# Text Generation Model - Multi-GPU Kaggle Training
## 10M Parameter GPT-2 with 2xT4 GPU Support

**Features:**
- ⚡ Multi-GPU training (2xT4 on Kaggle)
- 🔄 Automatic checkpoint management (keeps 4 most recent)
- 🛡️ Robust error handling with OOM recovery
- 💾 Smart memory optimization
- 📊 Progress tracking and logging

## 1. Environment Setup & Fix Warnings

In [None]:
# Environment tweaks; this cell is meant to run before any other import.
import os
import sys
import warnings

# Quiet TensorFlow logging, switch off tokenizers parallelism and select the
# pure-Python protobuf implementation, all in one update.
os.environ.update({
    'TF_CPP_MIN_LOG_LEVEL': '3',
    'TOKENIZERS_PARALLELISM': 'false',
    'PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION': 'python',
})

# Hide all Python-level warnings from here on.
warnings.filterwarnings('ignore')

print(f"Python version: {sys.version}")
# Either signal counts as "on Kaggle": the kernel env var or the working dir on sys.path.
print(f"Running on Kaggle: {'KAGGLE_KERNEL_RUN_TYPE' in os.environ or '/kaggle/working' in sys.path}")

In [None]:
# Fix protobuf version conflict (common Kaggle issue)
# Step 1 removes whatever protobuf is installed (stderr discarded); step 2
# installs the pinned 3.20.3 release quietly.
# NOTE(review): if protobuf was already imported in this session, a kernel
# restart is presumably needed before the pinned version is used -- confirm.
!pip uninstall -y protobuf 2>/dev/null
!pip install -q protobuf==3.20.3
print("‚úì Protobuf fixed")

## 2. Import Dependencies

In [None]:
# Import all required libraries
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.nn.parallel import DataParallel
from transformers import GPT2Config, GPT2LMHeadModel, GPT2Tokenizer
from transformers import get_linear_schedule_with_warmup
from datasets import load_dataset
import json
import glob
from pathlib import Path
from datetime import datetime
from tqdm.auto import tqdm
import gc
import math

print("‚úì All imports successful")
print(f"PyTorch version: {torch.__version__}")
# Version reporting is purely informational, so a failure here must not stop
# the notebook. Catch only the errors this code can plausibly raise and say
# what happened, instead of a bare `except: pass` that would also hide
# KeyboardInterrupt and genuine bugs.
try:
    import transformers
    import datasets as ds
    print(f"Transformers version: {transformers.__version__}")
    print(f"Datasets version: {ds.__version__}")
except (ImportError, AttributeError) as version_error:
    print(f"Could not report library versions: {version_error}")

## 3. Multi-GPU Detection and Setup

In [None]:
# Detect all available GPUs and choose the primary device.
# This cell defines three names used by later cells:
#   device        - torch.device the model and batches are moved to
#   use_multi_gpu - True when more than one GPU is visible
#   n_gpus        - number of visible GPUs (0 on CPU-only machines)
if torch.cuda.is_available():
    n_gpus = torch.cuda.device_count()
    # Query each device once instead of once per printed field.
    gpu_props = [torch.cuda.get_device_properties(i) for i in range(n_gpus)]

    print(f"\n{'='*60}")
    print(f"GPU CONFIGURATION")
    print(f"{'='*60}")
    print(f"Number of GPUs available: {n_gpus}")

    for i, props in enumerate(gpu_props):
        print(f"\nGPU {i}:")
        print(f"  Name: {torch.cuda.get_device_name(i)}")
        print(f"  Memory: {props.total_memory / 1e9:.2f} GB")
        print(f"  Compute Capability: {props.major}.{props.minor}")

    print(f"\nCUDA Version: {torch.version.cuda}")
    print(f"cuDNN Version: {torch.backends.cudnn.version()}")

    # Enable optimizations. Setting the TF32 flags is harmless everywhere,
    # but they only take effect on compute capability >= 8.0 (Ampere or
    # newer); a T4 is 7.5. Report TF32 only when a GPU can actually use it.
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    torch.backends.cudnn.benchmark = True
    tf32_supported = any(props.major >= 8 for props in gpu_props)
    if tf32_supported:
        print(f"\n‚úì TF32 and cuDNN optimizations enabled")
    else:
        print(f"\n‚úì cuDNN benchmark mode enabled (TF32 is not supported on these GPUs)")

    # Set device: GPU 0 is the primary device; DataParallel gathers there.
    device = torch.device('cuda:0')
    use_multi_gpu = n_gpus > 1

    if use_multi_gpu:
        print(f"\n‚ö° MULTI-GPU MODE: Will use {n_gpus} GPUs with DataParallel")
        print(f"   Effective batch size will be multiplied by {n_gpus}")
    else:
        print(f"\n‚ö†Ô∏è Single GPU mode (only 1 GPU detected)")

    print(f"{'='*60}")
else:
    # CPU fallback: the notebook keeps running, but the user is told how to
    # enable the accelerator.
    print("\n‚ùå ERROR: No GPU detected!")
    print("Please enable GPU in Kaggle notebook settings:")
    print("  Settings ‚Üí Accelerator ‚Üí GPU T4 x2")
    device = torch.device('cpu')
    use_multi_gpu = False
    n_gpus = 0

## 4. Configuration (Multi-GPU Optimized)

In [None]:
# Single source of truth for the run: architecture, optimisation,
# checkpointing, dataset and multi-GPU flags.
CONFIG = dict(
    # --- model architecture ---
    vocab_size=50257,
    n_positions=512,
    n_embd=256,
    n_layer=8,
    n_head=8,
    n_inner=1024,

    # --- training hyperparameters (depend on the GPU count) ---
    batch_size=16 if use_multi_gpu else 8,  # reported below as the per-GPU batch size
    gradient_accumulation_steps=4 if use_multi_gpu else 8,
    learning_rate=5e-4,
    weight_decay=0.01,
    max_grad_norm=1.0,
    epochs=3,
    warmup_steps=500,
    max_length=512,

    # --- checkpointing ---
    save_steps=1000,
    eval_steps=500,
    max_checkpoints=4,
    checkpoint_dir='/kaggle/working/checkpoints',

    # --- dataset ---
    dataset_name='wikitext',
    dataset_config='wikitext-103-v1',

    # --- multi-GPU settings (copied from the detection cell) ---
    use_multi_gpu=use_multi_gpu,
    n_gpus=n_gpus,

    # --- set to a checkpoint path to resume; None starts fresh ---
    resume_from_checkpoint=None,
)

# Make sure the checkpoint directory exists before anything is written to it.
os.makedirs(CONFIG['checkpoint_dir'], exist_ok=True)

# Samples contributing to one optimizer update, as reported in the summary.
effective_batch_size = (
    CONFIG['batch_size']
    * CONFIG['gradient_accumulation_steps']
    * (n_gpus if use_multi_gpu else 1)
)

print(
    "\n" + "=" * 60,
    "TRAINING CONFIGURATION",
    "=" * 60,
    "Model Parameters: ~10M",
    "\nMulti-GPU Settings:",
    f"  Number of GPUs: {n_gpus}",
    f"  Multi-GPU Mode: {'Enabled' if use_multi_gpu else 'Disabled'}",
    "\nBatch Configuration:",
    f"  Per-GPU Batch Size: {CONFIG['batch_size']}",
    f"  Gradient Accumulation: {CONFIG['gradient_accumulation_steps']}",
    sep="\n",
)
if use_multi_gpu:
    print(f"  Total Batch per Step: {CONFIG['batch_size'] * n_gpus}")
print(
    f"  Effective Batch Size: {effective_batch_size}",
    "\nTraining Settings:",
    f"  Learning Rate: {CONFIG['learning_rate']}",
    f"  Epochs: {CONFIG['epochs']}",
    f"  Max Checkpoints: {CONFIG['max_checkpoints']}",
    "=" * 60,
    sep="\n",
)

## 5. Checkpoint Management Functions

In [None]:
def get_checkpoint_list(checkpoint_dir):
    """Return the ``checkpoint_*.pt`` files in *checkpoint_dir*, newest first.

    Files are ordered by modification time, most recently modified first.
    A directory with no matching files (or that does not exist) yields an
    empty list.
    """
    pattern = os.path.join(checkpoint_dir, 'checkpoint_*.pt')
    return sorted(glob.glob(pattern), key=os.path.getmtime, reverse=True)

def cleanup_old_checkpoints(checkpoint_dir, max_keep=4):
    """Delete every checkpoint beyond the *max_keep* most recent ones.

    Deletion is best-effort: a file that cannot be removed is reported with
    a warning and the remaining files are still processed.
    """
    # The list is newest-first, so everything from index max_keep on is surplus;
    # the slice is simply empty when there is nothing to prune.
    for stale_path in get_checkpoint_list(checkpoint_dir)[max_keep:]:
        try:
            os.remove(stale_path)
            print(f"  Deleted old checkpoint: {os.path.basename(stale_path)}")
        except Exception as err:
            print(f"  Warning: Could not delete {stale_path}: {err}")

def save_checkpoint(filepath, model, optimizer, scheduler, epoch, step, train_loss, val_loss=None, config=None):
    """Atomically save a training checkpoint (handles DataParallel).

    The checkpoint is first written to ``<filepath>.tmp`` and then renamed
    onto *filepath*. An interrupted save (session timeout, full disk, Ctrl-C)
    therefore never leaves a truncated file under the final name, where it
    would be picked up as the newest checkpoint. The temporary name does not
    match the ``checkpoint_*.pt`` pattern used by the cleanup helpers.

    Args:
        filepath: Destination path of the checkpoint file.
        model: The model, optionally wrapped in ``nn.DataParallel``; the
            unwrapped weights are stored so the file loads either way.
        optimizer: Optimizer whose ``state_dict()`` is stored.
        scheduler: LR scheduler whose ``state_dict()`` is stored.
        epoch: Epoch number recorded in the checkpoint.
        step: Step number recorded in the checkpoint.
        train_loss: Training loss recorded in the checkpoint.
        val_loss: Optional validation loss recorded in the checkpoint.
        config: Optional configuration object stored alongside the weights.

    Returns:
        True on success, False if anything went wrong (the error is printed
        and no partial file is left behind).
    """
    tmp_path = f"{filepath}.tmp"
    try:
        # Store the underlying module's weights so keys carry no "module." prefix.
        wrapped = isinstance(model, nn.DataParallel)
        model_state = model.module.state_dict() if wrapped else model.state_dict()

        checkpoint = {
            'epoch': epoch,
            'step': step,
            'model_state_dict': model_state,
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'train_loss': train_loss,
            'val_loss': val_loss,
            'config': config,
            'timestamp': datetime.now().isoformat(),
        }

        torch.save(checkpoint, tmp_path)
        # os.replace overwrites the destination in a single rename.
        os.replace(tmp_path, filepath)
        return True
    except Exception as e:
        print(f"Error saving checkpoint: {e}")
        # Best-effort removal of a half-written temporary file.
        try:
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
        except OSError:
            pass
        return False

def load_checkpoint(filepath, model, optimizer=None, scheduler=None):
    """Restore model (and optionally optimizer/scheduler) state from a file.

    The file is loaded onto the CPU. Weights go into the underlying module
    when *model* is wrapped in ``nn.DataParallel``. Optimizer and scheduler
    state are restored only when the object is supplied and the checkpoint
    contains the matching entry.

    Returns:
        A dict with ``epoch``, ``step``, ``train_loss`` and ``val_loss``
        (defaults 0, 0, None, None), or None if loading failed; the error
        is printed in that case.
    """
    try:
        state = torch.load(filepath, map_location='cpu')

        # DataParallel keeps the real network under .module.
        target = model.module if isinstance(model, nn.DataParallel) else model
        target.load_state_dict(state['model_state_dict'])

        optional_parts = (
            (optimizer, 'optimizer_state_dict'),
            (scheduler, 'scheduler_state_dict'),
        )
        for component, key in optional_parts:
            if component is not None and key in state:
                component.load_state_dict(state[key])

        defaults = (('epoch', 0), ('step', 0), ('train_loss', None), ('val_loss', None))
        return {field: state.get(field, fallback) for field, fallback in defaults}
    except Exception as e:
        print(f"Error loading checkpoint: {e}")
        return None

# Confirmation that the cell defining the checkpoint helpers has run.
print("‚úì Checkpoint management functions defined")

## 6. Model Initialization with Multi-GPU Support

In [None]:
# Describe the GPT-2 architecture: sizes come from CONFIG, the activation
# and the three dropout rates are fixed here.
model_config = GPT2Config(
    **{
        key: CONFIG[key]
        for key in ('vocab_size', 'n_positions', 'n_embd', 'n_layer', 'n_head', 'n_inner')
    },
    activation_function='gelu_new',
    resid_pdrop=0.1,
    embd_pdrop=0.1,
    attn_pdrop=0.1,
)

# Echo the architecture actually stored on the config object.
print(
    "\n" + "=" * 60,
    "MODEL ARCHITECTURE",
    "=" * 60,
    f"Vocabulary Size: {model_config.vocab_size:,}",
    f"Max Sequence Length: {model_config.n_positions}",
    f"Embedding Dimension: {model_config.n_embd}",
    f"Number of Layers: {model_config.n_layer}",
    f"Number of Attention Heads: {model_config.n_head}",
    f"FFN Inner Dimension: {model_config.n_inner}",
    "=" * 60,
    sep="\n",
)

In [None]:
# Initialize model with random weights from the architecture config.
print("\nInitializing model...")
model = GPT2LMHeadModel(model_config)

# Move to the primary device first; DataParallel expects the module to live
# on device_ids[0] before it is wrapped.
model = model.to(device)

if use_multi_gpu:
    print(f"\n‚ö° Wrapping model with DataParallel for {n_gpus} GPUs...")
    model = nn.DataParallel(model, device_ids=list(range(n_gpus)))
    print(f"‚úì Model distributed across GPUs: {list(range(n_gpus))}")

# Count parameters on the underlying network (unwrap DataParallel if needed).
base_model = model.module if isinstance(model, nn.DataParallel) else model
total_params = sum(p.numel() for p in base_model.parameters())
trainable_params = sum(p.numel() for p in base_model.parameters() if p.requires_grad)
model_size_mb = total_params * 4 / 1e6  # 4 bytes per FP32 parameter

print(f"\n‚úì Model initialized successfully")
print(f"  Total Parameters: {total_params:,}")
print(f"  Trainable Parameters: {trainable_params:,}")
print(f"  Model Size (FP32): {model_size_mb:.2f} MB")
if use_multi_gpu:
    # DataParallel copies the whole model onto every GPU, so each device
    # holds the full weights -- not a 1/n_gpus share.
    print(f"  Per-GPU Memory: ~{model_size_mb:.2f} MB (replicated on each GPU)")

## 7. Data Loading and Preparation

In [None]:
# Load tokenizer
print("\nLoading tokenizer...")
# Use the pretrained 'gpt2' tokenizer.
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
# Reuse the end-of-sequence token as the padding token, so sequences can be
# padded to a common length during tokenization.
tokenizer.pad_token = tokenizer.eos_token
print(f"‚úì Tokenizer loaded (vocab size: {len(tokenizer):,})")

In [None]:
# Load dataset
# The dataset name and configuration both come from CONFIG. Any failure is
# printed and then re-raised, so the notebook stops here instead of
# continuing without data.
print("\nLoading dataset...")
try:
    dataset = load_dataset(CONFIG['dataset_name'], CONFIG['dataset_config'])
    print(f"‚úì Dataset loaded successfully")
    # Both the 'train' and 'validation' splits are required by later cells.
    print(f"  Train samples: {len(dataset['train']):,}")
    print(f"  Validation samples: {len(dataset['validation']):,}")
except Exception as e:
    print(f"Error loading dataset: {e}")
    raise

In [None]:
# Tokenization helpers
def has_text(example):
    """Return True when the example's text is not empty or whitespace-only."""
    return bool(example['text'].strip())


def tokenize_function(examples):
    """Tokenize a batch of raw texts into fixed-length model inputs.

    Every text is truncated or padded to ``CONFIG['max_length']`` tokens.
    Plain Python lists are returned; conversion to tensors happens once,
    via ``set_format('torch')`` below. Asking for ``return_tensors='pt'``
    inside a batched ``map`` is not needed.
    """
    return tokenizer(
        examples['text'],
        truncation=True,
        max_length=CONFIG['max_length'],
        padding='max_length',
    )

print("\nTokenizing datasets...")
# Blank lines in the raw corpus would each become a sample made entirely of
# padding: wasted compute and no learning signal. Drop them before tokenizing.
tokenized_train = dataset['train'].filter(
    has_text,
    desc="Dropping blank train lines"
).map(
    tokenize_function,
    batched=True,
    remove_columns=dataset['train'].column_names,
    desc="Tokenizing train set"
)

tokenized_val = dataset['validation'].filter(
    has_text,
    desc="Dropping blank validation lines"
).map(
    tokenize_function,
    batched=True,
    remove_columns=dataset['validation'].column_names,
    desc="Tokenizing validation set"
)

# Have the datasets hand out torch tensors when indexed by the DataLoader.
tokenized_train.set_format('torch')
tokenized_val.set_format('torch')
print("‚úì Tokenization complete")

In [None]:
# Create dataloaders (adjusted for multi-GPU)
print("\nCreating dataloaders...")

# nn.DataParallel splits every incoming batch across the GPUs. For each GPU
# to receive CONFIG['batch_size'] samples -- which is what the configuration
# summary and the "Samples per step" line below report -- the loader has to
# deliver batch_size * n_gpus samples per step.
loader_batch_size = CONFIG['batch_size'] * (n_gpus if use_multi_gpu else 1)
# Pinned host memory is only requested when CUDA is available.
pin_batches = torch.cuda.is_available()

train_loader = DataLoader(
    tokenized_train,
    batch_size=loader_batch_size,
    shuffle=True,
    num_workers=2,
    pin_memory=pin_batches
)

val_loader = DataLoader(
    tokenized_val,
    batch_size=loader_batch_size,
    num_workers=2,
    pin_memory=pin_batches
)

print(f"‚úì DataLoaders created")
print(f"  Train batches: {len(train_loader):,}")
print(f"  Validation batches: {len(val_loader):,}")
if use_multi_gpu:
    print(f"  Samples per step: {CONFIG['batch_size'] * n_gpus}")

## 8. Optimizer and Scheduler Setup

In [None]:
# AdamW over all model parameters, with hyperparameters taken from CONFIG.
optimizer = torch.optim.AdamW(
    model.parameters(),
    betas=(0.9, 0.999),
    eps=1e-8,
    lr=CONFIG['learning_rate'],
    weight_decay=CONFIG['weight_decay'],
)

# One scheduler step is taken per optimizer update, i.e. once every
# gradient_accumulation_steps batches.
total_steps = len(train_loader) * CONFIG['epochs'] // CONFIG['gradient_accumulation_steps']

# Linear warmup followed by linear decay over the whole run.
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_training_steps=total_steps,
    num_warmup_steps=CONFIG['warmup_steps'],
)

print(
    "\n" + "=" * 60,
    "TRAINING SETUP",
    "=" * 60,
    f"Total Training Steps: {total_steps:,}",
    f"Warmup Steps: {CONFIG['warmup_steps']:,}",
    f"Initial Learning Rate: {CONFIG['learning_rate']}",
    f"Weight Decay: {CONFIG['weight_decay']}",
    sep="\n",
)
if use_multi_gpu:
    print(
        "\n‚ö° Multi-GPU Training:",
        f"  Training will be {n_gpus}x faster (approximately)",
        f"  Each GPU processes {CONFIG['batch_size']} samples",
        sep="\n",
    )
print("=" * 60)

## 9. Resume from Checkpoint (Optional)

In [None]:
# Resume training from checkpoint if specified.
# Defaults describe a fresh run; they are overwritten only after a
# checkpoint has been loaded successfully.
start_epoch = 1
global_step = 0
best_val_loss = float('inf')

resume_path = CONFIG['resume_from_checkpoint']

if not resume_path:
    print("\n‚úì Starting training from scratch")
elif not os.path.exists(resume_path):
    # A configured but missing checkpoint is almost certainly a typo: say so
    # explicitly rather than silently behaving like a fresh run.
    print(f"\nWarning: checkpoint not found: {resume_path}")
    print("\n‚úì Starting training from scratch")
else:
    print(f"\nResuming from checkpoint: {resume_path}")
    metadata = load_checkpoint(
        resume_path,
        model,
        optimizer,
        scheduler
    )

    if metadata is not None:
        # NOTE: training restarts at the epoch after the recorded one, even
        # when the checkpoint was written part-way through that epoch.
        start_epoch = metadata['epoch'] + 1
        global_step = metadata['step']
        # Compare against None: a truthiness test would discard a legitimate 0.0.
        if metadata['val_loss'] is not None:
            best_val_loss = metadata['val_loss']
        print(f"‚úì Resumed from epoch {metadata['epoch']}, step {global_step}")
    else:
        # load_checkpoint has already printed the error. It restores the model
        # before the optimizer, so the model weights may have been replaced.
        print("Warning: checkpoint could not be loaded; training starts at epoch 1")

## 10. Training Functions (Multi-GPU Compatible)

In [None]:
def train_epoch(model, loader, optimizer, scheduler, device, epoch, config):
    """Train for one epoch with gradient accumulation (DataParallel-safe).

    Args:
        model: The language model, optionally wrapped in ``nn.DataParallel``.
        loader: Iterable of batches with ``input_ids`` and ``attention_mask``.
        optimizer: Optimizer stepped once every
            ``config['gradient_accumulation_steps']`` batches.
        scheduler: LR scheduler stepped together with the optimizer.
        device: Device the batch tensors are moved to.
        epoch: Epoch number, used for the progress bar and checkpoint names.
        config: Dict providing ``gradient_accumulation_steps``,
            ``max_grad_norm``, ``save_steps``, ``checkpoint_dir``,
            ``max_checkpoints``, ``use_multi_gpu`` and ``n_gpus``.

    Returns:
        Mean (un-scaled) loss over the batches that contributed gradients,
        or 0.0 if none did.

    Raises:
        RuntimeError: Any runtime error other than CUDA out-of-memory, which
            is handled by skipping the offending batch.
    """
    model.train()
    accum_steps = config['gradient_accumulation_steps']
    total_loss = 0.0
    batches_counted = 0  # batches that actually contributed to total_loss
    optimizer.zero_grad()

    progress_bar = tqdm(loader, desc=f"Epoch {epoch}")

    for step, batch in enumerate(progress_bar):
        try:
            # Move batch to device
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)

            # The pad token is the EOS token, so padding cannot be recognised
            # from the ids. Use the attention mask to set padded positions to
            # -100, which the loss ignores.
            labels = input_ids.masked_fill(attention_mask == 0, -100)

            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=labels
            )

            # nn.DataParallel gathers one loss per replica into a vector; it
            # does NOT average them. Reduce to a scalar before backward()
            # (on a scalar loss, mean() is a no-op).
            batch_loss = outputs.loss.mean()

            if torch.isfinite(batch_loss):
                # Scale so the accumulated gradient equals the window average.
                (batch_loss / accum_steps).backward()
                loss_value = batch_loss.item()
                total_loss += loss_value
                batches_counted += 1
            else:
                # E.g. a replica whose samples had no unmasked target tokens.
                # Skip backward so NaN/inf never reaches the gradients.
                loss_value = float('nan')
                print(f"\nWarning: non-finite loss at step {step}; batch skipped")

            # Update weights after accumulation
            if (step + 1) % accum_steps == 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), config['max_grad_norm'])
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()

            # Update progress bar
            progress_bar.set_postfix({
                'loss': f"{loss_value:.4f}",
                'lr': f"{scheduler.get_last_lr()[0]:.2e}",
                'gpus': f"{config['n_gpus']}" if config['use_multi_gpu'] else '1'
            })

            # Save checkpoint periodically
            if (step + 1) % config['save_steps'] == 0:
                checkpoint_path = os.path.join(
                    config['checkpoint_dir'],
                    f"checkpoint_epoch{epoch}_step{step+1}.pt"
                )

                if save_checkpoint(
                    checkpoint_path,
                    model,
                    optimizer,
                    scheduler,
                    epoch,
                    step + 1,
                    total_loss / max(batches_counted, 1),
                    config=config
                ):
                    print(f"\n‚úì Checkpoint saved: {os.path.basename(checkpoint_path)}")
                    cleanup_old_checkpoints(config['checkpoint_dir'], config['max_checkpoints'])

        except RuntimeError as e:
            if 'out of memory' not in str(e):
                raise
            # Best-effort OOM recovery: free cached memory and move on.
            print(f"\n‚ö†Ô∏è OOM Error at step {step}. Clearing cache...")
            torch.cuda.empty_cache()
            gc.collect()

    # Average over batches that really ran, so skipped ones do not deflate it.
    return total_loss / max(batches_counted, 1)


def evaluate(model, loader, device):
    """Evaluate the model and return ``(average loss, perplexity)``.

    Args:
        model: The language model, optionally wrapped in ``nn.DataParallel``.
        loader: Iterable of batches with ``input_ids`` and ``attention_mask``.
        device: Device the batch tensors are moved to.

    Returns:
        ``(avg_loss, perplexity)``, averaged over the batches that produced
        a finite loss. Perplexity is ``exp(avg_loss)``, or infinity when
        ``avg_loss`` is 100 or more. If no batch could be evaluated, both
        values are infinity, so the result can never be taken for a new best.

    Raises:
        RuntimeError: Any runtime error other than CUDA out-of-memory, which
            is handled by skipping the offending batch.
    """
    model.eval()
    total_loss = 0.0
    batches_counted = 0

    with torch.no_grad():
        for batch in tqdm(loader, desc="Evaluating"):
            try:
                input_ids = batch['input_ids'].to(device)
                attention_mask = batch['attention_mask'].to(device)

                # Exclude padding from the loss (pad token == EOS token, so
                # the attention mask is the only reliable marker).
                labels = input_ids.masked_fill(attention_mask == 0, -100)

                outputs = model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    labels=labels
                )
                # nn.DataParallel returns one loss per replica; reduce to a
                # scalar before .item(), which needs a single element.
                batch_loss = outputs.loss.mean().item()

            except RuntimeError as e:
                if 'out of memory' not in str(e):
                    raise
                torch.cuda.empty_cache()
                gc.collect()
                continue

            if math.isfinite(batch_loss):
                total_loss += batch_loss
                batches_counted += 1

    if batches_counted == 0:
        return float('inf'), float('inf')

    avg_loss = total_loss / batches_counted
    # Guard against overflow in exp() for very large losses.
    perplexity = math.exp(avg_loss) if avg_loss < 100 else float('inf')
    return avg_loss, perplexity

# Confirmation that the cell defining train_epoch and evaluate has run.
print("‚úì Training functions defined (multi-GPU compatible)")

## 11. Main Training Loop

In [None]:
# Training history: one dict per finished epoch.
training_history = []

# Bound before the try block so the KeyboardInterrupt handler can always build
# an emergency checkpoint, even if the interrupt arrives during the very first
# epoch. (The per-batch step counter lives inside train_epoch and is not
# visible here.)
completed_epoch = start_epoch - 1  # last epoch that finished training and evaluation
train_loss = None                  # loss of the most recent finished training pass

print("\n" + "="*60)
print("STARTING TRAINING")
print("="*60)
print(f"Training from epoch {start_epoch} to {CONFIG['epochs']}")
if use_multi_gpu:
    print(f"‚ö° Using {n_gpus} GPUs in parallel")
    print(f"‚ö° Effective speedup: ~{n_gpus}x")
print(f"Checkpoints: {CONFIG['checkpoint_dir']}")
print(f"Keeping {CONFIG['max_checkpoints']} most recent checkpoints")
print("="*60 + "\n")

try:
    for epoch in range(start_epoch, CONFIG['epochs'] + 1):
        print(f"\n{'='*60}")
        print(f"EPOCH {epoch}/{CONFIG['epochs']}")
        print(f"{'='*60}")

        # Train
        train_loss = train_epoch(
            model,
            train_loader,
            optimizer,
            scheduler,
            device,
            epoch,
            CONFIG
        )

        # Evaluate
        val_loss, val_perplexity = evaluate(model, val_loader, device)

        # Print results
        print(f"\n{'='*60}")
        print(f"Epoch {epoch} Results:")
        print(f"  Train Loss: {train_loss:.4f}")
        print(f"  Val Loss: {val_loss:.4f}")
        print(f"  Val Perplexity: {val_perplexity:.2f}")
        print(f"{'='*60}")

        # Save history
        training_history.append({
            'epoch': epoch,
            'train_loss': float(train_loss),
            'val_loss': float(val_loss),
            'val_perplexity': float(val_perplexity),
            'timestamp': datetime.now().isoformat()
        })

        # Save best model
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_model_path = '/kaggle/working/best_model.pt'

            if save_checkpoint(
                best_model_path,
                model,
                optimizer,
                scheduler,
                epoch,
                len(train_loader),
                train_loss,
                val_loss,
                CONFIG
            ):
                print(f"\n‚úì New best model saved! (val_loss: {val_loss:.4f})")

        # Save epoch checkpoint
        epoch_checkpoint_path = os.path.join(
            CONFIG['checkpoint_dir'],
            f"checkpoint_epoch{epoch}_final.pt"
        )

        # Only announce success when the save really succeeded
        # (save_checkpoint prints the error itself on failure).
        if save_checkpoint(
            epoch_checkpoint_path,
            model,
            optimizer,
            scheduler,
            epoch,
            len(train_loader),
            train_loss,
            val_loss,
            CONFIG
        ):
            print(f"‚úì Epoch {epoch} checkpoint saved")

        # Cleanup old checkpoints
        cleanup_old_checkpoints(CONFIG['checkpoint_dir'], CONFIG['max_checkpoints'])

        completed_epoch = epoch

        # Clear cache
        torch.cuda.empty_cache()
        gc.collect()

    print("\n" + "="*60)
    print("TRAINING COMPLETED SUCCESSFULLY!")
    print(f"Best Validation Loss: {best_val_loss:.4f}")
    if use_multi_gpu:
        print(f"‚ö° Trained using {n_gpus} GPUs")
    print("="*60)

except KeyboardInterrupt:
    print("\n\n‚ö†Ô∏è Training interrupted by user")
    emergency_path = '/kaggle/working/emergency_checkpoint.pt'
    # Record the last fully completed epoch: resuming starts at the epoch
    # after the recorded one, so the interrupted epoch is re-run, not skipped.
    if save_checkpoint(
        emergency_path,
        model,
        optimizer,
        scheduler,
        completed_epoch,
        0,
        train_loss,
        config=CONFIG
    ):
        print(f"‚úì Emergency checkpoint saved")

except Exception as e:
    # Top-level boundary: report with a traceback, then propagate the failure.
    print(f"\n\n‚ùå Training failed: {e}")
    import traceback
    traceback.print_exc()
    raise

## 12. Save Training History

In [None]:
# Persist the per-epoch metrics as pretty-printed JSON, then echo them.
history_path = '/kaggle/working/training_history.json'
with open(history_path, 'w') as f:
    f.write(json.dumps(training_history, indent=2))

print("\n" + "=" * 60, "TRAINING SUMMARY", "=" * 60, sep="\n")
for entry in training_history:
    # One summary line per finished epoch.
    print(
        f"Epoch {entry['epoch']}: Train={entry['train_loss']:.4f}, "
        f"Val={entry['val_loss']:.4f}, PPL={entry['val_perplexity']:.2f}"
    )
print("=" * 60)
print(f"\n‚úì Training history saved to {history_path}")

## 13. Text Generation Test

In [None]:
def generate_text(prompt, max_length=100, temperature=0.8, num_return_sequences=1):
    """Sample continuations of *prompt* and return them as decoded strings.

    Generation runs on the unwrapped model (the ``.module`` of a
    ``nn.DataParallel`` wrapper), which is switched to eval mode. Sampling
    uses top-k 50 and top-p 0.95 at the given temperature, with the EOS
    token doubling as the pad token.

    Returns:
        A list with ``num_return_sequences`` strings, special tokens removed.
    """
    gen_model = model.module if isinstance(model, nn.DataParallel) else model
    gen_model.eval()

    prompt_ids = tokenizer.encode(prompt, return_tensors='pt').to(device)

    sampling_options = dict(
        max_length=max_length,
        temperature=temperature,
        top_k=50,
        top_p=0.95,
        do_sample=True,
        num_return_sequences=num_return_sequences,
        pad_token_id=tokenizer.eos_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )

    with torch.no_grad():
        sequences = gen_model.generate(prompt_ids, **sampling_options)

    decoded = []
    for token_ids in sequences:
        decoded.append(tokenizer.decode(token_ids, skip_special_tokens=True))
    return decoded

# Prompts used for a quick qualitative check of the trained model.
test_prompts = [
    "The future of artificial intelligence",
    "In a world where technology",
    "Scientists have discovered",
    "Once upon a time"
]

print("\n" + "=" * 60, "TEXT GENERATION EXAMPLES", "=" * 60, sep="\n")

for prompt in test_prompts:
    print(f"\n{'‚îÄ'*60}", f"Prompt: '{prompt}'", f"{'‚îÄ'*60}", sep="\n")
    # A failure on one prompt is reported and the demo moves on to the next.
    try:
        generated = generate_text(prompt, max_length=150, temperature=0.8)
        print(generated[0])
    except Exception as e:
        print(f"Error: {e}")

## 14. Save Final Model

In [None]:
# Export the trained network, tokenizer and run configuration into one
# directory in HuggingFace format.
final_model_dir = '/kaggle/working/final_model'
print(f"\nSaving final model to {final_model_dir}...")

try:
    # save_pretrained is called on the underlying network, not the DataParallel wrapper.
    save_model = model.module if isinstance(model, nn.DataParallel) else model

    save_model.save_pretrained(final_model_dir)
    tokenizer.save_pretrained(final_model_dir)

    # Keep the training configuration next to the weights.
    config_path = os.path.join(final_model_dir, 'training_config.json')
    with open(config_path, 'w') as f:
        f.write(json.dumps(CONFIG, indent=2))

    print("‚úì Model saved successfully")
except Exception as e:
    # Best-effort export: report the problem and let the notebook continue.
    print(f"Error saving model: {e}")

## 15. Output Summary

In [None]:
# Print a tree of the output files, listing each remaining checkpoint with its size.
print("\n" + "=" * 60, "OUTPUT FILES", "=" * 60, sep="\n")
print(
    "\nüìÅ /kaggle/working/",
    "  ‚îú‚îÄ‚îÄ best_model.pt (best checkpoint)",
    "  ‚îú‚îÄ‚îÄ training_history.json",
    "  ‚îú‚îÄ‚îÄ checkpoints/",
    sep="\n",
)

checkpoints = get_checkpoint_list(CONFIG['checkpoint_dir'])
# An empty list simply prints no checkpoint rows.
for ckpt in checkpoints:
    size_mb = os.path.getsize(ckpt) / 1e6
    print(f"  ‚îÇ   ‚îú‚îÄ‚îÄ {os.path.basename(ckpt)} ({size_mb:.1f} MB)")

print(
    "  ‚îî‚îÄ‚îÄ final_model/ (HuggingFace format)",
    "      ‚îú‚îÄ‚îÄ pytorch_model.bin",
    "      ‚îú‚îÄ‚îÄ config.json",
    "      ‚îú‚îÄ‚îÄ training_config.json",
    "      ‚îî‚îÄ‚îÄ tokenizer files",
    "=" * 60,
    sep="\n",
)

if use_multi_gpu:
    print(
        f"\n‚ö° Training completed using {n_gpus} GPUs",
        f"‚ö° Approximate speedup: {n_gpus}x vs single GPU",
        sep="\n",
    )

print("\n‚úì Training complete! Download files from the Output tab.")