# Lab 2.5.3 Solution: AG News Data Pipeline

This notebook provides the complete solution for the "Try It Yourself" exercise in Lab 2.5.3.

**Task**: Build a complete data pipeline for the AG News dataset.

---

In [None]:
import torch
from datasets import load_dataset, DatasetDict
from transformers import AutoTokenizer
from collections import Counter
import numpy as np

# Notebook banner: title line followed by a 60-character rule.
banner_title = "AG News Data Pipeline Solution"
banner_rule = "=" * 60
print(banner_title, banner_rule, sep="\n")

## Step 1: Load the AG News Dataset

In [None]:
# 1. Load the dataset registered under the name "ag_news"
print("\nStep 1: Loading AG News dataset...")
ag_news = load_dataset("ag_news")

# Summarise what was loaded: available splits, their sizes and the schema
raw_train, raw_test = ag_news['train'], ag_news['test']
print(f"Splits: {list(ag_news.keys())}")
print(f"Train examples: {len(raw_train):,}")
print(f"Test examples: {len(raw_test):,}")
print(f"\nFeatures: {raw_train.features}")

# Peek at the first training record
print("\nSample:", raw_train[0], sep="\n")

## Step 2: Analyze Label Distribution

In [None]:
# 2. Analyze label distribution
print("\nStep 2: Label Distribution Analysis")
print("-" * 50)

# Human-readable class names; the list position is the integer label id
categories = ["World", "Sports", "Business", "Sci/Tech"]

for split in ('train', 'test'):
    labels = ag_news[split]['label']
    counts = Counter(labels)
    total = len(labels)

    print(f"\n{split.upper()}:")
    # Walk the label ids in ascending order and report count + share of split
    for label_id in sorted(counts):
        count = counts[label_id]
        pct = 100 * count / total
        print(f"  {label_id} ({categories[label_id]:10}): {count:6,} ({pct:.1f}%)")

## Step 3: Filter Articles > 200 Characters

In [None]:
# 3. Filter to only keep articles > 200 characters
print("\nStep 3: Filtering short articles...")

def filter_by_length(example, min_chars=200):
    """Return True when the article text is strictly longer than ``min_chars``.

    Args:
        example: Mapping with a ``'text'`` entry holding the article string.
        min_chars: Exclusive lower bound on the character count. Defaults to
            200, the threshold this lab asks for, so existing one-argument
            calls behave exactly as before.

    Returns:
        bool: ``True`` to keep the example, ``False`` to drop it.
    """
    return len(example['text']) > min_chars

# Apply the same length predicate to both loaded splits (4 worker processes each)
filtered_train = ag_news['train'].filter(filter_by_length, num_proc=4)
filtered_test = ag_news['test'].filter(filter_by_length, num_proc=4)

# Row counts before/after filtering, used for the report below
train_before, train_after = len(ag_news['train']), len(filtered_train)
test_before, test_after = len(ag_news['test']), len(filtered_test)

print(f"Train: {train_before:,} -> {train_after:,} ")
print(f"       (removed {train_before - train_after:,} short articles)")
print(f"Test: {test_before:,} -> {test_after:,}")
print(f"      (removed {test_before - test_after:,} short articles)")

## Step 4: Create Train/Val/Test Splits (80/10/10)

In [None]:
# 4. Create train/val/test splits (80/10/10)
print("\nStep 4: Creating stratified splits...")

# First split: hold out 10% of the filtered training data as the test set,
# keeping the remaining 90% for train + validation.
split1 = filtered_train.train_test_split(
    test_size=0.1,
    stratify_by_column='label',
    seed=42
)

# Second split: carve the validation set out of the remaining 90%.
# Validation should be 10% of the whole, i.e. exactly 1/9 of what is left
# (a rounded 0.111 undershoots slightly), leaving ~80% overall for training.
split2 = split1['train'].train_test_split(
    test_size=1 / 9,
    stratify_by_column='label',
    seed=42
)

# Create DatasetDict
dataset = DatasetDict({
    'train': split2['train'],
    'validation': split2['test'],
    'test': split1['test']
})

total = sum(len(d) for d in dataset.values())
# Sanity check: the three splits must partition the filtered data exactly
assert total == len(filtered_train), "splits do not add up to the filtered dataset"

print(f"\nFinal splits:")
for split, data in dataset.items():
    pct = 100 * len(data) / total
    print(f"  {split:12}: {len(data):,} ({pct:.1f}%)")

# Verify stratification: per-class share of each split, one entry per label id.
# The class count is read from the label feature instead of being hard-coded.
num_classes = dataset['train'].features['label'].num_classes
print("\nLabel distribution preserved:")
for split in ['train', 'validation', 'test']:
    counts = Counter(dataset[split]['label'])
    dist = [counts[i] / len(dataset[split]) for i in range(num_classes)]
    print(f"  {split}: {[f'{d:.1%}' for d in dist]}")

## Step 5: Tokenize with Transformer Tokenizer

In [None]:
# 5. Tokenize with a transformer tokenizer
print("\nStep 5: Tokenizing...")

# Tokenizer for the "distilbert-base-uncased" checkpoint
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")

def tokenize_function(examples):
    """Tokenize the ``'text'`` entry of a batch to fixed-length sequences.

    Args:
        examples: Batch mapping whose ``'text'`` entry is handed to the tokenizer.

    Returns:
        The tokenizer output, padded to ``max_length`` (256) and truncated
        where the input is longer.
    """
    tokenizer_kwargs = dict(padding='max_length', truncation=True, max_length=256)
    return tokenizer(examples['text'], **tokenizer_kwargs)

# Tokenize every split in batches, drop the raw text column, then rename
# `label` to `labels` (for Trainer)
tokenized_dataset = dataset.map(
    tokenize_function,
    batched=True,
    num_proc=4,
    remove_columns=['text'],
    desc="Tokenizing"
).rename_column('label', 'labels')

# Report the resulting columns and the length of one encoded example
tokenized_train = tokenized_dataset['train']
print(f"\nTokenized columns: {tokenized_train.column_names}")
print(f"Sample input_ids length: {len(tokenized_train[0]['input_ids'])}")

## Step 6: Set Format for PyTorch

In [None]:
# 6. Set format for PyTorch
print("\nStep 6: Setting PyTorch format...")

# Switches the output format of every split in place
tokenized_dataset.set_format("torch")

# Verify format by inspecting the first training example
sample = tokenized_dataset['train'][0]
print(f"\nSample tensor types:")
for key, val in sample.items():
    # Fall back to the literal 'scalar' for values without a `shape` attribute
    shape_repr = getattr(val, 'shape', 'scalar')
    print(f"  {key}: {type(val).__name__}, shape: {shape_repr}")

## Final Summary

In [None]:
# Closing banner
summary_rule = "=" * 60
print("\n" + summary_rule, "AG NEWS DATASET READY FOR TRAINING!", summary_rule, sep="\n")

# Structure of the final tokenized DatasetDict
print(f"\nDataset structure:")
print(tokenized_dataset)

# Label id -> category name
print(f"\nCategories (4 classes):")
for i, cat in enumerate(categories):
    print(f"  {i}: {cat}")

# Usage hint, printed line by line
print(f"\nReady to use with:")
usage_lines = (
    "  model = AutoModelForSequenceClassification.from_pretrained(",
    "      'distilbert-base-uncased', num_labels=4",
    "  )",
    "  trainer = Trainer(model=model, train_dataset=tokenized_dataset['train'], ...)",
)
print(*usage_lines, sep="\n")

---

## Bonus: Quick Training Test

In [None]:
# Optional: Quick training test
from transformers import AutoModelForSequenceClassification, Trainer, TrainingArguments
import evaluate

# Load model in the default (float32) precision.
# bfloat16 is requested through the `bf16` flag of TrainingArguments, which
# applies mixed precision during training. Casting the weights themselves to
# bfloat16 would also make the optimizer update bf16 parameters, where small
# learning-rate steps are easily rounded away.
model = AutoModelForSequenceClassification.from_pretrained(
    "distilbert-base-uncased",
    num_labels=4
)

# Metrics
accuracy = evaluate.load("accuracy")

def compute_metrics(eval_pred):
    """Turn model outputs into an accuracy score.

    Args:
        eval_pred: Pair that unpacks into ``(predictions, labels)``.
            The predictions are arg-maxed along axis 1, so they are presumably
            per-class scores of shape (n_examples, n_classes) -- TODO confirm.

    Returns:
        Whatever ``accuracy.compute`` returns for the predicted class ids
        versus the reference labels.
    """
    scores, labels = eval_pred
    predicted_ids = np.argmax(scores, axis=1)
    return accuracy.compute(predictions=predicted_ids, references=labels)

# Quick training args (1 epoch for verification)
# Enable bf16 mixed precision only when a CUDA device reports support for it,
# so this cell also runs on hardware without bfloat16 instead of failing.
use_bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()

args = TrainingArguments(
    output_dir="./ag_news_test",
    num_train_epochs=1,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=64,
    eval_strategy="epoch",
    bf16=use_bf16,
    logging_steps=100,
    report_to="none"
)

# Trainer
# Cap the quick-test subset sizes at the actual split sizes so select()
# never indexes past the end of a split (e.g. after a stricter filter).
n_train_quick = min(5000, len(tokenized_dataset['train']))
n_eval_quick = min(1000, len(tokenized_dataset['validation']))

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=tokenized_dataset['train'].select(range(n_train_quick)),  # Subset for quick test
    eval_dataset=tokenized_dataset['validation'].select(range(n_eval_quick)),
    tokenizer=tokenizer,
    compute_metrics=compute_metrics
)

# Train on the quick-test subset
print("\nQuick training test (1 epoch, subset)...")
trainer.train()

# Evaluate on the validation subset and report accuracy
results = trainer.evaluate()
validation_accuracy = results['eval_accuracy']
print(f"\nValidation accuracy: {validation_accuracy:.4f}")
print("Pipeline is working correctly!")