In [None]:
from pathlib import Path
import re, random, os, math
import numpy as np
from PIL import Image

import torch, torch.nn as nn, torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from Vit import CloudSegSpectralViT

import multiprocessing

import pickle

In [None]:
# CPU threading / numeric configuration for this kernel.

# Intra-op parallelism: use every logical core reported by the OS (e.g., 16).
torch.set_num_threads(multiprocessing.cpu_count())

# Inter-op parallelism: PyTorch only allows this to be set once per process and
# only before any inter-op parallel work has started. Re-running this cell in a
# live kernel would otherwise raise RuntimeError, so keep the existing value.
try:
    torch.set_num_interop_threads(4)
except RuntimeError as err:
    print(f"Inter-op thread count left at {torch.get_num_interop_threads()}: {err}")

# Allow float32 matmuls to trade some precision for speed where a faster
# reduced-precision kernel is available.
torch.set_float32_matmul_precision('medium')

In [None]:
# Load the pre-computed train/validation split from disk.
# NOTE: unpickling can execute arbitrary code; only load a split file that was
# produced locally or comes from a trusted source.
with open("train_val_split.pkl", "rb") as f:
    data = pickle.load(f)

# Both entries are looked up before either name is bound.
train_samples, val_samples = (data[split] for split in ("train", "val"))

In [None]:
# Sanity check: number of records in the train and validation splits.
print(*map(len, (train_samples, val_samples)))

In [None]:
class Cloud95Dataset(Dataset):
    """Four-band (R, G, B, NIR) cloud-segmentation dataset with binary masks.

    Every entry of ``items`` is a mapping with the keys ``"r"``, ``"g"``,
    ``"b"``, ``"n"`` (paths to single-band 2-D images) and ``"m"`` (path to the
    mask image).

    Args:
        items: Sequence of sample records as described above.
        tilesize: Side length of the square crop taken from each sample, or a
            falsy value (``None`` / ``0``) to return samples at full size.
            Samples smaller than ``tilesize`` in either dimension are returned
            uncropped.
        augment: When true, the crop position is random and random horizontal
            and vertical flips are applied. Otherwise a centred crop is taken
            and no flips are applied.
    """

    _BAND_KEYS = ("r", "g", "b", "n")

    def __init__(self, items, tilesize=None, augment=False):
        self.items = items
        self.tilesize = tilesize
        self.augment = augment

    def __len__(self):
        return len(self.items)

    @staticmethod
    def _scale_to_unit(arr):
        """Convert raw pixel values to ``float32`` in ``[0, 1]``.

        ``uint16`` data is divided by 65535 so 16-bit imagery keeps its full
        dynamic range instead of being wrapped into 8 bits. Every other dtype
        follows the 8-bit path (cast to ``uint8``, divide by 255).
        """
        if arr.dtype == np.uint16:
            return arr.astype(np.float32) / 65535.0
        return arr.astype(np.uint8, copy=False).astype(np.float32) / 255.0

    def _read_band(self, p: Path):
        """Read one band image from ``p`` and return it as float32 in [0, 1]."""
        # np.array() copies the pixel data, so the file can be closed right away.
        with Image.open(p) as im:
            arr = np.array(im)
        return self._scale_to_unit(arr)

    def _read_mask(self, p):
        """Read the mask at ``p`` as float32 with 1.0 wherever the pixel is non-zero."""
        with Image.open(p) as im:
            raw = np.array(im.convert("L"))
        return (raw > 0).astype(np.float32)

    def __getitem__(self, i):
        """Return ``(image, mask)`` tensors of shape ``(4, H, W)`` and ``(1, H, W)``."""
        rec = self.items[i]
        bands = [self._read_band(rec[key]) for key in self._BAND_KEYS]
        mask = self._read_mask(rec["m"])

        # Optional square crop: random when augmenting, centred otherwise.
        H, W = bands[0].shape
        s = self.tilesize
        if s and H >= s and W >= s:
            if self.augment:
                # x is drawn before y so seeded runs consume random numbers
                # in a stable order.
                x = 0 if W == s else random.randint(0, W - s)
                y = 0 if H == s else random.randint(0, H - s)
            else:
                x = (W - s) // 2
                y = (H - s) // 2
            bands = [band[y:y + s, x:x + s] for band in bands]
            mask = mask[y:y + s, x:x + s]

        # Light augmentation: independent horizontal and vertical flips.
        if self.augment and random.random() < 0.5:
            bands = [np.fliplr(band) for band in bands]
            mask = np.fliplr(mask)
        if self.augment and random.random() < 0.5:
            bands = [np.flipud(band) for band in bands]
            mask = np.flipud(mask)

        # Flips and crops yield strided views; torch.from_numpy rejects negative
        # strides, so materialise contiguous float32 arrays first.
        img = np.ascontiguousarray(np.stack(bands, axis=0), dtype=np.float32)  # (4,H,W)
        msk = np.ascontiguousarray(mask[None, ...], dtype=np.float32)          # (1,H,W)

        return torch.from_numpy(img), torch.from_numpy(msk)

In [None]:
class DiceLoss(nn.Module):
    """Soft Dice loss computed on sigmoid probabilities, averaged over the batch.

    Args:
        eps: Smoothing constant added once to the numerator and once to the
            denominator. It keeps the ratio defined when both prediction and
            target are empty, in which case the Dice score is 1 and the loss 0.
    """

    def __init__(self, eps=1e-6):
        super().__init__()
        self.eps = eps

    def forward(self, logits, targets):
        """Return the mean of ``1 - dice`` over the batch.

        Args:
            logits: Raw model outputs, batch dimension first.
            targets: Binary ground truth with the same shape as ``logits``.
        """
        probs = torch.sigmoid(logits)
        # Reduce over every non-batch dimension.
        dims = tuple(range(1, probs.dim()))
        intersection = (probs * targets).sum(dim=dims)
        total = probs.sum(dim=dims) + targets.sum(dim=dims)
        # eps is applied exactly once on each side of the ratio; applying it
        # twice in the denominator makes an empty/empty sample score ~0.5.
        dice = (2 * intersection + self.eps) / (total + self.eps)
        return (1 - dice).mean()

def iou_binary(logits, targets, thr=0.5):
    """Mean per-sample IoU between thresholded predictions and binary targets.

    Args:
        logits: Raw model outputs of shape (B, C, H, W).
        targets: Binary ground truth with the same shape.
        thr: Probability above which a pixel counts as positive.

    Returns:
        The batch-averaged IoU as a Python float.
    """
    smooth = 1e-6
    pred = torch.sigmoid(logits).gt(thr).float()
    overlap = pred * targets
    inter = overlap.sum(dim=(1, 2, 3))
    union = (pred + targets - overlap).sum(dim=(1, 2, 3))
    per_sample = (inter + smooth) / (union + smooth)
    return per_sample.mean().item()

In [None]:
# Training data pipeline: square crops of `tilesize` pixels with random
# crop position and flips (augment=True).
tilesize = 128
train_ds = Cloud95Dataset(train_samples, augment=True, tilesize=tilesize)

# Batches are loaded in the main process (num_workers=0), so persistent workers
# do not apply; pinned memory is left off.
train_dl = DataLoader(
    dataset=train_ds,
    batch_size=8,               # small batch to start with
    shuffle=True,
    num_workers=0,
    pin_memory=False,
    persistent_workers=False,
)

In [None]:
# Run on the GPU when PyTorch can see one, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Spectral ViT segmentation model; a lighter configuration to start with.
model = CloudSegSpectralViT(
    in_ch=4,             # R, G, B and NIR bands
    num_classes=1,       # single-logit binary head (2 would suit cross-entropy)
    patch_size=(8, 8),   # original note: matches the paper
    spec_group=3,        # original note: 3-channel grouping, 4 channels padded internally (not verifiable here)
    depth=8,
    embed_dim=384,
    num_heads=6,
)
model = model.to(device)

# Loss terms for the binary mask.
bce = nn.BCEWithLogitsLoss()
dice = DiceLoss()

# AdamW with a cosine learning-rate schedule spanning 20 scheduler steps.
opt = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=0.01)
sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=20)

# Best validation IoU seen so far and the checkpoint file associated with it.
best_iou = 0.0
best_path = "unet_cloud95.pt"

In [None]:
def _pad_to_multiple(x, multiple):
    """Zero-pad the last two dims of ``x`` (bottom/right) up to a multiple of ``multiple``.

    Returns the padded tensor together with the number of rows and columns
    that were added, so predictions can be cropped back if needed.
    """
    h, w = x.shape[-2:]
    ph = (-h) % multiple
    pw = (-w) % multiple
    if ph or pw:
        # F.pad order for the last two dims: (left, right, top, bottom).
        x = F.pad(x, (0, pw, 0, ph))
    return x, ph, pw


# Spatial sizes must be divisible by the ViT patch size (patch_size=(8, 8)).
patch_multiple = 8

# One pass over the training data. This cell uses only objects created in the
# cells above: train_dl, model, device, bce, dice and opt.
# NOTE(review): the model is built with num_classes=1, so only the binary
# (BCE + Dice) path is implemented; equal weighting of the two losses is an
# assumption. `sched` is not stepped in this cell.
model.train()
for imgs, masks in train_dl:         # imgs: (B,4,H,W), masks: (B,1,H,W)
    imgs = imgs.to(device)
    masks = masks.to(device).float()  # BCEWithLogitsLoss expects float targets

    # (ViT only) pad image and mask identically so logits and targets align.
    imgs, _, _ = _pad_to_multiple(imgs, patch_multiple)
    masks, _, _ = _pad_to_multiple(masks, patch_multiple)

    logits = model(imgs)
    loss = bce(logits, masks) + dice(logits, masks)

    opt.zero_grad()
    loss.backward()
    opt.step()
