In [1]:
## mount drive
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [2]:
!pip install snntorch segmentation-models-pytorch



In [None]:
import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split
from torch.cuda.amp import GradScaler, autocast
from pathlib import Path
from tqdm import tqdm
import albumentations as A
from albumentations.pytorch import ToTensorV2
import matplotlib.pyplot as plt
from sklearn.metrics import matthews_corrcoef
from torchvision import models
import snntorch as snn
from snntorch import surrogate
from snntorch import utils
import gc

# ==========================================
# 1. CONFIGURATION
# ==========================================
CONFIG = {
    "base_dir": "/content/drive/MyDrive/GlacierHack_practice/Train",
    "project_dir": "/content/drive/MyDrive/Glacier_SNN_ResNet",

    "model_type": "SNN",

    "time_steps": 8,       # T=8 for SNN
    "beta": 0.9,
    "epochs": 30,
    "batch_size": 4,
    "lr": 1e-4,
    "num_workers": 2,
    "device": torch.device("cuda" if torch.cuda.is_available() else "cpu")
}

os.makedirs(CONFIG['project_dir'], exist_ok=True)
torch.cuda.empty_cache()
gc.collect()

def set_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)

set_seed(42)

# ==========================================
# 2. DATASET
# ==========================================
class GlacierDataset(Dataset):
    def __init__(self, base_dir, transform=None):
        self.base_dir = Path(base_dir)
        self.band_dirs = [self.base_dir / f"Band{i}" for i in range(1, 6)]
        self.label_dir = self.base_dir / "labels"
        self.ids = sorted([p.stem for p in self.band_dirs[0].glob("*.tif")])
        self.transform = transform

    def __len__(self): return len(self.ids)

    def __getitem__(self, idx):
        img_id = self.ids[idx]
        bands = [cv2.imread(str(d / f"{img_id}.tif"), cv2.IMREAD_UNCHANGED).astype(np.float32) for d in self.band_dirs]
        image = np.stack(bands, axis=-1)
        label = cv2.imread(str(self.label_dir / f"{img_id}.tif"), cv2.IMREAD_UNCHANGED)
        if label.ndim == 3: label = cv2.cvtColor(label, cv2.COLOR_BGR2GRAY)
        mask = np.zeros_like(label, dtype=np.uint8)
        mask[label == 85] = 1; mask[label == 170] = 2; mask[label == 255] = 3

        p02, p98 = np.percentile(image, 2), np.percentile(image, 98)
        image = np.clip(image, p02, p98)
        image = (image - image.min()) / (image.max() - image.min() + 1e-6)

        if self.transform:
            aug = self.transform(image=image, mask=mask)
            return aug["image"].float(), torch.tensor(aug["mask"]).long()
        return torch.tensor(image.transpose(2,0,1)).float(), torch.tensor(mask).long()

class Wrapper(Dataset):
    def __init__(self, ds, t): self.ds, self.t = ds, t
    def __len__(self): return len(self.ds)
    def __getitem__(self, i):
        img, mask = self.ds[i]
        img = img.numpy().transpose(1,2,0); mask = mask.numpy()
        res = self.t(image=img, mask=mask)
        return res['image'], res['mask'].long()

train_transform = A.Compose([
    A.HorizontalFlip(p=0.5), A.VerticalFlip(p=0.5), A.RandomRotate90(p=0.5),
    A.GridDistortion(p=0.3),
    ToTensorV2(),
])
val_transform = A.Compose([ToTensorV2()])

full_ds = GlacierDataset(CONFIG['base_dir'], transform=train_transform)
val_len = int(len(full_ds)*0.2)
train_ds, val_ds = random_split(full_ds, [len(full_ds)-val_len, val_len], generator=torch.Generator().manual_seed(42))
val_ds.dataset.transform = val_transform

train_loader = DataLoader(Wrapper(train_ds, train_transform), batch_size=CONFIG['batch_size'], shuffle=True, num_workers=2)
val_loader = DataLoader(Wrapper(val_ds, val_transform), batch_size=CONFIG['batch_size'], shuffle=False, num_workers=2)

# ==========================================
# 3. FIXED RESNET ARCHITECTURE
# ==========================================
class UnifiedResNetEncoder(nn.Module):
    def __init__(self, mode="CNN"):
        super().__init__()
        resnet = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)

        self.conv1 = nn.Conv2d(5, 64, kernel_size=7, stride=2, padding=3, bias=False)
        with torch.no_grad():
            self.conv1.weight[:, :3] = resnet.conv1.weight
            self.conv1.weight[:, 3:] = resnet.conv1.weight[:, :2]

        self.bn1 = resnet.bn1
        self.relu = snn.Leaky(beta=CONFIG['beta'], spike_grad=surrogate.atan(), init_hidden=True) if mode == "SNN" else resnet.relu
        self.maxpool = resnet.maxpool

        self.layer1 = self._convert(resnet.layer1, mode)
        self.layer2 = self._convert(resnet.layer2, mode)
        self.layer3 = self._convert(resnet.layer3, mode)
        self.layer4 = self._convert(resnet.layer4, mode)

    def _convert(self, block, mode):
        if mode == "CNN": return block
        layers = []
        for b in block:
            b.relu = snn.Leaky(beta=CONFIG['beta'], spike_grad=surrogate.atan(), init_hidden=True)
            layers.append(b)
        return nn.Sequential(*layers)

    def forward(self, x):
        feats = []

        # Stem
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        feats.append(x) # [B, 64, 256, 256] -> Feat 0 (Important for 512x512 restoration)

        x = self.maxpool(x) # 128x128

        # Layers
        x = self.layer1(x)
        feats.append(x) # [B, 64, 128, 128] -> Feat 1

        x = self.layer2(x)
        feats.append(x) # [B, 128, 64, 64] -> Feat 2

        x = self.layer3(x)
        feats.append(x) # [B, 256, 32, 32] -> Feat 3

        x = self.layer4(x)
        feats.append(x) # [B, 512, 16, 16] -> Feat 4 (Bottleneck)

        return feats

class UnifiedDecoder(nn.Module):
    def __init__(self, mode="CNN"):
        super().__init__()
        spike_grad = surrogate.atan()

        def block(in_c, out_c):
            act = snn.Leaky(beta=CONFIG['beta'], spike_grad=spike_grad, init_hidden=True) if mode == "SNN" else nn.ReLU(inplace=True)
            return nn.Sequential(
                nn.Conv2d(in_c, out_c, 3, padding=1),
                nn.BatchNorm2d(out_c),
                act
            )

        # Decoder Steps (Reverse of Encoder)

        # 1. 16 -> 32 (Connect with Layer3 [256])
        self.up4 = nn.ConvTranspose2d(512, 256, 2, 2)
        self.dec4 = block(256+256, 256)

        # 2. 32 -> 64 (Connect with Layer2 [128])
        self.up3 = nn.ConvTranspose2d(256, 128, 2, 2)
        self.dec3 = block(128+128, 128)

        # 3. 64 -> 128 (Connect with Layer1 [64])
        self.up2 = nn.ConvTranspose2d(128, 64, 2, 2)
        self.dec2 = block(64+64, 64)

        # 4. 128 -> 256 (Connect with Stem [64])
        self.up1 = nn.ConvTranspose2d(64, 64, 2, 2)
        self.dec1 = block(64+64, 64)

        # 5. 256 -> 512 (Final Upsample)
        self.up_final = nn.ConvTranspose2d(64, 32, 2, 2)
        self.dec_final = block(32, 32)

        self.final = nn.Conv2d(32, 4, 1)

    def forward(self, enc_feats):
        f0, f1, f2, f3, f4 = enc_feats[0], enc_feats[1], enc_feats[2], enc_feats[3], enc_feats[4]

        u4 = self.up4(f4)
        if u4.shape != f3.shape: u4 = F.interpolate(u4, size=f3.shape[2:])
        d4 = self.dec4(torch.cat([f3, u4], 1))

        u3 = self.up3(d4)
        if u3.shape != f2.shape: u3 = F.interpolate(u3, size=f2.shape[2:])
        d3 = self.dec3(torch.cat([f2, u3], 1))

        u2 = self.up2(d3)
        if u2.shape != f1.shape: u2 = F.interpolate(u2, size=f1.shape[2:])
        d2 = self.dec2(torch.cat([f1, u2], 1))

        u1 = self.up1(d2)
        if u1.shape != f0.shape: u1 = F.interpolate(u1, size=f0.shape[2:])
        d1 = self.dec1(torch.cat([f0, u1], 1))

        out = self.up_final(d1)
        out = self.dec_final(out)

        return self.final(out)

class UnifiedUNet(nn.Module):
    def __init__(self, mode="CNN"):
        super().__init__()
        self.mode = mode
        print(f"‚è≥ Initializing {mode} ResNet18 U-Net...")
        self.encoder = UnifiedResNetEncoder(mode)
        self.decoder = UnifiedDecoder(mode)

    def forward(self, x):
        if self.mode == "SNN":
            utils.reset(self)
            spk_rec = []
            for step in range(CONFIG['time_steps']):
                enc_feats = self.encoder(x)
                out = self.decoder(enc_feats)
                spk_rec.append(out)
            return torch.stack(spk_rec).mean(0)
        else:
            enc_feats = self.encoder(x)
            return self.decoder(enc_feats)

# ==========================================
# 4. TRAINING & UTILS
# ==========================================
def manual_reset(model):
    for m in model.modules():
        if hasattr(m, "reset_mem"): m.reset_mem()

def save_vis(history, sample_vis, epoch, mode):
    plt.figure(figsize=(12, 5))
    plt.subplot(1, 2, 1); plt.plot(history['loss']); plt.title(f"{mode} Loss")
    plt.subplot(1, 2, 2); plt.plot(history['mcc']); plt.title(f"{mode} MCC")
    plt.savefig(f"{CONFIG['project_dir']}/{mode}_history.png"); plt.close()

    img, gt, pred = sample_vis
    rgb = img[[3,2,1]].transpose(1,2,0)
    rgb = (rgb - rgb.min()) / (rgb.max() - rgb.min() + 1e-6)
    plt.figure(figsize=(12, 4))
    plt.subplot(1, 3, 1); plt.imshow(rgb); plt.title("Input")
    plt.subplot(1, 3, 2); plt.imshow(gt, cmap='nipy_spectral'); plt.title("Ground Truth")
    plt.subplot(1, 3, 3); plt.imshow(pred, cmap='nipy_spectral'); plt.title(f"{mode} Pred")
    plt.savefig(f"{CONFIG['project_dir']}/{mode}_sample.png"); plt.close()

model = UnifiedUNet(mode=CONFIG['model_type']).to(CONFIG['device'])
optimizer = optim.AdamW(model.parameters(), lr=CONFIG['lr'])
weights = torch.tensor([0.2, 1.0, 1.0, 3.0]).to(CONFIG['device'])
criterion = nn.CrossEntropyLoss(weight=weights)
scaler = GradScaler() # Optional for SNN, good for CNN

best_mcc = -1.0
history = {'loss': [], 'mcc': []}

print(f"üî• Starting {CONFIG['model_type']} Training...")

for epoch in range(CONFIG['epochs']):
    model.train()
    running_loss = 0

    for imgs, masks in tqdm(train_loader, desc=f"Ep {epoch+1}"):
        imgs, masks = imgs.to(CONFIG['device']), masks.to(CONFIG['device']).long()

        if CONFIG['model_type'] == "SNN": manual_reset(model)

        optimizer.zero_grad()

        if CONFIG['model_type'] == "CNN":
            with autocast():
                out = model(imgs)
                loss = criterion(out, masks)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            out = model(imgs)
            loss = criterion(out, masks)
            loss.backward()
            optimizer.step()

        running_loss += loss.item()

    # Val
    model.eval()
    preds, targets = [], []
    with torch.no_grad():
        for i, (imgs, masks) in enumerate(val_loader):
            imgs = imgs.to(CONFIG['device'])
            if CONFIG['model_type'] == "SNN": manual_reset(model)

            with autocast(): out = model(imgs)
            preds.append(out.argmax(1).cpu())
            targets.append(masks.cpu())

            if i == 0: sample_vis = (imgs[0].cpu().numpy(), masks[0].cpu().numpy(), preds[-1][0].numpy())

    mcc = matthews_corrcoef(torch.cat(targets).numpy().flatten(), torch.cat(preds).numpy().flatten())
    history['mcc'].append(mcc)
    history['loss'].append(running_loss/len(train_loader))

    print(f"Ep {epoch+1} | Loss: {history['loss'][-1]:.4f} | Val MCC: {mcc:.4f}")
    save_vis(history, sample_vis, epoch+1, CONFIG['model_type'])

    if mcc > best_mcc:
        best_mcc = mcc
        torch.save(model.state_dict(), f"{CONFIG['project_dir']}/best_{CONFIG['model_type']}_ResNet.pth")
        print("‚úÖ Saved Best Model!")

print(f"üèÅ Finished. Best MCC: {best_mcc:.4f}")

‚è≥ Initializing SNN ResNet18 U-Net...


  scaler = GradScaler() # Optional for SNN, good for CNN


üî• Starting SNN Training...


  return aug["image"].float(), torch.tensor(aug["mask"]).long()
  return aug["image"].float(), torch.tensor(aug["mask"]).long()
Ep 1: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 5/5 [00:12<00:00,  2.42s/it]
  return aug["image"].float(), torch.tensor(aug["mask"]).long()
  return aug["image"].float(), torch.tensor(aug["mask"]).long()
  with autocast(): out = model(imgs)


Ep 1 | Loss: 1.2651 | Val MCC: 0.1101
‚úÖ Saved Best Model!


  return aug["image"].float(), torch.tensor(aug["mask"]).long()
  return aug["image"].float(), torch.tensor(aug["mask"]).long()
Ep 2: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 5/5 [00:09<00:00,  1.83s/it]
  return aug["image"].float(), torch.tensor(aug["mask"]).long()
  return aug["image"].float(), torch.tensor(aug["mask"]).long()
  with autocast(): out = model(imgs)


Ep 2 | Loss: 1.2233 | Val MCC: 0.2421
‚úÖ Saved Best Model!


  return aug["image"].float(), torch.tensor(aug["mask"]).long()
  return aug["image"].float(), torch.tensor(aug["mask"]).long()
Ep 3: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 5/5 [00:09<00:00,  1.92s/it]
  return aug["image"].float(), torch.tensor(aug["mask"]).long()
  return aug["image"].float(), torch.tensor(aug["mask"]).long()
  with autocast(): out = model(imgs)


Ep 3 | Loss: 1.1931 | Val MCC: 0.2804
‚úÖ Saved Best Model!


  return aug["image"].float(), torch.tensor(aug["mask"]).long()
  return aug["image"].float(), torch.tensor(aug["mask"]).long()
Ep 4: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 5/5 [00:09<00:00,  1.89s/it]
  return aug["image"].float(), torch.tensor(aug["mask"]).long()
  return aug["image"].float(), torch.tensor(aug["mask"]).long()
  with autocast(): out = model(imgs)


Ep 4 | Loss: 1.1675 | Val MCC: 0.2963
‚úÖ Saved Best Model!


  return aug["image"].float(), torch.tensor(aug["mask"]).long()
  return aug["image"].float(), torch.tensor(aug["mask"]).long()
Ep 5: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 5/5 [00:09<00:00,  1.90s/it]
  return aug["image"].float(), torch.tensor(aug["mask"]).long()
  return aug["image"].float(), torch.tensor(aug["mask"]).long()
  with autocast(): out = model(imgs)


Ep 5 | Loss: 1.1343 | Val MCC: 0.3069
‚úÖ Saved Best Model!


  return aug["image"].float(), torch.tensor(aug["mask"]).long()
  return aug["image"].float(), torch.tensor(aug["mask"]).long()
Ep 6: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 5/5 [00:09<00:00,  1.88s/it]
  return aug["image"].float(), torch.tensor(aug["mask"]).long()
  return aug["image"].float(), torch.tensor(aug["mask"]).long()
  with autocast(): out = model(imgs)


Ep 6 | Loss: 1.1143 | Val MCC: 0.3114
‚úÖ Saved Best Model!


  return aug["image"].float(), torch.tensor(aug["mask"]).long()
  return aug["image"].float(), torch.tensor(aug["mask"]).long()
Ep 7: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 5/5 [00:09<00:00,  1.88s/it]
  return aug["image"].float(), torch.tensor(aug["mask"]).long()
  return aug["image"].float(), torch.tensor(aug["mask"]).long()
  with autocast(): out = model(imgs)


Ep 7 | Loss: 1.0988 | Val MCC: 0.3153
‚úÖ Saved Best Model!


  return aug["image"].float(), torch.tensor(aug["mask"]).long()
  return aug["image"].float(), torch.tensor(aug["mask"]).long()
Ep 8: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 5/5 [00:09<00:00,  1.89s/it]
  return aug["image"].float(), torch.tensor(aug["mask"]).long()
  return aug["image"].float(), torch.tensor(aug["mask"]).long()
  with autocast(): out = model(imgs)


Ep 8 | Loss: 1.0794 | Val MCC: 0.3178
‚úÖ Saved Best Model!


  return aug["image"].float(), torch.tensor(aug["mask"]).long()
  return aug["image"].float(), torch.tensor(aug["mask"]).long()
Ep 9: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 5/5 [00:10<00:00,  2.02s/it]
  return aug["image"].float(), torch.tensor(aug["mask"]).long()
  return aug["image"].float(), torch.tensor(aug["mask"]).long()
  with autocast(): out = model(imgs)


Ep 9 | Loss: 1.0626 | Val MCC: 0.3198
‚úÖ Saved Best Model!


  return aug["image"].float(), torch.tensor(aug["mask"]).long()
  return aug["image"].float(), torch.tensor(aug["mask"]).long()
Ep 10: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 5/5 [00:09<00:00,  1.91s/it]
  return aug["image"].float(), torch.tensor(aug["mask"]).long()
  return aug["image"].float(), torch.tensor(aug["mask"]).long()
  with autocast(): out = model(imgs)


Ep 10 | Loss: 1.0581 | Val MCC: 0.3194


  return aug["image"].float(), torch.tensor(aug["mask"]).long()
  return aug["image"].float(), torch.tensor(aug["mask"]).long()
Ep 11: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 5/5 [00:09<00:00,  1.91s/it]
  return aug["image"].float(), torch.tensor(aug["mask"]).long()
  return aug["image"].float(), torch.tensor(aug["mask"]).long()
  with autocast(): out = model(imgs)


Ep 11 | Loss: 1.0369 | Val MCC: 0.3223
‚úÖ Saved Best Model!


  return aug["image"].float(), torch.tensor(aug["mask"]).long()
  return aug["image"].float(), torch.tensor(aug["mask"]).long()
Ep 12: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 5/5 [00:09<00:00,  1.92s/it]
  return aug["image"].float(), torch.tensor(aug["mask"]).long()
  return aug["image"].float(), torch.tensor(aug["mask"]).long()
  with autocast(): out = model(imgs)


Ep 12 | Loss: 1.0354 | Val MCC: 0.3238
‚úÖ Saved Best Model!


  return aug["image"].float(), torch.tensor(aug["mask"]).long()
  return aug["image"].float(), torch.tensor(aug["mask"]).long()
Ep 13: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 5/5 [00:09<00:00,  1.93s/it]
  return aug["image"].float(), torch.tensor(aug["mask"]).long()
  return aug["image"].float(), torch.tensor(aug["mask"]).long()
  with autocast(): out = model(imgs)


Ep 13 | Loss: 1.0133 | Val MCC: 0.3269
‚úÖ Saved Best Model!


  return aug["image"].float(), torch.tensor(aug["mask"]).long()
  return aug["image"].float(), torch.tensor(aug["mask"]).long()
Ep 14: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 5/5 [00:09<00:00,  1.94s/it]
  return aug["image"].float(), torch.tensor(aug["mask"]).long()
  return aug["image"].float(), torch.tensor(aug["mask"]).long()
  with autocast(): out = model(imgs)


Ep 14 | Loss: 1.0210 | Val MCC: 0.3280
‚úÖ Saved Best Model!


  return aug["image"].float(), torch.tensor(aug["mask"]).long()
  return aug["image"].float(), torch.tensor(aug["mask"]).long()
Ep 15: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 5/5 [00:09<00:00,  1.93s/it]
  return aug["image"].float(), torch.tensor(aug["mask"]).long()
  return aug["image"].float(), torch.tensor(aug["mask"]).long()
  with autocast(): out = model(imgs)


Ep 15 | Loss: 0.9977 | Val MCC: 0.3122


  return aug["image"].float(), torch.tensor(aug["mask"]).long()
  return aug["image"].float(), torch.tensor(aug["mask"]).long()
Ep 16: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 5/5 [00:09<00:00,  1.90s/it]
  return aug["image"].float(), torch.tensor(aug["mask"]).long()
  return aug["image"].float(), torch.tensor(aug["mask"]).long()
  with autocast(): out = model(imgs)


Ep 16 | Loss: 0.9859 | Val MCC: 0.3159


  return aug["image"].float(), torch.tensor(aug["mask"]).long()
  return aug["image"].float(), torch.tensor(aug["mask"]).long()
Ep 17: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 5/5 [00:09<00:00,  1.90s/it]
  return aug["image"].float(), torch.tensor(aug["mask"]).long()
  return aug["image"].float(), torch.tensor(aug["mask"]).long()
  with autocast(): out = model(imgs)


Ep 17 | Loss: 0.9860 | Val MCC: 0.3190


  return aug["image"].float(), torch.tensor(aug["mask"]).long()
  return aug["image"].float(), torch.tensor(aug["mask"]).long()
Ep 18: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 5/5 [00:09<00:00,  1.98s/it]
  return aug["image"].float(), torch.tensor(aug["mask"]).long()
  return aug["image"].float(), torch.tensor(aug["mask"]).long()
  with autocast(): out = model(imgs)


Ep 18 | Loss: 0.9606 | Val MCC: 0.3142


  return aug["image"].float(), torch.tensor(aug["mask"]).long()
  return aug["image"].float(), torch.tensor(aug["mask"]).long()
Ep 19: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 5/5 [00:09<00:00,  1.97s/it]
  return aug["image"].float(), torch.tensor(aug["mask"]).long()
  return aug["image"].float(), torch.tensor(aug["mask"]).long()
  with autocast(): out = model(imgs)


Ep 19 | Loss: 0.9454 | Val MCC: 0.3263


  return aug["image"].float(), torch.tensor(aug["mask"]).long()
  return aug["image"].float(), torch.tensor(aug["mask"]).long()
Ep 20: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 5/5 [00:09<00:00,  1.90s/it]
  return aug["image"].float(), torch.tensor(aug["mask"]).long()
  return aug["image"].float(), torch.tensor(aug["mask"]).long()
  with autocast(): out = model(imgs)


Ep 20 | Loss: 0.9646 | Val MCC: 0.3294
‚úÖ Saved Best Model!


  return aug["image"].float(), torch.tensor(aug["mask"]).long()
  return aug["image"].float(), torch.tensor(aug["mask"]).long()
Ep 21: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 5/5 [00:09<00:00,  1.95s/it]
  return aug["image"].float(), torch.tensor(aug["mask"]).long()
  return aug["image"].float(), torch.tensor(aug["mask"]).long()
  with autocast(): out = model(imgs)


Ep 21 | Loss: 0.9294 | Val MCC: 0.3301
‚úÖ Saved Best Model!


  return aug["image"].float(), torch.tensor(aug["mask"]).long()
  return aug["image"].float(), torch.tensor(aug["mask"]).long()
Ep 22: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 5/5 [00:09<00:00,  1.94s/it]
  return aug["image"].float(), torch.tensor(aug["mask"]).long()
  return aug["image"].float(), torch.tensor(aug["mask"]).long()
  with autocast(): out = model(imgs)


Ep 22 | Loss: 0.9486 | Val MCC: 0.3398
‚úÖ Saved Best Model!


  return aug["image"].float(), torch.tensor(aug["mask"]).long()
  return aug["image"].float(), torch.tensor(aug["mask"]).long()
Ep 23: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 5/5 [00:09<00:00,  1.91s/it]
  return aug["image"].float(), torch.tensor(aug["mask"]).long()
  return aug["image"].float(), torch.tensor(aug["mask"]).long()
  with autocast(): out = model(imgs)


Ep 23 | Loss: 0.9284 | Val MCC: 0.3354


  return aug["image"].float(), torch.tensor(aug["mask"]).long()
  return aug["image"].float(), torch.tensor(aug["mask"]).long()
Ep 24: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 5/5 [00:09<00:00,  1.99s/it]
  return aug["image"].float(), torch.tensor(aug["mask"]).long()
  return aug["image"].float(), torch.tensor(aug["mask"]).long()
  with autocast(): out = model(imgs)


Ep 24 | Loss: 0.9124 | Val MCC: 0.3293


  return aug["image"].float(), torch.tensor(aug["mask"]).long()
  return aug["image"].float(), torch.tensor(aug["mask"]).long()
Ep 25: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 5/5 [00:09<00:00,  1.91s/it]
  return aug["image"].float(), torch.tensor(aug["mask"]).long()
  return aug["image"].float(), torch.tensor(aug["mask"]).long()
  with autocast(): out = model(imgs)


Ep 25 | Loss: 0.9242 | Val MCC: 0.3250


  return aug["image"].float(), torch.tensor(aug["mask"]).long()
  return aug["image"].float(), torch.tensor(aug["mask"]).long()
Ep 26: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 5/5 [00:09<00:00,  1.90s/it]
  return aug["image"].float(), torch.tensor(aug["mask"]).long()
  return aug["image"].float(), torch.tensor(aug["mask"]).long()
  with autocast(): out = model(imgs)


Ep 26 | Loss: 0.9073 | Val MCC: 0.3346


  return aug["image"].float(), torch.tensor(aug["mask"]).long()
  return aug["image"].float(), torch.tensor(aug["mask"]).long()
Ep 27: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 5/5 [00:09<00:00,  1.98s/it]
  return aug["image"].float(), torch.tensor(aug["mask"]).long()
  return aug["image"].float(), torch.tensor(aug["mask"]).long()
  with autocast(): out = model(imgs)


Ep 27 | Loss: 0.9139 | Val MCC: 0.3397


  return aug["image"].float(), torch.tensor(aug["mask"]).long()
  return aug["image"].float(), torch.tensor(aug["mask"]).long()
Ep 28: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 5/5 [00:10<00:00,  2.05s/it]
  return aug["image"].float(), torch.tensor(aug["mask"]).long()
  return aug["image"].float(), torch.tensor(aug["mask"]).long()
  with autocast(): out = model(imgs)


Ep 28 | Loss: 0.9097 | Val MCC: 0.3463
‚úÖ Saved Best Model!


  return aug["image"].float(), torch.tensor(aug["mask"]).long()
  return aug["image"].float(), torch.tensor(aug["mask"]).long()
Ep 29: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 5/5 [00:09<00:00,  1.94s/it]
  return aug["image"].float(), torch.tensor(aug["mask"]).long()
  return aug["image"].float(), torch.tensor(aug["mask"]).long()
  with autocast(): out = model(imgs)


Ep 29 | Loss: 0.8952 | Val MCC: 0.3556
‚úÖ Saved Best Model!


  return aug["image"].float(), torch.tensor(aug["mask"]).long()
  return aug["image"].float(), torch.tensor(aug["mask"]).long()
Ep 30: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 5/5 [00:09<00:00,  1.94s/it]
  return aug["image"].float(), torch.tensor(aug["mask"]).long()
  return aug["image"].float(), torch.tensor(aug["mask"]).long()
  with autocast(): out = model(imgs)


Ep 30 | Loss: 0.8946 | Val MCC: 0.3265
üèÅ Finished. Best MCC: 0.3556
