# In [None]:
"""
BatteryMind - Inference Speed Test Notebook

Comprehensive inference speed testing and performance benchmarking for all
BatteryMind AI models including transformers, federated learning, reinforcement
learning, and ensemble models.

This notebook provides:
- Batch inference speed testing
- Real-time inference latency measurement
- Throughput analysis under different loads
- Memory usage during inference
- Model comparison across different architectures
- Edge deployment performance evaluation

Author: BatteryMind Development Team
Version: 1.0.0
"""

# Standard library
import gc
import sys
import threading
import time
import tracemalloc
import warnings
from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor
from typing import Dict, List, Tuple, Any

# Third-party
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import psutil
import seaborn as sns
import torch

# Memory profiling
from memory_profiler import profile

warnings.filterwarnings('ignore')

# Model imports
sys.path.append('../../')
from transformers.battery_health_predictor.predictor import BatteryHealthPredictor
from transformers.degradation_forecaster.forecaster import DegradationForecaster
from transformers.optimization_recommender.recommender import OptimizationRecommender
from transformers.ensemble_model.ensemble import EnsembleModel
from reinforcement_learning.agents.charging_agent import ChargingAgent
from federated_learning.client_models.local_trainer import LocalTrainer

# Utility imports
from utils.data_utils import generate_test_data
from utils.model_utils import load_model_artifacts
from utils.visualization import plot_performance_metrics

print("BatteryMind Inference Speed Test Notebook")
print("=" * 50)

# Benchmark settings handed to InferenceSpeedTester.
TEST_CONFIG = {
    # Input grid: every batch size is paired with every sequence length.
    'batch_sizes': [1, 8, 16, 32, 64, 128],
    'sequence_lengths': [100, 500, 1000, 2000],
    # Timed calls per configuration, and untimed calls made before them.
    'num_iterations': 100,
    'warmup_iterations': 10,
    'models_to_test': ['transformer', 'federated', 'rl', 'ensemble'],
    'test_types': ['single_inference', 'batch_inference', 'concurrent_inference'],
    # 'cpu' is always listed; 'gpu' is appended only when CUDA is usable.
    'device_types': ['cpu'] + (['gpu'] if torch.cuda.is_available() else []),
}

class InferenceSpeedTester:
    """
    Comprehensive inference speed testing framework for BatteryMind models.
    """
    
    def __init__(self, config: Dict[str, Any]):
        self.config = config
        self.results = {}
        self.models = {}
        self.test_data = {}
        
    def load_models(self):
        """Fill ``self.models`` with every model that will be benchmarked.

        Each entry is produced by calling ``load_model`` on the matching
        wrapper class with the artifact path listed in the table below.
        """
        print("Loading BatteryMind models...")

        # (key in self.models, wrapper class, artifact path), in load order.
        # NOTE(review): all three transformer wrappers read the same
        # transformer_v1.0/model.pkl file -- confirm that this is intended.
        artifacts = (
            ('battery_health', BatteryHealthPredictor,
             '../../model-artifacts/trained_models/transformer_v1.0/model.pkl'),
            ('degradation_forecaster', DegradationForecaster,
             '../../model-artifacts/trained_models/transformer_v1.0/model.pkl'),
            ('optimization_recommender', OptimizationRecommender,
             '../../model-artifacts/trained_models/transformer_v1.0/model.pkl'),
            ('ensemble', EnsembleModel,
             '../../model-artifacts/trained_models/ensemble_v1.0/ensemble_model.pkl'),
            ('rl_agent', ChargingAgent,
             '../../model-artifacts/trained_models/rl_agent_v1.0/policy_network.pt'),
            ('federated', LocalTrainer,
             '../../model-artifacts/trained_models/federated_v1.0/global_model.pkl'),
        )
        for registry_key, wrapper_cls, artifact_path in artifacts:
            self.models[registry_key] = wrapper_cls.load_model(artifact_path)

        print(f"Loaded {len(self.models)} models successfully")
    
    def generate_test_data(self):
        """Generate test data for different batch sizes and sequence lengths."""
        print("Generating test data...")
        
        for batch_size in self.config['batch_sizes']:
            for seq_len in self.config['sequence_lengths']:
                # Battery telemetry data
                battery_data = np.random.randn(batch_size, seq_len, 10)  # 10 features
                
                # Environmental data
                env_data = np.random.randn(batch_size, seq_len, 5)  # 5 features
                
                # State data for RL
                state_data = np.random.randn(batch_size, 20)  # 20 state features
                
                key = f"batch_{batch_size}_seq_{seq_len}"
                self.test_data[key] = {
                    'battery_data': battery_data,
                    'env_data': env_data,
                    'state_data': state_data
                }
        
        print(f"Generated test data for {len(self.test_data)} configurations")
    
    def measure_inference_time(self, model, data, num_iterations=100, warmup=10):
        """
        Measure inference time for a model with given data.
        
        Returns:
            Dict with timing statistics
        """
        # Warmup
        for _ in range(warmup):
            try:
                _ = model.predict(data)
            except:
                pass
        
        # Clear cache
        if hasattr(torch, 'cuda') and torch.cuda.is_available():
            torch.cuda.empty_cache()
        
        # Measure inference time
        times = []
        
        for i in range(num_iterations):
            start_time = time.perf_counter()
            
            try:
                output = model.predict(data)
            except Exception as e:
                print(f"Error during inference: {e}")
                continue
            
            end_time = time.perf_counter()
            times.append(end_time - start_time)
        
        if not times:
            return {'mean': 0, 'std': 0, 'min': 0, 'max': 0, 'p95': 0, 'p99': 0}
        
        times = np.array(times)
        
        return {
            'mean': np.mean(times),
            'std': np.std(times),
            'min': np.min(times),
            'max': np.max(times),
            'p95': np.percentile(times, 95),
            'p99': np.percentile(times, 99),
            'throughput': len(times) / np.sum(times)
        }
    
    def test_single_inference(self):
        """Test single inference speed for all models."""
        print("Testing single inference speed...")
        
        results = {}
        
        for model_name, model in self.models.items():
            print(f"Testing {model_name}...")
            model_results = {}
            
            for batch_size in self.config['batch_sizes']:
                for seq_len in self.config['sequence_lengths']:
                    key = f"batch_{batch_size}_seq_{seq_len}"
                    
                    if key not in self.test_data:
                        continue
                    
                    # Select appropriate data based on model type
                    if model_name == 'rl_agent':
                        data = self.test_data[key]['state_data']
                    else:
                        data = self.test_data[key]['battery_data']
                    
                    # Measure inference time
                    timing_stats = self.measure_inference_time(
                        model, data, 
                        num_iterations=self.config['num_iterations'],
                        warmup=self.config['warmup_iterations']
                    )
                    
                    model_results[key] = timing_stats
            
            results[model_name] = model_results
        
        self.results['single_inference'] = results
        return results
    
    def test_batch_inference(self):
        """Test batch inference performance."""
        print("Testing batch inference performance...")
        
        results = {}
        
        for model_name, model in self.models.items():
            print(f"Testing batch inference for {model_name}...")
            model_results = {}
            
            for batch_size in self.config['batch_sizes']:
                # Use a fixed sequence length for batch testing
                seq_len = 1000
                key = f"batch_{batch_size}_seq_{seq_len}"
                
                if key not in self.test_data:
                    continue
                
                # Select appropriate data
                if model_name == 'rl_agent':
                    data = self.test_data[key]['state_data']
                else:
                    data = self.test_data[key]['battery_data']
                
                # Measure batch inference time
                start_time = time.perf_counter()
                
                try:
                    output = model.predict(data)
                    end_time = time.perf_counter()
                    
                    inference_time = end_time - start_time
                    throughput = batch_size / inference_time
                    
                    model_results[f'batch_{batch_size}'] = {
                        'inference_time': inference_time,
                        'throughput': throughput,
                        'time_per_sample': inference_time / batch_size
                    }
                
                except Exception as e:
                    print(f"Error in batch inference for {model_name}: {e}")
                    model_results[f'batch_{batch_size}'] = {
                        'inference_time': 0,
                        'throughput': 0,
                        'time_per_sample': 0
                    }
            
            results[model_name] = model_results
        
        self.results['batch_inference'] = results
        return results
    
    def test_concurrent_inference(self):
        """Test concurrent inference performance."""
        print("Testing concurrent inference performance...")
        
        results = {}
        
        def run_inference(model, data, num_requests=100):
            """Run multiple inference requests concurrently."""
            times = []
            
            def single_request():
                start_time = time.perf_counter()
                try:
                    _ = model.predict(data)
                    end_time = time.perf_counter()
                    return end_time - start_time
                except:
                    return 0
            
            with ThreadPoolExecutor(max_workers=10) as executor:
                futures = [executor.submit(single_request) for _ in range(num_requests)]
                times = [future.result() for future in futures]
            
            return times
        
        for model_name, model in self.models.items():
            print(f"Testing concurrent inference for {model_name}...")
            
            # Use medium batch size and sequence length
            batch_size = 32
            seq_len = 1000
            key = f"batch_{batch_size}_seq_{seq_len}"
            
            if key not in self.test_data:
                continue
            
            if model_name == 'rl_agent':
                data = self.test_data[key]['state_data']
            else:
                data = self.test_data[key]['battery_data']
            
            # Run concurrent requests
            concurrent_times = run_inference(model, data, num_requests=50)
            
            if concurrent_times:
                results[model_name] = {
                    'mean_time': np.mean(concurrent_times),
                    'std_time': np.std(concurrent_times),
                    'min_time': np.min(concurrent_times),
                    'max_time': np.max(concurrent_times),
                    'p95_time': np.percentile(concurrent_times, 95),
                    'requests_per_second': len(concurrent_times) / np.sum(concurrent_times)
                }
        
        self.results['concurrent_inference'] = results
        return results
    
    def measure_memory_usage(self):
        """Measure memory usage during inference."""
        print("Measuring memory usage...")
        
        results = {}
        
        for model_name, model in self.models.items():
            print(f"Measuring memory usage for {model_name}...")
            
            # Use large batch for memory testing
            batch_size = 64
            seq_len = 2000
            key = f"batch_{batch_size}_seq_{seq_len}"
            
            if key not in self.test_data:
                continue
            
            if model_name == 'rl_agent':
                data = self.test_data[key]['state_data']
            else:
                data = self.test_data[key]['battery_data']
            
            # Measure memory before inference
            gc.collect()
            memory_before = psutil.Process().memory_info().rss / 1024 / 1024  # MB
            
            # Start memory tracing
            tracemalloc.start()
            
            try:
                # Run inference
                output = model.predict(data)
                
                # Measure memory after inference
                current, peak = tracemalloc.get_traced_memory()
                memory_after = psutil.Process().memory_info().rss / 1024 / 1024  # MB
                
                results[model_name] = {
                    'memory_before_mb': memory_before,
                    'memory_after_mb': memory_after,
                    'memory_delta_mb': memory_after - memory_before,
                    'peak_memory_mb': peak / 1024 / 1024,
                    'current_memory_mb': current / 1024 / 1024
                }
            
            except Exception as e:
                print(f"Error measuring memory for {model_name}: {e}")
                results[model_name] = {
                    'memory_before_mb': 0,
                    'memory_after_mb': 0,
                    'memory_delta_mb': 0,
                    'peak_memory_mb': 0,
                    'current_memory_mb': 0
                }
            
            finally:
                tracemalloc.stop()
        
        self.results['memory_usage'] = results
        return results
    
    def generate_performance_report(self):
        """Generate comprehensive performance report."""
        print("Generating performance report...")
        
        report = {
            'test_configuration': self.config,
            'models_tested': list(self.models.keys()),
            'test_timestamp': time.strftime('%Y-%m-%d %H:%M:%S'),
            'results': self.results
        }
        
        # Calculate summary statistics
        summary = {}
        
        if 'single_inference' in self.results:
            summary['fastest_model'] = self._find_fastest_model()
            summary['most_efficient_model'] = self._find_most_efficient_model()
        
        if 'memory_usage' in self.results:
            summary['lowest_memory_model'] = self._find_lowest_memory_model()
        
        report['summary'] = summary
        
        return report
    
    def _find_fastest_model(self):
        """Find the fastest model across all configurations."""
        fastest_times = {}
        
        for model_name, model_results in self.results['single_inference'].items():
            avg_times = []
            for config, timing_stats in model_results.items():
                avg_times.append(timing_stats['mean'])
            
            if avg_times:
                fastest_times[model_name] = np.mean(avg_times)
        
        if fastest_times:
            return min(fastest_times, key=fastest_times.get)
        return None
    
    def _find_most_efficient_model(self):
        """Find the most efficient model (highest throughput)."""
        throughputs = {}
        
        for model_name, model_results in self.results['single_inference'].items():
            avg_throughputs = []
            for config, timing_stats in model_results.items():
                avg_throughputs.append(timing_stats['throughput'])
            
            if avg_throughputs:
                throughputs[model_name] = np.mean(avg_throughputs)
        
        if throughputs:
            return max(throughputs, key=throughputs.get)
        return None
    
    def _find_lowest_memory_model(self):
        """Find the model with lowest memory usage."""
        memory_usage = {}
        
        for model_name, memory_stats in self.results['memory_usage'].items():
            memory_usage[model_name] = memory_stats['memory_delta_mb']
        
        if memory_usage:
            return min(memory_usage, key=memory_usage.get)
        return None
    
    def visualize_results(self):
        """Draw the 2x2 results dashboard, save it as a PNG and display it.

        A panel is drawn only when its result set exists in
        ``self.results``; otherwise its axes are left untouched. The figure
        is written to 'inference_speed_test_results.png'.
        """
        print("Creating visualizations...")

        # Set up plotting style
        plt.style.use('seaborn-v0_8')
        fig, axes = plt.subplots(2, 2, figsize=(15, 12))

        # (results key, grid position, drawing method, title, x label, y label)
        panels = (
            ('single_inference', (0, 0), self._plot_single_inference_latency,
             'Single Inference Latency by Model', 'Model', 'Latency (seconds)'),
            ('batch_inference', (0, 1), self._plot_batch_throughput,
             'Batch Inference Throughput', 'Batch Size',
             'Throughput (samples/second)'),
            ('memory_usage', (1, 0), self._plot_memory_usage,
             'Memory Usage by Model', 'Model', 'Memory Usage (MB)'),
            ('concurrent_inference', (1, 1), self._plot_concurrent_performance,
             'Concurrent Inference Performance', 'Model', 'Requests per Second'),
        )
        for result_key, position, draw_panel, title, x_label, y_label in panels:
            if result_key not in self.results:
                continue
            ax = axes[position]
            draw_panel(ax)
            ax.set_title(title)
            ax.set_xlabel(x_label)
            ax.set_ylabel(y_label)

        plt.tight_layout()
        plt.savefig('inference_speed_test_results.png', dpi=300, bbox_inches='tight')
        plt.show()
    
    def _plot_single_inference_latency(self, ax):
        """Plot single inference latency comparison."""
        model_names = []
        latencies = []
        
        for model_name, model_results in self.results['single_inference'].items():
            avg_latency = np.mean([stats['mean'] for stats in model_results.values()])
            model_names.append(model_name)
            latencies.append(avg_latency)
        
        ax.bar(model_names, latencies)
        ax.set_xticklabels(model_names, rotation=45, ha='right')
    
    def _plot_batch_throughput(self, ax):
        """Plot batch inference throughput."""
        for model_name, model_results in self.results['batch_inference'].items():
            batch_sizes = []
            throughputs = []
            
            for batch_config, stats in model_results.items():
                batch_size = int(batch_config.split('_')[1])
                batch_sizes.append(batch_size)
                throughputs.append(stats['throughput'])
            
            ax.plot(batch_sizes, throughputs, marker='o', label=model_name)
        
        ax.legend()
        ax.set_xscale('log')
    
    def _plot_memory_usage(self, ax):
        """Plot memory usage comparison."""
        model_names = []
        memory_usage = []
        
        for model_name, memory_stats in self.results['memory_usage'].items():
            model_names.append(model_name)
            memory_usage.append(memory_stats['memory_delta_mb'])
        
        ax.bar(model_names, memory_usage)
        ax.set_xticklabels(model_names, rotation=45, ha='right')
    
    def _plot_concurrent_performance(self, ax):
        """Plot concurrent inference performance."""
        model_names = []
        rps = []
        
        for model_name, stats in self.results['concurrent_inference'].items():
            model_names.append(model_name)
            rps.append(stats['requests_per_second'])
        
        ax.bar(model_names, rps)
        ax.set_xticklabels(model_names, rotation=45, ha='right')

# Run the inference speed tests.
# The steps below execute in order at module level. Models and test data
# must exist before any measurement, and the report and plots read whatever
# the measurement steps stored in tester.results.
print("Initializing Inference Speed Tester...")
tester = InferenceSpeedTester(TEST_CONFIG)

print("Loading models...")
tester.load_models()

print("Generating test data...")
tester.generate_test_data()

# Measurement steps: each call returns its results and also stores them
# on the tester under its own key.
print("Running single inference tests...")
single_results = tester.test_single_inference()

print("Running batch inference tests...")
batch_results = tester.test_batch_inference()

print("Running concurrent inference tests...")
concurrent_results = tester.test_concurrent_inference()

print("Measuring memory usage...")
memory_results = tester.measure_memory_usage()

print("Generating performance report...")
performance_report = tester.generate_performance_report()

print("Creating visualizations...")
tester.visualize_results()

# Save results as JSON in the current working directory.
# default=str converts any value the json encoder cannot serialize itself
# to its string form instead of raising TypeError.
import json
with open('inference_speed_test_results.json', 'w') as f:
    json.dump(performance_report, f, indent=2, default=str)

print("Inference speed test completed!")
print("Results saved to 'inference_speed_test_results.json'")
print("Visualizations saved to 'inference_speed_test_results.png'")

# Display summary results; a winner missing from the summary prints as N/A.
print("\nSUMMARY RESULTS:")
print("="*50)
if 'summary' in performance_report:
    summary = performance_report['summary']
    print(f"Fastest Model: {summary.get('fastest_model', 'N/A')}")
    print(f"Most Efficient Model: {summary.get('most_efficient_model', 'N/A')}")
    print(f"Lowest Memory Model: {summary.get('lowest_memory_model', 'N/A')}")

# Display detailed results table
print("\nDETAILED RESULTS:")
print("="*50)

if 'single_inference' in performance_report['results']:
    print("\nSingle Inference Latency (seconds):")
    for model_name, model_results in performance_report['results']['single_inference'].items():
        # Average of the per-configuration mean latencies for this model.
        avg_latency = np.mean([stats['mean'] for stats in model_results.values()])
        print(f"{model_name}: {avg_latency:.4f}")

if 'memory_usage' in performance_report['results']:
    print("\nMemory Usage (MB):")
    for model_name, memory_stats in performance_report['results']['memory_usage'].items():
        # memory_delta_mb is the stored after-minus-before process memory.
        print(f"{model_name}: {memory_stats['memory_delta_mb']:.2f}")

print("\nInference Speed Test Completed Successfully!")
