In [1]:
import os
import random
import numpy as np
import pandas as pd
from collections import defaultdict
from PIL import Image

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.amp import autocast, GradScaler

import albumentations as A
from albumentations.pytorch import ToTensorV2
import segmentation_models_pytorch as smp
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
from tqdm import tqdm

# Optional: for reproducibility
def set_seed(seed=42):
    """Seed the random number generators this notebook relies on.

    Applies the same seed to Python's ``random`` module, NumPy's global
    generator, and PyTorch (``torch.manual_seed`` and
    ``torch.cuda.manual_seed_all``).

    Args:
        seed: Integer seed handed to every generator.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)


set_seed(42)


In [2]:
# --------------------
# Config class
# --------------------
class Config:
    """Single source of truth for the notebook's tunable settings."""

    # --- Reproducibility ---
    SEED = 42

    # --- Data locations and geometry ---
    TRAIN_CSV = "data/train.csv"
    TRAIN_IMG_DIR = "data/train_images"
    IMG_SIZE = (256, 256)   # size images and masks are resized to
    NUM_CLASSES = 4         # number of mask channels / model outputs

    # --- Data loading ---
    BATCH_SIZE = 4
    NUM_WORKERS = 0

    # --- Optimisation ---
    EPOCHS = 40
    LEARNING_RATE = 1e-4
    WEIGHT_DECAY = 1e-5
    T_MAX = 32              # T_max passed to CosineAnnealingLR
    ETA_MIN = 1e-6          # eta_min passed to CosineAnnealingLR
    AMP = True              # request mixed-precision training
    TARGET_DICE = 0.95      # validation Dice at which training stops early

    # --- Runtime ---
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

    # --- Outputs ---
    OUTPUT_DIR = "outputs"
    CHECKPOINT_DIR = os.path.join(OUTPUT_DIR, "checkpoints")

# Shared settings object; create the output folders up front so later
# savefig / torch.save calls have somewhere to write.
config = Config()
for _required_dir in (config.OUTPUT_DIR, config.CHECKPOINT_DIR):
    os.makedirs(_required_dir, exist_ok=True)

# Utilities: RLE mask decoding and augmentation pipelines
def rle_decode(mask_rle, shape=(256,1600)):
    """Decode a run-length-encoded mask string into a binary 2-D array.

    The encoding is a whitespace-separated list of ``start length`` pairs.
    Starts are 1-based positions into the mask flattened column by column.

    Args:
        mask_rle: RLE string, or a missing value recognised by ``pd.isna``.
        shape: ``(height, width)`` of the decoded mask.

    Returns:
        ``np.uint8`` array of shape ``shape`` holding 1 inside the runs and
        0 elsewhere; all zeros when ``mask_rle`` is missing.
    """
    if pd.isna(mask_rle):
        return np.zeros(shape, dtype=np.uint8)

    tokens = mask_rle.split()
    run_starts = np.asarray(tokens[0::2], dtype=int)
    run_lengths = np.asarray(tokens[1::2], dtype=int)
    run_starts -= 1  # convert 1-based RLE positions to 0-based indices
    run_ends = run_starts + run_lengths

    height, width = shape[0], shape[1]
    flat = np.zeros(height * width, dtype=np.uint8)
    for begin, end in zip(run_starts, run_ends):
        flat[begin:end] = 1

    # Runs are laid out column by column: fill as (width, height), then transpose.
    return np.transpose(flat.reshape((width, height)))

def get_train_augmentations():
    """Build the augmentation pipeline applied to training samples.

    Returns:
        An ``A.Compose`` that applies random geometric and photometric
        transforms, then normalises the image and converts it to a tensor.
    """
    geometric = [
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        A.RandomRotate90(p=0.5),
        A.Transpose(p=0.5),
        A.ElasticTransform(p=0.5),
        A.GridDistortion(p=0.2),
    ]
    photometric = [
        A.HueSaturationValue(p=0.7),
        A.RandomBrightnessContrast(p=0.5),
        A.Blur(p=0.2),
        A.CLAHE(p=0.2),
    ]
    to_model_input = [
        A.Normalize(mean=(0.485,0.456,0.406), std=(0.229,0.224,0.225)),
        ToTensorV2(),
    ]
    # Order matters: geometry first, colour second, normalisation last.
    return A.Compose(geometric + photometric + to_model_input)

def get_val_augmentations():
    """Build the deterministic pipeline applied to validation samples.

    Returns:
        An ``A.Compose`` that only normalises the image and converts it to a
        tensor (no random augmentation).
    """
    channel_mean = (0.485,0.456,0.406)
    channel_std = (0.229,0.224,0.225)
    steps = [
        A.Normalize(mean=channel_mean, std=channel_std),
        ToTensorV2(),
    ]
    return A.Compose(steps)


In [3]:
# --------------------
# Dataset
# --------------------
class SteelDefectDataset(Dataset):
    """Segmentation dataset yielding ``(image, mask)`` tensor pairs.

    Each item is one image from ``img_dir`` together with a multi-channel
    binary mask (one channel per defect class) decoded from the RLE strings
    in ``df``.

    Args:
        df: DataFrame with ``ImageId``, ``ClassId`` (1-based) and
            ``EncodedPixels`` columns; one row per (image, class) annotation.
        img_dir: Directory containing the image files named by ``ImageId``.
        img_size: Target ``(height, width)`` for images and masks.
        mode: ``'train'`` selects the training augmentations; any other value
            selects the validation pipeline.
    """

    # (height, width) the RLE strings are decoded against before resizing.
    RLE_SHAPE = (256, 1600)

    def __init__(self, df, img_dir, img_size, mode='train'):
        self.df = df.copy()
        self.img_dir = img_dir
        self.img_size = img_size
        self.mode = mode
        self.image_ids = self.df['ImageId'].unique()
        self.transform = get_train_augmentations() if mode=='train' else get_val_augmentations()

        # Index the annotations once so __getitem__ does not have to scan the
        # whole DataFrame for every sample.
        self._annotations = {}
        for image_id, class_id, rle in zip(
            self.df['ImageId'], self.df['ClassId'], self.df['EncodedPixels']
        ):
            self._annotations.setdefault(image_id, []).append((class_id, rle))

    def __len__(self):
        """Return the number of distinct images."""
        return len(self.image_ids)

    def __getitem__(self, idx):
        """Load, resize and augment the sample at position ``idx``.

        Returns:
            ``(image, mask)`` where ``image`` is the transformed image tensor
            and ``mask`` is a float tensor of shape
            ``(NUM_CLASSES, height, width)``.
        """
        image_id = self.image_ids[idx]
        height, width = self.img_size
        # PIL's resize takes (width, height) while the NumPy arrays are
        # (height, width); passing img_size unswapped breaks non-square sizes.
        pil_size = (width, height)

        img_path = os.path.join(self.img_dir, image_id)
        with Image.open(img_path) as img_file:  # close the file handle promptly
            image = np.array(img_file.convert("RGB").resize(pil_size))

        mask = np.zeros((height, width, config.NUM_CLASSES), dtype=np.uint8)
        for class_id, rle in self._annotations.get(image_id, ()):
            class_idx = int(class_id) - 1
            single_mask = rle_decode(rle, shape=self.RLE_SHAPE)
            single_mask_resized = np.array(
                Image.fromarray(single_mask).resize(pil_size, resample=Image.NEAREST)
            )
            # Union instead of overwrite, in case a class appears on several rows.
            mask[..., class_idx] = np.maximum(mask[..., class_idx], single_mask_resized)

        auged = self.transform(image=image, mask=mask)
        image = auged['image']
        # The mask comes back channels-last; move channels first for the model.
        mask = auged['mask'].permute(2,0,1).float()
        return image, mask

# --------------------
# Model: Improved UNet
# --------------------
class DoubleConv(nn.Module):
    """Two consecutive 3x3 convolution -> BatchNorm -> ReLU stages.

    Both convolutions use ``padding=1`` and no bias, so the spatial size of
    the input is preserved and only the channel count changes.

    Args:
        in_ch: Number of input channels.
        out_ch: Number of channels produced by both stages.
    """

    @staticmethod
    def _stage(c_in, c_out):
        """Return the layers of one conv -> BN -> ReLU stage, in order."""
        return (
            nn.Conv2d(c_in, c_out, 3, padding=1, bias=False),
            nn.BatchNorm2d(c_out),
            nn.ReLU(inplace=True),
        )

    def __init__(self, in_ch, out_ch):
        super().__init__()
        layers = [*self._stage(in_ch, out_ch), *self._stage(out_ch, out_ch)]
        # Kept as a flat Sequential named `net` so parameter names stay net.0 ... net.5.
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Apply both stages to ``x``."""
        features = self.net(x)
        return features

class ImprovedUNet(nn.Module):
    """U-Net with four down-sampling and four up-sampling stages.

    The encoder halves the resolution with max-pooling and doubles the
    channel count at every stage. The decoder up-samples with transposed
    convolutions and concatenates the matching encoder feature map (skip
    connection) before each ``DoubleConv``. A final 1x1 convolution maps to
    ``out_ch`` logits per pixel.

    Args:
        in_ch: Number of input image channels.
        out_ch: Number of output channels (one logit map per class).
        base: Channel count of the first encoder stage.
    """

    def __init__(self, in_ch=3, out_ch=4, base=64):
        super().__init__()
        c1, c2, c3, c4, c5 = (base * factor for factor in (1, 2, 4, 8, 16))

        # Encoder
        self.inc = DoubleConv(in_ch, c1)
        self.down1 = nn.Sequential(nn.MaxPool2d(2), DoubleConv(c1, c2))
        self.down2 = nn.Sequential(nn.MaxPool2d(2), DoubleConv(c2, c3))
        self.down3 = nn.Sequential(nn.MaxPool2d(2), DoubleConv(c3, c4))
        self.down4 = nn.Sequential(nn.MaxPool2d(2), DoubleConv(c4, c5))

        # Decoder: each DoubleConv sees the up-sampled map concatenated with
        # the skip connection, hence twice the up-sampled channel count.
        self.up1 = nn.ConvTranspose2d(c5, c4, 2, stride=2)
        self.conv1 = DoubleConv(c5, c4)
        self.up2 = nn.ConvTranspose2d(c4, c3, 2, stride=2)
        self.conv2 = DoubleConv(c4, c3)
        self.up3 = nn.ConvTranspose2d(c3, c2, 2, stride=2)
        self.conv3 = DoubleConv(c3, c2)
        self.up4 = nn.ConvTranspose2d(c2, c1, 2, stride=2)
        self.conv4 = DoubleConv(c2, c1)

        self.outc = nn.Conv2d(c1, out_ch, 1)

    def forward(self, x):
        """Return per-pixel logits for ``x``."""
        # Encoder pass; keep every stage's output for the skip connections.
        skips = [self.inc(x)]
        for encoder in (self.down1, self.down2, self.down3, self.down4):
            skips.append(encoder(skips[-1]))

        x = skips.pop()  # bottleneck features
        decoder_stages = (
            (self.up1, self.conv1),
            (self.up2, self.conv2),
            (self.up3, self.conv3),
            (self.up4, self.conv4),
        )
        for upsample, fuse in decoder_stages:
            # Deepest remaining skip pairs with the current decoder stage.
            x = torch.cat([upsample(x), skips.pop()], dim=1)
            x = fuse(x)
        return self.outc(x)


In [4]:
class TverskyLoss(nn.Module):
    """Tversky loss computed on sigmoid probabilities.

    For every (sample, channel) pair the index
    ``(TP + smooth) / (TP + alpha * FP + beta * FN + smooth)`` is computed
    over the two trailing (spatial) dimensions; the loss is one minus the
    mean index.

    Args:
        alpha: Weight applied to the soft false positives.
        beta: Weight applied to the soft false negatives.
        smooth: Constant added to numerator and denominator.
    """

    def __init__(self, alpha=0.7, beta=0.3, smooth=1e-6):
        super().__init__()
        self.alpha, self.beta, self.smooth = alpha, beta, smooth

    def forward(self, preds, targets):
        """Return the scalar loss for raw logits ``preds`` and binary ``targets``."""
        probs = torch.sigmoid(preds)
        spatial_dims = (2, 3)

        true_pos = (probs * targets).sum(dim=spatial_dims)
        false_neg = ((1 - probs) * targets).sum(dim=spatial_dims)
        false_pos = (probs * (1 - targets)).sum(dim=spatial_dims)

        numerator = true_pos + self.smooth
        denominator = true_pos + self.alpha * false_pos + self.beta * false_neg + self.smooth
        tversky_index = numerator / denominator
        return 1.0 - tversky_index.mean()

class CombinedLoss(nn.Module):
    """Equal-weight blend of a multilabel Dice loss and a Tversky loss.

    Both terms receive raw logits; the Dice term is configured with
    ``from_logits=True`` and the Tversky term applies its own sigmoid.
    """

    def __init__(self):
        super().__init__()
        self.dice = smp.losses.DiceLoss(mode='multilabel', from_logits=True)
        self.tversky = TverskyLoss(alpha=0.7, beta=0.3)

    def forward(self, preds, targets):
        """Return ``0.5 * dice + 0.5 * tversky`` for the given logits and targets."""
        dice_term = self.dice(preds, targets)
        tversky_term = self.tversky(preds, targets)
        return 0.5 * dice_term + 0.5 * tversky_term

def dice_score(pred, target, smooth=1e-6):
    """Dice coefficient between thresholded predictions and the target.

    Logits are passed through a sigmoid and binarised at 0.5; the score is
    then computed over all elements of the tensors at once (no per-sample or
    per-class averaging).

    Args:
        pred: Raw logits.
        target: Binary ground-truth tensor of the same shape.
        smooth: Constant added to numerator and denominator.

    Returns:
        Scalar tensor ``(2 * |P∩T| + smooth) / (|P| + |T| + smooth)``.
    """
    probabilities = torch.sigmoid(pred)
    hard_pred = (probabilities > 0.5).float()
    overlap = (hard_pred * target).sum()
    total = hard_pred.sum() + target.sum()
    return (2. * overlap + smooth) / (total + smooth)

def safe_partial_load(model, ckpt_path, device, load_optimizer=False, optimizer=None):
    """Load whatever parts of a checkpoint fit the given model.

    Only tensors whose name exists in ``model`` and whose shape matches are
    copied; everything else is skipped and the counts are printed. The
    optimizer state is restored on a best-effort basis.

    Args:
        model: Module that receives the compatible weights.
        ckpt_path: Path of the checkpoint file.
        device: ``map_location`` used when reading the file.
        load_optimizer: Whether to also try restoring the optimizer state.
        optimizer: Optimizer to restore into (ignored when ``None``).

    Returns:
        Dict with ``'train_history'``, ``'val_history'``, ``'metrics'`` and
        ``'epoch'`` taken from the checkpoint (empty dicts / 0 when absent).
    """
    # NOTE: weights_only=False performs full unpickling; only load trusted files.
    checkpoint = torch.load(ckpt_path, map_location=device, weights_only=False)
    # Accept both a full training payload and a bare state dict.
    saved_weights = checkpoint.get('model_state_dict', checkpoint)

    merged = model.state_dict()
    compatible = {}
    for name, tensor in saved_weights.items():
        if name in merged and tensor.shape == merged[name].shape:
            compatible[name] = tensor
    merged.update(compatible)
    model.load_state_dict(merged, strict=False)

    num_skipped = len(saved_weights) - len(compatible)
    print(f"[safe_partial_load] loaded tensors: {len(compatible)} | skipped: {num_skipped}")

    wants_optimizer = load_optimizer and optimizer is not None
    if wants_optimizer and 'optimizer_state_dict' in checkpoint:
        try:
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        except Exception as e:
            # Deliberately non-fatal: training can continue with a fresh optimizer.
            print(f"[safe_partial_load] skip optimizer state: {e}")

    defaults = (
        ('train_history', {}),
        ('val_history', {}),
        ('metrics', {}),
        ('epoch', 0),
    )
    return {key: checkpoint.get(key, fallback) for key, fallback in defaults}


In [5]:
class Trainer:
    def __init__(self, model, train_loader, val_loader, criterion, optimizer, scheduler, device, cfg):
        self.model = model
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.criterion = criterion
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.device = device
        self.cfg = cfg
        self.scaler = GradScaler('cuda', enabled=cfg.AMP)
        self.train_history = defaultdict(list)
        self.val_history = defaultdict(list)
        self.best_val_dice = 0.0
        self.best_val_loss = float('inf')

    def train_epoch(self, epoch):
        print(f"\nEpoch {epoch+1}/{self.cfg.EPOCHS} [Train]", flush=True)
        self.model.train()
        losses, dices = [], []
        pbar = tqdm(self.train_loader, desc=f"Train {epoch+1}", unit="batch", leave=False)
        for images, masks in pbar:
            images = images.to(self.device, non_blocking=True)
            masks = masks.to(self.device, non_blocking=True)
            self.optimizer.zero_grad(set_to_none=True)
            with autocast('cuda', enabled=self.cfg.AMP):
                outputs = self.model(images)
                loss = self.criterion(outputs, masks)
            self.scaler.scale(loss).backward()
            self.scaler.step(self.optimizer)
            self.scaler.update()
            losses.append(loss.item())
            with torch.no_grad():
                d = dice_score(outputs, masks).item()
                dices.append(d)
            if len(losses) % 10 == 0:
                pbar.set_postfix(loss=f"{np.mean(losses[-10:]):.4f}", dice=f"{np.mean(dices[-10:]):.4f}")
        mean_loss, mean_dice = float(np.mean(losses)), float(np.mean(dices))
        self.train_history['loss'].append(mean_loss)
        self.train_history['dice'].append(mean_dice)
        print(f"Train  -> loss {mean_loss:.4f} | dice {mean_dice:.4f}", flush=True)
        return mean_loss, {'dice': mean_dice}

    def validate_epoch(self, epoch):
        print(f"Epoch {epoch+1}/{self.cfg.EPOCHS} [Val]", flush=True)
        self.model.eval()
        losses, dices = [], []
        with torch.no_grad():
            pbar = tqdm(self.val_loader, desc=f"Val {epoch+1}", unit="batch", leave=False)
            for images, masks in pbar:
                images = images.to(self.device, non_blocking=True)
                masks = masks.to(self.device, non_blocking=True)
                outputs = self.model(images)
                loss = self.criterion(outputs, masks)
                losses.append(loss.item())
                dices.append(dice_score(outputs, masks).item())
                if len(losses) % 10 == 0:
                    pbar.set_postfix(loss=f"{np.mean(losses[-10:]):.4f}", dice=f"{np.mean(dices[-10:]):.4f}")
        mean_loss, mean_dice = float(np.mean(losses)), float(np.mean(dices))
        self.val_history['loss'].append(mean_loss)
        self.val_history['dice'].append(mean_dice)
        print(f"Val    -> loss {mean_loss:.4f} | dice {mean_dice:.4f}\n", flush=True)
        return mean_loss, {'dice': mean_dice}

    def save_checkpoint(self, epoch, metrics, filename):
        path = os.path.join(self.cfg.CHECKPOINT_DIR, filename)
        payload = {
            'epoch': epoch + 1,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'train_history': dict(self.train_history),
            'val_history': dict(self.val_history),
            'metrics': metrics,
        }
        torch.save(payload, path)

class Visualizer:
    """Writes training curves and prediction overlays to ``cfg.OUTPUT_DIR``."""

    # Channel statistics passed to A.Normalize in get_train_augmentations /
    # get_val_augmentations; needed to undo the normalisation before display.
    NORM_MEAN = (0.485, 0.456, 0.406)
    NORM_STD = (0.229, 0.224, 0.225)

    def __init__(self, cfg):
        self.cfg = cfg

    @classmethod
    def _to_displayable(cls, image):
        """Convert a normalised ``(C, H, W)`` tensor to an ``(H, W, C)`` array in [0, 1].

        The dataset yields images that were mean/std normalised, so their
        values fall outside the range ``imshow`` expects for float RGB data.
        This inverts the normalisation and clips the result.
        """
        pixels = image.detach().cpu().permute(1, 2, 0).numpy().astype(np.float32)
        mean = np.asarray(cls.NORM_MEAN, dtype=np.float32)
        std = np.asarray(cls.NORM_STD, dtype=np.float32)
        return np.clip(pixels * std + mean, 0.0, 1.0)

    def plot_training_history(self, train_hist, val_hist):
        """Save loss and Dice curves to ``<OUTPUT_DIR>/history.png``.

        Args:
            train_hist: Mapping with optional ``'loss'`` / ``'dice'`` lists.
            val_hist: Same structure for the validation split.
        """
        fig = plt.figure(figsize=(12,5))
        plt.subplot(1,2,1)
        plt.plot(train_hist.get('loss', []), label='Train Loss')
        plt.plot(val_hist.get('loss', []), label='Val Loss')
        plt.legend(); plt.title('Loss'); plt.grid()
        plt.subplot(1,2,2)
        plt.plot(train_hist.get('dice', []), label='Train Dice')
        plt.plot(val_hist.get('dice', []), label='Val Dice')
        plt.legend(); plt.title('Dice'); plt.grid()
        plt.tight_layout()
        plt.savefig(os.path.join(self.cfg.OUTPUT_DIR, "history.png"))
        plt.close(fig)

    def visualize_predictions(self, model, val_dataset, device, num_samples=6):
        """Save ground-truth / prediction overlays for random validation samples.

        One PNG per sample is written to ``<OUTPUT_DIR>/viz/sample_<i>.png``:
        the input image followed by one panel per class with the ground truth
        in red and the thresholded prediction in green.

        Args:
            model: Segmentation model; switched to eval mode.
            val_dataset: Dataset yielding ``(image, mask)`` tensors.
            device: Device the model lives on.
            num_samples: Maximum number of samples to draw.
        """
        model.eval()
        viz_dir = os.path.join(self.cfg.OUTPUT_DIR, "viz")
        os.makedirs(viz_dir, exist_ok=True)
        n = min(num_samples, len(val_dataset))
        if n == 0:
            return
        # Use the configuration this instance was built with, not a global.
        num_classes = self.cfg.NUM_CLASSES
        idxs = np.random.choice(len(val_dataset), size=n, replace=False)
        with torch.no_grad():
            for i, idx in enumerate(idxs):
                image, mask = val_dataset[idx]
                image_t = image.unsqueeze(0).to(device)
                logits = model(image_t)
                probs = torch.sigmoid(logits)[0].cpu().numpy()
                pred_mask = (probs > 0.5).astype(np.uint8)
                fig, axes = plt.subplots(1, 1 + num_classes, figsize=(15,4))
                # Undo the normalisation so the picture shows true colours.
                axes[0].imshow(self._to_displayable(image))
                axes[0].set_title("Image"); axes[0].axis('off')
                for c in range(num_classes):
                    axes[c+1].imshow(mask[c].cpu().numpy(), alpha=0.5, cmap='Reds')
                    axes[c+1].imshow(pred_mask[c], alpha=0.5, cmap='Greens')
                    axes[c+1].set_title(f"Class {c+1}"); axes[c+1].axis('off')
                fig.tight_layout()
                fig.savefig(os.path.join(viz_dir, f"sample_{i}.png"))
                plt.close(fig)  # release the figure explicitly inside the loop


In [None]:
def main():
    """Train (or resume training) the segmentation model and export diagnostics.

    Steps: split the annotated images into train / validation sets, build
    loaders, model, loss, optimizer and scheduler, resume from
    ``best_model_dice.pth`` when it exists, train until ``config.EPOCHS`` or
    ``config.TARGET_DICE`` is reached, then write the history plot and
    prediction overlays.

    Returns:
        ``(trainer, model, val_dataset)`` so later notebook cells can reuse them.
    """
    train_df = pd.read_csv(config.TRAIN_CSV)
    image_ids = train_df['ImageId'].unique()
    # Split by image id so all rows of one image land in the same subset.
    train_ids, val_ids = train_test_split(image_ids, test_size=0.15, random_state=config.SEED, shuffle=True)
    train_data = train_df[train_df['ImageId'].isin(train_ids)]
    val_data = train_df[train_df['ImageId'].isin(val_ids)]
    train_dataset = SteelDefectDataset(train_data, config.TRAIN_IMG_DIR, config.IMG_SIZE, mode='train')
    val_dataset = SteelDefectDataset(val_data, config.TRAIN_IMG_DIR, config.IMG_SIZE, mode='val')

    # Pinned host memory only helps host-to-GPU copies.
    pin_memory = str(config.DEVICE).startswith("cuda")
    train_loader = DataLoader(train_dataset, batch_size=config.BATCH_SIZE, shuffle=True, num_workers=config.NUM_WORKERS, pin_memory=pin_memory)
    val_loader = DataLoader(val_dataset, batch_size=config.BATCH_SIZE, shuffle=False, num_workers=config.NUM_WORKERS, pin_memory=pin_memory)

    model = ImprovedUNet(in_ch=3, out_ch=config.NUM_CLASSES).to(config.DEVICE)
    criterion = CombinedLoss()
    optimizer = optim.AdamW(model.parameters(), lr=config.LEARNING_RATE, weight_decay=config.WEIGHT_DECAY)
    scheduler = CosineAnnealingLR(optimizer, T_max=config.T_MAX, eta_min=config.ETA_MIN)
    checkpoint_path = os.path.join(config.CHECKPOINT_DIR, 'best_model_dice.pth')
    start_epoch = 0
    trainer = Trainer(model, train_loader, val_loader, criterion, optimizer, scheduler, config.DEVICE, config)

    if os.path.exists(checkpoint_path):
        print("Loading checkpoint from:", checkpoint_path)
        hist = safe_partial_load(model, checkpoint_path, config.DEVICE, load_optimizer=True, optimizer=optimizer)
        trainer.train_history = defaultdict(list, hist['train_history'])
        trainer.val_history = defaultdict(list, hist['val_history'])
        metrics = hist['metrics']
        trainer.best_val_dice = metrics.get('dice', 0.0)
        # The metrics saved by the training loop contain only the Dice score,
        # so recover the best validation loss from the stored history.
        past_val_losses = hist['val_history'].get('loss', [])
        trainer.best_val_loss = metrics.get('loss', min(past_val_losses, default=float('inf')))
        # The checkpoint records the number of completed epochs, which is the
        # zero-based index of the next epoch to run.
        start_epoch = int(hist.get('epoch', 0))
        # NOTE(review): the scheduler state is not part of the checkpoint, so the
        # cosine schedule restarts from its first step after a resume.

    max_epochs = config.EPOCHS
    for epoch in range(start_epoch, max_epochs):
        tr_loss, tr_metrics = trainer.train_epoch(epoch)
        val_loss, val_metrics = trainer.validate_epoch(epoch)
        scheduler.step()
        print(f"Epoch {epoch+1}: Train Dice {tr_metrics['dice']:.4f}, Val Dice {val_metrics['dice']:.4f}", flush=True)
        if val_metrics['dice'] > trainer.best_val_dice:
            trainer.best_val_dice = val_metrics['dice']
            trainer.save_checkpoint(epoch, val_metrics, 'best_model_dice.pth')
        if val_loss < trainer.best_val_loss:
            trainer.best_val_loss = val_loss
            trainer.save_checkpoint(epoch, val_metrics, 'best_model_loss.pth')
        if val_metrics['dice'] >= config.TARGET_DICE:
            print(f"✓ Early stopping! Target Dice ({config.TARGET_DICE*100:.1f}%) reached at epoch {epoch+1}!")
            break

    vis = Visualizer(config)
    vis.plot_training_history(trainer.train_history, trainer.val_history)
    vis.visualize_predictions(model, val_dataset, config.DEVICE, num_samples=6)
    print("✓ Full resumed pipeline completed, results in:", config.OUTPUT_DIR)
    return trainer, model, val_dataset


# Bind the results at notebook scope: the reporting cells below use
# `trainer`, `model` and `val_dataset`.
trainer, model, val_dataset = main()


Loading checkpoint from: outputs\checkpoints\best_model_dice.pth
[safe_partial_load] loaded tensors: 0 | skipped: 220
[safe_partial_load] skip optimizer state: loaded state dict contains a parameter group that doesn't match the size of optimizer's group

Epoch 9/40 [Train]


                                                                                                                       

Train  -> loss 0.6898 | dice 0.1149
Epoch 9/40 [Val]


                                                                                                                       

Val    -> loss 0.6447 | dice 0.1981

Epoch 9: Train Dice 0.1149, Val Dice 0.1981





Epoch 10/40 [Train]


                                                                                                                       

Train  -> loss 0.6435 | dice 0.1853
Epoch 10/40 [Val]


                                                                                                                       

Val    -> loss 0.6081 | dice 0.2404

Epoch 10: Train Dice 0.1853, Val Dice 0.2404





Epoch 11/40 [Train]


                                                                                                                       

Train  -> loss 0.6215 | dice 0.2292
Epoch 11/40 [Val]


                                                                                                                       

Val    -> loss 0.5905 | dice 0.2869

Epoch 11: Train Dice 0.2292, Val Dice 0.2869





Epoch 12/40 [Train]


                                                                                                                       

Train  -> loss 0.6153 | dice 0.2558
Epoch 12/40 [Val]


                                                                                                                       

Val    -> loss 0.5931 | dice 0.2931

Epoch 12: Train Dice 0.2558, Val Dice 0.2931

Epoch 13/40 [Train]


                                                                                                                       

Train  -> loss 0.6096 | dice 0.2945
Epoch 13/40 [Val]


                                                                                                                       

Val    -> loss 0.5864 | dice 0.3372

Epoch 13: Train Dice 0.2945, Val Dice 0.3372





Epoch 14/40 [Train]


                                                                                                                       

Train  -> loss 0.6012 | dice 0.3409
Epoch 14/40 [Val]


                                                                                                                       

Val    -> loss 0.5823 | dice 0.4092

Epoch 14: Train Dice 0.3409, Val Dice 0.4092





Epoch 15/40 [Train]


                                                                                                                       

Train  -> loss 0.5954 | dice 0.3682
Epoch 15/40 [Val]


                                                                                                                       

Val    -> loss 0.5711 | dice 0.4233

Epoch 15: Train Dice 0.3682, Val Dice 0.4233





Epoch 16/40 [Train]


                                                                                                                       

Train  -> loss 0.5921 | dice 0.3840
Epoch 16/40 [Val]


                                                                                                                       

Val    -> loss 0.5703 | dice 0.4169

Epoch 16: Train Dice 0.3840, Val Dice 0.4169





Epoch 17/40 [Train]


                                                                                                                       

Train  -> loss 0.5881 | dice 0.3939
Epoch 17/40 [Val]


                                                                                                                       

Val    -> loss 0.5723 | dice 0.4222

Epoch 17: Train Dice 0.3939, Val Dice 0.4222

Epoch 18/40 [Train]


                                                                                                                       

Train  -> loss 0.5850 | dice 0.3971
Epoch 18/40 [Val]


                                                                                                                       

Val    -> loss 0.5671 | dice 0.4391

Epoch 18: Train Dice 0.3971, Val Dice 0.4391





Epoch 19/40 [Train]


                                                                                                                       

Train  -> loss 0.5816 | dice 0.4093
Epoch 19/40 [Val]


                                                                                                                       

Val    -> loss 0.5624 | dice 0.4773

Epoch 19: Train Dice 0.4093, Val Dice 0.4773





Epoch 20/40 [Train]


                                                                                                                       

Train  -> loss 0.5776 | dice 0.4218
Epoch 20/40 [Val]


                                                                                                                       

Val    -> loss 0.5629 | dice 0.4597

Epoch 20: Train Dice 0.4218, Val Dice 0.4597

Epoch 21/40 [Train]


                                                                                                                       

Train  -> loss 0.5750 | dice 0.4269
Epoch 21/40 [Val]


                                                                                                                       

Val    -> loss 0.5631 | dice 0.4655

Epoch 21: Train Dice 0.4269, Val Dice 0.4655

Epoch 22/40 [Train]


                                                                                                                       

Train  -> loss 0.5741 | dice 0.4319
Epoch 22/40 [Val]


                                                                                                                       

Val    -> loss 0.5595 | dice 0.4700

Epoch 22: Train Dice 0.4319, Val Dice 0.4700





Epoch 23/40 [Train]


                                                                                                                       

Train  -> loss 0.5712 | dice 0.4435
Epoch 23/40 [Val]


                                                                                                                       

Val    -> loss 0.5502 | dice 0.4930

Epoch 23: Train Dice 0.4435, Val Dice 0.4930





Epoch 24/40 [Train]


                                                                                                                       

Train  -> loss 0.5683 | dice 0.4487
Epoch 24/40 [Val]


                                                                                                                       

Val    -> loss 0.5503 | dice 0.5169

Epoch 24: Train Dice 0.4487, Val Dice 0.5169

Epoch 25/40 [Train]


                                                                                                                       

Train  -> loss 0.5668 | dice 0.4527
Epoch 25/40 [Val]


                                                                                                                       

Val    -> loss 0.5499 | dice 0.5110

Epoch 25: Train Dice 0.4527, Val Dice 0.5110





Epoch 26/40 [Train]


                                                                                                                       

Train  -> loss 0.5644 | dice 0.4610
Epoch 26/40 [Val]


                                                                                                                       

Val    -> loss 0.5482 | dice 0.5075

Epoch 26: Train Dice 0.4610, Val Dice 0.5075





Epoch 27/40 [Train]


                                                                                                                       

Train  -> loss 0.5627 | dice 0.4677
Epoch 27/40 [Val]


                                                                                                                       

Val    -> loss 0.5490 | dice 0.5101

Epoch 27: Train Dice 0.4677, Val Dice 0.5101

Epoch 28/40 [Train]


                                                                                                                       

Train  -> loss 0.5592 | dice 0.4726
Epoch 28/40 [Val]


                                                                                                                       

Val    -> loss 0.5446 | dice 0.5415

Epoch 28: Train Dice 0.4726, Val Dice 0.5415





Epoch 29/40 [Train]


                                                                                                                       

Train  -> loss 0.5592 | dice 0.4806
Epoch 29/40 [Val]


                                                                                                                       

Val    -> loss 0.5431 | dice 0.5313

Epoch 29: Train Dice 0.4806, Val Dice 0.5313





Epoch 30/40 [Train]


                                                                                                                       

Train  -> loss 0.5553 | dice 0.4851
Epoch 30/40 [Val]


                                                                                                                       

Val    -> loss 0.5425 | dice 0.5339

Epoch 30: Train Dice 0.4851, Val Dice 0.5339





Epoch 31/40 [Train]


                                                                                                                       

Train  -> loss 0.5532 | dice 0.4889
Epoch 31/40 [Val]


                                                                                                                       

Val    -> loss 0.5445 | dice 0.5142

Epoch 31: Train Dice 0.4889, Val Dice 0.5142

Epoch 32/40 [Train]


                                                                                                                       

Train  -> loss 0.5530 | dice 0.4966
Epoch 32/40 [Val]


                                                                                                                       

Val    -> loss 0.5406 | dice 0.5384

Epoch 32: Train Dice 0.4966, Val Dice 0.5384





Epoch 33/40 [Train]


                                                                                                                       

Train  -> loss 0.5504 | dice 0.5016
Epoch 33/40 [Val]


                                                                                                                       

Val    -> loss 0.5390 | dice 0.5442

Epoch 33: Train Dice 0.5016, Val Dice 0.5442





Epoch 34/40 [Train]


                                                                                                                       

Train  -> loss 0.5494 | dice 0.5108
Epoch 34/40 [Val]


                                                                                                                       

Val    -> loss 0.5398 | dice 0.5381

Epoch 34: Train Dice 0.5108, Val Dice 0.5381

Epoch 35/40 [Train]


                                                                                                                       

Train  -> loss 0.5479 | dice 0.5093
Epoch 35/40 [Val]


                                                                                                                       

Val    -> loss 0.5362 | dice 0.5551

Epoch 35: Train Dice 0.5093, Val Dice 0.5551





Epoch 36/40 [Train]


                                                                                                                       

Train  -> loss 0.5453 | dice 0.5178
Epoch 36/40 [Val]


                                                                                                                       

Val    -> loss 0.5363 | dice 0.5489

Epoch 36: Train Dice 0.5178, Val Dice 0.5489

Epoch 37/40 [Train]


                                                                                                                       

In [None]:
import pandas as pd
import matplotlib.pyplot as plt

# --- Save training curves ---
def plot_and_save_history(train_hist, val_hist, out_dir):
    """Plot loss and Dice curves side by side, save them, and show the figure.

    Args:
        train_hist: Mapping with optional ``'loss'`` / ``'dice'`` lists.
        val_hist: Same structure for the validation split.
        out_dir: Directory that receives ``history.png``.

    Returns:
        Path of the saved image.
    """
    panels = (("loss", "Loss"), ("dice", "Dice"))

    plt.figure(figsize=(12,5))
    for position, (key, title) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        plt.plot(train_hist.get(key, []), label=f"Train {title}")
        plt.plot(val_hist.get(key, []), label=f"Val {title}")
        plt.legend()
        plt.title(title)
        plt.grid()
    plt.tight_layout()

    out_path = os.path.join(out_dir, "history.png")
    plt.savefig(out_path)  # save before show(), which may release the figure
    plt.show()

    return out_path

# --- Save final metrics as CSV table ---
def save_metrics_table(trainer, out_dir):
    """Write the final and best metrics to ``metrics_table.csv`` and display them.

    The table has one row each for the last training epoch, the last
    validation epoch, and the best validation values tracked on ``trainer``.

    Args:
        trainer: Object exposing ``train_history``, ``val_history``,
            ``best_val_loss`` and ``best_val_dice``.
        out_dir: Directory that receives the CSV file.

    Returns:
        Path of the written CSV file.
    """
    def _latest(history):
        # Most recent epoch's loss and dice.
        return {'loss': history['loss'][-1], 'dice': history['dice'][-1]}

    rows = [
        dict(Split='Train', **_latest(trainer.train_history)),
        dict(Split='Validation', **_latest(trainer.val_history)),
        dict(Split='Best Val', loss=trainer.best_val_loss, dice=trainer.best_val_dice),
    ]
    table = pd.DataFrame(rows)

    metrics_path = os.path.join(out_dir, "metrics_table.csv")
    table.to_csv(metrics_path, index=False)
    print(f"Metrics table saved to {metrics_path}")
    display(table)
    return metrics_path

# --- Call these at end of training ---
# Both calls read `trainer` and `config` from the notebook's global namespace.
# NOTE(review): `trainer` must exist at notebook scope before this cell runs —
# confirm the training cell exposes it.
plot_and_save_history(trainer.train_history, trainer.val_history, config.OUTPUT_DIR)
save_metrics_table(trainer, config.OUTPUT_DIR)


In [None]:
def visualize_sample_predictions(model, val_dataset, out_dir, device, num_samples=6):
    """Save ground-truth / prediction overlays for random validation samples.

    One PNG per sample is written to ``out_dir/sample_<i>.png``: the input
    image followed by one panel per class with the ground truth in red and
    the thresholded (0.5) prediction in green.

    Args:
        model: Segmentation model; switched to eval mode.
        val_dataset: Dataset yielding normalised ``(image, mask)`` tensors.
        out_dir: Destination directory (created if missing).
        device: Device the model lives on.
        num_samples: Maximum number of samples to draw.
    """
    model.eval()
    os.makedirs(out_dir, exist_ok=True)

    # Statistics passed to A.Normalize in the augmentation pipelines; used to
    # undo the normalisation so imshow gets RGB values in [0, 1].
    norm_mean = np.array((0.485, 0.456, 0.406), dtype=np.float32)
    norm_std = np.array((0.229, 0.224, 0.225), dtype=np.float32)

    idxs = np.random.choice(len(val_dataset), size=min(num_samples, len(val_dataset)), replace=False)
    with torch.no_grad():
        for i, idx in enumerate(idxs):
            image, mask = val_dataset[idx]
            image_t = image.unsqueeze(0).to(device)
            logits = model(image_t)
            probs = torch.sigmoid(logits)[0].cpu().numpy()
            pred_mask = (probs > 0.5).astype(np.uint8)

            # Showing the normalised tensor directly would clip and distort colours.
            display_image = image.permute(1,2,0).cpu().numpy().astype(np.float32)
            display_image = np.clip(display_image * norm_std + norm_mean, 0.0, 1.0)

            fig, axes = plt.subplots(1, 1 + config.NUM_CLASSES, figsize=(15,4))
            axes[0].imshow(display_image)
            axes[0].set_title("Image"); axes[0].axis('off')
            for c in range(config.NUM_CLASSES):
                axes[c+1].imshow(mask[c].cpu().numpy(), alpha=0.5, cmap='Reds')
                axes[c+1].imshow(pred_mask[c], alpha=0.5, cmap='Greens')
                axes[c+1].set_title(f"Class {c+1}"); axes[c+1].axis('off')
            fig.tight_layout()
            img_path = os.path.join(out_dir, f"sample_{i}.png")
            fig.savefig(img_path)
            plt.close(fig)  # release the figure explicitly inside the loop
    print(f"{len(idxs)} validation prediction overlays saved to {out_dir}")

# Usage:
# Reads `model`, `val_dataset` and `config` from the notebook's global namespace.
# NOTE(review): `model` and `val_dataset` must exist at notebook scope before
# this cell runs — confirm the training cell exposes them.
visualize_sample_predictions(model, val_dataset, os.path.join(config.OUTPUT_DIR, "viz"), config.DEVICE)
