# 20M Parameter Text Generation Model - Kaggle Training
## Resume Training from JSON Checkpoint

This notebook trains a ~20M-parameter GPT-2-style language model on WikiText-103 on Kaggle, and can resume training from a JSON checkpoint.

### 1. Setup and Dependencies

In [None]:
# Install required packages
!pip install -q torch torchvision torchaudio
!pip install -q transformers datasets tokenizers accelerate
!pip install -q sentencepiece protobuf

In [None]:
import gc
import json
import os
import warnings

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from datasets import load_dataset
from tqdm.auto import tqdm
from transformers import GPT2Config, GPT2LMHeadModel, GPT2Tokenizer
from transformers import get_linear_schedule_with_warmup

# Seed torch's global RNGs (CPU and every CUDA device). Weight initialization,
# dropout and DataLoader shuffling all draw from them, so runs become repeatable.
SEED = 42
torch.manual_seed(SEED)

# Check GPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

### 2. Configuration

In [None]:
# Training hyperparameters and resume settings, shared by every later cell.
# Effective batch size = batch_size * gradient_accumulation_steps.
CONFIG = dict(
    batch_size=8,                     # Reduced for Kaggle
    learning_rate=5e-4,
    epochs=3,
    warmup_steps=500,                 # measured in optimizer steps
    gradient_accumulation_steps=8,    # Increased to compensate for smaller batch
    max_grad_norm=1.0,
    save_steps=500,                   # measured in batches within an epoch
    eval_steps=500,                   # not referenced by the training loop below (it evaluates once per epoch)
    max_length=512,
    resume_from_json=False,           # Set to True if resuming
    json_checkpoint_path='/kaggle/input/your-checkpoint/checkpoint.json',  # Update this
)

print("Configuration:")
print("\n".join(f"  {name}: {setting}" for name, setting in CONFIG.items()))

### 3. Model Configuration

In [None]:
# Model architecture: a small GPT-2 with ~19.3M parameters. About 12.9M of those are
# the 50257 x 256 token embeddings, which GPT2LMHeadModel ties to the output head.
# The exact count is printed below.
model_config = GPT2Config(
    vocab_size=50257,
    # The position table must cover every token the tokenizer may emit, so derive
    # it from the configured sequence length rather than hard-coding it.
    n_positions=CONFIG['max_length'],
    n_embd=256,
    n_layer=8,
    n_head=8,
    n_inner=1024,
    activation_function='gelu_new',
    resid_pdrop=0.1,
    embd_pdrop=0.1,
    attn_pdrop=0.1,
)

# Initialize model
model = GPT2LMHeadModel(model_config)
model = model.to(device)

# `parameters()` yields each shared (tied) tensor once, so tied weights are not double-counted.
total_params = sum(p.numel() for p in model.parameters())
print(f"\nTotal parameters: {total_params:,}")
print(f"Model size: {total_params * 4 / 1e6:.2f} MB (FP32)")

### 4. JSON Checkpoint Save/Load Utilities

In [None]:
def _json_to_tensor(entry):
    """
    Rebuild a tensor from a JSON record of the form {'data', 'shape', 'dtype'}.

    The recorded dtype and shape are restored when present. Without them,
    `torch.tensor` infers float32/int64 from the nested lists (e.g. float64
    values lose precision). It also cannot recover empty dimensions: a
    (0, 5) tensor serializes to `[]`.
    """
    dtype = None
    dtype_name = entry.get('dtype')
    if isinstance(dtype_name, str):
        # str(torch.float32) == 'torch.float32' -> look up `torch.float32`
        candidate = getattr(torch, dtype_name.replace('torch.', '', 1), None)
        if isinstance(candidate, torch.dtype):
            dtype = candidate
    tensor = torch.tensor(entry['data'], dtype=dtype)
    if 'shape' in entry:
        tensor = tensor.reshape(entry['shape'])
    return tensor


def load_checkpoint_from_json(json_path, model, optimizer=None):
    """
    Load checkpoint from JSON format

    Tensors are stored as {'data', 'shape', 'dtype'} records; any other value
    is passed through unchanged.

    Args:
        json_path: Path to JSON checkpoint
        model: Model to load weights into
        optimizer: Optimizer to load state into (optional)

    Returns:
        metadata: Dictionary with 'epoch', 'step', 'val_loss' and 'train_loss'
            (0 / None when missing from the file)
    """
    print(f"Loading checkpoint from: {json_path}")

    with open(json_path, 'r') as f:
        checkpoint = json.load(f)

    def _decode(value):
        # Tensor records become tensors; plain JSON values are kept as-is.
        if isinstance(value, dict) and 'data' in value:
            return _json_to_tensor(value)
        return value

    # Load model state
    if 'model_state_dict' in checkpoint:
        print("Loading model weights...")
        model_state = {
            key: _decode(value)
            for key, value in checkpoint['model_state_dict'].items()
        }
        model.load_state_dict(model_state)
        print("✓ Model weights loaded")

    # Load optimizer state
    if optimizer is not None and 'optimizer_state_dict' in checkpoint:
        print("Loading optimizer state...")
        opt_state = checkpoint['optimizer_state_dict']

        # JSON object keys are strings; optimizer state is keyed by integer param index.
        optimizer_state = {
            'state': {
                int(param_id): {key: _decode(value) for key, value in param_state.items()}
                for param_id, param_state in opt_state.get('state', {}).items()
            },
            'param_groups': opt_state.get('param_groups', []),
        }

        optimizer.load_state_dict(optimizer_state)
        print("✓ Optimizer state loaded")

    # Extract metadata
    metadata = {
        'epoch': checkpoint.get('epoch', 0),
        'step': checkpoint.get('step', 0),
        'val_loss': checkpoint.get('val_loss', None),
        'train_loss': checkpoint.get('train_loss', None)
    }

    print(f"\nCheckpoint metadata:")
    for key, value in metadata.items():
        print(f"  {key}: {value}")

    return metadata


def save_checkpoint_to_json(filepath, model, optimizer, epoch, step, train_loss, val_loss=None):
    """
    Save checkpoint in JSON format

    Every tensor is written as {'data': nested lists, 'shape': [...], 'dtype': 'torch.<dtype>'}.
    Non-tensor optimizer state entries that are plain JSON scalars (e.g. an integer
    'step') are written as-is. Other non-tensor entries are skipped.

    Args:
        filepath: Path to save JSON file
        model: Model to save
        optimizer: Optimizer to save
        epoch: Current epoch
        step: Current step
        train_loss: Training loss
        val_loss: Validation loss (optional)
    """
    print(f"Saving checkpoint to: {filepath}")

    def _encode(tensor):
        # Tensor.tolist() also handles dtypes NumPy lacks (e.g. bfloat16), unlike .numpy().
        t = tensor.detach().cpu()
        return {'data': t.tolist(), 'shape': list(t.shape), 'dtype': str(t.dtype)}

    full_opt_state = optimizer.state_dict()

    checkpoint = {
        'epoch': epoch,
        'step': step,
        'train_loss': train_loss,
        'val_loss': val_loss,
        'model_state_dict': {key: _encode(tensor) for key, tensor in model.state_dict().items()},
        'optimizer_state_dict': {
            'state': {},
            'param_groups': full_opt_state['param_groups']
        }
    }

    # JSON object keys must be strings; the loader converts them back to ints.
    for param_id, param_state in full_opt_state['state'].items():
        encoded_state = {}
        for key, value in param_state.items():
            if isinstance(value, torch.Tensor):
                encoded_state[key] = _encode(value)
            elif value is None or isinstance(value, (bool, int, float, str)):
                # Dropping scalars such as an int 'step' would make the state unloadable.
                encoded_state[key] = value
        checkpoint['optimizer_state_dict']['state'][str(param_id)] = encoded_state

    # Write to a temporary file and atomically rename it. An interrupted save
    # (e.g. the Kaggle session timing out) then never leaves a truncated file at `filepath`.
    tmp_path = f"{filepath}.tmp"
    with open(tmp_path, 'w') as f:
        json.dump(checkpoint, f)
    os.replace(tmp_path, filepath)

    print("✓ Checkpoint saved")

### 5. Data Preparation

In [None]:
# GPT-2 BPE tokenizer. GPT-2 defines no pad token, so EOS doubles as padding.
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
tokenizer.pad_token = tokenizer.eos_token

# Raw WikiText-103 splits (one text line per example)
print("Loading dataset...")
dataset = load_dataset('wikitext', 'wikitext-103-v1')

for split_label, split_key in (("Train", "train"), ("Validation", "validation")):
    print(f"{split_label} samples: {len(dataset[split_key])}")

In [None]:
# Tokenization
def tokenize_function(examples):
    """Tokenize a batch of lines, truncating/padding each to CONFIG['max_length'] tokens."""
    # `datasets.map` stores the returned lists itself, so no `return_tensors` is needed here.
    return tokenizer(
        examples['text'],
        truncation=True,
        max_length=CONFIG['max_length'],
        padding='max_length',
    )


def _has_text(example):
    """True for lines with non-whitespace content.

    Blank lines would become all-padding sequences. They contain no training
    signal, and with padding excluded from the loss they would yield NaN losses.
    """
    return bool(example['text'].strip())


print("Tokenizing datasets...")
tokenized_train = dataset['train'].filter(_has_text).map(
    tokenize_function,
    batched=True,
    remove_columns=dataset['train'].column_names
)

tokenized_val = dataset['validation'].filter(_has_text).map(
    tokenize_function,
    batched=True,
    remove_columns=dataset['validation'].column_names
)

tokenized_train.set_format('torch')
tokenized_val.set_format('torch')

print("✓ Tokenization complete")

In [None]:
# Batch the tokenized splits; only the training split is shuffled each epoch.
train_loader = DataLoader(tokenized_train, batch_size=CONFIG['batch_size'], shuffle=True)
val_loader = DataLoader(tokenized_val, batch_size=CONFIG['batch_size'])

# Number of batches per pass over each split
print(f"Train batches: {len(train_loader)}")
print(f"Val batches: {len(val_loader)}")

### 6. Training Setup

In [None]:
# Optimizer and scheduler
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=CONFIG['learning_rate'],
    weight_decay=0.01
)

# The optimizer steps once every `gradient_accumulation_steps` batches.
total_steps = len(train_loader) * CONFIG['epochs'] // CONFIG['gradient_accumulation_steps']
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=CONFIG['warmup_steps'],
    num_training_steps=total_steps
)

print(f"Total training steps: {total_steps}")

# Resume from checkpoint if specified
start_epoch = 1
global_step = 0
resume_step = 0  # batches of `start_epoch` already trained before the checkpoint

if CONFIG['resume_from_json'] and os.path.exists(CONFIG['json_checkpoint_path']):
    metadata = load_checkpoint_from_json(
        CONFIG['json_checkpoint_path'],
        model,
        optimizer
    )
    global_step = metadata['step']

    # Periodic checkpoints store the batch index inside the epoch. End-of-epoch
    # checkpoints store len(train_loader). Only the latter means the epoch is finished.
    if metadata['step'] < len(train_loader):
        start_epoch = max(metadata['epoch'], 1)
        resume_step = metadata['step']
    else:
        start_epoch = metadata['epoch'] + 1

    # The JSON checkpoint has no scheduler state. Replay the optimizer steps already
    # taken so the LR continues the schedule instead of restarting warmup.
    # train_epoch steps the optimizer once per full accumulation window in each epoch.
    accum = CONFIG['gradient_accumulation_steps']
    completed_steps = (start_epoch - 1) * (len(train_loader) // accum) + resume_step // accum
    with warnings.catch_warnings():
        # Silences "lr_scheduler.step() called before optimizer.step()" during replay.
        warnings.simplefilter('ignore', UserWarning)
        for _ in range(completed_steps):
            scheduler.step()

    print(f"\n✓ Resuming from epoch {start_epoch}, step {global_step}")
    print(f"  Scheduler fast-forwarded {completed_steps} steps (lr={scheduler.get_last_lr()[0]:.3e})")
elif CONFIG['resume_from_json']:
    # Do not silently fall back: resuming was requested but the file is missing.
    print(f"\n⚠ Checkpoint not found at {CONFIG['json_checkpoint_path']}; starting training from scratch")
else:
    print("\nStarting training from scratch")

### 7. Training Functions

In [None]:
def train_epoch(model, loader, optimizer, scheduler, device, epoch, start_step=0,
                keep_all_checkpoints=False):
    """
    Train `model` for one pass over `loader` with gradient accumulation.

    Args:
        model: Causal LM that accepts input_ids / attention_mask / labels.
        loader: Yields dicts with 'input_ids' and 'attention_mask' tensors.
        optimizer, scheduler: Stepped once every CONFIG['gradient_accumulation_steps'] batches.
        device: Device the batch tensors are moved to.
        epoch: Epoch number (used for the progress bar and checkpoint names).
        start_step: Number of leading batches to skip (mid-epoch resume). Skipped
            batches are still drawn from the loader; with a shuffling loader they are
            not the same batches that were trained before the interruption.
        keep_all_checkpoints: If False (default), each periodic checkpoint deletes the
            previous periodic one from this epoch. Full JSON checkpoints are very large,
            and saving every CONFIG['save_steps'] batches otherwise fills the disk.

    Returns:
        Mean unscaled loss over the batches actually trained on, or NaN if there were none.
    """
    accum = CONFIG['gradient_accumulation_steps']
    model.train()
    total_loss = 0.0
    trained_batches = 0
    previous_checkpoint = None
    # Also discards any partial accumulation left over from the previous epoch.
    optimizer.zero_grad()

    progress_bar = tqdm(loader, desc=f"Epoch {epoch}")

    for step, batch in enumerate(progress_bar):
        if step < start_step:
            continue

        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        # Padding (pad == EOS here) must not be scored: -100 is the ignore index of the LM loss.
        labels = input_ids.masked_fill(attention_mask == 0, -100)

        batch_loss = None
        # GPT2LMHeadModel shifts labels internally, so only labels[:, 1:] are scored.
        # A batch with no scored token would produce a NaN loss and poison the weights.
        if (labels[:, 1:] != -100).any():
            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=labels
            )
            (outputs.loss / accum).backward()
            batch_loss = outputs.loss.item()
            total_loss += batch_loss
            trained_batches += 1

        if (step + 1) % accum == 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), CONFIG['max_grad_norm'])
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()

        if batch_loss is not None:
            progress_bar.set_postfix({
                'loss': batch_loss,
                'lr': scheduler.get_last_lr()[0]
            })

        # Save checkpoint periodically
        if (step + 1) % CONFIG['save_steps'] == 0:
            checkpoint_path = f'/kaggle/working/checkpoint_epoch{epoch}_step{step+1}.json'
            running_loss = total_loss / trained_batches if trained_batches else float('nan')
            save_checkpoint_to_json(
                checkpoint_path,
                model,
                optimizer,
                epoch,
                step + 1,
                running_loss
            )
            if (not keep_all_checkpoints and previous_checkpoint is not None
                    and previous_checkpoint != checkpoint_path
                    and os.path.exists(previous_checkpoint)):
                os.remove(previous_checkpoint)
            previous_checkpoint = checkpoint_path

    return total_loss / trained_batches if trained_batches else float('nan')


def evaluate(model, loader, device):
    """
    Compute the mean per-batch LM loss and the corresponding perplexity on `loader`.

    Padding positions (attention_mask == 0) are excluded from the loss. Batches
    with no scored token are skipped.

    Returns:
        (avg_loss, perplexity) as floats; both NaN if no batch could be scored.
    """
    model.eval()
    total_loss = 0.0
    evaluated_batches = 0

    with torch.no_grad():
        for batch in tqdm(loader, desc="Evaluating"):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            # -100 marks padding so the loss ignores it (pad == EOS for this tokenizer).
            labels = input_ids.masked_fill(attention_mask == 0, -100)
            # Labels are shifted inside the model; an all-ignored batch would give NaN.
            if not (labels[:, 1:] != -100).any():
                continue

            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=labels
            )
            total_loss += outputs.loss.item()
            evaluated_batches += 1

    if evaluated_batches == 0:
        return float('nan'), float('nan')

    avg_loss = total_loss / evaluated_batches
    perplexity = torch.exp(torch.tensor(avg_loss))
    return avg_loss, perplexity.item()

### 8. Training Loop

In [None]:
# Training loop
best_val_loss = float('inf')
training_history = []

# Batches of `start_epoch` already trained before a mid-epoch checkpoint. The setup
# cell defines `resume_step` when it supports this; otherwise nothing is skipped.
first_epoch_skip = globals().get('resume_step', 0)

for epoch in range(start_epoch, CONFIG['epochs'] + 1):
    print(f"\n{'='*60}")
    print(f"Epoch {epoch}/{CONFIG['epochs']}")
    print(f"{'='*60}")

    # Train (only the first resumed epoch skips already-trained batches)
    train_loss = train_epoch(
        model,
        train_loader,
        optimizer,
        scheduler,
        device,
        epoch,
        start_step=first_epoch_skip if epoch == start_epoch else 0
    )

    # Evaluate
    val_loss, val_perplexity = evaluate(model, val_loader, device)

    print(f"\nTrain Loss: {train_loss:.4f}")
    print(f"Val Loss: {val_loss:.4f}")
    print(f"Val Perplexity: {val_perplexity:.2f}")

    training_history.append({
        'epoch': epoch,
        'train_loss': train_loss,
        'val_loss': val_loss,
        'val_perplexity': val_perplexity
    })

    # Save best model
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        save_checkpoint_to_json(
            '/kaggle/working/best_model.json',
            model,
            optimizer,
            epoch,
            len(train_loader),
            train_loss,
            val_loss
        )
        print("✓ Saved best model")

    # Save epoch checkpoint (step == len(train_loader) marks a completed epoch)
    save_checkpoint_to_json(
        f'/kaggle/working/checkpoint_epoch{epoch}.json',
        model,
        optimizer,
        epoch,
        len(train_loader),
        train_loss,
        val_loss
    )

    # Clear cache
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    gc.collect()

print("\n" + "="*60)
print("Training completed!")
print(f"Best validation loss: {best_val_loss:.4f}")

### 9. Save Training History

In [None]:
# Persist the per-epoch metrics next to the checkpoints, then echo a summary.
with open('/kaggle/working/training_history.json', 'w') as history_file:
    json.dump(training_history, history_file, indent=2)

print("Training history saved!")
print("\nFinal Results:")
for record in training_history:
    summary = (
        f"Epoch {record['epoch']}: "
        f"Train Loss={record['train_loss']:.4f}, "
        f"Val Loss={record['val_loss']:.4f}, "
        f"Perplexity={record['val_perplexity']:.2f}"
    )
    print(summary)

### 10. Text Generation Test

In [None]:
def generate_text(prompt, max_length=100, temperature=0.8, top_k=50, top_p=0.95):
    """
    Sample a continuation of `prompt` from the trained model.

    Args:
        prompt: Text to condition on.
        max_length: Maximum total length in tokens, prompt included.
        temperature: Sampling temperature.
        top_k: Keep only the k most likely tokens at each step.
        top_p: Nucleus-sampling probability mass.

    Returns:
        The decoded prompt plus continuation, with special tokens removed.
    """
    model.eval()
    # Pass the attention mask explicitly. The pad token is EOS here, so `generate`
    # cannot reliably infer the mask from the input ids alone.
    encoded = tokenizer(prompt, return_tensors='pt').to(device)

    with torch.no_grad():
        output = model.generate(
            encoded['input_ids'],
            attention_mask=encoded['attention_mask'],
            max_length=max_length,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id
        )

    return tokenizer.decode(output[0], skip_special_tokens=True)

# Test generation
test_prompts = [
    "The future of artificial intelligence",
    "In a world where technology",
    "Scientists have discovered"
]

print("\n" + "="*60)
print("Text Generation Examples")
print("="*60)

for prompt in test_prompts:
    print(f"\nPrompt: {prompt}")
    print("-" * 60)
    generated = generate_text(prompt, max_length=150)
    print(generated)
    print()

### 11. Save Final Model

In [None]:
# Export the weights and tokenizer side by side in HuggingFace format
# (reloadable with `from_pretrained`).
for artifact in (model, tokenizer):
    artifact.save_pretrained('/kaggle/working/final_model')

print("✓ Model saved in HuggingFace format")
print("\nOutput files:")
for listing in (
    "  - best_model.json (best checkpoint)",
    "  - checkpoint_epoch*.json (epoch checkpoints)",
    "  - training_history.json (training metrics)",
    "  - final_model/ (HuggingFace format)",
):
    print(listing)