# In [None]:
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
v4: MMSeg direct inference (no pipeline / no PackSegInputs).
    - Build MMSeg model from cfg/ckpt (init_model)
    - Read image as numpy, build minimal SegDataSample meta
    - Use model.data_preprocessor + model.predict
    - Visualize via model.visualizer.add_datasample

MMDet paths (Mask R-CNN, PanopticFPN) remain as in v2.
"""

import os
import random
from collections import OrderedDict
from typing import Dict, Tuple, List

import numpy as np
from PIL import Image, ImageDraw, ImageFont

import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt

# ======== USER CONFIG ========
# All paths below are machine-specific; edit them before running.

# DeeplabV3 (MMSeg, COCO-Stuff-164k) -- config file + matching checkpoint
MMSEG_CFG  = r'C:/Users/heheh/mmsegmentation/configs/deeplabv3/deeplabv3_r50-d8_4xb4-80k_coco-stuff164k-512x512.py'   # EDIT
MMSEG_CKPT = r'C:/Users/heheh/mmsegmentation/checkpoints/deeplabv3_r50-d8_512x512_4x4_80k_coco-stuff164k_20210709_163016-88675c24.pth'                 # EDIT

# Mask R-CNN (MMDet) -- config file + matching checkpoint
MMDET_MASKRCNN_CFG  = r'C:/Users/heheh/mmdetection/configs/mask_rcnn/mask-rcnn_r50_fpn_1x_coco.py'
MMDET_MASKRCNN_CKPT = r'C:/Users/heheh/mmdetection/checkpoints/mask_rcnn_r50_fpn_1x_coco_20200205-d4b0c5d6.pth'

# Panoptic FPN (MMDet) -- config file + matching checkpoint
MMDET_PANO_CFG  = r'C:/Users/heheh/mmdetection/configs/panoptic_fpn/panoptic-fpn_r50_fpn_1x_coco.py'
MMDET_PANO_CKPT = r'C:/Users/heheh/mmdetection/checkpoints/panoptic_fpn_r50_fpn_1x_coco_20210821_101153-9668fd13.pth'

# Device string passed as `device=` to every model-init helper.
DEVICE = 'cuda'  # or 'cpu'

# --- Datasets (directories) ---
# Image names are taken from CLEAN_DIR; the other directories should hold files
# with the same names (e.g., "000000123456.jpg"). A name missing from one of
# them is drawn as a grey "MISSING" tile in that row rather than aborting.
CLEAN_DIR     = r'C:/Users/heheh/mmdetection/data/coco/val2017'                 # originals
SEM_ADV_DIR   = r'C:/Users/heheh/mmsegmentation/data/semantic_adv_deeplabv3'          # semantic adversarial
INST_ADV_DIR  = r'C:/Users/heheh/mmdetection/data/coco/instance_maskrcnn_adv'           # instance adversarial
PANO_ADV_DIR  = r'C:/Users/heheh/mmdetection/data/coco/panoptic_fpn_adv'       # panoptic adversarial

# How many names to randomly sample from CLEAN_DIR (None = use all of them).
NUM_IMAGES = 5
# Seeds both `random` and `np.random` so the sampled image set is reproducible.
RANDOM_SEED = 123

# Output directory: per-cell renderings go to OUT_DIR/cells, grids to OUT_DIR/grids.
OUT_DIR = r'C:/Users/heheh/cross_paradigm_vis'

# Optional uniform resize before inference for nicer grids (None to disable).
# Given as (width, height); resized copies are written under OUT_DIR/_tmp_resized.
# NOTE(review): resizing resamples the adversarial images (and re-saves them),
# which can weaken the perturbations -- confirm this is acceptable before enabling.
RESIZE_TO = None  # e.g., (512, 512)

# Blend factor for the MMSeg mask overlay, passed as `opacity` to the MMSeg
# visualization helper (1.0 = mask drawn fully opaque over the image).
OPACITY = 1.0

# =============================

def _init_mmseg_model(cfg, ckpt, device='cuda'):
    """Build an MMSeg segmentor from a config file and a checkpoint.

    Args:
        cfg: Path to the MMSeg config file.
        ckpt: Path to the checkpoint to load.
        device: Torch device string, e.g. ``'cuda'`` or ``'cpu'``.

    Returns:
        The segmentor, switched to eval mode.
    """
    from mmseg.apis import init_model as init_segmentor
    model = init_segmentor(cfg, ckpt, device=device)
    model.eval()

    # Best effort: point an attached visualizer (if the model has one and it
    # supports `set_save_dir`) at OUT_DIR. Renderings are written through
    # explicit `out_file` paths anyway, so a failure here is only reported,
    # not fatal -- but it is no longer silently swallowed.
    vis = getattr(model, 'visualizer', None)
    set_save_dir = getattr(vis, 'set_save_dir', None)
    if callable(set_save_dir):
        try:
            set_save_dir(OUT_DIR)
        except Exception as exc:
            print(f"[WARN] Could not set MMSeg visualizer save dir: {exc}")
    return model

def _mmseg_predict_direct_and_vis(model, img_path, out_file, opacity=0.75):
    """Run MMSeg inference on one image without a data pipeline and save an overlay.

    No test pipeline / ``PackSegInputs`` is involved: the image is turned into
    a tensor by hand, passed through ``model.data_preprocessor`` and then
    ``model.predict``; the first prediction is drawn by a segmentation
    visualizer and written to ``out_file``.

    Args:
        model: MMSeg segmentor, e.g. from ``mmseg.apis.init_model``.
        img_path: Path of the image to segment.
        out_file: Path the rendered overlay is written to.
        opacity: Mask blending factor. It is applied through the visualizer's
            ``alpha`` attribute, because ``SegLocalVisualizer.add_datasample``
            has no ``opacity`` keyword (passing one raises ``TypeError``).
    """
    import torch
    from mmseg.structures import SegDataSample

    # RGB array: only used as the canvas for the overlay drawing.
    img = np.array(Image.open(img_path).convert('RGB'))
    H, W = img.shape[:2]

    # The data preprocessor expects a list of (C, H, W) torch tensors: it
    # calls `inputs[0].size(0)`, flips channels / normalizes with (C, 1, 1)
    # mean/std and moves tensors to the model device. A raw HWC numpy array
    # fails there. Feed BGR -- what MMSeg's standard mmcv-based
    # LoadImageFromFile produces -- so the preprocessor's configured channel
    # handling sees the same input it would through the pipeline.
    img_chw_bgr = torch.from_numpy(
        np.ascontiguousarray(img[..., ::-1])  # copy: torch rejects negative strides
    ).permute(2, 0, 1).contiguous()

    # Minimal metainfo: the image is not resized or padded, so every shape is (H, W).
    data_sample = SegDataSample()
    data_sample.set_metainfo(dict(
        ori_shape=(H, W),
        img_shape=(H, W),
        pad_shape=(H, W),
        scale_factor=1.0,
        img_path=img_path
    ))

    with torch.no_grad():
        batch = model.data_preprocessor(
            data=dict(inputs=[img_chw_bgr], data_samples=[data_sample]),
            training=False,
        )
        preds = model.predict(batch['inputs'], batch['data_samples'])  # List[SegDataSample]

    # Visualize
    vis = getattr(model, 'visualizer', None)
    if vis is None:
        from mmseg.visualization import SegLocalVisualizer
        vis = SegLocalVisualizer()
        use_model_meta = True  # replace the fresh visualizer's default class meta
    else:
        use_model_meta = getattr(vis, 'dataset_meta', None) is None
    model_meta = getattr(model, 'dataset_meta', None)
    # Never assign None: the visualizer looks up classes/palette in this dict.
    if use_model_meta and model_meta is not None:
        vis.dataset_meta = model_meta

    vis.alpha = opacity
    vis.add_datasample(
        name='pred',
        image=img,
        data_sample=preds[0],
        draw_gt=False,
        out_file=out_file,
        wait_time=0,
    )

def _init_mmdet_model(cfg, ckpt, device='cuda'):
    """Build an MMDet detector from a config file and a checkpoint.

    Args:
        cfg: Path to the MMDet config file.
        ckpt: Path to the checkpoint to load.
        device: Torch device string, e.g. ``'cuda'`` or ``'cpu'``.

    Returns:
        The detector, in eval mode.
    """
    from mmdet.apis import init_detector
    model = init_detector(cfg, ckpt, device=device)
    # Explicit for parity with _init_mmseg_model; harmless if already in eval mode.
    model.eval()

    # Best effort: point an attached visualizer (if the model has one and it
    # supports `set_save_dir`) at OUT_DIR. Renderings are written through
    # explicit `out_file` paths anyway, so a failure here is only reported,
    # not fatal -- but it is no longer silently swallowed.
    vis = getattr(model, 'visualizer', None)
    set_save_dir = getattr(vis, 'set_save_dir', None)
    if callable(set_save_dir):
        try:
            set_save_dir(OUT_DIR)
        except Exception as exc:
            print(f"[WARN] Could not set MMDet visualizer save dir: {exc}")
    return model

def _mmdet_infer_and_vis(model, img_path, out_file):
    """Run MMDet inference on one image and write the rendered predictions.

    Args:
        model: MMDet detector, e.g. from ``mmdet.apis.init_detector``.
        img_path: Image path; ``inference_detector`` loads it itself through
            the model's test pipeline.
        out_file: Path the visualizer writes the rendering to.
    """
    from mmdet.apis import inference_detector

    def _meta_from_model():
        # Prefer `dataset_meta`, fall back to `metainfo` when that is missing/empty.
        return getattr(model, 'dataset_meta', None) or getattr(model, 'metainfo', None)

    prediction = inference_detector(model, img_path)

    # Separate RGB copy of the image, used only as the drawing canvas.
    canvas_rgb = np.array(Image.open(img_path).convert('RGB'))

    visualizer = getattr(model, 'visualizer', None)
    if visualizer is None:
        from mmdet.visualization import DetLocalVisualizer
        visualizer = DetLocalVisualizer()
        visualizer.dataset_meta = _meta_from_model()
    elif getattr(visualizer, 'dataset_meta', None) is None:
        visualizer.dataset_meta = _meta_from_model()

    visualizer.add_datasample(
        name='pred',
        image=canvas_rgb,
        data_sample=prediction,
        draw_gt=False,
        out_file=out_file,
        wait_time=0,
    )

def _maybe_resize(src_path, dst_path, size_wh=None):
    if size_wh is None:
        return src_path
    W, H = size_wh
    img = Image.open(src_path).convert('RGB').resize((W, H), Image.BILINEAR)
    os.makedirs(os.path.dirname(dst_path), exist_ok=True)
    img.save(dst_path, quality=95)
    return dst_path

def _collect_basenames(img_dir: str) -> List[str]:
    exts = ('.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff', '.webp')
    names = [f for f in os.listdir(img_dir) if f.lower().endswith(exts)]
    names.sort()
    return names

def _compose_grid(cells: Dict[Tuple[int,int], str],
                  n_rows: int, n_cols: int,
                  titles_row: List[str], titles_col: List[str],
                  out_file: str):
    """Assemble per-cell images into one labelled grid image and save it.

    Layout: a header strip with column titles on top, a left margin with row
    titles, and tiles separated by a small gap. Every cell is as large as the
    largest available tile; smaller tiles are pasted at the cell's top-left.

    Args:
        cells: Maps ``(row, col)`` to an image path. Absent keys, ``None``
            values and paths that are not existing files are drawn as grey
            "MISSING" placeholders.
        n_rows: Number of grid rows.
        n_cols: Number of grid columns.
        titles_row: Row labels; rows beyond the list's length get no label.
        titles_col: Column labels; columns beyond the list's length get no label.
        out_file: Destination image path (parent directories are created).
    """
    # Load the real tiles first so the placeholder can match their size
    # instead of forcing every cell up to a fixed 640x480.
    loaded = {}
    for r in range(n_rows):
        for c in range(n_cols):
            p = cells.get((r, c))
            if p and os.path.isfile(p):
                # Context manager releases the file handle right away
                # (otherwise the file can stay locked on Windows).
                with Image.open(p) as src:
                    loaded[(r, c)] = src.convert('RGB')

    if loaded:
        tile_w = max(im.width for im in loaded.values())
        tile_h = max(im.height for im in loaded.values())
    else:
        tile_w, tile_h = 640, 480  # nothing rendered: keep the previous placeholder size

    # One shared placeholder; paste() copies pixels, so reuse is safe.
    placeholder = Image.new('RGB', (tile_w, tile_h), color=(230, 230, 230))
    ImageDraw.Draw(placeholder).text((10, 10), "MISSING", fill=(0, 0, 0))

    header_h = 40
    left_w = 160
    gap = 6
    total_w = left_w + n_cols * tile_w + (n_cols + 1) * gap
    total_h = header_h + n_rows * tile_h + (n_rows + 1) * gap

    canvas = Image.new('RGB', (total_w, total_h), color=(255, 255, 255))
    draw = ImageDraw.Draw(canvas)
    try:
        font = ImageFont.truetype("arial.ttf", 20)
    except Exception:
        # Best effort: Arial may be absent (non-Windows) or FreeType unavailable.
        font = ImageFont.load_default()

    for c in range(n_cols):
        x = left_w + gap + c * (tile_w + gap)
        if c < len(titles_col):
            draw.text((x, gap), titles_col[c], fill=(0, 0, 0), font=font)

    for r in range(n_rows):
        y = header_h + gap + r * (tile_h + gap)
        if r < len(titles_row):
            draw.text((10, y + 10), titles_row[r], fill=(0, 0, 0), font=font)
        for c in range(n_cols):
            x = left_w + gap + c * (tile_w + gap)
            tile = loaded.get((r, c))
            if tile is None:
                tile = placeholder
            canvas.paste(tile, (x, y))

    out_dir = os.path.dirname(out_file)
    if out_dir:  # os.makedirs('') raises for a bare file name
        os.makedirs(out_dir, exist_ok=True)
    canvas.save(out_file, quality=95)

def main():
    """Render one comparison grid per sampled image.

    Rows are the clean and adversarial dataset directories, columns are the
    three models. Image names are sampled from CLEAN_DIR; per-cell renderings
    are written to ``OUT_DIR/cells`` and the assembled grids to
    ``OUT_DIR/grids``.
    """
    import traceback

    random.seed(RANDOM_SEED)
    np.random.seed(RANDOM_SEED)

    # One table drives both the existence check and the grid rows, so the
    # two can no longer drift apart.
    dataset_rows = OrderedDict([
        ('Clean',        CLEAN_DIR),
        ('Semantic-Adv', SEM_ADV_DIR),
        ('Instance-Adv', INST_ADV_DIR),
        ('Panoptic-Adv', PANO_ADV_DIR),
    ])
    for label, d in dataset_rows.items():
        if not os.path.isdir(d):
            print(f"[WARN] Directory does not exist: {d} ({label})")

    # Image names come from the clean set: without it there is nothing to
    # compare, so stop cleanly (and before loading any model) instead of
    # crashing inside os.listdir.
    if not os.path.isdir(CLEAN_DIR):
        print(f"[ERR] Clean directory is required to pick images: {CLEAN_DIR}")
        return

    basenames = _collect_basenames(CLEAN_DIR)
    if NUM_IMAGES is not None and len(basenames) > NUM_IMAGES:
        basenames = random.sample(basenames, NUM_IMAGES)
    print(f"[INFO] Will visualize {len(basenames)} images.")
    if not basenames:
        print("[INFO] Nothing to visualize; skipping model loading.")
        return

    vis_dir = os.path.join(OUT_DIR, 'cells')
    grid_dir = os.path.join(OUT_DIR, 'grids')
    os.makedirs(vis_dir, exist_ok=True)
    os.makedirs(grid_dir, exist_ok=True)

    print("[INFO] Loading models...")
    seg_model  = _init_mmseg_model(MMSEG_CFG, MMSEG_CKPT, device=DEVICE)
    mask_model = _init_mmdet_model(MMDET_MASKRCNN_CFG, MMDET_MASKRCNN_CKPT, device=DEVICE)
    pano_model = _init_mmdet_model(MMDET_PANO_CFG, MMDET_PANO_CKPT, device=DEVICE)

    model_cols = OrderedDict([
        ('DeeplabV3',   ('mmseg', seg_model)),
        ('Mask R-CNN',  ('mmdet', mask_model)),
        ('PanopticFPN', ('mmdet', pano_model)),
    ])

    tmp_resize_root = os.path.join(OUT_DIR, '_tmp_resized') if RESIZE_TO is not None else None
    if tmp_resize_root:
        os.makedirs(tmp_resize_root, exist_ok=True)

    for i, base in enumerate(basenames):
        print(f"[{i+1}/{len(basenames)}] {base}")
        stem = os.path.splitext(base)[0]
        # (row, col) -> rendered cell path, or None when the input is missing
        # or inference failed (drawn as a "MISSING" tile by _compose_grid).
        cell_paths = {}

        for r, (row_name, row_dir) in enumerate(dataset_rows.items()):
            src_img = os.path.join(row_dir, base)
            input_path = None
            if not os.path.isfile(src_img):
                print(f"  [WARN] Missing image for row {row_name}: {src_img}")
            elif tmp_resize_root:
                try:
                    input_path = _maybe_resize(
                        src_img,
                        os.path.join(tmp_resize_root, row_name, base),
                        RESIZE_TO
                    )
                except OSError as e:
                    # Unreadable/corrupt image: treat the row as missing
                    # instead of aborting the whole run.
                    print(f"  [ERR] Could not resize {src_img}: {e}")
            else:
                input_path = src_img

            for c, (col_name, (family, model)) in enumerate(model_cols.items()):
                if input_path is None:
                    cell_paths[(r, c)] = None
                    continue

                out_cell = os.path.join(
                    vis_dir,
                    f"{stem}__{row_name.replace(' ','_')}__{col_name.replace(' ','_')}.jpg"
                )
                try:
                    if family == 'mmseg':
                        _mmseg_predict_direct_and_vis(model, input_path, out_cell, opacity=OPACITY)
                    else:
                        _mmdet_infer_and_vis(model, input_path, out_cell)
                    cell_paths[(r, c)] = out_cell
                except Exception as e:
                    # Per-cell boundary: one failing model/image must not stop
                    # the run, but keep the traceback so failures (e.g. API
                    # mismatches) are diagnosable rather than one terse line.
                    print(f"  [ERR] {col_name} failed on {row_name}/{base}: {e}")
                    traceback.print_exc()
                    cell_paths[(r, c)] = None

        grid_path = os.path.join(grid_dir, f"{stem}__grid.jpg")
        _compose_grid(
            cells=cell_paths,
            n_rows=len(dataset_rows),
            n_cols=len(model_cols),
            titles_row=list(dataset_rows.keys()),
            titles_col=list(model_cols.keys()),
            out_file=grid_path
        )

    print(f"[DONE] Grids saved to: {grid_dir}")

# Script entry point: run the whole load -> infer -> grid workflow when executed directly.
if __name__ == '__main__':
    main()
