In [1]:
from pathlib import Path
import re, random, os, math
import numpy as np
from PIL import Image

import torch, torch.nn as nn, torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from models import UNetSmall

import multiprocessing

import pickle

In [2]:
# Intra-op parallelism: one thread per logical CPU reported by the OS.
torch.set_num_threads(multiprocessing.cpu_count())  # e.g., 16

# The inter-op pool size can only be set once per process and only before any
# inter-op parallel work has started. Re-running this cell in a live kernel
# would otherwise raise RuntimeError, so keep the existing setting in that case.
try:
    torch.set_num_interop_threads(4)
except RuntimeError as err:
    print(f"set_num_interop_threads skipped: {err}")

# 'medium' lets float32 matmuls use reduced-precision internal formats where
# the backend supports them (speed over exactness).
torch.set_float32_matmul_precision('medium')

  _C._set_float32_matmul_precision(precision)


In [3]:
# Load the precomputed train/validation split.
# NOTE(review): pickle.load executes arbitrary code from the file; only use a
# split file that was produced locally / comes from a trusted source.
with open("train_val_split.pkl", "rb") as f:
    data = pickle.load(f)

# The pickle is expected to hold a mapping with "train" and "val" entries.
train_samples = data["train"]
val_samples = data["val"]

In [4]:
# Sanity check: number of training and validation records, space-separated.
print(f"{len(train_samples)} {len(val_samples)}")

21041 5260


In [5]:
class Cloud95Dataset(Dataset):
    """Four-band image / binary mask dataset read from per-band image files.

    Each element of ``items`` is a record (mapping) with the keys ``"r"``,
    ``"g"``, ``"b"``, ``"n"`` and ``"m"``; every value is something
    ``PIL.Image.open`` accepts. The four bands are stacked in that order into
    the image tensor and ``"m"`` is binarised into the mask.
    (Presumably r/g/b/n are red, green, blue and near-infrared - confirm
    against the code that builds the split.)

    Args:
        items: indexable collection of records as described above.
        tilesize: side length of the square crop, or ``None``/0 for no crop.
            The crop is only applied when the image is at least ``tilesize``
            in both dimensions; smaller images are returned uncropped.
        augment: when True the crop position is random and random
            horizontal/vertical flips are applied; when False the crop is
            centred and no flips happen.

    ``__getitem__`` returns ``(image, mask)`` as float32 tensors of shape
    ``(4, H, W)`` and ``(1, H, W)``.
    """

    def __init__(self, items, tilesize=None, augment=False):
        self.items = items
        self.tilesize = tilesize
        self.augment = augment

    def __len__(self):
        return len(self.items)

    def _read_band(self, p: Path):
        """Read one band image and scale it to float32 by dividing by 255."""
        # The context manager closes the file handle as soon as the pixels
        # have been copied into the array (Image.open alone leaves it open).
        # NOTE(review): the uint8 cast assumes 8-bit source imagery; values
        # outside 0..255 would not survive it - confirm the band bit depth.
        with Image.open(p) as im:
            arr = np.array(im, dtype=np.uint8)
        return arr.astype(np.float32) / 255.0

    def _read_mask(self, p):
        """Read the mask image and binarise it: 1.0 where pixel > 0, else 0.0."""
        with Image.open(p) as im:
            m = np.array(im.convert("L")).astype(np.uint8)
        return (m > 0).astype(np.float32)

    def __getitem__(self, i):
        rec = self.items[i]

        # planes[0:4] are the R, G, B, N bands; planes[4] is the mask. Keeping
        # them in one list guarantees every spatial transform hits all five.
        planes = [self._read_band(rec[k]) for k in ("r", "g", "b", "n")]
        planes.append(self._read_mask(rec["m"]))

        # Optional square crop; skipped when the image is smaller than the tile.
        H, W = planes[0].shape
        s = self.tilesize
        if s and (H >= s and W >= s):
            if self.augment:
                # Random crop. The `== s` shortcut avoids consuming a random
                # number when there is only one possible position.
                x = 0 if W == s else random.randint(0, W - s)
                y = 0 if H == s else random.randint(0, H - s)
            else:
                # Deterministic centre crop for evaluation.
                x = (W - s) // 2
                y = (H - s) // 2
            planes = [a[y:y+s, x:x+s] for a in planes]

        # Simple flips as light augmentation (each with probability 0.5).
        if self.augment and random.random() < 0.5:
            planes = [np.fliplr(a) for a in planes]
        if self.augment and random.random() < 0.5:
            planes = [np.flipud(a) for a in planes]

        # np.stack / astype produce fresh arrays, so the flipped views above
        # become regular arrays that torch.from_numpy can wrap.
        img = np.stack(planes[:4], axis=0).astype(np.float32, copy=False)  # (4,H,W)
        msk = planes[4][None, ...].astype(np.float32)                      # (1,H,W)

        return torch.from_numpy(img), torch.from_numpy(msk)

In [7]:
class DiceLoss(nn.Module):
    """Soft Dice loss computed on sigmoid probabilities.

    For each sample the loss is ``1 - (2*sum(p*t) + eps) / (sum(p) + sum(t) + eps)``
    with sums taken over dimensions 1, 2 and 3, so both inputs must be 4-D with
    matching shapes. The per-sample losses are averaged over dimension 0.

    Args:
        eps: smoothing constant added once to numerator and once to
            denominator, so an empty prediction on an empty target yields a
            Dice score of 1 (loss 0) rather than a division by zero.
    """

    def __init__(self, eps=1e-6):
        super().__init__()
        self.eps = eps

    def forward(self, logits, targets):
        """Return the scalar mean soft-Dice loss for raw ``logits`` vs ``targets``."""
        probs = torch.sigmoid(logits)
        dims = (1, 2, 3)
        num = 2 * (probs * targets).sum(dim=dims)
        # eps is applied exactly once below; adding it here as well would make
        # the empty/empty case evaluate to eps / (2*eps) = 0.5.
        den = probs.sum(dim=dims) + targets.sum(dim=dims)
        return (1 - (num + self.eps) / (den + self.eps)).mean()

def iou_binary(logits, targets, thr=0.5):
    """Mean intersection-over-union of thresholded predictions vs ``targets``.

    ``logits`` are passed through a sigmoid and binarised with ``> thr``.
    Intersection and union are summed over dimensions 1, 2 and 3 (so the
    inputs must be 4-D), smoothed with 1e-6, and the per-sample ratios are
    averaged. Returns a Python float.
    """
    smooth = 1e-6
    per_sample = (1, 2, 3)

    pred = torch.sigmoid(logits).gt(thr).float()
    overlap = pred * targets

    intersection = overlap.sum(dim=per_sample)
    union = (pred + targets - overlap).sum(dim=per_sample)

    ratio = (intersection + smooth) / (union + smooth)
    return ratio.mean().item()

In [8]:
# Side length of the square tile cropped from each validation image.
tilesize = 128

# Validation set: augment=False gives a centred crop and no random flips.
val_ds = Cloud95Dataset(val_samples, tilesize=tilesize, augment=False)

# Loader tuned for a single-process run:
#   - num_workers=0 loads batches in the main process (no worker subprocesses),
#     so persistent_workers stays False as well;
#   - pin_memory=False because pinning only helps host-to-GPU copies.
val_dl = DataLoader(
    val_ds,
    batch_size=8,
    shuffle=False,
    num_workers=0,
    pin_memory=False,
    persistent_workers=False,
)

In [9]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Four input channels (the dataset stacks R, G, B, N) and one output logit map.
model = UNetSmall(in_ch=4, out_ch=1)
model = model.to(device)

# The two loss terms that the evaluation cell blends 50/50.
bce = nn.BCEWithLogitsLoss()
dice = DiceLoss()

# Optimiser and cosine learning-rate schedule.
opt = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=0.01)
sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=20)

# Best validation IoU so far and the checkpoint file associated with it.
best_iou = 0.0
best_path = "unet_cloud95.pt"

In [10]:
# Restore weights from the checkpoint, remapping tensors onto `device`.
# NOTE(review): torch.load unpickles the file - only load checkpoints from a
# trusted source.
ckpt = torch.load(best_path, map_location=device)

# Accept either a wrapper dict carrying the weights under "model_state" or a
# bare state dict saved directly.
if isinstance(ckpt, dict) and "model_state" in ckpt:
    state_dict = ckpt["model_state"]
else:
    state_dict = ckpt

# strict=True: raise if any key is missing from or unexpected in the checkpoint.
model.load_state_dict(state_dict, strict=True)

<All keys matched successfully>

In [12]:

# Evaluate the loaded checkpoint on the validation loader.
# This cell is evaluation-only: it deliberately performs no optimizer or
# scheduler step, so re-running it leaves all training state untouched.
model.eval()
v_loss = 0.0; v_iou = 0.0; n = 0
n_samples = 0

with torch.no_grad():
    for i, (imgs, masks) in enumerate(val_dl, start=1):
        print(f"\rbatch {i}", end='', flush=True)
        imgs, masks = imgs.to(device), masks.to(device)
        logits = model(imgs)

        # Both metrics are per-batch means; weight them by the batch size so a
        # smaller final batch does not count as much as a full one.
        bs = imgs.size(0)
        v_loss += (0.5*bce(logits, masks) + 0.5*dice(logits, masks)).item() * bs
        v_iou  += iou_binary(logits, masks) * bs
        n_samples += bs
        n = i  # number of batches processed

if n_samples == 0:
    raise RuntimeError("val_dl produced no samples; nothing to evaluate")

# Sample-weighted averages over the whole validation set.
v_loss /= n_samples; v_iou /= n_samples

print()
print(f"val_loss={v_loss:.4f}  val_IoU={v_iou:.4f}")

batch 658
val_loss=0.5943  val_IoU=0.6784


