# Data Loading Basics for AG News Classification

## Overview

This tutorial demonstrates fundamental data loading techniques following best practices from:
- PyTorch DataLoader documentation and design patterns
- Hugging Face Datasets library architecture
- Efficient data pipeline design principles

### Learning Objectives
1. Understand the AG News dataset structure
2. Load data using custom dataset classes
3. Implement efficient batch loading
4. Handle data transformations and preprocessing
5. Create reproducible data splits

Author: Võ Hải Dũng  
Email: vohaidung.work@gmail.com  
Date: 2025

## 1. Setup and Imports

In [None]:
# Standard library imports
import os
import sys
from pathlib import Path
from typing import Any, Dict, List, Tuple, Optional
import json
import time
from collections import Counter
from itertools import islice

# Data manipulation
import numpy as np
import pandas as pd

# PyTorch
import torch
from torch.utils.data import DataLoader, Dataset, RandomSampler, Subset, random_split

# Visualization
import matplotlib.pyplot as plt
import seaborn as sns

# Progress bar
from tqdm.notebook import tqdm

# Project setup
PROJECT_ROOT = Path("../..").resolve()
sys.path.insert(0, str(PROJECT_ROOT))

# Project imports
from src.data.datasets.ag_news import AGNewsDataset, AGNewsConfig
from src.data.loaders.dataloader import create_dataloader, DataLoaderConfig
from src.utils.reproducibility import set_seed
from configs.constants import (
    AG_NEWS_CLASSES,
    AG_NEWS_NUM_CLASSES,
    LABEL_TO_ID,
    ID_TO_LABEL,
    DATA_DIR
)

# Fix every random number generator up front so results are repeatable
set_seed(42)

# Plot styling shared by every figure in this notebook
plt.style.use('seaborn-v0_8-darkgrid')
sns.set_palette("husl")

# Report the runtime environment (one line per fact)
print(
    f"PyTorch version: {torch.__version__}\n"
    f"CUDA available: {torch.cuda.is_available()}\n"
    f"Project root: {PROJECT_ROOT}"
)

## 2. Understanding AG News Dataset Structure

In [None]:
# Summarize the label space defined in configs.constants
print("AG News Dataset Information")
print("=" * 50)
print(f"Number of classes: {AG_NEWS_NUM_CLASSES}")
print(f"Class names: {AG_NEWS_CLASSES}")

print("\nClass to ID mapping:")
for name, idx in LABEL_TO_ID.items():
    print(f"  {name}: {idx}")

# Static description of one sample and of the official splits
structure_text = """
Each sample contains:
  - text: The news article text (title + description)
  - label: Integer label (0-3)
  - label_name: String label (World, Sports, Business, Sci/Tech)
  
Data splits:
  - Train: 120,000 samples (30,000 per class)
  - Test: 7,600 samples (1,900 per class)
  - No official validation set (we create one from training data)
"""
print("\nExpected Dataset Structure:")
print(structure_text)

## 3. Basic Data Loading

In [None]:
# Create dataset configuration
dataset_config = AGNewsConfig(
    data_dir=DATA_DIR / "processed",
    max_samples=1000,  # Limit samples for tutorial
    use_cache=True,
    validate_labels=True
)

print("Loading AG News dataset...")
print("Configuration:")
print(f"  Data directory: {dataset_config.data_dir}")
print(f"  Max samples: {dataset_config.max_samples}")
print(f"  Use cache: {dataset_config.use_cache}")

# Load training dataset. Every later cell uses `train_dataset`, so both the
# local path and the Hugging Face fallback must define it (previously the
# fallback only bound `hf_dataset`, and later cells failed with NameError).
try:
    train_dataset = AGNewsDataset(dataset_config, split="train")
except FileNotFoundError:
    print("\nError: Dataset files not found.")
    print("Please run the data preparation script first:")
    print("  python scripts/data_preparation/prepare_ag_news.py")

    # Alternative: Load from Hugging Face (requires the `datasets` package)
    print("\nAlternatively, loading from Hugging Face...")
    from datasets import load_dataset
    # Respect the same sample limit as the local configuration
    hf_dataset = load_dataset("ag_news", split=f"train[:{dataset_config.max_samples}]")
    print(f"Loaded {len(hf_dataset)} samples from Hugging Face")

    # Wrap the records as a list of dicts with the same keys the local
    # dataset yields; torch's DataLoader and random_split only need len()
    # and integer indexing, which a list provides.
    # NOTE(review): assumes HF label ids follow the same ordering as
    # ID_TO_LABEL (World, Sports, Business, Sci/Tech), and that the project's
    # create_dataloader accepts a plain list — confirm both.
    train_dataset = [
        {
            'text': example['text'],
            'label': example['label'],
            'label_name': ID_TO_LABEL[example['label']],
        }
        for example in hf_dataset
    ]

print("\nDataset loaded successfully!")
print(f"Number of samples: {len(train_dataset)}")

# Get first sample
first_sample = train_dataset[0]
print("\nFirst sample:")
print(f"  Text: {first_sample['text'][:100]}...")
print(f"  Label: {first_sample['label']}")
print(f"  Label name: {first_sample['label_name']}")

## 4. Creating Data Loaders

In [None]:
# Create DataLoader configuration
loader_config = DataLoaderConfig(
    batch_size=32,
    shuffle=True,
    num_workers=2,
    pin_memory=torch.cuda.is_available(),
    drop_last=False
)

print("Creating DataLoader...")
print("Configuration:")
print(f"  Batch size: {loader_config.batch_size}")
print(f"  Shuffle: {loader_config.shuffle}")
print(f"  Num workers: {loader_config.num_workers}")
print(f"  Pin memory: {loader_config.pin_memory}")

# Create DataLoader
train_loader = create_dataloader(train_dataset, loader_config)

print("\nDataLoader created!")
print(f"Number of batches: {len(train_loader)}")
# Nominal size; with drop_last=False the final batch may be smaller
print(f"Samples per batch: {loader_config.batch_size}")

# Test iteration: inspect the first batch and stop after three.
# An explicit counter (instead of reading the loop variable after the loop)
# stays correct for an empty loader and never picks up a stale `i` from
# another notebook cell.
MAX_DEMO_BATCHES = 3
print("\nTesting batch iteration...")
batches_seen = 0
for batch in train_loader:
    if batches_seen == 0:
        print(f"Batch keys: {batch.keys()}")
        print(f"Text shape: {len(batch['text'])}")
        print(f"Labels shape: {batch['label'].shape}")
        print(f"Label names: {batch['label_name'][:5]}")
    batches_seen += 1
    if batches_seen >= MAX_DEMO_BATCHES:  # Show only first 3 batches
        break
print(f"Iterated through {batches_seen} batches successfully")

## 5. Custom Dataset Implementation

In [None]:
class SimpleAGNewsDataset(Dataset):
    """
    Minimal in-memory AG News dataset built from parallel text/label lists.

    This demonstrates the core PyTorch map-style Dataset interface:
    ``__len__`` and ``__getitem__``.
    """

    def __init__(self, texts: List[str], labels: List[int]):
        """
        Initialize dataset.

        Args:
            texts: List of text samples
            labels: List of integer label ids, aligned with ``texts``;
                each id must be a key of ``ID_TO_LABEL``

        Raises:
            ValueError: If ``texts`` and ``labels`` differ in length.
        """
        # Validate before storing anything. An ``assert`` would be stripped
        # when Python runs with -O, letting mismatched data through.
        if len(texts) != len(labels):
            raise ValueError("Texts and labels must have same length")
        self.texts = texts
        self.labels = labels

    def __len__(self) -> int:
        """Return the number of samples."""
        return len(self.texts)

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        """
        Get a single sample.

        Args:
            idx: Sample index

        Returns:
            Dict with ``text``, integer ``label`` and its ``label_name``
        """
        label = self.labels[idx]
        return {
            'text': self.texts[idx],
            'label': label,
            'label_name': ID_TO_LABEL[label],
        }

# Four hand-written headlines paired with their AG News label ids
sample_texts = [
    "Stock market reaches new heights today",
    "Scientists discover new planet in solar system",
    "Football team wins championship",
    "New technology breakthrough announced"
]
sample_labels = [2, 3, 1, 3]  # Business, Sci/Tech, Sports, Sci/Tech

# Wrap the lists in the custom Dataset and show each sample (1-based listing)
simple_dataset = SimpleAGNewsDataset(sample_texts, sample_labels)
print(f"Custom dataset created with {len(simple_dataset)} samples")
print("\nSamples:")
for idx in range(len(simple_dataset)):
    item = simple_dataset[idx]
    print(f"{idx + 1}. {item['label_name']}: {item['text']}")

# The default collate function turns a list of dicts into a dict of batches
simple_loader = DataLoader(simple_dataset, batch_size=2, shuffle=True)
print("\nDataLoader created with batch size 2")

for mini_batch in simple_loader:
    print(f"Batch texts: {mini_batch['text']}")
    print(f"Batch labels: {mini_batch['label']}")

## 6. Creating Train/Validation Splits

In [None]:
def create_train_val_split(
    dataset: Dataset, val_ratio: float = 0.1, seed: int = 42
) -> Tuple[Subset, Subset]:
    """
    Split a dataset into reproducible, non-overlapping train/validation subsets.

    Args:
        dataset: Original dataset (any sized, integer-indexable object)
        val_ratio: Fraction of samples assigned to validation, in [0, 1].
            The validation size is rounded down; the rest goes to training.
        seed: Seed for the dedicated generator that drives the shuffle

    Returns:
        Tuple of (train_dataset, val_dataset), both ``torch.utils.data.Subset``

    Raises:
        ValueError: If ``val_ratio`` is outside [0, 1] (or NaN).
    """
    if not 0.0 <= val_ratio <= 1.0:
        raise ValueError(f"val_ratio must be in [0, 1], got {val_ratio}")

    # Calculate split sizes
    total_size = len(dataset)
    val_size = int(total_size * val_ratio)
    train_size = total_size - val_size

    # A local generator makes the split reproducible on its own. The global
    # RNG is deliberately left untouched (the previous torch.manual_seed call
    # silently reset shuffling/initialization randomness for the whole session).
    generator = torch.Generator().manual_seed(seed)
    train_dataset, val_dataset = random_split(
        dataset,
        [train_size, val_size],
        generator=generator,
    )

    return train_dataset, val_dataset

# Hold out 20% of the loaded samples for validation
print("Creating train/validation split...")
train_subset, val_subset = create_train_val_split(train_dataset, val_ratio=0.2)

n_total = len(train_dataset)
print(f"Original dataset size: {n_total}")
print(f"Training subset size: {len(train_subset)} ({len(train_subset)/n_total*100:.1f}%)")
print(f"Validation subset size: {len(val_subset)} ({len(val_subset)/n_total*100:.1f}%)")

# The two subsets must not share any source index
overlap = set(train_subset.indices) & set(val_subset.indices)
print(f"\nOverlap between train and validation: {len(overlap)} samples")

verdict = (
    "Split verification: PASSED (no overlap)"
    if not overlap
    else "Split verification: FAILED (overlap detected)"
)
print(verdict)

## 7. Efficient Batch Loading

In [None]:
# Compare different loading strategies
print("Comparing DataLoader Performance")
print("=" * 50)

NUM_BENCHMARK_BATCHES = 10  # Batches timed per configuration

# Test configurations
configs = [
    {"batch_size": 16, "num_workers": 0, "pin_memory": False},
    {"batch_size": 16, "num_workers": 2, "pin_memory": False},
    {"batch_size": 32, "num_workers": 2, "pin_memory": False},
    {"batch_size": 32, "num_workers": 4, "pin_memory": torch.cuda.is_available()},
]

results = []

for config in configs:
    loader = DataLoader(
        train_subset,
        batch_size=config["batch_size"],
        num_workers=config["num_workers"],
        pin_memory=config["pin_memory"],
        shuffle=True
    )

    # Measure loading time for a fixed number of batches.
    # - islice stops before fetching an extra batch (the old `if i >= 10: break`
    #   pulled an 11th batch inside the timed region).
    # - Counting the batches actually seen keeps throughput correct when the
    #   loader has fewer batches than requested.
    # - perf_counter is the monotonic clock intended for timing.
    # NOTE: with this few batches, worker start-up cost is included, which
    # penalizes num_workers > 0 on small subsets.
    start_time = time.perf_counter()
    batches_seen = 0
    for batch in islice(loader, NUM_BENCHMARK_BATCHES):
        # Simulate some processing
        _ = batch['label'].numpy()
        batches_seen += 1
    elapsed_time = time.perf_counter() - start_time

    result = {
        **config,
        "time": elapsed_time,
        "batches_per_second": batches_seen / elapsed_time if elapsed_time > 0 else 0.0,
    }
    results.append(result)

    print(f"Config: batch_size={config['batch_size']}, "
          f"workers={config['num_workers']}, "
          f"pin_memory={config['pin_memory']}")
    print(f"  Time: {elapsed_time:.3f}s")
    print(f"  Throughput: {result['batches_per_second']:.1f} batches/sec")
    print()

# Find best configuration
best_config = max(results, key=lambda x: x['batches_per_second'])
print("Best configuration:")
print(f"  Batch size: {best_config['batch_size']}")
print(f"  Workers: {best_config['num_workers']}")
print(f"  Pin memory: {best_config['pin_memory']}")
print(f"  Throughput: {best_config['batches_per_second']:.1f} batches/sec")

## 8. Memory-Efficient Loading for Large Datasets

In [None]:
class LazyAGNewsDataset(Dataset):
    """
    Lazy-loading dataset that reads JSON Lines samples from disk on demand.

    Only the byte offset of each record is kept in memory. A record is parsed
    when it is requested, so memory stays small for large files, and random
    access costs one seek instead of a scan from the start of the file.
    """

    def __init__(self, data_file: Path):
        """
        Initialize lazy dataset by indexing record offsets.

        Args:
            data_file: Path to data file (JSON lines format) whose records
                contain at least ``text`` and ``label`` keys. If the file does
                not exist, the dataset is empty.
        """
        self.data_file = data_file

        # One pass records where each non-blank line starts. Lengths are
        # accumulated from the raw bytes so the offsets are exact for seek().
        # The previous implementation re-read the file from the top on every
        # __getitem__, making a full epoch quadratic in the number of lines.
        self._offsets: List[int] = []
        if data_file.exists():
            offset = 0
            with open(data_file, 'rb') as f:
                for line in f:
                    if line.strip():  # blank lines are not samples
                        self._offsets.append(offset)
                    offset += len(line)

    def __len__(self) -> int:
        return len(self._offsets)

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        """
        Load a single sample from disk.

        Args:
            idx: Sample index in ``[0, len(self))``

        Returns:
            Dict containing ``text``, ``label`` and ``label_name``

        Raises:
            IndexError: If ``idx`` is out of range.
        """
        if not 0 <= idx < len(self._offsets):
            raise IndexError(f"Index {idx} out of range")

        # The file is opened per call rather than held open, so the dataset
        # holds no file handle between accesses. json.loads accepts UTF-8
        # bytes directly.
        with open(self.data_file, 'rb') as f:
            f.seek(self._offsets[idx])
            sample = json.loads(f.readline())

        return {
            'text': sample['text'],
            'label': sample['label'],
            'label_name': ID_TO_LABEL[sample['label']]
        }

# Demonstrate lazy loading concept
print("Lazy Loading Demonstration")
print("=" * 50)
print("""
Lazy loading is useful when:
1. Dataset is too large to fit in memory
2. You want to minimize memory usage
3. Data is stored in a streamable format

Trade-offs:
- Pros: Low memory usage, can handle huge datasets
- Cons: Slower data access, more I/O operations

Best practices:
- Use memory mapping for large files
- Cache frequently accessed samples
- Use efficient file formats (HDF5, Arrow, etc.)
""")

# Create sample data file for demonstration.
# NOTE: this writes into the project's data/ directory and overwrites any
# existing sample_lazy.jsonl there.
sample_data_file = PROJECT_ROOT / "data" / "sample_lazy.jsonl"
sample_data_file.parent.mkdir(parents=True, exist_ok=True)

# Write one JSON object per line (JSON Lines). An explicit UTF-8 encoding
# keeps the file portable across platforms with different default encodings.
with open(sample_data_file, 'w', encoding='utf-8') as f:
    f.writelines(
        json.dumps({"text": text, "label": label}) + '\n'
        for text, label in zip(sample_texts, sample_labels)
    )

# Create lazy dataset
lazy_dataset = LazyAGNewsDataset(sample_data_file)
print(f"\nLazy dataset created with {len(lazy_dataset)} samples")
print("Memory usage: Minimal (data not loaded)")

# Access samples on-demand
print("\nAccessing samples on-demand:")
for i in range(min(3, len(lazy_dataset))):
    sample = lazy_dataset[i]
    print(f"  Sample {i}: {sample['text'][:50]}...")

## 9. Data Loading Best Practices

In [None]:
def create_optimized_dataloader(
    dataset: Dataset,
    batch_size: int = 32,
    training: bool = True,
    num_workers: Optional[int] = None,
) -> DataLoader:
    """
    Create a DataLoader configured with common throughput best practices.

    Args:
        dataset: PyTorch dataset (any sized, integer-indexable object)
        batch_size: Number of samples per batch
        training: If True, shuffle and drop the last incomplete batch;
            if False, keep the original order and every sample
        num_workers: Worker processes to use. ``None`` (default) picks
            ``min(4, os.cpu_count() // 2)``, or 0 if the CPU count is unknown.

    Returns:
        Configured DataLoader
    """
    if num_workers is None:
        cpu_count = os.cpu_count()  # may be None when undeterminable
        num_workers = min(4, cpu_count // 2) if cpu_count else 0

    # Worker-only options. PyTorch rejects prefetch_factor (>= 2.0) and
    # persistent_workers=True when num_workers == 0. The previous
    # `2 if num_workers > 0 else 2` passed prefetch_factor unconditionally,
    # so single-process loading failed with ValueError.
    worker_kwargs = {}
    if num_workers > 0:
        worker_kwargs["persistent_workers"] = True  # Keep workers alive between epochs
        worker_kwargs["prefetch_factor"] = 2  # Batches prefetched per worker

    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=training,  # Shuffle only for training
        num_workers=num_workers,
        pin_memory=torch.cuda.is_available(),  # Pin memory for GPU
        drop_last=training,  # Drop last incomplete batch for training
        **worker_kwargs,
    )

# Create optimized loaders
print("Creating Optimized DataLoaders")
print("=" * 50)

train_loader_opt = create_optimized_dataloader(train_subset, batch_size=32, training=True)
val_loader_opt = create_optimized_dataloader(val_subset, batch_size=64, training=False)


def _describe_loader(title: str, loader: DataLoader) -> None:
    """
    Print the main settings of a DataLoader under a heading.

    DataLoader does not store a ``shuffle`` attribute, so reading
    ``loader.shuffle`` raises AttributeError. ``shuffle=True`` is implemented
    by giving the loader a RandomSampler, so that is what is checked here.

    Args:
        title: Heading printed before the settings
        loader: DataLoader to describe
    """
    print(title)
    print(f"  Batch size: {loader.batch_size}")
    print(f"  Shuffle: {isinstance(loader.sampler, RandomSampler)}")
    print(f"  Workers: {loader.num_workers}")
    print(f"  Pin memory: {loader.pin_memory}")
    print(f"  Drop last: {loader.drop_last}")


_describe_loader("Training DataLoader:", train_loader_opt)
_describe_loader("\nValidation DataLoader:", val_loader_opt)

# Best practices summary
print("\nDataLoader Best Practices:")
print("""
1. Use multiple workers for CPU-bound preprocessing
2. Pin memory when using GPU for faster transfers
3. Shuffle training data, but not validation/test data
4. Drop last incomplete batch during training for consistent batch sizes
5. Use persistent workers to avoid recreation overhead
6. Prefetch batches to hide I/O latency
7. Larger batch sizes for validation (no gradients needed)
8. Consider gradient accumulation for effective larger batches
""")

## 10. Summary and Next Steps

In [None]:
# Collect what earlier cells produced; each entry falls back to a default
# when its cell was skipped. At notebook top level locals() is the global
# namespace, so membership tests see every earlier cell's variables.
print("Data Loading Tutorial Summary")
print("=" * 50)

defined = locals()

dataset_section = {
    "Total samples": len(train_dataset) if 'train_dataset' in defined else 0,
    "Classes": AG_NEWS_NUM_CLASSES,
    "Class names": AG_NEWS_CLASSES,
}
split_section = {
    "Train samples": len(train_subset) if 'train_subset' in defined else 0,
    "Validation samples": len(val_subset) if 'val_subset' in defined else 0,
}
loader_section = {
    "Optimal batch size": best_config['batch_size'] if 'best_config' in defined else 32,
    "Optimal workers": best_config['num_workers'] if 'best_config' in defined else 2,
    "GPU available": torch.cuda.is_available(),
}

summary = {
    "Dataset": dataset_section,
    "Data Splits": split_section,
    "DataLoader Settings": loader_section,
}

for section_name, entries in summary.items():
    print(f"\n{section_name}:")
    for entry_name, entry_value in entries.items():
        print(f"  {entry_name}: {entry_value}")

takeaways_text = """
1. AG News has 4 balanced classes suitable for classification
2. PyTorch DataLoader provides efficient batch loading
3. Multiple workers can significantly speed up data loading
4. Pin memory improves GPU transfer speed
5. Lazy loading enables handling of large datasets
6. Proper configuration can improve training efficiency
"""
next_steps_text = """
1. Explore data preprocessing techniques (02_preprocessing_tutorial.ipynb)
2. Learn about model training (03_model_training_basics.ipynb)
3. Understand evaluation metrics (04_evaluation_tutorial.ipynb)
4. Experiment with different batch sizes and configurations
5. Implement custom data augmentation in the dataset
"""

print("\nKey Takeaways:")
print(takeaways_text)

print("\nNext Steps:")
print(next_steps_text)