MAIN CODE

In [None]:
# train_lggf.py
# One-file script: training, 3-seed grid launcher, and aggregation.
# Implements: 0, 1A, 1B, 1D, 2, 4
# Requires: torch, torchvision, pycocotools, albumentations, numpy, tqdm

import os
import csv
import gc
import json
import copy
import sys
import math
import time
import random
import argparse
import tempfile
import warnings
import subprocess
import datetime
from glob import glob
from collections import OrderedDict

import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision

from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

from albumentations.pytorch import ToTensorV2
import albumentations as A
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval

from torchvision.models.detection import (
    retinanet_resnet50_fpn,
    RetinaNet_ResNet50_FPN_Weights,
)
from torchvision.models.detection.anchor_utils import AnchorGenerator
from torchvision.ops.misc import FrozenBatchNorm2d
from tqdm import tqdm
import matplotlib.pyplot as plt

# --------------------------- Env & determinism --------------------------------
os.environ.setdefault("CUDA_LAUNCH_BLOCKING", "1")
os.environ.setdefault("TORCH_SHOW_CPP_STACKTRACES", "1")
os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":4096:8")
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "max_split_size_mb:128")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
warnings.filterwarnings("ignore", category=UserWarning, module="albumentations")
warnings.filterwarnings("ignore", category=UserWarning, module="torchvision")

try:
    from torch.amp import autocast, GradScaler
    autocast_kwargs = dict(device_type="cuda", dtype=torch.float16) if device.type == "cuda" else dict(dtype=torch.bfloat16)
    scaler = GradScaler(device="cuda") if device.type == "cuda" else None
except Exception:
    from torch.cuda.amp import autocast, GradScaler
    autocast_kwargs = dict(dtype=torch.float16) if device.type == "cuda" else dict(dtype=torch.bfloat16)
    scaler = GradScaler() if device.type == "cuda" else None

def set_seed(seed: int):
    """Seed all RNGs used by the pipeline and switch PyTorch to deterministic mode.

    Seeds Python ``random``, NumPy's global RNG and torch (CPU, plus every CUDA
    device when CUDA is available). Disables cuDNN autotuning and TF32 math,
    requests deterministic algorithms (best effort), and finally tries to seed
    Albumentations (best effort; failures are ignored).
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

    backends = torch.backends
    backends.cudnn.deterministic = True
    backends.cudnn.benchmark = False
    cuda_backend = getattr(backends, "cuda", None)
    if cuda_backend is not None and hasattr(cuda_backend, "matmul"):
        cuda_backend.matmul.allow_tf32 = False
    if hasattr(backends.cudnn, "allow_tf32"):
        backends.cudnn.allow_tf32 = False

    try:
        torch.use_deterministic_algorithms(True)
    except Exception:
        pass  # best effort: older builds may not support it

    # Albumentations exposes its seeding hook in different places across versions.
    try:
        if hasattr(A, "set_seed"):
            seeder = A.set_seed
        else:
            from albumentations import random_utils
            seeder = random_utils.set_seed
        seeder(seed)
    except Exception:
        pass

# ----------------------------- Args -------------------------------------------
def parse_args(argv=None):
    """Parse command-line options for training, the seed grid and result aggregation.

    Unknown arguments are ignored (``parse_known_args``) so this also works inside
    Jupyter, where the kernel injects its own flags.

    Args:
        argv: argument list to parse; ``None`` (default) means ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with the recognised options.
    """
    # The text is a description; passing it positionally made it the program name.
    p = argparse.ArgumentParser(description="lgf trainer with grid + aggregation")

    # What to train
    p.add_argument("--mode", type=str, default="baseline",
                   choices=["baseline","se","cbam","lgf_sum","lgf_softmax","lgf_gated","lgf_gated_spatial"])
    p.add_argument("--insert-level", type=str, default="C3", choices=["C3","C4","C5"])  # 1B
    p.add_argument("--seed", type=int, default=42)
    p.add_argument("--exp-name", type=str, default=None)

    # Dataset flags (1A)
    p.add_argument("--dataset", type=str, default="coco_nw",
                   choices=["coco_nw","coco_weather","acdc","custom"])
    p.add_argument("--train-img", type=str, default=None)
    p.add_argument("--train-ann", type=str, default=None)
    p.add_argument("--val-img", type=str, default=None)
    p.add_argument("--val-ann", type=str, default=None)

    # Performance knobs (1D)
    p.add_argument("--num-workers", type=int, default=8)
    p.add_argument("--prefetch-factor", type=int, default=4)

    # Training schedule
    p.add_argument("--epochs", type=int, default=80)
    p.add_argument("--batch-size", type=int, default=4)
    p.add_argument("--accum-steps", type=int, default=4)
    p.add_argument("--base-lr", type=float, default=0.005)
    p.add_argument("--warmup-epochs", type=int, default=2)
    p.add_argument("--lr-milestones", type=int, nargs="*", default=[40,60])
    p.add_argument("--img-size", type=int, default=640)

    # Grid runner (2)
    p.add_argument("--run-grid", action="store_true")
    p.add_argument("--grid-seeds", type=int, nargs="*", default=[42,1337,2025])
    p.add_argument("--grid-datasets", type=str, nargs="*", default=None)
    p.add_argument("--grid-levels", type=str, nargs="*", default=["C3"])
    p.add_argument("--grid-modes", type=str, nargs="*", default=None)
    p.add_argument("--subprocess", action="store_true", help="force subprocess grid even in notebook")

    # Aggregation (4)
    p.add_argument("--aggregate", action="store_true")
    p.add_argument("--agg-datasets", type=str, nargs="*", default=None)
    p.add_argument("--agg-levels", type=str, nargs="*", default=["C3"])
    p.add_argument("--agg-modes", type=str, nargs="*", default=None)
    p.add_argument("--agg-seeds", type=int, nargs="*", default=[42,1337,2025])

    # Jupyter friendliness: tolerate flags we don't define.
    known, _unknown = p.parse_known_args(argv)
    return known

args = parse_args()

# ------------------------- Experiment configs ---------------------------------
# Every experiment config is a dict with the keys BRANCH_PRESET, gating_type,
# block_type and description; --mode picks one of them through CONFIG_MAP.

def _make_config(branch_preset, gating_type, block_type, description):
    """Return one experiment config dict."""
    return {
        "BRANCH_PRESET": branch_preset,
        "gating_type": gating_type,
        "block_type": block_type,
        "description": description,
    }


BASELINE_CONFIG = _make_config("none", "none", "none", "BASELINE: RetinaNet without custom blocks")

lgf_SUM_CONFIG = _make_config("local_global", "sum", "lgf", "lgf: local+global, sum")
lgf_SOFTMAX_CONFIG = _make_config("local_global", "softmax", "lgf", "lgf: local+global, softmax")
lgf_GATED_CONFIG = _make_config("local_global", "gated", "lgf", "lgf: local+global, sigmoid gate")
lgf_GATED_SP_CONFIG = _make_config("local_global", "gated_spatial", "lgf", "lgf: local+global, spatial gate")

CONFIG_MAP = {
    "baseline": BASELINE_CONFIG,
    "lgf_sum": lgf_SUM_CONFIG,
    "lgf_softmax": lgf_SOFTMAX_CONFIG,
    "lgf_gated": lgf_GATED_CONFIG,
    "lgf_gated_spatial": lgf_GATED_SP_CONFIG,
    "se": _make_config("none", "none", "se", "SE before FPN"),
    "cbam": _make_config("none", "none", "cbam", "CBAM before FPN"),
}

CURRENT_CONFIG = CONFIG_MAP[args.mode]
# --insert-level name -> torchvision ResNet stage name (1B).
LEVEL_MAP = {"C3": "layer2", "C4": "layer3", "C5": "layer4"}

# ------------------------- Anchors & helpers ----------------------------------
def _probe_feature_names(backbone: nn.Module) -> list:
    with torch.no_grad():
        x = torch.zeros(1,3,args.img_size,args.img_size, device=next(backbone.parameters()).device)
        feats = backbone(x)
        if not isinstance(feats, OrderedDict):
            raise RuntimeError("Backbone must return OrderedDict of features")
        return list(feats.keys())

_SIZE_MAP = {
    "0": (32,48,64),
    "1": (64,96,128),
    "2": (128,192,256),
    "3": (256,384,512),
    "p6": (256,384,512),
    "pool": (384,512,640),
    "p7": (384,512,640),
}

FORCE_STOCK_ANCHORS = False
_STOCK_AG = None
def get_stock_anchor_generator():
    """Return torchvision's default RetinaNet anchor generator, built once and cached.

    The anchor generator holds no learned weights, so the reference model is built
    with ``weights=None, weights_backbone=None``. This avoids downloading and loading
    checkpoints just to read its anchor configuration.
    """
    global _STOCK_AG
    if _STOCK_AG is None:
        reference = retinanet_resnet50_fpn(weights=None, weights_backbone=None)
        _STOCK_AG = reference.anchor_generator
        del reference  # keep only the (small) anchor generator alive
    return _STOCK_AG

def make_anchor_generator_for(backbone: nn.Module) -> AnchorGenerator:
    """Build an AnchorGenerator whose per-level sizes match *backbone*'s outputs.

    Feature names are discovered with a dummy forward pass (``_probe_feature_names``)
    and looked up in ``_SIZE_MAP``. Every level uses aspect ratios (0.5, 1.0, 2.0).

    Raises:
        RuntimeError: if a feature name has no entry in ``_SIZE_MAP``. The first
            missing name is reported, without a chained KeyError traceback.
    """
    names = _probe_feature_names(backbone)
    missing = [n for n in names if n not in _SIZE_MAP]
    if missing:
        raise RuntimeError(f"No anchor size tuple for feature '{missing[0]}'. Backbone keys={names}")
    sizes = tuple(_SIZE_MAP[n] for n in names)
    ratios = ((0.5, 1.0, 2.0),) * len(sizes)
    return AnchorGenerator(sizes=sizes, aspect_ratios=ratios)

# ---------------------------- Blocks ------------------------------------------
class StableLGFBlock(nn.Module):
    """Residual local/global feature-fusion block: ``out = x + gamma * fuse(L, G)``.

    Branches (selected through ``branches``):
      * ``"local"``  - depthwise 3x3 conv -> 1x1 conv -> GroupNorm -> ReLU.
      * ``"global"`` - global average pool -> 1x1 conv -> GroupNorm -> SiLU, then
        broadcast back over H x W.

    With both branches present they are fused according to ``gating_type``:
      * ``"sum"``           - L + G.
      * ``"softmax"``       - two learned scalar weights, softmax with learned temperature.
      * ``"gated"``         - per-sample, per-branch sigmoid gates from a pooled MLP
        (the whole MLP runs in float32 with autocast disabled).
      * ``"gated_spatial"`` - per-channel, per-pixel sigmoid gates from 1x1 convs.
      * anything else       - falls back to L + G.
    For the two gated modes the logits are clamped to [-15, 15] and the sigmoid is
    evaluated in float32 with autocast disabled. With a single branch its output is
    used directly. With no branches the residual is zero and the block is an identity
    map. ``gamma`` (initialised to 0.1) scales the fused residual.

    Args:
        channels: input/output channels (GroupNorm needs it divisible by ``norm_groups``).
        branches: iterable containing ``"local"`` and/or ``"global"``.
        gating_type: fusion rule used when both branches are active.
        norm_groups: GroupNorm group count.
        squeeze_ratio: hidden-width divisor of the ``"gated"`` MLP (width >= 4).
    """
    def __init__(self, channels, branches=("local","global"), gating_type="sum", norm_groups=32, squeeze_ratio=16):
        super().__init__()
        self.branches = tuple(branches)
        self.gating_type = gating_type
        self.num_branches = len(self.branches)
        self.save_maps = False   # when True, forward stores detached L/G in viz_cache
        self.viz_cache = {}
        def GN(c): return nn.GroupNorm(norm_groups, c)

        # Local branch
        self.local = None
        if "local" in self.branches:
            self.local = nn.Sequential(
                nn.Conv2d(channels, channels, 3, padding=1, groups=channels, bias=False),
                nn.Conv2d(channels, channels, 1, bias=False),
                GN(channels), nn.ReLU(inplace=True)
            )

        # Global branch
        self.global_branch = None
        if "global" in self.branches:
            self.global_branch = nn.Sequential(
                nn.AdaptiveAvgPool2d(1),
                nn.Conv2d(channels, channels, 1, bias=False),
                GN(channels), nn.SiLU(inplace=True)
            )

        # Gates (only needed when two branches must be fused)
        if self.num_branches > 1:
            if self.gating_type == "softmax":
                self.temperature = nn.Parameter(torch.tensor(1.0))
                self.branch_weights = nn.Parameter(torch.ones(self.num_branches))
            elif self.gating_type == "gated":
                hid = max(channels // squeeze_ratio, 4)
                self.temperature = nn.Parameter(torch.tensor(1.0))
                self.gate_mlp = nn.Sequential(
                    nn.AdaptiveAvgPool2d(1), nn.Flatten(),
                    nn.Linear(channels, hid), nn.ReLU(inplace=True),
                    nn.Linear(hid, self.num_branches)
                )
                # Zero-init => gates start at sigmoid(0) = 0.5 for both branches.
                nn.init.zeros_(self.gate_mlp[-1].weight); nn.init.zeros_(self.gate_mlp[-1].bias)
            elif self.gating_type == "gated_spatial":
                r = 4
                self.temperature = nn.Parameter(torch.tensor(1.0))
                self.gate_reduce = nn.Conv2d(channels, channels//r, 1, bias=False)
                self.gate_expand = nn.Conv2d(channels//r, 2*channels, 1, bias=True)
                self.gate_norm   = nn.GroupNorm(num_groups=norm_groups, num_channels=2*channels)
                nn.init.zeros_(self.gate_expand.weight); nn.init.zeros_(self.gate_expand.bias)

        self.gamma = nn.Parameter(torch.tensor(0.1))

    def enable_visualization(self, enable=True):
        """Turn caching of the branch outputs on/off (clears the cache when turned off)."""
        self.save_maps = bool(enable)
        if not enable:
            self.viz_cache = {}

    def _broadcast(self, s, like):
        """Expand a per-sample vector (shape [N]) or a scalar to the shape of *like*."""
        if s.dim() == 1:
            s = s.view(-1,1,1,1)
        return s.expand_as(like)

    @staticmethod
    def _fp32_region(x):
        """Context manager disabling autocast for *x*'s device type (CUDA or CPU).

        ``torch.cuda.amp.autocast(enabled=False)`` only affects CUDA autocast, so it
        would not keep the gates in float32 under CPU bf16 autocast.
        """
        dev = x.device.type if x.device.type in ("cuda", "cpu") else "cpu"
        return torch.autocast(device_type=dev, enabled=False)

    def forward(self, x):
        L = self.local(x) if self.local is not None else None
        G = self.global_branch(x).expand_as(x) if self.global_branch is not None else None
        feats = [f for f in (L, G) if f is not None]

        if not feats:
            # No branches configured: zero residual, i.e. identity mapping.
            out = torch.zeros_like(x)
        elif len(feats) == 1:
            out = feats[0]
        elif self.gating_type == "softmax":
            tau = F.softplus(self.temperature) + 1e-3
            w = F.softmax(self.branch_weights / tau, dim=0)
            out = w[0] * L + w[1] * G
        elif self.gating_type == "gated":
            with self._fp32_region(x):
                logits = self.gate_mlp(x.float()).clamp(-15, 15)
                tau = F.softplus(self.temperature.float()) + 1e-3
                w = torch.sigmoid(logits / tau)          # [N, 2]
            w = w.to(dtype=x.dtype)
            wL = self._broadcast(w[:, 0], L)
            wG = self._broadcast(w[:, 1], G)
            out = wL * L + wG * G
        elif self.gating_type == "gated_spatial":
            h = F.relu(self.gate_reduce(x), inplace=True)
            logits = self.gate_norm(self.gate_expand(h))
            with self._fp32_region(x):
                logits32 = logits.float().clamp(-15, 15)
                tau = F.softplus(self.temperature.float()) + 1e-3
                w = torch.sigmoid(logits32 / tau)        # [N, 2C, H, W]
            w = w.to(dtype=x.dtype)
            N, twoC, H, W = w.shape
            w = w.view(N, 2, twoC // 2, H, W)            # split into L and G gates
            out = w[:, 0] * L + w[:, 1] * G
        else:
            # "sum" and any unrecognised gating type: plain addition.
            out = L + G

        out = x + self.gamma * out
        if self.save_maps:
            self.viz_cache = {"L": (L.detach() if L is not None else None),
                              "G": (G.detach() if G is not None else None)}
        return out

class SEBlock(nn.Module):
    """Squeeze-and-Excitation channel attention.

    Average-pools the input to one value per channel and passes it through a
    bottleneck MLP (channels -> hidden -> channels, sigmoid output). Each input
    channel is then rescaled by the resulting weight. ``hidden`` is
    ``channels // reduction``, clamped to at least 1 so small channel counts do
    not produce a zero-width layer.
    """
    def __init__(self, channels, reduction=16):
        super().__init__()
        hidden = max(1, channels // reduction)
        self.avg = nn.AdaptiveAvgPool2d(1)
        self.fc  = nn.Sequential(
            nn.Linear(channels, hidden, bias=False), nn.ReLU(inplace=True),
            nn.Linear(hidden, channels, bias=False), nn.Sigmoid()
        )

    def forward(self, x):
        b, c, _, _ = x.shape
        y = self.fc(self.avg(x).view(b, c)).view(b, c, 1, 1)
        return x * y

# class CBAMBlock(nn.Module):
#     def __init__(self, channels, reduction=16, k=7):
#         super().__init__()
#         self.avg = nn.AdaptiveAvgPool2d(1)
#         self.maxp = nn.AdaptiveMaxPool2d(1)
#         self.fc = nn.Sequential(
#             nn.Conv2d(channels, channels//reduction, 1, bias=False), nn.ReLU(inplace=True),
#             nn.Conv2d(channels//reduction, channels, 1, bias=False)
#         )
#         self.sigc = nn.Sigmoid()
#         self.convs = nn.Conv2d(2,1,k,padding=k//2,bias=False)
#         self.sigs = nn.Sigmoid()
#     def forward(self, x):
#         ca = self.fc(self.avg(x)) + self.fc(self.maxp(x))
#         x = x * self.sigc(ca)
#         s = torch.cat([x.mean(1,True), x.amax(1,True)], dim=1)
#         return x * self.sigs(self.convs(s))

class CBAMBlock(nn.Module):
    """
    Determinism-safe CBAM.
    - Channel attention: avg pool + softmax-pooling over HxW (approximates max) -> shared MLP
    - Spatial attention: concat(mean over C, softmax-pooling over C) -> k x k conv (default k=5)
    The softmax pooling avoids adaptive_max_pool2d backward, which is non-deterministic on CUDA.

    Args:
        channels: input/output channels.
        reduction: bottleneck ratio of the shared channel MLP (width clamped to >= 1).
        k: spatial-attention kernel size. With padding k//2 only odd k preserves H x W;
            an even k would make the attention map larger than the input.
        beta: softmax sharpness; larger values bring softmax-pooling closer to a hard max.
    """
    def __init__(self, channels, reduction=16, k=5, beta=20.0):
        super().__init__()
        self.beta = beta
        hidden = max(1, channels // reduction)
        self.avg = nn.AdaptiveAvgPool2d(1)  # deterministic
        # shared MLP for channel attention
        self.fc = nn.Sequential(
            nn.Conv2d(channels, hidden, 1, bias=False),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden, channels, 1, bias=False)
        )
        self.sigc = nn.Sigmoid()
        # spatial attention
        self.convs = nn.Conv2d(2, 1, k, padding=k // 2, bias=False)
        self.sigs = nn.Sigmoid()

    @torch.no_grad()
    def _normalize_stable(self, x, dim):
        """Subtract the max along *dim*. Not used by forward.

        Because of ``@torch.no_grad()`` the result carries no autograd graph.
        """
        return x - x.max(dim=dim, keepdim=True).values

    def _softmax_pool_spatial(self, x):
        """Softmax-weighted pooling over H*W per channel: [B,C,H,W] -> [B,C,1,1]."""
        B, C, H, W = x.shape
        x_flat = x.reshape(B, C, H * W)  # reshape: also works for non-contiguous input
        # Shift by the max for numerical stability (keeps beta * x_norm <= 0). Softmax is
        # shift-invariant, so the shift is detached: its gradient is mathematically zero.
        x_norm = x_flat - x_flat.max(dim=2, keepdim=True).values.detach()
        w = F.softmax(self.beta * x_norm, dim=2)
        pooled = (w * x_flat).sum(dim=2)  # [B,C]
        return pooled.view(B, C, 1, 1)

    def _softmax_pool_channel(self, x):
        """Softmax-weighted pooling over C per pixel: [B,C,H,W] -> [B,1,H,W]."""
        x_hw_c = x.permute(0, 2, 3, 1)              # [B,H,W,C]
        x_norm = x_hw_c - x_hw_c.max(dim=3, keepdim=True).values.detach()
        w = F.softmax(self.beta * x_norm, dim=3)    # [B,H,W,C]
        pooled = (w * x_hw_c).sum(dim=3, keepdim=True)  # [B,H,W,1]
        return pooled.permute(0, 3, 1, 2)           # [B,1,H,W]

    def forward(self, x):
        # Channel attention
        avg_pool = self.avg(x)
        smx_pool = self._softmax_pool_spatial(x)
        ca = self.fc(avg_pool) + self.fc(smx_pool)
        x = x * self.sigc(ca)

        # Spatial attention
        s_mean = x.mean(dim=1, keepdim=True)
        s_softmax_max = self._softmax_pool_channel(x)
        s = torch.cat([s_mean, s_softmax_max], dim=1)
        sa = self.sigs(self.convs(s))
        return x * sa

# ----------------------------- Dataset (1A, 1D) -------------------------------
def select_dataset_by_name(name: str):
    """Return the default ``(train_img, train_ann, val_img, val_ann)`` paths for *name*.

    Known presets: ``"coco_nw"`` (COCO non-weather mini split), ``"coco_weather"``
    (synthetic-weather COCO mini split) and ``"acdc"``. Any preset can be overridden
    with ``--train-img/--train-ann/--val-img/--val-ann``.

    Raises:
        RuntimeError: for ``"custom"`` (paths must come from those flags) or for an
            unknown dataset name.
    """
    # Local-disk mirrors previously used instead of the NAS paths (same file names):
    #   /root/COCO/images/train2017_non_weather-2400k6c
    #   /root/COCO/annotations/mini_train2017_non_weather-2400k6c.json
    #   /root/COCO/images/val2017_non_weather-500k6c
    #   /root/COCO/annotations/mini_val2017_non_weather-500k6c.json
    #   (and likewise with "weather" in place of "non_weather")
    presets = {
        "coco_nw": (
            "/nas.dbms/asera/PROJECTS/DATASET/COCO/non_weather-mini/images/train2017_non_weather-2400k6c",
            "/nas.dbms/asera/PROJECTS/DATASET/COCO/non_weather-mini/annotations/mini_train2017_non_weather-2400k6c.json",
            "/nas.dbms/asera/PROJECTS/DATASET/COCO/non_weather-mini/images/val2017_non_weather-500k6c",
            "/nas.dbms/asera/PROJECTS/DATASET/COCO/non_weather-mini/annotations/mini_val2017_non_weather-500k6c.json",
        ),
        "acdc": (
            "/nas.dbms/asera/PROJECTS/DATASET/ACDC-1/ACDC-1-NEW/images/train",
            "/nas.dbms/asera/PROJECTS/DATASET/ACDC-1/ACDC-1-NEW/annotations/mini_train.json",
            "/nas.dbms/asera/PROJECTS/DATASET/ACDC-1/ACDC-1-NEW/images/val",
            "/nas.dbms/asera/PROJECTS/DATASET/ACDC-1/ACDC-1-NEW/annotations/mini_val.json",
        ),
        # Synthetic-weather paths; override via the --train-img/... flags if they differ.
        "coco_weather": (
            "/nas.dbms/asera/PROJECTS/DATASET/COCO/weather-mini/images/train2017_weather-2400k6c",
            "/nas.dbms/asera/PROJECTS/DATASET/COCO/weather-mini/annotations/mini_train2017_weather-2400k6c.json",
            "/nas.dbms/asera/PROJECTS/DATASET/COCO/weather-mini/images/val2017_weather-500k6c",
            "/nas.dbms/asera/PROJECTS/DATASET/COCO/weather-mini/annotations/mini_val2017_weather-500k6c.json",
        ),
    }
    if name in presets:
        return presets[name]
    if name == "custom":
        raise RuntimeError("Use --train-img/--train-ann/--val-img/--val-ann for dataset=custom.")
    raise RuntimeError(
        f"Unknown dataset '{name}'. Expected one of {sorted(presets) + ['custom']}, "
        "or pass --train-img/--train-ann/--val-img/--val-ann."
    )

# Resolved path per (absolute annotation path, mtime). Repeated calls in one process
# (e.g. an in-notebook grid) therefore reuse one patched copy instead of leaking a
# new temp file per call.
_PATCHED_ANN_CACHE = {}


def patch_annotations_once(ann_file):
    """Return a path to a COCO annotation JSON that has a top-level ``"info"`` entry.

    If *ann_file* already contains ``"info"`` it is returned unchanged. Otherwise a
    copy with a placeholder ``info`` block is written to a temporary ``.json`` file
    (not deleted automatically) and that path is returned. Results are cached per
    (absolute path, modification time), so an unchanged file is patched at most once.
    """
    key = (os.path.abspath(ann_file), os.path.getmtime(ann_file))
    cached = _PATCHED_ANN_CACHE.get(key)
    if cached is not None and os.path.exists(cached):
        return cached

    with open(ann_file, "r", encoding="utf-8") as f:
        data = json.load(f)
    if "info" in data:
        _PATCHED_ANN_CACHE[key] = ann_file
        return ann_file

    data["info"] = {"description": "Patched COCO dataset", "version": "1.0"}
    with tempfile.NamedTemporaryFile("w", suffix=".json", delete=False, encoding="utf-8") as tmp:
        json.dump(data, tmp)
    _PATCHED_ANN_CACHE[key] = tmp.name
    return tmp.name

class COCODataset(torch.utils.data.Dataset):
    """COCO-format detection dataset yielding ``(image, target)`` pairs.

    Category ids are remapped to contiguous labels ``0..num_classes-1`` in sorted
    category-id order (see ``cat_id_to_label`` / ``label_to_cat_id``). Target boxes are
    absolute ``[x1, y1, x2, y2]`` pixels. Crowd annotations (``iscrowd`` truthy) are
    skipped. Boxes are clipped to the image, and kept only if wider and taller than
    1 px, before augmentation.

    Args:
        img_folder: directory holding the images referenced by ``file_name``.
        ann_file: COCO annotation JSON path.
        transforms: optional Albumentations pipeline, called as
            ``transforms(image=..., bboxes=..., labels=...)``.
        train: when true and the ``MAX_TRAIN_IMAGES`` env var is set, the image list is
            subsampled with a fixed-seed (12345) RNG.
    """
    def __init__(self, img_folder, ann_file, transforms=None, train=False):
        self.img_folder = img_folder
        self.coco = COCO(ann_file)
        self.transforms = transforms
        self.train = train
        valid_cat_ids = sorted(self.coco.getCatIds())
        self.cat_id_to_label = {cid: i for i, cid in enumerate(valid_cat_ids)}
        self.label_to_cat_id = {v: k for k, v in self.cat_id_to_label.items()}
        # expose number of classes for model construction
        self.num_classes = len(valid_cat_ids)
        self.ids = sorted(self.coco.imgs.keys())
        # Optional ablation subset
        k_env = os.getenv("MAX_TRAIN_IMAGES")
        if self.train and k_env:
            k = min(int(k_env), len(self.ids))
            rng = np.random.default_rng(12345)
            self.ids = rng.choice(self.ids, size=k, replace=False).tolist()

    def __len__(self):
        return len(self.ids)

    def _load_image(self, info):
        """Read the image as RGB uint8 HxWx3; use a black placeholder if unreadable."""
        path = os.path.join(self.img_folder, info["file_name"])
        img = cv2.imread(path)
        if img is None:
            warnings.warn(f"Could not read image {path!r}; using a black placeholder.", stacklevel=2)
            return np.zeros((info["height"], info["width"], 3), dtype=np.uint8)
        return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    def _load_boxes(self, img_id, width, height):
        """Return ``(boxes, labels)`` for non-crowd annotations, clipped to the image."""
        boxes, labels = [], []
        for a in self.coco.loadAnns(self.coco.getAnnIds(imgIds=img_id)):
            if a.get("iscrowd", 0):
                continue  # crowd regions are not individual object targets
            x, y, w, h = a["bbox"]
            if w <= 1 or h <= 1:
                continue
            # Albumentations rejects pascal_voc boxes outside the image, and COCO boxes
            # can overshoot the border slightly, so clip before augmenting.
            x1, y1 = max(0.0, x), max(0.0, y)
            x2, y2 = min(float(width), x + w), min(float(height), y + h)
            if x2 - x1 > 1 and y2 - y1 > 1:
                boxes.append([x1, y1, x2, y2])
                labels.append(self.cat_id_to_label[a["category_id"]])
        return boxes, labels

    def __getitem__(self, idx):
        img_id = self.ids[idx]
        info = self.coco.loadImgs(img_id)[0]
        img = self._load_image(info)
        height, width = img.shape[:2]
        boxes, labels = self._load_boxes(img_id, width, height)

        if self.transforms:
            t = self.transforms(image=img, bboxes=boxes, labels=labels)
            img, boxes, labels = t["image"], t["bboxes"], t["labels"]
        # Only torch tensors (e.g. from ToTensorV2) are rescaled to [0, 1] floats.
        if isinstance(img, torch.Tensor) and img.dtype == torch.uint8:
            img = img.float() / 255.0
        boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)
        labels = torch.tensor([int(l) for l in labels], dtype=torch.int64)
        target = {"boxes": boxes, "labels": labels, "image_id": torch.tensor([img_id]),
                  "orig_size": torch.tensor([info["width"], info["height"]])}
        return img, target

def get_transform(train=True):
    """Build the Albumentations pipeline for training (``train=True``) or evaluation.

    The training pipeline prepends a horizontal flip and brightness/contrast jitter
    (each applied with p=0.5). Both pipelines end with ``ToTensorV2`` and handle
    Pascal-VOC boxes with a ``labels`` field, dropping boxes with ``min_area=1`` /
    ``min_visibility=0.1`` and checking boxes after every transform.
    """
    bbox_params = A.BboxParams(
        format='pascal_voc',
        label_fields=['labels'],
        min_area=1,
        min_visibility=0.1,
        check_each_transform=True,
    )
    steps = [ToTensorV2()]
    if train:
        steps = [A.HorizontalFlip(p=0.5), A.RandomBrightnessContrast(p=0.5)] + steps
    return A.Compose(steps, bbox_params=bbox_params)

def collate_fn(batch):
    """Transpose a list of ``(image, target)`` samples into ``(images, targets)``.

    Samples are regrouped rather than stacked, so per-sample tensors may differ in size.
    """
    columns = zip(*batch)
    return tuple(columns)

def build_datasets_and_loaders(seed, dataset_code, overrides):
    """Seed everything, then build the train/val datasets and their DataLoaders.

    Paths come from *overrides* (keys ``train_img``, ``train_ann``, ``val_img``,
    ``val_ann``) when all four are set, otherwise from ``select_dataset_by_name``.
    The validation annotation file is passed through ``patch_annotations_once``.

    Each loader gets its own seeded ``torch.Generator``. With a shared generator, every
    validation pass would draw from the same RNG stream and shift the training shuffle
    order. ``pin_memory`` is only enabled when training on CUDA.

    Returns:
        ``(train_dataset, val_dataset, train_loader, val_loader, val_ann_path)``.
    """
    set_seed(seed)
    keys = ("train_img", "train_ann", "val_img", "val_ann")
    if all(overrides.get(k) for k in keys):
        tr_img, tr_ann, va_img, va_ann = (overrides[k] for k in keys)
    else:
        tr_img, tr_ann, va_img, va_ann = select_dataset_by_name(dataset_code)
    va_ann = patch_annotations_once(va_ann)
    train_dataset = COCODataset(tr_img, tr_ann, transforms=get_transform(train=True),  train=True)
    val_dataset   = COCODataset(va_img, va_ann, transforms=get_transform(train=False), train=False)

    def loader_kwargs(gen_seed):
        """Common DataLoader kwargs with a private generator seeded by *gen_seed*."""
        g = torch.Generator()
        g.manual_seed(gen_seed)
        kw = dict(collate_fn=collate_fn, worker_init_fn=worker_init_fn,
                  pin_memory=(device.type == "cuda"), generator=g)
        if args.num_workers > 0:
            kw.update(num_workers=args.num_workers, persistent_workers=True,
                      prefetch_factor=args.prefetch_factor)
        return kw

    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, drop_last=True,
                              **loader_kwargs(seed))
    val_loader   = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, drop_last=False,
                              **loader_kwargs(seed + 1))
    return train_dataset, val_dataset, train_loader, val_loader, va_ann

def worker_init_fn(worker_id):
    """DataLoader worker hook: seed NumPy, Python ``random`` and Albumentations.

    The seed is ``torch.initial_seed()`` folded to 32 bits. Inside a DataLoader
    worker, torch has already set this per worker. Albumentations seeding is best
    effort; failures are ignored.
    """
    seed32 = int(torch.initial_seed() % 2**32)
    np.random.seed(seed32)
    random.seed(seed32)
    try:
        if hasattr(A, "set_seed"):
            seeder = A.set_seed
        else:
            from albumentations import random_utils
            seeder = random_utils.set_seed
        seeder(seed32)
    except Exception:
        pass

# ------------------------------ Model -----------------------------------------
def _gn_like(norm, num_groups):
    """Return a fresh GroupNorm with *norm*'s channel count, on *norm*'s device/dtype."""
    ref = getattr(norm, "weight", None)
    if not isinstance(ref, torch.Tensor):
        ref = getattr(norm, "running_mean", None)  # e.g. BatchNorm2d(affine=False)
    num_channels = getattr(norm, "num_features", None)
    if num_channels is None:
        # torchvision's FrozenBatchNorm2d has no ``num_features``; its size lives in buffers.
        num_channels = ref.shape[0]
    factory = {"device": ref.device, "dtype": ref.dtype} if isinstance(ref, torch.Tensor) else {}
    return nn.GroupNorm(num_groups=num_groups, num_channels=num_channels, **factory)


def convert_bn_to_gn(module, num_groups=32, convert_frozen=False):
    """Recursively replace BatchNorm layers in *module* with new GroupNorm layers.

    ``nn.BatchNorm2d`` children are always replaced. ``FrozenBatchNorm2d`` children are
    replaced only if ``convert_frozen`` is true. Replacements are freshly initialised:
    BN affine parameters and running statistics are not carried over. They are placed
    on the replaced layer's device/dtype. Other children are recursed into.

    Returns:
        *module*, modified in place.
    """
    for name, child in module.named_children():
        if isinstance(child, nn.BatchNorm2d) or (convert_frozen and isinstance(child, FrozenBatchNorm2d)):
            setattr(module, name, _gn_like(child, num_groups))
        else:
            convert_bn_to_gn(child, num_groups=num_groups, convert_frozen=convert_frozen)
    return module

def get_block(channels, block_type, branches, gating_type):
    """Return the module to insert before the FPN for *block_type*.

    ``"lgf"`` -> StableLGFBlock(channels, branches, gating_type); ``"se"`` -> SEBlock;
    ``"cbam"`` -> CBAMBlock. Any other value gives a no-op ``nn.Identity``.
    """
    if block_type == "se":
        return SEBlock(channels)
    elif block_type == "cbam":
        return CBAMBlock(channels)
    elif block_type == "lgf":
        return StableLGFBlock(channels, branches=branches, gating_type=gating_type)
    else:
        return nn.Identity()

class CustomBackboneBlockBeforeFPN(nn.Module):
    """Wrap a ``BackboneWithFPN`` and apply per-level blocks to body outputs before the FPN.

    ``selected_levels`` may name body output keys directly (e.g. ``"0"``) or ResNet stage
    names (``"layer1"``..``"layer4"``). Stage names are resolved through the body's
    ``return_layers`` mapping (torchvision's IntermediateLayerGetter stores
    ``{stage_name: output_key}``). This matters because RetinaNet's body does not
    return layer1, so the k-th output key is not ``layer{k+1}``. The positional guess
    is only a fallback when no such mapping exists. The block for each level comes
    from ``get_block`` using ``CURRENT_CONFIG``.

    Attributes:
        selected_actual: body output keys that receive a block, de-duplicated.
        block_fpn_in: ModuleDict of the inserted blocks, keyed by output key.
        out_channels: channel count of the FPN outputs.

    Raises:
        RuntimeError: if ``selected_levels`` is non-empty but none of them resolve.
    """
    def __init__(self, backbone_with_fpn, selected_levels):
        super().__init__()
        self.body = backbone_with_fpn.body
        self.fpn  = backbone_with_fpn.fpn

        feats, fpn_feats = self._probe()
        actual_keys = list(feats.keys())
        stage_to_key = self._stage_to_key(actual_keys)

        wanted = []
        for lvl in selected_levels:
            key = lvl if lvl in actual_keys else stage_to_key.get(lvl)
            if key is not None and key not in wanted:  # never wrap a level twice
                wanted.append(key)
        if selected_levels and not wanted:
            raise RuntimeError(
                f"None of the requested levels {list(selected_levels)} match the backbone "
                f"outputs {actual_keys} (stage mapping: {stage_to_key})")
        self.selected_actual = wanted

        self.block_fpn_in = nn.ModuleDict()
        branches = ("local", "global") if CURRENT_CONFIG["BRANCH_PRESET"] == "local_global" else ()
        for k in self.selected_actual:
            C = feats[k].shape[1]
            self.block_fpn_in[str(k)] = get_block(C, CURRENT_CONFIG.get("block_type", "none"),
                                                  branches, CURRENT_CONFIG["gating_type"])

        first_out = next(iter(fpn_feats.keys()))
        self.out_channels = fpn_feats[first_out].shape[1]

    def _probe(self, size=224):
        """Dry-run body and FPN on a zero image; returns ``(body_feats, fpn_feats)``.

        Runs in eval mode (so norm layers with running statistics are not updated);
        every submodule's train/eval flag is restored afterwards.
        """
        ref = next(self.body.parameters(), None)
        dev = ref.device if ref is not None else torch.device("cpu")
        dtype = ref.dtype if ref is not None else torch.float32
        modes = [(m, m.training) for m in self.modules()]
        self.eval()
        try:
            with torch.no_grad():
                feats = self.body(torch.zeros(1, 3, size, size, device=dev, dtype=dtype))
                fpn_feats = self.fpn(feats)
        finally:
            for m, was_training in modes:
                m.training = was_training
        return feats, fpn_feats

    def _stage_to_key(self, actual_keys):
        """Map ResNet stage names ("layerN") to the body's output keys."""
        return_layers = getattr(self.body, "return_layers", None)
        if isinstance(return_layers, dict) and return_layers:
            return {str(stage): str(key) for stage, key in return_layers.items()
                    if str(key) in actual_keys}
        # Fallback: assume the body returns consecutive stages starting at layer1.
        return {f"layer{i + 1}": k for i, k in enumerate(actual_keys[:4])}

    def forward(self, x):
        feats = self.body(x)
        for k in self.selected_actual:
            feats[k] = self.block_fpn_in[str(k)](feats[k])
        return self.fpn(feats)

def build_model(num_classes, insert_level):
    """Build a RetinaNet-R50-FPN detector with an optional block before the FPN.

    Starts from torchvision's COCO-pretrained RetinaNet. When
    ``CURRENT_CONFIG["block_type"] != "none"`` the backbone is wrapped with
    CustomBackboneBlockBeforeFPN at ``LEVEL_MAP[insert_level]``. Anchors come from
    ``make_anchor_generator_for`` (or the stock generator if FORCE_STOCK_ANCHORS).
    The pretrained head and transform are reused, and the classification logits layer
    is re-created for *num_classes* with RetinaNet's prior-probability bias init.

    Args:
        num_classes: number of foreground classes of the target dataset.
        insert_level: ``"C3"``, ``"C4"`` or ``"C5"`` (only used when a block is inserted).

    Returns:
        torchvision ``RetinaNet`` model (on CPU).
    """
    pretrained = retinanet_resnet50_fpn(weights=RetinaNet_ResNet50_FPN_Weights.COCO_V1)
    # NOTE(review): the pretrained torchvision body presumably uses FrozenBatchNorm2d, which
    # convert_frozen=False leaves untouched, so this call is likely a no-op here. Confirm
    # whether GroupNorm was actually intended before changing it (it would reset pretrained
    # normalisation).
    convert_bn_to_gn(pretrained.backbone.body, num_groups=32, convert_frozen=False)

    if CURRENT_CONFIG.get("block_type", "none") != "none":
        bb = CustomBackboneBlockBeforeFPN(pretrained.backbone, selected_levels=[LEVEL_MAP[insert_level]])
    else:
        bb = pretrained.backbone

    # NOTE(review): the pretrained head is reused, so the generator must yield the same
    # anchors per location as that head expects (3 sizes x 3 ratios in _SIZE_MAP).
    ag = get_stock_anchor_generator() if FORCE_STOCK_ANCHORS else make_anchor_generator_for(bb)
    model = torchvision.models.detection.RetinaNet(
        backbone=bb,
        num_classes=num_classes,
        anchor_generator=ag,
        head=pretrained.head,
        detections_per_img=100,
        nms_thresh=0.5,
        score_thresh=0.05,
    )
    # RetinaNet's constructor has no ``transform`` parameter (an extra keyword is silently
    # swallowed), so the pretrained transform is attached explicitly.
    # NOTE(review): --img-size is not applied here; resizing follows that transform's settings.
    model.transform = pretrained.transform

    # Re-create the classification logits for our class count.
    cls_head = model.head.classification_head
    in_ch = cls_head.cls_logits.in_channels
    n_anchors = cls_head.num_anchors
    cls_head.cls_logits = nn.Conv2d(in_ch, n_anchors * num_classes, kernel_size=3, padding=1)
    cls_head.num_classes = num_classes
    torch.nn.init.normal_(cls_head.cls_logits.weight, std=0.01)
    prior_prob = 0.01
    # Focal-loss prior: initial foreground probability ~= prior_prob for every anchor.
    torch.nn.init.constant_(cls_head.cls_logits.bias, -math.log((1.0 - prior_prob) / prior_prob))
    return model

# ----------------------------- Eval & logging ---------------------------------
def coco_evaluation(model, data_loader, ann_file, device):
    """Run *model* over *data_loader* and score its boxes with pycocotools COCOeval (bbox).

    Predicted labels are mapped back to COCO category ids through
    ``data_loader.dataset.label_to_cat_id`` when the dataset provides it. The model is
    left in eval mode.

    Returns:
        dict with the 12 COCO summary numbers (``mAP``, ``AP50``, ``AP75``, ``AP_small``,
        ``AP_medium``, ``AP_large``, ``AR1``, ``AR10``, ``AR100``, ``AR_small``,
        ``AR_medium``, ``AR_large``), ``AP_per_class`` (category name -> AP@[.5:.95],
        area=all, maxDets=100, NaN if the class has no valid precision entries), an
        empty ``detailed_metrics`` dict and ``detection_count``. If the model produced
        no detections at all, every summary number is 0.0.
    """
    coco_gt = COCO(ann_file)
    # Resolve the label -> category-id mapping once instead of per detection.
    label_to_cat = getattr(data_loader.dataset, "label_to_cat_id", None)
    detections = []
    model.eval()
    with torch.no_grad():
        for images, targets in tqdm(data_loader, desc="Evaluating"):
            images = [im.to(device) for im in images]
            outputs = model(images)
            for target, out in zip(targets, outputs):
                img_id = int(target["image_id"].item())
                boxes = out["boxes"].detach().cpu().numpy()
                scores = out["scores"].detach().cpu().numpy()
                labels = out["labels"].detach().cpu().numpy()
                for (x1, y1, x2, y2), s, l in zip(boxes, scores, labels):
                    cat = label_to_cat[int(l)] if label_to_cat is not None else int(l)
                    # COCO results use [x, y, width, height].
                    detections.append({"image_id": img_id, "category_id": cat,
                                       "bbox": [float(x1), float(y1), float(x2 - x1), float(y2 - y1)],
                                       "score": float(s)})
    if not detections:
        return {k: 0.0 for k in ["mAP","AP50","AP75","AP_small","AP_medium","AP_large",
                                 "AR1","AR10","AR100","AR_small","AR_medium","AR_large"]} | \
               {"AP_per_class": {}, "detailed_metrics": {}, "detection_count": 0}

    coco_dt = coco_gt.loadRes(detections)
    coco_eval = COCOeval(coco_gt, coco_dt, iouType="bbox")
    coco_eval.params.iouThrs = np.array([0.5,0.55,0.6,0.65,0.7,0.75,0.8,0.85,0.9,0.95])
    coco_eval.params.maxDets = [1,10,100]
    coco_eval.evaluate(); coco_eval.accumulate(); coco_eval.summarize()
    stats = coco_eval.stats

    # Per-class AP. The K axis of eval["precision"] follows params.catIds (sorted ids), so
    # names are looked up in that same order; getCatIds() follows file order and could
    # mislabel classes whenever the JSON categories are not sorted by id.
    precisions = coco_eval.eval["precision"]
    cat_ids = list(coco_eval.params.catIds)
    names = [c["name"] for c in coco_gt.loadCats(cat_ids)]
    ap_per_class = []
    for k in range(len(cat_ids)):
        p = precisions[:, :, k, 0, 2]   # all IoUs/recalls, area=all, maxDets=100
        p = p[p > -1]                    # -1 marks entries without ground truth
        ap_per_class.append(np.mean(p) if p.size else float("nan"))
    per_class = dict(zip(names, ap_per_class))

    # stats layout: [AP, AP50, AP75, APS, APM, APL, AR1, AR10, AR100, ARS, ARM, ARL]
    return {
        "mAP": stats[0], "AP50": stats[1], "AP75": stats[2],
        "AP_small": stats[3], "AP_medium": stats[4], "AP_large": stats[5],
        "AR1": stats[6], "AR10": stats[7], "AR100": stats[8],
        "AR_small": stats[9], "AR_medium": stats[10], "AR_large": stats[11],
        "AP_per_class": per_class,
        "detailed_metrics": {},
        "detection_count": len(detections),
    }

class ValidationLogger:
    """Persist per-epoch validation metrics for one experiment.

    Every call to :meth:`log` appends a row to ``<root>/<experiment>/validation_results.csv``,
    writes each int/float metric to TensorBoard under ``<experiment>/Val/<metric>``, and dumps
    the full metrics dict to ``epoch_<NNN>_results.json``.
    Constructing a logger rewrites the CSV header (truncating a CSV left by an earlier run with the
    same experiment name); JSON snapshots already in the directory are NOT removed.
    """

    # Single source of truth for the CSV columns so header and rows can never disagree.
    FIELDS = (
        "epoch", "timestamp", "experiment",
        "mAP", "AP50", "AP75", "AP_small", "AP_medium", "AP_large",
        "AR1", "AR10", "AR100", "AR_small", "AR_medium", "AR_large",
        "detection_count",
    )
    DEFAULT_ROOT = "/nas.dbms/asera/validation_logs"

    def __init__(self, experiment_name, root=None):
        """Create ``<root>/<experiment_name>/`` and start a fresh CSV.

        Args:
            experiment_name: run name; used as the directory name and TensorBoard tag prefix.
            root: parent log directory; defaults to ``DEFAULT_ROOT``.
        """
        self.experiment = experiment_name
        self.log_dir = os.path.join(self.DEFAULT_ROOT if root is None else root, experiment_name)
        os.makedirs(self.log_dir, exist_ok=True)
        self.csv_path = os.path.join(self.log_dir, "validation_results.csv")
        self._init_csv()

    def _init_csv(self):
        """(Re)write the CSV so it contains only the header row."""
        with open(self.csv_path, "w", newline="", encoding="utf-8") as f:
            csv.DictWriter(f, fieldnames=list(self.FIELDS)).writeheader()

    def log(self, epoch, metrics, writer):
        """Record *metrics* for *epoch* and return *metrics* unchanged.

        Args:
            epoch: epoch index stored in the CSV/JSON and used as the TensorBoard step.
            metrics: metrics dict; keys outside FIELDS are left out of the CSV but kept in JSON.
            writer: object with ``add_scalar(tag, value, step)`` (e.g. SummaryWriter),
                or None to skip TensorBoard logging.

        Returns:
            The same *metrics* object.
        """
        timestamp = datetime.datetime.now().isoformat()
        row = {
            "epoch": epoch, "timestamp": timestamp, "experiment": self.experiment,
            **{k: metrics[k] for k in self.FIELDS if k in metrics},
        }
        with open(self.csv_path, "a", newline="", encoding="utf-8") as f:
            csv.DictWriter(f, fieldnames=list(self.FIELDS)).writerow(row)

        if writer is not None:
            for k, v in metrics.items():
                # Only scalars can be plotted; nested dicts (e.g. per-class AP) go to JSON only.
                if isinstance(v, (int, float)):
                    writer.add_scalar(f"{self.experiment}/Val/{k}", v, epoch)

        snapshot_path = os.path.join(self.log_dir, f"epoch_{epoch:03d}_results.json")
        with open(snapshot_path, "w", encoding="utf-8") as f:
            json.dump({"epoch": epoch, "timestamp": timestamp, "experiment": self.experiment,
                       "metrics": metrics}, f, indent=2)
        return metrics

# ------------------------------ Complexity/speed -------------------------------
def benchmark_inference(model, data_loader, device, max_images=200, warmup_batches=20):
    """Measure average inference latency per image.

    The model is put in eval mode and run under ``torch.no_grad()`` on batches from
    *data_loader* until at least *max_images* images have been timed. Host-to-device copies
    are excluded from the timing. On CUDA, *warmup_batches* untimed batches run first and
    timing uses CUDA events; on CPU no warmup is done and ``time.perf_counter`` is used.

    Latency is total timed milliseconds divided by total timed images, so a smaller final
    batch is weighted by its image count rather than counted as a full sample.

    Returns:
        {"latency_ms_per_image": float, "images_per_second": float}; both are NaN when no
        image was timed.
    """
    model.eval()
    use_cuda = device.type == "cuda"
    total_ms = 0.0
    counted = 0
    with torch.no_grad():
        if use_cuda:
            if warmup_batches > 0:
                for bi, (images, _) in enumerate(data_loader):
                    model([im.to(device) for im in images])
                    if bi + 1 >= warmup_batches:
                        break
            starter = torch.cuda.Event(enable_timing=True)
            ender = torch.cuda.Event(enable_timing=True)

        for images, _ in data_loader:
            images = [im.to(device) for im in images]
            if use_cuda:
                starter.record()
                model(images)
                ender.record()
                torch.cuda.synchronize()  # events are async; wait before reading elapsed time
                batch_ms = starter.elapsed_time(ender)
            else:
                t0 = time.perf_counter()
                model(images)
                batch_ms = 1000.0 * (time.perf_counter() - t0)
            total_ms += batch_ms
            counted += len(images)
            if counted >= max_images:
                break

    lat = total_ms / counted if counted else float("nan")
    return {"latency_ms_per_image": lat, "images_per_second": (1000.0 / lat if lat > 0 else float("nan"))}

def try_flops_params(model, image_size=(3, 800, 1333)):
    """Best-effort model complexity: parameter count and forward cost for one image.

    Tries fvcore's FlopCountAnalysis, then thop's profile, on a CPU deep copy of *model* (the
    original model is not moved). If neither package is available or profiling fails,
    ``FLOPs_G`` is None. fvcore counts one fused multiply-add as one flop and thop reports
    MACs, so both figures are effectively GMACs.

    Args:
        model: detection model taking a list of (C, H, W) image tensors (torchvision style).
        image_size: (C, H, W) of the single dummy image.

    Returns:
        {"params_M": float, "FLOPs_G": float | None}
    """
    params_m = float(sum(p.numel() for p in model.parameters()) / 1e6)
    try:
        m = copy.deepcopy(model).to("cpu").eval()
    except Exception:
        return {"params_M": params_m, "FLOPs_G": None}

    # torchvision detection models expect a *list of 3-D* (C, H, W) tensors; the former
    # 4-D tensor inside the list was rejected by the model's transform, so profiling
    # always failed and FLOPs_G was always None.
    inputs = ([torch.zeros(*image_size)],)
    try:
        from fvcore.nn import FlopCountAnalysis
        flops = FlopCountAnalysis(m, inputs).total()
        return {"params_M": params_m, "FLOPs_G": float(flops) / 1e9}
    except Exception:
        pass  # fvcore missing or tracing failed: fall through to thop
    try:
        from thop import profile
        macs, _ = profile(m, inputs=inputs, verbose=False)
        return {"params_M": params_m, "FLOPs_G": float(macs) / 1e9}
    except Exception:
        return {"params_M": params_m, "FLOPs_G": None}

# ------------------------------ Train (best by mAP) ---------------------------
def train_one(experiment_name, train_dataset, val_dataset, train_loader, val_loader, patched_val_ann_file, insert_level):
    """Train one detector configuration and keep the checkpoint with the best EMA validation mAP.

    Schedule:
      * epochs [0, 5): only ``model.head`` and, if present, ``model.backbone.block_fpn_in`` train;
        from epoch 5 every parameter is unfrozen and a fresh SGD optimizer is created;
      * LR: linear warmup over ``args.warmup_epochs`` epochs, then x0.1 at the absolute epochs
        listed in ``args.lr_milestones`` (also after the optimizer is recreated at unfreeze);
      * gradients are accumulated over ``args.accum_steps`` batches, clipped to norm 10, and an
        EMA copy (decay 0.9999) is refreshed after every optimizer step;
      * every 5 epochs and at the last epoch the EMA model is validated and benchmarked, and a
        checkpoint is written to /nas.dbms/asera/NEW whenever its mAP improves.

    Args:
        experiment_name: run name (log directory, TensorBoard tags, checkpoint filename).
        train_dataset: must expose ``num_classes``.
        val_dataset: unused here; kept for interface compatibility.
        train_loader / val_loader: DataLoaders yielding (images, targets).
        patched_val_ann_file: COCO annotation file used for evaluation.
        insert_level: where build_model inserts the attention block (e.g. "C3").

    Returns:
        (best_map, best_epoch, best_ckpt): best EMA mAP, its 1-based epoch, and the checkpoint
        path; (0.0, -1, None) if no evaluation exceeded an mAP of 0.0.
    """
    unfreeze_epoch = 5   # 0-based epoch at which the whole network becomes trainable
    eval_every = 5

    C = train_dataset.num_classes
    model = build_model(C, insert_level).to(device)

    # Isolation check: the model must contain exactly the block type CURRENT_CONFIG asks for.
    # Explicit raise (not `assert`) so the check also runs under `python -O`.
    has_blocks = (
        any(isinstance(m, StableLGFBlock) for m in model.modules()),
        any(isinstance(m, CBAMBlock) for m in model.modules()),
        any(isinstance(m, SEBlock) for m in model.modules()),
    )
    expected_blocks = {
        "lgf": (True, False, False),
        "cbam": (False, True, False),
        "se": (False, False, True),
        "none": (False, False, False),
    }
    bt = CURRENT_CONFIG.get("block_type", "none")
    if bt in expected_blocks and has_blocks != expected_blocks[bt]:
        raise AssertionError(
            f"block_type={bt!r} but model has (lgf, cbam, se) = {has_blocks}")

    convert_bn_to_gn(model, num_groups=32, convert_frozen=False)

    # Warmup freeze: heads + pre-FPN blocks only.
    for p in model.parameters():
        p.requires_grad = False
    for p in model.head.parameters():
        p.requires_grad = True
    if hasattr(model.backbone, "block_fpn_in"):
        for p in model.backbone.block_fpn_in.parameters():
            p.requires_grad = True

    ema = ModelEMA(model, decay=0.9999, device=device)
    val_logger = ValidationLogger(experiment_name)
    writer = SummaryWriter(f"/nas.dbms/asera/NEW-4.1.2/runs/{experiment_name}_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}")

    def _make_optim(params, start_epoch):
        """Create SGD + MultiStepLR whose decays land on the absolute epochs in args.lr_milestones."""
        # A new MultiStepLR counts from 0 at construction, so shift the milestones by
        # start_epoch and start from the LR the schedule would already have decayed to.
        passed = sum(1 for m in args.lr_milestones if m <= start_epoch)
        opt = torch.optim.SGD(params, lr=args.base_lr * (0.1 ** passed), momentum=0.9, weight_decay=1e-4)
        remaining = [m - start_epoch for m in args.lr_milestones if m > start_epoch]
        sched = torch.optim.lr_scheduler.MultiStepLR(opt, milestones=remaining, gamma=0.1)
        return opt, sched

    use_amp = device.type == "cuda"
    accum = args.accum_steps

    def _optimizer_step(opt):
        """Clip and apply the accumulated gradients, clear them, then refresh the EMA."""
        if use_amp:
            scaler.unscale_(opt)  # clip on true (unscaled) gradients
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10.0)
            scaler.step(opt)
            scaler.update()
        else:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10.0)
            opt.step()
        opt.zero_grad(set_to_none=True)
        ema.update(model)

    trainable = [p for p in model.parameters() if p.requires_grad]
    optimizer, lr_scheduler = _make_optim(trainable, 0)

    best_map = 0.0
    best_epoch = -1
    best_ckpt = None
    unfroze = False
    complexity = None  # params/FLOPs do not change during training: measure once

    try:
        for epoch in range(args.epochs):
            if not unfroze and epoch == unfreeze_epoch:
                for p in model.parameters():
                    p.requires_grad = True
                optimizer, lr_scheduler = _make_optim(model.parameters(), epoch)
                unfroze = True

            # Linear LR warmup (overrides the scheduler's LR during the warmup epochs).
            if epoch < args.warmup_epochs:
                lr_scale = min(1.0, float(epoch + 1) / args.warmup_epochs)
                for g in optimizer.param_groups:
                    g["lr"] = args.base_lr * lr_scale

            model.train()
            ep_loss = 0.0
            n_batches = 0
            prog = tqdm(train_loader, desc=f"{experiment_name} - Epoch {epoch+1}/{args.epochs}")
            for bi, (images, targets) in enumerate(prog):
                images = [im.to(device) for im in images]
                targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
                if use_amp:
                    with autocast(**autocast_kwargs):
                        loss_dict = model(images, targets)
                        loss = sum(loss_dict.values()) / accum
                    scaler.scale(loss).backward()
                else:
                    loss_dict = model(images, targets)
                    loss = sum(loss_dict.values()) / accum
                    loss.backward()
                if (bi + 1) % accum == 0:
                    _optimizer_step(optimizer)
                batch_loss = loss.item() * accum  # one host sync per batch
                ep_loss += batch_loss
                n_batches += 1
                prog.set_postfix(loss=batch_loss, lr=optimizer.param_groups[0]["lr"])

            # Apply a trailing partial accumulation so its gradients don't leak into next epoch.
            if n_batches % accum != 0:
                _optimizer_step(optimizer)

            if n_batches:
                writer.add_scalar(f"{experiment_name}/Train/loss", ep_loss / n_batches, epoch)
            writer.add_scalar(f"{experiment_name}/Train/lr", optimizer.param_groups[0]["lr"], epoch)

            lr_scheduler.step()

            if (epoch + 1) % eval_every == 0 or epoch == args.epochs - 1:
                eval_results = validate_model(ema.ema, epoch, writer, experiment_name, val_logger, val_loader, patched_val_ann_file)
                spd = benchmark_inference(ema.ema, val_loader, device)
                if complexity is None:
                    complexity = try_flops_params(ema.ema)
                print("[SPEED]", spd, "[COMPLEXITY]", complexity)
                # Best checkpoint is selected by overall mAP; to follow the paper's AP_small
                # criterion instead, compare eval_results["AP_small"] here.
                if eval_results["mAP"] > best_map:
                    best_map = eval_results["mAP"]
                    best_epoch = epoch + 1
                    os.makedirs("/nas.dbms/asera/NEW", exist_ok=True)
                    best_ckpt = os.path.join("/nas.dbms/asera/NEW",
                        f"BEST_{experiment_name}_epoch_{best_epoch}_map_{best_map:.4f}.pth")
                    torch.save({
                        "epoch": epoch,
                        "model_state_dict": model.state_dict(),
                        "ema_state_dict": ema.ema.state_dict(),
                        "optimizer_state_dict": optimizer.state_dict(),
                        "scaler_state_dict": (scaler.state_dict() if scaler is not None else None),
                        "best_map": best_map,
                        "eval_results": eval_results,
                        "experiment_name": experiment_name,
                        "config": CURRENT_CONFIG,
                    }, best_ckpt)
                    print(f"[SAVE] Best {experiment_name} by mAP: {best_map:.4f} @ epoch {best_epoch}")

            if (epoch + 1) % 10 == 0:
                torch.cuda.empty_cache()
                gc.collect()
    finally:
        writer.close()  # flush TensorBoard events even if training fails
    return best_map, best_epoch, best_ckpt

def validate_model(model, epoch, writer, experiment_name, val_logger, val_loader, patched_val_ann_file):
    """Run COCO evaluation on *model*, record the results, and return the metrics dict.

    Results are appended (best effort; failures only print a warning) to
    ``results/all_runs.csv`` through the optional ``lggf_controls_toolkit.append_result``, and
    logged through *val_logger*. The model's previous train/eval mode is restored afterwards,
    so a model that should stay in eval mode (e.g. the EMA copy) is not switched to training.

    Returns:
        The dict returned by ``val_logger.log`` (the coco_evaluation metrics).
    """
    was_training = model.training
    model.eval()
    try:
        eval_results = coco_evaluation(model, val_loader, patched_val_ann_file, device)
    finally:
        model.train(was_training)

    # In-process grid runs share one `args`, so args.seed / args.insert_level are the CLI values,
    # not this run's. Prefer the values encoded in the experiment name
    # ("{dataset}_{level}_{mode}_s{seed}") and fall back to args otherwise.
    _, sep, seed_tag = experiment_name.rpartition("_s")
    run_seed = int(seed_tag) if sep and seed_tag.isdigit() else args.seed
    run_level = next((lvl for lvl in ("C3", "C4", "C5") if f"_{lvl}_" in experiment_name),
                     args.insert_level)

    try:
        from lggf_controls_toolkit import append_result
        os.makedirs("results", exist_ok=True)
        append_result("results/all_runs.csv", dict(
            exp=experiment_name,
            gate_override=getattr(args, "gate_override", ""),   # empty if the flag is absent
            alpha_local=getattr(args, "alpha_local", ""),       # empty if the flag is absent
            insert_level=run_level,
            seed=run_seed,
            mAP=eval_results["mAP"],
            AP_small=eval_results["AP_small"],
            AP50=eval_results["AP50"],
            AP75=eval_results["AP75"],
            AR100=eval_results["AR100"],
            dets=eval_results.get("detection_count", 0),
            ips=0.0,  # throughput is benchmarked separately by the caller
        ))
    except Exception as e:
        print(f"[WARN] append_result failed: {e}")

    return val_logger.log(epoch, eval_results, writer)

class ModelEMA:
    """Exponential moving average (EMA) of a model's floating-point state.

    ``self.ema`` is a deep copy of the model, kept in eval mode with gradients disabled.
    Each :meth:`update` blends it toward the live model:
    ``ema = decay * ema + (1 - decay) * model`` for every floating-point state-dict entry.
    Non-floating entries (e.g. BatchNorm ``num_batches_tracked``) and keys missing from the
    live model are left untouched.
    """

    def __init__(self, model, decay=0.9999, device=None):
        """Copy *model*; optionally move the copy to *device*."""
        self.ema = copy.deepcopy(model)
        self.ema.eval()
        self.decay = decay
        if device is not None:
            self.ema.to(device)
        for p in self.ema.parameters():
            p.requires_grad_(False)

    def update(self, model):
        """Blend the current weights/buffers of *model* into the EMA copy (in place)."""
        d = self.decay
        with torch.no_grad():
            msd = model.state_dict()
            for k, v in self.ema.state_dict().items():
                if not v.dtype.is_floating_point:
                    continue
                src = msd.get(k)
                if src is None:
                    continue
                # In-place: no per-tensor temporaries on every training step.
                v.mul_(d).add_(src.to(device=v.device, dtype=v.dtype), alpha=1.0 - d)

# ------------------------------- Aggregation (4) -------------------------------
def best_by_map(run_dir):
    """Return the logged epoch with the highest mAP in *run_dir*, or None if nothing is logged.

    Reads every ``epoch_*_results.json`` snapshot in *run_dir*. Files are visited in sorted
    (zero-padded epoch) order, so on ties the earliest epoch wins.

    Returns:
        dict with ``epoch`` plus mAP, AP50, AP75, AP_small, AP_medium and AP_large, or None.
    """
    keys = ("mAP", "AP50", "AP75", "AP_small", "AP_medium", "AP_large")
    best = None
    for path in sorted(glob(os.path.join(run_dir, "epoch_*_results.json"))):
        with open(path, encoding="utf-8") as fh:  # close each snapshot deterministically
            snap = json.load(fh)
        metrics = snap["metrics"]
        row = {"epoch": snap["epoch"], **{k: metrics[k] for k in keys}}
        if best is None or row["mAP"] > best["mAP"]:
            best = row
    return best

def best_by_apsmall(run_dir):
    """Return the logged epoch with the highest AP_small in *run_dir*, or None if nothing is logged.

    Reads every ``epoch_*_results.json`` snapshot in *run_dir*. Files are visited in sorted
    (zero-padded epoch) order, so on ties the earliest epoch wins.

    Returns:
        dict with ``epoch`` plus the AP_* / mAP / AR* metrics listed below, or None.
    """
    keys = ("AP_small", "AP_medium", "AP_large", "AP50", "AP75", "mAP",
            "AR1", "AR10", "AR100", "AR_small", "AR_medium", "AR_large")
    best = None
    for path in sorted(glob(os.path.join(run_dir, "epoch_*_results.json"))):
        with open(path, encoding="utf-8") as fh:  # close each snapshot deterministically
            snap = json.load(fh)
        metrics = snap["metrics"]
        row = {"epoch": snap["epoch"], **{k: metrics[k] for k in keys}}
        if best is None or row["AP_small"] > best["AP_small"]:
            best = row
    return best

def aggregate_group(group_name, run_names, log_root="/nas.dbms/asera/validation_logs"):
    """Print mean ± sample std (ddof=1) of each run's best-AP_small epoch across *run_names*.

    Each run is looked up as ``<log_root>/<run_name>`` and reduced to its best epoch with
    ``best_by_apsmall``. Missing directories and directories without any epoch snapshot are
    skipped with a message (they used to crash the aggregation). The epochs printed are the
    values stored in the snapshots.

    Returns:
        dict metric -> (mean, std), or None when no run had logged results.
    """
    metric_keys = ["AP_small", "AP_medium", "AP_large", "AP50", "AP75", "mAP",
                   "AR1", "AR10", "AR100", "AR_small", "AR_medium", "AR_large"]
    rows = []
    for r in run_names:
        path = os.path.join(log_root, r)
        if not os.path.isdir(path):
            print(f"[SKIP] missing {path}")
            continue
        best = best_by_apsmall(path)
        if best is None:
            print(f"[SKIP] no epoch_*_results.json in {path}")
            continue
        rows.append(best)
    if not rows:
        print(f"[EMPTY] {group_name}")
        return None

    epochs = [row["epoch"] for row in rows]
    print("\n" + "=" * 70)
    # Report how many runs actually contributed: std over a partial set can mislead.
    print(f"{group_name}  [n={len(rows)}/{len(run_names)} runs]")
    summary = {}
    for k in metric_keys:
        vals = np.array([row[k] for row in rows], dtype=float)
        mu = vals.mean()
        sd = vals.std(ddof=1) if len(vals) > 1 else 0.0
        summary[k] = (float(mu), float(sd))
        print(f"{k}: {mu:.3f} ± {sd:.3f}  (epochs {epochs})")
    return summary

# ------------------------------- Grid runner (2) -------------------------------
def run_grid_subprocess(modes, datasets, levels, seeds):
    """Launch this script once per (dataset, level, mode, seed) combination, sequentially.

    Iteration order is dataset -> level -> mode -> seed. Each child receives its own
    --mode/--dataset/--insert-level/--seed/--exp-name plus the current epochs, batch size,
    accumulation and DataLoader settings. ``check=True`` makes the first failing child raise
    CalledProcessError and stop the grid.
    NOTE(review): LR settings and data-path overrides are not forwarded; children use their
    own defaults for those.
    """
    script_path = os.path.abspath(sys.argv[0])
    interpreter = sys.executable
    shared_flags = [
        "--epochs", str(args.epochs),
        "--batch-size", str(args.batch_size),
        "--accum-steps", str(args.accum_steps),
        "--num-workers", str(args.num_workers),
        "--prefetch-factor", str(args.prefetch_factor),
    ]
    for dataset_name in datasets:
        for level in levels:
            for mode_name in modes:
                for seed in seeds:
                    run_name = f"{dataset_name}_{level}_{mode_name}_s{seed}"
                    run_flags = [
                        "--mode", mode_name,
                        "--dataset", dataset_name,
                        "--insert-level", level,
                        "--seed", str(seed),
                        "--exp-name", run_name,
                    ]
                    command = [interpreter, "-u", script_path, *run_flags, *shared_flags]
                    print("[RUN]", " ".join(command))
                    subprocess.run(command, check=True)

def run_grid_inprocess(modes, datasets, levels, seeds):
    """Train every (dataset, level, mode, seed) combination sequentially in this process.

    For each mode the global CURRENT_CONFIG is set to ``CONFIG_MAP[mode]``. Datasets and
    loaders are rebuilt for every seed and released before the next run. Experiments are
    named ``{dataset}_{level}_{mode}_s{seed}``, the pattern the aggregation cells expect.

    Returns:
        list of dicts (exp, best_map, best_epoch, best_ckpt), one per finished run.
    """
    global CURRENT_CONFIG
    # Data-path overrides come from the CLI and are identical for every run.
    overrides = dict(train_img=args.train_img, train_ann=args.train_ann,
                     val_img=args.val_img, val_ann=args.val_ann)
    results = []
    for ds in datasets:
        for lvl in levels:
            for mode in modes:
                CURRENT_CONFIG = CONFIG_MAP[mode]
                for s in seeds:
                    exp = f"{ds}_{lvl}_{mode}_s{s}"
                    train_dataset, val_dataset, train_loader, val_loader, patched_val = \
                        build_datasets_and_loaders(s, ds, overrides)
                    print(f"\n=== {exp} ===")
                    best_map, best_epoch, best_ckpt = train_one(exp, train_dataset, val_dataset,
                                                                train_loader, val_loader, patched_val, lvl)
                    print(f"[DONE] {exp}: best mAP {best_map:.4f} @ {best_epoch}; ckpt={best_ckpt}")
                    results.append(dict(exp=exp, best_map=best_map, best_epoch=best_epoch, best_ckpt=best_ckpt))
                    # Release this run's loaders (their persistent workers) and datasets before
                    # the next run builds its own, and return cached GPU blocks.
                    del train_dataset, val_dataset, train_loader, val_loader
                    gc.collect()
                    if torch.cuda.is_available():
                        torch.cuda.empty_cache()
    return results

# ----------------------------------- main -------------------------------------
def main():
    """CLI entry point: run the seed grid (--run-grid) or a single training run.

    A single run can be followed by aggregation (--aggregate) over runs named
    ``{dataset}_{level}_{mode}_s{seed}``.
    """
    print(f"Device: {device}")
    print(f"Mode={args.mode} Insert={args.insert_level} Dataset={args.dataset} Seed={args.seed}")

    # Grid
    if args.run_grid:
        modes = args.grid_modes or ["baseline", "se", "cbam", "lgf_gated_spatial"]
        datasets = args.grid_datasets or ["coco_nw"]
        levels = args.grid_levels or ["C3"]
        seeds = args.grid_seeds or [42, 1337, 2025]
        if args.subprocess and hasattr(sys, "argv") and sys.argv[0].endswith(".py"):
            run_grid_subprocess(modes, datasets, levels, seeds)
        else:
            run_grid_inprocess(modes, datasets, levels, seeds)
        return

    # Single run.
    # NOTE(review): main does not set CURRENT_CONFIG; train_one reads the module-level value,
    # presumably derived from args.mode elsewhere in the file — confirm.
    overrides = dict(train_img=args.train_img, train_ann=args.train_ann,
                     val_img=args.val_img, val_ann=args.val_ann)
    train_dataset, val_dataset, train_loader, val_loader, patched_val = \
        build_datasets_and_loaders(args.seed, args.dataset, overrides)

    exp = args.exp_name or f"{args.dataset}_{args.insert_level}_{args.mode}_s{args.seed}"
    print(f"Training on {len(train_dataset)} images; validating on {len(val_dataset)} images")
    print(f"Batch={args.batch_size}, Accum={args.accum_steps}, LR={args.base_lr}, Epochs={args.epochs}")

    best_map, best_epoch, best_ckpt = train_one(exp, train_dataset, val_dataset,
                                                train_loader, val_loader, patched_val, args.insert_level)
    print(f"[RESULT] {exp}: best mAP {best_map:.4f} @ epoch {best_epoch}; ckpt={best_ckpt}")

    # Aggregate if asked
    if args.aggregate:
        agg_datasets = args.agg_datasets or [args.dataset]
        agg_levels = args.agg_levels or [args.insert_level]
        agg_modes = args.agg_modes or [args.mode]
        # agg_seeds may be unset (None); fall back to the grid seeds instead of crashing.
        agg_seeds = args.agg_seeds or args.grid_seeds or [42, 1337, 2025]
        for ds in agg_datasets:
            for lvl in agg_levels:
                for mode in agg_modes:
                    runs = [f"{ds}_{lvl}_{mode}_s{s}" for s in agg_seeds]
                    aggregate_group(f"Aggregate {ds} {lvl} {mode}", runs)

# if __name__ == "__main__":
#     main()


TRAINING & EVALUATION

In [None]:
# 3 seeds, all models, on COCO-NW, COCO-Weather and ACDC with the block inserted at C3.
# Signature: run_grid_inprocess(modes, datasets, levels, seeds) — the dataset names go in the
# second argument and the insert levels in the third.
run_grid_inprocess(
    ["baseline", "se", "cbam", "lgf_gated", "lgf_softmax", "lgf_sum", "lgf_gated_spatial"],
    ["coco_nw", "coco_weather", "acdc"],
    ["C3"],
    [42, 1337, 2025],
)

# Or run specific configurations, e.g. 3 seeds of LGGF gated-spatial (COCO-NW, C3):
# run_grid_inprocess(["lgf_gated_spatial"], ["coco_nw"], ["C3"], [42, 1337, 2025])

3-SEED SUMMARY REPORT

In [None]:
# BY MODEL AGGREGATIONS
# For each model (outer loop), aggregate its 3 seeds at C3 on every dataset (inner loop):
# COCO-NW, then COCO-Weather, then ACDC.
_BY_MODEL_SEEDS = [42, 1337, 2025]
_BY_MODEL_DATASETS = [
    ("coco_nw", "COCO-NW"),
    ("coco_weather", "COCO-Weather"),
    ("acdc", "ACDC"),
]
_BY_MODEL_MODELS = [
    ("baseline", "Baseline"),
    ("lgf_gated_spatial", "LGGF gated_spatial"),
    ("lgf_sum", "LGGF sum"),
    ("lgf_softmax", "LGGF softmax"),
    ("lgf_gated", "LGGF gated"),
    ("cbam", "CBAM"),
    ("se", "SE"),
]
for _bm_mode, _bm_model_label in _BY_MODEL_MODELS:
    for _bm_ds, _bm_ds_label in _BY_MODEL_DATASETS:
        runs = [f"{_bm_ds}_C3_{_bm_mode}_s{s}" for s in _BY_MODEL_SEEDS]
        aggregate_group(f"{_bm_model_label} C3 on {_bm_ds_label}", runs)


In [None]:
# BY DATASET AGGREGATIONS
# For each dataset (outer loop), aggregate every model's 3 seeds at C3 (inner loop), in the
# order: Baseline, CBAM, SE, LGGF gated_spatial, LGGF sum, LGGF softmax, LGGF gated.
_BY_DATASET_SEEDS = [42, 1337, 2025]
_BY_DATASET_DATASETS = [
    ("coco_nw", "COCO-NW"),
    ("coco_weather", "COCO-Weather"),
    ("acdc", "ACDC"),
]
_BY_DATASET_MODELS = [
    ("baseline", "Baseline"),
    ("cbam", "CBAM"),
    ("se", "SE"),
    ("lgf_gated_spatial", "LGGF gated_spatial"),
    ("lgf_sum", "LGGF sum"),
    ("lgf_softmax", "LGGF softmax"),
    ("lgf_gated", "LGGF gated"),
]
for _bd_ds, _bd_ds_label in _BY_DATASET_DATASETS:
    for _bd_mode, _bd_model_label in _BY_DATASET_MODELS:
        runs = [f"{_bd_ds}_C3_{_bd_mode}_s{s}" for s in _BY_DATASET_SEEDS]
        aggregate_group(f"{_bd_model_label} C3 on {_bd_ds_label}", runs)



POST-TRAINING EVALUATION OF A SAVED CHECKPOINT

In [None]:
# ==== EVAL-ONLY CELL ====
import json, torch
from torch.utils.data import DataLoader

def _find_key(sd, suffix):
    for k in sd.keys():
        if k.endswith(suffix):
            return k
    return None

def evaluate_checkpoint(
    ckpt_path,
    dataset="coco_weather",  # coco_nw | coco_weather | acdc | custom
    use_ema=True,
    insert_level=None,   # C3 | C4 | C5 | None for auto
    val_img=None,
    val_ann=None,
    batch_size=4,
    num_workers=8,
    prefetch_factor=4,
    seed=42,
    speed=True,
    complexity=False,
    max_images_speed=200
):
    """Rebuild a model from a train_one checkpoint and evaluate it with COCO metrics.

    Steps:
      * load the EMA (or raw) state dict and strip a DDP ``module.`` prefix;
      * build the validation loader for *dataset*;
      * set CURRENT_CONFIG to the checkpoint's ``config`` for the duration of this call only
        (the previous value is restored afterwards);
      * infer anchors-per-location and class count from the head weight shapes;
      * try insert levels (given, or parsed from the experiment name first) and keep the
        level whose weights load with the fewest missing/unexpected keys;
      * run coco_evaluation, plus optional speed/complexity measurements.

    Returns:
        dict with checkpoint/config info, the coco_evaluation metrics and, optionally,
        latency (speed=True) and params/FLOPs (complexity=True) entries; it is also printed.

    Raises:
        ValueError: dataset="custom" without val_img and val_ann.
        KeyError: no classification head weight in the checkpoint.
        RuntimeError: class-count mismatch, or no candidate level can load the weights.
    """
    global CURRENT_CONFIG
    set_seed(seed)
    dev = device

    # Load on CPU: load_state_dict copies weights to the model's device, and this avoids a
    # second GPU copy of the optimizer state. weights_only=False is required because train_one
    # stores numpy scalars (eval_results) and a config dict, which torch>=2.6's default
    # weights_only=True unpickler rejects. SECURITY: this unpickles arbitrary objects — only
    # load checkpoints you trust.
    try:
        ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)
    except TypeError:  # torch < 1.13 has no weights_only argument
        ckpt = torch.load(ckpt_path, map_location="cpu")

    sd = ckpt.get("ema_state_dict") if use_ema else None
    if sd is None:
        sd = ckpt["model_state_dict"]
    # Strip the DDP "module." prefix only from keys that carry it.
    sd = {k.removeprefix("module."): v for k, v in sd.items()}

    # Build validation dataset/loader first to know the dataset's class count.
    if dataset == "custom":
        if not (val_img and val_ann):
            raise ValueError("dataset='custom' requires val_img and val_ann")
        va_img, va_ann = val_img, patch_annotations_once(val_ann)
    else:
        _, _, va_img, raw_va_ann = select_dataset_by_name(dataset)
        va_ann = patch_annotations_once(raw_va_ann)

    val_dataset = COCODataset(va_img, va_ann, transforms=get_transform(train=False), train=False)
    dataset_k = val_dataset.num_classes

    g = torch.Generator()
    g.manual_seed(seed)
    common = dict(collate_fn=collate_fn, worker_init_fn=worker_init_fn, pin_memory=True, generator=g)
    if num_workers > 0:
        common.update(num_workers=num_workers, persistent_workers=True, prefetch_factor=prefetch_factor)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, drop_last=False, **common)

    # Infer anchors-per-location (A) and class count (K) from the head shapes.
    cls_w_key = _find_key(sd, "classification_head.cls_logits.weight")
    if cls_w_key is None:
        raise KeyError("Missing classification_head.cls_logits.weight in checkpoint")
    out_ch = sd[cls_w_key].shape[0]  # A * K
    reg_w_key = _find_key(sd, "regression_head.bbox_regression.weight")
    if reg_w_key is not None:
        # The regression head (4 box deltas per anchor) pins A unambiguously; divisibility by
        # dataset_k alone can accept a wrong K (e.g. 9 anchors x 3 classes = 27 = 3 x 9).
        a_per_loc = sd[reg_w_key].shape[0] // 4
    elif out_ch % dataset_k == 0:
        a_per_loc = out_ch // dataset_k
    else:
        a_per_loc = 9  # presumably the torchvision RetinaNet default anchor count
    if a_per_loc <= 0 or out_ch % a_per_loc != 0:
        raise RuntimeError(
            f"Cannot infer anchors per location (cls out={out_ch}, anchors per loc={a_per_loc})")
    k_trained = out_ch // a_per_loc
    if k_trained != dataset_k:
        raise RuntimeError(
            f"Class count mismatch. ckpt K={k_trained}, dataset K={dataset_k} "
            f"(cls out={out_ch}, anchors per loc={a_per_loc})"
        )

    def _parse_lvl_from_name(name):
        """Return "C3"/"C4"/"C5" if it appears as a name component, else None."""
        for lvl in ("C3", "C4", "C5"):
            if f"_{lvl}_" in name or name.endswith(f"_{lvl}") or name.startswith(f"{lvl}_"):
                return lvl
        return None

    if insert_level is None:
        parsed = _parse_lvl_from_name(str(ckpt.get("experiment_name", "")))
        if parsed is None:
            candidates = ["C3", "C4", "C5"]
        else:
            candidates = [parsed] + [l for l in ("C3", "C4", "C5") if l != parsed]
    else:
        candidates = [insert_level]

    prev_config = CURRENT_CONFIG
    if "config" in ckpt:
        CURRENT_CONFIG = ckpt["config"]
    try:
        model, best_level, best_score = None, None, None
        load_errors = {}
        for lvl in candidates:
            m = build_model(k_trained, lvl).to(dev)
            # Mirror train_one: BatchNorm layers were converted to GroupNorm before training,
            # so the saved weights belong to GN layers.
            convert_bn_to_gn(m, num_groups=32, convert_frozen=False)
            m.to(dev)
            try:
                res = m.load_state_dict(sd, strict=False)
            except RuntimeError as err:
                # strict=False still raises on size mismatch: this level's architecture
                # cannot hold the checkpoint, so try the next candidate.
                load_errors[lvl] = str(err).splitlines()[0]
                continue
            score = len(res.missing_keys) + len(res.unexpected_keys)
            if best_score is None or score < best_score:
                model, best_level, best_score = m, lvl, score
            if score == 0:
                break
        if model is None:
            raise RuntimeError(f"Checkpoint weights fit none of the insert levels {candidates}: {load_errors}")
        if best_score:
            print(f"[WARN] {ckpt_path}: level {best_level} still has {best_score} missing/unexpected keys")

        model.eval()
        with torch.no_grad():
            metrics = coco_evaluation(model, val_loader, va_ann, dev)

        out = {
            "checkpoint": ckpt_path,
            "experiment_name": ckpt.get("experiment_name"),
            "block_type": CURRENT_CONFIG.get("block_type", "none"),
            "gating_type": CURRENT_CONFIG.get("gating_type", "none"),
            "insert_level": best_level,
            "dataset": dataset,
            "num_classes": k_trained,
            **metrics
        }
        if speed:
            out.update(benchmark_inference(model, val_loader, dev, max_images=max_images_speed))
        if complexity:
            out.update(try_flops_params(model))
    finally:
        CURRENT_CONFIG = prev_config  # don't leak this checkpoint's config into later calls

    print(json.dumps(out, indent=2))
    return out


In [None]:
# Evaluate one checkpoint on a named dataset (EMA weights, with a speed benchmark).
single_eval_kwargs = dict(
    dataset="coco_nw",
    use_ema=True,
    batch_size=16,
    num_workers=8,
    prefetch_factor=4,
    speed=True,
    complexity=False,
)
_ = evaluate_checkpoint(
    "NEW/4.1.2/BEST_coco_nw_C3_cbam_s42_epoch_80_map_0.3422_apsmall_0.2223.pth",
    **single_eval_kwargs,
)


In [None]:
# Evaluate every checkpoint in a folder that matches the pattern (e.g. one per seed).
from glob import glob
ckpt_pattern = "/nas.dbms/asera/NEW/4.1.2/BEST_coco_nw_C3_lgf_gated_spatial_s*_epoch_*_map_*.pth"
ckpts = sorted(glob(ckpt_pattern))
results = [
    evaluate_checkpoint(ckpt_file, dataset="coco_nw", use_ema=True, speed=False)
    for ckpt_file in ckpts
]

In [None]:
# Custom dataset: evaluate against your own validation images/annotations.
# NOTE(review): val_img / val_ann below are placeholders and must be replaced
# before running. Calling them "OPTIONAL" is doubtful: _build_val_loader in this
# file raises ValueError for dataset='custom' without both paths, and
# evaluate_checkpoint presumably needs them too (its loader setup is not shown here).
_ = evaluate_checkpoint(
    "/nas.dbms/asera/NEW/BEST_custom_C4_cbam_s2025_epoch_080_map_0.3871.pth",
    dataset="custom",
    val_img="/abs/path/to/your/val/images",
    val_ann="/abs/path/to/your/val/annotations.json",
    use_ema=True
)


CONTROL VARIANTS

In [None]:
# ==== CONTROL VARIANTS EVAL CELL ====
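# Evaluates the gate-control variants implemented by _patch_gate_override ("native", "global_mean", "uniform", "shuffle", "swap_lg", "no_module", "only_local", "only_global", plus an alpha sweep) on a saved lgf model.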

import types, json, torch
from torch.utils.data import DataLoader

# ---------- helpers ----------
def _build_val_loader(dataset_code, seed=42, batch_size=4, num_workers=8, prefetch_factor=4,
                      val_img=None, val_ann=None):
    """Create the validation dataset and a deterministic, unshuffled loader for it.

    Args:
        dataset_code: name understood by ``select_dataset_by_name``, or
            ``"custom"`` to use ``val_img`` / ``val_ann`` directly.
        seed: seed for the ``torch.Generator`` handed to the DataLoader.
        batch_size, num_workers, prefetch_factor: DataLoader settings. Worker-only
            options (persistent workers, prefetch) are applied only when
            ``num_workers > 0``.
        val_img, val_ann: image root and annotation file, required for ``"custom"``.

    Returns:
        ``(val_dataset, loader, va_ann)``, where ``va_ann`` is the annotation
        path after ``patch_annotations_once``.

    Raises:
        ValueError: ``dataset_code == "custom"`` without both ``val_img`` and ``val_ann``.
    """
    if dataset_code != "custom":
        _, _, img_root, raw_ann = select_dataset_by_name(dataset_code)
        ann_path = patch_annotations_once(raw_ann)
    elif val_img and val_ann:
        img_root, ann_path = val_img, patch_annotations_once(val_ann)
    else:
        raise ValueError("dataset='custom' needs val_img and val_ann")

    dataset = COCODataset(img_root, ann_path, transforms=get_transform(train=False), train=False)

    gen = torch.Generator()
    gen.manual_seed(seed)
    loader_kwargs = {
        "collate_fn": collate_fn,
        "worker_init_fn": worker_init_fn,
        "pin_memory": True,
        "generator": gen,
    }
    if num_workers > 0:
        # These options are only valid with worker processes.
        loader_kwargs["num_workers"] = num_workers
        loader_kwargs["persistent_workers"] = True
        loader_kwargs["prefetch_factor"] = prefetch_factor

    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, drop_last=False, **loader_kwargs)
    return dataset, loader, ann_path

def _strip_module_prefix(sd):
    """Drop the ``module.`` key prefix left by (Distributed)DataParallel wrappers.

    Only keys that actually start with ``module.`` are rewritten; other keys are
    kept verbatim. Blindly slicing 7 characters off every key would corrupt
    un-prefixed entries, e.g. ``head.w`` would become ``""``.

    Args:
        sd: mapping of parameter/buffer names to tensors.

    Returns:
        ``sd`` itself when no key carries the prefix, otherwise a new dict with
        the prefix removed where present.
    """
    prefix = "module."
    if not any(k.startswith(prefix) for k in sd):
        return sd
    return {(k[len(prefix):] if k.startswith(prefix) else k): v for k, v in sd.items()}

def _infer_insert_level_from_name(name):
    """Guess which backbone level (C3/C4/C5) an experiment name refers to.

    A level matches when it appears as an underscore-delimited token: inside
    (``*_C4_*``), at the end (``*_C4``) or at the start (``C4_*``). Levels are
    tried in the order C3, C4, C5 and the first match wins.

    Returns:
        The matching level string, or None if no level token is found.
    """
    def _mentions(level):
        return (
            f"_{level}_" in name
            or name.endswith(f"_{level}")
            or name.startswith(f"{level}_")
        )

    return next((level for level in ("C3", "C4", "C5") if _mentions(level)), None)

def _native_spatial_weights(block, x):
    """Compute a gated_spatial block's local/global gate maps from its own layers.

    Pipeline: gate_reduce -> ReLU -> gate_expand -> gate_norm, then
    ``sigmoid(clamp(logits, -15, 15) / tau)`` with
    ``tau = softplus(temperature) + 1e-3``. The sigmoid stage runs in float32
    with autocast disabled.

    Args:
        block: module exposing ``gate_reduce``, ``gate_expand``, ``gate_norm``
            and a ``temperature`` tensor.
        x: input feature map; presumably ``[N, C, H, W]`` (TODO confirm). The
            gate logits must be 4-D with an even channel count.

    Returns:
        ``(wL, wG)`` cast to ``x.dtype``. The first half of the gate channels
        is wL and the second half is wG, each ``[N, C_gate // 2, H, W]``.
        The two are independent sigmoids, so their sum is not forced to 1.
    """
    hidden = F.relu(block.gate_reduce(x), inplace=True)
    logits = block.gate_norm(block.gate_expand(hidden))

    # Disable autocast for x's own device type. The previous
    # torch.cuda.amp.autocast(enabled=False) is deprecated and only ever
    # covered CUDA autocast.
    with torch.autocast(device_type=x.device.type, enabled=False):
        logits32 = logits.float().clamp_(-15, 15)
        # softplus keeps tau positive; the 1e-3 floor avoids dividing by ~0.
        tau = F.softplus(block.temperature.float()) + 1e-3
        gates = torch.sigmoid(logits32 / tau)

    n, two_c, h, w = gates.shape
    gates = gates.to(dtype=x.dtype).view(n, 2, two_c // 2, h, w)
    return gates[:, 0], gates[:, 1]  # wL, wG

# def _patch_gate_override(model, override="native", gate_seed=None, gamma_mult=None):
#     """
#     Monkey-patch StableLGFBlock.forward to enforce control behaviors during eval.
#     Supports only 'lgf' blocks. Leaves others untouched.
#     """
#     rng = torch.Generator(device=device) if device.type == "cuda" else torch.Generator()
#     if gate_seed is not None:
#         rng.manual_seed(int(gate_seed))

#     for m in model.modules():
#         if not isinstance(m, StableLGFBlock):
#             continue

#         # option to scale gamma consistently across all LGGF blocks
#         if gamma_mult is not None:
#             with torch.no_grad():
#                 m.gamma.fill_(float(gamma_mult))

#         if override == "native":
#             continue  # no patching needed

#         # keep original forward for restoration if you want later
#         if not hasattr(m, "_orig_forward"):
#             m._orig_forward = m.forward

#         def wrapped_forward(self, x, *_args, **_kwargs):
#             # Recompute L and G branches explicitly
#             feats = []
#             L = self.local(x) if self.local is not None else None
#             if L is not None: feats.append(L)
#             G = None
#             if self.global_branch is not None:
#                 g = self.global_branch(x)
#                 G = g.expand_as(x)
#                 feats.append(G)

#             if len(feats) == 1:
#                 out_core = feats[0]
#                 wL = torch.ones_like(out_core); wG = torch.zeros_like(out_core)
#             else:
#                 if self.gating_type == "gated_spatial":
#                     # Base weights from native computation
#                     wL_native, wG_native = _native_spatial_weights(self, x)

#                     if override == "uniform":
#                         wL = torch.ones_like(wL_native) * 0.5
#                         wG = torch.ones_like(wG_native) * 0.5
#                     elif override == "global_mean":
#                         # Collapse spatial to per-channel constants, broadcast back
#                         wL_mean = wL_native.mean(dim=(2,3), keepdim=True)
#                         wG_mean = wG_native.mean(dim=(2,3), keepdim=True)
#                         wL = wL_mean.expand_as(wL_native)
#                         wG = wG_mean.expand_as(wG_native)
#                     elif override == "shuffle":
#                         # Shuffle gates across images in batch, same permutation for L and G
#                         N = wL_native.shape[0]
#                         if N > 1:
#                             perm = torch.randperm(N, generator=rng, device=wL_native.device)
#                             wL = wL_native[perm]
#                             wG = wG_native[perm]
#                         else:
#                             wL, wG = wL_native, wG_native
#                     elif override == "swap_in_batch":
#                         # Rotate gates within batch by +1
#                         N = wL_native.shape[0]
#                         if N > 1:
#                             wL = torch.roll(wL_native, shifts=1, dims=0)
#                             wG = torch.roll(wG_native, shifts=1, dims=0)
#                         else:
#                             wL, wG = wL_native, wG_native
#                     else:
#                         # Fallback to native if unknown label
#                         wL, wG = wL_native, wG_native

#                     # Fuse
#                     # L, G are [N,C,H,W]; wL,wG are [N,C,H,W]
#                     out_core = wL * L + wG * G
#                 else:
#                     # For non-spatial gate modes, emulate simple controls
#                     if override in ("uniform", "global_mean"):
#                         out_core = 0.5 * L + 0.5 * G
#                     elif override in ("shuffle", "swap_in_batch"):
#                         # emulate by shuffling G across batch only
#                         N = G.shape[0]
#                         if N > 1:
#                             perm = torch.randperm(N, generator=rng, device=G.device) if override=="shuffle" else torch.roll(torch.arange(N, device=G.device), shifts=1)
#                             out_core = 0.5 * L + 0.5 * G[perm]
#                         else:
#                             out_core = 0.5 * L + 0.5 * G
#                     else:
#                         out_core = 0.5 * L + 0.5 * G

#             return x + self.gamma * out_core

        # m.forward = types.MethodType(wrapped_forward, m)
import types, torch

def _patch_gate_override(
    model,
    override="native",
    *,
    gate_seed=None,
    gamma_mult=None,
    temp_mult=None,
    temp_set=None,
    alpha_local=None,            # α for the α‑sweep: scales local gates before renorm
    global_mean_mode="scalar"    # "scalar" = per-image scalar, "channel" = per-channel
):
    """Patch every ``StableLGFBlock`` in ``model`` to run a gate-control variant.

    override options:
      "native"               learned gate maps as-is
      "global_mean"          per-image scalar gate (paper); set global_mean_mode to "channel" to match old per-channel behavior
      "uniform"              wL=wG=0.5
      "shuffle"              shuffle gates across images AND channels
      "swap_lg"              swap learned local/global gate assignments
      "no_module"            remove module, return x
      "only_local"           wL=1, wG=0
      "only_global"          wL=0, wG=1
    Optional knobs:
      alpha_local            α-sweep; multiply local gates by α then renormalize with global
      temp_mult/temp_set     gate temperature tweaks on the raw ``temperature`` parameter
                             (multiply first, then overwrite if temp_set is given)
      gamma_mult             multiplicatively scale learned residual γ (do not overwrite)

    For every patched path, including "native" with an α, the final wL/wG are
    rescaled so that the per-image mean of ``wL + wG`` matches the native gates.

    Side effects (in place on ``model``):
      * ``gamma`` / ``temperature`` are modified directly.
      * Patched blocks get ``forward`` replaced by a bound wrapper. The first-seen
        original is kept as ``_orig_forward``. Nothing is restored here: a later
        "native" call leaves an earlier patch active, so rebuild the model instead.
      * Blocks whose ``gating_type`` is not "gated_spatial" keep their original forward.

    Raises:
        ValueError: ``override`` is not one of the options listed above.
    """
    valid_overrides = ("native", "global_mean", "uniform", "shuffle",
                       "swap_lg", "no_module", "only_local", "only_global")
    # Validate before touching the model. An unknown label (e.g. "local_only")
    # would otherwise silently produce native-gate results under a misleading name.
    if override not in valid_overrides:
        raise ValueError(f"Unknown gate override {override!r}; expected one of {valid_overrides}")

    device = next(model.parameters()).device
    # One generator shared by all patched blocks (seeded when gate_seed is given).
    rng = torch.Generator(device=device)
    if gate_seed is not None:
        rng.manual_seed(int(gate_seed))

    for m in model.modules():
        if m.__class__.__name__ != "StableLGFBlock":
            continue

        # keep learned gamma, allow multiplicative scaling only
        if gamma_mult is not None:
            with torch.no_grad():
                m.gamma.mul_(float(gamma_mult))

        # optional temperature probe
        if hasattr(m, "temperature"):
            with torch.no_grad():
                if temp_mult is not None:
                    m.temperature.mul_(float(temp_mult))
                if temp_set is not None:
                    m.temperature.fill_(float(temp_set))

        if override == "native" and alpha_local is None:
            continue

        if not hasattr(m, "_orig_forward"):
            m._orig_forward = m.forward

        def wrapped_forward(self, x, *a, **k):
            # The controls only mean something for the spatial gate. Any other
            # gating type runs its original forward so results are not distorted.
            if getattr(self, "gating_type", None) != "gated_spatial":
                return self._orig_forward(x, *a, **k)

            # Removing the module needs no branch computation at all.
            if override == "no_module":
                return x

            # compute branches
            L = self.local(x) if getattr(self, "local", None) is not None else None
            G = None
            if getattr(self, "global_branch", None) is not None:
                g = self.global_branch(x)
                G = g.expand_as(L if L is not None else g)

            # degenerate cases: only one branch exists, nothing to gate
            if L is None or G is None:
                out_core = L if L is not None else G
                return x + self.gamma * out_core

            # native spatial gates (independent sigmoids, not forced to sum to 1)
            wL_native, wG_native = _native_spatial_weights(self, x)
            wL, wG = wL_native, wG_native

            if override == "only_local":
                wL, wG = torch.ones_like(wL_native), torch.zeros_like(wG_native)

            elif override == "only_global":
                wL, wG = torch.zeros_like(wL_native), torch.ones_like(wG_native)

            elif override == "uniform":
                wL = torch.full_like(wL_native, 0.5)
                wG = torch.full_like(wG_native, 0.5)

            elif override == "global_mean":
                if global_mean_mode == "scalar":
                    # per-image scalar (paper): mean over channels and spatial
                    wL_s = wL_native.mean(dim=(1, 2, 3), keepdim=True)
                    wG_s = wG_native.mean(dim=(1, 2, 3), keepdim=True)
                else:
                    # legacy per-channel: mean over spatial only
                    wL_s = wL_native.mean(dim=(2, 3), keepdim=True)
                    wG_s = wG_native.mean(dim=(2, 3), keepdim=True)
                wL = wL_s.expand_as(wL_native)
                wG = wG_s.expand_as(wG_native)

            elif override == "shuffle":
                # one permutation over (image, channel) pairs, applied to both gates
                N, C, H, W = wL_native.shape
                if N * C > 1:
                    idx = torch.randperm(N * C, generator=rng, device=wL_native.device)
                    wL = wL_native.reshape(N * C, H, W)[idx].reshape(N, C, H, W)
                    wG = wG_native.reshape(N * C, H, W)[idx].reshape(N, C, H, W)

            elif override == "swap_lg":
                wL, wG = wG_native, wL_native

            # α‑sweep on top of whatever override we used
            if alpha_local is not None:
                wL = wL * float(alpha_local)
                denom = wL + wG + 1e-6
                wL = wL / denom
                wG = wG / denom

            # Budget preservation: match the per-image mean of wL + wG to native,
            # so a control changes the *mix*, not the overall residual strength.
            s_native = (wL_native + wG_native).mean(dim=(1, 2, 3), keepdim=True)  # [N,1,1,1]
            s_override = (wL + wG).mean(dim=(1, 2, 3), keepdim=True) + 1e-6
            scale = s_native / s_override
            wL = wL * scale
            wG = wG * scale

            out_core = wL * L + wG * G
            return x + self.gamma * out_core

        m.forward = types.MethodType(wrapped_forward, m)


def _load_model_from_ckpt(ckpt_path, dataset_k, insert_level_hint=None, use_ema=True):
    """Load a checkpoint and rebuild the detector, choosing the best-fitting insert level.

    Args:
        ckpt_path: path to a checkpoint dict with ``model_state_dict`` and
            optionally ``ema_state_dict`` / ``experiment_name``.
        dataset_k: number of classes the evaluation dataset expects.
        insert_level_hint: level tried first. When not given, the level is
            inferred from ``experiment_name``.
        use_ema: prefer EMA weights when the checkpoint has a non-None EMA state.

    Returns:
        ``(ckpt, model, level)``: the raw checkpoint dict, the model in eval
        mode, and the level whose model matched the state dict with the fewest
        missing + unexpected keys.

    Raises:
        KeyError: no classification-head weight in the state dict.
        RuntimeError: the checkpoint's class count differs from ``dataset_k``,
            or no candidate level could load the weights.
    """
    # NOTE: weights_only=False unpickles arbitrary objects; only load trusted checkpoints.
    ckpt = torch.load(ckpt_path, map_location=device, weights_only=False)
    sd = ckpt.get("ema_state_dict") if use_ema else None
    if sd is None:
        # No EMA requested, absent, or stored as None: use the raw model weights.
        sd = ckpt["model_state_dict"]
    sd = _strip_module_prefix(sd)

    # Confirm the class count. cls_logits has A*K output channels and
    # bbox_regression has A*4, where A = anchors per location.
    cls_w = next((v for k, v in sd.items()
                  if k.endswith("classification_head.cls_logits.weight")), None)
    if cls_w is None:
        raise KeyError("cls_logits.weight not found in checkpoint")
    out_ch = cls_w.shape[0]

    reg_w = next((v for k, v in sd.items()
                  if k.endswith("regression_head.bbox_regression.weight")), None)
    if reg_w is not None:
        # Always check: out_ch being divisible by dataset_k does not imply K matches.
        num_anchors = reg_w.shape[0] // 4
        ckpt_k = out_ch // num_anchors if num_anchors else None
        if not num_anchors or out_ch % num_anchors or ckpt_k != dataset_k:
            raise RuntimeError(f"Class count mismatch. ckpt K={ckpt_k}, dataset K={dataset_k}")
    elif out_ch % dataset_k != 0:
        # No regression head to read A from: assume 9 anchors per location.
        ckpt_k = out_ch // 9
        if ckpt_k != dataset_k:
            raise RuntimeError(f"Class count mismatch. ckpt K={ckpt_k}, dataset K={dataset_k}")

    # decide insert level order: hint, else parsed from the experiment name
    exp_name = str(ckpt.get("experiment_name", ""))
    preferred = insert_level_hint or _infer_insert_level_from_name(exp_name)
    candidates = ["C3", "C4", "C5"]
    if preferred:
        candidates = [preferred] + [c for c in candidates if c != preferred]

    # Try levels until the keys fit best.
    best_model, best_level, best_score = None, None, None
    last_err = None
    for lvl in candidates:
        candidate = build_model(dataset_k, lvl).to(device)
        try:
            res = candidate.load_state_dict(sd, strict=False)
        except RuntimeError as err:
            # strict=False still raises on tensor size mismatches. Treat that as
            # "this level does not fit" and move on instead of aborting the search.
            last_err = err
            continue
        score = len(getattr(res, "missing_keys", [])) + len(getattr(res, "unexpected_keys", []))
        if best_score is None or score < best_score:
            best_model, best_level, best_score = candidate, lvl, score
        if score == 0:
            break

    if best_model is None:
        raise RuntimeError(
            f"No insert level in {candidates} could load the weights from {ckpt_path}"
        ) from last_err

    best_model.eval()
    return ckpt, best_model, best_level

# def run_eval(
#     mode, ckpt_path, *,
#     dataset="coco_nw",
#     insert_level=None,
#     gate_override="native",
#     gate_seed=None,
#     gamma_mult=1.0,
#     batch_size=4,
#     num_workers=8,
#     prefetch_factor=4,
#     seed=42,
#     speed=True,
#     complexity=False,
#     max_images_speed=200,
#     val_img=None,
#     val_ann=None,
#     exp_name=None
# ):
#     # dataset and loader
#     set_seed(seed)
#     val_dataset, val_loader, va_ann = _build_val_loader(
#         dataset, seed=seed, batch_size=batch_size, num_workers=num_workers,
#         prefetch_factor=prefetch_factor, val_img=val_img, val_ann=val_ann
#     )
#     K = val_dataset.num_classes

#     # build model from ckpt
#     ckpt, model, level = _load_model_from_ckpt(ckpt_path, K, insert_level_hint=insert_level, use_ema=True)

#     # sync CURRENT_CONFIG from ckpt when present, else from selected mode
#     global CURRENT_CONFIG
#     if "config" in ckpt:
#         CURRENT_CONFIG = ckpt["config"]
#     else:
#         CURRENT_CONFIG = CONFIG_MAP.get(mode, CONFIG_MAP["baseline"])

#     # patch gates
#     _patch_gate_override(model, override=gate_override, gate_seed=gate_seed, gamma_mult=gamma_mult)

#     # eval
#     with torch.no_grad():
#         metrics = coco_evaluation(model, val_loader, va_ann, device)

#     # add optional speed/complexity
#     out = {
#         "exp": exp_name or "control_eval",
#         "checkpoint": ckpt_path,
#         "mode": mode,
#         "insert_level": level,
#         "gate_override": gate_override,
#         "gamma_mult": gamma_mult,
#         **metrics
#     }
#     if speed:
#         out.update(benchmark_inference(model, val_loader, device, max_images=max_images_speed))
#     if complexity:
#         out.update(try_flops_params(model))

    # print(json.dumps(out, indent=2))
    # return out

def run_eval(
    mode, ckpt_path, *,
    dataset="coco_nw",
    insert_level=None,
    gate_override="native",
    gate_seed=None,
    gamma_mult=1.0,
    temp_mult=None,
    temp_set=None,
    alpha_local=None,           # new
    global_mean_mode="scalar",  # new
    score_thresh=None,
    batch_size=4,
    num_workers=8,
    prefetch_factor=4,
    seed=42,
    speed=True,
    complexity=False,
    max_images_speed=200,
    val_img=None,
    val_ann=None,
    exp_name=None,
    use_ema=True,
):
    """Evaluate one gate-control variant of a saved LGF checkpoint.

    Steps:
      1. Seed everything.
      2. Activate the checkpoint's stored config, or the default for ``mode``.
      3. Build the validation loader and rebuild the model from ``ckpt_path``.
      4. Apply the control via ``_patch_gate_override``.
      5. Run COCO evaluation, plus optional speed/complexity probes.

    Args:
        mode: key into CONFIG_MAP, used when the checkpoint has no "config".
        ckpt_path: checkpoint to evaluate.
        gate_override, gate_seed, gamma_mult, temp_mult, temp_set, alpha_local,
            global_mean_mode: forwarded to ``_patch_gate_override``.
        score_thresh: if given, overrides ``model.score_thresh``.
        use_ema: load EMA weights when available (default True, as before).
        remaining args: loader / benchmarking settings and output labels.

    Returns:
        dict with the run settings merged with the COCO metrics (and speed /
        complexity numbers when requested). The dict is also printed as JSON.

    Raises:
        RuntimeError: the rebuilt model contains no StableLGFBlock to patch.
    """
    set_seed(seed)
    # Activate the checkpoint's config before the model is created (per the
    # original note, model creation depends on CURRENT_CONFIG).
    global CURRENT_CONFIG
    try:
        # weights_only=False, matching _load_model_from_ckpt. On torch versions
        # where weights_only defaults to True, a checkpoint holding non-allowlisted
        # objects would fail here and its config would silently be replaced by
        # the mode default.
        ckpt_tmp = torch.load(ckpt_path, map_location="cpu", weights_only=False)
        ckpt_cfg = ckpt_tmp.get("config", None)
        CURRENT_CONFIG = ckpt_cfg if ckpt_cfg is not None else CONFIG_MAP.get(mode, CONFIG_MAP["baseline"])
        del ckpt_tmp
    except Exception as e:
        print(f"[WARN] Could not read config from checkpoint ({e}), using default for mode={mode}")
        CURRENT_CONFIG = CONFIG_MAP.get(mode, CONFIG_MAP["baseline"])

    val_dataset, val_loader, va_ann = _build_val_loader(
        dataset, seed=seed, batch_size=batch_size, num_workers=num_workers,
        prefetch_factor=prefetch_factor, val_img=val_img, val_ann=val_ann
    )
    K = val_dataset.num_classes
    ckpt, model, level = _load_model_from_ckpt(ckpt_path, K, insert_level_hint=insert_level, use_ema=use_ema)

    if score_thresh is not None:
        model.score_thresh = float(score_thresh)

    # Controls are meaningless without LGF blocks. Raise explicitly rather than
    # assert, since asserts are stripped under `python -O`.
    n_blocks = sum(1 for m in model.modules() if m.__class__.__name__ == "StableLGFBlock")
    if n_blocks == 0:
        raise RuntimeError(f"No lgf blocks found to patch. CURRENT_CONFIG={CURRENT_CONFIG}")

    _patch_gate_override(
        model,
        override=gate_override,
        gate_seed=gate_seed,
        gamma_mult=gamma_mult,
        temp_mult=temp_mult,
        temp_set=temp_set,
        alpha_local=alpha_local,
        global_mean_mode=global_mean_mode
    )

    with torch.no_grad():
        metrics = coco_evaluation(model, val_loader, va_ann, device)

    out = {
        "exp": exp_name or "control_eval",
        "checkpoint": ckpt_path,
        "mode": mode,
        "insert_level": level,
        "gate_override": gate_override,
        "gamma_mult": gamma_mult,
        "temp_mult": temp_mult,
        "temp_set": temp_set,
        "alpha_local": alpha_local,
        "global_mean_mode": global_mean_mode,
        "score_thresh": score_thresh,
        **metrics
    }
    if speed:
        out.update(benchmark_inference(model, val_loader, device, max_images=max_images_speed))
    if complexity:
        out.update(try_flops_params(model))
    print(json.dumps(out, indent=2))
    return out
    

In [None]:
# ==== BEST CHECKPOINTS DICT CELL ====
from glob import glob

def latest(pattern):
    """Return the last path, in sorted filename order, that matches ``pattern``.

    "Latest" means lexicographically last, not most recently modified.

    Raises:
        FileNotFoundError: when the glob matches nothing (message is the pattern).
    """
    matches = glob(pattern)
    if matches:
        return max(matches)
    raise FileNotFoundError(pattern)

# Fill this to match where you saved BEST_* files
# Maps a mode name (the MODE key used by the control-runs cell) to a checkpoint
# path. A dict can hold only one active "lgf_gated_spatial" entry, so uncomment
# the line for the dataset you are evaluating, comment out the others, and keep
# DATA in the control-runs cell consistent with that choice.
CKPT = {
    # "lgf_gated_spatial": "/nas.dbms/asera/NEW/4.1.2/BEST_coco_nw_C3_lgf_gated_spatial_s2025_epoch_70_map_0.3267_apsmall_0.2399.pth", # NW
    # "lgf_gated_spatial": "/nas.dbms/asera/NEW/4.1.2/BEST_coco_weather_C3_lgf_gated_spatial_s42_epoch_70_map_0.2935_apsmall_0.1973.pth", # Weather
    "lgf_gated_spatial": "/nas.dbms/asera/NEW/4.1.2/BEST_acdc_C3_lgf_gated_spatial_s1337_epoch_80_map_0.3305_apsmall_0.0866.pth", # ACDC

    # add others if you want
}


In [None]:
# ===== CONTROL RUNS  =====
# Each run_eval call reloads the checkpoint into a freshly built model, so the
# in-place gamma/temperature tweaks of one control cannot leak into the next.
MODE   = "lgf_gated_spatial"
# CKPT is rebound to a single path string below (kept for compatibility with
# later cells). Keep the mode->path table in CKPT_TABLE: without it, re-running
# this cell indexed a str with MODE and raised "string indices must be integers".
# Re-running the CKPT dict cell refreshes the table.
if isinstance(CKPT, dict):
    CKPT_TABLE = CKPT
CKPT   = CKPT_TABLE[MODE]
DATA   = "acdc"  # "coco_nw" | "coco_weather" | "acdc" | "custom"
# NOTE(review): DATA must match the dataset the active CKPT entry was trained on.

results = []

# 1. Native
results.append(run_eval(MODE, CKPT, dataset=DATA, exp_name="Native",
                        gate_override="native", gamma_mult=1.0))

# 2. Global-mean mask (per-image scalar)
results.append(run_eval(MODE, CKPT, dataset=DATA, exp_name="Global-mean",
                        gate_override="global_mean", global_mean_mode="scalar", gamma_mult=1.0))

# 3. Uniform 0.5/0.5
results.append(run_eval(MODE, CKPT, dataset=DATA, exp_name="Uniform",
                        gate_override="uniform", gamma_mult=1.0))

# 4. Shuffle (images+channels)
results.append(run_eval(MODE, CKPT, dataset=DATA, exp_name="Shuffle",
                        gate_override="shuffle", gate_seed=123, gamma_mult=1.0))

# 5. Swap local/global
results.append(run_eval(MODE, CKPT, dataset=DATA, exp_name="Swap",
                        gate_override="swap_lg", gamma_mult=1.0))

# 6. α-sweep (scale local gates, then renormalize)
for a in [0.25, 0.5, 2.0]:
    results.append(run_eval(MODE, CKPT, dataset=DATA, exp_name=f"alpha_{a}",
                            gate_override="native", alpha_local=a, gamma_mult=1.0))

# 7. No module
results.append(run_eval(MODE, CKPT, dataset=DATA, exp_name="No_module",
                        gate_override="no_module", gamma_mult=1.0))

# The override labels are "only_local" / "only_global"; the older
# "local_only" / "global_only" spellings are not recognised by _patch_gate_override.
# # 8. Local only
# results.append(run_eval(MODE, CKPT, dataset=DATA, exp_name="only_local",
#                         gate_override="only_local", gamma_mult=1.0))

# # 9. Global only
# results.append(run_eval(MODE, CKPT, dataset=DATA, exp_name="only_global",
#                         gate_override="only_global", gamma_mult=1.0))

In [None]:
# ===== SUMMARY FUNCTION =====
# # Install missing package when running in notebook
# %pip install pandas

import pandas as pd
from pathlib import Path

def summarize_results(results, baseline_exp="Native", baseline_override="native"):
    df = pd.DataFrame(results).copy()

    # pick baseline row: first native without alpha override
    alpha_col = 'alpha_local' if 'alpha_local' in df.columns else None
    base_mask = (df.get('exp','') == baseline_exp) | (df.get('gate_override','') == baseline_override)
    if alpha_col:
        base_mask &= df[alpha_col].isna()
    base = df[base_mask].iloc[0]

    keep_cols = [
        'exp','gate_override','alpha_local','insert_level',
        'mAP','AP_small','AP50','AP75','AR100',
        'detection_count','images_per_second'
    ]
    keep_cols = [c for c in keep_cols if c in df.columns]
    tab = df[keep_cols].copy()

    # deltas vs baseline
    for c in ['mAP','AP_small','AP50','AP75','AR100']:
        if c in tab.columns:
            tab[f'Δ{c}'] = tab[c] - float(base[c])

    # tidy names and rounding
    if 'detection_count' in tab.columns:
        tab = tab.rename(columns={'detection_count':'dets'})
        tab['dets'] = tab['dets'].astype(int)
    if 'images_per_second' in tab.columns:
        tab = tab.rename(columns={'images_per_second':'ips'})
        tab['ips'] = tab['ips'].round(3)

    for c in [x for x in tab.columns if x not in ('exp','gate_override','alpha_local','insert_level','dets')]:
        tab[c] = tab[c].astype(float).round(4)

    # order columns
    ordered = [c for c in [
        'exp','gate_override','alpha_local','insert_level',
        'mAP','AP_small','AP50','AP75','AR100','dets','ips',
        'ΔmAP','ΔAP_small','ΔAP50','ΔAP75','ΔAR100'
    ] if c in tab.columns]
    tab = tab[ordered]

    # sort by mAP desc if present
    if 'mAP' in tab.columns:
        tab = tab.sort_values('mAP', ascending=False)

    # print a compact text table
    print(tab.to_string(index=False))

    # save
    out_dir = Path("summaries"); out_dir.mkdir(exist_ok=True, parents=True)
    tab.to_csv(out_dir / "controls_summary.csv", index=False)

    # LaTeX table for the paper
    latex_cols = tab.rename(columns={
        'exp':'Control',
        'AP50':'AP$_{50}$','AP75':'AP$_{75}$',
        'AP_small':'AP$_{S}$','AR100':'AR$_{100}$',
        'ΔmAP':'$\\Delta$ mAP','ΔAP_small':'$\\Delta$ AP$_{S}$',
        'ΔAP50':'$\\Delta$ AP$_{50}$','ΔAP75':'$\\Delta$ AP$_{75}$',
        'ΔAR100':'$\\Delta$ AR$_{100}$'
    })
    tex = latex_cols.to_latex(index=False, escape=False, float_format="%.4f")
    (out_dir / "controls_summary.tex").write_text(tex)


In [None]:
# after you append all run_eval outputs into `results`
# Prints the comparison table and writes summaries/controls_summary.csv and
# summaries/controls_summary.tex under the current working directory. The
# baseline is the first "Native"/"native" row that has no alpha_local set.
summarize_results(results, baseline_exp="Native", baseline_override="native")