<a href="https://colab.research.google.com/github/aarnavg54/Deep-Learning-Radiomic-Stability/blob/main/U_Net%2B%2B_with_EfficientNet_b7.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Colab environment setup: install the segmentation library and training utilities.
# NOTE(review): the next two lines install the same package twice (PyPI release, then the
# GitHub source tree) — presumably only one is needed; confirm which version is intended.
!pip install -U segmentation-models-pytorch
!pip install -U git+https://github.com/qubvel-org/segmentation_models.pytorch
# NOTE(review): neither `lightning` nor `albumentations` is imported in the cells visible
# here — confirm they are still required.
!pip install lightning albumentations

In [None]:
# Creating a pytorch dataset
import torch
from torch.utils.data import Dataset

class HistologyDataset(Dataset):
    """Map-style dataset pairing channel-last image arrays with segmentation masks.

    Every item is returned as a pair of float32, channel-first ``torch.Tensor`` objects.

    Args:
        images: Indexable collection (e.g. a NumPy array) whose items are
            ``(H, W, C)`` arrays.
        masks: Indexable collection aligned with ``images``; each item is either
            an ``(H, W, C)`` array or a plain ``(H, W)`` array without a channel axis.
    """

    def __init__(self, images, masks):
        self.images = images
        self.masks = masks

    def __len__(self):
        """Return the number of samples, taken from ``images``."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(image, mask)`` for sample ``idx`` as channel-first float32 tensors."""
        image = torch.from_numpy(self.images[idx].astype("float32")).permute(2, 0, 1)

        mask = torch.from_numpy(self.masks[idx].astype("float32"))
        if mask.ndim == 2:
            # Mask stored without a channel axis: add one so the result is (1, H, W)
            # instead of failing inside permute().
            mask = mask.unsqueeze(0)
        else:
            mask = mask.permute(2, 0, 1)
        return image, mask


In [None]:
# Mount Google Drive so the pre-split .npy arrays stored there can be read.
# This import only exists inside Google Colab.
from google.colab import drive
drive.mount('/content/drive')

import numpy as np
# Load the train / validation / test image arrays (X_*) and mask arrays (y_*).
# NOTE(review): the file names suggest 256-pixel ultrasound images — confirm against the
# shapes printed in the next cell.
X_train = np.load('/content/drive/MyDrive/X_train_ultrasound_images_256_2.npy')
y_train = np.load('/content/drive/MyDrive/y_train_ultrasound_images_256_2.npy')
X_val = np.load('/content/drive/MyDrive/X_val_ultrasound_images_256_2.npy')
y_val = np.load('/content/drive/MyDrive/y_val_ultrasound_images_256_2.npy')
X_test = np.load('/content/drive/MyDrive/X_test_ultrasound_images_256_2.npy')
y_test = np.load('/content/drive/MyDrive/y_test_ultrasound_images_256_2.npy')

# Wrap each split in the dataset class defined above so samples come out as tensors.
train_dataset = HistologyDataset(X_train, y_train)
val_dataset = HistologyDataset(X_val, y_val)
test_dataset = HistologyDataset(X_test, y_test)

In [None]:
# Sanity check: print the shape of every loaded array, images first and masks second
# for each of the train / validation / test splits.
for split_array in (X_train, y_train, X_val, y_val, X_test, y_test):
    print(split_array.shape)

In [None]:
import segmentation_models_pytorch as smp

# Build a U-Net++ segmentation network on top of an EfficientNet-B7 encoder.
model = smp.UnetPlusPlus(
    encoder_name="efficientnet-b7",   # encoder backbone; flagged as UPDATED by the author
    encoder_weights="imagenet",       # initialise the encoder from ImageNet weights
    in_channels=3,                    # 3-channel input images
    classes=1,                        # single output channel (one foreground class)
    activation=None                   # no final activation: the model outputs raw logits
)

In [None]:
from segmentation_models_pytorch.losses import DiceLoss
import torch.optim as optim

# Binary Dice loss evaluated directly on logits (from_logits=True), which matches the
# model being built with activation=None.
dice_loss = DiceLoss(mode='binary', from_logits=True, smooth=1e-5)
# Adam over all model parameters with a fixed learning rate of 1e-4.
optimizer = optim.Adam(model.parameters(), lr=1e-4)

# --- SETUP TRAINING LOOP ---
from torch.utils.data import DataLoader
import os

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

In [None]:
# Mini-batches of 4 for training and validation; only the training split is shuffled.
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)  # UPDATED
val_loader = DataLoader(val_dataset, batch_size=4)
# No batch_size is given here, so DataLoader's default of one sample per batch applies.
test_loader = DataLoader(test_dataset)

num_epochs = 500               # upper bound on the number of training epochs
patience = 10                  # epochs without validation improvement before stopping
best_val_loss = float('inf')   # lowest validation loss seen so far
epochs_no_improve = 0          # consecutive epochs without a new best validation loss

# Location of the best-model checkpoint written by the training loop.
checkpoint_dir = "./checkpoints"
os.makedirs(checkpoint_dir, exist_ok=True)
checkpoint_path = os.path.join(checkpoint_dir, "best_model.pth")

In [None]:
# Train with Dice loss, validate after every epoch, checkpoint the weights whenever the
# validation loss improves, and stop early once it has failed to improve for `patience`
# consecutive epochs.
for epoch in range(1, num_epochs + 1):
    # ---- training pass ----
    model.train()
    train_losses = []
    for images, masks in train_loader:
        images, masks = images.to(device), masks.to(device)

        optimizer.zero_grad()
        outputs = model(images)

        loss = dice_loss(outputs, masks)
        loss.backward()
        optimizer.step()

        train_losses.append(loss.item())

    avg_train_loss = sum(train_losses) / len(train_losses)

    # ---- validation pass (no gradients, eval-mode layers) ----
    model.eval()
    val_losses = []
    with torch.no_grad():
        for images, masks in val_loader:
            images, masks = images.to(device), masks.to(device)
            outputs = model(images)

            loss = dice_loss(outputs, masks)
            val_losses.append(loss.item())

    avg_val_loss = sum(val_losses) / len(val_losses)

    print(f"Epoch {epoch:03d}: Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f}")

    if avg_val_loss < best_val_loss:
        # New best: remember it and persist the weights.
        best_val_loss = avg_val_loss
        epochs_no_improve = 0
        torch.save(model.state_dict(), checkpoint_path)
        print(f"Best model saved with Val Loss: {best_val_loss:.4f}")
    else:
        epochs_no_improve += 1
        if epochs_no_improve >= patience:
            print("Early stopping triggered.")
            break

# When training stops, `model` still holds the weights of the LAST epoch, which after
# early stopping is `patience` epochs past the best one. Reload the best checkpoint so
# the evaluation and export cells that follow operate on the best weights.
if best_val_loss < float('inf') and os.path.isfile(checkpoint_path):
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    print(f"Restored best model (Val Loss: {best_val_loss:.4f}) from {checkpoint_path}")

In [None]:
import torch
import segmentation_models_pytorch as smp

def dice_score(preds, targets, smooth=1e-6):
    """Return the batch-mean Dice coefficient between ``preds`` and ``targets``.

    Both tensors are flattened from dimension 1 onward, so every sample is scored on
    its own and the per-sample scores are then averaged.

    Args:
        preds: Prediction tensor whose first dimension indexes samples.
        targets: Ground-truth tensor of the same shape as ``preds``.
        smooth: Constant added to numerator and denominator to avoid division by zero.

    Returns:
        A scalar tensor holding the mean Dice score over the batch.
    """
    flat_preds, flat_targets = preds.flatten(1), targets.flatten(1)

    overlap = (flat_preds * flat_targets).sum(1)
    total = flat_preds.sum(1) + flat_targets.sum(1)
    per_sample = (2 * overlap + smooth) / (total + smooth)
    return per_sample.mean()

def iou_score(tp, fp, fn, tn, reduction="micro"):
    """Compute IoU from confusion counts by delegating to ``smp.metrics.iou_score``.

    Args:
        tp, fp, fn, tn: True-positive, false-positive, false-negative and true-negative
            counts, forwarded unchanged.
        reduction: Reduction mode forwarded to ``smp.metrics.iou_score``.

    Returns:
        Whatever ``smp.metrics.iou_score`` returns for the given reduction.
    """
    score = smp.metrics.iou_score(tp, fp, fn, tn, reduction=reduction)
    return score

# Evaluate on the test set: Dice and IoU are computed per batch on thresholded
# predictions, and the per-batch scores are averaged at the end.
model.eval()
dice_scores, iou_scores = [], []

with torch.no_grad():
    for images, masks in test_loader:
        images = images.to(device)
        masks = masks.to(device)

        # Logits -> probabilities -> hard 0/1 prediction (cut at 0.5)
        preds = (torch.sigmoid(model(images)) > 0.5).float()

        # Dice
        dice_scores.append(dice_score(preds, masks).cpu().item())

        # Per-sample confusion counts, summed over the channel and spatial axes
        neg_preds, neg_masks = 1 - preds, 1 - masks
        tp = (preds * masks).sum(dim=(1, 2, 3))
        fp = (preds * neg_masks).sum(dim=(1, 2, 3))
        fn = (neg_preds * masks).sum(dim=(1, 2, 3))
        tn = (neg_preds * neg_masks).sum(dim=(1, 2, 3))

        # IoU
        iou_scores.append(iou_score(tp, fp, fn, tn, reduction="micro").cpu().item())

avg_dice = sum(dice_scores) / len(dice_scores)
avg_iou = sum(iou_scores) / len(iou_scores)

print(f"Test Dice Score: {avg_dice:.4f}")
print(f"Test IoU Score: {avg_iou:.4f}")

In [None]:
import torch
import segmentation_models_pytorch as smp

# Cross-check of the test metrics using the functional API in `smp.metrics`.
# `smp.metrics` provides functions (get_stats, f1_score, iou_score, ...) rather than
# `Fscore` / `IoU` metric classes, so confusion counts are computed first and the scores
# are derived from them. Dice is the F1 score.
model.eval()
dice_scores = []
iou_scores = []

with torch.no_grad():
    for images, masks in test_loader:
        images, masks = images.to(device), masks.to(device)

        outputs = model(images)
        probs = torch.sigmoid(outputs)  # convert logits to probabilities

        # get_stats thresholds the float predictions itself but requires an integer
        # target, so the float masks are binarised and cast to long.
        # NOTE(review): assumes foreground pixels in the masks are > 0.5 — confirm.
        tp, fp, fn, tn = smp.metrics.get_stats(
            probs, (masks > 0.5).long(), mode="binary", threshold=0.5
        )

        dice = smp.metrics.f1_score(tp, fp, fn, tn, reduction="micro").cpu().item()
        iou = smp.metrics.iou_score(tp, fp, fn, tn, reduction="micro").cpu().item()

        dice_scores.append(dice)
        iou_scores.append(iou)

avg_dice = sum(dice_scores) / len(dice_scores)
avg_iou = sum(iou_scores) / len(iou_scores)

print(f"Test Dice Score: {avg_dice:.4f}")
print(f"Test IoU Score: {avg_iou:.4f}")

In [None]:
import matplotlib.pyplot as plt
import torch

def visualize_batch(images, true_masks, pred_masks, batch_size=4):
    """Plot input image, ground-truth mask and predicted mask side by side, one row per sample.

    Args:
        images: tensor, shape (B, C, H, W)
        true_masks: tensor, shape (B, 1, H, W)
        pred_masks: tensor, shape (B, 1, H, W); values above 0.5 are drawn as foreground
        batch_size: maximum number of rows to draw. Fewer rows are drawn when the
            tensors contain fewer samples.

    Returns:
        None. The figure is shown with ``plt.show()``.
    """
    images = images.cpu().permute(0, 2, 3, 1).numpy()  # (B, H, W, C)
    true_masks = true_masks.cpu().squeeze(1).numpy()   # (B, H, W)
    pred_masks = pred_masks.cpu().squeeze(1).numpy()   # (B, H, W)

    # Never index past the samples actually supplied (the default batch_size of 4
    # would otherwise raise IndexError for smaller batches).
    n_rows = min(batch_size, len(images), len(true_masks), len(pred_masks))
    if n_rows <= 0:
        return

    plt.figure(figsize=(12, n_rows * 3))
    for i in range(n_rows):
        # Input image
        # NOTE(review): imshow expects float images in [0, 1] — confirm the input range.
        plt.subplot(n_rows, 3, i * 3 + 1)
        plt.title("Input Image")
        plt.imshow(images[i])
        plt.axis("off")

        # Ground truth mask
        plt.subplot(n_rows, 3, i * 3 + 2)
        plt.title("Ground Truth Mask")
        plt.imshow(true_masks[i], cmap="gray")
        plt.axis("off")

        # Predicted mask (threshold at 0.5)
        plt.subplot(n_rows, 3, i * 3 + 3)
        plt.title("Predicted Mask")
        plt.imshow(pred_masks[i] > 0.5, cmap="gray")
        plt.axis("off")

    plt.tight_layout()
    plt.show()

In [None]:
# Qualitative check: run the model on one test sample and plot the input, the ground
# truth and the prediction next to each other.
img_idx = 200  # Image number you want to view (must be a valid index into test_dataset)

image, mask = test_dataset[img_idx]

model.eval()
with torch.no_grad():
    input_img = image.unsqueeze(0).to(device)  # add batch dim
    output = model(input_img)
    prob = torch.sigmoid(output)  # logits -> probabilities; visualize_batch thresholds at 0.5

# The mask also gets a batch dimension so all three tensors are batched the same way.
visualize_batch(input_img, mask.unsqueeze(0), prob, batch_size=1)


In [None]:
# Google Drive destination for the exported weights.
# NOTE(review): this rebinds `checkpoint_path`, the same name the training loop saves its
# best model to; re-running training after this cell would write to this Drive file.
# NOTE(review): the file name says "U-Net" although the model is UnetPlusPlus, and 0.9311
# is presumably a score from an earlier run — confirm both.
checkpoint_path = '/content/drive/MyDrive/U-Net_0.9311.pth'

In [None]:
# Write the weights currently held by `model` (state_dict only) to `checkpoint_path`.
torch.save(model.state_dict(), checkpoint_path)
print(f"Model saved to {checkpoint_path}")

In [None]:
def calculate_metrics(preds, targets, smooth=1e-6):
    """Compute batch-averaged segmentation metrics.

    Confusion counts are taken per sample (everything after dimension 0 is flattened),
    each metric is evaluated per sample, and the per-sample values are averaged.

    Args:
        preds: Prediction tensor whose first dimension indexes samples.
            NOTE(review): no thresholding happens here, so this presumably expects
            already-binarised 0/1 predictions — confirm with the caller.
        targets: Ground-truth tensor of the same shape as ``preds``.
        smooth: Constant added to numerators and denominators to avoid division by zero.

    Returns:
        dict with keys ``Dice``, ``IoU``, ``Precision``, ``Recall``, ``Specificity``,
        ``Accuracy``, ``HD95`` and ``Size_Error``. ``HD95`` is the mean of the
        per-sample values returned by ``hausdorff_distance``, skipping NaN results,
        and is NaN when no sample yields a value.
    """
    # Flatten tensors and calculate TP, FP, FN, TN
    preds_flat = preds.flatten(1)
    targets_flat = targets.flatten(1)

    tp = (preds_flat * targets_flat).sum(1)
    fp = (preds_flat * (1 - targets_flat)).sum(1)
    fn = ((1 - preds_flat) * targets_flat).sum(1)
    tn = ((1 - preds_flat) * (1 - targets_flat)).sum(1)

    # Standard metrics
    metrics = {
        'Dice': ((2 * tp + smooth) / (tp + fp + tp + fn + smooth)).mean().item(),
        'IoU': ((tp + smooth) / (tp + fp + fn + smooth)).mean().item(),
        'Precision': ((tp + smooth) / (tp + fp + smooth)).mean().item(),
        'Recall': ((tp + smooth) / (tp + fn + smooth)).mean().item(),
        'Specificity': ((tn + smooth) / (tn + fp + smooth)).mean().item(),
        'Accuracy': ((tp + tn + smooth) / (tp + tn + fp + fn + smooth)).mean().item(),
    }

    # Hausdorff distance: one value per sample, NaN results are skipped.
    hd_values = []
    for p, t in zip(preds, targets):
        hd = hausdorff_distance(p, t)
        if not np.isnan(hd):
            hd_values.append(hd)
    # Average across samples like every other metric. A percentile taken ACROSS the
    # samples of a batch is not a Hausdorff statistic.
    metrics['HD95'] = float(np.mean(hd_values)) if hd_values else np.nan

    # Tumor size error: relative difference between predicted and true foreground area
    pred_area = preds_flat.sum(1)
    target_area = targets_flat.sum(1)
    metrics['Size_Error'] = ((pred_area - target_area).abs() / (target_area + smooth)).mean().item()

    return metrics

In [None]:
from scipy.spatial.distance import directed_hausdorff

def hausdorff_distance(pred, target, percentile=95):
    """Compute the symmetric percentile Hausdorff distance between two masks.

    Both masks are binarised at 0.5 and reduced to their boundary pixels. For each
    direction the distance from every boundary pixel to the nearest boundary pixel of
    the other mask is collected, and the requested percentile of those distances is
    taken. The larger of the two directed values is returned.

    Args:
        pred: Predicted mask tensor; singleton dimensions are squeezed away.
        target: Ground-truth mask tensor with the same spatial shape as ``pred``.
        percentile: Percentile of the nearest-boundary distances to report. The default
            of 95 gives HD95; 100 gives the classical (maximum) Hausdorff distance.

    Returns:
        The distance in pixels as a float, or ``np.nan`` when either mask is empty.
    """
    from scipy.ndimage import binary_erosion
    from scipy.spatial import cKDTree

    pred_mask = pred.squeeze().cpu().numpy() > 0.5  # Remove batch/channel dims
    target_mask = target.squeeze().cpu().numpy() > 0.5

    # Handle empty masks: the distance is undefined when either mask has no foreground
    if not pred_mask.any() or not target_mask.any():
        return np.nan

    # Boundary = foreground pixels removed by one erosion step
    pred_coords = np.argwhere(pred_mask & ~binary_erosion(pred_mask))
    target_coords = np.argwhere(target_mask & ~binary_erosion(target_mask))

    # Nearest-boundary distances in both directions
    pred_to_target = cKDTree(target_coords).query(pred_coords)[0]
    target_to_pred = cKDTree(pred_coords).query(target_coords)[0]

    hd1 = np.percentile(pred_to_target, percentile)
    hd2 = np.percentile(target_to_pred, percentile)

    return float(max(hd1, hd2))