# 20M Parameter Text Generation Model - Kaggle Training (Clean)
## Train from Scratch - No Checkpoint Complications

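This notebook trains a small GPT-2-style transformer from scratch on WikiText-103 in a Kaggle session, using simplified checkpointing: plain PyTorch `.pt` files are written during training and nothing is resumed from an earlier checkpoint.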

### 1. Setup and Dependencies

In [None]:
# Install required packages (Kaggle already has most packages pre-installed)
# Only install what's missing or needs updating
!pip install -q --upgrade transformers datasets tokenizers
!pip install -q --no-deps sentencepiece

print("✓ Packages installed/updated")

In [None]:
# Suppress warnings
import warnings
warnings.filterwarnings('ignore')
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from transformers import GPT2Config, GPT2LMHeadModel, GPT2Tokenizer
from transformers import get_linear_schedule_with_warmup
from datasets import load_dataset
import json
from tqdm.auto import tqdm
import gc

print("✓ Imports successful")

# Pick the compute device once; every later cell moves tensors/models to `device`.
gpu_available = torch.cuda.is_available()
device = torch.device('cuda' if gpu_available else 'cpu')
print(f"\nUsing device: {device}")

if not gpu_available:
    print("⚠️ WARNING: No GPU detected! Training will be very slow.")
else:
    # Report the name and total memory (in GB) of GPU 0.
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    gpu_props = torch.cuda.get_device_properties(0)
    print(f"Memory: {gpu_props.total_memory / 1e9:.2f} GB")

### 2. Configuration

In [None]:
# Training configuration: every tunable of the run lives in this one dict and
# is looked up by key in the cells below.
CONFIG = dict(
    batch_size=8,                    # sequences per DataLoader batch
    learning_rate=5e-4,              # learning rate given to the optimizer
    epochs=3,                        # full passes over the training split
    warmup_steps=500,                # scheduler warmup, in optimizer updates
    gradient_accumulation_steps=8,   # batches accumulated per optimizer update
    max_grad_norm=1.0,               # gradient-clipping threshold
    save_steps=1000,                 # batches between periodic checkpoints
    eval_steps=500,                  # NOTE(review): not read by any cell shown here
    max_length=512,                  # tokens per sequence after truncation/padding
)

print("Configuration:")
print("\n".join(f"  {name}: {setting}" for name, setting in CONFIG.items()))

### 3. Model Configuration

In [None]:
# Model architecture: a small GPT-2. With these sizes the parameter count is
# roughly 19M (the exact figure is printed below); about two thirds of it is
# the 50257 x 256 token-embedding matrix.
model_config = GPT2Config(
    vocab_size=50257,
    # The context window follows the tokenization length so the two settings
    # cannot drift apart (CONFIG['max_length'] is 512, the previous literal).
    n_positions=CONFIG['max_length'],
    n_embd=256,
    n_layer=8,
    n_head=8,
    n_inner=1024,
    activation_function='gelu_new',
    resid_pdrop=0.1,
    embd_pdrop=0.1,
    attn_pdrop=0.1,
)

# Initialize model from scratch (random weights; nothing pretrained is loaded)
model = GPT2LMHeadModel(model_config)
model = model.to(device)

# Parameter count and its FP32 footprint at 4 bytes per parameter.
total_params = sum(p.numel() for p in model.parameters())
print(f"\nTotal parameters: {total_params:,}")
print(f"Model size: {total_params * 4 / 1e6:.2f} MB (FP32)")
print("\n✓ Starting fresh training from scratch")

### 4. Data Preparation

In [None]:
# Load the pretrained GPT-2 tokenizer. It defines no padding token of its own,
# so the end-of-text token is reused for padding.
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
tokenizer.pad_token = tokenizer.eos_token

# Load the WikiText-103 dataset and report the size of the splits used below.
print("Loading dataset...")
dataset = load_dataset('wikitext', 'wikitext-103-v1')

train_rows = len(dataset['train'])
validation_rows = len(dataset['validation'])
print(f"Train samples: {train_rows}")
print(f"Validation samples: {validation_rows}")

In [None]:
# Tokenization
def tokenize_function(examples):
    """Tokenize one batch of rows; meant for `Dataset.map(..., batched=True)`.

    Args:
        examples: Mapping whose 'text' entry holds the raw strings of the batch.

    Returns:
        The tokenizer output for the batch. Every sequence is truncated or
        padded to exactly CONFIG['max_length'] tokens.
    """
    # No `return_tensors` here: the mapped result is written back into the
    # dataset, and tensor conversion happens later via `set_format('torch')`,
    # so building torch tensors at this point would only be converted back.
    return tokenizer(
        examples['text'],
        truncation=True,
        max_length=CONFIG['max_length'],
        padding='max_length',
    )

def _has_text(example):
    """Return True when the row's 'text' holds something other than whitespace."""
    return bool(example['text'].strip())


print("Tokenizing datasets...")
# Rows without text tokenize to zero tokens and would be stored as
# CONFIG['max_length'] padding tokens each: they cost storage and compute but
# contain nothing to learn from, so they are dropped before tokenizing.
# NOTE(review): this assumes the corpus contains blank separator rows; compare
# the split sizes before and after filtering to confirm.
train_text = dataset['train'].filter(_has_text)
val_text = dataset['validation'].filter(_has_text)

tokenized_train = train_text.map(
    tokenize_function,
    batched=True,
    remove_columns=train_text.column_names
)

tokenized_val = val_text.map(
    tokenize_function,
    batched=True,
    remove_columns=val_text.column_names
)

# Have the datasets hand out torch tensors when indexed by the DataLoaders.
tokenized_train.set_format('torch')
tokenized_val.set_format('torch')

print("✓ Tokenization complete")

In [None]:
# Create dataloaders
# Both splits use the same batch size; only the training split is shuffled.
loader_batch_size = CONFIG['batch_size']

train_loader = DataLoader(tokenized_train, batch_size=loader_batch_size, shuffle=True)
val_loader = DataLoader(tokenized_val, batch_size=loader_batch_size)

train_batch_count = len(train_loader)
val_batch_count = len(val_loader)
print(f"Train batches: {train_batch_count}")
print(f"Val batches: {val_batch_count}")

### 5. Training Setup

In [None]:
# Optimizer and scheduler
optimizer = torch.optim.AdamW(
    model.parameters(), lr=CONFIG['learning_rate'], weight_decay=0.01
)

# The scheduler is stepped once per optimizer update, so its horizon is the
# number of batches over all epochs divided by the accumulation factor.
batches_over_all_epochs = len(train_loader) * CONFIG['epochs']
total_steps = batches_over_all_epochs // CONFIG['gradient_accumulation_steps']

warmup = CONFIG['warmup_steps']
scheduler = get_linear_schedule_with_warmup(
    optimizer, num_warmup_steps=warmup, num_training_steps=total_steps
)

print(f"Total training steps: {total_steps}")
print(f"Warmup steps: {warmup}")

### 6. Training Functions

In [None]:
def train_epoch(model, loader, optimizer, scheduler, device, epoch):
    """Train `model` for one pass over `loader` using gradient accumulation.

    Padding positions are excluded from the loss: labels are a copy of
    `input_ids` with every position whose attention mask is 0 set to -100.
    Because the pad token is the EOS token here, using `input_ids` directly as
    labels would train the model to predict padding.

    Args:
        model: Causal LM called as model(input_ids=, attention_mask=, labels=)
            whose output exposes `.loss`.
        loader: Iterable of batches with 'input_ids' and 'attention_mask'
            tensors (assumed to be shaped (batch, seq)).
        optimizer: Optimizer over the model's parameters.
        scheduler: LR scheduler, stepped once per optimizer update.
        device: Device the batch tensors are moved to.
        epoch: Epoch number, used for the progress bar and checkpoint names.

    Returns:
        Mean loss over the batches that contained at least one target token
        (0.0 if there were none).
    """
    model.train()
    accum_steps = CONFIG['gradient_accumulation_steps']
    total_loss = 0.0
    counted_batches = 0          # batches that actually produced a loss
    grads_pending = False        # gradients accumulated since the last update
    previous_checkpoint = None   # last periodic checkpoint written this epoch
    optimizer.zero_grad()

    def _apply_update():
        # Clip the accumulated gradients, then advance optimizer and schedule.
        torch.nn.utils.clip_grad_norm_(model.parameters(), CONFIG['max_grad_norm'])
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()

    progress_bar = tqdm(loader, desc=f"Epoch {epoch}")

    for step, batch in enumerate(progress_bar):
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)

        # -100 marks positions to be left out of the loss.
        labels = input_ids.masked_fill(attention_mask == 0, -100)

        # The first position has no preceding context to be predicted from, so
        # only labels[:, 1:] can be targets. A batch with no targets at all
        # would give an undefined (NaN) mean loss, so it is not fed through.
        batch_loss = None
        if bool((labels[:, 1:] != -100).any()):
            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=labels
            )
            # Scale so the accumulated gradient averages over the window.
            (outputs.loss / accum_steps).backward()
            batch_loss = outputs.loss.item()   # one host sync per batch
            total_loss += batch_loss
            counted_batches += 1
            grads_pending = True

        if (step + 1) % accum_steps == 0:
            _apply_update()
            grads_pending = False

        if batch_loss is not None:
            progress_bar.set_postfix({
                'loss': batch_loss,
                'lr': scheduler.get_last_lr()[0]
            })

        # Save checkpoint periodically (PyTorch format only)
        if (step + 1) % CONFIG['save_steps'] == 0:
            checkpoint_path = f'/kaggle/working/checkpoint_epoch{epoch}_step{step+1}.pt'
            torch.save({
                'epoch': epoch,
                'step': step + 1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'train_loss': total_loss / max(counted_batches, 1),
            }, checkpoint_path)
            print(f"\n✓ Checkpoint saved: {checkpoint_path}")
            # Each file holds model and optimizer state; keeping one per
            # `save_steps` batches of a long epoch can fill the working
            # directory, so only the newest periodic checkpoint is kept.
            if previous_checkpoint is not None and os.path.exists(previous_checkpoint):
                os.remove(previous_checkpoint)
            previous_checkpoint = checkpoint_path

    # Apply the gradients of a trailing, incomplete accumulation window rather
    # than dropping them when the next epoch zeroes the gradients.
    if grads_pending:
        _apply_update()

    return total_loss / max(counted_batches, 1)


def evaluate(model, loader, device):
    """Compute the token-level mean loss and perplexity of `model` on `loader`.

    Padding positions are excluded from the loss (labels are -100 wherever the
    attention mask is 0). Batches are weighted by their number of target
    tokens, so the result is a per-token average rather than a per-batch one.

    Args:
        model: Causal LM called as model(input_ids=, attention_mask=, labels=)
            whose output exposes `.loss`.
        loader: Iterable of batches with 'input_ids' and 'attention_mask'
            tensors (assumed to be shaped (batch, seq)).
        device: Device the batch tensors are moved to.

    Returns:
        (avg_loss, perplexity) as Python floats; (nan, nan) when the loader
        yields no target tokens at all.
    """
    model.eval()
    weighted_loss = 0.0   # sum over batches of (batch mean loss * target count)
    target_tokens = 0

    with torch.no_grad():
        for batch in tqdm(loader, desc="Evaluating"):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)

            # -100 marks positions to be left out of the loss.
            labels = input_ids.masked_fill(attention_mask == 0, -100)

            # Only positions after the first can be targets; a batch without
            # any target would yield an undefined (NaN) mean loss.
            n_targets = int((labels[:, 1:] != -100).sum())
            if n_targets == 0:
                continue

            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=labels
            )
            # Assumes `outputs.loss` is the mean over this batch's targets —
            # TODO confirm for the model class in use.
            weighted_loss += outputs.loss.item() * n_targets
            target_tokens += n_targets

    if target_tokens == 0:
        return float('nan'), float('nan')

    avg_loss = weighted_loss / target_tokens
    perplexity = torch.exp(torch.tensor(avg_loss))
    return avg_loss, perplexity.item()

### 7. Training Loop

In [None]:
# Training loop: for every epoch train, evaluate, log the metrics, then write
# the checkpoints.
best_val_loss = float('inf')
training_history = []

for epoch in range(1, CONFIG['epochs'] + 1):
    print(f"\n{'='*60}")
    print(f"Epoch {epoch}/{CONFIG['epochs']}")
    print(f"{'='*60}")

    # One pass over the training split, then a pass over the validation split.
    train_loss = train_epoch(model, train_loader, optimizer, scheduler, device, epoch)
    val_loss, val_perplexity = evaluate(model, val_loader, device)

    print(f"\nTrain Loss: {train_loss:.4f}")
    print(f"Val Loss: {val_loss:.4f}")
    print(f"Val Perplexity: {val_perplexity:.2f}")

    training_history.append({
        'epoch': epoch,
        'train_loss': train_loss,
        'val_loss': val_loss,
        'val_perplexity': val_perplexity
    })

    # The best-model file and the per-epoch file carry the same payload, so it
    # is assembled once and written to whichever paths apply.
    epoch_checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'train_loss': train_loss,
        'val_loss': val_loss,
        'val_perplexity': val_perplexity,
    }

    # Save best model (lowest validation loss seen so far)
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(epoch_checkpoint, '/kaggle/working/best_model.pt')
        print("✓ Saved best model")

    # Save epoch checkpoint (written every epoch)
    torch.save(epoch_checkpoint, f'/kaggle/working/checkpoint_epoch{epoch}.pt')

    # Release cached GPU memory and collect garbage between epochs.
    torch.cuda.empty_cache()
    gc.collect()

print("\n" + "="*60)
print("Training completed!")
print(f"Best validation loss: {best_val_loss:.4f}")

### 8. Save Training History

In [None]:
# Save training history: persist the per-epoch metrics as JSON, then echo them.
history_path = '/kaggle/working/training_history.json'
with open(history_path, 'w') as history_file:
    history_file.write(json.dumps(training_history, indent=2))

print("Training history saved!")
print("\nFinal Results:")
for entry in training_history:
    epoch_part = f"Epoch {entry['epoch']}: Train Loss={entry['train_loss']:.4f}, "
    validation_part = f"Val Loss={entry['val_loss']:.4f}, Perplexity={entry['val_perplexity']:.2f}"
    print(epoch_part + validation_part)

### 9. Text Generation Test

In [None]:
def generate_text(prompt, max_length=100, temperature=0.8):
    """Sample a continuation of `prompt` from the notebook's `model`.

    Args:
        prompt: Text to continue. An empty prompt is replaced by a single
            end-of-text token so generation still has a position to start from.
        max_length: Passed to `generate` as `max_length`; it bounds the total
            sequence length, prompt tokens included.
        temperature: Sampling temperature.

    Returns:
        The decoded sequence (prompt plus continuation), special tokens removed.
    """
    model.eval()
    encoded = tokenizer(prompt, return_tensors='pt')
    input_ids = encoded['input_ids'].to(device)
    attention_mask = encoded['attention_mask'].to(device)

    if input_ids.numel() == 0:
        # Nothing to condition on: seed with the end-of-text token.
        input_ids = torch.tensor([[tokenizer.eos_token_id]], device=device)
        attention_mask = torch.ones_like(input_ids)

    with torch.no_grad():
        output = model.generate(
            input_ids,
            # Passed explicitly: the pad token is the EOS token here, so the
            # mask cannot be reliably inferred from the token ids.
            attention_mask=attention_mask,
            max_length=max_length,
            temperature=temperature,
            top_k=50,
            top_p=0.95,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id
        )

    return tokenizer.decode(output[0], skip_special_tokens=True)

# Test generation: sample one continuation for each of a few fixed prompts.
test_prompts = [
    "The future of artificial intelligence",
    "In a world where technology",
    "Scientists have discovered"
]

print("\n" + "="*60)
print("Text Generation Examples")
print("="*60)

for prompt in test_prompts:
    # Header first, then the sample, followed by one blank line.
    print(f"\nPrompt: {prompt}")
    print("-" * 60)
    generated = generate_text(prompt, max_length=150)
    print(generated, end="\n\n")

### 10. Save Final Model

In [None]:
# Save model in HuggingFace format: model first, then tokenizer, both into the
# same directory.
final_model_dir = '/kaggle/working/final_model'
for artifact in (model, tokenizer):
    artifact.save_pretrained(final_model_dir)

print("✓ Model saved in HuggingFace format")
print("\nOutput files:")
for description in (
    "  - best_model.pt (best checkpoint)",
    "  - checkpoint_epoch*.pt (epoch checkpoints)",
    "  - training_history.json (training metrics)",
    "  - final_model/ (HuggingFace format)",
):
    print(description)