In [22]:
import torch, torch.nn as nn, torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from models import UNetSmall
import multiprocessing
import pickle

In [23]:
torch.set_num_threads(multiprocessing.cpu_count())  # e.g., 16
torch.set_num_interop_threads(4)
torch.set_float32_matmul_precision('medium')

RuntimeError: Error: cannot set number of interop threads after parallel work has started or set_num_interop_threads called

In [None]:

with open("train_val_split.pkl", "rb") as f:
    data = pickle.load(f)
train_samples, val_samples = data["train"], data["val"]


In [12]:
# ===== 3) Dataset =====
class Cloud95Dataset(Dataset):
    def __init__(self, items, tilesize=None, augment=False):
        self.items = items
        self.tilesize = tilesize
        self.augment = augment

    def __len__(self): return len(self.items)

    # def _read_band(self, p: Path):
    #     arr = np.array(Image.open(p))
    #     arr = arr.astype(np.float32)
    #     # Normalize: if looks like 8-bit, /255; else use /10000 (typical S2 scaling)
    #     maxv = arr.max() if arr.size else 255
    #     if maxv <= 255: arr = arr / 255.0
    #     else: arr = np.clip(arr / 10000.0, 0.0, 1.5)
    #     return arr

    def _read_band(self, p: Path):
        arr = np.array(Image.open(p), dtype=np.uint8).astype(np.float32)
        return arr / 255.0

    

    def __getitem__(self, i):
        rec = self.items[i]
        R = self._read_band(rec["r"])
        G = self._read_band(rec["g"])
        B = self._read_band(rec["b"])
        N = self._read_band(rec["n"])
        M = np.array(Image.open(rec["m"]).convert("L")).astype(np.uint8)
        M = (M > 0).astype(np.float32)

        # Optional random 512 crop (most are 512 already; keep robust)
        H, W = R.shape
        if self.tilesize and (H >= self.tilesize and W >= self.tilesize):
            s = self.tilesize
            # x = 0 if W == s else random.randint(0, W - s)
            # y = 0 if H == s else random.randint(0, H - s)
            # R, G, B, N, M = R[y:y+s, x:x+s], G[y:y+s, x:x+s], B[y:y+s, x:x+s], N[y:y+s, x:x+s], M[y:y+s, x:x+s]

            if self.augment:
                x = 0 if W == s else random.randint(0, W - s)
                y = 0 if H == s else random.randint(0, H - s)
            else:
                x = (W - s) // 2
                y = (H - s) // 2
                
            R, G, B, N, M = R[y:y+s, x:x+s], G[y:y+s, x:x+s], B[y:y+s, x:x+s], N[y:y+s, x:x+s], M[y:y+s, x:x+s]
        # Simple flips as light augmentation
        if self.augment and random.random() < 0.5:
            R, G, B, N, M = np.fliplr(R), np.fliplr(G), np.fliplr(B), np.fliplr(N), np.fliplr(M)
        if self.augment and random.random() < 0.5:
            R, G, B, N, M = np.flipud(R), np.flipud(G), np.flipud(B), np.flipud(N), np.flipud(M)

        img = np.stack([R,G,B,N], axis=0).astype(np.float32)        # (4,H,W)
        msk = M[None, ...].astype(np.float32)                        # (1,H,W)
        return torch.from_numpy(img), torch.from_numpy(msk)
        

In [13]:
# ===== 5) Loss, metrics, training =====
class DiceLoss(nn.Module):
    def __init__(self, eps=1e-6): super().__init__(); self.eps = eps
    def forward(self, logits, targets):
        probs = torch.sigmoid(logits)
        num = 2 * (probs*targets).sum(dim=(1,2,3))
        den = probs.sum(dim=(1,2,3)) + targets.sum(dim=(1,2,3)) + self.eps
        return (1 - (num + self.eps)/(den + self.eps)).mean()

def iou_binary(logits, targets, thr=0.5):
    p = (torch.sigmoid(logits) > thr).float()
    inter = (p*targets).sum(dim=(1,2,3))
    union = (p + targets - p*targets).sum(dim=(1,2,3))
    return ((inter+1e-6)/(union+1e-6)).mean().item()

In [14]:
# Datasets & loaders
tile = 128  # chips are typically 512; keep same
train_ds = Cloud95Dataset(train_samples, tilesize=tile, augment=True)
val_ds   = Cloud95Dataset(val_samples,   tilesize=tile, augment=False)
# train_ds.tilesize = 256
# val_ds.tilesize   = 256

# === Debug subset mode ===
DEBUG_FRACTION = 0.1  # load only 1/10th of dataset
if DEBUG_FRACTION < 1.0:
    from torch.utils.data import Subset
    import random
    n_train = int(len(train_ds) * DEBUG_FRACTION)
    n_val   = int(len(val_ds) * DEBUG_FRACTION)
    subset_train_idx = random.sample(range(len(train_ds)), n_train)
    subset_val_idx   = random.sample(range(len(val_ds)), n_val)
    train_ds = Subset(train_ds, subset_train_idx)
    val_ds   = Subset(val_ds, subset_val_idx)
    print(f"Debug mode: using {n_train}/{len(train_samples)} train and {n_val}/{len(val_samples)} val samples")

train_dl = DataLoader(train_ds, batch_size=4, shuffle=True,  num_workers=2, pin_memory=True)
val_dl   = DataLoader(val_ds,   batch_size=4, shuffle=False, num_workers=2, pin_memory=True)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = UNetSmall(in_ch=4, out_ch=1).to(device)
bce, dice = nn.BCEWithLogitsLoss(), DiceLoss()
criterion = nn.BCEWithLogitsLoss() #comment
opt = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=0.01)
sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=20)

best_iou, best_path = 0.0, "unet_cloud95.pt"

Debug mode: using 2104/21041 train and 526/5260 val samples


In [15]:
import time
t0 = time.time()
print("len(train_ds) =", len(train_ds))
img, msk = train_ds[0]   # direct call bypasses DataLoader workers
print("one sample shapes:", img.shape, msk.shape, "took", round(time.time()-t0, 2), "s")

len(train_ds) = 2104
one sample shapes: torch.Size([4, 128, 128]) torch.Size([1, 128, 128]) took 0.3 s


In [16]:
from torch.utils.data import DataLoader

train_dl = DataLoader(
    train_ds,
    batch_size=8,            # smaller batch to start
    shuffle=True,
    num_workers=0,           # <-- key
    pin_memory=False,        # pinning helps GPU, can slow CPU-only
    persistent_workers=False # <-- key
)

In [25]:
print('begin training')
model.train()
t0 = time.time()

for batch_i, (imgs, masks) in enumerate(train_dl, 1):
    print(f"\rbatch {batch_i}", end='', flush=True)  # prints as soon as first batch arrives
    imgs, masks = imgs.to(device), masks.to(device)
    logits = model(imgs)
    loss = 0.5*bce(logits, masks) + 0.5*dice(logits, masks)
    opt.zero_grad(); loss.backward(); opt.step()

print("\nend of epoch in", round(time.time()-t0, 1), "s")

begin training
batch 263
end of epoch in 284.4 s


In [26]:
# Save model weights (and optional optimizer/metadata)
epoch=1
torch.save({
    "epoch": epoch,
    "best_iou": best_iou,
    "model_state": model.state_dict(),
    "optimizer_state": opt.state_dict(),
}, best_path)

print(f"Model saved as {best_path}")


Model saved as unet_cloud95.pt
