# In [None]:
"""
BatteryMind - Transformer Model Development Notebook

Interactive development environment for building and training transformer models
for battery health prediction and degradation forecasting. This notebook provides
comprehensive model development workflow from data preparation to model evaluation.

Features:
- Multi-head attention transformer architecture for battery health prediction
- Time series preprocessing and feature engineering
- Advanced training techniques with early stopping and learning rate scheduling
- Model interpretability with attention visualization
- Comprehensive evaluation metrics and performance analysis
- Integration with BatteryMind ecosystem components

Author: BatteryMind Development Team
Version: 1.0.0
"""

# Import required libraries
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import torch.nn.functional as F
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler, MinMaxScaler
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
import warnings
warnings.filterwarnings('ignore')

# Set random seeds for reproducibility
# Seeds PyTorch's and NumPy's global generators; Python's built-in `random`
# module is not seeded here.
torch.manual_seed(42)
np.random.seed(42)

# Configure plotting
plt.style.use('seaborn-v0_8')
sns.set_palette("husl")

print("BatteryMind Transformer Development Environment")
print("=" * 50)

# Import BatteryMind modules
# The relative path below is resolved against the current working directory,
# so the notebook is expected to run from a folder two levels below the
# directory that contains the `transformers` and `training_data` packages.
# NOTE(review): the path is appended (searched last) and the local package is
# named `transformers`; if the Hugging Face `transformers` distribution is
# installed it would presumably be found first and these imports would fail --
# confirm, and consider `sys.path.insert(0, ...)` if so.
import sys
sys.path.append('../../')

from transformers.battery_health_predictor.model import BatteryHealthTransformer
from transformers.battery_health_predictor.trainer import BatteryHealthTrainer
from transformers.battery_health_predictor.data_loader import BatteryDataLoader
from transformers.battery_health_predictor.preprocessing import BatteryDataPreprocessor
from transformers.common.attention_layers import MultiHeadAttention
from transformers.common.positional_encoding import PositionalEncoding
from training_data.generators.synthetic_generator import SyntheticDataGenerator
from training_data.preprocessing_scripts.feature_extractor import BatteryFeatureExtractor
from training_data.preprocessing_scripts.normalization import BatteryDataNormalizer

print("✓ BatteryMind modules imported successfully")

# Configuration
class TransformerConfig:
    """Central collection of hyperparameters used by every stage of this notebook.

    All settings are class-level constants; instances carry no state of
    their own and exist only so the values can be passed around as one object.
    """

    # --- Data ---
    SEQUENCE_LENGTH = 24 * 6  # 24 hours at 10-minute intervals (144 steps)
    FEATURE_DIM = 12  # Number of input features per time step
    BATCH_SIZE = 32
    NUM_SAMPLES = 10000  # Synthetic battery scenarios to generate

    # --- Architecture ---
    D_MODEL = 256  # Width of the transformer embedding
    NUM_HEADS = 8  # Attention heads per encoder layer
    NUM_LAYERS = 6  # Stacked encoder layers
    D_FF = 1024  # Hidden width of each feed-forward sub-layer
    DROPOUT = 0.1

    # --- Optimisation ---
    LEARNING_RATE = 1e-4
    NUM_EPOCHS = 100  # Upper bound on training epochs
    PATIENCE = 10  # Epochs without validation improvement before stopping
    WEIGHT_DECAY = 1e-5

    # --- Outputs ---
    OUTPUT_DIM = 3  # SoH, RUL, Next_SoC

    # --- Hardware ---
    # Use the GPU when one is visible to PyTorch, otherwise fall back to CPU.
    DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Single shared configuration object used by all later cells
config = TransformerConfig()
print(f"✓ Configuration loaded - Device: {config.DEVICE}")

# Data Generation and Preprocessing
print("\n1. SYNTHETIC DATA GENERATION")
print("-" * 30)

# Generate synthetic battery data
data_generator = SyntheticDataGenerator()
print("Generating synthetic battery telemetry data...")

# Generate diverse battery scenarios
# One time series is generated per scenario. The random draws below come from
# NumPy's global generator, so their order determines the resulting dataset.
battery_data = []
for i in range(config.NUM_SAMPLES):
    # Generate battery parameters
    battery_params = {
        'capacity_ah': np.random.uniform(50, 100),
        'chemistry': np.random.choice(['lithium_ion', 'lifepo4']),
        'temperature_range': (-20, 60),
        'usage_pattern': np.random.choice(['urban', 'highway', 'mixed'])
    }

    # Generate time series data
    # NOTE(review): the units expected by `sampling_rate` are defined by the
    # generator, which is not visible here -- confirm that 10 means minutes.
    # NOTE(review): if this yields exactly 144 rows per battery (24 h at
    # 10-minute steps), `create_sequences` with SEQUENCE_LENGTH = 144 iterates
    # over `range(len(rows) - sequence_length)` and would produce no windows
    # at all -- verify the generated length.
    time_series = data_generator.generate_battery_timeseries(
        duration_hours=24,
        sampling_rate=10,  # 10-minute intervals
        battery_params=battery_params
    )

    battery_data.append(time_series)

    # Progress report every 1000 scenarios
    if (i + 1) % 1000 == 0:
        print(f"Generated {i + 1}/{config.NUM_SAMPLES} battery scenarios")

print("✓ Synthetic data generation completed")

# Convert to DataFrame for easier manipulation
# Each scenario becomes its own frame tagged with a `battery_id` equal to its
# position in `battery_data`; the frames are then stacked into one long table.
df_list = []
for i, data in enumerate(battery_data):
    df = pd.DataFrame(data)
    df['battery_id'] = i
    df_list.append(df)

battery_df = pd.concat(df_list, ignore_index=True)
print(f"✓ Combined dataset shape: {battery_df.shape}")

# Display sample data
print("\nSample Battery Data:")
print(battery_df.head())

# Feature Engineering
print("\n2. FEATURE ENGINEERING")
print("-" * 25)

feature_extractor = BatteryFeatureExtractor()

# Extract features from raw sensor data
# `features` is used below as a DataFrame (`.columns`, column indexing).
features = feature_extractor.extract_features(battery_df)
print(f"✓ Extracted {len(features.columns)} features")

# Feature importance analysis
# NOTE: the name `feature_importance` is rebound later in this notebook to the
# gradient-based importance array computed from the trained model.
# NOTE(review): `.head(10).items()` presumably relies on the extractor
# returning a pandas Series already sorted by importance -- confirm.
feature_importance = feature_extractor.analyze_feature_importance(features)
print("\nTop 10 Most Important Features:")
for i, (feature, importance) in enumerate(feature_importance.head(10).items()):
    print(f"{i+1:2d}. {feature:<25} {importance:.4f}")

# Visualize feature distributions
# A 2x3 grid is flattened so each key feature gets one histogram panel.
fig, axes = plt.subplots(2, 3, figsize=(15, 10))
axes = axes.ravel()

key_features = ['voltage', 'current', 'temperature', 'soc', 'soh', 'internal_resistance']
for i, feature in enumerate(key_features):
    # Features missing from the extracted table leave their panel empty.
    if feature in features.columns:
        axes[i].hist(features[feature], bins=50, alpha=0.7)
        axes[i].set_title(f'{feature.title()} Distribution')
        axes[i].set_xlabel(feature)
        axes[i].set_ylabel('Frequency')

plt.tight_layout()
plt.show()

# Data Preprocessing
print("\n3. DATA PREPROCESSING")
print("-" * 22)

# NOTE(review): `preprocessor` is created but not used anywhere in the code
# visible in this notebook -- confirm whether it is still needed.
preprocessor = BatteryDataPreprocessor()
normalizer = BatteryDataNormalizer()

# Normalize features
# NOTE(review): the normalisation scheme (and whether it also rescales the
# target columns and leaves `battery_id` / `timestamp` intact) is defined in
# BatteryDataNormalizer and cannot be confirmed from here.
normalized_features = normalizer.normalize(features)
print("✓ Feature normalization completed")

# Create sequences for transformer input
def create_sequences(data, sequence_length, target_columns, feature_columns=None):
    """Build sliding-window samples for transformer training.

    For every battery (rows sharing a ``battery_id``), rows are ordered by
    ``timestamp`` and each run of ``sequence_length`` consecutive rows becomes
    one input window. Its target is the row immediately after the window.
    Windows advance one row at a time, so a battery with ``n`` rows yields
    ``max(n - sequence_length, 0)`` samples.

    Args:
        data: DataFrame containing ``battery_id``, ``timestamp``, the feature
            columns and the target columns.
        sequence_length: Number of consecutive rows per input window.
        target_columns: Column names read from the row following each window.
        feature_columns: Column names used as model inputs. Defaults to the
            twelve features this notebook trains on.

    Returns:
        Tuple ``(sequences, targets)`` of NumPy arrays with shapes
        ``(num_samples, sequence_length, num_features)`` and
        ``(num_samples, num_targets)``. When no battery has more than
        ``sequence_length`` rows, both arrays are empty but keep those ranks.
    """
    if feature_columns is None:
        feature_columns = [
            'voltage', 'current', 'temperature', 'soc', 'power',
            'energy', 'efficiency', 'thermal_gradient', 'resistance_change',
            'capacity_fade', 'cycle_count', 'age_days',
        ]
    feature_columns = list(feature_columns)
    target_columns = list(target_columns)

    sequences = []
    targets = []

    for battery_id in data['battery_id'].unique():
        battery_data = data[data['battery_id'] == battery_id].sort_values('timestamp')

        # Convert once per battery; slicing NumPy arrays inside the loop is far
        # cheaper than building a pandas object for every window. Selecting the
        # target columns before conversion also keeps unrelated columns (for
        # example a datetime timestamp) from forcing the targets to object dtype.
        feature_values = battery_data[feature_columns].to_numpy()
        target_values = battery_data[target_columns].to_numpy()

        for start in range(len(battery_data) - sequence_length):
            stop = start + sequence_length
            sequences.append(feature_values[start:stop])
            targets.append(target_values[stop])

    if not sequences:
        # Keep the documented ranks so downstream shape handling still works.
        empty_windows = np.empty((0, max(sequence_length, 0), len(feature_columns)))
        empty_targets = np.empty((0, len(target_columns)))
        return empty_windows, empty_targets

    return np.stack(sequences), np.stack(targets)

# Create sequences
# Target order matters: later cells read column 0 as SoH, column 1 as RUL and
# column 2 as SoC.
target_columns = ['soh', 'remaining_useful_life', 'soc']
X, y = create_sequences(normalized_features, config.SEQUENCE_LENGTH, target_columns)

print(f"✓ Created {len(X)} sequences")
print(f"  Input shape: {X.shape}")
print(f"  Target shape: {y.shape}")

# Train-validation-test split
# 20% is held out for testing, then 25% of the remaining 80% for validation,
# giving a 60/20/20 split overall.
# NOTE(review): `create_sequences` emits overlapping windows (stride 1) per
# battery and this split shuffles individual windows, so near-identical
# windows from one battery can land in both train and test. Evaluation numbers
# may be optimistic; consider splitting by `battery_id` instead.
X_temp, X_test, y_temp, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42, shuffle=True
)
X_train, X_val, y_train, y_val = train_test_split(
    X_temp, y_temp, test_size=0.25, random_state=42, shuffle=True
)

print(f"✓ Data split completed:")
print(f"  Training: {X_train.shape[0]} samples")
print(f"  Validation: {X_val.shape[0]} samples")
print(f"  Test: {X_test.shape[0]} samples")

# Convert to PyTorch tensors
# Every split is converted to float32 and moved to the configured device in
# full, so the whole dataset must fit in that device's memory.
X_train_tensor = torch.FloatTensor(X_train).to(config.DEVICE)
y_train_tensor = torch.FloatTensor(y_train).to(config.DEVICE)
X_val_tensor = torch.FloatTensor(X_val).to(config.DEVICE)
y_val_tensor = torch.FloatTensor(y_val).to(config.DEVICE)
X_test_tensor = torch.FloatTensor(X_test).to(config.DEVICE)
y_test_tensor = torch.FloatTensor(y_test).to(config.DEVICE)

# Create data loaders
train_dataset = TensorDataset(X_train_tensor, y_train_tensor)
val_dataset = TensorDataset(X_val_tensor, y_val_tensor)
test_dataset = TensorDataset(X_test_tensor, y_test_tensor)

# Only the training loader reshuffles each epoch; validation and test keep a
# fixed order.
train_loader = DataLoader(train_dataset, batch_size=config.BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=config.BATCH_SIZE, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=config.BATCH_SIZE, shuffle=False)

print("✓ Data loaders created")

# Model Architecture
print("\n4. TRANSFORMER MODEL ARCHITECTURE")
print("-" * 35)

class BatteryTransformer(nn.Module):
    """
    Transformer encoder with three regression heads for battery health prediction.

    An input window is projected to ``D_MODEL`` dimensions, given positional
    information, passed through a stack of encoder layers, averaged over time
    and fed to separate heads for state of health (SoH), remaining useful life
    (RUL) and state of charge (SoC).
    """

    def __init__(self, config):
        """
        Build the network.

        Args:
            config: Object exposing FEATURE_DIM, D_MODEL, NUM_HEADS, NUM_LAYERS,
                D_FF and DROPOUT.
        """
        super().__init__()
        self.config = config

        # Input projection
        self.input_projection = nn.Linear(config.FEATURE_DIM, config.D_MODEL)

        # Positional encoding
        self.positional_encoding = PositionalEncoding(config.D_MODEL, config.DROPOUT)

        # Transformer encoder layers (batch-first: inputs are (batch, seq, dim))
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=config.D_MODEL,
            nhead=config.NUM_HEADS,
            dim_feedforward=config.D_FF,
            dropout=config.DROPOUT,
            activation='gelu',
            batch_first=True
        )

        self.transformer_encoder = nn.TransformerEncoder(
            encoder_layer, num_layers=config.NUM_LAYERS
        )

        # Output layers
        self.layer_norm = nn.LayerNorm(config.D_MODEL)
        self.dropout = nn.Dropout(config.DROPOUT)

        # Multi-task output heads
        # SoH and SoC end in a sigmoid, so their predictions lie in (0, 1).
        # NOTE(review): this assumes the corresponding targets are scaled to
        # [0, 1] by the normaliser -- confirm.
        self.soh_head = nn.Sequential(
            nn.Linear(config.D_MODEL, config.D_MODEL // 2),
            nn.ReLU(),
            nn.Dropout(config.DROPOUT),
            nn.Linear(config.D_MODEL // 2, 1),
            nn.Sigmoid()
        )

        # RUL ends in a ReLU, so predictions are non-negative.
        # NOTE(review): a ReLU output unit can get stuck at zero with zero
        # gradient; Softplus is a common alternative -- confirm intent.
        self.rul_head = nn.Sequential(
            nn.Linear(config.D_MODEL, config.D_MODEL // 2),
            nn.ReLU(),
            nn.Dropout(config.DROPOUT),
            nn.Linear(config.D_MODEL // 2, 1),
            nn.ReLU()
        )

        self.soc_head = nn.Sequential(
            nn.Linear(config.D_MODEL, config.D_MODEL // 2),
            nn.ReLU(),
            nn.Dropout(config.DROPOUT),
            nn.Linear(config.D_MODEL // 2, 1),
            nn.Sigmoid()
        )

        # Initialize weights
        self._init_weights()

    def _init_weights(self):
        """Apply Xavier-uniform weights and zero biases to every ``nn.Linear`` in the model."""
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)

    @staticmethod
    def _masked_mean(encoded, mask=None):
        """
        Average encoder outputs over the time axis, skipping padded steps.

        Args:
            encoded: Tensor of shape (batch_size, seq_len, d_model).
            mask: Optional (batch_size, seq_len) padding mask where a true /
                non-zero entry marks a padded position.

        Returns:
            Tensor of shape (batch_size, d_model).
        """
        if mask is None:
            return encoded.mean(dim=1)

        pad = mask.to(torch.bool).unsqueeze(-1)
        # Zero the padded steps (rather than multiplying) so that any NaN at a
        # padded position cannot leak into the sum.
        summed = encoded.masked_fill(pad, 0.0).sum(dim=1)
        # Guard against division by zero for a fully padded sequence.
        valid_steps = (~pad).sum(dim=1).clamp(min=1).to(encoded.dtype)
        return summed / valid_steps

    def forward(self, x, mask=None):
        """
        Forward pass through the transformer.

        Args:
            x: Input tensor of shape (batch_size, seq_len, feature_dim)
            mask: Optional padding mask of shape (batch_size, seq_len) passed
                to the encoder as ``src_key_padding_mask``; true entries mark
                padded positions. Padded positions are also excluded from the
                temporal pooling.

        Returns:
            Dictionary with keys 'soh', 'rul' and 'soc', each a tensor of
            shape (batch_size, 1)
        """
        # Input projection
        x = self.input_projection(x)

        # Add positional encoding
        x = self.positional_encoding(x)

        # Transformer encoding
        encoded = self.transformer_encoder(x, src_key_padding_mask=mask)

        # Global average pooling over valid (non-padded) time steps only
        pooled = self._masked_mean(encoded, mask)

        # Layer normalization and dropout
        pooled = self.layer_norm(pooled)
        pooled = self.dropout(pooled)

        # Multi-task predictions
        soh_pred = self.soh_head(pooled)
        rul_pred = self.rul_head(pooled)
        soc_pred = self.soc_head(pooled)

        return {
            'soh': soh_pred,
            'rul': rul_pred,
            'soc': soc_pred
        }

# Initialize model
# Build the network from the shared config and move it to the configured
# device; the count printed here includes every parameter, trainable or not.
model = BatteryTransformer(config).to(config.DEVICE)
print(f"✓ Model initialized with {sum(p.numel() for p in model.parameters()):,} parameters")

# Model summary
def count_parameters(model):
    """Return how many scalar parameters of ``model`` require gradients."""
    trainable = 0
    for param in model.parameters():
        # Frozen parameters (requires_grad=False) are left out of the tally.
        if param.requires_grad:
            trainable += param.numel()
    return trainable

# Report the number of parameters that will receive gradient updates
print(f"✓ Trainable parameters: {count_parameters(model):,}")

# Training Setup
print("\n5. TRAINING SETUP")
print("-" * 17)

# Loss functions
class MultiTaskLoss(nn.Module):
    """Weighted sum of per-task regression losses for battery prediction.

    SoH and SoC use mean squared error; RUL uses mean absolute error.
    """

    def __init__(self, weights=None):
        """
        Args:
            weights: Optional mapping from task name ('soh', 'rul', 'soc') to
                its loss weight. Tasks missing from the mapping keep the
                default weight of 1.0.
        """
        super().__init__()
        merged = {'soh': 1.0, 'rul': 1.0, 'soc': 1.0}
        if weights:
            # Overlay the caller's values so a partial mapping is accepted.
            merged.update(weights)
        self.weights = merged
        self.mse_loss = nn.MSELoss()
        self.mae_loss = nn.L1Loss()

    def forward(self, predictions, targets):
        """
        Compute multi-task loss.

        Args:
            predictions: Dictionary with 'soh', 'rul' and 'soc' tensors of
                shape (batch_size, 1)
            targets: Target tensor (batch_size, 3) for [soh, rul, soc]

        Returns:
            Tuple of the weighted total loss tensor (differentiable) and a
            dictionary of Python floats with the unweighted per-task losses
            and the total.
        """
        # Squeeze only the trailing dimension: a bare squeeze() would also
        # drop the batch dimension for a batch of one and leave a 0-d tensor
        # that no longer matches the (1,)-shaped target column.
        soh_loss = self.mse_loss(predictions['soh'].squeeze(-1), targets[:, 0])
        rul_loss = self.mae_loss(predictions['rul'].squeeze(-1), targets[:, 1])
        soc_loss = self.mse_loss(predictions['soc'].squeeze(-1), targets[:, 2])

        total_loss = (
            self.weights['soh'] * soh_loss +
            self.weights['rul'] * rul_loss +
            self.weights['soc'] * soc_loss
        )

        return total_loss, {
            'soh_loss': soh_loss.item(),
            'rul_loss': rul_loss.item(),
            'soc_loss': soc_loss.item(),
            'total_loss': total_loss.item()
        }

# Initialize loss and optimizer
# Default MultiTaskLoss weights give all three tasks equal weight.
criterion = MultiTaskLoss()
optimizer = optim.AdamW(
    model.parameters(),
    lr=config.LEARNING_RATE,
    weight_decay=config.WEIGHT_DECAY
)

# Learning rate scheduler
# Halves the learning rate once the monitored value (the validation loss
# passed to `scheduler.step` in the training loop) has not improved for more
# than 5 consecutive epochs.
# NOTE(review): the `verbose` argument of PyTorch LR schedulers is deprecated
# in recent releases -- confirm against the installed version; the training
# loop already prints the current learning rate.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.5, patience=5, verbose=True
)

print("✓ Loss function and optimizer initialized")

# Training Loop
print("\n6. MODEL TRAINING")
print("-" * 16)

def train_epoch(model, train_loader, criterion, optimizer, device):
    """Run one optimisation pass over ``train_loader``.

    Each batch is moved to ``device``, the loss is back-propagated, gradients
    are clipped to a global norm of 1.0 and the optimizer takes one step.

    Returns:
        Tuple of the mean total loss per batch and a dictionary with the mean
        'soh_loss', 'rul_loss' and 'soc_loss' per batch.
    """
    model.train()
    tracked_keys = ('soh_loss', 'rul_loss', 'soc_loss')
    running_total = 0
    running_parts = dict.fromkeys(tracked_keys, 0)

    for inputs, labels in train_loader:
        inputs = inputs.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        loss, parts = criterion(model(inputs), labels)

        loss.backward()
        # Clip before stepping to keep the update bounded.
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        running_total += loss.item()
        # Only the tracked per-task entries are accumulated; any other key the
        # criterion reports (such as its own total) is ignored.
        for name in tracked_keys:
            if name in parts:
                running_parts[name] += parts[name]

    num_batches = len(train_loader)
    mean_parts = {name: value / num_batches for name, value in running_parts.items()}
    return running_total / num_batches, mean_parts

def validate_epoch(model, val_loader, criterion, device):
    """Evaluate ``model`` on ``val_loader`` without updating its weights.

    The model is switched to evaluation mode and gradients are disabled.

    Returns:
        Tuple of the mean total loss per batch and a dictionary with the mean
        'soh_loss', 'rul_loss' and 'soc_loss' per batch.
    """
    model.eval()
    tracked_keys = ('soh_loss', 'rul_loss', 'soc_loss')
    running_total = 0
    running_parts = dict.fromkeys(tracked_keys, 0)

    with torch.no_grad():
        for inputs, labels in val_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            loss, parts = criterion(model(inputs), labels)

            running_total += loss.item()
            # Accumulate only the tracked per-task entries.
            for name in tracked_keys:
                if name in parts:
                    running_parts[name] += parts[name]

    num_batches = len(val_loader)
    mean_parts = {name: value / num_batches for name, value in running_parts.items()}
    return running_total / num_batches, mean_parts

# Training history
train_losses = []
val_losses = []
best_val_loss = float('inf')
patience_counter = 0  # epochs since the validation loss last improved

print("Starting training...")
for epoch in range(config.NUM_EPOCHS):
    # Training
    train_loss, train_components = train_epoch(
        model, train_loader, criterion, optimizer, config.DEVICE
    )

    # Validation
    val_loss, val_components = validate_epoch(
        model, val_loader, criterion, config.DEVICE
    )

    # Learning rate scheduling
    # The plateau scheduler is driven by the validation loss.
    scheduler.step(val_loss)

    # Record losses
    train_losses.append(train_loss)
    val_losses.append(val_loss)

    # Early stopping
    # Any strict improvement resets the counter and overwrites the checkpoint,
    # so the file on disk always holds the best weights seen so far.
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        patience_counter = 0
        # Save best model
        torch.save(model.state_dict(), 'best_transformer_model.pth')
    else:
        patience_counter += 1

    # Print progress
    # Only every 10th epoch is reported.
    if (epoch + 1) % 10 == 0:
        print(f"Epoch {epoch+1:3d}/{config.NUM_EPOCHS}")
        print(f"  Train Loss: {train_loss:.6f}")
        print(f"  Val Loss:   {val_loss:.6f}")
        print(f"  Components: SoH={val_components['soh_loss']:.6f}, "
              f"RUL={val_components['rul_loss']:.6f}, "
              f"SoC={val_components['soc_loss']:.6f}")
        print(f"  LR: {optimizer.param_groups[0]['lr']:.2e}")

    # Early stopping
    # Stop once PATIENCE consecutive epochs have passed without improvement.
    if patience_counter >= config.PATIENCE:
        print(f"Early stopping at epoch {epoch+1}")
        break

print("✓ Training completed")

# Load best model
# Restores the checkpoint written at the best validation epoch.
# NOTE(review): if the validation loss never compares below infinity (for
# example it is NaN from the first epoch), no checkpoint is written and this
# load fails with a missing-file error.
model.load_state_dict(torch.load('best_transformer_model.pth'))
print("✓ Best model loaded")

# Visualize training progress
plt.figure(figsize=(12, 4))

# Left panel: full loss history
plt.subplot(1, 2, 1)
plt.plot(train_losses, label='Training Loss', alpha=0.7)
plt.plot(val_losses, label='Validation Loss', alpha=0.7)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training Progress')
plt.legend()
plt.grid(True, alpha=0.3)

# Right panel: the last 50 recorded epochs only; the x-axis restarts at 0
# because the sliced lists are plotted against their own indices.
plt.subplot(1, 2, 2)
plt.plot(train_losses[-50:], label='Training Loss (Last 50)', alpha=0.7)
plt.plot(val_losses[-50:], label='Validation Loss (Last 50)', alpha=0.7)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training Progress (Recent)')
plt.legend()
plt.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

# Model Evaluation
print("\n7. MODEL EVALUATION")
print("-" * 18)

def evaluate_model(model, test_loader, device):
    """Collect the model's predictions and the matching targets for a whole loader.

    Returns:
        Tuple ``(predictions, targets)``. ``predictions`` maps 'soh', 'rul'
        and 'soc' to flat NumPy arrays with one value per sample; ``targets``
        is a NumPy array with one row per sample.
    """
    model.eval()
    task_names = ('soh', 'rul', 'soc')
    collected = {name: [] for name in task_names}
    collected_targets = []

    with torch.no_grad():
        for inputs, labels in test_loader:
            outputs = model(inputs.to(device))

            # Store one entry per sample so batches of different sizes
            # combine cleanly.
            for name in task_names:
                collected[name] += list(outputs[name].cpu().numpy())
            collected_targets += list(labels.to(device).cpu().numpy())

    # Convert to numpy arrays
    flat_predictions = {
        name: np.array(rows).flatten() for name, rows in collected.items()
    }
    return flat_predictions, np.array(collected_targets)

# Evaluate on test set
# `predictions` maps each task name to a flat array with one value per test
# sample; `targets` holds the matching ground truth, one row per sample.
predictions, targets = evaluate_model(model, test_loader, config.DEVICE)

# Calculate metrics
def calculate_metrics(y_true, y_pred, task_name):
    """Compute regression metrics for one task.

    Returns:
        Dictionary keyed by ``'<task_name>_mse'``, ``'<task_name>_mae'``,
        ``'<task_name>_r2'`` and ``'<task_name>_rmse'``.
    """
    squared_error = mean_squared_error(y_true, y_pred)

    results = {}
    results[f'{task_name}_mse'] = squared_error
    results[f'{task_name}_mae'] = mean_absolute_error(y_true, y_pred)
    results[f'{task_name}_r2'] = r2_score(y_true, y_pred)
    # RMSE is derived from the MSE computed above rather than recomputed.
    results[f'{task_name}_rmse'] = np.sqrt(squared_error)
    return results

# Calculate metrics for each task
# The position of each task in `tasks` must match its column in `targets`,
# i.e. the order of `target_columns` used when the sequences were created
# (soh, remaining_useful_life, soc).
metrics = {}
tasks = ['soh', 'rul', 'soc']
for i, task in enumerate(tasks):
    task_metrics = calculate_metrics(targets[:, i], predictions[task], task)
    metrics.update(task_metrics)

# Print evaluation results
print("Test Set Evaluation Results:")
print("=" * 40)
for task in tasks:
    print(f"\n{task.upper()} Prediction:")
    print(f"  RMSE: {metrics[f'{task}_rmse']:.6f}")
    print(f"  MAE:  {metrics[f'{task}_mae']:.6f}")
    print(f"  R²:   {metrics[f'{task}_r2']:.6f}")

# Visualize predictions vs targets
# One scatter panel per task; the dashed red diagonal marks perfect
# prediction over the observed range of true values.
fig, axes = plt.subplots(1, 3, figsize=(15, 5))

for i, task in enumerate(tasks):
    axes[i].scatter(targets[:, i], predictions[task], alpha=0.5, s=1)
    axes[i].plot([targets[:, i].min(), targets[:, i].max()], 
                 [targets[:, i].min(), targets[:, i].max()], 'r--', lw=2)
    axes[i].set_xlabel(f'True {task.upper()}')
    axes[i].set_ylabel(f'Predicted {task.upper()}')
    axes[i].set_title(f'{task.upper()} Predictions (R² = {metrics[f"{task}_r2"]:.3f})')
    axes[i].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

# Attention Visualization
print("\n8. ATTENTION VISUALIZATION")
print("-" * 27)

def visualize_attention(model, sample_input, layer_idx=0, head_idx=0):
    """Return the self-attention map of one head in one encoder layer.

    The sample is embedded, passed through every encoder layer that precedes
    ``layer_idx`` so the selected layer sees the input it receives in a normal
    forward pass, and the selected layer's attention module is then queried
    for per-head weights.

    Args:
        model: Model exposing ``input_projection``, ``positional_encoding``
            and ``transformer_encoder.layers``.
        sample_input: Single unbatched sample of shape (seq_len, feature_dim).
        layer_idx: Index of the encoder layer to inspect.
        head_idx: Index of the attention head to extract.

    Returns:
        NumPy array of shape (seq_len, seq_len); row ``q`` holds the attention
        distribution of query position ``q`` over all key positions.
    """
    model.eval()

    with torch.no_grad():
        # Add batch dimension
        batch = sample_input.unsqueeze(0)

        # Forward pass through input projection and positional encoding
        hidden = model.input_projection(batch)
        hidden = model.positional_encoding(hidden)

        # Run the layers that come before the one being inspected.
        layers = model.transformer_encoder.layers
        for preceding in layers[:layer_idx]:
            hidden = preceding(hidden)

        target_layer = layers[layer_idx]
        # Pre-norm layers normalise their input before attention.
        if getattr(target_layer, 'norm_first', False):
            attn_input = target_layer.norm1(hidden)
        else:
            attn_input = hidden

        # average_attn_weights=False keeps a separate map per head,
        # shape (batch, num_heads, seq_len, seq_len). The default averages
        # over heads, which would turn the head index below into a query-row
        # index.
        _, per_head_weights = target_layer.self_attn(
            attn_input, attn_input, attn_input,
            need_weights=True,
            average_attn_weights=False,
        )

        # Extract specific head
        return per_head_weights[0, head_idx].cpu().numpy()

# Visualize attention for a sample
# Uses the first test window and the function's defaults (layer 0, head 0),
# which is what the plot title below states.
sample_idx = 0
sample_input = X_test_tensor[sample_idx]
attention_weights = visualize_attention(model, sample_input)

plt.figure(figsize=(12, 8))
sns.heatmap(attention_weights, cmap='Blues', cbar=True)
plt.title('Attention Weights Visualization (Layer 0, Head 0)')
plt.xlabel('Key Position')
plt.ylabel('Query Position')
plt.show()

# Feature Importance Analysis
print("\n9. FEATURE IMPORTANCE ANALYSIS")
print("-" * 32)

def analyze_feature_importance(model, test_loader, device):
    """Score input features by the gradient of the SoH prediction.

    For every batch, the gradient of the summed SoH outputs with respect to
    the inputs is taken, its absolute value is averaged over the time axis and
    then over the batch, and the per-batch scores are finally averaged with
    equal weight per batch.

    Returns:
        NumPy array with one score per input feature.
    """
    model.eval()
    per_batch_scores = []

    for inputs, _ in test_loader:
        inputs = inputs.to(device)
        # Gradients are taken with respect to the inputs themselves.
        inputs.requires_grad_(True)

        soh_output = model(inputs)['soh']

        # A ones-vector as grad_outputs is equivalent to differentiating the
        # sum of all SoH predictions in the batch.
        (input_grads,) = torch.autograd.grad(
            outputs=soh_output,
            inputs=inputs,
            grad_outputs=torch.ones_like(soh_output),
            create_graph=False,
            retain_graph=False,
        )

        # Collapse the time axis first, then the batch axis.
        batch_scores = input_grads.abs().mean(dim=1).mean(dim=0)
        per_batch_scores.append(batch_scores.cpu().numpy())

    # Average across all batches
    return np.mean(per_batch_scores, axis=0)

# Calculate feature importance
# Rebinds `feature_importance`, which earlier held the extractor-based
# ranking, to the gradient-based scores from the trained model.
feature_importance = analyze_feature_importance(model, test_loader, config.DEVICE)

# Visualize feature importance
# These names must stay in the same order as the feature columns selected in
# `create_sequences`, since scores are matched to names by position.
feature_names = ['voltage', 'current', 'temperature', 'soc', 'power', 
                'energy', 'efficiency', 'thermal_gradient', 'resistance_change',
                'capacity_fade', 'cycle_count', 'age_days']

plt.figure(figsize=(12, 6))
# Sorted ascending so the most important feature is drawn at the top of the
# horizontal bar chart.
importance_df = pd.DataFrame({
    'feature': feature_names,
    'importance': feature_importance
}).sort_values('importance', ascending=True)

plt.barh(importance_df['feature'], importance_df['importance'])
plt.xlabel('Feature Importance (Gradient Magnitude)')
plt.title('Feature Importance Analysis')
plt.tight_layout()
plt.show()

# Model Interpretability
print("\n10. MODEL INTERPRETABILITY")
print("-" * 26)

def explain_prediction(model, sample_input, feature_names):
    """Attribute each task's prediction for one sample to the input features.

    Args:
        model: Model returning a dictionary with 'soh', 'rul' and 'soc'.
        sample_input: Single unbatched sample of shape (seq_len, feature_dim).
        feature_names: Accepted for interface compatibility; not used here.

    Returns:
        Tuple ``(explanations, predictions)``. ``explanations`` maps each task
        to a NumPy array holding the absolute input gradient averaged over the
        time axis, one value per feature; ``predictions`` is the model's
        output dictionary for the sample.
    """
    model.eval()

    batched = sample_input.unsqueeze(0)
    batched.requires_grad_(True)

    predictions = model(batched)

    def _saliency(task_output):
        # The graph is retained because the same forward pass is
        # differentiated once per task.
        (grads,) = torch.autograd.grad(
            outputs=task_output,
            inputs=batched,
            grad_outputs=torch.ones_like(task_output),
            create_graph=False,
            retain_graph=True,
        )
        # Average across sequence length
        return grads.abs().mean(dim=1).squeeze().cpu().numpy()

    explanations = {
        task: _saliency(predictions[task]) for task in ('soh', 'rul', 'soc')
    }
    return explanations, predictions

# Explain a sample prediction
# Uses the first test window, the same sample shown in the attention heatmap.
sample_explanations, sample_predictions = explain_prediction(
    model, X_test_tensor[0], feature_names
)

print("Sample Prediction Explanation:")
print(f"Predicted SoH: {sample_predictions['soh'].item():.4f}")
print(f"Predicted RUL: {sample_predictions['rul'].item():.4f}")
print(f"Predicted SoC: {sample_predictions['soc'].item():.4f}")

# Visualize explanation
# One bar chart per task, with one bar per input feature.
fig, axes = plt.subplots(1, 3, figsize=(15, 5))

for i, task in enumerate(['soh', 'rul', 'soc']):
    axes[i].bar(feature_names, sample_explanations[task])
    axes[i].set_title(f'{task.upper()} Prediction Explanation')
    axes[i].set_ylabel('Feature Importance')
    axes[i].tick_params(axis='x', rotation=45)

plt.tight_layout()
plt.show()

# Model Deployment Preparation
print("\n11. MODEL DEPLOYMENT PREPARATION")
print("-" * 35)

# Save final model
# The configuration is stored as a plain dict of built-in values. Pickling the
# `config` object itself would only record a reference to the notebook-local
# class (all of its settings are class attributes, so no values would be
# written), and the checkpoint could not be loaded without that class being
# importable. The device is stored as a string for the same reason.
config_snapshot = {
    name: getattr(config, name)
    for name in dir(config)
    if name.isupper()
}
config_snapshot['DEVICE'] = str(config_snapshot['DEVICE'])

torch.save({
    'model_state_dict': model.state_dict(),
    'config': config_snapshot,
    'feature_names': feature_names,
    # Plain floats instead of NumPy scalars keep the checkpoint free of
    # non-builtin pickled types.
    'metrics': {name: float(value) for name, value in metrics.items()},
    'model_architecture': str(model)
}, 'battery_transformer_final.pth')

print("✓ Model saved for deployment")

# Create ONNX export for production
# A random tensor with the training input shape is used to trace the model;
# only the batch dimension is declared dynamic.
# NOTE(review): recent PyTorch versions may need a higher opset than 11 to
# export the attention operators used by nn.TransformerEncoder -- confirm
# against the installed PyTorch/ONNX versions before relying on this export.
dummy_input = torch.randn(1, config.SEQUENCE_LENGTH, config.FEATURE_DIM).to(config.DEVICE)

torch.onnx.export(
    model,
    dummy_input,
    'battery_transformer.onnx',
    export_params=True,
    opset_version=11,
    do_constant_folding=True,
    input_names=['input'],
    output_names=['soh', 'rul', 'soc'],
    dynamic_axes={
        'input': {0: 'batch_size'},
        'soh': {0: 'batch_size'},
        'rul': {0: 'batch_size'},
        'soc': {0: 'batch_size'}
    }
)

print("✓ ONNX model exported for production deployment")

# Summary Report
# Prints a recap of the run using values computed in earlier cells; nothing
# is recomputed here apart from the trainable-parameter count.
print("\n12. DEVELOPMENT SUMMARY")
print("-" * 23)

print("Transformer Model Development Completed Successfully!")
print("=" * 55)
print(f"Model Architecture: {config.NUM_LAYERS}-layer Transformer")
print(f"Model Parameters: {count_parameters(model):,}")
print(f"Training Samples: {len(X_train):,}")
print(f"Validation Samples: {len(X_val):,}")
print(f"Test Samples: {len(X_test):,}")
print(f"Best Validation Loss: {best_val_loss:.6f}")

# Per-task test metrics taken from the `metrics` dictionary
print("\nFinal Test Performance:")
for task in tasks:
    print(f"  {task.upper()}: RMSE={metrics[f'{task}_rmse']:.4f}, "
          f"MAE={metrics[f'{task}_mae']:.4f}, R²={metrics[f'{task}_r2']:.4f}")

# Files written by the training and deployment cells
print("\nModel Artifacts Generated:")
print("  ✓ best_transformer_model.pth - Best model weights")
print("  ✓ battery_transformer_final.pth - Final model with metadata")
print("  ✓ battery_transformer.onnx - ONNX model for deployment")

print("\nNext Steps:")
print("  1. Deploy model to production environment")
print("  2. Set up real-time inference pipeline")
print("  3. Implement model monitoring and drift detection")
print("  4. Plan for continuous model updates")

print("\n" + "="*55)
print("BatteryMind Transformer Development Complete!")
