In [1]:
class UnetArhitecture(nn.Module):
    def __init__(self):
        super().__init__()
        
        #Encoder
        self.e1 = nn.Sequential(
            nn.Conv2d(3,64,3,padding=1),nn.BatchNorm2d(64),nn.ReLU(),
            nn.Conv2d(64,64,3,padding=1),nn.BatchNorm2d(64),nn.ReLU()
        )
        
        self.max_pooling_e1 = nn.MaxPool2d(2) 
        
        self.e2 = nn.Sequential(
            nn.Conv2d(64,128,3,padding=1),nn.BatchNorm2d(128),nn.ReLU(),
            nn.Conv2d(128,128,3,padding=1),nn.BatchNorm2d(128),nn.ReLU()
        )
        
        self.max_pooling_e2 = nn.MaxPool2d(2) 
        
        self.e3 = nn.Sequential(
            nn.Conv2d(128,256,3,padding=1),nn.BatchNorm2d(256),nn.ReLU(),
            nn.Conv2d(256,256,3,padding=1),nn.BatchNorm2d(256),nn.ReLU()
        )
        
        #BottleNeck
        self.max_pooling_e3 = nn.MaxPool2d(2)
        self.bottle_neck = nn.Sequential(
            nn.Conv2d(256,512,3,padding=1),nn.BatchNorm2d(512),nn.ReLU(),
            nn.Conv2d(512,512,3,padding=1),nn.BatchNorm2d(512),nn.ReLU(),
            nn.ConvTranspose2d(512, 256, 2, stride=2)
        )
    
        self.d3 = nn.Sequential(
            nn.Conv2d(512,256,3,padding=1),nn.BatchNorm2d(256),nn.ReLU(), #at this part concat with e3 (256 from e3 + 256 from bottle_neck)
            nn.Conv2d(256,256,3,padding=1),nn.BatchNorm2d(256),nn.ReLU(),
            nn.ConvTranspose2d(256, 128, 2, stride=2)
        )
        
        self.d2 = nn.Sequential(
            nn.Conv2d(256,128,3,padding=1),nn.BatchNorm2d(128),nn.ReLU(), #at this part concat with e2 (128 from e2 + 128 from d3)
            nn.Conv2d(128,128,3,padding=1),nn.BatchNorm2d(128),nn.ReLU(),
            nn.ConvTranspose2d(128, 64, 2, stride=2)
        )
        
        self.d1 = nn.Sequential(
            nn.Conv2d(128,64,3,padding=1),nn.BatchNorm2d(64),nn.ReLU(), #at this part concat with e1 (64 from e1 + 64 from d2)
            nn.Conv2d(64,64,3,padding=1),nn.BatchNorm2d(64),nn.ReLU(),
            nn.Conv2d(64, 1, 1)
        )
    
    def forward(self, x):
        #print("Input:", x.shape)
        e1 = self.e1(x)
        #print("e1:", e1.shape)
        e2 = self.e2(self.max_pooling_e1(e1))
        #print("e2:", e2.shape)
        e3 = self.e3(self.max_pooling_e2(e2))
        #print("e3:", e3.shape)
        bottle_neck = self.bottle_neck(self.max_pooling_e3(e3))
        #print("bottleneck:", bottle_neck.shape)
        d3 = self.d3(torch.cat([e3, bottle_neck], dim=1))
        #print("d3:", d3.shape)
        d2 = self.d2(torch.cat([e2, d3], dim=1))
        #print("d2:", d2.shape)
        d1 = self.d1(torch.cat([e1, d2], dim=1))
        #print("d1 / Output:", d1.shape)
        return d1


NameError: name 'nn' is not defined

In [None]:
!pip install -q monai

In [None]:
!pip install segmentation-models-pytorch


In [None]:
import os
import sys
import cv2
import random
import numpy as np
import torch
import torch.nn as nn
from torch import optim
from torchvision import transforms
import torchvision.transforms.functional as TF
import torch.nn.functional as F
from monai.losses import DiceLoss
from monai.losses import HausdorffDTLoss
from monai.metrics import compute_average_surface_distance
from PIL import Image 
import matplotlib.pyplot as plt
import torch.optim as optim
import seaborn as sns
from torch.utils.data import Dataset
from segmentation_models_pytorch.losses import FocalLoss
from sklearn.metrics import classification_report, confusion_matrix


class CellDataSet:
    def __init__(self, batch_size=1):
        self.batch_size = batch_size
        self.path = "/kaggle/input/celldetection/ds1"
        self.test_path = os.path.join(self.path, "test")
        self.train_path = os.path.join(self.path, "train")
        self.validation_path = os.path.join(self.path, "validation")

        self.img_path = "img/cls"
        self.bin_mask_path = "bin_mask/cls"
        self.multi_mask_path = "mult_mask/cls"

        self.train = os.listdir(os.path.join(self.train_path, self.img_path))
        self.test = os.listdir(os.path.join(self.test_path, self.img_path))

        self.current_index = 0
        self.indices = random.sample(self.train, len(self.train))


    def load_sample(self, base_path, img_id, grayscale = False, scharr = False):
        try:
            image = Image.open(os.path.join(base_path, self.img_path, img_id))
            if not grayscale:
                image = image.convert("RGB")
                image = torch.tensor(np.array(image), dtype=torch.float32).permute(2, 0, 1) / 255.0
            else:
                image = image.convert("L")
                image_np = np.array(image)
                if scharr:
                    # Apply Scharr edge detection
                    scharr_x = cv2.Scharr(image_np, cv2.CV_32F, 1, 0)
                    scharr_y = cv2.Scharr(image_np, cv2.CV_32F, 0, 1)
                    magnitude = cv2.magnitude(scharr_x, scharr_y)
                    magnitude = cv2.normalize(magnitude, None, 0, 1, cv2.NORM_MINMAX)
                    image = torch.tensor(magnitude, dtype=torch.float32).unsqueeze(0)
                else:
                    image = torch.tensor(image_np, dtype=torch.float32).unsqueeze(0) / 255.0
            bin_mask = Image.open(os.path.join(base_path, self.bin_mask_path, img_id)).convert("L")
            mult_mask = Image.open(os.path.join(base_path, self.multi_mask_path, img_id))
            bin_mask = torch.tensor(np.array(bin_mask), dtype=torch.float32).unsqueeze(0) / 255.0
            mult_mask = torch.tensor(np.array(mult_mask), dtype=torch.long)

            return image, bin_mask, mult_mask
        except Exception as e:
            print(f"[ERROR] Failed to load sample '{img_id}' from {base_path}: {e}")
            return None, None, None

    def get_train_index(self, index, grayscale = False, scharr = False):
        img_id = self.train[index]
        return self.load_sample(self.train_path, img_id, grayscale = grayscale)

    def get_test_index(self, index, grayscale = False, scharr = False):
        img_id = self.test[index]
        return self.load_sample(self.test_path, img_id,  grayscale = grayscale, scharr = scharr)

    def get_validation_by_name(self, name, grayscale = False, scharr = False):
        return self.load_sample(self.validation_path, name,  grayscale = grayscale, scharr = scharr)

    def apply_augmentations(self, sample):
        image, bin_mask, mult_mask = sample
        choice = random.choice(["hflip", "vflip", "rotate"])
        if choice == "hflip":
            image = TF.hflip(image)
            bin_mask = TF.hflip(bin_mask)
            mult_mask = TF.hflip(mult_mask)
        elif choice == "vflip":
            image = TF.vflip(image)
            bin_mask = TF.vflip(bin_mask)
            mult_mask = TF.vflip(mult_mask)
        elif choice == "rotate":
            angle = random.uniform(-15, 15)
            image = TF.rotate(image, angle, fill=0)
            bin_mask = TF.rotate(bin_mask, angle, fill=0)
            if mult_mask.ndim == 2:
                mult_mask = mult_mask.unsqueeze(0)  # [1, H, W]
            mult_mask = TF.rotate(mult_mask, angle, fill=0)
            mult_mask = mult_mask.squeeze(0)

        return image, bin_mask, mult_mask



        

In [None]:
import gc
def train_model(model, dataset, num_epochs=10, device="cuda", grayscale=False,  scharr = False,
                patience=10,  scale_factor=1, save_path="best_model.pth"):
    model = model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=1e-4)
    

    def compute_loss(output, bin_mask):
        #focal_loss = FocalLoss('binary') 
        #bce = torch.nn.BCEWithLogitsLoss()
        dice_loss = DiceLoss(sigmoid=True)
        #hausdorff_loss= HausdorffDTLoss(sigmoid=True)
        #haus = hausdorff_loss(output, bin_mask) 
        return dice_loss(output,bin_mask) #0.5 * focal_loss(output,bin_mask) +  + 0.00002 * haus
           
    val_img_dir = os.path.join(dataset.validation_path, dataset.img_path)
    val_filenames = sorted(os.listdir(val_img_dir))
    best_val_loss = float("inf")
    epochs_without_improvement = 0

    for epoch in range(num_epochs):
        print(f"\n=== Epoch {epoch+1}/{num_epochs} ===")
        model.train()
        running_loss = 0.0
        batch_count = 0

        for idx in range(len(dataset.train)):
            # --- Original sample ---
            image, bin_mask, _ = dataset.get_train_index(idx, grayscale=grayscale, scharr=scharr)
            if image is None:
                continue

            orig_size = image.shape[-2:]  # Save original size

            if scale_factor != 1:
                image = F.interpolate(image.unsqueeze(0), scale_factor=1/scale_factor, mode='bilinear', align_corners=False)
            else:
                image = image.unsqueeze(0)
            
            bin_mask = bin_mask.unsqueeze(0)

            image = image.to(device)
            bin_mask = bin_mask.to(device)

            outputs = model(image)

            if scale_factor != 1:
                outputs = F.interpolate(outputs, size=orig_size, mode='bilinear', align_corners=False)
            
            loss = compute_loss(outputs, bin_mask)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            batch_count += 1

            # --- Augmented sample ---
            aug_image, aug_bin_mask, _ = dataset.apply_augmentations(
                dataset.get_train_index(idx, grayscale=grayscale, scharr=scharr)
            )

            orig_size = aug_image.shape[-2:]  # Save original size

            if scale_factor != 1:
                aug_image = F.interpolate(aug_image.unsqueeze(0), scale_factor=1/scale_factor, mode='bilinear', align_corners=False)
            else:
                aug_image = aug_image.unsqueeze(0)

            aug_bin_mask = aug_bin_mask.unsqueeze(0).to(device)

            aug_image = aug_image.to(device)

            outputs = model(aug_image)
            if scale_factor != 1:
                outputs = F.interpolate(outputs, size=orig_size, mode='bilinear', align_corners=False)
            loss = compute_loss(outputs, aug_bin_mask)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            batch_count += 1

            #if batch_count % 5 == 0:
            #    print(f"  [Train] Batch {batch_count} - Loss: {loss.item():.4f}")

        avg_train_loss = running_loss / batch_count
        print(f"[Train] Epoch {epoch+1} avg loss: {avg_train_loss:.4f}")
        gc.collect()
        # === Validation ===
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for name in val_filenames:
                image, bin_mask, _ = dataset.get_validation_by_name(name, grayscale=grayscale)
                if image is None:
                    continue

                orig_size = image.shape[-2:]

                if scale_factor != 1:
                    image = F.interpolate(image.unsqueeze(0), scale_factor=1/scale_factor, mode='bilinear', align_corners=False)
                else:
                    image = image.unsqueeze(0)
                bin_mask = bin_mask.unsqueeze(0)
                
                image = image.to(device)
                bin_mask = bin_mask.to(device)

                output = model(image)

                if scale_factor != 1:
                    output = F.interpolate(output, size=orig_size, mode='bilinear', align_corners=False)
                
                loss = compute_loss(output, bin_mask)
                val_loss += loss.item()

        avg_val_loss = val_loss / len(val_filenames)
        print(f"[Validation] Epoch {epoch+1} avg loss: {avg_val_loss:.4f}")

        # === Early Stopping ===
        if avg_val_loss < best_val_loss:
            print("✅ Validation loss improved — saving model.")
            best_val_loss = avg_val_loss
            torch.save(model.state_dict(), save_path)
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            print(f"No improvement ({epochs_without_improvement}/{patience} patience).")
            if epochs_without_improvement >= patience:
                print("🛑 Early stopping triggered.")
                break


In [None]:
#print("Before Dataset: {:.2f} MB".format(torch.cuda.memory_allocated() / 1024**2))
dataset = CellDataSet(batch_size=2)
#print("Before Unet: {:.2f} MB".format(torch.cuda.memory_allocated() / 1024**2))
model = UnetArhitecture()
#print("By start Memory Allocated: {:.2f} MB".format(torch.cuda.memory_allocated() / 1024**2))
train_model(model, dataset, num_epochs=20)

In [None]:
from monai.metrics import DiceMetric, HausdorffDistanceMetric, SurfaceDistanceMetric
from monai.transforms import AsDiscrete
import torch

def test_detector(dataset, model_class, model_path, grayscale=False, scharr=True, scale_factor=1, device="cuda"):
    model = model_class().to(device)
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()

    # MONAI metric utilities
    dice_metric = DiceMetric(include_background=False, reduction="mean")
    hausdorff = HausdorffDistanceMetric(include_background=False, percentile=95)
    surface_dice = SurfaceDistanceMetric(include_background=False, symmetric=True)

    post_pred = AsDiscrete(threshold=0.5)
    post_label = AsDiscrete(threshold=0.5)

    all_preds = []
    all_labels = []

    print("🧪 Running detection on test dataset...")
    with torch.no_grad():
        for i in range(len(dataset.test)):
            sample = dataset.get_test_index(i, grayscale=grayscale, scharr=scharr)
            if sample[0] is None:
                continue

            image, bin_mask, multi_mask = sample

            orig_size = image.shape[-2:]  # Save original size

            if scale_factor != 1:
                image = F.interpolate(image.unsqueeze(0), scale_factor=1/scale_factor, mode='bilinear', align_corners=False)
            else:
                image = image.unsqueeze(0)
            bin_mask = bin_mask.unsqueeze(0).to(device)
            image = image.to(device)
            
            pred = model(image)

            if scale_factor != 1:
                pred = F.interpolate(pred, size=orig_size, mode='bilinear', align_corners=False)
            
            pred_bin = post_pred(torch.sigmoid(pred))
            label_bin = post_label(bin_mask)

            all_preds.append(pred_bin)
            all_labels.append(label_bin)

            del image, bin_mask, pred
            torch.cuda.empty_cache()

    preds_tensor = torch.cat(all_preds, dim=0)
    labels_tensor = torch.cat(all_labels, dim=0)

    # Compute metrics
    dice_val = dice_metric(preds_tensor, labels_tensor).mean().item()
    hausdorff_val = hausdorff(preds_tensor, labels_tensor).mean().item()
    surface_dice_val = surface_dice(preds_tensor, labels_tensor).mean().item()

    print(f"\n📊 Metrics:")
    print(f"🎯 Dice Score: {dice_val:.4f}")
    print(f"📏 Hausdorff Distance (95th percentile): {hausdorff_val:.4f}")
    print(f"🌊 Normalized Surface Dice: {surface_dice_val:.4f}")


In [None]:
def show_test(dataset, model_class, model_path, grayscale=False, scale_factor=1, device="cuda"):
    model = model_class().to(device)
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()

    y_true, y_pred = [], []
    
    with torch.no_grad():
        for i in range(len(dataset.test)):
            sample = dataset.get_test_index(i, grayscale=grayscale)

            image, bin_mask, multi_mask = sample
            orig_size = image.shape[-2:]  # Save original size

            if scale_factor != 1:
                image = F.interpolate(image.unsqueeze(0), scale_factor=1/scale_factor, mode='bilinear', align_corners=False)
            else:
                image = image.unsqueeze(0)
            image = image.to(device)
            pred = model(image)

            if scale_factor != 1:
                pred = F.interpolate(pred, size=orig_size, mode='bilinear', align_corners=False)

            pred_bin = (torch.sigmoid(pred).squeeze() > 0.5).cpu().numpy().astype(np.uint8)
            true_bin = bin_mask.squeeze().cpu().numpy().astype(np.uint8)

            plt.figure(figsize=(10, 5))

            # Predicted mask
            plt.subplot(1, 2, 1)
            plt.imshow(pred_bin, cmap='gray')
            plt.title('Predicted Binary Mask')
            plt.axis('off')
            
            # Ground truth mask
            plt.subplot(1, 2, 2)
            plt.imshow(true_bin, cmap='gray')
            plt.title('Ground Truth Binary Mask')
            plt.axis('off')
            
            plt.tight_layout()
            plt.show()


In [None]:


# === CONFIG ===
train_dir = "/kaggle/input/celldetection/ds1/train"
detector_model_path = "/kaggle/working/best_model.pth"
epochs = 25

# === LOAD DATASET ===
dataset = CellDataset(train_dir)


show_test(dataset, model_class=UnetArhitecture, model_path=detector_model_path)

In [None]:
class SmallerUnet(nn.Module):
    def __init__(self):
        super().__init__()

        # Encoder
        self.e1 = nn.Sequential(
            nn.Conv2d(3, 32, 3, padding=1), nn.BatchNorm2d(32), nn.ReLU(),
            nn.Conv2d(32, 32, 3, padding=1), nn.BatchNorm2d(32), nn.ReLU()
        )
        self.max_pooling_e1 = nn.MaxPool2d(2)

        self.e2 = nn.Sequential(
            nn.Conv2d(32, 48, 3, padding=1), nn.BatchNorm2d(48), nn.ReLU(),
            nn.Conv2d(48, 48, 3, padding=1), nn.BatchNorm2d(48), nn.ReLU()
        )
        self.max_pooling_e2 = nn.MaxPool2d(2)

        self.e3 = nn.Sequential(
            nn.Conv2d(48, 64, 3, padding=1), nn.BatchNorm2d(64), nn.ReLU(),
            nn.Conv2d(64, 64, 3, padding=1), nn.BatchNorm2d(64), nn.ReLU()
        )
        self.max_pooling_e3 = nn.MaxPool2d(2)

        # Bottleneck
        self.bottle_neck = nn.Sequential(
            nn.Conv2d(64, 96, 3, padding=1), nn.BatchNorm2d(96), nn.ReLU(),
            nn.Conv2d(96, 96, 3, padding=1), nn.BatchNorm2d(96), nn.ReLU(),
            nn.ConvTranspose2d(96, 64, 2, stride=2)
        )

        # Decoder
        self.d3 = nn.Sequential(
            nn.Conv2d(64 + 64, 64, 3, padding=1), nn.BatchNorm2d(64), nn.ReLU(),
            nn.Conv2d(64, 64, 3, padding=1), nn.BatchNorm2d(64), nn.ReLU(),
            nn.ConvTranspose2d(64, 48, 2, stride=2)
        )

        self.d2 = nn.Sequential(
            nn.Conv2d(48 + 48, 48, 3, padding=1), nn.BatchNorm2d(48), nn.ReLU(),
            nn.Conv2d(48, 48, 3, padding=1), nn.BatchNorm2d(48), nn.ReLU(),
            nn.ConvTranspose2d(48, 32, 2, stride=2)
        )

        self.d1 = nn.Sequential(
            nn.Conv2d(32 + 32, 32, 3, padding=1), nn.BatchNorm2d(32), nn.ReLU(),
            nn.Conv2d(32, 32, 3, padding=1), nn.BatchNorm2d(32), nn.ReLU(),
            nn.Conv2d(32, 1, 1)
        )

    def forward(self, x):
        e1 = self.e1(x)
        e2 = self.e2(self.max_pooling_e1(e1))
        e3 = self.e3(self.max_pooling_e2(e2))
        bottle = self.bottle_neck(self.max_pooling_e3(e3))

        d3 = self.d3(torch.cat([e3, bottle], dim=1))
        d2 = self.d2(torch.cat([e2, d3], dim=1))
        d1 = self.d1(torch.cat([e1, d2], dim=1))

        return d1


In [None]:
#print("Before Dataset: {:.2f} MB".format(torch.cuda.memory_allocated() / 1024**2))
dataset = CellDataSet(batch_size=2)
#print("Before Unet: {:.2f} MB".format(torch.cuda.memory_allocated() / 1024**2))
model = SmallerUnet()
#print("By start Memory Allocated: {:.2f} MB".format(torch.cuda.memory_allocated() / 1024**2))
train_model(model, dataset, num_epochs=40, save_path = "smaller_unet.pth")

In [None]:
def test_detector(dataset, model_path):
    device="cuda"
    model = SmallerUnet().to(device)
    model.load_state_dict(torch.load(model_path))
    model.eval()

    y_true, y_pred = [], []

    with torch.no_grad():
        for i in range(len(dataset)):
            image, bin_mask, _ = dataset[i]
            image = image.unsqueeze(0).to(device)
            pred = model(image)

            pred_bin = (torch.sigmoid(pred).squeeze() > 0.5).cpu().numpy().astype(np.uint8)
            true_bin = bin_mask.squeeze().cpu().numpy().astype(np.uint8)

            y_true.extend(true_bin.flatten())
            y_pred.extend(pred_bin.flatten())

            del image, bin_mask, pred
            torch.cuda.empty_cache()

    print("\n📊 Detection Report:")
    print(classification_report(y_true, y_pred, target_names=["Background", "Cell"], zero_division=0))

    cm = confusion_matrix(y_true, y_pred)
    sns.heatmap(cm, annot=True, fmt='d', cmap='Greens',
                xticklabels=["Pred 0", "Pred 1"], yticklabels=["True 0", "True 1"])
    plt.xlabel("Predicted")
    plt.ylabel("True")
    plt.title("Confusion Matrix - Cell Detection")
    plt.tight_layout()
    plt.show()



train_dir = "/kaggle/input/celldetection/ds1/train"
detector_model_path = "/kaggle/working/smaller_unet.pth"
#epochs = 25

# === LOAD DATASET ===
dataset = CellDataset(train_dir)


test_detector(dataset, detector_model_path)

In [None]:
class GrayScaleSmallerUnet(nn.Module):
    def __init__(self):
        super().__init__()

        # Encoder
        self.e1 = nn.Sequential(
            nn.Conv2d(1, 32, 3, padding=1), nn.BatchNorm2d(32), nn.ReLU(),
            nn.Conv2d(32, 32, 3, padding=1), nn.BatchNorm2d(32), nn.ReLU()
        )
        self.max_pooling_e1 = nn.MaxPool2d(2)

        self.e2 = nn.Sequential(
            nn.Conv2d(32, 48, 3, padding=1), nn.BatchNorm2d(48), nn.ReLU(),
            nn.Conv2d(48, 48, 3, padding=1), nn.BatchNorm2d(48), nn.ReLU()
        )
        self.max_pooling_e2 = nn.MaxPool2d(2)

        self.e3 = nn.Sequential(
            nn.Conv2d(48, 64, 3, padding=1), nn.BatchNorm2d(64), nn.ReLU(),
            nn.Conv2d(64, 64, 3, padding=1), nn.BatchNorm2d(64), nn.ReLU()
        )
        self.max_pooling_e3 = nn.MaxPool2d(2)

        # Bottleneck
        self.bottle_neck = nn.Sequential(
            nn.Conv2d(64, 96, 3, padding=1), nn.BatchNorm2d(96), nn.ReLU(),
            nn.Conv2d(96, 96, 3, padding=1), nn.BatchNorm2d(96), nn.ReLU(),
            nn.ConvTranspose2d(96, 64, 2, stride=2)
        )

        # Decoder
        self.d3 = nn.Sequential(
            nn.Conv2d(64 + 64, 64, 3, padding=1), nn.BatchNorm2d(64), nn.ReLU(),
            nn.Conv2d(64, 64, 3, padding=1), nn.BatchNorm2d(64), nn.ReLU(),
            nn.ConvTranspose2d(64, 48, 2, stride=2)
        )

        self.d2 = nn.Sequential(
            nn.Conv2d(48 + 48, 48, 3, padding=1), nn.BatchNorm2d(48), nn.ReLU(),
            nn.Conv2d(48, 48, 3, padding=1), nn.BatchNorm2d(48), nn.ReLU(),
            nn.ConvTranspose2d(48, 32, 2, stride=2)
        )

        self.d1 = nn.Sequential(
            nn.Conv2d(32 + 32, 32, 3, padding=1), nn.BatchNorm2d(32), nn.ReLU(),
            nn.Conv2d(32, 32, 3, padding=1), nn.BatchNorm2d(32), nn.ReLU(),
            nn.Conv2d(32, 1, 1)
        )

    def forward(self, x):
        e1 = self.e1(x)
        e2 = self.e2(self.max_pooling_e1(e1))
        e3 = self.e3(self.max_pooling_e2(e2))
        bottle = self.bottle_neck(self.max_pooling_e3(e3))

        d3 = self.d3(torch.cat([e3, bottle], dim=1))
        d2 = self.d2(torch.cat([e2, d3], dim=1))
        d1 = self.d1(torch.cat([e1, d2], dim=1))

        return d1

In [None]:
dataset = CellDataSet()
model = GrayScaleSmallerUnet()
train_model(model, dataset, num_epochs=40, grayscale=True, save_path = "smaller_unet.pth")


In [None]:
test_detector(dataset, model_class=GrayScaleSmallerUnet, model_path="smaller_unet.pth", grayscale=True)
show_test(dataset, model_class=GrayScaleSmallerUnet, model_path="smaller_unet.pth", grayscale=True)

In [None]:
dataset = CellDataSet()
test_detector(dataset, model_class=UnetArhitecture, model_path="best_model.pth")


In [None]:
test_detector(dataset, model_class=GrayScaleSmallerUnet, model_path="smaller_unet.pth", grayscale=True)

In [None]:
class GrayScaleMediumUnet(nn.Module):
    def __init__(self):
        super().__init__()

        # Encoder
        self.e1 = nn.Sequential(
            nn.Conv2d(1, 32, 3, padding=1), nn.BatchNorm2d(32), nn.ReLU(),
            nn.Conv2d(32, 32, 3, padding=1), nn.BatchNorm2d(32), nn.ReLU()
        )
        self.max_pooling_e1 = nn.MaxPool2d(2)

        self.e2 = nn.Sequential(
            nn.Conv2d(32, 64, 3, padding=1), nn.BatchNorm2d(64), nn.ReLU(),
            nn.Conv2d(64, 64, 3, padding=1), nn.BatchNorm2d(64), nn.ReLU()
        )
        self.max_pooling_e2 = nn.MaxPool2d(2)

        self.e3 = nn.Sequential(
            nn.Conv2d(64, 128, 3, padding=1), nn.BatchNorm2d(128), nn.ReLU(),
            nn.Conv2d(128, 128, 3, padding=1), nn.BatchNorm2d(128), nn.ReLU()
        )
        self.max_pooling_e3 = nn.MaxPool2d(2)

        # Bottleneck
        self.bottle_neck = nn.Sequential(
            nn.Conv2d(128, 256, 3, padding=1), nn.BatchNorm2d(256), nn.ReLU(),
            nn.Conv2d(256, 256, 3, padding=1), nn.BatchNorm2d(256), nn.ReLU(),
            nn.ConvTranspose2d(256, 128, 2, stride=2)
        )

        # Decoder
        self.d3 = nn.Sequential(
            nn.Conv2d(128 + 128, 128, 3, padding=1), nn.BatchNorm2d(128), nn.ReLU(),
            nn.Conv2d(128, 128, 3, padding=1), nn.BatchNorm2d(128), nn.ReLU(),
            nn.ConvTranspose2d(128, 64, 2, stride=2)
        )

        self.d2 = nn.Sequential(
            nn.Conv2d(64 + 64, 64, 3, padding=1), nn.BatchNorm2d(64), nn.ReLU(),
            nn.Conv2d(64, 64, 3, padding=1), nn.BatchNorm2d(64), nn.ReLU(),
            nn.ConvTranspose2d(64, 32, 2, stride=2)
        )

        self.d1 = nn.Sequential(
            nn.Conv2d(32 + 32, 32, 3, padding=1), nn.BatchNorm2d(32), nn.ReLU(),
            nn.Conv2d(32, 32, 3, padding=1), nn.BatchNorm2d(32), nn.ReLU(),
            nn.Conv2d(32, 1, 1)
        )

    def forward(self, x):
        e1 = self.e1(x)
        e2 = self.e2(self.max_pooling_e1(e1))
        e3 = self.e3(self.max_pooling_e2(e2))
        bottle = self.bottle_neck(self.max_pooling_e3(e3))

        d3 = self.d3(torch.cat([e3, bottle], dim=1))
        d2 = self.d2(torch.cat([e2, d3], dim=1))
        d1 = self.d1(torch.cat([e1, d2], dim=1))

        return d1

In [None]:
dataset = CellDataSet()
model = GrayScaleMediumUnet()
train_model(model, dataset, num_epochs=80, grayscale=True, save_path = "dice_medium.pth")

In [None]:
test_detector(dataset, model_class=GrayScaleMediumUnet, model_path="dice_medium.pth", grayscale=True)
show_test(dataset, model_class=GrayScaleMediumUnet, model_path="dice_medium.pth", grayscale=True)

In [None]:
dataset = CellDataSet()
model = UnetArhitecture()
train_model(model, dataset, num_epochs=80, save_path = "dice_unet.pth")

In [None]:
dataset = CellDataSet()
test_detector(dataset, model_class=UnetArhitecture, model_path="dice_unet.pth")
show_test(dataset, model_class=UnetArhitecture, model_path="dice_unet.pth")

In [None]:
dataset = CellDataSet()
model = GrayScaleMediumUnet()
train_model(model, dataset, num_epochs=80, grayscale=True, save_path = "dice_HDmedium.pth")

In [None]:
dataset = CellDataSet()
test_detector(dataset, model_class=GrayScaleMediumUnet, model_path="dice_HDmedium.pth",grayscale=True)
show_test(dataset, model_class=GrayScaleMediumUnet, model_path="dice_HDmedium.pth",grayscale=True)

In [None]:
model = GrayScaleMediumUnet()
model.load_state_dict(torch.load("dice_HDmedium.pth"))  # Load weights

train_model(model, dataset, num_epochs=40, grayscale=True, save_path="dice_HDmedium2.pth")

In [None]:
dataset = CellDataSet()
test_detector(dataset, model_class=GrayScaleMediumUnet, model_path="dice_HDmedium.pth",grayscale=True)
show_test(dataset, model_class=GrayScaleMediumUnet, model_path="dice_HDmedium.pth",grayscale=True)

In [None]:
dataset = CellDataSet()
model = GrayScaleMediumUnet()
#0,3 cross 0.6 dice 0,001 haus
train_model(model, dataset, num_epochs=80, grayscale=True, save_path = "dice_MultiLossMedium.pth")

In [None]:
dataset = CellDataSet()
test_detector(dataset, model_class=GrayScaleMediumUnet, model_path="dice_MultiLossMedium.pth",grayscale=True)
show_test(dataset, model_class=GrayScaleMediumUnet, model_path="dice_MultiLossMedium.pth",grayscale=True)

In [None]:
model = GrayScaleMediumUnet()
model.load_state_dict(torch.load("dice_MultiLossMedium.pth"))  # Load weights

train_model(model, dataset, num_epochs=80, grayscale=True, save_path="dice_MultiLossMedium.pth")
test_detector(dataset, model_class=GrayScaleMediumUnet, model_path="dice_MultiLossMedium.pth",grayscale=True)
show_test(dataset, model_class=GrayScaleMediumUnet, model_path="dice_MultiLossMedium.pth",grayscale=True)

In [None]:
dataset = CellDataSet()
model = GrayScaleMediumUnet()
#0,7 cross 0.3 dice
train_model(model, dataset, num_epochs=80, grayscale=True, save_path = "GrayCr07Dice03Medium.pth")

In [None]:
dataset = CellDataSet()
test_detector(dataset, model_class=GrayScaleMediumUnet, model_path="GrayCr07Dice03Medium.pth",grayscale=True)
show_test(dataset, model_class=GrayScaleMediumUnet, model_path="GrayCr07Dice03Medium.pth",grayscale=True)

In [None]:
test_detector(dataset, model_class=GrayScaleSmallerUnet, model_path="smaller_unet.pth",grayscale=True)
show_test(dataset, model_class=GrayScaleSmallerUnet, model_path="smaller_unet.pth",grayscale=True)

In [None]:
model = GrayScaleSmallerUnet()
#0,7 cross 0.3 dice
train_model(model, dataset, num_epochs=80, grayscale=True, save_path = "GrayCr07Dice03Small.pth")

In [None]:
test_detector(dataset, model_class=GrayScaleSmallerUnet, model_path="GrayCr07Dice03Small.pth",grayscale=True)
show_test(dataset, model_class=GrayScaleSmallerUnet, model_path="GrayCr07Dice03Small.pth",grayscale=True)

In [None]:
model = GrayScaleSmallerUnet()
dataset = CellDataSet()
train_model(model, dataset, num_epochs=80, grayscale=True, save_path = "GrayCr03Dice06H0001Small.pth")

In [None]:
dataset = CellDataSet()
test_detector(dataset, model_class=GrayScaleSmallerUnet, model_path="GrayCr03Dice06H0001Small.pth",grayscale=True)
show_test(dataset, model_class=GrayScaleSmallerUnet, model_path="GrayCr03Dice06H0001Small.pth",grayscale=True)

In [None]:
model = SmallerUnet()
#0,7 cross 0.3 dice
train_model(model, dataset, num_epochs=80, save_path = "Cr07Dice03Small.pth")

In [None]:
dataset = CellDataSet()
test_detector(dataset, model_class=SmallerUnet, model_path="Cr07Dice03Small.pth")
show_test(dataset, model_class=SmallerUnet, model_path="Cr07Dice03Small.pth")

In [None]:
class Layer4GrayScaleSmallerUnet(nn.Module):
    def __init__(self):
        super().__init__()

        # Encoder
        self.e1 = nn.Sequential(
            nn.Conv2d(1, 32, 3, padding=1), nn.BatchNorm2d(32), nn.ReLU(),
            nn.Conv2d(32, 32, 3, padding=1), nn.BatchNorm2d(32), nn.ReLU()
        )
        self.max_pooling_e1 = nn.MaxPool2d(2)

        self.e2 = nn.Sequential(
            nn.Conv2d(32, 48, 3, padding=1), nn.BatchNorm2d(48), nn.ReLU(),
            nn.Conv2d(48, 48, 3, padding=1), nn.BatchNorm2d(48), nn.ReLU()
        )
        self.max_pooling_e2 = nn.MaxPool2d(2)

        self.e3 = nn.Sequential(
            nn.Conv2d(48, 64, 3, padding=1), nn.BatchNorm2d(64), nn.ReLU(),
            nn.Conv2d(64, 64, 3, padding=1), nn.BatchNorm2d(64), nn.ReLU()
        )
        self.max_pooling_e3 = nn.MaxPool2d(2)

        self.e4 = nn.Sequential(
            nn.Conv2d(64, 80, 3, padding=1), nn.BatchNorm2d(80), nn.ReLU(),
            nn.Conv2d(80, 80, 3, padding=1), nn.BatchNorm2d(80), nn.ReLU()
        )
        self.max_pooling_e4 = nn.MaxPool2d(2)

        # Bottleneck
        self.bottle_neck = nn.Sequential(
            nn.Conv2d(80, 96, 3, padding=1), nn.BatchNorm2d(96), nn.ReLU(),
            nn.Conv2d(96, 96, 3, padding=1), nn.BatchNorm2d(96), nn.ReLU(),
            nn.ConvTranspose2d(96, 80, 2, stride=2)
        )

        # Decoder
        self.d4 = nn.Sequential(
            nn.Conv2d(80 + 80, 80, 3, padding=1), nn.BatchNorm2d(80), nn.ReLU(),
            nn.Conv2d(80, 80, 3, padding=1), nn.BatchNorm2d(80), nn.ReLU(),
            nn.ConvTranspose2d(80, 64, 2, stride=2)
        )

        self.d3 = nn.Sequential(
            nn.Conv2d(64 + 64, 64, 3, padding=1), nn.BatchNorm2d(64), nn.ReLU(),
            nn.Conv2d(64, 64, 3, padding=1), nn.BatchNorm2d(64), nn.ReLU(),
            nn.ConvTranspose2d(64, 48, 2, stride=2)
        )

        self.d2 = nn.Sequential(
            nn.Conv2d(48 + 48, 48, 3, padding=1), nn.BatchNorm2d(48), nn.ReLU(),
            nn.Conv2d(48, 48, 3, padding=1), nn.BatchNorm2d(48), nn.ReLU(),
            nn.ConvTranspose2d(48, 32, 2, stride=2)
        )

        self.d1 = nn.Sequential(
            nn.Conv2d(32 + 32, 32, 3, padding=1), nn.BatchNorm2d(32), nn.ReLU(),
            nn.Conv2d(32, 32, 3, padding=1), nn.BatchNorm2d(32), nn.ReLU(),
            nn.Conv2d(32, 1, 1)
        )

    def forward(self, x):
        e1 = self.e1(x)
        e2 = self.e2(self.max_pooling_e1(e1))
        e3 = self.e3(self.max_pooling_e2(e2))
        e4 = self.e4(self.max_pooling_e3(e3))
        bottle = self.bottle_neck(self.max_pooling_e4(e4))

        d4 = self.d4(torch.cat([e4, bottle], dim=1))
        d3 = self.d3(torch.cat([e3, d4], dim=1))
        d2 = self.d2(torch.cat([e2, d3], dim=1))
        d1 = self.d1(torch.cat([e1, d2], dim=1))

        return d1

In [None]:
model = Layer4GrayScaleSmallerUnet()
dataset = CellDataSet()
#0,7 cross 0.3 dice
train_model(model, dataset, num_epochs=80, save_path = "Cr07Dice03L4Small.pth",grayscale=True)

In [None]:
dataset = CellDataSet()
model = Layer4GrayScaleSmallerUnet()
model.load_state_dict(torch.load("Cr07Dice03L4Small.pth"))  # Load weights

train_model(model, dataset, num_epochs=80, grayscale=True, save_path="Cr07Dice03L4Small.pth")

In [None]:
dataset = CellDataSet()
test_detector(dataset, model_class=Layer4GrayScaleSmallerUnet, model_path="Cr07Dice03L4Small.pth", grayscale=True)
show_test(dataset, model_class=Layer4GrayScaleSmallerUnet, model_path="Cr07Dice03L4Small.pth", grayscale=True)

In [None]:
dataset = CellDataSet()
model = Layer4GrayScaleSmallerUnet()
model.load_state_dict(torch.load("Cr07Dice03L4Small2.pth")) 

train_model(model, dataset, num_epochs=80, grayscale=True, save_path="Cr07Dice03L4Small2.pth")
test_detector(dataset, model_class=Layer4GrayScaleSmallerUnet, model_path="Cr07Dice03L4Small2.pth", grayscale=True)
show_test(dataset, model_class=Layer4GrayScaleSmallerUnet, model_path="Cr07Dice03L4Small2.pth", grayscale=True)

In [None]:
dataset = CellDataSet()
model = Layer4GrayScaleSmallerUnet()
model.load_state_dict(torch.load("CDH05_0005.pth"))
test_detector(dataset, model_class=Layer4GrayScaleSmallerUnet, model_path="CDH05_0005.pth", grayscale=True)
show_test(dataset, model_class=Layer4GrayScaleSmallerUnet, model_path="CDH05_0005.pth", grayscale=True)
train_model(model, dataset, num_epochs=80, grayscale=True, save_path="CDH05_0005.pth")

In [None]:
dataset = CellDataSet()
model = GrayScaleSmallerUnet()

train_model(model, dataset, num_epochs=80, grayscale=True, save_path="F05Dice05H0002L4Small2.pth")
test_detector(dataset, model_class=GrayScaleSmallerUnet, model_path="F05Dice05H0002L4Small2.pth", grayscale=True)
show_test(dataset, model_class=GrayScaleSmallerUnet, model_path="F05Dice05H0002L4Small2.pth", grayscale=True)

In [None]:
dataset = CellDataSet()
model = GrayScaleSmallerUnet()

train_model(model, dataset, num_epochs=80, grayscale=True, save_path="F07Dice03H0002L4Small2.pth")
test_detector(dataset, model_class=GrayScaleSmallerUnet, model_path="F07Dice03H0002L4Small2.pth", grayscale=True)
show_test(dataset, model_class=GrayScaleSmallerUnet, model_path="F07Dice03H0002L4Small2.pth", grayscale=True)

In [None]:
dataset = CellDataSet()
model = GrayScaleSmallerUnet()
model.load_state_dict(torch.load("F07Dice03H0002L4Small2.pth"))
test_detector(dataset, model_class=GrayScaleSmallerUnet, model_path="F07Dice03H0002L4Small2.pth", grayscale=True)
show_test(dataset, model_class=GrayScaleSmallerUnet, model_path="F07Dice03H0002L4Small2.pth", grayscale=True)
train_model(model, dataset, num_epochs=80, grayscale=True, save_path="F07Dice03H0002L4Small2.pth")
test_detector(dataset, model_class=GrayScaleSmallerUnet, model_path="F07Dice03H0002L4Small2.pth", grayscale=True)
show_test(dataset, model_class=GrayScaleSmallerUnet, model_path="F07Dice03H0002L4Small2.pth", grayscale=True)

In [None]:
dataset = CellDataSet()
model = GrayScaleSmallerUnet()

train_model(model, dataset, num_epochs=80, grayscale=True, save_path="F05Dice05BC052.pth")
test_detector(dataset, model_class=GrayScaleSmallerUnet, model_path="F05Dice05BC052.pth", grayscale=True)
show_test(dataset, model_class=GrayScaleSmallerUnet, model_path="F05Dice05BC052.pth", grayscale=True)

In [None]:
dataset = CellDataSet()
model = GrayScaleSmallerUnet()

# Antrenare
train_model(model, dataset, num_epochs=80, grayscale=True, scale_factor=2, save_path="SC2F05Dice05BC052.pth")




In [None]:
# Testare
test_detector(dataset, model_class=GrayScaleSmallerUnet, model_path="SC2F05Dice05BC052.pth", grayscale=True, scale_factor=2)
show_test(dataset, model_class=GrayScaleSmallerUnet, model_path="SC2F05Dice05BC052.pth", grayscale=True, scale_factor=2)

In [None]:
dataset = CellDataSet()
model = GrayScaleSmallerUnet()

train_model(model, dataset, num_epochs=80, grayscale=True, save_path="F07Dice015BC015.pth")
test_detector(dataset, model_class=GrayScaleSmallerUnet, model_path="F07Dice015BC015.pth", grayscale=True)
show_test(dataset, model_class=GrayScaleSmallerUnet, model_path="F07Dice015BC015.pth", grayscale=True)

In [None]:
class GrUnetArhitecture(nn.Module):
    def __init__(self):
        super().__init__()
        
        #Encoder
        self.e1 = nn.Sequential(
            nn.Conv2d(1,64,3,padding=1),nn.BatchNorm2d(64),nn.ReLU(),
            nn.Conv2d(64,64,3,padding=1),nn.BatchNorm2d(64),nn.ReLU()
        )
        
        self.max_pooling_e1 = nn.MaxPool2d(2) 
        
        self.e2 = nn.Sequential(
            nn.Conv2d(64,128,3,padding=1),nn.BatchNorm2d(128),nn.ReLU(),
            nn.Conv2d(128,128,3,padding=1),nn.BatchNorm2d(128),nn.ReLU()
        )
        
        self.max_pooling_e2 = nn.MaxPool2d(2) 
        
        self.e3 = nn.Sequential(
            nn.Conv2d(128,256,3,padding=1),nn.BatchNorm2d(256),nn.ReLU(),
            nn.Conv2d(256,256,3,padding=1),nn.BatchNorm2d(256),nn.ReLU()
        )
        
        #BottleNeck
        self.max_pooling_e3 = nn.MaxPool2d(2)
        self.bottle_neck = nn.Sequential(
            nn.Conv2d(256,512,3,padding=1),nn.BatchNorm2d(512),nn.ReLU(),
            nn.Conv2d(512,512,3,padding=1),nn.BatchNorm2d(512),nn.ReLU(),
            nn.ConvTranspose2d(512, 256, 2, stride=2)
        )
    
        self.d3 = nn.Sequential(
            nn.Conv2d(512,256,3,padding=1),nn.BatchNorm2d(256),nn.ReLU(), #at this part concat with e3 (256 from e3 + 256 from bottle_neck)
            nn.Conv2d(256,256,3,padding=1),nn.BatchNorm2d(256),nn.ReLU(),
            nn.ConvTranspose2d(256, 128, 2, stride=2)
        )
        
        self.d2 = nn.Sequential(
            nn.Conv2d(256,128,3,padding=1),nn.BatchNorm2d(128),nn.ReLU(), #at this part concat with e2 (128 from e2 + 128 from d3)
            nn.Conv2d(128,128,3,padding=1),nn.BatchNorm2d(128),nn.ReLU(),
            nn.ConvTranspose2d(128, 64, 2, stride=2)
        )
        
        self.d1 = nn.Sequential(
            nn.Conv2d(128,64,3,padding=1),nn.BatchNorm2d(64),nn.ReLU(), #at this part concat with e1 (64 from e1 + 64 from d2)
            nn.Conv2d(64,64,3,padding=1),nn.BatchNorm2d(64),nn.ReLU(),
            nn.Conv2d(64, 1, 1)
        )
    
    def forward(self, x):
        #print("Input:", x.shape)
        e1 = self.e1(x)
        #print("e1:", e1.shape)
        e2 = self.e2(self.max_pooling_e1(e1))
        #print("e2:", e2.shape)
        e3 = self.e3(self.max_pooling_e2(e2))
        #print("e3:", e3.shape)
        bottle_neck = self.bottle_neck(self.max_pooling_e3(e3))
        #print("bottleneck:", bottle_neck.shape)
        d3 = self.d3(torch.cat([e3, bottle_neck], dim=1))
        #print("d3:", d3.shape)
        d2 = self.d2(torch.cat([e2, d3], dim=1))
        #print("d2:", d2.shape)
        d1 = self.d1(torch.cat([e1, d2], dim=1))
        #print("d1 / Output:", d1.shape)
        return d1


In [None]:
dataset = CellDataSet()
model = GrUnetArhitecture()

train_model(model, dataset, num_epochs=80, grayscale=True, save_path="BUGF07Dice015BC015.pth")
test_detector(dataset, model_class=GrUnetArhitecture, model_path="BUGF07Dice015BC015.pth", grayscale=True)
show_test(dataset, model_class=GrUnetArhitecture, model_path="BUGF07Dice015BC015.pth", grayscale=True)

In [None]:
dataset = CellDataSet()
model = GrUnetArhitecture()

train_model(model, dataset, num_epochs=80, grayscale=True, save_path="BUGF05Dice05BC05.pth")
test_detector(dataset, model_class=GrUnetArhitecture, model_path="BUGF05Dice05BC05.pth", grayscale=True)
show_test(dataset, model_class=GrUnetArhitecture, model_path="BUGF05Dice05BC05.pth", grayscale=True)

In [None]:
dataset = CellDataSet()
model = GrUnetArhitecture()

train_model(model, dataset, num_epochs=80, grayscale=True, scale_factor=2,save_path="SC2BUGF05Dice05BC05.pth")
test_detector(dataset, model_class=GrUnetArhitecture,scale_factor=2, model_path="SC2BUGF05Dice05BC05.pth", grayscale=True)
show_test(dataset, model_class=GrUnetArhitecture,scale_factor=2, model_path="SC2BUGF05Dice05BC05.pth", grayscale=True)

In [None]:
dataset = CellDataSet()
model = GrUnetArhitecture()

train_model(model, dataset, num_epochs=80, grayscale=True, scale_factor=2,save_path="SC2BUGF015Dice07BC015.pth")
test_detector(dataset, model_class=GrUnetArhitecture,scale_factor=2, model_path="SC2BUGF015Dice07BC015.pth", grayscale=True)
show_test(dataset, model_class=GrUnetArhitecture,scale_factor=2, model_path="SC2BUGF015Dice07BC015.pth", grayscale=True)

In [None]:
from monai.metrics import DiceMetric, HausdorffDistanceMetric, SurfaceDistanceMetric
from monai.transforms import AsDiscrete
import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix
import torch.nn.functional as F

def test_detector(dataset, model_class, model_path, grayscale=False, scharr=False, scale_factor=1, device="cuda"):
    model = model_class().to(device)
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()
    
    # MONAI metric utilities
    dice_metric = DiceMetric(include_background=False, reduction="none")  # Changed to "none" to get per-image scores
    hausdorff = HausdorffDistanceMetric(include_background=False, percentile=95, reduction="none")
    surface_dice = SurfaceDistanceMetric(include_background=False, symmetric=True, reduction="none")
    post_pred = AsDiscrete(threshold=0.5)
    post_label = AsDiscrete(threshold=0.5)
    
    all_preds = []
    all_labels = []
    all_multi_masks = []  # Store multi-class masks for confusion matrix
    
    # Lists to store per-image metrics for visualization
    dice_scores = []
    hausdorff_distances = []
    surface_dice_scores = []
    
    print("🧪 Running detection on test dataset...")
    with torch.no_grad():
        for i in range(len(dataset.test)):
            sample = dataset.get_test_index(i, grayscale=grayscale, scharr=scharr)
            if sample[0] is None:
                continue
            
            image, bin_mask, multi_mask = sample
            orig_size = image.shape[-2:]  # Save original size
            
            if scale_factor != 1:
                image = F.interpolate(image.unsqueeze(0), scale_factor=1/scale_factor, mode='bilinear', align_corners=False)
            else:
                image = image.unsqueeze(0)
            
            bin_mask = bin_mask.unsqueeze(0).to(device)
            image = image.to(device)
            
            pred = model(image)
            if scale_factor != 1:
                pred = F.interpolate(pred, size=orig_size, mode='bilinear', align_corners=False)
            
            pred_bin = post_pred(torch.sigmoid(pred))
            label_bin = post_label(bin_mask)
            
            # Calculate per-image metrics
            dice_val = dice_metric(pred_bin, label_bin).item()
            hausdorff_val = hausdorff(pred_bin, label_bin).item()
            surface_dice_val = surface_dice(pred_bin, label_bin).item()
            
            # Store metrics
            dice_scores.append(dice_val)
            hausdorff_distances.append(hausdorff_val)
            surface_dice_scores.append(surface_dice_val)
            
            all_preds.append(pred_bin)
            all_labels.append(label_bin)
            all_multi_masks.append(multi_mask.unsqueeze(0))  # Keep multi-class mask
            
            del image, bin_mask, pred
            torch.cuda.empty_cache()
    
    preds_tensor = torch.cat(all_preds, dim=0)
    labels_tensor = torch.cat(all_labels, dim=0)
    multi_masks_tensor = torch.cat(all_multi_masks, dim=0)
    
    # Compute overall metrics (mean of per-image metrics)
    mean_dice = np.mean(dice_scores)
    mean_hausdorff = np.mean([h for h in hausdorff_distances if not np.isinf(h)])
    mean_surface_dice = np.mean(surface_dice_scores)
    
    print(f"\n📊 Standard Metrics:")
    print(f"🎯 Dice Score: {mean_dice:.4f} ± {np.std(dice_scores):.4f}")
    print(f"📏 Hausdorff Distance (95th percentile): {mean_hausdorff:.4f} ± {np.std([h for h in hausdorff_distances if not np.isinf(h)]):.4f}")
    print(f"🌊 Normalized Surface Dice: {mean_surface_dice:.4f} ± {np.std(surface_dice_scores):.4f}")
    
    # Create metrics distribution plots
    create_metrics_plots(dice_scores, hausdorff_distances, surface_dice_scores)
    
    # Create confusion matrix analysis
    analyze_cell_type_confusion(preds_tensor, multi_masks_tensor)
    
    return mean_dice, mean_hausdorff, mean_surface_dice

def create_metrics_plots(dice_scores, hausdorff_distances, surface_dice_scores):
    """
    Create visualization plots for the three main metrics.
    """
    
    fig, axes = plt.subplots(1, 3, figsize=(18, 5))
    
    # Dice Score distribution
    if dice_scores:
        axes[0].hist(dice_scores, bins=20, alpha=0.7, color='skyblue', edgecolor='black')
        axes[0].set_xlabel('Dice Score')
        axes[0].set_ylabel('Number of Images')
        axes[0].set_title('Distribution of Dice Scores')
        axes[0].axvline(np.mean(dice_scores), color='red', linestyle='--', 
                       label=f'Mean: {np.mean(dice_scores):.3f}')
        axes[0].legend()
        axes[0].grid(True, alpha=0.3)
    
    # Hausdorff Distance distribution
    if hausdorff_distances:
        finite_hausdorff = [h for h in hausdorff_distances if not np.isinf(h)]
        if finite_hausdorff:
            axes[1].hist(finite_hausdorff, bins=20, alpha=0.7, color='lightcoral', edgecolor='black')
            axes[1].set_xlabel('Hausdorff Distance')
            axes[1].set_ylabel('Number of Images')
            axes[1].set_title('Distribution of Hausdorff Distances')
            axes[1].axvline(np.mean(finite_hausdorff), color='red', linestyle='--', 
                           label=f'Mean: {np.mean(finite_hausdorff):.3f}')
            axes[1].legend()
            axes[1].grid(True, alpha=0.3)
        else:
            axes[1].text(0.5, 0.5, 'No finite Hausdorff distances', 
                        transform=axes[1].transAxes, ha='center', va='center')
            axes[1].set_title('Distribution of Hausdorff Distances')
    
    # Surface Dice distribution
    if surface_dice_scores:
        axes[2].hist(surface_dice_scores, bins=20, alpha=0.7, color='lightgreen', edgecolor='black')
        axes[2].set_xlabel('Surface Dice Score')
        axes[2].set_ylabel('Number of Images')
        axes[2].set_title('Distribution of Surface Dice Scores')
        axes[2].axvline(np.mean(surface_dice_scores), color='red', linestyle='--', 
                       label=f'Mean: {np.mean(surface_dice_scores):.3f}')
        axes[2].legend()
        axes[2].grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.show()
    
    # Print additional statistics
    print(f"\n📈 Detailed Metrics Statistics:")
    print("=" * 50)
    if dice_scores:
        print(f"Dice Score - Min: {np.min(dice_scores):.3f}, Max: {np.max(dice_scores):.3f}, Median: {np.median(dice_scores):.3f}")
    
    finite_hausdorff = [h for h in hausdorff_distances if not np.isinf(h)]
    if finite_hausdorff:
        print(f"Hausdorff Distance - Min: {np.min(finite_hausdorff):.3f}, Max: {np.max(finite_hausdorff):.3f}, Median: {np.median(finite_hausdorff):.3f}")
    
    if surface_dice_scores:
        print(f"Surface Dice - Min: {np.min(surface_dice_scores):.3f}, Max: {np.max(surface_dice_scores):.3f}, Median: {np.median(surface_dice_scores):.3f}")

def analyze_cell_type_confusion(preds_tensor, multi_masks_tensor):
    """
    Analyze which cell types are being missed by the binary segmentation model
    """
    print(f"\n🔍 Cell Type Detection Analysis:")
    
    # Convert tensors to numpy and flatten
    preds_np = preds_tensor.cpu().numpy().flatten()
    multi_masks_np = multi_masks_tensor.cpu().numpy().flatten()
    
    # Cell type names (adjust according to your dataset)
    cell_types = ['Background', 'Type 1', 'Type 2', 'Type 3', 'Type 4', 
                  'Type 5', 'Type 6', 'Type 7']
    
    # Create binary ground truth from multi-class (any non-zero = cell)
    gt_binary = (multi_masks_np > 0).astype(int)
    
    # Statistics for each cell type
    print("\n📈 Detection Statistics by Cell Type:")
    print("=" * 60)
    
    overall_stats = {
        'cell_type': [],
        'total_pixels': [],
        'detected_pixels': [],
        'detection_rate': [],
        'false_positive_on_bg': 0,
        'total_bg_pixels': 0
    }
    
    for cell_type_id in range(len(cell_types)):
        # Mask for current cell type
        cell_mask = (multi_masks_np == cell_type_id)
        total_pixels = cell_mask.sum()
        
        if cell_type_id == 0:  # Background
            # For background, count false positives
            false_positives = ((cell_mask) & (preds_np == 1)).sum()
            overall_stats['false_positive_on_bg'] = false_positives
            overall_stats['total_bg_pixels'] = total_pixels
            print(f"{cell_types[cell_type_id]:>10}: {total_pixels:>8} pixels | "
                  f"False Positives: {false_positives:>6} ({100*false_positives/total_pixels:.2f}%)")
        else:  # Cell types
            # For cell types, count true positives (correctly detected)
            detected_pixels = ((cell_mask) & (preds_np == 1)).sum()
            detection_rate = detected_pixels / total_pixels if total_pixels > 0 else 0
            
            overall_stats['cell_type'].append(cell_types[cell_type_id])
            overall_stats['total_pixels'].append(total_pixels)
            overall_stats['detected_pixels'].append(detected_pixels)
            overall_stats['detection_rate'].append(detection_rate)
            
            print(f"{cell_types[cell_type_id]:>10}: {total_pixels:>8} pixels | "
                  f"Detected: {detected_pixels:>6} ({100*detection_rate:.2f}%)")
    
    # Create confusion matrix visualization
    create_confusion_matrix_plot(preds_np, multi_masks_np, cell_types)
    
    # Summary of poorly detected cell types
    print(f"\n⚠️  Detection Performance Summary:")
    print("=" * 50)
    
    detection_rates = overall_stats['detection_rate']
    if detection_rates:
        avg_detection_rate = np.mean(detection_rates)
        print(f"Average Detection Rate: {100*avg_detection_rate:.2f}%")
        
        # Find poorly detected cell types (below average)
        poor_detection_threshold = avg_detection_rate * 0.8  # 80% of average
        poorly_detected = []
        
        for i, (cell_type, rate) in enumerate(zip(overall_stats['cell_type'], detection_rates)):
            if rate < poor_detection_threshold:
                poorly_detected.append((cell_type, rate))
        
        if poorly_detected:
            print(f"\n🚨 Cell types with poor detection (< {100*poor_detection_threshold:.1f}%):")
            for cell_type, rate in poorly_detected:
                print(f"   • {cell_type}: {100*rate:.2f}%")
        else:
            print("✅ All cell types have reasonable detection rates!")
    
    # Background false positive analysis
    bg_fp_rate = overall_stats['false_positive_on_bg'] / overall_stats['total_bg_pixels']
    print(f"\n🎭 Background False Positive Rate: {100*bg_fp_rate:.2f}%")
    
    return overall_stats

def create_confusion_matrix_plot(preds_np, multi_masks_np, cell_types):
    """
    Create a confusion matrix plot showing binary prediction vs multi-class ground truth
    """
    # Create a modified ground truth for visualization
    # 0 = background, 1-8 = cell types, but we'll group cell types
    gt_for_confusion = multi_masks_np.copy()
    
    # For confusion matrix, we want to see:
    # - Background correctly classified as background (TN)
    # - Background incorrectly classified as cell (FP)  
    # - Each cell type correctly classified as cell (TP)
    # - Each cell type incorrectly classified as background (FN)
    
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))
    
    # Plot 1: Detailed confusion matrix (background vs each cell type)
    confusion_detailed = np.zeros((len(cell_types), 2))  # [cell_types x 2] (pred_bg, pred_cell)
    
    for cell_type_id in range(len(cell_types)):
        cell_mask = (multi_masks_np == cell_type_id)
        
        # Count predictions for this cell type
        pred_bg_count = ((cell_mask) & (preds_np == 0)).sum()
        pred_cell_count = ((cell_mask) & (preds_np == 1)).sum()
        
        confusion_detailed[cell_type_id, 0] = pred_bg_count
        confusion_detailed[cell_type_id, 1] = pred_cell_count
    
    # Normalize by row to show percentages
    confusion_detailed_pct = confusion_detailed / (confusion_detailed.sum(axis=1, keepdims=True) + 1e-8)
    
    sns.heatmap(confusion_detailed_pct, 
                xticklabels=['Pred: Background', 'Pred: Cell'],
                yticklabels=cell_types,
                annot=True, 
                fmt='.3f',
                cmap='Blues',
                ax=ax1)
    ax1.set_title('Detection Rate by Cell Type\n(Row-normalized)')
    ax1.set_ylabel('True Cell Type')
    
    # Plot 2: Simple 2x2 confusion matrix (Binary: Cell vs Background)
    gt_binary = (multi_masks_np > 0).astype(int)
    cm_binary = confusion_matrix(gt_binary, preds_np)
    
    sns.heatmap(cm_binary, 
                xticklabels=['Pred: Background', 'Pred: Cell'],
                yticklabels=['True: Background', 'True: Cell'],
                annot=True, 
                fmt='d',
                cmap='Blues',
                ax=ax2)
    ax2.set_title('Binary Confusion Matrix\n(Cell vs Background)')
    
    plt.tight_layout()
    plt.show()
    
    # Print binary confusion matrix metrics
    tn, fp, fn, tp = cm_binary.ravel()
    precision = tp / (tp + fp) if (tp + fp) > 0 else 0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0
    f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
    
    print(f"\n📋 Binary Classification Metrics:")
    print(f"   Precision (Cell): {precision:.4f}")
    print(f"   Recall (Cell):    {recall:.4f}")
    print(f"   F1-Score:         {f1:.4f}")
    print(f"   True Positives:   {tp:,}")
    print(f"   False Positives:  {fp:,}")
    print(f"   False Negatives:  {fn:,}")
    print(f"   True Negatives:   {tn:,}")

In [None]:
dataset = CellDataSet()

In [None]:
test_detector(dataset, model_class=GrUnetArhitecture,scale_factor=2, model_path="SC2BUGF015Dice07BC015.pth", grayscale=True)

In [None]:
test_detector(dataset, model_class=GrUnetArhitecture, model_path="BUGF05Dice05BC05.pth", grayscale=True)

In [None]:
test_detector(dataset, model_class=GrayScaleSmallerUnet, model_path="smaller_unet.pth", grayscale=True)

In [None]:
test_detector(dataset, model_class=GrayScaleMediumUnet, model_path="dice_medium.pth", grayscale=True)

In [None]:
test_detector(dataset, model_class=GrayScaleSmallerUnet, model_path="F05Dice05H0002L4Small2.pth", grayscale=True)

In [None]:
test_detector(dataset, model_class=UnetArhitecture, model_path="dice_unet.pth")

In [None]:
test_detector(dataset, model_class=GrayScaleMediumUnet, model_path="dice_HDmedium.pth",grayscale=True)

In [None]:
test_detector(dataset, model_class=GrayScaleMediumUnet, model_path="dice_MultiLossMedium.pth",grayscale=True)

In [None]:
test_detector(dataset, model_class=GrayScaleSmallerUnet, model_path="GrayCr07Dice03Small.pth",grayscale=True)
test_detector(dataset, model_class=GrayScaleSmallerUnet, model_path="GrayCr03Dice06H0001Small.pth",grayscale=True)


In [None]:
import gc
gc.collect()
torch.cuda.empty_cache()

In [None]:
test_detector(dataset, model_class=SmallerUnet, model_path="Cr07Dice03Small.pth")
test_detector(dataset, model_class=Layer4GrayScaleSmallerUnet, model_path="Cr07Dice03L4Small.pth", grayscale=True)
test_detector(dataset, model_class=Layer4GrayScaleSmallerUnet, model_path="Cr07Dice03L4Small2.pth", grayscale=True)
test_detector(dataset, model_class=Layer4GrayScaleSmallerUnet, model_path="CDH05_0005.pth", grayscale=True)
test_detector(dataset, model_class=GrayScaleSmallerUnet, model_path="F05Dice05H0002L4Small2.pth", grayscale=True)
test_detector(dataset, model_class=GrayScaleSmallerUnet, model_path="F07Dice03H0002L4Small2.pth", grayscale=True)

In [None]:
test_detector(dataset, model_class=GrayScaleSmallerUnet, model_path="F07Dice03H0002L4Small2.pth", grayscale=True)
test_detector(dataset, model_class=GrayScaleSmallerUnet, model_path="F05Dice05BC052.pth", grayscale=True)

In [None]:
dataset = CellDataSet()
model = Layer4GrayScaleSmallerUnet()
model.load_state_dict(torch.load("L4F04D06H0003.pth")) 
train_model(model, dataset, num_epochs=80, grayscale=True, save_path="L4F04D06H0003.pth")
test_detector(dataset, model_class=Layer4GrayScaleSmallerUnet, model_path="L4F04D06H0003.pth", grayscale=True)

In [None]:

gc.collect()
dataset = CellDataSet()
try:
    model = GrayScaleSmallerUnet()
    train_model(model, dataset, num_epochs=80, grayscale=True, save_path="SF04D06H0003.pth")
    test_detector(dataset, model_class=GrayScaleSmallerUnet, model_path="SF04D06H0003.pth", grayscale=True)
except:
    print("An exception occurred")
    gc.collect()
    torch.cuda.empty_cache()
    model = GrayScaleSmallerUnet()
    model.load_state_dict(torch.load("SF04D06H0003.pth")) 
    train_model(model, dataset, num_epochs=80, grayscale=True, save_path="SF04D06H0003.pth")
    test_detector(dataset, model_class=GrayScaleSmallerUnet, model_path="SF04D06H0003.pth", grayscale=True)


In [None]:

dataset = CellDataSet()
model = GrayScaleSmallerUnet()
model.load_state_dict(torch.load("SF04D06H0003.pth")) 
train_model(model, dataset, num_epochs=80, grayscale=True, save_path="SF04D06H0003.pth")
test_detector(dataset, model_class=GrayScaleSmallerUnet, model_path="SF04D06H0003.pth", grayscale=True)


In [None]:

test_detector(dataset, model_class=GrayScaleMediumUnet, model_path="MF05D05H0002v2.pth", grayscale=True)
model = GrayScaleMediumUnet()
model.load_state_dict(torch.load("MF05D05H0002v2.pth")) 
train_model(model, dataset, num_epochs=80, grayscale=True, save_path="MF05D05H0002v3.pth")
test_detector(dataset, model_class=GrayScaleMediumUnet, model_path="MF05D05H0002v3.pth", grayscale=True)

In [None]:
gc.collect()
model = GrayScaleMediumUnet()
train_model(model, dataset, num_epochs=80, grayscale=True, save_path="MF05D05H0002NoAUG.pth")
test_detector(dataset, model_class=GrayScaleMediumUnet, model_path="MF05D05H0002NoAUG.pth", grayscale=True)

In [None]:
import os
import cv2
import gc
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.cluster import DBSCAN
from sklearn.metrics import confusion_matrix
from pathlib import Path
import glob
from tqdm import tqdm
import torch
from monai.metrics import DiceMetric, HausdorffDistanceMetric, SurfaceDistanceMetric
from monai.transforms import AsDiscrete

def dbscan_cell_detection(image_path, color_eps=0.1, color_min_samples=3, 
                         spatial_eps=0.005, spatial_min_samples=80, visualize=False):
    """
    Apply DBSCAN-based cell detection to a single image.
    
    Args:
        image_path: Path to the input image
        color_eps: DBSCAN epsilon for color clustering
        color_min_samples: DBSCAN min_samples for color clustering
        spatial_eps: DBSCAN epsilon for spatial clustering
        spatial_min_samples: DBSCAN min_samples for spatial clustering
        visualize: Whether to show intermediate visualizations
    
    Returns:
        mask_grayscale: Binary mask of detected cells
        stats: Dictionary with detection statistics
    """
    # Read image
    picture = np.array(cv2.imread(image_path, cv2.IMREAD_COLOR))
    if picture is None:
        return None, None
    
    original_picture = picture.copy()
    original_shape = picture.shape
    
    # Step 1: Color-based clustering to remove background
    picture_reshaped = picture.reshape((picture.shape[0] * picture.shape[1], picture.shape[2]))
    
    # Apply DBSCAN on color space
    dbscan_color = DBSCAN(eps=color_eps, min_samples=color_min_samples)
    color_labels = dbscan_color.fit_predict(picture_reshaped)
    
    # Create mask - set background (noise) to black
    color_mask = np.array(picture_reshaped)
    color_mask[color_labels != -1] = np.array([0, 0, 0])
    color_mask = color_mask.reshape(original_shape)
    color_mask = color_mask.astype(np.uint8)
    
    if visualize:
        plt.figure(figsize=(15, 5))
        plt.subplot(1, 3, 1)
        plt.imshow(cv2.cvtColor(original_picture, cv2.COLOR_BGR2RGB))
        plt.title('Original Image')
        plt.axis('off')
        
        plt.subplot(1, 3, 2)
        plt.imshow(cv2.cvtColor(color_mask, cv2.COLOR_BGR2RGB))
        plt.title('After Color Clustering')
        plt.axis('off')
    
    # Step 2: Spatial clustering on remaining pixels
    cells = np.argwhere(np.any(color_mask != [0, 0, 0], axis=-1))
    
    if len(cells) == 0:
        # No cells detected
        mask_grayscale = np.zeros(original_shape[:2], dtype=np.uint8)
        stats = {
            'total_pixels': original_shape[0] * original_shape[1],
            'background_pixels': original_shape[0] * original_shape[1],
            'cell_pixels': 0,
            'detected_clusters': 0,
            'noise_pixels': 0
        }
        return mask_grayscale, stats
    
    # Normalize pixel coordinates
    normalized_pixels = cells / np.array(original_shape[0:2])
    
    # Apply DBSCAN on spatial coordinates
    dbscan_spatial = DBSCAN(eps=spatial_eps, min_samples=spatial_min_samples)
    spatial_labels = dbscan_spatial.fit_predict(normalized_pixels)
    
    # Create final mask
    cells_valid = cells[spatial_labels != -1]
    spatial_labels_valid = spatial_labels[spatial_labels != -1]
    
    mask_grayscale = np.zeros(original_shape[:2], dtype=np.uint8)
    if len(cells_valid) > 0:
        mask_grayscale[cells_valid[:, 0], cells_valid[:, 1]] = 255
    
    if visualize:
        plt.subplot(1, 3, 3)
        plt.imshow(mask_grayscale, cmap='gray')
        plt.title('Final Cell Detection')
        plt.axis('off')
        plt.tight_layout()
        plt.show()
    
    # Calculate statistics
    total_pixels = original_shape[0] * original_shape[1]
    cell_pixels = len(cells_valid)
    noise_pixels = len(cells) - len(cells_valid) if len(cells) > 0 else 0
    background_pixels = total_pixels - len(cells)
    detected_clusters = len(np.unique(spatial_labels_valid)) if len(spatial_labels_valid) > 0 else 0
    
    stats = {
        'total_pixels': total_pixels,
        'background_pixels': background_pixels,
        'cell_pixels': cell_pixels,
        'detected_clusters': detected_clusters,
        'noise_pixels': noise_pixels,
        'detection_rate': cell_pixels / total_pixels,
        'cluster_efficiency': cell_pixels / len(cells) if len(cells) > 0 else 0
    }
    
    return mask_grayscale, stats

def compute_advanced_metrics(pred_mask, gt_mask):
    """
    Compute Dice score, Hausdorff distance, and Surface Dice using MONAI metrics.
    
    Args:
        pred_mask: Predicted binary mask (H, W)
        gt_mask: Ground truth binary mask (H, W)
    
    Returns:
        dict: Dictionary containing the three metrics
    """
    # Convert to torch tensors and add batch and channel dimensions
    pred_tensor = torch.tensor(pred_mask, dtype=torch.float32).unsqueeze(0).unsqueeze(0)  # (1, 1, H, W)
    gt_tensor = torch.tensor(gt_mask, dtype=torch.float32).unsqueeze(0).unsqueeze(0)      # (1, 1, H, W)
    
    # Ensure binary masks (0 or 1)
    pred_tensor = (pred_tensor > 0.5).float()
    gt_tensor = (gt_tensor > 0.5).float()
    
    # Initialize MONAI metrics
    dice_metric = DiceMetric(include_background=False, reduction="mean")
    hausdorff_metric = HausdorffDistanceMetric(include_background=False, percentile=95)
    surface_dice_metric = SurfaceDistanceMetric(include_background=False, symmetric=True)
    
    try:
        # Compute metrics
        dice_score = dice_metric(pred_tensor, gt_tensor).item()
        hausdorff_dist = hausdorff_metric(pred_tensor, gt_tensor).item()
        surface_dice = surface_dice_metric(pred_tensor, gt_tensor).item()
        
        # Handle edge cases
        if np.isnan(dice_score) or np.isinf(dice_score):
            dice_score = 0.0
        if np.isnan(hausdorff_dist) or np.isinf(hausdorff_dist):
            hausdorff_dist = float('inf')
        if np.isnan(surface_dice) or np.isinf(surface_dice):
            surface_dice = 0.0
            
    except Exception as e:
        print(f"Warning: Error computing metrics: {e}")
        dice_score = 0.0
        hausdorff_dist = float('inf')
        surface_dice = 0.0
    gc.collect()
    return {
        'dice_score': dice_score,
        'hausdorff_distance': hausdorff_dist,
        'surface_dice': surface_dice
    }

def analyze_cell_type_detection(pred_mask, multi_mask, cell_types=None):
    """
    Analyze detection performance for each cell type using multi-class ground truth.
    
    Args:
        pred_mask: Binary prediction mask (H, W)
        multi_mask: Multi-class ground truth mask (H, W)
        cell_types: List of cell type names
    
    Returns:
        dict: Detection statistics for each cell type
    """
    if cell_types is None:
        unique_labels = np.unique(multi_mask)
        cell_types = [f'Type_{i}' if i > 0 else 'Background' for i in unique_labels]
    
    # Ensure binary prediction mask
    pred_binary = (pred_mask > 0).astype(int)
    
    cell_type_stats = {}
    
    for cell_type_id in np.unique(multi_mask):
        cell_name = cell_types[cell_type_id] if cell_type_id < len(cell_types) else f'Type_{cell_type_id}'
        
        # Mask for current cell type
        cell_mask = (multi_mask == cell_type_id)
        total_pixels = cell_mask.sum()
        
        if cell_type_id == 0:  # Background
            # For background, count false positives
            detected_pixels = ((cell_mask) & (pred_binary == 1)).sum()
            detection_rate = detected_pixels / total_pixels if total_pixels > 0 else 0
            cell_type_stats[cell_name] = {
                'total_pixels': total_pixels,
                'false_positives': detected_pixels,
                'false_positive_rate': detection_rate
            }
        else:  # Cell types
            # For cell types, count true positives
            detected_pixels = ((cell_mask) & (pred_binary == 1)).sum()
            detection_rate = detected_pixels / total_pixels if total_pixels > 0 else 0
            cell_type_stats[cell_name] = {
                'total_pixels': total_pixels,
                'detected_pixels': detected_pixels,
                'detection_rate': detection_rate
            }
    
    return cell_type_stats

def test_dbscan_detector(image_folder, bin_mask_folder, mult_mask_folder,
                        color_eps=0.1, color_min_samples=3,
                        spatial_eps=0.005, spatial_min_samples=80,
                        file_pattern="*.tif", visualize_samples=3,
                        cell_types=None):
    """
    Test DBSCAN detector with advanced metrics similar to the UNet test function.
    
    Args:
        image_folder: Path to folder containing test images
        bin_mask_folder: Path to binary ground truth masks
        mult_mask_folder: Path to multi-class ground truth masks
        color_eps, color_min_samples: DBSCAN parameters for color clustering
        spatial_eps, spatial_min_samples: DBSCAN parameters for spatial clustering
        file_pattern: File pattern to match
        visualize_samples: Number of samples to visualize
        cell_types: List of cell type names for analysis
    
    Returns:
        dict: Comprehensive test results
    """
    
    image_folder = Path(image_folder)
    bin_mask_folder = Path(bin_mask_folder) if bin_mask_folder else None
    mult_mask_folder = Path(mult_mask_folder) if mult_mask_folder else None
    
    if not image_folder.exists():
        print(f"❌ Image folder {image_folder} does not exist!")
        return None
    
    # Get all image files
    image_files = list(image_folder.glob(file_pattern))
    if not image_files:
        print(f"❌ No files matching pattern '{file_pattern}' found in {image_folder}")
        return None
    
    print(f"🧪 Testing DBSCAN detector on {len(image_files)} images...")
    print(f"📁 Images: {image_folder}")
    print(f"📁 Binary masks: {bin_mask_folder}")
    print(f"📁 Multi-class masks: {mult_mask_folder}")
    print(f"⚙️  Parameters: color_eps={color_eps}, spatial_eps={spatial_eps}")
    print("=" * 80)
    
    # Initialize storage for results
    all_dice_scores = []
    all_hausdorff_distances = []
    all_surface_dice_scores = []
    all_cell_type_stats = []
    all_image_stats = []
    
    # Process each image
    for i, image_path in enumerate(tqdm(image_files, desc="Processing images")):
        
        # Show visualization for first few samples
        show_viz = i < visualize_samples
        
        # Run DBSCAN detection
        pred_mask, basic_stats = dbscan_cell_detection(
            str(image_path), 
            color_eps=color_eps,
            color_min_samples=color_min_samples,
            spatial_eps=spatial_eps,
            spatial_min_samples=spatial_min_samples,
            visualize=show_viz
        )
        
        if pred_mask is None:
            print(f"⚠️  Failed to process {image_path.name}")
            continue
        
        image_results = {
            'filename': image_path.name,
            'basic_stats': basic_stats
        }
        
        # Load and process ground truth masks
        if bin_mask_folder:
            bin_gt_path = bin_mask_folder / image_path.name
            if bin_gt_path.exists():
                bin_gt_mask = cv2.imread(str(bin_gt_path), cv2.IMREAD_GRAYSCALE)
                if bin_gt_mask is not None:
                    # Convert to binary (0 or 1)
                    bin_gt_mask = (bin_gt_mask > 0).astype(np.uint8)
                    
                    # Compute advanced metrics
                    metrics = compute_advanced_metrics(pred_mask, bin_gt_mask)
                    image_results['metrics'] = metrics
                    
                    all_dice_scores.append(metrics['dice_score'])
                    all_hausdorff_distances.append(metrics['hausdorff_distance'])
                    all_surface_dice_scores.append(metrics['surface_dice'])
                    
                    if show_viz:
                        # Show comparison
                        plt.figure(figsize=(15, 5))
                        plt.subplot(1, 3, 1)
                        plt.imshow(pred_mask, cmap='gray')
                        plt.title(f'DBSCAN Prediction\n{image_path.name}')
                        plt.axis('off')
                        
                        plt.subplot(1, 3, 2)
                        plt.imshow(bin_gt_mask, cmap='gray')
                        plt.title('Ground Truth')
                        plt.axis('off')
                        
                        plt.subplot(1, 3, 3)
                        plt.imshow(pred_mask, cmap='Reds', alpha=0.7)
                        plt.imshow(bin_gt_mask, cmap='Blues', alpha=0.3)
                        plt.title(f'Overlay\nDice: {metrics["dice_score"]:.3f}')
                        plt.axis('off')
                        plt.tight_layout()
                        plt.show()
        
        # Analyze cell type detection
        if mult_mask_folder:
            mult_gt_path = mult_mask_folder / image_path.name
            if mult_gt_path.exists():
                mult_gt_mask = cv2.imread(str(mult_gt_path), cv2.IMREAD_GRAYSCALE)
                if mult_gt_mask is not None:
                    cell_type_stats = analyze_cell_type_detection(pred_mask, mult_gt_mask, cell_types)
                    image_results['cell_type_stats'] = cell_type_stats
                    all_cell_type_stats.append(cell_type_stats)
        
        all_image_stats.append(image_results)
    
    # Print comprehensive results
    print_comprehensive_results(all_dice_scores, all_hausdorff_distances, all_surface_dice_scores, 
                               all_cell_type_stats, all_image_stats)
    
    return {
        'dice_scores': all_dice_scores,
        'hausdorff_distances': all_hausdorff_distances,
        'surface_dice_scores': all_surface_dice_scores,
        'cell_type_stats': all_cell_type_stats,
        'image_stats': all_image_stats
    }

def print_comprehensive_results(dice_scores, hausdorff_distances, surface_dice_scores, 
                               cell_type_stats, image_stats):
    """
    Print comprehensive test results similar to UNet test function.
    """
    
    print(f"\n📊 DBSCAN Cell Detection Test Results")
    print("=" * 80)
    
    # Standard metrics
    if dice_scores:
        avg_dice = np.mean(dice_scores)
        std_dice = np.std(dice_scores)
        print(f"🎯 Dice Score: {avg_dice:.4f} ± {std_dice:.4f}")
        print(f"   Range: {min(dice_scores):.4f} - {max(dice_scores):.4f}")
    
    if hausdorff_distances:
        # Filter out infinite values for statistics
        finite_hausdorff = [h for h in hausdorff_distances if not np.isinf(h)]
        if finite_hausdorff:
            avg_hausdorff = np.mean(finite_hausdorff)
            std_hausdorff = np.std(finite_hausdorff)
            print(f"📏 Hausdorff Distance (95th percentile): {avg_hausdorff:.4f} ± {std_hausdorff:.4f}")
            print(f"   Range: {min(finite_hausdorff):.4f} - {max(finite_hausdorff):.4f}")
            if len(finite_hausdorff) < len(hausdorff_distances):
                print(f"   Note: {len(hausdorff_distances) - len(finite_hausdorff)} images had infinite Hausdorff distance")
    
    if surface_dice_scores:
        avg_surface_dice = np.mean(surface_dice_scores)
        std_surface_dice = np.std(surface_dice_scores)
        print(f"🌊 Normalized Surface Dice: {avg_surface_dice:.4f} ± {std_surface_dice:.4f}")
        print(f"   Range: {min(surface_dice_scores):.4f} - {max(surface_dice_scores):.4f}")
    
    # Cell type analysis
    if cell_type_stats:
        print(f"\n🔍 Cell Type Detection Analysis:")
        print("=" * 60)
        
        # Aggregate cell type statistics
        cell_type_summary = {}
        for stats in cell_type_stats:
            for cell_type, type_stats in stats.items():
                if cell_type not in cell_type_summary:
                    cell_type_summary[cell_type] = {
                        'total_pixels': [],
                        'detected_pixels': [] if 'detected_pixels' in type_stats else [],
                        'detection_rates': [] if 'detection_rate' in type_stats else [],
                        'false_positive_rates': [] if 'false_positive_rate' in type_stats else []
                    }
                
                cell_type_summary[cell_type]['total_pixels'].append(type_stats['total_pixels'])
                if 'detected_pixels' in type_stats:
                    cell_type_summary[cell_type]['detected_pixels'].append(type_stats['detected_pixels'])
                    cell_type_summary[cell_type]['detection_rates'].append(type_stats['detection_rate'])
                if 'false_positive_rate' in type_stats:
                    cell_type_summary[cell_type]['false_positive_rates'].append(type_stats['false_positive_rate'])
        
        # Print summary for each cell type
        for cell_type, summary in cell_type_summary.items():
            total_pixels = sum(summary['total_pixels'])
            
            if cell_type == 'Background':
                if summary['false_positive_rates']:
                    avg_fp_rate = np.mean(summary['false_positive_rates'])
                    print(f"{cell_type:>12}: {total_pixels:>10,} pixels | "
                          f"False Positive Rate: {100*avg_fp_rate:.2f}%")
            else:
                if summary['detection_rates']:
                    total_detected = sum(summary['detected_pixels'])
                    avg_detection_rate = np.mean(summary['detection_rates'])
                    print(f"{cell_type:>12}: {total_pixels:>10,} pixels | "
                          f"Detected: {total_detected:>8,} ({100*avg_detection_rate:.2f}%)")
    
    # Create performance visualizations
    if dice_scores or hausdorff_distances or surface_dice_scores:
        create_metrics_plots(dice_scores, hausdorff_distances, surface_dice_scores)
    
    # Binary confusion matrix analysis
    if image_stats:
        create_overall_confusion_analysis(image_stats)

def create_metrics_plots(dice_scores, hausdorff_distances, surface_dice_scores):
    """
    Create visualization plots for the three main metrics.
    """
    
    fig, axes = plt.subplots(1, 3, figsize=(18, 5))
    
    # Dice Score distribution
    if dice_scores:
        axes[0].hist(dice_scores, bins=20, alpha=0.7, color='skyblue', edgecolor='black')
        axes[0].set_xlabel('Dice Score')
        axes[0].set_ylabel('Number of Images')
        axes[0].set_title('Distribution of Dice Scores')
        axes[0].axvline(np.mean(dice_scores), color='red', linestyle='--', 
                       label=f'Mean: {np.mean(dice_scores):.3f}')
        axes[0].legend()
    
    # Hausdorff Distance distribution
    if hausdorff_distances:
        finite_hausdorff = [h for h in hausdorff_distances if not np.isinf(h)]
        if finite_hausdorff:
            axes[1].hist(finite_hausdorff, bins=20, alpha=0.7, color='lightcoral', edgecolor='black')
            axes[1].set_xlabel('Hausdorff Distance')
            axes[1].set_ylabel('Number of Images')
            axes[1].set_title('Distribution of Hausdorff Distances')
            axes[1].axvline(np.mean(finite_hausdorff), color='red', linestyle='--', 
                           label=f'Mean: {np.mean(finite_hausdorff):.3f}')
            axes[1].legend()
    
    # Surface Dice distribution
    if surface_dice_scores:
        axes[2].hist(surface_dice_scores, bins=20, alpha=0.7, color='lightgreen', edgecolor='black')
        axes[2].set_xlabel('Surface Dice Score')
        axes[2].set_ylabel('Number of Images')
        axes[2].set_title('Distribution of Surface Dice Scores')
        axes[2].axvline(np.mean(surface_dice_scores), color='red', linestyle='--', 
                       label=f'Mean: {np.mean(surface_dice_scores):.3f}')
        axes[2].legend()
    
    plt.tight_layout()
    plt.show()

def create_overall_confusion_analysis(image_stats):
    """
    Create overall confusion matrix analysis from all processed images.
    """
    all_predictions = []
    all_ground_truths = []
    
    for img_stats in image_stats:
        if 'metrics' in img_stats:
            # This is a simplified approach - in practice you'd need to store the actual masks
            # For now, we'll use the detection rate as a proxy
            detection_rate = img_stats['basic_stats']['detection_rate']
            # Approximate binary classification based on detection rate
            # This is a simplification - ideally you'd store the actual pixel-wise predictions
            pass
    
    print(f"\n📋 Overall Performance Summary:")
    print(f"   Total Images Processed: {len(image_stats)}")
    
    # Find best and worst performing images
    if any('metrics' in stats for stats in image_stats):
        dice_results = [(stats['filename'], stats['metrics']['dice_score']) 
                       for stats in image_stats if 'metrics' in stats]
        
        if dice_results:
            dice_results.sort(key=lambda x: x[1], reverse=True)
            
            print(f"\n🏆 Best Performing Images (Dice Score):")
            for filename, dice in dice_results[:3]:
                print(f"   • {filename}: {dice:.4f}")
            
            print(f"\n⚠️  Worst Performing Images (Dice Score):")
            for filename, dice in dice_results[-3:]:
                print(f"   • {filename}: {dice:.4f}")

# Example usage function
def run_dbscan_test_example():
    """
    Example of how to run the DBSCAN test with all metrics.
    """
    
    # Define cell types (adjust according to your dataset)
    cell_types = ['Background', 'Type 1', 'Type 2', 'Type 3', 'Type 4', 
                  'Type 5', 'Type 6', 'Type 7']
    
    # Run the test
    results = test_dbscan_detector(
        image_folder="/kaggle/input/celldetection/ds1/test/img/cls",
        bin_mask_folder="/kaggle/input/celldetection/ds1/test/bin_mask/cls",
        mult_mask_folder="/kaggle/input/celldetection/ds1/test/mult_mask/cls",
        color_eps=0.1,
        color_min_samples=3,
        spatial_eps=0.005,
        spatial_min_samples=80,
        file_pattern="*.tif",
        visualize_samples=3,
        cell_types=cell_types
    )
    
    return results

if __name__ == "__main__":
    # Example usage
    print("DBSCAN Cell Detection with Advanced Metrics")
    run_dbscan_test_example()

In [None]:
dataset = CellDataSet()
model = GrUnetArhitecture()
train_model(model, dataset, num_epochs=80,patience=5, grayscale=True,scharr=True, save_path = "dice_Scharr.pth")
test_detector(dataset, model_class=GrUnetArhitecture, model_path="dice_Scharr.pth", grayscale=True,scharr=True)

In [None]:
dataset = CellDataSet()
test_detector(dataset, model_class=GrayScaleMediumUnet, model_path="dice_mediumScharr.pth", grayscale=True,scharr=True)

In [None]:
test_detector(dataset, model_class=GrUnetArhitecture, model_path="dice_Scharr.pth", grayscale=True,scharr=True)