# LEVIR-CD | Dense vs SP-in (30 epochs each)

This notebook trains **two** change-detection models (Dense baseline and **SP-in**) on a LEVIR-CD–style folder:

```
root/
  A/      # t1 images
  B/      # t2 images
  label/  # binary masks (0/255 or 0/1)
```
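or nested as `root/train/{A,B,label}`, `root/val/{...}`, `root/test/{...}`. With the flat layout, images are sorted by filename and split 80% / 10% / 10% into train / val / test.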

On Colab it mounts Google Drive, looks for a likely dataset path, and saves outputs (checkpoints, metrics, overlays) to `MyDrive/LEVIR_CD_runs/run_YYYYmmdd_HHMMSS/` on your Drive (or to `/content/outputs/run_YYYYmmdd_HHMMSS/` if Drive is not available).

In [None]:
import os, math, json, time, argparse, sys
from typing import Tuple, List, Dict, Optional

import numpy as np
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

def in_colab() -> bool:
    """Report whether the notebook runs on Google Colab (``google.colab`` importable)."""
    try:
        import google.colab  # type: ignore  # noqa: F401
    except Exception:
        return False
    else:
        return True

def set_seed(seed: int = 42):
    """Seed Python's ``random``, NumPy's global RNG and PyTorch (CPU + every CUDA device)."""
    import random
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)

def colab_mount_and_resolve_paths(user_root: Optional[str]) -> Tuple[Optional[str], str]:
    """Mount Drive on Colab and locate the LEVIR-CD root and an output base directory.

    Outside Colab nothing is touched and ``(user_root, "/content")`` is returned.

    On Colab:
      * Google Drive is mounted at ``/content/drive``. A failure is reported and
        then ignored, so local paths can still be used.
      * The dataset root is the first of ``user_root``, ``MyDrive/LEVIR-CD``,
        ``MyDrive/datasets/LEVIR-CD`` and ``/content/LEVIR-CD`` that contains an
        ``A/`` or a ``train/A/`` folder. If none matches it is ``None``.
      * The output base is ``MyDrive/LEVIR_CD_runs`` when ``/content/drive/MyDrive``
        exists, otherwise ``/content/outputs``. It is created if missing.

    Returns:
        ``(resolved_root_or_None, out_dir)``.
    """
    if not in_colab():
        return user_root, "/content"

    try:
        from google.colab import drive  # type: ignore
        drive.mount("/content/drive", force_remount=False)
    except Exception as exc:
        # Best-effort: keep going with local paths, but say why Drive is missing.
        print(f"[Colab] Drive mount failed ({exc}); falling back to local paths.")

    def looks_like_levir_root(path: str) -> bool:
        # Flat layout (root/A) or nested layout (root/train/A).
        return os.path.isdir(path) and (
            os.path.isdir(os.path.join(path, "A"))
            or os.path.isdir(os.path.join(path, "train", "A"))
        )

    candidates: List[str] = [
        p
        for p in (
            user_root,
            "/content/drive/MyDrive/LEVIR-CD",
            "/content/drive/MyDrive/datasets/LEVIR-CD",
            "/content/LEVIR-CD",
        )
        if p
    ]
    resolved_root: Optional[str] = next((p for p in candidates if looks_like_levir_root(p)), None)

    out_dir = (
        "/content/drive/MyDrive/LEVIR_CD_runs"
        if os.path.isdir("/content/drive/MyDrive")
        else "/content/outputs"
    )
    os.makedirs(out_dir, exist_ok=True)
    if resolved_root is None:
        print("[Colab] Could not auto-locate LEVIR-CD. Please set ROOT manually in the next cell.")
    return resolved_root, out_dir

# Quick environment sanity check: PyTorch version and whether a CUDA GPU is visible.
print(f"Torch: {torch.__version__} CUDA: {torch.cuda.is_available()}")

## 1) Configure paths & hyperparameters
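- On Colab the cell checks `ROOT` first. If `ROOT` has no `A/` (or `train/A/`) folder, the cell tries a few common locations (`MyDrive/LEVIR-CD`, `MyDrive/datasets/LEVIR-CD`, `/content/LEVIR-CD`).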
- Change `ROOT` to your dataset folder if needed.
- `EPOCHS_DENSE` and `EPOCHS_SP` are both 30 by default.

In [None]:
ROOT = "/content/drive/MyDrive/LEVIR-CD"
RESIZE = 256        # images and masks are resized to RESIZE x RESIZE
BATCH_SIZE = 8
LR = 1e-3           # Adam learning rate
EPOCHS_DENSE = 30
EPOCHS_SP = 30
SEED = 42

set_seed(SEED)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
resolved_root, default_out_dir = colab_mount_and_resolve_paths(ROOT)
ROOT = resolved_root or ROOT
# Fail fast, before an empty run folder is created, if the dataset path is wrong.
# ROOT is always a non-empty string here, so an `is not None` check could never fire.
if not ROOT or not os.path.isdir(ROOT):
    raise FileNotFoundError(
        "Please set ROOT to the folder containing A/B/label (or nested train/val/test). "
        f"Got: {ROOT!r}"
    )
OUT_BASE = default_out_dir
RUN_DIR = os.path.join(OUT_BASE, time.strftime("run_%Y%m%d_%H%M%S"))
os.makedirs(RUN_DIR, exist_ok=True)
print("ROOT:", ROOT)
print("RUN_DIR:", RUN_DIR)

## 2) Dataset (LEVIR-CD style)

In [None]:
class LevirCD(Dataset):
    """LEVIR-CD style change-detection dataset.

    Two layouts are supported:
      * nested: ``root/<split>/{A,B,label}``. Every image of that split folder is used.
      * flat: ``root/{A,B,label}``. Images are sorted by filename and split
        80% / 10% / 10% into train / val / test.

    Only filenames present in all of A, B and label are used.

    Each item is ``(a, b, y, name)``:
      * ``a`` / ``b``: float tensors 3xHxW scaled to [0, 1];
      * ``y``: float tensor 1xHxW with values {0, 1};
      * ``name``: the file name.

    Random horizontal/vertical flips are applied only when ``aug`` is set and
    ``split == "train"``.
    """

    def __init__(self, root: str, split: str = "train", resize: int = 256, aug: bool = True):
        super().__init__()
        self.root = root
        self.resize = resize
        self.aug = aug and (split == "train")

        cand = os.path.join(root, split)
        nested = os.path.isdir(cand) and os.path.isdir(os.path.join(cand, "A"))
        base = cand if nested else root
        a_dir = os.path.join(base, "A")
        b_dir = os.path.join(base, "B")
        l_dir = os.path.join(base, "label")

        names = set(os.listdir(a_dir)) & set(os.listdir(b_dir)) & set(os.listdir(l_dir))
        names = sorted(n for n in names if n.lower().endswith((".png", ".jpg", ".jpeg", ".tif", ".bmp")))

        if nested:
            # The folder already *is* the requested split; re-splitting it would
            # silently discard most of val/test and 20% of train.
            self.names = names
        else:
            n = len(names)
            n_train = int(0.8 * n)
            n_val = int(0.1 * n)
            if split == "train":
                self.names = names[:n_train]
            elif split == "val":
                self.names = names[n_train:n_train + n_val]
            else:
                self.names = names[n_train + n_val:]
        self.a_dir, self.b_dir, self.l_dir = a_dir, b_dir, l_dir

    def __len__(self):
        return len(self.names)

    def _load_img(self, path: str) -> np.ndarray:
        """Load an RGB image (bilinear-resized if ``resize`` is set) as HxWx3 uint8."""
        img = Image.open(path).convert("RGB")
        if self.resize is not None:
            img = img.resize((self.resize, self.resize), Image.BILINEAR)
        return np.array(img, dtype=np.uint8)

    def _load_mask(self, path: str) -> np.ndarray:
        """Load a mask (nearest-resized) as HxW uint8 in {0, 1}; 0/255 masks are thresholded at 127."""
        m = Image.open(path).convert("L")
        if self.resize is not None:
            m = m.resize((self.resize, self.resize), Image.NEAREST)
        arr = np.array(m, dtype=np.uint8)
        if arr.max() > 1:
            arr = (arr > 127).astype(np.uint8)
        return arr

    def __getitem__(self, idx):
        name = self.names[idx]
        a = self._load_img(os.path.join(self.a_dir, name))
        b = self._load_img(os.path.join(self.b_dir, name))
        y = self._load_mask(os.path.join(self.l_dir, name))
        if self.aug:
            # The same flip is applied to both images and the mask to keep them aligned.
            if np.random.rand() < 0.5:
                a = np.flip(a, axis=1).copy(); b = np.flip(b, axis=1).copy(); y = np.flip(y, axis=1).copy()
            if np.random.rand() < 0.5:
                a = np.flip(a, axis=0).copy(); b = np.flip(b, axis=0).copy(); y = np.flip(y, axis=0).copy()
        a = torch.from_numpy(a).permute(2, 0, 1).float() / 255.0
        b = torch.from_numpy(b).permute(2, 0, 1).float() / 255.0
        y = torch.from_numpy(y).float().unsqueeze(0)
        return a, b, y, name

train_set = LevirCD(ROOT, split="train", resize=RESIZE, aug=True)
val_set = LevirCD(ROOT, split="val", resize=RESIZE, aug=False)
test_set = LevirCD(ROOT, split="test", resize=RESIZE, aug=False)
print("Train/Val/Test sizes:", *(len(ds) for ds in (train_set, val_set, test_set)))

# Loader knobs come from the environment: CD_WORKERS (default 0 worker processes)
# and CD_PIN_MEMORY ("1" enables pinned host memory).
workers = int(os.getenv("CD_WORKERS", "0"))
pin_mem = os.getenv("CD_PIN_MEMORY", "0") == "1"
_loader_kw = dict(batch_size=BATCH_SIZE, num_workers=workers, pin_memory=pin_mem)
train_loader = DataLoader(train_set, shuffle=True, **_loader_kw)
val_loader = DataLoader(val_set, shuffle=False, **_loader_kw)
test_loader = DataLoader(test_set, shuffle=False, **_loader_kw)

## 3) SP-in Conv + FC-Siam-conc model

In [None]:
class SPInConv2d(nn.Module):
    """Conv2d whose fan-in (input channels) is pruned per output channel ("SP-in").

    Forward: masked conv -> fan-in rescale -> ReLU -> optional BatchNorm.

    In training mode every forward call:
      * Updates ``activation_counts``, an EMA (rate ``alpha``) of each output
        channel's firing rate. The rate is the fraction of ReLU outputs > 0,
        measured *before* BatchNorm, because after BN the sign no longer means "fired".
      * After ``warmup_steps`` calls, and then every ``update_interval`` calls,
        prunes the mask:
          - channels with EMA > ``tau_high`` keep about ``dens_freq`` of ``in_ch`` inputs;
          - channels with EMA < ``tau_low`` keep about ``dens_rare``;
          - pruning never goes below ``min_fan`` inputs.
        Inputs are ranked by mean |weight| over the kernel, only among inputs
        that are still alive, so pruning is one-way (no regrowth).

    ``step`` / ``_counter`` count training-mode forward calls, not optimizer steps.
    NOTE(review): in the Siamese encoder each layer runs once per branch, so a
    mask update can fall between the t1 and t2 passes of the same batch.
    """

    def __init__(self, in_ch, out_ch, k=3, s=1, p=1, bias=True,
                 alpha=0.05, update_interval=50, warmup_steps=1000,
                 tau_high=0.7, tau_low=0.1, dens_freq=0.4, dens_rare=0.9,
                 min_fan=4, rescale_mode='sqrt', use_bn=True):
        super().__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, k, s, p, bias=bias)
        self.bn = nn.BatchNorm2d(out_ch) if use_bn else None
        self.alpha = float(alpha)
        self.update_interval = int(update_interval)
        self.warmup_steps = int(warmup_steps)
        self.tau_h = float(tau_high)
        self.tau_l = float(tau_low)
        self.d_f = float(dens_freq)
        self.d_r = float(dens_rare)
        self.min_fan = int(min_fan)
        self.rescale_mode = rescale_mode
        # mask: [out, in, 1, 1], broadcast over the kernel; 1 = input kept.
        self.register_buffer("mask", torch.ones(out_ch, in_ch, 1, 1))
        # EMA of per-output-channel firing rate (see class docstring).
        self.register_buffer("activation_counts", torch.zeros(out_ch))
        self.register_buffer("step", torch.zeros((), dtype=torch.long), persistent=False)
        self._counter = 0

    def _rescale(self, pre):
        """Scale each output channel by in_ch/fan ('linear') or its sqrt (default)."""
        if self.rescale_mode == "none":
            return pre
        fan = self.mask.view(self.mask.size(0), -1).sum(dim=1).clamp_min(1.0)
        base = float(self.mask.size(1))
        if self.rescale_mode == "linear":
            s = base / fan
        else:
            s = (base / fan).sqrt()
        return pre * s.view(1, -1, 1, 1).to(pre.device, pre.dtype)

    @torch.no_grad()
    def _update_mask(self):
        """Prune the fan-in of frequent / rare output channels (see class docstring)."""
        avg = self.activation_counts
        cur = self.mask.detach().clone()                                 # [out, in, 1, 1]
        w_mag = self.conv.weight.detach().abs().flatten(2).mean(dim=2)  # [out, in]
        in_ch = cur.size(1)
        freq = (avg > self.tau_h).nonzero(as_tuple=True)[0].tolist()
        rare = (avg < self.tau_l).nonzero(as_tuple=True)[0].tolist()

        def apply_row(i, density):
            alive = cur[i, :, 0, 0].bool()
            k = max(self.min_fan, int(round(in_ch * float(density))))
            k = min(k, int(alive.sum().item()))
            if k <= 0:
                return
            cand = w_mag[i][alive]
            if cand.numel() <= k:
                return  # already at or below the target fan-in
            # k-th largest alive magnitude. No epsilon: adding one made the k-th
            # input itself fail `>= thr` for small weights, keeping only k-1.
            thr = cand.topk(k, largest=True).values[-1]
            cur[i, :, 0, 0] = (alive & (w_mag[i] >= thr)).to(cur.dtype)

        for i in freq:
            apply_row(i, self.d_f)
        for i in rare:
            apply_row(i, self.d_r)
        # Rebind rather than copy in place: the old mask tensor is saved in this
        # step's autograd graph (weight * mask) and must not change before backward.
        self.mask = cur

    def forward(self, x):
        m = self.mask.to(dtype=self.conv.weight.dtype, device=self.conv.weight.device)
        w = self.conv.weight * m.expand_as(self.conv.weight)
        pre = F.conv2d(x, w, self.conv.bias, stride=self.conv.stride,
                       padding=self.conv.padding, dilation=self.conv.dilation, groups=self.conv.groups)
        fired = F.relu(self._rescale(pre), inplace=False)
        h = self.bn(fired) if self.bn is not None else fired
        if self.training:
            with torch.no_grad():
                # Firing rate on the ReLU output, before BN re-centres it.
                act = (fired > 0).to(fired.dtype).mean(dim=(0, 2, 3))
                self.activation_counts.mul_(1 - self.alpha).add_(self.alpha * act)
                self.step += 1
                self._counter += 1
                if (int(self.step.item()) >= self.warmup_steps and self.update_interval > 0
                        and self._counter % self.update_interval == 0):
                    self._update_mask()
        return h

class ConvBlock(nn.Module):
    """Two 3x3 conv stages (``c1`` then ``c2``), each producing ``out_ch`` channels.

    ``c1`` is an :class:`SPInConv2d` when ``use_sp`` is set (``sp_kwargs`` are
    forwarded to it); otherwise it is Conv-BN-ReLU. ``c2`` is always Conv-BN-ReLU.
    """

    def __init__(self, in_ch, out_ch, use_sp=False, sp_kwargs=None):
        super().__init__()
        if not use_sp:
            first = nn.Sequential(
                nn.Conv2d(in_ch, out_ch, 3, padding=1, bias=False),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
            )
        else:
            first = SPInConv2d(in_ch, out_ch, k=3, p=1, use_bn=True, **(sp_kwargs or {}))
        self.c1 = first
        self.c2 = nn.Sequential(
            nn.Conv2d(out_ch, out_ch, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        for stage in (self.c1, self.c2):
            x = stage(x)
        return x

class SiameseEncoder(nn.Module):
    """Weight-shared encoder applied to both images of a pair.

    One ConvBlock plus a 2x2 max-pool per entry of ``feats``. The pre-pool
    output of every block is kept as a skip feature.
    """

    def __init__(self, in_ch=3, feats=(32,64,128,256), use_sp=False, sp_kwargs=None):
        super().__init__()
        self.blocks = nn.ModuleList()
        self.pools = nn.ModuleList()
        widths = (in_ch, *feats)
        for c_in, c_out in zip(widths, widths[1:]):
            self.blocks.append(ConvBlock(c_in, c_out, use_sp=use_sp, sp_kwargs=sp_kwargs))
            self.pools.append(nn.MaxPool2d(2))

    def forward_once(self, x):
        """Encode one image; return (skip features shallow->deep, pooled bottleneck)."""
        skips = []
        for idx, block in enumerate(self.blocks):
            x = block(x)
            skips.append(x)
            x = self.pools[idx](x)
        return skips, x

    def forward(self, a, b):
        """Encode both images with the same weights; return (fa, xa, fb, xb)."""
        fa, xa = self.forward_once(a)
        fb, xb = self.forward_once(b)
        return fa, xa, fb, xb

class Decoder(nn.Module):
    """FC-Siam-conc decoder.

    Starts from the channel-concatenated bottlenecks of both branches. At each
    scale, deepest first, it:
      1. upsamples 2x with a transposed conv;
      2. concatenates the two encoder skip features of that scale;
      3. fuses the result with a dense ConvBlock.
    A final 1x1 conv outputs a single logit channel.
    """

    def __init__(self, feats=(32,64,128,256)):
        super().__init__()
        self.upconvs = nn.ModuleList()
        self.blocks = nn.ModuleList()
        in_ch = feats[-1] * 2
        for width in feats[::-1]:
            self.upconvs.append(nn.ConvTranspose2d(in_ch, width, 2, stride=2))
            # upsampled features + skip of branch A + skip of branch B
            self.blocks.append(ConvBlock(width * 3, width, use_sp=False))
            in_ch = width
        self.final = nn.Conv2d(feats[0], 1, 1)

    def forward(self, fa, xa, fb, xb):
        x = torch.cat([xa, xb], dim=1)
        for up, fuse, skip_a, skip_b in zip(self.upconvs, self.blocks, reversed(fa), reversed(fb)):
            x = fuse(torch.cat([up(x), skip_a, skip_b], dim=1))
        return self.final(x)

class FCSiamConc(nn.Module):
    """FC-Siam-conc change detector: shared encoder + concatenating decoder -> 1-channel logits."""

    def __init__(self, in_ch=3, feats=(32,64,128,256), use_sp=False, sp_kwargs=None):
        super().__init__()
        self.encoder = SiameseEncoder(in_ch=in_ch, feats=feats, use_sp=use_sp, sp_kwargs=sp_kwargs)
        self.decoder = Decoder(feats=feats)

    def forward(self, a, b):
        encoded = self.encoder(a, b)  # (fa, xa, fb, xb)
        return self.decoder(*encoded)


## 4) Metrics & helpers

In [None]:
def binarize(logits: torch.Tensor, thr: float = 0.5) -> torch.Tensor:
    """Threshold sigmoid probabilities at ``thr`` (inclusive) into a uint8 {0, 1} tensor."""
    probs = torch.sigmoid(logits)
    return probs.ge(thr).to(dtype=torch.uint8)

def iou_f1(pred: np.ndarray, gt: np.ndarray):
    """Return ``(IoU, F1)`` of the positive class (value 1) for one image.

    When neither ``pred`` nor ``gt`` contains any positive pixel, the prediction
    is a correct "no change", so both scores are 1.0. Without this, every
    change-free image that was predicted perfectly would pull per-image means
    down to 0.
    """
    pos_p = pred == 1
    pos_g = gt == 1
    inter = np.logical_and(pos_p, pos_g).sum()
    union = np.logical_or(pos_p, pos_g).sum()
    if union == 0:
        return 1.0, 1.0
    fp = np.logical_and(pos_p, gt == 0).sum()
    fn = np.logical_and(pred == 0, pos_g).sum()
    iou = inter / (union + 1e-6)
    f1 = (2 * inter) / (2 * inter + fp + fn + 1e-6)
    return float(iou), float(f1)

@torch.no_grad()
def evaluate(model, loader, device, thr=0.5):
    """Return ``{"mean_IoU", "mean_F1"}`` averaged per image over ``loader``.

    NOTE: the training cell below redefines ``evaluate`` (adding predictions
    and names to the result), and that later definition is the one used for training.
    """
    model.eval()
    scores = []
    for a, b, y, _ in loader:
        logits = model(a.to(device), b.to(device))
        pred = binarize(logits, thr=thr).cpu().numpy().astype(np.uint8)
        gt = (y.cpu().numpy() > 0.5).astype(np.uint8)
        scores.extend(iou_f1(p[0], g[0]) for p, g in zip(pred, gt))
    if not scores:
        return {"mean_IoU": 0.0, "mean_F1": 0.0}
    ious, f1s = zip(*scores)
    return {"mean_IoU": float(np.mean(ious)),
            "mean_F1": float(np.mean(f1s))}

def overlay_example(img_rgb: np.ndarray, gt: np.ndarray, pred_dense: np.ndarray, pred_sp: np.ndarray):
    """Tint the true-change pixels on which the Dense and SP-in models agree or differ.

    Colours are blended at 50 % alpha, and only where ``gt == 1``:
      * yellow: both models predict change;
      * green: only SP-in predicts change;
      * red: only Dense predicts change.
    All other pixels keep their original colour.

    Args:
        img_rgb: HxWx3 uint8 image.
        gt, pred_dense, pred_sp: HxW arrays with values {0, 1}.

    Returns:
        HxWx3 uint8 overlay.
    """
    hit = gt == 1
    layers = (
        ((pred_dense == 1) & (pred_sp == 1) & hit, (255, 255, 0)),  # yellow
        ((pred_sp == 1) & (pred_dense == 0) & hit, (0, 255, 0)),    # green
        ((pred_dense == 1) & (pred_sp == 0) & hit, (255, 0, 0)),    # red
    )
    canvas = img_rgb.astype(np.float32)
    for mask, color in layers:
        tint = np.asarray(color, dtype=np.float32)
        canvas = np.where(mask[:, :, None], 0.5 * canvas + 0.5 * tint, canvas)
    return canvas.clip(0, 255).astype(np.uint8)


## 5) Train loops (Dense / SP-in)

In [None]:
from tqdm import tqdm

def train_one_epoch(model, loader, optimizer, device, epoch=None, total_epochs=None, method=None):
    """Train ``model`` for one pass over ``loader`` with BCE-with-logits loss.

    A tqdm bar shows the running sample-weighted mean loss and the first param
    group's learning rate. Returns that mean loss (0.0 for an empty loader).
    """
    model.train()
    total, seen = 0.0, 0
    label = f"Train[{method}] {epoch}/{total_epochs}" if method else "Train"
    bar = tqdm(loader, desc=label, dynamic_ncols=True, leave=True)
    for a, b, y, _ in bar:
        a = a.to(device)
        b = b.to(device)
        y = y.to(device)
        loss = F.binary_cross_entropy_with_logits(model(a, b), y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch = a.size(0)
        total += float(loss.item()) * batch
        seen += batch
        current_lr = optimizer.param_groups[0]['lr']
        bar.set_postfix(loss=f"{total / max(1, seen):.4f}", lr=f"{current_lr:.4g}")
    bar.close()
    return total / max(1, seen)

@torch.no_grad()
def evaluate(model, loader, device, thr=0.5, desc=None):
    """Evaluate a change-detection model on ``loader`` (no bucketing).

    Metrics:
      * ``mean_IoU`` / ``mean_F1``: per-image scores averaged over images. An
        image with neither predicted nor true change scores 1.0, since "no
        change" was predicted correctly.
      * ``global_IoU`` / ``global_F1``: TP/FP/FN pooled over every pixel of
        the loader (the usual LEVIR-CD reporting style).
    Every metric is 0.0 for an empty loader.

    Also returns:
      * ``preds`` / ``gts``: uint8 arrays of shape NxHxW, or None if the loader is empty;
      * ``names``: the sample names.
    A tqdm bar is shown only when ``desc`` is given.
    """
    model.eval()
    iou_all, f1_all = [], []
    tp_sum = fp_sum = fn_sum = 0
    all_preds, all_gts, all_names = [], [], []
    it = tqdm(loader, desc=desc, dynamic_ncols=True, leave=True) if desc else loader
    for a, b, y, names in it:
        a, b, y = a.to(device), b.to(device), y.to(device)
        logits = model(a, b)
        pred = (torch.sigmoid(logits) >= thr).to(torch.uint8).cpu().numpy()
        gt = (y.cpu().numpy() > 0.5).astype(np.uint8)
        for p, g in zip(pred[:, 0] == 1, gt[:, 0] == 1):
            tp = int(np.count_nonzero(p & g))
            fp = int(np.count_nonzero(p & ~g))
            fn = int(np.count_nonzero(~p & g))
            tp_sum += tp
            fp_sum += fp
            fn_sum += fn
            if tp + fp + fn == 0:
                iou = f1 = 1.0  # empty GT and empty prediction: correct "no change"
            else:
                iou = tp / (tp + fp + fn + 1e-6)
                f1 = (2 * tp) / (2 * tp + fp + fn + 1e-6)
            iou_all.append(float(iou))
            f1_all.append(float(f1))
        all_preds.append(pred[:, 0])
        all_gts.append(gt[:, 0])
        all_names.extend(names)

    denom = tp_sum + fp_sum + fn_sum
    if not iou_all:
        global_iou = global_f1 = 0.0
    elif denom == 0:
        global_iou = global_f1 = 1.0
    else:
        global_iou = tp_sum / denom
        global_f1 = (2 * tp_sum) / (2 * tp_sum + fp_sum + fn_sum)
    return {
        "mean_IoU": float(np.mean(iou_all)) if iou_all else 0.0,
        "mean_F1": float(np.mean(f1_all)) if f1_all else 0.0,
        "global_IoU": float(global_iou),
        "global_F1": float(global_f1),
        "preds": np.concatenate(all_preds, axis=0) if all_preds else None,
        "gts": np.concatenate(all_gts, axis=0) if all_gts else None,
        "names": all_names,
    }
def _json_metrics(res):
    """Keep only the scalar (JSON-serialisable) entries of an ``evaluate`` result."""
    return {k: v for k, v in res.items() if isinstance(v, (int, float))}


def train_and_eval(method: str, use_sp: bool, epochs: int, run_subdir: str, sp_kwargs=None):
    """Train one FC-Siam-conc variant, keep the best-val-F1 checkpoint, and test it.

    Files written under ``RUN_DIR/run_subdir``:
      * ``last.pth``, ``best.pth``: model state_dicts;
      * ``val_best.json``, ``test_metrics.json``: scalar metrics;
      * ``epoch_log.jsonl``: one line per epoch.

    Uses the module-level ``device``, ``LR`` and the train/val/test loaders.

    Returns:
        ``(save_dir, model_with_best_weights, test_result)``. ``test_result`` is
        the full ``evaluate`` dict, including prediction arrays.
    """
    sp_kwargs = sp_kwargs or {}
    save_dir = os.path.join(RUN_DIR, run_subdir)
    os.makedirs(save_dir, exist_ok=True)
    model = FCSiamConc(in_ch=3, feats=(32,64,128,256), use_sp=use_sp, sp_kwargs=sp_kwargs).to(device)
    optim = torch.optim.Adam(model.parameters(), lr=LR)
    best_f1, best_path = -1.0, os.path.join(save_dir, "best.pth")
    last_path = os.path.join(save_dir, "last.pth")
    log = []
    for ep in range(1, epochs + 1):
        tr_loss = train_one_epoch(model, train_loader, optim, device,
                                  epoch=ep, total_epochs=epochs, method=method)
        val_res = evaluate(model, val_loader, device, desc=f"Val[{method}] {ep}/{epochs}")
        # val_res also carries numpy prediction arrays, which json cannot serialise.
        val_metrics = _json_metrics(val_res)
        log.append({"epoch": ep, "train_loss": tr_loss, **val_metrics})
        tqdm.write(f"[{method}] Epoch {ep:03d} | Train {tr_loss:.4f} | Val F1 {val_res['mean_F1']:.4f} | IoU {val_res['mean_IoU']:.4f}")

        torch.save(model.state_dict(), last_path)
        if val_res["mean_F1"] > best_f1:
            best_f1 = val_res["mean_F1"]
            torch.save(model.state_dict(), best_path)
            with open(os.path.join(save_dir, "val_best.json"), "w") as f:
                json.dump({"epoch": ep, **val_metrics}, f, indent=2)
        with open(os.path.join(save_dir, "epoch_log.jsonl"), "a") as f:
            f.write(json.dumps(log[-1]) + "\n")

    # With epochs < 1 no checkpoint exists; test the freshly initialised model instead.
    if os.path.exists(best_path):
        model.load_state_dict(torch.load(best_path, map_location=device))
    test_res = evaluate(model, test_loader, device, desc=f"Test[{method}]")
    with open(os.path.join(save_dir, "test_metrics.json"), "w") as f:
        json.dump(_json_metrics(test_res), f, indent=2)
    print(f"[{method}] Test F1 {test_res['mean_F1']:.4f} | IoU {test_res['mean_IoU']:.4f}")
    return save_dir, model, test_res


# SP-in default hyperparameters (safe starting point). Keys are SPInConv2d keyword
# arguments; update_interval / warmup_steps count training-mode forward calls.
SP_KW = {
    "alpha": 0.05,
    "update_interval": 25,
    "warmup_steps": 40,
    "tau_high": 0.7,
    "tau_low": 0.1,
    "dens_freq": 0.4,
    "dens_rare": 0.9,
    "min_fan": 4,
    "rescale_mode": 'sqrt',
}

dense_dir, dense_model, dense_test = train_and_eval(
    "dense", use_sp=False, epochs=EPOCHS_DENSE, run_subdir="dense")
sp_dir, sp_model, sp_test = train_and_eval(
    "sp_in", use_sp=True, epochs=EPOCHS_SP, run_subdir="sp_in", sp_kwargs=SP_KW)


## 6) Compare on test & produce a “best example” overlay

In [None]:
@torch.no_grad()
def pick_best_example_and_save(dense_model, sp_model, loader, save_dir):
    """Find the image where SP-in gains most over Dense and save its overlay.

    Per-image score = (# true-change pixels found only by SP-in)
                    - (# true-change pixels found only by Dense).
    The highest-scoring image gets an ``overlay_example`` drawn on the t2 image.
    The overlay is written to ``save_dir/best_example_overlay.png``, and
    ``save_dir`` is created if needed. Uses the module-level ``device``.

    Returns:
        ``(png_path, sample_name, score)``, or ``(None, None, 0)`` if the loader
        yields no images.
    """
    dense_model.eval()
    sp_model.eval()
    best_score, best_name, best_img = -1e9, None, None
    for a, b, y, names in loader:
        a, b, y = a.to(device), b.to(device), y.to(device)
        pred_d = binarize(dense_model(a, b), 0.5).cpu().numpy().astype(np.uint8)
        pred_s = binarize(sp_model(a, b), 0.5).cpu().numpy().astype(np.uint8)
        gt = (y.cpu().numpy() > 0.5).astype(np.uint8)
        b_np = (b.cpu().numpy() * 255).astype(np.uint8).transpose(0, 2, 3, 1)
        for i in range(pred_d.shape[0]):
            d, s, g = pred_d[i, 0], pred_s[i, 0], gt[i, 0]
            sp_only = int(np.count_nonzero((s == 1) & (d == 0) & (g == 1)))
            dense_only = int(np.count_nonzero((d == 1) & (s == 0) & (g == 1)))
            score = sp_only - dense_only
            if score > best_score:
                best_score, best_name = score, names[i]
                best_img = overlay_example(b_np[i], g, d, s)
    if best_img is None:
        return None, None, 0
    os.makedirs(save_dir, exist_ok=True)
    out_png = os.path.join(save_dir, "best_example_overlay.png")
    Image.fromarray(best_img).save(out_png)
    return out_png, best_name, int(best_score)

overlay_png, sample_name, score = pick_best_example_and_save(dense_model, sp_model, test_loader, RUN_DIR)

# evaluate() results also carry numpy prediction arrays and file names; json
# cannot serialise ndarrays, so only the scalar metrics go into the summary.
summary = {
    "dense_test": {k: v for k, v in dense_test.items() if isinstance(v, (int, float))},
    "sp_in_test": {k: v for k, v in sp_test.items() if isinstance(v, (int, float))},
    "best_example": {
        "file": overlay_png,
        "sample": sample_name,
        "score_sp_only_minus_dense_only": score
    }
}
with open(os.path.join(RUN_DIR, "comparison_summary.json"), "w") as f:
    json.dump(summary, f, indent=2)
print("Saved comparison summary to:", os.path.join(RUN_DIR, "comparison_summary.json"))
print("Best example overlay:", overlay_png)
summary

In [None]:
# @title LEVIR-CD | Dense vs SP-in (30 epochs each) – Colab-ready
# -*- coding: utf-8 -*-
import os, sys, math, time, random, json
from typing import Optional, Tuple, List, Dict

import numpy as np
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from tqdm import tqdm


# =========================
# Configs (edit here if needed)
# =========================
CFG = {
    # LEVIR-CD root: must contain train/, val/ and test/, each with A/, B/ and label/.
    "root": "/content/drive/MyDrive/LEVIR-CD",
    "resize": 256,
    "batch_size": 8,
    "epochs_per_method": 30,
    "lr": 1e-3,
    "seed": 42,
    "out_dir": "/content/drive/MyDrive/LEVIR_CD_runs",  # where results are saved
    # SP-in hyperparameters (conservative defaults); keys match SPInConv2d keyword arguments.
    # Step counts are training-mode forward calls of the layer. The Siamese encoder runs each
    # layer twice per batch (t1 and t2), so 10 calls are about 5 batches.
    "sp": {
        "alpha": 0.05,           # EMA rate of the per-channel firing statistics
        "update_interval": 10,   # prune the mask every 10 forward calls
        "warmup_steps": 25,      # no pruning during the first 25 forward calls
        "tau_high": 0.7,         # firing EMA above this -> "frequent" channel
        "tau_low": 0.3,          # firing EMA below this -> "rare" channel
        "dens_freq": 0.5,        # frequent channels keep ~50% of their inputs (sparser)
        "dens_rare": 0.9,        # rare channels keep ~90% of their inputs (denser)
        "min_fan": 4,            # pruning never leaves fewer than 4 inputs per output channel
        "rescale_mode": "sqrt",  # scale outputs by sqrt(in_ch / kept inputs)
    },
}


# =========================
# Colab helpers
# =========================
def in_colab() -> bool:
    """True when ``google.colab`` imports, i.e. the notebook is running on Colab."""
    try:
        import google.colab  # type: ignore  # noqa: F401
    except Exception:
        available = False
    else:
        available = True
    return available

def mount_drive_if_colab():
    """Mount Google Drive at ``/content/drive`` when running on Colab; no-op elsewhere.

    A mount failure is reported rather than raised, so the notebook can still
    run against local paths.
    """
    if not in_colab():
        return
    from google.colab import drive  # type: ignore
    try:
        drive.mount("/content/drive", force_remount=False)
    except Exception as exc:
        # Best-effort mount, but don't hide why Drive paths will be missing.
        print(f"[Colab] Drive mount failed ({exc}); continuing without /content/drive.")


# =========================
# Repro
# =========================
def set_seed(seed: int = 42):
    """Make runs repeatable by seeding ``random``, NumPy and PyTorch (CPU and all GPUs)."""
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all)
    for seeder in seeders:
        seeder(seed)


# =========================
# Dataset
# =========================
class LevirCD(Dataset):
    """
    LEVIR-CD reader for one split.

    Required layout:
      root/
        train/ A|B|label
        val/   A|B|label
        test/  A|B|label

    Matching file names in A, B and label form one sample. Items are
    ``(a, b, y, name)``: ``a``/``b`` are 3xHxW float in [0, 1], and ``y`` is
    1xHxW float in {0, 1}. Random flips are applied only for the train split
    with ``aug`` enabled.
    """

    def __init__(self, root: str, split: str = "train", resize: int = 256, aug: bool = True):
        super().__init__()
        base = os.path.join(root, split)
        dirs = {sub: os.path.join(base, sub) for sub in ("A", "B", "label")}
        if not all(os.path.isdir(path) for path in dirs.values()):
            raise FileNotFoundError(f"[{split}] expected folders A/B/label under: {base}")

        common = set(os.listdir(dirs["A"]))
        common &= set(os.listdir(dirs["B"]))
        common &= set(os.listdir(dirs["label"]))
        exts = (".png", ".jpg", ".jpeg", ".tif", ".bmp")
        self.names = sorted(n for n in common if n.lower().endswith(exts))
        if not self.names:
            raise FileNotFoundError(f"No images found under {base}/A|B|label")
        self.a_dir, self.b_dir, self.l_dir = dirs["A"], dirs["B"], dirs["label"]
        self.resize = resize
        self.aug = aug and (split == "train")

    def __len__(self):
        return len(self.names)

    def _load_img(self, path: str) -> np.ndarray:
        """RGB image as HxWx3 uint8, bilinear-resized to a square when ``resize`` is set."""
        img = Image.open(path).convert("RGB")
        if self.resize is not None:
            side = self.resize
            img = img.resize((side, side), Image.BILINEAR)
        return np.array(img, dtype=np.uint8)

    def _load_mask(self, path: str) -> np.ndarray:
        """Mask as HxW uint8 in {0, 1}; masks with values > 1 are thresholded at 127."""
        m = Image.open(path).convert("L")
        if self.resize is not None:
            side = self.resize
            m = m.resize((side, side), Image.NEAREST)
        arr = np.array(m, dtype=np.uint8)
        return (arr > 127).astype(np.uint8) if arr.max() > 1 else arr

    def __getitem__(self, idx):
        name = self.names[idx]
        a = self._load_img(os.path.join(self.a_dir, name))
        b = self._load_img(os.path.join(self.b_dir, name))
        y = self._load_mask(os.path.join(self.l_dir, name))

        if self.aug:
            # Horizontal flip first, then vertical; each with probability 0.5,
            # applied identically to both images and the mask.
            for axis in (1, 0):
                if np.random.rand() < 0.5:
                    a, b, y = (np.flip(t, axis=axis).copy() for t in (a, b, y))

        a, b = (torch.from_numpy(t).permute(2, 0, 1).float() / 255.0 for t in (a, b))
        y = torch.from_numpy(y).float().unsqueeze(0)
        return a, b, y, name


# =========================
# SP-in for Conv2d (fan-in mask)
# =========================
class SPInConv2d(nn.Module):
    """
    Dynamic fan-in (input-channel) sparsity for Conv2d.

    Forward: masked conv -> fan-in rescale -> ReLU -> optional BatchNorm.
      - Tracks ``act_ema``, an EMA of each output channel's firing rate. The
        rate is the fraction of ReLU outputs > 0, measured *before* BatchNorm,
        because after BN the sign no longer means "fired".
      - Channels firing often (EMA > tau_high) are pruned to about dens_freq of
        the inputs. Channels firing rarely (EMA < tau_low) are pruned to about
        dens_rare. Pruning never goes below min_fan inputs.
      - Top-k by mean |weight| is taken only over the inputs that are currently
        alive (prune-only, no regrowth), which keeps training stable.
      - The mask has shape [out, in, 1, 1] and is broadcast over the kernel.

    ``step`` / ``_tick`` count training-mode forward calls, not optimizer steps.
    """

    def __init__(self, in_ch, out_ch, k=3, s=1, p=1, bias=True,
                 alpha=0.05, update_interval=50, warmup_steps=1000,
                 tau_high=0.7, tau_low=0.1, dens_freq=0.4, dens_rare=0.9,
                 min_fan=4, rescale_mode='sqrt', use_bn=True):
        super().__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, k, s, p, bias=bias)
        self.use_bn = use_bn
        self.bn = nn.BatchNorm2d(out_ch) if use_bn else None

        self.alpha = float(alpha)
        self.update_interval = int(update_interval)
        self.warmup_steps = int(warmup_steps)
        self.tau_h = float(tau_high)
        self.tau_l = float(tau_low)
        self.d_f = float(dens_freq)
        self.d_r = float(dens_rare)
        self.min_fan = int(min_fan)
        self.rescale_mode = rescale_mode

        self.register_buffer("mask", torch.ones(out_ch, in_ch, 1, 1))
        self.register_buffer("act_ema", torch.zeros(out_ch))
        self.register_buffer("step", torch.zeros((), dtype=torch.long), persistent=False)
        self._tick = 0

    def _rescale(self, pre):
        if self.rescale_mode == "none":
            return pre
        # number of fan-in inputs retained for each output channel
        fan = self.mask.view(self.mask.size(0), -1).sum(dim=1).clamp_min(1.0)
        base = float(self.mask.size(1))
        if self.rescale_mode == "linear":
            s = base / fan
        else:
            s = (base / fan).sqrt()
        return pre * s.view(1, -1, 1, 1).to(pre.device, pre.dtype)

    @torch.no_grad()
    def _update_mask(self):
        # only update after warmup and every update_interval steps
        if int(self.step.item()) < self.warmup_steps: return
        if self.update_interval <= 0: return
        if (self._tick % self.update_interval) != 0: return

        cur = self.mask.detach().clone()                     # [out,in,1,1]
        W = self.conv.weight.detach().abs()                  # [out,in,kh,kw]
        W2 = W.view(W.size(0), W.size(1), -1).mean(dim=2)    # [out,in]
        in_ch = cur.size(1)

        freq = (self.act_ema > self.tau_h).nonzero(as_tuple=True)[0].tolist()
        rare = (self.act_ema < self.tau_l).nonzero(as_tuple=True)[0].tolist()

        def apply_row(i, density):
            alive = cur[i,:,0,0].bool()
            k = max(self.min_fan, int(round(in_ch * float(density))))
            k = min(k, int(alive.sum().item()))
            if k <= 0: return
            cand = W2[i][alive]
            if cand.numel() <= k:
                return  # already at or below the target fan-in
            thr = cand.topk(k, largest=True).values[-1]
            cur[i,:,0,0] = (alive & (W2[i] >= thr)).to(cur.dtype)

        for i in freq: apply_row(i, self.d_f)
        for i in rare: apply_row(i, self.d_r)

        # Rebind instead of copying in place: the previous mask tensor is saved in
        # this step's autograd graph (weight * mask) and must stay unchanged until backward.
        self.mask = cur

    def forward(self, x):
        m = self.mask.to(dtype=self.conv.weight.dtype, device=self.conv.weight.device)
        w = self.conv.weight * m.expand_as(self.conv.weight)
        pre = F.conv2d(x, w, self.conv.bias, stride=self.conv.stride,
                       padding=self.conv.padding, dilation=self.conv.dilation, groups=self.conv.groups)
        fired = F.relu(self._rescale(pre), inplace=False)
        h = self.bn(fired) if self.use_bn else fired

        if self.training:
            with torch.no_grad():
                # ratio of activated units per output channel after ReLU (before BN), as EMA
                act = (fired > 0).to(fired.dtype).mean(dim=(0,2,3))
                self.act_ema.mul_(1 - self.alpha).add_(self.alpha * act)
                self.step += 1
                self._tick += 1
                self._update_mask()
        return h


# =========================
# FC-Siam-conc
# =========================
class ConvBlock(nn.Module):
    """Two 3x3 conv stages.

    ``c1`` is an SP-in layer when ``use_sp`` is set (``sp_kwargs`` are forwarded
    to it); otherwise it is Conv-BN-ReLU. ``c2`` is always Conv-BN-ReLU.
    """

    def __init__(self, in_ch, out_ch, use_sp=False, sp_kwargs=None):
        super().__init__()
        kwargs = sp_kwargs or {}
        self.c1 = (
            SPInConv2d(in_ch, out_ch, k=3, p=1, use_bn=True, **kwargs)
            if use_sp
            else nn.Sequential(
                nn.Conv2d(in_ch, out_ch, 3, padding=1, bias=False),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
            )
        )
        self.c2 = nn.Sequential(
            nn.Conv2d(out_ch, out_ch, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        h = self.c1(x)
        return self.c2(h)


class SiameseEncoder(nn.Module):
    """Shared-weight encoder for an image pair.

    Each entry of ``feats`` adds one ConvBlock followed by a 2x2 max-pool. The
    pre-pool output of every block is returned as a skip feature.
    """

    def __init__(self, in_ch=3, feats=(32,64,128,256), use_sp=False, sp_kwargs=None):
        super().__init__()
        self.blocks = nn.ModuleList()
        self.pools  = nn.ModuleList()
        prev = in_ch
        for width in feats:
            self.blocks.append(ConvBlock(prev, width, use_sp=use_sp, sp_kwargs=sp_kwargs))
            self.pools.append(nn.MaxPool2d(2))
            prev = width

    def forward_once(self, x):
        """Run one image; return (skips ordered shallow->deep, pooled bottleneck)."""
        skips = []
        for i in range(len(self.blocks)):
            x = self.blocks[i](x)
            skips.append(x)
            x = self.pools[i](x)
        return skips, x

    def forward(self, a, b):
        """Return (fa, xa, fb, xb) for the t1 and t2 images."""
        out_a = self.forward_once(a)
        out_b = self.forward_once(b)
        return (*out_a, *out_b)


class Decoder(nn.Module):
    """U-Net style decoder for FC-Siam-conc.

    Starting from the channel-concatenated deepest pooled maps of both dates,
    each level upsamples 2x with a transposed conv, concatenates the matching
    t1 and t2 skip features, and refines with a dense ConvBlock. A final 1x1
    conv emits one logit channel.

    Inputs whose height/width are not multiples of 2**len(feats) are supported:
    max-pooling floors odd sizes, so the upsampled map can be one pixel smaller
    than its skip. It is then resized to the skip's size before concatenation.
    When the sizes already match (e.g. 256x256 inputs) nothing changes.
    """

    def __init__(self, feats=(32,64,128,256)):
        super().__init__()
        self.up  = nn.ModuleList()
        self.blk = nn.ModuleList()
        ch = feats[-1]*2  # bottleneck = concat of both branches' deepest maps
        for f in reversed(feats):
            self.up.append(nn.ConvTranspose2d(ch, f, 2, stride=2))
            self.blk.append(ConvBlock(f*3, f, use_sp=False))  # input = [up, skip_t1, skip_t2]
            ch = f
        self.final = nn.Conv2d(feats[0], 1, 1)

    def forward(self, fa, xa, fb, xb):
        """Decode encoder outputs (fa, xa, fb, xb) into per-pixel change logits."""
        x = torch.cat([xa, xb], dim=1)
        for i, (up, blk) in enumerate(zip(self.up, self.blk)):
            x = up(x)
            fa_i = fa[-(i+1)]; fb_i = fb[-(i+1)]
            if x.shape[-2:] != fa_i.shape[-2:]:
                # odd spatial size upstream: align to the skip before concatenating
                x = F.interpolate(x, size=tuple(fa_i.shape[-2:]),
                                  mode="bilinear", align_corners=False)
            x = torch.cat([x, fa_i, fb_i], dim=1)
            x = blk(x)
        return self.final(x)


class FCSiamConc(nn.Module):
    """FC-Siam-conc change detector: shared encoder + concatenation decoder.

    ``use_sp``/``sp_kwargs`` are forwarded to the encoder only; the decoder is
    always dense. ``forward(a, b)`` returns one-channel change logits.
    """

    def __init__(self, in_ch=3, feats=(32,64,128,256), use_sp=False, sp_kwargs=None):
        super().__init__()
        self.encoder = SiameseEncoder(
            in_ch=in_ch, feats=feats, use_sp=use_sp, sp_kwargs=sp_kwargs
        )
        self.decoder = Decoder(feats=feats)

    def forward(self, a, b):
        encoded = self.encoder(a, b)  # (fa, xa, fb, xb)
        return self.decoder(*encoded)


# =========================
# Metrics & helpers
# =========================
def binarize(logits: torch.Tensor, thr: float = 0.5) -> torch.Tensor:
    """Turn logits into a uint8 {0,1} mask: 1 where sigmoid(logits) >= thr."""
    probs = logits.sigmoid()
    return probs.ge(thr).to(dtype=torch.uint8)

def iou_f1(pred: np.ndarray, gt: np.ndarray, empty_value: float = 0.0) -> Tuple[float, float]:
    """Per-image IoU and F1 of the positive (changed) class.

    A pixel counts as positive iff its value == 1.

    Args:
        pred: predicted mask.
        gt: ground-truth mask, same shape as ``pred``.
        empty_value: score returned for BOTH metrics when neither ``pred`` nor
            ``gt`` has any positive pixel (IoU and F1 are 0/0 there). The
            default 0.0 reproduces the old epsilon-based result. Pass 1.0 to
            score a correct "no change" prediction as perfect.

    Returns:
        (iou, f1) as Python floats; exact ratios when defined.
    """
    p = pred == 1
    g = gt == 1
    tp = int(np.count_nonzero(p & g))
    fp = int(np.count_nonzero(p & ~g))
    fn = int(np.count_nonzero(~p & g))
    union = tp + fp + fn  # == |pred ∪ gt|
    if union == 0:
        # nothing predicted and nothing to find: metrics undefined, use the policy value
        return float(empty_value), float(empty_value)
    return tp / union, (2 * tp) / (2 * tp + fp + fn)

@torch.no_grad()
def evaluate(model, loader, device, thr=0.5):
    """Evaluate ``model`` on ``loader`` with per-image IoU/F1.

    Puts the model in eval mode (it is not switched back).

    Args:
        model: change-detection model taking (a, b) and returning logits [B,1,H,W].
        loader: yields (a, b, y, names) batches.
        device: device the model lives on.
        thr: sigmoid threshold for a positive prediction.

    Returns:
        dict with "mean_IoU"/"mean_F1" (means over images, 0.0 if empty),
        "preds"/"gts" (uint8 {0,1} arrays [N,H,W], or None if empty) and
        "names" (sample names in loader order).
    """
    model.eval()
    ious, f1s = [], []
    all_preds, all_gts, all_names = [], [], []
    for a, b, y, names in loader:
        logits = model(a.to(device), b.to(device))
        pred = binarize(logits, thr=thr).cpu().numpy().astype(np.uint8)
        # Labels are only thresholded on the host, so they never need to visit `device`.
        gt = (y.cpu().numpy() > 0.5).astype(np.uint8)
        for i in range(pred.shape[0]):
            iou, f1 = iou_f1(pred[i, 0], gt[i, 0])
            ious.append(iou); f1s.append(f1)
        all_preds.append(pred[:, 0])
        all_gts.append(gt[:, 0])
        all_names.extend(list(names))
    return {
        "mean_IoU": float(np.mean(ious)) if ious else 0.0,
        "mean_F1": float(np.mean(f1s)) if f1s else 0.0,
        "preds": np.concatenate(all_preds, axis=0) if all_preds else None,
        "gts":   np.concatenate(all_gts, axis=0) if all_gts else None,
        "names": all_names,
    }

def train_one_epoch(model, loader, optimizer, device):
    """Run one optimisation pass with pixel-wise BCE-with-logits loss.

    Args:
        model: model taking (a, b) and returning logits shaped like ``y``.
        loader: yields (a, b, y, names) batches.
        optimizer: optimizer over ``model``'s parameters.
        device: device the model lives on.

    Returns:
        Sample-weighted mean training loss (0.0 for an empty loader).
    """
    model.train()
    total_loss = 0.0
    seen = 0
    for a, b, y, _ in tqdm(loader, desc="Train", leave=False):
        a = a.to(device)
        b = b.to(device)
        y = y.to(device)
        loss = F.binary_cross_entropy_with_logits(model(a, b), y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        batch = a.size(0)
        total_loss += float(loss.item()) * batch
        seen += batch
    return total_loss / max(1, seen)

def choose_perfect_example(a_dir, b_dir, name,
                           gt: np.ndarray, pd_dense: np.ndarray, pd_spin: np.ndarray) -> Image.Image:
    """Render Dense-vs-SP-in hits on the t2 (B) image with a labelled legend strip.

    Only pixels where GT == 1 are coloured:
      cyan  (0,255,255): both models hit
      green (0,255,0):   SP-in hit, Dense miss
      red   (255,0,0):   Dense hit, SP-in miss
    All other pixels (both miss, or GT == 0) get no colour. They show the B
    image blended 50/50 with black, i.e. darkened.

    Args:
        a_dir: unused; kept for call compatibility.
        b_dir: folder containing the B image ``name``.
        name: image file name.
        gt, pd_dense, pd_spin: {0,1} masks of identical shape (H, W); B is
            resized (bilinear) to that size.

    Returns:
        RGB image of size (W, 70 + H): legend strip stacked above the overlay.
    """
    from PIL import ImageDraw  # local import keeps this helper self-contained

    H, W = gt.shape
    with Image.open(os.path.join(b_dir, name)) as im:
        base = np.array(im.convert("RGB").resize((W, H), Image.BILINEAR), dtype=np.uint8)

    hit_dense = (pd_dense == 1) & (gt == 1)
    hit_spin = (pd_spin == 1) & (gt == 1)
    legend_items = (
        ((0, 255, 255), "Both"),
        ((0, 255, 0), "SP-in only"),
        ((255, 0, 0), "Dense only"),
    )
    overlay = np.zeros((H, W, 3), dtype=np.uint8)
    overlay[hit_dense & hit_spin] = legend_items[0][0]
    overlay[hit_spin & ~hit_dense] = legend_items[1][0]
    overlay[hit_dense & ~hit_spin] = legend_items[2][0]

    alpha = 0.5
    vis = (base * (1 - alpha) + overlay * alpha).astype(np.uint8)

    # Legend: white 70-px strip, one 40x30 swatch every 160 px with its label to
    # the right. PIL clips anything drawn past the image width.
    legend = Image.new("RGB", (W, 70), (255, 255, 255))
    draw = ImageDraw.Draw(legend)
    for k, (color, text) in enumerate(legend_items):
        x0 = 10 + 160 * k
        draw.rectangle([x0, 10, x0 + 39, 39], fill=color)
        draw.text((x0 + 46, 18), text, fill=(0, 0, 0))
    return Image.fromarray(np.vstack([np.array(legend, dtype=np.uint8), vis]))

from PIL import Image, ImageDraw, ImageFont

# load raw triplet (original resolution)
def load_raw_triplet(root: str, split: str, name: str):
    """Load ``name`` from ``root/split/{A,B,label}`` at native resolution.

    Returns:
        A, B: uint8 RGB arrays [H, W, 3].
        GT: uint8 array [H, W] in {0, 1} (label gray level > 127 -> 1).
    """
    def _read(sub: str, mode: str) -> np.ndarray:
        with Image.open(os.path.join(root, split, sub, name)) as im:
            return np.array(im.convert(mode), dtype=np.uint8)

    A = _read("A", "RGB")
    B = _read("B", "RGB")
    GT = (_read("label", "L") > 127).astype(np.uint8)
    return A, B, GT

# upsample 0/1 mask to target size
def upsample_mask_to(mask_small: np.ndarray, target_hw: tuple[int,int]) -> np.ndarray:
    """Nearest-neighbour resize of a {0,1} mask to ``target_hw`` = (H, W); returns uint8 {0,1}."""
    out_h, out_w = target_hw
    as_img = Image.fromarray(mask_small.astype(np.uint8) * 255)
    resized = np.asarray(as_img.resize((out_w, out_h), Image.NEAREST), dtype=np.uint8)
    return np.where(resized > 127, 1, 0).astype(np.uint8)

# use small prediction to build "original overlay", and get "colored GT" by unmixing, then use colored GT to overlay for consistency check
def make_overlay_and_colorized(B: np.ndarray, GT: np.ndarray,
                               pred_dense_small: np.ndarray,
                               pred_spin_small:  np.ndarray,
                               alpha: float = 0.5):
    """Colour GT pixels by which model hit them, blend onto B, then recover the
    colours from the blend as a consistency check.

    Colours (only where GT == 1): cyan = both hit, green = SP-in only,
    red = Dense only, black = neither.

    Args:
        B: t2 image, uint8 [H, W, 3].
        GT: {0,1} mask [H, W] at B's resolution.
        pred_dense_small, pred_spin_small: {0,1} masks at prediction
            resolution; upsampled (nearest) to (H, W).
        alpha: weight of the colour layer in the blend; must be in (0, 1].

    Returns:
        (PD, PS, overlay, color_quant, overlay_check):
        - PD, PS: the upsampled masks.
        - overlay: the blended image.
        - color_quant: the colour map recovered from the blend, snapped to the palette.
        - overlay_check: that recovered map re-blended onto B.

    Raises:
        ValueError: if alpha is not in (0, 1]; at 0 the colour cannot be unmixed.
    """
    if not (0.0 < alpha <= 1.0):
        raise ValueError(f"alpha must be in (0, 1], got {alpha!r}")

    H, W = B.shape[:2]
    PD = upsample_mask_to(pred_dense_small, (H, W))
    PS = upsample_mask_to(pred_spin_small,  (H, W))

    BOTH = (0,255,255); SPONLY = (0,255,0); DENSEONLY = (255,0,0); MISS = (0,0,0)

    in_gt = GT == 1
    both = (PD==1) & (PS==1) & in_gt
    sp   = (PS==1) & (PD==0) & in_gt
    dn   = (PD==1) & (PS==0) & in_gt

    color_gt = np.zeros_like(B, dtype=np.uint8)
    color_gt[both] = BOTH
    color_gt[sp]   = SPONLY
    color_gt[dn]   = DENSEONLY

    overlay = (B*(1-alpha) + color_gt*alpha).astype(np.uint8)

    # Unmix: overlay ≈ (1-alpha)*B + alpha*COLOR  =>  COLOR ≈ (overlay - (1-alpha)*B) / alpha
    # (for alpha = 0.5 this is 2*overlay - B). The uint8 truncation above makes it
    # approximate, so each recovered colour is snapped to the nearest palette entry.
    # Only GT pixels are unmixed; everything else stays black.
    # MISS is in the palette so "neither hit" pixels stay black instead of being
    # snapped to a hit colour.
    palette = np.array([BOTH, SPONLY, DENSEONLY, MISS], dtype=np.int32)
    est = (overlay[in_gt].astype(np.float64)
           - (1 - alpha) * B[in_gt].astype(np.float64)) / alpha
    est = np.clip(np.rint(est), 0, 255).astype(np.int32)  # (N, 3)
    # int32 on purpose: squared channel distances reach 3*255**2, which overflows int16.
    dist = ((est[:, None, :] - palette[None, :, :]) ** 2).sum(axis=-1)  # (N, 4)
    color_quant = np.zeros_like(B, dtype=np.uint8)
    color_quant[in_gt] = palette[dist.argmin(axis=1)].astype(np.uint8)

    overlay_check = (B*(1-alpha) + color_quant*alpha).astype(np.uint8)

    return PD, PS, overlay, color_quant, overlay_check

# 2×3 check panel
def make_check_panel(root: str, name: str,
                     pred_dense_small: np.ndarray,
                     pred_spin_small:  np.ndarray,
                     alpha: float = 0.5,
                     split: str = "test") -> Image.Image:
    """Build a titled 2x3 consistency panel for one image of ``split``.

    Row 1: A (t1), B (t2), GT mask.
    Row 2: B overlaid with the hit colours, the colour map recovered from that
    overlay, and the recovered map re-overlaid on B (should match the first
    overlay).

    Args:
        root: dataset root containing ``split/{A,B,label}``.
        name: image file name.
        pred_dense_small, pred_spin_small: {0,1} prediction masks (any size).
        alpha: colour-layer blend weight.
        split: dataset split to read the raw images from (default "test").

    Returns:
        RGB panel image of size (3*W, 2*H).
    """
    A, B, GT = load_raw_triplet(root, split, name)
    _, _, overlay, color_gt, overlay_check = make_overlay_and_colorized(
        B, GT, pred_dense_small, pred_spin_small, alpha=alpha
    )

    gt_rgb = np.repeat((GT.astype(np.uint8) * 255)[..., None], 3, axis=-1)
    panel = np.concatenate([
        np.concatenate([A, B, gt_rgb], axis=1),
        np.concatenate([overlay, color_gt, overlay_check], axis=1),
    ], axis=0)

    img = Image.fromarray(panel)
    draw = ImageDraw.Draw(img)
    try:
        font = ImageFont.load_default()
    except Exception:
        font = None  # draw.text/textbbox then use their built-in default

    H, W = B.shape[0], B.shape[1]
    titles = [
        ("A (t1)", 0,0), ("B (t2)", 1,0), ("Ground Truth", 2,0),
        ("Overlay (B + color)", 0,1),
        ("Colorized GT (recovered)", 1,1),
        ("Re-overlay check", 2,1),
    ]
    for text, cx, cy in titles:
        x = cx*W + 10; y = cy*H + 10
        # Size the white backing box from the real text extent, not a ~7 px/char guess.
        left, top, right, bottom = draw.textbbox((x, y), text, font=font)
        draw.rectangle([left-6, top-6, right+6, bottom+6], fill=(255,255,255))
        draw.text((x, y), text, fill=(0,0,0), font=font)

    return img




# =========================
# Main training script (two methods)
# =========================
def _train_select_test(tag, model, loaders, device, run_dir):
    """Train ``model``, keep the epoch with the best validation mean F1, and test it.

    Args:
        tag: label used in log lines, e.g. "Dense" or "SP-in".
        model: freshly built model, already on ``device``.
        loaders: (train_loader, val_loader, test_loader).
        device: torch device.
        run_dir: directory that receives ``best.pth`` (created if missing).

    Returns:
        (test_metrics, ckpt_path): ``evaluate`` output on the test split using
        the best checkpoint, and that checkpoint's path.
    """
    train_loader, val_loader, test_loader = loaders
    os.makedirs(run_dir, exist_ok=True)
    optim = torch.optim.Adam(model.parameters(), lr=CFG["lr"])
    n_epochs = CFG["epochs_per_method"]
    best_f1, best_path = -1.0, os.path.join(run_dir, "best.pth")
    for ep in range(1, n_epochs + 1):
        tr = train_one_epoch(model, train_loader, optim, device)
        val = evaluate(model, val_loader, device)
        print(f"[{tag}] Epoch {ep:02d}/{n_epochs} | "
              f"TrainLoss {tr:.4f} | Val F1 {val['mean_F1']:.4f} | Val IoU {val['mean_IoU']:.4f}")
        if val["mean_F1"] > best_f1:
            best_f1 = val["mean_F1"]
            torch.save(model.state_dict(), best_path)
    model.load_state_dict(torch.load(best_path, map_location=device))
    test = evaluate(model, test_loader, device)
    print(f"==> {tag} Test:", {k: round(v, 4) for k, v in test.items() if isinstance(v, float)})
    return test, best_path


def _rank_by_delta_tp(dense_preds, spin_preds, gts):
    """Rank images by ΔTP = TP(SP-in) − TP(Dense), largest first.

    Args:
        dense_preds, spin_preds, gts: {0,1} arrays of identical shape [N, H, W].

    Returns:
        (order, deltas): indices sorted by descending ΔTP, and the per-image
        ΔTP values as a list of Python ints.
    """
    gt_pos = gts == 1
    tp_dense = ((dense_preds == 1) & gt_pos).sum(axis=(1, 2)).astype(np.int64)
    tp_spin = ((spin_preds == 1) & gt_pos).sum(axis=(1, 2)).astype(np.int64)
    deltas = tp_spin - tp_dense
    return np.argsort(deltas)[::-1], deltas.tolist()


def main():
    """Train Dense and SP-in FC-Siam-conc models, test both, and save comparison panels.

    Reads the module-level CFG dict: seed, root, out_dir, resize, batch_size,
    lr, epochs_per_method, sp. Writes checkpoints to out_dir/{dense,sp_in}/best.pth,
    top-10 ΔTP panels to out_dir/panels_top10/, and out_dir/summary.json.

    Raises:
        FileNotFoundError: if any of {train,val,test}/{A,B,label} is missing.
        RuntimeError: if the test split yields no predictions or the
            Dense/SP-in/GT arrays disagree in shape.
    """
    mount_drive_if_colab()
    set_seed(CFG["seed"])
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    root = CFG["root"]
    out_dir = CFG["out_dir"]
    os.makedirs(out_dir, exist_ok=True)

    # Sanity check folders
    for sp in ("train","val","test"):
        for sub in ("A","B","label"):
            p = os.path.join(root, sp, sub)
            if not os.path.isdir(p):
                raise FileNotFoundError(f"Missing folder: {p}")
    print(f"[OK] Found LEVIR-CD at: {root}")

    # Datasets & loaders
    train_set = LevirCD(root, split="train", resize=CFG["resize"], aug=True)
    val_set   = LevirCD(root, split="val",   resize=CFG["resize"], aug=False)
    test_set  = LevirCD(root, split="test",  resize=CFG["resize"], aug=False)

    workers = 2 if torch.cuda.is_available() else 0
    pin_mem = torch.cuda.is_available()
    train_loader = DataLoader(train_set, batch_size=CFG["batch_size"], shuffle=True,
                              num_workers=workers, pin_memory=pin_mem)
    val_loader   = DataLoader(val_set,   batch_size=CFG["batch_size"], shuffle=False,
                              num_workers=workers, pin_memory=pin_mem)
    test_loader  = DataLoader(test_set,  batch_size=CFG["batch_size"], shuffle=False,
                              num_workers=workers, pin_memory=pin_mem)
    loaders = (train_loader, val_loader, test_loader)

    results = {}

    # ---------- Method A: Dense ----------
    # Re-seed before each method so both start from the same RNG state (same
    # shuffling order and augmentation draws), which makes the comparison fairer.
    set_seed(CFG["seed"])
    model = FCSiamConc(in_ch=3, feats=(32,64,128,256), use_sp=False).to(device)
    print("\n==== Train: DENSE ====")
    test_dense, ckpt = _train_select_test("Dense", model, loaders, device,
                                          os.path.join(out_dir, "dense"))
    results["dense"] = {"test": test_dense, "ckpt": ckpt}

    # ---------- Method B: SP-in ----------
    set_seed(CFG["seed"])
    model = FCSiamConc(in_ch=3, feats=(32,64,128,256), use_sp=True, sp_kwargs=CFG["sp"]).to(device)
    print("\n==== Train: SP-in ====")
    test_spin, ckpt = _train_select_test("SP-in", model, loaders, device,
                                         os.path.join(out_dir, "sp_in"))
    results["sp_in"] = {"test": test_spin, "ckpt": ckpt}

    # ---------- Top-10 panels by ΔTP(SP-in − Dense) ----------
    dense_preds, spin_preds = test_dense["preds"], test_spin["preds"]   # [N,Hs,Ws]
    gts_small, names = test_spin["gts"], test_spin["names"]
    # Explicit checks rather than `assert`, which is stripped under `python -O`.
    if dense_preds is None or spin_preds is None or gts_small is None:
        raise RuntimeError("Test split produced no predictions; cannot build comparison panels.")
    if not (dense_preds.shape == spin_preds.shape == gts_small.shape):
        raise RuntimeError(
            f"Prediction/GT shape mismatch: dense {dense_preds.shape}, "
            f"sp-in {spin_preds.shape}, gt {gts_small.shape}")

    # ΔTP counts GT pixels hit, at prediction resolution
    order, deltas = _rank_by_delta_tp(dense_preds, spin_preds, gts_small)
    topk = int(min(10, len(order)))
    pick = order[:topk].tolist()
    print(f"[Viz] Selected top-{topk} indices by ΔTP:", [int(i) for i in pick])

    panel_dir = os.path.join(out_dir, "panels_top10")
    os.makedirs(panel_dir, exist_ok=True)

    saved = []
    for rank, i in enumerate(pick, 1):
        name = names[i]
        panel = make_check_panel(
            root=CFG["root"],
            name=name,
            pred_dense_small=dense_preds[i],
            pred_spin_small=spin_preds[i],
            alpha=0.5,  # colour-layer blend weight for the overlay
        )
        fname = f"panel_rank{rank:02d}_dTP{int(deltas[i])}_{name if name.lower().endswith('.png') else name+'.png'}"
        fpath = os.path.join(panel_dir, fname)
        panel.save(fpath)
        saved.append(fpath)
        print(f"[Saved] {fpath}")

    # —— save concise summary JSON (only plain Python types, so json can serialise it) ——
    summary = {
        "cfg": CFG,
        "dense": {"test_mean_F1": float(results["dense"]["test"]["mean_F1"]),
                  "test_mean_IoU": float(results["dense"]["test"]["mean_IoU"]),
                  "ckpt": results["dense"]["ckpt"]},
        "sp_in": {"test_mean_F1": float(results["sp_in"]["test"]["mean_F1"]),
                  "test_mean_IoU": float(results["sp_in"]["test"]["mean_IoU"]),
                  "ckpt": results["sp_in"]["ckpt"]},
        "top10_panels": saved,
        "top10_indices": [int(i) for i in pick],
        "top10_names": [str(names[i]) for i in pick],
        "top10_delta_tp": [int(deltas[i]) for i in pick],
    }
    summary_path = os.path.join(out_dir, "summary.json")
    with open(summary_path, "w") as f:
        json.dump(summary, f, indent=2)
    print(f"[Saved] summary -> {summary_path}")

# =========================
# Run
# =========================
# Run the full pipeline when executed as a script. In a notebook cell __name__
# is also "__main__", so running the cell starts training.
if __name__ == "__main__":
    main()


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
[OK] Found LEVIR-CD at: /content/drive/MyDrive/LEVIR-CD

==== Train: DENSE ====




[Dense] Epoch 01/30 | TrainLoss 0.5161 | Val F1 0.0710 | Val IoU 0.0530




[Dense] Epoch 02/30 | TrainLoss 0.3102 | Val F1 0.0391 | Val IoU 0.0247




[Dense] Epoch 03/30 | TrainLoss 0.2180 | Val F1 0.1702 | Val IoU 0.1009




[Dense] Epoch 04/30 | TrainLoss 0.1644 | Val F1 0.3599 | Val IoU 0.2609




[Dense] Epoch 05/30 | TrainLoss 0.1324 | Val F1 0.2828 | Val IoU 0.1929




[Dense] Epoch 06/30 | TrainLoss 0.1091 | Val F1 0.3303 | Val IoU 0.2262




[Dense] Epoch 07/30 | TrainLoss 0.0972 | Val F1 0.3742 | Val IoU 0.2666




[Dense] Epoch 08/30 | TrainLoss 0.0872 | Val F1 0.4008 | Val IoU 0.2913




[Dense] Epoch 09/30 | TrainLoss 0.0803 | Val F1 0.4443 | Val IoU 0.3300




[Dense] Epoch 10/30 | TrainLoss 0.0759 | Val F1 0.4863 | Val IoU 0.3774




[Dense] Epoch 11/30 | TrainLoss 0.0678 | Val F1 0.5090 | Val IoU 0.3953




[Dense] Epoch 12/30 | TrainLoss 0.0662 | Val F1 0.3039 | Val IoU 0.2141




[Dense] Epoch 13/30 | TrainLoss 0.0681 | Val F1 0.4398 | Val IoU 0.3249




[Dense] Epoch 14/30 | TrainLoss 0.0614 | Val F1 0.5518 | Val IoU 0.4409




[Dense] Epoch 15/30 | TrainLoss 0.0567 | Val F1 0.4938 | Val IoU 0.3836




[Dense] Epoch 16/30 | TrainLoss 0.0536 | Val F1 0.4904 | Val IoU 0.3831




[Dense] Epoch 17/30 | TrainLoss 0.0521 | Val F1 0.5310 | Val IoU 0.4245




[Dense] Epoch 18/30 | TrainLoss 0.0513 | Val F1 0.5716 | Val IoU 0.4682




[Dense] Epoch 19/30 | TrainLoss 0.0492 | Val F1 0.5723 | Val IoU 0.4703




[Dense] Epoch 20/30 | TrainLoss 0.0479 | Val F1 0.5239 | Val IoU 0.4302




[Dense] Epoch 21/30 | TrainLoss 0.0485 | Val F1 0.5849 | Val IoU 0.4820




[Dense] Epoch 22/30 | TrainLoss 0.0463 | Val F1 0.5734 | Val IoU 0.4720




[Dense] Epoch 23/30 | TrainLoss 0.0455 | Val F1 0.5711 | Val IoU 0.4674




[Dense] Epoch 24/30 | TrainLoss 0.0424 | Val F1 0.5574 | Val IoU 0.4515




[Dense] Epoch 25/30 | TrainLoss 0.0438 | Val F1 0.5777 | Val IoU 0.4735




[Dense] Epoch 26/30 | TrainLoss 0.0432 | Val F1 0.5681 | Val IoU 0.4643




[Dense] Epoch 27/30 | TrainLoss 0.0408 | Val F1 0.5952 | Val IoU 0.4977




[Dense] Epoch 28/30 | TrainLoss 0.0411 | Val F1 0.5777 | Val IoU 0.4758




[Dense] Epoch 29/30 | TrainLoss 0.0399 | Val F1 0.5969 | Val IoU 0.5001




[Dense] Epoch 30/30 | TrainLoss 0.0387 | Val F1 0.5617 | Val IoU 0.4559
==> Dense Test: {'mean_IoU': 0.5344, 'mean_F1': 0.6418}

==== Train: SP-in ====




[SP-in] Epoch 01/30 | TrainLoss 0.4848 | Val F1 0.0056 | Val IoU 0.0030




[SP-in] Epoch 02/30 | TrainLoss 0.3006 | Val F1 0.0070 | Val IoU 0.0038




[SP-in] Epoch 03/30 | TrainLoss 0.2097 | Val F1 0.0001 | Val IoU 0.0001




[SP-in] Epoch 04/30 | TrainLoss 0.1581 | Val F1 0.2738 | Val IoU 0.1826




[SP-in] Epoch 05/30 | TrainLoss 0.1277 | Val F1 0.3569 | Val IoU 0.2508




[SP-in] Epoch 06/30 | TrainLoss 0.1059 | Val F1 0.3568 | Val IoU 0.2525




[SP-in] Epoch 07/30 | TrainLoss 0.0948 | Val F1 0.4269 | Val IoU 0.3209




[SP-in] Epoch 08/30 | TrainLoss 0.0865 | Val F1 0.3628 | Val IoU 0.2612




[SP-in] Epoch 09/30 | TrainLoss 0.0794 | Val F1 0.4989 | Val IoU 0.3891




[SP-in] Epoch 10/30 | TrainLoss 0.0717 | Val F1 0.2863 | Val IoU 0.2033




[SP-in] Epoch 11/30 | TrainLoss 0.0683 | Val F1 0.4682 | Val IoU 0.3614




[SP-in] Epoch 12/30 | TrainLoss 0.0647 | Val F1 0.4760 | Val IoU 0.3666




[SP-in] Epoch 13/30 | TrainLoss 0.0623 | Val F1 0.5048 | Val IoU 0.3971




[SP-in] Epoch 14/30 | TrainLoss 0.0563 | Val F1 0.5351 | Val IoU 0.4287




[SP-in] Epoch 15/30 | TrainLoss 0.0555 | Val F1 0.5355 | Val IoU 0.4260




[SP-in] Epoch 16/30 | TrainLoss 0.0527 | Val F1 0.5709 | Val IoU 0.4676




[SP-in] Epoch 17/30 | TrainLoss 0.0500 | Val F1 0.5534 | Val IoU 0.4483




[SP-in] Epoch 18/30 | TrainLoss 0.0493 | Val F1 0.5648 | Val IoU 0.4594




[SP-in] Epoch 19/30 | TrainLoss 0.0485 | Val F1 0.5551 | Val IoU 0.4546




[SP-in] Epoch 20/30 | TrainLoss 0.0434 | Val F1 0.5779 | Val IoU 0.4817




[SP-in] Epoch 21/30 | TrainLoss 0.0444 | Val F1 0.5385 | Val IoU 0.4367




[SP-in] Epoch 22/30 | TrainLoss 0.0427 | Val F1 0.5904 | Val IoU 0.4909




[SP-in] Epoch 23/30 | TrainLoss 0.0452 | Val F1 0.5121 | Val IoU 0.4127




[SP-in] Epoch 24/30 | TrainLoss 0.0429 | Val F1 0.5959 | Val IoU 0.5004




[SP-in] Epoch 25/30 | TrainLoss 0.0404 | Val F1 0.6024 | Val IoU 0.5060




[SP-in] Epoch 26/30 | TrainLoss 0.0398 | Val F1 0.6048 | Val IoU 0.5079




[SP-in] Epoch 27/30 | TrainLoss 0.0385 | Val F1 0.6099 | Val IoU 0.5111




[SP-in] Epoch 28/30 | TrainLoss 0.0368 | Val F1 0.6029 | Val IoU 0.5031




[SP-in] Epoch 29/30 | TrainLoss 0.0370 | Val F1 0.6081 | Val IoU 0.5105




[SP-in] Epoch 30/30 | TrainLoss 0.0359 | Val F1 0.5753 | Val IoU 0.4790
==> SP-in Test: {'mean_IoU': 0.5373, 'mean_F1': 0.6456}
[Viz] Selected top-10 indices by ΔTP: [68, 69, 107, 106, 55, 33, 61, 66, 103, 18]
[Saved] /content/drive/MyDrive/LEVIR_CD_runs/panels_top10/panel_rank01_dTP2609_test_45.png
[Saved] /content/drive/MyDrive/LEVIR_CD_runs/panels_top10/panel_rank02_dTP1624_test_46.png
[Saved] /content/drive/MyDrive/LEVIR_CD_runs/panels_top10/panel_rank03_dTP1203_test_80.png
[Saved] /content/drive/MyDrive/LEVIR_CD_runs/panels_top10/panel_rank04_dTP1132_test_8.png
[Saved] /content/drive/MyDrive/LEVIR_CD_runs/panels_top10/panel_rank05_dTP790_test_33.png
[Saved] /content/drive/MyDrive/LEVIR_CD_runs/panels_top10/panel_rank06_dTP758_test_13.png
[Saved] /content/drive/MyDrive/LEVIR_CD_runs/panels_top10/panel_rank07_dTP693_test_39.png
[Saved] /content/drive/MyDrive/LEVIR_CD_runs/panels_top10/panel_rank08_dTP624_test_43.png
[Saved] /content/drive/MyDrive/LEVIR_CD_runs/panels_top10/panel_ran

In [None]:
# ===== Top-K visualization with crisp English labels (no retraining) =====
import os, json, numpy as np, torch
from PIL import Image, ImageDraw, ImageFont

TOPK = 10            # number of panels to render, highest ΔTP first
ALPHA = 0.5          # blend weight of the colour layer in the overlay tile
FONT_SCALE = 0.045   # label font size ≈ tile width * FONT_SCALE (at least 16 px, see draw_label)
PAD = 12             # padding (px) between label text and its background box

# ---------- helpers ----------
def load_raw_triplet(root: str, split: str, name: str):
    """Read ``name`` from ``root/split/{A,B,label}`` at native resolution.

    Returns:
        A, B: uint8 RGB arrays [H, W, 3].
        GT: uint8 {0,1} array [H, W] (label gray level > 127 -> 1).
    """
    arrays = []
    for sub, mode in (("A", "RGB"), ("B", "RGB"), ("label", "L")):
        with Image.open(os.path.join(root, split, sub, name)) as im:
            arrays.append(np.array(im.convert(mode), dtype=np.uint8))
    A, B, label = arrays
    return A, B, (label > 127).astype(np.uint8)

def upsample_mask_to(mask_small: np.ndarray, target_hw):
    """Resize a {0,1} mask to ``target_hw`` = (H, W) with nearest sampling; returns uint8 {0,1}."""
    H, W = target_hw
    scaled = mask_small.astype(np.uint8) * 255
    big = Image.fromarray(scaled).resize((W, H), resample=Image.NEAREST)
    return (np.asarray(big, dtype=np.uint8) > 127).astype(np.uint8)

def binarize(logits, thr=0.5):
    """Return a uint8 {0,1} mask that is 1 where sigmoid(logits) >= thr."""
    mask = torch.sigmoid(logits).ge(thr)
    return mask.to(torch.uint8)

@torch.no_grad()
def eval_preds(model, loader, device, thr=0.5):
    """Run ``model`` over ``loader`` and collect binary predictions, GT masks and names.

    Args:
        model: model taking (a, b) and returning logits [B,1,H,W].
        loader: yields (a, b, y, names) batches.
        device: device the model lives on.
        thr: sigmoid threshold for a positive prediction.

    Returns:
        (preds, gts, names): uint8 {0,1} arrays [N, H, W] (channel 0 of the
        model output / label tensor) and the sample names in loader order.
    """
    model.eval()
    all_preds, all_gts, all_names = [], [], []
    for a, b, y, names in loader:
        logits = model(a.to(device), b.to(device))
        all_preds.append(binarize(logits, thr=thr).cpu().numpy().astype(np.uint8)[:, 0])
        # Labels are only thresholded on the host, so they never need to visit `device`.
        all_gts.append((y.cpu().numpy() > 0.5).astype(np.uint8)[:, 0])
        all_names += list(names)
    return np.concatenate(all_preds, 0), np.concatenate(all_gts, 0), all_names

def _load_font(px):
    """Return a bold TrueType font of size ``px``, falling back to Pillow's default.

    The DejaVu/Liberation paths are checked first. If neither exists, the
    default font is requested at ``px``. That works on Pillow >= 10.1, whose
    default font is scalable. Older Pillow only has a fixed-size bitmap default,
    in which case ``px`` cannot be honoured.
    """
    for p in [
        "/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf",
        "/usr/share/fonts/truetype/liberation/LiberationSans-Bold.ttf",
    ]:
        if os.path.exists(p):
            return ImageFont.truetype(p, px)
    try:
        return ImageFont.load_default(size=px)
    except TypeError:
        # Pillow < 10.1: load_default() accepts no size argument
        return ImageFont.load_default()

def draw_label(img: Image.Image, xy, text, box_alpha=180):
    """Draw ``text`` at ``xy`` on ``img`` (in place) over a semi-transparent white box.

    The text is black with a white outline so it stays legible on any
    background. Font size scales with one panel tile, which is assumed to be a
    third of ``img``'s width: max(16, tile_w * FONT_SCALE).

    Args:
        img: RGBA panel image, modified in place.
        xy: top-left corner (x, y) of the label box.
        text: label text.
        box_alpha: opacity (0-255) of the white background box.
    """
    W, _ = img.size
    tile_w = W // 3
    fs = max(16, int(tile_w * FONT_SCALE))
    font = _load_font(fs)
    stroke = max(2, fs // 10)

    x, y = xy
    draw = ImageDraw.Draw(img, mode="RGBA")
    # Measure with the same stroke width used for drawing so the outline fits inside the box.
    tw, th = draw.textbbox((0, 0), text, font=font, stroke_width=stroke)[2:]
    box = Image.new("RGBA", (tw + PAD*2, th + PAD*2), (255, 255, 255, box_alpha))
    img.paste(box, (x, y), box)
    draw.text((x + PAD, y + PAD), text, font=font, fill=(0, 0, 0, 255),
              stroke_width=stroke, stroke_fill=(255, 255, 255, 255))

def make_panel_fixed_hd(root, name, pd_small, ps_small, alpha=0.5) -> Image.Image:
    """Build a labelled 2x3 Dense-vs-SP-in panel for one test image.

    Row 1: A (t1), B (t2), GT mask.
    Row 2: Dense-only GT hits (red on black), SP-in-only GT hits (green on
    black), and B blended with the combined map (cyan = both hit).
    Predictions are upsampled (nearest) to the native image size, and all
    colouring is restricted to GT == 1.
    """
    A, B, GT = load_raw_triplet(root, "test", name)
    H, W = B.shape[:2]
    dense_hit = upsample_mask_to(pd_small, (H, W)) == 1
    spin_hit = upsample_mask_to(ps_small, (H, W)) == 1
    in_gt = GT == 1

    both_m = dense_hit & spin_hit & in_gt
    sp_m = spin_hit & ~dense_hit & in_gt
    dn_m = dense_hit & ~spin_hit & in_gt

    red_only = np.zeros_like(B)
    red_only[dn_m] = (255, 0, 0)
    green_only = np.zeros_like(B)
    green_only[sp_m] = (0, 255, 0)
    color = np.zeros_like(B)
    for m, c in ((both_m, (0, 255, 255)), (sp_m, (0, 255, 0)), (dn_m, (255, 0, 0))):
        color[m] = c
    overlay = (B * (1 - alpha) + color * alpha).astype(np.uint8)

    gt_rgb = np.repeat((GT.astype(np.uint8) * 255)[..., None], 3, axis=-1)
    top = np.hstack([A, B, gt_rgb])
    bottom = np.hstack([red_only, green_only, overlay])
    img = Image.fromarray(np.vstack([top, bottom])).convert("RGBA")

    titles = (
        ("A (t1)", 0, 0), ("B (t2)", 1, 0), ("Ground Truth", 2, 0),
        ("Dense-only (red, within GT)", 0, 1),
        ("Adding SP-in (green, within GT)", 1, 1),
        ("Overlay (cyan=both, green=SP-in, red=Dense)", 2, 1),
    )
    for text, col, row in titles:
        draw_label(img, (col * W + 10, row * H + 10), text)

    return img.convert("RGB")

# ---------- build test loader & reload ckpts ----------
# Only the test split is rebuilt; the best checkpoints written by main() are
# reloaded, so panels can be regenerated without retraining.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
test_set  = LevirCD(CFG["root"], split="test", resize=CFG["resize"], aug=False)
test_loader = torch.utils.data.DataLoader(
    test_set, batch_size=CFG["batch_size"], shuffle=False,
    num_workers=2 if torch.cuda.is_available() else 0, pin_memory=torch.cuda.is_available()
)

# Dense preds
model_dense = FCSiamConc(in_ch=3, feats=(32,64,128,256), use_sp=False).to(device)
dense_ckpt = os.path.join(CFG["out_dir"], "dense", "best.pth")
model_dense.load_state_dict(torch.load(dense_ckpt, map_location=device))
dense_preds, gts_small, names = eval_preds(model_dense, test_loader, device)

# SP-in preds (same unshuffled loader, so index i is the same image for both models)
model_spin = FCSiamConc(in_ch=3, feats=(32,64,128,256), use_sp=True, sp_kwargs=CFG["sp"]).to(device)
spin_ckpt = os.path.join(CFG["out_dir"], "sp_in", "best.pth")
model_spin.load_state_dict(torch.load(spin_ckpt, map_location=device))
spin_preds, _, _ = eval_preds(model_spin, test_loader, device)

# ---------- select Top-K by ΔTP ----------
# ΔTP per image = GT pixels hit by SP-in minus GT pixels hit by Dense, at
# prediction resolution (vectorised over all images).
gt_pos = gts_small == 1
tp_dense = ((dense_preds == 1) & gt_pos).sum(axis=(1, 2)).astype(np.int64)
tp_spin  = ((spin_preds == 1) & gt_pos).sum(axis=(1, 2)).astype(np.int64)
deltas = (tp_spin - tp_dense).tolist()
order = np.argsort(np.asarray(deltas))[::-1]
pick = order[:min(TOPK, len(order))]

# ---------- save ----------
panel_dir = os.path.join(CFG["out_dir"], "panels_topk_hdlabels")
os.makedirs(panel_dir, exist_ok=True)
saved = []
for rank, idx in enumerate(pick, 1):
    i = int(idx)
    name = names[i]
    panel = make_panel_fixed_hd(CFG["root"], name, dense_preds[i], spin_preds[i], alpha=ALPHA)
    fname = f"panel_rank{rank:02d}_dTP{int(deltas[i])}_{name if name.lower().endswith('.png') else name+'.png'}"
    fpath = os.path.join(panel_dir, fname)
    # fname always ends in .png; PNG is lossless and ignores a `quality` option, so none is passed.
    panel.save(fpath)
    saved.append(fpath)
    print(f"[Saved] {fpath}")

with open(os.path.join(panel_dir, "summary_hdlabels.json"), "w") as f:
    json.dump({
        "topk": int(len(pick)),
        "indices": [int(i) for i in pick],
        "names": [names[int(i)] for i in pick],
        "delta_tp": [int(deltas[int(i)]) for i in pick],
        "panels": saved,
        "dense_ckpt": dense_ckpt,
        "sp_in_ckpt": spin_ckpt,
        "alpha": ALPHA, "font_scale": FONT_SCALE
    }, f, indent=2)
print(f"[Saved] summary -> {os.path.join(panel_dir, 'summary_hdlabels.json')}")


[Saved] /content/drive/MyDrive/LEVIR_CD_runs/panels_topk_hdlabels/panel_rank01_dTP2609_test_45.png
[Saved] /content/drive/MyDrive/LEVIR_CD_runs/panels_topk_hdlabels/panel_rank02_dTP1624_test_46.png
[Saved] /content/drive/MyDrive/LEVIR_CD_runs/panels_topk_hdlabels/panel_rank03_dTP1203_test_80.png
[Saved] /content/drive/MyDrive/LEVIR_CD_runs/panels_topk_hdlabels/panel_rank04_dTP1132_test_8.png
[Saved] /content/drive/MyDrive/LEVIR_CD_runs/panels_topk_hdlabels/panel_rank05_dTP790_test_33.png
[Saved] /content/drive/MyDrive/LEVIR_CD_runs/panels_topk_hdlabels/panel_rank06_dTP758_test_13.png
[Saved] /content/drive/MyDrive/LEVIR_CD_runs/panels_topk_hdlabels/panel_rank07_dTP693_test_39.png
[Saved] /content/drive/MyDrive/LEVIR_CD_runs/panels_topk_hdlabels/panel_rank08_dTP624_test_43.png
[Saved] /content/drive/MyDrive/LEVIR_CD_runs/panels_topk_hdlabels/panel_rank09_dTP613_test_77.png
[Saved] /content/drive/MyDrive/LEVIR_CD_runs/panels_topk_hdlabels/panel_rank10_dTP541_test_115.png
[Saved] summary 