# Zero-Shot Classification: German News Articles
## Model: MoritzLaurer/mDeBERTa-v3-base-mnli-xnli

This notebook evaluates zero-shot classification performance on the test split (617 articles)
of the `Zorryy/news_articles_2025_elections_germany` dataset.

**Two experiments:**
1. Classification using **headline** only
2. Classification using **full article text** (truncated so that each article–hypothesis pair fits the model's 512-token input)

**Primary metric:** F1 macro score across 13 German topic categories.

In [None]:
!pip install -q transformers datasets huggingface_hub scikit-learn matplotlib seaborn tqdm pandas

In [None]:
import os
import time
import warnings
from getpass import getpass

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
from datasets import load_dataset
from huggingface_hub import login
from sklearn.metrics import (
    classification_report,
    confusion_matrix,
    f1_score,
    precision_score,
    recall_score,
)
from tqdm.auto import tqdm
from transformers import pipeline

# Silence all Python warnings to keep the notebook output compact.
# NOTE(review): this also hides every warning that numpy/sklearn/transformers emit through the
# `warnings` module in later cells; comment this line out when debugging unexpected results.
warnings.filterwarnings("ignore")

In [None]:
# Resolve the Hugging Face token in order of preference:
#   1. the "HF_TOKEN" Colab secret,
#   2. the HF_TOKEN environment variable,
#   3. an interactive prompt that does not echo the token to the notebook.
try:
    # google.colab only exists inside Colab; importing it inside the try block lets the
    # fallback run in any other Jupyter environment instead of crashing on ImportError.
    from google.colab import userdata

    hf_token = userdata.get("HF_TOKEN")
    print("Token loaded from Colab secrets.")
except Exception:
    # Deliberate best-effort: any failure to read the Colab secret falls through to the
    # environment variable, then to a hidden interactive prompt.
    hf_token = os.environ.get("HF_TOKEN") or getpass("Enter your HuggingFace token: ")

login(token=hf_token)
print("Authenticated with HuggingFace.")

In [None]:
DATASET_ID = "Zorryy/news_articles_2025_elections_germany"

# Only the test split is evaluated; the token is needed for authenticated access.
ds = load_dataset(DATASET_ID, split="test", token=hf_token)

print(f"Test split loaded: {len(ds)} articles")
print(f"Columns: {ds.column_names}")
print("\nLabel distribution:")
label_counts = pd.Series(ds["label"]).value_counts()
print(label_counts.to_string())

In [None]:
df = ds.to_pandas()

print(f"Total articles: {len(df)}")

# Character-length summaries for the two inputs used by the experiments below.
for column_name, pretty_name in (("text", "Text"), ("headline", "Headline")):
    print(f"\n{pretty_name} length statistics (characters):")
    print(df[column_name].str.len().describe().to_string())

print(f"\nMissing headlines: {df['headline'].isna().sum()}")
print(f"Missing text: {df['text'].isna().sum()}")

In [None]:
# Topic labels passed verbatim to the model. They must match the dataset's `label` values
# exactly, because predictions are compared to gold labels by string equality later on.
CANDIDATE_LABELS = [
    "Klima / Energie",
    "Zuwanderung",
    "Renten",
    "Soziales Gef\u00e4lle",
    "AfD/Rechte",
    "Arbeitslosigkeit",
    "Wirtschaftslage",
    "Politikverdruss",
    "Gesundheitswesen, Pflege",
    "Kosten/L\u00f6hne/Preise",
    "Ukraine/Krieg/Russland",
    "Bundeswehr/Verteidigung",
    "Andere",
]

# German NLI hypothesis; "{}" is replaced by each candidate label.
HYPOTHESIS_TEMPLATE = "Dieser Text handelt von {}."

print(f"Number of candidate labels: {len(CANDIDATE_LABELS)}")
print(f"Hypothesis template: '{HYPOTHESIS_TEMPLATE}'")
print("\nExample hypotheses:")
example_hypotheses = [HYPOTHESIS_TEMPLATE.format(name) for name in CANDIDATE_LABELS[:3]]
for hypothesis in example_hypotheses:
    print(f"  '{hypothesis}'")

In [None]:
MODEL_ID = "MoritzLaurer/mDeBERTa-v3-base-mnli-xnli"
# Maximum length (in tokens) of one premise/hypothesis pair fed to the model.
MAX_SEQ_LENGTH = 512

device = 0 if torch.cuda.is_available() else -1
device_name = "GPU" if device == 0 else "CPU"
print(f"Using device: {device_name}")

# mDeBERTa does NOT support FP16 (produces NaN). Always use FP32.
# Pass the dtype explicitly so a library default change cannot silently switch precision.
classifier = pipeline(
    "zero-shot-classification",
    model=MODEL_ID,
    device=device,
    torch_dtype=torch.float32,
)

# The zero-shot pipeline truncates the article (premise) to tokenizer.model_max_length.
# If a tokenizer config leaves that unset (a huge sentinel value), transformers disables
# truncation altogether. Cap it so the truncation this notebook documents actually happens.
if classifier.tokenizer.model_max_length > MAX_SEQ_LENGTH:
    classifier.tokenizer.model_max_length = MAX_SEQ_LENGTH

print(f"Model loaded: {MODEL_ID}")
print(f"Tokenizer max length: {classifier.tokenizer.model_max_length}")

## Experiment 1: Classification Using Headlines Only

Headlines are short (typically <20 tokens), well within the 512 token limit.
No truncation needed.

In [None]:
def classify_batch(
    texts,
    classifier,
    candidate_labels,
    hypothesis_template,
    batch_size=16,
    default_label="Andere",
):
    """Run zero-shot classification over ``texts`` and return the top label for each one.

    Texts that are empty, whitespace-only, or not strings (e.g. ``None`` or NaN) are not
    sent to the model; they are assigned ``default_label`` instead.

    Args:
        texts: Sequence of input texts, one per article.
        classifier: A ``transformers`` zero-shot-classification pipeline.
        candidate_labels: Labels scored against every text.
        hypothesis_template: NLI hypothesis with one ``{}`` placeholder for the label.
        batch_size: Number of texts per progress-bar step. It is also forwarded to the
            pipeline so forward passes are batched instead of run one pair at a time.
        default_label: Label assigned to texts that were skipped.

    Returns:
        List of predicted labels, aligned index-by-index with ``texts``.
    """
    predictions = [None] * len(texts)
    non_empty_indices = [
        i for i, t in enumerate(texts) if isinstance(t, str) and t.strip()
    ]
    non_empty_texts = [texts[i] for i in non_empty_indices]

    for start in tqdm(range(0, len(non_empty_texts), batch_size), desc="Classifying"):
        batch = non_empty_texts[start : start + batch_size]
        batch_indices = non_empty_indices[start : start + batch_size]
        results = classifier(
            batch,
            candidate_labels=candidate_labels,
            hypothesis_template=hypothesis_template,
            multi_label=False,
            # Pipelines default to batch_size=1; forward it so the model sees real batches.
            batch_size=batch_size,
        )
        # A single input produces one result dict rather than a list of dicts.
        if isinstance(results, dict):
            results = [results]
        for idx, result in zip(batch_indices, results):
            # The pipeline returns labels sorted by descending score; index 0 is the top one.
            predictions[idx] = result["labels"][0]

    # Fill skipped (empty / non-string) texts with the default label.
    empty_count = sum(1 for p in predictions if p is None)
    if empty_count > 0:
        print(f"  {empty_count} empty texts -> defaulting to '{default_label}'")
    return [p if p is not None else default_label for p in predictions]


# Missing headlines become "" so classify_batch assigns them the default label.
headlines = df["headline"].fillna("").tolist()

headline_start = time.time()
headline_predictions = classify_batch(
    headlines,
    classifier,
    CANDIDATE_LABELS,
    HYPOTHESIS_TEMPLATE,
    batch_size=16,
)
headline_time = time.time() - headline_start

df["pred_headline"] = headline_predictions
print(f"\nHeadline classification completed in {headline_time:.1f} seconds")
print(f"({headline_time / len(headlines):.2f} sec/article)")

## Experiment 2: Classification Using Full Article Text

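mDeBERTa is used here with a 512-token input limit, and many articles are longer (see the token statistics below).
The zero-shot pipeline encodes each article together with one hypothesis per label and truncates only the article
(provided the tokenizer's `model_max_length` is 512, as printed when the model is loaded). Slightly fewer than 512
article tokens are kept, because the hypothesis and special tokens use the rest of the window.
These leading tokens typically cover the headline, lede, and first several paragraphs —
the most information-dense portion of a news article (inverted pyramid structure).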

In [None]:
tokenizer = classifier.tokenizer
# Total length of one premise/hypothesis pair after truncation.
MAX_PAIR_TOKENS = 512

token_lengths = [
    len(tokenizer.encode(text, add_special_tokens=False))
    for text in tqdm(df["text"].fillna("").tolist(), desc="Tokenizing")
]
df["token_count"] = token_lengths

# The pipeline encodes each article together with one hypothesis per label and truncates only
# the article. The article therefore gets the 512-token window minus the hypothesis and the
# special tokens. Using the longest hypothesis gives the article length above which truncation
# happens for at least one label.
# NOTE(review): assumes tokenizer.model_max_length == 512 (printed when the model is loaded).
longest_hypothesis_tokens = max(
    len(tokenizer.encode(HYPOTHESIS_TEMPLATE.format(name), add_special_tokens=False))
    for name in CANDIDATE_LABELS
)
premise_budget = (
    MAX_PAIR_TOKENS
    - tokenizer.num_special_tokens_to_add(pair=True)
    - longest_hypothesis_tokens
)

n_articles = len(token_lengths)
n_over_limit = sum(t > MAX_PAIR_TOKENS for t in token_lengths)
n_truncated = sum(t > premise_budget for t in token_lengths)
pct_truncated = n_truncated / n_articles * 100 if n_articles else 0.0

print("Token length statistics:")
print(pd.Series(token_lengths).describe().to_string())
print(f"\nArticles exceeding {MAX_PAIR_TOKENS} tokens: {n_over_limit} / {n_articles}")
print(f"Article token budget per premise/hypothesis pair: {premise_budget}")
print(f"Articles truncated (> {premise_budget} tokens): {n_truncated} / {n_articles}")
print(f"Percentage truncated: {pct_truncated:.1f}%")

fig, ax = plt.subplots(figsize=(10, 4))
ax.hist(token_lengths, bins=50, edgecolor="black", alpha=0.7)
ax.axvline(x=MAX_PAIR_TOKENS, color="red", linestyle="--", linewidth=2, label="512 token limit")
ax.axvline(
    x=premise_budget, color="orange", linestyle=":", linewidth=2,
    label=f"Article budget ({premise_budget} tokens)",
)
ax.set_xlabel("Token count")
ax.set_ylabel("Number of articles")
ax.set_title("Article Token Length Distribution (Test Set)")
ax.legend()
plt.tight_layout()
plt.show()

In [None]:
# Missing article bodies become "" so classify_batch assigns them the default label.
texts = df["text"].fillna("").tolist()

text_start = time.time()
text_predictions = classify_batch(
    texts,
    classifier,
    CANDIDATE_LABELS,
    HYPOTHESIS_TEMPLATE,
    batch_size=8,  # smaller than for headlines: inputs are up to 512 tokens long
)
text_time = time.time() - text_start

df["pred_text"] = text_predictions
print(f"\nFull-text classification completed in {text_time:.1f} seconds")
print(f"({text_time / len(texts):.2f} sec/article)")

## Evaluation Metrics

Computing per-class and aggregate metrics for both experiments:
- **F1 macro** (primary metric — the unweighted mean of per-class F1, so rare classes count as much as frequent ones)
- **F1 per class**
- **Precision and Recall** (macro and per-class)
- **Confusion matrices**

In [None]:
def compute_all_metrics(y_true, y_pred, labels, experiment_name):
    """Print and return aggregate and per-class classification scores for one experiment.

    Every metric is restricted to ``labels`` and uses ``zero_division=0``, so ill-defined
    precision/recall (e.g. a class that is never predicted) is reported as 0.

    Args:
        y_true: Gold labels.
        y_pred: Predicted labels, aligned with ``y_true``.
        labels: Label set (and row order of the per-class table) used by every metric.
        experiment_name: Heading printed above the results.

    Returns:
        dict with float entries ``f1_macro``, ``f1_weighted``, ``precision_macro``,
        ``recall_macro`` and a ``per_class`` DataFrame (Label, Precision, Recall, F1, Support).
    """
    rule = "=" * 70
    print(rule)
    print(f"  {experiment_name}")
    print(rule)

    common = {"labels": labels, "zero_division": 0}
    f1_macro = f1_score(y_true, y_pred, average="macro", **common)
    f1_weighted = f1_score(y_true, y_pred, average="weighted", **common)
    precision_macro = precision_score(y_true, y_pred, average="macro", **common)
    recall_macro = recall_score(y_true, y_pred, average="macro", **common)

    print(f"\n  F1 Macro:        {f1_macro:.4f}")
    print(f"  F1 Weighted:     {f1_weighted:.4f}")
    print(f"  Precision Macro: {precision_macro:.4f}")
    print(f"  Recall Macro:    {recall_macro:.4f}")

    report = classification_report(y_true, y_pred, output_dict=True, **common)

    per_class = pd.DataFrame(
        [
            {
                "Label": name,
                "Precision": report[name]["precision"],
                "Recall": report[name]["recall"],
                "F1": report[name]["f1-score"],
                "Support": report[name]["support"],
            }
            for name in labels
        ],
        columns=["Label", "Precision", "Recall", "F1", "Support"],
    )

    print("\n  Per-Class Metrics:")
    print(per_class.to_string(index=False))

    return dict(
        f1_macro=f1_macro,
        f1_weighted=f1_weighted,
        precision_macro=precision_macro,
        recall_macro=recall_macro,
        per_class=per_class,
    )

In [None]:
true_labels = df["label"].tolist()

# All metrics are computed with labels=CANDIDATE_LABELS, so gold labels spelled differently
# (or stored as integer ids) would silently fall out of the per-class scores, and every
# prediction would count as wrong. Fail loudly instead of reporting misleading numbers.
unknown_labels = sorted(set(true_labels) - set(CANDIDATE_LABELS), key=str)
assert not unknown_labels, f"Gold labels missing from CANDIDATE_LABELS: {unknown_labels}"

headline_metrics = compute_all_metrics(
    true_labels,
    df["pred_headline"].tolist(),
    CANDIDATE_LABELS,
    "Experiment 1: Headline-Based Zero-Shot Classification",
)

In [None]:
# Same evaluation as Experiment 1, applied to the full-text predictions.
text_metrics = compute_all_metrics(
    y_true=true_labels,
    y_pred=df["pred_text"].tolist(),
    labels=CANDIDATE_LABELS,
    experiment_name="Experiment 2: Full-Text Zero-Shot Classification",
)

In [None]:
def normalize_confusion_rows(cm):
    """Row-normalize a confusion matrix so each non-empty row sums to 1.

    Args:
        cm: 2-D array-like of counts (rows = true class, columns = predicted class).

    Returns:
        Float ndarray of the same shape. Rows with no true samples (row sum 0) are
        returned as all zeros instead of NaN from a 0/0 division.
    """
    counts = np.asarray(cm, dtype=float)
    row_sums = counts.sum(axis=1, keepdims=True)
    return np.divide(counts, row_sums, out=np.zeros_like(counts), where=row_sums != 0)


def plot_confusion_matrix(y_true, y_pred, labels, title, max_label_chars=20):
    """Plot raw and row-normalized confusion matrices side by side.

    Args:
        y_true: Gold labels.
        y_pred: Predicted labels, aligned with ``y_true``.
        labels: Labels indexing the matrix rows/columns, in display order.
        title: Prefix for both subplot titles.
        max_label_chars: Tick labels are cut to this many characters to keep axes readable.

    Returns:
        The raw count matrix (rows = true class, columns = predicted class).
    """
    cm = confusion_matrix(y_true, y_pred, labels=labels)
    cm_normalized = normalize_confusion_rows(cm)

    short_labels = [name[:max_label_chars] for name in labels]

    fig, axes = plt.subplots(1, 2, figsize=(28, 11))
    panels = [
        (axes[0], cm, "d", "Counts"),
        (axes[1], cm_normalized, ".2f", "Normalized"),
    ]
    for ax, matrix, fmt, suffix in panels:
        sns.heatmap(
            matrix, annot=True, fmt=fmt, cmap="Blues",
            xticklabels=short_labels, yticklabels=short_labels,
            ax=ax,
        )
        ax.set_xlabel("Predicted")
        ax.set_ylabel("True")
        ax.set_title(f"{title} ({suffix})")
        # Right-align rotated labels so each one ends under its own column.
        plt.setp(ax.get_xticklabels(), rotation=45, ha="right")
        ax.tick_params(axis="y", rotation=0)

    plt.tight_layout()
    plt.show()

    return cm


# Confusion matrices for Experiment 1 (headline input).
headline_cm = plot_confusion_matrix(
    y_true=true_labels,
    y_pred=df["pred_headline"].tolist(),
    labels=CANDIDATE_LABELS,
    title="Headline Zero-Shot",
)

In [None]:
# Confusion matrices for Experiment 2 (full-text input).
text_cm = plot_confusion_matrix(
    y_true=true_labels,
    y_pred=df["pred_text"].tolist(),
    labels=CANDIDATE_LABELS,
    title="Full-Text Zero-Shot",
)

## Comparison: Headline vs. Full Text

In [None]:
# (column prefix, per-class column, macro key) for each metric in the side-by-side table.
metric_specs = [
    ("F1", "F1", "f1_macro"),
    ("Prec", "Precision", "precision_macro"),
    ("Rec", "Recall", "recall_macro"),
]

table_columns = {"Label": CANDIDATE_LABELS}
macro_row = {"Label": "MACRO AVERAGE"}
for prefix, per_class_col, macro_key in metric_specs:
    table_columns[f"{prefix} (Headline)"] = headline_metrics["per_class"][per_class_col].values
    table_columns[f"{prefix} (Text)"] = text_metrics["per_class"][per_class_col].values
    macro_row[f"{prefix} (Headline)"] = headline_metrics[macro_key]
    macro_row[f"{prefix} (Text)"] = text_metrics[macro_key]

table_columns["Support"] = headline_metrics["per_class"]["Support"].values.astype(int)
comparison = pd.DataFrame(table_columns)
comparison["F1 Delta"] = comparison["F1 (Text)"] - comparison["F1 (Headline)"]

macro_row["Support"] = len(df)
macro_row["F1 Delta"] = text_metrics["f1_macro"] - headline_metrics["f1_macro"]

comparison = pd.concat([comparison, pd.DataFrame([macro_row])], ignore_index=True)

print("Comparison: Headline vs. Full-Text Zero-Shot Classification")
print("=" * 100)
print(comparison.to_string(index=False, float_format="%.4f"))

In [None]:
fig, ax = plt.subplots(figsize=(14, 7))

positions = np.arange(len(CANDIDATE_LABELS))
bar_width = 0.35

# (offset direction, metrics dict, legend name, bar colour) for each experiment.
experiments = [
    (-1, headline_metrics, "Headline", "#4C72B0"),
    (1, text_metrics, "Full Text", "#DD8452"),
]
for direction, metrics, legend_name, colour in experiments:
    ax.bar(
        positions + direction * bar_width / 2,
        metrics["per_class"]["F1"].values,
        bar_width,
        label=f'{legend_name} (F1 macro={metrics["f1_macro"]:.3f})',
        color=colour,
        alpha=0.85,
    )

plot_title = (
    "Per-Class F1: Headline vs. Full-Text Zero-Shot Classification\n"
    f"Model: {MODEL_ID}"
)
ax.set_xlabel("Category")
ax.set_ylabel("F1 Score")
ax.set_title(plot_title)
ax.set_xticks(positions)
ax.set_xticklabels([name[:18] for name in CANDIDATE_LABELS], rotation=45, ha="right")
ax.legend(loc="upper right")
ax.set_ylim(0, 1.0)
ax.grid(axis="y", alpha=0.3)

plt.tight_layout()
plt.show()

In [None]:
result_columns = ["id", "domain", "headline", "label", "pred_headline", "pred_text", "token_count"]
output_df = df[result_columns].rename(columns={"label": "true_label"})
output_df = output_df.assign(
    headline_correct=output_df["true_label"] == output_df["pred_headline"],
    text_correct=output_df["true_label"] == output_df["pred_text"],
)

output_path = "zero_shot_results_mdeberta.csv"
output_df.to_csv(output_path, index=False, encoding="utf-8")
print(f"Results saved to: {output_path}")
print(f"Total rows: {len(output_df)}")

# Cross-tabulate which input (headline vs. full text) got each article right.
hl_ok = output_df["headline_correct"]
tx_ok = output_df["text_correct"]
n_rows = len(output_df)
agreement = [
    ("Both correct:", (hl_ok & tx_ok).sum()),
    ("Headline only:", (hl_ok & ~tx_ok).sum()),
    ("Text only:", (~hl_ok & tx_ok).sum()),
    ("Neither correct:", (~hl_ok & ~tx_ok).sum()),
]

print("\nPrediction agreement:")
for caption, count in agreement:
    print(f"  {caption:<20}{count} ({count/n_rows*100:.1f}%)")

In [None]:
rule = "=" * 70
# Counts articles longer than 512 tokens. The actual truncation threshold is slightly lower,
# because each hypothesis shares the 512-token window with the article.
n_long_articles = sum(1 for t in token_lengths if t > 512)

summary_lines = [
    rule,
    "  FINAL SUMMARY",
    rule,
    f"\n  Model:              {MODEL_ID}",
    f"  Test set size:      {len(df)} articles",
    f"  Candidate labels:   {len(CANDIDATE_LABELS)}",
    f"  Hypothesis:         '{HYPOTHESIS_TEMPLATE}'",
    "",
    "  HEADLINE-based:",
    f"    F1 Macro:         {headline_metrics['f1_macro']:.4f}",
    f"    F1 Weighted:      {headline_metrics['f1_weighted']:.4f}",
    f"    Runtime:          {headline_time:.1f}s",
    "",
    "  FULL-TEXT-based:",
    f"    F1 Macro:         {text_metrics['f1_macro']:.4f}",
    f"    F1 Weighted:      {text_metrics['f1_weighted']:.4f}",
    f"    Runtime:          {text_time:.1f}s",
    "",
    f"  Articles truncated: {n_long_articles} / {len(token_lengths)}",
    rule,
]
print("\n".join(summary_lines))