# In [None]:
import os, json, glob
import numpy as np
from tqdm import tqdm
import pandas as pd

import torch
from mmengine.config import Config
from mmengine.runner import Runner
from mmdet.apis import init_detector, inference_detector

# ---- Base paths (EDIT) ----
DATA_ROOT   = r'C:/Users/heheh/mmdetection/data/coco'
CONFIG      = r'C:/Users/heheh/mmdetection/configs/panoptic_fpn/panoptic-fpn_r50_fpn_1x_coco.py'
CHECKPOINT  = r'C:/Users/heheh/mmdetection/checkpoints/panoptic_fpn_r50_fpn_1x_coco_20210821_101153-9668fd13.pth'

IMG_DIR_CLEAN = os.path.join(DATA_ROOT, 'val2017')
PAN_JSON  = os.path.join(DATA_ROOT, 'annotations', 'panoptic_val2017.json')
PAN_SEG   = os.path.join(DATA_ROOT, 'annotations', 'panoptic_val2017')
WORK_DIR  = './work_dirs/compare_multi'

# Adversarial image directories to compare, keyed by a short tag that is used
# in work-dir names and in the results table. Add as many dirs as you want.
ADV_DIRS = {
    'deeplabv3_pgd': r'C:/Users/heheh/mmsegmentation/data/semantic_adv_deeplabv3',
    'maskrcnn_pgd':  r'C:/Users/heheh/mmdetection/data/coco/instance_maskrcnn_adv',
    'panoptic_pgd':  r'C:/Users/heheh/mmdetection/data/coco/panoptic_fpn_adv',
}

BATCH = 2
NUM_WORKERS = 2
DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'

# ---- Subsampling ----
SAMPLE_N = 5000            # evaluate at most SAMPLE_N images per adv setting
SAMPLE_SEED = 42
SAMPLE_MODE = 'random'     # 'random' or 'first'

# ---- ASR-predchange ----
# An image counts as a successful attack when at least this fraction of its
# predicted semantic pixels changes between the clean and adversarial input.
PREDCHANGE_THRESH = 0.05   # success if >= 5% of pixels change

def existing_indices(ann_json, img_root):
    """Return the COCO image positions whose files exist, plus all file names.

    Args:
        ann_json: Path to a COCO-style annotation JSON containing an
            ``images`` list whose entries have a ``file_name``.
        img_root: Directory in which each ``file_name`` is looked up.

    Returns:
        ``(idx, names)``. ``names`` lists every ``file_name`` in JSON order.
        ``idx`` holds the positions in ``names`` whose file exists under
        ``img_root``.
    """
    # Explicit UTF-8: JSON is UTF-8 by spec, while the default locale encoding
    # on Windows (e.g. cp1252) can fail on non-ASCII bytes.
    with open(ann_json, 'r', encoding='utf-8') as f:
        data = json.load(f)
    names = [im['file_name'] for im in data['images']]
    idx = [i for i, n in enumerate(names) if os.path.exists(os.path.join(img_root, n))]
    return idx, names

def pick_sample(indices, n, mode='random', seed=0):
    """Pick up to ``n`` indices deterministically.

    Args:
        indices: Sequence of candidate integer indices.
        n: Maximum number of indices to return; ``n <= 0`` yields ``[]``.
        mode: ``'first'`` keeps the first ``n`` indices in input order. Any
            other value draws a seeded random sample without replacement.
        seed: Seed for ``np.random.default_rng`` in random mode.

    Returns:
        A list of plain Python ``int`` (at most ``n`` entries). When
        ``len(indices) <= n`` all indices are returned in input order.
    """
    # Plain ints: numpy int64 values are not JSON-serializable, and the result
    # is stored into an mmengine config as dataset indices.
    pool = [int(i) for i in indices]
    if not pool or n <= 0:
        return []
    if mode == 'first' or len(pool) <= n:
        return pool[:n]
    rng = np.random.default_rng(seed)
    return [int(i) for i in rng.choice(pool, size=n, replace=False)]

def restrict_to_existing(indices, names, root, n=None, seed=0, mode='random'):
    """Filter ``indices`` to images present under ``root``, optionally capping the count.

    Args:
        indices: Positions into ``names`` to consider, in the order to keep.
        names: File names addressed by ``indices``.
        root: Directory in which each ``names[i]`` must exist.
        n: If given and more than ``n`` indices survive, subsample them with
            ``pick_sample(..., mode=mode, seed=seed)``.

    Returns:
        The surviving indices.
    """
    present = []
    for i in indices:
        if os.path.exists(os.path.join(root, names[i])):
            present.append(i)
    if n is None or len(present) <= n:
        return present
    return pick_sample(present, n, mode=mode, seed=seed)

def build_cfg(img_root, tag):
    """Return a test-time PanopticFPN config that reads images from ``img_root``.

    The config is loaded from ``CONFIG`` with weights from ``CHECKPOINT``.
    Runner output goes to ``WORK_DIR/<tag>``, which is created if needed.
    The test split is evaluated against the COCO panoptic annotations
    (``PAN_JSON`` / ``PAN_SEG``) with ``CocoPanopticMetric``. Visualizer
    backends are cleared when a visualizer is configured.
    """
    cfg = Config.fromfile(CONFIG)
    cfg.default_scope = 'mmdet'
    cfg.load_from = CHECKPOINT

    run_dir = os.path.join(WORK_DIR, tag)
    cfg.work_dir = run_dir
    os.makedirs(cfg.work_dir, exist_ok=True)

    # Point the test split at the requested image folder while keeping the
    # panoptic annotations as ground truth.
    test_loader = cfg.test_dataloader
    dataset_cfg = test_loader.dataset
    dataset_cfg.type = 'CocoPanopticDataset'
    dataset_cfg.data_root = DATA_ROOT
    dataset_cfg.ann_file = PAN_JSON
    dataset_cfg.data_prefix = dict(img=img_root, seg=PAN_SEG)
    dataset_cfg.test_mode = True

    test_loader.batch_size = BATCH
    test_loader.num_workers = NUM_WORKERS

    evaluator = dict(
        type='CocoPanopticMetric',
        ann_file=PAN_JSON,
        seg_prefix=PAN_SEG,
        format_only=False,
    )
    cfg.test_evaluator = evaluator

    if hasattr(cfg, 'visualizer'):
        cfg.visualizer.vis_backends = None
    return cfg

def run_eval(cfg):
    """Build an mmengine ``Runner`` from ``cfg`` and run its test loop.

    Returns:
        Whatever ``Runner.test()`` returns (the metrics dict).
    """
    runner = Runner.from_cfg(cfg)
    metrics = runner.test()
    return metrics

def get_pred_sem_map(ds):
    """
    Return an int64 semantic label map from a DetDataSample-like object, or None.

    Lookup order:
      1. ``ds.pred_sem_seg`` (its ``sem_seg`` or ``data`` field). A (C, H, W)
         tensor with C > 1 is treated as per-class scores and argmax-ed over
         C. A single-channel (1, H, W) tensor is treated as labels already.
      2. ``ds.pred_panoptic_seg`` (its ``panoptic_seg``, ``sem_seg`` or
         ``data`` field). If ``segments_info`` is available (as an attribute
         or in ``metainfo``), each segment id is mapped to its
         ``category_id``; pixels of unlisted ids become 0.
      3. Otherwise panoptic ids are decoded as ``id % INSTANCE_OFFSET`` (1000).

    Leading size-1 axes are dropped, so a (1, H, W) prediction comes back as
    (H, W) and clean/adv maps compare element-wise.
    """

    def _to_2d(arr):
        # Drop leading singleton axes, e.g. (1, H, W) -> (H, W)
        while arr.ndim > 2 and arr.shape[0] == 1:
            arr = arr[0]
        return arr

    # 1) Prefer semantic prediction if present
    sem = getattr(ds, 'pred_sem_seg', None)
    if sem is not None:
        seg = getattr(sem, 'sem_seg', None)
        if seg is None:
            seg = getattr(sem, 'data', None)
        if seg is not None:
            if isinstance(seg, torch.Tensor):
                seg = seg.detach().cpu()
                # Only argmax genuine multi-channel score maps: argmax over a
                # size-1 channel axis would turn a label map into all zeros.
                if seg.ndim == 3 and seg.shape[0] > 1:
                    seg = seg.argmax(dim=0)
                seg = seg.to(torch.int64).numpy()
            else:
                seg = np.asarray(seg, dtype=np.int64)
            return _to_2d(seg)
        # No data on the semantic element: fall through to the panoptic one.

    # 2) Panoptic prediction
    p = getattr(ds, 'pred_panoptic_seg', None)
    if p is None:
        return None

    # Try fields one by one (no "or" chaining: tensors have no truth value)
    pan = getattr(p, 'panoptic_seg', None)
    if pan is None:
        pan = getattr(p, 'sem_seg', None)
    if pan is None:
        pan = getattr(p, 'data', None)
    if pan is None:
        return None

    if isinstance(pan, torch.Tensor):
        pan = pan.detach().cpu().numpy()
    pan = _to_2d(np.asarray(pan))

    # segments_info can be on the element itself or in its metainfo
    segs = getattr(p, 'segments_info', None)
    if segs is None:
        meta = getattr(p, 'metainfo', None)
        if isinstance(meta, dict):
            segs = meta.get('segments_info', None)

    if segs:
        sem_map = np.zeros_like(pan, dtype=np.int64)
        for s in segs:
            sid = int(s['id'] if isinstance(s, dict) else s.id)
            cid = int(s['category_id'] if isinstance(s, dict) else s.category_id)
            sem_map[pan == sid] = cid
        return sem_map

    # 3) Fallback: decode with the common INSTANCE_OFFSET convention
    INSTANCE_OFFSET = 1000
    return (pan % INSTANCE_OFFSET).astype(np.int64)


def pull_metric(m, which):  # which in {'PQ','SQ','RQ'}
    """Return metric ``which`` from the results dict ``m`` as a float, or NaN.

    Lookup order:
      1. the exact key ``which``;
      2. the first key equal to ``which`` or ending with ``'.' + which``,
         compared case-insensitively.
    Both lookups apply the same scaling: numeric values in [0, 1.5] are
    assumed to be fractions and multiplied by 100, so the exact and fuzzy
    paths report on one scale. NOTE(review): a genuine percentage below 1.5
    would be over-scaled by this heuristic.
    """
    def _scale(v):
        # Same rule as _scale_if_fraction: fractions in [0, 1.5] -> percent
        if isinstance(v, (int, float)) and 0.0 <= v <= 1.5:
            return float(v) * 100.0
        return float(v)

    if which in m:
        return _scale(m[which])
    lk = which.lower()
    for k, v in m.items():
        kl = k.lower()
        if kl == lk or kl.endswith('.' + lk):
            return _scale(v)
    return float('nan')

def _pick_case_insensitive(d, key):
    # get d[key] ignoring case; returns None if missing
    lk = key.lower()
    for k, v in d.items():
        if k.lower() == lk:
            return v
    return None

def _scale_if_fraction(x):
    # If looks like 0–1, convert to %
    return float(x) * 100.0 if isinstance(x, (int, float)) and 0.0 <= x <= 1.5 else float(x)
import re
import numpy as np

def pull_panoptic_all(m):
    """
    Return (PQ_all, SQ_all, RQ_all) from CocoPanopticMetric results.

    For each of PQ / SQ / RQ, the first numeric hit wins (keys are compared
    case-insensitively):
      1. the flat key 'coco_panoptic/<X>';
      2. any other flat key equal to '<X>' or ending in '/<X>' or '.<X>'.
         Keys containing '_th', '_st', 'things' or 'stuff' are skipped, so
         only the "All" split is reported;
      3. nested {'All': {'<x>': ...}};
      4. nested {'<X>': {'All': ...}}.

    Flat values (1, 2) are returned unchanged. They are assumed to already
    be percentages: MMDet's panoptic metric multiplies its 0–1 scores by 100.
    NOTE(review): confirm for your MMDet version. Nested values (3, 4) look
    like raw per-split averages and are scaled by 100 when in [0, 1.5].
    Missing metrics come back as NaN.
    """
    def _is_num(v):
        return (isinstance(v, (int, float, np.integer, np.floating))
                and not isinstance(v, bool))

    def _frac_to_pct(v):
        v = float(v)
        return v * 100.0 if 0.0 <= v <= 1.5 else v

    def pick(key_base):
        kb = key_base.lower()

        # 1) direct match like 'coco_panoptic/PQ'
        for k, v in m.items():
            if k.lower() == f'coco_panoptic/{kb}' and _is_num(v):
                return float(v)

        # 2) suffix '/PQ' (not *_th/_st). Non-numeric values are skipped so a
        #    nested block under a bare 'PQ' key falls through to step 4
        #    instead of crashing in float().
        pat = re.compile(rf'(^|[/\.]){kb}$')
        for k, v in m.items():
            kl = k.lower()
            if any(x in kl for x in ('_th', '_st', 'things', 'stuff')):
                continue
            if pat.search(kl) and _is_num(v):
                return float(v)

        # 3) nested under 'All': {'All': {'pq': ...}}
        for k, v in m.items():
            if k.lower() == 'all' and isinstance(v, dict):
                for kk, vv in v.items():
                    if kk.lower() == kb and _is_num(vv):
                        return _frac_to_pct(vv)
                break

        # 4) nested under 'PQ': {'PQ': {'All': ...}}
        for k, v in m.items():
            if k.lower() == kb and isinstance(v, dict):
                for kk, vv in v.items():
                    if kk.lower() == 'all' and _is_num(vv):
                        return _frac_to_pct(vv)
                break

        return np.nan

    # Flat evaluator values are NOT rescaled: an attacked model can genuinely
    # score below 1.5 %, which a blanket fraction heuristic would inflate 100x.
    return pick('PQ'), pick('SQ'), pick('RQ')

def pull_all_metrics(m):
    """
    Extract (PQ_all, SQ_all, RQ_all) from a panoptic results dict.

    Each metric takes the first of these lookups that yields an int/float.
    Keys are matched case-insensitively, and every hit is passed through
    ``_scale_if_fraction``:
      1. top-level ``m['PQ']``;
      2. nested ``m['PQ']['All']``;
      3. nested ``m['All']['PQ']``;
      4. any numeric flat key that contains 'all' and either contains '.pq'
         or ends in 'pq'. For each key, PQ, SQ and RQ are tried in that
         order, and the first still-missing metric that matches is filled.
    Metrics never found are returned as NaN.
    """
    import numpy as np

    metric_names = ('PQ', 'SQ', 'RQ')
    found = dict.fromkeys(metric_names, np.nan)

    def is_number(value):
        return isinstance(value, (int, float))

    # 1) direct top-level numbers
    for name in metric_names:
        top = _pick_case_insensitive(m, name)
        if is_number(top):
            found[name] = _scale_if_fraction(top)

    # 2) nested as {'PQ': {'All': ...}, ...}
    for name in metric_names:
        if not np.isnan(found[name]):
            continue
        block = _pick_case_insensitive(m, name)
        if isinstance(block, dict):
            inner = _pick_case_insensitive(block, 'All')
            if is_number(inner):
                found[name] = _scale_if_fraction(inner)

    # 3) nested as {'All': {'PQ': ..., 'SQ': ..., 'RQ': ...}}
    all_block = _pick_case_insensitive(m, 'All')
    if isinstance(all_block, dict):
        for name in metric_names:
            if np.isnan(found[name]):
                inner = _pick_case_insensitive(all_block, name)
                if is_number(inner):
                    found[name] = _scale_if_fraction(inner)

    # 4) flattened keys such as 'All.pq'
    for key, value in m.items():
        if not is_number(value):
            continue
        lowered = key.lower()
        if 'all' not in lowered:
            continue
        for name in metric_names:
            suffix = name.lower()
            matches = ('.' + suffix) in lowered or lowered.endswith(suffix)
            if matches and np.isnan(found[name]):
                found[name] = _scale_if_fraction(value)
                break

    return found['PQ'], found['SQ'], found['RQ']

def _drop_leading_singletons(arr):
    """Strip leading size-1 axes from a label map, e.g. (1, H, W) -> (H, W)."""
    while arr.ndim > 2 and arr.shape[0] == 1:
        arr = arr[0]
    return arr


def _resize_nearest(label_map, out_hw):
    """Nearest-neighbour resize of the last two axes of a label map to ``out_hw``.

    Done in NumPy because cv2.resize does not accept the int64 maps produced
    here. Source index = floor(dst_index * src_size / dst_size), the same
    mapping as nearest-neighbour interpolation.
    """
    src_h, src_w = label_map.shape[-2:]
    out_h, out_w = out_hw
    rows = np.minimum((np.arange(out_h) * src_h) // out_h, src_h - 1)
    cols = np.minimum((np.arange(out_w) * src_w) // out_w, src_w - 1)
    return label_map[..., rows[:, None], cols]


def _wilson_ci(p, n, z=1.96):
    """Wilson score interval for a proportion ``p`` over ``n`` trials (z=1.96 -> 95%)."""
    denom = 1 + z**2 / n
    center = (p + z*z/(2*n)) / denom
    margin = z * np.sqrt((p*(1-p) + z*z/(4*n)) / n) / denom
    return max(0.0, center - margin), min(1.0, center + margin)


def asr_predchange(adv_root, model, subset_names, predchange_thresh=PREDCHANGE_THRESH):
    """
    Prediction-change attack success rate of ``model`` on adversarial images.

    Only names in ``subset_names`` that exist under ``adv_root`` are used.
    For each one, the clean image (``IMG_DIR_CLEAN/<name>``) and the
    adversarial image (``adv_root/<name>``) are run through
    ``inference_detector`` and converted with ``get_pred_sem_map``. The
    per-image change rate is the fraction of pixels whose label differs. If
    the adv map has a different size, it is nearest-resized to the clean
    map's size first. An image is a success when its rate is >=
    ``predchange_thresh``. Images whose maps are missing (None) or whose
    leading dimensions cannot be aligned are skipped and do not count
    toward N.

    Returns:
        ``(asr, N, mean_change, (ci_low, ci_high), std_change)``:
          - ``asr`` is successes / N;
          - ``(ci_low, ci_high)`` is the 95% Wilson score interval for ``asr``;
          - ``mean_change`` / ``std_change`` summarise the per-image change
            fractions.
        If nothing can be evaluated, the result is
        ``(nan, 0, nan, (nan, nan), nan)``.
    """
    names = [n for n in subset_names if os.path.exists(os.path.join(adv_root, n))]
    if not names:
        return np.nan, 0, np.nan, (np.nan, np.nan), np.nan  # asr, N, mean, CI, std

    diffrates = []
    successes = 0
    for n in tqdm(names, desc=f'ASR-predchange ({os.path.basename(adv_root)})'):
        ds_c = inference_detector(model, os.path.join(IMG_DIR_CLEAN, n))
        ds_a = inference_detector(model, os.path.join(adv_root, n))
        lc = get_pred_sem_map(ds_c)
        la = get_pred_sem_map(ds_a)

        # One-time shape debug print. The flag lives on get_pred_sem_map and
        # is never reset, so it fires once per session.
        if not hasattr(get_pred_sem_map, "_once"):
            print("[dbg] shapes:", None if lc is None else lc.shape,
                                None if la is None else la.shape)
            get_pred_sem_map._once = True

        if lc is None or la is None:
            continue
        lc = _drop_leading_singletons(lc)
        la = _drop_leading_singletons(la)
        if lc.ndim != la.ndim or lc.shape[:-2] != la.shape[:-2]:
            continue  # cannot align maps of different rank / leading dims
        if lc.shape != la.shape:
            la = _resize_nearest(la, lc.shape[-2:])

        diff = float((lc != la).mean())
        diffrates.append(diff)
        if diff >= predchange_thresh:
            successes += 1

    N = len(diffrates)
    if N == 0:
        return np.nan, 0, np.nan, (np.nan, np.nan), np.nan

    asr = successes / N
    ci_low, ci_high = _wilson_ci(asr, N)
    return asr, N, float(np.mean(diffrates)), (ci_low, ci_high), float(np.std(diffrates))

# ---------- Build global name list ----------
# all_names: every image file_name in the panoptic JSON (JSON order).
# all_idx_clean: positions in all_names whose clean image exists on disk.
all_idx_clean, all_names = existing_indices(PAN_JSON, IMG_DIR_CLEAN)
# Base candidate pool (images that exist in clean dir)
base_pool = all_idx_clean
# One shared deterministic sample of the clean pool. Each adv setting is then
# intersected with it, so every setting draws from the same images.
base_sample = pick_sample(base_pool, SAMPLE_N, mode=SAMPLE_MODE, seed=SAMPLE_SEED)

# Build once for inference-based ASR
pan_model = init_detector(CONFIG, CHECKPOINT, device=DEVICE)

# ---------- 2) Loop all ADV dirs ----------
# One dict per evaluated adv setting, consumed by the results-table code below.
rows = []

for tag, adv_dir in ADV_DIRS.items():
    if not os.path.isdir(adv_dir):
        print(f"[WARN] skip missing dir: {adv_dir}")
        continue
    # Keep the sampled images that also exist in the ADV dir (deterministic).
    # Plain ints so the indices serialize cleanly into the mmengine config.
    adv_subset_idx = [int(i) for i in restrict_to_existing(
        indices=base_sample,
        names=all_names,
        root=adv_dir,
        n=SAMPLE_N, seed=SAMPLE_SEED, mode=SAMPLE_MODE
    )]
    if not adv_subset_idx:
        # An empty `indices` list would give the Runner an empty test set.
        print(f"[WARN] skip {tag}: none of the sampled images exist in {adv_dir}")
        continue
    subset_names = [all_names[i] for i in adv_subset_idx]

    # NOTE(review): `indices` selects positions in the dataset's data_list;
    # this assumes that list follows the JSON 'images' order of all_names.
    # --- Clean metrics on the same subset
    cfg_clean = build_cfg(IMG_DIR_CLEAN, f'pan_clean_{tag}')
    cfg_clean.test_dataloader.dataset.indices = adv_subset_idx
    met_clean = run_eval(cfg_clean)
    # --- Adv metrics on the same subset
    cfg_adv = build_cfg(adv_dir, f'pan_adv_{tag}')
    cfg_adv.test_dataloader.dataset.indices = adv_subset_idx
    met_adv = run_eval(cfg_adv)

    # --- ASR on the same subset (no GT)
    asr, N_asr, avg_diff, (ci_lo, ci_hi), diff_std = asr_predchange(adv_dir, pan_model, subset_names)

    pq_c, sq_c, rq_c = pull_panoptic_all(met_clean)
    pq_a, sq_a, rq_a = pull_panoptic_all(met_adv)

    rows.append({
        'adv_tag': tag,
        'PQ_clean': pq_c, 'SQ_clean': sq_c, 'RQ_clean': rq_c,
        'PQ_adv': pq_a, 'SQ_adv': sq_a, 'RQ_adv': rq_a,
        'ASR_predchange': asr,
        'ASR_95CI_low': ci_lo,
        'ASR_95CI_high': ci_hi,
        'ASR_N': N_asr,
        'avg_changed_pixels': avg_diff,
        'std_changed_pixels': diff_std,
        'thresh': PREDCHANGE_THRESH,
        'images_eval': len(adv_subset_idx),
    })
    
# (B) Build LONG format with split={clean, adv}, then pivot to columns.
# Each entry of `rows` describes one adv setting with PQ/SQ/RQ_{clean,adv}
# plus adv-only ASR statistics. The long table has one row per (tag, split).
_ASR_COLS = ['ASR_predchange', 'ASR_95CI_low', 'ASR_95CI_high', 'ASR_N',
             'avg_changed_pixels', 'std_changed_pixels']
rows_long = []
for r in rows:
    for split in ('clean', 'adv'):
        entry = {'adv_tag': r['adv_tag'], 'split': split}
        for metric in ('PQ', 'SQ', 'RQ'):
            entry[metric] = r[f'{metric}_{split}']
        for col in _ASR_COLS:
            # ASR is adv-only; keep NaN for clean
            entry[col] = r[col] if split == 'adv' else np.nan
        entry['thresh'] = r['thresh']
        entry['images_eval'] = r['images_eval']
        rows_long.append(entry)

df_long = pd.DataFrame(rows_long)

# List the value columns we want to pivot
value_cols = ['PQ', 'SQ', 'RQ', *_ASR_COLS, 'thresh', 'images_eval']

if df_long.empty:
    # Nothing was evaluated (e.g. every adv dir was skipped). pivot() would
    # raise KeyError on the missing 'adv_tag' column, so emit an empty table.
    print("[WARN] no adv settings were evaluated; writing an empty results table")
    df_wide = pd.DataFrame(columns=['adv_tag'])
else:
    # Pivot so we get columns like PQ_clean, PQ_adv, ...
    df_wide = df_long.pivot(index='adv_tag', columns='split', values=value_cols)

    # Add delta columns (adv - clean) for PQ, SQ, RQ
    for m in ['PQ', 'SQ', 'RQ']:
        df_wide[(m, 'delta')] = df_wide[(m, 'adv')] - df_wide[(m, 'clean')]

    # Flatten MultiIndex columns -> "PQ_clean", "PQ_adv", "PQ_delta", ...
    df_wide.columns = [f'{a}_{b}' for a, b in df_wide.columns]
    df_wide = df_wide.reset_index()

# (optional) reorder columns
ordered = (
    ['adv_tag', 'images_eval_clean'] +
    [f'{m}_{s}' for m in ['PQ','SQ','RQ'] for s in ['clean','adv','delta']] +
    ['ASR_predchange_adv','ASR_95CI_low_adv','ASR_95CI_high_adv','ASR_N_adv',
     'avg_changed_pixels_adv','std_changed_pixels_adv','thresh_adv']
)
# Keep only those that exist
ordered = [c for c in ordered if c in df_wide.columns]
df_wide = df_wide[ordered]

out_csv_wide = f'./panoptic_transfer_subset{SAMPLE_N}_seed{SAMPLE_SEED}_clean_as_columns.csv'
df_wide.to_csv(out_csv_wide, index=False)
# Pretty print (optional)
def f3(x):
    """Format ``x`` with three decimals; None and float NaN render as an em dash."""
    if x is None:
        return "—"
    if isinstance(x, float) and np.isnan(x):
        return "—"
    return f"{x:.3f}"

# Print the results table (every non-tag column formatted by f3) and the CSV path.
print("\n=== PanopticFPN — results (clean/adv as columns) ===")
show_cols = [c for c in df_wide.columns if c != 'adv_tag']
table_view = df_wide[['adv_tag'] + show_cols]
col_formatters = {c: f3 for c in show_cols}
print(table_view.to_string(index=False, formatters=col_formatters))
print(f'\n[Saved] {out_csv_wide}')
