# Nuclear Fusion Model Training and Evaluation

This notebook demonstrates how to train and evaluate multiple machine learning models that predict the fusion Q factor (the `q_factor` target) on a generated dataset.

## Contents
1. Data Preparation
2. Model Training
3. Model Comparison
4. Best Model Evaluation
5. Feature Importance Analysis
6. Error Analysis
7. Physics Validation
8. Model Saving
9. Summary

In [None]:
# Import required libraries
import sys
import os
sys.path.append('..')

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error, r2_score

# Import fusion analyzer modules
from src.data.generator import FusionDataGenerator
from src.data.processor import FusionDataProcessor
from src.models.fusion_predictor import FusionPredictor
from src.models.anomaly_detector import FusionAnomalyDetector
from src.utils.evaluator import FusionModelEvaluator
from src.visualization.plotter import FusionPlotter

# Set style
# Apply a seaborn-flavoured matplotlib stylesheet to every figure drawn below.
# NOTE(review): the 'seaborn-v0_8' stylesheet name is, as far as I know, only
# shipped by matplotlib 3.6+ — confirm against the project's pinned version.
plt.style.use('seaborn-v0_8')
# Make seaborn's "husl" palette the default color cycle.
sns.set_palette("husl")

print("Libraries imported successfully!")

## 1. Data Preparation

In [None]:
# Generate training data
# Seed NumPy's global RNG so that Restart & Run All regenerates the same
# dataset (and therefore the same downstream metrics) on every run.
# NOTE(review): this only has an effect if FusionDataGenerator draws from
# NumPy's global RNG — confirm, or pass the seed to the generator directly if
# it accepts one.
RANDOM_SEED = 42
np.random.seed(RANDOM_SEED)

generator = FusionDataGenerator()
raw_data = generator.generate_dataset(n_samples=10000)

# Quick sanity check on what was generated: size and the target's value range.
print(f"Generated {len(raw_data)} samples with {len(raw_data.columns)} features")
print(f"Target variable (Q Factor) range: {raw_data['q_factor'].min():.3f} to {raw_data['q_factor'].max():.3f}")

In [None]:
# Preprocess data
# Run the project's preprocessing pipeline on the raw frame. `q_factor` is the
# regression target; test_size=0.2 and validation_size=0.1 are passed straight
# through to the pipeline.
processor = FusionDataProcessor()
processed_data = processor.preprocess_pipeline(
    raw_data,
    target_column='q_factor',
    test_size=0.2,
    validation_size=0.1,
)

# Pull the six splits out of the result dict in one pass.
X_train, X_val, X_test, y_train, y_val, y_test = (
    processed_data[split_key]
    for split_key in ('X_train', 'X_val', 'X_test', 'y_train', 'y_val', 'y_test')
)

# Report the shape of each feature matrix, then the number of features.
for split_label, split_features in (
    ("Training", X_train),
    ("Validation", X_val),
    ("Test", X_test),
):
    print(f"{split_label} set: {split_features.shape}")
print(f"Features: {len(processed_data['feature_names'])}")

## 2. Model Training

In [None]:
# Initialize predictor and train all models
predictor = FusionPredictor()

print("Training multiple ML models...")
# `predictor.models` is keyed by model name; show what is about to be trained.
available_models = list(predictor.models.keys())
print("Available models:", available_models)

In [None]:
# Train individual models
# Each entry of `model_results` ends up as either the metrics dict returned by
# `train_model`, or `{'error': <message>}` when training that model raised.
model_results = {}

# (label, results key) pairs reported after each fit, in print order.
reported_metrics = (
    ("Training R²", 'train_r2'),
    ("Validation R²", 'val_r2'),
    ("Validation RMSE", 'val_rmse'),
)

for model_name in predictor.models.keys():
    print(f"\nTraining {model_name}...")
    try:
        results = predictor.train_model(model_name, X_train, y_train, X_val, y_val)
        model_results[model_name] = results
        for metric_label, metric_key in reported_metrics:
            print(f"  {metric_label}: {results[metric_key]:.4f}")
    except Exception as e:
        # Best effort: one failing model must not abort the remaining ones.
        print(f"  Failed: {e}")
        model_results[model_name] = {'error': str(e)}

In [None]:
# Train deep learning model if available
# Best effort: the deep learning backend may be missing or may fail to train,
# in which case the notebook carries on with the classical models only.
try:
    print("\nTraining deep learning model...")
    dl_results = predictor.train_deep_learning_model(
        X_train, y_train, X_val, y_val, epochs=50
    )
    model_results['deep_learning'] = dl_results
    print(f"  Training R²: {dl_results['train_r2']:.4f}")
    print(f"  Validation R²: {dl_results['val_r2']:.4f}")
except Exception as e:
    print(f"Deep learning training failed: {e}")
    # Record the failure in the same {'error': ...} shape the per-model
    # training loop uses. Besides keeping `model_results` consistent, this
    # overwrites any stale 'deep_learning' metrics left over from an earlier
    # successful run of this cell, so a failed re-run cannot leak old numbers
    # into the comparison table.
    model_results['deep_learning'] = {'error': str(e)}

## 3. Model Comparison

In [None]:
# Create comparison DataFrame
# Display column -> key in each model's results dict.
metric_columns = {
    'Train R²': 'train_r2',
    'Val R²': 'val_r2',
    'Train RMSE': 'train_rmse',
    'Val RMSE': 'val_rmse',
    'Train MAE': 'train_mae',
    'Val MAE': 'val_mae',
}

comparison_data = []
for model_name, results in model_results.items():
    # Failed runs are stored as {'error': ...} and carry no 'val_r2'; skip them.
    if isinstance(results, dict) and 'val_r2' in results:
        row = {'Model': model_name}
        for column, key in metric_columns.items():
            # A metric the model did not report becomes NaN rather than 0:
            # a 0 RMSE/MAE would read as a perfect score in the table and plots.
            row[column] = results.get(key, np.nan)
        comparison_data.append(row)

# Passing `columns` explicitly keeps the schema even when no model succeeded;
# without it an empty frame has no 'Val R²' column and sort_values raises
# KeyError before the `comparison_df.empty` guards in later cells can help.
comparison_df = pd.DataFrame(comparison_data, columns=['Model', *metric_columns])
comparison_df = comparison_df.sort_values('Val R²', ascending=False)

print("Model Performance Comparison:")
if comparison_df.empty:
    print("No model produced validation metrics - check the training output above.")
else:
    print(comparison_df.round(4))

In [None]:
# Visualize model comparison
# Side-by-side grouped bar charts (training vs validation) for R² and RMSE,
# one bar pair per model, in the row order of `comparison_df`.
if not comparison_df.empty:
    fig, axes = plt.subplots(1, 2, figsize=(15, 6))

    # Shared bar geometry: each model gets a training/validation pair
    # centred on its integer x position.
    x_pos = np.arange(len(comparison_df))
    width = 0.35

    # One panel per metric: (axis, column suffix / title word, y-axis label).
    panels = (
        (axes[0], 'R²', 'R² Score'),
        (axes[1], 'RMSE', 'RMSE'),
    )
    for panel_ax, metric, y_label in panels:
        panel_ax.bar(x_pos - width/2, comparison_df[f'Train {metric}'], width,
                     label='Training', alpha=0.8)
        panel_ax.bar(x_pos + width/2, comparison_df[f'Val {metric}'], width,
                     label='Validation', alpha=0.8)

        panel_ax.set_xlabel('Model')
        panel_ax.set_ylabel(y_label)
        panel_ax.set_title(f'Model {metric} Comparison')
        panel_ax.set_xticks(x_pos)
        panel_ax.set_xticklabels(comparison_df['Model'], rotation=45)
        panel_ax.legend()
        panel_ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

## 4. Best Model Evaluation

In [None]:
# Select best model based on validation R²
if not comparison_df.empty:
    # The comparison frame was sorted by 'Val R²' descending when it was
    # built, so its first row is the best-scoring model.
    best_model_name = comparison_df['Model'].iloc[0]
    print(f"Best model: {best_model_name}")

    # Get test predictions
    y_test_pred = predictor.predict(X_test, best_model_name)

    # Comprehensive evaluation of the fitted model on the held-out test split
    evaluator = FusionModelEvaluator()
    best_trained_model = predictor.trained_models[best_model_name]
    test_evaluation = evaluator.comprehensive_evaluation(
        best_trained_model, X_test, y_test, y_test_pred, best_model_name
    )

    # Generate evaluation report
    report = evaluator.generate_evaluation_report(test_evaluation)
    print("\n" + report)

In [None]:
# Prediction vs actual plot
# Only drawn when the evaluation cell produced test predictions.
if 'y_test_pred' in locals():
    plt.figure(figsize=(10, 8))
    ax = plt.gca()

    # Scatter plot
    ax.scatter(y_test, y_test_pred, alpha=0.6, s=20)

    # Perfect prediction line: the diagonal spanning the combined value range
    diag_lo = min(y_test.min(), y_test_pred.min())
    diag_hi = max(y_test.max(), y_test_pred.max())
    ax.plot([diag_lo, diag_hi], [diag_lo, diag_hi], 'r--', lw=2, label='Perfect Prediction')

    # Breakeven lines (Q = 1) on both axes; only one gets a legend entry
    ax.axhline(y=1.0, color='orange', linestyle=':', alpha=0.7, label='Q=1 (Breakeven)')
    ax.axvline(x=1.0, color='orange', linestyle=':', alpha=0.7)

    ax.set_xlabel('Actual Q Factor')
    ax.set_ylabel('Predicted Q Factor')
    ax.set_title(f'{best_model_name} - Prediction vs Actual')
    ax.legend()
    ax.grid(True, alpha=0.3)

    # Add R² score as an annotation in axes-relative coordinates (top-left)
    ax.text(0.05, 0.95, f'R² = {r2_score(y_test, y_test_pred):.4f}',
            transform=ax.transAxes,
            bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))

    plt.tight_layout()
    plt.show()

## 5. Feature Importance Analysis

In [None]:
# Analyze feature importance
# `best_model_name` only exists when the model-selection cell found at least
# one successfully trained model. Check for it first so this cell degrades to
# a message instead of raising NameError, matching the guards used by the
# neighbouring cells.
has_best_model = 'best_model_name' in locals()

if has_best_model and best_model_name in predictor.feature_importance:
    importance = predictor.feature_importance[best_model_name]

    # Convert to DataFrame and sort ascending, so the most important features
    # sit at the tail (and therefore at the top of the horizontal bar chart).
    importance_df = pd.DataFrame(list(importance.items()),
                                 columns=['Feature', 'Importance'])
    importance_df = importance_df.sort_values('Importance', ascending=True)

    # Plot top 15 features
    top_features = importance_df.tail(15)

    plt.figure(figsize=(10, 8))
    plt.barh(range(len(top_features)), top_features['Importance'])
    plt.yticks(range(len(top_features)), top_features['Feature'])
    plt.xlabel('Feature Importance')
    plt.title(f'Top 15 Feature Importance - {best_model_name}')
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.show()

    # Text ranking: reverse the ascending tail so rank 1 is the largest value.
    print("Top 10 Most Important Features:")
    top_ten = importance_df.tail(10).iloc[::-1]
    for i, (feature, imp) in enumerate(top_ten.itertuples(index=False), 1):
        print(f"{i:2d}. {feature:<25} {imp:.4f}")
else:
    # Either no model was selected, or the selected model has no entry in
    # `predictor.feature_importance`; say so instead of skipping silently.
    print("No feature importance available for the selected model.")

## 6. Error Analysis

In [None]:
# Analyze prediction errors
# Four diagnostic panels on the test-set residuals (actual minus predicted),
# followed by summary statistics. Skipped when no test predictions exist.
if 'y_test_pred' in locals():
    from scipy import stats

    errors = y_test - y_test_pred

    fig, axes = plt.subplots(2, 2, figsize=(15, 12))
    (ax_hist, ax_signed), (ax_abs, ax_qq) = axes

    # Error distribution, with the mean error marked
    mean_error = errors.mean()
    ax_hist.hist(errors, bins=50, alpha=0.7, edgecolor='black')
    ax_hist.axvline(mean_error, color='red', linestyle='--',
                    label=f'Mean: {mean_error:.4f}')
    ax_hist.set(xlabel='Prediction Error', ylabel='Frequency',
                title='Error Distribution')
    ax_hist.legend()

    # Error vs actual values, with a zero-error reference line
    ax_signed.scatter(y_test, errors, alpha=0.6)
    ax_signed.axhline(y=0, color='red', linestyle='--')
    ax_signed.set(xlabel='Actual Q Factor', ylabel='Prediction Error',
                  title='Error vs Actual Values')

    # Absolute error vs actual
    abs_errors = np.abs(errors)
    ax_abs.scatter(y_test, abs_errors, alpha=0.6)
    ax_abs.set(xlabel='Actual Q Factor', ylabel='Absolute Error',
               title='Absolute Error vs Actual Values')

    # Q-Q plot against a normal distribution; probplot writes its own title,
    # so ours is set afterwards.
    stats.probplot(errors, dist="norm", plot=ax_qq)
    ax_qq.set_title('Q-Q Plot (Normal Distribution)')

    # Same light grid on every panel
    for panel_ax in (ax_hist, ax_signed, ax_abs, ax_qq):
        panel_ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

    # Error statistics
    print("Error Statistics:")
    print(f"  Mean Error: {mean_error:.6f}")
    print(f"  Std Error: {errors.std():.6f}")
    print(f"  MAE: {abs_errors.mean():.6f}")
    print(f"  Max Absolute Error: {abs_errors.max():.6f}")
    print(f"  95th Percentile Error: {np.percentile(abs_errors, 95):.6f}")

## 7. Physics Validation

In [None]:
# Validate predictions against physics constraints
# Counts out-of-range predictions and checks how often the prediction lands on
# the same side of two Q thresholds as the actual value. Skipped when no test
# predictions exist.
if 'y_test_pred' in locals():
    print("Physics Validation Results:")
    print("=" * 40)

    n_predictions = len(y_test_pred)

    # Check for unphysical predictions: Q below zero or above 50
    negative_q = (y_test_pred < 0).sum()
    extreme_q = (y_test_pred > 50).sum()

    print(f"Negative Q factor predictions: {negative_q} ({negative_q/n_predictions*100:.2f}%)")
    print(f"Extremely high Q predictions (>50): {extreme_q} ({extreme_q/n_predictions*100:.2f}%)")

    # Breakeven and ignition prediction accuracy: fraction of samples where
    # actual and predicted Q agree on whether the threshold is reached.
    # NOTE(review): Q >= 5 is what this notebook labels "ignition" — confirm
    # the threshold matches the project's definition.
    breakeven_accuracy = ((y_test >= 1.0) == (y_test_pred >= 1.0)).mean()
    ignition_accuracy = ((y_test >= 5.0) == (y_test_pred >= 5.0)).mean()

    print(f"\nBreakeven prediction accuracy: {breakeven_accuracy:.3f}")
    print(f"Ignition prediction accuracy: {ignition_accuracy:.3f}")

    # Physical realism score: share of predictions inside the [0, 50] range
    violations = negative_q + extreme_q
    realism_score = 1.0 - violations / n_predictions
    print(f"\nPhysical realism score: {realism_score:.3f}")

## 8. Model Saving

In [None]:
# Save trained models
# `os` is already imported in the notebook's import cell, so it is not
# re-imported here.
save_dir = '../saved_models'
os.makedirs(save_dir, exist_ok=True)

# Best effort: a failed save is reported but does not stop the notebook.
try:
    predictor.save_models(save_dir)
    print(f"Models saved to {save_dir}")

    # List saved files. os.listdir returns entries in arbitrary order, so sort
    # them to keep the output stable between runs. Note that this lists
    # everything in the directory, including files left by earlier runs.
    saved_files = sorted(os.listdir(save_dir))
    print("Saved model files:")
    for file_name in saved_files:
        print(f"  - {file_name}")

except Exception as e:
    print(f"Error saving models: {e}")

## 9. Summary

In [None]:
# Generate training summary
print("FUSION MODEL TRAINING SUMMARY")
print("=" * 50)

print(f"Dataset: {len(raw_data)} samples, {len(processed_data['feature_names'])} features")
print(f"Target: Q Factor (range: {raw_data['q_factor'].min():.3f} - {raw_data['q_factor'].max():.3f})")

# At least one model must have produced validation metrics for the run to
# count as a success; the comparison frame only holds such models.
training_succeeded = not comparison_df.empty

if training_succeeded:
    print(f"\nModels Trained: {len(comparison_df)}")
    print(f"Best Model: {best_model_name}")

    # comparison_df is sorted by validation R², so row 0 is the best model.
    best_performance = comparison_df.iloc[0]
    print(f"Best Validation R²: {best_performance['Val R²']:.4f}")
    print(f"Best Validation RMSE: {best_performance['Val RMSE']:.4f}")

if 'y_test_pred' in locals():
    test_r2 = r2_score(y_test, y_test_pred)
    test_rmse = np.sqrt(mean_squared_error(y_test, y_test_pred))
    print(f"\nTest Performance:")
    print(f"  Test R²: {test_r2:.4f}")
    print(f"  Test RMSE: {test_rmse:.4f}")
    # The realism score comes from the physics-validation cell; guard on it
    # separately so a skipped validation cell does not raise NameError here.
    if 'realism_score' in locals():
        print(f"  Physical Realism Score: {realism_score:.3f}")

# Do not claim success when every model failed to train.
if training_succeeded:
    print("\nTraining completed successfully!")
else:
    print("\nNo model trained successfully - review the errors reported above.")