In [1]:
# Colab-only setup: expose Google Drive under /content/drive so the
# dataset and output paths configured below (under /content/drive/MyDrive) resolve.
from google.colab import drive

drive.mount("/content/drive")

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [2]:
!pip install snntorch



In [3]:
import os
os.environ.update(PYTORCH_CUDA_ALLOC_CONF="expandable_segments:True")  # CUDA caching-allocator option; set before `import torch`

import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split
from torch.cuda.amp import GradScaler, autocast
from pathlib import Path
from tqdm import tqdm
import albumentations as A
from albumentations.pytorch import ToTensorV2
import matplotlib.pyplot as plt
from sklearn.metrics import matthews_corrcoef
import snntorch as snn
from snntorch import surrogate
from snntorch import utils

# ==========================================
# 1. CONFIGURATION
# ==========================================
# Every tunable of the experiment lives here; cells below read it by key.
CONFIG = dict(
    # --- paths ---
    base_dir="/content/drive/MyDrive/GlacierHack_practice/Train",    # Band1..Band5 + labels
    project_dir="/content/drive/MyDrive/Glacier_SNN_Project",        # plots and checkpoints

    # --- architecture switch: "CNN" or "SNN" ---
    model_type="CNN",

    # --- spiking-network settings (only used when model_type == "SNN") ---
    time_steps=4,   # number of simulation steps T
    beta=0.9,       # membrane decay rate of the Leaky neurons

    # --- optimisation / data loading ---
    epochs=30,
    batch_size=8,
    lr=1e-3,
    num_workers=2,
    seed=42,
    device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
)

# Output folder for history/sample PNGs and best-model checkpoints.
Path(CONFIG["project_dir"]).mkdir(parents=True, exist_ok=True)

def set_seed(seed, deterministic=False):
    """Seed every random-number generator the pipeline can draw from.

    Seeds Python's ``random`` (used for augmentation sampling by some albumentations
    versions — TODO confirm for the installed version), NumPy's legacy global RNG,
    and PyTorch's CPU and all-CUDA generators.

    Args:
        seed: integer seed applied to every generator.
        deterministic: when True, additionally force deterministic cuDNN kernels and
            disable cuDNN autotuning (slower, but bit-reproducible convolutions).
            When False (default) the cuDNN flags are left untouched.
    """
    import random

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # no-op without CUDA
    if deterministic:
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

set_seed(seed=CONFIG["seed"])  # seed before the train/val split and model initialisation below

# ==========================================
# 2. DATASET
# ==========================================
class GlacierDataset(Dataset):
    """Five-band glacier tiles with a 4-class segmentation mask.

    Expected layout under ``base_dir``: ``Band1/`` .. ``Band5/`` each contain one
    ``<id>.tif`` per tile, and ``labels/`` contains the matching ``<id>.tif`` label
    image. Tile ids are taken from the ``Band1`` folder.

    Each item is ``(image, mask)``:
      * image: float tensor (5, H, W), clipped to the 2nd/98th percentile of all band
        values of the tile (pooled over bands) and min-max scaled to roughly [0, 1];
      * mask: long tensor (H, W) with label grey levels 85/170/255 mapped to classes
        1/2/3 and every other value mapped to 0.
    If ``transform`` is given it is called as ``transform(image=HWC array, mask=HW array)``
    and must return torch tensors under ``"image"`` and ``"mask"``.
    """

    # Grey level in the label image -> class index (anything else becomes class 0).
    LABEL_TO_CLASS = {85: 1, 170: 2, 255: 3}

    def __init__(self, base_dir, transform=None):
        self.base_dir = Path(base_dir)
        self.band_dirs = [self.base_dir / f"Band{i}" for i in range(1, 6)]
        self.label_dir = self.base_dir / "labels"
        self.ids = sorted(p.stem for p in self.band_dirs[0].glob("*.tif"))
        self.transform = transform

    def __len__(self):
        return len(self.ids)

    @staticmethod
    def _read(path):
        """Read a raster unchanged; cv2.imread returns None instead of raising on failure."""
        arr = cv2.imread(str(path), cv2.IMREAD_UNCHANGED)
        if arr is None:
            raise FileNotFoundError(f"Could not read image file: {path}")
        return arr

    def __getitem__(self, idx):
        img_id = self.ids[idx]
        bands = [self._read(d / f"{img_id}.tif").astype(np.float32) for d in self.band_dirs]
        image = np.stack(bands, axis=-1)  # (H, W, 5)

        label = self._read(self.label_dir / f"{img_id}.tif")
        if label.ndim == 3:
            label = cv2.cvtColor(label, cv2.COLOR_BGR2GRAY)
        mask = np.zeros_like(label, dtype=np.uint8)
        for grey, cls in self.LABEL_TO_CLASS.items():
            mask[label == grey] = cls

        # Robust contrast stretch: percentiles are pooled over all five bands.
        p02, p98 = np.percentile(image, (2, 98))
        image = np.clip(image, p02, p98)
        image = (image - image.min()) / (image.max() - image.min() + 1e-6)

        if self.transform:
            aug = self.transform(image=image, mask=mask)
            return aug["image"].float(), aug["mask"].long()
        return torch.tensor(image.transpose(2, 0, 1)).float(), torch.tensor(mask).long()

class Wrapper(Dataset):
    """Apply an albumentations-style transform on top of a tensor-yielding dataset.

    ``ds[i]`` must return ``(image, mask)`` torch tensors shaped (C, H, W) and (H, W).
    They are converted to HWC / HW numpy arrays, passed as ``t(image=..., mask=...)``,
    and the transformed ``"image"`` plus the ``"mask"`` (cast to long) are returned.
    This lets one un-augmented base dataset (e.g. the halves of a ``random_split``)
    be served with different train / validation transforms.
    """

    def __init__(self, ds, t):
        self.ds, self.t = ds, t

    def __len__(self):
        return len(self.ds)

    def __getitem__(self, i):
        img, mask = self.ds[i]
        # CHW tensor -> contiguous HWC array: OpenCV-backed transforms want a plain C-ordered buffer.
        img = np.ascontiguousarray(img.numpy().transpose(1, 2, 0))
        mask = mask.numpy()
        # The base dataset yields int64 masks; OpenCV 4.x has no 64-bit integer depth, so
        # warping transforms (e.g. GridDistortion) need a smaller dtype. Cast only when lossless.
        if mask.dtype != np.uint8 and mask.size and mask.min() >= 0 and mask.max() <= 255:
            mask = mask.astype(np.uint8)
        res = self.t(image=img, mask=mask)
        return res['image'], res['mask'].long()

# Geometric augmentations only, applied jointly to image and mask; validation is not augmented.
train_transform = A.Compose([
    A.HorizontalFlip(p=0.5), A.VerticalFlip(p=0.5), A.RandomRotate90(p=0.5),
    A.GridDistortion(p=0.3),
    ToTensorV2(),
])
val_transform = A.Compose([ToTensorV2()])

VAL_FRACTION = 0.2  # share of tiles held out for validation / best-model selection

full_dataset = GlacierDataset(CONFIG['base_dir'], transform=None)
n_total = len(full_dataset)
if n_total < 2:
    raise RuntimeError(
        f"Need at least 2 tiles for a train/val split, found {n_total} under {CONFIG['base_dir']}"
    )
# Keep at least one validation tile so the per-epoch MCC is always defined.
val_len = max(1, int(n_total * VAL_FRACTION))
# Dedicated generator: the split depends only on the seed, not on how much of the
# global torch RNG has been consumed by code that runs before this point.
split_generator = torch.Generator().manual_seed(CONFIG['seed'])
train_ds, val_ds = random_split(full_dataset, [n_total - val_len, val_len], generator=split_generator)

train_loader = DataLoader(Wrapper(train_ds, train_transform), batch_size=CONFIG['batch_size'],
                          shuffle=True, num_workers=CONFIG['num_workers'])
val_loader = DataLoader(Wrapper(val_ds, val_transform), batch_size=CONFIG['batch_size'],
                        shuffle=False, num_workers=CONFIG['num_workers'])

# ==========================================
# 3. ARCHITECTURES
# ==========================================

# A. CNN Block
class CNNBlock(nn.Module):
    """Double convolution: two (3x3 conv -> BatchNorm -> ReLU) stages.

    Maps ``in_c`` to ``out_c`` channels and keeps H and W unchanged (padding=1).
    Layer order inside ``self.net`` is fixed so checkpoint keys stay ``net.0`` .. ``net.5``.
    """

    def __init__(self, in_c, out_c):
        super().__init__()
        layers = []
        for stage_in in (in_c, out_c):  # first stage changes channels, second keeps them
            layers += [
                nn.Conv2d(stage_in, out_c, 3, padding=1, bias=False),  # bias redundant before BN
                nn.BatchNorm2d(out_c),
                nn.ReLU(inplace=True),
            ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)

# B. SNN Block (snnTorch)
class SNNBlock(nn.Module):
    """Spiking counterpart of a double-conv block: two (3x3 conv -> BatchNorm -> snn.Leaky) stages.

    Both Leaky neurons use ``CONFIG['beta']`` as decay rate and share a single
    fast-sigmoid surrogate-gradient object. They are created with ``init_hidden=True``,
    so snnTorch keeps their hidden state inside the modules between calls; the state
    must be reset before a new input (``ProjectUNet.forward`` calls ``utils.reset``).
    """

    def __init__(self, in_c, out_c):
        super().__init__()
        spike_grad = surrogate.fast_sigmoid()
        stages = []
        for stage_in in (in_c, out_c):
            stages.extend([
                nn.Conv2d(stage_in, out_c, 3, padding=1, bias=False),
                nn.BatchNorm2d(out_c),
                snn.Leaky(beta=CONFIG['beta'], spike_grad=spike_grad, init_hidden=True),
            ])
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        return self.net(x)

# C. The U-Net
class ProjectUNet(nn.Module):
    """Three-level U-Net (32-64-128-256 channels) producing per-pixel class logits.

    Args:
        in_ch: number of input channels.
        n_classes: number of output logit channels.
        mode: ``"SNN"`` builds every stage from ``SNNBlock`` and simulates the network
            over several time steps; any other value builds ``CNNBlock`` stages and runs
            a single pass.
        time_steps: number of SNN simulation steps. ``None`` (default) reads
            ``CONFIG['time_steps']`` at every forward call. Ignored in CNN mode.

    ``forward`` maps (B, in_ch, H, W) to logits (B, n_classes, H, W). In SNN mode the
    same static input is fed at every step and the per-step logits are averaged.
    """

    def __init__(self, in_ch=5, n_classes=4, mode="CNN", time_steps=None):
        super().__init__()
        self.mode = mode
        self.time_steps = time_steps
        Block = SNNBlock if mode == "SNN" else CNNBlock

        # Encoder
        self.inc = Block(in_ch, 32)
        self.down1 = nn.Sequential(nn.MaxPool2d(2), Block(32, 64))
        self.down2 = nn.Sequential(nn.MaxPool2d(2), Block(64, 128))
        self.down3 = nn.Sequential(nn.MaxPool2d(2), Block(128, 256))

        # Decoder
        self.up1 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.conv1 = Block(256, 128)
        self.up2 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.conv2 = Block(128, 64)
        self.up3 = nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2)
        self.conv3 = Block(64, 32)

        self.outc = nn.Conv2d(32, n_classes, 1)

    @staticmethod
    def _match(x, ref):
        """Resize ``x`` to ``ref``'s spatial size when pooling of odd sizes left them off by a pixel."""
        if x.shape[-2:] != ref.shape[-2:]:
            x = F.interpolate(x, size=ref.shape[2:])
        return x

    def _forward_once(self, x):
        """One full encoder-decoder pass with skip connections; returns logits."""
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)

        y = self.conv1(torch.cat([x3, self._match(self.up1(x4), x3)], dim=1))
        y = self.conv2(torch.cat([x2, self._match(self.up2(y), x2)], dim=1))
        y = self.conv3(torch.cat([x1, self._match(self.up3(y), x1)], dim=1))
        return self.outc(y)

    def forward(self, x):
        if self.mode != "SNN":
            return self._forward_once(x)

        # Spiking mode: clear membrane states left over from the previous batch first.
        utils.reset(self)
        steps = CONFIG['time_steps'] if self.time_steps is None else self.time_steps
        step_logits = [self._forward_once(x) for _ in range(steps)]
        # Average the per-step outputs over time.
        return torch.stack(step_logits).mean(0)

# ==========================================
# 4. VISUALIZATION
# ==========================================
def save_artifacts(history, sample_vis, epoch, mode, n_classes=4):
    """Write training curves and one qualitative validation sample to ``CONFIG['project_dir']``.

    Args:
        history: dict with per-epoch lists ``'loss'`` (train loss) and ``'mcc'`` (val MCC).
        sample_vis: ``(image, ground_truth, prediction)`` numpy arrays — image presumably
            (C, H, W) with at least 4 bands, masks (H, W) of class ids — or ``None`` to
            skip the sample figure.
        epoch: 1-based epoch number shown in the prediction title.
        mode: model tag used in titles and file names (``<mode>_history.png``,
            ``<mode>_sample.png``).
        n_classes: number of classes; fixes the colour scale so the same class gets the
            same colour in ground truth and prediction.
    """
    out_dir = CONFIG['project_dir']

    # --- training curves ---
    fig, (ax_loss, ax_mcc) = plt.subplots(1, 2, figsize=(12, 5))
    ax_loss.plot(range(1, len(history['loss']) + 1), history['loss'], label='Train Loss')
    ax_loss.set(title=f"{mode} Loss", xlabel="Epoch")
    ax_loss.legend()
    ax_mcc.plot(range(1, len(history['mcc']) + 1), history['mcc'], label='Val MCC', color='green')
    ax_mcc.set(title=f"{mode} MCC", xlabel="Epoch")
    ax_mcc.legend()
    fig.savefig(f"{out_dir}/{mode}_history.png")
    plt.close(fig)

    if sample_vis is None:
        return

    # --- qualitative sample ---
    img, gt, pred = sample_vis
    # Bands at indices 3, 2, 1 displayed as R, G, B — assumes Band4/3/2 form a true-colour composite (TODO confirm).
    rgb = img[[3, 2, 1]].transpose(1, 2, 0)
    rgb = (rgb - rgb.min()) / (rgb.max() - rgb.min() + 1e-6)

    # Without explicit vmin/vmax imshow rescales each mask to its own min/max, so a
    # prediction missing some classes would be painted with the wrong colours.
    class_scale = dict(cmap='nipy_spectral', vmin=0, vmax=n_classes - 1, interpolation='nearest')

    fig, (ax_in, ax_gt, ax_pred) = plt.subplots(1, 3, figsize=(12, 4))
    ax_in.imshow(rgb)
    ax_in.set_title("Input")
    ax_gt.imshow(gt, **class_scale)
    ax_gt.set_title("Ground Truth")
    ax_pred.imshow(pred, **class_scale)
    ax_pred.set_title(f"Pred (Ep {epoch})")
    fig.savefig(f"{out_dir}/{mode}_sample.png")
    plt.close(fig)

# ==========================================
# 5. TRAINING LOOP
# ==========================================
DEVICE = CONFIG['device']
# Mixed precision (fp16 autocast + loss scaling) only on CUDA; on CPU both become no-ops.
USE_AMP = DEVICE.type == "cuda"

print(f"🚀 Initializing {CONFIG['model_type']} (snnTorch)...")
model = ProjectUNet(in_ch=5, n_classes=4, mode=CONFIG['model_type']).to(DEVICE)

optimizer = optim.AdamW(model.parameters(), lr=CONFIG['lr'])
criterion = nn.CrossEntropyLoss()
# torch.amp replaces the deprecated torch.cuda.amp API (source of the FutureWarnings in the old log).
scaler = torch.amp.GradScaler("cuda", enabled=USE_AMP)


def _train_one_epoch(model, loader, optimizer, criterion, scaler, desc):
    """Run one optimisation pass over ``loader``; return the mean per-batch training loss."""
    model.train()
    total_loss, n_batches = 0.0, 0
    for imgs, masks in tqdm(loader, desc=desc):
        imgs, masks = imgs.to(DEVICE), masks.to(DEVICE).long()
        optimizer.zero_grad()
        with torch.amp.autocast(device_type=DEVICE.type, enabled=USE_AMP):
            out = model(imgs)
            loss = criterion(out, masks)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        total_loss += loss.item()
        n_batches += 1
    if n_batches == 0:
        raise RuntimeError("train_loader yielded no batches")
    return total_loss / n_batches


@torch.no_grad()
def _evaluate(model, loader):
    """Predict every batch of ``loader``.

    Returns ``(mcc, sample_vis)``: the pixel-level Matthews correlation over all
    predictions, and ``(image, ground_truth, prediction)`` numpy arrays for the first
    validation tile (used for the sample figure).
    """
    model.eval()
    preds, targets, sample_vis = [], [], None
    for imgs, masks in loader:
        imgs = imgs.to(DEVICE)
        with torch.amp.autocast(device_type=DEVICE.type, enabled=USE_AMP):
            out = model(imgs)
        batch_pred = out.argmax(1).cpu()
        preds.append(batch_pred)
        targets.append(masks.cpu())
        if sample_vis is None:
            sample_vis = (imgs[0].cpu().numpy(), masks[0].cpu().numpy(), batch_pred[0].numpy())
    if not preds:
        raise RuntimeError("val_loader yielded no batches")
    mcc = matthews_corrcoef(torch.cat(targets).numpy().ravel(), torch.cat(preds).numpy().ravel())
    return mcc, sample_vis


best_mcc = -1.0
history = {'loss': [], 'mcc': []}
best_ckpt_path = f"{CONFIG['project_dir']}/best_{CONFIG['model_type']}.pth"

print(f"🔥 Starting {CONFIG['model_type']} Training...")

for epoch in range(CONFIG['epochs']):
    avg_loss = _train_one_epoch(model, train_loader, optimizer, criterion, scaler, desc=f"Ep {epoch+1}")
    history['loss'].append(avg_loss)

    mcc, sample_vis = _evaluate(model, val_loader)
    history['mcc'].append(mcc)

    print(f"Ep {epoch+1} | Loss: {avg_loss:.4f} | Val MCC: {mcc:.4f}")

    save_artifacts(history, sample_vis, epoch+1, CONFIG['model_type'])

    # Model selection on validation MCC; only the best-so-far weights are kept on disk.
    if mcc > best_mcc:
        best_mcc = mcc
        torch.save(model.state_dict(), best_ckpt_path)
        print("✅ Saved Best Model!")

print(f"🏁 Training Complete. Best MCC: {best_mcc:.4f}")

🚀 Initializing CNN (snnTorch)...


  scaler = GradScaler()


🔥 Starting CNN Training...


  with autocast():
Ep 1: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 3/3 [00:03<00:00,  1.08s/it]
  with autocast():


Ep 1 | Loss: 1.2649 | Val MCC: 0.0000
‚úÖ Saved Best Model!


  with autocast():
Ep 2: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 3/3 [00:03<00:00,  1.27s/it]
  with autocast():


Ep 2 | Loss: 1.0950 | Val MCC: 0.0000


  with autocast():
Ep 3: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 3/3 [00:03<00:00,  1.13s/it]
  with autocast():


Ep 3 | Loss: 1.0166 | Val MCC: 0.0984
‚úÖ Saved Best Model!


  with autocast():
Ep 4: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 3/3 [00:02<00:00,  1.26it/s]
  with autocast():


Ep 4 | Loss: 0.9407 | Val MCC: 0.0005


  with autocast():
Ep 5: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 3/3 [00:02<00:00,  1.14it/s]
  with autocast():


Ep 5 | Loss: 0.8875 | Val MCC: 0.0513


  with autocast():
Ep 6: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 3/3 [00:01<00:00,  1.56it/s]
  with autocast():


Ep 6 | Loss: 0.8510 | Val MCC: 0.2449
‚úÖ Saved Best Model!


  with autocast():
Ep 7: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 3/3 [00:01<00:00,  1.61it/s]
  with autocast():


Ep 7 | Loss: 0.7990 | Val MCC: 0.2362


  with autocast():
Ep 8: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 3/3 [00:04<00:00,  1.38s/it]
  with autocast():


Ep 8 | Loss: 0.7608 | Val MCC: 0.2026


  with autocast():
Ep 9: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 3/3 [00:01<00:00,  1.58it/s]
  with autocast():


Ep 9 | Loss: 0.7314 | Val MCC: 0.4035
‚úÖ Saved Best Model!


  with autocast():
Ep 10: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 3/3 [00:01<00:00,  1.58it/s]
  with autocast():


Ep 10 | Loss: 0.7261 | Val MCC: 0.4333
‚úÖ Saved Best Model!


  with autocast():
Ep 11: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 3/3 [00:01<00:00,  1.59it/s]
  with autocast():


Ep 11 | Loss: 0.6969 | Val MCC: 0.3939


  with autocast():
Ep 12: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 3/3 [00:02<00:00,  1.20it/s]
  with autocast():


Ep 12 | Loss: 0.6777 | Val MCC: 0.4482
‚úÖ Saved Best Model!


  with autocast():
Ep 13: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 3/3 [00:01<00:00,  1.59it/s]
  with autocast():


Ep 13 | Loss: 0.6676 | Val MCC: 0.4831
‚úÖ Saved Best Model!


  with autocast():
Ep 14: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 3/3 [00:01<00:00,  1.64it/s]
  with autocast():


Ep 14 | Loss: 0.6485 | Val MCC: 0.4942
‚úÖ Saved Best Model!


  with autocast():
Ep 15: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 3/3 [00:01<00:00,  1.55it/s]
  with autocast():


Ep 15 | Loss: 0.6389 | Val MCC: 0.5010
‚úÖ Saved Best Model!


  with autocast():
Ep 16: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 3/3 [00:02<00:00,  1.30it/s]
  with autocast():


Ep 16 | Loss: 0.6200 | Val MCC: 0.4215


  with autocast():
Ep 17: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 3/3 [00:01<00:00,  1.61it/s]
  with autocast():


Ep 17 | Loss: 0.6255 | Val MCC: 0.4461


  with autocast():
Ep 18: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 3/3 [00:01<00:00,  1.60it/s]
  with autocast():


Ep 18 | Loss: 0.5892 | Val MCC: 0.5262
‚úÖ Saved Best Model!


  with autocast():
Ep 19: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 3/3 [00:01<00:00,  1.56it/s]
  with autocast():


Ep 19 | Loss: 0.5913 | Val MCC: 0.5067


  with autocast():
Ep 20: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 3/3 [00:02<00:00,  1.25it/s]
  with autocast():


Ep 20 | Loss: 0.5819 | Val MCC: 0.4936


  with autocast():
Ep 21: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 3/3 [00:01<00:00,  1.58it/s]
  with autocast():


Ep 21 | Loss: 0.5652 | Val MCC: 0.5557
‚úÖ Saved Best Model!


  with autocast():
Ep 22: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 3/3 [00:01<00:00,  1.57it/s]
  with autocast():


Ep 22 | Loss: 0.5861 | Val MCC: 0.5240


  with autocast():
Ep 23: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 3/3 [00:02<00:00,  1.50it/s]
  with autocast():


Ep 23 | Loss: 0.5585 | Val MCC: 0.5419


  with autocast():
Ep 24: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 3/3 [00:02<00:00,  1.38it/s]
  with autocast():


Ep 24 | Loss: 0.5418 | Val MCC: 0.5395


  with autocast():
Ep 25: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 3/3 [00:01<00:00,  1.58it/s]
  with autocast():


Ep 25 | Loss: 0.5312 | Val MCC: 0.4963


  with autocast():
Ep 26: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 3/3 [00:01<00:00,  1.55it/s]
  with autocast():


Ep 26 | Loss: 0.5225 | Val MCC: 0.5440


  with autocast():
Ep 27: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 3/3 [00:02<00:00,  1.50it/s]
  with autocast():


Ep 27 | Loss: 0.5311 | Val MCC: 0.5030


  with autocast():
Ep 28: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 3/3 [00:02<00:00,  1.34it/s]
  with autocast():


Ep 28 | Loss: 0.5296 | Val MCC: 0.4090


  with autocast():
Ep 29: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 3/3 [00:02<00:00,  1.23it/s]
  with autocast():


Ep 29 | Loss: 0.4970 | Val MCC: 0.5326


  with autocast():
Ep 30: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 3/3 [00:01<00:00,  1.58it/s]
  with autocast():


Ep 30 | Loss: 0.5302 | Val MCC: 0.5742
✅ Saved Best Model!
🏁 Training Complete. Best MCC: 0.5742
