# Per-Class Threshold Tuning for News Topic Classification

Optimizes per-label decision thresholds for the multi-label classifier `ContextNews/news-classifier`, using the `ContextNews/labelled_articles` dataset.

- **Tune** thresholds on the validation split
- **Evaluate** with frozen thresholds on the held-out test split

**Make sure to set Runtime > Change runtime type > T4 GPU**

In [None]:
# Install the notebook's dependencies quietly (-q). pip leaves packages that are
# already installed alone; note `accelerate` is not imported anywhere below.
!pip install -q datasets transformers accelerate scikit-learn torch matplotlib

In [None]:
# Authenticate with the Hugging Face Hub (interactive token prompt). Needed if
# MODEL_ID / DATASET_ID below are private or gated repos.
import huggingface_hub

huggingface_hub.login()

## 1. Setup and Load Model + Dataset

In [None]:
import json

import matplotlib.pyplot as plt
import numpy as np
import torch
from datasets import load_dataset
from sklearn.metrics import f1_score, precision_score, recall_score
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# Hugging Face Hub identifiers for the fine-tuned model and its labelled dataset.
MODEL_ID = "ContextNews/news-classifier"
DATASET_ID = "ContextNews/labelled_articles"

# Label names. The rest of the notebook assumes column i of the model output
# (and of each row's "labels" vector) corresponds to TOPICS[i].
TOPICS = [
    "politics",
    "geopolitics",
    "conflict",
    "crime",
    "law",
    "business",
    "economy",
    "markets",
    "technology",
    "science",
    "health",
    "environment",
    "society",
    "education",
    "sports",
    "entertainment",
]

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Device: {device}")

In [None]:
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_ID).to(device)
model.eval()  # inference only: disables dropout

print(f"Model labels: {model.config.id2label}")

# Every later cell treats output column i as TOPICS[i]: labels are built in
# TOPICS order and thresholds are applied per column. Fail fast if the
# checkpoint disagrees, rather than silently scoring the wrong labels.
_model_labels = [
    model.config.id2label.get(i, f"LABEL_{i}") for i in range(model.config.num_labels)
]
if len(_model_labels) != len(TOPICS):
    raise ValueError(
        f"Model has {len(_model_labels)} output labels but TOPICS lists {len(TOPICS)}."
    )
if _model_labels != TOPICS:
    if sorted(_model_labels) == sorted(TOPICS):
        raise ValueError(
            f"Model label order {_model_labels} differs from TOPICS {TOPICS}; "
            "reorder TOPICS to match the model."
        )
    # Names don't match at all (e.g. generic LABEL_i names), so the order cannot
    # be verified; keep going, but make the assumption visible.
    print("WARNING: model id2label names differ from TOPICS; assuming TOPICS is in model column order.")

In [None]:
# Validation is used to tune thresholds; test is only used to evaluate them.
val_ds, test_ds = (
    load_dataset(DATASET_ID, split=split_name) for split_name in ("validation", "test")
)
for caption, split_ds in (("Validation set:", val_ds), ("Test set:", test_ds)):
    # Left-pad captions to 16 columns so the row counts line up.
    print(f"{caption:<16}{len(split_ds)} rows")

## 2. Preprocessing and Inference Helpers

In [None]:
def build_input_text(row, max_words=300):
    """Concatenate an article's title, summary and leading body text.

    Missing or empty fields (None or "") are skipped, so no stray separators
    appear. Whitespace in the body text is collapsed, and only the first
    ``max_words`` words are kept.

    Args:
        row: Mapping with optional ``"title"``, ``"summary"`` and ``"text"`` keys.
        max_words: Maximum number of whitespace-separated words kept from
            ``"text"`` (default 300, the previously hard-coded limit).

    Returns:
        The non-empty parts joined by single spaces ("" if all are empty).

    Raises:
        ValueError: if ``max_words`` is negative. A negative slice bound would
            silently drop words from the end instead.
    """
    if max_words < 0:
        raise ValueError(f"max_words must be >= 0, got {max_words}")
    title = row.get("title") or ""
    summary = row.get("summary") or ""
    text = row.get("text") or ""
    # str.split() with no argument collapses runs of spaces/newlines.
    text_excerpt = " ".join(text.split()[:max_words])
    return " ".join(part for part in (title, summary, text_excerpt) if part)


def preprocess(row):
    """Attach the model input string and the multi-hot target vector to ``row``.

    Intended for ``Dataset.map``: ``row`` is modified in place and returned.
    ``row["input_text"]`` comes from ``build_input_text``. ``row["labels"]``
    holds one float per entry of TOPICS, in TOPICS order; None or other falsy
    values become 0.0.
    """
    row["input_text"] = build_input_text(row)
    targets = []
    for topic in TOPICS:
        value = row[topic]
        targets.append(float(value if value else 0))
    row["labels"] = targets
    return row


def run_inference(ds, batch_size=32, max_length=512, log_every=10):
    """Run batched multi-label inference; return ``(y_true, y_probs)`` numpy arrays.

    Uses the notebook-level ``tokenizer``, ``model`` and ``device``. ``ds`` must
    already carry the ``"input_text"`` and ``"labels"`` columns added by
    ``preprocess``. Slicing ``ds[a:b]`` must return a dict of column lists, as
    a Hugging Face ``datasets.Dataset`` does.

    Args:
        ds: Preprocessed dataset to score.
        batch_size: Rows per forward pass (>= 1).
        max_length: Token limit; longer inputs are truncated (default 512, the
            previously hard-coded value).
        log_every: Print a progress line every this many batches (>= 1). The
            final batch is always reported.

    Returns:
        ``y_true``: ``np.array`` of the ``"labels"`` column, one row per example.
        ``y_probs``: per-label sigmoid probabilities, with the same number of rows.

    Raises:
        ValueError: if ``batch_size`` or ``log_every`` is less than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    if log_every < 1:
        raise ValueError(f"log_every must be >= 1, got {log_every}")

    n_rows = len(ds)
    n_batches = -(-n_rows // batch_size)  # ceiling division
    all_probs = []
    all_labels = []

    for batch_idx, start in enumerate(range(0, n_rows, batch_size)):
        batch = ds[start : start + batch_size]

        # padding=True pads only to the longest sequence in this batch (capped at
        # max_length) instead of always to max_length. Padded positions are
        # masked out, so outputs should match up to float rounding, and short
        # batches run much faster.
        encodings = tokenizer(
            batch["input_text"],
            truncation=True,
            padding=True,
            max_length=max_length,
            return_tensors="pt",
        ).to(device)

        with torch.inference_mode():
            logits = model(**encodings).logits
        # Independent sigmoid per label (multi-label), not a softmax.
        # .float() guards the NumPy conversion: NumPy has no bfloat16 dtype.
        probs = torch.sigmoid(logits).float().cpu().numpy()

        all_probs.append(probs)
        all_labels.append(np.array(batch["labels"]))

        # Also report the last batch so the log ends at n_rows/n_rows.
        if batch_idx % log_every == 0 or batch_idx == n_batches - 1:
            done = min(start + batch_size, n_rows)
            print(f"\r  Processed {done}/{n_rows}", end="", flush=True)

    print()
    if not all_probs:
        # np.concatenate([]) raises ValueError; return empty, correctly shaped arrays.
        n_labels = model.config.num_labels
        return np.empty((0, n_labels)), np.empty((0, n_labels), dtype=np.float32)
    return np.concatenate(all_labels, axis=0), np.concatenate(all_probs, axis=0)

## 3. Run Inference on Validation Set

In [None]:
# Add "input_text"/"labels" columns, then score the whole validation split.
val_ds = val_ds.map(preprocess)
val_true, val_probs = run_inference(val_ds)
for shape_caption, shape_arr in (("val_true shape:", val_true), ("val_probs shape:", val_probs)):
    # Pad captions to 16 columns so both shapes line up.
    print(f"{shape_caption:<16} {shape_arr.shape}")

## 4. Baseline Metrics on Validation (threshold = 0.5)

In [None]:
# Baseline: a fixed 0.5 cut-off applied to every label.
val_pred_baseline = (val_probs >= 0.5).astype(int)

# metric name -> (sklearn scorer, averaging mode). "precision" and "recall" are
# micro-averaged, i.e. pooled over every (example, label) decision.
_baseline_scorers = {
    "f1_micro": (f1_score, "micro"),
    "f1_macro": (f1_score, "macro"),
    "precision": (precision_score, "micro"),
    "recall": (recall_score, "micro"),
}
val_baseline_metrics = {
    name: scorer(val_true, val_pred_baseline, average=avg, zero_division=0)
    for name, (scorer, avg) in _baseline_scorers.items()
}

print("Validation baseline metrics (threshold = 0.5):")
for name, value in val_baseline_metrics.items():
    print(f"  {name}: {value:.4f}")

print("\nPer-class F1 (validation baseline):")
# average=None returns one binary F1 per label column, in column order (which
# is TOPICS order, since labels are built that way). These are the same numbers
# as scoring each column separately.
per_label_f1 = f1_score(val_true, val_pred_baseline, average=None, zero_division=0)
val_baseline_per_class = {}
for topic, score in zip(TOPICS, per_label_f1):
    val_baseline_per_class[topic] = float(score)
    print(f"  {topic:15s} {val_baseline_per_class[topic]:.4f}")

## 5. Per-Class Threshold Sweep (on Validation)

In [None]:
# Candidate thresholds 0.05, 0.06, ..., 0.95. Dividing integers by 100 gives
# exactly the float each two-decimal literal denotes. np.arange with a float
# step can be a few ULPs off (e.g. 0.35000000000000003). The threshold evaluated
# here is therefore the exact value that is saved and applied on test.
thresholds_range = np.arange(5, 96) / 100

optimal_thresholds = {}
f1_curves = {}  # topic -> list of (threshold, f1) pairs over thresholds_range

for i, topic in enumerate(TOPICS):
    # Fallback when no threshold beats F1 = 0 (e.g. a label with no positives
    # in validation, where every F1 is 0 under zero_division=0).
    best_t = 0.5
    best_f1 = 0.0
    curve = []

    for t in thresholds_range:
        t = float(t)
        preds = (val_probs[:, i] >= t).astype(int)
        f1 = f1_score(val_true[:, i], preds, zero_division=0)
        curve.append((t, f1))
        # Strict ">" keeps the lowest threshold among equally good ones.
        if f1 > best_f1:
            best_f1 = f1
            best_t = t

    optimal_thresholds[topic] = best_t
    f1_curves[topic] = curve
    print(f"  {topic:15s}  threshold={best_t:.2f}  F1={best_f1:.4f}  (baseline F1={val_baseline_per_class[topic]:.4f})")

print("\nOptimal thresholds:")
print(json.dumps(optimal_thresholds, indent=2))

## 6. Evaluate on Held-Out Test Set

Thresholds were tuned on the validation split, so validation scores with those thresholds are optimistically biased. To get an unbiased estimate, we freeze them and evaluate on the held-out **test** split.

In [None]:
# Same preprocessing as validation, then score the held-out test split.
test_ds = test_ds.map(preprocess)
test_true, test_probs = run_inference(test_ds)
for shape_caption, shape_arr in (("test_true shape:", test_true), ("test_probs shape:", test_probs)):
    # Pad captions to 17 columns so both shapes line up.
    print(f"{shape_caption:<17} {shape_arr.shape}")

In [None]:
def _multilabel_metrics(y_true, y_pred):
    """Score a multi-label indicator prediction matrix against the ground truth.

    Returns micro- and macro-averaged F1 plus micro-averaged precision and
    recall, under the keys f1_micro, f1_macro, precision and recall.
    zero_division=0 makes undefined 0/0 ratios count as 0.
    """
    return {
        "f1_micro": f1_score(y_true, y_pred, average="micro", zero_division=0),
        "f1_macro": f1_score(y_true, y_pred, average="macro", zero_division=0),
        "precision": precision_score(y_true, y_pred, average="micro", zero_division=0),
        "recall": recall_score(y_true, y_pred, average="micro", zero_division=0),
    }


# Baseline on test (threshold = 0.5)
test_pred_baseline = (test_probs >= 0.5).astype(int)
test_baseline_metrics = _multilabel_metrics(test_true, test_pred_baseline)

# Tuned on test (frozen thresholds from validation). The vector is built in
# TOPICS order, so broadcasting compares column i against TOPICS[i]'s threshold.
threshold_vec = np.array([optimal_thresholds[topic] for topic in TOPICS])
test_pred_tuned = (test_probs >= threshold_vec).astype(int)
test_tuned_metrics = _multilabel_metrics(test_true, test_pred_tuned)

print("Test set — Baseline (0.5) vs Tuned (frozen from validation):")
print(f"{'metric':>12s}  {'baseline':>8s}  {'tuned':>8s}  {'delta':>8s}")
print(f"{'—'*12}  {'—'*8}  {'—'*8}  {'—'*8}")
for k in ["f1_micro", "f1_macro", "precision", "recall"]:
    b = test_baseline_metrics[k]
    t = test_tuned_metrics[k]
    print(f"{k:>12s}  {b:8.4f}  {t:8.4f}  {t - b:+8.4f}")

print("\n" + "=" * 50)
print(f"  Baseline macro F1 (test): {test_baseline_metrics['f1_macro']:.4f}")
print(f"  Tuned macro F1 (test):    {test_tuned_metrics['f1_macro']:.4f}")
improvement = test_tuned_metrics["f1_macro"] - test_baseline_metrics["f1_macro"]
print(f"  Improvement:              {improvement:+.4f}")
print("=" * 50)

## 7. Save Thresholds to Disk

In [None]:
OUTPUT_PATH = "news_classifier_thresholds.json"

# Persist the validation-tuned thresholds (topic -> threshold) as indented JSON.
payload = json.dumps(optimal_thresholds, indent=2)
with open(OUTPUT_PATH, "w") as out_file:
    out_file.write(payload)

print(f"Thresholds saved to {OUTPUT_PATH}")

## 8. F1 vs Threshold Curves (3 Worst Performing Labels)

Curves come from the validation sweep (where the thresholds were tuned). "Worst" means the lowest per-label validation F1 at the default 0.5 threshold.

In [None]:
# Three labels with the lowest baseline (0.5) validation F1. The sort is
# stable, so ties keep TOPICS order.
worst_3 = sorted(val_baseline_per_class, key=val_baseline_per_class.get)[:3]

fig, axes = plt.subplots(1, 3, figsize=(15, 4))

for ax, topic in zip(axes, worst_3):
    sweep_ts, sweep_f1s = zip(*f1_curves[topic])
    chosen_t = optimal_thresholds[topic]
    peak_f1 = max(sweep_f1s)

    ax.plot(sweep_ts, sweep_f1s, linewidth=1.5)
    # Dashed references: the default cut-off vs the tuned one.
    ax.axvline(x=0.5, color="gray", linestyle="--", alpha=0.5, label="default (0.5)")
    ax.axvline(x=chosen_t, color="red", linestyle="--", alpha=0.7, label=f"optimal ({chosen_t:.2f})")
    ax.scatter([chosen_t], [peak_f1], color="red", zorder=5)
    ax.set_title(topic)
    ax.set_xlabel("Threshold")
    ax.set_ylabel("F1")
    ax.set_ylim(0, 1)
    ax.legend(fontsize=8)

plt.suptitle("F1 vs Threshold - 3 Worst Performing Labels (validation)", y=1.02)
plt.tight_layout()
plt.show()