# Quick script to search directories of models and plots (with metrics) to find best current model configs.

In [1]:
#!/usr/bin/env python
import os
import re
import json
import glob
import numpy as np
import pandas as pd
import pickle
# add validation loss.
from pprint import pprint

# Simulation set and snapshot to inspect.
sim_set = "LH"
snap_str = "044"

# Where the trained models and the SBI plots/metrics for that snapshot are stored.
model_dir = f"/disk/xray15/aem2/data/6pams/{sim_set}/IllustrisTNG/models/colours_lfs/{snap_str}"
plots_dir = f"/disk/xray15/aem2/plots/6pams/{sim_set}/IllustrisTNG/test/sbi_plots/colours_lfs/{snap_str}"

# Echo both locations, one per line.
print(model_dir, plots_dir, sep="\n")

# Lift pandas' truncation limits so printed DataFrames are shown in full.
for _option_name, _option_value in (
    ('display.max_rows', None),
    ('display.max_columns', None),
    ('display.max_colwidth', None),
    ('display.width', 1000),
):
    pd.set_option(_option_name, _option_value)


def _extract_val_log_prob(data):
    """Return the maximum validation log-probability stored in a loaded model pickle.

    Args:
        data: Object loaded from a model pickle. Only a dict with a 'summaries'
            entry is inspected; 'summaries' may be one summary dict or a list
            of summary dicts.

    Returns:
        max() of the 'validation_log_probs' sequence, or None when no non-empty
        sequence is present.
    """
    if not isinstance(data, dict) or 'summaries' not in data:
        return None

    summaries = data['summaries']
    if isinstance(summaries, dict):
        summaries = [summaries]
    elif not isinstance(summaries, list):
        return None

    val_loss = None
    for summary in summaries:
        if 'validation_log_probs' not in summary:
            continue
        val_probs = summary['validation_log_probs']
        # Use len() rather than truthiness: bool() of a NumPy array with more
        # than one element raises ValueError.
        if val_probs is not None and len(val_probs) > 0:
            # NOTE(review): with several summaries the last one wins (not the
            # best across summaries) -- confirm this is the intended ranking.
            val_loss = max(val_probs)
    return val_loss


def _parse_hyperparams(model_name):
    """Parse hyperparameters encoded in an underscore-separated directory name.

    Recognised tokens: 'batch<N>', 'lr<X>', 'epochs<N>', 'max_num_epochs<N>',
    'h<N>', 't<N>' and 'nn<N>'. Values are returned as strings.

    Args:
        model_name: Base name of the directory holding the model file.

    Returns:
        Dict mapping hyperparameter names to their string values.
    """
    params = {}

    # 'max_num_epochs' itself contains underscores, so it must be pulled out
    # before splitting; otherwise its value is misread as plain 'epochs'.
    max_epochs_match = re.search(r'max_num_epochs_?(\d+)', model_name)
    if max_epochs_match:
        params['max_epochs'] = max_epochs_match.group(1)
        model_name = (
            model_name[:max_epochs_match.start()]
            + model_name[max_epochs_match.end():]
        )

    for part in model_name.split('_'):
        if 'batch' in part:
            params['batch_size'] = part.replace('batch', '')
            continue
        if 'lr' in part:
            params['learning_rate'] = part.replace('lr', '')
            continue
        if 'epochs' in part and 'max' not in part:
            params['epochs'] = part.replace('epochs', '')
            continue

        # Single-letter tags are only accepted when the rest of the token is a
        # number, so words such as "test" or "tanh" are not mistaken for them.
        for letter, key in (('h', 'hidden_size'), ('t', 'transforms')):
            if letter in part and len(part) < 5:
                value = part.replace(letter, '')
                if re.fullmatch(r'\d+(?:\.\d+)?', value):
                    params[key] = value
                    break
        else:
            if 'nn' in part:
                params['num_nets'] = part.replace('nn', '')

    return params


def find_best_model_with_metrics(model_dir, plots_dir):
    """
    Rank the models under model_dir by validation log-probability and attach
    the per-parameter metrics found under plots_dir.

    Every '*.pkl' file whose name contains 'complete' (and does not end in
    '_posterior.pkl') is loaded; files that fail to load are reported and
    skipped. Hyperparameters are parsed from the name of the directory that
    holds the file.

    Args:
        model_dir: Directory containing model files
        plots_dir: Directory containing metrics files

    Returns:
        DataFrame with one row per model file, sorted by 'val_loss' in
        descending order (higher log-probability is better), with
        'R²_<parameter>' / 'RMSE_<parameter>' columns merged in when metrics
        files are found. An empty DataFrame when no model file could be read.
    """
    print(f"Searching for models in: {model_dir}")
    print(f"Looking for metrics in: {plots_dir}")

    # One dict per successfully loaded model file
    model_info = []

    for root, _dirs, files in os.walk(model_dir):
        model_name = os.path.basename(root)
        for file in files:
            is_model_file = file.endswith('.pkl') and 'complete' in file
            if not is_model_file or file.endswith('_posterior.pkl'):
                continue

            file_path = os.path.join(root, file)
            try:
                # SECURITY: pickle.load can execute arbitrary code; only point
                # model_dir at files you trust.
                with open(file_path, 'rb') as f:
                    data = pickle.load(f)

                model_info.append({
                    'model_dir': root,
                    'model_file': file,
                    'model_name': model_name,
                    'val_loss': _extract_val_log_prob(data),
                    **_parse_hyperparams(model_name)
                })
            except Exception as e:
                # Best effort: one unreadable file must not abort the scan.
                print(f"Error processing {file_path}: {e}")

    if not model_info:
        return pd.DataFrame()

    # Highest log-probability first; rows without a value sort last.
    df_models = pd.DataFrame(model_info).sort_values('val_loss', ascending=False)

    # Now find and parse metrics files in the plots directory
    metrics_data = []
    metrics_files = glob.glob(os.path.join(plots_dir, "**", "metrics_*.txt"), recursive=True)

    print(f"Found {len(metrics_files)} metrics files in plots directory.")

    for metrics_file in metrics_files:
        metrics_dict, config = parse_metrics_file(metrics_file)
        if metrics_dict:
            # Create a row for each parameter
            for i, param in enumerate(metrics_dict['parameters']):
                metrics_data.append({
                    'model_name': config,
                    'parameter': param,
                    'R²': metrics_dict['R²'][i] if i < len(metrics_dict['R²']) else None,
                    'RMSE': metrics_dict['RMSE'][i] if i < len(metrics_dict['RMSE']) else None
                })

    if not metrics_data:
        return df_models

    df_metrics = pd.DataFrame(metrics_data)

    # One row per model, one column per (metric, parameter) pair
    df_metrics_pivoted = pd.pivot_table(
        df_metrics,
        index='model_name',
        columns='parameter',
        values=['R²', 'RMSE']
    )

    # Flatten the column multi-index into '<metric>_<parameter>' names
    df_metrics_pivoted.columns = [f"{metric}_{param}" for metric, param in df_metrics_pivoted.columns]
    df_metrics_pivoted = df_metrics_pivoted.reset_index()

    # Left merge keeps every model, including those without metrics
    return pd.merge(
        df_models,
        df_metrics_pivoted,
        left_on='model_name',
        right_on='model_name',
        how='left'
    )

def parse_metrics_file(file_path):
    """Parse a metrics file and extract R² and RMSE for each parameter.

    The file is scanned line by line. A line matching 'Metrics for <name>:'
    starts a new parameter; 'R²: <number>' and 'RMSE: <number>' lines that
    follow are attributed to that parameter. Numbers may be integers, decimals
    or use scientific notation.

    Args:
        file_path: Path to a metrics text file. The name of its parent
            directory is used as the model configuration name.

    Returns:
        Tuple (metrics_dict, config). metrics_dict has the keys 'config',
        'parameters', 'R²' and 'RMSE'; the three lists always have the same
        length, with None where a parameter has no value for that metric.
        metrics_dict is None when the file cannot be read.
    """
    # The configuration is the directory that directly contains the file.
    # (Searching the whole path for '/<dir>/metrics_' would pick the wrong
    # component when a parent directory name starts with 'metrics_'.)
    config = os.path.basename(os.path.dirname(file_path))

    metrics_dict = {
        'config': config,
        'parameters': [],
        'R²': [],
        'RMSE': []
    }

    number = r'(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][-+]?\d+)?'
    param_pattern = re.compile(r'Metrics for (.*):')
    r2_pattern = re.compile(r'R²:\s*([-+]?' + number + r')')
    rmse_pattern = re.compile(r'RMSE:\s*(' + number + r')')

    try:
        # The files contain the non-ASCII character '²', so decode explicitly.
        with open(file_path, 'r', encoding='utf-8') as f:
            for line in f:
                param_match = param_pattern.search(line)
                if param_match:
                    metrics_dict['parameters'].append(param_match.group(1))
                    # Reserve a slot per metric so values stay aligned with
                    # their parameter even when one of them is missing.
                    metrics_dict['R²'].append(None)
                    metrics_dict['RMSE'].append(None)
                    continue

                if not metrics_dict['parameters']:
                    # Metric lines before the first parameter header are ignored.
                    continue

                r2_match = r2_pattern.search(line)
                if r2_match:
                    metrics_dict['R²'][-1] = float(r2_match.group(1))

                rmse_match = rmse_pattern.search(line)
                if rmse_match:
                    metrics_dict['RMSE'][-1] = float(rmse_match.group(1))

        return metrics_dict, config

    except (OSError, ValueError) as e:
        # ValueError also covers UnicodeDecodeError for badly encoded files.
        print(f"Error parsing {file_path}: {e}")
        return None, config

# Build the comparison table and decide up front where it will be written.
df_combined = find_best_model_with_metrics(model_dir, plots_dir)
results_path = os.path.join(plots_dir, "model_comparison_with_metrics.csv")

if df_combined.empty:
    print("\nNo model files found.")
else:
    # Make sure nothing is truncated when the table is printed.
    pd.set_option('display.max_columns', None)
    pd.set_option('display.max_rows', None)
    print("\nModel Comparison (sorted by validation loss):")
    print(df_combined)

    has_val_loss = 'val_loss' in df_combined.columns and df_combined['val_loss'].notna().any()
    if not has_val_loss:
        print("\nNo validation loss information found in any models.")
    else:
        # The table is already sorted, so the first row is the best model.
        best_model = df_combined.iloc[0]
        print("\nBest model found:")
        print(f"Directory: {best_model['model_dir']}")
        print(f"File: {best_model['model_file']}")
        print(f"Validation loss: {best_model['val_loss']}")

        # Hyperparameters that were actually parsed for this model.
        params_str = [
            f"{param}={best_model[param]}"
            for param in ['batch_size', 'learning_rate', 'hidden_size', 'transforms', 'num_nets', 'epochs', 'max_epochs']
            if param in best_model and not pd.isna(best_model[param])
        ]
        print(f"Parameters: {', '.join(params_str)}")

        # Per-parameter metrics, when a metrics file was matched to this model.
        metrics_cols = [col for col in best_model.index if 'R²_' in col or 'RMSE_' in col]
        if metrics_cols:
            print("\nMetrics for best model:")
            for col in metrics_cols:
                value = best_model[col]
                if not pd.isna(value):
                    print(f"{col}: {value}")

    # Persist the full comparison next to the plots.
    df_combined.to_csv(results_path, index=False)
    print(f"\nResults saved to: {results_path}")

/disk/xray15/aem2/data/6pams/LH/IllustrisTNG/models/colours_lfs/044
/disk/xray15/aem2/plots/6pams/LH/IllustrisTNG/test/sbi_plots/colours_lfs/044
Searching for models in: /disk/xray15/aem2/data/6pams/LH/IllustrisTNG/models/colours_lfs/044
Looking for metrics in: /disk/xray15/aem2/plots/6pams/LH/IllustrisTNG/test/sbi_plots/colours_lfs/044
Found 5 metrics files in plots directory.

Model Comparison (sorted by validation loss):
                                                                                                                                                          model_dir                                           model_file                                                                                    model_name  val_loss batch_size learning_rate epochs hidden_size transforms num_nets  RMSE_$A_{\mathrm{AGN1}}$  RMSE_$A_{\mathrm{AGN2}}$  RMSE_$A_{\mathrm{SN1}}$  RMSE_$A_{\mathrm{SN2}}$  RMSE_$\Omega_m$  RMSE_$\sigma_8$  R²_$A_{\mathrm{AGN1}}$  R²_$A_{\mathrm{AGN2}}$  R²