In [2]:
from datasets import load_dataset

# Load AG News dataset
dataset = load_dataset("ag_news")

print("Dataset structure:", dataset)
print("Sample text:", dataset["train"][0]["text"])
print("Label (0=World, 1=Sports, 2=Business, 3=Sci/Tech):", dataset["train"][0]["label"])

Dataset structure: DatasetDict({
    train: Dataset({
        features: ['text', 'label'],
        num_rows: 120000
    })
    test: Dataset({
        features: ['text', 'label'],
        num_rows: 7600
    })
})
Sample text: Wall St. Bears Claw Back Into the Black (Reuters) Reuters - Short-sellers, Wall Street's dwindling\band of ultra-cynics, are seeing green again.
Label (0=World, 1=Sports, 2=Business, 3=Sci/Tech): 2


In [3]:
# Define label names
label_names = ['World', 'Sports', 'Business', 'Sci/Tech']

# Print a few examples
for i in range(3):
    print(f"Text: {dataset['train'][i]['text']}")
    print(f"Label: {label_names[dataset['train'][i]['label']]}")
    print("-" * 50)

Text: Wall St. Bears Claw Back Into the Black (Reuters) Reuters - Short-sellers, Wall Street's dwindling\band of ultra-cynics, are seeing green again.
Label: Business
--------------------------------------------------
Text: Carlyle Looks Toward Commercial Aerospace (Reuters) Reuters - Private investment firm Carlyle Group,\which has a reputation for making well-timed and occasionally\controversial plays in the defense industry, has quietly placed\its bets on another part of the market.
Label: Business
--------------------------------------------------
Text: Oil and Economy Cloud Stocks' Outlook (Reuters) Reuters - Soaring crude prices plus worries\about the economy and the outlook for earnings are expected to\hang over the stock market next week during the depth of the\summer doldrums.
Label: Business
--------------------------------------------------


In [4]:
# %%
from transformers import AutoTokenizer

# Load pre-trained tokenizer
model_name = "bert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_name)

def tokenize_function(examples):
    return tokenizer(examples["text"], truncation=True, padding=True, max_length=64)

# Apply tokenization
tokenized_datasets = dataset.map(tokenize_function, batched=True)

# Rename 'label' to 'labels' (required by PyTorch models)
tokenized_datasets = tokenized_datasets.rename_column("label", "labels")

# Set format for PyTorch tensors
tokenized_datasets.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])

print("Tokenized sample keys:", tokenized_datasets["train"][0].keys())

Map:   0%|          | 0/7600 [00:00<?, ? examples/s]

Tokenized sample keys: dict_keys(['labels', 'input_ids', 'attention_mask'])


In [5]:
# %%
from transformers import AutoModelForSequenceClassification

# Number of labels: 4 (AG News)
num_labels = 4

# Load model
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=num_labels)

print("Model loaded successfully!")

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Model loaded successfully!


In [6]:
# %%
from transformers import TrainingArguments, Trainer
import evaluate

# Load accuracy and F1 metrics
accuracy_metric = evaluate.load("accuracy")
f1_metric = evaluate.load("f1")

def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    predictions = predictions.argmax(axis=-1)
    
    accuracy = accuracy_metric.compute(predictions=predictions, references=labels)
    f1 = f1_metric.compute(predictions=predictions, references=labels, average="weighted")
    
    return {
        "accuracy": accuracy["accuracy"],
        "f1": f1["f1"]
    }

In [7]:
# %%
training_args = TrainingArguments(
    output_dir="bert-news-classifier-fast",
    eval_strategy="epoch",
    save_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=16,
    num_train_epochs=2,
    max_steps=500,               # 🔥 Stop after 500 steps total (~2 epochs on 2k data)
    eval_steps=250,
    weight_decay=0.01,
    logging_dir="./logs",
    logging_steps=50,
    load_best_model_at_end=True,
    metric_for_best_model="f1",
    report_to="none",
    save_total_limit=2,
    seed=42,
)

In [8]:
# %%
# Reduce dataset for faster training (ideal for development)
train_subset = tokenized_datasets["train"].select(range(2000))   # Use only 2000 samples
test_subset  = tokenized_datasets["test"] .select(range(500))    # Use 500 for eval

print("Training on", len(train_subset), "samples")
print("Evaluating on", len(test_subset), "samples")

Training on 2000 samples
Evaluating on 500 samples


In [9]:
from transformers import Trainer

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["test"],
    compute_metrics=compute_metrics,  # We defined this earlier
)

In [10]:
# %%
import evaluate

accuracy_metric = evaluate.load("accuracy")
f1_metric = evaluate.load("f1")

def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    predictions = predictions.argmax(axis=-1)
    
    accuracy = accuracy_metric.compute(predictions=predictions, references=labels)
    f1 = f1_metric.compute(predictions=predictions, references=labels, average="weighted")
    
    return {
        "accuracy": accuracy["accuracy"],
        "f1": f1["f1"]
    }
# %%
print("Starting training...")
trainer.train()

Starting training...




Epoch,Training Loss,Validation Loss,Accuracy,F1
0,0.4283,0.329753,0.901974,0.901801


TrainOutput(global_step=500, training_loss=0.47203154373168943, metrics={'train_runtime': 1521.6789, 'train_samples_per_second': 2.629, 'train_steps_per_second': 0.329, 'total_flos': 131557890048000.0, 'train_loss': 0.47203154373168943, 'epoch': 0.03333333333333333})

In [11]:
# %%
results = trainer.evaluate()

print(f"Final Test Accuracy: {results['eval_accuracy']:.4f}")
print(f"Final F1-Score: {results['eval_f1']:.4f}")



Final Test Accuracy: 0.9020
Final F1-Score: 0.9018


In [12]:
# %%
model.save_pretrained("bert-news-classifier-finetuned")
tokenizer.save_pretrained("bert-news-classifier-finetuned")

print("✅ Model and tokenizer saved to 'bert-news-classifier-finetuned'")

✅ Model and tokenizer saved to 'bert-news-classifier-finetuned'


In [13]:
# %%
model.save_pretrained("bert-news-classifier-finetuned")
tokenizer.save_pretrained("bert-news-classifier-finetuned")

print("✅ Model and tokenizer saved to 'bert-news-classifier-finetuned'")

✅ Model and tokenizer saved to 'bert-news-classifier-finetuned'


In [15]:
# %%
import torch
def predict_topic(text):
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=64)
    with torch.no_grad():
        logits = model(**inputs).logits
    pred_class = logits.argmax().item()
    return label_names[pred_class]

# Test with sample headlines
test_headlines = [
    "France wins World Cup in dramatic penalty shootout",
    "Apple announces new AI-powered iPhone next month",
    "Stock markets hit record high on tech rally",
    "Scientists discover black hole near Milky Way center"
]

for headline in test_headlines:
    pred = predict_topic(headline)
    print(f"Headline: {headline}")
    print(f"Predicted Topic: {pred}\n")

Headline: France wins World Cup in dramatic penalty shootout
Predicted Topic: Sports

Headline: Apple announces new AI-powered iPhone next month
Predicted Topic: Sci/Tech

Headline: Stock markets hit record high on tech rally
Predicted Topic: Business

Headline: Scientists discover black hole near Milky Way center
Predicted Topic: Sci/Tech



In [17]:
# %%
model.save_pretrained("bert-news-classifier-finetuned")
tokenizer.save_pretrained("bert-news-classifier-finetuned")

('bert-news-classifier-finetuned/tokenizer_config.json',
 'bert-news-classifier-finetuned/special_tokens_map.json',
 'bert-news-classifier-finetuned/vocab.txt',
 'bert-news-classifier-finetuned/added_tokens.json',
 'bert-news-classifier-finetuned/tokenizer.json')