In [2]:
# ===== 0) Config =====
from pathlib import Path
import re, random, os, math
import numpy as np
from PIL import Image
import torch, torch.nn as nn, torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
print('finished imports')

finished imports


In [3]:
import multiprocessing
# Intra-op parallelism (conv/matmul kernels on CPU): use every logical core.
torch.set_num_threads(multiprocessing.cpu_count())  # e.g., 16
# Inter-op threads can only be set once per process, before any inter-op
# parallel work starts; a second call raises RuntimeError. Guard it so this
# cell can be re-run without restarting the kernel.
try:
    torch.set_num_interop_threads(4)
except RuntimeError as exc:
    print(f"skipping set_num_interop_threads: {exc}")
# 'medium' lets float32 matmuls use a faster reduced-precision internal path
# where the backend provides one (see torch.set_float32_matmul_precision docs).
torch.set_float32_matmul_precision('medium')

  _C._set_float32_matmul_precision(precision)


In [4]:
# Dataset root. Set the CLOUD95_ROOT environment variable to point at the
# dataset on another machine; otherwise fall back to the original local path.
root = Path(os.environ.get(
    "CLOUD95_ROOT",
    r"C:\Users\garvi\.cache\kagglehub\datasets\sorour\95cloud-cloud-segmentation-on-satellite-images\versions\3\95-cloud_training_only_additional_to38-cloud",
))
print('path saved to root')

path saved to root


In [5]:
# Per-band image folders plus the ground-truth mask folder, all under `root`.
RED, GREEN, BLUE, NIR, GT = (
    root / "train_red_additional_to38cloud",
    root / "train_green_additional_to38cloud",
    root / "train_blue_additional_to38cloud",
    root / "train_nir_additional_to38cloud",
    root / "train_gt_additional_to38cloud",
)
# Only the paths are defined here; no files are read yet.
print('data loaded')

data loaded


In [6]:
# Fail fast, naming exactly which expected folders are absent (or are not directories).
_missing_dirs = [d for d in (RED, GREEN, BLUE, NIR, GT) if not d.is_dir()]
assert not _missing_dirs, f"One or more band/GT folders missing: {[str(d) for d in _missing_dirs]}"

In [7]:
# ===== 1) Pair files by base key (strip '_red/_green/_blue/_nir/_gt' tokens) =====
# Accepted lower-case suffixes: band images, and masks (same list minus the JPEG variants).
IMG_EXT = tuple(".tif .tiff .png .jpg .jpeg".split())
MSK_EXT = IMG_EXT[:3]

In [8]:
# Whole band/mask tokens to drop from file stems, so that e.g. "red_patch_1_by_2"
# and "gt_patch_1_by_2" share the key "patch_1_by_2". The separators are checked
# with lookarounds instead of being consumed; consuming them made adjacent
# tokens such as "_red_gt_" skip the second token.
_token = re.compile(r"(?<![^_\-])(red|green|blue|nir|gt|mask|label)(?![^_\-])", re.I)

def norm_stem(p: Path) -> str:
    """Return the band-independent pairing key for file ``p``.

    Every ``_``/``-``-delimited token matched by ``_token`` is removed from the
    stem (case-insensitively). Runs of separators are then collapsed to one
    ``_``, leading/trailing separators stripped, and the result lower-cased.
    """
    s = _token.sub("_", p.stem)
    s = re.sub(r"[_\-]+", "_", s).strip("_-")
    return s.lower()

def list_files(folder, exts):
    """Return the regular files directly inside ``folder`` whose suffix is in ``exts``.

    The suffix comparison is case-insensitive. The result is sorted by path so
    it does not depend on the filesystem's directory order. Code that later
    picks the first file per key therefore behaves reproducibly.

    Args:
        folder: directory to scan (``Path`` or ``str``).
        exts: iterable of suffixes including the dot, e.g. ``(".tif", ".png")``.
    """
    wanted = {e.lower() for e in exts}
    return sorted(
        p for p in Path(folder).iterdir()
        if p.is_file() and p.suffix.lower() in wanted
    )

# Candidate files per band folder; masks use the narrower MSK_EXT list.
R, G, B, N = (list_files(folder, IMG_EXT) for folder in (RED, GREEN, BLUE, NIR))
M = list_files(GT, MSK_EXT)

def index(files):
    """Group ``files`` by their ``norm_stem`` key.

    Returns a plain dict mapping key -> list of paths, with paths kept in
    input order within each group.
    """
    groups = {}
    for path in files:
        key = norm_stem(path)
        if key not in groups:
            groups[key] = []
        groups[key].append(path)
    return groups

In [9]:
R_idx, G_idx, B_idx, N_idx, M_idx = (index(files) for files in (R, G, B, N, M))
# A chip is usable only if its key exists in every band folder and the mask folder.
keys = sorted(set(R_idx).intersection(G_idx, B_idx, N_idx, M_idx))
print(f"Found {len(keys)} paired chips.")

Found 26301 paired chips.


In [10]:
# One record per paired chip; where a key maps to several files, the first one is used.
pairs = [
    {"r": R_idx[k][0], "g": G_idx[k][0], "b": B_idx[k][0], "n": N_idx[k][0], "m": M_idx[k][0]}
    for k in keys
]

In [11]:
# ===== 2) Train/val split =====
SEED = 42
random.seed(SEED)        # global stdlib RNG: also drives Cloud95Dataset's random crops/flips
torch.manual_seed(SEED)  # model initialisation and DataLoader shuffling
# Shuffle a *copy* with a dedicated RNG so re-running this cell always gives the
# same split. Shuffling `pairs` in place reshuffled an already-shuffled list on
# each re-run. Random(SEED) yields the same permutation that
# random.seed(SEED); random.shuffle(pairs) produced on a fresh kernel.
shuffled_pairs = list(pairs)
random.Random(SEED).shuffle(shuffled_pairs)
val_frac = 0.2
n_val = int(len(shuffled_pairs) * val_frac)
val_samples = shuffled_pairs[:n_val]
train_samples = shuffled_pairs[n_val:]
print(f"Train: {len(train_samples)}  Val: {len(val_samples)}")

Train: 21041  Val: 5260


In [12]:
# ===== 3) Dataset =====
class Cloud95Dataset(Dataset):
    """95-Cloud chips as ``(image, mask)`` tensor pairs.

    Each entry of ``items`` is a dict with keys ``"r"``, ``"g"``, ``"b"``,
    ``"n"`` (band image paths) and ``"m"`` (ground-truth mask path).

    ``__getitem__`` returns:
      * image: float32 tensor ``(4, H, W)``, channels R, G, B, NIR, scaled to
        [0, 1] according to the band file's dtype (see ``_scale_band``);
      * mask: float32 tensor ``(1, H, W)``, 1.0 where the GT pixel is > 0.

    Args:
        items: list of path dicts as described above.
        tilesize: if set and the chip is at least this large in both
            dimensions, crop a ``tilesize`` x ``tilesize`` window. The window is
            at a random position when ``augment`` is true, centred otherwise.
            Smaller chips are returned uncropped.
        augment: enable the random crop position and random horizontal and
            vertical flips (each with probability 0.5).

    Raises (from ``__getitem__``):
        ValueError: if the four bands and the mask do not share one 2-D shape.
    """

    def __init__(self, items, tilesize=None, augment=False):
        self.items = items
        self.tilesize = tilesize
        self.augment = augment

    def __len__(self):
        return len(self.items)

    @staticmethod
    def _scale_band(arr):
        """Convert a band array to float32 in [0, 1].

        uint8 data is divided by 255, identical to the previous reader. Any
        other integer dtype is treated as 16-bit sensor data and divided by
        65535. Float/bool data is only cast to float32. A fixed per-dtype
        divisor keeps the scaling consistent across chips.
        """
        arr = np.asarray(arr)
        if arr.dtype == np.uint8:
            return arr.astype(np.float32) / 255.0
        if np.issubdtype(arr.dtype, np.integer):
            return arr.astype(np.float32) / 65535.0
        return arr.astype(np.float32)

    def _read_band(self, p: Path):
        # Read the file at its native dtype. Forcing dtype=np.uint8 here would
        # silently wrap any value > 255 (e.g. 16-bit bands) modulo 256.
        with Image.open(p) as im:
            arr = np.array(im)
        return self._scale_band(arr)

    def _read_mask(self, p: Path):
        """Binary float32 mask from the GT file; any pixel value > 0 counts as positive."""
        with Image.open(p) as im:
            m = np.array(im.convert("L"))
        return (m > 0).astype(np.float32)

    def __getitem__(self, i):
        rec = self.items[i]
        # Channel order R, G, B, NIR; the mask is read after the bands.
        bands = [self._read_band(rec[k]) for k in ("r", "g", "b", "n")]
        mask = self._read_mask(rec["m"])

        H, W = bands[0].shape
        if any(b.shape != (H, W) for b in bands) or mask.shape != (H, W):
            raise ValueError(
                f"sample {i}: band/mask shapes differ: "
                f"{[b.shape for b in bands]} vs mask {mask.shape}"
            )

        # Optional crop to tilesize x tilesize; chips smaller than the tile pass through.
        s = self.tilesize
        if s and H >= s and W >= s:
            if self.augment:
                x = 0 if W == s else random.randint(0, W - s)
                y = 0 if H == s else random.randint(0, H - s)
            else:
                x = (W - s) // 2
                y = (H - s) // 2
            bands = [b[y:y + s, x:x + s] for b in bands]
            mask = mask[y:y + s, x:x + s]

        # Light augmentation: independent 50% horizontal and vertical flips.
        if self.augment and random.random() < 0.5:
            bands = [np.fliplr(b) for b in bands]
            mask = np.fliplr(mask)
        if self.augment and random.random() < 0.5:
            bands = [np.flipud(b) for b in bands]
            mask = np.flipud(mask)

        img = np.stack(bands, axis=0).astype(np.float32)   # (4, H, W)
        msk = mask[None, ...].astype(np.float32)           # (1, H, W)
        return torch.from_numpy(img), torch.from_numpy(msk)
        

In [13]:
class DoubleConv(nn.Module):
    """Two ``3x3 conv -> GroupNorm -> ReLU`` stages; spatial size is preserved (padding=1).

    Args:
        in_ch: input channels.
        out_ch: output channels of both convolutions.
        groups: GroupNorm group count; must divide ``out_ch``.
    """
    def __init__(self, in_ch, out_ch, groups=8):
        super().__init__()
        # Convs have no bias: each is immediately followed by GroupNorm's affine shift.
        self.net = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1, bias=False),
            nn.GroupNorm(groups, out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1, bias=False),
            nn.GroupNorm(groups, out_ch),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.net(x)


class UNetSmall(nn.Module):
    """Three-level U-Net returning raw logits of shape ``(B, out_ch, H, W)``.

    Encoder: DoubleConv blocks with 24/48/96 channels separated by 2x2
    max-pooling, then a 192-channel bottleneck. Each decoder stage does a
    bilinear x2 upsample plus a 1x1 conv, concatenates the matching encoder
    skip, and applies a DoubleConv. Any H, W >= 8 are accepted.
    """
    def __init__(self, in_ch=4, out_ch=1):
        super().__init__()
        # slimmer channels help a lot on CPU
        ch = [24, 48, 96, 192]
        self.down1 = DoubleConv(in_ch, ch[0])
        self.pool1 = nn.MaxPool2d(2)
        self.down2 = DoubleConv(ch[0], ch[1])
        self.pool2 = nn.MaxPool2d(2)
        self.down3 = DoubleConv(ch[1], ch[2])
        self.pool3 = nn.MaxPool2d(2)
        self.bottom = DoubleConv(ch[2], ch[3])

        def up(in_c, skip_c, out_c):
            # Keys "up" (Upsample, 1x1 conv) and "dec" fix the state_dict layout
            # (e.g. "up3.up.1.weight"); keep them so saved checkpoints still load.
            return nn.ModuleDict({
                "up": nn.Sequential(nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False),
                                    nn.Conv2d(in_c, out_c, 1, bias=False)),
                "dec": DoubleConv(out_c + skip_c, out_c)
            })
        self.up3 = up(ch[3], ch[2], ch[2])
        self.up2 = up(ch[2], ch[1], ch[1])
        self.up1 = up(ch[1], ch[0], ch[0])
        self.head = nn.Conv2d(ch[0], out_ch, 1)

    @staticmethod
    def _decode(stage, x, skip):
        """Upsample ``x`` with ``stage["up"]``, concat with ``skip``, run ``stage["dec"]``.

        Max-pooling floors odd sizes, so for inputs not divisible by 8 the x2
        upsample comes out one pixel short of ``skip``. Only in that case it is
        resized to ``skip``'s spatial size, instead of failing in ``torch.cat``.
        """
        u = stage["up"](x)
        if u.shape[-2:] != skip.shape[-2:]:
            u = F.interpolate(u, size=skip.shape[-2:], mode="bilinear", align_corners=False)
        return stage["dec"](torch.cat([u, skip], 1))

    def forward(self, x):
        c1 = self.down1(x)
        c2 = self.down2(self.pool1(c1))
        c3 = self.down3(self.pool2(c2))
        cb = self.bottom(self.pool3(c3))

        d3 = self._decode(self.up3, cb, c3)
        d2 = self._decode(self.up2, d3, c2)
        d1 = self._decode(self.up1, d2, c1)
        return self.head(d1)

In [14]:
# ===== 4) U-Net (small) =====
# class DoubleConv(nn.Module):
#     def __init__(self, in_ch, out_ch):
#         super().__init__()
#         self.net = nn.Sequential(
#             nn.Conv2d(in_ch, out_ch, 3, padding=1), nn.BatchNorm2d(out_ch), nn.ReLU(inplace=True),
#             nn.Conv2d(out_ch, out_ch, 3, padding=1), nn.BatchNorm2d(out_ch), nn.ReLU(inplace=True),
#         )
#     def forward(self, x): return self.net(x)

# class UNetSmall(nn.Module):
#     def __init__(self, in_ch=4, out_ch=1):
#         super().__init__()
#         self.down1 = DoubleConv(in_ch, 32)
#         self.pool1 = nn.MaxPool2d(2)
#         self.down2 = DoubleConv(32, 64)
#         self.pool2 = nn.MaxPool2d(2)
#         self.down3 = DoubleConv(64, 128)
#         self.pool3 = nn.MaxPool2d(2)
#         self.bottom = DoubleConv(128, 256)
#         self.up3 = nn.ConvTranspose2d(256, 128, 2, 2)
#         self.dec3 = DoubleConv(256, 128)
#         self.up2 = nn.ConvTranspose2d(128, 64, 2, 2)
#         self.dec2 = DoubleConv(128, 64)
#         self.up1 = nn.ConvTranspose2d(64, 32, 2, 2)
#         self.dec1 = DoubleConv(64, 32)
#         self.head = nn.Conv2d(32, out_ch, 1)

    # def forward(self, x):
    #     c1 = self.down1(x)         # (B,32,H,W)
    #     p1 = self.pool1(c1)        # (B,32,H/2,W/2)
    #     c2 = self.down2(p1)        # (B,64, ...)
    #     p2 = self.pool2(c2)
    #     c3 = self.down3(p2)        # (B,128,...)
    #     p3 = self.pool3(c3)
    #     cb = self.bottom(p3)       # (B,256,...)
    #     u3 = self.up3(cb)          # -> match c3
    #     d3 = self.dec3(torch.cat([u3, c3], dim=1))
    #     u2 = self.up2(d3)
    #     d2 = self.dec2(torch.cat([u2, c2], dim=1))
    #     u1 = self.up1(d2)
    #     d1 = self.dec1(torch.cat([u1, c1], dim=1))
    #     return self.head(d1)       # logits (B,1,H,W)

In [15]:
# ===== 5) Loss, metrics, training =====
class DiceLoss(nn.Module):
    """Soft Dice loss computed from logits and averaged over the batch.

    Per sample, with ``p = sigmoid(logits)`` and sums over all non-batch dims:
    ``1 - (2*sum(p*t) + eps) / (sum(p) + sum(t) + eps)``.

    ``eps`` is added exactly once to numerator and denominator. A sample whose
    prediction and target are both empty therefore scores 0; adding it twice
    to the denominator made that case score 0.5.
    """

    def __init__(self, eps=1e-6):
        super().__init__()
        self.eps = eps

    def forward(self, logits, targets):
        probs = torch.sigmoid(logits)
        dims = tuple(range(1, probs.dim()))  # every axis except the batch axis
        num = 2 * (probs * targets).sum(dim=dims)
        den = probs.sum(dim=dims) + targets.sum(dim=dims)
        return (1 - (num + self.eps) / (den + self.eps)).mean()

def iou_binary(logits, targets, thr=0.5):
    """Batch-mean IoU of thresholded predictions against ``targets``.

    Predictions are ``sigmoid(logits) > thr``. Per sample, intersection and
    union are summed over dims 1-3 and 1e-6 is added to both, so an empty
    prediction against an empty target scores 1. Returns a Python float.
    """
    axes = (1, 2, 3)
    pred = torch.sigmoid(logits).gt(thr).float()
    overlap = (pred * targets).sum(dim=axes)
    union = (pred + targets - pred * targets).sum(dim=axes)
    per_sample = (overlap + 1e-6) / (union + 1e-6)
    return per_sample.mean().item()

In [16]:
# Datasets & loaders
# Chips at least `tile` px on each side are cropped to a tile x tile window
# (random position for training, centred for validation); smaller chips pass through.
tile = 128
train_ds = Cloud95Dataset(train_samples, tilesize=tile, augment=True)
val_ds   = Cloud95Dataset(val_samples,   tilesize=tile, augment=False)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# DataLoader workers must re-import Cloud95Dataset. With Windows' "spawn" start
# method, a class defined inside a notebook is not importable by the children,
# so load in-process there. Pinned memory only helps host->GPU copies.
num_workers = 0 if os.name == "nt" else 2
pin = device.type == "cuda"
train_dl = DataLoader(train_ds, batch_size=4, shuffle=True,  num_workers=num_workers, pin_memory=pin)
val_dl   = DataLoader(val_ds,   batch_size=4, shuffle=False, num_workers=num_workers, pin_memory=pin)

model = UNetSmall(in_ch=4, out_ch=1).to(device)
bce, dice = nn.BCEWithLogitsLoss(), DiceLoss()
criterion = nn.BCEWithLogitsLoss()  # BCE-only loss, used by the last profiling cell
EPOCHS = 20  # cosine schedule length; anneals once over this many scheduler steps
opt = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=0.01)
sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=EPOCHS)

best_iou, best_path = 0.0, "unet_cloud95.pt"

In [17]:
import time

# Fetch one sample directly (no DataLoader workers) to gauge per-sample read cost.
t0 = time.time()
print("len(train_ds) =", len(train_ds))
img, msk = train_ds[0]
elapsed = round(time.time() - t0, 2)
print("one sample shapes:", img.shape, msk.shape, "took", elapsed, "s")

len(train_ds) = 21041
one sample shapes: torch.Size([4, 128, 128]) torch.Size([1, 128, 128]) took 0.29 s


In [18]:
from torch.utils.data import DataLoader

def _make_loader(ds, shuffle):
    """Batch-size-8 DataLoader with in-process loading and no pinned memory (CPU-friendly)."""
    return DataLoader(
        ds,
        batch_size=8,
        shuffle=shuffle,
        num_workers=0,            # no worker processes
        pin_memory=False,         # pinning only pays off for GPU transfers
        persistent_workers=False,
    )

train_dl = _make_loader(train_ds, shuffle=True)
val_dl = _make_loader(val_ds, shuffle=False)

In [19]:
# One manual training epoch with the 50/50 BCE + Dice loss.
print('begin training')
model.train()
t0 = time.time()
epoch_loss, seen = 0.0, 0

for batch_i, (imgs, masks) in enumerate(train_dl, 1):
    print(f"\rbatch {batch_i}", end='', flush=True)  # prints as soon as first batch arrives
    imgs, masks = imgs.to(device), masks.to(device)
    logits = model(imgs)
    loss = 0.5*bce(logits, masks) + 0.5*dice(logits, masks)
    opt.zero_grad(); loss.backward(); opt.step()
    # `loss` is a batch mean; weight by batch size for an exact epoch mean.
    epoch_loss += loss.item() * imgs.size(0)
    seen += imgs.size(0)

print("\nend of epoch in", round(time.time()-t0, 1), "s")
if seen:
    print(f"mean train loss: {epoch_loss / seen:.4f}")
# torch.save needs a destination (the call without one raises TypeError).
# Write a separate "last epoch" checkpoint so best_path is only written by
# the validation-driven training loop.
last_path = "unet_cloud95_last.pt"
torch.save({"model": model.state_dict(), "opt": opt.state_dict()}, last_path)
print("saved", last_path)

begin training
batch 2631
end of epoch in 4862.3 s


In [None]:
print("r")  # NOTE(review): leftover debug output; this cell can be deleted.

In [None]:
# Validate the current model once and advance the LR schedule (end-of-epoch
# bookkeeping for the manual train/eval cell pair).
print('begin evaluation')
model.eval()
t0 = time.time()
v_loss = 0.0; v_iou = 0.0; n = 0
with torch.no_grad():
    for batch_i, (imgs, masks) in enumerate(val_dl, 1):
        print(f"\rbatch {batch_i}", end='', flush=True)
        imgs, masks = imgs.to(device), masks.to(device)
        logits = model(imgs)
        bs = imgs.size(0)
        # Both metrics are batch means; weight by batch size so a short final
        # batch does not skew the dataset-level average.
        v_loss += (0.5*bce(logits, masks) + 0.5*dice(logits, masks)).item() * bs
        v_iou  += iou_binary(logits, masks) * bs
        n += bs
if n == 0:
    raise RuntimeError("val_dl produced no samples")
v_loss /= n; v_iou /= n
# NOTE: every run of this cell advances the scheduler by one step.
sched.step()
# No epoch number here: `ep` only exists once the full training loop has run.
print(f"\nval_loss={v_loss:.4f}  val_IoU={v_iou:.4f}  next_lr={sched.get_last_lr()[0]:.2e}")
print("end of evaluation in", round(time.time()-t0, 1), "s")

begin evaluation




In [None]:
epochs = 20  # keep equal to the CosineAnnealingLR T_max so the LR anneals exactly once
saved_epoch = None
for ep in range(1, epochs+1):
    # --- train ---
    model.train()
    t_loss, t_n = 0.0, 0
    for imgs, masks in train_dl:
        imgs, masks = imgs.to(device), masks.to(device)
        logits = model(imgs)
        loss = 0.5*bce(logits, masks) + 0.5*dice(logits, masks)
        opt.zero_grad(); loss.backward(); opt.step()
        t_loss += loss.item() * imgs.size(0)
        t_n += imgs.size(0)

    # --- validate (sample-weighted means: a short last batch counts proportionally) ---
    model.eval()
    v_loss = 0.0; v_iou = 0.0; n = 0
    with torch.no_grad():
        for imgs, masks in val_dl:
            imgs, masks = imgs.to(device), masks.to(device)
            logits = model(imgs)
            bs = imgs.size(0)
            v_loss += (0.5*bce(logits, masks) + 0.5*dice(logits, masks)).item() * bs
            v_iou  += iou_binary(logits, masks) * bs
            n += bs
    if n == 0:
        raise RuntimeError("val_dl produced no samples")
    v_loss /= n; v_iou /= n
    t_loss /= max(t_n, 1)
    sched.step()
    print(f"Epoch {ep:02d}  train_loss={t_loss:.4f}  val_loss={v_loss:.4f}  val_IoU={v_iou:.4f}")

    if v_iou > best_iou:
        best_iou = v_iou
        saved_epoch = ep
        torch.save({"model": model.state_dict()}, best_path)

if saved_epoch is None:
    print(f"No epoch beat best_iou={best_iou:.4f}; {best_path} was not written by this loop")
else:
    print(f"Best IoU: {best_iou:.4f} (epoch {saved_epoch:02d})  -> saved {best_path}")

In [None]:
# ===== 6) Load best model and evaluate on validation set =====
print("\nLoading best model for final evaluation...")
# torch.load unpickles the file; only load checkpoints this notebook wrote.
checkpoint = torch.load(best_path, map_location=device)
model.load_state_dict(checkpoint["model"])
model.eval()

metric_eps = 1e-6
ious, dices = [], []
with torch.no_grad():
    for imgs, masks in val_dl:
        imgs, masks = imgs.to(device), masks.to(device)
        preds = (torch.sigmoid(model(imgs)) > 0.5).float()
        inter = (preds * masks).sum(dim=(1,2,3))
        union = (preds + masks - preds*masks).sum(dim=(1,2,3))
        total = preds.sum(dim=(1,2,3)) + masks.sum(dim=(1,2,3))
        # Named dice_score, not `dice`. `dice` is the DiceLoss instance the
        # training cells call, and rebinding it to a tensor breaks any re-run.
        # eps also goes in the numerator so empty-vs-empty scores 1, like IoU.
        dice_score = (2*inter + metric_eps) / (total + metric_eps)
        iou        = (inter + metric_eps) / (union + metric_eps)
        ious.extend(iou.cpu().tolist())
        dices.extend(dice_score.cpu().tolist())

if not ious:
    raise RuntimeError("val_dl produced no samples")
final_iou = np.mean(ious)
final_dice = np.mean(dices)
print(f"✅ Final Validation IoU:  {final_iou:.4f}")
print(f"✅ Final Validation Dice: {final_dice:.4f}")


In [17]:
import time

# Profile the training step: average seconds per batch spent loading data,
# in forward + loss, and in backward + optimizer step.
# NOTE: these are real optimizer steps, so this cell updates the model weights.
# CUDA kernels run asynchronously; synchronize so the timers cover the work.
_sync = torch.cuda.synchronize if device.type == "cuda" else (lambda: None)

model.train()  # an earlier eval cell may have left the model in eval mode
load_t = fwd_t = bwd_t = 0.0
n_batches = min(50, len(train_dl))  # not `N`: that name holds the NIR file list

it = iter(train_dl)
for i in range(1, n_batches+1):
    t0 = time.perf_counter()
    imgs, masks = next(it)     # load
    load_t += time.perf_counter() - t0

    imgs, masks = imgs.to(device), masks.to(device)
    _sync()
    t1 = time.perf_counter()
    logits = model(imgs)       # forward
    loss = 0.5*bce(logits, masks) + 0.5*dice(logits, masks)
    _sync()
    fwd_t += time.perf_counter() - t1

    t2 = time.perf_counter()
    opt.zero_grad(); loss.backward(); opt.step()  # backward
    _sync()
    bwd_t += time.perf_counter() - t2

denom = max(n_batches, 1)
print(f"avg load  : {load_t/denom:.3f}s  | avg fwd : {fwd_t/denom:.3f}s  | avg bwd : {bwd_t/denom:.3f}s")

avg load  : 0.214s  | avg fwd : 0.913s  | avg bwd : 1.851s


In [None]:
import time

# Same per-batch profile as above, but with the BCE-only `criterion`, for
# comparison with the BCE + Dice loss. NOTE: performs real optimizer steps.
_sync = torch.cuda.synchronize if device.type == "cuda" else (lambda: None)
model.train()
load_t = fwd_t = bwd_t = 0.0
it = iter(train_dl)
n_batches = min(50, len(train_dl))  # not `N`: that name holds the NIR file list
for i in range(1, n_batches+1):
    t0 = time.perf_counter(); imgs, masks = next(it); load_t += time.perf_counter() - t0
    # Move the batch to the model's device; without this the forward fails on CUDA.
    imgs, masks = imgs.to(device), masks.to(device)
    _sync()
    t1 = time.perf_counter(); logits = model(imgs); loss = criterion(logits, masks); _sync(); fwd_t += time.perf_counter() - t1
    t2 = time.perf_counter(); opt.zero_grad(); loss.backward(); opt.step(); _sync(); bwd_t += time.perf_counter() - t2
denom = max(n_batches, 1)
print(f"avg load: {load_t/denom:.3f}s | avg fwd: {fwd_t/denom:.3f}s | avg bwd: {bwd_t/denom:.3f}s")

