# 0. Setup

In [None]:
import kagglehub
# Download the 'anasmohammedtahir/covidqu' Kaggle dataset (on Colab a cached copy
# may be reused) and keep its local path; the next cell copies it into /content.
anasmohammedtahir_covidqu_path = kagglehub.dataset_download('anasmohammedtahir/covidqu')

Using Colab cache for faster access to the 'covidqu' dataset.


In [None]:
import os
import shutil

# Mirror the downloaded dataset at a fixed path (the one `config['data_root']`
# points into). The copy is skipped when the destination already exists, e.g.
# on a re-run within the same runtime.
src = anasmohammedtahir_covidqu_path
dst = "/content/covidqu"

already_copied = os.path.exists(dst)
if not already_copied:
    shutil.copytree(src, dst)

In [None]:
!pip install -q segmentation-models-pytorch albumentations opencv-python scikit-image

[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/154.8 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m154.8/154.8 kB[0m [31m14.0 MB/s[0m eta [36m0:00:00[0m
[?25h

# 1. Requirements

In [None]:
import os
import random
import warnings

import numpy as np
import pandas as pd
import cv2
import matplotlib.pyplot as plt
from PIL import Image
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import KFold
import albumentations as A
from albumentations.pytorch import ToTensorV2
import segmentation_models_pytorch as smp
from tqdm import tqdm
from sklearn.metrics import confusion_matrix, precision_score, recall_score, f1_score
from scipy.spatial.distance import directed_hausdorff
# Hide only library deprecation noise. A blanket `filterwarnings('ignore')` also
# hides UserWarnings -- e.g. a library reporting that an argument was not
# recognised and ignored -- which point at real configuration problems.
warnings.filterwarnings('ignore', category=FutureWarning)
warnings.filterwarnings('ignore', category=DeprecationWarning)

# Seed the Python, NumPy and torch RNGs (torch.manual_seed covers all devices).
SEED = 3
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

config = {
    'data_root': '/content/covidqu/Infection Segmentation Data/Infection Segmentation Data',
    'seed': SEED,

    'image_size': 256, #384,
    'batch_size': 8,
    'num_epochs': 30,
    'learning_rate': 3e-4,
    'weight_decay': 1e-5,
    'patience': 5,  # epochs without val-Dice improvement before early stopping

    'encoder_name': 'efficientnet-b3',
    'encoder_weights': 'imagenet',
    'architecture': 'UnetPlusPlus',
    'loss_function': 'bce_dice',
}

# 2. Dataset

In [None]:
def load_all_categories(data_root, split='Train'):
    """Index every image of one split that has a readable infection mask.

    Looks under ``<data_root>/<split>/<category>/{images, infection masks, lung masks}``
    for the categories COVID-19, Non-COVID and Normal. Only ``.png`` images are
    considered, in sorted order. An image is kept only if a same-named infection
    mask exists and cv2 can decode it. Missing category folders are reported and
    skipped.

    Returns:
        pandas.DataFrame with columns image_path, infection_mask_path,
        lung_mask_path (None when absent), filename, category, split.
    """
    records = []
    for category in ('COVID-19', 'Non-COVID', 'Normal'):
        base_dir = os.path.join(data_root, split, category)
        image_dir = os.path.join(base_dir, 'images')
        infection_dir = os.path.join(base_dir, 'infection masks')
        lung_dir = os.path.join(base_dir, 'lung masks')

        if not os.path.exists(image_dir):
            print(f"  Warning: {image_dir} does not exist, skipping...")
            continue

        png_names = sorted(name for name in os.listdir(image_dir) if name.endswith('.png'))
        print(f"  {category}: Found {len(png_names)} images")

        for name in png_names:
            infection_path = os.path.join(infection_dir, name)
            if not os.path.exists(infection_path):
                continue
            # Decode once up front so corrupt/unreadable masks never reach training.
            if cv2.imread(infection_path, cv2.IMREAD_GRAYSCALE) is None:
                continue

            lung_path = os.path.join(lung_dir, name)
            records.append({
                'image_path': os.path.join(image_dir, name),
                'infection_mask_path': infection_path,
                'lung_mask_path': lung_path if os.path.exists(lung_path) else None,
                'filename': name,
                'category': category,
                'split': split,
            })

    return pd.DataFrame(records)

# Load the three predefined splits; each load prints per-category image counts.
split_specs = [
    ("\nLoading Train split:", 'Train'),
    ("\nLoading Val split:", 'Val'),
    ("\nLoading Test split:", 'Test'),
]
loaded_splits = []
for banner, split_name in split_specs:
    print(banner)
    loaded_splits.append(load_all_categories(config['data_root'], split_name))
train_df, val_df, test_df = loaded_splits

print(f"Train samples: {len(train_df)}")
print(f"Val samples: {len(val_df)}")
print(f"Test samples: {len(test_df)}")
print(f"Total samples: {len(train_df) + len(val_df) + len(test_df)}")


Loading Train split:
  COVID-19: Found 1864 images
  Non-COVID: Found 932 images
  Normal: Found 932 images

Loading Val split:
  COVID-19: Found 466 images
  Non-COVID: Found 233 images
  Normal: Found 233 images

Loading Test split:
  COVID-19: Found 583 images
  Non-COVID: Found 292 images
  Normal: Found 291 images
Train samples: 3728
Val samples: 932
Test samples: 1166
Total samples: 5826


# 3. Class Weights

In [None]:
def calculate_class_weights(df):
    """Return the negative-to-positive pixel ratio over the infection masks in `df`.

    Each mask is read as grayscale and binarized at > 127. The ratio is used as
    the BCE ``pos_weight`` so the rare infection pixels are up-weighted. The raw
    pixel counts and the resulting weight are printed.
    """
    positives = 0
    negatives = 0

    mask_rows = df.iterrows()
    for _, row in tqdm(mask_rows, total=len(df), desc="Calculating class weights"):
        gray = cv2.imread(row['infection_mask_path'], cv2.IMREAD_GRAYSCALE)
        foreground = (gray > 127).astype(np.uint8)
        positives += foreground.sum()
        negatives += (1 - foreground).sum()

    # The small epsilon only guards against division by zero for all-empty masks.
    ratio = negatives / (positives + 1e-6)

    print(f"\nClass distribution:")
    print(f"  Positive pixels: {positives}")
    print(f"  Negative pixels: {negatives}")
    print(f"  Positive weight: {ratio:.4f}")

    return ratio

pos_weight = calculate_class_weights(train_df)

Calculating class weights: 100%|██████████| 3728/3728 [00:03<00:00, 1066.67it/s]


Class distribution:
  Positive pixels: 17043669
  Negative pixels: 227274539
  Positive weight: 13.3348





# 4. Loss

In [None]:
class DiceLoss(nn.Module):
    """Soft Dice loss, ``1 - Dice``, computed over all elements of the batch at once.

    `pred` is compared directly with `target` (no activation is applied here), so
    it should already hold probabilities -- BCEDiceLoss passes ``sigmoid(logits)``.
    `smooth` keeps the ratio defined when both inputs are empty (loss 0 then).
    """

    def __init__(self, smooth=1e-6):
        super().__init__()
        self.smooth = smooth

    def forward(self, pred, target):
        flat_pred = pred.reshape(-1)
        flat_target = target.reshape(-1)

        overlap = torch.sum(flat_pred * flat_target)
        denominator = torch.sum(flat_pred) + torch.sum(flat_target)
        return 1 - (2. * overlap + self.smooth) / (denominator + self.smooth)

class BCEDiceLoss(nn.Module):
    """Unweighted mix of BCE-with-logits and soft Dice.

    forward(pred, target): `pred` are raw logits. BCE consumes them directly,
    while Dice is computed on ``sigmoid(pred)``. Returns
    ``bce_weight * BCE + dice_weight * Dice``.
    """

    def __init__(self, bce_weight=0.5, dice_weight=0.5):
        super().__init__()
        self.bce = nn.BCEWithLogitsLoss()
        self.dice = DiceLoss()
        self.bce_weight = bce_weight
        self.dice_weight = dice_weight

    def forward(self, pred, target):
        weighted_bce = self.bce_weight * self.bce(pred, target)
        weighted_dice = self.dice_weight * self.dice(torch.sigmoid(pred), target)
        return weighted_bce + weighted_dice

class FocalLoss(nn.Module):
    """Binary focal loss on logits (Lin et al., "Focal Loss for Dense Object Detection").

    loss = alpha_t * (1 - p_t) ** gamma * BCE, averaged over all elements, where
    p_t is the predicted probability of the true class and
    alpha_t = alpha for positive targets and 1 - alpha for negative ones.

    Args:
        alpha: positive-class weight in [0, 1]; negatives get 1 - alpha.
            Pass None (or a negative value) to disable class weighting.
        gamma: focusing parameter; 0 reduces the loss to (alpha-weighted) BCE.
    """

    def __init__(self, alpha=0.25, gamma=2.0):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, pred, target):
        bce_loss = F.binary_cross_entropy_with_logits(pred, target, reduction='none')
        # exp(-BCE) equals p when target == 1 and 1 - p when target == 0, i.e. p_t.
        pt = torch.exp(-bce_loss)
        focal_loss = (1 - pt) ** self.gamma * bce_loss

        # Class-balanced alpha. Multiplying every pixel by the same alpha (as
        # before) only rescales the loss and does not re-weight the classes.
        if self.alpha is not None and self.alpha >= 0:
            alpha_t = self.alpha * target + (1 - self.alpha) * (1 - target)
            focal_loss = alpha_t * focal_loss

        return focal_loss.mean()

class WeightedBCEDiceLoss(nn.Module):
    """Positive-weighted BCE-with-logits plus soft Dice, linearly mixed.

    Args:
        bce_weight, dice_weight: coefficients of the two terms.
        pos_weight: multiplier on the positive-class BCE term, e.g. the
            negative/positive pixel ratio from `calculate_class_weights`.

    forward(pred, target): `pred` are raw logits and `target` a binary mask of the
    same shape. Dice is computed on ``sigmoid(pred)``. Returns a scalar.
    """

    def __init__(self, bce_weight=0.5, dice_weight=0.5, pos_weight=1.0):
        super(WeightedBCEDiceLoss, self).__init__()
        # Force a float32 buffer: pos_weight often arrives as a numpy float64 (a
        # ratio of numpy pixel counts), and torch.tensor would keep that float64
        # dtype, mixing double precision into an otherwise float32 loss.
        self.bce = nn.BCEWithLogitsLoss(
            pos_weight=torch.tensor([float(pos_weight)], dtype=torch.float32)
        )
        self.dice = DiceLoss()
        self.bce_weight = bce_weight
        self.dice_weight = dice_weight

    def forward(self, pred, target):
        bce_loss = self.bce(pred, target)
        pred_sigmoid = torch.sigmoid(pred)
        dice_loss = self.dice(pred_sigmoid, target)

        return self.bce_weight * bce_loss + self.dice_weight * dice_loss


class _SigmoidDiceLoss(nn.Module):
    """Dice loss for raw logits: applies sigmoid, then DiceLoss (which expects probabilities)."""

    def __init__(self):
        super().__init__()
        self.dice = DiceLoss()

    def forward(self, pred, target):
        return self.dice(torch.sigmoid(pred), target)


def get_loss_function(loss_name, pos_weight = 1.0):
    """Return the loss module named `loss_name`. Every returned loss takes raw logits.

    Args:
        loss_name: 'dice', 'bce_dice' (positive-weighted BCE + Dice), 'focal',
            or 'bce' (positive-weighted BCE).
        pos_weight: positive-class weight for the BCE-based losses; ignored by
            'dice' and 'focal'.

    Raises:
        ValueError: for an unknown `loss_name`.
    """
    if loss_name == 'dice':
        # The training/validation loops feed model logits straight into the
        # criterion, so plain DiceLoss would score unbounded logits as probabilities.
        return _SigmoidDiceLoss()
    if loss_name == 'bce_dice':
        # float(): a numpy float64 pos_weight would otherwise yield a float64 buffer.
        return WeightedBCEDiceLoss(pos_weight=float(pos_weight))
    if loss_name == 'focal':
        return FocalLoss()
    if loss_name == 'bce':
        return nn.BCEWithLogitsLoss(
            pos_weight=torch.tensor([float(pos_weight)], dtype=torch.float32)
        )
    raise ValueError(f"Unknown loss function: {loss_name}")

# 5. Metrics

In [None]:
def dice_coefficient(pred, target, smooth=1e-6):
    """Dice = 2|P∩T| / (|P| + |T|) over all elements, returned as a Python float.

    `smooth` makes two empty inputs score 1.0 instead of dividing by zero.
    """
    flat_pred = pred.reshape(-1)
    flat_target = target.reshape(-1)

    overlap = (flat_pred * flat_target).sum()
    total = flat_pred.sum() + flat_target.sum()
    return ((2. * overlap + smooth) / (total + smooth)).item()

def iou_score(pred, target, smooth=1e-6):
    """Intersection-over-union over all elements, returned as a Python float.

    `smooth` makes two empty inputs score 1.0 instead of dividing by zero.
    """
    flat_pred = pred.reshape(-1)
    flat_target = target.reshape(-1)

    overlap = torch.sum(flat_pred * flat_target)
    union = torch.sum(flat_pred) + torch.sum(flat_target) - overlap
    return ((overlap + smooth) / (union + smooth)).item()

def pixel_accuracy(pred, target):
    """Fraction of elements where `pred` equals `target`, as a Python float."""
    matches = torch.eq(pred.reshape(-1), target.reshape(-1)).sum()
    return (matches.float() / target.numel()).item()

def calculate_metrics(pred, target, threshold=0.5):
    """Threshold `pred` (probabilities) and score it against `target`.

    All scores are computed over every element of the inputs at once (the whole
    batch when called per batch). Precision/recall/F1 use scikit-learn with
    ``zero_division=0``, so an empty prediction gets precision 0 even when Dice is 1.

    Returns:
        dict with keys 'dice', 'iou', 'accuracy', 'precision', 'recall', 'f1'.
    """
    hard_pred = (pred > threshold).float()
    hard_target = target.float()

    flat_pred = hard_pred.cpu().numpy().flatten()
    flat_target = hard_target.cpu().numpy().flatten()
    sk_kwargs = {'zero_division': 0}

    return {
        'dice': dice_coefficient(hard_pred, hard_target),
        'iou': iou_score(hard_pred, hard_target),
        'accuracy': pixel_accuracy(hard_pred, hard_target),
        'precision': precision_score(flat_target, flat_pred, **sk_kwargs),
        'recall': recall_score(flat_target, flat_pred, **sk_kwargs),
        'f1': f1_score(flat_target, flat_pred, **sk_kwargs),
    }

# 6. Architecture

In [None]:
def get_model(architecture, encoder_name, encoder_weights='imagenet'):
    """Build a segmentation_models_pytorch model for 1-channel input and 1 output class.

    The model has no final activation (``activation=None``), so it emits raw logits.

    Args:
        architecture: 'Unet', 'UnetPlusPlus', 'FPN' or 'DeepLabV3Plus'.
        encoder_name, encoder_weights: forwarded to the smp constructor.

    Raises:
        ValueError: if `architecture` is not one of the supported names.
    """
    supported = ('Unet', 'UnetPlusPlus', 'FPN', 'DeepLabV3Plus')
    if architecture not in supported:
        raise ValueError(f"Unknown architecture: {architecture}")

    # The accepted names are exactly smp's class names, so look the class up directly.
    model_cls = getattr(smp, architecture)
    return model_cls(
        encoder_name=encoder_name,
        encoder_weights=encoder_weights,
        in_channels=1,
        classes=1,
        activation=None,
    )

# 7. Augmentation

In [None]:
def _gauss_noise(p):
    """GaussNoise with std ~2.2-3.9 grey levels (variance 5-15 on the 0-255 scale).

    Newer albumentations releases replaced `var_limit` with `std_range` (std as a
    fraction of the image maximum, 255 for uint8). Those releases only warn about
    unknown arguments, which is easy to miss, especially with warnings filtered,
    and then use their much stronger default noise. Older releases reject
    `std_range` with TypeError, so fall back to `var_limit` there.
    """
    std_low, std_high = np.sqrt(5.0) / 255.0, np.sqrt(15.0) / 255.0
    try:
        return A.GaussNoise(std_range=(std_low, std_high), p=p)
    except TypeError:
        return A.GaussNoise(var_limit=(5.0, 15.0), p=p)


def get_train_transforms(image_size):
    """Training augmentation for single-channel X-rays, applied jointly to image and mask.

    Pipeline: resize -> horizontal flip -> small shift/scale/rotate -> mild
    brightness/contrast -> light Gaussian noise -> single-channel normalization
    -> tensor conversion.
    """
    return A.Compose([
        A.Resize(image_size, image_size),
        A.HorizontalFlip(p=0.5),
        A.ShiftScaleRotate(shift_limit=0.05, scale_limit=0.05, rotate_limit=10, p=0.3),
        A.RandomBrightnessContrast(brightness_limit=0.15, contrast_limit=0.15, p=0.3),
        _gauss_noise(p=0.2),
        A.Normalize(mean=[0.485], std=[0.229]),
        ToTensorV2(),
    ])

def get_val_transforms(image_size):
    """Deterministic evaluation pipeline: resize, single-channel normalize, to tensor."""
    steps = [
        A.Resize(image_size, image_size),
        A.Normalize(mean=[0.485], std=[0.229]),
        ToTensorV2(),
    ]
    return A.Compose(steps)

In [None]:
class COVID19Dataset(Dataset):
    """Grayscale chest X-ray images with binarized infection masks.

    Each item is a dict:
      - 'image': tensor [C, H, W] (C == 1 for the grayscale images read here),
      - 'mask': float tensor [1, H, W] with values 0/1 (mask pixels > 127 -> 1),
      - 'filename': the source file name.

    Args:
        dataframe: table with 'image_path', 'infection_mask_path' and 'filename'
            columns (as produced by `load_all_categories`). Its index is reset.
        transform: optional albumentations pipeline called as
            ``transform(image=..., mask=...)``. When it is None, or does not end in
            ToTensorV2, numpy outputs are converted to tensors here (HWC -> CHW).
        debug: print shapes/values at each stage for index 0.
    """

    def __init__(self, dataframe, transform=None, debug=False):
        self.df = dataframe.reset_index(drop=True)
        self.transform = transform
        self.debug = debug

    def __len__(self):
        return len(self.df)

    @staticmethod
    def _read_grayscale(path):
        """Read `path` as grayscale; raise instead of letting cv2 return None."""
        if not os.path.isfile(path):
            raise FileNotFoundError(f"File not found: {path}")
        array = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
        if array is None:
            raise ValueError(f"Could not decode image: {path}")
        return array

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        show_debug = self.debug and idx == 0

        image = self._read_grayscale(row['image_path'])
        mask = self._read_grayscale(row['infection_mask_path'])

        if show_debug:
            print(f"Debug - Before processing:")
            print(f"  Image shape: {image.shape}, range: [{image.min()}, {image.max()}]")
            print(f"  Mask shape: {mask.shape}, range: [{mask.min()}, {mask.max()}]")
            print(f"  Infection pixels: {np.sum(mask > 127)}")

        image = np.expand_dims(image, axis=-1)  # [H, W] -> [H, W, 1] for albumentations

        # Binarize mask: 0 or 1
        mask = (mask > 127).astype(np.float32)

        if show_debug:
            print(f"Debug - After binarization:")
            print(f"  Mask unique values: {np.unique(mask)}")
            print(f"  Positive pixels: {np.sum(mask > 0.5)}")

        if self.transform:
            transformed = self.transform(image=image, mask=mask)
            image = transformed['image']
            mask = transformed['mask']

        # No transform (or no ToTensorV2 at the end): convert numpy outputs here.
        if isinstance(image, np.ndarray):
            if image.ndim == 2:
                image = image[..., None]
            image = torch.from_numpy(np.ascontiguousarray(image.transpose(2, 0, 1)))
        if isinstance(mask, np.ndarray):
            mask = torch.from_numpy(np.ascontiguousarray(mask))

        # Ensure mask has correct shape [1, H, W]
        if len(mask.shape) == 2:
            mask = mask.unsqueeze(0)

        if show_debug:
            print(f"Debug - After transform:")
            print(f"  Image shape: {image.shape}")
            print(f"  Mask shape: {mask.shape}")
            print(f"  Mask unique values: {torch.unique(mask)}")

        return {
            'image': image,
            'mask': mask,
            'filename': row['filename']
        }

# Sanity-check one training sample end to end with debug output. Use a dedicated
# name: binding this one-image dataset to `test_dataset` lets a stale copy get
# evaluated as "the test set" if cells run out of order.
debug_dataset = COVID19Dataset(train_df[:1], transform=get_val_transforms(config['image_size']), debug=True)
debug_sample = debug_dataset[0]

Debug - Before processing:
  Image shape: (256, 256), range: [0, 255]
  Mask shape: (256, 256), range: [0, 255]
  Infection pixels: 9040
Debug - After binarization:
  Mask unique values: [0. 1.]
  Positive pixels: 9040
Debug - After transform:
  Image shape: torch.Size([1, 256, 256])
  Mask shape: torch.Size([1, 256, 256])
  Mask unique values: tensor([0., 1.])


In [None]:
# Datasets: random augmentation for training only; val/test use the deterministic pipeline.
train_dataset, val_dataset, test_dataset = (
    COVID19Dataset(train_df, transform=get_train_transforms(config['image_size'])),
    COVID19Dataset(val_df, transform=get_val_transforms(config['image_size'])),
    COVID19Dataset(test_df, transform=get_val_transforms(config['image_size'])),
)

loader_kwargs = {'batch_size': config['batch_size'], 'num_workers': 2}
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

print(f"\nTrain batches: {len(train_loader)}")
print(f"Val batches: {len(val_loader)}")
print(f"Test batches: {len(test_loader)}")

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"\nUsing device: {device}")

# NOTE: section 8 builds a fresh model/criterion/optimizer/scheduler before training.
model = get_model(config['architecture'], config['encoder_name'], config['encoder_weights']).to(device)
criterion = get_loss_function(config['loss_function'], pos_weight=pos_weight).to(device)

optimizer = torch.optim.AdamW(model.parameters(), lr=config['learning_rate'], weight_decay=config['weight_decay'])
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=config['num_epochs'])

Train batches: 466
Val batches: 117
Test batches: 146


In [None]:
def train_epoch(model, dataloader, criterion, optimizer, device):
    """Run one optimization pass over `dataloader`.

    `criterion` receives the raw model outputs (logits); metrics are computed on
    ``sigmoid(outputs)`` per batch.

    Returns:
        dict with the mean over batches of 'dice', 'iou', 'accuracy',
        'precision', 'recall', 'f1', plus 'loss' (mean per-batch loss).
    """
    metric_names = ('dice', 'iou', 'accuracy', 'precision', 'recall', 'f1')
    history = {name: [] for name in metric_names}
    running_loss = 0
    model.train()

    progress = tqdm(dataloader, desc='Training')
    for batch in progress:
        inputs = batch['image'].to(device)
        targets = batch['mask'].to(device)

        optimizer.zero_grad()
        logits = model(inputs)
        batch_loss = criterion(logits, targets)
        batch_loss.backward()
        optimizer.step()

        loss_value = batch_loss.item()
        running_loss += loss_value

        with torch.no_grad():
            batch_metrics = calculate_metrics(torch.sigmoid(logits), targets)
        for name in metric_names:
            history[name].append(batch_metrics[name])

        progress.set_postfix({
            'loss': f'{loss_value:.4f}',
            'dice': f'{batch_metrics["dice"]:.4f}'
        })

    summary = {name: np.mean(values) for name, values in history.items()}
    summary['loss'] = running_loss / len(dataloader)
    return summary

In [None]:
from scipy import ndimage

def post_process_mask(pred_mask, min_size=100, kernel_size=5):
    """Clean a binary 2-D prediction: drop small blobs, then close small holes.

    Args:
        pred_mask: 2-D binary array (bool or 0/1).
        min_size: connected components with fewer pixels than this are removed.
        kernel_size: side length of the square kernel used for morphological closing.

    Returns:
        uint8 array of the same shape with values in {0, 1}. The dtype is uint8
        for every input, including an empty mask.
    """
    pred_mask = np.asarray(pred_mask)
    labeled_mask, num_features = ndimage.label(pred_mask)

    if num_features == 0:
        # Nothing predicted: return an all-zero uint8 mask (same dtype as the other path).
        return pred_mask.astype(np.uint8)

    # Pixel count of every label; label 0 is the background.
    component_sizes = ndimage.sum(pred_mask, labeled_mask, range(num_features + 1))

    keep = component_sizes >= min_size
    keep[0] = False  # never keep the background label
    cleaned_mask = keep[labeled_mask].astype(np.uint8)

    # Morphological closing to fill small holes
    kernel = np.ones((kernel_size, kernel_size), np.uint8)
    return cv2.morphologyEx(cleaned_mask, cv2.MORPH_CLOSE, kernel)

def validate_epoch(model, dataloader, criterion, device):
    """Evaluate `model` on `dataloader` without gradient tracking.

    Per-batch metrics come from `calculate_metrics` on ``sigmoid(outputs)``.
    'dice_postprocess' is computed per *sample* after `post_process_mask`
    (thresholded at 0.5), so its mean is over samples rather than batches.

    Returns:
        dict with the means of 'dice', 'dice_postprocess', 'iou', 'accuracy',
        'precision', 'recall', 'f1', plus 'loss' (mean per-batch loss).
    """
    batch_keys = ['dice', 'iou', 'accuracy', 'precision', 'recall', 'f1']
    history = {
        'dice': [],
        'dice_postprocess': [],
        'iou': [],
        'accuracy': [],
        'precision': [],
        'recall': [],
        'f1': [],
    }
    running_loss = 0
    model.eval()

    progress = tqdm(dataloader, desc='Validation')
    with torch.no_grad():
        for batch in progress:
            inputs = batch['image'].to(device)
            targets = batch['mask'].to(device)

            logits = model(inputs)
            batch_loss = criterion(logits, targets)
            running_loss += batch_loss.item()

            probs = torch.sigmoid(logits)

            # Raw (un-post-processed) metrics for the whole batch.
            batch_metrics = calculate_metrics(probs, targets)
            for key in batch_keys:
                history[key].append(batch_metrics[key])

            # Dice per sample after connected-component cleanup.
            hard_preds = (probs > 0.5).cpu().numpy()
            for sample_idx, sample_pred in enumerate(hard_preds):
                cleaned = post_process_mask(sample_pred[0])
                cleaned_tensor = torch.from_numpy(cleaned).float().unsqueeze(0).to(device)
                history['dice_postprocess'].append(
                    dice_coefficient(cleaned_tensor, targets[sample_idx])
                )

            progress.set_postfix({
                'loss': f'{batch_loss.item():.4f}',
                'dice': f'{batch_metrics["dice"]:.4f}'
            })

    summary = {key: np.mean(values) for key, values in history.items()}
    summary['loss'] = running_loss / len(dataloader)
    return summary

# 8. Train + Validation

In [None]:
train_df = load_all_categories(config['data_root'], 'Train')
val_df = load_all_categories(config['data_root'], 'Val')
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"\nUsing device: {device}")

print(f"Train samples: {len(train_df)}")
print(f"Val samples: {len(val_df)}")

pos_weight = calculate_class_weights(train_df)

train_dataset = COVID19Dataset(train_df, transform=get_train_transforms(config['image_size']))
val_dataset = COVID19Dataset(val_df, transform=get_val_transforms(config['image_size']))

train_loader = DataLoader(train_dataset, batch_size=config['batch_size'], shuffle=True, num_workers=2)
val_loader = DataLoader(val_dataset, batch_size=config['batch_size'], shuffle=False, num_workers=2)

model = get_model(config['architecture'], config['encoder_name'], config['encoder_weights'])
model = model.to(device)

criterion = get_loss_function(config['loss_function'], pos_weight=pos_weight)
criterion = criterion.to(device)

optimizer = torch.optim.AdamW(model.parameters(), lr=config['learning_rate'], weight_decay=config['weight_decay'])
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=config['num_epochs'])

# Resume state consumed by the training loop. A fresh run starts at epoch 0 with
# no best score yet; -inf guarantees the first completed epoch writes a checkpoint.
checkpoint_path = '/content/best_model.pth'
start_epoch = 0
best_dice = float('-inf')

if os.path.exists(checkpoint_path):
    # weights_only=False unpickles arbitrary objects (the checkpoint also stores
    # `config`); only load checkpoints this notebook wrote itself.
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)

    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

    # 'epoch' is stored 0-based, so training continues with the following epoch.
    start_epoch = checkpoint['epoch'] + 1
    best_dice = checkpoint['best_dice']

    print(f"Resuming from epoch {start_epoch + 1}")

COVID-19: Found 1864 images
Non-COVID: Found 932 images
Normal: Found 932 images
COVID-19: Found 466 images
Non-COVID: Found 233 images
Normal: Found 233 images
Using device: cuda
Train samples: 3728
Val samples: 932
Calculating class weights: 100%|██████████| 3728/3728 [00:02<00:00, 1263.23it/s]
Class distribution:
Positive pixels: 17043669
Negative pixels: 227274539
Positive weight: 13.3348


In [None]:
# Resume state: reuse `start_epoch` / `best_dice` if a checkpoint was restored
# earlier in this session, otherwise begin a fresh run. -inf makes the first
# completed epoch always save a checkpoint for the test section to load.
start_epoch = globals().get('start_epoch', 0)
best_dice = globals().get('best_dice', float('-inf'))
best_epoch = start_epoch  # 1-based epoch at which best_dice was reached
patience_counter = 0


for epoch in range(start_epoch, config['num_epochs']):
    print(f"\nEpoch {epoch + 1}/{config['num_epochs']}")
    print("-" * 70)

    # Train
    train_metrics = train_epoch(model, train_loader, criterion, optimizer, device)

    # Validate
    val_metrics = validate_epoch(model, val_loader, criterion, device)

    # Step scheduler (cosine schedule advances once per epoch)
    scheduler.step()

    # Print metrics
    print(f"\nTrain - Loss: {train_metrics['loss']:.4f}, Dice: {train_metrics['dice']:.4f}, IoU: {train_metrics['iou']:.4f}")
    print(f"Val   - Loss: {val_metrics['loss']:.4f}, Dice: {val_metrics['dice']:.4f}, IoU: {val_metrics['iou']:.4f}")


    # Save best model; otherwise count the epoch toward early stopping
    if val_metrics['dice'] > best_dice:
        best_dice = val_metrics['dice']
        best_epoch = epoch + 1

        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'best_dice': best_dice,
            'config': config
        }, '/content/best_model.pth')

        print(f"\n✓ Saved best model with Dice: {best_dice:.4f}")
        patience_counter = 0
    else:
        patience_counter += 1

    # Early stopping
    if patience_counter >= config['patience']:
        print(f"\n⚠ Early stopping triggered after {epoch + 1} epochs")
        print(f"Best Dice: {best_dice:.4f} at epoch {best_epoch}")
        break
print(f"Best Validation Dice: {best_dice:.4f} at epoch {best_epoch}")

Epoch 30/30
----------------------------------------------------------------------
Training: 100%|██████████| 466/466 [02:55<00:00,  2.65it/s, loss=0.3025, dice=0.6788]
Validation: 100%|██████████| 117/117 [00:21<00:00,  5.45it/s, loss=0.5000, dice=1.0000]
Train - Loss: 0.1779, Dice: 0.8183, IoU: 0.7050
Val   - Loss: 0.3982, Dice: 0.8952, IoU: 0.8390
TRAINING COMPLETE
Best Validation Dice: 0.9032 at epoch 26
Run summary:
epoch	30
learning_rate	0
train_accuracy	0.97321
train_dice	0.8183
train_f1	0.8183
train_iou	0.70495
train_loss	0.17789
train_precision	0.72724
train_recall	0.95723
val_accuracy	0.97686


# 9. Testing

In [None]:
# Rebuild the test split explicitly from `test_df`. `test_dataset` was also used
# as the name of a one-image debug dataset, and evaluating that stale object
# reports "test" metrics for a single training image (a 1-batch loader).
test_dataset = COVID19Dataset(test_df, transform=get_val_transforms(config['image_size']))
test_loader = DataLoader(test_dataset, batch_size=config['batch_size'], shuffle=False, num_workers=2)

# weights_only=False unpickles arbitrary objects; only load checkpoints written by this notebook.
checkpoint = torch.load('/content/best_model.pth', map_location=device, weights_only=False)
model.load_state_dict(checkpoint['model_state_dict'])
model = model.to(device)
model.eval()

all_test_metrics = { # Evaluate on test set (values are per-batch, averaged below)
    'dice': [],
    'iou': [],
    'accuracy': [],
    'precision': [],
    'recall': [],
    'f1': []
}

print(f"\nRunning inference on test set ({len(test_dataset)} samples)")
with torch.no_grad():
    for batch in tqdm(test_loader, desc='Test Inference'):
        images = batch['image'].to(device)
        masks = batch['mask'].to(device)

        outputs = model(images)
        pred_sigmoid = torch.sigmoid(outputs)

        metrics = calculate_metrics(pred_sigmoid, masks)
        for key in all_test_metrics.keys():
            all_test_metrics[key].append(metrics[key])

avg_test_metrics = {key: np.mean(values) for key, values in all_test_metrics.items()}

print("TEST SET RESULTS")
print(f"Dice Coefficient: {avg_test_metrics['dice']:.4f}")
print(f"IoU Score:        {avg_test_metrics['iou']:.4f}")
print(f"Pixel Accuracy:   {avg_test_metrics['accuracy']:.4f}")
print(f"Precision:        {avg_test_metrics['precision']:.4f}")
print(f"Recall:           {avg_test_metrics['recall']:.4f}")
print(f"F1 Score:         {avg_test_metrics['f1']:.4f}")

Running inference on test set
Test Inference:   0%|          | 0/1 [00:00<?, ?it/s]Debug - Before processing:
 Image shape: (256, 256), range: [0, 255]
 Mask shape: (256, 256), range: [0, 255]
 Infection pixels: 9040
Debug - After binarization:
 Mask unique values: [0. 1.]
 Positive pixels: 9040
Debug - After transform:
 Image shape: torch.Size([1, 256, 256])
 Mask shape: torch.Size([1, 256, 256])
 Mask unique values: tensor([0., 1.])
Test Inference: 100%|██████████| 1/1 [00:00<00:00,  1.23it/s]TEST SET RESULTS
Dice Coefficient: 0.9218
IoU Score:        0.8550
Pixel Accuracy:   0.9768
Precision:        0.8610
Recall:           0.9918
F1 Score:         0.9218


# 10. Visualize

In [None]:
# Visualize some test predictions
num_samples = min(8, len(test_dataset))
fig, axes = plt.subplots(num_samples, 3, figsize=(12, num_samples * 4))

model.eval()
with torch.no_grad():
    for i in range(num_samples):
        sample = test_dataset[i]
        image = sample['image'].unsqueeze(0).to(device)
        true_mask = sample['mask'].squeeze().cpu().numpy()

        output = model(image)
        pred_mask = torch.sigmoid(output).squeeze().cpu().numpy()
        pred_mask_binary = (pred_mask > 0.5).astype(np.uint8)

        img_np = sample['image'].squeeze().cpu().numpy()

        # Denormalize image for visualization
        img_denorm = img_np * 0.229 + 0.485

        axes[i, 0].imshow(img_denorm, cmap='gray')
        axes[i, 0].set_title('Input Image')
        axes[i, 0].axis('off')

        axes[i, 1].imshow(true_mask, cmap='gray')
        axes[i, 1].set_title('Ground Truth')
        axes[i, 1].axis('off')

        axes[i, 2].imshow(img_denorm, cmap='gray')
        axes[i, 2].imshow(pred_mask_binary, cmap='Reds', alpha=0.5)

        # Calculate Dice for this sample
        dice = dice_coefficient(
            torch.from_numpy(pred_mask_binary).float(),
            torch.from_numpy(true_mask).float()
        )
        axes[i, 2].set_title(f'Prediction (Dice: {dice:.3f})')
        axes[i, 2].axis('off')

plt.tight_layout()
plt.savefig('/content/test_predictions.png', dpi=150, bbox_inches='tight')
plt.show()

In [None]:
# Headline numbers of this run in one place (also printed below).
summary = dict(
    best_val_dice=best_dice,
    best_epoch=best_epoch,
    test_dice=avg_test_metrics['dice'],
    test_iou=avg_test_metrics['iou'],
    test_f1=avg_test_metrics['f1'],
    config=config,
)
print(f"Best Val Dice:  {summary['best_val_dice']:.4f}")
print(f"Test Dice:      {summary['test_dice']:.4f}")
print(f"Test IoU:       {summary['test_iou']:.4f}")