# U-Net++ Change Detection Training Notebook
This notebook trains a U-Net++ model for change detection using A/B/label folders for train, val, and test. It includes installation, data loading, training with early stopping, saving metrics, and displaying results.

## Assignment Compliance (Segmentation)
- Problem: Change detection (binary segmentation of change mask)
- Model: U-Net++ (a recent, deeper variant of the base U-Net)
- Epochs: Min 50 with early stopping (patience 10)
- Data: Uses the existing train / val / test folders exactly as provided (no re-splitting is performed).
- Metrics tracked: IoU, Dice, Precision, Recall, F1, Accuracy, Loss + confusion matrix (pixel-wise)
- Outputs: Metric plots, sample predictions, parameter count, GFLOPs, saved best weights.
- Saved artifacts: best_model.pth, training_history.csv, test_metrics.csv, confusion_matrix.txt, prediction PNGs.


In [1]:
# Install all required packages.
# `--quiet` reduces pip's console output; the leading `!` runs the line as a
# shell command from the notebook rather than as Python.
!pip install segmentation-models-pytorch torch torchvision albumentations scikit-learn pandas tqdm thop torchinfo matplotlib seaborn --quiet

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m154.8/154.8 kB[0m [31m4.0 MB/s[0m eta [36m0:00:00[0m00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m363.4/363.4 MB[0m [31m4.4 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.8/13.8 MB[0m [31m83.2 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m24.6/24.6 MB[0m [31m64.8 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m883.7/883.7 kB[0m [31m37.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m664.8/664.8 MB[0m [31m2.5 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m211.5/211.5 MB[0m [31m7.7 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m 

In [2]:
# Import Required Libraries
import os
import numpy as np
import pandas as pd
from PIL import Image
from tqdm import tqdm
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import segmentation_models_pytorch as smp
from sklearn.metrics import f1_score, accuracy_score, jaccard_score

In [3]:
# Data Loading and Dataset Definition
# Root of the dataset; change this to your own dataset path.
DATA_ROOT = '/kaggle/input/earthquakedataset/earthquakeDataset'

# One directory per split directly under the root.
TRAIN_DIR, VAL_DIR, TEST_DIR = (
    os.path.join(DATA_ROOT, split_name)
    for split_name in ('train', 'val', 'test')
)

# Paths to the A, B and label sub-folders of each split.
A_TRAIN, B_TRAIN, LABEL_TRAIN = (
    os.path.join(TRAIN_DIR, folder)
    for folder in ('A_train_aug', 'B_train_aug', 'label_train_aug')
)
A_VAL, B_VAL, LABEL_VAL = (
    os.path.join(VAL_DIR, folder)
    for folder in ('A_val', 'B_val', 'label_val')
)
A_TEST, B_TEST, LABEL_TEST = (
    os.path.join(TEST_DIR, folder)
    for folder in ('A_test', 'B_test', 'label_test')
)

# Image size, batch size and early-stopping settings.
IMG_SIZE = (256, 256)
BATCH_SIZE, MIN_EPOCHS, PATIENCE = 16, 100, 10

class ChangeDataset(Dataset):
    """Change-detection dataset backed by three folders of PNG files.

    Sample ``idx`` pairs the ``idx``-th file (in sorted path order) of
    ``a_dir``, ``b_dir`` and ``label_dir``. The two RGB images are scaled to
    [0, 1], normalised with the ImageNet mean/std and stacked along the
    channel axis; the label is scaled to [0, 1] and binarised at 0.5.

    Args:
        a_dir: Folder holding the "A" ``.png`` images.
        b_dir: Folder holding the "B" ``.png`` images.
        label_dir: Folder holding the ``.png`` change masks.

    Raises:
        AssertionError: If the three folders do not contain the same number
            of ``.png`` files.
    """

    def __init__(self, a_dir, b_dir, label_dir):
        self.a_files = self._sorted_pngs(a_dir)
        self.b_files = self._sorted_pngs(b_dir)
        self.label_files = self._sorted_pngs(label_dir)
        assert len(self.a_files) == len(self.b_files) == len(self.label_files), (
            f'Folder sizes differ: A={len(self.a_files)}, '
            f'B={len(self.b_files)}, label={len(self.label_files)}'
        )

        # ImageNet normalization constants.
        # Stored as float32 on purpose: float64 constants would promote the
        # float32 images to float64 during normalisation, and the resulting
        # double tensors do not match the model's float32 weights.
        self.imagenet_mean = np.array([0.485, 0.456, 0.406], dtype='float32')
        self.imagenet_std = np.array([0.229, 0.224, 0.225], dtype='float32')

    @staticmethod
    def _sorted_pngs(folder):
        """Return the sorted full paths of the ``.png`` files in ``folder``."""
        return sorted(
            os.path.join(folder, name)
            for name in os.listdir(folder)
            if name.endswith('.png')
        )

    def _load_rgb(self, path):
        """Load ``path`` as an ImageNet-normalised float32 (H, W, 3) array."""
        with Image.open(path) as img:
            pixels = np.array(img.convert('RGB')).astype('float32') / 255.0
        return (pixels - self.imagenet_mean) / self.imagenet_std

    def __len__(self):
        """Return the number of (A, B, label) samples."""
        return len(self.a_files)

    def __getitem__(self, idx):
        """Return ``(x, y)`` for sample ``idx`` as float32 tensors.

        ``x`` is the A and B images concatenated on the channel axis and
        moved to channel-first layout (6, H, W); ``y`` is the binary mask
        with a leading channel axis (1, H, W).
        """
        a_img = self._load_rgb(self.a_files[idx])
        b_img = self._load_rgb(self.b_files[idx])

        with Image.open(self.label_files[idx]) as img:
            label = np.array(img.convert('L')).astype('float32') / 255.0
        # Binarise so the mask contains only 0.0 and 1.0.
        label = (label > 0.5).astype('float32')

        x = np.concatenate([a_img, b_img], axis=2)
        x = np.transpose(x, (2, 0, 1))
        y = label[np.newaxis, ...]
        return torch.tensor(x), torch.tensor(y)


# One dataset per split, each built from its (A, B, label) folder triple.
train_ds = ChangeDataset(a_dir=A_TRAIN, b_dir=B_TRAIN, label_dir=LABEL_TRAIN)
val_ds = ChangeDataset(a_dir=A_VAL, b_dir=B_VAL, label_dir=LABEL_VAL)
test_ds = ChangeDataset(a_dir=A_TEST, b_dir=B_TEST, label_dir=LABEL_TEST)

# Loader settings shared by all splits; only the training loader shuffles.
_loader_opts = {'batch_size': BATCH_SIZE, 'num_workers': 2}
train_loader = DataLoader(train_ds, shuffle=True, **_loader_opts)
val_loader = DataLoader(val_ds, shuffle=False, **_loader_opts)
test_loader = DataLoader(test_ds, shuffle=False, **_loader_opts)

In [4]:
# Model Setup and Training Loop with extended metrics
# U-Net++ with a 'resnet34' encoder: 6 input channels (the stacked A and B
# images) and a single output channel with no final activation.
model = smp.UnetPlusPlus(encoder_name='resnet34', in_channels=6, classes=1, activation=None)
# Move the model to the GPU before the optimizer collects its parameters.
model = model.cuda()

# Adam with a fixed learning rate; the loss takes raw logits, matching
# activation=None above.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_fn = nn.BCEWithLogitsLoss()

from sklearn.metrics import precision_score, recall_score, f1_score

def compute_seg_metrics(y_true, y_pred):
    """Compute pixel-wise binary segmentation metrics.

    Both inputs are flattened, so only their element counts need to agree
    (the training loop passes (N,1,H,W) float tensors).

    Args:
        y_true: Ground-truth mask tensor; cast to uint8 as-is.
        y_pred: Predicted probability tensor; thresholded at 0.5.

    Returns:
        dict with keys IoU, Dice, Precision, Recall, F1 and Accuracy.
    """
    truth = y_true.detach().cpu().numpy().astype('uint8').reshape(-1)
    guess = (y_pred.detach().cpu().numpy() > 0.5).astype('uint8').reshape(-1)

    overlap = (truth & guess).sum()
    # The small epsilon keeps the ratios finite when both masks are empty.
    combined = (truth | guess).sum() + 1e-7
    total_positive = truth.sum() + guess.sum() + 1e-7

    return {
        'IoU': overlap / combined,
        'Dice': (2 * overlap) / total_positive,
        'Precision': precision_score(truth, guess, zero_division=0),
        'Recall': recall_score(truth, guess, zero_division=0),
        'F1': f1_score(truth, guess, zero_division=0),
        'Accuracy': (truth == guess).mean(),
    }

# Train with early stopping on validation loss.
# - The checkpoint with the lowest validation loss is written to best_model.pth.
# - Training stops once PATIENCE consecutive epochs bring no improvement,
#   but never before MIN_EPOCHS epochs have run.
# - The per-epoch history is written even if the loop is interrupted, so the
#   plotting cell still has a training_history.csv to read.
best_val_loss = float('inf')
epochs_no_improve = 0
history = []
try:
    for epoch in range(1, 1000):
        # ---- training pass ----
        model.train()
        train_loss = 0
        for x, y in train_loader:
            x, y = x.cuda(), y.cuda()
            optimizer.zero_grad()
            out = model(x)
            loss = loss_fn(out, y)
            loss.backward()
            optimizer.step()
            # Weight by batch size so the epoch value is a per-sample mean
            # even when the last batch is smaller.
            train_loss += loss.item() * x.size(0)
        train_loss /= len(train_loader.dataset)

        # ---- validation pass ----
        model.eval()
        val_loss = 0
        val_preds = []
        val_trues = []
        with torch.no_grad():
            for x, y in val_loader:
                x, y = x.cuda(), y.cuda()
                out = model(x)
                loss = loss_fn(out, y)
                val_loss += loss.item() * x.size(0)
                # Move results off the GPU right away so the whole validation
                # set is not held in GPU memory until the end of the epoch.
                val_preds.append(torch.sigmoid(out).cpu())
                val_trues.append(y.cpu())
        val_loss /= len(val_loader.dataset)

        # Metrics on concatenated tensors
        val_preds_cat = torch.cat(val_preds, dim=0)
        val_trues_cat = torch.cat(val_trues, dim=0)
        metrics = compute_seg_metrics(val_trues_cat, val_preds_cat)
        history.append({'epoch':epoch,'train_loss':train_loss,'val_loss':val_loss, **metrics})

        print(f"Epoch {epoch}: TL {train_loss:.4f} VL {val_loss:.4f} IoU {metrics['IoU']:.4f} Dice {metrics['Dice']:.4f} F1 {metrics['F1']:.4f}")

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            epochs_no_improve = 0
            torch.save(model.state_dict(), 'best_model.pth')
        else:
            epochs_no_improve += 1

        # Honour the configured patience and minimum epoch count instead of
        # a hard-coded value.
        if epoch >= MIN_EPOCHS and epochs_no_improve >= PATIENCE:
            print(f'Early stopping at epoch {epoch}')
            break
finally:
    # Skip the write when no epoch finished: an empty CSV could not be read back.
    if history:
        hist_df = pd.DataFrame(history)
        hist_df.to_csv('training_history.csv', index=False)
print('Training complete. Best val loss:', best_val_loss)

config.json:   0%|          | 0.00/156 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/87.3M [00:00<?, ?B/s]

Epoch 1: TL 0.3278 VL 0.2847 IoU 0.0973 Dice 0.1774 F1 0.1774
Epoch 2: TL 0.3066 VL 0.3018 IoU 0.0000 Dice 0.0000 F1 0.0000
Epoch 3: TL 0.3004 VL 0.2831 IoU 0.0010 Dice 0.0020 F1 0.0020
Epoch 4: TL 0.2961 VL 0.2628 IoU 0.0228 Dice 0.0445 F1 0.0445
Epoch 5: TL 0.2896 VL 0.4019 IoU 0.0005 Dice 0.0009 F1 0.0009
Epoch 6: TL 0.2865 VL 0.2515 IoU 0.0941 Dice 0.1720 F1 0.1720
Epoch 7: TL 0.2805 VL 0.2498 IoU 0.3444 Dice 0.5124 F1 0.5124
Epoch 8: TL 0.2744 VL 0.2291 IoU 0.3562 Dice 0.5253 F1 0.5253
Epoch 9: TL 0.2670 VL 0.2502 IoU 0.2010 Dice 0.3347 F1 0.3347
Epoch 10: TL 0.2579 VL 0.2532 IoU 0.2951 Dice 0.4557 F1 0.4557
Epoch 11: TL 0.2463 VL 0.2450 IoU 0.2943 Dice 0.4548 F1 0.4548
Epoch 12: TL 0.2343 VL 0.2479 IoU 0.3784 Dice 0.5490 F1 0.5490
Epoch 13: TL 0.2171 VL 0.2478 IoU 0.3435 Dice 0.5114 F1 0.5114


KeyboardInterrupt: 

In [None]:
# Load Best Model and Evaluate on Test Set with full metrics
from sklearn.metrics import confusion_matrix, precision_score, recall_score, f1_score
from torchinfo import summary
from thop import profile

# Restore the weights saved by the training cell and switch to inference mode.
model.load_state_dict(torch.load('best_model.pth'))
model.eval()

# Parameter count & GFLOPs
# NOTE(review): the profiler result is bound to `macs` and reported as GFLOPs
# without rescaling — confirm which convention (MACs vs. 2 x MACs) is expected.
sample_input = torch.randn(1,6,256,256).cuda()
macs, params = profile(model, inputs=(sample_input,), verbose=False)
GFLOPs = macs/1e9
param_millions = params/1e6
print(f'GFLOPs: {GFLOPs:.3f}, Params (M): {param_millions:.3f}')

# Collect thresholded predictions and labels for the whole test set on the CPU.
all_preds = []
all_labels = []
with torch.no_grad():
    for x, y in test_loader:
        out = model(x.cuda())
        preds = (torch.sigmoid(out) > 0.5).float()
        all_preds.append(preds.cpu())
        # Labels are only used on the CPU, so they are never moved to the GPU.
        all_labels.append(y.cpu())

all_preds = torch.cat(all_preds, dim=0)
all_labels = torch.cat(all_labels, dim=0)

# Metrics
flat_p = all_preds.numpy().reshape(-1).astype('uint8')
flat_t = all_labels.numpy().reshape(-1).astype('uint8')
intersection = (flat_p & flat_t).sum()
# The small epsilon keeps the ratios finite when both masks are empty.
union = (flat_p | flat_t).sum() + 1e-7
iou = intersection/union
dice = (2*intersection)/(flat_p.sum()+flat_t.sum()+1e-7)
prec = precision_score(flat_t, flat_p, zero_division=0)
rec = recall_score(flat_t, flat_p, zero_division=0)
f1 = f1_score(flat_t, flat_p, zero_division=0)
acc = (flat_p==flat_t).mean()
# Fix the label set so the matrix is always 2x2, even when one class is
# absent from both the labels and the predictions.
cm = confusion_matrix(flat_t, flat_p, labels=[0, 1]).astype(int)
print('Confusion Matrix:\n', cm)
metrics = {'IoU':iou,'Dice':dice,'Precision':prec,'Recall':rec,'F1':f1,'Accuracy':acc,'GFLOPs':GFLOPs,'Params_M':param_millions}
print('Test Metrics:', metrics)

pd.DataFrame([metrics]).to_csv('test_metrics.csv', index=False)
np.savetxt('confusion_matrix.txt', cm, fmt='%d')

# Save predictions as images (first 10)
os.makedirs('test_predictions', exist_ok=True)
for i in range(min(10, all_preds.shape[0])):
    # Scale the {0, 1} mask to {0, 255} so it is visible as an 8-bit image.
    pred_img = (all_preds[i,0].numpy()*255).astype('uint8')
    Image.fromarray(pred_img).save(f'test_predictions/pred_{i}.png')


In [None]:
# Display Final Results & Plots
import os

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from PIL import Image

# Per-epoch history written by the training cell.
hist_df = pd.read_csv('training_history.csv')
print('History head:')
print(hist_df.head())

# One panel per tracked series on a 2x3 grid.
plot_cols = ['train_loss','val_loss','IoU','Dice','Precision','Recall']
fig, axes = plt.subplots(2,3, figsize=(16,8))
axes = axes.ravel()
for panel, column in zip(axes, plot_cols):
    panel.plot(hist_df['epoch'], hist_df[column], label=column)
    panel.set_xlabel('Epoch')
    panel.set_title(column)
    panel.legend()
plt.tight_layout()
plt.show()

# F1 & Accuracy share one separate figure.
plt.figure(figsize=(6,4))
for column in ('F1', 'Accuracy'):
    plt.plot(hist_df['epoch'], hist_df[column], label=column)
plt.title('F1 & Accuracy')
plt.xlabel('Epoch')
plt.legend()
plt.show()

# Test metrics saved by the evaluation cell.
metrics_df = pd.read_csv('test_metrics.csv')
print('Test Metrics:')
print(metrics_df)

# Confusion matrix heatmap (rows labelled True, columns labelled Predicted).
cm = np.loadtxt('confusion_matrix.txt', dtype=int)
plt.figure(figsize=(4,4))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues')
plt.ylabel('True')
plt.xlabel('Predicted')
plt.title('Pixel Confusion Matrix')
plt.show()

# Show up to three saved prediction masks; missing files are skipped.
for i in range(3):
    p_path = f'test_predictions/pred_{i}.png'
    if not os.path.exists(p_path):
        continue
    plt.figure()
    plt.imshow(Image.open(p_path), cmap='gray')
    plt.title(f'Pred {i}')
    plt.axis('off')
plt.show()