# SegFormer Change Detection Training Notebook
This notebook trains a SegFormer model for change detection using A/B/label folders for train, val, and test. It implements SegFormer from scratch with the same aggressive dropout strategy, training approach, and evaluation metrics as the U-Net++ implementation.

## Assignment Compliance (Segmentation)
- Problem: Change detection (binary segmentation of change mask)
- Model: SegFormer (a Vision Transformer-based segmentation model, implemented from scratch)
- Epochs: Min 50 with early stopping (patience 10)
- Data: Uses the existing train / val / test folders exactly as provided (no re-splitting is performed).
- Metrics tracked: IoU, Dice, Precision, Recall, F1, Accuracy, Loss + confusion matrix (pixel-wise)
- Outputs: Metric plots, sample predictions, parameter count, saved best weights.
- Saved artifacts: best_segformer.pth, training_history_segformer.csv, test_metrics_segformer.csv, confusion_matrix_segformer.txt, training_curves_segformer.png, prediction PNGs (test_predictions_segformer/).
- Dropout: Very aggressive dropout strategy matching U-Net++ configuration

In [8]:
# Install all required packages
!pip install torch torchvision timm albumentations scikit-learn pandas tqdm matplotlib seaborn einops --quiet

In [9]:
# Imports & Setup for custom SegFormer training (from scratch implementation)
import os, random, math
import numpy as np
import pandas as pd
from PIL import Image
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns
from einops import rearrange

# Reproducibility & device selection
# One fixed seed is pushed into every RNG the notebook uses (Python, NumPy, PyTorch).
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Use the GPU when one is visible, otherwise run on the CPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if DEVICE.type == 'cuda':
    # Seed every CUDA device and ask cuDNN for deterministic kernels
    # (autotuning is switched off because it may pick different kernels per run).
    torch.cuda.manual_seed_all(SEED)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
print(f"Using device: {DEVICE}")

# Loss components (Dice + BCE) - Same as U-Net++
class DiceLoss(nn.Module):
    """Soft Dice loss computed per (sample, channel) map and averaged.

    Args:
        smooth: Constant added to numerator and denominator so that two
            empty maps give a Dice score of 1 instead of 0/0.
    """

    def __init__(self, smooth=1e-6):
        super().__init__()
        self.smooth = smooth

    def forward(self, preds, targets):
        """Return ``1 - mean Dice``.

        ``preds`` are probabilities (already passed through sigmoid) and
        ``targets`` are binary masks; both are reduced over dims 2 and 3,
        so 4-D inputs are required.
        """
        spatial_axes = (2, 3)
        probs = preds.contiguous()
        truth = targets.contiguous()
        overlap = torch.sum(probs * truth, dim=spatial_axes)
        total = torch.sum(probs, dim=spatial_axes) + torch.sum(truth, dim=spatial_axes)
        per_map_dice = (2 * overlap + self.smooth) / (total + self.smooth)
        return 1 - per_map_dice.mean()

def combined_loss(logits, targets, bce_w=0.6, dice_w=0.4):
    """Weighted sum of BCE-with-logits and soft Dice loss.

    Args:
        logits: Raw (pre-sigmoid) model outputs.
        targets: Binary ground-truth masks with the same shape as ``logits``.
        bce_w: Weight of the binary cross-entropy term.
        dice_w: Weight of the Dice term.

    Returns:
        Scalar tensor ``bce_w * BCE + dice_w * Dice``.
    """
    # BCE works on logits directly; Dice needs probabilities.
    bce_term = nn.functional.binary_cross_entropy_with_logits(logits, targets)
    dice_term = DiceLoss()(torch.sigmoid(logits), targets)
    return bce_w * bce_term + dice_w * dice_term

@torch.no_grad()
def batch_metrics(logits, targets, thresh=0.3):  # Same 0.3 threshold as U-Net++
    """Pixel-wise confusion counts and derived metrics for one batch.

    Args:
        logits: Raw (pre-sigmoid) model outputs.
        targets: Binary ground-truth masks with the same number of elements
            as ``logits``; values above 0.5 count as the positive class.
        thresh: Probability at or above which a pixel is predicted positive.

    Returns:
        dict with integer counts ``tp``, ``fp``, ``fn``, ``tn`` and float
        scores ``iou``, ``dice``, ``precision``, ``recall``, ``f1``, ``acc``.
    """
    # Count directly on the tensors' device with boolean masks. ``reshape``
    # (unlike ``view``) also accepts non-contiguous inputs, and only four
    # scalars are moved to the host instead of every pixel.
    pred_pos = (torch.sigmoid(logits) >= thresh).reshape(-1)
    true_pos = (targets > 0.5).reshape(-1)

    tp = int((pred_pos & true_pos).sum())
    fp = int((pred_pos & ~true_pos).sum())
    fn = int((~pred_pos & true_pos).sum())
    tn = int(pred_pos.numel()) - tp - fp - fn

    eps = 1e-8  # keeps the ratios finite when a denominator is zero
    iou = tp / (tp + fp + fn + eps)
    dice = (2 * tp) / (2 * tp + fp + fn + eps)
    precision = tp / (tp + fp + eps) if (tp + fp) > 0 else 0.0
    recall = tp / (tp + fn + eps) if (tp + fn) > 0 else 0.0
    f1 = 2 * precision * recall / (precision + recall + eps) if (precision + recall) > 0 else 0.0
    acc = (tp + tn) / (tp + tn + fp + fn + eps)
    return dict(tp=tp, fp=fp, fn=fn, tn=tn, iou=float(iou), dice=float(dice),
                precision=float(precision), recall=float(recall), f1=float(f1), acc=float(acc))

class EarlyStopping:
    """Stop training once the monitored loss has stopped improving.

    Args:
        patience: Number of consecutive non-improving calls tolerated.
        min_delta: The loss must drop by more than this to count as an improvement.
        restore_best: If true, keep a CPU copy of the best ``state_dict`` and
            load it back into the model when stopping.
        min_epochs: Never signal a stop before this many epochs have run.
    """

    def __init__(self, patience=10, min_delta=1e-4, restore_best=True, min_epochs=10):  # Same patience as U-Net++
        self.patience = patience
        self.min_delta = min_delta
        self.restore_best = restore_best
        self.min_epochs = min_epochs
        self.best_loss = None
        self.best_state = None
        self.counter = 0

    def __call__(self, epoch, current_loss, model):
        """Record this epoch's loss; return True when training should stop.

        ``epoch`` is zero-based. The non-improvement counter also advances
        during the first ``min_epochs`` epochs; only the stop signal is held back.
        """
        improved = self.best_loss is None or (self.best_loss - current_loss) > self.min_delta
        if improved:
            self.best_loss = current_loss
            self.counter = 0
            if self.restore_best:
                # Detached CPU clones so later optimizer steps cannot alter the snapshot.
                self.best_state = {
                    name: tensor.detach().cpu().clone()
                    for name, tensor in model.state_dict().items()
                }
        else:
            self.counter += 1

        past_warmup = epoch + 1 >= self.min_epochs
        if not (past_warmup and self.counter >= self.patience):
            return False

        if self.restore_best and self.best_state is not None:
            model.load_state_dict(self.best_state)
        return True

Using device: cuda


In [10]:
# Dataset (same as U-Net++ implementation)
# Root folder holding the train/, val/ and test/ sub-folders that ChangeDataset reads.
DATA_ROOT = '/kaggle/input/earthquakedatasetnew/earthquakeDataset'  # Adjust to local path as needed
# Target size passed to transforms.Resize for both images and masks.
IMG_SIZE = (256, 256)
TRAIN_BATCH = 6  # Same batch sizes as U-Net++
VAL_BATCH = 2
TEST_BATCH = 1

# Image pipeline: resize -> tensor -> per-channel normalisation.
# NOTE(review): the mean/std values look like the usual ImageNet statistics - confirm
# they are appropriate for this imagery.
transform_img = transforms.Compose([
    transforms.Resize(IMG_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225])
])

# Mask pipeline: resize -> tensor, no normalisation.
# NOTE(review): Resize is used with its default interpolation, so resized masks may
# hold intermediate values; ChangeDataset.__getitem__ re-binarises them with (m>0.5).
transform_mask = transforms.Compose([
    transforms.Resize(IMG_SIZE),
    transforms.ToTensor()
])

class ChangeDataset(Dataset):
    """Paired A/B images plus a binary change mask, read from split folders.

    Expected layout under ``root``::

        train/A_train_aug, train/B_train_aug, train/label_train_aug
        val/A_val,         val/B_val,         val/label_val
        test/A_test,       test/B_test,       test/label_test

    Any ``split`` other than ``'train'`` or ``'val'`` selects the test folders.
    The sample list is the sorted ``.png`` file names found in the A folder;
    the same name is looked up in the B and label folders.
    """

    def __init__(self, root, split='train'):
        if split == 'train':
            folder, tag = 'train', 'train_aug'
        elif split == 'val':
            folder, tag = 'val', 'val'
        else:
            folder, tag = 'test', 'test'

        a_dir = os.path.join(root, folder, f'A_{tag}')
        b_dir = os.path.join(root, folder, f'B_{tag}')
        m_dir = os.path.join(root, folder, f'label_{tag}')

        self.a_files = sorted(name for name in os.listdir(a_dir) if name.endswith('.png'))
        self.a_dir, self.b_dir, self.m_dir = a_dir, b_dir, m_dir

    def __len__(self):
        return len(self.a_files)

    def __getitem__(self, idx):
        """Return ``(x, m)``: the A and B images stacked on dim 0, and the binarised mask."""
        fname = self.a_files[idx]
        img_a = transform_img(Image.open(os.path.join(self.a_dir, fname)).convert('RGB'))
        img_b = transform_img(Image.open(os.path.join(self.b_dir, fname)).convert('RGB'))
        mask = transform_mask(Image.open(os.path.join(self.m_dir, fname)).convert('L'))
        # Threshold so the mask holds only 0.0 / 1.0 after resizing.
        mask = (mask > 0.5).float()
        stacked = torch.cat([img_a, img_b], dim=0)  # 6 channels (A+B concatenated)
        return stacked, mask

# One dataset per split, all read from the folders under DATA_ROOT.
train_ds = ChangeDataset(DATA_ROOT,'train')
val_ds = ChangeDataset(DATA_ROOT,'val')
test_ds = ChangeDataset(DATA_ROOT,'test')

# Only the training loader shuffles; val/test iterate in the datasets' sorted file order.
train_loader = DataLoader(train_ds, batch_size=TRAIN_BATCH, shuffle=True, num_workers=2, pin_memory=True)
val_loader = DataLoader(val_ds, batch_size=VAL_BATCH, shuffle=False, num_workers=1)
test_loader = DataLoader(test_ds, batch_size=TEST_BATCH, shuffle=False, num_workers=1)

# Report the number of samples found in each split.
print(f"Train {len(train_ds)} | Val {len(val_ds)} | Test {len(test_ds)}")

Train 2268 | Val 189 | Test 189


In [11]:
# Custom SegFormer implementation from scratch with aggressive dropout
class PatchEmbed(nn.Module):
    """Overlapping patch embedding: strided conv -> token sequence -> LayerNorm -> dropout.

    ``img_size`` is accepted for interface compatibility but is not used;
    the output grid size is taken from the convolution output at run time.
    """

    def __init__(self, img_size=256, patch_size=7, stride=4, in_chans=6, embed_dim=64, dropout_rate=0.0):
        super().__init__()
        # padding = patch_size // 2 makes neighbouring patches overlap.
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size,
                              stride=stride, padding=patch_size // 2)
        self.norm = nn.LayerNorm(embed_dim)
        if dropout_rate > 0:
            self.dropout = nn.Dropout(dropout_rate)
        else:
            self.dropout = nn.Identity()

    def forward(self, x):
        """Map ``(B, in_chans, H_in, W_in)`` to ``(tokens, H, W)`` with tokens of shape ``(B, H*W, embed_dim)``."""
        feat = self.proj(x)  # (B, C, H, W)
        H, W = feat.shape[-2:]
        tokens = feat.flatten(2).transpose(1, 2)  # (B, H*W, C)
        tokens = self.dropout(self.norm(tokens))
        return tokens, H, W

class EfficientSelfAttention(nn.Module):
    """Multi-head self-attention with optional spatial reduction of keys/values.

    When ``sr_ratio > 1`` the token grid is shrunk by a strided convolution
    (kernel = stride = ``sr_ratio``) and layer-normalised before keys and
    values are projected, so the attention matrix has fewer columns.
    Queries always come from the full-resolution tokens.
    """

    def __init__(self, dim, num_heads=8, qkv_bias=False, sr_ratio=1, dropout_rate=0.0, attention_dropout=0.0):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.sr_ratio = sr_ratio
        # Standard 1/sqrt(head_dim) scaling of the dot products.
        self.scale = (dim // num_heads) ** -0.5

        self.q = nn.Linear(dim, dim, bias=qkv_bias)
        self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)

        if sr_ratio > 1:
            self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)
            self.norm = nn.LayerNorm(dim)

        self.attn_drop = nn.Dropout(attention_dropout)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(dropout_rate)

    def forward(self, x, H, W):
        """Attend over tokens ``x`` of shape ``(B, N, C)`` laid out on an ``H x W`` grid (``N == H*W``)."""
        B, N, C = x.shape
        heads = self.num_heads
        head_dim = C // heads

        q = self.q(x).reshape(B, N, heads, head_dim).permute(0, 2, 1, 3)

        # Keys/values come from the (optionally) spatially reduced sequence.
        kv_source = x
        if self.sr_ratio > 1:
            grid = x.permute(0, 2, 1).reshape(B, C, H, W)
            reduced = self.sr(grid).reshape(B, C, -1).permute(0, 2, 1)
            kv_source = self.norm(reduced)
        kv = self.kv(kv_source).reshape(B, -1, 2, heads, head_dim).permute(2, 0, 3, 1, 4)
        k, v = kv[0], kv[1]

        scores = (q @ k.transpose(-2, -1)) * self.scale
        weights = self.attn_drop(scores.softmax(dim=-1))

        merged = (weights @ v).transpose(1, 2).reshape(B, N, C)
        return self.proj_drop(self.proj(merged))

class MixFFN(nn.Module):
    """Feed-forward block with a 3x3 depthwise convolution between the two linear layers.

    Order of operations: ``fc1 -> depthwise conv (on the H x W grid) -> GELU
    -> dropout -> fc2 -> dropout``. The same dropout module is applied twice.
    """

    def __init__(self, dim, hidden_dim, dropout_rate=0.0):
        super().__init__()
        self.fc1 = nn.Linear(dim, hidden_dim)
        # groups == channels makes the convolution depthwise.
        self.dwconv = nn.Conv2d(hidden_dim, hidden_dim, 3, 1, 1, groups=hidden_dim)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_dim, dim)
        self.drop = nn.Dropout(dropout_rate)

    def forward(self, x, H, W):
        """Transform tokens ``(B, H*W, dim)``; the output has the same shape."""
        hidden = self.fc1(x)
        B, _, C = hidden.shape
        # Tokens -> image grid for the convolution, then back to tokens.
        hidden = self.dwconv(hidden.transpose(1, 2).reshape(B, C, H, W))
        hidden = hidden.flatten(2).transpose(1, 2)
        hidden = self.drop(self.act(hidden))
        return self.drop(self.fc2(hidden))

class TransformerBlock(nn.Module):
    """Pre-norm transformer block: efficient self-attention followed by Mix-FFN.

    Both sub-layers are wrapped in a residual connection. ``layer_dropout``
    is applied to each residual branch with an element-wise ``nn.Dropout``
    (``nn.Identity`` when the rate is 0).
    """

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, sr_ratio=1,
                 dropout_rate=0.0, attention_dropout=0.0, layer_dropout=0.0):
        super().__init__()
        use_layer_drop = layer_dropout > 0

        # Attention branch.
        self.norm1 = nn.LayerNorm(dim)
        self.attn = EfficientSelfAttention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            sr_ratio=sr_ratio,
            dropout_rate=dropout_rate,
            attention_dropout=attention_dropout,
        )
        self.layer_drop1 = nn.Dropout(layer_dropout) if use_layer_drop else nn.Identity()

        # Feed-forward branch; hidden width is dim * mlp_ratio, truncated to int.
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = MixFFN(dim, int(dim * mlp_ratio), dropout_rate=dropout_rate)
        self.layer_drop2 = nn.Dropout(layer_dropout) if use_layer_drop else nn.Identity()

    def forward(self, x, H, W):
        """Apply both residual sub-layers to tokens ``x`` laid out on an ``H x W`` grid."""
        attn_branch = self.attn(self.norm1(x), H, W)
        x = x + self.layer_drop1(attn_branch)
        ffn_branch = self.mlp(self.norm2(x), H, W)
        return x + self.layer_drop2(ffn_branch)

class MixTransformerStage(nn.Module):
    """One encoder stage: patch embedding, ``depth`` transformer blocks, final LayerNorm.

    ``dropout_rates`` is a dict; the keys ``'patch_embed'``, ``'transformer'``,
    ``'attention'`` and ``'layer'`` are read, each defaulting to 0.0 when absent.
    """

    def __init__(self, img_size, patch_size, stride, in_chans, embed_dim, depth, num_heads,
                 mlp_ratio, qkv_bias, sr_ratio, dropout_rates):
        super().__init__()

        embed_drop = dropout_rates.get('patch_embed', 0.0)
        block_drop = dropout_rates.get('transformer', 0.0)
        attn_drop = dropout_rates.get('attention', 0.0)
        layer_drop = dropout_rates.get('layer', 0.0)

        self.patch_embed = PatchEmbed(img_size, patch_size, stride, in_chans, embed_dim, embed_drop)

        blocks = []
        for _ in range(depth):
            blocks.append(
                TransformerBlock(embed_dim, num_heads, mlp_ratio, qkv_bias, sr_ratio,
                                 block_drop, attn_drop, layer_drop)
            )
        self.blocks = nn.ModuleList(blocks)

        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x):
        """Return the stage's feature map in ``(B, embed_dim, H, W)`` layout."""
        tokens, H, W = self.patch_embed(x)
        for block in self.blocks:
            tokens = block(tokens, H, W)
        tokens = self.norm(tokens)
        # Token sequence -> channels-first grid for the next stage / decoder.
        channels = tokens.size(-1)
        return tokens.reshape(-1, H, W, channels).permute(0, 3, 1, 2)

class MLPDecoder(nn.Module):
    """Lightweight decoder head that fuses multi-scale features with 1x1 convolutions.

    Args:
        in_channels: Channel count of each input feature map, in order.
        num_classes: Number of output channels.
        dropout_rate: ``Dropout2d`` rate applied before the final 1x1 conv.
    """

    def __init__(self, in_channels, num_classes=1, dropout_rate=0.0):
        super().__init__()
        fused_width = 256
        layers = [
            nn.Conv2d(sum(in_channels), fused_width, 1, bias=False),
            nn.BatchNorm2d(fused_width),
            nn.ReLU(inplace=True),
            nn.Dropout2d(dropout_rate),
            nn.Conv2d(fused_width, num_classes, 1),
        ]
        self.decoder = nn.Sequential(*layers)

    def forward(self, features):
        """Resize every map to the first map's spatial size, concatenate on channels, decode."""
        target_size = features[0].shape[2:]
        aligned = [
            feat if feat.shape[2:] == target_size
            else F.interpolate(feat, size=target_size, mode='bilinear', align_corners=False)
            for feat in features
        ]
        return self.decoder(torch.cat(aligned, dim=1))

class SegFormer(nn.Module):
    """SegFormer model from scratch with aggressive dropout.

    A four-stage Mix Transformer encoder (overall strides 4, 8, 16, 32)
    followed by a lightweight MLP decoder. ``forward`` returns raw logits
    with ``num_classes`` channels at the spatial resolution of the input.

    Args:
        img_size: Nominal input size; forwarded to the stages for bookkeeping only.
        in_chans: Channels of the input tensor (6 = images A and B concatenated).
        num_classes: Number of output channels.
        embed_dims, depths, num_heads, mlp_ratios, sr_ratios: Per-stage
            settings, one entry per stage. The list defaults are only read,
            never mutated.
        qkv_bias: Whether the q/kv projections carry a bias term.
        dropout_rates: Dict with per-stage lists under ``'patch_embed'``,
            ``'transformer'``, ``'attention'`` and ``'layer'`` plus a scalar
            ``'decoder'`` rate. ``None`` selects the built-in schedule below.
    """

    def __init__(self, img_size=256, in_chans=6, num_classes=1,
                 embed_dims=[64, 128, 256, 512], depths=[3, 4, 6, 3], num_heads=[1, 2, 4, 8],
                 mlp_ratios=[4, 4, 4, 4], qkv_bias=True, sr_ratios=[8, 4, 2, 1], dropout_rates=None):
        super().__init__()

        # Same aggressive dropout as U-Net++
        if dropout_rates is None:
            dropout_rates = {
                'patch_embed': [0.0, 0.15, 0.25, 0.4],     # Increasing dropout
                'transformer': [0.0, 0.2, 0.35, 0.5],     # Very high dropout in deeper stages
                'attention': [0.0, 0.15, 0.3, 0.45],       # Attention dropout
                'layer': [0.0, 0.1, 0.2, 0.35],           # Residual-branch dropout
                'decoder': 0.4                             # High dropout in decoder
            }

        # Multi-stage encoder
        self.stages = nn.ModuleList()
        patch_sizes = [7, 3, 3, 3]
        strides = [4, 2, 2, 2]

        for i in range(len(embed_dims)):
            stage_dropout = {
                'patch_embed': dropout_rates['patch_embed'][i],
                'transformer': dropout_rates['transformer'][i],
                'attention': dropout_rates['attention'][i],
                'layer': dropout_rates['layer'][i]
            }

            stage = MixTransformerStage(
                # Resolution entering stage i: the full image for stage 0,
                # then img_size / 4, / 8, / 16 for the later stages.
                img_size=img_size // (4 * 2**(i - 1)) if i > 0 else img_size,
                patch_size=patch_sizes[i],
                stride=strides[i],
                in_chans=in_chans if i == 0 else embed_dims[i-1],
                embed_dim=embed_dims[i],
                depth=depths[i],
                num_heads=num_heads[i],
                mlp_ratio=mlp_ratios[i],
                qkv_bias=qkv_bias,
                sr_ratio=sr_ratios[i],
                dropout_rates=stage_dropout
            )
            self.stages.append(stage)

        # Lightweight MLP decoder with dropout
        self.decoder = MLPDecoder(embed_dims, num_classes, dropout_rates['decoder'])

    def forward(self, x):
        """Return logits of shape ``(B, num_classes, H, W)`` for input ``(B, in_chans, H, W)``."""
        # Remember the input resolution so the logits can be resized to match it.
        input_size = x.shape[2:]

        # Multi-scale feature extraction
        features = []
        for stage in self.stages:
            x = stage(x)
            features.append(x)

        # Decode features (output is at the first stage's resolution)
        x = self.decoder(features)

        # Upsample to the exact input resolution. A fixed scale_factor=4 is only
        # right when H and W are multiples of 4; resizing to the stored size
        # gives the same result in that case and stays aligned with the mask otherwise.
        x = F.interpolate(x, size=input_size, mode='bilinear', align_corners=False)

        return x

# Very aggressive dropout configuration (matching U-Net++)
# Per-stage lists have one entry per encoder stage (shallow -> deep).
very_aggressive_dropout = {
    'patch_embed': [0.0, 0.15, 0.25, 0.4],     # Increasing dropout in deeper stages
    'transformer': [0.0, 0.2, 0.4, 0.6],       # Very high dropout
    'attention': [0.0, 0.15, 0.3, 0.5],        # High attention dropout
    'layer': [0.0, 0.1, 0.25, 0.4],           # High residual-branch dropout
    'decoder': 0.4                             # High decoder dropout
}

# Create SegFormer model with aggressive dropout
model = SegFormer(
    img_size=256,
    in_chans=6,
    num_classes=1,
    embed_dims=[64, 128, 256, 512],  # Similar capacity to U-Net++
    depths=[3, 4, 6, 3],             # Reasonable depth
    num_heads=[1, 2, 4, 8],
    dropout_rates=very_aggressive_dropout
).to(DEVICE)

print(f"Model params: {sum(p.numel() for p in model.parameters()):,}")

# Same optimizer settings as U-Net++
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=5e-4)
# The deprecated `verbose` argument is intentionally not passed: the training
# loop already prints the current learning rate after every epoch.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5, min_lr=1e-6)
# NOTE(review): the notebook header promises "Min 50" epochs, but early stopping
# may trigger from epoch 10 here - confirm which value is intended.
early_stop = EarlyStopping(patience=10, min_delta=1e-4, min_epochs=10)  # Same as U-Net++

print("SegFormer model created with aggressive dropout strategy!")

Model params: 20,522,817
SegFormer model created with aggressive dropout strategy!


In [None]:
# Training Loop - Same configuration as U-Net++
# Each epoch: one optimisation pass over train_loader, one evaluation pass over
# val_loader, then LR scheduling, best-checkpoint saving and early stopping -
# all driven by the validation loss.
EPOCHS = 200
history = []
best_val_loss = float('inf')  # lowest validation loss seen so far

for epoch in range(EPOCHS):
    model.train()  # Important: enables dropout
    train_loss = 0.0
    for xb, yb in tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS} Train", leave=False):
        xb, yb = xb.to(DEVICE), yb.to(DEVICE)
        optimizer.zero_grad()
        logits = model(xb)
        loss = combined_loss(logits, yb)
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), 1.0)  # Gradient clipping
        optimizer.step()
        # Weight by batch size so the epoch value is a per-sample mean.
        train_loss += loss.item() * xb.size(0)
    train_loss /= len(train_loader.dataset)

    # Validation
    model.eval()  # Important: disables dropout
    val_loss = 0.0
    agg = dict(tp=0,fp=0,fn=0,tn=0)
    with torch.no_grad():
        for xb, yb in tqdm(val_loader, desc=f"Epoch {epoch+1}/{EPOCHS} Val", leave=False):
            xb, yb = xb.to(DEVICE), yb.to(DEVICE)
            logits = model(xb)
            loss = combined_loss(logits, yb)
            val_loss += loss.item() * xb.size(0)
            mets = batch_metrics(logits, yb, thresh=0.3)  # Same 0.3 threshold
            # Sum raw pixel counts; metrics are derived once from the totals.
            for k in agg: agg[k] += mets[k]

    val_loss /= len(val_loader.dataset)
    eps=1e-8
    tp,fp,fn,tn = agg['tp'],agg['fp'],agg['fn'],agg['tn']
    iou = tp / (tp+fp+fn+eps)
    dice = (2*tp)/(2*tp+fp+fn+eps)
    precision = tp/(tp+fp+eps) if (tp+fp)>0 else 0
    recall = tp/(tp+fn+eps) if (tp+fn)>0 else 0
    f1 = 2*precision*recall/(precision+recall+eps) if (precision+recall)>0 else 0
    acc = (tp+tn)/(tp+tn+fp+fn+eps)
    history.append(dict(epoch=epoch+1, train_loss=train_loss, val_loss=val_loss, IoU=iou, Dice=dice, Precision=precision, Recall=recall, F1=f1, Accuracy=acc))
    # Persist the history every epoch so an interrupted run keeps its records.
    pd.DataFrame(history).to_csv('training_history_segformer.csv', index=False)

    scheduler.step(val_loss)
    print(f"Epoch {epoch+1}: TL {train_loss:.4f} VL {val_loss:.4f} IoU {iou:.4f} Dice {dice:.4f} F1 {f1:.4f} LR {optimizer.param_groups[0]['lr']:.2e}")

    # Save best model
    # The first epoch always writes a checkpoint so the file exists; after that
    # only a validation loss at least as low as the running best overwrites it.
    # A NaN loss never counts as an improvement, so it cannot block later saves.
    improved = val_loss <= best_val_loss
    if epoch == 0 or improved:
        torch.save(model.state_dict(), 'best_segformer.pth')
    if improved:
        best_val_loss = val_loss

    if early_stop(epoch, val_loss, model):
        print(f"Early stopping at epoch {epoch+1}")
        break

# Save training history
# Final write (also covers the case where the loop body never ran).
pd.DataFrame(history).to_csv('training_history_segformer.csv', index=False)
print('SegFormer training complete.')

                                                                    

Epoch 1: TL 0.5719 VL 0.4734 IoU 0.2640 Dice 0.4178 F1 0.4178 LR 1.00e-04


Epoch 2/200 Train:  24%|██▍       | 90/378 [00:50<02:42,  1.77it/s]

In [None]:
# Test evaluation - Same as U-Net++
# Rebuild the architecture with the same hyper-parameters used for training,
# then load the checkpoint written by the training loop.
model = SegFormer(
    img_size=256, 
    in_chans=6, 
    num_classes=1,
    embed_dims=[64, 128, 256, 512],
    depths=[3, 4, 6, 3],
    num_heads=[1, 2, 4, 8],
    dropout_rates=very_aggressive_dropout
).to(DEVICE)

# Load the trained weights
model.load_state_dict(torch.load('best_segformer.pth', map_location=DEVICE))
model.eval()  # IMPORTANT: This disables dropout for inference

# Pixel-level confusion counts summed over the whole test set.
agg = dict(tp=0,fp=0,fn=0,tn=0)
# NOTE(review): every thresholded prediction is kept in memory, although only
# the first 10 are written to disk below.
all_preds = []

with torch.no_grad():
    for xb, yb in tqdm(test_loader, desc="Test", leave=False):
        xb, yb = xb.to(DEVICE), yb.to(DEVICE)
        logits = model(xb)
        mets = batch_metrics(logits, yb, thresh=0.3)  # Same 0.3 threshold
        for k in agg: agg[k] += mets[k]
        probs = torch.sigmoid(logits)
        preds = (probs>=0.3).float().cpu()  # Same 0.3 threshold
        all_preds.append(preds)

all_preds = torch.cat(all_preds, dim=0)
# Derive the metrics once from the summed counts; eps keeps the ratios finite
# when a denominator is zero.
eps=1e-8
tp,fp,fn,tn = agg['tp'],agg['fp'],agg['fn'],agg['tn']
iou = tp/(tp+fp+fn+eps)
dice = (2*tp)/(2*tp+fp+fn+eps)
precision = tp/(tp+fp+eps) if (tp+fp)>0 else 0
recall = tp/(tp+fn+eps) if (tp+fn)>0 else 0
f1 = 2*precision*recall/(precision+recall+eps) if (precision+recall)>0 else 0
acc = (tp+tn)/(tp+tn+fp+fn+eps)

# Confusion matrix layout: rows = actual class (0, 1), columns = predicted class (0, 1).
cm = np.array([[tn, fp],[fn, tp]])
metrics = dict(IoU=iou, Dice=dice, Precision=precision, Recall=recall, F1=f1, Accuracy=acc, TP=tp, FP=fp, FN=fn, TN=tn)

print('\nTest Metrics (SegFormer with Aggressive Dropout):')
print(f'IoU: {iou:.4f}')
print(f'Dice: {dice:.4f}')
print(f'Precision: {precision:.4f}')
print(f'Recall: {recall:.4f}')
print(f'F1: {f1:.4f}')
print(f'Accuracy: {acc:.4f}')

# Save results
pd.DataFrame([metrics]).to_csv('test_metrics_segformer.csv', index=False)
np.savetxt('confusion_matrix_segformer.txt', cm, fmt='%d')

# Save first 10 prediction masks
# Masks hold 0/1 values, so scaling by 255 gives black/white 8-bit PNGs.
os.makedirs('test_predictions_segformer', exist_ok=True)
for i in range(min(10, all_preds.shape[0])):
    img = (all_preds[i,0].numpy()*255).astype('uint8')
    Image.fromarray(img).save(f'test_predictions_segformer/pred_{i}.png')
print('Saved SegFormer prediction samples.')
print("SegFormer testing complete with 0.3 threshold!")

In [None]:
# Visualization - Same as U-Net++
def _plot_history(ax, frame, curves, title, ylabel, show_legend=True):
    """Plot history columns against the 'epoch' column on a single axis.

    ``curves`` is an iterable of ``(column, label, color)`` triples.
    """
    for column, label, color in curves:
        ax.plot(frame['epoch'], frame[column], label=label, color=color)
    ax.set_title(title)
    ax.set_xlabel('Epoch')
    ax.set_ylabel(ylabel)
    if show_legend:
        ax.legend()
    ax.grid(True, alpha=0.3)


hist_df = pd.read_csv('training_history_segformer.csv')
print('SegFormer Training History:')
print(hist_df.head())

# Create training plots
fig, ((ax1, ax2, ax3),(ax4, ax5, ax6)) = plt.subplots(2,3, figsize=(16,8))

# Plot 1: Loss curves
_plot_history(ax1, hist_df,
              [('train_loss', 'Train Loss', 'blue'), ('val_loss', 'Val Loss', 'red')],
              'SegFormer Loss Curves', 'Loss')

# Plot 2: IoU
_plot_history(ax2, hist_df, [('IoU', 'IoU', 'green')],
              'SegFormer Validation IoU', 'IoU', show_legend=False)

# Plot 3: Dice
_plot_history(ax3, hist_df, [('Dice', 'Dice', 'orange')],
              'SegFormer Validation Dice', 'Dice', show_legend=False)

# Plot 4: Precision & Recall
_plot_history(ax4, hist_df,
              [('Precision', 'Precision', 'purple'), ('Recall', 'Recall', 'brown')],
              'SegFormer Precision & Recall', 'Score')

# Plot 5: F1 & Accuracy
_plot_history(ax5, hist_df,
              [('F1', 'F1', 'red'), ('Accuracy', 'Accuracy', 'blue')],
              'SegFormer F1 & Accuracy', 'Score')

# Plot 6: Summary statistics
ax6.axis('off')
# Defaults keep the final report printable when the history has no rows;
# previously these names were only bound inside the branch below.
best_iou = best_dice = best_f1 = float('nan')
if len(hist_df) > 0:
    best_epoch = hist_df.loc[hist_df['val_loss'].idxmin(), 'epoch']
    best_val_loss = hist_df['val_loss'].min()
    best_iou = hist_df['IoU'].max()
    best_dice = hist_df['Dice'].max()
    best_f1 = hist_df['F1'].max()
    
    summary_text = f"""
    SEGFORMER TRAINING SUMMARY
    =========================
    Total Epochs: {len(hist_df)}
    Best Epoch: {best_epoch}
    
    Best Metrics:
    Val Loss: {best_val_loss:.4f}
    IoU: {best_iou:.4f}
    Dice: {best_dice:.4f}
    F1: {best_f1:.4f}
    
    Final Train/Val Gap:
    {hist_df['val_loss'].iloc[-1] - hist_df['train_loss'].iloc[-1]:.4f}
    
    Model: SegFormer w/ Aggressive Dropout
    Threshold: 0.3
    """
    ax6.text(0.1, 0.5, summary_text, fontsize=9, fontfamily='monospace',
             verticalalignment='center', bbox=dict(boxstyle="round,pad=0.3", facecolor="lightblue"))

plt.tight_layout()
plt.savefig('training_curves_segformer.png', dpi=150, bbox_inches='tight')
plt.show()

print("\n" + "="*50)
print("SEGFORMER IMPLEMENTATION COMPLETE!")
print("="*50)
print(f"Model Parameters: {sum(p.numel() for p in model.parameters()):,}")
print(f"Training Epochs: {len(hist_df)}")
print(f"Best Validation IoU: {best_iou:.4f}")
print(f"Best Validation Dice: {best_dice:.4f}")
print(f"Best Validation F1: {best_f1:.4f}")
print("\nFiles generated:")
print("- best_segformer.pth")
print("- training_history_segformer.csv")
print("- test_metrics_segformer.csv")
print("- confusion_matrix_segformer.txt")
print("- training_curves_segformer.png")
print("- test_predictions_segformer/ (folder with predictions)")
print("\nSegFormer training and evaluation completed successfully!")