In [1]:
# train_ViT_train.py
"""
Training script for CloudSegSpectralViT.
Place this file in same folder as Vit.py and data_loader.py and run:
    python train_ViT_train.py
"""

'\nTraining script for CloudSegSpectralViT.\nPlace this file in same folder as Vit.py and data_loader.py and run:\n    python train_ViT_train.py\n'

In [2]:
import os
from pathlib import Path
import pickle
import random
import time
from typing import List, Dict

In [3]:
import numpy as np
from PIL import Image

In [4]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim

In [5]:
# ---- adjust these if needed ----
SEED = 42                # fixed seed so shuffling, random crops and weight init are repeatable
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
BATCH_SIZE = 8           # change to fit GPU memory
NUM_EPOCHS = 2
LR = 1e-4
WEIGHT_DECAY = 1e-5
NUM_WORKERS = 0
PATCH_DIV = 8            # model patch size in Vit.py default (8) -> ensure image dims divisible by this
CHECKPOINT_DIR = Path("checkpoints")
# parents=True so a nested checkpoint path can be configured without a manual mkdir
CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True)

# Seed every RNG this notebook draws from: `random` (dataset crops),
# NumPy, and torch (weight init, DataLoader shuffling).
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
# --------------------------------

In [6]:
# import your model (do not modify Vit.py)
from Vit import CloudSegSpectralViT

In [7]:
# ---------------------------------------------------------
# Dataset: reads the dict produced by your data_loader.py
# ---------------------------------------------------------
class Cloud95Dataset(Dataset):
    """Four-band (R, G, B, NIR) tiles with a binary cloud mask.

    Each item is a pair ``(image, mask)``:
      * ``image`` - float32 tensor of shape (4, H, W), every band min-max
        scaled to [0, 1] independently per tile;
      * ``mask``  - int64 tensor of shape (H, W), 1 where the mask file is > 0.

    H and W are cropped (bottom/right) to a multiple of the module-level
    ``PATCH_DIV``.
    """

    def __init__(self, samples: List[Dict], transforms=None, crop_size=None):
        """
        samples: list of dicts with keys "r","g","b","n","m" where values are Path objects / strings.
        transforms: optional callable(image_np, mask_np) -> (image_np, mask_np);
            receives the (H, W, 4) float32 image and (H, W) uint8 mask after
            normalisation and the optional random crop.
        crop_size: optional tuple (H, W) to randomly crop tiles (must be divisible by PATCH_DIV)
        """
        self.samples = samples
        self.transforms = transforms
        self.crop_size = crop_size

    def __len__(self):
        return len(self.samples)

    def _open_band(self, p):
        """Read one image file and return its pixels as a float32 array."""
        # PIL loads lazily, so the array must be materialised inside the
        # `with` block; the context manager then releases the file handle.
        with Image.open(p) as im:
            return np.array(im).astype(np.float32)

    def __getitem__(self, idx):
        s = self.samples[idx]
        bands = [self._open_band(s[key]) for key in ("r", "g", "b", "n")]
        m = self._open_band(s["m"])  # mask

        # stack channels -> (H, W, 4)
        img = np.stack(bands, axis=-1).astype(np.float32)

        # Per-band min-max normalisation to [0, 1]. A constant band has zero
        # range: dividing by 1 instead leaves it at all zeros.
        mn = img.min(axis=(0, 1), keepdims=True)
        mx = img.max(axis=(0, 1), keepdims=True)
        scale = np.where(mx > mn, mx - mn, np.float32(1.0))
        img = (img - mn) / scale

        # mask -> ensure single channel and 0/1 labels
        if m.ndim == 3:
            m = m[..., 0]
        mask = (m > 0).astype(np.uint8)  # cloud=1, background=0

        H, W, _ = img.shape

        # Optional random crop if crop_size provided
        if self.crop_size is not None:
            ch, cw = self.crop_size
            if H < ch or W < cw:
                # zero-pad bottom/right so the tile is at least crop-sized
                pad_h = max(0, ch - H)
                pad_w = max(0, cw - W)
                img = np.pad(img, ((0, pad_h), (0, pad_w), (0, 0)), mode="constant", constant_values=0)
                mask = np.pad(mask, ((0, pad_h), (0, pad_w)), mode="constant", constant_values=0)
                H, W = img.shape[:2]

            top = random.randint(0, H - ch)
            left = random.randint(0, W - cw)
            img = img[top:top + ch, left:left + cw]
            mask = mask[top:top + ch, left:left + cw]

        # User augmentation hook. Runs before the divisibility crop so any
        # shape change it makes is still corrected below.
        if self.transforms is not None:
            img, mask = self.transforms(img, mask)

        # final safety: crop bottom/right so both dims are divisible by PATCH_DIV
        H, W = img.shape[:2]
        H2 = H - H % PATCH_DIV
        W2 = W - W % PATCH_DIV
        if H2 != H or W2 != W:
            img = img[:H2, :W2]
            mask = mask[:H2, :W2]

        # C,H,W layout; ascontiguousarray guarantees positive, dense strides
        # (torch.from_numpy rejects negatively-strided views such as flips).
        img_t = torch.from_numpy(np.ascontiguousarray(img.transpose(2, 0, 1), dtype=np.float32))
        # CrossEntropyLoss expects long (int64) class indices
        mask_t = torch.from_numpy(np.ascontiguousarray(mask, dtype=np.int64))

        return img_t, mask_t

In [8]:
# ---------------------------------------------------------
# simple IoU metric
# ---------------------------------------------------------
def batch_iou(pred_logits, target, threshold=0.5, ignore_background=False):
    """Mean per-sample IoU of a batch of segmentation predictions.

    pred_logits: (B, C, H, W) logits
    target: (B, H, W) ints 0..C-1
    threshold: probability cut-off for class 1. Only consulted when C == 2 and
        the value differs from 0.5; otherwise the prediction is the plain
        argmax over classes (equivalent to a 0.5 cut-off for two classes).
    ignore_background: if True, score only class 1 (cloud); if False
        (default), score the mean of class-0 and class-1 IoU.

    Returns a single Python float: the per-sample score averaged over the
    batch. A class absent from both prediction and target counts as IoU 1.0.
    An empty batch yields NaN.
    """
    with torch.no_grad():
        batch_size = target.shape[0]
        if batch_size == 0:
            return float("nan")

        if pred_logits.shape[1] == 2 and threshold != 0.5:
            cloud_prob = torch.softmax(pred_logits, dim=1)[:, 1]
            preds = (cloud_prob > threshold).long()
        else:
            # softmax is monotonic, so argmax over raw logits is sufficient
            preds = pred_logits.argmax(dim=1)  # (B,H,W)

        # flatten() copies only when needed, so non-contiguous inputs are fine
        pred_flat = preds.flatten(start_dim=1)
        tgt_flat = target.flatten(start_dim=1)

        def _counts(cls):
            """Per-sample intersection and union pixel counts for one class."""
            pred_c = pred_flat == cls
            tgt_c = tgt_flat == cls
            return (pred_c & tgt_c).sum(dim=1), (pred_c | tgt_c).sum(dim=1)

        inter1, union1 = _counts(1)
        inter0, union0 = _counts(0)

        # One device->host transfer for the whole batch instead of four
        # .item() syncs per sample.
        stats = torch.stack([inter0, union0, inter1, union1]).cpu().numpy().astype(np.float64)

        def _ratio(inter, union):
            # empty union -> class absent in both prediction and target -> 1.0
            return np.where(union > 0, inter / np.maximum(union, 1.0), 1.0)

        iou0 = _ratio(stats[0], stats[1])
        iou1 = _ratio(stats[2], stats[3])
        per_sample = iou1 if ignore_background else (iou0 + iou1) / 2.0
        return float(np.mean(per_sample))

In [9]:
# ---------------------------------------------------------
# load train/val split produced by data_loader.py
# ---------------------------------------------------------
SPLIT_PKL = Path("train_val_split.pkl")
if not SPLIT_PKL.exists():
    raise FileNotFoundError("train_val_split.pkl not found. Run data_loader.py first to create it.")

# NOTE: pickle.load can execute arbitrary code embedded in the file - only
# load a split file that you generated yourself with data_loader.py.
with open(SPLIT_PKL, "rb") as f:
    splits = pickle.load(f)

# Fail here with a clear message instead of a bare KeyError / IndexError later.
missing_keys = [key for key in ("train", "val") if key not in splits]
if missing_keys:
    raise KeyError(f"{SPLIT_PKL} is missing split(s): {missing_keys}")
train_samples = splits["train"]
val_samples = splits["val"]
if len(train_samples) == 0:
    raise ValueError(f"{SPLIT_PKL} contains an empty 'train' split")

print(f"Loaded splits -> train: {len(train_samples)}, val: {len(val_samples)}")

# Optional: small debug subset
# train_samples = train_samples[:200]
# val_samples = val_samples[:100]

# Create datasets / dataloaders
# If your tiles are already target-size (e.g., 256x256), set crop_size=None.
# If you want random crops of size 256, set crop_size=(256,256)
CROP_SIZE = None  # or (256,256) or (128,128) as per your tile size
train_ds = Cloud95Dataset(train_samples, crop_size=CROP_SIZE)
val_ds = Cloud95Dataset(val_samples, crop_size=None)  # keep full tiles at val

# Pinned host memory only helps host->GPU copies; requesting it on a CPU-only
# machine just triggers a warning.
PIN_MEMORY = DEVICE == "cuda"
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)
val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)


Loaded splits -> train: 21041, val: 5260


In [11]:
# ---------------------------------------------------------
# Build model, loss, optimizer
# ---------------------------------------------------------
# Infer the channel count from the first training sample so the model input
# always matches what the dataset actually yields.
sample_img, _ = train_ds[0]
in_ch = sample_img.shape[0]
print("Detected input channels:", in_ch)

# patch_size is tied to PATCH_DIV: the dataset crops every tile to a multiple
# of PATCH_DIV, so the two values must not drift apart.
model = CloudSegSpectralViT(
    in_ch=in_ch,
    num_classes=2,
    embed_dim=768,
    depth=8,
    num_heads=8,
    patch_size=(PATCH_DIV, PATCH_DIV),
    spec_group=3,
)
model = model.to(DEVICE)

criterion = nn.CrossEntropyLoss()  # expects logits and target (B,H,W) with values 0..C-1
optimizer = optim.AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
# mode="max": the training loop steps this scheduler on validation IoU,
# where higher is better.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="max", factor=0.5, patience=3)

# Mixed precision: the gradient scaler is only created when running on CUDA;
# `scaler is None` selects the plain fp32 path in the training loop.
scaler = torch.cuda.amp.GradScaler() if DEVICE.startswith("cuda") else None

Detected input channels: 4


In [None]:
# ---------------------------------------------------------
# training loop
# ---------------------------------------------------------
best_val_iou = 0.0
start_time = time.time()
num_train_batches = len(train_loader)
print('begin training')
for epoch in range(1, NUM_EPOCHS + 1):
    # ---- train ----
    model.train()
    running_loss = 0.0
    t0 = time.time()
    for batch_i, (imgs, masks) in enumerate(train_loader, start=1):
        imgs = imgs.to(DEVICE, non_blocking=True)
        masks = masks.to(DEVICE, non_blocking=True)  # shape (B,H,W)
        optimizer.zero_grad()

        if scaler is not None:
            # mixed-precision path: scale the loss so fp16 gradients do not underflow
            with torch.cuda.amp.autocast():
                logits = model(imgs)  # (B,2,H,W)
                loss = criterion(logits, masks)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            logits = model(imgs)
            loss = criterion(logits, masks)
            loss.backward()
            optimizer.step()

        running_loss += loss.item()
        # Single in-place progress line (first batch, every 20th, and last);
        # two separate "\r" prints would overwrite each other.
        if batch_i == 1 or batch_i % 20 == 0 or batch_i == num_train_batches:
            print(
                f"\rEpoch {epoch} | Batch {batch_i}/{num_train_batches} | loss: {running_loss / batch_i:.4f}",
                end="",
                flush=True,
            )

    # max(..., 1) avoids ZeroDivisionError if the train loader is empty
    epoch_loss = running_loss / max(num_train_batches, 1)
    t1 = time.time()

    # ---- validate ----
    model.eval()
    val_losses = []
    val_ious = []
    val_batch_sizes = []
    with torch.no_grad():
        for imgs, masks in val_loader:
            imgs = imgs.to(DEVICE, non_blocking=True)
            masks = masks.to(DEVICE, non_blocking=True)
            logits = model(imgs)
            loss = criterion(logits, masks)
            val_losses.append(loss.item())
            val_ious.append(batch_iou(logits, masks))
            val_batch_sizes.append(imgs.shape[0])

    # Weight each batch by its size so a short final batch is not over-counted.
    if val_batch_sizes:
        mean_val_loss = float(np.average(val_losses, weights=val_batch_sizes))
        mean_val_iou = float(np.average(val_ious, weights=val_batch_sizes))
    else:
        mean_val_loss = float("nan")
        mean_val_iou = float("nan")

    print(f"\nEpoch {epoch} finished in {(t1-t0):.1f}s | train_loss={epoch_loss:.4f} val_loss={mean_val_loss:.4f} val_iou={mean_val_iou:.4f}")

    # scheduler step (ReduceLROnPlateau)
    if scheduler is not None:
        scheduler.step(mean_val_iou)

    # Save everything needed to resume: besides model/optimizer, the LR
    # scheduler and AMP scaler carry state of their own.
    ckpt = {
        "epoch": epoch,
        "model_state": model.state_dict(),
        "optim_state": optimizer.state_dict(),
        "sched_state": scheduler.state_dict() if scheduler is not None else None,
        "scaler_state": scaler.state_dict() if scaler is not None else None,
        "val_iou": mean_val_iou
    }
    ckpt_path = CHECKPOINT_DIR / f"ckpt_epoch{epoch:03d}_iou{mean_val_iou:.4f}.pth"
    torch.save(ckpt, ckpt_path)

    # keep best
    if mean_val_iou > best_val_iou:
        best_val_iou = mean_val_iou
        best_path = CHECKPOINT_DIR / "best_model.pth"
        torch.save(ckpt, best_path)
        print(f"  -> New best model saved (val_iou={best_val_iou:.4f})")

total_time = time.time() - start_time
print(f"Training completed in {total_time/60:.1f} minutes. Best val IoU: {best_val_iou:.4f}")

begin training




batch 1