In [None]:
import warnings
# Suppress every warning issued through the `warnings` module for the rest of the session.
# NOTE: this also hides deprecation/user warnings from torch/timm; re-enable while debugging.
warnings.filterwarnings("ignore")

In [None]:
!pip install -q timm==1.0.22
# pinned because the preinstalled timm version does not provide the DINOv3 models

In [None]:
import timm

# Show which DINOv3 architectures the installed timm build knows about.
dinov3_models = [name for name in timm.list_models() if "dinov3" in name]
print(dinov3_models)

In [None]:
import os, random
from pathlib import Path

import numpy as np
from PIL import Image
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

In [None]:
class CFG:
    """Experiment configuration shared by every later cell.

    Defining the class has a side effect: ``out_dir`` and ``pred_dir`` are created on disk.
    """

    # ---- paths ----
    root = Path("/kaggle/input/oscd-for-change-detection/OSCD/")
    out_dir = Path("/kaggle/working")
    out_dir.mkdir(parents=True, exist_ok=True)
    pred_dir = out_dir / "predictions"  # i.e. /kaggle/working/predictions
    pred_dir.mkdir(parents=True, exist_ok=True)

    # ---- OSCD city split: 12 train / 2 validation / 10 test ----
    train_cities = (
        "aguasclaras bercy bordeaux nantes paris rennes saclay_e "
        "abudhabi cupertino mumbai hongkong pisa"
    ).split()
    validation_cities = ["beihai", "beirut"]
    test_cities = (
        "brasilia montpellier norcia rio saclay_w valencia dubai "
        "lasvegas milano chongqing"
    ).split()

    # ---- sliding-window patching ----
    patch_size = 256
    stride = 128

    # ---- optimisation ----
    epochs = 100
    batch_size = 4
    freeze_backbone_epochs = 20   # epochs during which the whole backbone stays frozen
    unfreeze_blocks = 2           # afterwards, only the last 2 ViT blocks become trainable
    lr_backbone = 5e-7
    lr_decoder = 3e-5

    # ---- runtime ----
    num_workers = 4
    seed = 42
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # ---- model ----
    vit_name = "vit_large_patch16_dinov3.sat493m"
    dinov3_ckpt = ""  # not read by the cells shown; weights come from timm (pretrained=True)


In [None]:
def set_seed(seed=42, deterministic=False):
    """Seed Python's ``random``, NumPy's global RNG and PyTorch (CPU and all CUDA devices).

    Args:
        seed: integer seed applied to every RNG.
        deterministic: if True, additionally ask cuDNN for deterministic kernels and
            disable its autotuner (slower, but makes GPU runs repeatable). The default
            (False) leaves the cuDNN flags untouched.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # no-op on CPU-only machines
    if deterministic:
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

# Seed all RNGs once, before datasets and the model are created, so weight init and
# shuffling start from the same state on every run (GPU kernels may still be nondeterministic).
set_seed(CFG.seed)

In [None]:
import numpy as np
import random
import cv2

def random_block_mask(img, num_blocks=2, max_frac=0.15):
    """Return a copy of ``img`` with ``num_blocks`` random rectangles set to 0.

    Works on ``(H, W)`` and ``(H, W, C)`` arrays; the input array is not modified.
    Each rectangle's height and width are uniform fractions in ``[0.05, max_frac]`` of the
    image height/width, clamped to the image size so ``max_frac > 1`` cannot crash.

    NOTE: not used by the pipeline below; ``paired_augment`` applies its own occlusion so
    that pre and post share the same rectangle.
    """
    H, W = img.shape[:2]
    out = img.copy()

    for _ in range(num_blocks):
        bh = min(int(random.uniform(0.05, max_frac) * H), H)
        bw = min(int(random.uniform(0.05, max_frac) * W), W)

        y = random.randint(0, H - bh)
        x = random.randint(0, W - bw)

        out[y:y + bh, x:x + bw] = 0

    return out

def paired_augment(pre, post, mask, scale_range=(0.8, 1.2), crop_size=256, blur_prob=0.3, jitter_prob=0.3,
                   occlusion_prob=0.4):
    """Apply one random augmentation identically to a pre/post image pair and its change mask.

    Geometric steps (flips, 90-degree rotation, rescale + crop) are applied to all three
    arrays; photometric steps (blur, colour jitter, occlusion) only to the two images, with
    the same random parameters for pre and post.

    Args:
        pre, post: (H, W, 3) images, expected uint8; colour jitter treats them as RGB.
        mask: (H, W) label map; resized with nearest-neighbour so labels stay discrete.
        scale_range: (min, max) of the uniform rescale factor.
        crop_size: side length of the square output.
        blur_prob: probability of a shared 3x3 or 5x5 Gaussian blur.
        jitter_prob: probability of shared brightness/contrast/saturation jitter.
        occlusion_prob: probability of zeroing one shared rectangle (5-15% of each side)
            in both images.

    Returns:
        (pre, post, mask) as new contiguous arrays of spatial size crop_size x crop_size.
    """
    # 1. Horizontal flip
    if random.random() < 0.5:
        pre = np.flip(pre, axis=1)
        post = np.flip(post, axis=1)
        mask = np.flip(mask, axis=1)

    # 2. Vertical flip
    if random.random() < 0.5:
        pre = np.flip(pre, axis=0)
        post = np.flip(post, axis=0)
        mask = np.flip(mask, axis=0)

    # 3. Rotation (0°, 90°, 180°, 270°)
    k = random.randint(0, 3)
    if k > 0:
        pre = np.rot90(pre, k)
        post = np.rot90(post, k)
        mask = np.rot90(mask, k)

    # Read the size only after rotating: a 90/270-degree turn swaps H and W for
    # non-square inputs, and the resize below must use the rotated shape.
    H, W = pre.shape[:2]

    # 4. Random RESCALE (scale ∈ scale_range)
    scale = random.uniform(scale_range[0], scale_range[1])
    new_h = int(H * scale)
    new_w = int(W * scale)

    pre_r = cv2.resize(pre, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
    post_r = cv2.resize(post, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
    mask_r = cv2.resize(mask, (new_w, new_h), interpolation=cv2.INTER_NEAREST)

    # 5. Random CROP back to crop_size
    if new_h > crop_size and new_w > crop_size:
        top = random.randint(0, new_h - crop_size)
        left = random.randint(0, new_w - crop_size)

        pre_r = pre_r[top:top+crop_size, left:left+crop_size]
        post_r = post_r[top:top+crop_size, left:left+crop_size]
        mask_r = mask_r[top:top+crop_size, left:left+crop_size]
    else:
        # Not rare: for inputs of size crop_size this branch runs whenever scale <= 1
        # (about half the draws with the default range). The result is stretched back to
        # crop_size, so it acts as a resolution loss rather than a zoom-out.
        pre_r  = cv2.resize(pre_r,  (crop_size, crop_size))
        post_r = cv2.resize(post_r, (crop_size, crop_size))
        mask_r = cv2.resize(mask_r, (crop_size, crop_size), interpolation=cv2.INTER_NEAREST)

    # 6. Gaussian BLUR (mask is never blurred)
    if random.random() < blur_prob:
        ksize = random.choice([3, 5])
        pre_r = cv2.GaussianBlur(pre_r,  (ksize, ksize), 0)
        post_r = cv2.GaussianBlur(post_r, (ksize, ksize), 0)

    # 7. COLOR JITTER (brightness, contrast, saturation), same factors for pre and post
    if random.random() < jitter_prob:
        # brightness
        b = random.uniform(0.8, 1.2)
        pre_r  = np.clip(pre_r  * b, 0, 255)
        post_r = np.clip(post_r * b, 0, 255)

        # contrast
        c = random.uniform(0.8, 1.2)
        pre_r  = np.clip((pre_r  - 128) * c + 128, 0, 255)
        post_r = np.clip((post_r - 128) * c + 128, 0, 255)

        # saturation (via HSV); converting back also restores uint8
        sat = random.uniform(0.8, 1.2)
        pre_hsv  = cv2.cvtColor(pre_r.astype(np.uint8),  cv2.COLOR_RGB2HSV)
        post_hsv = cv2.cvtColor(post_r.astype(np.uint8), cv2.COLOR_RGB2HSV)

        pre_hsv[:,:,1]  = np.clip(pre_hsv[:,:,1]  * sat, 0, 255)
        post_hsv[:,:,1] = np.clip(post_hsv[:,:,1] * sat, 0, 255)

        pre_r  = cv2.cvtColor(pre_hsv,  cv2.COLOR_HSV2RGB)
        post_r = cv2.cvtColor(post_hsv, cv2.COLOR_HSV2RGB)

    # 8. RANDOM OCCLUSION MASKING (same rectangle for pre & post)
    # NOTE(review): the change mask is left untouched, so "change" pixels inside the
    # occluded rectangle become visually identical in pre and post; confirm this label
    # noise is intended (alternatives: clear or ignore those mask pixels).
    if random.random() < occlusion_prob:
        Hc, Wc = pre_r.shape[:2]
        bh = int(random.uniform(0.05, 0.15) * Hc)
        bw = int(random.uniform(0.05, 0.15) * Wc)

        y = random.randint(0, Hc - bh)
        x = random.randint(0, Wc - bw)

        pre_r[y:y+bh, x:x+bw] = 0
        post_r[y:y+bh, x:x+bw] = 0

    return pre_r.copy(), post_r.copy(), mask_r.copy()


In [None]:
# Normalization: ImageNet per-channel statistics, shaped (3, 1, 1) to broadcast over CHW tensors
IMAGENET_MEAN = torch.as_tensor((0.485, 0.456, 0.406), dtype=torch.float32).reshape(3, 1, 1)
IMAGENET_STD = torch.as_tensor((0.229, 0.224, 0.225), dtype=torch.float32).reshape(3, 1, 1)

class OSCDDataset(Dataset):
    """Sliding-window patch dataset over OSCD city image pairs.

    Each city directory must contain ``img1.png`` (pre), ``img2.png`` (post) and
    ``cm.png`` (change map); cities missing any of them are skipped with a warning.
    All images are decoded once in ``__init__`` and kept in memory, so ``__getitem__``
    only slices arrays instead of re-decoding three full-size PNGs per patch.

    Items are ``(pre, post, mask)``:
        pre, post: float32 tensors (3, p, p), scaled to [0, 1] and ImageNet-normalised.
        mask: int64 tensor (p, p), 1 where channel 0 of ``cm.png`` is > 127, else 0.
    ``p`` is ``patch_size``, or ``paired_augment``'s crop size (256 by default) when
    ``augment`` is True.

    Patches start every ``stride`` pixels; border strips not reachable by a full patch
    are not covered.
    """

    def __init__(self, root: Path, cities, patch_size=256, stride=128, augment=False):
        self.root = Path(root)
        self.cities = cities
        self.patch_size = patch_size
        self.stride = stride
        self.augment = augment

        self.samples = []
        for city in cities:
            cdir = self.root / city
            pre = cdir / "img1.png"
            post = cdir / "img2.png"
            mask = cdir / "cm.png"
            if pre.exists() and post.exists() and mask.exists():
                self.samples.append((pre, post, mask))
            else:
                print(f"[WARN] Missing image in {city}")

        assert len(self.samples) > 0, f"No complete OSCD city found under {self.root}"

        # (pre RGB uint8, post RGB uint8, binary uint8 mask) per sample, decoded once.
        self._images = [self._load_sample(*paths) for paths in self.samples]

        self.index = []
        for i, (pre_img, _, _) in enumerate(self._images):
            H, W = pre_img.shape[:2]
            for y in range(0, H - patch_size + 1, stride):
                for x in range(0, W - patch_size + 1, stride):
                    self.index.append((i, x, y))

    @staticmethod
    def _load_rgb(path):
        """Decode an image file as an (H, W, 3) uint8 RGB array, closing the file."""
        with Image.open(path) as img:
            return np.array(img.convert("RGB"))

    @classmethod
    def _load_sample(cls, pre_p, post_p, mask_p):
        """Load one city's pre/post images and its binarised change mask."""
        pre = cls._load_rgb(pre_p)
        post = cls._load_rgb(post_p)
        mask = (cls._load_rgb(mask_p)[..., 0] > 127).astype(np.uint8)
        return pre, post, mask

    def __len__(self):
        return len(self.index)

    def __getitem__(self, idx):
        i, x, y = self.index[idx]
        pre_full, post_full, mask_full = self._images[i]
        ps = self.patch_size

        # Copy the windows so nothing downstream can write into the cached arrays.
        pre = pre_full[y:y+ps, x:x+ps].copy()
        post = post_full[y:y+ps, x:x+ps].copy()
        mask = mask_full[y:y+ps, x:x+ps].copy()

        if self.augment:
            pre, post, mask = paired_augment(pre, post, mask)

        # tensor
        pre = torch.from_numpy(pre).permute(2, 0, 1).float() / 255.
        post = torch.from_numpy(post).permute(2, 0, 1).float() / 255.

        # normalize
        pre = (pre - IMAGENET_MEAN) / IMAGENET_STD
        post = (post - IMAGENET_MEAN) / IMAGENET_STD

        mask = torch.from_numpy(mask).long()

        return pre, post, mask


In [None]:
# Alternative backbone (currently commented out): extracts only the last-layer features
# class ViTBackboneFeatures(nn.Module):
#     def __init__(self, vit_name="vit_large_patch16_dinov3.sat493m"):
#         super().__init__()
#         self.backbone = timm.create_model(
#             vit_name,
#             pretrained=True,
#             features_only=True,
#         )
#         self.out_channels = self.backbone.feature_info.channels()[-1]

#     def forward(self, x):
#         feats = self.backbone(x)
#         f = feats[-1]
#         return f

# Multi-scale backbone: returns one feature map per ViT block listed in out_indices
class ViTBackboneMultiScale(nn.Module):
    """Wrap a timm ViT in ``features_only`` mode and return one spatial map per selected block.

    Args:
        vit_name: timm model name; pretrained weights are requested.
        out_indices: indices of the blocks whose outputs are returned.
            NOTE(review): for a 24-block ViT-L the default (4, 8, 12) stops around the
            middle of the network; confirm this is intended.

    Attributes:
        channels: channel count of each returned feature map, in ``out_indices`` order.
    """

    def __init__(self, vit_name, out_indices=(4, 8, 12)):
        super().__init__()
        timm_kwargs = dict(pretrained=True, features_only=True, out_indices=out_indices)
        self.backbone = timm.create_model(vit_name, **timm_kwargs)
        self.channels = self.backbone.feature_info.channels()

    def forward(self, x):
        feature_maps = self.backbone(x)  # list of [B, C, h, w], one per out_index
        return feature_maps

In [None]:
# Fusion from ChangeFormer
class Fusion(nn.Module):
    """Fuse pre/post features by concatenating [pre, post, |pre - post|] and a 1x1 conv back to ``dim``."""

    def __init__(self, dim):
        super().__init__()
        self.proj = nn.Conv2d(3 * dim, dim, kernel_size=1)

    def forward(self, f_pre, f_post):
        stacked = torch.cat((f_pre, f_post, (f_pre - f_post).abs()), dim=1)
        return self.proj(stacked)

# Fusion with normalization and ReLU
class NormFusion(nn.Module):
    """Concatenate [pre, post, |pre - post|], then 1x1 conv (no bias) -> BatchNorm -> ReLU to ``dim`` channels."""

    def __init__(self, dim):
        super().__init__()
        layers = [
            nn.Conv2d(3 * dim, dim, kernel_size=1, bias=False),
            nn.BatchNorm2d(dim),
            nn.ReLU(inplace=True),
        ]
        self.proj = nn.Sequential(*layers)

    def forward(self, f_pre, f_post):
        abs_diff = (f_pre - f_post).abs()
        return self.proj(torch.cat((f_pre, f_post, abs_diff), dim=1))

# Fusion at mid and last layers
class MultiScaleFusion(nn.Module):
    """Per-scale pre/post fusion: [pre, post, |pre - post|] -> 1x1 conv (no bias) -> BN -> ReLU.

    Args:
        in_dim: channels of each input feature map.
        out_dim: channels of the fused output (default 256).
    """

    def __init__(self, in_dim, out_dim=256):
        super().__init__()
        conv = nn.Conv2d(in_dim * 3, out_dim, 1, bias=False)
        self.proj = nn.Sequential(conv, nn.BatchNorm2d(out_dim), nn.ReLU(inplace=True))

    def forward(self, f_pre, f_post):
        features = [f_pre, f_post, torch.abs(f_pre - f_post)]
        return self.proj(torch.cat(features, dim=1))

In [None]:
# FPN Decoder for multiscale feature extraction
class FPNDecoder(nn.Module):
    """Top-down FPN-style decoder producing a single-channel logit map.

    ``feats`` are expected ordered from highest to lowest resolution (``feats[0]`` largest):
    the running map starts at ``feats[-1]`` and is bilinearly resized to each earlier
    level, where that level's 1x1 lateral projection is added. The head outputs
    (B, 1, h, w) at ``feats[0]``'s resolution.
    """

    def __init__(self, in_channels, fpn_dim=256):
        super().__init__()
        self.lateral = nn.ModuleList(nn.Conv2d(ch, fpn_dim, 1) for ch in in_channels)

        self.head = nn.Sequential(
            nn.Conv2d(fpn_dim, fpn_dim, 3, padding=1, bias=False),
            nn.BatchNorm2d(fpn_dim),
            nn.ReLU(inplace=True),
            nn.Conv2d(fpn_dim, 1, 1),
        )

    def forward(self, feats):
        merged = self.lateral[-1](feats[-1])
        for level in range(len(feats) - 2, -1, -1):
            target = feats[level]
            merged = F.interpolate(merged, size=target.shape[-2:], mode="bilinear", align_corners=False)
            merged = merged + self.lateral[level](target)
        return self.head(merged)  # (B, 1, h, w)

# Decoder for multiscale feature extraction
class Decoder(nn.Module):
    """Concatenate same-resolution fused maps, mix with 1x1 conv, refine with 3x3 convs, predict 1 logit channel.

    Args:
        n_feats: number of input maps to concatenate.
        feat_dim: channels of each input map.
        hidden: working channel width.
        n_blocks: number of (3x3 conv -> BN -> ReLU) refinement blocks.
    """

    def __init__(self, n_feats=3, feat_dim=256, hidden=256, n_blocks=3):
        super().__init__()
        self.mix = nn.Sequential(
            nn.Conv2d(n_feats * feat_dim, hidden, 1, bias=False),
            nn.BatchNorm2d(hidden),
            nn.ReLU(inplace=True),
        )
        # Flattened conv/BN/ReLU triples, created block by block.
        self.refine = nn.Sequential(*[
            layer
            for _ in range(n_blocks)
            for layer in (
                nn.Conv2d(hidden, hidden, 3, padding=1, bias=False),
                nn.BatchNorm2d(hidden),
                nn.ReLU(inplace=True),
            )
        ])
        self.head = nn.Conv2d(hidden, 1, 1)

    def forward(self, feats):
        # feats: list of [B, feat_dim, h, w] with identical h, w (e.g. 16x16 for 256-px patches)
        stacked = torch.cat(feats, dim=1)
        return self.head(self.refine(self.mix(stacked)))  # [B, 1, h, w]

In [None]:
class PositionEncodingSine(nn.Module):
    """DETR-style 2-D sine/cosine positional encoding (positions normalised to (0, 1], no 2*pi scaling).

    Only the shape and device of ``x`` are used. Returns a float32 tensor of shape
    (B, 2 * num_pos_feats, H, W): the first ``num_pos_feats`` channels encode the row
    position, the rest the column position, alternating sin/cos per frequency.
    """

    def __init__(self, num_pos_feats=128, temperature=10000):
        super().__init__()
        self.num_pos_feats = num_pos_feats
        self.temperature = temperature

    def forward(self, x):
        batch, _, height, width = x.shape
        dev = x.device
        eps = 1e-6

        # 1-based row / column counters, normalised by the last entry.
        ones = torch.ones(batch, height, width, device=dev, dtype=torch.float32)
        rows = ones.cumsum(dim=1)
        cols = ones.cumsum(dim=2)
        rows = rows / (rows[:, -1:, :] + eps)
        cols = cols / (cols[:, :, -1:] + eps)

        idx = torch.arange(self.num_pos_feats, device=dev)
        freqs = self.temperature ** (2 * (idx // 2) / self.num_pos_feats)

        def encode(coord):
            scaled = coord.unsqueeze(-1) / freqs
            return torch.stack((scaled[..., 0::2].sin(), scaled[..., 1::2].cos()), dim=4).flatten(3)

        pos = torch.cat((encode(rows), encode(cols)), dim=3)
        return pos.permute(0, 3, 1, 2)


In [None]:
class PixelDecoder(nn.Module):
    """Expand one (B, dim, h, w) map into a 3-level pyramid ``[p2, p3, p4]`` at 4x, 2x and 1x size.

    p4 = lateral4(x); each finer level is a 2x bilinear upsample followed by a 3x3 conv.
    NOTE(review): ``lateral3``, ``lateral2`` and ``out4`` are created but never used in
    ``forward``; they are kept so parameter init order and state_dict keys stay unchanged.
    """

    def __init__(self, dim):
        super().__init__()
        self.lateral4 = nn.Conv2d(dim, dim, 1)
        self.lateral3 = nn.Conv2d(dim, dim, 1)
        self.lateral2 = nn.Conv2d(dim, dim, 1)

        self.out4 = nn.Conv2d(dim, dim, 3, padding=1)
        self.out3 = nn.Conv2d(dim, dim, 3, padding=1)
        self.out2 = nn.Conv2d(dim, dim, 3, padding=1)

    @staticmethod
    def _up2(t):
        """Bilinear 2x upsampling."""
        return F.interpolate(t, scale_factor=2, mode="bilinear", align_corners=False)

    def forward(self, x):
        p4 = self.lateral4(x)
        p3 = self.out3(self._up2(p4))
        p2 = self.out2(self._up2(p3))
        return [p2, p3, p4]


In [None]:
class Mask2FormerLayer(nn.Module):
    """One pre-norm decoder layer: query self-attention, cross-attention to pixels, FFN.

    Args:
        dim: query / pixel feature width (must be divisible by ``nheads``).
        nheads: number of attention heads.
        use_mask: if True, cross-attention is restricted to pixels whose ``mask`` value
            is >= 0.5 (Mask2Former-style masked attention). The default (False) ignores
            the ``mask`` argument entirely, i.e. plain cross-attention.
    """

    def __init__(self, dim, nheads=8, use_mask=False):
        super().__init__()
        self.nheads = nheads
        self.use_mask = use_mask
        self.self_attn = nn.MultiheadAttention(dim, nheads, batch_first=True)
        self.cross_attn = nn.MultiheadAttention(dim, nheads, batch_first=True)

        self.linear1 = nn.Linear(dim, dim * 4)
        self.linear2 = nn.Linear(dim * 4, dim)
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.norm3 = nn.LayerNorm(dim)
        self.dropout = nn.Dropout(0.1)

    def _cross_attn_mask(self, mask, num_queries):
        """Turn a (B, 1 or Q, H, W) soft mask into a bool (B*nheads, Q, H*W) attention mask.

        True marks pixels a query may NOT attend to (mask < 0.5), as nn.MultiheadAttention
        expects. Rows that would block every pixel are re-enabled entirely, because a
        fully-blocked softmax row produces NaN.
        """
        blocked = mask.flatten(2) < 0.5                        # (B, M, HW)
        if blocked.size(1) == 1:
            blocked = blocked.expand(-1, num_queries, -1)      # share one mask across queries
        blocked = blocked.clone()                              # expand() is a view; make writable
        blocked[blocked.all(dim=-1)] = False
        # Batch-major head layout: index b * nheads + h.
        return blocked.repeat_interleave(self.nheads, dim=0)

    def forward(self, queries, key_value, mask):
        """
        Args:
            queries: (B, Q, dim) query embeddings.
            key_value: (B, dim, H, W) pixel features used as keys and values.
            mask: None or a (B, 1 or Q, H, W) soft mask; only used when ``use_mask`` is True.

        Returns:
            (B, Q, dim) updated queries.
        """
        # self attention
        q = self.norm1(queries)
        q2, _ = self.self_attn(q, q, q)
        q = queries + self.dropout(q2)

        # (optionally masked) cross attention over flattened pixel features
        q_norm = self.norm2(q)
        kv = key_value.flatten(2).transpose(1, 2)  # (B, HW, C)

        attn_mask = None
        if self.use_mask and mask is not None:
            attn_mask = self._cross_attn_mask(mask, q.size(1))

        q2, _ = self.cross_attn(q_norm, kv, kv, attn_mask=attn_mask)
        q = q + self.dropout(q2)

        # FFN
        q_norm = self.norm3(q)
        f = self.linear2(F.relu(self.linear1(q_norm)))
        q = q + self.dropout(f)

        return q


In [None]:
class MaskHead(nn.Module):
    """Per-query mask logits: dot product of projected queries with projected pixel features.

    Args (forward):
        pixel_features: (B, dim, H, W).
        queries: (B, N, dim).
    Returns:
        (B, N, H, W) mask logits.
    """

    def __init__(self, dim):
        super().__init__()
        self.feat_proj = nn.Conv2d(dim, dim, 1)
        self.query_proj = nn.Linear(dim, dim)

    def forward(self, pixel_features, queries):
        batch, _, height, width = pixel_features.shape
        pixel_emb = self.feat_proj(pixel_features).flatten(2)   # (B, dim, H*W)
        query_emb = self.query_proj(queries)                     # (B, N, dim)
        logits = torch.einsum("bnd,bdp->bnp", query_emb, pixel_emb)
        return logits.view(batch, -1, height, width)


In [None]:
class Mask2FormerDecoder(nn.Module):
    """Simplified Mask2Former-style decoder that outputs one change-logit map.

    Learnable queries are refined by ``num_layers`` Mask2FormerLayer blocks against the
    finest PixelDecoder level (4x the input resolution). MaskHead turns every query into
    a mask, and the output is the mean of the per-query mask logits.
    """

    def __init__(self, dim, num_queries=100, num_layers=4):
        super().__init__()
        self.dim = dim
        self.num_queries = num_queries
        self.num_layers = num_layers

        self.query_embed = nn.Parameter(torch.randn(num_queries, dim))

        self.pixel_decoder = PixelDecoder(dim)
        self.layers = nn.ModuleList([Mask2FormerLayer(dim) for _ in range(num_layers)])
        self.mask_head = MaskHead(dim)

    def forward(self, fused_features, out_size=None):
        """
        Args:
            fused_features: (B, dim, h, w) fused pre/post features.
            out_size: optional (H, W) of the returned map. If None, the map is upsampled
                by a further 4x (e.g. 16x16 ViT tokens -> 64x64 pixel features -> 256x256).

        Returns:
            (B, 1, H, W) change logits.
        """
        B = fused_features.size(0)

        pixel_features, _, _ = self.pixel_decoder(fused_features)  # finest level (p2)
        queries = self.query_embed.unsqueeze(0).repeat(B, 1, 1)

        masks_q = None
        mask_for_next = None
        for layer in self.layers:
            queries = layer(queries, pixel_features, mask_for_next)
            masks_q = self.mask_head(pixel_features, queries)
            # Query-averaged soft mask, passed as the next layer's `mask` argument.
            mask_for_next = masks_q.sigmoid().mean(1, keepdim=True)

        if masks_q is None:
            # num_layers == 0: read masks directly off the initial queries
            # (previously this raised NameError).
            masks_q = self.mask_head(pixel_features, queries)

        # final mask = average of query masks
        final_mask = masks_q.mean(1, keepdim=True)
        if out_size is None:
            return F.interpolate(final_mask, scale_factor=4, mode="bilinear", align_corners=False)
        return F.interpolate(final_mask, size=out_size, mode="bilinear", align_corners=False)


In [None]:
# Alternative model (currently commented out): uses only the last-layer features with the Mask2Former-style decoder
# class SiameseDINOv3_M2F(nn.Module):
#     def __init__(self, vit_name):
#         super().__init__()
#         self.backbone = ViTBackboneFeatures(vit_name)
#         self.dim = self.backbone.out_channels
#         self.fusion = NormFusion(self.dim)
#         self.decoder = Mask2FormerDecoder(self.dim)

#     def forward(self, pre, post):
#         f_pre = self.backbone(pre)
#         f_post = self.backbone(post)
#         fused = self.fusion(f_pre, f_post)
#         out = self.decoder(fused)
#         return out

# Main model: fuses pre/post features from several ViT blocks (chosen by out_indices) and decodes them jointly
class SiameseDINOv3(nn.Module):
    """Siamese change-detection model on a shared (weight-tied) multi-scale ViT backbone.

    The same backbone encodes ``pre`` and ``post``. Each selected block's pair of maps is
    fused by its own MultiScaleFusion, all fused maps are decoded jointly by Decoder,
    and the logits are bilinearly resized to the input's spatial size.

    Returns (forward): (B, 1, H, W) change logits for (B, 3, H, W) inputs.
    """

    def __init__(self, vit_name, out_indices=(4, 8, 12), feat_dim=256, hidden=256, n_blocks=3):
        super().__init__()
        self.backbone = ViTBackboneMultiScale(vit_name, out_indices=out_indices)
        stage_channels = self.backbone.channels
        self.fusions = nn.ModuleList(MultiScaleFusion(ch, feat_dim) for ch in stage_channels)
        self.decoder = Decoder(
            n_feats=len(stage_channels),
            feat_dim=feat_dim,
            hidden=hidden,
            n_blocks=n_blocks,
        )

    def forward(self, pre, post):
        out_hw = pre.shape[-2:]
        pre_maps = self.backbone(pre)
        post_maps = self.backbone(post)

        fused = [fuse(fp, fq) for fuse, fp, fq in zip(self.fusions, pre_maps, post_maps)]
        coarse = self.decoder(fused)
        return F.interpolate(coarse, size=out_hw, mode="bilinear", align_corners=False)


In [None]:
# Binary Cross-Entropy
# def bce_loss(logits, targets, pos_weight=3.0):
#     targets = targets.unsqueeze(1).float()
#     pw = torch.tensor([pos_weight], device=logits.device)
#     return F.binary_cross_entropy_with_logits(
#         logits, targets, pos_weight=pw
#     )

# Binary Cross-Entropy + DICE
# def dice_loss(probs, targets, eps=1e-6):
#     targets = targets.unsqueeze(1).float()
#     inter = (probs * targets).sum(dim=(2,3))
#     union = probs.sum(dim=(2,3)) + targets.sum(dim=(2,3))
#     dice = (2*inter + eps) / (union + eps)
#     return 1 - dice.mean()

# def bce_dice_loss(logits, targets, pos_weight=3.0):
#     pw = torch.tensor([pos_weight], device=logits.device)
#     bce = F.binary_cross_entropy_with_logits(
#         logits, targets.unsqueeze(1).float(), pos_weight=pw
#     )
#     probs = torch.sigmoid(logits)
#     d = dice_loss(probs, targets)
#     return bce + d

# Focal loss
def focal_loss_with_logits(logits, targets, alpha=0.5, gamma=1.0, reduction="mean"):
    """Binary focal loss computed from raw logits.

    loss = w * (1 - p_t) ** gamma * BCE(logits, targets), where p_t is the predicted
    probability of the true class, w = alpha for positives and 1 - alpha for negatives.
    With the default alpha=0.5 both classes get the same weight (w = 0.5 everywhere).

    Args:
        logits: raw scores, e.g. (B, 1, H, W).
        targets: {0, 1} labels, either with one dimension fewer than ``logits``
            (e.g. (B, H, W); a channel dim is inserted at position 1) or already
            shaped like ``logits``.
        alpha: positive-class weight.
        gamma: focusing exponent (0 gives weighted BCE).
        reduction: "mean" (default), "sum" or "none".

    Returns:
        Scalar loss, or the per-element loss for reduction="none".

    Raises:
        ValueError: on an unknown ``reduction``.
    """
    targets = targets.float()
    if targets.dim() == logits.dim() - 1:
        targets = targets.unsqueeze(1)
    bce = F.binary_cross_entropy_with_logits(logits, targets, reduction="none")
    p = torch.sigmoid(logits)
    pt = p * targets + (1 - p) * (1 - targets)          # prob of the true class
    w = alpha * targets + (1 - alpha) * (1 - targets)   # class weight
    loss = w * (1 - pt).pow(gamma) * bce
    if reduction == "mean":
        return loss.mean()
    if reduction == "sum":
        return loss.sum()
    if reduction == "none":
        return loss
    raise ValueError(f"unknown reduction: {reduction!r}")


In [None]:
train_ds = OSCDDataset(CFG.root, CFG.train_cities, CFG.patch_size, CFG.stride, augment=True)
val_ds = OSCDDataset(CFG.root, CFG.validation_cities, CFG.patch_size, CFG.stride, augment=False)
test_ds = OSCDDataset(CFG.root, CFG.test_cities, CFG.patch_size, CFG.stride, augment=False)

# Settings shared by every loader; only the training loader shuffles.
loader_kwargs = dict(batch_size=CFG.batch_size, num_workers=CFG.num_workers, pin_memory=True)
train_loader = DataLoader(train_ds, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_ds, shuffle=False, **loader_kwargs)
test_loader = DataLoader(test_ds, shuffle=False, **loader_kwargs)

len(train_ds), len(val_ds), len(test_ds)


In [None]:
@torch.no_grad()
def evaluate(model, loader, device, threshold=0.5):
    """Pixel-level binary change-detection metrics of ``model`` over ``loader``.

    A pixel is predicted "change" when sigmoid(logit) > threshold. Confusion counts are
    accumulated over every pixel of every batch; with overlapping patches (stride <
    patch_size) a pixel contributes once per patch that contains it. The model is left
    in eval mode.

    Args:
        model: module called as ``model(pre, post)``; expected to return (B, 1, H, W) logits.
        loader: iterable of (pre, post, mask) batches, mask expected (B, H, W) with 0/1 labels.
        device: device the batches are moved to.
        threshold: probability threshold for the change class.

    Returns:
        (iou, f1, precision, oa, change_acc, no_change_acc) as floats, where change_acc
        is the recall of the change class (OSCD "change accuracy") and no_change_acc the
        true-negative rate (OSCD "no-change accuracy").
    """
    model.eval()

    tp = fp = fn = tn = 0

    for pre, post, mask in loader:
        pre, post, mask = pre.to(device), post.to(device), mask.to(device).long()

        probs = torch.sigmoid(model(pre, post))
        pred = (probs > threshold).long().squeeze(1)

        tp += ((pred == 1) & (mask == 1)).sum().item()
        fp += ((pred == 1) & (mask == 0)).sum().item()
        fn += ((pred == 0) & (mask == 1)).sum().item()
        tn += ((pred == 0) & (mask == 0)).sum().item()

    eps = 1e-6  # avoids division by zero when a class is absent

    precision = tp / (tp + fp + eps)
    recall = tp / (tp + fn + eps)
    f1 = (2 * precision * recall) / (precision + recall + eps)
    iou = tp / (tp + fp + fn + eps)

    change_acc = recall                           # OSCD change accuracy
    no_change_acc = tn / (tn + fp + eps)          # OSCD no-change accuracy
    oa = (tp + tn) / (tp + tn + fp + fn + eps)    # overall accuracy

    return iou, f1, precision, oa, change_acc, no_change_acc


In [None]:
# Build the change-detection model on the configured device; the backbone requests
# pretrained weights from timm (pretrained=True in ViTBackboneMultiScale).
model = SiameseDINOv3(CFG.vit_name).to(CFG.device)

def unfreeze_last_vit_blocks(model, n_blocks=2):
    """Freeze all ViT params, then make only the last ``n_blocks`` blocks and the final norm trainable.

    Args:
        model: module whose ViT is reachable as ``model.backbone.backbone.model``
            (assumes timm's ``features_only`` wrapper exposes the underlying ViT as
            ``.model`` with a ``blocks`` sequence; TODO confirm for other timm versions).
        n_blocks: number of trailing transformer blocks to unfreeze. ``0`` (or a negative
            value) keeps every block frozen; note that ``blocks[-0:]`` would select all blocks.
    """
    vit = model.backbone.backbone.model

    # freeze all vit params
    for p in vit.parameters():
        p.requires_grad = False

    # unfreeze last n blocks
    if n_blocks > 0:
        for blk in vit.blocks[-n_blocks:]:
            for p in blk.parameters():
                p.requires_grad = True

    # unfreeze final norm (it may have no parameters, e.g. if replaced by an identity)
    norm = getattr(vit, "norm", None)
    if norm is not None:
        for p in norm.parameters():
            p.requires_grad = True

# Stage 1: train only fusion + decoder; the whole backbone starts frozen.
for p in model.backbone.parameters():
    p.requires_grad = False

# The backbone gets no param group yet. When it is partially unfrozen (inside the loop),
# its trainable params are appended with add_param_group, so the fusion/decoder AdamW
# moments and their LR schedule carry on instead of being reset by a new optimizer.
optimizer = torch.optim.AdamW([
    {"params": model.fusions.parameters(),  "lr": CFG.lr_decoder},
    {"params": model.decoder.parameters(), "lr": CFG.lr_decoder},
], weight_decay=1e-4)

# Per-epoch linear decay of every param group's LR, reaching 0 at the last epoch.
scheduler = torch.optim.lr_scheduler.LinearLR(
    optimizer, start_factor=1.0, end_factor=0.0, total_iters=CFG.epochs
)

history = {"loss": [], "iou": [], "f1": [], "prec": [], "oa": [], "chg": [], "no_chg": []}
# Start below any reachable IoU so best.pt is always written at least once
# (the threshold-tuning cell loads it unconditionally).
best_iou = -1.0
has_unfrozen = False

VAL_THRESHOLD = 0.35  # decision threshold used for per-epoch model selection

for epoch in range(CFG.epochs):
    model.train()

    # Stage 2 (triggered once): partially unfreeze the ViT.
    if (epoch >= CFG.freeze_backbone_epochs) and (not has_unfrozen):
        print(f"[INFO] Unfreezing last {CFG.unfreeze_blocks} ViT blocks at epoch {epoch}")
        unfreeze_last_vit_blocks(model, n_blocks=CFG.unfreeze_blocks)
        has_unfrozen = True

        backbone_params = [p for p in model.backbone.parameters() if p.requires_grad]
        if backbone_params:
            # LinearLR.step() rescales each group's current LR by the ratio of consecutive
            # schedule factors, so this group decays from lr_backbone to 0 over the
            # remaining epochs (the same curve a freshly built LinearLR would give).
            optimizer.add_param_group({
                "params": backbone_params,
                "lr": CFG.lr_backbone,
                "initial_lr": CFG.lr_backbone,
            })

    running = 0.0
    pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{CFG.epochs}")

    for step, (pre, post, mask) in enumerate(pbar, start=1):
        pre  = pre.to(CFG.device, non_blocking=True)
        post = post.to(CFG.device, non_blocking=True)
        mask = mask.to(CFG.device, non_blocking=True)

        logits = model(pre, post)
        loss = focal_loss_with_logits(logits, mask)

        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()

        running += loss.item()
        # Mean loss over the batches seen so far. pbar.n is not a usable count here:
        # it is still 0 on the first batch and tqdm only refreshes it when it redraws.
        pbar.set_postfix(loss=running / step)

    scheduler.step()

    # ---- VALIDATION ----
    iou, f1, precision, oa, change_acc, no_change_acc = evaluate(
        model, val_loader, CFG.device, threshold=VAL_THRESHOLD
    )

    print(
        f"[VAL] IoU={iou:.4f} | F1={f1:.4f} | "
        f"P={precision:.4f} | OA={oa:.4f} | "
        f"ChangeAcc={change_acc:.4f} | NoChangeAcc={no_change_acc:.4f}"
    )

    # log
    history["loss"].append(running / max(1, len(train_loader)))
    history["iou"].append(iou)
    history["f1"].append(f1)
    history["prec"].append(precision)
    history["oa"].append(oa)
    history["chg"].append(change_acc)
    history["no_chg"].append(no_change_acc)

    # save best
    if iou > best_iou:
        best_iou = iou
        torch.save(model.state_dict(), CFG.out_dir / "best.pt")
        print("  saved best.pt")

print("Training done. Best IoU:", best_iou)


In [None]:
import matplotlib.pyplot as plt

# One panel per logged series: (history key, title, y-axis label or None).
panels = [
    ("loss", "Training Loss", "Loss"),
    ("iou", "Validation IoU", None),
    ("f1", "Validation F1", None),
    ("prec", "Validation Precision", None),
    ("chg", "Validation Change Acc.", None),        # OSCD change accuracy
    ("no_chg", "Validation No-change Acc.", None),  # OSCD no-change accuracy
    ("oa", "Validation Overall Accuracy", None),
]

plt.figure(figsize=(20, 10))

for slot, (key, title, ylabel) in enumerate(panels, start=1):
    plt.subplot(2, 4, slot)
    plt.plot(history[key], '-o')
    plt.title(title)
    plt.xlabel("Epoch")
    if ylabel is not None:
        plt.ylabel(ylabel)

plt.tight_layout()
plt.savefig("training_metrics.png", dpi=300)
plt.show()


In [None]:
import numpy as np
import torch

@torch.no_grad()
def _collect_probs(model, loader, device):
    """Run ``model`` once over ``loader``; return flattened CPU probabilities and integer labels."""
    model.eval()
    probs, labels = [], []
    for pre, post, mask in loader:
        logits = model(pre.to(device), post.to(device))
        probs.append(torch.sigmoid(logits).squeeze(1).flatten().cpu())
        labels.append(mask.long().flatten().cpu())
    if not probs:  # empty loader: all confusion counts will be 0
        return torch.empty(0), torch.empty(0, dtype=torch.long)
    return torch.cat(probs), torch.cat(labels)


def _metrics_at(probs, labels, thr, eps=1e-6):
    """Metrics at one threshold, using the same formulas (and eps) as ``evaluate``."""
    pred_pos = probs > thr
    pred_neg = ~pred_pos
    gt_pos = labels == 1
    gt_neg = labels == 0
    tp = int((pred_pos & gt_pos).sum())
    fp = int((pred_pos & gt_neg).sum())
    fn = int((pred_neg & gt_pos).sum())
    tn = int((pred_neg & gt_neg).sum())

    precision = tp / (tp + fp + eps)
    recall = tp / (tp + fn + eps)
    return {
        "iou": float(tp / (tp + fp + fn + eps)),
        "f1": float((2 * precision * recall) / (precision + recall + eps)),
        "prec": float(precision),
        "oa": float((tp + tn) / (tp + tn + fp + fn + eps)),
        "change_acc": float(recall),
        "no_change_acc": float(tn / (tn + fp + eps)),
    }


def tune_threshold(model, loader, device):
    """Pick the decision threshold in [0.2, 0.6] (step 0.01) that maximises F1 on ``loader``.

    The model runs over the loader only once; all 41 candidate thresholds are then scored
    on the cached probabilities. On ties the smallest threshold wins.

    Returns:
        dict with keys thr, iou, f1, prec, oa, change_acc, no_change_acc.
    """
    probs, labels = _collect_probs(model, loader, device)
    thresholds = np.linspace(0.2, 0.6, 41)
    best = {"thr": 0.5, "f1": -1.0}

    for thr in thresholds:
        metrics = _metrics_at(probs, labels, float(thr))
        if metrics["f1"] > best["f1"]:
            best = {"thr": float(thr), **metrics}

    return best


best_ckpt = CFG.out_dir / "best.pt"
best_state = torch.load(best_ckpt, map_location=CFG.device)
model.load_state_dict(best_state)

# Choose the decision threshold on the validation cities only.
best = tune_threshold(model, val_loader, CFG.device)

summary = (
    f"[VAL-THR] best_thr={best['thr']:.2f} | "
    f"IoU={best['iou']:.4f} | F1={best['f1']:.4f} | "
    f"Prec={best['prec']:.4f} | "
    f"ChangeAcc={best['change_acc']:.4f} | NoChangeAcc={best['no_change_acc']:.4f} | "
    f"OA={best['oa']:.4f}"
)
print(summary)


In [None]:
def sliding_indices(length, patch_size, stride):
    """Start offsets of windows of size ``patch_size`` that together cover ``[0, length)``.

    Windows advance by ``stride``; if the last regular window does not end exactly at
    ``length``, one extra window flush with the end is appended. Lengths that fit in a
    single patch yield ``[0]``.
    """
    if length > patch_size:
        last_start = length - patch_size
        starts = list(range(0, last_start + 1, stride))
        if starts[-1] < last_start:
            starts.append(last_start)
        return starts
    return [0]



def pad_to_patch(x, patch_size):
    """Reflect-pad an (h, w, c) array on the bottom/right until both sides are >= ``patch_size``.

    Arrays that are already large enough are returned unchanged (the same object).
    """
    h, w, _ = x.shape
    extra_h = patch_size - h if h < patch_size else 0
    extra_w = patch_size - w if w < patch_size else 0
    if extra_h == 0 and extra_w == 0:
        return x
    return np.pad(x, ((0, extra_h), (0, extra_w), (0, 0)), mode="reflect")

@torch.no_grad()
def tta_predict_prob(model, pp_t, qq_t):
    preds = []

    # original
    logits = model(pp_t, qq_t)
    preds.append(torch.sigmoid(logits))

    # horizontal flip
    logits = model(torch.flip(pp_t, dims=[3]), torch.flip(qq_t, dims=[3]))
    preds.append(torch.flip(torch.sigmoid(logits), dims=[3]))

    # vertical flip
    logits = model(torch.flip(pp_t, dims=[2]), torch.flip(qq_t, dims=[2]))
    preds.append(torch.flip(torch.sigmoid(logits), dims=[2]))

    # hv flip (both)
    logits = model(torch.flip(pp_t, dims=[2,3]), torch.flip(qq_t, dims=[2,3]))
    preds.append(torch.flip(torch.sigmoid(logits), dims=[2,3]))

    prob = torch.stack(preds, dim=0).mean(dim=0)
    return prob[0, 0].cpu().numpy()

@torch.no_grad()
def predict_city(model, city, patch_size=256, stride=128, threshold=0.5):
    model.eval()

    city_dir = CFG.root / city
    pre  = np.array(Image.open(city_dir/"img1.png").convert("RGB"))
    post = np.array(Image.open(city_dir/"img2.png").convert("RGB"))

    H, W, _ = pre.shape
    out = np.zeros((H, W), dtype=np.float32)
    cnt = np.zeros((H, W), dtype=np.float32)

    ys = sliding_indices(H, patch_size, stride)
    xs = sliding_indices(W, patch_size, stride)

    for y in ys:
        for x in xs:
            pp = pre[y:y+patch_size, x:x+patch_size]
            qq = post[y:y+patch_size, x:x+patch_size]

            # actual (unpadded) region size
            h0, w0 = pp.shape[:2]

            # pad to patch_size x patch_size if needed
            pp = pad_to_patch(pp, patch_size)
            qq = pad_to_patch(qq, patch_size)

            # to tensor + normalize
            pp_t = torch.from_numpy(pp).permute(2,0,1).float() / 255.
            qq_t = torch.from_numpy(qq).permute(2,0,1).float() / 255.

            pp_t = (pp_t - IMAGENET_MEAN) / IMAGENET_STD
            qq_t = (qq_t - IMAGENET_MEAN) / IMAGENET_STD

            pp_t = pp_t.unsqueeze(0).to(CFG.device)
            qq_t = qq_t.unsqueeze(0).to(CFG.device)

            logits = model(pp_t, qq_t)
            prob = torch.sigmoid(logits)[0, 0].cpu().numpy()  # (patch_size, patch_size)
            # prob = tta_predict_prob(model, pp_t, qq_t) # Make use of TTA


            # write back ONLY the valid part (before padding)
            out[y:y+h0, x:x+w0] += prob[:h0, :w0]
            cnt[y:y+h0, x:x+w0] += 1.0

    out /= np.maximum(cnt, 1e-6)
    return out, (out > threshold).astype(np.uint8) * 255

In [None]:
print("\nTest Evaluation")
model.eval()

best_thr = best["thr"]  # threshold tuned on the validation cities

# NOTE(review): test_loader yields overlapping patches (stride < patch_size) and skips
# border strips no full patch reaches, so these are patch-level metrics; the stitched
# full-image maps from predict_city below are the basis for full-image scores.
test_metrics = evaluate(model, test_loader, CFG.device, threshold=best_thr)
test_iou, test_f1, test_prec, test_oa, test_change_acc, test_no_change_acc = test_metrics

print(
    f"[TEST] "
    f"IoU={test_iou:.4f} | "
    f"F1={test_f1:.4f} | "
    f"Prec={test_prec:.4f} | "
    f"ChangeAcc={test_change_acc:.4f} | "
    f"NoChangeAcc={test_no_change_acc:.4f} | "
    f"OA={test_oa:.4f}"
)

print("\nSaving sample Test predictions")

for city in CFG.test_cities:
    prob, pred = predict_city(model, city, threshold=best_thr)
    out_path = CFG.pred_dir / f"{city}_pred.png"
    Image.fromarray(pred).save(out_path)  # uint8 map: 255 = change, 0 = no change
    print(f"Saved: {out_path}")
