In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import mean_squared_error, r2_score, mean_absolute_error
import json
from pathlib import Path

In [None]:
class ModelPostProcessor:
    """Collect per-model predictions and produce metrics, plots and reports.

    Predictions are registered under a model name with
    :meth:`load_model_predictions`. The remaining methods compute regression
    metrics (via scikit-learn plus a MAPE helper), draw diagnostic plots with
    matplotlib/seaborn, and write a JSON summary.
    """

    def __init__(self, data_path='merged_steam_data.csv'):
        """Initialize the post processor with the data path.

        Args:
            data_path (str): CSV file read eagerly with ``pd.read_csv``; the
                resulting DataFrame is stored on ``self.data``.
        """
        self.data = pd.read_csv(data_path)
        # NOTE(review): ``self.models`` is not read or written anywhere in this
        # class; kept so external code that populates it keeps working.
        self.models = {}
        # model_name -> {'predictions': array-like, 'actual': array-like}
        self.predictions = {}

    def load_model_predictions(self, model_name, predictions, actual_values):
        """Load predictions from a model for comparison.

        The inputs are stored as given; they are converted to 1-D float arrays
        only when metrics or plots are computed.

        Args:
            model_name (str): Name of the model (e.g., 'ridge', 'knn', 'ann')
            predictions (array-like): Model predictions; an ``(n, 1)`` column
                is accepted and flattened on use.
            actual_values (array-like): Actual target values
        """
        self.predictions[model_name] = {
            'predictions': predictions,
            'actual': actual_values
        }

    @staticmethod
    def _to_1d(values):
        """Convert array-like *values* to a float array, flattening an (n, 1) column."""
        arr = np.asarray(values, dtype=float)
        if arr.ndim == 2 and arr.shape[1] == 1:
            arr = arr[:, 0]
        return arr

    def _get_arrays(self, model_name):
        """Return ``(predictions, actual)`` for *model_name* as float ndarrays.

        Converting here means plain lists work with element-wise arithmetic,
        pandas Series are compared by position rather than aligned by index,
        and an ``(n, 1)`` prediction column no longer broadcasts against
        ``(n,)`` targets into an ``(n, n)`` matrix.

        Raises:
            ValueError: if no predictions are loaded for *model_name*, the
                stored values are empty, or the two arrays differ in shape.
        """
        if model_name not in self.predictions:
            raise ValueError(f"No predictions found for model: {model_name}")
        entry = self.predictions[model_name]
        pred = self._to_1d(entry['predictions'])
        actual = self._to_1d(entry['actual'])
        if actual.size == 0:
            raise ValueError(f"No values stored for model: {model_name}")
        if pred.shape != actual.shape:
            raise ValueError(
                f"Shape mismatch for model {model_name}: "
                f"predictions {pred.shape} vs actual {actual.shape}"
            )
        return pred, actual

    @staticmethod
    def _mean_absolute_percentage_error(actual, pred):
        """Return the mean absolute percentage error, in percent.

        Samples whose actual value is 0 are skipped, because their relative
        error is undefined (the naive formula yields ``inf``). Returns ``nan``
        when every actual value is 0.
        """
        nonzero = actual != 0
        if not nonzero.any():
            return float('nan')
        relative = (actual[nonzero] - pred[nonzero]) / actual[nonzero]
        return float(np.mean(np.abs(relative)) * 100)

    def calculate_metrics(self, model_name):
        """Calculate performance metrics for a specific model.

        Args:
            model_name (str): Name of the model to calculate metrics for

        Returns:
            dict: ``mse``, ``rmse``, ``mae``, ``r2`` and
            ``mean_absolute_percentage_error`` (percent, zero targets skipped).

        Raises:
            ValueError: if the model is unknown or its inputs are invalid.
        """
        pred, actual = self._get_arrays(model_name)
        mse = mean_squared_error(actual, pred)

        return {
            'mse': mse,
            'rmse': np.sqrt(mse),
            'mae': mean_absolute_error(actual, pred),
            'r2': r2_score(actual, pred),
            'mean_absolute_percentage_error':
                self._mean_absolute_percentage_error(actual, pred),
        }

    def plot_prediction_vs_actual(self, model_name, save_path=None):
        """Create a scatter plot of predictions vs actual values.

        A dashed red ``y = x`` line spanning the range of the actual values
        marks perfect predictions. The figure is always closed afterwards.

        Args:
            model_name (str): Name of the model to plot
            save_path (str, optional): Path to save the plot
        """
        pred, actual = self._get_arrays(model_name)
        lo, hi = actual.min(), actual.max()

        fig, ax = plt.subplots(figsize=(10, 6))
        try:
            ax.scatter(actual, pred, alpha=0.5)
            ax.plot([lo, hi], [lo, hi], 'r--')  # perfect-prediction reference
            ax.set_xlabel('Actual Values')
            ax.set_ylabel('Predicted Values')
            ax.set_title(f'{model_name.upper()} Predictions vs Actual Values')

            if save_path:
                fig.savefig(save_path)
        finally:
            plt.close(fig)

    def plot_error_distribution(self, model_name, save_path=None):
        """Plot the distribution of prediction errors (``actual - predicted``).

        The figure is always closed afterwards.

        Args:
            model_name (str): Name of the model to plot
            save_path (str, optional): Path to save the plot
        """
        pred, actual = self._get_arrays(model_name)
        errors = actual - pred

        fig, ax = plt.subplots(figsize=(10, 6))
        try:
            sns.histplot(errors, kde=True, ax=ax)
            ax.set_xlabel('Prediction Error')
            ax.set_ylabel('Frequency')
            ax.set_title(f'{model_name.upper()} Error Distribution')

            if save_path:
                fig.savefig(save_path)
        finally:
            plt.close(fig)

    def compare_models(self, save_path=None):
        """Compare performance metrics across all loaded models.

        Args:
            save_path (str, optional): Path to save the comparison plot

        Returns:
            pandas.DataFrame: metrics as rows, model names as columns.

        Raises:
            ValueError: if no model predictions are loaded.
        """
        if not self.predictions:
            raise ValueError("No model predictions loaded")

        metrics_df = pd.DataFrame({
            name: pd.Series(self.calculate_metrics(name))
            for name in self.predictions
        })

        # Draw on an explicit Axes: without ``ax=`` pandas opens its own figure,
        # leaving a sized-but-empty figure behind that is never closed.
        # NOTE(review): metrics of very different scales (mse vs r2) share one
        # y-axis here; consider separate panels if that hides differences.
        fig, ax = plt.subplots(figsize=(12, 6))
        try:
            metrics_df.T.plot(kind='bar', ax=ax)
            ax.set_title('Model Performance Comparison')
            ax.set_xlabel('Models')
            ax.set_ylabel('Metric Value')
            ax.tick_params(axis='x', labelrotation=45)
            ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
            fig.tight_layout()

            if save_path:
                # bbox_inches keeps the legend placed outside the axes in frame.
                fig.savefig(save_path, bbox_inches='tight')
        finally:
            plt.close(fig)

        return metrics_df

    def generate_summary_report(self, output_path='model_summary.json'):
        """Generate a comprehensive summary report of all models.

        Args:
            output_path (str): Path to save the summary report (JSON).

        Returns:
            dict: model_name -> {'metrics': {...}, 'sample_size': int}
        """
        summary = {
            name: {
                # Plain floats keep the report JSON-serializable whatever
                # numpy scalar type the metric functions return.
                'metrics': {
                    key: float(value)
                    for key, value in self.calculate_metrics(name).items()
                },
                'sample_size': len(self.predictions[name]['actual']),
            }
            for name in self.predictions
        }

        with open(output_path, 'w', encoding='utf-8') as f:
            json.dump(summary, f, indent=4)

        return summary


if __name__ == "__main__":
    # Entry point: build a processor over the default CSV. No predictions are
    # loaded here; register them with load_model_predictions() before
    # computing metrics, plots or reports.
    post_processor = ModelPostProcessor(data_path='merged_steam_data.csv')