# News Topic Classifier Training Notebook
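Fine-tune **bert-base-uncased** on **AG News** for 4-class news topic classification (World, Sports, Business, Sci/Tech), training on headlines only — the preprocessing falls back to the full `text` field when the dataset provides no separate `title` column.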


In [None]:
from datasets import load_dataset
from transformers import BertTokenizer, BertForSequenceClassification, Trainer, TrainingArguments, DataCollatorWithPadding
from sklearn.metrics import accuracy_score, f1_score
import numpy as np
# Class names indexed by integer label id; len(labels) sets the classifier's output size.
labels = "World Sports Business Sci/Tech".split()
# DatasetDict of AG News splits; this notebook uses the 'train' and 'test' splits.
dataset = load_dataset("ag_news")
def select_cols(batch):
    """Build the ``headline`` column for a batched ``Dataset.map`` call.

    For each row, the ``title`` value is used when the batch has a ``title``
    column and that value is a non-blank string. Otherwise the row's ``text``
    value is used.

    NOTE(review): if the loaded AG News dataset exposes only ``text``/``label``,
    every row takes the ``text`` fallback. In that case the model is not
    trained on headlines alone. Confirm the dataset schema.

    Args:
        batch: mapping of column name -> list of values. Must contain ``text``
            and ``label`` and may contain ``title``.

    Returns:
        dict with ``headline`` (list of str) and ``label`` (passed through
        unchanged).
    """
    texts = batch["text"]
    # Look up the optional column once per batch. The previous version
    # rebuilt a [None] * n default list for every single row.
    titles = batch["title"] if "title" in batch else [None] * len(texts)
    headlines = [
        # Treat empty *and* whitespace-only titles as missing.
        title if title is not None and title.strip() else text
        for title, text in zip(titles, texts)
    ]
    return {"headline": headlines, "label": batch["label"]}
dataset = dataset.map(
    select_cols,
    batched=True,
    # Drop every original column; select_cols re-emits `label` next to `headline`.
    remove_columns=dataset["train"].column_names,
)
# Tokenizer from the same checkpoint the classifier is initialised from.
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
def tokenize(batch):
    """Tokenize the ``headline`` strings of a batch with the module-level tokenizer.

    Sequences are truncated to at most 64 tokens. No padding is requested
    here; it is left to the ``DataCollatorWithPadding`` used at batch time.
    """
    headlines = batch["headline"]
    return tokenizer(headlines, max_length=64, truncation=True)
# Tokenize every split; sequences are unpadded, so the collator pads each batch
# to its own longest sequence.
encoded = dataset.map(tokenize, batched=True)
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
# Store the human-readable class names in the model config. Without them the
# saved model only carries transformers' generic LABEL_<i> names.
model = BertForSequenceClassification.from_pretrained(
    "bert-base-uncased",
    num_labels=len(labels),
    id2label=dict(enumerate(labels)),
    label2id={name: idx for idx, name in enumerate(labels)},
)
def compute_metrics(eval_pred):
    """Compute accuracy and weighted F1 for a Trainer evaluation pass.

    Args:
        eval_pred: ``(predictions, label_ids)`` pair supplied by ``Trainer``.
            The last axis of ``predictions`` indexes the classes.

    Returns:
        dict with ``"accuracy"`` and ``"f1"``. The F1 score is averaged
        across classes, weighted by each class's support.
    """
    logits, y = eval_pred
    # When the model returns extra outputs, Trainer passes a tuple of arrays.
    # In that case the logits are the first element.
    if isinstance(logits, tuple):
        logits = logits[0]
    preds = np.argmax(logits, axis=-1)
    return {"accuracy": accuracy_score(y, preds), "f1": f1_score(y, preds, average="weighted")}
# Hold out a seeded slice of the training data for per-epoch evaluation and
# best-checkpoint selection. Using encoded['test'] for that, with
# load_best_model_at_end=True, would pick the checkpoint on the test set and
# then report that same set's score, which is an optimistically biased estimate.
splits = encoded["train"].train_test_split(test_size=0.1, seed=42)
train_ds, val_ds = splits["train"], splits["test"]
# NOTE(review): newer transformers releases rename `evaluation_strategy` to
# `eval_strategy` and Trainer's `tokenizer=` to `processing_class=`. The
# original names are kept here for older versions; confirm against the
# installed transformers version.
args = TrainingArguments(
    output_dir="./models",
    evaluation_strategy="epoch",
    # Save on the same schedule as evaluation, which load_best_model_at_end requires.
    save_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=3,
    weight_decay=0.01,
    logging_dir="./logs",
    logging_steps=100,
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    report_to="none",
)
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_ds,
    eval_dataset=val_ds,
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)
trainer.train()
# Final evaluation on the test split, which played no part in training or
# checkpoint selection. Keys keep Trainer's default "eval_" prefix.
metrics = trainer.evaluate(eval_dataset=encoded["test"])
# A bare `metrics` expression is not displayed unless it is the cell's last
# statement, so print explicitly.
print(metrics)
trainer.save_model("./models/news_topic_classifier")
tokenizer.save_pretrained("./models/news_topic_classifier")
