# PantheraML TPU Inference Example

This notebook demonstrates how to perform high-performance model inference on TPUs using PantheraML's advanced TPU functions powered by **PantheraML-Zoo** (the TPU-enabled fork of unsloth_zoo).

## Features Covered:
- TPU device detection and setup
- Model loading with TPU optimizations using PantheraML-Zoo
- Phase 1: Basic TPU inference with error handling
- Phase 2: Performance-optimized inference with XLA
- Phase 3: Advanced multi-pod inference with JAX integration
- Memory management and optimization
- Batch inference and dynamic shapes

## Requirements:
- TPU runtime (Google Colab TPU or Google Cloud TPU)
- PantheraML with PantheraML-Zoo for full TPU support
- torch-xla for XLA integration

## PantheraML-Zoo:
This notebook leverages **PantheraML-Zoo**, our TPU-optimized fork of unsloth_zoo that provides:
- Enhanced TPU compatibility
- Advanced XLA compilation support
- Multi-pod TPU training capabilities
- Optimized memory management for TPU workloads

## 🚨 Troubleshooting Guide

### Common Issues and Solutions:

**Issue 1**: `NotImplementedError: PantheraML currently only works on NVIDIA GPUs, Intel GPUs, and TPUs (experimental).`
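- **Solution**: Run the setup cells in order. The install cell sets `PANTHERAML_FORCE_TPU=1` and `DEVICE_TYPE=tpu` before `pantheraml` is imported. If `pantheraml` was already imported in this session, restart the runtime and rerun the cells from the top.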

**Issue 2**: `ImportError: No module named 'torch_xla'`
- **Solution**: Make sure you're using a TPU runtime in Colab (Runtime > Change runtime type > Hardware accelerator > TPU)

**Issue 3**: TPU detection fails
- **Solution**: The notebook includes fallbacks. If PantheraML's TPU kernels can't be imported, placeholder modules stand in for them and basic inference should still run on the XLA device.

**Issue 4**: Memory errors
- **Solution**: Use smaller models or reduce batch sizes. The notebook includes memory management examples.

### ✅ Quick Setup Checklist:
1. Enable TPU in Colab runtime settings
2. Run the "Quick Environment Test" cell first
3. Wait for all imports to complete before proceeding
4. If errors occur, try restarting runtime and running cells in order
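
Issue 1 and the checklist both come down to ordering. Libraries often resolve the device once, while they are being imported, and keep that answer for the rest of the session. The class below is a hypothetical illustration of that pattern (PantheraML's real detection logic is not shown in this notebook). It reuses the notebook's `DEVICE_TYPE` and `PANTHERAML_FORCE_TPU` variable names to show why setting them after the import has no effect until the runtime is restarted.

```python
from typing import Dict, Mapping, Optional


class DeviceDetector:
    """Hypothetical stand-in for import-time device detection (not PantheraML's code).

    The device type is resolved on the first call and cached, the way a library
    that inspects the environment while being imported would behave.
    """

    def __init__(self) -> None:
        self._cached: Optional[str] = None

    def device_type(self, env: Mapping[str, str]) -> str:
        if self._cached is None:
            if env.get("DEVICE_TYPE"):
                self._cached = env["DEVICE_TYPE"]
            elif env.get("PANTHERAML_FORCE_TPU") == "1":
                self._cached = "tpu"
            else:
                self._cached = "cpu"
        return self._cached


env: Dict[str, str] = {}                      # stands in for os.environ
detector = DeviceDetector()
first = detector.device_type(env)             # "cpu": nothing set yet
env["PANTHERAML_FORCE_TPU"] = "1"             # set after the "import": too late
second = detector.device_type(env)            # still "cpu" (cached)
fresh = DeviceDetector().device_type(env)     # "tpu": like restarting the runtime
print(first, second, fresh)                   # cpu cpu tpu
```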

---

## 1. Environment Setup and TPU Detection

In [None]:
# Quick Environment Test - Run this first!
import os
import torch

print("🔧 Quick TPU Environment Test")
print("=" * 35)

xla_ok = False
ops_ok = False

# Test 1: Basic torch_xla
try:
    import torch_xla.core.xla_model as xm
    device = xm.xla_device()
    xla_ok = True
    print(f"✅ Test 1: torch_xla working - Device: {device}")
except Exception as e:
    print(f"❌ Test 1: torch_xla failed: {e}")
    print("   Solution: Make sure you're running on a TPU runtime")

# Test 2: Basic tensor operation on TPU (only meaningful if Test 1 passed)
if xla_ok:
    try:
        x = torch.tensor([1.0, 2.0, 3.0]).to(device)
        y = x * 2
        xm.mark_step()  # Cut the lazy graph and execute it on the TPU
        print(f"✅ Test 2: TPU tensor operations working: {y.cpu().tolist()}")
        ops_ok = True
    except Exception as e:
        print(f"❌ Test 2: TPU operations failed: {e}")

# Test 3: Environment summary
print("🔍 Current environment:")
print(f"   TPU_NAME set: {'TPU_NAME' in os.environ}")
print(f"   torch_xla importable: {xla_ok}")
print(f"   CUDA available: {torch.cuda.is_available()}")

print(f"\n{'✅ Environment ready!' if ops_ok else '❌ Environment needs setup'}")

In [None]:
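# Install required packages (run this in a Colab TPU environment)
# torch-xla wheels are built against a specific torch version; if the import below
# fails, install the matching torch/torch-xla pair listed in the PyTorch/XLA README.
!pip install -q torch-xla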
!pip install -q transformers datasets
!pip install -q accelerate bitsandbytes

# Set TPU environment variables BEFORE importing PantheraML
import os
import sys

# Force TPU detection by setting environment variables
os.environ["TPU_NAME"] = "local"  # For Colab TPU
os.environ["PANTHERAML_FORCE_TPU"] = "1"  # Force TPU mode
os.environ["DEVICE_TYPE"] = "tpu"  # Override device detection

# Check if we're on a TPU
try:
    import torch_xla.core.xla_model as xm
    print("🔍 TPU environment detected, setting up...")
    print(f"   XLA devices: {xm.get_xla_supported_devices()}")
    tpu_available = True
except ImportError:
    print("❌ torch-xla not available. Installing...")
    !pip install torch-xla
    import torch_xla.core.xla_model as xm
    tpu_available = True

# Now import PantheraML - it should detect TPU correctly
print("📦 Importing PantheraML with TPU support...")
import pantheraml
from pantheraml import FastLanguageModel

# Import TPU-specific modules
try:
    from pantheraml.kernels import tpu_kernels, tpu_performance, tpu_advanced
    print("✅ TPU kernels imported successfully")
except ImportError as e:
    print(f"⚠️ TPU kernels not available: {e}")
    print("   Using fallback imports...")
    # Create mock modules for compatibility
    class MockTPUModule:
        def __getattr__(self, name):
            return lambda *args, **kwargs: None
    tpu_kernels = MockTPUModule()
    tpu_performance = MockTPUModule()
    tpu_advanced = MockTPUModule()

from pantheraml.distributed import (
    is_tpu_available, 
    setup_multi_tpu, 
    MultiTPUConfig,
    get_tpu_rank,
    get_tpu_world_size
)

import torch
import numpy as np
import time
from typing import List, Dict, Any

print("✅ All imports successful!")
print(f"   PantheraML device type: {pantheraml.DEVICE_TYPE if hasattr(pantheraml, 'DEVICE_TYPE') else 'Unknown'}")

In [None]:
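# Install PantheraML-Zoo (TPU-enabled fork of unsloth_zoo)
# `!pip` does not raise a Python exception when the install fails, so a try/except
# around it never reaches the fallback. Run pip via subprocess and check its exit code.
# If `pantheraml` was imported before this cell, a runtime restart may be needed
# for it to pick up the newly installed package.
import subprocess
import sys

def pip_install(*packages: str) -> bool:
    """Run `pip install -q <packages>`; return True if pip exited with status 0."""
    proc = subprocess.run([sys.executable, "-m", "pip", "install", "-q", *packages])
    return proc.returncode == 0

print("📦 Installing PantheraML-Zoo for optimal TPU support...")
if pip_install("git+https://github.com/PantheraAIML/PantheraML-Zoo.git"):
    print("✅ PantheraML-Zoo installed successfully (full TPU support)")
else:
    print("⚠️ Failed to install PantheraML-Zoo")
    print("   Fallback: Installing original unsloth_zoo...")
    if pip_install("unsloth_zoo"):
        print("✅ unsloth_zoo installed (limited TPU support)")
    else:
        print("❌ Failed to install unsloth_zoo")
        print("   Some features may be limited")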

# Verify installation
try:
    import pantheraml_zoo
    print("✅ PantheraML-Zoo import successful")
except ImportError:
    try:
        import unsloth_zoo
        print("⚠️ Using unsloth_zoo fallback")
    except ImportError:
        print("❌ No zoo library available")

In [None]:
# TPU Environment Detection and Setup
print("🔍 Detecting TPU Environment...")
print("=" * 50)

# First check torch_xla availability (primary TPU indicator)
try:
    import torch_xla.core.xla_model as xm
    import torch_xla.distributed.xla_multiprocessing as xmp
    
    # Get TPU device information
    device = xm.xla_device()
    
    print(f"✅ TPU Available!")
    print(f"   Device: {device}")
    print(f"   XLA Devices: {xm.get_xla_supported_devices()}")
    
    tpu_detected = True
    
except Exception as e:
    print(f"❌ TPU not available: {e}")
    print(f"   Please enable TPU in Colab: Runtime > Change runtime type > Hardware accelerator > TPU")
    raise RuntimeError("TPU required for this example")

# Check PantheraML's TPU detection
try:
    pantheraml_tpu_available = is_tpu_available()
    print(f"   PantheraML TPU detection: {pantheraml_tpu_available}")
except Exception as e:
    print(f"   ⚠️ PantheraML TPU detection failed: {e}")
    print(f"   Continuing with manual TPU setup...")
    pantheraml_tpu_available = True  # Force it since we know TPU is available

if tpu_detected:
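    # Get world size and rank
    try:
        world_size = get_tpu_world_size()
        rank = get_tpu_rank()
    except Exception:
        # Fallback to torch_xla's runtime API, then to the older (deprecated) xm helpers
        try:
            import torch_xla.runtime as xr
            world_size = xr.world_size()
            rank = xr.global_ordinal()
        except (ImportError, AttributeError):
            world_size = xm.xrt_world_size()
            rank = xm.get_ordinal()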
    
    print(f"   World Size: {world_size}")
    print(f"   Rank: {rank}")
    
    # Initialize TPU kernels with Phase 1 enhancements
    print(f"\n🚀 Initializing TPU Kernels...")
    try:
        if hasattr(tpu_kernels, 'initialize_tpu_kernels') and callable(tpu_kernels.initialize_tpu_kernels):
            if tpu_kernels.initialize_tpu_kernels():
                print(f"✅ Phase 1 TPU kernels initialized")
                
                # Get TPU status
                if hasattr(tpu_kernels, 'get_tpu_status'):
                    status = tpu_kernels.get_tpu_status()
                    print(f"   Memory Available: {status.get('memory_info', {}).get('gb_limit', 'Unknown')} GB")
                    print(f"   Cores Available: {status.get('device_info', {}).get('cores', 'Unknown')}")
            else:
                print(f"⚠️ TPU kernels initialization failed")
        else:
            print(f"⚠️ TPU kernels not available, using basic TPU support")
            
    except Exception as e:
        print(f"⚠️ TPU kernel initialization error: {e}")
        print(f"   Continuing with basic TPU support...")

print(f"\n✅ TPU setup complete!")

## 2. Model Loading with TPU Optimizations
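
A quick sanity check for the `Memory Size` line printed by the loading cell: weight memory is roughly parameter count × bytes per element, and `bfloat16` uses 2 bytes per element versus 4 for `float32`. The parameter count below is an illustrative round number for a "1.5B" model, not the exact count of `Qwen/Qwen2.5-1.5B-Instruct`; substitute the value the cell prints. Activations and the KV cache come on top of this.

```python
BYTES_PER_ELEMENT = {"float32": 4, "bfloat16": 2, "float16": 2, "int8": 1}


def weight_memory_gib(num_params: int, dtype: str = "bfloat16") -> float:
    """Approximate memory for model weights only (no activations, no KV cache)."""
    return num_params * BYTES_PER_ELEMENT[dtype] / 2**30


# Illustrative round number for a "1.5B" model; use the count printed by the cell.
approx_params = 1_500_000_000
for dtype in ("float32", "bfloat16"):
    print(f"{dtype:>9}: {weight_memory_gib(approx_params, dtype):.2f} GiB")
```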

In [None]:
# Model configuration for TPU inference
MODEL_CONFIG = {
    "model_name": "Qwen/Qwen2.5-1.5B-Instruct",  # Good size for TPU inference
    "max_seq_length": 2048,
    "dtype": torch.bfloat16,  # Optimal for TPU
    "load_in_4bit": False,   # Not supported on TPU
    "device_map": None       # Manual TPU placement
}

print(f"📦 Loading Model for TPU Inference...")
print(f"   Model: {MODEL_CONFIG['model_name']}")
print(f"   Max Length: {MODEL_CONFIG['max_seq_length']}")
print(f"   Data Type: {MODEL_CONFIG['dtype']}")

# Load model and tokenizer
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=MODEL_CONFIG["model_name"],
    max_seq_length=MODEL_CONFIG["max_seq_length"],
    dtype=MODEL_CONFIG["dtype"],
    load_in_4bit=MODEL_CONFIG["load_in_4bit"],
    device_map=MODEL_CONFIG["device_map"]
)

print(f"✅ Model loaded successfully")
print(f"   Parameters: {sum(p.numel() for p in model.parameters()):,}")
print(f"   Memory Size: {sum(p.numel() * p.element_size() for p in model.parameters()) / 1024 / 1024:.1f} MB")

In [None]:
# Move model to TPU with optimizations
print(f"🔄 Moving model to TPU...")

# Enable inference mode for optimization
model = FastLanguageModel.for_inference(model)

# Move to TPU device
model = model.to(device)

# Apply TPU-specific optimizations
print(f"⚡ Applying TPU optimizations...")

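# Phase 1: Basic error handling and memory management
# getattr + hasattr guard: the fallback MockTPUModule returns a no-op lambda for any
# attribute, which is truthy but has no clear_cache()/get_memory_stats() methods.
tpu_memory_manager = getattr(tpu_kernels, "tpu_memory_manager", None)
if not hasattr(tpu_memory_manager, "get_memory_stats"):
    tpu_memory_manager = None
if tpu_memory_manager:
    tpu_memory_manager.clear_cache()
    memory_stats = tpu_memory_manager.get_memory_stats()
    print(f"   Memory after model loading: {memory_stats.get('allocated_gb', 0):.1f} GB")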

print(f"✅ Model ready for TPU inference")
print(f"   Model device: {next(model.parameters()).device}")
print(f"   Model dtype: {next(model.parameters()).dtype}")

## 3. Phase 1: Basic TPU Inference with Error Handling
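
Phase 1 wraps generation in `tpu_kernels.tpu_error_handler.safe_execute(fn, description, max_retries=3)` when that helper exists. Its exact behaviour isn't documented in this notebook, so the following is a hypothetical, dependency-free stand-in showing what a retry wrapper of that shape typically does: call the function, retry on failure with exponential backoff, and re-raise once the attempts run out.

```python
import time
from typing import Callable, Optional, TypeVar

T = TypeVar("T")


def safe_execute(fn: Callable[[], T], description: str, max_retries: int = 3,
                 base_delay: float = 0.5) -> T:
    """Hypothetical retry wrapper (not PantheraML's implementation).

    Treats `max_retries` as the total number of attempts (an assumption) and sleeps
    base_delay * 2**attempt between failed attempts.
    """
    last_error: Optional[Exception] = None
    for attempt in range(max_retries):
        try:
            return fn()
        except Exception as err:  # broad on purpose: device/runtime errors vary
            last_error = err
            print(f"{description}: attempt {attempt + 1}/{max_retries} failed: {err}")
            if attempt < max_retries - 1:
                time.sleep(base_delay * 2 ** attempt)
    raise RuntimeError(f"{description} failed after {max_retries} attempts") from last_error


# Usage: a callable that fails twice, then succeeds
calls = {"n": 0}


def flaky() -> str:
    calls["n"] += 1
    if calls["n"] < 3:
        raise RuntimeError("transient error")
    return "ok"


result = safe_execute(flaky, "demo generation", base_delay=0.01)
print(result, calls["n"])  # ok 3
```

Retries only help with transient failures. An out-of-memory or shape error fails the same way on every attempt, which is why the wrapper re-raises instead of returning a placeholder.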

In [None]:
# Phase 1: Basic TPU inference with comprehensive error handling
print(f"🧪 Phase 1: Basic TPU Inference")
print(f"=" * 40)

def basic_tpu_inference(prompt: str, max_new_tokens: int = 50) -> str:
    """Basic TPU inference with Phase 1 error handling."""
    
    try:
        # Tokenize input
        inputs = tokenizer(
            prompt,
            return_tensors="pt",
            truncation=True,
            max_length=MODEL_CONFIG["max_seq_length"] - max_new_tokens
        )
        
        # Move inputs to TPU
        inputs = {k: v.to(device) for k, v in inputs.items()}
        
        # Sampling settings shared by every generation path below
        gen_kwargs = dict(
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
            pad_token_id=tokenizer.eos_token_id,
        )
        
        # Generate with error handling
        with torch.no_grad():
            # Try to use TPU error handler if available
            try:
                error_handler = getattr(tpu_kernels, 'tpu_error_handler', None)
                
                if error_handler and hasattr(error_handler, 'safe_execute'):
                    # Wrapped generation with error handling
                    outputs = error_handler.safe_execute(
                        lambda: model.generate(**inputs, **gen_kwargs),
                        "TPU inference generation",
                        max_retries=3
                    )
                    print(f"   ✅ Used TPU error handler")
                else:
                    # Direct generation fallback
                    outputs = model.generate(**inputs, **gen_kwargs)
                    print(f"   ⚠️ Using direct generation (error handler not available)")
                    
            except Exception as handler_error:
                print(f"   ⚠️ Error handler failed: {handler_error}")
                # Fallback to direct generation
                outputs = model.generate(**inputs, **gen_kwargs)
        
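        # Decode only the newly generated tokens. Slicing at the prompt length is more
        # reliable than stripping the prompt string, which may not round-trip through decode.
        prompt_length = inputs["input_ids"].shape[1]
        new_text = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True).strip()
        
        return new_text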
        
    except Exception as e:
        print(f"❌ Basic inference error: {e}")
        return f"Error: {str(e)}"

# Test basic inference
test_prompts = [
    "What are the benefits of using TPUs for machine learning?",
    "Explain quantum computing in simple terms:",
    "Write a short poem about artificial intelligence:"
]

print(f"🔄 Testing basic TPU inference...")
for i, prompt in enumerate(test_prompts, 1):
    print(f"\n📝 Test {i}: {prompt[:50]}...")
    
    start_time = time.time()
    response = basic_tpu_inference(prompt)
    inference_time = time.time() - start_time
    
    print(f"   ⏱️ Time: {inference_time:.2f}s")
    print(f"   📄 Response: {response[:100]}...")
    
    # Force XLA synchronization
    xm.mark_step()

print(f"\n✅ Phase 1 basic inference complete!")

## 4. Phase 2: Performance-Optimized Inference with XLA
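
XLA compiles one graph per distinct input shape, which is why the Phase 2 cell pads prompts to a fixed length. Padding every prompt to the full `max_seq_length - max_new_tokens` keeps a single shape but spends compute on padding for short prompts; a common middle ground is to round lengths up to a small set of buckets. The sketch below is plain Python with hypothetical bucket sizes. It left-pads, which is what decoder-only models need for batched `generate`, and builds the matching attention mask.

```python
from bisect import bisect_left
from typing import List, Sequence, Tuple

BUCKETS = (64, 128, 256, 512, 1024, 2048)  # hypothetical bucket sizes


def bucket_length(n: int, buckets: Sequence[int] = BUCKETS) -> int:
    """Smallest bucket >= n; raises if n exceeds the largest bucket."""
    i = bisect_left(buckets, n)
    if i == len(buckets):
        raise ValueError(f"length {n} exceeds largest bucket {buckets[-1]}")
    return buckets[i]


def left_pad_batch(batch: List[List[int]], pad_id: int,
                   buckets: Sequence[int] = BUCKETS) -> Tuple[List[List[int]], List[List[int]]]:
    """Left-pad every sequence to the bucket of the longest one.

    Returns (input_ids, attention_mask); the mask is 0 on padding and 1 on real tokens.
    """
    target = bucket_length(max(len(seq) for seq in batch), buckets)
    input_ids, attention_mask = [], []
    for seq in batch:
        pad = target - len(seq)
        input_ids.append([pad_id] * pad + list(seq))
        attention_mask.append([0] * pad + [1] * len(seq))
    return input_ids, attention_mask


ids, mask = left_pad_batch([[5, 6, 7], [8] * 70], pad_id=0)
print(len(ids[0]), len(ids[1]), sum(mask[0]))  # 128 128 3
```

With Hugging Face tokenizers, setting `tokenizer.padding_side = "left"` and passing `padding=True, pad_to_multiple_of=...` gives a similar effect without hand-written padding.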

In [None]:
# Phase 2: Performance-optimized inference with XLA compilation
print(f"⚡ Phase 2: Performance-Optimized TPU Inference")
print(f"=" * 50)

# Initialize Phase 2 performance optimizers
try:
    # XLA Attention Optimizer
    xla_attention = tpu_performance.XLAAttentionOptimizer()
    
    # Dynamic Shape Manager for variable-length inputs
    shape_manager = tpu_performance.DynamicShapeManager(
        max_length=MODEL_CONFIG["max_seq_length"]
    )
    
    # Performance profiler
    profiler = tpu_performance.TPUPerformanceProfiler()
    
    print(f"✅ Phase 2 optimizers initialized")
    
except Exception as e:  # ImportError, or errors raised by the optimizer constructors
    print(f"⚠️ Phase 2 optimizers not available ({e}), using basic inference")
    xla_attention = None
    shape_manager = None
    profiler = None

def optimized_tpu_inference(prompt: str, max_new_tokens: int = 50) -> Dict[str, Any]:
    """Performance-optimized TPU inference with Phase 2 enhancements."""
    
    results = {
        "response": "",
        "metrics": {},
        "optimizations_applied": []
    }
    
    try:
        # Start profiling
        if profiler:
            profiler.start_profiling()
            results["optimizations_applied"].append("performance_profiling")
        
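        # Tokenize with dynamic shape optimization.
        # Decoder-only models must be left-padded for generate(); right padding would
        # place pad tokens between the prompt and the generated continuation.
        tokenizer.padding_side = "left"
        inputs = tokenizer(
            prompt,
            return_tensors="pt",
            truncation=True,
            max_length=MODEL_CONFIG["max_seq_length"] - max_new_tokens,
            padding="max_length" if shape_manager else False  # Fixed shape avoids XLA recompilation
        )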
        
        # Apply dynamic shape optimization
        if shape_manager:
            inputs = shape_manager.optimize_input_shapes(inputs)
            results["optimizations_applied"].append("dynamic_shapes")
        
        # Move inputs to TPU
        inputs = {k: v.to(device) for k, v in inputs.items()}
        
        # Generate with XLA optimization
        generation_start = time.time()
        
        with torch.no_grad():
            # Use a separate name: assigning to `model` anywhere in this function makes it
            # local, so `model.generate` would raise UnboundLocalError when xla_attention is None
            gen_model = model
            if xla_attention:
                # Optimize attention computation (cheaper to do once, outside this function)
                gen_model = xla_attention.optimize_model(model)
                results["optimizations_applied"].append("xla_attention")
            
            # Generate with optimized settings
            outputs = gen_model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                do_sample=True,
                temperature=0.7,
                top_p=0.9,
                pad_token_id=tokenizer.eos_token_id,
                use_cache=True,  # Enable KV cache for efficiency
                num_beams=1      # Single beam for speed
            )
        
        generation_time = time.time() - generation_start
        
        # Force XLA synchronization and measure
        sync_start = time.time()
        xm.mark_step()
        sync_time = time.time() - sync_start
        
        # Decode only the newly generated tokens (token slicing is more reliable than
        # stripping the prompt string from the decoded text)
        prompt_length = inputs["input_ids"].shape[1]
        new_tokens = outputs[0][prompt_length:]
        results["response"] = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
        
        # Collect performance metrics; count generated ids directly instead of re-encoding
        total_time = generation_time + sync_time
        tokens_generated = int(new_tokens.shape[0])
        results["metrics"] = {
            "generation_time": generation_time,
            "sync_time": sync_time,
            "total_time": total_time,
            "tokens_generated": tokens_generated,
            "tokens_per_second": tokens_generated / total_time if total_time > 0 else 0.0
        }
        
        # Stop profiling and get detailed metrics (guard against a None result)
        if profiler:
            profile_data = profiler.stop_profiling()
            if profile_data:
                results["metrics"].update(profile_data)
        
        return results
        
    except Exception as e:
        results["response"] = f"Error: {str(e)}"
        results["metrics"]["error"] = str(e)
        return results

# Test optimized inference
print(f"🔄 Testing optimized TPU inference...")

optimization_prompts = [
    "Explain how TPU optimization works in machine learning:",
    "What makes XLA compilation effective for neural networks?"
]

for i, prompt in enumerate(optimization_prompts, 1):
    print(f"\n📝 Optimized Test {i}: {prompt[:50]}...")
    
    result = optimized_tpu_inference(prompt)
    
    print(f"   ⚡ Optimizations: {', '.join(result['optimizations_applied'])}")
    print(f"   ⏱️ Total Time: {result['metrics'].get('total_time', 0):.2f}s")
    print(f"   🚀 Tokens/sec: {result['metrics'].get('tokens_per_second', 0):.1f}")
    print(f"   📄 Response: {result['response'][:100]}...")

print(f"\n✅ Phase 2 optimized inference complete!")

## 5. Phase 3: Advanced Multi-Pod Inference with JAX Integration
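
In a multi-process TPU job each process (rank) should generate for its own slice of the prompts. The cell below tokenizes the full prompt list on every rank and only synchronizes with `xm.rendezvous` afterwards; in a single-process notebook `world_size` is typically 1, so this makes no difference here. As a hypothetical sketch, a strided shard computed from the `rank` and `world_size` values set in the setup cell splits the work evenly:

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def shard_for_rank(items: Sequence[T], rank: int, world_size: int) -> List[T]:
    """Strided shard: rank r gets items r, r + world_size, r + 2 * world_size, ..."""
    if not 0 <= rank < world_size:
        raise ValueError(f"rank {rank} out of range for world_size {world_size}")
    return list(items[rank::world_size])


prompts = [f"prompt {i}" for i in range(10)]
shards = [shard_for_rank(prompts, r, 4) for r in range(4)]
print([len(s) for s in shards])  # [3, 3, 2, 2]
```

Strided sharding keeps shard sizes within one item of each other. To reassemble results in the original order, keep each prompt's index alongside its response when gathering.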

In [None]:
# Phase 3: Advanced multi-pod inference with JAX integration
print(f"🚀 Phase 3: Advanced Multi-Pod TPU Inference")
print(f"=" * 50)

# Initialize Phase 3 advanced features
try:
    from pantheraml.kernels.tpu_advanced import (
        Phase3Manager,
        MultiPodConfig,
        JAXConfig,
        AutoScalingConfig
    )
    
    # Configure Phase 3 settings
    multi_pod_config = MultiPodConfig(
        num_pods=min(2, world_size),  # Use available pods
        cores_per_pod=8,
        enable_cross_pod_communication=True,
        pod_slice_shape=[2, 2, 1, 1]  # 2x2 slice
    )
    
    jax_config = JAXConfig(
        enable_jax_backend=True,
        precision="bfloat16",
        enable_jit=True,
        memory_fraction=0.9
    )
    
    auto_scaling_config = AutoScalingConfig(
        enable_auto_scaling=True,
        min_replicas=1,
        max_replicas=min(4, world_size),
        target_utilization=0.8
    )
    
    # Initialize Phase 3 manager
    phase3_manager = Phase3Manager(
        multi_pod_config=multi_pod_config,
        jax_config=jax_config,
        auto_scaling_config=auto_scaling_config
    )
    
    print(f"✅ Phase 3 manager initialized")
    print(f"   Multi-pod support: {multi_pod_config.num_pods} pods")
    print(f"   JAX backend: {jax_config.enable_jax_backend}")
    print(f"   Auto-scaling: {auto_scaling_config.enable_auto_scaling}")
    
    phase3_available = True
    
except Exception as e:  # ImportError, or a config argument the installed version rejects
    print(f"⚠️ Phase 3 features not available: {e}")
    phase3_manager = None
    phase3_available = False

def advanced_tpu_inference(prompts: List[str], max_new_tokens: int = 50) -> Dict[str, Any]:
    """Advanced multi-pod TPU inference with Phase 3 features."""
    
    results = {
        "responses": [],
        "batch_metrics": {},
        "advanced_features": []
    }
    
    try:
        if not phase3_available:
            # Fallback to batch inference without Phase 3
            print(f"   📦 Using batch inference fallback")
            
            for prompt in prompts:
                response = basic_tpu_inference(prompt, max_new_tokens)
                results["responses"].append(response)
            
            return results
        
        # Phase 3 advanced inference
        print(f"   🚀 Initializing Phase 3 inference pipeline...")
        
        # Initialize advanced pipeline
        if phase3_manager.initialize_advanced_pipeline():
            results["advanced_features"].append("phase3_pipeline")
        
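        # Batch tokenization with padding for efficient processing.
        # Left padding keeps each prompt directly before its generated tokens, which
        # batched generate() needs for decoder-only models.
        batch_start = time.time()
        
        tokenizer.padding_side = "left"
        inputs = tokenizer(
            prompts,
            return_tensors="pt",
            truncation=True,
            padding=True,  # Pad to the longest prompt in the batch
            max_length=MODEL_CONFIG["max_seq_length"] - max_new_tokens
        )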
        
        # Move batch to TPU
        inputs = {k: v.to(device) for k, v in inputs.items()}
        
        # Advanced batch generation with multi-pod coordination
        with torch.no_grad():
            # Apply auto-scaling if needed
            if phase3_manager.should_scale_up(len(prompts)):
                phase3_manager.scale_up()
                results["advanced_features"].append("auto_scaling")
            
            # Generate with advanced features
            generation_start = time.time()
            
            outputs = model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                do_sample=True,
                temperature=0.7,
                top_p=0.9,
                pad_token_id=tokenizer.eos_token_id,
                use_cache=True,
                num_beams=1
            )
            
            generation_time = time.time() - generation_start
        
        # Multi-pod synchronization
        sync_start = time.time()
        if world_size > 1:
            xm.rendezvous("batch_inference_sync")
            results["advanced_features"].append("multi_pod_sync")
        xm.mark_step()
        sync_time = time.time() - sync_start
        
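        # Decode batch outputs: with left padding, every row's new tokens start at the same index
        prompt_length = inputs["input_ids"].shape[1]
        for output in outputs:
            new_tokens = output[prompt_length:]
            results["responses"].append(tokenizer.decode(new_tokens, skip_special_tokens=True).strip())
        
        batch_time = time.time() - batch_start
        
        # Advanced metrics: count generated ids, excluding the padding that generate()
        # appends (pad_token_id=eos_token_id) to rows that finish early
        generated = outputs[:, prompt_length:]
        total_tokens = int((generated != tokenizer.eos_token_id).sum().item())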
        
        results["batch_metrics"] = {
            "batch_size": len(prompts),
            "total_time": batch_time,
            "generation_time": generation_time,
            "sync_time": sync_time,
            "total_tokens": total_tokens,
            "tokens_per_second": total_tokens / batch_time,
            "throughput_per_prompt": total_tokens / len(prompts) / batch_time,
            "world_size": world_size,
            "rank": rank
        }
        
        # Get resource utilization
        if phase3_manager:
            utilization = phase3_manager.get_resource_utilization()
            results["batch_metrics"].update(utilization)
        
        return results
        
    except Exception as e:
        results["responses"] = [f"Error: {str(e)}" for _ in prompts]
        results["batch_metrics"]["error"] = str(e)
        return results

# Test advanced batch inference
print(f"🔄 Testing advanced batch TPU inference...")

advanced_prompts = [
    "Describe the advantages of distributed computing:",
    "How does multi-pod TPU training improve efficiency?",
    "Explain JAX and its role in high-performance computing:",
    "What are the key benefits of auto-scaling in ML inference?"
]

result = advanced_tpu_inference(advanced_prompts)

print(f"\n📊 Advanced Inference Results:")
print(f"   🚀 Features Used: {', '.join(result['advanced_features'])}")
print(f"   📦 Batch Size: {result['batch_metrics'].get('batch_size', 0)}")
print(f"   ⏱️ Total Time: {result['batch_metrics'].get('total_time', 0):.2f}s")
print(f"   🎯 Tokens/sec: {result['batch_metrics'].get('tokens_per_second', 0):.1f}")
print(f"   📈 Throughput/prompt: {result['batch_metrics'].get('throughput_per_prompt', 0):.1f} tok/s")
print(f"   🌐 World Size: {result['batch_metrics'].get('world_size', 1)}")

print(f"\n📝 Sample Responses:")
for i, (prompt, response) in enumerate(zip(advanced_prompts[:2], result["responses"][:2]), 1):
    print(f"   {i}. {prompt[:40]}...")
    print(f"      {response[:80]}...")

print(f"\n✅ Phase 3 advanced inference complete!")

## 6. Performance Comparison and Benchmarking
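
The first call of each phase includes XLA graph compilation, so averaging it with later runs mixes compile time into steady-state latency. A small harness that discards warm-up runs and computes tokens/s from the number of tokens actually produced makes the comparison fairer. This is a dependency-free sketch; `fake_generate` is a hypothetical placeholder workload standing in for any of the three inference functions.

```python
import statistics
import time
from typing import Callable, Dict


def benchmark(fn: Callable[[], int], runs: int = 3, warmup: int = 1) -> Dict[str, float]:
    """Time `fn`, which must return the number of tokens it generated.

    Warm-up calls (e.g. the one that triggers XLA compilation) are excluded from the stats.
    """
    for _ in range(warmup):
        fn()
    times, tokens = [], []
    for _ in range(runs):
        start = time.perf_counter()
        tokens.append(fn())
        times.append(time.perf_counter() - start)
    return {
        "avg_time": statistics.mean(times),
        "min_time": min(times),
        "max_time": max(times),
        "tokens_per_second": sum(tokens) / sum(times),
    }


def fake_generate() -> int:
    """Hypothetical placeholder workload: pretend to generate 75 tokens in ~10 ms."""
    time.sleep(0.01)
    return 75


stats = benchmark(fake_generate, runs=3, warmup=1)
print(f"{stats['avg_time']:.3f}s avg, {stats['tokens_per_second']:.0f} tok/s")
```

When `fn` runs on the TPU, calling `xm.mark_step()` followed by `xm.wait_device_ops()` before it returns helps ensure queued device work is included in the measured time.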

In [None]:
# Performance comparison across all three phases
print(f"📊 Performance Comparison Across TPU Phases")
print(f"=" * 55)

benchmark_prompt = "Explain the importance of optimization in machine learning inference:"
benchmark_tokens = 75

performance_results = {}

def run_benchmark(label, fn, runs=3, warmup=1):
    """Time `fn` over `runs` calls after `warmup` untimed calls.

    The first call of each phase triggers XLA compilation, so it is excluded from the
    averages. Tokens/s assumes all `benchmark_tokens` were generated, which is an upper
    bound because generation may stop early at EOS.
    """
    print(label)
    for _ in range(warmup):
        fn()
        xm.mark_step()
    times = []
    for _ in range(runs):
        start = time.time()
        fn()
        times.append(time.time() - start)
        xm.mark_step()
    return {
        "avg_time": np.mean(times),
        "min_time": np.min(times),
        "max_time": np.max(times),
        "tokens_per_second": benchmark_tokens / np.mean(times)
    }

# Phase 1 Benchmark
performance_results["Phase 1"] = run_benchmark(
    "\n🧪 Benchmarking Phase 1 (Basic)...",
    lambda: basic_tpu_inference(benchmark_prompt, benchmark_tokens)
)

# Phase 2 Benchmark
performance_results["Phase 2"] = run_benchmark(
    "⚡ Benchmarking Phase 2 (Optimized)...",
    lambda: optimized_tpu_inference(benchmark_prompt, benchmark_tokens)
)

# Phase 3 Benchmark (single prompt for a like-for-like comparison)
performance_results["Phase 3"] = run_benchmark(
    "🚀 Benchmarking Phase 3 (Advanced)...",
    lambda: advanced_tpu_inference([benchmark_prompt], benchmark_tokens)
)

# Display comparison
print(f"\n📈 Performance Comparison Results:")
print(f"{'Phase':<10} {'Avg Time':<12} {'Min Time':<12} {'Max Time':<12} {'Tokens/s':<12} {'Speedup':<10}")
print(f"{'-'*75}")

baseline_time = performance_results["Phase 1"]["avg_time"]

for phase, metrics in performance_results.items():
    speedup = baseline_time / metrics["avg_time"]
    speedup_str = f"{speedup:.2f}x"  # format first so the "x" stays next to the number
    print(f"{phase:<10} {metrics['avg_time']:<12.2f} {metrics['min_time']:<12.2f} {metrics['max_time']:<12.2f} {metrics['tokens_per_second']:<12.1f} {speedup_str:<10}")

# Memory usage comparison
print(f"\n💾 Memory Usage Analysis:")
if tpu_memory_manager:
    final_memory_stats = tpu_memory_manager.get_memory_stats()
    print(f"   Current Memory Usage: {final_memory_stats.get('allocated_gb', 0):.2f} GB")
    print(f"   Peak Memory Usage: {final_memory_stats.get('peak_gb', 0):.2f} GB")
    print(f"   Memory Efficiency: {final_memory_stats.get('efficiency_percent', 0):.1f}%")

print(f"\n🎯 Summary:")
best_phase = min(performance_results.keys(), key=lambda x: performance_results[x]["avg_time"])
best_speedup = baseline_time / performance_results[best_phase]["avg_time"]
print(f"   🏆 Best Performance: {best_phase} ({best_speedup:.2f}x speedup)")
print(f"   ⚡ Max Throughput: {max(perf['tokens_per_second'] for perf in performance_results.values()):.1f} tokens/second")
print(f"   🎮 Recommended: Phase 2 for balanced performance, Phase 3 for batch workloads")

## 7. Memory Management and Cleanup
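
Besides the weights, generation with `use_cache=True` holds a key/value cache that grows linearly with batch size and sequence length: 2 (K and V) × layers × KV heads × head dim × tokens × batch × bytes per element. The architecture numbers below are placeholders, not values taken from this notebook's model; for most Hugging Face decoder configs the real ones are `model.config.num_hidden_layers`, `model.config.num_key_value_heads`, and `hidden_size // num_attention_heads`.

```python
def kv_cache_bytes(num_layers: int, num_kv_heads: int, head_dim: int,
                   seq_len: int, batch_size: int, bytes_per_element: int = 2) -> int:
    """Size of a dense KV cache: one K and one V tensor per layer."""
    return 2 * num_layers * num_kv_heads * head_dim * seq_len * batch_size * bytes_per_element


# Placeholder architecture values (hypothetical); read the real ones from model.config.
size = kv_cache_bytes(num_layers=28, num_kv_heads=2, head_dim=128,
                      seq_len=2048, batch_size=4, bytes_per_element=2)
print(f"{size / 2**20:.0f} MiB")  # 224 MiB
```

Doubling either the batch size or the sequence length doubles this figure, which is why reducing batch size is the first suggestion for memory errors in the troubleshooting guide.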

In [None]:
# Memory management and cleanup
print(f"🧹 TPU Memory Management and Cleanup")
print(f"=" * 40)

# Get final memory statistics
if tpu_memory_manager:
    print(f"📊 Final Memory Statistics:")
    memory_stats = tpu_memory_manager.get_memory_stats()
    
    for key, value in memory_stats.items():
        if isinstance(value, (int, float)):
            print(f"   {key.replace('_', ' ').title()}: {value:.2f}")
        else:
            print(f"   {key.replace('_', ' ').title()}: {value}")

# Cleanup operations
print(f"\n🔄 Performing cleanup operations...")

# Clear model from memory
del model
print(f"   ✅ Model cleared from memory")

# Clear TPU cache
if tpu_memory_manager:
    tpu_memory_manager.clear_cache()
    print(f"   ✅ TPU cache cleared")

# Force garbage collection
import gc
gc.collect()
print(f"   ✅ Garbage collection completed")

# Final XLA synchronization
xm.mark_step()
print(f"   ✅ XLA synchronization completed")

# Cleanup Phase 3 resources if available
if phase3_manager:
    try:
        phase3_manager.cleanup_advanced_pipeline()
        print(f"   ✅ Phase 3 resources cleaned up")
    except Exception as e:
        print(f"   ⚠️ Phase 3 cleanup not available: {e}")

print(f"\n✅ Cleanup completed successfully!")
print(f"\n🎉 TPU Inference Example Complete!")
print(f"   📚 You've successfully demonstrated all three phases of PantheraML TPU inference")
print(f"   🚀 Ready to scale up to production workloads")
print(f"   💡 Consider using Phase 2 for most applications, Phase 3 for large-scale batch inference")

## 8. Summary and Best Practices

### 🎯 Key Takeaways:

1. **Phase 1 (Basic)**: Provides robust error handling and basic TPU support
2. **Phase 2 (Optimized)**: Adds XLA compilation and performance optimizations
3. **Phase 3 (Advanced)**: Enables multi-pod inference with JAX integration

### ⚡ Performance Tips:

- Use `bfloat16` precision for optimal TPU performance
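- Pad inputs to a fixed length (or a small set of bucket lengths) so XLA reuses compiled graphs instead of recompiling for every new input shape
- Batch multiple prompts together when possible, left-padded for decoder-only generation
- Enable KV cache for autoregressive generation
- Call `xm.mark_step()` to cut the lazy XLA graph and dispatch it for execution; it does not necessarily wait for the device to finish, so add `xm.wait_device_ops()` when you need accurate timings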

### 🏗️ Production Recommendations:

- **Phase 2** for most production inference workloads
- **Phase 3** for large-scale batch processing
- Monitor memory usage with PantheraML's memory manager
- Implement proper error handling and fallbacks
- Use auto-scaling for variable workloads
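
`AutoScalingConfig` in the Phase 3 cell exposes `min_replicas`, `max_replicas`, and `target_utilization`, but how PantheraML turns those into a scaling decision isn't shown. One widely used policy is the proportional rule of the Kubernetes Horizontal Pod Autoscaler: desired = ceil(current × observed / target), skipped when the ratio is within a tolerance, then clamped to the allowed range. The function below is a hypothetical illustration of that rule, not PantheraML's implementation.

```python
import math


def desired_replicas(current: int, utilization: float, target: float = 0.8,
                     min_replicas: int = 1, max_replicas: int = 4,
                     tolerance: float = 0.1) -> int:
    """HPA-style proportional scaling, clamped to [min_replicas, max_replicas]."""
    ratio = utilization / target
    if abs(ratio - 1.0) <= tolerance:  # close enough to target: avoid churn
        return current
    proposed = math.ceil(current * ratio)
    return max(min_replicas, min(max_replicas, proposed))


print(desired_replicas(current=2, utilization=0.95))  # ratio ~1.19 -> ceil(2.375) = 3
print(desired_replicas(current=2, utilization=0.30))  # ratio 0.375 -> ceil(0.75) = 1
print(desired_replicas(current=4, utilization=1.00))  # ceil(5.0) = 5, clamped to 4
```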

### 📚 Additional Resources:

- [PantheraML Documentation](https://github.com/PantheraML/docs)
- [TPU Performance Guide](https://cloud.google.com/tpu/docs/performance-guide)
- [XLA Overview](https://www.tensorflow.org/xla)

---

**Happy TPU inferencing with PantheraML! 🚀**