# Text Generation Model - Multi-GPU with Accurate Resume
## 10M Parameter GPT-2 with 2xT4 GPU + Perfect Resume Support

**Features:**
- ⚡ Multi-GPU training (2xT4 on Kaggle)
- 🔄 **ACCURATE RESUME** - Continue from exact step (e.g., step 20,000)
- 💾 Automatic checkpoint management (keeps 4 most recent)
- 🛡️ Fixed protobuf warnings
- 📊 Global step tracking across epochs
- 🎯 No lost training progress - picks up exactly where it left off

## 1. Environment Setup & Fix Warnings

In [None]:
# Fix protobuf warnings FIRST
# This cell runs before the cell that imports torch/transformers, so the
# environment variables below are already set when those libraries load.
import os
import sys

# NOTE(review): presumably silences TensorFlow's C++ log output — confirm.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
# NOTE(review): presumably avoids the tokenizers parallelism warning once
# DataLoader worker processes are started — confirm.
os.environ['TOKENIZERS_PARALLELISM'] = 'false'
# Ask protobuf for its pure-Python implementation.
os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION'] = 'python'

import warnings
# Blanket filter: hides every Python warning from here on, not only the
# protobuf ones this cell is aimed at.
warnings.filterwarnings('ignore')

print(f"Python version: {sys.version}")
# Heuristic Kaggle check: either the working directory is on sys.path or the
# KAGGLE_KERNEL_RUN_TYPE environment variable is present.
print(f"Running on Kaggle: {'/kaggle/working' in sys.path or 'KAGGLE_KERNEL_RUN_TYPE' in os.environ}")

In [None]:
# Fix protobuf version
# The leading `!` hands each line to the shell (notebook syntax); this cell
# does not run as plain Python.
# NOTE(review): a kernel that already imported protobuf presumably keeps the
# old version until it is restarted — confirm.
!pip uninstall -y protobuf 2>/dev/null
!pip install -q protobuf==3.20.3
# Printed unconditionally: the pip exit status is not checked.
print("‚úì Protobuf fixed")

## 2. Import Dependencies

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.nn.parallel import DataParallel
from transformers import GPT2Config, GPT2LMHeadModel, GPT2Tokenizer
from transformers import get_linear_schedule_with_warmup
from datasets import load_dataset
import json
import glob
from pathlib import Path
from datetime import datetime
from tqdm.auto import tqdm
import gc
import math

# Reaching this point means every import above succeeded; also report the
# PyTorch version in use.
print("‚úì All imports successful")
print(f"PyTorch: {torch.__version__}")

## 3. Multi-GPU Detection

In [None]:
# Detect GPUs
# Defines the globals used by later cells: `device`, `use_multi_gpu`, `n_gpus`.
cuda_ready = torch.cuda.is_available()
n_gpus = torch.cuda.device_count() if cuda_ready else 0
banner = '=' * 60

if not cuda_ready:
    # CPU fallback: single device, no DataParallel.
    print("‚ùå No GPU! Enable GPU in settings.")
    device = torch.device('cpu')
    use_multi_gpu = False
else:
    print(f"\n{banner}")
    print("GPU CONFIGURATION")
    print(banner)
    print(f"GPUs available: {n_gpus}")

    # One name/memory report per visible device.
    for gpu_index in range(n_gpus):
        gpu_name = torch.cuda.get_device_name(gpu_index)
        gpu_props = torch.cuda.get_device_properties(gpu_index)
        print(f"\nGPU {gpu_index}: {gpu_name}")
        print(f"  Memory: {gpu_props.total_memory / 1e9:.2f} GB")

    # Backend switches: allow TF32 kernels and let cuDNN autotune.
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    torch.backends.cudnn.benchmark = True

    # The first GPU is the primary device; more than one GPU enables
    # the multi-GPU path in later cells.
    device = torch.device('cuda:0')
    use_multi_gpu = n_gpus > 1

    if use_multi_gpu:
        print(f"\n‚ö° MULTI-GPU MODE: Using {n_gpus} GPUs")
    print(banner)

## 4. Configuration

In [None]:
# Central configuration shared by every later cell.
CONFIG = {
    # Model
    'vocab_size': 50257,
    'n_positions': 512,
    'n_embd': 256,
    'n_layer': 8,
    'n_head': 8,
    'n_inner': 1024,

    # Training (optimized for multi-GPU)
    # batch_size is the DataLoader batch. With nn.DataParallel that batch is
    # split across the GPUs, so 16 on 2 GPUs is 8 samples per GPU; together
    # with the accumulation factor both modes give 64 samples per update.
    'batch_size': 16 if use_multi_gpu else 8,
    'gradient_accumulation_steps': 4 if use_multi_gpu else 8,
    'learning_rate': 5e-4,
    'weight_decay': 0.01,
    'max_grad_norm': 1.0,
    'epochs': 3,
    'warmup_steps': 500,
    'max_length': 512,

    # Checkpointing
    # save_steps counts batches within an epoch (not optimizer updates).
    'save_steps': 1000,
    'eval_steps': 500,
    'max_checkpoints': 4,
    'checkpoint_dir': '/kaggle/working/checkpoints',

    # Dataset
    'dataset_name': 'wikitext',
    'dataset_config': 'wikitext-103-v1',

    # Multi-GPU
    'use_multi_gpu': use_multi_gpu,
    'n_gpus': n_gpus,

    # RESUME SETTINGS
    'resume_from_checkpoint': None,  # Set to checkpoint path to resume
    # Example: '/kaggle/input/my-checkpoint/checkpoint_epoch2_step20000.pt'
}

os.makedirs(CONFIG['checkpoint_dir'], exist_ok=True)

# Samples contributing to one optimizer update. The DataLoader batch already
# covers all GPUs (DataParallel chunks it along the batch dimension), so it
# must NOT be multiplied by n_gpus again.
effective_batch = CONFIG['batch_size'] * CONFIG['gradient_accumulation_steps']

# Share of each DataLoader batch that lands on one GPU (ceiling division,
# matching how an uneven batch is chunked).
if use_multi_gpu:
    per_gpu_batch = -(-CONFIG['batch_size'] // n_gpus)
else:
    per_gpu_batch = CONFIG['batch_size']

print("\n" + "="*60)
print("CONFIGURATION")
print("="*60)
print(f"GPUs: {n_gpus}")
print(f"DataLoader Batch: {CONFIG['batch_size']}")
print(f"Per-GPU Batch: {per_gpu_batch}")
print(f"Gradient Accumulation: {CONFIG['gradient_accumulation_steps']}")
print(f"Effective Batch: {effective_batch}")
print(f"Epochs: {CONFIG['epochs']}")
print(f"Save Every: {CONFIG['save_steps']} steps")
print(f"Keep: {CONFIG['max_checkpoints']} checkpoints")
print("="*60)

## 5. Enhanced Checkpoint Management with Accurate Resume

In [None]:
def get_checkpoint_list(checkpoint_dir):
    """Return the ``checkpoint_*.pt`` files in ``checkpoint_dir``, newest first.

    Ordering is by file modification time, most recently modified first.
    A directory with no matching files (or that does not exist) yields ``[]``.
    """
    pattern = os.path.join(checkpoint_dir, 'checkpoint_*.pt')
    return sorted(glob.glob(pattern), key=os.path.getmtime, reverse=True)

def cleanup_old_checkpoints(checkpoint_dir, max_keep=4):
    """Delete all but the ``max_keep`` most recently modified checkpoints.

    Deletion is best-effort: a file that cannot be removed produces a
    warning line and the remaining files are still processed.
    """
    # The list is newest-first, so everything past index max_keep is stale;
    # the slice is simply empty when there is nothing to prune.
    stale_paths = get_checkpoint_list(checkpoint_dir)[max_keep:]
    for stale_path in stale_paths:
        try:
            os.remove(stale_path)
            print(f"  Deleted: {os.path.basename(stale_path)}")
        except Exception as err:
            print(f"  Warning: {err}")

def save_checkpoint(filepath, model, optimizer, scheduler, epoch, step, global_step, train_loss, val_loss=None, config=None):
    """Write a resumable training checkpoint to ``filepath``.

    Args:
        filepath: Destination path of the checkpoint file.
        model: Model to save; an ``nn.DataParallel`` wrapper is unwrapped so
            the stored keys carry no ``module.`` prefix.
        optimizer: Optimizer whose ``state_dict()`` is stored.
        scheduler: LR scheduler whose ``state_dict()`` is stored.
        epoch: Epoch number the checkpoint belongs to.
        step: Batch position within ``epoch``.
        global_step: Optimizer updates performed across all epochs.
        train_loss: Training loss to record (may be ``None``).
        val_loss: Validation loss to record, if any.
        config: Configuration object stored alongside the weights.

    Returns:
        ``True`` when the file was written, ``False`` on any error (the
        error is printed, not raised).
    """
    # Write next to the target and rename afterwards: an interrupted or
    # failed save then never leaves a truncated file under the final name
    # (which the checkpoint rotation would treat as the newest checkpoint).
    tmp_path = f"{filepath}.tmp"
    try:
        # Unwrap DataParallel
        if isinstance(model, nn.DataParallel):
            model_state = model.module.state_dict()
        else:
            model_state = model.state_dict()

        checkpoint = {
            # CRITICAL: track both the in-epoch step and the global step
            'epoch': epoch,
            'step': step,  # Step within current epoch
            'global_step': global_step,  # Total steps across ALL epochs

            # Model and optimizer states
            'model_state_dict': model_state,
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),

            # Losses
            'train_loss': train_loss,
            'val_loss': val_loss,

            # Metadata
            'config': config,
            'timestamp': datetime.now().isoformat(),
            # Stored as a plain str so the file contains only basic types
            # (torch.__version__ itself is a str subclass).
            'pytorch_version': str(torch.__version__),
        }

        torch.save(checkpoint, tmp_path)
        os.replace(tmp_path, filepath)
        return True
    except Exception as e:
        print(f"‚ùå Error saving: {e}")
        # Best-effort removal of a partially written temp file.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        return False

def load_checkpoint(filepath, model, optimizer=None, scheduler=None):
    """Restore model (and optionally optimizer/scheduler) state from a checkpoint.

    Args:
        filepath: Path of a file written by ``save_checkpoint``.
        model: Model to load weights into; an ``nn.DataParallel`` wrapper is
            handled by loading into ``model.module``.
        optimizer: If given, its state is restored when the file contains one.
        scheduler: If given, its state is restored when the file contains one.

    Returns:
        A dict with ``epoch``, ``step``, ``global_step``, ``train_loss`` and
        ``val_loss`` (missing counters default to 0, missing losses to
        ``None``), or ``None`` if anything failed (the error and traceback
        are printed, not raised).
    """
    try:
        print(f"\nLoading checkpoint: {filepath}")
        # SECURITY NOTE: torch.load unpickles the file. Only load checkpoints
        # you created yourself; do not point this at untrusted downloads.
        checkpoint = torch.load(filepath, map_location='cpu')

        # Load model (handle DataParallel)
        if isinstance(model, nn.DataParallel):
            model.module.load_state_dict(checkpoint['model_state_dict'])
        else:
            model.load_state_dict(checkpoint['model_state_dict'])

        # Load optimizer (explicit None test rather than object truthiness)
        if optimizer is not None and 'optimizer_state_dict' in checkpoint:
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            print("‚úì Optimizer state loaded")

        # Load scheduler
        if scheduler is not None and 'scheduler_state_dict' in checkpoint:
            scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
            print("‚úì Scheduler state loaded")

        # Return resume information
        metadata = {
            'epoch': checkpoint.get('epoch', 0),
            'step': checkpoint.get('step', 0),
            'global_step': checkpoint.get('global_step', 0),  # CRITICAL for resume
            'train_loss': checkpoint.get('train_loss', None),
            'val_loss': checkpoint.get('val_loss', None),
        }

        print(f"\n‚úì Checkpoint loaded successfully")
        print(f"  Epoch: {metadata['epoch']}")
        print(f"  Step in epoch: {metadata['step']}")
        print(f"  Global step: {metadata['global_step']}")
        # Compare against None so a legitimate 0.0 loss is still reported and
        # a missing loss prints nothing instead of a blank line.
        if metadata['train_loss'] is not None:
            print(f"  Train loss: {metadata['train_loss']:.4f}")
        if metadata['val_loss'] is not None:
            print(f"  Val loss: {metadata['val_loss']:.4f}")

        return metadata
    except Exception as e:
        print(f"‚ùå Error loading: {e}")
        import traceback
        traceback.print_exc()
        return None

# Status banner only; nothing is executed or verified by these lines.
print("‚úì Enhanced checkpoint functions defined")
print("  - Tracks global_step for accurate resume")
print("  - Can resume from exact step (e.g., 20,000)")
print("  - No training loss when resuming")

## 6. Model Initialization

In [None]:
# Create model
# Architecture sizes come from CONFIG; dropout is 0.1 on the residual,
# embedding and attention paths.
model_config = GPT2Config(
    vocab_size=CONFIG['vocab_size'],
    n_positions=CONFIG['n_positions'],
    n_embd=CONFIG['n_embd'],
    n_layer=CONFIG['n_layer'],
    n_head=CONFIG['n_head'],
    n_inner=CONFIG['n_inner'],
    activation_function='gelu_new',
    resid_pdrop=0.1,
    embd_pdrop=0.1,
    attn_pdrop=0.1,
)

print("\nInitializing model...")
# Built from the config object, i.e. no pretrained weights are loaded here.
model = GPT2LMHeadModel(model_config)
model = model.to(device)

if use_multi_gpu:
    print(f"‚ö° Wrapping with DataParallel for {n_gpus} GPUs")
    # From here on `model` is the wrapper; the underlying network is
    # `model.module` (the checkpoint helpers unwrap it).
    model = nn.DataParallel(model, device_ids=list(range(n_gpus)))

# Count parameters on the unwrapped network.
# NOTE(review): with these sizes the total looks closer to ~19M than the
# "10M" in the notebook title (the 50257x256 token embedding alone is ~12.9M)
# — confirm against the printed value.
total_params = sum(p.numel() for p in (model.module if isinstance(model, nn.DataParallel) else model).parameters())
print(f"‚úì Model ready: {total_params:,} parameters")

## 7. Data Loading

In [None]:
# Tokenizer
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
# Reuse the end-of-sequence token as the padding token, so padded positions
# hold the EOS id and can only be told apart via the attention mask.
tokenizer.pad_token = tokenizer.eos_token

# Dataset
print("\nLoading dataset...")
# Dataset name and configuration come from CONFIG; the 'train' and
# 'validation' splits are used below.
dataset = load_dataset(CONFIG['dataset_name'], CONFIG['dataset_config'])
print(f"‚úì Train: {len(dataset['train']):,} samples")
print(f"‚úì Val: {len(dataset['validation']):,} samples")

In [None]:
# Tokenize
def tokenize_function(examples):
    """Tokenize the ``text`` column of a batch into fixed-length sequences.

    Every row is truncated and padded to ``CONFIG['max_length']`` tokens.
    Returns whatever the tokenizer call returns for the batch.
    """
    encode_options = {
        'truncation': True,
        'max_length': CONFIG['max_length'],
        'padding': 'max_length',
    }
    return tokenizer(examples['text'], **encode_options)

print("Tokenizing...")
# Apply tokenize_function in batches and drop the original columns, leaving
# only the tokenizer's output columns.
tokenized_train = dataset['train'].map(
    tokenize_function,
    batched=True,
    remove_columns=dataset['train'].column_names
)
tokenized_val = dataset['validation'].map(
    tokenize_function,
    batched=True,
    remove_columns=dataset['validation'].column_names
)
# Have the datasets return torch tensors when indexed.
tokenized_train.set_format('torch')
tokenized_val.set_format('torch')
print("‚úì Tokenization complete")

In [None]:
# DataLoaders
# Training batches are shuffled; no seeded generator is passed here.
# NOTE(review): without a fixed shuffle seed the batch order presumably
# differs between runs, so "skip the first N batches" on resume would not
# replay the interrupted epoch's order — confirm.
train_loader = DataLoader(
    tokenized_train,
    batch_size=CONFIG['batch_size'],
    shuffle=True,
    num_workers=2,
    pin_memory=True
)
# Validation keeps dataset order (shuffle is not enabled).
val_loader = DataLoader(
    tokenized_val,
    batch_size=CONFIG['batch_size'],
    num_workers=2,
    pin_memory=True
)
print(f"‚úì DataLoaders: {len(train_loader):,} train batches")

## 8. Optimizer and Scheduler

In [None]:
# Optimizer
# AdamW over all model parameters with the learning rate and weight decay
# from CONFIG.
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=CONFIG['learning_rate'],
    weight_decay=CONFIG['weight_decay']
)

# Scheduler
# Planned number of optimizer updates: total batches over all epochs divided
# by the accumulation factor. The schedule is always sized for the full run,
# including when training is later resumed from a checkpoint.
total_steps = (len(train_loader) * CONFIG['epochs']) // CONFIG['gradient_accumulation_steps']
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=CONFIG['warmup_steps'],
    num_training_steps=total_steps
)

print(f"\n‚úì Optimizer and scheduler ready")
print(f"  Total training steps: {total_steps:,}")
print(f"  Warmup steps: {CONFIG['warmup_steps']:,}")

## 9. Resume from Checkpoint (If Specified)

In [None]:
# Initialize training state
start_epoch = 1
start_step = 0
global_step = 0  # CRITICAL: total optimizer updates across all epochs
best_val_loss = float('inf')

resume_path = CONFIG['resume_from_checkpoint']

# Resume if checkpoint specified
if resume_path and os.path.exists(resume_path):
    print("\n" + "="*60)
    print("RESUMING FROM CHECKPOINT")
    print("="*60)

    metadata = load_checkpoint(
        resume_path,
        model,
        optimizer,
        scheduler
    )

    if metadata:
        start_epoch = metadata['epoch']
        start_step = metadata['step']
        global_step = metadata['global_step']  # Resume from exact global step

        # End-of-epoch checkpoints (and best_model.pt) record
        # step == len(train_loader). Re-entering that epoch would skip every
        # batch and then divide by zero, so continue with the next epoch.
        # NOTE: step counts are in batches, so this assumes the same
        # batch_size as the run that wrote the checkpoint.
        if start_step >= len(train_loader):
            start_epoch += 1
            start_step = 0

        # `is not None` so a stored loss of 0.0 is still honoured.
        if metadata['val_loss'] is not None:
            best_val_loss = metadata['val_loss']

        print(f"\n‚ö° WILL RESUME FROM:")
        print(f"  Epoch: {start_epoch}")
        print(f"  Step in epoch: {start_step}")
        print(f"  Global step: {global_step}")
        print(f"  (Will skip first {start_step} steps of epoch {start_epoch})")
        print("="*60)
    else:
        print("\n‚ö†Ô∏è Failed to load checkpoint, starting from scratch")
elif resume_path:
    # A path was configured but is missing: say so explicitly instead of
    # claiming that no checkpoint was specified.
    print(f"\nWARNING: checkpoint not found: {resume_path}")
    print("Starting training from scratch")
else:
    print("\n‚úì Starting training from scratch (no checkpoint specified)")

## 10. Training Functions with Accurate Resume

In [None]:
def train_epoch(model, loader, optimizer, scheduler, device, epoch, config, start_step=0, global_step=0):
    """Train for one epoch, optionally skipping the first ``start_step`` batches.

    Args:
        model: Model (optionally wrapped in ``nn.DataParallel``) called with
            ``input_ids``, ``attention_mask`` and ``labels``.
        loader: Iterable of batches providing ``input_ids`` and
            ``attention_mask`` tensors.
        optimizer: Optimizer stepped once per accumulation window.
        scheduler: LR scheduler stepped together with the optimizer.
        device: Device the batch tensors are moved to.
        epoch: Epoch number (progress bar, checkpoint names, shuffle seed).
        config: Dict with ``gradient_accumulation_steps``, ``max_grad_norm``,
            ``save_steps``, ``checkpoint_dir`` and ``max_checkpoints``; an
            optional ``seed`` (default 0) seeds the shuffle order.
        start_step: Number of leading batches to skip (resume).
        global_step: Optimizer updates performed before this call.

    Returns:
        ``(mean_loss, global_step)``: the mean loss over the batches that
        were actually trained on (0.0 if there were none) and the updated
        global step.

    Raises:
        RuntimeError: Any runtime error other than CUDA out-of-memory.
    """
    model.train()
    accum_steps = config['gradient_accumulation_steps']
    total_loss = 0.0
    batches_done = 0  # batches that really contributed to total_loss
    optimizer.zero_grad()

    # Make the shuffle order depend only on (seed, epoch). Skipping
    # `start_step` batches on resume then replays the interrupted epoch's
    # order, provided the checkpoint was written with this seeding in place.
    sampler = getattr(loader, 'sampler', None)
    if isinstance(sampler, torch.utils.data.RandomSampler):
        if sampler.generator is None:
            sampler.generator = torch.Generator()
        sampler.generator.manual_seed(config.get('seed', 0) + epoch)

    progress_bar = tqdm(loader, desc=f"Epoch {epoch}")

    for step, batch in enumerate(progress_bar):
        # SKIP steps if resuming
        if step < start_step:
            continue

        try:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)

            # Padding uses the EOS id, so exclude padded positions from the
            # loss explicitly (-100 labels are ignored by the LM loss).
            labels = input_ids.masked_fill(attention_mask == 0, -100)

            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=labels
            )

            # DataParallel returns one loss per GPU; reduce to a scalar
            # (a no-op for a single-device scalar loss).
            batch_loss = outputs.loss.mean()

            # A batch consisting only of padding has no valid targets and
            # yields a non-finite loss; keep it out of gradients and average.
            if torch.isfinite(batch_loss):
                (batch_loss / accum_steps).backward()
                total_loss += batch_loss.item()
                batches_done += 1

            if (step + 1) % accum_steps == 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), config['max_grad_norm'])
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()
                global_step += 1  # Increment global step

            progress_bar.set_postfix({
                'loss': f"{batch_loss.item():.4f}",
                'global_step': global_step,  # Show global step
                'lr': f"{scheduler.get_last_lr()[0]:.2e}"
            })

            # Save checkpoint (save_steps counts batches; keep it a multiple
            # of the accumulation factor so no partial gradients are pending)
            if (step + 1) % config['save_steps'] == 0:
                checkpoint_path = os.path.join(
                    config['checkpoint_dir'],
                    f"checkpoint_epoch{epoch}_step{step+1}_global{global_step}.pt"
                )

                if save_checkpoint(
                    checkpoint_path,
                    model,
                    optimizer,
                    scheduler,
                    epoch,
                    step + 1,
                    global_step,  # Save global step
                    total_loss / max(batches_done, 1),
                    config=config
                ):
                    print(f"\n‚úì Saved: {os.path.basename(checkpoint_path)}")
                    print(f"  Global step: {global_step}")
                    cleanup_old_checkpoints(config['checkpoint_dir'], config['max_checkpoints'])

        except RuntimeError as e:
            if 'out of memory' in str(e):
                # Drop this batch and carry on with the next one.
                print(f"\n‚ö†Ô∏è OOM at step {step}")
                torch.cuda.empty_cache()
                gc.collect()
                continue
            else:
                raise

    # Average over the batches actually used; never divides by zero, even
    # when every batch was skipped.
    return total_loss / max(batches_done, 1), global_step

def evaluate(model, loader, device):
    """Compute the mean loss and perplexity of ``model`` over ``loader``.

    Args:
        model: Model (optionally wrapped in ``nn.DataParallel``) called with
            ``input_ids``, ``attention_mask`` and ``labels``.
        loader: Iterable of batches providing ``input_ids`` and
            ``attention_mask`` tensors.
        device: Device the batch tensors are moved to.

    Returns:
        ``(avg_loss, perplexity)``. The average is the mean of the per-batch
        losses over the batches that were evaluated; perplexity is
        ``exp(avg_loss)``, or ``inf`` when ``avg_loss`` is 100 or more. If
        no batch could be evaluated both values are ``inf``.

    Raises:
        RuntimeError: Any runtime error other than CUDA out-of-memory.
    """
    model.eval()
    total_loss = 0.0
    batches_done = 0  # batches that produced a usable loss

    with torch.no_grad():
        for batch in tqdm(loader, desc="Evaluating"):
            try:
                input_ids = batch['input_ids'].to(device)
                attention_mask = batch['attention_mask'].to(device)

                # Padding uses the EOS id, so exclude padded positions from
                # the loss explicitly (-100 labels are ignored).
                labels = input_ids.masked_fill(attention_mask == 0, -100)

                outputs = model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    labels=labels
                )
                # DataParallel returns one loss per GPU; reduce before
                # converting to a Python float.
                batch_loss = outputs.loss.mean().item()
            except RuntimeError as e:
                if 'out of memory' in str(e):
                    torch.cuda.empty_cache()
                    gc.collect()
                    continue
                else:
                    raise

            # An all-padding batch has no valid targets (non-finite loss).
            if math.isfinite(batch_loss):
                total_loss += batch_loss
                batches_done += 1

    if batches_done == 0:
        return float('inf'), float('inf')

    avg_loss = total_loss / batches_done
    perplexity = math.exp(avg_loss) if avg_loss < 100 else float('inf')
    return avg_loss, perplexity

# Status banner only.
print("‚úì Training functions ready with accurate resume support")

## 11. Main Training Loop with Resume

In [None]:
# Training history
training_history = []

print("\n" + "="*60)
print("STARTING TRAINING")
print("="*60)
# A resume may start exactly on an epoch boundary (start_step == 0), so the
# global step is checked as well.
if start_step > 0 or global_step > 0:
    print(f"‚ö° RESUMING from epoch {start_epoch}, step {start_step}")
    print(f"‚ö° Global step: {global_step}")
else:
    print(f"Starting fresh from epoch 1")
if use_multi_gpu:
    print(f"‚ö° Using {n_gpus} GPUs")
print("="*60 + "\n")

# Bound up front so the KeyboardInterrupt handler below can always build an
# emergency checkpoint, even when interrupted during the very first epoch.
epoch = start_epoch
skip_steps = start_step
epoch_start_global_step = global_step
train_loss = None

try:
    for epoch in range(start_epoch, CONFIG['epochs'] + 1):
        print(f"\n{'='*60}")
        print(f"EPOCH {epoch}/{CONFIG['epochs']}")
        if epoch == start_epoch and start_step > 0:
            print(f"(Resuming from step {start_step})")
        print(f"{'='*60}")

        # Determine if we need to skip steps (only for first resumed epoch)
        skip_steps = start_step if epoch == start_epoch else 0
        epoch_start_global_step = global_step

        # Train
        train_loss, global_step = train_epoch(
            model,
            train_loader,
            optimizer,
            scheduler,
            device,
            epoch,
            CONFIG,
            start_step=skip_steps,
            global_step=global_step
        )

        # Evaluate
        val_loss, val_perplexity = evaluate(model, val_loader, device)

        # Print results
        print(f"\n{'='*60}")
        print(f"Epoch {epoch} Results:")
        print(f"  Train Loss: {train_loss:.4f}")
        print(f"  Val Loss: {val_loss:.4f}")
        print(f"  Perplexity: {val_perplexity:.2f}")
        print(f"  Global Step: {global_step}")
        print(f"{'='*60}")

        # Save history
        training_history.append({
            'epoch': epoch,
            'global_step': global_step,
            'train_loss': float(train_loss),
            'val_loss': float(val_loss),
            'val_perplexity': float(val_perplexity),
            'timestamp': datetime.now().isoformat()
        })

        # Save best model
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_path = '/kaggle/working/best_model.pt'

            if save_checkpoint(
                best_path,
                model,
                optimizer,
                scheduler,
                epoch,
                len(train_loader),
                global_step,
                train_loss,
                val_loss,
                CONFIG
            ):
                print(f"\n‚úì New best model! (val_loss: {val_loss:.4f})")

        # Save epoch checkpoint
        epoch_path = os.path.join(
            CONFIG['checkpoint_dir'],
            f"checkpoint_epoch{epoch}_final_global{global_step}.pt"
        )

        save_checkpoint(
            epoch_path,
            model,
            optimizer,
            scheduler,
            epoch,
            len(train_loader),
            global_step,
            train_loss,
            val_loss,
            CONFIG
        )
        print(f"‚úì Epoch {epoch} checkpoint saved")

        # Cleanup
        cleanup_old_checkpoints(CONFIG['checkpoint_dir'], CONFIG['max_checkpoints'])
        torch.cuda.empty_cache()
        gc.collect()

        # Reset start_step after first epoch
        start_step = 0

    print("\n" + "="*60)
    print("‚úì TRAINING COMPLETED!")
    print(f"  Best val loss: {best_val_loss:.4f}")
    print(f"  Total global steps: {global_step}")
    if use_multi_gpu:
        print(f"  Trained on {n_gpus} GPUs")
    print("="*60)

except KeyboardInterrupt:
    print("\n‚ö†Ô∏è Interrupted! Saving emergency checkpoint...")
    emergency_path = '/kaggle/working/emergency_checkpoint.pt'

    # The in-epoch position lives inside train_epoch and is not visible here
    # (the old handler read an unbound `step`). Reconstruct it instead: the
    # scheduler is stepped once per optimizer update, so its counter gives
    # the real number of updates, and each update covers one accumulation
    # window of batches. Batches after the last update only contributed
    # gradients that are discarded, so resuming from the window boundary is
    # the right place. (Approximate if batches were dropped for OOM.)
    updates_done = max(getattr(scheduler, 'last_epoch', global_step), global_step)
    resume_step = skip_steps + (
        (updates_done - epoch_start_global_step) * CONFIG['gradient_accumulation_steps']
    )
    resume_step = min(resume_step, len(train_loader))
    global_step = updates_done

    # train_loss is the last completed epoch's value (None during the first
    # epoch); the interrupted epoch has no final loss yet.
    if save_checkpoint(
        emergency_path,
        model,
        optimizer,
        scheduler,
        epoch,
        resume_step,
        global_step,
        train_loss,
        config=CONFIG
    ):
        print(f"‚úì Emergency checkpoint saved")

except Exception as e:
    print(f"\n‚ùå Training failed: {e}")
    import traceback
    traceback.print_exc()
    raise

## 12. Save Training History

In [None]:
# Save history
# Persist the per-epoch records collected by the training loop as JSON.
history_path = '/kaggle/working/training_history.json'
with open(history_path, 'w') as f:
    json.dump(training_history, f, indent=2)

print("\n" + "="*60)
print("TRAINING SUMMARY")
print("="*60)
# One line per completed epoch; the list is empty if no epoch finished.
for entry in training_history:
    print(f"Epoch {entry['epoch']} (step {entry['global_step']}): "
          f"Train={entry['train_loss']:.4f}, "
          f"Val={entry['val_loss']:.4f}, "
          f"PPL={entry['val_perplexity']:.2f}")
print("="*60)
print(f"\n‚úì History saved to {history_path}")

## 13. Text Generation Test

In [None]:
def generate_text(prompt, max_length=100, temperature=0.8):
    """Sample a continuation of ``prompt`` from the global ``model``.

    Args:
        prompt: Text to continue.
        max_length: ``max_length`` passed to ``generate``.
        temperature: Sampling temperature passed to ``generate``.

    Returns:
        The decoded first generated sequence, with special tokens removed.
        The underlying model is switched to eval mode as a side effect.
    """
    # generate() lives on the underlying network, not on the DataParallel wrapper.
    gen_model = model.module if isinstance(model, nn.DataParallel) else model
    gen_model.eval()

    prompt_ids = tokenizer.encode(prompt, return_tensors='pt').to(device)

    # Top-k / nucleus sampling settings; EOS doubles as the pad token.
    sampling_options = {
        'max_length': max_length,
        'temperature': temperature,
        'top_k': 50,
        'top_p': 0.95,
        'do_sample': True,
        'pad_token_id': tokenizer.eos_token_id,
    }

    with torch.no_grad():
        sequences = gen_model.generate(prompt_ids, **sampling_options)

    first_sequence = sequences[0]
    return tokenizer.decode(first_sequence, skip_special_tokens=True)

# Test
# Smoke-test prompts for a quick look at the trained model's output.
test_prompts = [
    "The future of artificial intelligence",
    "In a world where technology",
    "Scientists have discovered"
]

print("\n" + "="*60)
print("TEXT GENERATION EXAMPLES")
print("="*60)

for prompt in test_prompts:
    print(f"\nPrompt: '{prompt}'")
    print("‚îÄ" * 60)
    # Best-effort: a failing prompt is reported and the loop moves on.
    try:
        generated = generate_text(prompt, max_length=150)
        print(generated)
    except Exception as e:
        print(f"Error: {e}")

## 14. Save Final Model

In [None]:
# Save in HuggingFace format
final_dir = '/kaggle/working/final_model'
print(f"\nSaving final model...")

try:
    # save_pretrained is called on the underlying network, so unwrap a
    # DataParallel wrapper first.
    if isinstance(model, nn.DataParallel):
        save_model = model.module
    else:
        save_model = model

    save_model.save_pretrained(final_dir)
    tokenizer.save_pretrained(final_dir)

    # Keep the training configuration next to the exported model.
    with open(os.path.join(final_dir, 'training_config.json'), 'w') as f:
        json.dump(CONFIG, f, indent=2)

    print("‚úì Model saved")
except Exception as e:
    # Failure is only printed; the notebook continues with the next cell.
    print(f"Error: {e}")

## 15. Output Summary

In [None]:
# Summary
# The tree below is mostly fixed text: only the entries under checkpoints/
# are read from disk, the other files are listed without checking that they
# exist.
print("\n" + "="*60)
print("OUTPUT FILES")
print("="*60)
print("\nüìÅ /kaggle/working/")
print("  ‚îú‚îÄ‚îÄ best_model.pt (best checkpoint)")
print("  ‚îú‚îÄ‚îÄ training_history.json")
print("  ‚îú‚îÄ‚îÄ checkpoints/")

# Remaining rotated checkpoints, newest first, with their size in MB.
checkpoints = get_checkpoint_list(CONFIG['checkpoint_dir'])
if checkpoints:
    for ckpt in checkpoints:
        size = os.path.getsize(ckpt) / 1e6
        print(f"  ‚îÇ   ‚îú‚îÄ‚îÄ {os.path.basename(ckpt)} ({size:.1f} MB)")

print("  ‚îî‚îÄ‚îÄ final_model/ (HuggingFace format)")
print("="*60)

if use_multi_gpu:
    print(f"\n‚ö° Trained on {n_gpus} GPUs")
print(f"\n‚úì Total global steps: {global_step}")
print("\n‚úì Training complete! Download from Output tab.")

# Static instructions for continuing a run in a later session.
print("\n" + "="*60)
print("TO RESUME TRAINING:")
print("="*60)
print("1. Upload any checkpoint to Kaggle Datasets")
print("2. Add dataset to notebook")
print("3. Set CONFIG['resume_from_checkpoint'] = '/kaggle/input/...'")
print("4. Run notebook - it will continue from exact step!")
print("="*60)