# File Location: notebooks/07_evaluation_export_predict/17_onnx_torchscript_export.ipynb

# ONNX and TorchScript Export for Production Deployment

This notebook covers model export strategies using ONNX and TorchScript for production deployment, including optimization techniques, cross-platform compatibility, and performance benchmarking.

## Learning Objectives
- Export PyTorch Lightning models to ONNX and TorchScript formats
- Optimize models for different deployment scenarios
- Implement cross-platform compatibility testing
- Benchmark performance across different export formats
- Handle dynamic batch dimensions and batched inference

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
import torchvision
from torchvision import transforms
import numpy as np
import matplotlib.pyplot as plt
import time
import json
import os
from typing import Dict, List, Tuple, Any, Optional
import warnings
warnings.filterwarnings('ignore')

# ONNX and optimization imports
try:
    import onnx
    import onnxruntime as ort
    ONNX_AVAILABLE = True
    print(f"ONNX version: {onnx.__version__}")
    print(f"ONNX Runtime version: {ort.__version__}")
except ImportError:
    ONNX_AVAILABLE = False
    print("ONNX not available. Install with: pip install onnx onnxruntime")

# Seed the two random number generators this notebook draws from
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

for framework_name, framework_version in (("PyTorch", torch.__version__), ("Lightning", pl.__version__)):
    print(f"{framework_name} version: {framework_version}")
```

## 1. Export-Ready Lightning Module

```python
class ExportableModel(pl.LightningModule):
    """Small MNIST CNN whose ``forward`` uses only export-friendly layers.

    Args:
        num_classes: Number of logits produced by the final linear layer.
        input_shape: Per-sample input shape; stored for :meth:`export_summary`
            only (the first conv layer is fixed to 1 input channel).
        learning_rate: Adam learning rate used by :meth:`configure_optimizers`.
    """

    def __init__(self, num_classes=10, input_shape=(1, 28, 28), learning_rate=0.001):
        super().__init__()
        self.save_hyperparameters()
        self.input_shape = input_shape

        # Model architecture - designed for export compatibility
        self.features = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.AdaptiveAvgPool2d((7, 7))
        )

        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(128 * 7 * 7, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(512, num_classes)
        )

        # Accuracy is computed with plain tensor ops in _accuracy(). The previous
        # ``pl.metrics.Accuracy(task=...)`` is not constructible: ``pl.metrics`` was
        # removed from Lightning (it moved to the separate torchmetrics package),
        # and the old API never accepted ``task``. This also keeps the module free
        # of extra metric submodules that script/ONNX export would have to handle.

    def forward(self, x):
        """Return class logits for ``x``.

        Accepts a batch ``(N, C, H, W)`` or a single sample ``(C, H, W)``, which
        gets a leading batch dimension. Non-tensor input is converted to a
        float32 tensor first.
        """
        if not isinstance(x, torch.Tensor):
            x = torch.tensor(x, dtype=torch.float32)

        if len(x.shape) == 3:  # Add batch dimension
            x = x.unsqueeze(0)

        features = self.features(x)
        output = self.classifier(features)
        return output

    def predict_with_softmax(self, x):
        """Return class probabilities: ``softmax(forward(x))`` over dim 1."""
        logits = self.forward(x)
        return F.softmax(logits, dim=1)

    @staticmethod
    def _accuracy(logits, targets):
        """Return the fraction of rows whose argmax over dim 1 equals ``targets``."""
        return (logits.argmax(dim=1) == targets).float().mean()

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)

        self.log('train_loss', loss, on_epoch=True, prog_bar=True)
        self.log('train_acc', self._accuracy(logits, y), on_epoch=True, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)

        self.log('val_loss', loss, on_epoch=True, prog_bar=True)
        self.log('val_acc', self._accuracy(logits, y), on_epoch=True, prog_bar=True)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
        return optimizer

    def export_summary(self):
        """Print input shape, parameter counts and an estimated size for export validation."""
        total_params = sum(p.numel() for p in self.parameters())
        trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)

        print(f"Model Export Summary:")
        print(f"  Input shape: {self.input_shape}")
        print(f"  Total parameters: {total_params:,}")
        print(f"  Trainable parameters: {trainable_params:,}")
        # Estimate assumes 4 bytes (float32) per parameter; buffers are not counted.
        print(f"  Model size (MB): {total_params * 4 / 1024 / 1024:.2f}")

# Build the network that every later section exports
MODEL_INPUT_SHAPE = (1, 28, 28)
model = ExportableModel(input_shape=MODEL_INPUT_SHAPE, num_classes=10)
model.export_summary()
```

## 2. Data Module and Training Before Export

```python
class ExportDataModule(pl.LightningDataModule):
    """MNIST data module used to train and evaluate the model before export.

    Args:
        batch_size: Samples per batch for every dataloader.
        num_workers: Worker processes passed to each ``DataLoader``.
        data_dir: Directory MNIST is downloaded to / read from.
    """

    def __init__(self, batch_size=64, num_workers=4, data_dir='./data'):
        super().__init__()
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.data_dir = data_dir

        # Standard MNIST mean/std normalization applied after ToTensor()'s [0, 1] scaling
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))
        ])

    def _mnist(self, train):
        """Return the MNIST train (``train=True``) or test split, downloading it if missing."""
        return torchvision.datasets.MNIST(self.data_dir, train=train, transform=self.transform, download=True)

    def setup(self, stage=None):
        # Lightning passes 'fit', 'validate', 'test' or 'predict'; None builds everything.
        # 'validate' was previously unhandled, so trainer.validate() hit a missing val_dataset.
        if stage in ('fit', None):
            self.train_dataset = self._mnist(train=True)
        if stage in ('fit', 'validate', None):
            # MNIST has no separate validation split, so its test split is used.
            self.val_dataset = self._mnist(train=False)
        if stage in ('test', None):
            self.test_dataset = self._mnist(train=False)

    def _loader(self, dataset, shuffle):
        """Build a DataLoader for ``dataset`` with this module's batch size and workers."""
        return torch.utils.data.DataLoader(dataset, batch_size=self.batch_size, shuffle=shuffle, num_workers=self.num_workers)

    def train_dataloader(self):
        return self._loader(self.train_dataset, shuffle=True)

    def val_dataloader(self):
        return self._loader(self.val_dataset, shuffle=False)

    def test_dataloader(self):
        return self._loader(self.test_dataset, shuffle=False)

# Fit the model briefly so the exported artifacts carry trained weights
data_module = ExportDataModule(batch_size=128)

trainer_kwargs = dict(
    max_epochs=3,
    accelerator='auto',
    devices=1,
    enable_checkpointing=False,
    logger=False,
)
trainer = pl.Trainer(**trainer_kwargs)

print("Training model for export...")
trainer.fit(model, data_module)
final_val_acc = trainer.callback_metrics.get('val_acc', 'N/A')
print(f"Training completed. Final validation accuracy: {final_val_acc}")
```

## 3. TorchScript Export Implementation

```python
class TorchScriptExporter:
    """Export a model with ``torch.jit.script`` / ``torch.jit.trace`` and benchmark the results.

    Args:
        model: Any ``nn.Module`` (a ``LightningModule`` works). It is switched to
            eval mode here so exports capture inference behavior.
    """

    def __init__(self, model: nn.Module):
        self.model = model
        self.model.eval()  # Set to evaluation mode

    @staticmethod
    def _benchmark(fn, example_input: torch.Tensor, num_runs: int = 100, warmup_runs: int = 10):
        """Time ``num_runs`` calls of ``fn(example_input)`` after ``warmup_runs`` untimed calls.

        Warmup matters for a fair comparison: TorchScript typically spends its
        first few invocations profiling and optimizing the graph.

        Returns:
            ``(elapsed_seconds, last_output)``; ``last_output`` is ``None`` when
            ``num_runs`` is 0.
        """
        output = None
        with torch.no_grad():
            for _ in range(warmup_runs):
                fn(example_input)
            if example_input.is_cuda:
                torch.cuda.synchronize()
            start = time.perf_counter()
            for _ in range(num_runs):
                output = fn(example_input)
            if example_input.is_cuda:
                torch.cuda.synchronize()  # include queued GPU work in the measurement
            elapsed = time.perf_counter() - start
        return elapsed, output

    def export_scripted(self, example_input: torch.Tensor, save_path: str = "model_scripted.pt") -> Optional[torch.jit.ScriptModule]:
        """Compile the model with ``torch.jit.script``, check it against eager output and save it.

        Returns the scripted module, or ``None`` if scripting/saving raised.
        """
        try:
            print("Exporting with torch.jit.script...")
            scripted_model = torch.jit.script(self.model)

            with torch.no_grad():
                scripted_output = scripted_model(example_input)
                original_output = self.model(example_input)

                if torch.allclose(scripted_output, original_output, atol=1e-6):
                    print("✓ Script export successful - outputs match")
                else:
                    print("⚠ Warning: Script export outputs differ slightly")

            torch.jit.save(scripted_model, save_path)
            print(f"Model saved to {save_path}")

            return scripted_model

        except Exception as e:
            print(f"Script export failed: {str(e)}")
            return None

    def export_traced(self, example_input: torch.Tensor, save_path: str = "model_traced.pt") -> Optional[torch.jit.ScriptModule]:
        """Trace the model on ``example_input``, check it against eager output and save it.

        Returns the traced module, or ``None`` if tracing/saving raised.
        """
        try:
            print("Exporting with torch.jit.trace...")

            with torch.no_grad():
                traced_model = torch.jit.trace(self.model, example_input)

                traced_output = traced_model(example_input)
                original_output = self.model(example_input)

                if torch.allclose(traced_output, original_output, atol=1e-6):
                    print("✓ Trace export successful - outputs match")
                else:
                    print("⚠ Warning: Trace export outputs differ slightly")

            torch.jit.save(traced_model, save_path)
            print(f"Model saved to {save_path}")

            return traced_model

        except Exception as e:
            print(f"Trace export failed: {str(e)}")
            return None

    def optimize_torchscript(self, scripted_model: torch.jit.ScriptModule, example_input: torch.Tensor) -> torch.jit.ScriptModule:
        """Freeze and apply ``optimize_for_inference``; return the input module unchanged on failure."""
        print("Applying TorchScript optimizations...")

        try:
            # Freezing inlines parameters/attributes as constants (module must be in eval mode).
            optimized_model = torch.jit.freeze(scripted_model)
            optimized_model = torch.jit.optimize_for_inference(optimized_model)

            # Warmup
            with torch.no_grad():
                for _ in range(10):
                    _ = optimized_model(example_input)

            print("✓ TorchScript optimization completed")
            return optimized_model

        except Exception as e:
            print(f"TorchScript optimization failed: {str(e)}")
            return scripted_model

    def compare_torchscript_methods(self, example_input: torch.Tensor, num_runs: int = 100) -> Dict[str, Any]:
        """Export via script and trace, then time eager vs. exported models.

        Every variant gets the same warmup before ``num_runs`` timed calls.
        Exported files are written to ``model_scripted.pt`` / ``model_traced.pt``.

        Returns:
            Mapping of method name to ``{'time', 'output_shape'}`` plus
            ``'speedup'`` (relative to eager) for exported variants.

        Raises:
            ValueError: if ``num_runs`` is less than 1.
        """
        if num_runs < 1:
            raise ValueError(f"num_runs must be >= 1, got {num_runs}")

        results = {}
        print("\n=== TorchScript Export Comparison ===")

        original_time, original_output = self._benchmark(self.model, example_input, num_runs)
        results['original'] = {
            'time': original_time,
            'output_shape': original_output.shape
        }

        def record(name, exported_model):
            elapsed, output = self._benchmark(exported_model, example_input, num_runs)
            results[name] = {
                'time': elapsed,
                'speedup': original_time / elapsed if elapsed > 0 else float('inf'),
                'output_shape': output.shape
            }

        scripted_model = self.export_scripted(example_input, "model_scripted.pt")
        if scripted_model is not None:
            record('scripted', scripted_model)

        traced_model = self.export_traced(example_input, "model_traced.pt")
        if traced_model is not None:
            record('traced', traced_model)
            record('optimized', self.optimize_torchscript(traced_model, example_input))

        print("\nTorchScript Performance Comparison:")
        for method, metrics in results.items():
            if method == 'original':
                print(f"{method:12}: {metrics['time']:.4f}s (baseline)")
            else:
                print(f"{method:12}: {metrics['time']:.4f}s ({metrics['speedup']:.2f}x speedup)")

        return results

# Run the TorchScript comparison on one MNIST-shaped random sample
example_input = torch.randn(1, 1, 28, 28)
torchscript_exporter = TorchScriptExporter(model)
torchscript_results = torchscript_exporter.compare_torchscript_methods(example_input=example_input)
```

## 4. ONNX Export Implementation

```python
class ONNXExporter:
    """Export a model to ONNX, verify it with ONNX Runtime and benchmark it.

    Relies on the module-level ``ONNX_AVAILABLE`` flag and the ``onnx`` /
    ``onnxruntime`` (``ort``) imports from the setup cell.

    Args:
        model: Any ``nn.Module`` (a ``LightningModule`` works); switched to eval mode.
    """

    def __init__(self, model: nn.Module):
        self.model = model
        self.model.eval()

    @staticmethod
    def _providers() -> List[str]:
        """Return ORT providers to request: CUDA first when both PyTorch and this ORT build support it, then CPU."""
        providers = ['CPUExecutionProvider']
        if torch.cuda.is_available() and 'CUDAExecutionProvider' in ort.get_available_providers():
            providers.insert(0, 'CUDAExecutionProvider')
        return providers

    def export_onnx(self, example_input: torch.Tensor, save_path: str = "model.onnx",
                   dynamic_axes: Optional[Dict] = None, opset_version: int = 11) -> bool:
        """Export ``self.model`` to ONNX and verify the file with ONNX Runtime.

        Args:
            example_input: Tensor used to trace the model during export.
            save_path: Destination ``.onnx`` file.
            dynamic_axes: ``torch.onnx.export`` dynamic-axes mapping. ``None`` makes
                dim 0 of ``input``/``output`` dynamic; pass ``{}`` for a static export.
            opset_version: Target ONNX opset.

        Returns:
            True if export and verification both succeeded.
        """
        if not ONNX_AVAILABLE:
            print("ONNX not available. Please install onnx and onnxruntime")
            return False

        try:
            print(f"Exporting to ONNX (opset version {opset_version})...")

            if dynamic_axes is None:
                dynamic_axes = {
                    'input': {0: 'batch_size'},
                    'output': {0: 'batch_size'}
                }

            torch.onnx.export(
                self.model,
                example_input,
                save_path,
                export_params=True,
                opset_version=opset_version,
                do_constant_folding=True,  # Pre-compute constant subgraphs at export time
                input_names=['input'],
                output_names=['output'],
                dynamic_axes=dynamic_axes,
                verbose=False
            )

            if self._verify_onnx_model(save_path, example_input):
                print(f"✓ ONNX export successful - model saved to {save_path}")
                return True
            else:
                print("✗ ONNX export verification failed")
                return False

        except Exception as e:
            print(f"ONNX export failed: {str(e)}")
            return False

    def _verify_onnx_model(self, model_path: str, example_input: torch.Tensor) -> bool:
        """Run the ONNX checker and compare ORT output with PyTorch output (atol=1e-5)."""
        try:
            onnx_model = onnx.load(model_path)
            onnx.checker.check_model(onnx_model)

            # Providers must be explicit: ORT builds with several providers
            # (e.g. onnxruntime-gpu) raise when none are given. CPU is always present.
            ort_session = ort.InferenceSession(model_path, providers=['CPUExecutionProvider'])

            input_array = example_input.detach().cpu().numpy()
            ort_outputs = ort_session.run(None, {ort_session.get_inputs()[0].name: input_array})

            with torch.no_grad():
                torch_output = self.model(example_input)

            if np.allclose(torch_output.detach().cpu().numpy(), ort_outputs[0], atol=1e-5):
                return True
            print("Warning: ONNX outputs differ from PyTorch outputs")
            return False

        except Exception as e:
            print(f"ONNX verification failed: {str(e)}")
            return False

    def optimize_onnx(self, model_path: str, optimized_path: str = "model_optimized.onnx") -> bool:
        """Apply ONNX Runtime graph optimizations and save the optimized graph to ``optimized_path``.

        Returns True only if the optimized model file was written.
        """
        if not ONNX_AVAILABLE:
            return False

        try:
            print("Applying ONNX optimizations...")

            # Levels: ORT_DISABLE_ALL, ORT_ENABLE_BASIC, ORT_ENABLE_EXTENDED, ORT_ENABLE_ALL
            sess_options = ort.SessionOptions()
            sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
            # Serialize the optimized graph; previously optimized_path was never used.
            # Beyond the basic level the saved graph may contain ORT-specific fused
            # nodes, so it is intended to be loaded by ONNX Runtime itself.
            sess_options.optimized_model_filepath = optimized_path

            session = ort.InferenceSession(model_path, sess_options, providers=self._providers())

            if not os.path.exists(optimized_path):
                print(f"ONNX optimization failed: {optimized_path} was not written")
                return False

            print(f"✓ ONNX model optimized for providers: {session.get_providers()}")
            print(f"  Optimized model saved to {optimized_path}")
            return True

        except Exception as e:
            print(f"ONNX optimization failed: {str(e)}")
            return False

    def benchmark_onnx_vs_pytorch(self, model_path: str, example_input: torch.Tensor, num_runs: int = 100) -> Dict[str, float]:
        """Time ``num_runs`` PyTorch and ONNX Runtime inferences (after 10 warmup runs each).

        Returns:
            ``{'pytorch': s, 'onnx': s, 'speedup': pytorch/onnx}``; only ``'pytorch'``
            if the ONNX part fails, ``{}`` if ONNX is unavailable.

        Raises:
            ValueError: if ``num_runs`` is less than 1.
        """
        if not ONNX_AVAILABLE:
            return {}
        if num_runs < 1:
            raise ValueError(f"num_runs must be >= 1, got {num_runs}")

        results = {}

        print("Benchmarking PyTorch model...")
        with torch.no_grad():
            for _ in range(10):
                _ = self.model(example_input)

            start_time = time.perf_counter()
            for _ in range(num_runs):
                pytorch_output = self.model(example_input)
            pytorch_time = time.perf_counter() - start_time

        results['pytorch'] = pytorch_time

        try:
            print("Benchmarking ONNX model...")

            ort_session = ort.InferenceSession(model_path, providers=self._providers())
            ort_input_name = ort_session.get_inputs()[0].name
            ort_inputs = {ort_input_name: example_input.detach().cpu().numpy()}

            for _ in range(10):
                _ = ort_session.run(None, ort_inputs)

            start_time = time.perf_counter()
            for _ in range(num_runs):
                onnx_outputs = ort_session.run(None, ort_inputs)
            onnx_time = time.perf_counter() - start_time

            results['onnx'] = onnx_time
            results['speedup'] = pytorch_time / onnx_time

            if np.allclose(pytorch_output.detach().cpu().numpy(), onnx_outputs[0], atol=1e-5):
                print("✓ ONNX outputs match PyTorch outputs")
            else:
                print("⚠ Warning: ONNX outputs differ from PyTorch outputs")

        except Exception as e:
            print(f"ONNX benchmarking failed: {str(e)}")
            return results

        print(f"\nBenchmark Results ({num_runs} runs):")
        print(f"PyTorch: {pytorch_time:.4f}s")
        print(f"ONNX:    {onnx_time:.4f}s")
        print(f"Speedup: {results.get('speedup', 0):.2f}x")

        return results

# Test ONNX export if available
if ONNX_AVAILABLE:
    onnx_exporter = ONNXExporter(model)

    example_input = torch.randn(1, 1, 28, 28)

    # Basic export: fully static shapes (batch size fixed to 1). An empty dict
    # overrides export_onnx's default dynamic batch axis; without it this file
    # was identical to the "dynamic" export below.
    success = onnx_exporter.export_onnx(example_input, "model_basic.onnx", dynamic_axes={})

    if success:
        # Benchmark performance
        benchmark_results = onnx_exporter.benchmark_onnx_vs_pytorch("model_basic.onnx", example_input)

        # Export with a dynamic batch dimension ("model_dynamic.onnx" is what the
        # batch-size tests load)
        dynamic_axes = {
            'input': {0: 'batch_size'},
            'output': {0: 'batch_size'}
        }
        onnx_exporter.export_onnx(example_input, "model_dynamic.onnx", dynamic_axes=dynamic_axes)
else:
    print("Skipping ONNX export - ONNX not available")
```

## 5. Cross-Platform Deployment Testing

```python
class CrossPlatformTester:
    """Run each exported artifact and summarize success, latency and size.

    Reads the files written by the export cells from the current working
    directory. ONNX checks use the module-level ``ONNX_AVAILABLE`` flag and
    the ``ort`` (onnxruntime) import. Latencies come from single calls and
    are a smoke test, not a benchmark.
    """

    def __init__(self):
        # Most recent results from test_model_formats().
        self.test_results = {}

    def test_model_formats(self, original_model, test_input):
        """Run the in-memory model and every exported artifact once on ``test_input``.

        Returns:
            Dict mapping format name to its result dict; also stored on ``self.test_results``.
        """
        print("=== Cross-Platform Model Testing ===")

        results = {
            'pytorch': self._test_pytorch_model(original_model, test_input),
            'torchscript_traced': self._test_torchscript_model("model_traced.pt", test_input),
            'torchscript_scripted': self._test_torchscript_model("model_scripted.pt", test_input),
        }

        if ONNX_AVAILABLE:
            results['onnx'] = self._test_onnx_model("model_basic.onnx", test_input)

        self.test_results = results
        return results

    def _test_pytorch_model(self, model, test_input):
        """Run the eager model once; report time, output shape and float32 parameter size."""
        try:
            model.eval()
            with torch.no_grad():
                start_time = time.perf_counter()
                output = model(test_input)
                inference_time = time.perf_counter() - start_time

            return {
                'success': True,
                'inference_time': inference_time,
                'output_shape': output.shape,
                'model_size_mb': sum(p.numel() * 4 for p in model.parameters()) / 1024 / 1024
            }
        except Exception as e:
            return {'success': False, 'error': str(e)}

    def _test_torchscript_model(self, model_path, test_input):
        """Load a TorchScript file, run it once; report time, output shape and file size."""
        try:
            if not os.path.exists(model_path):
                return {'success': False, 'error': 'Model file not found'}

            model = torch.jit.load(model_path)
            model.eval()

            with torch.no_grad():
                start_time = time.perf_counter()
                output = model(test_input)
                inference_time = time.perf_counter() - start_time

            return {
                'success': True,
                'inference_time': inference_time,
                'output_shape': output.shape,
                'file_size_mb': os.path.getsize(model_path) / 1024 / 1024
            }
        except Exception as e:
            return {'success': False, 'error': str(e)}

    def _test_onnx_model(self, model_path, test_input):
        """Load an ONNX file in ONNX Runtime (CPU), run it once; report time, output shape and file size."""
        try:
            if not os.path.exists(model_path):
                return {'success': False, 'error': 'Model file not found'}

            # Explicit providers: multi-provider ORT builds raise without them.
            ort_session = ort.InferenceSession(model_path, providers=['CPUExecutionProvider'])
            ort_inputs = {ort_session.get_inputs()[0].name: test_input.detach().cpu().numpy()}

            start_time = time.perf_counter()
            ort_outputs = ort_session.run(None, ort_inputs)
            inference_time = time.perf_counter() - start_time

            return {
                'success': True,
                'inference_time': inference_time,
                'output_shape': ort_outputs[0].shape,
                'file_size_mb': os.path.getsize(model_path) / 1024 / 1024
            }
        except Exception as e:
            return {'success': False, 'error': str(e)}

    @staticmethod
    def _format_seconds(value):
        """Format a timing in seconds, or 'N/A' when it is missing/failed."""
        return 'N/A' if value is None else f"{value:.4f}s"

    @staticmethod
    def _time_once(run, test_input):
        """Return seconds for one ``run(test_input)`` after one untimed warmup call.

        Returns None when ``run`` is None (model failed to load) or the call raises.
        """
        if run is None:
            return None
        try:
            run(test_input)
            start_time = time.perf_counter()
            run(test_input)
            return time.perf_counter() - start_time
        except Exception as e:
            print(f"Inference failed for batch of shape {tuple(test_input.shape)}: {e}")
            return None

    @staticmethod
    def _load_torchscript_runner(path):
        """Return a callable running the TorchScript model at ``path`` without grad, or None if loading fails."""
        try:
            ts_model = torch.jit.load(path)
            ts_model.eval()
        except Exception as e:
            print(f"Could not load {path}: {e}")
            return None

        def run(batch):
            with torch.no_grad():
                return ts_model(batch)
        return run

    @staticmethod
    def _load_onnx_runner(path):
        """Return a callable running the ONNX model at ``path`` on CPU, or None if loading fails."""
        try:
            session = ort.InferenceSession(path, providers=['CPUExecutionProvider'])
            input_name = session.get_inputs()[0].name
        except Exception as e:
            print(f"Could not load {path}: {e}")
            return None

        def run(batch):
            return session.run(None, {input_name: batch.numpy()})
        return run

    def test_batch_sizes(self, model_paths=None, batch_sizes=(1, 4, 16, 32, 64)):
        """Time one inference per batch size for the traced TorchScript and dynamic ONNX models.

        Args:
            model_paths: Optional mapping overriding the default artifact paths
                ``{'torchscript': 'model_traced.pt', 'onnx': 'model_dynamic.onnx'}``;
                ``None`` or ``{}`` keeps the defaults.
            batch_sizes: Batch sizes to test, in order.

        Returns:
            ``{batch_size: {format: seconds or None}}``. A format whose file is
            missing is absent; one that failed to load or run maps to ``None``.
        """
        print("\n=== Batch Size Performance Testing ===")

        paths = {'torchscript': 'model_traced.pt', 'onnx': 'model_dynamic.onnx'}
        paths.update(model_paths or {})

        # Load each model once rather than once per batch size.
        runners = {}
        if os.path.exists(paths['torchscript']):
            runners['torchscript'] = self._load_torchscript_runner(paths['torchscript'])
        if ONNX_AVAILABLE and os.path.exists(paths['onnx']):
            runners['onnx'] = self._load_onnx_runner(paths['onnx'])

        results = {}
        for batch_size in batch_sizes:
            test_input = torch.randn(batch_size, 1, 28, 28)
            batch_results = {name: self._time_once(run, test_input) for name, run in runners.items()}
            results[batch_size] = batch_results
            # Missing/failed timings print as N/A instead of crashing the float format.
            print(f"Batch size {batch_size:2d}: "
                  f"TorchScript={self._format_seconds(batch_results.get('torchscript'))}, "
                  f"ONNX={self._format_seconds(batch_results.get('onnx'))}")

        return results

    @staticmethod
    def _size_mb(result):
        """Return a result's size in MB: file size if present, else parameter estimate, else inf."""
        for key in ('file_size_mb', 'model_size_mb'):
            if key in result:
                return result[key]
        return float('inf')

    def generate_deployment_report(self, test_results):
        """Summarize successful formats, the fastest/smallest ones, and recommendations.

        Sizes are compared per format (file size for exported artifacts,
        parameter estimate for the eager model) rather than with a single key
        chosen from the fastest format.
        """
        report = {
            'summary': {},
            'recommendations': [],
            'compatibility': {},
            'performance': test_results
        }

        successful_formats = [fmt for fmt, result in test_results.items() if result.get('success', False)]
        if not successful_formats:
            return report

        fastest_format = min(successful_formats,
                             key=lambda fmt: test_results[fmt].get('inference_time', float('inf')))
        smallest_format = min(successful_formats, key=lambda fmt: self._size_mb(test_results[fmt]))

        smallest_size = self._size_mb(test_results[smallest_format])
        report['summary'] = {
            'successful_formats': successful_formats,
            'fastest_format': fastest_format,
            'smallest_format': smallest_format,
            'fastest_time': test_results[fastest_format].get('inference_time', 0),
            'smallest_size': 0 if smallest_size == float('inf') else smallest_size
        }

        if 'onnx' in successful_formats:
            report['recommendations'].append("ONNX: Best for cross-platform deployment")
        if 'torchscript_traced' in successful_formats:
            report['recommendations'].append("TorchScript Traced: Best for PyTorch ecosystem")
        if fastest_format == 'onnx':
            report['recommendations'].append("ONNX provides fastest inference")

        return report

# Exercise every exported artifact on one random MNIST-shaped input
tester = CrossPlatformTester()
test_input = torch.randn(1, 1, 28, 28)
cross_platform_results = tester.test_model_formats(model, test_input)

# Show a one-line status per format
print("\n=== Cross-Platform Test Results ===")
for format_name, result in cross_platform_results.items():
    if not result['success']:
        print(f"{format_name:20}: ✗ {result.get('error', 'Unknown error')}")
        continue
    size_field = 'file_size_mb' if 'file_size_mb' in result else 'model_size_mb'
    size_mb = result.get(size_field, 0)
    print(f"{format_name:20}: ✓ {result['inference_time']:.4f}s, {size_mb:.2f}MB")

# Per-batch-size latency for the exported models
batch_results = tester.test_batch_sizes({})

# Summarize and print recommendations
deployment_report = tester.generate_deployment_report(cross_platform_results)
print("\n=== Deployment Recommendations ===")
for recommendation in deployment_report['recommendations']:
    print(f"• {recommendation}")
```

## 6. Production Deployment Pipeline

```python
class ProductionDeploymentPipeline:
    """Complete pipeline for production model deployment"""
    
    def __init__(self, model, model_name="mnist_classifier"):
        self.model = model
        self.model_name = model_name
        self.deployment_artifacts = {}
        
    def prepare_for_deployment(self, example_input, target_platforms=['torchscript', 'onnx']):
        """Prepare all deployment artifacts"""
        print(f"=== Preparing {self.model_name} for Production Deployment ===")
        
        artifacts = {}
        
        # Create deployment directory
        deploy_dir = f"deployment_{self.model_name}"
        os.makedirs(deploy_dir, exist_ok=True)
        
        # Export to different formats
        if 'torchscript' in target_platforms:
            artifacts.update(self._export_torchscript(example_input, deploy_dir))
        
        if 'onnx' in target_platforms and ONNX_AVAILABLE:
            artifacts.update(self._export_onnx(example_input, deploy_dir))
        
        # Generate metadata and documentation
        artifacts['metadata'] = self._generate_metadata(example_input, deploy_dir)
        artifacts['readme'] = self._generate_readme(deploy_dir)
        
        self.deployment_artifacts = artifacts
        return artifacts
    
    def _export_torchscript(self, example_input, deploy_dir):
        """Export TorchScript models"""
        artifacts = {}
        
        # Traced model
        try:
            traced_model = torch.jit.trace(self.model, example_input)
            traced_path = os.path.join(deploy_dir, f"{self.model_name}_traced.pt")
            torch.jit.save(traced_model, traced_path)
            artifacts['torchscript_traced'] = traced_path
            print(f"✓ TorchScript traced model saved to {traced_path}")
        except Exception as e:
            print(f"✗ TorchScript traced export failed: {e}")
        
        # Optimized model
        try:
            if 'torchscript_traced' in artifacts:
                optimized_model = torch.jit.optimize_for_inference(traced_model)
                optimized_path = os.path.join(deploy_dir, f"{self.model_name}_optimized.pt")
                torch.jit.save(optimized_model, optimized_path)
                artifacts['torchscript_optimized'] = optimized_path
                print(f"✓ Optimized TorchScript model saved to {optimized_path}")
        except Exception as e:
            print(f"✗ TorchScript optimization failed: {e}")
        
        return artifacts
    
    def _export_onnx(self, example_input, deploy_dir):
        """Export ONNX models"""
        artifacts = {}
        
        # Standard ONNX export
        try:
            onnx_path = os.path.join(deploy_dir, f"{self.model_name}.onnx")
            torch.onnx.export(
                self.model, example_input, onnx_path,
                export_params=True, opset_version=11, do_constant_folding=True,
                input_names=['input'], output_names=['output'],
                dynamic_axes={'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}}
            )
            artifacts['onnx'] = onnx_path
            print(f"✓ ONNX model saved to {onnx_path}")
        except Exception as e:
            print(f"✗ ONNX export failed: {e}")
        
        return artifacts
    
    def _generate_metadata(self, example_input, deploy_dir):
        """Generate model metadata"""
        metadata = {
            'model_name': self.model_name,
            'input_shape': list(example_input.shape),
            'input_dtype': str(example_input.dtype),
            'num_classes': self.model.hparams.num_classes,
            'pytorch_version': torch.__version__,
            'export_timestamp': time.strftime('%Y-%m-%d %H:%M:%S'),
            'model_parameters': sum(p.numel() for p in self.model.parameters()),
            'model_size_mb': sum(p.numel() * 4 for p in self.model.parameters()) / 1024 / 1024,
            'preprocessing': {
                'normalization': {
                    'mean': [0.1307],
                    'std': [0.3081]
                },
                'input_range': [0, 1]
            },
            'postprocessing': {
                'output_type': 'logits',
                'apply_softmax': True,
                'class_names': [f'class_{i}' for i in range(10)]
            }
        }
        
        metadata_path = os.path.join(deploy_dir, 'model_metadata.json')
        with open(metadata_path, 'w') as f:
            json.dump(metadata, f, indent=2)
        
        print(f"✓ Model metadata saved to {metadata_path}")
        return metadata_path
    
    def _generate_readme(self, deploy_dir):
        """Generate deployment README"""
        readme_content = f"""# {self.model_name.upper()} - Production Deployment
```

## Model Information
- **Model Name**: {self.model_name}
- **Task**: MNIST Digit Classification
- **Input Shape**: (batch_size, 1, 28, 28)
- **Output Shape**: (batch_size, 10)
- **Parameters**: {sum(p.numel() for p in self.model.parameters()):,}

## Available Formats
- **TorchScript**: Optimized for PyTorch ecosystem
- **ONNX**: Cross-platform deployment

## Usage Example

### TorchScript
```python
import torch

# Load model
model = torch.jit.load('{self.model_name}_traced.pt')
model.eval()

# Prepare input (1x1x28x28 tensor)
input_tensor = torch.randn(1, 1, 28, 28)

# Inference
with torch.no_grad():
    output = model(input_tensor)
    predictions = torch.softmax(output, dim=1)
```

### ONNX
```python
import onnxruntime as ort
import numpy as np

# Load model
session = ort.InferenceSession('{self.model_name}.onnx')

# Prepare input
input_data = np.random.randn(1, 1, 28, 28).astype(np.float32)

# Inference
outputs = session.run(None, {{'input': input_data}})
predictions = outputs[0]
```

## Preprocessing
1. Normalize input images: `(x - 0.1307) / 0.3081`
2. Ensure input shape is (batch_size, 1, 28, 28)
3. Input should be float32 tensors in range [0, 1]

## Postprocessing
1. Apply softmax to get probabilities
2. Use argmax to get predicted class
3. Classes are numbered 0-9 representing digits

## Performance Notes
- TorchScript: Optimized for CPU inference
- ONNX: Better cross-platform compatibility
- Batch processing recommended for throughput

## Files
- `{self.model_name}_traced.pt`: TorchScript traced model
- `{self.model_name}_optimized.pt`: Optimized TorchScript model  
- `{self.model_name}.onnx`: ONNX model
- `model_metadata.json`: Model configuration and metadata
"""

```python        
        readme_path = os.path.join(deploy_dir, 'README.md')
        with open(readme_path, 'w') as f:
            f.write(readme_content)
        
        print(f"✓ README generated at {readme_path}")
        return readme_path
    
    def validate_deployment(self, num_test_samples=100):
        """Check every exported artifact against the original eager model.

        A random batch of ``num_test_samples`` inputs shaped (1, 28, 28) per sample
        is fed to ``self.model`` to produce reference outputs. Every registered
        artifact whose name starts with ``'torchscript'`` and whose path ends in
        ``.pt`` is checked with ``_validate_torchscript``; an ``'onnx'`` artifact,
        when registered, is checked with ``_validate_onnx``.

        Args:
            num_test_samples: Batch size of the random validation input.

        Returns:
            dict mapping artifact name to the corresponding validator's result.
        """
        print(f"\n=== Validating Deployment Artifacts ===")

        samples = torch.randn(num_test_samples, 1, 28, 28)

        # Reference outputs from the un-exported model, with autograd disabled.
        self.model.eval()
        with torch.no_grad():
            expected = self.model(samples)

        artifacts = self.deployment_artifacts
        results = {
            name: self._validate_torchscript(path, samples, expected)
            for name, path in artifacts.items()
            if name.startswith('torchscript') and path.endswith('.pt')
        }

        if 'onnx' in artifacts:
            results['onnx'] = self._validate_onnx(artifacts['onnx'], samples, expected)

        return results
    
    def _validate_torchscript(self, model_path, test_input, reference_output):
        """Validate TorchScript model"""
        try:
            model = torch.jit.load(model_path)
            model.eval()
            
            with torch.no_grad():
                output = model(test_input)
            
            # Check output similarity
            max_diff = torch.max(torch.abs(output - reference_output)).item()
            avg_diff = torch.mean(torch.abs(output - reference_output)).item()
            
            return {
                'success': True,
                'max_difference': max_diff,
                'avg_difference': avg_diff,
                'outputs_match': max_diff < 1e-5
            }
        except Exception as e:
            return {'success': False, 'error': str(e)}
    
    def _validate_onnx(self, model_path, test_input, reference_output):
        """Validate ONNX model"""
        try:
            session = ort.InferenceSession(model_path)
            ort_inputs = {session.get_inputs()[0].name: test_input.numpy()}
            ort_outputs = session.run(None, ort_inputs)
            
            onnx_output = torch.from_numpy(ort_outputs[0])
            
            # Check output similarity
            max_diff = torch.max(torch.abs(onnx_output - reference_output)).item()
            avg_diff = torch.mean(torch.abs(onnx_output - reference_output)).item()
            
            return {
                'success': True,
                'max_difference': max_diff,
                'avg_difference': avg_diff,
                'outputs_match': max_diff < 1e-4
            }
        except Exception as e:
            return {'success': False, 'error': str(e)}

# Wire the trained model into the deployment pipeline
pipeline = ProductionDeploymentPipeline(model, "mnist_classifier_v1")

# Export TorchScript + ONNX artifacts, using a single MNIST-shaped example input
example_input = torch.randn(1, 1, 28, 28)
deployment_artifacts = pipeline.prepare_for_deployment(
    example_input, ['torchscript', 'onnx']
)

# Re-run every artifact on random inputs and compare against the eager model
validation_results = pipeline.validate_deployment(num_test_samples=50)

print(f"\n=== Deployment Validation Results ===")
for artifact_name, result in validation_results.items():
    if not result['success']:
        print(f"{artifact_name:20}: ✗ FAIL ({result.get('error', 'Unknown error')})")
        continue
    status = "✓ PASS" if result['outputs_match'] else "⚠ DIFF"
    print(f"{artifact_name:20}: {status} (max_diff: {result['max_difference']:.2e})")
```

## 7. Performance Benchmarking Suite

```python
class ModelBenchmarkSuite:
    """Latency and throughput benchmarking for exported TorchScript/ONNX models.

    ``run_comprehensive_benchmark`` stores its results in ``self.results`` as
    ``{config_name: {model_name: stats_or_error}}``; ``generate_benchmark_report``
    summarizes them per configuration.
    """

    def __init__(self):
        self.results = {}

    def run_comprehensive_benchmark(self, model_paths, test_configs):
        """Benchmark every model in ``model_paths`` under every test configuration.

        Args:
            model_paths: Mapping of model name -> file path. Names starting with
                ``'torchscript'`` are loaded with ``torch.jit.load``; the name
                ``'onnx'`` is run with ONNX Runtime.
            test_configs: Mapping of configuration name -> dict with optional keys
                ``batch_size`` (default 1), ``num_runs`` (default 100),
                ``warmup_runs`` (default 10) and ``input_shape`` (per-sample
                shape, default ``(1, 28, 28)``).

        Returns:
            ``{config_name: {model_name: result}}`` where ``result`` is a stats
            dict or ``{'error': message}``. Also stored in ``self.results``.
        """
        print("=== Comprehensive Model Benchmarking ===")

        results = {}

        for config_name, config in test_configs.items():
            print(f"\nTesting configuration: {config_name}")

            batch_size = config.get('batch_size', 1)
            num_runs = config.get('num_runs', 100)
            warmup_runs = config.get('warmup_runs', 10)
            input_shape = tuple(config.get('input_shape', (1, 28, 28)))

            # One shared input per configuration so every format sees the same data.
            test_input = torch.randn(batch_size, *input_shape)

            config_results = {}
            for model_name, model_path in model_paths.items():
                if os.path.exists(model_path):
                    config_results[model_name] = self._benchmark_single_model(
                        model_path, test_input, num_runs, warmup_runs, model_name
                    )
                else:
                    config_results[model_name] = {'error': 'Model file not found'}

            results[config_name] = config_results

        self.results = results
        return results

    def _benchmark_single_model(self, model_path, test_input, num_runs, warmup_runs, model_type):
        """Dispatch to the format-specific benchmark; failures become ``{'error': msg}``."""
        try:
            if model_type.startswith('torchscript'):
                return self._benchmark_torchscript(model_path, test_input, num_runs, warmup_runs)
            if model_type == 'onnx':
                if not ONNX_AVAILABLE:
                    return {'error': 'ONNX Runtime not available'}
                return self._benchmark_onnx(model_path, test_input, num_runs, warmup_runs)
            return {'error': f'Unsupported model type: {model_type}'}
        except Exception as e:
            # Best-effort: one failing model must not abort the whole benchmark run.
            return {'error': str(e)}

    @staticmethod
    def _time_runs(run_once, num_runs, warmup_runs):
        """Call ``run_once`` ``warmup_runs`` times untimed, then return ``num_runs`` timings in seconds."""
        if num_runs < 1:
            raise ValueError(f'num_runs must be >= 1, got {num_runs}')
        for _ in range(warmup_runs):
            run_once()
        times = []
        for _ in range(num_runs):
            # perf_counter is monotonic and high-resolution; time.time() may be
            # coarse (milliseconds on some platforms) or jump with clock updates.
            start = time.perf_counter()
            run_once()
            times.append(time.perf_counter() - start)
        return times

    @staticmethod
    def _summarize_times(times, batch_size, model_path):
        """Aggregate per-run latencies into the statistics reported for one model."""
        mean_time = np.mean(times)
        return {
            'mean_time': mean_time,
            'std_time': np.std(times),
            'min_time': np.min(times),
            'max_time': np.max(times),
            'throughput': batch_size / mean_time,  # samples/second
            'file_size_mb': os.path.getsize(model_path) / 1024 / 1024
        }

    def _benchmark_torchscript(self, model_path, test_input, num_runs, warmup_runs):
        """Benchmark a TorchScript model loaded with ``torch.jit.load``."""
        model = torch.jit.load(model_path)
        model.eval()

        with torch.no_grad():
            times = self._time_runs(lambda: model(test_input), num_runs, warmup_runs)

        return self._summarize_times(times, test_input.shape[0], model_path)

    def _benchmark_onnx(self, model_path, test_input, num_runs, warmup_runs):
        """Benchmark an ONNX model with an ONNX Runtime ``InferenceSession``."""
        session = ort.InferenceSession(model_path)
        ort_inputs = {session.get_inputs()[0].name: test_input.numpy()}

        times = self._time_runs(lambda: session.run(None, ort_inputs), num_runs, warmup_runs)

        return self._summarize_times(times, test_input.shape[0], model_path)

    @staticmethod
    def _best(successful, metric, choose):
        """Return ``(name, value)`` of the model that ``choose`` (min/max) picks for ``metric``.

        Returns ``(None, None)`` when ``successful`` is empty; ties keep the first
        model in iteration order.
        """
        if not successful:
            return None, None
        name = choose(successful, key=lambda model_name: successful[model_name][metric])
        return name, successful[name][metric]

    def generate_benchmark_report(self, save_path="benchmark_report.json"):
        """Summarize ``self.results``, save the summary as JSON and print it.

        For each configuration, the fastest (lowest mean latency), most efficient
        (highest throughput) and smallest (file size) successful model is
        recorded. Configurations without any successful run report ``None`` for
        both model and metric, which keeps the JSON standard-compliant (no
        ``Infinity``).

        Returns:
            The summary dict, or ``None`` when no results are available.
        """
        if not self.results:
            print("No benchmark results available")
            return None

        summary = {
            'fastest_model': {},
            'most_efficient': {},
            'smallest_model': {},
            'detailed_results': self.results
        }

        for config_name, config_results in self.results.items():
            # Only runs that completed without an error compete for "best".
            successful = {
                name: result for name, result in config_results.items()
                if 'error' not in result
            }
            fastest_model, fastest_time = self._best(successful, 'mean_time', min)
            smallest_model, smallest_size = self._best(successful, 'file_size_mb', min)
            most_efficient, highest_throughput = self._best(successful, 'throughput', max)

            summary['fastest_model'][config_name] = {
                'model': fastest_model,
                'time': fastest_time
            }
            summary['smallest_model'][config_name] = {
                'model': smallest_model,
                'size_mb': smallest_size
            }
            summary['most_efficient'][config_name] = {
                'model': most_efficient,
                'throughput': highest_throughput
            }

        with open(save_path, 'w', encoding='utf-8') as f:
            json.dump(summary, f, indent=2, default=str)

        print(f"✓ Benchmark report saved to {save_path}")

        print("\n=== Benchmark Summary ===")
        for config_name in self.results:
            print(f"\nConfiguration: {config_name}")
            fastest = summary['fastest_model'][config_name]
            if fastest['model'] is None:
                print("  No successful benchmark runs")
                continue
            efficient = summary['most_efficient'][config_name]
            smallest = summary['smallest_model'][config_name]
            print(f"  Fastest: {fastest['model']} ({fastest['time']:.4f}s)")
            print(f"  Most Efficient: {efficient['model']} ({efficient['throughput']:.1f} samples/s)")
            print(f"  Smallest: {smallest['model']} ({smallest['size_mb']:.2f} MB)")

        return summary

# Benchmark every exported format across a sweep of batch sizes
benchmark_suite = ModelBenchmarkSuite()

# Scenario -> (batch size, timed runs, warmup runs)
test_configs = {
    scenario: {'batch_size': batch, 'num_runs': runs, 'warmup_runs': warmup}
    for scenario, batch, runs, warmup in (
        ('single_inference', 1, 100, 10),
        ('small_batch', 4, 100, 10),
        ('medium_batch', 16, 50, 5),
        ('large_batch', 64, 20, 5),
    )
}

# Expected artifact locations (update to match the actual deployment directory)
model_paths = {
    label: 'deployment_mnist_classifier_v1/mnist_classifier_v1' + suffix
    for label, suffix in (
        ('torchscript_traced', '_traced.pt'),
        ('torchscript_optimized', '_optimized.pt'),
        ('onnx', '.onnx'),
    )
}

# Only benchmark artifacts that were actually produced
existing_models = {
    label: artifact
    for label, artifact in model_paths.items()
    if os.path.exists(artifact)
}

if not existing_models:
    print("No deployment models found for benchmarking")
else:
    benchmark_results = benchmark_suite.run_comprehensive_benchmark(existing_models, test_configs)
    benchmark_summary = benchmark_suite.generate_benchmark_report()
```

# Summary

This notebook provided comprehensive model export strategies for production deployment using ONNX and TorchScript. Key implementations and concepts covered:

## Export Format Implementations
- **TorchScript Tracing**: Records the operations executed on an example input (data-dependent control flow is not captured)
- **TorchScript Scripting**: Compiles the model's Python source, preserving data-dependent control flow
- **ONNX Export**: Cross-platform deployment format
- **Model Optimization**: Performance tuning for each export format

## Production Deployment Pipeline  
- **Multi-Format Export**: Automated export to multiple formats
- **Validation Testing**: Ensuring output consistency across formats
- **Metadata Generation**: Complete model documentation and configuration
- **Deployment Artifacts**: Production-ready model packages

## Performance Analysis and Benchmarking
- **Cross-Platform Testing**: Compatibility verification across formats
- **Batch Size Optimization**: Performance scaling analysis
- **Memory Usage Monitoring**: Efficient resource utilization
- **Throughput Benchmarking**: Real-world performance metrics

## Key Benefits Achieved
- **Production Readiness**: Complete deployment pipeline with validation
- **Cross-Platform Compatibility**: Models deployable across different environments
- **Performance Optimization**: Format-specific optimizations for speed
- **Documentation**: Comprehensive usage guides and metadata

## Export Format Comparison
- **TorchScript**: Best for the PyTorch ecosystem; supports a statically typed subset of Python, including control flow when scripted
- **ONNX**: Superior cross-platform support, broader runtime options
- **Optimization Trade-offs**: Speed vs. compatibility considerations
- **File Size**: Serialized artifact sizes compared per format (reported by the benchmark suite)

## Next Steps
- Integrate with cloud deployment platforms (AWS, GCP, Azure)
- Add support for quantization and pruning
- Implement model serving with FastAPI/TorchServe
- Develop continuous integration for model validation

The export framework provides a robust foundation for deploying PyTorch Lightning models in production environments with optimal performance and reliability.