# MBFC URL-masked: slices + robustness (v6)

This notebook adds reviewer-facing analyses on top of the **v6 URL-masked** experiment outputs:

1. **Head/torso/tail domain analysis** (macro-F1 + calibration by domain-frequency bin).
2. **Slice calibration** (reliability diagrams + ECE by theme + message-length bins).
3. **Tagger-noise robustness** for **style-only (no Theme)** tags.

It reuses the v6 **main URL-masked split** (the one that produced `test_predictions_all_models.csv`) so we can compute slices without re-running the full expensive training.


In [None]:
from __future__ import annotations

from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import (
    accuracy_score,
    brier_score_loss,
    f1_score,
    recall_score,
    roc_auc_score,
)
from sklearn.model_selection import GroupShuffleSplit, train_test_split
from sklearn.preprocessing import MultiLabelBinarizer


def _resolve_upward(start: Path, rel: Path) -> Path:
    """Locate ``rel`` relative to ``start`` or the nearest ancestor that has it.

    Candidates are tried from ``start`` outward through ``start.parents``; the
    first existing ``<dir> / rel`` is returned.

    Raises:
        FileNotFoundError: if no directory on the way up contains ``rel``.
    """
    search_roots = (start, *start.parents)
    found = next(
        (base / rel for base in search_roots if (base / rel).exists()),
        None,
    )
    if found is None:
        raise FileNotFoundError(f"Could not resolve path upward: {rel}")
    return found


# Current working directory; input files are searched for here and in every
# parent directory, so the notebook can run from any sub-folder of the project.
HERE = Path.cwd().resolve()
# Tagged message table (CSV).
# NOTE(review): the "MBFC " path component ends with a space — confirm this
# matches the on-disk folder name exactly.
DATA_PATH = _resolve_upward(
    HERE,
    Path("mbfc_channel_masked_logreg_fullpackage_v2_MBFC_C") / "MBFC " / "mega_samples_dedup_qwen_mbfc.csv",
)
# Per-row test-set predictions for the tfidf / style / combined models.
V6_PREDS_PATH = _resolve_upward(
    HERE,
    Path("mbfc_channel_masked_logreg_fullpackage")
    / "Pipelines"
    / "mbfc_url_masked_logreg_results_v6"
    / "test_predictions_all_models.csv",
)

# Echo the resolved locations so an unintended match higher up the tree is easy to spot.
print({"data_path": str(DATA_PATH)})
print({"v6_test_predictions": str(V6_PREDS_PATH)})

# Load the full message table.
df = pd.read_csv(DATA_PATH, low_memory=False)
# If there is no "source" column, treat the first column as the source identifier.
if "source" not in df.columns:
    df = df.rename(columns={df.columns[0]: "source"})

# Match v6 preprocessing: strip URLs from text.
# NOTE(review): the `https?://` alternative matches only the scheme, so the
# host/path of http(s) URLs stays in the text, and the separate `http://`
# alternative is redundant. Left as-is because any change here alters which
# rows end up empty (and therefore the reconstructed split) — confirm against
# the v6 pipeline before touching it.
# NOTE(review): astype(str) turns missing messages into the literal string
# "nan", so such rows are not removed by the empty-string filter below.
df["message"] = (
    df["message"]
    .astype(str)
    .str.replace(r"(https?://|http://|www\.[^\s]*|t\.me/[^\s]*)", " ", regex=True)
    .str.strip()
)
# Drop rows whose message is empty after URL stripping.
df = df[df["message"] != ""].copy()

# Use MBFC-derived binary risk label as y; rows without a label are discarded.
df = df.dropna(subset=["risk_label"]).copy()
df["y"] = df["risk_label"].astype(int)

# Quick sanity summary of the cleaned table.
print(
    {
        "rows": int(len(df)),
        "unique_domains": int(df["normalized_domain"].nunique()),
        "pos_rate": float(df["y"].mean()),
    }
)


# Reconstruct the v6 main URL-masked split used for test_predictions_all_models.csv.
# This only reproduces v6 if the row order, seeds and selection rule below are
# the same as in the v6 run; the alignment asserts further down verify it.
groups = df["normalized_domain"].astype(str).values
y = df["y"].values
p_global = float(y.mean())

# Outer split: 50 candidate domain-grouped 80/20 splits; keep the one whose
# trainval and test positive rates are both closest to the global positive rate.
gss_outer = GroupShuffleSplit(n_splits=50, test_size=0.2, random_state=42)
best_score_outer = None
best_outer = None
for split_id, (trainval_idx, test_idx) in enumerate(gss_outer.split(df, y, groups)):
    y_trainval = y[trainval_idx]
    y_test = y[test_idx]
    p_trainval = y_trainval.mean()
    p_test = y_test.mean()
    # Balance score = worst-case deviation from the global positive rate (lower is better).
    score = max(abs(p_trainval - p_global), abs(p_test - p_global))
    # Strict "<" keeps the earliest candidate on ties.
    if best_score_outer is None or score < best_score_outer:
        best_score_outer = score
        best_outer = (trainval_idx, test_idx, split_id)

trainval_idx, test_idx, outer_id = best_outer
df_trainval = df.iloc[trainval_idx].copy()
# Test rows get a fresh 0..n-1 index so they can be concatenated column-wise
# with the prediction file later.
df_test = df.iloc[test_idx].copy().reset_index(drop=True)

# Inner split: 12.5% of trainval becomes validation, stratified by label.
# Note this split is NOT grouped by domain, so a domain can appear in both
# train and val (only the test set is domain-disjoint).
df_train, df_val = train_test_split(
    df_trainval,
    test_size=0.125,
    random_state=43,
    stratify=df_trainval["y"],
)

print(
    {
        "outer_split_id": int(outer_id),
        "outer_balance_score": float(best_score_outer),
        "train_rows": int(len(df_train)),
        "val_rows": int(len(df_val)),
        "test_rows": int(len(df_test)),
        "test_pos_rate": float(df_test["y"].mean()),
    }
)


# Load v6 predictions and align row-by-row (v6 wrote them in df_test order).
preds = pd.read_csv(V6_PREDS_PATH).reset_index(drop=True)
# Guard the positional join below: same length, same identifying columns and
# same labels in the same order. A failure here means the reconstructed split
# does not match the one that produced the prediction file.
assert len(preds) == len(df_test), (len(preds), len(df_test))
for col in ["source", "msg_id", "channel"]:
    # Compared as strings so dtype differences between the two CSV reads do not matter.
    assert (df_test[col].astype(str) == preds[col].astype(str)).all(), f"mismatch in {col}"
assert (df_test["y"].astype(int) == preds["y_true"].astype(int)).all(), "y mismatch"

# Column-wise concat relies on both frames having a 0..n-1 index (both were reset above).
df_test_pred = pd.concat(
    [
        df_test,
        preds[[
            "tfidf_pred",
            "tfidf_proba",
            "style_pred",
            "style_proba",
            "combined_pred",
            "combined_proba",
        ]],
    ],
    axis=1,
)
print({"df_test_pred_rows": int(len(df_test_pred))})


# Output folder for every CSV/figure written by this notebook. The path is
# relative, so it is created under the current working directory; parents are
# not created (mkdir is called without parents=True).
RESULTS_DIR = Path("mbfc_url_masked_logreg_generalization_analyses_results_v1")
RESULTS_DIR.mkdir(exist_ok=True)
print({"results_dir": str(RESULTS_DIR.resolve())})


In [None]:
def expected_calibration_error(y_true, y_proba, n_bins=10) -> float:
    """Return the expected calibration error (ECE) over equal-width bins.

    Probabilities are grouped into ``n_bins`` equal-width bins on [0, 1]; the
    ECE is the sample-weighted mean of ``|mean(p) - mean(y)|`` over the
    non-empty bins.

    Args:
        y_true: binary labels (array-like of 0/1).
        y_proba: predicted probability of the positive class (array-like).
        n_bins: number of equal-width bins.

    Returns:
        The ECE as a float; ``0.0`` for empty input.

    Notes:
        Bins are ``[lo, hi)`` except the last, which is closed at 1.0. The bin
        index is clipped into ``[0, n_bins - 1]`` so that a probability of
        exactly 1.0 lands in the last bin instead of being silently dropped
        (which previously under-reported the ECE of over-confident models).
        Values outside [0, 1] are clamped to the edge bins by the same clip.
    """
    y_true = np.asarray(y_true)
    y_proba = np.asarray(y_proba)
    n = len(y_true)
    if n == 0:
        return 0.0
    bins = np.linspace(0.0, 1.0, n_bins + 1)
    # digitize returns n_bins + 1 for p == 1.0 (right edge is open), hence the clip.
    idx = np.clip(np.digitize(y_proba, bins) - 1, 0, n_bins - 1)
    ece = 0.0
    for b in range(n_bins):
        mask = idx == b
        if not np.any(mask):
            continue
        p_bin = float(y_proba[mask].mean())
        y_bin = float(y_true[mask].mean())
        weight = float(mask.sum() / n)
        ece += weight * abs(p_bin - y_bin)
    return float(ece)


def calibration_bins(y_true, y_proba, n_bins=10):
    """Return per-bin reliability statistics for a reliability diagram.

    Args:
        y_true: binary labels (array-like of 0/1).
        y_proba: predicted probability of the positive class (array-like).
        n_bins: number of equal-width bins on [0, 1].

    Returns:
        A DataFrame with one row per non-empty bin and the columns ``bin``
        (0-based bin index), ``count``, ``p_mean`` (mean predicted
        probability) and ``y_mean`` (observed positive rate). The columns are
        present even when no bin is populated, so callers can always index
        ``["p_mean"]`` / ``["y_mean"]``.

    Notes:
        Bins are ``[lo, hi)`` except the last, which is closed at 1.0: the bin
        index is clipped into ``[0, n_bins - 1]`` so probabilities of exactly
        1.0 are reported in the last bin rather than being dropped.
    """
    y_true = np.asarray(y_true)
    y_proba = np.asarray(y_proba)
    bins = np.linspace(0.0, 1.0, n_bins + 1)
    # digitize returns n_bins + 1 for p == 1.0 (right edge is open), hence the clip.
    idx = np.clip(np.digitize(y_proba, bins) - 1, 0, n_bins - 1)
    rows = []
    for b in range(n_bins):
        mask = idx == b
        if not np.any(mask):
            continue
        rows.append(
            {
                "bin": int(b),
                "count": int(mask.sum()),
                "p_mean": float(y_proba[mask].mean()),
                "y_mean": float(y_true[mask].mean()),
            }
        )
    return pd.DataFrame(rows, columns=["bin", "count", "p_mean", "y_mean"])


# (model name, hard-prediction column, probability column) for each model
# whose predictions were joined into df_test_pred from the v6 prediction file.
MODELS = [
    ("tfidf", "tfidf_pred", "tfidf_proba"),
    ("style", "style_pred", "style_proba"),
    ("combined", "combined_pred", "combined_proba"),
]


In [None]:
# 1) Head/torso/tail domain analysis

# Domain frequency is counted over the full cleaned table (all splits), so a
# test row's bin reflects how common its domain is in the whole dataset.
domain_counts = df["normalized_domain"].astype(str).value_counts()
df_test_pred["domain_count"] = df_test_pred["normalized_domain"].astype(str).map(domain_counts).astype(int)


def domain_bin(count: int) -> str:
    """Map a domain's message count to its head/torso/tail frequency bin.

    Thresholds: ``>= 50`` is head, ``5–49`` is torso, ``<= 4`` is tail
    (chosen roughly from the global count quantiles).
    """
    # Checked from the highest floor down; the first floor reached wins.
    for floor, label in ((50, "head (>=50)"), (5, "torso (5-49)")):
        if count >= floor:
            return label
    return "tail (<=4)"


# Assign every test row to its head/torso/tail bin and show the bin sizes.
df_test_pred["domain_bin"] = df_test_pred["domain_count"].apply(domain_bin)
print(df_test_pred["domain_bin"].value_counts().to_dict())

# One metrics row per (domain bin, model).
# NOTE(review): roc_auc_score is undefined when a bin contains a single class
# (it raises or returns NaN depending on the scikit-learn version) — confirm
# every bin has both classes, or guard the call.
rows = []
for bin_name, g in df_test_pred.groupby("domain_bin"):
    y_true = g["y"].astype(int).to_numpy()
    for model_name, pred_col, proba_col in MODELS:
        y_pred = g[pred_col].astype(int).to_numpy()
        y_proba = g[proba_col].astype(float).to_numpy()
        rows.append(
            {
                "domain_bin": bin_name,
                "model": model_name,
                "n": int(len(g)),
                "pos_rate": float(y_true.mean()),
                "macro_f1": float(f1_score(y_true, y_pred, average="macro")),
                "macro_recall": float(recall_score(y_true, y_pred, average="macro")),
                "roc_auc": float(roc_auc_score(y_true, y_proba)),
                "brier": float(brier_score_loss(y_true, y_proba)),
                "ece": float(expected_calibration_error(y_true, y_proba, n_bins=10)),
            }
        )

domain_slice_df = pd.DataFrame(rows)
domain_slice_csv = RESULTS_DIR / "domain_head_torso_tail_metrics.csv"
domain_slice_df.to_csv(domain_slice_csv, index=False)
print({"domain_slice_csv": str(domain_slice_csv)})
# `display` is provided by the IPython/Jupyter session, not imported here.
display(domain_slice_df.sort_values(["domain_bin", "model"]))


# Plot: macro-F1 and ECE by domain bin
bin_order = ["head (>=50)", "torso (5-49)", "tail (<=4)"]
model_order = ["tfidf", "style", "combined"]
colors = {"tfidf": "#666666", "style": "#1f77b4", "combined": "#2ca02c"}

x = np.arange(len(bin_order))
width = 0.25

fig, axes = plt.subplots(1, 2, figsize=(9.0, 3.0), dpi=300)

# Grouped bars: the three models are offset by -width, 0, +width around each bin's tick.
for mi, model in enumerate(model_order):
    sub = domain_slice_df[domain_slice_df["model"] == model].set_index("domain_bin")
    # A bin that is absent from the test set is plotted as NaN (no bar).
    y_f1 = [float(sub.loc[b, "macro_f1"]) if b in sub.index else np.nan for b in bin_order]
    y_ece = [float(sub.loc[b, "ece"]) if b in sub.index else np.nan for b in bin_order]
    axes[0].bar(x + (mi - 1) * width, y_f1, width, label=model, color=colors[model])
    axes[1].bar(x + (mi - 1) * width, y_ece, width, label=model, color=colors[model])

axes[0].set_title("Macro-F1 by domain frequency")
axes[0].set_ylabel("Macro-F1")
axes[0].set_xticks(x)
axes[0].set_xticklabels(bin_order, rotation=15, ha="right")
axes[0].set_ylim(0.0, 1.0)
axes[0].grid(axis="y", alpha=0.25)

axes[1].set_title("ECE by domain frequency")
axes[1].set_ylabel("ECE")
axes[1].set_xticks(x)
axes[1].set_xticklabels(bin_order, rotation=15, ha="right")
axes[1].grid(axis="y", alpha=0.25)

# A single legend (left panel) serves both panels since the colors are shared.
axes[0].legend(loc="lower left", frameon=False)

fig.tight_layout()
fig_path = RESULTS_DIR / "domain_head_torso_tail_macro_f1_ece.png"
pdf_path = RESULTS_DIR / "domain_head_torso_tail_macro_f1_ece.pdf"
fig.savefig(fig_path, dpi=300)
fig.savefig(pdf_path)
# Close the figure so it is not also rendered inline / kept in memory.
plt.close(fig)
print({"domain_head_torso_tail_plot_png": str(fig_path)})


In [None]:
# 2) Slice calibration: reliability diagram + ECE by theme + message-length bins

# Theme normalization from v6 (topic bucket only; Link/URL removed elsewhere).
THEME_BUCKETS = [
    "Finance/Crypto",
    "Public health & medicine",
    "Politics",
    "Lifestyle & well-being",
    "Crime & public safety",
    "Gaming/Gambling",
    "News/Information",
    "Sports",
    "Technology",
    "Conversation/Chat/Other",
    "Other theme",
]

# Ordered (substrings, bucket) rules applied to the lower-cased theme string.
# The first rule with any matching substring wins, so rule order is part of
# the behavior and must not be rearranged.
_THEME_KEYWORD_RULES = (
    (
        (
            "crypto",
            "token",
            "coin",
            "airdrop",
            "ido",
            "staking",
            "defi",
            "exchange",
            "market",
            "finance",
            "econom",
        ),
        "Finance/Crypto",
    ),
    (
        (
            "health",
            "covid",
            "vaccine",
            "vaccination",
            "medicine",
            "medical",
            "clinical",
            "disease",
            "pandemic",
            "public health",
            "hospital",
        ),
        "Public health & medicine",
    ),
    (
        (
            "politic",
            "election",
            "parliament",
            "congress",
            "senate",
            "government",
            "president",
            "minister",
            "policy",
            "war",
            "conflict",
            "ukraine",
            "russia",
        ),
        "Politics",
    ),
    (
        (
            "crime",
            "criminal",
            "terror",
            "shooting",
            "police",
            "public safety",
            "fraud",
            "scam",
        ),
        "Crime & public safety",
    ),
    (
        ("gaming", "gambling", "casino", "betting", "lottery", "poker"),
        "Gaming/Gambling",
    ),
    (
        ("sport", "football", "soccer", "basketball", "tennis", "nba", "nfl"),
        "Sports",
    ),
    (
        (
            "technology",
            "tech",
            "software",
            "app ",
            "platform",
            "ai ",
            " a.i.",
            "machine learning",
            "blockchain",
            "internet",
            "social media",
            "algorithm",
            "science",
            "research",
            "study",
        ),
        "Technology",
    ),
    (
        (
            "lifestyle",
            "well-being",
            "wellbeing",
            "culture",
            "entertainment",
            "media",
            "celebrity",
            "social issues",
            "society",
            "family",
            "community",
        ),
        "Lifestyle & well-being",
    ),
    (
        ("news", "headline", "breaking", "coverage", "roundup", "update"),
        "News/Information",
    ),
    (
        ("comment", "conversation", "chat", "q&a", "ama", "ask me anything"),
        "Conversation/Chat/Other",
    ),
)


def _norm_theme(raw: object) -> Optional[str]:
    """Collapse a free-text theme tag into one of ``THEME_BUCKETS``.

    Returns ``None`` for non-string or blank input. Otherwise the tag is
    stripped, a few unicode dashes are folded to ASCII ``-``, and:

    * an exact bucket name is returned unchanged;
    * else the first keyword rule with a substring hit decides the bucket;
    * else ``"Other theme"``.

    NOTE: matching is plain substring containment on the lower-cased tag, so
    short keywords also hit inside longer words (e.g. "war" inside "software"
    sends a "Software ..." tag to Politics, because the Politics rule is
    evaluated before Technology). This mirrors the existing behavior on purpose.
    """
    if not isinstance(raw, str):
        return None
    text = raw.strip()
    if not text:
        return None

    # Fold non-breaking hyphen, en dash and em dash to a plain hyphen.
    for dash in ("\u2011", "\u2013", "\u2014"):
        text = text.replace(dash, "-")

    if text in THEME_BUCKETS:
        return text

    lowered = text.lower()
    for needles, bucket in _THEME_KEYWORD_RULES:
        if any(needle in lowered for needle in needles):
            return bucket
    return "Other theme"


# Bucketed theme per test row (None where the raw theme is missing or blank).
df_test_pred["theme_norm"] = df_test_pred["theme"].apply(_norm_theme)


# Reliability diagram
y_true = df_test_pred["y"].astype(int).to_numpy()

fig, ax = plt.subplots(figsize=(4.2, 4.0), dpi=300)
# Diagonal = perfect calibration.
ax.plot([0, 1], [0, 1], "--", color="black", linewidth=1, alpha=0.6)

# One curve per model over the whole test set: mean predicted probability vs.
# observed positive rate in each non-empty probability bin.
for model_name, _, proba_col in MODELS:
    y_proba = df_test_pred[proba_col].astype(float).to_numpy()
    bins_df = calibration_bins(y_true, y_proba, n_bins=10)
    ax.plot(bins_df["p_mean"], bins_df["y_mean"], marker="o", label=model_name)

ax.set_xlabel("Predicted probability")
ax.set_ylabel("Observed frequency")
ax.set_title("Reliability diagram (test)")
ax.grid(alpha=0.25)
ax.legend(frameon=False, loc="upper left")
fig.tight_layout()

rel_png = RESULTS_DIR / "reliability_diagram_tfidf_style_combined.png"
rel_pdf = RESULTS_DIR / "reliability_diagram_tfidf_style_combined.pdf"
fig.savefig(rel_png, dpi=300)
fig.savefig(rel_pdf)
plt.close(fig)
print({"reliability_png": str(rel_png)})


# ECE by theme slice
theme_rows = []
for theme, g in df_test_pred.groupby("theme_norm"):
    # NOTE(review): groupby drops missing keys by default, so this guard is
    # presumably never hit; kept as a defensive check.
    if theme is None:
        continue
    y_t = g["y"].astype(int).to_numpy()
    # Skip themes with fewer than 50 test rows: too few samples for a 10-bin ECE.
    if len(y_t) < 50:
        continue
    for model_name, _, proba_col in MODELS:
        y_p = g[proba_col].astype(float).to_numpy()
        theme_rows.append(
            {
                "theme": theme,
                "model": model_name,
                "n": int(len(g)),
                "pos_rate": float(y_t.mean()),
                "ece": float(expected_calibration_error(y_t, y_p, n_bins=10)),
                "brier": float(brier_score_loss(y_t, y_p)),
            }
        )

# NOTE(review): if no theme reaches 50 rows, theme_rows is empty, the frame has
# no columns, and the sort_values call below would fail on the missing keys.
theme_slice_df = pd.DataFrame(theme_rows)
theme_slice_csv = RESULTS_DIR / "ece_brier_by_theme.csv"
theme_slice_df.to_csv(theme_slice_csv, index=False)
print({"ece_by_theme_csv": str(theme_slice_csv)})
display(theme_slice_df.sort_values(["theme", "model"]))


# Message-length bins (quartiles) for ECE
# Length = number of whitespace-separated tokens in the URL-stripped message.
df_test_pred["tok_len"] = df_test_pred["message"].astype(str).str.split().apply(len)
# NOTE(review): qcut raises if two quartile edges coincide (e.g. many messages
# of identical length) — confirm on this data or handle duplicate edges.
df_test_pred["len_bin"] = pd.qcut(df_test_pred["tok_len"], q=4, labels=["Q1", "Q2", "Q3", "Q4"]).astype(str)

# One calibration row per (length quartile, model).
len_rows = []
for lb, g in df_test_pred.groupby("len_bin"):
    y_t = g["y"].astype(int).to_numpy()
    for model_name, _, proba_col in MODELS:
        y_p = g[proba_col].astype(float).to_numpy()
        len_rows.append(
            {
                "len_bin": lb,
                "model": model_name,
                "n": int(len(g)),
                "ece": float(expected_calibration_error(y_t, y_p, n_bins=10)),
                "brier": float(brier_score_loss(y_t, y_p)),
            }
        )

len_slice_df = pd.DataFrame(len_rows)
len_slice_csv = RESULTS_DIR / "ece_brier_by_length_quartile.csv"
len_slice_df.to_csv(len_slice_csv, index=False)
print({"ece_by_length_csv": str(len_slice_csv)})
display(len_slice_df.sort_values(["len_bin", "model"]))


In [None]:
# 3) Tagger-noise robustness for style-only (no Theme)
# We simulate LLM tagging noise by flipping entries in the multi-hot tag vector.

# v6 behavior: remove Link/URL evidence tags
# When True, tokenize_multi discards any label that normalizes to "link/url".
DROP_LINK_URL_LABEL = True
# Comparison form of the label: lower-cased with all whitespace removed.
_LINK_URL_LABEL_NORM = "link/url"


def tokenize_multi(value: object) -> List[str]:
    """Split a multi-label tag string into its individual labels.

    Labels may be separated by ``,`` or ``+``. Each label is stripped and
    empty pieces are discarded. Non-string input yields an empty list. While
    ``DROP_LINK_URL_LABEL`` is set, labels equal to "link/url" (ignoring case
    and whitespace) are removed.
    """
    if not isinstance(value, str):
        return []

    labels: List[str] = []
    for chunk in value.replace("+", ",").split(","):
        cleaned = chunk.strip()
        if cleaned:
            labels.append(cleaned)

    if DROP_LINK_URL_LABEL:
        labels = [
            lbl
            for lbl in labels
            if "".join(lbl.lower().split()) != _LINK_URL_LABEL_NORM
        ]
    return labels


# Canonical claim-type labels; a raw claim tag that equals one of these is kept verbatim.
CLAIM_BUCKETS = [
    "Verifiable factual statement",
    "Rumour / unverified report",
    "Announcement",
    "Opinion / subjective statement",
    "Misleading context / cherry-picking",
    "Promotional hype / exaggerated profit guarantee",
    "Emotional appeal / fear-mongering",
    "Scarcity/FOMO tactic",
    "Statistics",
    "Other claim type",
    "No substantive claim",
    "Fake content",
    "Speculative forecast / prediction",
    "None / assertion only",
]

# Canonical call-to-action labels.
CTA_BUCKETS = [
    "Visit external link / watch video",
    "Engage/Ask questions",
    "Join/Subscribe",
    "Buy / invest / donate",
    "Attend event / livestream",
    "Share / repost / like",
    "No CTA",
    "Other CTA",
]

# Canonical evidence labels. "Link/URL" is listed here but is never emitted as
# a feature: the evidence normalizer below skips it explicitly.
EVID_BUCKETS = [
    "Link/URL",
    "Statistics",
    "Quotes/Testimony",
    "Chart / price graph / TA diagram",
    "Other (Evidence)",
    "None / assertion only",
]


def _norm_claim_labels(raw: object) -> List[str]:
    """Map a raw claim-type tag string onto canonical ``CLAIM_BUCKETS`` labels.

    The string is split with ``tokenize_multi``. A label that already equals a
    bucket name is kept; otherwise the first keyword rule with a substring hit
    (on the lower-cased label) decides, falling back to "Other claim type".
    The result is de-duplicated, keeping first-occurrence order.
    """
    # Ordered (substrings, bucket) rules; earlier rules take precedence.
    keyword_rules = (
        (("verifiable", "factual"), "Verifiable factual statement"),
        (("rumour", "unverified"), "Rumour / unverified report"),
        (("misleading context", "cherry"), "Misleading context / cherry-picking"),
        (
            ("promotional hype", "exaggerated profit"),
            "Promotional hype / exaggerated profit guarantee",
        ),
        (
            ("emotional appeal", "fear-mongering", "fear mongering"),
            "Emotional appeal / fear-mongering",
        ),
        (("scarcity", "fomo"), "Scarcity/FOMO tactic"),
        (("statistic",), "Statistics"),
        (("fake content", "fabricated"), "Fake content"),
        (("predict", "forecast"), "Speculative forecast / prediction"),
        (("announcement",), "Announcement"),
        (
            ("opinion", "interpretive", "analysis", "review"),
            "Opinion / subjective statement",
        ),
        (("none / assertion only", "assertion only"), "None / assertion only"),
    )

    mapped: List[str] = []
    for label in tokenize_multi(raw):
        text = label.strip()
        if not text:
            continue
        if text in CLAIM_BUCKETS:
            mapped.append(text)
            continue
        lowered = text.lower()
        bucket = next(
            (name for needles, name in keyword_rules if any(n in lowered for n in needles)),
            "Other claim type",
        )
        mapped.append(bucket)

    # dict keys keep insertion order, so this drops repeats but preserves order.
    return list(dict.fromkeys(mapped))


def _norm_cta_labels(raw: object) -> List[str]:
    """Map a raw call-to-action tag string onto canonical ``CTA_BUCKETS`` labels.

    The string is split with ``tokenize_multi``. A label that already equals a
    bucket name is kept; "None" / "No CTA" (or any label containing "no cta")
    becomes "No CTA"; otherwise the first keyword rule with a substring hit on
    the lower-cased label decides, falling back to "Other CTA". The result is
    de-duplicated, keeping first-occurrence order.
    """
    # Ordered (substrings, bucket) rules; earlier rules take precedence.
    keyword_rules = (
        (("engage", "ask", "anything"), "Engage/Ask questions"),
        (("attend", "event", "livestream", "live stream"), "Attend event / livestream"),
        (("join", "subscribe", "follow", "whitelist"), "Join/Subscribe"),
        (("buy", "invest", "donate", "stake", "swap"), "Buy / invest / donate"),
        (("share", "repost", "like"), "Share / repost / like"),
        (
            ("visit", "read", "watch", "link", "website", "check", "view charts"),
            "Visit external link / watch video",
        ),
    )

    mapped: List[str] = []
    for label in tokenize_multi(raw):
        text = label.strip()
        if not text:
            continue
        if text in CTA_BUCKETS:
            mapped.append(text)
            continue

        lowered = text.lower()
        if text in {"None", "No CTA"} or "no cta" in lowered:
            mapped.append("No CTA")
            continue

        bucket = next(
            (name for needles, name in keyword_rules if any(n in lowered for n in needles)),
            "Other CTA",
        )
        mapped.append(bucket)

    # dict keys keep insertion order, so this drops repeats but preserves order.
    return list(dict.fromkeys(mapped))


def _norm_evidence_labels(raw: object) -> List[str]:
    """Map a raw evidence tag string onto canonical ``EVID_BUCKETS`` labels.

    The string is split with ``tokenize_multi``. Link/URL evidence is never
    emitted: the exact "Link/URL" bucket and any other label containing
    "link" or "url" are skipped. A label that equals another bucket name is
    kept; otherwise the first keyword rule with a substring hit on the
    lower-cased label decides, falling back to "Other (Evidence)". The result
    is de-duplicated, keeping first-occurrence order.
    """
    # Ordered (substrings, bucket) rules; earlier rules take precedence.
    keyword_rules = (
        (("statistic",), "Statistics"),
        (("quote", "testimony"), "Quotes/Testimony"),
        (("chart", "graph", "diagram"), "Chart / price graph / TA diagram"),
        (("none / assertion only", "assertion only"), "None / assertion only"),
    )

    mapped: List[str] = []
    for label in tokenize_multi(raw):
        text = label.strip()
        if not text:
            continue

        if text in EVID_BUCKETS:
            if text != "Link/URL":
                mapped.append(text)
            continue

        lowered = text.lower()
        # Any remaining link-like evidence is masked out.
        if "link/url" in lowered or "link" in lowered or "url" in lowered:
            continue

        bucket = next(
            (name for needles, name in keyword_rules if any(n in lowered for n in needles)),
            "Other (Evidence)",
        )
        mapped.append(bucket)

    # dict keys keep insertion order, so this drops repeats but preserves order.
    return list(dict.fromkeys(mapped))


def build_style_tokens_no_theme(row: pd.Series) -> List[str]:
    """Build the style-only feature tokens for one message row (no theme).

    Reads the ``claim_types``, ``ctas`` and ``evidence`` fields of ``row``,
    normalizes each, and returns ``claim=<label>`` tokens, then
    ``cta=<label>`` tokens, then ``evid=<label>`` tokens.
    """
    claim_tokens = [f"claim={label}" for label in _norm_claim_labels(row.get("claim_types"))]
    cta_tokens = [f"cta={label}" for label in _norm_cta_labels(row.get("ctas"))]
    evid_tokens = [f"evid={label}" for label in _norm_evidence_labels(row.get("evidence"))]
    return claim_tokens + cta_tokens + evid_tokens


def sweep_thresholds(y_true, proba, grid=None):
    """Pick the decision threshold that maximizes macro-F1.

    Args:
        y_true: binary labels (array-like of 0/1).
        proba: predicted positive-class probabilities (array-like).
        grid: candidate thresholds (any iterable, including a numpy array).
            ``None`` or an empty grid selects the default 0.05, 0.10, ..., 0.95.

    Returns:
        ``{"threshold": float, "macro_f1": float}`` for the best candidate.
        Ties keep the earliest (lowest-index) threshold in the grid.
    """
    # `grid or default` raised "truth value of an array is ambiguous" for numpy
    # grids; materialize first, then test emptiness explicitly.
    candidates = [] if grid is None else list(grid)
    if not candidates:
        candidates = [round(t, 2) for t in np.linspace(0.05, 0.95, 19)]

    # Accept plain lists as well as arrays for the element-wise comparison below.
    proba = np.asarray(proba)

    best = None
    for t in candidates:
        pred = (proba >= t).astype(int)
        macro_f1 = f1_score(y_true, pred, average="macro")
        if best is None or macro_f1 > best["macro_f1"]:
            best = {"threshold": float(t), "macro_f1": float(macro_f1)}
    return best


def fit_logreg_val_tuned(X_train, y_train, X_val, y_val, X_test, y_test, seed=0):
    """Fit a logistic regression, tune its threshold on val, and score on test.

    Steps:
      1. Fit on train and choose the macro-F1-optimal threshold on val.
      2. Refit an identically configured model on train+val.
      3. Score the refit model on test using the threshold from step 1.

    Returns:
        Dict with ``threshold``, ``macro_f1``, ``roc_auc``, ``accuracy``,
        ``brier`` and ``ece`` (the last two computed from test probabilities).
    """

    def _make_classifier():
        # Same configuration for the tuning fit and the final fit.
        return LogisticRegression(
            penalty="l2",
            C=1.0,
            solver="lbfgs",
            max_iter=1000,
            class_weight="balanced",
            random_state=seed,
        )

    # Step 1: threshold selection on the validation split.
    tuning_model = _make_classifier()
    tuning_model.fit(X_train, y_train)
    thr = sweep_thresholds(y_val, tuning_model.predict_proba(X_val)[:, 1])

    # Step 2: retrain on train+val (features stacked row-wise).
    final_model = _make_classifier()
    final_model.fit(np.vstack([X_train, X_val]), np.concatenate([y_train, y_val]))

    # Step 3: evaluate on test with the val-tuned threshold.
    test_proba = final_model.predict_proba(X_test)[:, 1]
    chosen = thr["threshold"]
    test_pred = (test_proba >= chosen).astype(int)

    report = {"threshold": chosen}
    report["macro_f1"] = float(f1_score(y_test, test_pred, average="macro"))
    report["roc_auc"] = float(roc_auc_score(y_test, test_proba))
    report["accuracy"] = float(accuracy_score(y_test, test_pred))
    report["brier"] = float(brier_score_loss(y_test, test_proba))
    report["ece"] = float(expected_calibration_error(y_test, test_proba, n_bins=10))
    return report


# Build clean style-only matrices for the v6 main split
train_tokens = df_train.apply(build_style_tokens_no_theme, axis=1).tolist()
val_tokens = df_val.apply(build_style_tokens_no_theme, axis=1).tolist()
test_tokens = df_test.apply(build_style_tokens_no_theme, axis=1).tolist()

# Vocabulary comes from train+val only; the column order is the sorted vocab.
# NOTE(review): tokens that occur only in test are outside `classes`, so they
# are presumably ignored (with a warning) by mlb.transform — confirm.
vocab = sorted(set(t for toks in (train_tokens + val_tokens) for t in toks))
mlb = MultiLabelBinarizer(classes=vocab)
# Multi-hot 0/1 matrices, stored as int8.
X_train_clean = mlb.fit_transform(train_tokens).astype(np.int8)
X_val_clean = mlb.transform(val_tokens).astype(np.int8)
X_test_clean = mlb.transform(test_tokens).astype(np.int8)

y_train = df_train["y"].astype(int).to_numpy()
y_val = df_val["y"].astype(int).to_numpy()
y_test = df_test["y"].astype(int).to_numpy()

print({"style_only_no_theme_vocab": int(len(vocab))})


def flip_noise(X: np.ndarray, noise_rate: float, rng: np.random.Generator) -> np.ndarray:
    """Flip each entry of a 0/1 matrix independently with probability ``noise_rate``.

    Args:
        X: multi-hot matrix (any dtype that casts to bool).
        noise_rate: per-entry flip probability. Values ``<= 0`` return ``X``
            itself (same object, no copy) without drawing from ``rng``.
        rng: numpy random generator used to draw the flip mask.

    Returns:
        An ``int8`` matrix of the same shape with the selected bits inverted.
    """
    if noise_rate <= 0:
        return X
    # One uniform draw per cell; cells below the rate get toggled.
    flip = rng.random(X.shape) < noise_rate
    noisy = X.astype(bool) ^ flip
    return noisy.astype(np.int8)


# Per-entry flip probabilities to evaluate, and the number of repeats per level.
noise_levels = [0.0, 0.05, 0.10, 0.20]
n_repeats = 5

robust_rows = []
for p in noise_levels:
    for rep in range(n_repeats):
        # The generator seed depends only on `rep`, so for a given repeat the
        # same uniform draws are reused at every noise level (a higher rate
        # flips a superset of the cells flipped at a lower rate).
        rng = np.random.default_rng(1000 + rep)
        # Noise is injected into train, val and test alike.
        Xtr = flip_noise(X_train_clean, p, rng)
        Xva = flip_noise(X_val_clean, p, rng)
        Xte = flip_noise(X_test_clean, p, rng)
        metrics = fit_logreg_val_tuned(Xtr, y_train, Xva, y_val, Xte, y_test, seed=rep)
        metrics.update({"noise_rate": float(p), "rep": int(rep)})
        robust_rows.append(metrics)

robust_df = pd.DataFrame(robust_rows)
robust_csv = RESULTS_DIR / "tagger_noise_robustness_style_only_no_theme.csv"
robust_df.to_csv(robust_csv, index=False)
print({"noise_robustness_csv": str(robust_csv)})

# Mean and standard deviation across repeats for each noise level.
# NOTE(review): at noise_rate 0.0 the inputs are identical across repeats, so
# the std there presumably reflects only the classifier seed — confirm.
robust_summary = robust_df.groupby("noise_rate")[["macro_f1", "roc_auc", "ece", "brier", "accuracy"]].agg([
    "mean",
    "std",
]).round(4)
display(robust_summary)


# Plot robustness curves
# Mean across repeats per noise level, with the std across repeats as error bars.
means = robust_df.groupby("noise_rate")[["macro_f1", "roc_auc", "ece"]].mean()
stds = robust_df.groupby("noise_rate")[["macro_f1", "roc_auc", "ece"]].std()

fig, ax = plt.subplots(figsize=(4.8, 3.0), dpi=300)
x = means.index.to_numpy(dtype=float)
# All three metrics share one y-axis ("Score").
ax.errorbar(x, means["macro_f1"], yerr=stds["macro_f1"], marker="^", label="Macro-F1", color="black")
ax.errorbar(x, means["roc_auc"], yerr=stds["roc_auc"], marker="s", label="ROC-AUC", color="#1f77b4")
ax.errorbar(x, means["ece"], yerr=stds["ece"], marker="o", label="ECE", color="#d62728")
ax.set_xlabel("Noise rate (bit flips)")
ax.set_ylabel("Score")
ax.set_title("Tagger-noise robustness (style-only, no Theme)")
ax.grid(alpha=0.25)
ax.legend(frameon=False)
fig.tight_layout()

noise_png = RESULTS_DIR / "tagger_noise_robustness_curve.png"
noise_pdf = RESULTS_DIR / "tagger_noise_robustness_curve.pdf"
fig.savefig(noise_png, dpi=300)
fig.savefig(noise_pdf)
# Close the figure so it is not also rendered inline / kept in memory.
plt.close(fig)
print({"noise_robustness_plot_png": str(noise_png)})
