# Training 200M LLaMA Model with Infini-Attention using Nanotron

This notebook demonstrates how to train a 200M parameter LLaMA model with Infini-Attention using nanotron's distributed trainer and dataloader with your preprocessed data.

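Infini-Attention augments the LLaMA attention layers with a compressive memory: each sequence is processed in fixed-length segments, and information from earlier segments is carried forward in that memory, so the usable context is no longer bounded by the local attention window.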

In [None]:
# Install required packages if not available
import subprocess
import sys

def install_package(package_name, import_name=None):
    """Install a package with pip unless it can already be imported.

    Args:
        package_name: Distribution name handed to ``pip install``.
        import_name: Module name used for the import probe. Defaults to
            ``package_name``; pass it when the two differ (for example the
            ``pyyaml`` distribution is imported as ``yaml``).

    Raises:
        subprocess.CalledProcessError: If ``pip install`` exits non-zero.
    """
    import importlib

    module_name = import_name or package_name
    try:
        __import__(module_name)
    except ImportError:
        print(f"📦 Installing {package_name}...")
        # Install under the *distribution* name, not the import name:
        # `pip install yaml` does not provide the `yaml` module.
        subprocess.check_call([sys.executable, "-m", "pip", "install", package_name])
        # Let the running interpreter see modules installed after start-up.
        importlib.invalidate_caches()
        print(f"✅ {package_name} installed successfully")
    else:
        print(f"✅ {package_name} is already installed")

# Install required packages
required_packages = [
    "pyyaml",  # For YAML configuration files
    "psutil",  # For system monitoring
    "tqdm",    # For progress bars
    "wandb",   # For experiment tracking (optional)
]

# Distributions whose importable module name differs from the pip name.
package_import_names = {
    "pyyaml": "yaml",
}

print("🔧 Checking and installing required packages...")
for package in required_packages:
    install_package(package, package_import_names.get(package))

print("\n📚 Importing libraries...")
import os
import sys
import torch
import yaml
from pathlib import Path
from dataclasses import dataclass, asdict, field
from typing import Optional
import logging
import psutil
import warnings
from tqdm.auto import tqdm
import time
import json

print("✅ All libraries imported successfully!")

# GPU and Device Detection
def detect_training_hardware():
    """Print a CUDA hardware report and summarise it for the training setup.

    Returns:
        A 4-tuple ``(has_cuda, gpu_count, current_device, any_flash_capable)``.
        Without CUDA this is ``(False, 0, None, False)``. A GPU counts as
        Flash-Attention capable when its compute capability major version is
        at least 8 and it has at least 8 GB of memory.
    """
    print("\n🔍 Hardware Detection for Training:")
    print(f"   PyTorch version: {torch.__version__}")

    # Guard clause: nothing further to probe without CUDA.
    if not torch.cuda.is_available():
        print("   ⚠️  CUDA not available - training will use CPU (not recommended)")
        return False, 0, None, False

    n_devices = torch.cuda.device_count()
    print(f"   ✅ CUDA available with {n_devices} GPU(s)")

    capable_devices = []
    for idx in range(n_devices):
        props = torch.cuda.get_device_properties(idx)
        total_gb = props.total_memory / 1024**3
        print(f"      GPU {idx}: {props.name} ({total_gb:.1f} GB)")

        # Compute capability 8.x and above (Ampere or newer) with >= 8 GB.
        supports_flash = props.major >= 8 and total_gb >= 8
        print(f"         Compute capability: {props.major}.{props.minor}")
        print(f"         Flash Attention supported: {'✅' if supports_flash else '❌'}")

        if supports_flash:
            capable_devices.append(idx)

    active_device = torch.cuda.current_device()
    print(f"   Current device: cuda:{active_device}")
    print(f"   Flash Attention available on GPUs: {capable_devices}")
    return True, n_devices, active_device, bool(capable_devices)

def check_flash_attention():
    """Report whether the ``flash_attn`` package can be imported.

    Returns:
        ``True`` when the import succeeds, ``False`` (after printing an
        install hint) when it raises ``ImportError``.
    """
    try:
        import flash_attn
    except ImportError:
        print("   ⚠️  Flash Attention not installed")
        print("      Install with: pip install flash-attn --no-build-isolation")
        return False

    print("   ✅ Flash Attention is installed")
    return True

# Memory Management Class
class GPUMemoryManager:
    """Report memory usage for one CUDA device, or host RAM when CUDA is absent."""

    def __init__(self, device_id=0):
        # CUDA availability is sampled once, at construction time.
        self.has_gpu = torch.cuda.is_available()
        self.device_id = device_id

    def get_memory_info(self):
        """Return a dict describing current memory usage.

        With CUDA the dict has ``type='GPU'`` plus allocated/reserved/total
        sizes in GB and the allocated percentage; otherwise it has
        ``type='RAM'`` with the host figures reported by ``psutil``.
        """
        gib = 1024**3

        if self.has_gpu:
            # Wait for pending kernels so the counters reflect finished work.
            torch.cuda.synchronize(self.device_id)
            allocated_gb = torch.cuda.memory_allocated(self.device_id) / gib
            reserved_gb = torch.cuda.memory_reserved(self.device_id) / gib
            total_gb = torch.cuda.get_device_properties(self.device_id).total_memory / gib
            return {
                'type': 'GPU',
                'allocated_gb': allocated_gb,
                'reserved_gb': reserved_gb,
                'total_gb': total_gb,
                'percent': (allocated_gb / total_gb) * 100
            }

        host = psutil.virtual_memory()
        return {
            'type': 'RAM',
            'used_gb': host.used / gib,
            'total_gb': host.total / gib,
            'percent': host.percent
        }

    def clear_cache(self):
        """Release cached CUDA blocks; a no-op without a GPU."""
        if not self.has_gpu:
            return
        torch.cuda.empty_cache()

    def print_memory_summary(self, prefix=""):
        """Print a one-line usage summary, preceded by ``prefix``."""
        stats = self.get_memory_info()
        if stats['type'] != 'GPU':
            print(f"{prefix}RAM: {stats['used_gb']:.2f}GB/{stats['total_gb']:.2f}GB ({stats['percent']:.1f}%)")
            return
        print(f"{prefix}GPU Memory: {stats['allocated_gb']:.2f}GB/{stats['total_gb']:.2f}GB ({stats['percent']:.1f}%)")

# Training Device Configuration
has_gpu, gpu_count, current_gpu, flash_attention_hw_support = detect_training_hardware()
# Flash Attention is usable only when the package imports AND at least one GPU
# passed the hardware check.
flash_attention_available = check_flash_attention() and flash_attention_hw_support

# Initialize memory manager
memory_manager = GPUMemoryManager(current_gpu if current_gpu is not None else 0)
memory_manager.print_memory_summary("🧠 Initial ")

# Device settings for training
TRAINING_DEVICE = f"cuda:{current_gpu}" if has_gpu else "cpu"
# Gate on the checks above: a GPU alone is not sufficient (unsupported
# hardware or a missing flash-attn package must leave this disabled).
USE_FLASH_ATTENTION = has_gpu and flash_attention_available
USE_DISTRIBUTED = gpu_count > 1  # Multi-GPU training
GPU_MEMORY_FRACTION = 0.95  # Use 95% of GPU memory for training

# Flash Attention imports (only if supported). This runs before the summary
# below so the printed flag reflects the final decision.
if USE_FLASH_ATTENTION:
    try:
        import flash_attn
        from flash_attn import flash_attn_func, flash_attn_with_kvcache
        print(f"   ✅ Flash Attention v{flash_attn.__version__} loaded successfully")
    except ImportError as e:
        print(f"   ⚠️  Flash Attention not available: {e}")
        print("   💡 Install with: pip install flash-attn --no-build-isolation")
        USE_FLASH_ATTENTION = False
        flash_attention_available = False

print("\n🎯 Training Configuration:")
print(f"   Training device: {TRAINING_DEVICE}")
print(f"   Flash Attention: {'Enabled' if USE_FLASH_ATTENTION else 'Disabled'}")
print(f"   Distributed training: {'Enabled' if USE_DISTRIBUTED else 'Disabled'}")
print(f"   GPU memory fraction: {GPU_MEMORY_FRACTION}")

# Set memory fraction if using GPU
if has_gpu:
    torch.cuda.set_per_process_memory_fraction(GPU_MEMORY_FRACTION, device=current_gpu)
    # Clear cache to start fresh
    torch.cuda.empty_cache()
# Import nanotron components with robust error handling
print("\n📦 Importing nanotron components...")

try:
    # Import nanotron configuration classes
    from nanotron.config import (
        Config,
        LlamaConfig,
        ParallelismArgs,
        RandomInit,
        DataArgs,
        DatasetStageArgs,
        LoggingArgs,
        CheckpointsArgs,
        GeneralArgs,
        TokensArgs,
        OptimizerArgs,
        LRSchedulerArgs,
        InfiniAttentionArgs
    )
    print("✅ Configuration classes imported successfully")

    # Import trainer
    from nanotron.trainer import DistributedTrainer
    print("✅ DistributedTrainer imported successfully")

    # Import dataloader utilities
    from nanotron.dataloader import get_datasets, clm_process, get_train_dataloader
    print("✅ Dataloader utilities imported successfully")

    # Import model classes
    from nanotron.models.llama import LlamaForTraining
    print("✅ LLaMA model imported successfully")

    # Import constants
    from nanotron import constants
    print("✅ Constants imported successfully")

    print("🎉 All nanotron components imported successfully!")

except ImportError as e:
    print(f"❌ Failed to import nanotron components: {e}")
    print("\n🛠️  Troubleshooting steps:")
    print("1. Install nanotron in development mode:")
    # The path contains spaces, so the suggested command must quote it.
    print('   cd "/Users/zhang/Desktop/huawei/untitled folder 5/nanotron-infini"')
    print("   pip install -e .")
    print("2. Check if all dependencies are installed:")
    print("   pip install torch pyyaml numpy packaging safetensors dacite tqdm")
    print("3. Restart the kernel and try again")
    # Chain explicitly so the traceback names the original failure as the cause.
    raise ImportError("Cannot proceed without nanotron - please install it first") from e

# Setup Infini attention constants.
# NOTE(review): the original note says this must happen before importing
# models, yet `LlamaForTraining` is imported earlier in this cell — confirm
# whether the ordering actually matters.
@dataclass
class InfiniAttentionConfig:
    """Default Infini attention settings for this notebook.

    Field meanings are inferred from their names; confirm them against the
    nanotron code that consumes this configuration.
    """
    segment_length: int = 64  # presumably tokens per attention segment — TODO confirm
    turn_on_memory: bool = True  # presumably toggles the compressive memory — TODO confirm
    balance_factor_lr: float = 0.00015  # learning rate used for the balance factors (per name)
    balance_act_type: str = "hard_sigmoid"
    balance_init_type: str = "zeros"
    logging: bool = True
    logging_interval: int = 1000
    log_grad: bool = False
    log_segment_acts: bool = False
    balance_factor_weight_decay: Optional[float] = None  # None leaves the value unset
    # The default is captured when this class statement runs; later changes to
    # USE_FLASH_ATTENTION do not alter it.
    use_flash_attention: bool = USE_FLASH_ATTENTION  # Add Flash Attention flag

@dataclass
class ConfigWithInfini:
    """Minimal config object exposing only an ``infini_attention`` section."""
    # default_factory builds a fresh InfiniAttentionConfig for every instance.
    infini_attention: InfiniAttentionConfig = field(default_factory=InfiniAttentionConfig)

# Set up the Infini attention configuration in constants.
# This overwrites the module attribute `nanotron.constants.CONFIG` with the
# notebook's defaults; presumably nanotron code reads it from there — confirm.
constants.CONFIG = ConfigWithInfini()

# Echo the effective Infini attention settings for a quick visual check.
print("\n✅ All imports successful!")
print(f"🔬 Infini attention configured:")
print(f"   - Segment length: {constants.CONFIG.infini_attention.segment_length}")
print(f"   - Memory enabled: {constants.CONFIG.infini_attention.turn_on_memory}")
print(f"   - Balance activation: {constants.CONFIG.infini_attention.balance_act_type}")
print(f"   - Balance initialization: {constants.CONFIG.infini_attention.balance_init_type}")
print(f"   - Flash Attention: {constants.CONFIG.infini_attention.use_flash_attention}")
print(f"\n🚀 Ready for GPU training on {TRAINING_DEVICE}!")

In [None]:
# Install nanotron in development mode
# Imported here so the cell also works when it is the first one executed.
import os
import subprocess
import sys

print("\n🔧 Installing nanotron in development mode...")
nanotron_path = "/Users/zhang/Desktop/huawei/untitled folder 5/nanotron-infini"

try:
    # Run pip inside the checkout via `cwd=` instead of os.chdir(), so the
    # notebook's working directory is never left changed if something fails.
    result = subprocess.run(
        [sys.executable, "-m", "pip", "install", "-e", "."],
        cwd=nanotron_path,
        capture_output=True,
        text=True,
        check=False,
    )
except OSError as e:
    # Raised e.g. when nanotron_path does not exist or cannot be entered.
    print(f"⚠️  Could not install nanotron in dev mode: {e}")
    print("Will try path-based import...")
else:
    if result.returncode == 0:
        print("✅ nanotron installed successfully in development mode")
    else:
        print(f"⚠️  Warning during nanotron installation: {result.stderr}")
        print("Continuing with path-based import...")

# Add nanotron source to Python path as backup
nanotron_src_path = os.path.join(nanotron_path, "src")
if nanotron_src_path not in sys.path:
    sys.path.insert(0, nanotron_src_path)
    print(f"📁 Added to Python path: {nanotron_src_path}")

# Verify nanotron can be imported
try:
    import nanotron
except ImportError as e:
    print(f"❌ Still cannot import nanotron: {e}")
    print("\n🛠️  Troubleshooting steps:")
    print("1. Check if you're in the correct directory")
    print("2. Try running: pip install -e . from the nanotron-infini directory")
    print("3. Check Python environment and dependencies")

    # List what's actually in the nanotron source path to aid debugging
    if os.path.exists(nanotron_src_path):
        print(f"\n📂 Contents of {nanotron_src_path}:")
        for item in os.listdir(nanotron_src_path):
            print(f"   - {item}")

        nanotron_module_path = os.path.join(nanotron_src_path, "nanotron")
        if os.path.exists(nanotron_module_path):
            print("\n📂 Contents of nanotron module:")
            for item in os.listdir(nanotron_module_path)[:10]:  # Show first 10 items
                print(f"   - {item}")

    raise ImportError("Cannot proceed without nanotron module") from e
else:
    # Not every build exposes __version__; do not fail the check over it.
    nanotron_version = getattr(nanotron, "__version__", "unknown")
    print(f"✅ nanotron imported successfully (version {nanotron_version})")

## Configuration Setup

Define the training configuration for our 200M LLaMA model.

## GPU Memory Management and Configuration

Optimized memory management for training on different hardware configurations with Infini attention.

**Features:**
- 🔍 Automatic GPU detection and memory optimization
- 📊 Real-time memory monitoring during training
- ⚡ Flash Attention integration when supported
- 🔧 Batch size recommendation based on free GPU memory
- 💾 Memory cleanup and cache management

In [None]:
# Model configuration for 200M parameters with Infini-Attention
# Based on the proven fineweb_local_200m_infini_config.yaml
# Optimized for GPU training with Flash Attention support
#
# NOTE(review): `use_flash_attention_2` and `attention_dropout` are assumed to
# be accepted by this fork's LlamaConfig — confirm against nanotron.config.

model_config = LlamaConfig(
    bos_token_id=1,
    eos_token_id=2,
    hidden_act="silu",
    hidden_size=1024,         # Proven 200M config
    initializer_range=0.02,
    intermediate_size=4096,   # 4 * hidden_size for 200M model
    max_position_embeddings=256,  # Start smaller, can be extended with Infini attention
    num_attention_heads=8,    # Proven 200M config
    num_hidden_layers=6,      # Proven 200M config
    num_key_value_heads=8,    # Same as attention heads
    pretraining_tp=1,
    rms_norm_eps=1e-5,
    rope_scaling=None,
    rope_theta=10000.0,
    tie_word_embeddings=False,  # Important: False for Infini attention
    use_cache=True,
    vocab_size=49152,         # Proven vocab size from the config
    pad_token_id=None,        # No dedicated padding id is configured
    rope_interleaved=False,
    is_using_mup=False,
    # GPU and Flash Attention optimizations
    use_flash_attention_2=USE_FLASH_ATTENTION,  # Enable Flash Attention 2 if supported
    attention_dropout=0.0 if USE_FLASH_ATTENTION else 0.1,  # Disable dropout with Flash Attention
)

# GPU-optimized training configuration
GPU_TRAINING_CONFIG = {
    "device": TRAINING_DEVICE,
    "use_flash_attention": USE_FLASH_ATTENTION,
    "use_distributed": USE_DISTRIBUTED,
    "gpu_memory_fraction": GPU_MEMORY_FRACTION,
    "mixed_precision": has_gpu,  # Enable mixed precision on GPU
    "gradient_checkpointing": has_gpu,  # Enable gradient checkpointing for memory efficiency
    # Detect torch.compile directly. Comparing version *strings* is
    # lexicographic ("10.0" < "2.0"), so it breaks on two-digit majors.
    "compile_model": has_gpu and hasattr(torch, "compile"),  # torch.compile for PyTorch 2.0+
}

print("🔧 GPU Training Configuration:")
for key, value in GPU_TRAINING_CONFIG.items():
    print(f"   {key}: {value}")

# GPU Memory Management Functions
class GPUMemoryManager:
    """Track and tune CUDA memory usage for Infini attention training.

    All CUDA queries target the device named by ``self.device``. Without CUDA
    every method degrades to a no-op or a zero-filled result.

    Attributes:
        device: Device given at construction (``TRAINING_DEVICE`` if omitted).
        memory_history: One dict per ``monitor_memory`` call, in call order.
        peak_memory: Largest allocated size (GB) seen by ``monitor_memory``.
    """

    def __init__(self, device=None):
        # Compare against None explicitly: `device or TRAINING_DEVICE` would
        # silently discard a legitimate device index of 0.
        self.device = TRAINING_DEVICE if device is None else device
        self.memory_history = []
        self.peak_memory = 0

    def _cuda_index(self):
        """Return the integer CUDA index to query for ``self.device``.

        Falls back to the current CUDA device when ``self.device`` is not a
        CUDA device or carries no explicit index. Call only when CUDA is
        available.
        """
        target = torch.device(self.device)
        if target.type != "cuda" or target.index is None:
            return torch.cuda.current_device()
        return target.index

    def get_memory_info(self):
        """Return current memory figures in GB.

        The dict always has the keys ``device``, ``allocated``, ``reserved``,
        ``total``, ``free`` and ``usage_percent``; without CUDA the numbers
        are all zero and ``device`` is ``"cpu"``.
        """
        if not torch.cuda.is_available():
            return {
                "device": "cpu",
                "allocated": 0,
                "reserved": 0,
                "total": 0,
                "free": 0,
                "usage_percent": 0.0,
            }

        index = self._cuda_index()
        gib = 1024**3
        allocated = torch.cuda.memory_allocated(index) / gib
        reserved = torch.cuda.memory_reserved(index) / gib
        total = torch.cuda.get_device_properties(index).total_memory / gib

        return {
            "device": self.device,
            "allocated": allocated,
            "reserved": reserved,
            "total": total,
            # "free" is what the caching allocator has not reserved yet.
            "free": total - reserved,
            "usage_percent": (allocated / total) * 100
        }

    def monitor_memory(self, step_name="", log=True):
        """Record a memory snapshot labelled ``step_name`` and optionally print it."""
        if not torch.cuda.is_available():
            return

        info = self.get_memory_info()
        self.memory_history.append({"step": step_name, **info})
        self.peak_memory = max(self.peak_memory, info["allocated"])

        if log:
            print(f"📊 GPU Memory [{step_name}]:")
            print(f"   Allocated: {info['allocated']:.2f} GB ({info['usage_percent']:.1f}%)")
            print(f"   Reserved: {info['reserved']:.2f} GB")
            print(f"   Free: {info['free']:.2f} GB")

    def clear_cache(self, force=False):
        """Release cached CUDA memory; with ``force`` run Python GC first."""
        if not torch.cuda.is_available():
            return

        if force:
            import gc
            gc.collect()

        torch.cuda.empty_cache()
        torch.cuda.synchronize(self._cuda_index())

    def optimize_for_infini_attention(self):
        """Apply process-wide allocator, cuDNN and memory-fraction settings."""
        if not torch.cuda.is_available():
            return

        print("🔧 Applying Infini attention memory optimizations...")

        # Set memory management environment variables.
        # NOTE(review): the CUDA caching allocator likely reads
        # PYTORCH_CUDA_ALLOC_CONF when it is first initialised, so setting it
        # after earlier CUDA calls in this process may have no effect — confirm.
        os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128"
        os.environ["TOKENIZERS_PARALLELISM"] = "false"

        # cuDNN autotuning: favours speed for consistent input sizes.
        if torch.backends.cudnn.is_available():
            torch.backends.cudnn.benchmark = True
            torch.backends.cudnn.deterministic = False  # Allow non-deterministic for speed

        # Cap this process's share of the managed device's memory.
        torch.cuda.set_per_process_memory_fraction(GPU_MEMORY_FRACTION, device=self._cuda_index())

        print(f"   ✅ Memory fraction set to {GPU_MEMORY_FRACTION}")
        print("   ✅ Memory allocation optimized for segmented processing")

    def get_recommended_batch_size(self, model_size_mb=800):  # 200M model ≈ 800MB
        """Return a heuristic batch size from the currently free GPU memory.

        Args:
            model_size_mb: Approximate model size in MB; each sample is
                budgeted at three times this (model, gradients, activations).

        Returns:
            An integer >= 1; always 1 without CUDA.
        """
        if not torch.cuda.is_available():
            return 1

        info = self.get_memory_info()
        available_memory_gb = info["free"]

        # Conservative estimate: model + gradients + activations + buffer
        memory_per_sample_mb = model_size_mb * 3
        memory_per_sample_gb = memory_per_sample_mb / 1024

        # Keep 2GB buffer for other operations, but never budget below 0.5GB
        usable_memory_gb = max(0.5, available_memory_gb - 2.0)

        recommended_batch_size = max(1, int(usable_memory_gb / memory_per_sample_gb))

        print("💡 Memory-based batch size recommendation:")
        print(f"   Available GPU memory: {available_memory_gb:.2f} GB")
        print(f"   Estimated memory per sample: {memory_per_sample_gb:.3f} GB")
        print(f"   Recommended batch size: {recommended_batch_size}")

        return recommended_batch_size

    def print_memory_summary(self, prefix=""):
        """Print a multi-line usage summary; ``prefix`` precedes the header line.

        ``prefix`` mirrors the signature of the simpler memory manager defined
        earlier in this file, so calls written for either class keep working.
        """
        if not torch.cuda.is_available():
            print(f"{prefix}📊 Memory Summary: CPU mode (no GPU available)")
            return

        current = self.get_memory_info()

        print(f"{prefix}📊 GPU Memory Summary:")
        print(f"   Device: {current['device']}")
        print(f"   Total GPU Memory: {current['total']:.2f} GB")
        print(f"   Currently Allocated: {current['allocated']:.2f} GB ({current['usage_percent']:.1f}%)")
        print(f"   Currently Reserved: {current['reserved']:.2f} GB")
        print(f"   Peak Usage (session): {self.peak_memory:.2f} GB")
        print(f"   Free Memory: {current['free']:.2f} GB")

        if len(self.memory_history) > 1:
            print(f"   Memory tracking points: {len(self.memory_history)}")
            
# Initialize memory manager.
# This rebinds `memory_manager`, replacing the instance created in the
# hardware-detection cell; with no argument it tracks TRAINING_DEVICE.
memory_manager = GPUMemoryManager()
memory_manager.optimize_for_infini_attention()
# First snapshot, recorded under the label "Initial setup".
memory_manager.monitor_memory("Initial setup", log=True)

# Calculate approximate parameter count
def estimate_parameters(config):
    """Return a rough integer estimate of the model's parameter count.

    Counts the token embedding, the per-layer attention/MLP/norm weights plus
    one Infini attention balance factor per head, the final norm, and the LM
    head (skipped when ``tie_word_embeddings`` is true). Biases are ignored.

    Args:
        config: Object exposing ``vocab_size``, ``hidden_size``,
            ``intermediate_size``, ``num_hidden_layers`` and
            ``num_attention_heads``. ``num_key_value_heads`` and
            ``tie_word_embeddings`` are optional and default to the head
            count and ``False`` respectively.

    Returns:
        The estimated number of parameters as an ``int``.
    """
    vocab_size = config.vocab_size
    hidden_size = config.hidden_size
    intermediate_size = config.intermediate_size
    num_layers = config.num_hidden_layers
    num_heads = config.num_attention_heads
    num_kv_heads = getattr(config, "num_key_value_heads", None) or num_heads

    # Embedding parameters
    embedding_params = vocab_size * hidden_size

    # With grouped-query attention the K/V projections are narrower than Q/O.
    kv_width = num_kv_heads * (hidden_size // num_heads)
    attention_params = 2 * hidden_size * hidden_size + 2 * hidden_size * kv_width
    mlp_params = 3 * hidden_size * intermediate_size   # gate, up, down
    layer_norm_params = 2 * hidden_size                # input and post attention layer norms
    balance_params = num_heads                         # one balance factor per head

    # Integer arithmetic throughout: a parameter count must not become a float.
    per_layer_params = attention_params + mlp_params + layer_norm_params + balance_params
    total_layer_params = num_layers * per_layer_params

    # Final layer norm, plus the LM head only when it is not tied to the embedding.
    final_params = hidden_size
    if not getattr(config, "tie_word_embeddings", False):
        final_params += vocab_size * hidden_size

    return embedding_params + total_layer_params + final_params

# Estimate the parameter count of the configured model and print it both
# exactly and in millions.
estimated_params = estimate_parameters(model_config)
print(f"📊 Estimated parameters: {estimated_params:,} ({estimated_params/1e6:.1f}M)")
# Echo the Infini attention settings currently stored on nanotron's constants.
print(f"🔬 Infini-Attention features:")
print(f"   - Segment length: {constants.CONFIG.infini_attention.segment_length}")
print(f"   - Compressive memory: {constants.CONFIG.infini_attention.turn_on_memory}")
print(f"   - Balance factor learning rate: {constants.CONFIG.infini_attention.balance_factor_lr}")

In [None]:
# Training configuration based on proven fineweb_local_200m_infini_config.yaml
#
# Builds the top-level nanotron Config: run metadata, checkpointing,
# parallelism, model, token budget, optimizer/LR schedule, logging, data
# stages and the Infini attention section.
# NOTE(review): checkpoint and data paths are absolute paths on one
# developer's machine; adjust them before running elsewhere.
config = Config(
    general=GeneralArgs(
        project="llama-200m-infini-training",
        run="llama-200m-infini-experiment",
        seed=42,
        step=None,
        consumed_train_samples=None,
        ignore_sanity_checks=False  # Keep sanity checks for safety
    ),
    
    checkpoints=CheckpointsArgs(
        checkpoints_path=Path("/Users/zhang/Desktop/huawei/untitled folder 5/nanotron-infini/checkpoints/llama_200m_infini"),
        checkpoint_interval=500,     # Save more frequently for testing
        save_initial_state=False,
        resume_checkpoint_path=None,
        checkpoints_path_is_shared_file_system=False
    ),
    
    # dp = pp = tp = 1: a single-process layout with no parallelism.
    parallelism=ParallelismArgs(
        dp=1,                    # Data parallel size
        pp=1,                    # Pipeline parallel size
        tp=1,                    # Tensor parallel size
        pp_engine="1f1b",        # Pipeline engine
        tp_mode="ALL_REDUCE",
        tp_linear_async_communication=False,
        expert_parallel_size=1
    ),
    
    model=model_config,
    
    tokens=TokensArgs(
        sequence_length=256,     # Start with proven config length; equals max_position_embeddings above
        train_steps=5000,        # Reduced for initial testing
        micro_batch_size=4,      # Proven batch size from config
        batch_accumulation_per_replica=1,
        limit_val_batches=0
    ),
    
    optimizer=OptimizerArgs(
        zero_stage=0,
        weight_decay=0.1,        # Proven from config
        clip_grad=1.0,
        accumulate_grad_in_fp32=True,
        adam_eps=1e-8,
        adam_beta1=0.9,
        adam_beta2=0.95,
        torch_adam_is_fused=True,
        learning_rate_scheduler=LRSchedulerArgs(
            learning_rate=0.0000375,  # Proven learning rate from config
            lr_warmup_steps=500,      # Reduced proportionally
            lr_warmup_style="linear",
            lr_decay_steps=4500,      # Remaining steps after warmup (500 + 4500 = train_steps)
            lr_decay_style="cosine",
            min_decay_lr=0.00000375   # 10x smaller than main LR
        )
    ),
    
    logging=LoggingArgs(
        iteration_step_info_interval=50,  # More frequent logging for testing
        log_level="info",
        log_level_replica="info"
    ),
    
    # A single data stage, starting at training step 1.
    data_stages=[
        DatasetStageArgs(
            name="infini_training_stage",
            start_training_step=1,
            data=DataArgs(
                dataset={
                    "parquet": {
                        "data_dir": "/Users/zhang/Desktop/huawei/untitled folder 5/nanotron-infini/data/train",
                        "data_files": "train_data.parquet"
                    }
                },
                seed=42,
                num_loading_workers=0,  # Single process for testing
                dataloader_type="single"
            )
        )
    ],
    
    # Add Infini attention configuration: every field is copied from the
    # defaults stored on constants.CONFIG so the two stay in agreement.
    infini_attention=InfiniAttentionArgs(
        segment_length=constants.CONFIG.infini_attention.segment_length,
        turn_on_memory=constants.CONFIG.infini_attention.turn_on_memory,
        balance_factor_lr=constants.CONFIG.infini_attention.balance_factor_lr,
        balance_act_type=constants.CONFIG.infini_attention.balance_act_type,
        balance_init_type=constants.CONFIG.infini_attention.balance_init_type,
        logging=constants.CONFIG.infini_attention.logging,
        logging_interval=constants.CONFIG.infini_attention.logging_interval,
        log_grad=constants.CONFIG.infini_attention.log_grad,
        log_segment_acts=constants.CONFIG.infini_attention.log_segment_acts,
        balance_factor_weight_decay=constants.CONFIG.infini_attention.balance_factor_weight_decay
    )
)

print("✅ Configuration created successfully!")
# Global batch size = micro batch size * accumulation steps * data-parallel replicas.
print(f"📈 Global batch size: {config.tokens.micro_batch_size * config.tokens.batch_accumulation_per_replica * config.parallelism.dp}")
print(f"🔄 Total training steps: {config.tokens.train_steps}")
print(f"📏 Sequence length: {config.tokens.sequence_length}")
# The directory below is a literal; it is not read back from `config`.
print(f"📁 Data directory: /Users/zhang/Desktop/huawei/untitled folder 5/nanotron-infini/data/train")
print(f"🔬 Infini attention segment length: {config.infini_attention.segment_length}")
print(f"🧠 Memory enabled: {config.infini_attention.turn_on_memory}")

## Data Loading Setup

Prepare the dataset and dataloader for training.

In [None]:
# Load the proven tokenizer from the config
class SimpleTokenizer:
    """Dependency-free word-level tokenizer used as a last-resort fallback.

    Words are mapped to ids with a stable checksum, so the same text always
    produces the same ids across processes and runs. Distinct words may
    collide; this is for demonstration only.
    """

    RESERVED_IDS = 10  # ids below this are kept free for special tokens

    def __init__(self, vocab_size=49152):
        if vocab_size <= self.RESERVED_IDS:
            raise ValueError(
                f"vocab_size must be greater than {self.RESERVED_IDS}, got {vocab_size}"
            )
        self.vocab_size = vocab_size
        self.pad_token_id = 0
        self.bos_token_id = 1
        self.eos_token_id = 2

    def encode(self, text, add_special_tokens=True):
        """Return token ids for ``text``, optionally wrapped in BOS/EOS."""
        import re
        import zlib

        # Simple preprocessing: lowercase, split on whitespace and punctuation
        words = re.findall(r'\w+|[^\w\s]', str(text).lower())
        # zlib.crc32 is deterministic; the built-in hash() of a str is salted
        # per process, which made token ids change from run to run.
        bucket_count = self.vocab_size - self.RESERVED_IDS
        tokens = [
            zlib.crc32(word.encode("utf-8")) % bucket_count + self.RESERVED_IDS
            for word in words
        ]
        if add_special_tokens:
            tokens = [self.bos_token_id] + tokens + [self.eos_token_id]
        return tokens

    def batch_encode_plus(self, texts, return_attention_mask=False, return_token_type_ids=False):
        """Encode every text; returns ``{"input_ids": [...]}`` (flags are ignored)."""
        return {"input_ids": [self.encode(text, add_special_tokens=True) for text in texts]}


def _require_vocab_fits(tokenizer, model_vocab_size):
    """Raise ValueError if the tokenizer can emit ids beyond the model vocabulary."""
    tokenizer_size = len(tokenizer)
    if tokenizer_size > model_vocab_size:
        raise ValueError(
            f"tokenizer has {tokenizer_size} tokens but the model vocabulary "
            f"only covers {model_vocab_size}"
        )


def setup_tokenizer():
    """Return a tokenizer compatible with the module-level ``config.model``.

    Tries, in order: the project tokenizer, GPT-2, then ``SimpleTokenizer``.
    A Hugging Face tokenizer is rejected when it has more tokens than the
    model vocabulary, because its ids could fall outside the embedding table.
    """
    model_cfg = config.model

    try:
        from transformers import AutoTokenizer
    except ImportError as e:
        # Without transformers neither Hugging Face tokenizer can be tried.
        print(f"❌ Could not load any tokenizer ({e}), using simple tokenizer")
        return SimpleTokenizer(vocab_size=model_cfg.vocab_size)

    try:
        # Use the proven tokenizer from the config
        tokenizer_name = "lvwerra/the-tokenizer-v1"  # From fineweb_local_200m_infini_config.yaml

        print(f"🔤 Loading tokenizer: {tokenizer_name}")
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
        _require_vocab_fits(tokenizer, model_cfg.vocab_size)

        # A smaller tokenizer vocabulary is fine; only a larger one is an error.
        if tokenizer.vocab_size != model_cfg.vocab_size:
            print(f"⚠️  Tokenizer vocab size ({tokenizer.vocab_size}) doesn't match model vocab size ({model_cfg.vocab_size})")
            print(f"   This is expected - the model vocab size includes padding for efficiency")

        # Set special tokens to match our model config
        if hasattr(tokenizer, 'pad_token_id'):
            tokenizer.pad_token_id = model_cfg.pad_token_id or 0
        if hasattr(tokenizer, 'bos_token_id'):
            tokenizer.bos_token_id = model_cfg.bos_token_id
        if hasattr(tokenizer, 'eos_token_id'):
            tokenizer.eos_token_id = model_cfg.eos_token_id

        print(f"✅ Using proven AutoTokenizer: {tokenizer.__class__.__name__}")
        return tokenizer

    except Exception as e:
        # Deliberate best-effort: any load failure moves on to the next fallback.
        print(f"⚠️  Could not load proven tokenizer ({e}), using GPT-2 fallback")

    # Fallback to GPT-2 tokenizer
    try:
        tokenizer = AutoTokenizer.from_pretrained("gpt2")
        _require_vocab_fits(tokenizer, model_cfg.vocab_size)

        # Add padding token if missing
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        # Set special tokens
        tokenizer.pad_token_id = 0
        tokenizer.bos_token_id = model_cfg.bos_token_id
        tokenizer.eos_token_id = model_cfg.eos_token_id

        print(f"✅ Using GPT-2 fallback tokenizer")
        return tokenizer

    except Exception as e2:
        print(f"❌ Could not load any tokenizer ({e2}), using simple tokenizer")

    return SimpleTokenizer(vocab_size=model_cfg.vocab_size)

# Initialize tokenizer (module-level; used by the data loading code below)
tokenizer = setup_tokenizer()
print(f"✅ Tokenizer initialized: {tokenizer.__class__.__name__}")
print(f"   📊 Vocab size: {tokenizer.vocab_size}")
# getattr guards against tokenizer objects that lack a pad_token_id attribute.
print(f"   🔤 Special tokens: BOS={tokenizer.bos_token_id}, EOS={tokenizer.eos_token_id}, PAD={getattr(tokenizer, 'pad_token_id', 0)}")

In [None]:
def _build_dummy_clm_dataset(config, tokenizer, num_samples=1000):
    """Build a dataset of random token sequences for use when real data cannot be loaded.

    Every sequence holds ``config.tokens.sequence_length + 1`` token IDs,
    starts with the tokenizer's BOS ID and ends with its EOS ID.
    """
    import numpy as np
    from datasets import Dataset, Features, Sequence, Value

    seq_length = config.tokens.sequence_length + 1  # +1 for CLM processing
    # Random IDs start at 10 and stay below both the model and tokenizer vocab limits
    upper_bound = min(config.model.vocab_size - 10, tokenizer.vocab_size)

    dummy_data = []
    for _ in range(num_samples):
        tokens = np.random.randint(10, upper_bound, size=seq_length)
        # Ensure BOS and EOS tokens
        tokens[0] = tokenizer.bos_token_id
        tokens[-1] = tokenizer.eos_token_id
        dummy_data.append({"input_ids": tokens.tolist()})

    return Dataset.from_list(
        dummy_data,
        features=Features({"input_ids": Sequence(feature=Value(dtype="int64"), length=seq_length)})
    )


def load_preprocessed_data(
    config,
    tokenizer,
    train_data_path="/Users/zhang/Desktop/huawei/untitled folder 5/nanotron-infini/data/train/train_data.parquet",
    max_samples=50000,
    min_text_length=50,
):
    """Load the preprocessed parquet file and turn it into a CLM training dataset.

    Args:
        config: Training config; ``tokens.sequence_length``,
            ``model.vocab_size`` and ``infini_attention.segment_length`` are read.
        tokenizer: Tokenizer handed to ``clm_process``; its ``vocab_size``,
            ``bos_token_id`` and ``eos_token_id`` are used for the fallback data.
        train_data_path: Parquet file to read. Defaults to the path this
            notebook was written against.
        max_samples: Keep at most this many leading rows; ``None`` keeps all.
        min_text_length: Texts with fewer stripped characters are dropped
            (applied after the ``max_samples`` cut).

    Returns:
        The dataset produced by ``clm_process``. If loading or processing
        raises, the error is printed and a dataset of random token sequences
        is returned instead so the rest of the notebook can still be exercised.
    """
    import pandas as pd
    from datasets import Dataset

    try:
        print(f"📚 Loading preprocessed training data from: {train_data_path}")

        # Check if file exists
        if not os.path.exists(train_data_path):
            raise FileNotFoundError(f"Training data not found at {train_data_path}")

        # Load the parquet file
        df = pd.read_parquet(train_data_path)
        print(f"✅ Loaded dataframe with shape: {df.shape}")
        print(f"📋 Columns: {list(df.columns)}")

        # Show the first row, truncating long text values to 100 characters
        if len(df) > 0:
            print(f"👀 Sample data:")
            for col in df.columns:
                first_value = df[col].iloc[0]
                if df[col].dtype == 'object':  # Text columns
                    as_text = str(first_value)
                    sample_text = as_text[:100] + "..." if len(as_text) > 100 else as_text
                    print(f"   {col}: {sample_text}")
                else:
                    print(f"   {col}: {first_value}")

        # Determine text column - look for 'text' first, then any string column
        if 'text' in df.columns:
            text_column = 'text'
        else:
            text_columns = df.select_dtypes(include=['object']).columns.tolist()
            if not text_columns:
                raise ValueError("No text column found in the dataset")
            text_column = text_columns[0]
            print(f"⚠️  No 'text' column found, using '{text_column}' instead")

        # Extract texts
        texts = df[text_column].dropna().astype(str).tolist()
        print(f"📝 Found {len(texts)} text samples")

        # Take a reasonable subset for training (adjust based on your needs)
        if max_samples is not None and len(texts) > max_samples:
            texts = texts[:max_samples]
            print(f"📊 Using subset of {len(texts)} samples for training")

        # Filter out very short texts
        texts = [text for text in texts if len(text.strip()) >= min_text_length]
        print(f"📝 After filtering short texts: {len(texts)} samples")

        if not texts:
            raise ValueError("No valid text samples found after filtering")

        # Create HuggingFace dataset
        train_dataset = Dataset.from_dict({"text": texts})
        print(f"✅ Created HuggingFace dataset with {len(train_dataset)} samples")

        # Process for causal language modeling with the proven tokenizer
        print("🔄 Processing dataset for CLM with Infini attention...")
        processed_dataset = clm_process(
            raw_dataset=train_dataset,
            tokenizer=tokenizer,
            text_column_name="text",
            dataset_processing_num_proc_per_process=1,
            dataset_overwrite_cache=True,
            sequence_length=config.tokens.sequence_length
        )

    except Exception as e:
        print(f"❌ Error loading preprocessed data: {e}")
        print("🔄 Creating dummy dataset for demonstration...")

        # Generate dummy sequences for Infini attention testing
        dummy_dataset = _build_dummy_clm_dataset(config, tokenizer)

        print(f"✅ Dummy dataset created: {len(dummy_dataset)} sequences")
        print(f"🔬 Ready for Infini attention testing")
        return dummy_dataset

    # Reported outside the try block so that a failure while logging cannot
    # replace the successfully processed dataset with random data.
    print(f"✅ Dataset processed: {len(processed_dataset)} sequences")
    print(f"🔬 Ready for Infini attention training with segment length: {config.infini_attention.segment_length}")
    return processed_dataset

# Load the preprocessed training dataset
train_dataset = load_preprocessed_data(config, tokenizer)

## Model Initialization and Training

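Save the configuration, initialize the distributed environment, build the dataloader, and initialize the trainer. Training itself starts in the next section.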

In [None]:
# Save config to file for trainer: make sure the target folder exists first
config_path = Path("/Users/zhang/Desktop/huawei/untitled folder 5/nanotron-infini/configs/llama_200m_infini_config.yaml")
config_path.parent.mkdir(parents=True, exist_ok=True)

# Prefer the config's own as_dict(); when that raises AttributeError, serialize
# the dataclass sections one by one, keeping the same key order
try:
    config_dict = config.as_dict()
except AttributeError:
    config_dict = dict(
        {
            section: asdict(getattr(config, section))
            for section in ('general', 'checkpoints', 'parallelism', 'model', 'tokens', 'optimizer', 'logging')
        },
        data_stages=[asdict(stage) for stage in config.data_stages],
        infini_attention=asdict(config.infini_attention),
    )

# Convert Path objects to strings for YAML serialization
def convert_paths_to_strings(obj):
    """Return a copy of ``obj`` with every path object replaced by its string form.

    Dicts, lists and tuples are walked recursively (dict keys are left
    untouched); tuples stay tuples and named tuples keep their type. Any
    other value is returned unchanged.
    """
    from pathlib import PurePath

    if isinstance(obj, dict):
        return {key: convert_paths_to_strings(value) for key, value in obj.items()}
    if isinstance(obj, list):
        return [convert_paths_to_strings(item) for item in obj]
    if isinstance(obj, tuple):
        converted = [convert_paths_to_strings(item) for item in obj]
        # Named tuples take their fields as positional arguments
        if hasattr(obj, '_fields'):
            return type(obj)(*converted)
        return tuple(converted)
    if isinstance(obj, PurePath):
        return str(obj)
    return obj

config_dict = convert_paths_to_strings(config_dict)

# allow_unicode=True makes the dumper write non-ASCII characters verbatim, so
# the file is opened as UTF-8 explicitly instead of relying on the platform's
# default encoding.
with open(config_path, 'w', encoding='utf-8') as f:
    yaml.dump(config_dict, f, default_flow_style=False, allow_unicode=True, sort_keys=False)

# Echo the Infini attention settings that were just written
print(f"✅ Infini attention config saved to: {config_path}")
print("🔬 Configuration includes:")
print(f"   - Segment length: {config.infini_attention.segment_length}")
print(f"   - Memory enabled: {config.infini_attention.turn_on_memory}")
print(f"   - Balance factor LR: {config.infini_attention.balance_factor_lr}")
print(f"   - Balance activation: {config.infini_attention.balance_act_type}")

In [None]:
# Initialize distributed environment for Infini attention training
import torch.distributed as dist

def init_distributed():
    """Initialize the torch.distributed process group for Infini attention training.

    ``RANK``, ``WORLD_SIZE``, ``MASTER_ADDR`` and ``MASTER_PORT`` default to a
    single local process when they are not already set, and the process group
    is created from those environment values. NCCL is used when CUDA is
    available, Gloo otherwise.

    Returns:
        True when the process group is (or already was) initialized, False if
        initialization raised; the error is printed rather than propagated.
    """
    try:
        if not dist.is_initialized():
            # Single-process defaults; values supplied by a launcher are kept
            os.environ.setdefault('RANK', '0')
            os.environ.setdefault('WORLD_SIZE', '1')
            os.environ.setdefault('MASTER_ADDR', 'localhost')
            os.environ.setdefault('MASTER_PORT', '12355')

            # Use the environment values for the group as well, so the rank and
            # world size passed below cannot contradict what was just read.
            rank = int(os.environ['RANK'])
            world_size = int(os.environ['WORLD_SIZE'])

            # Set CUDA device if available
            if torch.cuda.is_available():
                backend = 'nccl'
                torch.cuda.set_device(int(os.environ.get('LOCAL_RANK', '0')))
                print(f"🔥 Using GPU for Infini attention training")
            else:
                backend = 'gloo'
                print("⚠️  CUDA not available, using CPU (not recommended for Infini attention)")

            dist.init_process_group(
                backend=backend,
                init_method='env://',
                world_size=world_size,
                rank=rank
            )

            print(f"✅ Distributed environment initialized ({backend} backend)")
        else:
            print("✅ Distributed environment already initialized")

        # Set up for Infini attention training
        if torch.cuda.is_available():
            torch.cuda.empty_cache()  # Clear any existing cache
            print(f"🔥 GPU memory cleared for Infini attention training")
            # total_memory is the device capacity, not the amount currently free
            device_index = torch.cuda.current_device()
            print(f"   Total GPU memory: {torch.cuda.get_device_properties(device_index).total_memory / 1024**3:.1f} GB")

        return True

    except Exception as e:
        print(f"❌ Error initializing distributed environment: {e}")
        print("⚠️  Continuing without distributed training (may affect performance)")
        return False

# Initialize distributed environment
dist_success = init_distributed()

if dist_success:
    print("🔬 Infini attention training environment ready!")
else:
    print("⚠️  Training may proceed but with limited functionality")

In [None]:
def _build_dummy_dataloader(config):
    """Build a DataLoader of random tensors for testing when the real one cannot be created.

    Inputs and labels are drawn independently at random, so the batches only
    exercise the training plumbing; they carry no learnable signal.
    """
    from torch.utils.data import DataLoader, Dataset

    class InfiniDummyDataset(Dataset):
        """Random ``input_ids``/``label_ids`` with all-true masks."""

        def __init__(self, size, seq_len, vocab_size):
            self.size = size
            self.seq_len = seq_len
            self.vocab_size = vocab_size

        def __len__(self):
            return self.size

        def __getitem__(self, idx):
            return {
                "input_ids": torch.randint(0, self.vocab_size, (self.seq_len,)),
                "input_mask": torch.ones(self.seq_len, dtype=torch.bool),
                "label_ids": torch.randint(0, self.vocab_size, (self.seq_len,)),
                "label_mask": torch.ones(self.seq_len, dtype=torch.bool),
            }

    dummy_dataset = InfiniDummyDataset(
        size=1000,
        seq_len=config.tokens.sequence_length,
        vocab_size=min(config.model.vocab_size, 32000)  # Cap for testing
    )

    return DataLoader(
        dummy_dataset,
        batch_size=config.tokens.micro_batch_size,
        shuffle=True,
        drop_last=True
    )


def create_dataloader_dict(train_dataset, config):
    """Create the dataloader dictionary for the Infini attention trainer.

    Args:
        train_dataset: Dataset passed to ``get_train_dataloader``.
        config: Training config; the ``parallelism``, ``tokens``, ``general``,
            ``model`` and ``infini_attention`` sections are read.

    Returns:
        ``{"infini_training_stage": dataloader}``. If ``get_train_dataloader``
        raises, the error is printed and a dataloader of random tensors is
        returned under the same key instead.
    """
    from nanotron.parallel import ParallelContext

    try:
        # Create proper parallel context for Infini attention
        parallel_context = ParallelContext(
            tensor_parallel_size=config.parallelism.tp,
            pipeline_parallel_size=config.parallelism.pp,
            data_parallel_size=config.parallelism.dp,
            expert_parallel_size=config.parallelism.expert_parallel_size,
        )
        print("✅ Parallel context created successfully")

    except Exception as e:
        print(f"⚠️  Using mock parallel context for single GPU: {e}")

        # Mock parallel context for single GPU training
        class MockParallelContext:
            """Exposes the process-group attributes; each is WORLD, or None when
            torch.distributed is not initialized."""

            def __init__(self):
                group = dist.group.WORLD if dist.is_initialized() else None
                self.world_pg = group
                self.dp_pg = group
                self.tp_pg = group
                self.pp_pg = group
                self.expert_pg = group

        parallel_context = MockParallelContext()

    # Only the dataloader construction is guarded: the summary below must not
    # be able to trigger the random-data fallback once a real loader exists.
    try:
        dataloader = get_train_dataloader(
            train_dataset=train_dataset,
            sequence_length=config.tokens.sequence_length,
            parallel_context=parallel_context,
            input_pp_rank=0,
            output_pp_rank=0,
            micro_batch_size=config.tokens.micro_batch_size,
            consumed_train_samples=0,
            dataloader_num_workers=0,  # Keep simple for testing
            seed_worker=config.general.seed,
            dataloader_drop_last=True,
            dataloader_pin_memory=False,  # Disable for stability
            use_loop_to_round_batch_size=False,
        )
    except Exception as e:
        print(f"❌ Error creating Infini attention dataloader: {e}")
        print("🔄 Creating simple dataloader for testing...")

        dummy_dataloader = _build_dummy_dataloader(config)

        print("✅ Dummy Infini attention dataloader created")
        return {"infini_training_stage": dummy_dataloader}

    sequence_length = config.tokens.sequence_length
    segment_length = config.infini_attention.segment_length

    print("✅ Infini attention dataloader created successfully")
    print(f"🔄 Batch size: {config.tokens.micro_batch_size}")
    print(f"📏 Sequence length: {sequence_length}")
    print(f"🔬 Segment length: {segment_length}")

    # Segments per sequence, rounded up; skipped for a non-positive segment
    # length, where the division is undefined.
    if segment_length > 0:
        segments_per_seq = -(-sequence_length // segment_length)
        print(f"🧩 Segments per sequence: {segments_per_seq}")

    return {"infini_training_stage": dataloader}

# Create dataloader for Infini attention training
dataloader_dict = create_dataloader_dict(train_dataset, config)

In [None]:
# Initialize Infini attention trainer
print("🚀 Initializing Infini attention trainer...")

# Only trainer construction decides whether `trainer` is usable; the
# diagnostics further down are informational and handled separately.
try:
    trainer = DistributedTrainer(
        config_or_config_file=config,
        model_class=LlamaForTraining  # This will use Infini attention with our constants setup
    )
except Exception as e:
    print(f"❌ Error initializing Infini attention trainer: {e}")
    import traceback
    traceback.print_exc()
    trainer = None

    print("⚠️  Troubleshooting tips:")
    print("1. Ensure all Infini attention constants are properly set")
    print("2. Check that the configuration is valid")
    print("3. Verify GPU memory is sufficient")
    print("4. Try reducing batch size if out of memory")
else:
    print("✅ Infini attention trainer initialized successfully!")

    try:
        # Print model information
        print(f"📊 Model: {trainer.unwrapped_model.__class__.__name__}")

        # Count parameters and their storage in a single pass
        total_params = 0
        trainable_params = 0
        total_bytes = 0
        for param in trainer.model.parameters():
            param_count = param.numel()
            total_params += param_count
            # element_size() reflects the parameter's actual dtype
            total_bytes += param_count * param.element_size()
            if param.requires_grad:
                trainable_params += param_count

        print(f"📈 Total parameters: {total_params:,} ({total_params/1e6:.1f}M)")
        print(f"🎯 Trainable parameters: {trainable_params:,} ({trainable_params/1e6:.1f}M)")
        print(f"💾 Model size (estimated): {total_bytes / 1024**2:.1f} MB")

        # Check for Infini attention specific components by attribute name
        infini_attention_layers = 0
        balance_factors_count = 0

        for module in trainer.unwrapped_model.modules():
            if hasattr(module, 'segment_length'):
                infini_attention_layers += 1
            if hasattr(module, 'balance_factors'):
                balance_factors_count += 1

        print(f"🔬 Infini attention components:")
        print(f"   - Layers with segmentation: {infini_attention_layers}")
        print(f"   - Balance factors: {balance_factors_count}")
        print(f"   - Segment length: {constants.CONFIG.infini_attention.segment_length}")
        print(f"   - Memory enabled: {constants.CONFIG.infini_attention.turn_on_memory}")
        print(f"   - Balance activation: {constants.CONFIG.infini_attention.balance_act_type}")

        # Check GPU memory usage if available
        if torch.cuda.is_available():
            torch.cuda.synchronize()
            allocated = torch.cuda.memory_allocated() / 1024**2
            reserved = torch.cuda.memory_reserved() / 1024**2
            print(f"🔥 GPU Memory:")
            print(f"   - Allocated: {allocated:.1f} MB")
            print(f"   - Reserved: {reserved:.1f} MB")

    except Exception as diag_error:
        # The trainer itself was created; keep it and only report the problem
        print(f"⚠️  Could not collect trainer diagnostics: {diag_error}")

## Training Loop

Start the training process.

In [None]:
import time  # Add missing import

def _save_emergency_checkpoint(trainer, clear_cuda_cache=False):
    """Best-effort checkpoint save after training stopped early; never raises.

    Args:
        trainer: Object whose ``save_checkpoint()`` is called when it exists.
        clear_cuda_cache: Empty the CUDA cache before saving (used after an
            out-of-memory error).
    """
    try:
        if clear_cuda_cache:
            torch.cuda.empty_cache()
        if hasattr(trainer, 'save_checkpoint'):
            checkpoint_path = trainer.save_checkpoint()
            print(f"💾 Emergency checkpoint saved to: {checkpoint_path}")
    except Exception as save_error:
        print(f"❌ Could not save emergency checkpoint: {save_error}")


if trainer is not None:
    print("🏃‍♂️ Starting Infini attention training...")
    print(f"🔬 Training with {config.infini_attention.segment_length}-token segments")
    print(f"🧠 Memory state: {'enabled' if config.infini_attention.turn_on_memory else 'disabled'}")

    try:
        # Pre-training setup
        start_time = time.time()
        initial_step = getattr(trainer, 'iteration_step', 0)

        print(f"📅 Training started at step {initial_step}")
        print(f"🎯 Target steps: {config.tokens.train_steps}")

        # Start Infini attention training
        trainer.train(dataloader_or_dls=dataloader_dict)

        # Calculate training time
        end_time = time.time()
        training_time = end_time - start_time

        print("🎉 Infini attention training completed successfully!")
        print(f"⏱️  Training time: {training_time:.2f} seconds ({training_time/60:.1f} minutes)")

        # Final statistics
        final_step = getattr(trainer, 'iteration_step', 0)
        steps_completed = final_step - initial_step

        print(f"📈 Training summary:")
        print(f"   - Steps completed: {steps_completed}")
        # max(..., 1) avoids dividing by zero when no step was recorded
        print(f"   - Average time per step: {training_time/max(steps_completed, 1):.2f} seconds")

        if torch.cuda.is_available():
            print(f"   - Peak GPU memory: {torch.cuda.max_memory_allocated() / 1024**2:.1f} MB")

    except KeyboardInterrupt:
        print("⚠️  Training interrupted by user")
        _save_emergency_checkpoint(trainer)

    except torch.cuda.OutOfMemoryError as oom_error:
        print(f"❌ GPU out of memory during Infini attention training: {oom_error}")
        print("💡 Suggestions to fix:")
        print("   1. Reduce micro_batch_size in config")
        print("   2. Reduce sequence_length")
        print("   3. Reduce segment_length for Infini attention")
        print("   4. Enable gradient checkpointing")
        print("   5. Use mixed precision training")

        # Try to save what we can, after releasing cached GPU memory
        _save_emergency_checkpoint(trainer, clear_cuda_cache=True)

    except Exception as e:
        print(f"❌ Error during Infini attention training: {e}")
        import traceback
        traceback.print_exc()
        _save_emergency_checkpoint(trainer)

else:
    print("❌ Cannot start training - Infini attention trainer initialization failed")
    print("💡 Please check the error messages above and:")
    print("   1. Verify the configuration is correct")
    print("   2. Ensure sufficient GPU memory")
    print("   3. Check that all dependencies are installed")
    print("   4. Try running the data preprocessing notebook first")

## Training Monitoring and Results

Monitor training progress and analyze results.

In [None]:
# Infini attention training monitoring and results
if trainer is not None:
    print("📈 Infini Attention Training Summary:")
    print("=" * 50)

    # Basic training stats
    try:
        current_step = getattr(trainer, 'iteration_step', 0)
        consumed_samples = getattr(trainer, 'consumed_train_samples', 0)

        # NOTE(review): the "- 1" presumes iteration_step is one past the last
        # completed step - confirm against the trainer. Clamped so that a
        # missing attribute is not reported as -1.
        print(f"✅ Training steps completed: {max(current_step - 1, 0)}")
        print(f"📈 Consumed samples: {consumed_samples:,}")

        # Calculate effective context; the segment count is rounded up so a
        # trailing partial segment is counted, as in the dataloader and
        # inference cells.
        sequence_length = config.tokens.sequence_length
        segment_length = config.infini_attention.segment_length
        segments_per_sequence = -(-sequence_length // segment_length)

        print(f"📏 Sequence length: {sequence_length}")
        print(f"🔬 Segment length: {segment_length}")
        print(f"🧩 Segments per sequence: {segments_per_sequence}")
        print(f"🧠 Memory-augmented context: Infinite (via compressive memory)")

        # Batch information
        global_batch_size = config.tokens.micro_batch_size * config.tokens.batch_accumulation_per_replica * config.parallelism.dp
        print(f"🔢 Global batch size: {global_batch_size}")

    except Exception as e:
        print(f"⚠️  Could not retrieve training stats: {e}")

    # Check checkpoints; Path() accepts both str and Path values
    checkpoints_dir = Path(config.checkpoints.checkpoints_path)
    if checkpoints_dir.is_dir():
        # iterdir + is_dir lists only folders on every Python version
        checkpoint_folders = [entry for entry in checkpoints_dir.iterdir() if entry.is_dir()]
        if checkpoint_folders:
            print(f"💾 Infini attention checkpoints saved: {len(checkpoint_folders)}")
            # Order by modification time so "last 3" means the most recently
            # written folders regardless of how they are named
            recent_folders = sorted(checkpoint_folders, key=lambda folder: folder.stat().st_mtime)[-3:]
            for cp in recent_folders:
                size_mb = sum(f.stat().st_size for f in cp.rglob('*') if f.is_file()) / 1024**2
                print(f"   📁 {cp.name} ({size_mb:.1f} MB)")
        else:
            print("❌ No checkpoints found")
    else:
        print(f"❌ Checkpoint directory does not exist: {checkpoints_dir}")

    # Memory usage
    if torch.cuda.is_available():
        peak_allocated_bytes = torch.cuda.max_memory_allocated()
        print(f"🔥 GPU Memory Usage:")
        print(f"   - Current allocated: {torch.cuda.memory_allocated() / 1024**2:.1f} MB")
        print(f"   - Current reserved: {torch.cuda.memory_reserved() / 1024**2:.1f} MB")
        print(f"   - Peak allocated: {peak_allocated_bytes / 1024**2:.1f} MB")

        # Peak allocation as a share of device 0's capacity (both in bytes)
        total_memory_bytes = torch.cuda.get_device_properties(0).total_memory
        memory_usage_pct = peak_allocated_bytes / total_memory_bytes * 100
        print(f"   - Peak usage: {memory_usage_pct:.1f}% of total GPU memory")

    # Infini attention specific metrics
    print(f"🔬 Infini Attention Configuration:")
    print(f"   - Segment processing: {config.infini_attention.segment_length} tokens per segment")
    print(f"   - Compressive memory: {'Enabled' if config.infini_attention.turn_on_memory else 'Disabled'}")
    print(f"   - Balance factor LR: {config.infini_attention.balance_factor_lr}")
    print(f"   - Balance activation: {config.infini_attention.balance_act_type}")
    print(f"   - Balance initialization: {config.infini_attention.balance_init_type}")

    # Model architecture summary; guard on the attribute that is actually read
    if hasattr(trainer, 'model'):
        total_params = sum(p.numel() for p in trainer.model.parameters())
        print(f"🏠 Model Architecture:")
        print(f"   - Total parameters: {total_params:,} ({total_params/1e6:.1f}M)")
        print(f"   - Hidden size: {config.model.hidden_size}")
        print(f"   - Attention heads: {config.model.num_attention_heads}")
        print(f"   - Layers: {config.model.num_hidden_layers}")
        print(f"   - Vocab size: {config.model.vocab_size}")

else:
    print("❌ No training results available - trainer initialization failed")
    print("💡 Check the error messages above for troubleshooting steps")

In [None]:
# Infini attention model inference test
def _count_segments(seq_len, segment_length):
    """Return how many segments of ``segment_length`` are needed to cover ``seq_len`` tokens."""
    return -(-seq_len // segment_length)


def _timed_forward(model, input_ids, input_mask):
    """Run one forward pass and return ``(outputs, elapsed_seconds)``.

    CUDA work is queued asynchronously, so the device is synchronized before
    the clock starts and again before it is read; otherwise the measurement
    would mostly cover the launch, not the computation. The module is called
    rather than its ``forward`` method so registered hooks run.
    """
    on_cuda = input_ids.is_cuda
    if on_cuda:
        torch.cuda.synchronize()
    started = time.perf_counter()
    outputs = model(input_ids=input_ids, input_mask=input_mask)
    if on_cuda:
        torch.cuda.synchronize()
    return outputs, time.perf_counter() - started


if trainer is not None:
    try:
        print("🧪 Testing Infini attention model inference...")

        # Put model in eval mode
        trainer.model.eval()

        # Create test input suitable for Infini attention
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        inner_model = trainer.unwrapped_model.model
        segment_length = config.infini_attention.segment_length
        # Use smaller vocab for testing
        token_upper_bound = min(config.model.vocab_size - 10, 1000)

        # Test with different sequence lengths to verify Infini attention
        test_lengths = [32, 128, config.tokens.sequence_length]

        for seq_len in test_lengths:
            print(f"\n📏 Testing sequence length: {seq_len}")

            # Create test input
            test_input = torch.randint(
                10, token_upper_bound,
                (1, seq_len),  # batch_size=1
                device=device
            )
            test_mask = torch.ones_like(test_input, dtype=torch.bool)

            # Ensure proper token boundaries
            test_input[0, 0] = tokenizer.bos_token_id  # Start with BOS
            test_input[0, -1] = tokenizer.eos_token_id  # End with EOS

            # Test forward pass with Infini attention
            with torch.no_grad():
                outputs, inference_time = _timed_forward(inner_model, test_input, test_mask)

            # Extract the output tensor from whichever container was returned
            if isinstance(outputs, dict):
                logits = outputs.get('hidden_states', outputs.get('logits', None))
            elif hasattr(outputs, 'logits'):
                logits = outputs.logits
            else:
                logits = outputs

            if logits is not None:
                print(f"   ✅ Inference successful!")
                print(f"   📉 Input shape: {test_input.shape}")
                print(f"   📉 Output shape: {logits.shape}")
                print(f"   ⏱️  Inference time: {inference_time*1000:.2f} ms")
                print(f"   🎯 Max output value: {logits.max().item():.3f}")
                print(f"   🎯 Min output value: {logits.min().item():.3f}")

                # Calculate segments processed
                segments = _count_segments(seq_len, segment_length)
                print(f"   🧩 Segments processed: {segments}")

                # Memory usage for this inference
                if torch.cuda.is_available():
                    memory_used = torch.cuda.memory_allocated() / 1024**2
                    print(f"   🔥 GPU memory: {memory_used:.1f} MB")
            else:
                print(f"   ⚠️  Could not extract logits from output")

            # Clear cache between tests
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

        # Test segmented processing specifically
        print(f"\n🔬 Testing Infini attention segmentation:")

        # Create a longer sequence to test segmentation
        long_seq_len = segment_length * 3 + 10  # Multiple segments
        # Rounded-up count, which stays correct for segment lengths of 10 or less
        print(f"   Testing with {long_seq_len} tokens ({_count_segments(long_seq_len, segment_length)} segments)")

        long_test_input = torch.randint(
            10, token_upper_bound,
            (1, long_seq_len),
            device=device
        )
        long_test_mask = torch.ones_like(long_test_input, dtype=torch.bool)

        with torch.no_grad():
            long_outputs, segmented_time = _timed_forward(inner_model, long_test_input, long_test_mask)

        print(f"   ✅ Segmented processing successful!")
        print(f"   ⏱️  Processing time: {segmented_time*1000:.2f} ms")
        print(f"   🧠 Memory state: {'Active' if config.infini_attention.turn_on_memory else 'Inactive'}")

        print(f"\n🎉 All Infini attention inference tests passed!")

    except Exception as e:
        print(f"❌ Infini attention inference test failed: {e}")
        import traceback
        traceback.print_exc()

        print(f"\n💡 Troubleshooting:")
        print("   1. Check if model is properly initialized with Infini attention")
        print("   2. Verify input dimensions are correct")
        print("   3. Ensure sufficient GPU memory")
        print("   4. Try with smaller sequence lengths")
else:
    print("❌ No inference test possible - trainer not initialized")

## Cleanup and Next Steps

Clean up resources and provide guidance for next steps.

In [None]:
# Cleanup Infini attention training resources
import gc

print("🧹 Cleaning up Infini attention training session...")

# Release the trainer before touching the CUDA cache: empty_cache() can only
# return memory that is no longer referenced, so clearing the cache while the
# trainer is still alive would free nothing.
if 'trainer' in locals() and trainer is not None:
    # Try to save final state if training was successful
    try:
        if hasattr(trainer, 'save_checkpoint'):
            final_checkpoint = trainer.save_checkpoint()
            print(f"💾 Final checkpoint saved: {final_checkpoint}")
    except Exception as e:
        print(f"⚠️  Could not save final checkpoint: {e}")

    del trainer
    gc.collect()  # drop reference cycles that would keep tensors alive
    print("✅ Infini attention trainer object deleted")

# Clear CUDA cache
if torch.cuda.is_available():
    torch.cuda.synchronize()
    torch.cuda.empty_cache()
    print("✅ CUDA cache cleared")

    # Final memory report
    final_memory = torch.cuda.memory_allocated() / 1024**2
    print(f"   Final GPU memory usage: {final_memory:.1f} MB")

# Constants are left in place; only their presence is reported
try:
    if hasattr(constants, 'CONFIG'):
        # Don't delete, just note the state
        print("✅ Infini attention constants preserved for future use")
except NameError:
    # `constants` was never imported in this session; nothing to report
    pass

print("🎉 Infini attention training session completed!")
print("\n" + "="*60)
print("📝 INFINI ATTENTION TRAINING SUMMARY:")
print("="*60)
print("🏠 Model: LLaMA 200M with Infini-Attention")
print("📚 Dataset: Preprocessed parquet data")
print(f"📁 Data source: /Users/zhang/Desktop/huawei/untitled folder 5/nanotron-infini/data/train/train_data.parquet")
print(f"🔬 Segment Length: {config.infini_attention.segment_length} tokens")
print(f"🧠 Compressive Memory: {'Enabled' if config.infini_attention.turn_on_memory else 'Disabled'}")
print(f"⚖️  Balance Activation: {config.infini_attention.balance_act_type}")
print(f"📏 Sequence Length: {config.tokens.sequence_length}")
print(f"🔢 Micro Batch Size: {config.tokens.micro_batch_size}")
print(f"🔄 Training Steps: {config.tokens.train_steps}")
print(f"📈 Learning Rate: {config.optimizer.learning_rate_scheduler.learning_rate}")
print(f"💾 Checkpoints: {config.checkpoints.checkpoints_path}")
print("="*60)

In [None]:
# Final Infini attention training status report.
# Lines without interpolated values are emitted together through one print
# call (sep="\n"); lines that read config values keep their own print call.
print(
    "📋 INFINI ATTENTION TRAINING SESSION SUMMARY",
    "="*50,
    f"🏠 Model: LLaMA 200M with Infini-Attention",
    f"📚 Dataset: Preprocessed parquet data",
    f"📁 Data source: /Users/zhang/Desktop/huawei/untitled folder 5/nanotron-infini/data/train/train_data.parquet",
    sep="\n",
)
print(f"⚙️  Configuration: {config_path}")
print()

# Infini attention settings
print("🔬 INFINI ATTENTION FEATURES:")
print(f"   🧩 Segment Length: {config.infini_attention.segment_length} tokens")
print(f"   🧠 Compressive Memory: {'Enabled' if config.infini_attention.turn_on_memory else 'Disabled'}")
print(f"   ⚖️  Balance Factor LR: {config.infini_attention.balance_factor_lr}")
print(f"   🔄 Balance Activation: {config.infini_attention.balance_act_type}")
print(f"   🎨 Balance Initialization: {config.infini_attention.balance_init_type}")
print()

# Token, batch and optimizer settings
print("📊 TRAINING CONFIGURATION:")
print(f"   📏 Sequence Length: {config.tokens.sequence_length}")
print(f"   🔢 Micro Batch Size: {config.tokens.micro_batch_size}")
print(f"   🔄 Training Steps: {config.tokens.train_steps}")
print(f"   📈 Learning Rate: {config.optimizer.learning_rate_scheduler.learning_rate}")
print(f"   💾 Checkpoints: {config.checkpoints.checkpoints_path}")
print()

# Locations used by this notebook
print("📁 FILES AND PATHS:")
print(f"   📄 Config file: {config_path}")
print(
    f"   📁 Project directory: /Users/zhang/Desktop/huawei/untitled folder 5/nanotron-infini",
    f"   📚 Training data: /Users/zhang/Desktop/huawei/untitled folder 5/nanotron-infini/data/train/",
    sep="\n",
)
print(f"   💾 Checkpoints: {config.checkpoints.checkpoints_path}")

# Closing guidance: fixed text only
print(
    "\n✅ Infini attention training setup complete!",
    "\n💡 TO USE YOUR OWN DATA:",
    "1. 🚀 Run the data preprocessing notebook (data.ipynb) first",
    "2. 📄 Make sure train_data.parquet exists in the data/train/ directory",
    "3. 🔤 Update the text column name if different from 'text'",
    "4. ⚙️  Adjust segment_length based on your data characteristics",
    "5. 📊 Monitor memory usage and adjust batch_size if needed",
    "\n🔬 INFINI ATTENTION BENEFITS:",
    "• 🧠 Infinite context length through compressive memory",
    "• 🚀 Efficient processing with fixed segment sizes",
    "• ⚖️  Learnable balance between local and global attention",
    "• 💾 Memory-efficient for long sequences",
    "\n📈 NEXT STEPS FOR PRODUCTION:",
    "1. 🔍 Scale up training with more data and longer sequences",
    "2. 📏 Experiment with different segment lengths",
    "3. ⚖️  Tune balance factor learning rates",
    "4. 📈 Monitor attention patterns and memory usage",
    "5. 🧪 Test on long-context evaluation tasks",
    "6. 💾 Save models in formats suitable for inference",
    "\n" + "="*50,
    "🎆 Happy training with Infini-Attention! 🎆",
    "="*50,
    sep="\n",
)