In [None]:
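# FIX 1: give cuBLAS a fixed workspace size; set_torch_determinism() below calls torch.use_deterministic_algorithms(True), which needs this on CUDA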
%env CUBLAS_WORKSPACE_CONFIG=:4096:8


In [None]:
# # FIX 2
# %env CUBLAS_WORKSPACE_CONFIG=:4096:8
# import os; os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"


In [None]:
# Parameter count, FLOPs, and throughput measurement script
"""
measure_lggf_metrics.py

Compute parameter count, FLOPs, and inference throughput (images/sec) for saved RetinaNet checkpoints
trained with baseline, SE, CBAM, or LGF blocks.

Usage examples:

1) Single checkpoint on GPU with synthetic timing images of size 640x640
   python measure_lggf_metrics.py --ckpts /nas.dbms/asera/NEW/4.1.2/BEST_coco_nw_C3_lgf_gated_spatial_s2025_epoch_70_map_0.3267_apsmall_0.2399.pth

2) Multiple checkpoints and write CSV
   python measure_lggf_metrics.py --ckpts ckpts/*.pth --report-csv report.csv

3) Force insert level if experiment name is not standard
   python measure_lggf_metrics.py --ckpts model.pth --insert-level C3

Notes:
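- FLOPs are computed with fvcore if available, otherwise thop; if neither yields a usable result, a manual Conv2d/Linear count via forward hooks is used. Both packages are optional.
- Throughput is measured with randomly generated images; this script does not define an --image-folder option.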
- Checkpoints saved by the training script contain 'ema_state_dict' and 'config'; this script uses those.

"""

# %pip install -q fvcore iopath
# # or
# %pip install -q thop


# Emulate a command-line invocation when running inside a notebook cell:
# parse_args() reads sys.argv, so it is overwritten here with the checkpoints
# and options to measure.
import sys
sys.argv = [
    "measure_lggf_metrics.py",
    "--ckpts",
    # NOTE(review): this filename ends in "apsmall_2199" while the others use
    # "apsmall_0.xxxx" -- confirm the path matches the file on disk.
    "/nas.dbms/asera/NEW/4.1.2/BEST_coco_nw_C3_baseline_s2025_epoch_75_map_0.3363_apsmall_2199.pth",
    "/nas.dbms/asera/NEW/4.1.2/BEST_coco_nw_C3_se_s2025_epoch_70_map_0.3220_apsmall_0.2156.pth",
    "/nas.dbms/asera/NEW/4.1.2/BEST_coco_nw_C3_cbam_s42_epoch_80_map_0.3422_apsmall_0.2223.pth",
    "/nas.dbms/asera/NEW/4.1.2/BEST_coco_nw_C3_lgf_gated_spatial_s2025_epoch_70_map_0.3267_apsmall_0.2399.pth",
    "--img-size", "640",
    "--batch-size", "16",
    "--measure-images", "1000",
    "--report-csv", "lggf_report.csv",
]


import os
import re
import csv
import math
import time
import argparse
import copy
from glob import glob

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

import torchvision
from torchvision.models.detection import retinanet_resnet50_fpn, RetinaNet_ResNet50_FPN_Weights
from torchvision.ops.misc import FrozenBatchNorm2d
from torchvision.models.detection.anchor_utils import AnchorGenerator

# ------------------------------ Utilities -------------------------------------

def try_import(module_name):
    """Report whether ``module_name`` can be imported.

    Any exception raised during the import attempt (not only ``ImportError``)
    is treated as "not available".

    Returns:
        True when the import succeeds, False otherwise.
    """
    try:
        __import__(module_name)
    except Exception:
        available = False
    else:
        available = True
    return available

def set_torch_determinism():
    """Switch PyTorch to deterministic execution where possible.

    Disables the cuDNN autotuner, requests deterministic cuDNN kernels and,
    when this torch build exposes ``use_deterministic_algorithms``, enables it.
    Enabling the latter is best effort: a failure is ignored.
    """
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

    enable_deterministic = getattr(torch, "use_deterministic_algorithms", None)
    if enable_deterministic is None:
        return
    try:
        enable_deterministic(True)
    except Exception:
        # Best effort only: keep going with the cuDNN flags set above.
        pass

# ------------------------------ Blocks ----------------------------------------

class StableLGFBlock(nn.Module):
    """Residual block mixing a local (conv) and a global (pooled) feature branch.

    The block returns ``x + gamma * mixed`` where ``mixed`` is either the only
    configured branch or a combination of both branches chosen by
    ``gating_type``:

    * ``"sum"`` (and any unrecognised value): plain sum of the two branches.
    * ``"softmax"``: two learnable scalars passed through a temperature softmax.
    * ``"gated"``: per-sample sigmoid gates from a small MLP on pooled ``x``.
    * ``"gated_spatial"``: per-channel, per-pixel sigmoid gates from 1x1 convs.

    Args:
        channels: number of input (and output) channels.
        branches: names of the branches to build; "local" and "global" are
            the names that create sub-modules.
        gating_type: how two branches are combined (see above).
        norm_groups: group count for every GroupNorm in the block.
        squeeze_ratio: channel reduction of the "gated" MLP hidden layer.
    """

    def __init__(self, channels, branches=("local","global"), gating_type="sum", norm_groups=32, squeeze_ratio=16):
        super().__init__()
        self.branches = tuple(branches)
        self.gating_type = gating_type
        self.num_branches = len(self.branches)

        # Local branch: depthwise 3x3 conv -> pointwise 1x1 conv -> GroupNorm -> ReLU.
        self.local = None
        if "local" in self.branches:
            depthwise = nn.Conv2d(channels, channels, 3, padding=1, groups=channels, bias=False)
            pointwise = nn.Conv2d(channels, channels, 1, bias=False)
            self.local = nn.Sequential(
                depthwise,
                pointwise,
                nn.GroupNorm(norm_groups, channels),
                nn.ReLU(inplace=True),
            )

        # Global branch: global average pool -> 1x1 conv -> GroupNorm -> SiLU.
        self.global_branch = None
        if "global" in self.branches:
            self.global_branch = nn.Sequential(
                nn.AdaptiveAvgPool2d(1),
                nn.Conv2d(channels, channels, 1, bias=False),
                nn.GroupNorm(norm_groups, channels),
                nn.SiLU(inplace=True),
            )

        # Gate parameters only exist when more than one branch name was given.
        if self.num_branches > 1:
            self._build_gate(channels, norm_groups, squeeze_ratio)

        # Scale of the residual update.
        self.gamma = nn.Parameter(torch.tensor(0.1))

    def _build_gate(self, channels, norm_groups, squeeze_ratio):
        """Register the parameters/modules required by ``self.gating_type``."""
        kind = self.gating_type
        if kind == "softmax":
            self.temperature = nn.Parameter(torch.tensor(1.0))
            self.branch_weights = nn.Parameter(torch.ones(self.num_branches))
        elif kind == "gated":
            hidden = max(channels // squeeze_ratio, 4)
            self.temperature = nn.Parameter(torch.tensor(1.0))
            self.gate_mlp = nn.Sequential(
                nn.AdaptiveAvgPool2d(1),
                nn.Flatten(),
                nn.Linear(channels, hidden),
                nn.ReLU(inplace=True),
                nn.Linear(hidden, self.num_branches),
            )
            # Zero-initialised last layer: every gate starts at sigmoid(0).
            last = self.gate_mlp[-1]
            nn.init.zeros_(last.weight)
            nn.init.zeros_(last.bias)
        elif kind == "gated_spatial":
            reduction = 4
            self.temperature = nn.Parameter(torch.tensor(1.0))
            self.gate_reduce = nn.Conv2d(channels, channels // reduction, 1, bias=False)
            self.gate_expand = nn.Conv2d(channels // reduction, 2 * channels, 1, bias=True)
            self.gate_norm = nn.GroupNorm(num_groups=norm_groups, num_channels=2 * channels)
            # Zero-initialised expansion: every gate starts at sigmoid(0).
            nn.init.zeros_(self.gate_expand.weight)
            nn.init.zeros_(self.gate_expand.bias)

    def _broadcast(self, s, like):
        """Expand a per-sample vector (or broadcastable tensor) to ``like``'s shape."""
        shaped = s.view(-1, 1, 1, 1) if s.dim() == 1 else s
        return shaped.expand_as(like)

    def _mix(self, x, feats, local_feat, global_feat):
        """Combine the branch outputs according to ``self.gating_type``."""
        kind = self.gating_type
        if kind == "softmax":
            tau = F.softplus(self.temperature) + 1e-3
            weights = F.softmax(self.branch_weights / tau, dim=0)
            return weights[0] * feats[0] + weights[1] * feats[1]
        if kind == "gated":
            # Gate logits are computed in float32 and clamped before the sigmoid.
            logits = self.gate_mlp(x.float())
            tau = F.softplus(self.temperature.float()) + 1e-3
            logits = logits.clamp_(-15, 15)
            gates = torch.sigmoid(logits / tau).to(dtype=x.dtype)
            gate_local = self._broadcast(gates[:, 0], local_feat)
            gate_global = self._broadcast(gates[:, 1], global_feat)
            return gate_local * local_feat + gate_global * global_feat
        if kind == "gated_spatial":
            hidden = F.relu(self.gate_reduce(x), inplace=True)
            logits = self.gate_norm(self.gate_expand(hidden))
            logits32 = logits.float().clamp_(-15, 15)
            tau = F.softplus(self.temperature.float()) + 1e-3
            gates = torch.sigmoid(logits32 / tau).to(dtype=x.dtype)
            n, doubled, height, width = gates.shape
            # First half of the channels gates the local branch, second half the global one.
            gate_local, gate_global = gates.view(n, 2, doubled // 2, height, width).unbind(1)
            return gate_local * local_feat + gate_global * global_feat
        # "sum" and any unrecognised gating type: equally weighted sum.
        return feats[0] + feats[1]

    def forward(self, x):
        local_feat = None if self.local is None else self.local(x)
        global_feat = None
        if self.global_branch is not None:
            # The pooled descriptor is 1x1 spatially; expand it to x's shape.
            global_feat = self.global_branch(x).expand_as(x)
        feats = [f for f in (local_feat, global_feat) if f is not None]

        if len(feats) == 1:
            mixed = feats[0]
        else:
            mixed = self._mix(x, feats, local_feat, global_feat)
        return x + self.gamma * mixed

class SEBlock(nn.Module):
    """Channel re-weighting block: global average pool -> 2-layer MLP -> sigmoid scale.

    Args:
        channels: number of input channels.
        reduction: divisor for the hidden width of the MLP.
    """

    def __init__(self, channels, reduction=16):
        super().__init__()
        squeezed = channels // reduction
        self.avg = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channels, squeezed, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(squeezed, channels, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        batch, chans, _height, _width = x.shape
        # One descriptor per (sample, channel).
        descriptor = self.avg(x).view(batch, chans)
        scale = self.fc(descriptor)
        # Reshape so the scale broadcasts over the spatial dimensions.
        return x * scale.view(batch, chans, 1, 1)

class CBAMBlock(nn.Module):
    """Channel attention followed by spatial attention.

    Both attention stages combine an average-pooled descriptor with a
    softmax-weighted pooled descriptor (sharpness controlled by ``beta``).

    Args:
        channels: number of input channels.
        reduction: divisor for the hidden width of the shared channel MLP.
        k: kernel size of the spatial-attention convolution.
        beta: multiplier applied to the shifted activations before the softmax.
    """

    def __init__(self, channels, reduction=16, k=5, beta=20.0):
        super().__init__()
        self.beta = beta
        self.avg = nn.AdaptiveAvgPool2d(1)
        hidden = channels // reduction
        # Shared MLP (as 1x1 convs) applied to both channel descriptors.
        self.fc = nn.Sequential(
            nn.Conv2d(channels, hidden, 1, bias=False),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden, channels, 1, bias=False),
        )
        self.sigc = nn.Sigmoid()
        # Spatial attention over the 2-channel [mean, softmax-pooled] map.
        self.convs = nn.Conv2d(2, 1, k, padding=k // 2, bias=False)
        self.sigs = nn.Sigmoid()

    def _softmax_pool_spatial(self, x):
        """Softmax-weighted average over all spatial positions, per channel."""
        batch, chans, height, width = x.shape
        flat = x.view(batch, chans, height * width)
        # Subtract the per-channel maximum before scaling, for numerical stability.
        shifted = flat - flat.max(dim=2, keepdim=True).values
        weights = F.softmax(self.beta * shifted, dim=2)
        pooled = torch.sum(weights * flat, dim=2)
        return pooled.view(batch, chans, 1, 1)

    def _softmax_pool_channel(self, x):
        """Softmax-weighted average over channels, per spatial position."""
        batch, chans, height, width = x.shape
        channels_last = x.permute(0, 2, 3, 1)
        shifted = channels_last - channels_last.max(dim=3, keepdim=True).values
        weights = F.softmax(self.beta * shifted, dim=3)
        pooled = torch.sum(weights * channels_last, dim=3, keepdim=True)
        return pooled.permute(0, 3, 1, 2)

    def forward(self, x):
        # Channel attention
        mean_descriptor = self.avg(x)
        soft_descriptor = self._softmax_pool_spatial(x)
        channel_logits = self.fc(mean_descriptor) + self.fc(soft_descriptor)
        x = x * self.sigc(channel_logits)

        # Spatial attention
        spatial_in = torch.cat(
            [x.mean(dim=1, keepdim=True), self._softmax_pool_channel(x)], dim=1
        )
        return x * self.sigs(self.convs(spatial_in))

# ----------------------- Backbones and model build ----------------------------

# Maps the C3/C4/C5 insert-level names accepted by --insert-level to the
# "layerN" names that build_model_from_config passes as selected_levels.
LEVEL_MAP = {"C3":"layer2", "C4":"layer3", "C5":"layer4"}

def convert_bn_to_gn(module, num_groups=32, convert_frozen=False):
    """Recursively replace batch-norm children of ``module`` with GroupNorm, in place.

    Each replacement is a freshly initialised ``nn.GroupNorm``; the affine
    parameters and statistics of the replaced layer are not carried over.

    Args:
        module: root module whose descendants are rewritten.
        num_groups: group count for every created GroupNorm.
        convert_frozen: when True, ``FrozenBatchNorm2d`` layers are replaced too.

    Returns:
        The same ``module`` object, for call chaining.
    """
    for name, child in module.named_children():
        is_bn = isinstance(child, nn.BatchNorm2d)
        is_frozen_bn = (not is_bn) and convert_frozen and isinstance(child, FrozenBatchNorm2d)
        if is_bn or is_frozen_bn:
            # torchvision's FrozenBatchNorm2d keeps no `num_features` attribute,
            # so fall back to the length of its per-channel `weight` buffer.
            num_channels = getattr(child, "num_features", None)
            if num_channels is None:
                num_channels = child.weight.shape[0]
            setattr(module, name, nn.GroupNorm(num_groups=num_groups, num_channels=num_channels))
        else:
            convert_bn_to_gn(child, num_groups=num_groups, convert_frozen=convert_frozen)
    return module

def _probe_feature_names(backbone: nn.Module) -> list:
    with torch.no_grad():
        x = torch.zeros(1,3,224,224, device=next(backbone.parameters()).device)
        feats = backbone(x)
        from collections import OrderedDict
        if not isinstance(feats, OrderedDict):
            raise RuntimeError("Backbone must return OrderedDict of features")
        return list(feats.keys())

# Anchor sizes per feature map, keyed by the feature name the backbone returns.
# make_anchor_generator_for looks up every probed feature name here and passes
# the resulting tuples to AnchorGenerator as `sizes`; a name missing from this
# table makes that function raise RuntimeError.
_SIZE_MAP = {
    "0": (32,48,64),
    "1": (64,96,128),
    "2": (128,192,256),
    "3": (256,384,512),
    "p6": (256,384,512),
    "pool": (384,512,640),
    "p7": (384,512,640),
}

def make_anchor_generator_for(backbone: nn.Module) -> AnchorGenerator:
    """Build an AnchorGenerator whose levels match the backbone's feature maps.

    Feature names are discovered with ``_probe_feature_names``; each name is
    mapped to its anchor sizes through ``_SIZE_MAP`` and every level uses the
    aspect ratios (0.5, 1.0, 2.0).

    Raises:
        RuntimeError: if a feature name has no entry in ``_SIZE_MAP``.
    """
    feature_names = _probe_feature_names(backbone)
    level_sizes = []
    try:
        for feature_name in feature_names:
            level_sizes.append(_SIZE_MAP[feature_name])
    except KeyError as e:
        raise RuntimeError(f"No anchor size tuple for feature '{e.args[0]}'. Backbone keys={feature_names}")
    level_sizes = tuple(level_sizes)
    level_ratios = tuple((0.5, 1.0, 2.0) for _ in level_sizes)
    return AnchorGenerator(sizes=level_sizes, aspect_ratios=level_ratios)

def get_block(channels, block_type, branches, gating_type):
    """Instantiate the attention block named by ``block_type``.

    ``"lgf"``, ``"se"`` and ``"cbam"`` select StableLGFBlock, SEBlock and
    CBAMBlock respectively; ``branches`` and ``gating_type`` are only used for
    ``"lgf"``. Any other value yields ``nn.Identity``.
    """
    if block_type == "lgf":
        block = StableLGFBlock(channels, branches=branches, gating_type=gating_type)
    elif block_type == "se":
        block = SEBlock(channels)
    elif block_type == "cbam":
        block = CBAMBlock(channels)
    else:
        block = nn.Identity()
    return block

class CustomBackboneBlockBeforeFPN(nn.Module):
    """Backbone wrapper that applies extra blocks to selected body features before the FPN.

    Args:
        backbone_with_fpn: object exposing ``body`` and ``fpn`` sub-modules.
        selected_levels: feature names to modify; either keys returned by
            ``body`` or the aliases "layer1".."layer4" (resolved by position
            among the body's output keys). Unknown names are ignored.
        config: mapping read for "BRANCH_PRESET", "block_type" and
            "gating_type" when creating each block via ``get_block``.

    Attributes set here: ``selected_actual`` (resolved body keys),
    ``block_fpn_in`` (ModuleDict keyed by ``str(key)``) and ``out_channels``
    (channel count of the first FPN output).
    """

    def __init__(self, backbone_with_fpn, selected_levels, config):
        super().__init__()
        self.body = backbone_with_fpn.body
        self.fpn  = backbone_with_fpn.fpn

        # Dry run on a zero image to learn the feature names and channel counts.
        with torch.no_grad():
            body_out = self.body(torch.zeros(1,3,224,224))
            fpn_out = self.fpn(body_out)

        body_keys = list(body_out.keys())
        # "layerN" refers to the N-th key the body returns (up to four).
        alias = {f"layer{pos}": key for pos, key in enumerate(body_keys[:4], start=1)}

        resolved = []
        for level in selected_levels:
            if level in body_keys:
                resolved.append(level)
                continue
            target = alias.get(level)
            if target is not None:
                resolved.append(target)
        self.selected_actual = resolved

        self.block_fpn_in = nn.ModuleDict()
        for key in self.selected_actual:
            width = body_out[key].shape[1]
            if config.get("BRANCH_PRESET","none") == "local_global":
                branches = ("local","global")
            else:
                branches = ()
            self.block_fpn_in[str(key)] = get_block(
                width,
                config.get("block_type","none"),
                branches,
                config.get("gating_type","none"),
            )

        self.out_channels = next(iter(fpn_out.values())).shape[1]

    def forward(self, x):
        features = self.body(x)
        for key in self.selected_actual:
            block = self.block_fpn_in[str(key)]
            features[key] = block(features[key])
        return self.fpn(features)

def build_model_from_config(num_classes, insert_level, config, transform_img_size=None):
    """Assemble a RetinaNet matching the training-time architecture described by ``config``.

    Starts from torchvision's COCO-pretrained RetinaNet-ResNet50-FPN, optionally
    wraps the backbone with ``CustomBackboneBlockBeforeFPN`` (when
    ``config["block_type"]`` is not "none"), reuses the pretrained head and
    transform, and replaces the classification logits layer so it predicts
    ``num_classes`` classes.

    Args:
        num_classes: number of classes of the new classification layer.
        insert_level: "C3", "C4" or "C5" (a key of ``LEVEL_MAP``); only read
            when a block type is configured.
        config: mapping with optional "block_type", "gating_type" and
            "BRANCH_PRESET" entries.
        transform_img_size: when given, the detection transform's ``min_size``
            and ``max_size`` are set to this single value.

    Returns:
        The assembled ``torchvision.models.detection.RetinaNet``.
    """
    base = retinanet_resnet50_fpn(weights=RetinaNet_ResNet50_FPN_Weights.COCO_V1)
    convert_bn_to_gn(base.backbone.body, num_groups=32, convert_frozen=False)

    if config.get("block_type","none") != "none":
        backbone = CustomBackboneBlockBeforeFPN(
            base.backbone,
            selected_levels=[LEVEL_MAP[insert_level]],
            config=config,
        )
    else:
        backbone = base.backbone

    anchor_generator = make_anchor_generator_for(backbone)
    model = torchvision.models.detection.RetinaNet(
        backbone=backbone,
        num_classes=num_classes,
        anchor_generator=anchor_generator,
        head=base.head,
        transform=base.transform,
        detections_per_img=100,
        nms_thresh=0.5,
        score_thresh=0.05,
    )

    # Swap in a classification layer sized for `num_classes`.
    cls_head = model.head.classification_head
    in_ch = cls_head.cls_logits.in_channels
    n_anchors = cls_head.num_anchors
    cls_head.cls_logits = nn.Conv2d(in_ch, n_anchors*num_classes, kernel_size=3, padding=1)
    cls_head.num_classes = num_classes
    torch.nn.init.normal_(cls_head.cls_logits.weight, std=0.01)
    # Bias chosen so the initial sigmoid output equals prior_prob.
    prior_prob = 0.01
    bias_value = -torch.log(torch.tensor((1.0 - prior_prob) / prior_prob))
    torch.nn.init.constant_(cls_head.cls_logits.bias, bias_value)

    # Pin the detection transform to one fixed size if requested (best effort).
    if transform_img_size is not None and hasattr(model, "transform"):
        try:
            fixed_size = int(transform_img_size)
            # min_size is a tuple of candidate sizes; use a single entry.
            model.transform.min_size = (fixed_size,)
            model.transform.max_size = fixed_size
        except Exception:
            pass

    convert_bn_to_gn(model, num_groups=32, convert_frozen=False)
    return model

# ----------------------- Metrics: params, FLOPs, speed -------------------------

def count_params_m(model):
    """Return the total number of parameters of ``model``, in millions."""
    total = 0
    for param in model.parameters():
        total += param.numel()
    return total / 1e6

# def try_flops_params(model, image_size=(3, 800, 1333)):
#     params_m = count_params_m(model)
#     # Build a CPU copy for FLOPs tools
#     m = copy.deepcopy(model).to("cpu").eval()
#     dummy = torch.zeros(1,*image_size)
#     # # correct: RetinaNet expects a list of [C,H,W] tensors
#     # dummy = torch.zeros(*image_size) # 3xHxW?
#     # fvcore
#     if try_import("fvcore"):
#         try:
#             from fvcore.nn import FlopCountAnalysis
#             flops = FlopCountAnalysis(m, ([dummy],)).total()
#             return {"params_M": float(params_m), "FLOPs_G": float(flops)/1e9}
#         except Exception:
#             pass
#     # thop
#     if try_import("thop"):
#         try:
#             from thop import profile
#             macs,_ = profile(m, inputs=([dummy],), verbose=False)
#             return {"params_M": float(params_m), "FLOPs_G": float(macs)/1e9}
#         except Exception:
#             pass
#     return {"params_M": float(params_m), "FLOPs_G": None}

# def try_flops_params(model, image_size=(3, 640, 640)):
#     """
#     Robust FLOPs:
#       - counts parameters on the full model
#       - counts FLOPs on a heads-only wrapper (backbone + FPN + heads),
#         so there’s no transform/NMS/postprocess to confuse the counter.
#       - uses fvcore first, falls back to thop
#     """
#     params_m = sum(p.numel() for p in model.parameters()) / 1e6

#     class BackboneHead(nn.Module):
#         def __init__(self, m: nn.Module):
#             super().__init__()
#             self.backbone = copy.deepcopy(m.backbone).cpu().eval()
#             self.head = copy.deepcopy(m.head).cpu().eval()

#         def forward(self, x: torch.Tensor):
#             # x: [N,3,H,W], bypass torchvision's detection transform
#             feats = self.backbone(x)
#             if isinstance(feats, dict):
#                 feats = list(feats.values())
#             cls_logits, bbox_reg = self.head(feats)  # lists per FPN level
#             # return a tuple so graph isn't pruned
#             return tuple(cls_logits) + tuple(bbox_reg)

#     bh = BackboneHead(model)
#     dummy = torch.zeros(1, *image_size)  # batch=1 for FLOPs

#     # Try fvcore
#     try:
#         from fvcore.nn import FlopCountAnalysis
#         flops = FlopCountAnalysis(bh, dummy).total()
#         return {"params_M": float(params_m), "FLOPs_G": float(flops) / 1e9}
#     except Exception:
#         pass

#     # Fallback: thop
#     try:
#         from thop import profile
#         macs, _ = profile(bh, inputs=(dummy,), verbose=False)
#         return {"params_M": float(params_m), "FLOPs_G": float(macs) / 1e9}
#     except Exception:
#         return {"params_M": float(params_m), "FLOPs_G": None}

def try_flops_params(model, image_size=(3, 640, 640)):
    """Count parameters and estimate the forward cost of a detection model.

    Only the backbone (including FPN) and the head are profiled, on CPU copies,
    so the detection transform and post-processing are bypassed.

    Strategy:
      1) fvcore ``FlopCountAnalysis`` on the backbone+head wrapper
      2) fallback: thop ``profile`` on the same wrapper
      3) final fallback: manual Conv2d/Linear count via forward hooks

    Results of 1) and 2) are only accepted when they exceed 1e6; any exception
    (including the package not being installed) moves on to the next strategy.

    NOTE(review): the strategies do not share a unit -- the thop value is
    reported as returned (MACs) while the manual count is 2 * MACs; fvcore's
    convention should be checked before comparing numbers obtained through
    different strategies.

    Args:
        model: module exposing ``backbone`` and ``head`` attributes.
        image_size: (C, H, W) of the single dummy image used for counting.

    Returns:
        {"params_M": float, "FLOPs_G": float or None}
    """
    params_m = sum(p.numel() for p in model.parameters()) / 1e6

    class BackboneHead(nn.Module):
        def __init__(self, m: nn.Module):
            super().__init__()
            self.backbone = copy.deepcopy(m.backbone).cpu().eval()
            self.head = copy.deepcopy(m.head).cpu().eval()
        def forward(self, x: torch.Tensor):
            # x: [N,3,H,W], bypass torchvision detection transform/postprocess
            feats = self.backbone(x)
            if isinstance(feats, dict):
                feats = list(feats.values())
            cls_logits, bbox_reg = self.head(feats)  # lists per FPN level
            # return tuple to keep graph from collapsing
            return tuple(cls_logits) + tuple(bbox_reg)

    bh = BackboneHead(model)
    dummy = torch.zeros(1, *image_size)

    # 1) fvcore (optional dependency; any failure falls through)
    try:
        from fvcore.nn import FlopCountAnalysis
        flops = FlopCountAnalysis(bh, dummy).total()
        if flops and flops > 1e6:  # reject implausibly small ("0.0G") results
            return {"params_M": float(params_m), "FLOPs_G": float(flops) / 1e9}
    except Exception:
        pass

    # 2) thop (optional dependency; any failure falls through)
    try:
        from thop import profile
        macs, _ = profile(bh, inputs=(dummy,), verbose=False)
        if macs and macs > 1e6:
            return {"params_M": float(params_m), "FLOPs_G": float(macs) / 1e9}
    except Exception:
        pass

    # 3) Manual conv/linear FLOPs via forward hooks (MACs*2)
    def conv_linear_flops(m: nn.Module, inp_size=(1, 3, 640, 640)):
        hooks, flops = [], []

        def first_tensor(outputs):
            return outputs[0] if isinstance(outputs, (list, tuple)) else outputs

        def on_conv(mod, inputs, outputs):
            # output: [N, Cout, H, W]; each output element needs
            # (Cin / groups) * kH * kW multiply-accumulates.
            n, c_out, h, w = first_tensor(outputs).shape
            k_h, k_w = mod.kernel_size
            macs = n * c_out * h * w * (mod.in_channels // mod.groups) * k_h * k_w
            flops.append(2 * macs)

        def on_linear(mod, inputs, outputs):
            out = first_tensor(outputs)
            # Count every output row, not just the leading dimension, so inputs
            # shaped (N, ..., in_features) are fully accounted for.
            rows = out.numel() // mod.out_features if mod.out_features else 0
            flops.append(2 * rows * mod.in_features * mod.out_features)

        for mod in m.modules():
            if isinstance(mod, nn.Conv2d):
                hooks.append(mod.register_forward_hook(on_conv))
            elif isinstance(mod, nn.Linear):
                hooks.append(mod.register_forward_hook(on_linear))
        try:
            with torch.no_grad():
                m(torch.zeros(*inp_size))
        finally:
            # Always detach the hooks, even if the probe forward pass fails.
            for h in hooks:
                h.remove()
        return sum(flops)

    manual = conv_linear_flops(bh, (1, *image_size))
    return {"params_M": float(params_m), "FLOPs_G": float(manual) / 1e9 if manual else None}

@torch.no_grad()
def benchmark_inference_random(model, device, img_size=640, batch_size=4, warmup_batches=20, measure_images=200):
    """Time inference on random square images and report per-image latency.

    The model is put in eval mode and called with a list of ``batch_size``
    random (3, img_size, img_size) tensors. At least one warmup batch is run
    untimed. On CUDA each batch is timed with CUDA events followed by a
    synchronize; otherwise ``time.perf_counter`` is used. Image generation is
    outside the timed region.

    Args:
        model: callable module accepting a list of image tensors.
        device: ``torch.device`` (or device string) to create the images on.
        img_size: side length of the square images.
        batch_size: images per call; must be >= 1.
        warmup_batches: untimed batches run first (at least one is always run).
        measure_images: timing stops once at least this many images were processed.

    Returns:
        {"latency_ms_per_image": float, "images_per_second": float}; both are
        NaN when nothing was measured.

    Raises:
        ValueError: if ``batch_size`` is smaller than 1 (the timing loop would
            otherwise never make progress).
    """
    batch_size = int(batch_size)
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    device = torch.device(device)

    model.eval()
    side = int(img_size)

    def random_batch():
        return [torch.rand(3, side, side, device=device) for _ in range(batch_size)]

    # Warmup
    for _ in range(max(1, warmup_batches)):
        model(random_batch())

    on_cuda = device.type == "cuda"
    if on_cuda:
        starter = torch.cuda.Event(enable_timing=True)
        ender = torch.cuda.Event(enable_timing=True)

    timings = []  # per-image latency of each timed batch, in milliseconds
    processed = 0
    while processed < measure_images:
        images = random_batch()
        if on_cuda:
            starter.record()
            model(images)
            ender.record()
            # Wait for the recorded events before reading the elapsed time.
            torch.cuda.synchronize()
            batch_ms = starter.elapsed_time(ender)
        else:
            t0 = time.perf_counter()
            model(images)
            batch_ms = 1000.0 * (time.perf_counter() - t0)
        timings.append(batch_ms / len(images))
        processed += len(images)

    lat = float(np.mean(timings)) if timings else float("nan")
    ips = 1000.0 / lat if lat > 0 else float("nan")
    return {"latency_ms_per_image": lat, "images_per_second": ips}

# ----------------------- Checkpoint handling ----------------------------------

def infer_insert_level_from_name(name):
    """Extract the insert level ("C3", "C4" or "C5") from an experiment or file name.

    The token must not be directly preceded or followed by a letter or digit.
    Underscores count as separators, so names such as
    ``BEST_coco_nw_C4_lgf_...`` are recognised (a plain ``\\b`` word boundary
    would not match there, because ``_`` is a word character).

    Returns:
        The first matching level, or None when the name contains none.
    """
    m = re.search(r'(?<![A-Za-z0-9])(C3|C4|C5)(?![A-Za-z0-9])', name)
    return m.group(1) if m else None

def infer_mode_from_config(cfg):
    """Derive a short mode label from a checkpoint config mapping.

    ``block_type`` "none" (also the default) maps to "baseline", "se" and
    "cbam" map to themselves, "lgf" maps to ``"lgf_<gating_type>"`` (gating
    type defaulting to "sum"); anything else yields "unknown".
    """
    block_type = cfg.get("block_type", "none")
    if block_type == "lgf":
        gating = cfg.get("gating_type","sum")
        return f"lgf_{gating}"
    for known, label in (("none", "baseline"), ("se", "se"), ("cbam", "cbam")):
        if block_type == known:
            return label
    return "unknown"

def num_classes_from_ckpt_state(sd, n_anchors):
    """Infer the class count from the classification logits weight in a state dict.

    The weight is looked up under "head.classification_head.cls_logits.weight";
    if that key is absent, the first key ending in "cls_logits.weight" is used.
    The class count is the weight's first dimension divided by ``n_anchors``.

    Raises:
        KeyError: if no suitable key exists in ``sd``.
    """
    default_key = "head.classification_head.cls_logits.weight"
    if default_key in sd:
        weight_key = default_key
    else:
        # Fall back to any key with the expected suffix (e.g. a prefixed name).
        weight_key = next(
            (candidate for candidate in sd.keys() if candidate.endswith("cls_logits.weight")),
            default_key,
        )
    out_channels = sd[weight_key].shape[0]
    return int(out_channels // n_anchors)

def load_checkpoint_build_model(ckpt_path, insert_level_override=None, device="cuda" if torch.cuda.is_available() else "cpu", transform_img_size=None):
    """Load a checkpoint and rebuild the matching model with its weights.

    The state dict is taken from "ema_state_dict", else "model_state_dict",
    else the checkpoint itself. The architecture comes from the checkpoint's
    "config" entry; the insert level comes from ``insert_level_override``, else
    from the experiment name, else defaults to "C3". The class count is
    inferred from the classification logits weight in the state dict.

    Args:
        ckpt_path: path of the ``.pth`` file.
        insert_level_override: "C3"/"C4"/"C5" to force the insert level.
        device: device the returned model is moved to.
        transform_img_size: forwarded to ``build_model_from_config``.

    Returns:
        (model, meta) where meta holds "mode", "insert_level", "num_classes"
        and "experiment_name".
    """
    # NOTE(review): torch.load unpickles the file -- only use trusted checkpoints.
    ckpt = torch.load(ckpt_path, map_location="cpu")
    sd = ckpt.get("ema_state_dict") or ckpt.get("model_state_dict") or ckpt

    cfg = ckpt.get("config", {})  # contains block_type/gating_type/BRANCH_PRESET
    exp_name = ckpt.get("experiment_name", os.path.basename(ckpt_path))

    ins = insert_level_override or infer_insert_level_from_name(exp_name) or "C3"

    # A throwaway model is only needed to read the anchors-per-location count,
    # so it stays on the CPU and is released before the real model is built
    # (avoids holding two full models on the target device at once).
    probe_model = build_model_from_config(num_classes=80, insert_level=ins, config=cfg, transform_img_size=transform_img_size)
    n_anchors = probe_model.head.classification_head.num_anchors
    del probe_model
    num_classes = num_classes_from_ckpt_state(sd, n_anchors)

    model = build_model_from_config(num_classes=num_classes, insert_level=ins, config=cfg, transform_img_size=transform_img_size).to(device)
    missing, unexpected = model.load_state_dict(sd, strict=False)
    if unexpected:
        print(f"[WARN] Unexpected keys in state_dict: {unexpected}")
    if missing:
        # num_batches_tracked entries are treated as harmless; report the
        # first few of the remaining missing keys.
        miss_sig = [k for k in missing if "num_batches_tracked" not in k]
        if miss_sig:
            print(f"[WARN] Missing keys in state_dict: {miss_sig[:5]}{' ...' if len(miss_sig)>5 else ''}")

    mode = infer_mode_from_config(cfg)
    return model, {"mode": mode, "insert_level": ins, "num_classes": num_classes, "experiment_name": exp_name}

# ------------------------------ CLI -------------------------------------------

def parse_args():
    """Parse the command-line options of the measurement script from ``sys.argv``.

    Returns:
        The ``argparse.Namespace`` with ckpts, insert_level, img_size,
        batch_size, warmup_batches, measure_images, device and report_csv.
    """
    parser = argparse.ArgumentParser("Measure params, FLOPs, and throughput for saved RetinaNet checkpoints")
    # (flag, add_argument keyword arguments), registered in this order.
    option_specs = (
        ("--ckpts", dict(type=str, nargs="+", required=True, help="Path(s) or glob(s) to .pth checkpoints")),
        ("--insert-level", dict(type=str, choices=["C3","C4","C5"], default=None, help="Override insert level if not in experiment name")),
        ("--img-size", dict(type=int, default=640, help="Square image size for FLOPs and timing")),
        ("--batch-size", dict(type=int, default=4, help="Batch size for timing")),
        ("--warmup-batches", dict(type=int, default=20, help="Warmup batches before timing")),
        ("--measure-images", dict(type=int, default=200, help="Number of images to time in total")),
        ("--device", dict(type=str, default="auto", choices=["auto","cuda","cpu"])),
        ("--report-csv", dict(type=str, default=None, help="Optional CSV to write results")),
    )
    for flag, spec in option_specs:
        parser.add_argument(flag, **spec)
    return parser.parse_args()

def expand_paths(paths):
    """Expand glob patterns in ``paths`` and drop duplicates, keeping first-seen order.

    Entries containing ``*``, ``?`` or ``[`` are expanded with ``glob`` (a
    pattern matching nothing contributes no entries); other entries are kept
    verbatim, whether or not the file exists.
    """
    wildcard_chars = ("*", "?", "[")
    candidates = []
    for entry in paths:
        if any(ch in entry for ch in wildcard_chars):
            candidates += glob(entry)
        else:
            candidates.append(entry)
    # dict keys preserve insertion order, which removes repeats in place.
    return list(dict.fromkeys(candidates))

def main():
    """Measure every requested checkpoint and optionally write the results to a CSV.

    For each checkpoint the model is rebuilt, its parameter count and FLOPs are
    computed on a CPU copy, and its inference throughput is timed on the chosen
    device. One summary line is printed per checkpoint.
    """
    args = parse_args()
    set_torch_determinism()

    want_cuda = args.device == "cuda" or (args.device == "auto" and torch.cuda.is_available())
    device = torch.device("cuda" if want_cuda else "cpu")
    if device.type == "cuda" and hasattr(torch, "set_float32_matmul_precision"):
        torch.set_float32_matmul_precision("high")

    ckpt_paths = expand_paths(args.ckpts)
    if not ckpt_paths:
        raise SystemExit("No checkpoints matched the given patterns.")

    rows = []
    for ckpt_path in ckpt_paths:
        model, meta = load_checkpoint_build_model(
            ckpt_path,
            insert_level_override=args.insert_level,
            device=device,
            transform_img_size=args.img_size,
        )
        # Params + FLOPs (on a CPU copy, so the benchmarked model is untouched)
        cost = try_flops_params(
            copy.deepcopy(model).to("cpu").eval(),
            image_size=(3, args.img_size, args.img_size),
        )
        # Speed
        speed = benchmark_inference_random(
            model,
            device=device,
            img_size=args.img_size,
            batch_size=args.batch_size,
            warmup_batches=args.warmup_batches,
            measure_images=args.measure_images,
        )

        flops_g = cost["FLOPs_G"]
        if flops_g is not None:
            flops_g = round(flops_g, 3)
        row = {
            "checkpoint": ckpt_path,
            "experiment": meta.get("experiment_name",""),
            "mode": meta["mode"],
            "insert_level": meta["insert_level"],
            "num_classes": meta["num_classes"],
            "img_size": args.img_size,
            "params_M": round(cost["params_M"], 3),
            "FLOPs_G": flops_g,
            "latency_ms_per_image": round(speed["latency_ms_per_image"], 3),
            "images_per_second": round(speed["images_per_second"], 3),
            "device": str(device),
            "batch_size": args.batch_size,
        }
        rows.append(row)
        # Print one-line summary
        print(f"[OK] {os.path.basename(ckpt_path)} | {row['mode']} {row['insert_level']} | "
              f"Params {row['params_M']}M | FLOPs {row['FLOPs_G']}G | "
              f"{row['images_per_second']} img/s @ {args.img_size} on {device}")

    # Optional CSV
    if args.report_csv:
        report_dir = os.path.dirname(args.report_csv) or "."
        os.makedirs(report_dir, exist_ok=True)
        with open(args.report_csv, "w", newline="") as f:
            writer = csv.DictWriter(f, fieldnames=list(rows[0].keys()))
            writer.writeheader()
            writer.writerows(rows)
        print(f"[WROTE] {args.report_csv}")

# Entry point: run the measurement when this code is executed directly
# rather than imported as a module.
if __name__ == "__main__":
    main()
