# Local Training Notebook for Custom LLM

In [None]:
import sys
sys.path.append('..')  # Add parent directory to path

import torch
from transformers import AutoTokenizer
from datasets import load_dataset

from src.model.transformer import CustomTransformer
from src.training.trainer import Trainer
from src.data.data_processor import DataProcessor

## 1. Load and Prepare Data

We'll use the WikiText-2 dataset for testing; it is much smaller than the WikiText-103 dataset used for full training.

In [2]:
# Load the raw WikiText-2 corpus
dataset = load_dataset('wikitext', 'wikitext-2-raw-v1')

# Reuse the pretrained GPT-2 tokenizer
tokenizer = AutoTokenizer.from_pretrained('gpt2')

# Pull the text column out of the train and validation splits.
# NOTE(review): val_texts is not passed to prepare_data below; the validation
# dataloader presumably comes from split_ratio applied to train_texts --
# confirm against DataProcessor.prepare_data.
train_texts, val_texts = (
    dataset[split_name]['text'] for split_name in ('train', 'validation')
)

# Sequence length and batch size are kept small for a local test run
data_processor = DataProcessor(
    tokenizer=tokenizer,
    max_length=512,
    batch_size=8,
)

# Build both dataloaders from the training texts
train_dataloader, val_dataloader = data_processor.prepare_data(
    texts=train_texts,
    split_ratio=0.1,
)

## 2. Initialize Model

We'll create a smaller version of the model for testing purposes.

In [None]:
# Build a reduced-size model for local testing
model = CustomTransformer(
    vocab_size=len(tokenizer),  # vocabulary size taken from the tokenizer
    d_model=256,  # Smaller dimension
    n_heads=4,   # Fewer attention heads
    n_layers=4,  # Fewer layers
    d_ff=1024,   # Smaller feed-forward dimension
    dropout=0.1
)

# Total parameter count, reported in millions
print(f'Model Parameters: {sum(p.numel() for p in model.parameters())/1e6:.2f}M')

# Outer double quotes are required here: reusing the enclosing quote character
# inside an f-string expression is a SyntaxError before Python 3.12.
# This only reports whether CUDA is available; this cell does not move the
# model to a device.
print(f"Device: {'cuda' if torch.cuda.is_available() else 'cpu'}")

## 3. Training Configuration

Set up the trainer with appropriate hyperparameters for local testing.

In [4]:
# Run-level settings for the short local test
EPOCHS = 2
SAVE_PATH = '../checkpoints/model_local.pt'

# Trainer with hyperparameters sized for local testing
trainer = Trainer(
    model=model,
    learning_rate=1e-4,
    warmup_steps=100,  # short warmup for the test run
    max_grad_norm=1.0,
    use_wandb=False,  # no W&B logging when running locally
)

## 4. Training Loop

Run the training loop and monitor the results.

In [None]:
# Run training with the dataloaders, epoch count and checkpoint path defined
# in the earlier cells. The returned `history` is indexed by metric name
# ('train_loss', 'val_loss', ...) in the plotting cell.
history = trainer.train(
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    epochs=EPOCHS,
    save_path=SAVE_PATH,
    log_interval=10  # More frequent logging for debugging
)

## 5. Analyze Results

Plot training metrics to visualize the model's performance.

In [None]:
import matplotlib.pyplot as plt

def _plot_history_metric(key, name):
    """Draw one figure comparing training and validation curves of a metric.

    key is the suffix of the history entries to read ('train_<key>' and
    'val_<key>'); name is the display name used in the title, the legend
    labels and the y-axis label.
    """
    plt.figure(figsize=(10, 5))
    plt.plot(history[f'train_{key}'], label=f'Training {name}')
    plt.plot(history[f'val_{key}'], label=f'Validation {name}')
    plt.title(f'Training and Validation {name}')
    plt.xlabel('Epoch')
    plt.ylabel(name)
    plt.legend()
    plt.show()


# Loss first, then perplexity -- one figure each
_plot_history_metric('loss', 'Loss')
_plot_history_metric('perplexity', 'Perplexity')