# AD Context MCQ Ablation (Colab)

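This notebook compares three AD (anomaly-detection) context settings for the MCQ (multiple-choice question) evaluation, all run on the exact same set of sampled images:
1. no AD context (baseline)
2. legacy AD context
3. v2 AD context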

Outputs include overall accuracy and per-question-type accuracy.

In [None]:
import json
import os
import random
import shlex
import subprocess
from collections import defaultdict
from datetime import datetime
from pathlib import Path

import pandas as pd

In [None]:
# ===== User config =====
# Locations on the Colab VM.
REPO_ROOT = Path('/content/multimodal-anomaly-report-generation')
DATA_ROOT = '/content/dataset/MMAD'
MMAD_JSON = '/content/dataset/MMAD/mmad_10classes.json'
CHECKPOINT_DIR = '/content/dataset/MMAD/patchcore_ckpt'

# Evaluation knobs forwarded to scripts/run_experiment.py.
LLM_MODEL = 'internvl3.5-2b'
FEW_SHOT = 1
BATCH_MODE = 'true'

# Stratified sampling: up to this many images per folder group, reproducible via the seed.
SAMPLE_PER_FOLDER = 3
SAMPLE_SEED = 42

# Optional: use only part of sampled images for quick debug
MAX_IMAGES = None  # e.g., 100

# Each run writes into its own timestamped directory to keep runs apart.
RUN_TAG = datetime.now().strftime('%Y%m%d_%H%M%S')
WORK_DIR = REPO_ROOT / 'outputs' / 'eval_ablation' / RUN_TAG
WORK_DIR.mkdir(parents=True, exist_ok=True)

# Intermediate artifacts shared by all three settings.
SAMPLED_MMAD_JSON, AD_LEGACY_JSON, AD_V2_JSON = (
    WORK_DIR / fname for fname in ('_sampled_mmad.json', 'ad_legacy.json', 'ad_v2.json')
)

# One results directory per ablation setting.
OUT_NO_AD, OUT_LEGACY, OUT_V2 = (
    WORK_DIR / sub for sub in ('no_ad', 'legacy_ad', 'v2_ad')
)
for out_dir in (OUT_NO_AD, OUT_LEGACY, OUT_V2):
    out_dir.mkdir(parents=True, exist_ok=True)

print('WORK_DIR:', WORK_DIR)

In [None]:
def run_cmd(cmd: str, cwd: 'Path | None' = None):
    """Run a shell command, streaming its combined stdout/stderr as it is produced.

    Args:
        cmd: Full shell command line. It is executed with ``shell=True``, so
            callers must ``shlex.quote`` every interpolated value.
        cwd: Working directory for the command. ``None`` (the default) means
            ``REPO_ROOT``, looked up at call time so that re-running the config
            cell with a new ``REPO_ROOT`` takes effect without re-running this cell.

    Raises:
        RuntimeError: if the command exits with a non-zero status.
    """
    workdir = REPO_ROOT if cwd is None else cwd
    print('\n$ ' + cmd)
    # Child Python processes block-buffer stdout when it is a pipe; force
    # unbuffered output so progress lines arrive while the command runs.
    env = {**os.environ, 'PYTHONUNBUFFERED': '1'}
    with subprocess.Popen(
        cmd,
        cwd=str(workdir),
        shell=True,
        text=True,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        env=env,
    ) as proc:
        try:
            # Echo line by line instead of only after exit, so long evaluation
            # runs are observable and partial output survives an interrupt.
            for line in proc.stdout:
                print(line, end='')
        except BaseException:
            # Mirror subprocess.run: never leave the child running if we stop reading
            # (e.g. the notebook cell was interrupted).
            proc.kill()
            raise
        returncode = proc.wait()
    if returncode != 0:
        raise RuntimeError(f'Command failed ({returncode}): {cmd}')


def stratified_sample_paths(image_paths, n_per_folder: int, seed: int = 42):
    """Draw up to ``n_per_folder`` paths from every folder group, reproducibly.

    Each path is split on '/'. Its group key is ``parts[0]/parts[1]/parts[3]``
    when it has at least four components (component 2 is skipped), the first
    two components when it has two or three, and ``'unknown'`` otherwise.
    Groups are visited in sorted key order and sampled without replacement
    from a single ``random.Random(seed)``, so identical input and seed always
    give the same list.
    """
    def folder_key(path):
        parts = path.split('/')
        if len(parts) < 2:
            return 'unknown'
        if len(parts) < 4:
            return f'{parts[0]}/{parts[1]}'
        return f'{parts[0]}/{parts[1]}/{parts[3]}'

    groups = defaultdict(list)
    for path in image_paths:
        groups[folder_key(path)].append(path)

    rng = random.Random(seed)
    picked = []
    for key in sorted(groups):
        members = groups[key]
        picked += rng.sample(members, min(n_per_folder, len(members)))
    return picked


def subsample_in_order(paths, k, seed):
    """Return a seeded, uniformly random subset of ``k`` items from ``paths``.

    Kept items stay in their original relative order. If ``k`` is at least
    ``len(paths)``, a copy of all items is returned.

    Raises:
        ValueError: if ``k`` is negative.
    """
    if k < 0:
        raise ValueError(f'max_images must be non-negative, got {k}')
    if k >= len(paths):
        return list(paths)
    keep = set(random.Random(seed).sample(range(len(paths)), k))
    return [p for i, p in enumerate(paths) if i in keep]


def build_sampled_mmad(input_json: str, output_json: Path, n_per_folder: int, seed: int = 42, max_images=None):
    """Write a stratified subset of an MMAD annotation JSON to ``output_json``.

    The keys of the input JSON object are treated as image paths and sampled
    with ``stratified_sample_paths``. The selected entries are written unchanged.

    Args:
        input_json: Path to the full MMAD JSON.
        output_json: Destination path; its parent directory is created if needed.
        n_per_folder: Maximum images drawn per folder group.
        seed: Seed for both the stratified sample and the optional subset.
        max_images: If not None, keep only this many of the sampled images,
            chosen at random (seeded) from across all folder groups.

    Raises:
        ValueError: if ``max_images`` is negative.
    """
    with open(input_json, 'r', encoding='utf-8') as f:
        mmad_data = json.load(f)

    sampled_paths = stratified_sample_paths(list(mmad_data.keys()), n_per_folder=n_per_folder, seed=seed)
    if max_images is not None:
        # The stratified sample is ordered by folder key, so a plain prefix
        # would keep only the first few categories. A seeded random subset
        # draws from all folder groups instead.
        sampled_paths = subsample_in_order(sampled_paths, max_images, seed)

    sampled_data = {k: mmad_data[k] for k in sampled_paths}
    output_json = Path(output_json)
    output_json.parent.mkdir(parents=True, exist_ok=True)
    with open(output_json, 'w', encoding='utf-8') as f:
        json.dump(sampled_data, f, ensure_ascii=False, indent=2)

    n_good = sum('/good/' in p for p in sampled_paths)
    print('sampled images:', len(sampled_paths), 'good:', n_good, 'anomaly:', len(sampled_paths) - n_good)
    print('sample json:', output_json)


def latest_meta(output_dir: Path):
    """Return the most recently modified ``*.meta.json`` file in ``output_dir``.

    On equal modification times, the file listed last by ``glob`` wins
    (stable sort).

    Raises:
        FileNotFoundError: if the directory has no ``*.meta.json`` file.
    """
    candidates = list(output_dir.glob('*.meta.json'))
    candidates.sort(key=lambda meta_path: meta_path.stat().st_mtime)
    if not candidates:
        raise FileNotFoundError(f'No .meta.json found in {output_dir}')
    return candidates[-1]


def question_type_accuracy(answers_json_path: Path):
    """Compute per-question-type accuracy from an answers JSON file.

    The file is read as a JSON list of row dicts. Rows are grouped by
    ``question_type`` (``'unknown'`` when missing or null). A row counts as
    correct only when ``gpt_answer`` is present and equals ``correct_answer``.

    Returns:
        DataFrame with columns ``question_type``, ``total``, ``correct`` and
        ``accuracy`` (percent, rounded to 2 decimals), sorted by accuracy then
        total, both descending. An empty answers list yields an empty frame
        with the same columns.
    """
    with open(answers_json_path, 'r', encoding='utf-8') as f:
        rows = json.load(f)

    by_type = defaultdict(lambda: {'total': 0, 'correct': 0})
    for r in rows:
        qt = r.get('question_type') or 'unknown'
        by_type[qt]['total'] += 1
        answer = r.get('gpt_answer')
        # A missing answer must not "match" a missing key (None == None).
        if answer is not None and answer == r.get('correct_answer'):
            by_type[qt]['correct'] += 1

    columns = ['question_type', 'total', 'correct', 'accuracy']
    out = [
        {
            'question_type': qt,
            'total': s['total'],
            'correct': s['correct'],
            'accuracy': round(100.0 * s['correct'] / s['total'], 2),
        }
        for qt, s in by_type.items()
    ]
    # Explicit columns keep sort_values working when there are no rows at all.
    df = pd.DataFrame(out, columns=columns)
    return df.sort_values(['accuracy', 'total'], ascending=[False, False]).reset_index(drop=True)

In [None]:
# Draw the shared image subset that every setting below is evaluated on.
sampling_kwargs = dict(
    input_json=MMAD_JSON,
    output_json=SAMPLED_MMAD_JSON,
    n_per_folder=SAMPLE_PER_FOLDER,
    seed=SAMPLE_SEED,
    max_images=MAX_IMAGES,
)
build_sampled_mmad(**sampling_kwargs)

In [None]:
# Build the AD context JSONs (legacy and v2) for the sampled images.
import sys


def _ad_inference_cmd(output_json: Path, context_mode: str) -> str:
    """Return the scripts/run_ad_inference.py command line for one context mode.

    Only the output path and ``--context-mode`` differ between the legacy and
    v2 runs; all other arguments come from the config cell.
    """
    return ' '.join([
        # Run under the notebook kernel's interpreter rather than whatever
        # `python` happens to be first on PATH.
        shlex.quote(sys.executable),
        'scripts/run_ad_inference.py',
        '--backend ckpt',
        f'--checkpoint-dir {shlex.quote(CHECKPOINT_DIR)}',
        f'--data-root {shlex.quote(DATA_ROOT)}',
        f'--mmad-json {shlex.quote(str(SAMPLED_MMAD_JSON))}',
        f'--output {shlex.quote(str(output_json))}',
        '--output-format report',
        '--device cpu',
        f'--context-mode {shlex.quote(context_mode)}',
    ])


# Build legacy AD context JSON
cmd_legacy = _ad_inference_cmd(AD_LEGACY_JSON, 'legacy')
run_cmd(cmd_legacy)

# Build v2 AD context JSON
cmd_v2 = _ad_inference_cmd(AD_V2_JSON, 'v2')
run_cmd(cmd_v2)

In [None]:
# Run the MCQ evaluation once per setting, all on the same sampled images.
import sys


def _eval_cmd(output_dir: Path, ad_output=None) -> str:
    """Return the scripts/run_experiment.py command line for one ablation setting.

    ``ad_output=None`` produces the baseline command (``--ad-model null``).
    Otherwise the command passes ``--ad-model patchcore`` together with
    ``--ad-output`` pointing at the given AD context JSON.
    """
    parts = [
        # Run under the notebook kernel's interpreter rather than whatever
        # `python` happens to be first on PATH.
        shlex.quote(sys.executable),
        'scripts/run_experiment.py',
        f'--llm {shlex.quote(LLM_MODEL)}',
    ]
    if ad_output is None:
        parts.append('--ad-model null')
    else:
        parts += ['--ad-model patchcore', f'--ad-output {shlex.quote(str(ad_output))}']
    parts += [
        f'--data-root {shlex.quote(DATA_ROOT)}',
        f'--mmad-json {shlex.quote(str(SAMPLED_MMAD_JSON))}',
        f'--few-shot {shlex.quote(str(FEW_SHOT))}',
        f'--batch-mode {shlex.quote(str(BATCH_MODE))}',
        f'--output-dir {shlex.quote(str(output_dir))}',
    ]
    return ' '.join(parts)


# no AD
cmd_no_ad = _eval_cmd(OUT_NO_AD)
run_cmd(cmd_no_ad)

# legacy AD context
cmd_legacy_eval = _eval_cmd(OUT_LEGACY, AD_LEGACY_JSON)
run_cmd(cmd_legacy_eval)

# v2 AD context
cmd_v2_eval = _eval_cmd(OUT_V2, AD_V2_JSON)
run_cmd(cmd_v2_eval)

In [None]:
def _load_meta(output_dir: Path) -> dict:
    """Load the most recent ``*.meta.json`` in ``output_dir``, closing the file."""
    with open(latest_meta(output_dir), 'r', encoding='utf-8') as f:
        return json.load(f)


def _answers_path(meta: dict) -> Path:
    """Resolve ``meta['answers_file']`` to a path usable from the notebook.

    run_cmd runs the scripts from REPO_ROOT, so a relative path (if one is ever
    written) is relative to it. Joining with an absolute path returns that
    path unchanged.
    """
    return REPO_ROOT / meta['answers_file']


meta_no_ad = _load_meta(OUT_NO_AD)
meta_legacy = _load_meta(OUT_LEGACY)
meta_v2 = _load_meta(OUT_V2)

summary_fields = ['accuracy', 'processed', 'errors', 'answers_file']
settings = [
    ('no_ad', meta_no_ad),
    ('legacy_ad_context', meta_legacy),
    ('v2_ad_context', meta_v2),
]
summary = pd.DataFrame([
    {'setting': name, **{field: meta[field] for field in summary_fields}}
    for name, meta in settings
]).sort_values('accuracy', ascending=False).reset_index(drop=True)
display(summary)

# Per-question-type view (e.g. anomaly detection / localization questions)
df_no_ad = question_type_accuracy(_answers_path(meta_no_ad))
df_legacy = question_type_accuracy(_answers_path(meta_legacy))
df_v2 = question_type_accuracy(_answers_path(meta_v2))

# Outer merge keeps question types seen in any setting. Missing entries stay
# NaN: filling them with 0.0 would report a fake 0% accuracy and a bogus delta.
pivot = (
    df_no_ad[['question_type', 'accuracy']].rename(columns={'accuracy': 'no_ad'})
    .merge(df_legacy[['question_type', 'accuracy']].rename(columns={'accuracy': 'legacy'}), on='question_type', how='outer')
    .merge(df_v2[['question_type', 'accuracy']].rename(columns={'accuracy': 'v2'}), on='question_type', how='outer')
)
pivot['delta_v2_vs_legacy'] = (pivot['v2'] - pivot['legacy']).round(2)
pivot['delta_v2_vs_no_ad'] = (pivot['v2'] - pivot['no_ad']).round(2)
pivot = pivot.sort_values('delta_v2_vs_legacy', ascending=False, na_position='last').reset_index(drop=True)
display(pivot)

print('Work directory:', WORK_DIR)