In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed.
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load:

import numpy as np  # linear algebra
import pandas as pd  # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "/kaggle/input/" directory.
# The stock Kaggle template prints every single file path found there, but an
# image dataset such as Pascal VOC holds thousands of files, which floods the
# cell output. Print one line per directory (with its file count) instead,
# capped at MAX_LISTED_DIRS lines.

import os

MAX_LISTED_DIRS = 50

listed_dirs = 0
for dirname, _, filenames in os.walk('/kaggle/input'):
    if not filenames:
        continue  # skip directories that only contain sub-directories
    if listed_dirs >= MAX_LISTED_DIRS:
        print(f"... (listing stopped after {MAX_LISTED_DIRS} directories)")
        break
    print(f"{dirname}  ({len(filenames)} files)")
    listed_dirs += 1

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All".
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session.

In [None]:
import os
import sys
import random
import time
from pathlib import Path
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from PIL import Image
from sklearn.model_selection import train_test_split

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision
from torchvision import transforms

for label, version in (
    ("Python", sys.version.split()[0]),
    ("PyTorch", torch.__version__),
    ("Torchvision", torchvision.__version__),
):
    print(f"{label}: {version}")

# Seed every RNG the notebook relies on (Python, NumPy, PyTorch CPU and CUDA).
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

use_cuda = torch.cuda.is_available()
if use_cuda:
    torch.cuda.manual_seed(SEED)
    torch.cuda.manual_seed_all(SEED)
    # Ask cuDNN for deterministic kernels so runs are repeatable.
    torch.backends.cudnn.deterministic = True

# Every model/tensor in the notebook is moved to this device.
device = torch.device('cuda' if use_cuda else 'cpu')
print(f"Device: {device}")
if use_cuda:
    print(f"GPU: {torch.cuda.get_device_name(0)}")

In [None]:
def find_voc_dataset(input_base='/kaggle/input'):
    """Locate the Pascal VOC 2012 root directory under ``input_base``.

    A directory qualifies as the dataset root only if it contains BOTH a
    ``JPEGImages`` and a ``SegmentationClass`` sub-directory. The previous
    version accepted any directory with just ``JPEGImages`` (e.g. a test split
    without masks), which made the notebook's later SegmentationClass check
    fail even when a usable dataset existed elsewhere.

    Args:
        input_base: directory to search (defaults to Kaggle's input mount).

    Returns:
        pathlib.Path of the first qualifying directory. Known Kaggle layouts
        are tried first, then a deterministic (sorted) recursive search.

    Raises:
        FileNotFoundError: if no qualifying directory exists.
    """
    input_base = Path(input_base)
    required = ('JPEGImages', 'SegmentationClass')

    def is_voc_root(path):
        return all((path / name).is_dir() for name in required)

    # Layouts of the Kaggle "pascal-voc-2012-dataset" upload, most likely first.
    dataset_dir = input_base / 'pascal-voc-2012-dataset'
    potential_paths = [
        dataset_dir / 'VOC2012_train_val' / 'VOC2012_train_val',
        dataset_dir / 'VOC2012_test' / 'VOC2012_test',
        dataset_dir / 'VOCdevkit' / 'VOC2012',
        dataset_dir / 'VOC2012',
        dataset_dir,
    ]

    for p in potential_paths:
        if is_voc_root(p):
            print(f"✓ Found dataset at: {p}")
            return p

    # Fallback: search recursively
    print("Searching for JPEGImages directory...")
    for root, dirs, files in os.walk(input_base):
        dirs.sort()  # os.walk order is filesystem-dependent; sort for reproducibility
        if all(name in dirs for name in required):
            found_path = Path(root)
            print(f"✓ Found dataset at: {found_path}")
            return found_path
        # Image/mask folders hold thousands of files and cannot contain a
        # dataset root themselves, so don't descend into them.
        dirs[:] = [d for d in dirs if d not in required]

    raise FileNotFoundError("Could not find VOC dataset. Please check the input directory.")

ROOT = find_voc_dataset()
print(f"\nDataset root: {ROOT}")
print(f"\nDataset structure:")
for item in sorted(ROOT.iterdir()):
    if not item.is_dir():
        print(f"  📄 {item.name}")
        continue
    file_count = len(list(item.glob('*')))
    print(f"  📁 {item.name}/ ({file_count} items)")

# Fail fast if either directory the Dataset class reads from is missing.
for subdir, message in (
    ('JPEGImages', "JPEGImages directory not found"),
    ('SegmentationClass', "SegmentationClass directory not found"),
):
    assert (ROOT / subdir).exists(), message

jpg_count = sum(1 for _ in (ROOT / 'JPEGImages').glob('*.jpg'))
png_count = sum(1 for _ in (ROOT / 'SegmentationClass').glob('*.png'))
print(f"\n✓ Found {jpg_count} images and {png_count} masks")


In [None]:
def create_pascal_label_colormap():
    """Build the 256-entry PASCAL VOC label colormap.

    Entry ``i`` is an RGB triple obtained by distributing the bits of ``i``
    over the channels (bit 0 -> R, bit 1 -> G, bit 2 -> B, repeating every 3
    bits), filling each channel from its most significant bit downwards.

    Returns:
        np.ndarray of shape (256, 3), dtype uint8.
    """
    colormap = np.zeros((256, 3), dtype=np.uint8)

    for i in range(256):
        r = g = b = 0
        cid = i
        for j in range(8):
            r |= ((cid >> 0) & 1) << (7 - j)
            g |= ((cid >> 1) & 1) << (7 - j)
            b |= ((cid >> 2) & 1) << (7 - j)
            cid >>= 3
        colormap[i] = [r, g, b]

    return colormap

VOC_COLORMAP = create_pascal_label_colormap()
VOC_CLASSES = [
    'background', 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle',
    'bus', 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog',
    'horse', 'motorbike', 'person', 'pottedplant', 'sheep',
    'sofa', 'train', 'tvmonitor'
]

def _pack_rgb(rgb):
    """Pack an (..., 3) RGB array into (...) uint32 codes ``R << 16 | G << 8 | B``."""
    rgb = np.asarray(rgb, dtype=np.uint32)
    return (rgb[..., 0] << 16) | (rgb[..., 1] << 8) | rgb[..., 2]

def mask_to_class_index(mask_pil, num_classes=21, ignore_index=255):
    """Convert a colour-coded VOC mask into a 2-D array of class indices.

    Args:
        mask_pil: mask image; anything whose ``.convert('RGB')`` yields an
            (H, W, 3) array-like (e.g. a PIL palette or RGB image).
        num_classes: match only the first ``num_classes`` colormap colours
            (default 21: background + 20 object classes).
        ignore_index: value written for pixels whose colour matches none of
            those colours; must fit in uint8.

    Returns:
        np.ndarray of shape (H, W), dtype uint8.
    """
    arr = np.array(mask_pil.convert('RGB'), dtype=np.uint8)
    # One integer per pixel -> a single comparison per class instead of three.
    pixel_codes = _pack_rgb(arr)
    class_codes = _pack_rgb(VOC_COLORMAP[:num_classes])

    result = np.full(pixel_codes.shape, ignore_index, dtype=np.uint8)
    for idx, code in enumerate(class_codes):
        result[pixel_codes == code] = idx
    return result

def class_index_to_rgb(mask_arr, num_classes=21):
    """Colourise a 2-D class-index array for visualization.

    Indices ``0..num_classes-1`` get their VOC colour; any other value
    (e.g. the 255 ignore label) stays black.

    Returns:
        np.ndarray of shape (H, W, 3), dtype uint8.
    """
    h, w = mask_arr.shape
    rgb = np.zeros((h, w, 3), dtype=np.uint8)

    for cls_idx in range(num_classes):
        rgb[mask_arr == cls_idx] = VOC_COLORMAP[cls_idx]

    return rgb

print(f"✓ VOC colormap created with {len(VOC_CLASSES)} classes")

In [None]:
class DoubleConv(nn.Module):
    """Two stacked 3x3 convolutions, each followed by BatchNorm and ReLU.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels produced by the second convolution.
        mid_channels: channels between the two convolutions; a falsy value
            (None or 0) falls back to ``out_channels``.
    """
    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        mid = mid_channels or out_channels
        layers = [
            nn.Conv2d(in_channels, mid, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(mid),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        ]
        # The attribute name is part of the state_dict keys: keep it stable.
        self.double_conv = nn.Sequential(*layers)

    def forward(self, x):
        return self.double_conv(x)


class Down(nn.Module):
    """Encoder stage: 2x2 max-pool (halves H and W) followed by DoubleConv."""
    def __init__(self, in_channels, out_channels):
        super().__init__()
        pool = nn.MaxPool2d(2)
        self.maxpool_conv = nn.Sequential(pool, DoubleConv(in_channels, out_channels))

    def forward(self, x):
        return self.maxpool_conv(x)


class Up(nn.Module):
    """Decoder stage: upsample x2, concatenate the skip connection, DoubleConv.

    With ``bilinear=True`` upsampling is parameter-free and DoubleConv narrows
    through ``in_channels // 2``; otherwise a transposed convolution halves the
    channels while upsampling.
    """
    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()
        half = in_channels // 2
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = DoubleConv(in_channels, out_channels, half)
        else:
            self.up = nn.ConvTranspose2d(in_channels, half, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        """Upsample ``x1``, pad it to ``x2``'s spatial size, then fuse them."""
        x1 = self.up(x1)
        # Tensors are NCHW. Odd input sizes make the upsampled map smaller
        # than the skip connection, so pad it (split evenly, extra on the end).
        diff_y = x2.shape[2] - x1.shape[2]
        diff_x = x2.shape[3] - x1.shape[3]
        top, left = diff_y // 2, diff_x // 2
        x1 = F.pad(x1, [left, diff_x - left, top, diff_y - top])
        return self.conv(torch.cat([x2, x1], dim=1))


class OutConv(nn.Module):
    """Final 1x1 convolution mapping features to per-class logits."""
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        return self.conv(x)


class UNet(nn.Module):
    """U-Net with four down/up stages returning (N, n_classes, H, W) logits."""
    def __init__(self, n_channels=3, n_classes=21, bilinear=True):
        super().__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear

        # Submodules are created (and randomly initialised) in this exact
        # order; their attribute names form the keys of saved checkpoints.
        factor = 2 if bilinear else 1
        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        self.down4 = Down(512, 1024 // factor)
        self.up1 = Up(1024, 512 // factor, bilinear)
        self.up2 = Up(512, 256 // factor, bilinear)
        self.up3 = Up(256, 128 // factor, bilinear)
        self.up4 = Up(128, 64, bilinear)
        self.outc = OutConv(64, n_classes)

    def forward(self, x):
        # Encoder: keep every stage's output as a skip connection.
        skips = [self.inc(x)]
        for down in (self.down1, self.down2, self.down3):
            skips.append(down(skips[-1]))
        x = self.down4(skips[-1])
        # Decoder: consume the skips deepest-first.
        for up, skip in zip((self.up1, self.up2, self.up3, self.up4), reversed(skips)):
            x = up(x, skip)
        return self.outc(x)

print("✓ U-Net architecture defined")

In [None]:
class PascalVOCDataset(Dataset):
    """Pascal VOC 2012 semantic-segmentation dataset.

    Each item is ``(image, mask)``:
      * image: float tensor (3, H, W), normalized with ImageNet mean/std.
      * mask: int64 tensor (H, W) of class indices produced by
        ``mask_to_class_index`` (255 where a pixel matches no class colour).
    Both are resized to ``img_size``, which follows PIL's (width, height) order.

    Args:
        root: directory containing ``JPEGImages`` and ``SegmentationClass``.
        image_ids: file stems to serve (``<id>.jpg`` / ``<id>.png``).
        img_size: (width, height) every image and mask is resized to.
        augment: apply random flips and colour jitter (intended for training).
    """

    # ImageNet normalization statistics.
    MEAN = [0.485, 0.456, 0.406]
    STD = [0.229, 0.224, 0.225]

    def __init__(self, root, image_ids, img_size=(256, 256), augment=False):
        self.root = Path(root)
        self.image_ids = image_ids
        self.img_size = img_size
        self.augment = augment

        self.img_dir = self.root / 'JPEGImages'
        self.mask_dir = self.root / 'SegmentationClass'

        # Data augmentation transforms
        self.color_jitter = transforms.ColorJitter(
            brightness=0.3, contrast=0.3, saturation=0.3, hue=0.1
        )
        # Tensor conversion is stateless: build it once, not on every item.
        self.to_tensor = transforms.ToTensor()
        self.normalize = transforms.Normalize(mean=self.MEAN, std=self.STD)

    def __len__(self):
        return len(self.image_ids)

    def _paths(self, idx):
        """Return the (image, mask) paths for ``self.image_ids[idx]``."""
        img_id = self.image_ids[idx]
        return self.img_dir / f"{img_id}.jpg", self.mask_dir / f"{img_id}.png"

    def _augment(self, img, mask):
        """Apply paired random flips to image and mask, then jitter the image only."""
        # Random horizontal flip
        if random.random() > 0.5:
            img = img.transpose(Image.FLIP_LEFT_RIGHT)
            mask = mask.transpose(Image.FLIP_LEFT_RIGHT)

        # Random vertical flip
        # NOTE(review): vertical flips are unusual for natural-image
        # segmentation (upside-down scenes); consider disabling.
        if random.random() > 0.5:
            img = img.transpose(Image.FLIP_TOP_BOTTOM)
            mask = mask.transpose(Image.FLIP_TOP_BOTTOM)

        # Color jittering (only on image)
        img = self.color_jitter(img)
        return img, mask

    def __getitem__(self, idx):
        # An invalid idx still raises IndexError here, as before.
        img_path, mask_path = self._paths(idx)

        # If files are missing, fall back to the following indices (wrapping
        # around). Iterative and bounded, so a dataset with no usable pair
        # raises instead of recursing until RecursionError.
        offset = 0
        while not (img_path.exists() and mask_path.exists()):
            offset += 1
            if offset >= len(self):
                raise FileNotFoundError(f"No image/mask pair found under {self.root}")
            img_path, mask_path = self._paths((idx + offset) % len(self))

        img = Image.open(img_path).convert('RGB')
        mask = Image.open(mask_path)

        # Resize; NEAREST keeps mask pixels as exact label colours (no blending).
        img = img.resize(self.img_size, Image.BILINEAR)
        mask = mask.resize(self.img_size, Image.NEAREST)

        if self.augment:
            img, mask = self._augment(img, mask)

        # Convert mask to class indices
        mask_arr = mask_to_class_index(mask)

        img_tensor = self.normalize(self.to_tensor(img))
        mask_tensor = torch.from_numpy(mask_arr).long()

        return img_tensor, mask_tensor

In [None]:
# Collect IDs that have BOTH a JPEG image and a segmentation mask; not every
# JPEG in the dataset necessarily has a SegmentationClass mask.
all_img_ids = [p.stem for p in sorted((ROOT / 'JPEGImages').glob('*.jpg'))]

valid_ids = [
    img_id for img_id in all_img_ids
    if (ROOT / 'JPEGImages' / f'{img_id}.jpg').exists()
    and (ROOT / 'SegmentationClass' / f'{img_id}.png').exists()
]

print(f"Total valid samples: {len(valid_ids)}")

# Split into train and validation (85/15 split), reproducible via SEED.
train_ids, val_ids = train_test_split(
    valid_ids, test_size=0.15, random_state=SEED
)

print(f"Train samples: {len(train_ids)}")
print(f"Val samples: {len(val_ids)}")

# Create datasets
IMG_SIZE = (256, 256)  # (width, height) passed to PIL resize

train_dataset = PascalVOCDataset(
    ROOT, train_ids, img_size=IMG_SIZE, augment=True
)
val_dataset = PascalVOCDataset(
    ROOT, val_ids, img_size=IMG_SIZE, augment=False
)

# Create dataloaders
BATCH_SIZE = 8  # Adjust based on GPU memory
NUM_WORKERS = 2
# Pinned host memory only speeds up host->GPU copies; on a CPU-only runtime
# PyTorch warns about it, so enable it only when CUDA is available.
PIN_MEMORY = torch.cuda.is_available()

train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
    pin_memory=PIN_MEMORY,
    drop_last=True  # keeps every training batch the same size
)

val_loader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=PIN_MEMORY
)

print(f"\n✓ DataLoaders created")
print(f"  Image size: {IMG_SIZE}")
print(f"  Batch size: {BATCH_SIZE}")
print(f"  Train batches: {len(train_loader)}")
print(f"  Val batches: {len(val_loader)}")

In [None]:
def denormalize(img_tensor):
    """Undo ImageNet normalization for display.

    Args:
        img_tensor: (3, H, W) tensor normalized with ImageNet mean/std, on any
            device.

    Returns:
        ``img_tensor * std + mean`` on the same device (not clipped to [0, 1]).
    """
    # Build the statistics on the input's device: a CPU (3, 1, 1) tensor
    # cannot be combined with a CUDA tensor. dtype is left as float32 so that
    # integer inputs are promoted rather than truncating the statistics.
    mean = torch.tensor([0.485, 0.456, 0.406], device=img_tensor.device).view(3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225], device=img_tensor.device).view(3, 1, 1)
    return img_tensor * std + mean

def visualize_batch(images, masks, predictions=None, num_samples=4):
    """Plot image / ground truth / (optional) prediction rows for a batch.

    Args:
        images: (N, 3, H, W) normalized image tensor (CPU or GPU).
        masks: (N, H, W) class-index tensor.
        predictions: optional (N, H, W) class-index tensor; adds a third column.
        num_samples: maximum number of rows to draw (capped at N).
    """
    num_samples = min(num_samples, images.shape[0])

    cols = 3 if predictions is not None else 2
    # squeeze=False keeps `axes` 2-D even when only one row is drawn.
    fig, axes = plt.subplots(num_samples, cols, figsize=(cols*5, num_samples*5), squeeze=False)

    for i in range(num_samples):
        # Move to CPU first so GPU batches can be converted to NumPy.
        img = denormalize(images[i].detach().cpu()).permute(1, 2, 0).numpy()
        img = np.clip(img, 0, 1)

        mask_rgb = class_index_to_rgb(masks[i].detach().cpu().numpy())

        axes[i, 0].imshow(img)
        axes[i, 0].set_title('Image', fontsize=14, fontweight='bold')
        axes[i, 0].axis('off')

        axes[i, 1].imshow(mask_rgb)
        axes[i, 1].set_title('Ground Truth', fontsize=14, fontweight='bold')
        axes[i, 1].axis('off')

        if predictions is not None:
            pred_rgb = class_index_to_rgb(predictions[i].detach().cpu().numpy())
            axes[i, 2].imshow(pred_rgb)
            axes[i, 2].set_title('Prediction', fontsize=14, fontweight='bold')
            axes[i, 2].axis('off')

    plt.tight_layout()
    plt.show()

# Peek at one augmented training batch (the iterator is temporary and discarded).
sample_images, sample_masks = next(iter(train_loader))
print(f"Sample batch shape: images={sample_images.shape}, masks={sample_masks.shape}")
visualize_batch(images=sample_images, masks=sample_masks, predictions=None, num_samples=3)

In [None]:
NUM_CLASSES = len(VOC_CLASSES)  # 21 = background + 20 object classes

# Initialize U-Net model
model = UNet(n_channels=3, n_classes=NUM_CLASSES, bilinear=True).to(device)

# Pixels labelled 255 (what mask_to_class_index writes for unmatched colours)
# do not contribute to the loss.
criterion = nn.CrossEntropyLoss(ignore_index=255)

# Optimizer - Adam with weight decay
optimizer = optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-5)

# Halve the LR when the validation loss has not improved for 5 epochs.
# `verbose=True` was removed: the argument is deprecated in recent PyTorch
# releases, and the training loop already prints the current LR every epoch.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.5, patience=5, min_lr=1e-7
)

# Count parameters
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"\n{'='*60}")
print(f"U-Net Model Initialized")
print(f"{'='*60}")
print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")
print(f"Model size: {total_params * 4 / 1024 / 1024:.2f} MB (float32)")  # 4 bytes per float32
print(f"{'='*60}\n")

In [None]:
def _flatten_valid(pred, target, ignore_index):
    """Flatten both label maps and keep only positions whose target != ignore_index."""
    pred, target = pred.flatten(), target.flatten()
    keep = target != ignore_index
    return pred[keep], target[keep]

def compute_iou(pred, target, num_classes=21, ignore_index=255):
    """Per-class Intersection-over-Union between two label maps.

    Works on NumPy arrays or torch tensors of matching shape. Pixels whose
    target equals ``ignore_index`` are excluded.

    Returns:
        list of ``num_classes`` floats; NaN for a class that appears in neither
        the prediction nor the target.
    """
    pred, target = _flatten_valid(pred, target, ignore_index)

    ious = []
    for cls in range(num_classes):
        in_pred = pred == cls
        in_target = target == cls
        union = (in_pred | in_target).sum().item()
        if union == 0:
            ious.append(float('nan'))  # class absent from both maps
            continue
        ious.append((in_pred & in_target).sum().item() / union)
    return ious

def compute_pixel_accuracy(pred, target, ignore_index=255):
    """Fraction of non-ignored pixels where ``pred`` equals ``target`` (0.0 if none)."""
    valid = target != ignore_index
    total = valid.sum().item()
    if total <= 0:
        return 0.0
    return (pred[valid] == target[valid]).sum().item() / total

def compute_dice_score(pred, target, num_classes=21, ignore_index=255):
    """Per-class Dice coefficient ``2|A∩B| / (|A| + |B|)`` between two label maps.

    Returns:
        list of ``num_classes`` floats; NaN for a class that appears in neither
        the prediction nor the target.
    """
    pred, target = _flatten_valid(pred, target, ignore_index)

    scores = []
    for cls in range(num_classes):
        in_pred = pred == cls
        in_target = target == cls
        denominator = in_pred.sum().item() + in_target.sum().item()
        if denominator == 0:
            scores.append(float('nan'))
            continue
        scores.append(2 * (in_pred & in_target).sum().item() / denominator)
    return scores

print("✓ Metrics functions defined (IoU, Pixel Accuracy, Dice Score)")

In [None]:
def train_one_epoch(model, loader, criterion, optimizer, device, epoch):
    """Run one training epoch over ``loader``.

    Args:
        model: segmentation network producing (N, C, H, W) logits.
        loader: iterable of (images, masks) batches with a ``len``.
        criterion: loss taking (logits, masks).
        optimizer: optimizer updating ``model``'s parameters.
        device: device each batch is moved to.
        epoch: epoch index; not used inside the function (kept for callers).

    Returns:
        (mean_loss, mean_pixel_accuracy), each averaged over batches.

    Raises:
        ValueError: if ``loader`` has no batches.
    """
    model.train()

    num_batches = len(loader)
    if num_batches == 0:
        raise ValueError("train loader has no batches")

    running_loss = 0.0
    running_acc = 0.0

    for batch_idx, (images, masks) in enumerate(loader, start=1):
        # non_blocking only has an effect with pinned memory; harmless otherwise.
        images = images.to(device, non_blocking=True)
        masks = masks.to(device, non_blocking=True)

        outputs = model(images)
        loss = criterion(outputs, masks)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # One host sync for the loss value, reused for the running sum and the log line.
        batch_loss = loss.item()
        acc = compute_pixel_accuracy(outputs.argmax(dim=1), masks)
        running_loss += batch_loss
        running_acc += acc

        if batch_idx % 20 == 0 or batch_idx == num_batches:
            print(f"  [{batch_idx:3d}/{num_batches}] "
                  f"Loss: {batch_loss:.4f} | Acc: {acc:.4f}")

    return running_loss / num_batches, running_acc / num_batches

def validate(model, loader, criterion, device, ignore_index=255):
    """Evaluate ``model`` on ``loader`` without gradient tracking.

    Loss and pixel accuracy are weighted by the number of non-ignored pixels,
    so a short final batch is not over-weighted as with a per-batch average.
    The loss weighting assumes ``criterion`` is mean-reduced over non-ignored
    pixels (true for ``CrossEntropyLoss(ignore_index=255)``).

    Args:
        ignore_index: target label excluded from accuracy and IoU/Dice.

    Returns:
        (loss, pixel_acc, mean_iou, mean_dice, per_class_iou) where
        per_class_iou is each class's IoU averaged over images (NaN for classes
        absent everywhere) and mean_iou / mean_dice are nan-means over classes.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    model.eval()

    loss_sum = 0.0  # sum over batches of (mean loss * valid pixels)
    correct = 0
    total = 0
    all_ious = []
    all_dice = []

    with torch.no_grad():
        for images, masks in loader:
            images = images.to(device, non_blocking=True)
            masks = masks.to(device, non_blocking=True)

            outputs = model(images)
            loss = criterion(outputs, masks)
            preds = outputs.argmax(dim=1)

            valid = masks != ignore_index
            n_valid = valid.sum().item()
            if n_valid > 0:  # an all-ignored batch has a NaN loss; skip it
                loss_sum += loss.item() * n_valid
                correct += (preds[valid] == masks[valid]).sum().item()
                total += n_valid

            # One device->host transfer per batch (was two per sample per metric).
            preds_np = preds.cpu().numpy()
            masks_np = masks.cpu().numpy()
            for pred_np, mask_np in zip(preds_np, masks_np):
                all_ious.append(compute_iou(pred_np, mask_np, ignore_index=ignore_index))
                all_dice.append(compute_dice_score(pred_np, mask_np, ignore_index=ignore_index))

    if not all_ious:
        raise ValueError("validation loader yielded no samples")

    epoch_loss = loss_sum / total if total else float('nan')
    epoch_acc = correct / total if total else 0.0

    # Per-class means over images, skipping images where a class is absent (NaN).
    mean_ious = np.nanmean(np.array(all_ious), axis=0)
    mean_dice = np.nanmean(np.array(all_dice), axis=0)
    mean_iou = np.nanmean(mean_ious)
    mean_dice_score = np.nanmean(mean_dice)

    return epoch_loss, epoch_acc, mean_iou, mean_dice_score, mean_ious

print("✓ Training and validation functions defined")

In [None]:
NUM_EPOCHS = 30
CHECKPOINT_PATH = '/kaggle/working/best_unet_model.pth'
best_miou = 0.0
best_epoch = None  # set only when a checkpoint is written during THIS run

history = {
    'train_loss': [], 'train_acc': [],
    'val_loss': [], 'val_acc': [],
    'val_miou': [], 'val_dice': []
}

print(f"\n{'='*70}")
print(f"Starting U-Net Training for {NUM_EPOCHS} epochs")
print(f"{'='*70}\n")

start_time = time.time()

for epoch in range(NUM_EPOCHS):
    epoch_start = time.time()

    print(f"\nEpoch [{epoch+1}/{NUM_EPOCHS}]")
    print("-" * 70)

    # Train
    train_loss, train_acc = train_one_epoch(
        model, train_loader, criterion, optimizer, device, epoch
    )

    # Validate
    val_loss, val_acc, val_miou, val_dice, class_ious = validate(
        model, val_loader, criterion, device
    )

    # ReduceLROnPlateau watches the validation loss.
    scheduler.step(val_loss)

    # Save history
    history['train_loss'].append(train_loss)
    history['train_acc'].append(train_acc)
    history['val_loss'].append(val_loss)
    history['val_acc'].append(val_acc)
    history['val_miou'].append(val_miou)
    history['val_dice'].append(val_dice)

    # Print epoch summary
    epoch_time = time.time() - epoch_start
    print(f"\n{'='*70}")
    print(f"Epoch {epoch+1} Summary:")
    print(f"  Train - Loss: {train_loss:.4f} | Acc: {train_acc:.4f}")
    print(f"  Val   - Loss: {val_loss:.4f} | Acc: {val_acc:.4f}")
    print(f"         mIoU: {val_miou:.4f} | Dice: {val_dice:.4f}")
    print(f"  Time: {epoch_time:.2f}s | LR: {optimizer.param_groups[0]['lr']:.2e}")

    # Save best model (selected by validation mIoU; a NaN mIoU never wins).
    if val_miou > best_miou:
        best_miou = float(val_miou)
        best_epoch = epoch
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            # Plain Python floats: NumPy scalars are rejected by
            # torch.load(weights_only=True), the default since PyTorch 2.6.
            'val_miou': float(val_miou),
            'val_dice': float(val_dice),
        }, CHECKPOINT_PATH)
        print(f"  ✓ Saved best model! (mIoU: {best_miou:.4f})")

    print(f"{'='*70}")

total_time = time.time() - start_time
print(f"\n{'='*70}")
print(f"Training Completed!")
print(f"{'='*70}")
print(f"Total time: {total_time/60:.2f} minutes")
print(f"Best validation mIoU: {best_miou:.4f}")
print(f"{'='*70}\n")

# The evaluation cells below use `model`; restore the best checkpoint so they
# analyse the weights whose mIoU is reported above, not the last epoch's.
if best_epoch is not None:
    checkpoint = torch.load(CHECKPOINT_PATH, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    print(f"✓ Restored best model weights from epoch {best_epoch + 1}")

In [None]:
fig, axes = plt.subplots(2, 2, figsize=(16, 12))

# One entry per panel: (grid position, y-label, title,
#   [(history key, legend label, marker, colour or None for the default cycle)])
panels = [
    ((0, 0), 'Loss', 'Training and Validation Loss',
     [('train_loss', 'Train Loss', 'o', None), ('val_loss', 'Val Loss', 's', None)]),
    ((0, 1), 'Accuracy', 'Training and Validation Accuracy',
     [('train_acc', 'Train Acc', 'o', 'green'), ('val_acc', 'Val Acc', 's', 'orange')]),
    ((1, 0), 'mIoU', 'Validation Mean IoU',
     [('val_miou', 'Val mIoU', 's', 'purple')]),
    ((1, 1), 'Dice Score', 'Validation Dice Score',
     [('val_dice', 'Val Dice', 's', 'red')]),
]

for (row, col), ylabel, title, series in panels:
    ax = axes[row, col]
    for key, label, marker, color in series:
        line_style = {'label': label, 'marker': marker, 'linewidth': 2}
        if color is not None:
            line_style['color'] = color
        ax.plot(history[key], **line_style)
    ax.set_xlabel('Epoch', fontsize=12)
    ax.set_ylabel(ylabel, fontsize=12)
    ax.set_title(title, fontsize=14, fontweight='bold')
    ax.legend(fontsize=11)
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('/kaggle/working/unet_training_history.png', dpi=150, bbox_inches='tight')
plt.show()

print("✓ Training history plots saved")

In [None]:
print("\n" + "="*70)
print("Per-Class Performance Analysis on Validation Set")
print("="*70)

# Per-image IoU/Dice over the full validation set; each class is then averaged
# over the images where it appears in the prediction or the ground truth.
model.eval()
all_class_ious = []
all_class_dice = []

with torch.no_grad():
    for images, masks in val_loader:
        preds = model(images.to(device)).argmax(dim=1).cpu().numpy()
        # One host transfer per batch instead of two per sample per metric.
        for pred_np, mask_np in zip(preds, masks.cpu().numpy()):
            all_class_ious.append(compute_iou(pred_np, mask_np))
            all_class_dice.append(compute_dice_score(pred_np, mask_np))

# Calculate mean per class (NaN for classes that never appear)
all_class_ious = np.array(all_class_ious)
all_class_dice = np.array(all_class_dice)
mean_ious = np.nanmean(all_class_ious, axis=0)
mean_dice = np.nanmean(all_class_dice, axis=0)

# Create results DataFrame (classes with no data are left out)
results = [
    {
        'Class': cls_name,
        'IoU': f"{iou:.4f}",
        'Dice': f"{dice:.4f}",
        'IoU_numeric': iou,
        'Dice_numeric': dice
    }
    for cls_name, iou, dice in zip(VOC_CLASSES, mean_ious, mean_dice)
    if not np.isnan(iou)
]

# Explicit columns keep sort_values / column selection valid even if no class has data.
results_df = pd.DataFrame(results, columns=['Class', 'IoU', 'Dice', 'IoU_numeric', 'Dice_numeric'])
results_df = results_df.sort_values('IoU_numeric', ascending=False)
results_df_display = results_df[['Class', 'IoU', 'Dice']].reset_index(drop=True)

print("\n" + results_df_display.to_string(index=False))
print("\n" + "="*70)
print(f"Overall Mean IoU: {np.nanmean(mean_ious):.4f}")
print(f"Overall Mean Dice: {np.nanmean(mean_dice):.4f}")
print("="*70)

# Save results
results_df_display.to_csv('/kaggle/working/unet_class_metrics.csv', index=False)
print("\n✓ Per-class metrics saved to /kaggle/working/unet_class_metrics.csv")

In [None]:
# Plot per-class IoU (results_df rows are already sorted by IoU, best first)
fig, ax = plt.subplots(figsize=(14, 8))

classes_with_data = results_df['Class'].tolist()
ious_with_data = results_df['IoU_numeric'].tolist()
overall_mean_iou = np.nanmean(mean_ious)

colors = plt.cm.viridis(np.linspace(0, 1, len(classes_with_data)))
bars = ax.barh(classes_with_data, ious_with_data, color=colors)

for set_label, text in ((ax.set_xlabel, 'IoU Score'), (ax.set_ylabel, 'Class')):
    set_label(text, fontsize=12, fontweight='bold')
ax.set_title('Per-Class IoU Performance (U-Net)', fontsize=14, fontweight='bold')
ax.axvline(
    x=overall_mean_iou, color='red', linestyle='--', linewidth=2,
    label=f'Mean IoU: {overall_mean_iou:.4f}'
)
ax.legend(fontsize=11)
ax.grid(axis='x', alpha=0.3)

plt.tight_layout()
plt.savefig('/kaggle/working/unet_class_iou_barplot.png', dpi=150, bbox_inches='tight')
plt.show()

print("✓ Per-class IoU plot saved")

In [None]:
def predict_image(model, img_path, device, img_size=(256, 256), return_overlay=False):
    """Segment a single image file with ``model``.

    The image is resized to ``img_size`` (PIL (width, height) order) and
    normalized with ImageNet mean/std before inference. ``model`` is switched
    to eval mode as a side effect.

    Returns:
        ``(pred_rgb_pil, pred)``, or ``(pred_rgb_pil, pred, overlay_pil)`` when
        ``return_overlay`` is true:
          * pred_rgb_pil: colourised prediction resized (nearest neighbour)
            back to the original image size.
          * pred: 2-D class-index array at the model input resolution
            (``img_size``), NOT the original size.
          * overlay_pil: 60% original image + 40% colourised prediction, at
            the original size.
    """
    model.eval()

    # Load and preprocess image; convert() copies the pixels, so the file can be closed.
    with Image.open(img_path) as src:
        img = src.convert('RGB')
    original_size = img.size

    img_resized = img.resize(img_size, Image.BILINEAR)
    img_tensor = transforms.ToTensor()(img_resized)
    img_tensor = transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )(img_tensor)
    img_tensor = img_tensor.unsqueeze(0).to(device)

    with torch.no_grad():
        pred = model(img_tensor).argmax(dim=1).squeeze(0).cpu().numpy()

    # Colourise, then upscale with NEAREST so no new colours are interpolated.
    pred_rgb_pil = Image.fromarray(class_index_to_rgb(pred)).resize(original_size, Image.NEAREST)

    if return_overlay:
        # Reuse the already-upscaled prediction instead of resizing it a second time.
        overlay = (np.array(img) * 0.6 + np.array(pred_rgb_pil) * 0.4).astype(np.uint8)
        return pred_rgb_pil, pred, Image.fromarray(overlay)

    return pred_rgb_pil, pred

# Test on random validation images
print("\nTesting inference on random validation samples:\n")

num_test_samples = 4
test_ids = random.sample(val_ids, min(num_test_samples, len(val_ids)))

# Size the grid by the samples actually drawn (the validation split may hold
# fewer than num_test_samples); squeeze=False keeps `axes` 2-D for one row.
n_rows = len(test_ids)
fig, axes = plt.subplots(n_rows, 4, figsize=(20, n_rows*5), squeeze=False)

for idx, test_id in enumerate(test_ids):
    test_img_path = ROOT / 'JPEGImages' / f'{test_id}.jpg'
    test_mask_path = ROOT / 'SegmentationClass' / f'{test_id}.png'

    # Predict
    pred_rgb, pred_mask, overlay = predict_image(model, test_img_path, device, img_size=IMG_SIZE, return_overlay=True)

    # Ground truth at the model's resolution (NEAREST keeps exact label colours)
    # so it aligns pixel-for-pixel with pred_mask.
    ground_truth = Image.open(test_mask_path).resize(IMG_SIZE, Image.NEAREST)
    gt_arr = mask_to_class_index(ground_truth)
    gt_rgb = class_index_to_rgb(gt_arr)

    # Mean IoU over the classes present in this sample's prediction or ground truth
    sample_iou = compute_iou(pred_mask, gt_arr)
    mean_sample_iou = np.nanmean(sample_iou)

    # Plot
    axes[idx, 0].imshow(Image.open(test_img_path))
    axes[idx, 0].set_title(f'Original Image\n{test_id}', fontsize=11, fontweight='bold')
    axes[idx, 0].axis('off')

    axes[idx, 1].imshow(gt_rgb)
    axes[idx, 1].set_title('Ground Truth', fontsize=11, fontweight='bold')
    axes[idx, 1].axis('off')

    axes[idx, 2].imshow(class_index_to_rgb(pred_mask))
    axes[idx, 2].set_title(f'Prediction\nmIoU: {mean_sample_iou:.4f}', fontsize=11, fontweight='bold')
    axes[idx, 2].axis('off')

    axes[idx, 3].imshow(overlay)
    axes[idx, 3].set_title('Overlay (60% img + 40% pred)', fontsize=11, fontweight='bold')
    axes[idx, 3].axis('off')

plt.tight_layout()
plt.savefig('/kaggle/working/unet_inference_samples.png', dpi=150, bbox_inches='tight')
plt.show()

print("✓ Inference samples saved")

In [None]:
print("\nComputing confusion matrix (this may take a minute)...")

from sklearn.metrics import confusion_matrix
import seaborn as sns

# Accumulate the confusion matrix batch by batch with np.bincount over ALL
# non-ignored pixels. Collecting every pixel into Python lists and then
# sub-sampling 100k of them was slow, memory-hungry, only approximate, and
# consumed the global NumPy RNG.
num_cm_classes = len(VOC_CLASSES)
cm = np.zeros((num_cm_classes, num_cm_classes), dtype=np.int64)

model.eval()
with torch.no_grad():
    for images, masks in val_loader:
        preds = model(images.to(device)).argmax(dim=1).cpu().numpy().ravel()
        # Promote to int64 first: uint8 labels * num_classes would overflow.
        targets = masks.cpu().numpy().ravel().astype(np.int64)

        valid = targets != 255
        # Row = true class, column = predicted class (same layout as sklearn).
        flat_index = targets[valid] * num_cm_classes + preds[valid]
        cm += np.bincount(flat_index, minlength=num_cm_classes ** 2).reshape(num_cm_classes, num_cm_classes)

# Row-normalize: each row shows how pixels of one true class were predicted.
cm_normalized = cm.astype('float') / (cm.sum(axis=1)[:, np.newaxis] + 1e-10)

# Plot
fig, ax = plt.subplots(figsize=(16, 14))
sns.heatmap(cm_normalized, annot=False, fmt='.2f', cmap='Blues',
            xticklabels=VOC_CLASSES, yticklabels=VOC_CLASSES,
            cbar_kws={'label': 'Normalized Count'}, ax=ax)
ax.set_xlabel('Predicted Class', fontsize=12, fontweight='bold')
ax.set_ylabel('True Class', fontsize=12, fontweight='bold')
ax.set_title('Confusion Matrix (Normalized)', fontsize=14, fontweight='bold')
plt.xticks(rotation=45, ha='right')
plt.yticks(rotation=0)
plt.tight_layout()
plt.savefig('/kaggle/working/unet_confusion_matrix.png', dpi=150, bbox_inches='tight')
plt.show()

print("✓ Confusion matrix saved")

In [None]:
def export_predictions(model, val_loader, output_dir='/kaggle/working/predictions', device=None):
    """Save a colourised PNG for every prediction ``model`` makes on ``val_loader``.

    Files are named ``pred_0000.png``, ``pred_0001.png``, ... in loader order.

    Args:
        model: segmentation network producing (N, C, H, W) logits.
        val_loader: iterable of (images, masks) batches; masks are not used.
        output_dir: destination directory, created (with parents) if missing.
        device: where to run inference. Defaults to the device of the model's
            first parameter instead of relying on a notebook-global ``device``.

    Returns:
        Number of PNG files written.
    """
    if device is None:
        device = next(model.parameters()).device

    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    model.eval()
    count = 0

    print(f"Exporting predictions to {output_dir}...")

    with torch.no_grad():
        for batch_idx, (images, _masks) in enumerate(val_loader, start=1):
            preds = model(images.to(device)).argmax(dim=1).cpu().numpy()

            for pred_mask in preds:
                pred_img = Image.fromarray(class_index_to_rgb(pred_mask))
                pred_img.save(output_dir / f'pred_{count:04d}.png')
                count += 1

            if batch_idx % 10 == 0:
                print(f"  Exported {count} predictions...")

    print(f"✓ Exported {count} predictions successfully!")
    return count

# Uncomment to export every validation prediction as a colour PNG:
# export_predictions(model, val_loader)

print("\n🎉 All cells executed successfully! Your U-Net model is ready to use.")