In [None]:
# =============================================================== #
#                     2D Cell Segmentation                        #
#                      Mustansir Verdawala                        #
# =============================================================== #

In [None]:
#%% Installing libraries

!pip install monai[itk] itk matplotlib torch torchvision torchaudio nibabel pydicom

In [None]:
#%% Download Dataset

!mkdir -p data/train data/masks/train
!mkdir -p data/val data/masks/val
!mkdir -p data/test data/masks/test

!wget -O data/train.zip "https://huggingface.co/datasets/alkzar90/cell_benchmark/resolve/main/data/train.zip"
!wget -O data/masks/train.zip "https://huggingface.co/datasets/alkzar90/cell_benchmark/resolve/main/data/masks/train.zip"
!wget -O data/val.zip "https://huggingface.co/datasets/alkzar90/cell_benchmark/resolve/main/data/val.zip"
!wget -O data/masks/val.zip "https://huggingface.co/datasets/alkzar90/cell_benchmark/resolve/main/data/masks/val.zip"
!wget -O data/test.zip "https://huggingface.co/datasets/alkzar90/cell_benchmark/resolve/main/data/test.zip"
!wget -O data/masks/test.zip "https://huggingface.co/datasets/alkzar90/cell_benchmark/resolve/main/data/masks/test.zip"

!unzip -q data/train.zip -d data/train
!unzip -q data/masks/train.zip -d data/masks/train
!unzip -q data/val.zip -d data/val
!unzip -q data/masks/val.zip -d data/masks/val
!unzip -q data/test.zip -d data/test
!unzip -q data/masks/test.zip -d data/masks/test

In [None]:
#%% Viewing dataset

import os

# Folder of every split: images live in data/<split>/<split>, masks in data/masks/<split>.
train_images_dir, train_masks_dir = "data/train/train", "data/masks/train"
val_images_dir, val_masks_dir = "data/val/val", "data/masks/val"
test_images_dir, test_masks_dir = "data/test/test", "data/masks/test"

# (label, folder) pairs in the order they are reported.
_listing_plan = (
    ("Train Images:", train_images_dir),
    ("Train Masks:", train_masks_dir),
    ("Validation Images:", val_images_dir),
    ("Validation Masks:", val_masks_dir),
    ("Test Images:", test_images_dir),
    ("Test Masks:", test_masks_dir),
)

# Print the raw directory listing of each folder.
for _label, _folder in _listing_plan:
    print(_label, os.listdir(_folder))

In [None]:
#%% Previewing image

import itk
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path

img_dir = Path("data/train/train")
mask_dir = Path("data/masks/train")

# Sorted so that position k in both lists is used as one image/mask pair.
# NOTE(review): this presumes image and mask file names sort in the same order - confirm.
img_files = sorted(img_dir.glob("*.jpg"))
mask_files = sorted(mask_dir.glob("*.png"))

# Fail with a readable message instead of an IndexError when the download or
# extraction step did not produce any files.
if not img_files:
    raise FileNotFoundError(f"No .jpg images found in {img_dir}")
if not mask_files:
    raise FileNotFoundError(f"No .png masks found in {mask_dir}")

# Read the first image as 8-bit RGB and the first mask as 8-bit single channel.
img_itk = itk.imread(str(img_files[0]), itk.RGBPixel[itk.UC])
mask_itk = itk.imread(str(mask_files[0]), itk.UC)

# NumPy views on the ITK buffers (no copy); reused by the analysis cell below.
img_np = itk.array_view_from_image(img_itk)
mask_np = itk.array_view_from_image(mask_itk)

print(img_files[0])
print("Image shape:", img_np.shape)
print("Mask shape:", mask_np.shape)


fig, axes = plt.subplots(1, 2, figsize=(10, 5))

axes[0].imshow(img_np)
axes[0].set_title("Original Image")
axes[0].axis("off")

# This panel shows the mask alone, so it is titled accordingly
# (it was previously labelled as an overlay, which it is not).
axes[1].imshow(mask_np, cmap="gray")
axes[1].set_title("Mask")
axes[1].axis("off")

plt.tight_layout()
plt.show()

In [None]:
#%% Image Analysis

colors = ['red', 'green', 'blue']


def _show_channel_histograms(select_pixels, title, figsize):
    """Draw one 256-bin histogram figure per RGB channel of `img_np`.

    select_pixels: callable that receives one 2-D channel plane and returns
                   the pixel values to be histogrammed.
    title:         figure title shared by the three figures.
    figsize:       matplotlib figure size.
    """
    for channel_idx, channel_name in enumerate(colors):
        values = select_pixels(img_np[:, :, channel_idx])
        plt.figure(figsize=figsize)
        plt.hist(values.flatten(), bins=256, color=channel_name, alpha=0.5, label=f"{channel_name} channel")
        plt.title(title)
        plt.xlabel("Pixel intensity")
        plt.ylabel("Frequency")
        plt.legend()
        plt.show()


# 1) Intensity distribution of the complete image, channel by channel.
_show_channel_histograms(lambda plane: plane, "Histogram per Channel", (8, 5))

# 2) Intensity distribution of the mask itself.
plt.figure(figsize=(6, 4))
plt.hist(mask_np.flatten(), bins=256, color='gray')
plt.title("Histogram of Mask")
plt.xlabel("Pixel intensity")
plt.ylabel("Frequency")
plt.show()

# The mask was read as 8-bit unsigned, so "> 127" and "< 128" split every
# pixel into exactly one of the two groups.
_inside_mask = mask_np > 127
_outside_mask = mask_np < 128

# 3) Image intensities restricted to the masked pixels.
_show_channel_histograms(
    lambda plane: plane[_inside_mask],
    "Histogram of Image Pixels (within mask)",
    (10, 6),
)

# 4) Image intensities restricted to the pixels outside the mask.
_show_channel_histograms(
    lambda plane: plane[_outside_mask],
    "Histogram of Image Pixels (outside mask)",
    (10, 6),
)

In [None]:
#%% Training Loop

import os
from pathlib import Path
import cv2
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
import matplotlib.pyplot as plt
from monai.networks.nets import UNet
from monai.losses import DiceLoss


# Hyper-parameters of the run
IMG_SIZE, BATCH_SIZE, EPOCHS = 512, 5, 1000
LR = 1e-2
# Train on the GPU when one is visible to PyTorch, otherwise on the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

_data_root = Path("data")

# Training split: sorted so images and masks are paired by position.
img_dir = _data_root / "train" / "train"
mask_dir = _data_root / "masks" / "train"
img_files = sorted(img_dir.glob("*.jpg"))
mask_files = sorted(mask_dir.glob("*.png"))

# Validation split, listed the same way.
val_img_dir = _data_root / "val" / "val"
val_mask_dir = _data_root / "masks" / "val"
val_img_files = sorted(val_img_dir.glob("*.jpg"))
val_mask_files = sorted(val_mask_dir.glob("*.png"))

# Dataset
class CellDataset(Dataset):
    """Pairs of RGB cell images and two-channel (background, foreground) masks.

    Args:
        img_files:  sequence of image paths.
        mask_files: sequence of mask paths; entry k is used as the mask of image k.
        img_size:   images and masks are resized to (img_size, img_size).

    Raises:
        ValueError: if the two sequences differ in length, because the
            position-based pairing would then be wrong.
    """

    def __init__(self, img_files, mask_files, img_size=IMG_SIZE):
        if len(img_files) != len(mask_files):
            raise ValueError(
                f"Number of images ({len(img_files)}) and masks ({len(mask_files)}) differ"
            )
        self.img_files = img_files
        self.mask_files = mask_files
        self.img_size = img_size

    def __len__(self):
        """Number of image/mask pairs."""
        return len(self.img_files)

    def __getitem__(self, idx):
        """Return (image, mask) as float32 tensors.

        image: (3, img_size, img_size), RGB, scaled to [0, 1].
        mask:  (2, img_size, img_size); channel 0 is background (1 - mask),
               channel 1 is foreground (mask value > 127).

        Raises:
            FileNotFoundError: if OpenCV cannot read the image or the mask.
        """
        img_path = str(self.img_files[idx])
        mask_path = str(self.mask_files[idx])

        # cv2.imread signals failure by returning None instead of raising,
        # so check explicitly to get an error that names the file.
        img = cv2.imread(img_path)
        if img is None:
            raise FileNotFoundError(f"Could not read image: {img_path}")
        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        if mask is None:
            raise FileNotFoundError(f"Could not read mask: {mask_path}")

        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        target_size = (self.img_size, self.img_size)
        img = cv2.resize(img, target_size).astype(np.float32) / 255.0
        # Nearest-neighbour keeps the mask values unblended before thresholding.
        mask = cv2.resize(mask, target_size, interpolation=cv2.INTER_NEAREST)
        mask = (mask > 127).astype(np.float32)

        # HWC -> CHW; made contiguous so the tensor can share the buffer.
        img = np.ascontiguousarray(np.transpose(img, (2, 0, 1)))

        # Channel 0 = background, channel 1 = foreground.
        mask_encode = np.stack([1.0 - mask, mask]).astype(np.float32)

        return torch.from_numpy(img), torch.from_numpy(mask_encode)

# Datasets first, then their loaders; only the training loader shuffles.
train_dataset = CellDataset(img_files, mask_files)
val_dataset = CellDataset(val_img_files, val_mask_files)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

# Model: 2-D residual U-Net, RGB in, two output channels (background / foreground)
_unet_kwargs = dict(
    spatial_dims=2,
    in_channels=3,
    out_channels=2,
    channels=(16, 32, 64, 128, 256),
    strides=(2, 2, 2, 2),
    num_res_units=2,
    act='leakyrelu',
    dropout=0.5,
)
model = UNet(**_unet_kwargs).to(DEVICE)

# Dice loss with softmax applied to the network output inside the loss.
loss_fn = DiceLoss(softmax=True)
optimizer = optim.Adam(model.parameters(), lr=LR)

# Each scheduler.step() multiplies the learning rate by 1/sqrt(10);
# the training loop decides how often step() is called.
_lr_decay = 1 / (10 ** 0.5)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=_lr_decay)

# Early-stopping state: best validation loss so far, allowed number of
# non-improving epochs, and the current count of such epochs.
best_val_dice, patience, trigger_times = 1.0, 50, 0

# Train loop
# NOTE: `val_dice` / `best_val_dice` hold the validation Dice *loss*
# (value of loss_fn), so lower is better.
checkpoint_path = "best_unet_rgb_cells.pth"
checkpoint_saved = False  # True once THIS run has written a checkpoint

for epoch in range(EPOCHS):
    # ---- one pass over the training data ----
    model.train()
    epoch_loss = 0
    for imgs, masks in train_loader:
        imgs, masks = imgs.to(DEVICE), masks.to(DEVICE)
        optimizer.zero_grad()
        outputs = model(imgs)
        loss = loss_fn(outputs, masks)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
    epoch_loss = epoch_loss / len(train_loader)

    # Validation
    model.eval()
    val_dice_total = 0
    with torch.no_grad():
        for imgs, masks in val_loader:
            imgs, masks = imgs.to(DEVICE), masks.to(DEVICE)
            outputs = model(imgs)
            loss = loss_fn(outputs, masks)
            val_dice_total += loss.item()
    val_dice = val_dice_total / len(val_loader)

    # Early stopping: keep the weights of the best validation loss on disk and
    # stop after `patience` consecutive epochs without improvement.
    if val_dice < best_val_dice:
        best_val_dice = val_dice
        trigger_times = 0
        torch.save(model.state_dict(), checkpoint_path)
        checkpoint_saved = True
    else:
        trigger_times += 1
        if trigger_times >= patience:
            print(f"Early stopping at epoch {epoch+1}")
            break

    # Decay the learning rate once every 50 epochs.
    if (epoch+1)%50==0:
        scheduler.step()

    # Print
    print(f"Epoch {epoch+1}/{EPOCHS} | Train Loss: {epoch_loss:.4f} | Val Loss: {val_dice:.4f} | LR: {scheduler.get_last_lr()[0]:.6f}")

    # Every 10 epochs show the first sample of the last validation batch.
    if (epoch+1) % 10 == 0:
        # argmax over the channel axis gives 1 for foreground; softmax is
        # monotonic, so it is not needed to pick the winning channel.
        # .cpu() is required: matplotlib cannot plot a tensor that lives on the GPU.
        pred_bin = torch.argmax(outputs, dim=1)[0].cpu()
        fig, (ax_mask, ax_pred) = plt.subplots(1, 2, figsize=(12, 6))
        ax_mask.set_title("Mask")
        # Channel 1 is the foreground channel of the two-channel mask; channel 0
        # is its inverse and would not be comparable with the prediction.
        ax_mask.imshow(masks[0, 1].cpu(), cmap='gray')
        ax_pred.set_title("Predicted")
        ax_pred.imshow(pred_bin, cmap='gray')
        plt.show(block=False)
        plt.pause(0.001)
        plt.close(fig)

# Continue with the best weights rather than those of the last epoch, which
# may be up to `patience` epochs past the best validation loss.
if checkpoint_saved:
    model.load_state_dict(torch.load(checkpoint_path, map_location=DEVICE))
    model.eval()

In [None]:
#%% Testing loop

def dice_score(pred, target, eps=1e-6):
    """Mean Dice coefficient over the batch.

    Args:
        pred:   binary prediction tensor of shape (B, ...), at least 2-D.
        target: binary ground-truth tensor with the same shape as `pred`.
        eps:    smoothing term; makes the score 1 when both are empty.

    Returns:
        Python float: per-sample Dice (2*|P∩T| + eps) / (|P| + |T| + eps),
        averaged over the batch axis.
    """
    # Flatten everything except the batch axis so any (B, ...) layout works,
    # e.g. (B, H, W) as well as (B, C, H, W).
    pred = pred.float().flatten(start_dim=1)
    target = target.float().flatten(start_dim=1)
    intersection = (pred * target).sum(dim=1)
    # Sum of both sizes (the Dice denominator), not a set union.
    union = pred.sum(dim=1) + target.sum(dim=1)
    return ((2 * intersection + eps) / (union + eps)).mean().item()

def evaluate_loader(loader):
    """Score the global `model` on every batch of `loader`.

    The network is trained with a softmax over its two output channels, so the
    predicted class is the argmax over the channel axis (not a per-channel
    sigmoid threshold). Dice is computed on the foreground channel only:
    summed over both one-hot channels it would reduce to pixel accuracy.

    Returns:
        (batch_dice, batch_sizes): Dice of each batch and the number of samples
        in each batch.
    """
    model.eval()  # make sure dropout is disabled, whatever ran before this cell
    batch_dice, batch_sizes = [], []
    with torch.no_grad():
        for imgs, masks in loader:
            imgs, masks = imgs.to(DEVICE), masks.to(DEVICE)
            logits = model(imgs)
            # (B, 1, H, W): 1 where the foreground channel wins.
            pred_fg = torch.argmax(logits, dim=1, keepdim=True).float()
            # (B, 1, H, W): channel 1 of the two-channel mask is the foreground.
            target_fg = masks[:, 1:2]
            batch_dice.append(dice_score(pred_fg, target_fg))
            batch_sizes.append(imgs.shape[0])
    return batch_dice, batch_sizes


EVAL_BATCH_SIZE = 3
EVAL_SPLITS = (
    ("Train", Path("data/train/train"), Path("data/masks/train")),
    ("Validation", Path("data/val/val"), Path("data/masks/val")),
    ("Test", Path("data/test/test"), Path("data/masks/test")),
)

eval_datasets, eval_loaders = {}, {}

for split_name, img_dir, mask_dir in EVAL_SPLITS:
    img_files = sorted(img_dir.glob("*.jpg"))
    mask_files = sorted(mask_dir.glob("*.png"))

    eval_datasets[split_name] = CellDataset(img_files, mask_files)
    eval_loaders[split_name] = DataLoader(
        eval_datasets[split_name], batch_size=EVAL_BATCH_SIZE, shuffle=False
    )

    all_dice, batch_sizes = evaluate_loader(eval_loaders[split_name])

    # Report
    # Weight each batch by its size so a shorter last batch is not over-weighted;
    # the result is the mean Dice per sample.
    mean_dice = np.average(all_dice, weights=batch_sizes) if all_dice else float("nan")
    print(f"{split_name} set Dice score: {mean_dice:.4f}")

# Keep the dataset / loader names this cell has always defined.
train_dataset, val_dataset, test_dataset = (eval_datasets[name] for name, _, _ in EVAL_SPLITS)
train_loader, val_loader, test_loader = (eval_loaders[name] for name, _, _ in EVAL_SPLITS)