# 3️⃣ DistilBERT Fine-tuning on Yahoo Answers


### 🧠 Model: DistilBERT Fine-tuning
**Dataset:** Yahoo Answers  
**Classes:** 10  
**Technique:** We fine-tune a lightweight transformer — DistilBERT — using Hugging Face’s `Trainer` API for multiclass classification.

DistilBERT offers a balance between speed and accuracy, making it suitable for mid-scale datasets with reasonable training time.


In [None]:
!pip install transformers datasets accelerate scikit-learn nltk --quiet

In [1]:
from datasets import load_dataset
from transformers import DistilBertTokenizerFast, DistilBertForSequenceClassification, TrainingArguments, Trainer
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
import numpy as np

In [None]:
# Load Dataset
dataset = load_dataset("yahoo_answers_topics")
dataset = dataset.rename_column("topic", "label")

In [None]:
# Tokenizer
tokenizer = DistilBertTokenizerFast.from_pretrained("distilbert-base-uncased")

def tokenize(example):
    # Combine title and content for each example in the batch
    full_text = [title + " " + content for title, content in zip(example["question_title"], example["question_content"])]
    # Pass the list of combined texts to the tokenizer
    return tokenizer(full_text, padding="max_length", truncation=True)

tokenized_ds = dataset.map(tokenize, batched=True)

In [17]:
tokenized_ds.set_format("torch", columns=["input_ids", "attention_mask", "label"])

In [13]:
tokenized_ds['train']

Dataset({
    features: ['id', 'label', 'question_title', 'question_content', 'best_answer', 'input_ids', 'attention_mask'],
    num_rows: 1400000
})

In [None]:
# Load Model
model = DistilBertForSequenceClassification.from_pretrained("distilbert-base-uncased", num_labels=10)

In [19]:
#  Metrics
def compute_metrics(pred):
    labels = pred.label_ids
    preds = np.argmax(pred.predictions, axis=1)
    acc = accuracy_score(labels, preds)
    precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average="weighted")
    return {"accuracy": acc, "precision": precision, "recall": recall, "f1": f1}


In [None]:
# Training Arguments
training_args = TrainingArguments(
    output_dir="./distilbert-yahoo",
    eval_strategy="epoch",
    save_strategy="epoch", # Ensure save strategy matches eval strategy
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=32,
    num_train_epochs=3,
    weight_decay=0.01,
    logging_dir="./logs",
    logging_strategy="steps", # Set logging strategy to log based on steps
    logging_steps=100,       # Log every 100 steps
    save_total_limit=1,
    push_to_hub=False,
    load_best_model_at_end=True
)

# Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_ds["train"].shuffle(seed=42).select(range(50000)),
    eval_dataset=tokenized_ds["test"].select(range(5000)),
    compute_metrics=compute_metrics,
)

In [22]:
# Train
trainer.train()



<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize?ref=models
[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mkoushikreddy143749[0m ([33mkoushikreddy143749-na[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,0.8997,0.902129,0.7152,0.714531,0.7152,0.71208


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,0.8997,0.902129,0.7152,0.714531,0.7152,0.71208
2,0.7236,0.920743,0.7234,0.720477,0.7234,0.718284
3,0.5978,0.941784,0.725,0.721594,0.725,0.721938


TrainOutput(global_step=9375, training_loss=0.776661376953125, metrics={'train_runtime': 7347.9987, 'train_samples_per_second': 20.414, 'train_steps_per_second': 1.276, 'total_flos': 1.987294464e+16, 'train_loss': 0.776661376953125, 'epoch': 3.0})

In [23]:
# Evaluate
results = trainer.evaluate()
print(results)

{'eval_loss': 0.9021289944648743, 'eval_accuracy': 0.7152, 'eval_precision': 0.7145307273334158, 'eval_recall': 0.7152, 'eval_f1': 0.7120797031423628, 'eval_runtime': 77.5182, 'eval_samples_per_second': 64.501, 'eval_steps_per_second': 2.025, 'epoch': 3.0}
