# MBFC URL-masked TAG2RISK tag-field ablations (v6)

This notebook runs **tag-field ablations** for the TAG2RISK (tag-only) logistic regression model, using the **same URL-masked (domain-disjoint) evaluation protocol as `mbfc_url_masked_logreg_v6`**.

We keep everything fixed (splits, hyperparameter search, threshold tuning, metrics) and only change which **tag fields** are included in the multi-hot vector:

- `Theme` (`theme=...`)
- `Claim/Framing` (`claim=...`)
- `CTA` (`cta=...`)
- `Evidence` (`evid=...`)

Note: this matches **v6** behavior (drops `Link/URL` evidence tags from the style vector to avoid URL leakage).


In [None]:
from __future__ import annotations

from pathlib import Path

import numpy as np
import pandas as pd
from scipy.special import expit
from scipy import sparse
from sklearn.metrics import (
    accuracy_score,
    brier_score_loss,
    f1_score,
    recall_score,
    roc_auc_score,
)
from sklearn.model_selection import GroupShuffleSplit, train_test_split
from sklearn.preprocessing import MultiLabelBinarizer


# Locate the dataset CSV by walking from `start` up through its ancestors.
def _resolve_data_path(start: Path) -> Path:
    """Return the dataset CSV found at ``start`` or the nearest ancestor.

    Each directory from ``start`` upwards is probed for the fixed relative
    location of ``mega_samples_dedup_qwen_mbfc.csv``; the first existing hit
    wins.

    Raises:
        FileNotFoundError: if no directory on the way up contains the file.
    """
    # NOTE: the "MBFC " component (trailing space included) is reproduced
    # exactly as the lookup has always used it.
    relative_csv = Path(
        "mbfc_channel_masked_logreg_fullpackage_v2_MBFC_C",
        "MBFC ",
        "mega_samples_dedup_qwen_mbfc.csv",
    )
    probes = (root / relative_csv for root in (start, *start.parents))
    found = next((probe for probe in probes if probe.exists()), None)
    if found is None:
        raise FileNotFoundError(
            "Could not locate mega_samples_dedup_qwen_mbfc.csv from current working directory"
        )
    return found

# Load the MBFC sample table, clean the message text, derive the binary label
# `y`, and build `df_eval` (rows usable for the domain-disjoint evaluation).
DATA_PATH = _resolve_data_path(Path.cwd().resolve())
print({"data_path": str(DATA_PATH)})

df = pd.read_csv(DATA_PATH, low_memory=False)

# Fix legacy header quirk in the v2 MBFC CSV where the first column name was mangled.
if "source" not in df.columns:
    first_col = df.columns[0]
    df = df.rename(columns={first_col: "source"})

# Strip URLs from message text at load time so the pipeline never sees raw URLs.
# NOTE(review): the `https?://` alternative matches only the scheme itself, so
# for "https://example.com/x" just "https://" is replaced and "example.com/x"
# stays in the text (the separate `http://` alternative is redundant). Only the
# `www.` and `t.me/` alternatives consume up to the next whitespace. Widening
# the pattern (e.g. `https?://\S+`) would also change which rows become empty
# and are dropped below, i.e. the evaluation rows/splits -- confirm against the
# v6 notebook before changing it so the baselines stay comparable.
# NOTE(review): `.astype(str)` runs before the emptiness filter, so a missing
# message presumably becomes the literal string "nan" and is kept -- verify.
df["message"] = (
    df["message"]
    .astype(str)
    .str.replace(r"(https?://|http://|www\.[^\s]*|t\.me/[^\s]*)", " ", regex=True)
    .str.strip()
)
# Drop rows whose message becomes empty after URL stripping.
df = df[df["message"] != ""].copy()

# Use MBFC-derived binary risk label as y (1 = higher-risk / lower-credibility)
df = df.dropna(subset=["risk_label"]).copy()
df["y"] = df["risk_label"].astype(int)

# For URL-masked evaluation we only keep rows that have a resolved normalized_domain.
# `normalized_domain` is later used as the grouping key for the train/test split.
df_eval = df.dropna(subset=["normalized_domain"]).copy()
# Quick sanity summary: total labelled rows, rows kept for evaluation, number of
# distinct domains (split groups) and the positive-class rate within df_eval.
print(
    {
        "rows": int(len(df)),
        "eval_rows": int(len(df_eval)),
        "unique_domains": int(df_eval["normalized_domain"].nunique()),
        "pos_rate": float(df_eval["y"].mean()),
    }
)


In [None]:
from typing import List, Optional


# v6: Qwen 'Link/URL' labels are removed from the style features.
# DROP_LINK_URL_LABEL toggles the filter; _LINK_URL_LABEL_NORM is the label in
# normalised form (lower-case, all whitespace removed).
DROP_LINK_URL_LABEL = True
_LINK_URL_LABEL_NORM = "link/url"


def tokenize_multi(value: object) -> List[str]:
    """Break a multi-label cell into its individual, stripped labels.

    ``+`` and ``,`` are both treated as separators. Non-string input yields an
    empty list, and blank pieces are discarded. When ``DROP_LINK_URL_LABEL``
    is set, any piece whose normalised form equals ``link/url`` is dropped,
    so spellings such as ``'Link / URL'`` are filtered as well.
    """
    if not isinstance(value, str):
        return []

    kept: List[str] = []
    for chunk in value.replace("+", ",").split(","):
        piece = chunk.strip()
        if not piece:
            continue
        # Compare on a whitespace-free, lower-cased copy of the piece.
        if DROP_LINK_URL_LABEL and "".join(piece.lower().split()) == _LINK_URL_LABEL_NORM:
            continue
        kept.append(piece)
    return kept


# Canonical theme buckets. `_norm_theme` passes a value through unchanged when
# it is exactly one of these strings; otherwise it maps the raw theme to one of
# them by keyword, with "Other theme" as the fallback.
THEME_BUCKETS = [
    "Finance/Crypto",
    "Public health & medicine",
    "Politics",
    "Lifestyle & well-being",
    "Crime & public safety",
    "Gaming/Gambling",
    "News/Information",
    "Sports",
    "Technology",
    "Conversation/Chat/Other",
    "Other theme",
]

# Canonical claim/framing buckets used by `_norm_claim_labels`.
# "No substantive claim" has no keyword rule there, so it is only produced
# when the raw label matches it exactly. "Statistics" and
# "None / assertion only" also appear in EVID_BUCKETS; the emitted tokens are
# kept apart by their `claim=` / `evid=` prefixes.
CLAIM_BUCKETS = [
    "Verifiable factual statement",
    "Rumour / unverified report",
    "Announcement",
    "Opinion / subjective statement",
    "Misleading context / cherry-picking",
    "Promotional hype / exaggerated profit guarantee",
    "Emotional appeal / fear-mongering",
    "Scarcity/FOMO tactic",
    "Statistics",
    "Other claim type",
    "No substantive claim",
    "Fake content",
    "Speculative forecast / prediction",
    "None / assertion only",
]

# Canonical call-to-action buckets used by `_norm_cta_labels`
# ("Other CTA" is its fallback).
CTA_BUCKETS = [
    "Visit external link / watch video",
    "Engage/Ask questions",
    "Join/Subscribe",
    "Buy / invest / donate",
    "Attend event / livestream",
    "Share / repost / like",
    "No CTA",
    "Other CTA",
]

# Canonical evidence buckets used by `_norm_evidence_labels`.
# "Link/URL" is listed so it is recognised as a known label, but that function
# never emits it (v6 drops it to avoid URL leakage).
EVID_BUCKETS = [
    "Link/URL",
    "Statistics",
    "Quotes/Testimony",
    "Chart / price graph / TA diagram",
    "Other (Evidence)",
    "None / assertion only",
]


# Ordered keyword rules for `_norm_theme`: the first rule with any needle that
# occurs as a substring of the lower-cased theme decides the bucket.
# NOTE(review): matching is plain substring containment and rule order matters,
# e.g. "software" contains "war" and therefore hits the Politics rule before
# the Technology rule is tried. The needles "app " and "ai " end in a space, so
# they cannot match at the very end of the (already stripped) theme string.
# Left as-is to keep features identical to the existing protocol.
_THEME_KEYWORD_RULES = (
    (
        "Finance/Crypto",
        (
            "crypto",
            "token",
            "coin",
            "airdrop",
            "ido",
            "staking",
            "defi",
            "exchange",
            "market",
            "finance",
            "econom",
        ),
    ),
    (
        "Public health & medicine",
        (
            "health",
            "covid",
            "vaccine",
            "vaccination",
            "medicine",
            "medical",
            "clinical",
            "disease",
            "pandemic",
            "public health",
            "hospital",
        ),
    ),
    (
        "Politics",
        (
            "politic",
            "election",
            "parliament",
            "congress",
            "senate",
            "government",
            "president",
            "minister",
            "policy",
            "war",
            "conflict",
            "ukraine",
            "russia",
        ),
    ),
    (
        "Crime & public safety",
        (
            "crime",
            "criminal",
            "terror",
            "shooting",
            "police",
            "public safety",
            "fraud",
            "scam",
        ),
    ),
    (
        "Gaming/Gambling",
        ("gaming", "gambling", "casino", "betting", "lottery", "poker"),
    ),
    (
        "Sports",
        ("sport", "football", "soccer", "basketball", "tennis", "nba", "nfl"),
    ),
    (
        "Technology",
        (
            "technology",
            "tech",
            "software",
            "app ",
            "platform",
            "ai ",
            " a.i.",
            "machine learning",
            "blockchain",
            "internet",
            "social media",
            "algorithm",
            "science",
            "research",
            "study",
        ),
    ),
    (
        "Lifestyle & well-being",
        (
            "lifestyle",
            "well-being",
            "wellbeing",
            "culture",
            "entertainment",
            "media",
            "celebrity",
            "social issues",
            "society",
            "family",
            "community",
        ),
    ),
    (
        "News/Information",
        ("news", "headline", "breaking", "coverage", "roundup", "update"),
    ),
    (
        "Conversation/Chat/Other",
        ("comment", "conversation", "chat", "q&a", "ama", "ask me anything"),
    ),
)


def _norm_theme(raw: object) -> Optional[str]:
    """Map a raw theme value onto one of the canonical theme buckets.

    Returns ``None`` for non-string or blank input. A stripped value that is
    already a member of ``THEME_BUCKETS`` is returned unchanged. Otherwise the
    lower-cased text is checked against ``_THEME_KEYWORD_RULES`` in order, and
    ``"Other theme"`` is returned when no rule matches.
    """
    if not isinstance(raw, str):
        return None
    cleaned = raw.strip()
    if not cleaned:
        return None

    # Exact (case-sensitive) bucket names short-circuit the keyword rules.
    if cleaned in THEME_BUCKETS:
        return cleaned

    lowered = cleaned.lower()
    for bucket, needles in _THEME_KEYWORD_RULES:
        for needle in needles:
            if needle in lowered:
                return bucket
    return "Other theme"


# Ordered (needles, bucket) rules for `_norm_claim_labels`; the first rule with
# a needle contained in the lower-cased label wins.
_CLAIM_KEYWORD_RULES = (
    (("verifiable", "factual"), "Verifiable factual statement"),
    (("rumour", "unverified"), "Rumour / unverified report"),
    (("misleading context", "cherry"), "Misleading context / cherry-picking"),
    (
        ("promotional hype", "exaggerated profit"),
        "Promotional hype / exaggerated profit guarantee",
    ),
    (
        ("emotional appeal", "fear-mongering", "fear mongering"),
        "Emotional appeal / fear-mongering",
    ),
    (("scarcity", "fomo"), "Scarcity/FOMO tactic"),
    (("statistic",), "Statistics"),
    (("fake content", "fabricated"), "Fake content"),
    (("predict", "forecast"), "Speculative forecast / prediction"),
    (("announcement",), "Announcement"),
    (
        ("opinion", "interpretive", "analysis", "review"),
        "Opinion / subjective statement",
    ),
    (("none / assertion only", "assertion only"), "None / assertion only"),
)


def _norm_claim_labels(raw: object) -> List[str]:
    """Normalise a raw claim-type cell into canonical claim buckets.

    The cell is split with ``tokenize_multi``. Each piece that is already in
    ``CLAIM_BUCKETS`` is kept verbatim; any other piece is mapped by the first
    matching keyword rule, falling back to ``"Other claim type"``. The result
    keeps first-seen order and contains no duplicates.
    """
    mapped: List[str] = []
    for piece in tokenize_multi(raw):
        name = piece.strip()
        if not name:
            continue
        if name in CLAIM_BUCKETS:
            mapped.append(name)
            continue

        lowered = name.lower()
        bucket = next(
            (
                target
                for needles, target in _CLAIM_KEYWORD_RULES
                if any(needle in lowered for needle in needles)
            ),
            "Other claim type",
        )
        mapped.append(bucket)

    # dict keys preserve insertion order, giving an order-stable de-duplication.
    return list(dict.fromkeys(mapped))


# Ordered (needles, bucket) rules for `_norm_cta_labels`; the first rule with a
# needle contained in the lower-cased label wins.
_CTA_KEYWORD_RULES = (
    (("engage", "ask", "anything"), "Engage/Ask questions"),
    (("attend", "event", "livestream", "live stream"), "Attend event / livestream"),
    (("join", "subscribe", "follow", "whitelist"), "Join/Subscribe"),
    (("buy", "invest", "donate", "stake", "swap"), "Buy / invest / donate"),
    (("share", "repost", "like"), "Share / repost / like"),
    (
        ("visit", "read", "watch", "link", "website", "check", "view charts"),
        "Visit external link / watch video",
    ),
)


def _norm_cta_labels(raw: object) -> List[str]:
    """Normalise a raw call-to-action cell into canonical CTA buckets.

    The cell is split with ``tokenize_multi``. Pieces already in
    ``CTA_BUCKETS`` are kept verbatim. The exact labels ``"None"`` / ``"No
    CTA"`` and anything containing ``"no cta"`` become ``"No CTA"``. Remaining
    pieces go through the keyword rules in order, with ``"Other CTA"`` as the
    fallback. The result keeps first-seen order and contains no duplicates.
    """
    mapped: List[str] = []
    for piece in tokenize_multi(raw):
        name = piece.strip()
        if not name:
            continue
        if name in CTA_BUCKETS:
            mapped.append(name)
            continue

        lowered = name.lower()
        # The explicit "no call to action" spellings take precedence over the
        # keyword rules.
        if name in {"None", "No CTA"} or "no cta" in lowered:
            mapped.append("No CTA")
            continue

        bucket = next(
            (
                target
                for needles, target in _CTA_KEYWORD_RULES
                if any(needle in lowered for needle in needles)
            ),
            "Other CTA",
        )
        mapped.append(bucket)

    # dict keys preserve insertion order, giving an order-stable de-duplication.
    return list(dict.fromkeys(mapped))


# Ordered (needles, bucket) rules for `_norm_evidence_labels`, applied after
# link/URL-like labels have been discarded.
_EVID_KEYWORD_RULES = (
    (("statistic",), "Statistics"),
    (("quote", "testimony"), "Quotes/Testimony"),
    (("chart", "graph", "diagram"), "Chart / price graph / TA diagram"),
    (("none / assertion only", "assertion only"), "None / assertion only"),
)


def _norm_evidence_labels(raw: object) -> List[str]:
    """Normalise a raw evidence cell into canonical evidence buckets.

    The cell is split with ``tokenize_multi``. Pieces already in
    ``EVID_BUCKETS`` are kept verbatim, except ``"Link/URL"`` which is never
    emitted. Any other piece mentioning ``link`` or ``url`` is discarded
    (v6: no Link/URL evidence tag). The rest go through the keyword rules with
    ``"Other (Evidence)"`` as the fallback. The result keeps first-seen order
    and contains no duplicates.
    """
    mapped: List[str] = []
    for piece in tokenize_multi(raw):
        name = piece.strip()
        if not name:
            continue

        if name in EVID_BUCKETS:
            if name != "Link/URL":
                mapped.append(name)
            continue

        lowered = name.lower()
        # v6: drop Link/URL evidence tag entirely. Any label containing
        # "link/url" also contains "link", so two checks cover all cases.
        if "link" in lowered or "url" in lowered:
            continue

        bucket = next(
            (
                target
                for needles, target in _EVID_KEYWORD_RULES
                if any(needle in lowered for needle in needles)
            ),
            "Other (Evidence)",
        )
        mapped.append(bucket)

    # dict keys preserve insertion order, giving an order-stable de-duplication.
    return list(dict.fromkeys(mapped))


def build_style_tokens(row: pd.Series, include_fields: set[str]) -> List[str]:
    """Build the prefixed tag tokens for one row, limited to the chosen fields.

    Args:
        row: mapping-like record; read via ``row.get`` for the columns
            ``theme``, ``claim_types``, ``ctas`` and ``evidence``.
        include_fields: any subset of ``{"theme", "claim", "cta", "evidence"}``;
            fields not listed contribute no tokens.

    Returns:
        Tokens in the fixed order theme, claim, cta, evidence, formatted as
        ``theme=...``, ``claim=...``, ``cta=...`` and ``evid=...``.
    """
    out: List[str] = []

    if "theme" in include_fields:
        bucket = _norm_theme(row.get("theme"))
        # A missing/blank theme normalises to None and adds nothing.
        if bucket is not None:
            out.append(f"theme={bucket}")

    if "claim" in include_fields:
        out.extend(f"claim={name}" for name in _norm_claim_labels(row.get("claim_types")))

    if "cta" in include_fields:
        out.extend(f"cta={name}" for name in _norm_cta_labels(row.get("ctas")))

    if "evidence" in include_fields:
        out.extend(f"evid={name}" for name in _norm_evidence_labels(row.get("evidence")))

    return out


def build_style_features_for_fields(train_df, val_df, test_df, include_fields: set[str]):
    """Encode the selected tag fields as multi-hot matrices for all three splits.

    The ``MultiLabelBinarizer`` vocabulary is fitted on the training tokens
    only and then applied to the validation and test tokens.

    Returns:
        ``(X_train, X_val, X_test, mlb)`` where ``mlb`` is the fitted binarizer.
    """

    def tokens_of(frame):
        # One token list per row, restricted to the requested fields.
        return frame.apply(lambda record: build_style_tokens(record, include_fields), axis=1)

    binarizer = MultiLabelBinarizer()
    X_train = binarizer.fit_transform(tokens_of(train_df))
    X_val = binarizer.transform(tokens_of(val_df))
    X_test = binarizer.transform(tokens_of(test_df))
    return X_train, X_val, X_test, binarizer


In [None]:
# Training/eval utilities (copied from v6)

# One domain-grouped train/test split is drawn per seed in the run loop.
RANDOM_STATES = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
# Gradient-descent step sizes tried by `fit_with_val_search`.
LR_GRID = [0.001, 0.003, 0.01]
# Decision thresholds 0.05, 0.10, ..., 0.95 (rounded to 2 decimals) swept on
# the validation split.
THRESH_GRID = [round(t, 2) for t in np.linspace(0.05, 0.95, 19)]
# Key of the `sweep_thresholds` row used to pick the best threshold / lr.
PRIMARY_METRIC = "macro_f1"


class ManualLogisticRegression:
    """Binary logistic regression fitted with full-batch gradient descent.

    The objective is the (optionally class-weighted) mean log-loss plus an L2
    penalty ``0.5 * ||w||^2 / (C * n_samples)``; the intercept is not
    penalised. Dense arrays and SciPy sparse matrices are both accepted.

    Parameters:
        lr: gradient-descent step size.
        max_iter: maximum number of gradient steps.
        C: inverse regularisation strength.
        class_weight: ``None`` (uniform), ``"balanced"``
            (``n_samples / (n_classes * class_count)``) or a dict mapping a
            label to its weight (labels missing from the dict get 1.0).
        tol: when not ``None``, the loss is evaluated every 10th step (and on
            the last step) and training stops once two consecutive evaluations
            differ by less than ``tol``. ``None`` disables both the loss
            tracking and the early stop.
        verbose: print the loss at each evaluation.
        n_jobs: kept for API compatibility; not used.

    Attributes set by ``fit``: ``coef_``, ``intercept_``, ``loss_history_``
    and ``classes_``.
    """

    def __init__(
        self,
        lr: float = 0.1,
        max_iter: int = 200,
        C: float = 1.0,
        class_weight=None,
        tol: float | None = 1e-4,
        verbose: bool = False,
        n_jobs=None,  # kept for API compatibility; not used
    ):
        self.lr = lr
        self.max_iter = max_iter
        self.C = C
        self.class_weight = class_weight
        self.tol = tol
        self.verbose = verbose
        self.n_jobs = n_jobs

    def _prepare_X(self, X):
        """Return ``X`` as CSR when sparse, otherwise as a float ndarray."""
        return X.tocsr() if sparse.issparse(X) else np.asarray(X, dtype=float)

    def _sample_weights(self, y, n_samples):
        """Expand the ``class_weight`` setting into one weight per sample."""
        if self.class_weight is None:
            return np.ones_like(y)
        if self.class_weight == "balanced":
            labels, counts = np.unique(y, return_counts=True)
            n_classes = len(labels)
            per_class = {
                label: n_samples / (n_classes * count) for label, count in zip(labels, counts)
            }
            return np.array([per_class[yi] for yi in y], dtype=float)
        if isinstance(self.class_weight, dict):
            return np.array([self.class_weight.get(yi, 1.0) for yi in y], dtype=float)
        raise ValueError("Unsupported class_weight specification")

    def _penalised_loss(self, X, y, sample_weights, n_samples):
        """Weighted mean log-loss at the current parameters plus the L2 term."""
        p = expit(X.dot(self.coef_) + self.intercept_)
        eps = 1e-15  # keeps log() finite when p saturates at 0 or 1
        per_sample = (-(y * np.log(p + eps) + (1 - y) * np.log(1 - p + eps))) * sample_weights
        return per_sample.mean() + 0.5 * np.sum(self.coef_**2) / (self.C * n_samples)

    def fit(self, X, y):
        """Run gradient descent from zero-initialised parameters; return self."""
        X = self._prepare_X(X)
        y = np.asarray(y, dtype=float)
        n_samples, n_features = X.shape
        self.coef_ = np.zeros(n_features, dtype=float)
        self.intercept_ = 0.0
        self.loss_history_ = []

        sample_weights = self._sample_weights(y, n_samples)

        last_loss = None
        final_step = self.max_iter - 1
        for step in range(self.max_iter):
            probs = expit(X.dot(self.coef_) + self.intercept_)
            weighted_err = (probs - y) * sample_weights

            # Gradient of the weighted mean log-loss, then the L2 contribution.
            grad_w = X.T.dot(weighted_err) / n_samples
            grad_w += self.coef_ / (self.C * n_samples)
            grad_b = weighted_err.mean()

            self.coef_ -= self.lr * grad_w
            self.intercept_ -= self.lr * grad_b

            # Loss is only tracked when a tolerance is configured, and only on
            # every 10th step plus the final one.
            if self.tol is None:
                continue
            if step % 10 != 0 and step != final_step:
                continue

            loss = self._penalised_loss(X, y, sample_weights, n_samples)
            self.loss_history_.append(float(loss))
            if self.verbose:
                print(f"Iter {step}: loss={loss:.6f}")
            if last_loss is not None and abs(last_loss - loss) < self.tol:
                break
            last_loss = loss

        self.classes_ = np.array([0.0, 1.0])
        return self

    def _decision_function(self, X):
        """Linear score ``X @ coef_ + intercept_``."""
        return self._prepare_X(X).dot(self.coef_) + self.intercept_

    def predict_proba(self, X):
        """Return an ``(n_samples, 2)`` array of ``[P(y=0), P(y=1)]``."""
        positive = expit(self._decision_function(X))
        return np.vstack([1 - positive, positive]).T

    def predict(self, X):
        """Predict 1 where ``P(y=1) >= 0.5`` and 0 otherwise."""
        positive = self.predict_proba(X)[:, 1]
        return (positive >= 0.5).astype(int)


def _stack_features(a, b):
    """Stack ``b`` below ``a``, staying sparse when ``a`` is a sparse matrix."""
    stacker = sparse.vstack if sparse.issparse(a) else np.vstack
    return stacker([a, b])


def sweep_thresholds(y_true, proba):
    """Score every threshold in ``THRESH_GRID`` and return the best row.

    For each threshold ``t`` the prediction is ``proba >= t``. The returned
    dict holds ``threshold``, ``macro_f1``, ``macro_recall`` and ``recall_pos``
    for the threshold that maximises ``PRIMARY_METRIC``; on ties the lowest
    threshold in grid order is kept.
    """

    def score_at(t):
        pred = (proba >= t).astype(int)
        return {
            "threshold": float(t),
            "macro_f1": f1_score(y_true, pred, average="macro"),
            "macro_recall": recall_score(y_true, pred, average="macro"),
            "recall_pos": recall_score(y_true, pred, pos_label=1),
        }

    # max() keeps the first maximal element, i.e. the earliest grid entry.
    return max((score_at(t) for t in THRESH_GRID), key=lambda row: row[PRIMARY_METRIC])


def expected_calibration_error(y_true, y_proba, n_bins=10):
    """Expected calibration error over ``n_bins`` equal-width probability bins.

    Each sample is assigned to a bin by its predicted probability; the ECE is
    the sample-weighted mean of ``|mean(probability) - mean(label)|`` over the
    non-empty bins.

    Args:
        y_true: binary labels (0/1), array-like of length n.
        y_proba: predicted probabilities of the positive class, same length.
        n_bins: number of equal-width bins on [0, 1].

    Returns:
        The ECE as a float; 0.0 for empty input.
    """
    y_true = np.asarray(y_true, dtype=float)
    y_proba = np.asarray(y_proba, dtype=float)
    n = len(y_true)
    if n == 0:
        return 0.0

    bins = np.linspace(0.0, 1.0, n_bins + 1)
    # np.digitize uses half-open bins [lo, hi), so a probability of exactly 1.0
    # lands one past the last bin. Clip it into the last bin; otherwise those
    # samples are silently excluded while still being counted in `n`.
    indices = np.clip(np.digitize(y_proba, bins) - 1, 0, n_bins - 1)

    ece = 0.0
    for b in range(n_bins):
        mask = indices == b
        if not np.any(mask):
            continue
        p_bin = y_proba[mask].mean()
        y_bin = y_true[mask].mean()
        weight = mask.sum() / n
        ece += weight * abs(p_bin - y_bin)
    return float(ece)


def fit_with_val_search(X_train, y_train, X_val, y_val, X_test, y_test, lr_grid=None):
    """Tune (lr, threshold) on validation, refit on train+val, score on test.

    For every learning rate in the grid a model is fitted on the training
    split and its best decision threshold is chosen on the validation split
    via ``sweep_thresholds``. The learning rate with the highest
    ``PRIMARY_METRIC`` wins (the first one on ties). A final model with that
    learning rate is fitted on train+val and evaluated on the test split at
    the validation-selected threshold.

    Args:
        X_train, X_val, X_test: feature matrices (dense or SciPy sparse).
        y_train, y_val, y_test: binary label arrays.
        lr_grid: iterable of learning rates; ``None`` or an empty iterable
            falls back to ``LR_GRID``.

    Returns:
        Dict with ``best_lr``, ``threshold`` and the test metrics
        ``test_macro_f1``, ``test_macro_recall``, ``test_recall_pos``,
        ``test_roc_auc``, ``test_accuracy``, ``test_brier`` and ``test_ece``.
    """
    # `lr_grid or LR_GRID` raises "truth value of an array is ambiguous" for a
    # NumPy array; test for None explicitly and keep the empty-grid fallback.
    candidates = LR_GRID if lr_grid is None else (list(lr_grid) or LR_GRID)

    def _new_clf(lr):
        # Single place for the fixed hyper-parameters shared by the search
        # models and the final model.
        return ManualLogisticRegression(
            max_iter=1000,
            lr=lr,
            C=1.0,
            class_weight="balanced",
            tol=None,
            n_jobs=-1,
            verbose=False,
        )

    best = None
    for lr in candidates:
        clf = _new_clf(lr)
        clf.fit(X_train, y_train)
        val_proba = clf.predict_proba(X_val)[:, 1]
        best_thr = sweep_thresholds(y_val, val_proba)
        candidate = {
            "lr": lr,
            "val_threshold": best_thr["threshold"],
            "val_macro_f1": best_thr["macro_f1"],
            "val_macro_recall": best_thr["macro_recall"],
            "val_recall_pos": best_thr["recall_pos"],
            "primary_score": best_thr[PRIMARY_METRIC],
        }
        # Strict ">" keeps the earliest learning rate when scores tie.
        if best is None or candidate["primary_score"] > best["primary_score"]:
            best = candidate

    # Refit on train+val with the selected learning rate; the threshold is the
    # one tuned on validation and is not re-tuned on the test split.
    X_trainval = _stack_features(X_train, X_val)
    y_trainval = np.concatenate([y_train, y_val])
    final_clf = _new_clf(best["lr"])
    final_clf.fit(X_trainval, y_trainval)
    test_proba = final_clf.predict_proba(X_test)[:, 1]
    test_pred = (test_proba >= best["val_threshold"]).astype(int)

    return {
        "best_lr": float(best["lr"]),
        "threshold": float(best["val_threshold"]),
        "test_macro_f1": float(f1_score(y_test, test_pred, average="macro")),
        "test_macro_recall": float(recall_score(y_test, test_pred, average="macro")),
        "test_recall_pos": float(recall_score(y_test, test_pred, pos_label=1)),
        "test_roc_auc": float(roc_auc_score(y_test, test_proba)),
        "test_accuracy": float(accuracy_score(y_test, test_pred)),
        "test_brier": float(brier_score_loss(y_test, test_proba)),
        "test_ece": float(expected_calibration_error(y_test, test_proba, n_bins=10)),
    }


In [None]:
# Run tag-field ablations (tag-only model)

# Ablation name -> set of tag fields passed to `build_style_tokens`.
# Covers the full set, each single field, "no theme", and each drop-one variant.
ABLATIONS: dict[str, set[str]] = {
    "tags_full": {"theme", "claim", "cta", "evidence"},
    "tags_theme_only": {"theme"},
    "tags_claim_only": {"claim"},
    "tags_cta_only": {"cta"},
    "tags_evidence_only": {"evidence"},
    "tags_style_only_no_theme": {"claim", "cta", "evidence"},
    "tags_drop_claim": {"theme", "cta", "evidence"},
    "tags_drop_cta": {"theme", "claim", "evidence"},
    "tags_drop_evidence": {"theme", "claim", "cta"},
}

# Output directory is relative to the current working directory.
results_dir = Path("mbfc_url_masked_logreg_tag_field_ablations_results_v2")
results_dir.mkdir(exist_ok=True)

rows = []
for seed in RANDOM_STATES:
    # Test split: 20% of the domain groups, so no test domain occurs in
    # train/val (grouping key is `normalized_domain`).
    gss = GroupShuffleSplit(n_splits=1, test_size=0.2, random_state=seed)
    trainval_idx, test_idx = next(gss.split(df_eval, df_eval["y"], df_eval["normalized_domain"]))
    df_trainval = df_eval.iloc[trainval_idx].copy()
    df_test_split = df_eval.iloc[test_idx].copy()

    # Validation split: 12.5% of train+val rows, stratified on the label.
    # NOTE(review): this split is row-level, not grouped by domain, so
    # validation rows can share domains with training rows; only the test
    # split is domain-disjoint. Confirm this matches the intended v6 protocol.
    df_train_split, df_val_split = train_test_split(
        df_trainval, test_size=0.125, random_state=100 + seed, stratify=df_trainval["y"]
    )
    y_train = df_train_split["y"].values
    y_val = df_val_split["y"].values
    y_test = df_test_split["y"].values

    # The same three splits are reused for every ablation of this seed; the
    # tag tokens are rebuilt from the rows for each ablation.
    for ablation_name, include_fields in ABLATIONS.items():
        X_train, X_val, X_test, mlb = build_style_features_for_fields(
            df_train_split, df_val_split, df_test_split, include_fields
        )
        metrics = fit_with_val_search(X_train, y_train, X_val, y_val, X_test, y_test)
        # Attach run metadata; n_features is the vocabulary size learned from
        # this seed's training split.
        metrics.update(
            {
                "model": ablation_name,
                "split_seed": int(seed),
                "n_features": int(len(mlb.classes_)),
                "fields": "+".join(sorted(include_fields)),
            }
        )
        rows.append(metrics)
        print({"seed": int(seed), "model": ablation_name, "roc_auc": round(metrics["test_roc_auc"], 4)})

# One row per (seed, ablation) with all test metrics.
ablation_df = pd.DataFrame(rows)
ablation_csv = results_dir / "tag_field_ablation_metrics.csv"
ablation_df.to_csv(ablation_csv, index=False)
print({"ablation_metrics_csv": str(ablation_csv)})

# Test metrics summarised per ablation. Defined once so the multi-index summary
# and the flat summary below always cover the same columns.
metrics_for_tables = [
    "test_accuracy",
    "test_roc_auc",
    "test_macro_f1",
    "test_macro_recall",
    "test_recall_pos",
    "test_brier",
    "test_ece",
]

# Mean and std over seeds, with (metric, statistic) multi-index columns.
# NOTE(review): `n_features` is part of the group key. It is derived from each
# seed's training vocabulary, so if it ever differs between seeds a model is
# split over several rows (and single-seed groups get NaN std). Kept as-is to
# preserve the CSV layout; verify the vocabulary size is constant across seeds.
summary = (
    ablation_df.groupby(["model", "fields", "n_features"])[metrics_for_tables]
    .agg(["mean", "std"])
    .round(4)
)
summary_csv = results_dir / "tag_field_ablation_summary.csv"
summary.to_csv(summary_csv)
print({"ablation_summary_csv": str(summary_csv)})
# "\u00b1" is the plus-minus sign; written as an escape so the header cannot be
# mis-decoded again (it previously printed as mojibake).
print("\nTag-field ablation summary (mean\u00b1std over seeds):")
display(summary)

# Also save a flat (tidy) summary table with `<metric>_mean` / `<metric>_std`
# columns (pandas named aggregation).
agg_spec = {}
for col in metrics_for_tables:
    agg_spec[f"{col}_mean"] = (col, "mean")
    agg_spec[f"{col}_std"] = (col, "std")

summary_flat = (
    ablation_df.groupby(["model", "fields", "n_features"]).agg(**agg_spec).reset_index().round(4)
)
summary_flat_csv = results_dir / "tag_field_ablation_summary_flat.csv"
summary_flat.to_csv(summary_flat_csv, index=False)
print({"ablation_summary_flat_csv": str(summary_flat_csv)})


# Load v6 TF-IDF and Combined baselines for side-by-side comparison.
def _resolve_v6_metrics_path(start: Path) -> Path:
    """Return the v6 metrics CSV found at ``start`` or the nearest ancestor.

    Each directory from ``start`` upwards is probed for
    ``mbfc_url_masked_logreg_results_v6/url_masked_val_tuned_metrics.csv``.

    Raises:
        FileNotFoundError: if no directory on the way up contains the file.
    """
    relative_csv = Path("mbfc_url_masked_logreg_results_v6", "url_masked_val_tuned_metrics.csv")
    probes = (root / relative_csv for root in (start, *start.parents))
    found = next((probe for probe in probes if probe.exists()), None)
    if found is None:
        raise FileNotFoundError(
            "Could not locate mbfc_url_masked_logreg_results_v6/url_masked_val_tuned_metrics.csv"
        )
    return found

# Read the v6 per-seed metrics and keep only the two baseline models.
v6_metrics_path = _resolve_v6_metrics_path(Path.cwd().resolve())
v6_df = pd.read_csv(v6_metrics_path)
baseline_df = v6_df[v6_df["model"].isin(["tfidf", "combined"])].copy()
# Fill the ablation-only metadata columns so the frames can be concatenated.
baseline_df["fields"] = ""
baseline_df["n_features"] = np.nan

# Baselines and ablations stacked into one long table keyed by `model`.
all_df = pd.concat([baseline_df, ablation_df], ignore_index=True)
# Metrics summarised across all models. "test_recall_pos" is not included here
# although the ablation-only tables report it -- presumably because the v6 CSV
# lacks that column; verify against the v6 file.
all_metrics = [
    "test_accuracy",
    "test_roc_auc",
    "test_macro_f1",
    "test_macro_recall",
    "test_brier",
    "test_ece",
]

# Named aggregation spec: `<metric>_mean` and `<metric>_std` per metric.
all_agg_spec = {}
for col in all_metrics:
    all_agg_spec[f"{col}_mean"] = (col, "mean")
    all_agg_spec[f"{col}_std"] = (col, "std")

# One row per model, aggregated over that model's rows (seeds).
all_summary_flat = all_df.groupby(["model"]).agg(**all_agg_spec).reset_index().round(4)
all_summary_flat_csv = results_dir / "all_models_summary_flat.csv"
all_summary_flat.to_csv(all_summary_flat_csv, index=False)
print({"all_models_summary_flat_csv": str(all_summary_flat_csv)})


# Marker-style figures (like TF-IDF vs Style vs TF-IDF+Style example)
import matplotlib.pyplot as plt

# Metrics shown in the figures.
plot_metrics = [
    "test_accuracy",
    "test_roc_auc",
    "test_macro_f1",
    "test_macro_recall",
]
# Per-model mean of each plotted metric; `plot_marker_figure` reads `means`
# as a module-level global.
means = all_df.groupby("model")[plot_metrics].mean()
# NOTE(review): `stds` is computed but not referenced by the plotting code
# visible in this notebook (the figures show means only, without error bars).
stds = all_df.groupby("model")[plot_metrics].std()

# (legend label, column in `means`, matplotlib marker) for each plotted metric.
METRIC_SPECS = [
    ("Accuracy", "test_accuracy", "o"),
    ("ROC-AUC", "test_roc_auc", "s"),
    ("F1 (macro)", "test_macro_f1", "^"),
    ("Recall (macro)", "test_macro_recall", "D"),
]


def plot_marker_figure(
    model_order: list[str],
    model_labels: list[str],
    out_png: Path,
    out_pdf: Path | None = None,
    title: str | None = None,
    figsize=(6.0, 3.0),
    xtick_rotation: float = 0,
):
    """Draw one hollow marker per (model, metric) and save the figure.

    Mean scores are read from the module-level ``means`` table for every
    metric in ``METRIC_SPECS``. The y-axis spans the plotted values padded by
    0.06 on each side and clamped to [0, 1].

    Args:
        model_order: row labels of ``means`` to plot, left to right.
        model_labels: x tick labels, one per entry of ``model_order``.
        out_png: PNG output path (always written, 300 dpi).
        out_pdf: optional PDF output path.
        title: optional axes title.
        figsize: figure size passed to matplotlib.
        xtick_rotation: tick-label rotation; a non-zero value right-aligns
            the labels.
    """
    positions = np.arange(len(model_order), dtype=float)
    fig, ax = plt.subplots(figsize=figsize, dpi=300)

    # Look up each metric's values once; reused for axis limits and markers.
    series = [
        (legend, marker, [float(means.loc[model, column]) for model in model_order])
        for legend, column, marker in METRIC_SPECS
    ]
    plotted = [value for _, _, values in series for value in values]
    lower = max(0.0, min(plotted) - 0.06)
    upper = min(1.0, max(plotted) + 0.06)

    for legend, marker, values in series:
        ax.scatter(
            positions,
            values,
            marker=marker,
            s=70,
            facecolors="none",
            edgecolors="black",
            linewidths=1.2,
            label=legend,
            zorder=3,
        )

    ax.set_xticks(positions)
    if xtick_rotation:
        alignment = "right"
    else:
        alignment = "center"
    ax.set_xticklabels(model_labels, rotation=xtick_rotation, ha=alignment)
    ax.set_ylim(lower, upper)
    ax.set_ylabel("Score")
    if title:
        ax.set_title(title)
    ax.grid(axis="y", alpha=0.25, zorder=0)
    ax.legend(loc="upper left", frameon=False)
    fig.tight_layout()

    fig.savefig(out_png, dpi=300)
    if out_pdf is not None:
        fig.savefig(out_pdf)
    plt.close(fig)
    print({"figure_png": str(out_png), "figure_pdf": str(out_pdf) if out_pdf else None})


# 1) Baseline plot: TF-IDF vs Style vs TF-IDF+Style
baseline_order = ["tfidf", "tags_full", "combined"]
# Keep only models that actually have rows in `means` (e.g. if a v6 baseline
# is absent from the loaded CSV).
baseline_order = [m for m in baseline_order if m in means.index]
baseline_label_map = {"tfidf": "TF-IDF", "tags_full": "Style", "combined": "TF-IDF\n+Style"}
baseline_labels = [baseline_label_map[m] for m in baseline_order]
plot_marker_figure(
    baseline_order,
    baseline_labels,
    out_png=results_dir / "marker_metrics_tfidf_style_combined.png",
    out_pdf=results_dir / "marker_metrics_tfidf_style_combined.pdf",
    figsize=(4.6, 2.8),
)


# 2) Full plot: single-field + drop-one-field ablations (plus TF-IDF and Combined)
# Left-to-right order on the x-axis.
ablation_order = [
    "tfidf",
    "tags_theme_only",
    "tags_claim_only",
    "tags_cta_only",
    "tags_evidence_only",
    "tags_style_only_no_theme",
    "tags_drop_claim",
    "tags_drop_cta",
    "tags_drop_evidence",
    "tags_full",
    "combined",
]
# Same availability filter as for the baseline plot.
ablation_order = [m for m in ablation_order if m in means.index]
# Display names for the x tick labels; unknown models fall back to their key.
label_map = {
    "tfidf": "TF-IDF",
    "tags_theme_only": "Theme only",
    "tags_claim_only": "Claim only",
    "tags_cta_only": "CTA only",
    "tags_evidence_only": "Evidence only",
    "tags_style_only_no_theme": "Style only\n(no Theme)",
    "tags_drop_claim": "Full - Claim",
    "tags_drop_cta": "Full - CTA",
    "tags_drop_evidence": "Full - Evidence",
    "tags_full": "Full tags",
    "combined": "TF-IDF\n+Style",
}
ablation_labels = [label_map.get(m, m) for m in ablation_order]
plot_marker_figure(
    ablation_order,
    ablation_labels,
    out_png=results_dir / "marker_metrics_all_ablations.png",
    out_pdf=results_dir / "marker_metrics_all_ablations.pdf",
    figsize=(9.6, 3.0),
    xtick_rotation=20,
)
