# Stage 2 Sweep: Single Patch MIL Bags

혼합비(mix ratio)를 30%로 고정한 채 bag당 인스턴스 수를 10~50개(10 단위)로 바꿔 가며 Stage 2 데이터를 반복 생성합니다. 실험 설정을 파이썬 함수로 정리해, 한 번의 실행으로 여러 데이터 세트를 만들 수 있도록 했습니다.

In [None]:
import os
import pickle
import random
from dataclasses import dataclass, replace
from typing import Any, Dict, List, Sequence, Tuple

import numpy as np
import pandas as pd

# Bag sizes (instances per bag) swept by the generation loop: 10, 20, ..., 50.
INSTANCES_PER_BAG_GRID = list(range(10, 51, 10))
# Added to Stage2Config.seed_base so each split gets its own RNG seed.
SEED_OFFSETS = dict(train=0, val=10, test=20)


@dataclass
class Stage2Config:
    """Settings for one Stage 2 bag-generation run.

    Describes where embeddings are read from and bags are written to, how tokens
    are sampled and cut into sliding windows, how positive bags are mixed, and how
    output files are named. ``for_instances`` derives a copy for another bag size.
    """

    # ArcFace margin as it appears in the input CSV and output pickle file names.
    margin_value: str = "0.4"
    # Directory holding the mil_arcface_margin_<margin>_<split>_data.csv files.
    embedding_dir: str = "/workspace/MIL/data/processed/embeddings"
    # Output directory for the generated bag pickles.
    bags_dir: str = "/workspace/MIL/data/processed/bags"
    # Raw metadata CSV; in this notebook only its row count is reported.
    raw_meta_csv: str = "/workspace/MIL/data/raw/naver_ocr.csv"
    # Base RNG seed; each split adds its own offset.
    seed_base: int = 42
    # Sliding-window length (tokens per instance) and step between windows.
    win: int = 1
    stride: int = 1
    # Number of instances (windows) stacked into every bag.
    instances_per_bag: int = 30
    # Lower bound on partner-writer instances in a positive bag (the anchor side
    # is clamped to keep at least this many as well).
    min_partner_instances: int = 1
    # Target fraction of partner-writer instances inside a positive bag.
    mix_instance_ratio: float = 0.30
    # Tokens sampled per writer before windowing, for negative / positive bags.
    tokens_negative: int = 60
    tokens_positive: int = 60
    # Bags generated per writer, and the fraction of them that are positive.
    total_bags_per_writer: int = 20
    target_positive_ratio: float = 0.30
    # Instance order inside positive bags: "a5b5" (anchor block, then partner
    # block), "abab" (alternating) or anything else (random shuffle).
    positive_order: str = "shuffle"
    # File-name suffixes for the primary pickle and its compatibility aliases.
    baseline_suffix: str = "baseline"
    compat_suffix: str = "random"

    def for_instances(self, instances: int) -> "Stage2Config":
        """Return a copy of this config with ``instances`` instances per bag.

        Token budgets are raised (never lowered) so that a single sampled token
        sequence can yield at least ``instances`` distinct sliding windows, which
        lets bag assembly sample windows without replacement.

        Args:
            instances: new value for ``instances_per_bag``.

        Returns:
            A new ``Stage2Config``; ``self`` is left unchanged.
        """
        # T tokens give (T - win) // stride + 1 windows, so this many tokens are
        # enough for `instances` windows. For win=1, stride=1 it equals `instances`.
        needed_tokens = (instances - 1) * self.stride + self.win
        return replace(
            self,
            instances_per_bag=instances,
            tokens_negative=max(self.tokens_negative, needed_tokens),
            tokens_positive=max(self.tokens_positive, needed_tokens),
        )


# Default configuration shared by the cells below; the sweep derives per-size
# copies from it via Stage2Config.for_instances().
config: Stage2Config = Stage2Config()


def setup_environment(seed: int) -> None:
    """Pin CUDA device enumeration/visibility and seed the global RNGs.

    ``CUDA_VISIBLE_DEVICES`` is taken from ``MIL_STAGE2_GPU`` when that variable is
    set; otherwise an existing non-empty ``CUDA_VISIBLE_DEVICES`` is kept, and
    ``"4"`` is used as the last resort. Both Python's ``random`` and NumPy's legacy
    global RNG are seeded with ``seed``.
    """
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    fallback_devices = os.environ.get("CUDA_VISIBLE_DEVICES") or "4"
    os.environ["CUDA_VISIBLE_DEVICES"] = os.getenv("MIL_STAGE2_GPU", fallback_devices)
    for seeder in (np.random.seed, random.seed):
        seeder(seed)


# Apply process-wide settings (CUDA env vars, global RNG seeds), then make sure
# the output directory for the bag pickles exists.
setup_environment(seed=config.seed_base)
os.makedirs(name=config.bags_dir, exist_ok=True)


In [None]:
def pick_writer_col(df: pd.DataFrame, min_unique: int = 10) -> str:
    """Choose the column that identifies the writer (author) of each row.

    Candidate columns are tried in the order ``author_id``, ``writer_id``,
    ``writer``, ``writerID``, ``author``. The first one present with at least
    ``min_unique`` distinct values is returned. If none qualifies, ``label`` is
    used as a last resort, unless it has 5 or fewer distinct values (then it
    looks like a class label rather than a writer id).

    Args:
        df: split DataFrame to inspect.
        min_unique: minimum number of distinct writers a candidate column needs
            to be accepted. The default of 10 matches the previous hard-coded
            threshold; lower it for small splits.

    Returns:
        Name of the selected column.

    Raises:
        ValueError: if only a low-cardinality ``label`` column is available, or
            if no usable column exists at all.
    """
    candidates = ("author_id", "writer_id", "writer", "writerID", "author")
    for col in candidates:
        # A genuine writer column should distinguish at least `min_unique` writers.
        if col in df.columns and df[col].nunique() >= min_unique:
            return col
    # Last resort: fall back to 'label', but refuse if it looks like a class label.
    if "label" in df.columns:
        nunq = df["label"].nunique()
        if nunq <= 5:
            raise ValueError(
                f"'label'({nunq} uniques)로 작성자를 그룹핑하려고 합니다. "
                f"작성자 식별 컬럼(author_id/writer_id 등)을 CSV에 포함시키세요."
            )
        return "label"
    raise ValueError("작성자 식별 컬럼(author_id/writer_id 등)을 찾을 수 없습니다.")


def load_split_csv(config: Stage2Config, split: str) -> Tuple[pd.DataFrame, str, List[str]]:
    """Read one ArcFace embedding split and detect its writer and embedding columns.

    The file read is ``<embedding_dir>/mil_arcface_margin_<margin>_<split>_data.csv``.

    Returns:
        ``(df, writer_col, emb_cols)``, where ``emb_cols`` lists every column whose
        name starts with ``"embedding"``, in file order.

    Raises:
        ValueError: if no embedding columns exist, or if ``pick_writer_col``
            cannot find a writer column.
    """
    csv_name = f"mil_arcface_margin_{config.margin_value}_{split}_data.csv"
    path = os.path.join(config.embedding_dir, csv_name)
    frame = pd.read_csv(path)
    writer_column = pick_writer_col(frame)
    embedding_columns = list(filter(lambda name: name.startswith("embedding"), frame.columns))
    if len(embedding_columns) == 0:
        raise ValueError(f"No embedding_* columns found in {path}")
    return frame, writer_column, embedding_columns


def load_splits(config: Stage2Config) -> Tuple[Dict[str, pd.DataFrame], str, List[str]]:
    """Load the train/val/test embedding CSVs for ``config``.

    The writer column and embedding columns are determined from the train split
    and used for all splits. The val/test splits are therefore checked to contain
    those columns, so a schema mismatch fails here with a clear message instead of
    later during grouping or column selection.

    Returns:
        ``(frames, writer_col, emb_cols)`` with ``frames`` keyed
        ``"train"``, ``"val"``, ``"test"`` in that order.

    Raises:
        ValueError: if val/test lack the train writer column or any train
            embedding column (plus any error raised by ``load_split_csv``).
    """
    train_df, writer_col, emb_cols = load_split_csv(config, "train")
    frames: Dict[str, pd.DataFrame] = {"train": train_df}
    for split in ("val", "test"):
        df, _, _ = load_split_csv(config, split)
        if writer_col not in df.columns:
            raise ValueError(
                f"Split '{split}' has no '{writer_col}' column (writer column chosen from train)."
            )
        missing = [c for c in emb_cols if c not in df.columns]
        if missing:
            raise ValueError(
                f"Split '{split}' is missing {len(missing)} embedding column(s) present in train, "
                f"e.g. {missing[:3]}"
            )
        frames[split] = df
    return frames, writer_col, emb_cols


def build_writer_index(df: pd.DataFrame, writer_col: str, emb_cols: Sequence[str]) -> Dict[int, Dict[str, Any]]:
    """Group a split's rows by writer.

    Returns a dict keyed by ``int(writer_id)`` (in ``groupby`` order) whose values
    hold that writer's rows as:

    * ``"emb"``   - float32 array of the ``emb_cols`` values, one row per word;
    * ``"paths"`` - the ``path`` column values, or ``""`` per row if absent;
    * ``"idx"``   - the original DataFrame index labels of those rows.
    """
    path_available = "path" in df.columns
    index: Dict[int, Dict[str, Any]] = {}
    for writer_value, rows in df.groupby(writer_col):
        writer_key = int(writer_value)
        if path_available:
            path_list = rows["path"].tolist()
        else:
            path_list = [""] * len(rows)
        index[writer_key] = {
            "emb": rows[emb_cols].to_numpy(dtype=np.float32),
            "paths": path_list,
            "idx": rows.index.to_list(),
        }
    return index


def build_writer_indices(frames, writer_col: str, emb_cols: Sequence[str]):
    """Apply ``build_writer_index`` to every split DataFrame in ``frames``.

    Returns a dict with the same keys (and key order) as ``frames``.
    """
    indices = {}
    for split_name, split_df in frames.items():
        indices[split_name] = build_writer_index(split_df, writer_col, emb_cols)
    return indices


frames, writer_col, emb_cols = load_splits(config)
embed_dim = len(emb_cols)
writer_indices = build_writer_indices(frames, writer_col, emb_cols)

print("Loaded ArcFace splits:")
for split_key, df in frames.items():
    print(f"  {split_key.capitalize():5s}: {len(df)} rows")
print(f"Embedding dimension: {embed_dim}, writer column: {writer_col}")

# The raw metadata row count is informational only, so a failure here must not
# stop the pipeline, but the reason is reported instead of being swallowed.
# OSError covers missing/unreadable files; pandas parse errors (ParserError,
# EmptyDataError) and UnicodeDecodeError are ValueError subclasses.
try:
    raw_count = len(pd.read_csv(config.raw_meta_csv))
except (OSError, ValueError) as exc:
    print(f"Raw metadata not available or failed to load ({type(exc).__name__}: {exc}).")
else:
    print(f"Raw metadata rows: {raw_count}")

print("Writer counts:")
for split_key, store in writer_indices.items():
    print(f"  {split_key.capitalize():5s}: {len(store)} writers")


In [None]:
def sample_k(n: int, k: int, rng: np.random.Generator, replace_if_needed: bool = True) -> List[int]:
    """Draw ``k`` indices from ``range(n)`` with ``rng``.

    Sampling is without replacement whenever ``n >= k``. Otherwise it is with
    replacement, or, if ``replace_if_needed`` is False, a permutation of all
    ``n`` indices (so fewer than ``k`` are returned).
    """
    enough = n >= k
    if not (enough or replace_if_needed):
        return rng.choice(n, size=n, replace=False).tolist()
    return rng.choice(n, size=k, replace=not enough).tolist()


def pack(word_indices: Sequence[int], wid: int, store: Dict[str, Any]) -> List[Tuple[np.ndarray, int, str, int]]:
    """Turn positions into one writer's store into ``(embedding, wid, path, row_index)`` tuples.

    ``store`` is a per-writer entry as built by ``build_writer_index`` (keys
    ``"emb"``, ``"paths"``, ``"idx"``). The row index is converted to a Python ``int``.
    """
    embeddings, path_list, row_ids = store["emb"], store["paths"], store["idx"]
    packed: List[Tuple[np.ndarray, int, str, int]] = []
    for pos in word_indices:
        packed.append((embeddings[pos], wid, path_list[pos], int(row_ids[pos])))
    return packed


def sliding_windows(
    seq: Sequence[Tuple[np.ndarray, int, str, int]], win: int, stride: int
) -> Tuple[List[np.ndarray], List[Dict[str, Any]]]:
    """Cut ``seq`` into windows of ``win`` consecutive items, one every ``stride`` items.

    Items are ``(embedding, writer_id, path, row_index)`` tuples as produced by
    ``pack``. Returns two parallel lists: each window's embeddings stacked with
    ``np.stack`` along a new axis 0, and a metadata dict per window with its start
    offset and the row indices, paths and writer ids it covers. A sequence shorter
    than ``win`` yields ``([], [])``.
    """
    n_items = len(seq)
    if n_items < win:
        return [], []
    stacked: List[np.ndarray] = []
    info: List[Dict[str, Any]] = []
    for offset in range(0, n_items - win + 1, stride):
        window = seq[offset : offset + win]
        emb_parts = [entry[0] for entry in window]
        stacked.append(np.stack(emb_parts, axis=0))
        info.append(
            {
                "window_idx": offset,
                "word_indices": [entry[3] for entry in window],
                "word_paths": [entry[2] for entry in window],
                "writer_ids": [entry[1] for entry in window],
            }
        )
    return stacked, info


def resolve_instance_counts(instances_per_bag: int, mix_ratio: float, min_partner: int) -> Tuple[int, int]:
    """Split a bag size into ``(anchor_count, partner_count)``.

    The partner share starts at ``round(instances_per_bag * mix_ratio)`` (Python's
    round-half-to-even). It is raised to at least ``min_partner`` and capped so
    the anchor side also keeps at least ``min_partner`` instances. The two counts
    always sum to ``instances_per_bag``.

    Raises:
        ValueError: if ``min_partner`` is negative, or if ``instances_per_bag``
            is smaller than 1 or than ``2 * min_partner``. Such a size cannot give
            both writers their minimum, and would otherwise yield a partner count
            below ``min_partner`` (possibly 0 or negative), i.e. a "positive" bag
            without partner instances.
    """
    if min_partner < 0:
        raise ValueError(f"min_partner must be non-negative, got {min_partner}")
    if instances_per_bag < max(1, 2 * min_partner):
        raise ValueError(
            f"instances_per_bag={instances_per_bag} cannot hold at least "
            f"{min_partner} anchor and {min_partner} partner instance(s)"
        )
    partner = max(min_partner, int(round(instances_per_bag * mix_ratio)))
    # Keep at least `min_partner` instances for the anchor writer as well.
    partner = min(partner, instances_per_bag - min_partner)
    return instances_per_bag - partner, partner


def make_negative_bag(
    wid: int,
    store: Dict[str, Any],
    rng: np.random.Generator,
    cfg: Stage2Config,
) -> Tuple[np.ndarray, List[Dict[str, Any]], List[int]]:
    """Build one single-writer (label 0) bag for writer ``wid``.

    ``cfg.tokens_negative`` tokens are sampled (with replacement if the writer has
    fewer) and cut into sliding windows. ``cfg.instances_per_bag`` of those windows
    are then picked, without replacement when enough exist. Each instance's
    metadata is copied and tagged with ``source_writer``.

    Returns:
        ``(bag, instance_metas, [wid])``, where ``bag`` stacks the picked windows
        along a new axis 0.

    Raises:
        ValueError: if no window can be formed for this writer.
    """
    token_ids = sample_k(len(store["emb"]), cfg.tokens_negative, rng, replace_if_needed=True)
    windows, window_metas = sliding_windows(pack(token_ids, wid, store), cfg.win, cfg.stride)
    if len(windows) == 0:
        raise ValueError(f"No windows available for writer {wid}")
    need_repeats = len(windows) < cfg.instances_per_bag
    chosen = rng.choice(len(windows), size=cfg.instances_per_bag, replace=need_repeats).tolist()
    bag = np.stack([windows[j] for j in chosen], axis=0)
    instance_metas = [{**window_metas[j], "source_writer": int(wid)} for j in chosen]
    return bag, instance_metas, [int(wid)]


def make_pure_windows_for_writer(
    wid: int,
    store: Dict[str, Any],
    rng: np.random.Generator,
    tokens_to_sample: int,
    cfg: Stage2Config,
) -> Tuple[List[np.ndarray], List[Dict[str, Any]]]:
    """Sample ``tokens_to_sample`` tokens of one writer and cut them into windows.

    Tokens are drawn with replacement when the writer has fewer than requested.
    Returns the ``(windows, metas)`` pair produced by ``sliding_windows``.
    """
    n_available = len(store["emb"])
    token_order = sample_k(n_available, tokens_to_sample, rng, replace_if_needed=True)
    token_seq = pack(token_order, wid, store)
    return sliding_windows(token_seq, cfg.win, cfg.stride)


def make_positive_bag(
    wid_anchor: int,
    wid_partner: int,
    anchor_store: Dict[str, Any],
    partner_store: Dict[str, Any],
    rng: np.random.Generator,
    cfg: Stage2Config,
) -> Tuple[np.ndarray, List[Dict[str, Any]], List[int]]:
    """Build one mixed-writer (label 1) bag from an anchor and a partner writer.

    Windows are generated independently for both writers (``cfg.tokens_positive``
    tokens each). ``resolve_instance_counts`` then decides how many instances come
    from each side. The selected windows are arranged according to
    ``cfg.positive_order`` (compared case-insensitively):

    * ``"a5b5"``: all anchor instances first, then all partner instances;
    * ``"abab"``: anchor/partner alternating, leftovers appended at the end;
    * anything else: the combined list is shuffled with ``rng``.

    Every selected instance's metadata dict is copied and tagged with
    ``source_writer``. The sequence of ``rng`` calls below determines the generated
    data for a given seed, so reordering statements changes the dataset.

    Returns:
        ``(bag, instance_metas, [wid_anchor, wid_partner])``, where ``bag`` stacks
        the arranged windows along a new axis 0.

    Raises:
        ValueError: if either writer yields no windows, or if fewer than
            ``cfg.instances_per_bag`` instances were assembled.
    """
    # RNG consumption order: anchor tokens, partner tokens, anchor pick,
    # partner pick, then (for the default order) the shuffle.
    anchor_wins, anchor_meta = make_pure_windows_for_writer(
        wid_anchor, anchor_store, rng, cfg.tokens_positive, cfg
    )
    partner_wins, partner_meta = make_pure_windows_for_writer(
        wid_partner, partner_store, rng, cfg.tokens_positive, cfg
    )

    anchor_count, partner_count = resolve_instance_counts(
        cfg.instances_per_bag, cfg.mix_instance_ratio, cfg.min_partner_instances
    )

    def pick_k(
        wins: Sequence[np.ndarray],
        metas: Sequence[Dict[str, Any]],
        k: int,
    ) -> Tuple[List[np.ndarray], List[Dict[str, Any]]]:
        """Pick ``k`` (window, meta) pairs; duplicates only when fewer than ``k`` exist."""
        if not wins:
            raise ValueError("Positive bag sampling failed: empty window list")
        if len(wins) >= k:
            idx = rng.choice(len(wins), size=k, replace=False).tolist()
        else:
            idx = rng.choice(len(wins), size=k, replace=True).tolist()
        return [wins[i] for i in idx], [metas[i] for i in idx]

    anchor_sel, anchor_meta_sel = pick_k(anchor_wins, anchor_meta, anchor_count)
    partner_sel, partner_meta_sel = pick_k(partner_wins, partner_meta, partner_count)

    def annotate(meta_list: Sequence[Dict[str, Any]], wid: int) -> List[Dict[str, Any]]:
        """Return copies of ``meta_list`` tagged with ``source_writer=int(wid)``."""
        annotated: List[Dict[str, Any]] = []
        for meta in meta_list:
            # Copy so repeated picks of the same window do not share one dict.
            new_meta = dict(meta)
            new_meta["source_writer"] = int(wid)
            annotated.append(new_meta)
        return annotated

    anchor_meta_sel = annotate(anchor_meta_sel, wid_anchor)
    partner_meta_sel = annotate(partner_meta_sel, wid_partner)

    order = cfg.positive_order.lower()
    if order == "a5b5":
        # Block layout: anchor instances, then partner instances.
        seq_w = anchor_sel + partner_sel
        seq_m = anchor_meta_sel + partner_meta_sel
    elif order == "abab":
        # Alternate anchor/partner for as many pairs as possible, then append
        # whichever side has instances left over.
        seq_w: List[np.ndarray] = []
        seq_m: List[Dict[str, Any]] = []
        pairable = min(len(anchor_sel), len(partner_sel))
        for i in range(pairable):
            seq_w.extend([anchor_sel[i], partner_sel[i]])
            seq_m.extend([anchor_meta_sel[i], partner_meta_sel[i]])
        if len(anchor_sel) > pairable:
            seq_w.extend(anchor_sel[pairable:])
            seq_m.extend(anchor_meta_sel[pairable:])
        if len(partner_sel) > pairable:
            seq_w.extend(partner_sel[pairable:])
            seq_m.extend(partner_meta_sel[pairable:])
    else:
        # Default ("shuffle" or any unrecognized value): random order, keeping
        # each window paired with its metadata.
        combined = list(zip(anchor_sel + partner_sel, anchor_meta_sel + partner_meta_sel))
        rng.shuffle(combined)
        seq_w = [item[0] for item in combined]
        seq_m = [item[1] for item in combined]

    # resolve_instance_counts returns counts that sum to instances_per_bag, so the
    # truncation below is normally a no-op; these lines act as a guard.
    if len(seq_w) < cfg.instances_per_bag:
        raise ValueError("Positive bag assembly produced insufficient instances.")
    seq_w = seq_w[: cfg.instances_per_bag]
    seq_m = seq_m[: cfg.instances_per_bag]

    bag_tensor = np.stack(seq_w, axis=0)
    return bag_tensor, seq_m, [int(wid_anchor), int(wid_partner)]


def generate_split(
    name: str,
    writer_index: Dict[int, Dict[str, Any]],
    cfg: Stage2Config,
    neg_per_writer: int,
    pos_per_writer: int,
    embed_dim: int,
    seed: int,
) -> Tuple[List[np.ndarray], List[int], List[Dict[str, Any]]]:
    """Generate, shuffle and validate every bag of one split.

    For each writer in ``writer_index``, ``neg_per_writer`` single-writer bags
    (label 0) are built first. Then ``pos_per_writer`` mixed bags (label 1) are
    built with the writer as anchor and a partner drawn uniformly from the other
    writers. All randomness comes from one ``np.random.default_rng(seed)``, so a
    split is reproducible for a fixed seed and writer order.

    Args:
        name: split name, used only in error messages.
        writer_index: per-writer stores as built by ``build_writer_index``.
        cfg: generation settings.
        neg_per_writer / pos_per_writer: bags of each label per writer.
        embed_dim: expected embedding dimension of every instance.
        seed: seed for this split's generator.

    Returns:
        ``(bags, labels, metadata)`` as parallel lists in a shuffled order.

    Raises:
        ValueError: if no bags were produced, if positive bags are requested with
            fewer than two writers, or if any bag's shape differs from
            ``(cfg.instances_per_bag, cfg.win, embed_dim)``.
    """
    rng = np.random.default_rng(seed)
    writer_ids = list(writer_index.keys())

    bags: List[np.ndarray] = []
    labels: List[int] = []
    metadata: List[Dict[str, Any]] = []

    for wid in writer_ids:
        for _ in range(neg_per_writer):
            bag, metas, authors = make_negative_bag(wid, writer_index[wid], rng, cfg)
            bags.append(bag)
            labels.append(0)
            metadata.append({
                "authors": authors,
                "bag_type": "negative",
                "instances": metas,
                "anchor_writer_id": int(wid),
                "partner_writer_id": None,
                "b_writer": None,
                "partner_instance_count": 0,
                "partner_instance_ratio": 0.0,
                "mix_ratio_target": 0.0,
            })

    for wid_anchor in writer_ids:
        # Depends only on the anchor: build it once per anchor, not once per bag.
        partner_candidates = [w for w in writer_ids if w != wid_anchor]
        for _ in range(pos_per_writer):
            if not partner_candidates:
                raise ValueError("Positive bag requires at least two writers.")
            wid_partner = int(rng.choice(partner_candidates))
            bag, metas, authors = make_positive_bag(
                wid_anchor,
                wid_partner,
                writer_index[wid_anchor],
                writer_index[wid_partner],
                rng,
                cfg,
            )
            partner_writer = int(wid_partner)
            # Realized (not target) share of partner instances in this bag.
            partner_instances = sum(1 for meta in metas if meta.get("source_writer") == partner_writer)
            partner_ratio = partner_instances / len(metas) if metas else 0.0
            bags.append(bag)
            labels.append(1)
            metadata.append({
                "authors": authors,
                "bag_type": "positive",
                "instances": metas,
                "anchor_writer_id": int(wid_anchor),
                "partner_writer_id": partner_writer,
                "b_writer": partner_writer,
                "partner_instance_count": partner_instances,
                "partner_instance_ratio": partner_ratio,
                "mix_ratio_target": cfg.mix_instance_ratio,
            })

    if not bags:
        raise ValueError(f"No bags generated for split {name}.")

    # Shuffle bags, labels and metadata with one shared permutation.
    idx = rng.permutation(len(labels))
    bags = [bags[i] for i in idx]
    labels = [int(labels[i]) for i in idx]
    metadata = [metadata[i] for i in idx]

    # Validate every bag (not only the first) so downstream stacking cannot fail
    # on a stray shape.
    expected_shape = (cfg.instances_per_bag, cfg.win, embed_dim)
    for position, bag in enumerate(bags):
        if bag.shape != expected_shape:
            raise ValueError(
                f"Unexpected bag shape {bag.shape}, expected {expected_shape} "
                f"(split {name}, bag {position})"
            )

    return bags, labels, metadata


In [None]:
def format_ratio_tag(cfg: Stage2Config) -> str:
    """Build the ``<mix%>p_single<instances>`` tag used in output file names.

    Both numbers are zero-padded to two digits, e.g. ``"30p_single10"``.
    """
    mix_percent = int(round(cfg.mix_instance_ratio * 100))
    bag_size = cfg.instances_per_bag
    return f"{mix_percent:02d}p_single{bag_size:02d}"


def save_split(
    split_key: str,
    payload: Dict[str, Any],
    cfg: Stage2Config,
    ratio_tag: str,
    plain_tag: str,
) -> Tuple[str, str, str]:
    """Pickle one split under its baseline file name plus two compatibility aliases.

    File names follow ``bags_arcface_margin_<margin>_<tag>_<suffix>_<split>.pkl``:

    * base:  ``ratio_tag`` + ``cfg.baseline_suffix``
    * alias: ``ratio_tag`` + ``cfg.compat_suffix``
    * plain: ``plain_tag`` + ``cfg.compat_suffix``

    The payload is pickled once into the base file; the aliases are byte copies
    of it. When two names coincide (e.g. ``plain_tag == ratio_tag``), that file is
    written only once instead of the payload being re-pickled.

    Returns:
        ``(base_path, alias_path, plain_path)``; entries may be equal when names coincide.
    """
    import shutil  # only needed for the alias copies

    prefix = f"bags_arcface_margin_{cfg.margin_value}"
    base_name = f"{prefix}_{ratio_tag}_{cfg.baseline_suffix}_{split_key}.pkl"
    alias_name = f"{prefix}_{ratio_tag}_{cfg.compat_suffix}_{split_key}.pkl"
    plain_name = f"{prefix}_{plain_tag}_{cfg.compat_suffix}_{split_key}.pkl"

    base_path = os.path.join(cfg.bags_dir, base_name)
    alias_path = os.path.join(cfg.bags_dir, alias_name)
    plain_path = os.path.join(cfg.bags_dir, plain_name)

    with open(base_path, "wb") as f:
        pickle.dump(payload, f)
    # Copy the already-serialized file instead of pickling the (potentially large)
    # payload again, and skip targets identical to a path already written.
    written = {base_path}
    for target in (alias_path, plain_path):
        if target not in written:
            shutil.copyfile(base_path, target)
            written.add(target)

    print(f"    saved {base_name}")
    print(f"    alias {alias_name}")
    print(f"    alias {plain_name}")
    return base_path, alias_path, plain_path


def summarize_labels(split_key: str, labels: Sequence[int]) -> None:
    """Print the bag count, positive count/percentage and negative count of a split."""
    n_total = len(labels)
    n_pos = sum(labels)
    pct = 0.0 if not n_total else n_pos / n_total * 100
    n_neg = n_total - n_pos
    print(f"    {split_key.capitalize():5s}: N={n_total}, Pos={n_pos} ({pct:.1f}%), Neg={n_neg}")


def run_stage2_generation(cfg: Stage2Config, writer_indices, embed_dim: int) -> Dict[str, Tuple[str, str, str]]:
    """Generate, save and summarize the train/val/test bags for one configuration.

    Each writer gets ``cfg.total_bags_per_writer`` bags, of which
    ``round(total * target_positive_ratio)`` are positive. Split ``s`` is seeded
    with ``cfg.seed_base + SEED_OFFSETS[s]``. All splits are generated before any
    file is written.

    Returns:
        Mapping of split name to the three paths returned by ``save_split``.

    Raises:
        ValueError: if the per-writer budget leaves no positive or no negative bags.
    """
    n_pos = int(round(cfg.total_bags_per_writer * cfg.target_positive_ratio))
    n_neg = cfg.total_bags_per_writer - n_pos
    if min(n_pos, n_neg) <= 0:
        raise ValueError("Adjust total_bags_per_writer/target_positive_ratio to include both bag types.")

    ratio_tag = format_ratio_tag(cfg)
    mix_pct = int(round(cfg.mix_instance_ratio * 100))
    plain_tag = f"{mix_pct:02d}p_single{cfg.instances_per_bag:02d}"
    pos_share = n_pos / (n_pos + n_neg)

    print(f"\n=== Stage 2 single-patch generation (instances={cfg.instances_per_bag}) ===")
    print(f"  mix_ratio={cfg.mix_instance_ratio:.2f} → ratio_tag={ratio_tag}")
    print(f"  bags per writer: neg={n_neg}, pos={n_pos} (share {pos_share * 100:.1f}%)")

    generated: Dict[str, Dict[str, Any]] = {}
    for split in ("train", "val", "test"):
        writers_for_split = writer_indices[split]
        split_bags, split_labels, split_meta = generate_split(
            split.capitalize(),
            writers_for_split,
            cfg,
            n_neg,
            n_pos,
            embed_dim,
            cfg.seed_base + SEED_OFFSETS[split],
        )
        generated[split] = {"bags": split_bags, "labels": split_labels, "metadata": split_meta}
        print(f"  {split.capitalize():5s} split → {len(split_labels)} bags, sample_shape={split_bags[0].shape}")

    print("  saving splits...")
    saved_paths: Dict[str, Tuple[str, str, str]] = {
        split: save_split(split, payload, cfg, ratio_tag, plain_tag)
        for split, payload in generated.items()
    }

    print("  label summary:")
    for split, payload in generated.items():
        summarize_labels(split, payload["labels"])

    return saved_paths


# Sweep over bag sizes; every run writes its own set of train/val/test pickles.
for instances in INSTANCES_PER_BAG_GRID:
    inst_cfg = config.for_instances(instances=instances)
    run_stage2_generation(cfg=inst_cfg, writer_indices=writer_indices, embed_dim=embed_dim)
