In [None]:
#############################################
# Cell 1: Imports and Global Configurations
#############################################

import os
import glob
import copy
import time
import datetime
import numpy as np
import matplotlib.pyplot as plt
import cv2

from PIL import Image

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
import torchvision.transforms as T
import timm

# Seed both random number sources up front so repeated runs are comparable
np.random.seed(42)
torch.manual_seed(42)

# Prefer the GPU whenever CUDA is usable; otherwise fall back to the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Using device:", device)


In [None]:
#############################################
# Cell 2: Custom Dataset & Transforms
#############################################

class ChestXRayDataset(Dataset):
    """
    Paired image/mask dataset for Chest X-Ray segmentation.

    Every ``*.png`` in ``images_dir`` is paired with the ``*.png`` found at the
    same position in the sorted listing of ``masks_dir``, so both folders are
    expected to use identical file names.

    Args:
        images_dir: folder containing the input PNG images.
        masks_dir: folder containing the PNG masks.
        transform: optional callable that receives and returns a
            ``{"image": ..., "mask": ...}`` dict. When omitted, both items are
            converted with ``ToTensor`` and the mask is thresholded at 0.5.

    Raises:
        AssertionError: if the two folders hold a different number of PNGs.

    Warns:
        UserWarning: if any image/mask pair differs in file name, because the
            pairing then depends purely on sort order.
    """
    def __init__(self, images_dir, masks_dir, transform=None):
        import warnings  # local import: only needed for the pairing check below

        super().__init__()
        self.images_dir = images_dir
        self.masks_dir = masks_dir
        self.transform = transform
        
        self.image_paths = sorted(glob.glob(os.path.join(images_dir, "*.png")))
        self.mask_paths = sorted(glob.glob(os.path.join(masks_dir, "*.png")))
        
        assert len(self.image_paths) == len(self.mask_paths), \
            "Number of images and masks do not match."

        # Equal counts alone do not prove the pairing is right: a single
        # missing/extra file on each side shifts every later pair. Surface
        # that instead of silently training on mismatched masks.
        mismatched = [
            (img, msk)
            for img, msk in zip(self.image_paths, self.mask_paths)
            if os.path.basename(img) != os.path.basename(msk)
        ]
        if mismatched:
            first_img, first_msk = mismatched[0]
            warnings.warn(
                f"{len(mismatched)} of {len(self.image_paths)} image/mask pairs "
                f"have different file names (first: "
                f"{os.path.basename(first_img)!r} vs {os.path.basename(first_msk)!r}); "
                "pairing relies on sort order only.",
                UserWarning,
                stacklevel=2,
            )
    
    def __len__(self):
        return len(self.image_paths)
    
    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        mask_path = self.mask_paths[idx]
        
        # convert() returns a fully loaded copy, so the file handles can be
        # released as soon as the with-blocks exit.
        with Image.open(img_path) as img_file:
            image = img_file.convert("RGB")
        with Image.open(mask_path) as mask_file:
            mask = mask_file.convert("L")  # grayscale mask
        
        if self.transform:
            sample = {"image": image, "mask": mask}
            sample = self.transform(sample)
            image, mask = sample["image"], sample["mask"]
        else:
            # Default: just convert to tensor (no resize, no normalization)
            to_tensor = T.ToTensor()
            image = to_tensor(image)
            mask = to_tensor(mask)
            mask = (mask > 0.5).float()
        
        return image, mask


class JointTransformWrapper:
    """
    Applies the same geometric transforms to an image and its mask.

    Pipeline: resize both -> (optionally) flip both horizontally together ->
    convert both to tensors -> ImageNet-normalize the image -> binarize the
    mask at 0.5.

    Args:
        augment: when True, image and mask are flipped horizontally together
            with the probability stored in ``augment_transform.p``.
        image_size: target size handed to ``T.Resize``.
    """
    def __init__(self, augment=True, image_size=(224, 224)):
        self.augment = augment
        self.image_size = image_size
        
        # Common transformations (resize, etc.)
        self.common_transforms = T.Compose([
            T.Resize(self.image_size),
        ])
        
        # Augmentations (e.g. random horizontal flip); only its probability
        # ``p`` is consulted, the flip itself is applied explicitly in __call__.
        self.augment_transform = T.RandomHorizontalFlip(p=0.5)
        
        # Convert to tensor
        self.to_tensor_img = T.ToTensor()
        self.to_tensor_mask = T.ToTensor()

        # ImageNet normalization (recommended for pretrained Swin)
        self.normalize = T.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
        
    def __call__(self, sample):
        """
        Transform one ``{"image": ..., "mask": ...}`` sample.

        Returns a dict with the same keys holding the normalized image tensor
        and the binary (0/1) float mask tensor.
        """
        image, mask = sample["image"], sample["mask"]
        
        # 1. Resize
        image = self.common_transforms(image)
        mask = self.common_transforms(mask)
        
        # 2. Joint augmentation. One random draw decides the flip for BOTH
        #    items. Re-seeding the global torch RNG to replay the same random
        #    transform twice would also reset every other consumer of that
        #    RNG on each sample, so the decision is made once here instead.
        if self.augment and torch.rand(1).item() < self.augment_transform.p:
            image = T.functional.hflip(image)
            mask = T.functional.hflip(mask)
        
        # 3. To Tensor
        image = self.to_tensor_img(image)
        mask = self.to_tensor_mask(mask)
        
        # 4. Normalize image (mask is 0/1, so no normalization)
        image = self.normalize(image)
        
        # 5. Binarize mask
        mask = (mask > 0.5).float()
        
        return {"image": image, "mask": mask}


In [None]:
#############################################
# Cell 3: Create Dataset Instances & Splits
#############################################

# Dataset Path
dataset_path = r"C:\Users\offic\OneDrive\Masaüstü\datasets\Chest_XRay"
images_dir = os.path.join(dataset_path, "images")
masks_dir = os.path.join(dataset_path, "masks")

# Create the transformation pipelines: random flips are for training only.
# Validation/test samples must be deterministic, otherwise early stopping,
# LR scheduling and the reported test metrics are computed on randomly
# augmented data.
joint_transform = JointTransformWrapper(augment=True, image_size=(224, 224))
eval_transform = JointTransformWrapper(augment=False, image_size=(224, 224))

# Two views over the same files: one augmenting (train), one not (val/test)
full_dataset = ChestXRayDataset(images_dir, masks_dir, transform=joint_transform)
eval_dataset = ChestXRayDataset(images_dir, masks_dir, transform=eval_transform)
print("Total samples in dataset:", len(full_dataset))

# Split into train/val/test
dataset_len = len(full_dataset)
train_size = int(0.7 * dataset_len)  # 70%
val_size = int(0.15 * dataset_len)   # 15%
test_size = dataset_len - train_size - val_size  # remainder, so sizes sum exactly

# random_split only decides the indices; the val/test indices are then
# re-pointed at the non-augmenting dataset.
train_dataset, _val_split, _test_split = random_split(
    full_dataset, [train_size, val_size, test_size]
)
val_dataset = torch.utils.data.Subset(eval_dataset, _val_split.indices)
test_dataset = torch.utils.data.Subset(eval_dataset, _test_split.indices)

print("Train:", len(train_dataset), "Val:", len(val_dataset), "Test:", len(test_dataset))


In [None]:
#############################################
# Cell 4: Create DataLoaders
#############################################

def create_dataloaders(train_ds, val_ds, test_ds, batch_size=4):
    """
    Wrap the three dataset splits in DataLoaders.

    Only the training loader shuffles; all loaders load in the main process
    (``num_workers=0``). Returns ``(train_loader, val_loader, test_loader)``.
    """
    shared_kwargs = dict(batch_size=batch_size, num_workers=0)
    train_loader = DataLoader(train_ds, shuffle=True, **shared_kwargs)
    val_loader = DataLoader(val_ds, shuffle=False, **shared_kwargs)
    test_loader = DataLoader(test_ds, shuffle=False, **shared_kwargs)
    return train_loader, val_loader, test_loader

# Mini-batch size shared by the train/val/test loaders
batch_size = 4
# Build the three loaders from the splits created in Cell 3
train_loader, val_loader, test_loader = create_dataloaders(
    train_dataset, val_dataset, test_dataset, batch_size=batch_size
)


In [None]:
#############################################
# Cell 5: Define Simpler Decoder (No Attention)
#         + SwinTransformerSegModel
#############################################

class SimpleDecoder(nn.Module):
    """
    Lightweight UNet-style decoder without attention blocks.

    Four encoder feature maps are projected with 1x1 convolutions, merged
    top-down through three (transposed-conv upsample -> concat skip -> 3x3
    conv) stages, then upsampled twice more (2x each) before a 1x1 output
    convolution.

    Args:
        encoder_channels: channel counts of the four encoder stages, shallowest
            first (e.g. [96, 192, 384, 768] for swin_tiny).
        out_channels: number of channels in the produced map.
    """
    def __init__(self, encoder_channels, out_channels=1):
        super().__init__()

        # 1x1 projections that bring every encoder stage to a fixed width
        # (deepest stage first)
        self.conv_f4 = nn.Conv2d(encoder_channels[3], 512, kernel_size=1)
        self.conv_f3 = nn.Conv2d(encoder_channels[2], 256, kernel_size=1)
        self.conv_f2 = nn.Conv2d(encoder_channels[1], 128, kernel_size=1)
        self.conv_f1 = nn.Conv2d(encoder_channels[0], 64, kernel_size=1)

        # Three upsample + skip-fusion stages (fused with f3, f2, f1 in turn)
        self.up1 = self._upsample_block(512, 256)
        self.fuse1 = self._fuse_block(256 + 256, 256)

        self.up2 = self._upsample_block(256, 128)
        self.fuse2 = self._fuse_block(128 + 128, 128)

        self.up3 = self._upsample_block(128, 64)
        self.fuse3 = self._fuse_block(64 + 64, 64)

        # Two skip-free upsampling stages that finish the path back up to the
        # input resolution
        self.up4 = self._upsample_block(64, 32)
        self.up5 = self._upsample_block(32, 16)

        self.out_conv = nn.Conv2d(16, out_channels, kernel_size=1)

    @staticmethod
    def _upsample_block(in_ch, out_ch):
        """2x spatial upsampling via transposed convolution followed by ReLU."""
        return nn.Sequential(
            nn.ConvTranspose2d(in_ch, out_ch, kernel_size=2, stride=2),
            nn.ReLU(inplace=True),
        )

    @staticmethod
    def _fuse_block(in_ch, out_ch):
        """3x3 convolution + ReLU applied to the concatenated skip features."""
        return nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        )

    def forward(self, features):
        """Decode ``[f1, f2, f3, f4]`` (shallowest first) into a (B, out_channels, H, W) map."""
        shallow, middle, deep, deepest = features

        # Project every stage to its decoder width
        x = self.conv_f4(deepest)
        skips = (self.conv_f3(deep), self.conv_f2(middle), self.conv_f1(shallow))

        # Stages 1-3: upsample, concatenate the matching skip, fuse
        stages = (
            (self.up1, self.fuse1),
            (self.up2, self.fuse2),
            (self.up3, self.fuse3),
        )
        for (upsample, fuse), skip in zip(stages, skips):
            x = fuse(torch.cat([upsample(x), skip], dim=1))

        # Stages 4-5: plain upsampling, then the final 1x1 projection
        x = self.up5(self.up4(x))
        return self.out_conv(x)


class SwinTransformerSegModel(nn.Module):
    """
    Full segmentation model with a Swin Transformer encoder
    and our simpler UNet-like decoder (no attention).

    Args:
        backbone_name: timm model name used to build the feature extractor.
        out_channels: number of channels in the output map.
    """
    def __init__(self, backbone_name="swin_tiny_patch4_window7_224", out_channels=1):
        super().__init__()
        self.encoder = timm.create_model(backbone_name, pretrained=True, features_only=True)
        # Kept on the instance so forward() can recognise the channel axis
        self.encoder_channels = list(self.encoder.feature_info.channels())  # e.g. [96, 192, 384, 768]
        self.decoder = SimpleDecoder(self.encoder_channels, out_channels)

    @staticmethod
    def _to_channels_first(feature, num_channels):
        """
        Return ``feature`` as (B, C, H, W).

        A 4-D map is permuted only when its LAST axis holds the expected
        channel count and axis 1 does not. Comparing axis sizes against each
        other (axis 1 < last axis) misfires whenever the spatial size is not
        smaller than the channel count. If both axes equal ``num_channels``
        the layout cannot be decided from the shape and the tensor is returned
        unchanged.
        """
        if (
            feature.dim() == 4
            and feature.shape[-1] == num_channels
            and feature.shape[1] != num_channels
        ):
            return feature.permute(0, 3, 1, 2)
        return feature
    
    def forward(self, x):
        """Run the encoder, normalise feature layout, and decode to a segmentation map."""
        # Extract features from the Swin encoder
        features = self.encoder(x)
        
        # If channels-last, ensure channels-first
        channels_first = [
            self._to_channels_first(feature, num_channels)
            for feature, num_channels in zip(features, self.encoder_channels)
        ]
        
        # Pass to decoder
        return self.decoder(channels_first)


In [None]:
#############################################
# Cell 6: Loss Functions & Metrics
#############################################

def dice_loss(pred, target, smooth=1e-5):
    """
    Soft Dice loss computed from raw logits.

    ``pred`` is passed through a sigmoid; the Dice score is evaluated per
    (sample, channel) over axes 2 and 3, and the mean of ``1 - dice`` is
    returned as a scalar tensor. ``smooth`` keeps the ratio defined for
    empty maps.
    """
    probs = torch.sigmoid(pred)
    spatial_axes = (2, 3)
    overlap = torch.sum(probs * target, dim=spatial_axes)
    total = torch.sum(probs, dim=spatial_axes) + torch.sum(target, dim=spatial_axes)
    per_map_loss = 1 - (2. * overlap + smooth) / (total + smooth)
    return per_map_loss.mean()

class ComboLoss(nn.Module):
    """
    Convex blend of BCE-with-logits and Dice loss:
    ``weight_bce * BCE + (1 - weight_bce) * Dice``.
    """
    def __init__(self, weight_bce=0.5):
        super().__init__()
        self.bce = nn.BCEWithLogitsLoss()
        self.weight_bce = weight_bce

    def forward(self, pred, target):
        """Both terms receive the raw logits ``pred`` and the binary ``target``."""
        bce_term = self.weight_bce * self.bce(pred, target)
        dice_term = (1 - self.weight_bce) * dice_loss(pred, target)
        return bce_term + dice_term

def dice_coefficient(pred, target, threshold=0.5, smooth=1e-5):
    """
    Hard Dice score from raw logits.

    Sigmoid probabilities are binarized at ``threshold``; Dice is computed per
    (sample, channel) over axes 2 and 3 and the mean is returned as a Python
    float.
    """
    binary = (torch.sigmoid(pred) > threshold).float()
    spatial_axes = (2, 3)
    overlap = (binary * target).sum(dim=spatial_axes)
    denominator = binary.sum(dim=spatial_axes) + target.sum(dim=spatial_axes)
    scores = (2. * overlap + smooth) / (denominator + smooth)
    return scores.mean().item()

def iou_coefficient(pred, target, threshold=0.5, smooth=1e-5):
    """
    Hard IoU (Jaccard) score from raw logits.

    Sigmoid probabilities are binarized at ``threshold``; IoU is computed per
    (sample, channel) over axes 2 and 3 and the mean is returned as a Python
    float.
    """
    binary = (torch.sigmoid(pred) > threshold).float()
    spatial_axes = (2, 3)
    overlap = (binary * target).sum(dim=spatial_axes)
    # |A ∪ B| = |A| + |B| - |A ∩ B|
    combined = binary.sum(dim=spatial_axes) + target.sum(dim=spatial_axes) - overlap
    scores = (overlap + smooth) / (combined + smooth)
    return scores.mean().item()

def compute_confusion_matrix(pred, target, threshold=0.5):
    """
    Computes pixel-level confusion matrix (TP, FP, TN, FN) for a batch.

    pred:   raw logits; sigmoid is applied and the result binarized at
            ``threshold``.
    target: binary (0/1) tensor with the same number of elements as ``pred``.
    Returns: TP, FP, TN, FN as Python floats.
    """
    pred = torch.sigmoid(pred)
    pred = (pred > threshold).float()
    
    # Flatten. reshape() (unlike view()) also accepts non-contiguous inputs,
    # e.g. a target that was transposed or sliced upstream.
    pred_flat = pred.reshape(-1)
    target_flat = target.reshape(-1)
    
    tp = (pred_flat * target_flat).sum()
    fp = (pred_flat * (1 - target_flat)).sum()
    fn = ((1 - pred_flat) * target_flat).sum()
    tn = ((1 - pred_flat) * (1 - target_flat)).sum()
    
    return tp.item(), fp.item(), tn.item(), fn.item()

def compute_additional_metrics(tp, fp, tn, fn, eps=1e-7):
    """
    Derive accuracy, precision, recall, specificity and F1 from accumulated
    confusion-matrix counts. ``eps`` is added to every denominator so that
    empty classes yield 0 instead of a division error.
    """
    def ratio(numerator, denominator):
        return numerator / (denominator + eps)

    precision = ratio(tp, tp + fp)
    recall = ratio(tp, tp + fn)

    results = {}
    results["accuracy"] = ratio(tp + tn, tp + tn + fp + fn)
    results["precision"] = precision
    results["recall"] = recall
    results["specificity"] = ratio(tn, tn + fp)
    results["f1_score"] = ratio(2 * (precision * recall), precision + recall)
    return results


In [None]:
#############################################
# Cell 7: Training Loop with Early Stopping,
#         LR Scheduling + Extended Metrics,
#         Time Tracking (elapsed/remaining).
#############################################

def train_one_epoch(model, train_loader, criterion, optimizer, device=None):
    """
    Run one optimisation pass over ``train_loader``.

    Args:
        model: network to train (switched to train mode here).
        train_loader: DataLoader yielding ``(images, masks)`` batches.
        criterion: loss callable taking ``(outputs, masks)``.
        optimizer: optimizer stepping ``model``'s parameters.
        device: where batches are moved before the forward pass. Defaults to
            the device of the model's first parameter, so the function does
            not depend on a notebook-level global.

    Returns:
        Mean loss per sample over the epoch (batch losses are weighted by
        batch size, then divided by the dataset length).
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    model.train()
    running_loss = 0.0
    
    for images, masks in train_loader:
        images = images.to(device)
        masks = masks.to(device)
        
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, masks)
        loss.backward()
        optimizer.step()
        
        # Weight by batch size so a smaller final batch is not over-counted
        running_loss += loss.item() * images.size(0)
    
    epoch_loss = running_loss / len(train_loader.dataset)
    return epoch_loss


def validate_one_epoch(model, val_loader, criterion, device=None):
    """
    Evaluate ``model`` on ``val_loader`` without gradient tracking.

    Args:
        model: network to evaluate (switched to eval mode here).
        val_loader: DataLoader yielding ``(images, masks)`` batches.
        criterion: loss callable taking ``(outputs, masks)``.
        device: where batches are moved before the forward pass. Defaults to
            the device of the model's first parameter, so the function does
            not depend on a notebook-level global.

    Returns:
        ``(loss, dice, iou, metrics)`` where the first three are per-sample
        means over the dataset and ``metrics`` is the dict produced by
        ``compute_additional_metrics`` from the summed confusion counts.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    model.eval()
    running_loss = 0.0
    running_dice = 0.0
    running_iou  = 0.0
    
    # For confusion matrix
    total_tp, total_fp, total_tn, total_fn = 0, 0, 0, 0
    
    with torch.no_grad():
        for images, masks in val_loader:
            images = images.to(device)
            masks = masks.to(device)
            batch_len = images.size(0)
            
            outputs = model(images)
            loss = criterion(outputs, masks)
            running_loss += loss.item() * batch_len
            
            # dice & iou are batch means, so re-weight them by batch size
            running_dice += dice_coefficient(outputs, masks) * batch_len
            running_iou  += iou_coefficient(outputs, masks) * batch_len
            
            # confusion matrix counts are summed over the whole split
            tp, fp, tn, fn = compute_confusion_matrix(outputs, masks)
            total_tp += tp
            total_fp += fp
            total_tn += tn
            total_fn += fn
    
    num_samples = len(val_loader.dataset)
    epoch_val_loss = running_loss / num_samples
    epoch_val_dice = running_dice / num_samples
    epoch_val_iou  = running_iou  / num_samples
    
    # Additional metrics
    metrics = compute_additional_metrics(total_tp, total_fp, total_tn, total_fn)
    
    return epoch_val_loss, epoch_val_dice, epoch_val_iou, metrics


def format_time(seconds):
    """Render a duration in seconds as an h:mm:ss string (fractions are truncated)."""
    whole_seconds = int(seconds)
    duration = datetime.timedelta(seconds=whole_seconds)
    return str(duration)


def train_model(model, 
                train_loader, 
                val_loader, 
                device, 
                epochs=20, 
                patience=5, 
                lr=1e-4, 
                weight_bce=0.5,
                reduce_on_plateau=True):
    """
    Full training driver.

    - Loss: ComboLoss (BCE + Dice) weighted by ``weight_bce``
    - Optimizer: AdamW with learning rate ``lr``
    - Optional ReduceLROnPlateau on the validation loss (factor 0.5, patience 2)
    - Early stopping after ``patience`` epochs without a new best validation loss
    - Per-epoch report including elapsed and estimated remaining time

    ``device`` is accepted as part of the signature but is not referenced
    inside this function.

    Returns ``(model, history)``: the model with the best-validation-loss
    weights loaded, and a dict of per-epoch metric lists.
    """
    criterion = ComboLoss(weight_bce=weight_bce)
    optimizer = optim.AdamW(model.parameters(), lr=lr)

    # Learning Rate Scheduler (Reduce LR on Plateau) - no verbose param
    scheduler = None
    if reduce_on_plateau:
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            mode='min',
            factor=0.5,
            patience=2
        )

    # One list per tracked quantity, filled once per completed epoch
    history = {
        "train_loss": [],
        "val_loss": [],
        "val_dice": [],
        "val_iou": [],
        "val_accuracy": [],
        "val_precision": [],
        "val_recall": [],
        "val_specificity": [],
        "val_f1": []
    }

    best_val_loss = float('inf')
    best_model_wts = copy.deepcopy(model.state_dict())
    stale_epochs = 0

    run_start = time.time()

    for epoch in range(1, epochs + 1):
        epoch_start = time.time()

        # Train, then validate
        train_loss = train_one_epoch(model, train_loader, criterion, optimizer)
        val_loss, val_dice, val_iou, extra = validate_one_epoch(model, val_loader, criterion)

        history["train_loss"].append(train_loss)
        history["val_loss"].append(val_loss)
        history["val_dice"].append(val_dice)
        history["val_iou"].append(val_iou)
        history["val_accuracy"].append(extra["accuracy"])
        history["val_precision"].append(extra["precision"])
        history["val_recall"].append(extra["recall"])
        history["val_specificity"].append(extra["specificity"])
        history["val_f1"].append(extra["f1_score"])

        # Scheduler step (on val loss); without a scheduler the LR never moves
        if scheduler is None:
            current_lr = lr
        else:
            scheduler.step(val_loss)
            current_lr = scheduler.optimizer.param_groups[0]['lr']

        # Time calculations: remaining time extrapolates the mean epoch duration
        now = time.time()
        epoch_time = now - epoch_start
        total_time = now - run_start
        remaining_time = (total_time / epoch) * (epochs - epoch)

        print(f"Epoch [{epoch}/{epochs}] (LR: {current_lr:.6f})")
        print(f"  Train Loss: {train_loss:.4f}")
        print(f"  Val Loss:   {val_loss:.4f} | Dice: {val_dice:.4f} | IoU: {val_iou:.4f}")
        print(f"  Accuracy: {extra['accuracy']:.4f}, "
              f"Precision: {extra['precision']:.4f}, "
              f"Recall: {extra['recall']:.4f}, "
              f"Specificity: {extra['specificity']:.4f}, "
              f"F1: {extra['f1_score']:.4f}")

        print(f"  Time Elapsed: {format_time(total_time)} | "
              f"Epoch Time: {format_time(epoch_time)} | "
              f"Est. Remaining: {format_time(remaining_time)}\n")

        # Early stopping bookkeeping: snapshot weights on every new best
        improved = val_loss < best_val_loss
        if improved:
            best_val_loss = val_loss
            best_model_wts = copy.deepcopy(model.state_dict())
            stale_epochs = 0
        else:
            stale_epochs += 1

        if stale_epochs >= patience:
            print("Early stopping triggered!")
            break

    # Load best weights
    model.load_state_dict(best_model_wts)

    return model, history


In [None]:
#############################################
# Cell 8: Train the Model
#############################################

# Instantiate the Swin-Tiny segmentation network (single-channel output)
# and move it to the device selected in Cell 1.
model = SwinTransformerSegModel(
    backbone_name="swin_tiny_patch4_window7_224",
    out_channels=1
).to(device)

# Hyperparameters
epochs = 20       # upper bound on the number of training epochs
patience = 5      # epochs without a new best val loss before early stopping
lr = 1e-4         # learning rate handed to train_model
weight_bce = 0.5  # BCE weight in the combined loss; Dice receives 1 - weight_bce

# train_model returns the model with the best-validation-loss weights loaded
# plus a dict of per-epoch metric lists.
model, history = train_model(
    model,
    train_loader,
    val_loader,
    device,
    epochs=epochs,
    patience=patience,
    lr=lr,
    weight_bce=weight_bce,
    reduce_on_plateau=True  # LR scheduling
)


In [None]:
#############################################
# Cell 9: Plot Training Curves
#############################################

def plot_training_curves(history):
    """
    Draw a 2x3 grid of training curves from the ``history`` dict.

    Panels, in order: loss (train vs. val), Dice, IoU, accuracy,
    precision & recall, specificity & F1. The x axis is the 1-based epoch
    index derived from the length of ``history["train_loss"]``.
    """
    epochs_range = range(1, len(history["train_loss"]) + 1)

    # (title, y-axis label, [(history key, legend label), ...]) per panel
    panels = [
        ("Loss", "Loss", [("train_loss", "Train Loss"), ("val_loss", "Val Loss")]),
        ("Dice", "Dice", [("val_dice", "Val Dice")]),
        ("IoU", "IoU", [("val_iou", "Val IoU")]),
        ("Accuracy", "Accuracy", [("val_accuracy", "Val Accuracy")]),
        ("Precision & Recall", "Value",
         [("val_precision", "Val Precision"), ("val_recall", "Val Recall")]),
        ("Specificity & F1", "Value",
         [("val_specificity", "Val Specificity"), ("val_f1", "Val F1")]),
    ]

    plt.figure(figsize=(16, 10))

    for position, (title, y_label, curves) in enumerate(panels, start=1):
        plt.subplot(2, 3, position)
        for key, legend_label in curves:
            plt.plot(epochs_range, history[key], label=legend_label)
        plt.title(title)
        plt.xlabel("Epoch")
        plt.ylabel(y_label)
        plt.legend()

    plt.tight_layout()
    plt.show()

# Draw the loss/metric curves recorded by train_model in Cell 8
plot_training_curves(history)


In [None]:
#############################################
# Cell 10: Final Evaluation on Test Set 
#          (Multiple metrics)
#############################################

def evaluate_model(model, test_loader):
    """
    Evaluate ``model`` on ``test_loader`` and print the results.

    The pass itself is delegated to ``validate_one_epoch`` so the test
    numbers are produced by exactly the same code as the validation numbers
    seen during training. The loss is ComboLoss with its default BCE weight.
    Prints loss, Dice, IoU, accuracy, precision, recall, specificity and F1;
    returns nothing.
    """
    criterion = ComboLoss()  # same combo used in training
    
    test_loss, test_dice, test_iou, additional_metrics = validate_one_epoch(
        model, test_loader, criterion
    )
    
    print("=== Test Results ===")
    print(f"Loss: {test_loss:.4f}")
    print(f"Dice: {test_dice:.4f} | IoU: {test_iou:.4f}")
    print(f"Accuracy: {additional_metrics['accuracy']:.4f}")
    print(f"Precision: {additional_metrics['precision']:.4f}")
    print(f"Recall: {additional_metrics['recall']:.4f}")
    print(f"Specificity: {additional_metrics['specificity']:.4f}")
    print(f"F1: {additional_metrics['f1_score']:.4f}")


# Report the metrics of the trained model (best-validation weights) on the test split
evaluate_model(model, test_loader)


In [None]:
#############################################
# Cell 11: Overlay Predictions for Visualization
#############################################

def overlay_mask_on_image(image, mask, alpha=0.5, color=(0, 255, 0)):
    """
    Blend a solid-colour rendering of ``mask`` onto ``image``.

    image: (H x W x 3) NumPy array, RGB or BGR
    mask:  (H x W) NumPy array; every pixel > 0 is painted
    alpha: weight of the painted copy in the blend (the original image gets 1 - alpha)
    color: colour painted over the masked pixels
    Returns the blended image; ``image`` itself is left untouched.
    """
    painted = image.copy()
    foreground = mask > 0
    painted[foreground] = color
    blended = cv2.addWeighted(painted, alpha, image, 1 - alpha, 0)
    return blended


def visualize_predictions(model, loader, num_samples=4):
    """
    Show original / ground-truth overlay / prediction overlay for the first
    batch of ``loader``.

    Args:
        model: segmentation network producing logits; run in eval mode on the
            device its parameters live on.
        loader: DataLoader yielding ``(images, masks)`` where images carry
            ImageNet normalization (it is undone before display).
        num_samples: maximum number of samples to draw; capped by the batch size.
    """
    model.eval()

    # Run on whatever device the model lives on instead of a notebook global
    first_param = next(model.parameters(), None)
    model_device = first_param.device if first_param is not None else torch.device("cpu")
    
    images, masks = next(iter(loader))  # get one batch
    
    with torch.no_grad():
        outputs = model(images.to(model_device))
        preds = torch.sigmoid(outputs)
        preds = (preds > 0.5).float()
    
    # ImageNet statistics used to undo the normalization; constant per call
    mean = np.array([0.485, 0.456, 0.406])
    std  = np.array([0.229, 0.224, 0.225])
    
    for i in range(min(num_samples, images.size(0))):
        # Convert to CPU numpy (ground-truth masks never needed the model device)
        img_np = images[i].cpu().numpy().transpose(1, 2, 0)
        mask_gt = masks[i].cpu().numpy().squeeze()
        mask_pred = preds[i].cpu().numpy().squeeze()
        
        # Denormalize image and convert to displayable uint8
        img_np = (img_np * std + mean)
        img_np = np.clip(img_np, 0, 1)
        img_np = (img_np * 255).astype(np.uint8)
        
        # Overlay ground truth (in RED for example)
        overlay_gt   = overlay_mask_on_image(img_np, mask_gt, alpha=0.5, color=(255, 0, 0))
        # Overlay prediction (in GREEN)
        overlay_pred = overlay_mask_on_image(img_np, mask_pred, alpha=0.5, color=(0, 255, 0))
        
        fig, axes = plt.subplots(1, 3, figsize=(12, 5))
        
        axes[0].imshow(img_np)
        axes[0].set_title("Original")
        axes[0].axis("off")
        
        axes[1].imshow(overlay_gt)
        axes[1].set_title("Ground Truth Overlay (Red)")
        axes[1].axis("off")
        
        axes[2].imshow(overlay_pred)
        axes[2].set_title("Prediction Overlay (Green)")
        axes[2].axis("off")
        
        plt.tight_layout()
        plt.show()

# Visualize a few predictions from the test set
# (only the first batch of test_loader is drawn, up to num_samples images)
visualize_predictions(model, test_loader, num_samples=4)
