# Fine-tune DistilBERT for News Topic Classification

Multi-label classification using `ContextNews/labelled_articles`.

**Make sure to set Runtime > Change runtime type > T4 GPU**

In [None]:
!pip install -q datasets transformers accelerate scikit-learn torch

In [None]:
from huggingface_hub import login
login()

In [None]:
import torch
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"Device: {torch.cuda.get_device_name(0)}")

In [None]:
import numpy as np
import torch
from datasets import load_dataset
from sklearn.metrics import f1_score, precision_score, recall_score
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
)

DATASET_ID = "ContextNews/labelled_articles"
BASE_MODEL = "distilbert-base-uncased"
PUSH_TO = "ContextNews/news-classifier"  # change this to your repo

TOPICS = [
    "politics", "geopolitics", "conflict", "crime", "law", "business",
    "economy", "markets", "technology", "science", "health", "environment",
    "society", "education", "sports", "entertainment",
]

In [None]:
# Load dataset
train_ds = load_dataset(DATASET_ID, split="train")
val_ds = load_dataset(DATASET_ID, split="validation")
test_ds = load_dataset(DATASET_ID, split="test")

print(f"Train: {len(train_ds)}, Val: {len(val_ds)}, Test: {len(test_ds)}")

In [None]:
# Preprocess
def build_input_text(row):
    title = row.get("title") or ""
    summary = row.get("summary") or ""
    text = row.get("text") or ""
    text_excerpt = " ".join(text.split()[:300])
    return " ".join(p for p in [title, summary, text_excerpt] if p)


def preprocess(row):
    row["input_text"] = build_input_text(row)
    row["labels"] = [float(row[t] or 0) for t in TOPICS]
    return row


train_ds = train_ds.map(preprocess)
val_ds = val_ds.map(preprocess)
test_ds = test_ds.map(preprocess)

In [None]:
# Tokenize
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)


def tokenize(batch):
    encoding = tokenizer(
        batch["input_text"],
        truncation=True,
        padding="max_length",
        max_length=512,
    )
    encoding["labels"] = batch["labels"]
    return encoding


cols = train_ds.column_names
train_ds = train_ds.map(tokenize, batched=True, remove_columns=cols)
val_ds = val_ds.map(tokenize, batched=True, remove_columns=cols)
test_ds = test_ds.map(tokenize, batched=True, remove_columns=cols)

train_ds.set_format("torch")
val_ds.set_format("torch")
test_ds.set_format("torch")

In [None]:
# Model
model = AutoModelForSequenceClassification.from_pretrained(
    BASE_MODEL,
    num_labels=len(TOPICS),
    problem_type="multi_label_classification",
    id2label={i: t for i, t in enumerate(TOPICS)},
    label2id={t: i for i, t in enumerate(TOPICS)},
)

In [None]:
# Metrics
def compute_metrics(eval_pred):
    logits, labels = eval_pred
    preds = (torch.sigmoid(torch.tensor(logits)) > 0.5).int().numpy()
    return {
        "f1_micro": f1_score(labels, preds, average="micro", zero_division=0),
        "f1_macro": f1_score(labels, preds, average="macro", zero_division=0),
        "precision": precision_score(labels, preds, average="micro", zero_division=0),
        "recall": recall_score(labels, preds, average="micro", zero_division=0),
    }

In [None]:
# Training
training_args = TrainingArguments(
    output_dir="./model_output",
    num_train_epochs=3,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=32,
    learning_rate=2e-5,
    weight_decay=0.01,
    eval_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="f1_micro",
    logging_steps=50,
    push_to_hub=True,
    hub_model_id=PUSH_TO,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=val_ds,
    compute_metrics=compute_metrics,
)

trainer.train()

In [None]:
# Evaluate on test set
print("Test set evaluation:")
metrics = trainer.evaluate(test_ds)
for k, v in metrics.items():
    print(f"  {k}: {v:.4f}" if isinstance(v, float) else f"  {k}: {v}")

In [None]:
# Push to HuggingFace
trainer.push_to_hub()
tokenizer.push_to_hub(PUSH_TO)
print(f"Model pushed to {PUSH_TO}")