# 🤖 Cyberdyne LLM - Complete Training & Inference System

A complete system for training and deploying custom language models with Llama3-inspired architecture (RMSNorm, SwiGLU, RoPE).

## Features
- 🎓 Train custom language models from scratch
- 📊 Multi-dataset support (9+ Hugging Face datasets)
- 💾 Optimized for Kaggle's 50GB SSD with intelligent caching
- 🚀 Gradient accumulation for effective large batch training
- 🔄 Checkpoint resuming for interrupted training
- 💬 Offline inference

## Quick Start Guide
1. Run the installation cell
2. Configure dataset caching for optimal disk usage
3. Choose to either train a new model or load an existing one
4. Load and test your trained model

**Compatible with Google Colab & Kaggle (Optimized for 50GB SSD)**

## 📦 Installation & Setup

In [1]:
# Install required packages
!pip install -q torch>=2.0.0 transformers>=4.30.0 datasets>=2.14.0 tqdm>=4.65.0 huggingface_hub>=0.16.0 accelerate>=0.20.0 psutil

print("✅ All packages installed successfully!")

[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
bigframes 2.12.0 requires google-cloud-bigquery-storage<3.0.0,>=2.30.0, which is not installed.
pylibcudf-cu12 25.2.2 requires pyarrow<20.0.0a0,>=14.0.0; platform_machine == "x86_64", but you have pyarrow 22.0.0 which is incompatible.
cudf-cu12 25.2.2 requires pyarrow<20.0.0a0,>=14.0.0; platform_machine == "x86_64", but you have pyarrow 22.0.0 which is incompatible.
bigframes 2.12.0 requires google-cloud-bigquery[bqstorage,pandas]>=3.31.0, but you have google-cloud-bigquery 3.25.0 which is incompatible.
bigframes 2.12.0 requires rich<14,>=12.4.4, but you have rich 14.1.0 which is incompatible.
libcugraph-cu12 25.6.0 requires libraft-cu12==25.6.*, but you have libraft-cu12 25.2.0 which is incompatible.
gradio 5.38.1 requires pydantic<2.12,>=2.0, but you have pydantic 2.12.0a1 which is incompatible.
cudf-polars

In [2]:
# Import libraries
import torch
from torch import nn
from torch.utils.data import DataLoader, IterableDataset
from transformers import GPT2Tokenizer
from datasets import load_dataset
from tqdm import tqdm
import time
import os
import math
import uuid
import shutil
import psutil
import gc
from pathlib import Path
from datetime import datetime
from typing import Optional, Dict, List
import json

# Check device
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"🖥️  Using device: {device}")
if device == 'cuda':
    print(f"   GPU: {torch.cuda.get_device_name(0)}")
    print(f"   Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
    # Check if multiple GPUs available
    n_gpus = torch.cuda.device_count()
    print(f"   Number of GPUs: {n_gpus}")
    if n_gpus > 1:
        print("   Multi-GPU training available!")

🖥️  Using device: cuda
   GPU: Tesla T4
   Memory: 15.83 GB
   Number of GPUs: 2
   Multi-GPU training available!


## 💾 Dataset Configuration & Disk Space Management

Optimized configuration for Kaggle's 50GB SSD space

In [3]:
# Dataset Configuration for Kaggle Environment
import os
import shutil
from pathlib import Path
import psutil

# Kaggle-specific paths
KAGGLE_CACHE_DIR = "/kaggle/working/cache"
KAGGLE_MODEL_DIR = "/kaggle/working/models"
KAGGLE_DATA_DIR = "/kaggle/working/data"
KAGGLE_CHECKPOINT_DIR = "/kaggle/working/checkpoints"

# Create directories
for dir_path in [KAGGLE_CACHE_DIR, KAGGLE_MODEL_DIR, KAGGLE_DATA_DIR, KAGGLE_CHECKPOINT_DIR]:
    os.makedirs(dir_path, exist_ok=True)

# Disk space utilities
def get_disk_usage(path="/kaggle/working"):
    """Get disk usage statistics for the specified path."""
    if os.path.exists(path):
        total, used, free = shutil.disk_usage(path)
    else:
        # Fallback to current directory if Kaggle path doesn't exist
        total, used, free = shutil.disk_usage(".")
    
    return {
        "total_gb": total / (1024**3),
        "used_gb": used / (1024**3), 
        "free_gb": free / (1024**3),
        "percent_used": (used / total) * 100 if total > 0 else 0
    }

def get_memory_usage():
    """Get current memory usage."""
    memory = psutil.virtual_memory()
    return {
        "total_gb": memory.total / (1024**3),
        "available_gb": memory.available / (1024**3),
        "used_gb": memory.used / (1024**3),
        "percent_used": memory.percent
    }

def monitor_resources():
    """Monitor disk and memory resources."""
    disk_info = get_disk_usage()
    mem_info = get_memory_usage()
    
    print("\n" + "="*60)
    print("📊 Resource Monitor")
    print("="*60)
    print(f"💾 Disk: {disk_info['free_gb']:.1f}GB free of {disk_info['total_gb']:.1f}GB ({disk_info['percent_used']:.1f}% used)")
    print(f"🧠 RAM: {mem_info['available_gb']:.1f}GB free of {mem_info['total_gb']:.1f}GB ({mem_info['percent_used']:.1f}% used)")
    
    if device == 'cuda':
        gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1e9
        gpu_used = torch.cuda.memory_allocated() / 1e9
        gpu_cached = torch.cuda.memory_reserved() / 1e9
        print(f"🎮 GPU: {gpu_used:.1f}GB used, {gpu_cached:.1f}GB cached of {gpu_memory:.1f}GB")
    print("="*60 + "\n")
    
    return disk_info, mem_info

# Calculate optimal data configuration
disk_info = get_disk_usage()
print("\n🚀 Kaggle Environment Configuration")
print("="*60)
print(f"💾 Disk space: {disk_info['free_gb']:.1f}GB free of {disk_info['total_gb']:.1f}GB")

# Reserve space for models and overhead
DATA_CACHE_SIZE_GB = min(35, disk_info['free_gb'] - 10)  # Reserve 10GB for models/checkpoints/overhead
MAX_SAMPLES_ESTIMATE = int(DATA_CACHE_SIZE_GB * 1e9 / 2000)  # Estimate ~2KB per tokenized sample

print(f"📦 Will use up to {DATA_CACHE_SIZE_GB:.1f}GB for data caching")
print(f"📈 Estimated capacity: {MAX_SAMPLES_ESTIMATE:,} samples")
print("="*60)

# Configuration for different training scales
TRAINING_CONFIGS = {
    'quick': {
        'max_samples': min(1_000_000, MAX_SAMPLES_ESTIMATE // 20),
        'batch_size': 8,
        'gradient_accumulation_steps': 4,
        'checkpoint_interval': 500,
        'description': 'Quick training with 1M samples'
    },
    'standard': {
        'max_samples': min(5_000_000, MAX_SAMPLES_ESTIMATE // 4),
        'batch_size': 8,
        'gradient_accumulation_steps': 8,
        'checkpoint_interval': 1000,
        'description': 'Standard training with 5M samples'
    },
    'large': {
        'max_samples': min(20_000_000, MAX_SAMPLES_ESTIMATE),
        'batch_size': 4,
        'gradient_accumulation_steps': 16,
        'checkpoint_interval': 2000,
        'description': 'Large-scale training with up to 20M samples'
    },
    'pile': {
        'max_samples': MAX_SAMPLES_ESTIMATE,
        'batch_size': 4,
        'gradient_accumulation_steps': 32,
        'checkpoint_interval': 5000,
        'description': f'Pile dataset training with {MAX_SAMPLES_ESTIMATE:,} samples'
    }
}

print("\n📋 Available Training Configurations:")
for config_name, config in TRAINING_CONFIGS.items():
    print(f"  • {config_name}: {config['description']}")
    print(f"    - Batch size: {config['batch_size']} × {config['gradient_accumulation_steps']} accumulation")
    print(f"    - Effective batch size: {config['batch_size'] * config['gradient_accumulation_steps']}")


🚀 Kaggle Environment Configuration
💾 Disk space: 19.5GB free of 19.5GB
📦 Will use up to 9.5GB for data caching
📈 Estimated capacity: 4,751,216 samples

📋 Available Training Configurations:
  • quick: Quick training with 1M samples
    - Batch size: 8 × 4 accumulation
    - Effective batch size: 32
  • standard: Standard training with 5M samples
    - Batch size: 8 × 8 accumulation
    - Effective batch size: 64
  • large: Large-scale training with up to 20M samples
    - Batch size: 4 × 16 accumulation
    - Effective batch size: 64
  • pile: Pile dataset training with 4,751,216 samples
    - Batch size: 4 × 32 accumulation
    - Effective batch size: 128


## 🏗️ Model Architecture

Llama3-inspired Transformer with RMSNorm, SwiGLU, and Rotary Position Embeddings

In [4]:
# ==================== RMSNorm Implementation ====================
class RMSNorm(nn.Module):
    """Root Mean Square Layer Normalization.
    
    As used in Llama3, RMSNorm is computationally simpler than LayerNorm.
    It normalizes by the root mean square without subtracting mean.
    """
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        norm = torch.sqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + self.eps)
        return x / norm * self.weight


# ==================== Rotary Position Embeddings (RoPE) ====================
def precompute_freqs_cis(dim: int, max_len: int, theta: float = 10000.0):
    """Precompute the frequency tensor for complex exponentials (cis) with given dimensions."""
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    t = torch.arange(max_len)
    freqs = torch.outer(t, freqs).float()
    freqs_cos = torch.cos(freqs)
    freqs_sin = torch.sin(freqs)
    return freqs_cos, freqs_sin


def apply_rotary_emb(xq: torch.Tensor, xk: torch.Tensor, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor) -> tuple:
    """Apply rotary embeddings to query and key tensors."""
    # Reshape for rotary computation
    xq_r = xq.float().reshape(*xq.shape[:-1], -1, 2)
    xk_r = xk.float().reshape(*xk.shape[:-1], -1, 2)
    
    # Apply rotation
    xq_out_r = xq_r[..., 0] * freqs_cos - xq_r[..., 1] * freqs_sin
    xq_out_i = xq_r[..., 0] * freqs_sin + xq_r[..., 1] * freqs_cos
    xk_out_r = xk_r[..., 0] * freqs_cos - xk_r[..., 1] * freqs_sin
    xk_out_i = xk_r[..., 0] * freqs_sin + xk_r[..., 1] * freqs_cos
    
    # Concatenate back
    xq_out = torch.stack([xq_out_r, xq_out_i], dim=-1).flatten(-2)
    xk_out = torch.stack([xk_out_r, xk_out_i], dim=-1).flatten(-2)
    
    return xq_out.type_as(xq), xk_out.type_as(xk)


# ==================== SwiGLU Activation Function ====================
class SwiGLU(nn.Module):
    """SwiGLU activation function as used in Llama3.
    
    Combines Swish activation with Gated Linear Unit.
    More effective than ReLU/GELU for language modeling.
    """
    def __init__(self, dim: int, hidden_dim: Optional[int] = None):
        super().__init__()
        if hidden_dim is None:
            hidden_dim = int(2 * dim * 4 / 3)
            # Round to nearest multiple of 64 for hardware efficiency
            hidden_dim = ((hidden_dim + 63) // 64) * 64
        
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.w2(nn.functional.silu(self.w1(x)) * self.w3(x))


# ==================== Multi-Head Attention with RoPE ====================
class MultiHeadAttention(nn.Module):
    def __init__(self, emb_size, n_heads, max_len=512, dropout=0.1):
        super().__init__()
        assert emb_size % n_heads == 0
        
        self.emb_size = emb_size
        self.n_heads = n_heads
        self.head_dim = emb_size // n_heads
        
        self.qkv = nn.Linear(emb_size, 3 * emb_size, bias=False)
        self.out = nn.Linear(emb_size, emb_size)
        self.dropout = nn.Dropout(dropout)
        
        # Precompute RoPE frequencies
        self.freqs_cos, self.freqs_sin = precompute_freqs_cis(self.head_dim, max_len)
        
    def forward(self, x, mask=None):
        batch_size, seq_len, _ = x.shape
        device = x.device
        
        # Compute Q, K, V
        qkv = self.qkv(x)
        qkv = qkv.reshape(batch_size, seq_len, 3, self.n_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        
        # Apply RoPE to Q and K
        self.freqs_cos = self.freqs_cos.to(device)
        self.freqs_sin = self.freqs_sin.to(device)
        freqs_cos = self.freqs_cos[:seq_len].unsqueeze(0).unsqueeze(0)
        freqs_sin = self.freqs_sin[:seq_len].unsqueeze(0).unsqueeze(0)
        
        q, k = apply_rotary_emb(q, k, freqs_cos, freqs_sin)
        
        # Compute attention scores
        attn_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        
        # Create causal mask for autoregressive generation
        if mask is None:
            mask = torch.tril(torch.ones(seq_len, seq_len)).to(device)
        
        # Apply the mask to attention scores
        attn_scores = attn_scores.masked_fill(mask == 0, float('-inf'))
        
        
        attn_weights = torch.softmax(attn_scores, dim=-1)
        attn_weights = self.dropout(attn_weights)
        
        # Apply attention to values
        out = torch.matmul(attn_weights, v)
        out = out.permute(0, 2, 1, 3).contiguous()
        out = out.reshape(batch_size, seq_len, self.emb_size)
        out = self.out(out)
        
        return out


# ==================== Transformer Block with RMSNorm and SwiGLU ====================
class TransformerBlock(nn.Module):
    def __init__(self, emb_size, n_heads, ff_size, max_len=512, dropout=0.1, use_checkpoint=False):
        super().__init__()
        self.attn = MultiHeadAttention(emb_size, n_heads, max_len, dropout)
        self.ff = SwiGLU(emb_size)  # Using SwiGLU instead of standard FeedForward
        self.ln1 = RMSNorm(emb_size)  # Using RMSNorm instead of LayerNorm
        self.ln2 = RMSNorm(emb_size)  # Using RMSNorm instead of LayerNorm
        self.dropout = nn.Dropout(dropout)
        self.use_checkpoint = use_checkpoint
    
    def forward(self, x, mask=None):
        if self.use_checkpoint and self.training:
            # Use gradient checkpointing to save memory
            return torch.utils.checkpoint.checkpoint(self._forward, x, mask)
        else:
            return self._forward(x, mask)
    
    def _forward(self, x, mask=None):
        # Pre-norm with RMSNorm -> Attention with residual
        attn_out = self.attn(self.ln1(x), mask)
        x = x + self.dropout(attn_out)
        # Pre-norm with RMSNorm -> SwiGLU FFN with residual
        ff_out = self.ff(self.ln2(x))
        x = x + self.dropout(ff_out)
        return x


# ==================== Advanced LLM with Llama3-Inspired Architecture ====================
class AdvancedLLM(nn.Module):
    """Advanced LLM with Llama3-inspired improvements:
    - RMSNorm for efficient normalization
    - SwiGLU activation for better performance
    - Rotary Position Embeddings (RoPE) for better position encoding
    - Optional gradient checkpointing for memory efficiency
    """
    def __init__(self, vocab_size, emb_size=768, n_layers=12, n_heads=12,
                 ff_size=3072, max_len=512, dropout=0.1, use_gradient_checkpoint=False):
        super().__init__()
        self.vocab_size = vocab_size
        self.emb_size = emb_size
        self.n_layers = n_layers
        self.n_heads = n_heads
        self.max_len = max_len
        self.use_gradient_checkpoint = use_gradient_checkpoint
        
        # Token embeddings (no position embeddings needed with RoPE)
        self.token_embed = nn.Embedding(vocab_size, emb_size)
        self.dropout = nn.Dropout(dropout)
        
        # Transformer blocks with RMSNorm and SwiGLU
        self.blocks = nn.ModuleList([
            TransformerBlock(emb_size, n_heads, ff_size, max_len, dropout, use_gradient_checkpoint)
            for _ in range(n_layers)
        ])
        
        # Final RMSNorm and output projection
        self.ln_final = RMSNorm(emb_size)
        self.head = nn.Linear(emb_size, vocab_size, bias=False)
        
        # Weight tying
        self.token_embed.weight = self.head.weight
        
        self._init_weights()
    
    def _init_weights(self):
        """Initialize model weights."""
        for module in self.modules():
            if isinstance(module, nn.Linear):
                torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
                if module.bias is not None:
                    torch.nn.init.zeros_(module.bias)
            elif isinstance(module, nn.Embedding):
                torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
    
    def forward(self, x, mask=None):
        batch_size, seq_len = x.shape
        
        # Token embeddings (no position embeddings, using RoPE instead)
        token_emb = self.token_embed(x)
        x = self.dropout(token_emb)
        
        # Apply transformer blocks
        for block in self.blocks:
            x = block(x, mask)
        
        # Final norm and output projection
        x = self.ln_final(x)
        logits = self.head(x)
        
        return logits
    
    def get_num_params(self, non_embedding=False):
        """Count model parameters."""
        n_params = sum(p.numel() for p in self.parameters())
        if non_embedding:
            n_params -= self.token_embed.weight.numel()
        return n_params
    
    def estimate_memory_usage(self, batch_size, seq_length):
        """Estimate memory usage for given batch size and sequence length."""
        # Rough estimation in GB
        params_memory = self.get_num_params() * 4 / 1e9  # 4 bytes per float32
        activation_memory = batch_size * seq_length * self.emb_size * self.n_layers * 4 / 1e9
        gradient_memory = params_memory if self.training else 0
        total_memory = params_memory + activation_memory + gradient_memory
        return {
            'params_gb': params_memory,
            'activations_gb': activation_memory,
            'gradients_gb': gradient_memory,
            'total_gb': total_memory
        }
    
    def save_model(self, path):
        """Save model to disk."""
        config = {
            'vocab_size': self.vocab_size,
            'emb_size': self.emb_size,
            'n_layers': self.n_layers,
            'n_heads': self.n_heads,
            'max_len': self.max_len,
            'use_gradient_checkpoint': self.use_gradient_checkpoint,
            'state_dict': self.state_dict()
        }
        torch.save(config, path)
        file_size = os.path.getsize(path) / 1e9
        print(f"✅ Model saved to {path} ({file_size:.2f}GB)")
    
    @classmethod
    def load_model(cls, path, device='cpu'):
        """Load model from disk."""
        config = torch.load(path, map_location=device)
        model = cls(
            vocab_size=config['vocab_size'],
            emb_size=config['emb_size'],
            n_layers=config['n_layers'],
            n_heads=config['n_heads'],
            max_len=config['max_len'],
            use_gradient_checkpoint=config.get('use_gradient_checkpoint', False)
        )
        model.load_state_dict(config['state_dict'])
        model.to(device)
        print(f"✅ Model loaded from {path}")
        return model

print("✅ Llama3-inspired model architecture defined")
print("   Features: RMSNorm, SwiGLU activation, Rotary Position Embeddings (RoPE), Gradient Checkpointing")

✅ Llama3-inspired model architecture defined
   Features: RMSNorm, SwiGLU activation, Rotary Position Embeddings (RoPE), Gradient Checkpointing


## 📚 Enhanced Dataset Loader with Caching

Optimized for large-scale data loading with local disk caching

In [5]:
class HuggingFaceDataset(IterableDataset):
    """Enhanced dataset loader with disk caching support for efficient large-scale training."""
    
    def __init__(self, dataset_name, tokenizer, max_samples=100000,
                 max_len=512, split='train', config=None, streaming=True,
                 text_field='text', instruction_field=None,
                 cache_dir=None, use_disk_cache=True, prefetch_batches=10):
        
        self.dataset_name = dataset_name
        self.tokenizer = tokenizer
        self.max_samples = max_samples
        self.max_len = max_len
        self.split = split
        self.config = config
        self.streaming = streaming
        self.text_field = text_field
        self.instruction_field = instruction_field
        self.prefetch_batches = prefetch_batches
        
        # Caching configuration
        self.cache_dir = cache_dir or KAGGLE_CACHE_DIR
        self.use_disk_cache = use_disk_cache
        
        # Set up caching for datasets library
        if use_disk_cache:
            os.environ['HF_DATASETS_CACHE'] = self.cache_dir
            os.environ['TRANSFORMERS_CACHE'] = self.cache_dir
            os.makedirs(self.cache_dir, exist_ok=True)
            print(f"💾 Using cache directory: {self.cache_dir}")

        try:
            # Load dataset with caching
            if config:
                self.dataset = load_dataset(
                    dataset_name, config, split=split, 
                    streaming=streaming,
                    cache_dir=self.cache_dir if use_disk_cache else None,
                    keep_in_memory=False  # Don't keep entire dataset in RAM
                )
            else:
                self.dataset = load_dataset(
                    dataset_name, split=split, 
                    streaming=streaming,
                    cache_dir=self.cache_dir if use_disk_cache else None,
                    keep_in_memory=False
                )
            print(f"✅ Loaded dataset: {dataset_name}")
            
            # For large datasets, optionally take a subset
            if streaming and max_samples is not None:
                self.dataset = self.dataset.take(max_samples)
                
        except Exception as e:
            print(f"❌ Error loading dataset {dataset_name}: {e}")
            raise
    
    def __iter__(self):
        count = 0
        buffer = []
        
        for item in self.dataset:
            if count >= self.max_samples:
                break

            text = self._extract_text(item)
            if not text:
                continue

            tokens = self.tokenizer.encode(
                text,
                max_length=self.max_len,
                truncation=True,
                padding='max_length',
                return_tensors='pt'
            ).squeeze(0)

            target = tokens.clone()
            target[:-1] = tokens[1:]
            target[-1] = self.tokenizer.eos_token_id or self.tokenizer.pad_token_id

            buffer.append((tokens, target))
            count += 1
            
            # Yield in batches for better efficiency
            if len(buffer) >= self.prefetch_batches:
                for item in buffer:
                    yield item
                buffer = []
        
        # Yield remaining items
        for item in buffer:
            yield item

    def _extract_text(self, item):
        if self.instruction_field and self.instruction_field in item:
            if isinstance(item[self.instruction_field], list):
                instruction = item[self.instruction_field][0] if item[self.instruction_field] else ""
            else:
                instruction = item[self.instruction_field]

            if self.text_field in item:
                if isinstance(item[self.text_field], list):
                    response = item[self.text_field][0] if item[self.text_field] else ""
                else:
                    response = item[self.text_field]
                return f"Instruction: {instruction}\n\nResponse: {response}"
            return instruction

        if self.text_field in item:
            text = item[self.text_field]
            if isinstance(text, list):
                return text[0] if text else ""
            return text

        for key in ['text', 'content', 'document', 'article', 'response', 'output']:
            if key in item:
                value = item[key]
                if isinstance(value, list):
                    return value[0] if value else ""
                return value

        return ""
    
    def get_cache_size(self):
        """Get the current cache directory size in GB."""
        if not os.path.exists(self.cache_dir):
            return 0
        
        total_size = 0
        for dirpath, dirnames, filenames in os.walk(self.cache_dir):
            for filename in filenames:
                filepath = os.path.join(dirpath, filename)
                total_size += os.path.getsize(filepath)
        
        return total_size / (1024**3)  # Convert to GB


class MultiDatasetLoader:
    """Multi-dataset loader with support for Pile and other large datasets."""
    
    DATASET_CONFIGS = {
        'wikipedia': {
            'name': 'wikipedia',
            'config': '20231101.en',
            'text_field': 'text',
        },
        'openwebtext': {
            'name': 'openwebtext',
            'text_field': 'text',
        },
        'wikitext': {
            'name': 'wikitext',
            'config': 'wikitext-103-v1',
            'text_field': 'text',
        },
        'bookcorpus': {
            'name': 'bookcorpus',
            'text_field': 'text',
        },
        'c4': {
            'name': 'c4',
            'config': 'en',
            'text_field': 'text',
        },
        'dolly': {
            'name': 'databricks/databricks-dolly-15k',
            'text_field': 'response',
            'instruction_field': 'instruction',
        },
        'alpaca': {
            'name': 'tatsu-lab/alpaca',
            'text_field': 'output',
            'instruction_field': 'instruction',
        },
        'squad': {
            'name': 'squad',
            'text_field': 'context',
        },
        'pile': {
            'name': 'EleutherAI/pile',
            'text_field': 'text',
            'config': 'all',  # Can also use specific subsets
        },
        'pile_cc': {
            'name': 'EleutherAI/pile',
            'config': 'pile_cc',
            'text_field': 'text',
        },
        'pile_pubmed': {
            'name': 'EleutherAI/pile',
            'config': 'pubmed_central',
            'text_field': 'text',
        },
        'pile_arxiv': {
            'name': 'EleutherAI/pile',
            'config': 'arxiv',
            'text_field': 'text',
        },
        'pile_books3': {
            'name': 'EleutherAI/pile',
            'config': 'books3',
            'text_field': 'text',
        },
        'pile_openwebtext2': {
            'name': 'EleutherAI/pile',
            'config': 'openwebtext2',
            'text_field': 'text',
        },
    }

    @classmethod
    def create_dataset(cls, dataset_key, tokenizer, max_samples=100000,
                       max_len=512, split='train', streaming=True,
                       cache_dir=None, use_disk_cache=True):
        if dataset_key not in cls.DATASET_CONFIGS:
            raise ValueError(f"Unknown dataset: {dataset_key}. Available: {list(cls.DATASET_CONFIGS.keys())}")

        config = cls.DATASET_CONFIGS[dataset_key]
        return HuggingFaceDataset(
            dataset_name=config['name'],
            tokenizer=tokenizer,
            max_samples=max_samples,
            max_len=max_len,
            split=split,
            config=config.get('config'),
            streaming=streaming,
            text_field=config['text_field'],
            instruction_field=config.get('instruction_field'),
            cache_dir=cache_dir,
            use_disk_cache=use_disk_cache
        )
    
    @classmethod
    def create_pile_dataset(cls, tokenizer, subsets=None, max_samples_per_subset=1_000_000,
                           max_len=512, cache_dir=None, use_disk_cache=True):
        """Create a dataset from multiple Pile subsets."""
        if subsets is None:
            subsets = ['pile_cc', 'pile_pubmed', 'pile_arxiv', 'pile_openwebtext2']
        
        datasets = []
        for subset in subsets:
            if subset in cls.DATASET_CONFIGS:
                try:
                    ds = cls.create_dataset(
                        subset, tokenizer, 
                        max_samples=max_samples_per_subset,
                        max_len=max_len,
                        streaming=True,
                        cache_dir=cache_dir,
                        use_disk_cache=use_disk_cache
                    )
                    datasets.append(ds)
                    print(f"✅ Added Pile subset: {subset}")
                except Exception as e:
                    print(f"⚠️ Failed to load {subset}: {e}")
        
        return datasets

    @classmethod
    def list_available_datasets(cls):
        return list(cls.DATASET_CONFIGS.keys())

    @classmethod
    def get_dataset_info(cls, dataset_key):
        if dataset_key in cls.DATASET_CONFIGS:
            return cls.DATASET_CONFIGS[dataset_key]
        return None

print("✅ Enhanced dataset loader with caching defined")
print(f"📚 Available datasets: {MultiDatasetLoader.list_available_datasets()}")

✅ Enhanced dataset loader with caching defined
📚 Available datasets: ['wikipedia', 'openwebtext', 'wikitext', 'bookcorpus', 'c4', 'dolly', 'alpaca', 'squad', 'pile', 'pile_cc', 'pile_pubmed', 'pile_arxiv', 'pile_books3', 'pile_openwebtext2']


## 🎓 Advanced Training Engine with Gradient Accumulation

Supports checkpoint resuming, gradient accumulation, and multi-GPU training

In [6]:
class LLMTrainer:
    """Advanced trainer with gradient accumulation, checkpoint resuming, and resource monitoring."""
    
    def __init__(self, model_name='cyberdyne-llm', vocab_size=None, emb_size=768,
                 n_layers=12, n_heads=12, ff_size=3072, max_len=512,
                 learning_rate=5e-5, weight_decay=0.01, warmup_steps=1000,
                 device=None, use_gradient_checkpoint=False, mixed_precision=False):
        
        self.model_name = model_name
        self.device = device or ('cuda' if torch.cuda.is_available() else 'cpu')
        self.mixed_precision = mixed_precision and self.device == 'cuda'
        
        # Initialize tokenizer
        if vocab_size is None:
            tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
            tokenizer.add_special_tokens({'pad_token': '[PAD]'})
            vocab_size = len(tokenizer)

        # Initialize model
        self.model = AdvancedLLM(
            vocab_size=vocab_size,
            emb_size=emb_size,
            n_layers=n_layers,
            n_heads=n_heads,
            ff_size=ff_size,
            max_len=max_len,
            use_gradient_checkpoint=use_gradient_checkpoint
        ).to(self.device)
        
        # Multi-GPU support
        self.n_gpus = torch.cuda.device_count() if self.device == 'cuda' else 0
        if self.n_gpus > 1:
            print(f"🚀 Using {self.n_gpus} GPUs for training")
            self.model = nn.DataParallel(self.model)
        
        # Optimizer with weight decay
        self.optimizer = torch.optim.AdamW(
            self.model.parameters(), 
            lr=learning_rate,
            weight_decay=weight_decay,
            betas=(0.9, 0.95)  # Llama-style betas
        )
        
        # Learning rate scheduler
        self.warmup_steps = warmup_steps
        self.scheduler = None
        
        # Loss function
        self.loss_fn = nn.CrossEntropyLoss()
        
        # Mixed precision training
        self.scaler = torch.amp.GradScaler('cuda') if self.mixed_precision else None
        
        # Training state
        self.global_step = 0
        self.best_loss = float('inf')
        
        # Get actual model for parameter counting
        base_model = self.model.module if hasattr(self.model, 'module') else self.model
        print(f"🤖 Model initialized with {base_model.get_num_params():,} parameters")
        print(f"🖥️  Using device: {self.device}")
        
        if use_gradient_checkpoint:
            print("💾 Gradient checkpointing enabled for memory efficiency")
        if self.mixed_precision:
            print("⚡ Mixed precision training enabled")

    def train(self, dataset_key, num_epochs=3, batch_size=4, max_samples=10000,
              max_len=512, gradient_accumulation_steps=8, 
              checkpoint_interval=1000, save_dir='models', 
              checkpoint_dir='checkpoints', log_interval=10,
              resume_from_checkpoint=None, cache_dir=None,
              eval_interval=None, eval_samples=1000):
        
        print(f"\n🎓 Starting training on dataset: {dataset_key}")
        print(f"   Epochs: {num_epochs}, Batch size: {batch_size}")
        print(f"   Gradient accumulation: {gradient_accumulation_steps}")
        print(f"   Effective batch size: {batch_size * gradient_accumulation_steps * max(1, self.n_gpus)}")
        print(f"   Max samples: {max_samples:,}")
        
        os.makedirs(save_dir, exist_ok=True)
        os.makedirs(checkpoint_dir, exist_ok=True)
        
        # Load checkpoint if resuming
        start_epoch = 0
        if resume_from_checkpoint and os.path.exists(resume_from_checkpoint):
            start_epoch, self.global_step = self._load_checkpoint(resume_from_checkpoint)
            print(f"📂 Resumed from checkpoint: epoch {start_epoch}, step {self.global_step}")
        
        # Initialize tokenizer
        tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
        tokenizer.add_special_tokens({'pad_token': '[PAD]'})
        
        # Create dataset with caching
        try:
            dataset = MultiDatasetLoader.create_dataset(
                dataset_key,
                tokenizer,
                max_samples=max_samples,
                max_len=max_len,
                streaming=True,
                cache_dir=cache_dir or KAGGLE_CACHE_DIR,
                use_disk_cache=True
            )
        except Exception as e:
            print(f"❌ Error loading dataset: {e}")
            return None
        
        # Create dataloader
        dataloader = DataLoader(
            dataset, 
            batch_size=batch_size, 
            num_workers=0,
            pin_memory=True if self.device == 'cuda' else False
        )
        
        # Setup learning rate scheduler
        total_steps = (max_samples // (batch_size * gradient_accumulation_steps)) * num_epochs
        self.scheduler = self._get_linear_schedule_with_warmup(total_steps)
        
        # Training history
        training_start = time.time()
        training_history = {
            'model_name': self.model_name,
            'dataset': dataset_key,
            'epochs': num_epochs,
            'batch_size': batch_size,
            'gradient_accumulation_steps': gradient_accumulation_steps,
            'losses': [],
            'epoch_losses': [],
            'checkpoints': []
        }
        
        self.model.train()
        accumulated_loss = 0
        
        for epoch in range(start_epoch, num_epochs):
            epoch_loss = 0.0
            batch_count = 0
            
            progress_bar = tqdm(
                dataloader, 
                desc=f"Epoch {epoch+1}/{num_epochs}",
                total=max_samples // batch_size
            )
            
            for batch_idx, (inputs, targets) in enumerate(progress_bar):
                inputs = inputs.to(self.device)
                targets = targets.to(self.device)
                
                # Mixed precision training
                if self.mixed_precision:
                    with torch.cuda.amp.autocast():
                        logits = self.model(inputs)
                        loss = self.loss_fn(logits.view(-1, logits.size(-1)), targets.view(-1))
                        loss = loss / gradient_accumulation_steps
                    
                    self.scaler.scale(loss).backward()
                else:
                    logits = self.model(inputs)
                    loss = self.loss_fn(logits.view(-1, logits.size(-1)), targets.view(-1))
                    loss = loss / gradient_accumulation_steps
                    loss.backward()
                
                accumulated_loss += loss.item()
                
                # Gradient accumulation
                if (batch_idx + 1) % gradient_accumulation_steps == 0:
                    if self.mixed_precision:
                        self.scaler.unscale_(self.optimizer)
                        torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
                        self.scaler.step(self.optimizer)
                        self.scaler.update()
                    else:
                        torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
                        self.optimizer.step()
                    
                    if self.scheduler:
                        self.scheduler.step()
                    
                    self.optimizer.zero_grad()
                    self.global_step += 1
                    
                    # Update metrics
                    current_loss = accumulated_loss * gradient_accumulation_steps
                    epoch_loss += current_loss
                    batch_count += 1
                    accumulated_loss = 0
                    
                    # Update progress bar
                    current_lr = self.optimizer.param_groups[0]['lr']
                    progress_bar.set_postfix({
                        'loss': f'{current_loss:.4f}',
                        'lr': f'{current_lr:.2e}',
                        'step': self.global_step
                    })
                    
                    # Logging
                    if self.global_step % log_interval == 0:
                        training_history['losses'].append({
                            'epoch': epoch + 1,
                            'step': self.global_step,
                            'loss': current_loss,
                            'lr': current_lr
                        })
                    
                    # Save checkpoint
                    if self.global_step % checkpoint_interval == 0:
                        checkpoint_path = self._save_checkpoint(
                            checkpoint_dir, epoch, batch_idx
                        )
                        training_history['checkpoints'].append(checkpoint_path)
                        
                        # Monitor resources
                        if self.global_step % (checkpoint_interval * 2) == 0:
                            monitor_resources()
                            # Clear cache if needed
                            if self.device == 'cuda':
                                torch.cuda.empty_cache()
                            gc.collect()
                
                # Early stopping if we've processed enough samples
                if (batch_idx + 1) * batch_size >= max_samples:
                    break
            
            # Epoch statistics
            avg_epoch_loss = epoch_loss / max(batch_count, 1)
            training_history['epoch_losses'].append(avg_epoch_loss)
            print(f"✅ Epoch {epoch+1} completed. Average loss: {avg_epoch_loss:.4f}")
            
            # Save epoch checkpoint
            epoch_checkpoint = os.path.join(
                save_dir, f"{self.model_name}_epoch_{epoch+1}.pt"
            )
            self._save_model(epoch_checkpoint)
            
            # Evaluation
            if eval_interval and (epoch + 1) % eval_interval == 0:
                self._evaluate(dataset_key, tokenizer, eval_samples, max_len)
        
        # Training completion
        training_duration = time.time() - training_start
        training_history['duration'] = training_duration
        
        # Save final model
        final_path = os.path.join(save_dir, f"{self.model_name}_final.pt")
        self._save_model(final_path)
        
        # Save training history
        history_path = os.path.join(save_dir, f"{self.model_name}_history.json")
        with open(history_path, 'w') as f:
            json.dump(training_history, f, indent=2)
        
        print(f"\n🎉 Training completed in {training_duration/60:.2f} minutes")
        print(f"💾 Final model saved to: {final_path}")
        print(f"📊 Training history saved to: {history_path}")
        
        # Final resource check
        monitor_resources()
        
        return training_history
    
    def _save_model(self, path):
        """Save model with proper handling for DataParallel."""
        model_to_save = self.model.module if hasattr(self.model, 'module') else self.model
        model_to_save.save_model(path)
    
    def _save_checkpoint(self, checkpoint_dir, epoch, batch_idx):
        """Save training checkpoint."""
        checkpoint_path = os.path.join(
            checkpoint_dir, 
            f"checkpoint_epoch_{epoch+1}_step_{self.global_step}.pt"
        )
        
        model_state = self.model.module.state_dict() if hasattr(self.model, 'module') else self.model.state_dict()
        
        checkpoint = {
            'epoch': epoch,
            'batch_idx': batch_idx,
            'global_step': self.global_step,
            'model_state_dict': model_state,
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict() if self.scheduler else None,
            'best_loss': self.best_loss,
            'model_config': {
                'vocab_size': self.model.module.vocab_size if hasattr(self.model, 'module') else self.model.vocab_size,
                'emb_size': self.model.module.emb_size if hasattr(self.model, 'module') else self.model.emb_size,
                'n_layers': self.model.module.n_layers if hasattr(self.model, 'module') else self.model.n_layers,
                'n_heads': self.model.module.n_heads if hasattr(self.model, 'module') else self.model.n_heads,
                'max_len': self.model.module.max_len if hasattr(self.model, 'module') else self.model.max_len,
            }
        }
        
        torch.save(checkpoint, checkpoint_path)
        print(f"💾 Checkpoint saved: {checkpoint_path}")
        
        # Clean old checkpoints to save space (keep only last 3)
        self._cleanup_old_checkpoints(checkpoint_dir, keep_last=3)
        
        return checkpoint_path
    
    def _load_checkpoint(self, checkpoint_path):
        """Load training checkpoint."""
        checkpoint = torch.load(checkpoint_path, map_location=self.device)
        
        if hasattr(self.model, 'module'):
            self.model.module.load_state_dict(checkpoint['model_state_dict'])
        else:
            self.model.load_state_dict(checkpoint['model_state_dict'])
        
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        
        if checkpoint.get('scheduler_state_dict') and self.scheduler:
            self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        
        self.global_step = checkpoint['global_step']
        self.best_loss = checkpoint.get('best_loss', float('inf'))
        
        return checkpoint['epoch'], checkpoint['global_step']
    
    def _cleanup_old_checkpoints(self, checkpoint_dir, keep_last=3):
        """Remove old checkpoints to save disk space."""
        checkpoints = sorted(
            [f for f in os.listdir(checkpoint_dir) if f.startswith('checkpoint_')],
            key=lambda x: os.path.getmtime(os.path.join(checkpoint_dir, x))
        )
        
        if len(checkpoints) > keep_last:
            for checkpoint in checkpoints[:-keep_last]:
                os.remove(os.path.join(checkpoint_dir, checkpoint))
                print(f"🗑️ Removed old checkpoint: {checkpoint}")
    
    def _get_linear_schedule_with_warmup(self, total_steps):
        """Create linear learning rate schedule with warmup."""
        def lr_lambda(step):
            if step < self.warmup_steps:
                return float(step) / float(max(1, self.warmup_steps))
            return max(
                0.0, float(total_steps - step) / float(max(1, total_steps - self.warmup_steps))
            )
        
        return torch.optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda)
    
    def _evaluate(self, dataset_key, tokenizer, eval_samples=1000, max_len=512):
        """Evaluate model during training."""
        print(f"\n📊 Evaluating on {eval_samples} samples...")
        self.model.eval()
        
        # Create evaluation dataset
        eval_dataset = MultiDatasetLoader.create_dataset(
            dataset_key,
            tokenizer,
            max_samples=eval_samples,
            max_len=max_len,
            split='validation' if 'validation' in dataset_key else 'train',
            streaming=True
        )
        
        eval_loader = DataLoader(eval_dataset, batch_size=8, num_workers=0)
        
        total_loss = 0
        batch_count = 0
        
        with torch.no_grad():
            for inputs, targets in tqdm(eval_loader, desc="Evaluating", total=eval_samples//8):
                inputs = inputs.to(self.device)
                targets = targets.to(self.device)
                
                logits = self.model(inputs)
                loss = self.loss_fn(logits.view(-1, logits.size(-1)), targets.view(-1))
                
                total_loss += loss.item()
                batch_count += 1
                
                if batch_count * 8 >= eval_samples:
                    break
        
        avg_loss = total_loss / max(batch_count, 1)
        print(f"📈 Evaluation loss: {avg_loss:.4f}")
        
        self.model.train()
        return avg_loss

print("✅ Advanced training engine with gradient accumulation defined")
print("   Features: Gradient accumulation, checkpoint resuming, mixed precision, multi-GPU support")

✅ Advanced training engine with gradient accumulation defined
   Features: Gradient accumulation, checkpoint resuming, mixed precision, multi-GPU support


## 💬 Inference Engine

Conversational AI with context management

In [7]:
class ConversationalInference:
    def __init__(self, model, tokenizer_name='gpt2', device=None, max_context_length=5):
        self.model = model
        self.device = device or ('cuda' if torch.cuda.is_available() else 'cpu')
        self.model.to(self.device)
        self.model.eval()

        self.tokenizer = GPT2Tokenizer.from_pretrained(tokenizer_name)
        self.tokenizer.add_special_tokens({'pad_token': '[PAD]'})

        self.max_context_length = max_context_length
        self.conversations = {}

        print(f"💬 Inference engine initialized on {self.device}")

    def create_session(self):
        session_id = str(uuid.uuid4())
        self.conversations[session_id] = []
        return session_id

    def get_conversation_history(self, session_id):
        return self.conversations.get(session_id, [])

    def clear_session(self, session_id):
        if session_id in self.conversations:
            self.conversations[session_id] = []

    def generate_response(self, prompt, session_id=None, max_new_tokens=150,
                         temperature=0.8, top_k=50, top_p=0.95,
                         use_context=True):
        if session_id is None:
            session_id = self.create_session()

        if session_id not in self.conversations:
            self.conversations[session_id] = []

        if use_context and self.conversations[session_id]:
            context_messages = self.conversations[session_id][-self.max_context_length:]
            context_text = self._build_context(context_messages)
            full_prompt = f"{context_text}\nUser: {prompt}\nAssistant:"
        else:
            full_prompt = f"User: {prompt}\nAssistant:"

        response = self._generate_text(
            full_prompt,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p
        )

        self.conversations[session_id].append({
            'role': 'user',
            'content': prompt
        })
        self.conversations[session_id].append({
            'role': 'assistant',
            'content': response
        })

        return response, session_id

    def _build_context(self, messages):
        context_parts = []
        for msg in messages:
            if msg['role'] == 'user':
                context_parts.append(f"User: {msg['content']}")
            else:
                context_parts.append(f"Assistant: {msg['content']}")
        return "\n".join(context_parts)

    def _generate_text(self, prompt, max_new_tokens=150, temperature=0.8,
                       top_k=50, top_p=0.95):
        tokens = self.tokenizer.encode(prompt, return_tensors='pt').to(self.device)

        if tokens.size(1) > self.model.max_len - max_new_tokens:
            tokens = tokens[:, -(self.model.max_len - max_new_tokens):]

        generated = tokens

        with torch.no_grad():
            for _ in range(max_new_tokens):
                if generated.size(1) >= self.model.max_len:
                    break

                logits = self.model(generated)
                next_token_logits = logits[:, -1, :]

                next_token_logits = next_token_logits / temperature

                sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
                cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)

                sorted_indices_to_remove = cumulative_probs > top_p
                sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                sorted_indices_to_remove[..., 0] = 0

                indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
                next_token_logits[indices_to_remove] = float('-inf')

                top_probs, top_indices = torch.topk(torch.softmax(next_token_logits, dim=-1), top_k)
                next_token = top_indices[0, torch.multinomial(top_probs[0], 1)]

                generated = torch.cat([generated, next_token.unsqueeze(0)], dim=1)


                if next_token.item() == self.tokenizer.eos_token_id:
                    break

        output_text = self.tokenizer.decode(generated[0], skip_special_tokens=True)

        if "Assistant:" in output_text:
            response = output_text.split("Assistant:")[-1].strip()
        else:
            response = output_text[len(prompt):].strip()

        return response

    def chat(self, message, session_id=None, **kwargs):
        return self.generate_response(message, session_id=session_id, **kwargs)


class OfflineInference:
    def __init__(self, model_path, device=None):
        self.device = device or ('cuda' if torch.cuda.is_available() else 'cpu')

        print(f"📂 Loading model from: {model_path}")
        checkpoint = torch.load(model_path, map_location=self.device)

        self.model = AdvancedLLM(
            vocab_size=checkpoint['vocab_size'],
            emb_size=checkpoint['emb_size'],
            n_layers=checkpoint['n_layers'],
            n_heads=checkpoint['n_heads'],
            max_len=checkpoint['max_len'],
            use_gradient_checkpoint=checkpoint.get('use_gradient_checkpoint', False)
        )
        self.model.load_state_dict(checkpoint['state_dict'])
        self.model.to(self.device)

        tokenizer_name = checkpoint.get('tokenizer_name', 'gpt2')
        self.inference_engine = ConversationalInference(
            self.model,
            tokenizer_name=tokenizer_name,
            device=self.device
        )

        print("✅ Offline inference ready!")

    def chat(self, message, session_id=None, **kwargs):
        return self.inference_engine.chat(message, session_id=session_id, **kwargs)

    def create_session(self):
        return self.inference_engine.create_session()

    def clear_session(self, session_id):
        self.inference_engine.clear_session(session_id)

    def get_history(self, session_id):
        return self.inference_engine.get_conversation_history(session_id)

print("✅ Inference engine defined")

✅ Inference engine defined


## 🚀 Quick Training (1M Samples)

Optimized for Kaggle's 50GB SSD - Quick training with 1 million samples

In [8]:
# Quick training configuration - 1M samples
print("🚀 Quick Training Configuration")
print("="*60)

# Monitor resources before training
disk_info, mem_info = monitor_resources()

# Select configuration
config = TRAINING_CONFIGS['quick']
print(f"\n📋 Using configuration: {config['description']}")
print(f"   Max samples: {config['max_samples']:,}")
print(f"   Batch size: {config['batch_size']}")
print(f"   Gradient accumulation: {config['gradient_accumulation_steps']}")
print(f"   Effective batch size: {config['batch_size'] * config['gradient_accumulation_steps']}")

# Initialize trainer with memory optimization
quick_trainer = LLMTrainer(
    model_name='cyberdyne-quick-1M',
    emb_size=512,
    n_layers=6,
    n_heads=8,
    learning_rate=5e-5,
    warmup_steps=500,
    use_gradient_checkpoint=True,  # Enable gradient checkpointing
    mixed_precision=True  # Enable mixed precision
)

# Estimate memory usage
# Handle DataParallel wrapper if present
model_for_estimate = quick_trainer.model.module if hasattr(quick_trainer.model, 'module') else quick_trainer.model
memory_estimate = model_for_estimate.estimate_memory_usage(
    batch_size=config['batch_size'],
    seq_length=512
)
print(f"\n💾 Estimated memory usage:")
print(f"   Model parameters: {memory_estimate['params_gb']:.2f}GB")
print(f"   Activations: {memory_estimate['activations_gb']:.2f}GB")
print(f"   Total: {memory_estimate['total_gb']:.2f}GB")

# Train on wikitext dataset with 1M samples
print("\n🎓 Starting quick training on wikitext dataset...")
history = quick_trainer.train(
    dataset_key='wikitext',
    num_epochs=2,
    batch_size=config['batch_size'],
    gradient_accumulation_steps=config['gradient_accumulation_steps'],
    max_samples=config['max_samples'],
    checkpoint_interval=config['checkpoint_interval'],
    save_dir=KAGGLE_MODEL_DIR,
    checkpoint_dir=KAGGLE_CHECKPOINT_DIR,
    cache_dir=KAGGLE_CACHE_DIR,
    log_interval=50,
    eval_interval=1  # Evaluate after each epoch
)

if history:
    print("\n" + "="*60)
    print("📊 Quick Training Summary")
    print("="*60)
    print(f"Model: {history['model_name']}")
    print(f"Dataset: {history['dataset']}")
    print(f"Epochs: {history['epochs']}")
    print(f"Samples processed: {config['max_samples']:,}")
    print(f"Final Loss: {history['epoch_losses'][-1]:.4f}")
    print(f"Duration: {history['duration']/60:.2f} minutes")
    print(f"Training speed: {config['max_samples'] / (history['duration']/60):.0f} samples/minute")
    print("="*60)
    
    # Check disk usage after training
    print("\n📊 Post-training resource check:")
    monitor_resources()

🚀 Quick Training Configuration

📊 Resource Monitor
💾 Disk: 19.5GB free of 19.5GB (0.0% used)
🧠 RAM: 29.7GB free of 31.4GB (5.4% used)
🎮 GPU: 0.0GB used, 0.0GB cached of 15.8GB


📋 Using configuration: Quick training with 1M samples
   Max samples: 237,560
   Batch size: 8
   Gradient accumulation: 4
   Effective batch size: 32


tokenizer_config.json:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]

🚀 Using 2 GPUs for training
🤖 Model initialized with 45,009,408 parameters
🖥️  Using device: cuda
💾 Gradient checkpointing enabled for memory efficiency
⚡ Mixed precision training enabled

💾 Estimated memory usage:
   Model parameters: 0.18GB
   Activations: 0.05GB
   Total: 0.41GB

🎓 Starting quick training on wikitext dataset...

🎓 Starting training on dataset: wikitext
   Epochs: 2, Batch size: 8
   Gradient accumulation: 4
   Effective batch size: 64
   Max samples: 237,560
💾 Using cache directory: /kaggle/working/cache


README.md: 0.00B [00:00, ?B/s]

✅ Loaded dataset: wikitext


  with torch.cuda.amp.autocast():
  return fn(*args, **kwargs)
Epoch 1/2:   7%|▋         | 2000/29695 [15:55<5:53:28,  1.31it/s, loss=7.5793, lr=5.00e-05, step=500] 

💾 Checkpoint saved: /kaggle/working/checkpoints/checkpoint_epoch_1_step_500.pt


Epoch 1/2:  13%|█▎        | 3999/29695 [31:44<3:22:41,  2.11it/s, loss=4.2031, lr=4.83e-05, step=1000]

💾 Checkpoint saved: /kaggle/working/checkpoints/checkpoint_epoch_1_step_1000.pt

📊 Resource Monitor
💾 Disk: 18.5GB free of 19.5GB (5.2% used)
🧠 RAM: 26.8GB free of 31.4GB (14.4% used)
🎮 GPU: 1.0GB used, 4.0GB cached of 15.8GB



Epoch 1/2:  20%|██        | 6000/29695 [47:33<5:00:38,  1.31it/s, loss=1.4510, lr=4.65e-05, step=1500]

💾 Checkpoint saved: /kaggle/working/checkpoints/checkpoint_epoch_1_step_1500.pt


Epoch 1/2:  27%|██▋       | 7999/29695 [1:03:21<2:50:23,  2.12it/s, loss=3.3327, lr=4.48e-05, step=2000] 

💾 Checkpoint saved: /kaggle/working/checkpoints/checkpoint_epoch_1_step_2000.pt
🗑️ Removed old checkpoint: checkpoint_epoch_1_step_500.pt

📊 Resource Monitor
💾 Disk: 18.0GB free of 19.5GB (7.7% used)
🧠 RAM: 26.8GB free of 31.4GB (14.5% used)
🎮 GPU: 1.0GB used, 4.4GB cached of 15.8GB



Epoch 1/2:  34%|███▎      | 10000/29695 [1:19:11<4:23:33,  1.25it/s, loss=6.3364, lr=4.30e-05, step=2500]

💾 Checkpoint saved: /kaggle/working/checkpoints/checkpoint_epoch_1_step_2500.pt
🗑️ Removed old checkpoint: checkpoint_epoch_1_step_1000.pt


Epoch 1/2:  40%|████      | 11999/29695 [1:34:58<2:19:40,  2.11it/s, loss=4.9071, lr=4.13e-05, step=3000] 

💾 Checkpoint saved: /kaggle/working/checkpoints/checkpoint_epoch_1_step_3000.pt
🗑️ Removed old checkpoint: checkpoint_epoch_1_step_1500.pt

📊 Resource Monitor
💾 Disk: 18.0GB free of 19.5GB (7.7% used)
🧠 RAM: 26.7GB free of 31.4GB (14.9% used)
🎮 GPU: 1.0GB used, 4.4GB cached of 15.8GB



Epoch 1/2:  47%|████▋     | 14000/29695 [1:50:47<3:29:35,  1.25it/s, loss=5.0418, lr=3.95e-05, step=3500] 

💾 Checkpoint saved: /kaggle/working/checkpoints/checkpoint_epoch_1_step_3500.pt
🗑️ Removed old checkpoint: checkpoint_epoch_1_step_2000.pt


Epoch 1/2:  54%|█████▍    | 15999/29695 [2:06:34<1:48:00,  2.11it/s, loss=4.0013, lr=3.78e-05, step=4000]

💾 Checkpoint saved: /kaggle/working/checkpoints/checkpoint_epoch_1_step_4000.pt
🗑️ Removed old checkpoint: checkpoint_epoch_1_step_2500.pt

📊 Resource Monitor
💾 Disk: 18.0GB free of 19.5GB (7.7% used)
🧠 RAM: 26.7GB free of 31.4GB (14.9% used)
🎮 GPU: 1.0GB used, 4.4GB cached of 15.8GB



Epoch 1/2:  61%|██████    | 18000/29695 [2:22:24<2:34:42,  1.26it/s, loss=4.7974, lr=3.61e-05, step=4500]

💾 Checkpoint saved: /kaggle/working/checkpoints/checkpoint_epoch_1_step_4500.pt
🗑️ Removed old checkpoint: checkpoint_epoch_1_step_3000.pt


Epoch 1/2:  65%|██████▍   | 19197/29695 [2:31:50<1:23:02,  2.11it/s, loss=3.5425, lr=3.50e-05, step=4799]


✅ Epoch 1 completed. Average loss: 5.4589
✅ Model saved to /kaggle/working/models/cyberdyne-quick-1M_epoch_1.pt (0.18GB)

📊 Evaluating on 1000 samples...
💾 Using cache directory: /kaggle/working/cache
✅ Loaded dataset: wikitext


Evaluating:  65%|██████▍   | 81/125 [00:13<00:07,  6.12it/s]


📈 Evaluation loss: 1.0564


Epoch 2/2:   3%|▎         | 803/29695 [06:21<3:46:59,  2.12it/s, loss=2.7456, lr=3.43e-05, step=5000] 

💾 Checkpoint saved: /kaggle/working/checkpoints/checkpoint_epoch_2_step_5000.pt
🗑️ Removed old checkpoint: checkpoint_epoch_1_step_3500.pt

📊 Resource Monitor
💾 Disk: 17.8GB free of 19.5GB (8.6% used)
🧠 RAM: 26.4GB free of 31.4GB (15.8% used)
🎮 GPU: 1.0GB used, 4.4GB cached of 15.8GB



Epoch 2/2:   9%|▉         | 2804/29695 [22:10<5:59:35,  1.25it/s, loss=4.3997, lr=3.26e-05, step=5500]

💾 Checkpoint saved: /kaggle/working/checkpoints/checkpoint_epoch_2_step_5500.pt
🗑️ Removed old checkpoint: checkpoint_epoch_1_step_4000.pt


Epoch 2/2:  16%|█▌        | 4803/29695 [37:57<3:17:47,  2.10it/s, loss=7.2639, lr=3.08e-05, step=6000]

💾 Checkpoint saved: /kaggle/working/checkpoints/checkpoint_epoch_2_step_6000.pt
🗑️ Removed old checkpoint: checkpoint_epoch_1_step_4500.pt

📊 Resource Monitor
💾 Disk: 17.8GB free of 19.5GB (8.6% used)
🧠 RAM: 26.4GB free of 31.4GB (15.8% used)
🎮 GPU: 1.0GB used, 4.4GB cached of 15.8GB



Epoch 2/2:  23%|██▎       | 6804/29695 [53:46<5:12:09,  1.22it/s, loss=4.3983, lr=2.91e-05, step=6500]

💾 Checkpoint saved: /kaggle/working/checkpoints/checkpoint_epoch_2_step_6500.pt
🗑️ Removed old checkpoint: checkpoint_epoch_2_step_5000.pt


Epoch 2/2:  30%|██▉       | 8803/29695 [1:09:33<2:44:42,  2.11it/s, loss=3.9933, lr=2.73e-05, step=7000]

💾 Checkpoint saved: /kaggle/working/checkpoints/checkpoint_epoch_2_step_7000.pt
🗑️ Removed old checkpoint: checkpoint_epoch_2_step_5500.pt

📊 Resource Monitor
💾 Disk: 17.8GB free of 19.5GB (8.6% used)
🧠 RAM: 26.4GB free of 31.4GB (15.8% used)
🎮 GPU: 1.0GB used, 4.4GB cached of 15.8GB



Epoch 2/2:  36%|███▋      | 10804/29695 [1:25:22<4:13:12,  1.24it/s, loss=2.5392, lr=2.56e-05, step=7500] 

💾 Checkpoint saved: /kaggle/working/checkpoints/checkpoint_epoch_2_step_7500.pt
🗑️ Removed old checkpoint: checkpoint_epoch_2_step_6000.pt


Epoch 2/2:  43%|████▎     | 12803/29695 [1:41:09<2:12:34,  2.12it/s, loss=2.3715, lr=2.39e-05, step=8000]

💾 Checkpoint saved: /kaggle/working/checkpoints/checkpoint_epoch_2_step_8000.pt
🗑️ Removed old checkpoint: checkpoint_epoch_2_step_6500.pt

📊 Resource Monitor
💾 Disk: 17.8GB free of 19.5GB (8.6% used)
🧠 RAM: 26.4GB free of 31.4GB (15.9% used)
🎮 GPU: 1.0GB used, 4.4GB cached of 15.8GB



Epoch 2/2:  50%|████▉     | 14804/29695 [1:56:58<3:24:30,  1.21it/s, loss=2.7847, lr=2.21e-05, step=8500]

💾 Checkpoint saved: /kaggle/working/checkpoints/checkpoint_epoch_2_step_8500.pt
🗑️ Removed old checkpoint: checkpoint_epoch_2_step_7000.pt


Epoch 2/2:  57%|█████▋    | 16803/29695 [2:12:45<1:43:02,  2.09it/s, loss=6.4116, lr=2.04e-05, step=9000]

💾 Checkpoint saved: /kaggle/working/checkpoints/checkpoint_epoch_2_step_9000.pt
🗑️ Removed old checkpoint: checkpoint_epoch_2_step_7500.pt

📊 Resource Monitor
💾 Disk: 17.8GB free of 19.5GB (8.6% used)
🧠 RAM: 26.4GB free of 31.4GB (15.8% used)
🎮 GPU: 1.0GB used, 4.4GB cached of 15.8GB



Epoch 2/2:  63%|██████▎   | 18804/29695 [2:28:34<2:26:02,  1.24it/s, loss=2.1818, lr=1.86e-05, step=9500]

💾 Checkpoint saved: /kaggle/working/checkpoints/checkpoint_epoch_2_step_9500.pt
🗑️ Removed old checkpoint: checkpoint_epoch_2_step_8000.pt


Epoch 2/2:  65%|██████▍   | 19197/29695 [2:31:40<1:22:56,  2.11it/s, loss=3.3948, lr=1.83e-05, step=9598]


✅ Epoch 2 completed. Average loss: 4.3524
✅ Model saved to /kaggle/working/models/cyberdyne-quick-1M_epoch_2.pt (0.18GB)

📊 Evaluating on 1000 samples...
💾 Using cache directory: /kaggle/working/cache
✅ Loaded dataset: wikitext


Evaluating:  65%|██████▍   | 81/125 [00:13<00:07,  6.22it/s]


📈 Evaluation loss: 1.0057
✅ Model saved to /kaggle/working/models/cyberdyne-quick-1M_final.pt (0.18GB)

🎉 Training completed in 304.00 minutes
💾 Final model saved to: /kaggle/working/models/cyberdyne-quick-1M_final.pt
📊 Training history saved to: /kaggle/working/models/cyberdyne-quick-1M_history.json

📊 Resource Monitor
💾 Disk: 17.5GB free of 19.5GB (10.3% used)
🧠 RAM: 26.3GB free of 31.4GB (16.0% used)
🎮 GPU: 0.9GB used, 4.4GB cached of 15.8GB


📊 Quick Training Summary
Model: cyberdyne-quick-1M
Dataset: wikitext
Epochs: 2
Samples processed: 237,560
Final Loss: 4.3524
Duration: 304.00 minutes
Training speed: 781 samples/minute

📊 Post-training resource check:

📊 Resource Monitor
💾 Disk: 17.5GB free of 19.5GB (10.3% used)
🧠 RAM: 26.3GB free of 31.4GB (16.0% used)
🎮 GPU: 0.7GB used, 4.4GB cached of 15.8GB



## 💪 Standard Training (5M Samples)

Medium-scale training with 5 million samples

In [9]:
# Standard training configuration - 5M samples
print("💪 Standard Training Configuration")
print("="*60)

# Monitor resources before training
disk_info, mem_info = monitor_resources()

# Select configuration
config = TRAINING_CONFIGS['standard']
print(f"\n📋 Using configuration: {config['description']}")
print(f"   Max samples: {config['max_samples']:,}")
print(f"   Batch size: {config['batch_size']}")
print(f"   Gradient accumulation: {config['gradient_accumulation_steps']}")

# Initialize trainer
standard_trainer = LLMTrainer(
    model_name='cyberdyne-standard-5M',
    emb_size=768,
    n_layers=8,
    n_heads=12,
    ff_size=3072,
    learning_rate=3e-5,
    warmup_steps=1000,
    use_gradient_checkpoint=True,
    mixed_precision=True
)

# Train on dolly instruction dataset
print("\n🎓 Starting standard training on dolly dataset...")
history = standard_trainer.train(
    dataset_key='dolly',
    num_epochs=3,
    batch_size=config['batch_size'],
    gradient_accumulation_steps=config['gradient_accumulation_steps'],
    max_samples=config['max_samples'],
    checkpoint_interval=config['checkpoint_interval'],
    save_dir=KAGGLE_MODEL_DIR,
    checkpoint_dir=KAGGLE_CHECKPOINT_DIR,
    cache_dir=KAGGLE_CACHE_DIR,
    log_interval=100,
    resume_from_checkpoint=None  # Set to checkpoint path to resume
)

if history:
    print("\n" + "="*60)
    print("📊 Standard Training Summary")
    print("="*60)
    print(f"Model: {history['model_name']}")
    print(f"Dataset: {history['dataset']}")
    print(f"Duration: {history['duration']/3600:.2f} hours")
    print("="*60)

💪 Standard Training Configuration

📊 Resource Monitor
💾 Disk: 17.5GB free of 19.5GB (10.3% used)
🧠 RAM: 26.3GB free of 31.4GB (16.0% used)
🎮 GPU: 0.7GB used, 4.4GB cached of 15.8GB


📋 Using configuration: Standard training with 5M samples
   Max samples: 1,187,804
   Batch size: 8
   Gradient accumulation: 8
🚀 Using 2 GPUs for training
🤖 Model initialized with 95,240,448 parameters
🖥️  Using device: cuda
💾 Gradient checkpointing enabled for memory efficiency
⚡ Mixed precision training enabled

🎓 Starting standard training on dolly dataset...

🎓 Starting training on dataset: dolly
   Epochs: 3, Batch size: 8
   Gradient accumulation: 8
   Effective batch size: 128
   Max samples: 1,187,804
💾 Using cache directory: /kaggle/working/cache


README.md: 0.00B [00:00, ?B/s]

✅ Loaded dataset: databricks/databricks-dolly-15k


  with torch.cuda.amp.autocast():
Epoch 1/3:   1%|▏         | 1877/148475 [33:17<43:19:46,  1.06s/it, loss=14.4490, lr=7.02e-06, step=234]


✅ Epoch 1 completed. Average loss: 31.4455
✅ Model saved to /kaggle/working/models/cyberdyne-standard-5M_epoch_1.pt (0.38GB)


Epoch 2/3:   1%|▏         | 1877/148475 [33:14<43:16:30,  1.06s/it, loss=11.4922, lr=1.40e-05, step=468]


✅ Epoch 2 completed. Average loss: 12.5462
✅ Model saved to /kaggle/working/models/cyberdyne-standard-5M_epoch_2.pt (0.38GB)


Epoch 3/3:   1%|▏         | 1877/148475 [33:14<43:15:44,  1.06s/it, loss=10.3081, lr=2.11e-05, step=702]


✅ Epoch 3 completed. Average loss: 10.4863
✅ Model saved to /kaggle/working/models/cyberdyne-standard-5M_epoch_3.pt (0.38GB)
✅ Model saved to /kaggle/working/models/cyberdyne-standard-5M_final.pt (0.38GB)

🎉 Training completed in 99.80 minutes
💾 Final model saved to: /kaggle/working/models/cyberdyne-standard-5M_final.pt
📊 Training history saved to: /kaggle/working/models/cyberdyne-standard-5M_history.json

📊 Resource Monitor
💾 Disk: 16.1GB free of 19.5GB (17.6% used)
🧠 RAM: 26.2GB free of 31.4GB (16.4% used)
🎮 GPU: 2.4GB used, 6.0GB cached of 15.8GB


📊 Standard Training Summary
Model: cyberdyne-standard-5M
Dataset: dolly
Duration: 1.66 hours


## 🔥 Large-Scale Pile Training (20M+ Samples)

Maximum scale training using the Pile dataset

In [10]:
# Large-scale Pile training - Maximum samples that fit in 50GB
print("🔥 Large-Scale Pile Training Configuration")
print("="*60)

# Monitor resources
disk_info, mem_info = monitor_resources()

# Select configuration
config = TRAINING_CONFIGS['pile']
print(f"\n📋 Using configuration: {config['description']}")
print(f"   Max samples: {config['max_samples']:,}")
print(f"   This will use most of the available {DATA_CACHE_SIZE_GB:.1f}GB cache space")

# Initialize large model with maximum memory optimization
pile_trainer = LLMTrainer(
    model_name='cyberdyne-pile-large',
    emb_size=768,
    n_layers=12,
    n_heads=12,
    ff_size=3072,
    learning_rate=2e-5,
    weight_decay=0.01,
    warmup_steps=2000,
    use_gradient_checkpoint=True,  # Essential for large models
    mixed_precision=True  # Essential for speed and memory
)

# Pile dataset configuration
PILE_SUBSETS = [
    'pile_cc',           # Common Crawl
    'pile_pubmed',       # PubMed Central
    'pile_arxiv',        # ArXiv
    'pile_openwebtext2', # OpenWebText2
]

print(f"\n📚 Will train on Pile subsets: {', '.join(PILE_SUBSETS)}")
print(f"   Samples per subset: {config['max_samples'] // len(PILE_SUBSETS):,}")

# Create tokenizer
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
tokenizer.add_special_tokens({'pad_token': '[PAD]'})

# Option 1: Train on a single Pile subset
print("\n🎓 Starting large-scale training on Pile dataset...")
history = pile_trainer.train(
    dataset_key='pile_cc',  # Start with Common Crawl subset
    num_epochs=2,
    batch_size=config['batch_size'],
    gradient_accumulation_steps=config['gradient_accumulation_steps'],
    max_samples=config['max_samples'],
    checkpoint_interval=config['checkpoint_interval'],
    save_dir=KAGGLE_MODEL_DIR,
    checkpoint_dir=KAGGLE_CHECKPOINT_DIR,
    cache_dir=KAGGLE_CACHE_DIR,
    log_interval=200,
    resume_from_checkpoint=None  # Can resume from checkpoint if interrupted
)

# Option 2: Train on multiple Pile subsets (uncomment to use)
# datasets = MultiDatasetLoader.create_pile_dataset(
#     tokenizer=tokenizer,
#     subsets=PILE_SUBSETS,
#     max_samples_per_subset=config['max_samples'] // len(PILE_SUBSETS),
#     max_len=512,
#     cache_dir=KAGGLE_CACHE_DIR,
#     use_disk_cache=True
# )

if history:
    print("\n" + "="*60)
    print("📊 Large-Scale Pile Training Summary")
    print("="*60)
    print(f"Model: {history['model_name']}")
    print(f"Dataset: Pile")
    print(f"Samples processed: {config['max_samples']:,}")
    print(f"Duration: {history['duration']/3600:.2f} hours")
    print(f"Final Loss: {history['epoch_losses'][-1]:.4f}")
    print(f"Checkpoints saved: {len(history['checkpoints'])}")
    print("="*60)
    
    # Final resource check
    print("\n📊 Final resource status:")
    disk_info, mem_info = monitor_resources()
    
    # Cache usage
    cache_size = sum(
        os.path.getsize(os.path.join(dirpath, filename))
        for dirpath, dirnames, filenames in os.walk(KAGGLE_CACHE_DIR)
        for filename in filenames
    ) / (1024**3)
    print(f"📦 Cache size: {cache_size:.2f}GB")

🔥 Large-Scale Pile Training Configuration

📊 Resource Monitor
💾 Disk: 16.1GB free of 19.5GB (17.6% used)
🧠 RAM: 26.2GB free of 31.4GB (16.4% used)
🎮 GPU: 2.3GB used, 6.0GB cached of 15.8GB


📋 Using configuration: Pile dataset training with 4,751,216 samples
   Max samples: 4,751,216
   This will use most of the available 9.5GB cache space
🚀 Using 2 GPUs for training
🤖 Model initialized with 123,561,216 parameters
🖥️  Using device: cuda
💾 Gradient checkpointing enabled for memory efficiency
⚡ Mixed precision training enabled

📚 Will train on Pile subsets: pile_cc, pile_pubmed, pile_arxiv, pile_openwebtext2
   Samples per subset: 1,187,804

🎓 Starting large-scale training on Pile dataset...

🎓 Starting training on dataset: pile_cc
   Epochs: 2, Batch size: 4
   Gradient accumulation: 32
   Effective batch size: 256
   Max samples: 4,751,216
💾 Using cache directory: /kaggle/working/cache


README.md: 0.00B [00:00, ?B/s]

pile.py: 0.00B [00:00, ?B/s]

❌ Error loading dataset EleutherAI/pile: Dataset scripts are no longer supported, but found pile.py
❌ Error loading dataset: Dataset scripts are no longer supported, but found pile.py


## 🔄 Resume Training from Checkpoint

Continue interrupted training sessions

In [11]:
# Resume training from checkpoint
print("🔄 Resume Training from Checkpoint")
print("="*60)

# List available checkpoints
if os.path.exists(KAGGLE_CHECKPOINT_DIR):
    checkpoints = sorted([
        f for f in os.listdir(KAGGLE_CHECKPOINT_DIR) 
        if f.endswith('.pt')
    ])
    
    if checkpoints:
        print("\n📁 Available checkpoints:")
        for i, checkpoint in enumerate(checkpoints):
            checkpoint_path = os.path.join(KAGGLE_CHECKPOINT_DIR, checkpoint)
            size_gb = os.path.getsize(checkpoint_path) / (1024**3)
            print(f"  {i+1}. {checkpoint} ({size_gb:.2f}GB)")
        
        # Use the latest checkpoint
        latest_checkpoint = os.path.join(KAGGLE_CHECKPOINT_DIR, checkpoints[-1])
        print(f"\n✅ Will resume from: {latest_checkpoint}")
        
        # Resume training
        resume_trainer = LLMTrainer(
            model_name='cyberdyne-resumed',
            emb_size=768,
            n_layers=12,
            n_heads=12,
            learning_rate=2e-5,
            use_gradient_checkpoint=True,
            mixed_precision=True
        )
        
        # Continue training
        config = TRAINING_CONFIGS['standard']
        history = resume_trainer.train(
            dataset_key='pile_cc',
            num_epochs=3,  # Total epochs (will continue from checkpoint)
            batch_size=config['batch_size'],
            gradient_accumulation_steps=config['gradient_accumulation_steps'],
            max_samples=config['max_samples'],
            checkpoint_interval=config['checkpoint_interval'],
            save_dir=KAGGLE_MODEL_DIR,
            checkpoint_dir=KAGGLE_CHECKPOINT_DIR,
            cache_dir=KAGGLE_CACHE_DIR,
            resume_from_checkpoint=latest_checkpoint
        )
    else:
        print("❌ No checkpoints found. Train a model first.")
else:
    print("❌ Checkpoint directory doesn't exist. Train a model first.")

🔄 Resume Training from Checkpoint

📁 Available checkpoints:
  1. checkpoint_epoch_2_step_8500.pt (0.50GB)
  2. checkpoint_epoch_2_step_9000.pt (0.50GB)
  3. checkpoint_epoch_2_step_9500.pt (0.50GB)

✅ Will resume from: /kaggle/working/checkpoints/checkpoint_epoch_2_step_9500.pt
🚀 Using 2 GPUs for training
🤖 Model initialized with 123,561,216 parameters
🖥️  Using device: cuda
💾 Gradient checkpointing enabled for memory efficiency
⚡ Mixed precision training enabled

🎓 Starting training on dataset: pile_cc
   Epochs: 3, Batch size: 8
   Gradient accumulation: 8
   Effective batch size: 128
   Max samples: 1,187,804


RuntimeError: Error(s) in loading state_dict for AdvancedLLM:
	Missing key(s) in state_dict: "blocks.6.attn.qkv.weight", "blocks.6.attn.out.weight", "blocks.6.attn.out.bias", "blocks.6.ff.w1.weight", "blocks.6.ff.w2.weight", "blocks.6.ff.w3.weight", "blocks.6.ln1.weight", "blocks.6.ln2.weight", "blocks.7.attn.qkv.weight", "blocks.7.attn.out.weight", "blocks.7.attn.out.bias", "blocks.7.ff.w1.weight", "blocks.7.ff.w2.weight", "blocks.7.ff.w3.weight", "blocks.7.ln1.weight", "blocks.7.ln2.weight", "blocks.8.attn.qkv.weight", "blocks.8.attn.out.weight", "blocks.8.attn.out.bias", "blocks.8.ff.w1.weight", "blocks.8.ff.w2.weight", "blocks.8.ff.w3.weight", "blocks.8.ln1.weight", "blocks.8.ln2.weight", "blocks.9.attn.qkv.weight", "blocks.9.attn.out.weight", "blocks.9.attn.out.bias", "blocks.9.ff.w1.weight", "blocks.9.ff.w2.weight", "blocks.9.ff.w3.weight", "blocks.9.ln1.weight", "blocks.9.ln2.weight", "blocks.10.attn.qkv.weight", "blocks.10.attn.out.weight", "blocks.10.attn.out.bias", "blocks.10.ff.w1.weight", "blocks.10.ff.w2.weight", "blocks.10.ff.w3.weight", "blocks.10.ln1.weight", "blocks.10.ln2.weight", "blocks.11.attn.qkv.weight", "blocks.11.attn.out.weight", "blocks.11.attn.out.bias", "blocks.11.ff.w1.weight", "blocks.11.ff.w2.weight", "blocks.11.ff.w3.weight", "blocks.11.ln1.weight", "blocks.11.ln2.weight". 
	size mismatch for token_embed.weight: copying a param with shape torch.Size([50258, 512]) from checkpoint, the shape in current model is torch.Size([50258, 768]).
	size mismatch for blocks.0.attn.qkv.weight: copying a param with shape torch.Size([1536, 512]) from checkpoint, the shape in current model is torch.Size([2304, 768]).
	size mismatch for blocks.0.attn.out.weight: copying a param with shape torch.Size([512, 512]) from checkpoint, the shape in current model is torch.Size([768, 768]).
	size mismatch for blocks.0.attn.out.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for blocks.0.ff.w1.weight: copying a param with shape torch.Size([1408, 512]) from checkpoint, the shape in current model is torch.Size([2048, 768]).
	size mismatch for blocks.0.ff.w2.weight: copying a param with shape torch.Size([512, 1408]) from checkpoint, the shape in current model is torch.Size([768, 2048]).
	size mismatch for blocks.0.ff.w3.weight: copying a param with shape torch.Size([1408, 512]) from checkpoint, the shape in current model is torch.Size([2048, 768]).
	size mismatch for blocks.0.ln1.weight: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for blocks.0.ln2.weight: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for blocks.1.attn.qkv.weight: copying a param with shape torch.Size([1536, 512]) from checkpoint, the shape in current model is torch.Size([2304, 768]).
	size mismatch for blocks.1.attn.out.weight: copying a param with shape torch.Size([512, 512]) from checkpoint, the shape in current model is torch.Size([768, 768]).
	size mismatch for blocks.1.attn.out.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for blocks.1.ff.w1.weight: copying a param with shape torch.Size([1408, 512]) from checkpoint, the shape in current model is torch.Size([2048, 768]).
	size mismatch for blocks.1.ff.w2.weight: copying a param with shape torch.Size([512, 1408]) from checkpoint, the shape in current model is torch.Size([768, 2048]).
	size mismatch for blocks.1.ff.w3.weight: copying a param with shape torch.Size([1408, 512]) from checkpoint, the shape in current model is torch.Size([2048, 768]).
	size mismatch for blocks.1.ln1.weight: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for blocks.1.ln2.weight: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for blocks.2.attn.qkv.weight: copying a param with shape torch.Size([1536, 512]) from checkpoint, the shape in current model is torch.Size([2304, 768]).
	size mismatch for blocks.2.attn.out.weight: copying a param with shape torch.Size([512, 512]) from checkpoint, the shape in current model is torch.Size([768, 768]).
	size mismatch for blocks.2.attn.out.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for blocks.2.ff.w1.weight: copying a param with shape torch.Size([1408, 512]) from checkpoint, the shape in current model is torch.Size([2048, 768]).
	size mismatch for blocks.2.ff.w2.weight: copying a param with shape torch.Size([512, 1408]) from checkpoint, the shape in current model is torch.Size([768, 2048]).
	size mismatch for blocks.2.ff.w3.weight: copying a param with shape torch.Size([1408, 512]) from checkpoint, the shape in current model is torch.Size([2048, 768]).
	size mismatch for blocks.2.ln1.weight: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for blocks.2.ln2.weight: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for blocks.3.attn.qkv.weight: copying a param with shape torch.Size([1536, 512]) from checkpoint, the shape in current model is torch.Size([2304, 768]).
	size mismatch for blocks.3.attn.out.weight: copying a param with shape torch.Size([512, 512]) from checkpoint, the shape in current model is torch.Size([768, 768]).
	size mismatch for blocks.3.attn.out.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for blocks.3.ff.w1.weight: copying a param with shape torch.Size([1408, 512]) from checkpoint, the shape in current model is torch.Size([2048, 768]).
	size mismatch for blocks.3.ff.w2.weight: copying a param with shape torch.Size([512, 1408]) from checkpoint, the shape in current model is torch.Size([768, 2048]).
	size mismatch for blocks.3.ff.w3.weight: copying a param with shape torch.Size([1408, 512]) from checkpoint, the shape in current model is torch.Size([2048, 768]).
	size mismatch for blocks.3.ln1.weight: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for blocks.3.ln2.weight: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for blocks.4.attn.qkv.weight: copying a param with shape torch.Size([1536, 512]) from checkpoint, the shape in current model is torch.Size([2304, 768]).
	size mismatch for blocks.4.attn.out.weight: copying a param with shape torch.Size([512, 512]) from checkpoint, the shape in current model is torch.Size([768, 768]).
	size mismatch for blocks.4.attn.out.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for blocks.4.ff.w1.weight: copying a param with shape torch.Size([1408, 512]) from checkpoint, the shape in current model is torch.Size([2048, 768]).
	size mismatch for blocks.4.ff.w2.weight: copying a param with shape torch.Size([512, 1408]) from checkpoint, the shape in current model is torch.Size([768, 2048]).
	size mismatch for blocks.4.ff.w3.weight: copying a param with shape torch.Size([1408, 512]) from checkpoint, the shape in current model is torch.Size([2048, 768]).
	size mismatch for blocks.4.ln1.weight: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for blocks.4.ln2.weight: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for blocks.5.attn.qkv.weight: copying a param with shape torch.Size([1536, 512]) from checkpoint, the shape in current model is torch.Size([2304, 768]).
	size mismatch for blocks.5.attn.out.weight: copying a param with shape torch.Size([512, 512]) from checkpoint, the shape in current model is torch.Size([768, 768]).
	size mismatch for blocks.5.attn.out.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for blocks.5.ff.w1.weight: copying a param with shape torch.Size([1408, 512]) from checkpoint, the shape in current model is torch.Size([2048, 768]).
	size mismatch for blocks.5.ff.w2.weight: copying a param with shape torch.Size([512, 1408]) from checkpoint, the shape in current model is torch.Size([768, 2048]).
	size mismatch for blocks.5.ff.w3.weight: copying a param with shape torch.Size([1408, 512]) from checkpoint, the shape in current model is torch.Size([2048, 768]).
	size mismatch for blocks.5.ln1.weight: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for blocks.5.ln2.weight: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for ln_final.weight: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for head.weight: copying a param with shape torch.Size([50258, 512]) from checkpoint, the shape in current model is torch.Size([50258, 768]).

## 💬 Test Your Model

Interactive testing with your trained models

In [12]:
# Load and test trained model
print("💬 Model Testing Interface")
print("="*60)

# List available models
if os.path.exists(KAGGLE_MODEL_DIR):
    models = sorted([f for f in os.listdir(KAGGLE_MODEL_DIR) if f.endswith('_final.pt')])
    
    if models:
        print("\n📁 Available models:")
        for i, model in enumerate(models):
            model_path = os.path.join(KAGGLE_MODEL_DIR, model)
            size_gb = os.path.getsize(model_path) / (1024**3)
            print(f"  {i+1}. {model} ({size_gb:.2f}GB)")
        
        # Load the first model (or modify to choose)
        model_path = os.path.join(KAGGLE_MODEL_DIR, models[0])
        print(f"\n📂 Loading: {model_path}")
        
        inference = OfflineInference(model_path)
        session_id = inference.create_session()
        
        # Test with sample prompts
        test_prompts = [
            "What is artificial intelligence?",
            "Explain machine learning in simple terms.",
            "How does a neural network work?",
            "What are the benefits of deep learning?"
        ]
        
        print("\n" + "="*60)
        print("🧪 Testing Model Responses")
        print("="*60 + "\n")
        
        for prompt in test_prompts:
            print(f"👤 User: {prompt}")
            response, _ = inference.chat(
                prompt, 
                session_id=session_id,
                max_new_tokens=100,
                temperature=0.7
            )
            print(f"🤖 Model: {response}")
            print("-" * 40)
            
        # Interactive chat (uncomment to enable)
        # print("\n💬 Interactive Chat (type 'quit' to exit)")
        # print("="*60)
        # while True:
        #     user_input = input("\n👤 You: ")
        #     if user_input.lower() in ['quit', 'exit', 'q']:
        #         break
        #     response, _ = inference.chat(user_input, session_id=session_id)
        #     print(f"🤖 Model: {response}")
        
    else:
        print("❌ No trained models found. Please train a model first.")
else:
    print("❌ Model directory doesn't exist. Please train a model first.")

💬 Model Testing Interface

📁 Available models:
  1. cyberdyne-quick-1M_final.pt (0.17GB)
  2. cyberdyne-standard-5M_final.pt (0.35GB)

📂 Loading: /kaggle/working/models/cyberdyne-quick-1M_final.pt
📂 Loading model from: /kaggle/working/models/cyberdyne-quick-1M_final.pt
💬 Inference engine initialized on cuda
✅ Offline inference ready!

🧪 Testing Model Responses

👤 User: What is artificial intelligence?
🤖 Model: , and is then thought to be a high @-@ wing , in which the body is a few @-@ year @-@ old black @-@ down body . The entire <unk> is a major @-@ class body , with a single time that is a red @-@ of @-@ a @-@ hour and a @-@ point <unk> @-@ old @-@ old . The two @-@ based <unk> @-@
----------------------------------------
👤 User: Explain machine learning in simple terms.
🤖 Model: : the <unk> @-@ <unk> .
----------------------------------------
👤 User: How does a neural network work?
🤖 Model: for their own .
----------------------------------------
👤 User: What are the benefits of de

## 🧹 Cleanup and Optimization

Manage disk space and optimize cache

In [None]:
# Cleanup utilities
def cleanup_cache(cache_dir=KAGGLE_CACHE_DIR, keep_gb=10):
    """Clean up cache directory to free disk space."""
    print(f"🧹 Cleaning cache directory: {cache_dir}")
    
    if not os.path.exists(cache_dir):
        print("❌ Cache directory doesn't exist")
        return
    
    # Get current size
    total_size = 0
    file_list = []
    
    for dirpath, dirnames, filenames in os.walk(cache_dir):
        for filename in filenames:
            filepath = os.path.join(dirpath, filename)
            size = os.path.getsize(filepath)
            total_size += size
            file_list.append((filepath, size, os.path.getmtime(filepath)))
    
    current_gb = total_size / (1024**3)
    print(f"📊 Current cache size: {current_gb:.2f}GB")
    
    if current_gb <= keep_gb:
        print(f"✅ Cache is within limit ({keep_gb}GB)")
        return
    
    # Sort by modification time (oldest first)
    file_list.sort(key=lambda x: x[2])
    
    # Remove oldest files until we're under the limit
    removed_size = 0
    removed_count = 0
    target_size = keep_gb * (1024**3)
    
    for filepath, size, _ in file_list:
        if total_size - removed_size <= target_size:
            break
        
        try:
            os.remove(filepath)
            removed_size += size
            removed_count += 1
        except Exception as e:
            print(f"⚠️ Failed to remove {filepath}: {e}")
    
    freed_gb = removed_size / (1024**3)
    print(f"✅ Removed {removed_count} files, freed {freed_gb:.2f}GB")
    
    # Final status
    monitor_resources()

def optimize_models(model_dir=KAGGLE_MODEL_DIR, keep_best=3):
    """Keep only the best models to save space."""
    print(f"🗂️ Optimizing model directory: {model_dir}")
    
    if not os.path.exists(model_dir):
        print("❌ Model directory doesn't exist")
        return
    
    models = [f for f in os.listdir(model_dir) if f.endswith('.pt')]
    
    if len(models) <= keep_best:
        print(f"✅ Only {len(models)} models, no cleanup needed")
        return
    
    # Sort by modification time (keep newest)
    models.sort(
        key=lambda x: os.path.getmtime(os.path.join(model_dir, x)),
        reverse=True
    )
    
    # Remove older models
    for model in models[keep_best:]:
        model_path = os.path.join(model_dir, model)
        size_gb = os.path.getsize(model_path) / (1024**3)
        os.remove(model_path)
        print(f"🗑️ Removed {model} ({size_gb:.2f}GB)")
    
    print(f"✅ Kept {keep_best} best models")

# Run cleanup
print("🧹 Cleanup and Optimization")
print("="*60)

# Check current status
monitor_resources()

# Clean cache if needed (keep 10GB)
# cleanup_cache(keep_gb=10)

# Optimize models (keep best 3)
# optimize_models(keep_best=3)

# Clear GPU cache
if device == 'cuda':
    torch.cuda.empty_cache()
    print("✅ GPU cache cleared")

# Garbage collection
gc.collect()
print("✅ Python garbage collected")

# Final status
print("\n📊 Final resource status:")
monitor_resources()

## 📊 Training Analysis & Visualization

Analyze training history and performance

In [None]:
import json
import matplotlib.pyplot as plt

def analyze_training_history(model_name):
    """Analyze and visualize training history."""
    history_path = os.path.join(KAGGLE_MODEL_DIR, f"{model_name}_history.json")
    
    if not os.path.exists(history_path):
        print(f"❌ No history found for {model_name}")
        return
    
    with open(history_path, 'r') as f:
        history = json.load(f)
    
    print(f"\n📊 Training Analysis: {model_name}")
    print("="*60)
    print(f"Dataset: {history['dataset']}")
    print(f"Epochs: {history['epochs']}")
    print(f"Batch size: {history['batch_size']} × {history['gradient_accumulation_steps']} accumulation")
    print(f"Duration: {history['duration']/3600:.2f} hours")
    print(f"Final loss: {history['epoch_losses'][-1]:.4f}")
    
    # Plot training curves
    if history['losses']:
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))
        
        # Loss over steps
        steps = [l['step'] for l in history['losses']]
        losses = [l['loss'] for l in history['losses']]
        ax1.plot(steps, losses)
        ax1.set_xlabel('Steps')
        ax1.set_ylabel('Loss')
        ax1.set_title('Training Loss')
        ax1.grid(True, alpha=0.3)
        
        # Learning rate over steps
        lrs = [l['lr'] for l in history['losses']]
        ax2.plot(steps, lrs)
        ax2.set_xlabel('Steps')
        ax2.set_ylabel('Learning Rate')
        ax2.set_title('Learning Rate Schedule')
        ax2.grid(True, alpha=0.3)
        
        plt.tight_layout()
        plt.show()
        
        # Epoch summary
        if history['epoch_losses']:
            print("\n📈 Epoch Losses:")
            for i, loss in enumerate(history['epoch_losses'], 1):
                print(f"  Epoch {i}: {loss:.4f}")

# Analyze available models
if os.path.exists(KAGGLE_MODEL_DIR):
    history_files = [f for f in os.listdir(KAGGLE_MODEL_DIR) if f.endswith('_history.json')]
    
    if history_files:
        print("📊 Available training histories:")
        for history_file in history_files:
            model_name = history_file.replace('_history.json', '')
            print(f"  - {model_name}")
            # Uncomment to analyze
            # analyze_training_history(model_name)
    else:
        print("❌ No training histories found")

## 📝 Summary & Next Steps

### ✅ What You've Accomplished
- Configured optimized dataset caching for Kaggle's 50GB SSD
- Implemented gradient accumulation for effective large batch training
- Added checkpoint resuming for interrupted training sessions
- Created configurations for different training scales (1M, 5M, 20M+ samples)
- Enabled memory optimizations (gradient checkpointing, mixed precision)
- Built resource monitoring and cleanup utilities

### 🎯 Next Steps
1. **Scale Up**: Try the large-scale Pile configuration with 20M+ samples
2. **Fine-tune**: Load a checkpoint and fine-tune on specific datasets
3. **Experiment**: Try different model architectures (layers, heads, embedding size)
4. **Deploy**: Export models for inference in production
5. **Benchmark**: Compare performance across different configurations

### 💡 Tips for Kaggle
- Always monitor disk usage with `monitor_resources()`
- Use gradient accumulation for larger effective batch sizes
- Enable gradient checkpointing for larger models
- Clean cache periodically with `cleanup_cache()`
- Save checkpoints frequently for long training runs
- Use mixed precision training for 2x speedup

### 🔗 Resources
- [Hugging Face Datasets](https://huggingface.co/datasets)
- [PyTorch Documentation](https://pytorch.org/docs/)
- [The Pile Dataset](https://pile.eleuther.ai/)

Happy Training! 🚀