In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForSequenceClassification



  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from transformers import AutoTokenizer

# Load AG News from the Hugging Face Hub. The default download mode reuses the
# local cache, so re-running the notebook does not fetch the data again.
# (Pass download_mode="force_redownload" temporarily if the cache is corrupted.)
dataset = load_dataset("ag_news")

# Report split sizes and class names as a quick sanity check
print(f"Train set: {len(dataset['train'])} examples")
print(f"Test set: {len(dataset['test'])} examples")
print(f"Label names: {dataset['train'].features['label'].names}")

# Load the tokenizer that matches the DistilBERT checkpoint used below
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")

# Tokenization function
def tokenize_function(examples, max_length=128):
    """Tokenize a batch of examples with the notebook-level `tokenizer`.

    Args:
        examples: Mapping with a "text" entry holding the raw strings to
            encode (the batch handed over by `map(..., batched=True)`).
        max_length: Length every sequence is padded or truncated to.
            Defaults to 128, the value that used to be hard-coded here.

    Returns:
        Whatever the tokenizer returns for the batch.
    """
    return tokenizer(
        examples["text"],
        padding="max_length",
        truncation=True,
        max_length=max_length,
    )

# Tokenize every split in batches, then reshape the columns for training:
# the raw "text" column is dropped and "label" is renamed to "labels".
tokenized_datasets = (
    dataset
    .map(tokenize_function, batched=True)
    .remove_columns(["text"])
    .rename_column("label", "labels")
)

# Have the datasets return PyTorch tensors
tokenized_datasets.set_format("torch")

Generating train split: 100%|█| 120000/120000 [00:00<00:00, 1950680.29 examples/s]
Generating test split: 100%|██████| 7600/7600 [00:00<00:00, 1450061.88 examples/s]


Train set: 120000 examples
Test set: 7600 examples
Label names: ['World', 'Sports', 'Business', 'Sci/Tech']


Map: 100%|██████████████████████| 120000/120000 [00:04<00:00, 29962.60 examples/s]
Map: 100%|██████████████████████████| 7600/7600 [00:00<00:00, 25356.77 examples/s]


In [3]:
# Take the class names from the dataset so the size of the classification head
# and the label mappings stored in the model config stay in sync with the data.
label_names = dataset["train"].features["label"].names

model = AutoModelForSequenceClassification.from_pretrained(
    "distilbert-base-uncased",
    num_labels=len(label_names),  # AG News has 4 classes
    id2label=dict(enumerate(label_names)),
    label2id={name: idx for idx, name in enumerate(label_names)},
)

Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [4]:
from transformers import TrainingArguments, Trainer

# Settings for the fine-tuning run, collected in one mapping so they can be
# read or adjusted at a glance before being handed to TrainingArguments.
training_config = {
    "output_dir": "./results",
    "per_device_train_batch_size": 32,
    "per_device_eval_batch_size": 32,
    "num_train_epochs": 3,
    "learning_rate": 2e-5,
    # The entries below are optional and can be removed if they cause issues
    "logging_dir": "./logs",
    "logging_steps": 1500,
    "report_to": "none",
}

training_args = TrainingArguments(**training_config)

In [5]:
from sklearn.metrics import accuracy_score, precision_recall_fscore_support

def compute_metrics(eval_pred):
    """Compute accuracy and weighted precision/recall/F1 for an evaluation pass.

    Args:
        eval_pred: A ``(predictions, labels)`` pair. ``predictions`` holds the
            per-class scores (one row per example); if it arrives as a tuple
            of arrays, the first element is used.

    Returns:
        dict with the keys ``accuracy``, ``f1``, ``precision`` and ``recall``.
    """
    predictions, labels = eval_pred

    # Guard against predictions being delivered as a tuple of arrays rather
    # than a single array of scores; the scores are taken from position 0.
    if isinstance(predictions, tuple):
        predictions = predictions[0]

    # Highest-scoring class per example
    predictions = predictions.argmax(axis=1)

    accuracy = accuracy_score(labels, predictions)
    # zero_division=0 yields the same value as the default but without a
    # warning when some class is never predicted.
    precision, recall, f1, _ = precision_recall_fscore_support(
        labels, predictions, average='weighted', zero_division=0
    )
    return {
        'accuracy': accuracy,
        'f1': f1,
        'precision': precision,
        'recall': recall
    }

In [6]:
import os
# Set WANDB_MODE to 'disabled' (intended to keep Weights & Biases logging off)
os.environ['WANDB_MODE'] = 'disabled'

# Carve a 10% validation split out of the training data with a fixed seed;
# the dataset's own test split is kept aside untouched.
train_val_dataset = tokenized_datasets["train"].train_test_split(test_size=0.1, seed=42)
train_dataset, eval_dataset = train_val_dataset["train"], train_val_dataset["test"]
test_dataset = tokenized_datasets["test"]

# Wire the model, arguments, data splits and metric function together
trainer_kwargs = dict(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    compute_metrics=compute_metrics,
)
trainer = Trainer(**trainer_kwargs)

In [7]:
# Run fine-tuning and display the returned training summary as the cell output
train_output = trainer.train()
train_output



Step,Training Loss
1500,0.2768
3000,0.2007
4500,0.1537
6000,0.1369
7500,0.1122
9000,0.0959




TrainOutput(global_step=10125, training_loss=0.1554474178361304, metrics={'train_runtime': 16259.9665, 'train_samples_per_second': 19.926, 'train_steps_per_second': 0.623, 'total_flos': 1.0730241994752e+16, 'train_loss': 0.1554474178361304, 'epoch': 3.0})

In [8]:
# Write the fine-tuned model and its tokenizer into the same directory
model_path = "./distilbert-ag-news"
for artifact in (model, tokenizer):
    artifact.save_pretrained(model_path)
print(f"Model saved to {model_path}")

Model saved to ./distilbert-ag-news


In [9]:
# Metrics on the validation split the Trainer was constructed with
eval_results = trainer.evaluate()
print(f"Evaluation results: {eval_results}")

# Also score the held-out test split, which is prepared earlier but otherwise
# never used; the "test" prefix keeps its metric keys distinct from the
# validation ones.
test_results = trainer.evaluate(eval_dataset=test_dataset, metric_key_prefix="test")
print(f"Test results: {test_results}")



Evaluation results: {'eval_loss': 0.19162078201770782, 'eval_accuracy': 0.94525, 'eval_f1': 0.9452412045369722, 'eval_precision': 0.9453686726275785, 'eval_recall': 0.94525, 'eval_runtime': 65.3533, 'eval_samples_per_second': 183.617, 'eval_steps_per_second': 5.738, 'epoch': 3.0}
