In [12]:
import os
import re
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from glob import glob
import json
from ast import literal_eval


def extract_metadata_from_filename(filepath):
    """Parse experiment hyper-parameters out of a result file's name.

    Only the basename of ``filepath`` is inspected. It has to start with
    ``<function>_d<dim>_h<hidden>_depth<depth>_lr<lr>_b<batch>_mode<mode>_dist<dist>_exp<n>``;
    whatever follows that prefix (e.g. ``_loadfactor_layerwise...``) is ignored.

    Args:
        filepath: Path (or bare filename) of a result file.

    Returns:
        A dict with the keys ``function_name``, ``input_dim``, ``hidden_size``,
        ``depth``, ``learning_rate``, ``batch_size``, ``mode``,
        ``input_distribution`` and ``experiment_num``, or ``None`` when the
        name does not match or any error occurs (the error is printed).
    """
    # (output key, converter) pairs, listed in regex capture-group order.
    field_specs = (
        ('function_name', str),
        ('input_dim', int),
        ('hidden_size', int),
        ('depth', int),
        ('learning_rate', float),
        ('batch_size', int),
        ('mode', str),
        ('input_distribution', str),
        ('experiment_num', int),
    )

    try:
        name = os.path.basename(filepath)

        # The trailing optional group tolerates the _loadfactor_layerwise suffix
        parsed = re.match(r'(\w+)_d(\d+)_h(\d+)_depth(\d+)_lr([\d\.]+)_b(\d+)_mode(\w+)_dist(\w+)_exp(\d+)(?:_loadfactor_layerwise)?', name)
        if parsed is None:
            return None

        return {
            key: convert(text)
            for (key, convert), text in zip(field_specs, parsed.groups())
        }
    except Exception as e:
        print(f"Error extracting metadata from filename {filepath}: {e}")
        return None


def extract_epoch_from_filename(filepath):
    """Determine which training iteration ("epoch") a load factor file belongs to.

    Resolution order:

    1. A basename containing ``_initial.npz`` maps to ``0``.
    2. A basename containing ``_final.npz`` maps to the single value stored
       under the ``iteration`` key of the archive; if the file cannot be read
       or holds no such single value, ``1000000`` is used as a placeholder.
    3. Otherwise the digits following ``_iter`` in the basename are used.

    Args:
        filepath: Path of the ``.npz`` file. The file itself is only opened
            for ``_final`` files.

    Returns:
        The iteration as an ``int``, or ``None`` if it cannot be determined
        (a message is printed in that case).
    """
    # Tried in order, most specific first; group 1 is the iteration number.
    iteration_patterns = (
        r'_layerwise_termloadfactor_iter(\d+)',
        r'_layerwise_loadfactor_iter(\d+)',
        r'_loadfactor_iter(\d+)',
        r'_iter(\d+)\.npz$',
    )

    try:
        base_filename = os.path.basename(filepath)

        # "initial" snapshots are taken before any training step
        if "_initial.npz" in base_filename:
            return 0

        # "final" snapshots do not encode the iteration in their name
        if "_final.npz" in base_filename:
            try:
                with np.load(filepath, allow_pickle=True) as data:
                    if 'iteration' in data:
                        iteration = np.asarray(data['iteration'])
                        # .item() handles both 0-d and 1-element arrays;
                        # int() on a 1-element 1-D array is deprecated in NumPy.
                        if iteration.size == 1:
                            return int(iteration.item())
            except Exception as e:
                # Best effort only: report the reason and use the placeholder
                # (previously a bare ``except: pass`` hid it completely).
                print(f"Could not read iteration from {filepath}: {e}")
            # Default for final iterations if not found in file
            return 1000000

        for pattern in iteration_patterns:
            found = re.search(pattern, base_filename)
            if found:
                return int(found.group(1))

        # If we get here, we couldn't determine the epoch
        print(f"Could not determine epoch for {base_filename}, skipping")
        return None

    except Exception as e:
        print(f"Error extracting epoch from {filepath}: {e}")
        return None


def load_load_factors(file_path):
    """Read the ``load_factors`` dictionary stored in an ``.npz`` archive.

    The archive's keys are printed for debugging. An object-dtype array under
    ``load_factors`` is unwrapped with ``.item()`` before the type check.

    Args:
        file_path: Path of the ``.npz`` file to open.

    Returns:
        The dictionary stored under ``load_factors``, or ``None`` when the key
        is missing, the value is not a dictionary, or loading fails (a message
        is printed in each of those cases).
    """
    try:
        with np.load(file_path, allow_pickle=True) as archive:
            # Debug aid: show what the archive contains
            print(f"Keys in {os.path.basename(file_path)}: {list(archive.keys())}")

            if 'load_factors' not in archive:
                print(f"No load_factors in {file_path}")
                return None

            payload = archive['load_factors']

            # A dict saved through numpy comes back wrapped in an object array
            wrapped_object = isinstance(payload, np.ndarray) and payload.dtype == np.dtype('O')
            if wrapped_object:
                payload = payload.item()

            if not isinstance(payload, dict):
                print(f"load_factors in {file_path} is not a dictionary: {type(payload)}")
                return None

            return payload
    except Exception as e:
        print(f"Error loading {file_path}: {e}")
        return None


def _read_train_stats(npz_path, verbose=True):
    """Return ``(epochs, losses)`` from the ``train_stats`` array of one archive, or ``None``.

    ``train_stats`` is indexed as a 2-D array: column 0 is returned as epochs
    and column 1 as losses. ``verbose`` controls whether a missing key or an
    empty array is reported; load errors are always printed.
    """
    try:
        with np.load(npz_path, allow_pickle=True) as data:
            if 'train_stats' not in data:
                if verbose:
                    print(f"No 'train_stats' key found in {npz_path}")
                return None

            train_stats = data['train_stats']
            if len(train_stats) == 0:
                if verbose:
                    print(f"Training data array is empty in {npz_path}")
                return None

            epochs = train_stats[:, 0]
            losses = train_stats[:, 1]
    except Exception as e:
        print(f"Error loading training data from {npz_path}: {e}")
        return None

    print(f"Successfully loaded training data from {npz_path}: {len(epochs)} epochs")
    return (epochs, losses)


def load_training_data(folder_path, model_base):
    """
    Load the training curve that belongs to a model.

    Several filename conventions are tried in ``folder_path``: the model name
    as given, with ``_loadfactor`` appended, and with the ``_loadfactor...``
    suffixes stripped. If none of those files yields data, every other
    ``.npz`` file whose name contains the part of ``model_base`` before
    ``_loadfactor`` (and does not contain ``_termloadfactor_``) is tried.

    Args:
        folder_path: Directory that holds the result files.
        model_base: Model name, possibly carrying a ``_loadfactor...`` suffix.

    Returns:
        Tuple ``(epochs, losses)`` taken from columns 0 and 1 of the first
        usable ``train_stats`` array, or ``None`` if no file provides one.
    """
    # Part of the name shared by all files of this model
    stem = model_base.split("_loadfactor")[0]

    candidates = [
        f"{model_base}.npz",
        f"{model_base}_loadfactor.npz",
        model_base.replace("_loadfactor_layerwise", "") + ".npz",
        model_base.replace("_loadfactor", "") + ".npz",
        stem + ".npz",
    ]

    # Several of the patterns above collapse to the same name (e.g. when
    # model_base has no suffix); dict.fromkeys removes the repeats while
    # keeping priority order so no file is opened more than once.
    tried = set()
    for filename in dict.fromkeys(candidates):
        tried.add(filename)
        results_file = os.path.join(folder_path, filename)
        if not os.path.exists(results_file):
            continue

        print(f"Attempting to load training data from: {results_file}")
        result = _read_train_stats(results_file)
        if result is not None:
            return result

    # If we get here, we've tried all possible files and found nothing
    print(f"Could not find training data for {model_base} after trying multiple filename patterns")

    # Last resort: scan every .npz file in the folder for a related name.
    # Sorted so the chosen file does not depend on directory listing order.
    for file in sorted(glob(os.path.join(folder_path, "*.npz"))):
        base_name = os.path.basename(file)
        if base_name in tried:
            # Already attempted above; reading it again cannot succeed
            continue
        if stem in base_name and "_termloadfactor_" not in base_name:
            print(f"Trying alternative file: {file}")
            result = _read_train_stats(file, verbose=False)
            if result is not None:
                return result

    return None


def analyze_layerwise_termloadfactor_file(filepath):
    """
    Extract plot rows from a single ``*termloadfactor*`` archive.

    The archive must contain the keys ``load_factors`` (a dictionary mapping
    each term to a per-layer sequence of values, stored as an object array)
    and ``term_descriptions``. The epoch comes from the filename and, failing
    that, from the ``iteration`` entry of the first ``metadata`` element.

    Args:
        filepath: Path of the ``.npz`` file.

    Returns:
        A list of dictionaries with the keys ``Epoch``, ``Layer``, ``Term``,
        ``Load Factor``, ``Model`` and ``Depth`` (the same format as the rows
        built by ``analyze_load_factors``). An empty list is returned when
        the file cannot be used; the reason is printed.
    """
    try:
        with np.load(filepath, allow_pickle=True) as data:
            # Debug output
            print(f"Keys in {os.path.basename(filepath)}: {list(data.keys())}")

            # Check for required keys
            if not all(key in data for key in ('load_factors', 'term_descriptions')):
                print(f"Missing required keys in {filepath}")
                return []

            # Extract epoch from filename, falling back to the stored metadata
            epoch = extract_epoch_from_filename(filepath)
            if epoch is None and 'metadata' in data:
                try:
                    entries = np.atleast_1d(data['metadata'])
                    if entries.size > 0:
                        entry = entries[0]
                        if isinstance(entry, dict):
                            metadata_dict = entry
                        else:
                            # SECURITY: this string comes from the file, so it
                            # is parsed as a literal instead of being passed
                            # to eval(), which would execute arbitrary code.
                            metadata_text = str(entry)
                            try:
                                metadata_dict = literal_eval(metadata_text)
                            except (ValueError, SyntaxError):
                                metadata_dict = json.loads(metadata_text)
                        epoch = metadata_dict.get('iteration', None)
                        print(f"Extracted epoch {epoch} from metadata")
                except Exception as e:
                    print(f"Failed to parse metadata in {filepath}: {e}")

            if epoch is None:
                print(f"Could not determine epoch for {filepath}, skipping")
                return []

            # Model name = filename without the trailing termloadfactor part
            model_base = re.sub(
                r'_termloadfactor_(?:initial|iter\d+|final)\.npz$',
                '',
                os.path.basename(filepath),
            )

            # Extract metadata from filename
            metadata = extract_metadata_from_filename(filepath)
            if not metadata:
                print(f"Could not extract metadata from {filepath}, skipping")
                return []

            depth = metadata.get('depth', 0)

            # A dict saved through numpy comes back wrapped in an object array
            raw_factors = data['load_factors']
            load_factors = raw_factors.item() if raw_factors.dtype == np.dtype('O') else raw_factors
            if not isinstance(load_factors, dict):
                print(f"load_factors in {filepath} is not a dictionary: {type(load_factors)}")
                return []

            # One row per (term, layer); layers are numbered from 1
            return [
                {
                    'Epoch': epoch,
                    'Layer': f"Layer {layer_number}",
                    'Term': term_key,
                    'Load Factor': factor,
                    'Model': model_base,
                    'Depth': depth,
                }
                for term_key, layer_factors in load_factors.items()
                for layer_number, factor in enumerate(layer_factors, start=1)
            ]

    except Exception as e:
        print(f"Error processing layerwise termloadfactor file {filepath}: {e}")
        return []


def _collect_standard_load_factor_rows(folder_path):
    """Build plot rows from the plain (non-term) ``*loadfactor*`` snapshots in a folder.

    Files are grouped by model name, ordered by epoch, and every
    (term, layer) value becomes one row dictionary with the keys ``Epoch``,
    ``Layer``, ``Term``, ``Load Factor``, ``Model`` and ``Depth``.
    """
    # The three patterns can overlap, so collect into a set to avoid
    # processing a file twice.
    matched = set()
    for tail in ("iter*.npz", "initial.npz", "final.npz"):
        matched.update(glob(os.path.join(folder_path, "*loadfactor*" + tail)))

    # Exclude termloadfactor files. The check is made on the basename: testing
    # the full path would discard every file whenever a parent directory name
    # happens to contain "termloadfactor".
    load_factor_files = sorted(
        f for f in matched if "termloadfactor" not in os.path.basename(f)
    )
    print(f"Found {len(load_factor_files)} standard load factor files")

    # Group files by model configuration
    model_files = {}
    for file_path in load_factor_files:
        base_filename = os.path.basename(file_path)

        # Strip the epoch/iteration suffix; the first substitution handles the
        # doubled "_loadfactor_..._loadfactor_" naming.
        model_base = re.sub(r'_loadfactor_(?:layerwise_)?loadfactor_(?:initial|iter\d+|final)\.npz$', '', base_filename)
        model_base = re.sub(r'_loadfactor_(?:initial|iter\d+|final)\.npz$', '', model_base)

        snapshots = model_files.setdefault(model_base, [])
        epoch = extract_epoch_from_filename(file_path)
        if epoch is not None:
            snapshots.append((epoch, file_path))

    rows = []
    for model_base, file_list in model_files.items():
        if not file_list:
            print(f"No valid files for {model_base}, skipping")
            continue

        # Chronological order within a model
        file_list.sort(key=lambda item: item[0])

        # Try to get metadata from filename
        metadata = extract_metadata_from_filename(file_list[0][1])
        if not metadata:
            print(f"Could not extract metadata for {model_base}, skipping")
            continue

        depth = metadata.get('depth', 0)
        print(f"Processing model: {model_base}, depth: {depth}")

        for epoch, file_path in file_list:
            load_factors = load_load_factors(file_path)
            if not load_factors:
                print(f"Failed to load factors from {file_path}, skipping")
                continue

            # The first term defines how many layers there are
            first_factors = next(iter(load_factors.values()))

            # Accept any flat per-layer sequence, as the termloadfactor path
            # does, rather than only built-in lists.
            is_layer_sequence = isinstance(first_factors, (list, tuple)) or (
                isinstance(first_factors, np.ndarray) and first_factors.ndim == 1
            )
            if not is_layer_sequence:
                print(f"Unexpected format for factors: {first_factors}")
                continue

            num_layers = len(first_factors)
            for term, factors in load_factors.items():
                for layer_idx, factor in enumerate(factors):
                    if layer_idx >= num_layers:
                        # Ignore entries beyond the layer count of the first term
                        break
                    rows.append({
                        'Epoch': epoch,
                        'Layer': f"Layer {layer_idx + 1}",
                        'Term': term,
                        'Load Factor': factor,
                        'Model': model_base,
                        'Depth': depth,
                    })

    return rows


def analyze_load_factors(folder_path):
    """Collect load factor data from a results folder and plot it per depth.

    ``*termloadfactor*.npz`` files are read first. Only if they yield no rows
    are the standard ``*loadfactor*`` snapshot files read instead; the two
    sources are never mixed.

    Args:
        folder_path: Directory containing the ``.npz`` result files. Plots are
            written by ``plot_load_factors_by_depth``, which receives this
            folder as its output location.

    Returns:
        A DataFrame with the columns ``Epoch``, ``Layer``, ``Term``,
        ``Load Factor``, ``Model`` and ``Depth``; it is empty when no file
        could be processed.
    """
    # Sorted so the row order does not depend on directory listing order
    termloadfactor_files = sorted(glob(os.path.join(folder_path, "*termloadfactor*.npz")))
    print(f"Found {len(termloadfactor_files)} termloadfactor files")

    # Process all termloadfactor files directly
    all_data = []
    for file_path in termloadfactor_files:
        all_data.extend(analyze_layerwise_termloadfactor_file(file_path))

    df = pd.DataFrame(all_data)

    if df.empty:
        print("No data was successfully processed from termloadfactor files. Trying standard load factor files...")

        # Fallback: the standard rows replace the (empty) termloadfactor table
        standard_data = _collect_standard_load_factor_rows(folder_path)
        if standard_data:
            df = pd.DataFrame(standard_data)

    if df.empty:
        print("No data was successfully processed. Dataframe is empty.")
        return df

    # Now create plots
    print(f"Successfully processed data for {len(df['Model'].unique())} models")
    plot_load_factors_by_depth(df, folder_path)

    return df


def plot_load_factors_by_depth(df, output_folder):
    """
    Create one figure per (depth, model) with a panel per layer plus a training-loss panel.

    Each layer panel shows the load factor of every term against the epoch.
    The last panel shows the training loss returned by ``load_training_data``
    (looked up in ``output_folder``), or a placeholder text if none is found.
    Figures are saved as ``load_factors_depth<depth>_<model>.png``.

    Args:
        df: DataFrame with load factor data; the columns ``Depth``, ``Model``,
            ``Term``, ``Layer``, ``Epoch`` and ``Load Factor`` are read.
            ``Layer`` values must look like ``"Layer <n>"``.
        output_folder: Folder to save the plots (also searched for the
            training data files).
    """
    if df.empty:
        print("DataFrame is empty, cannot create plots.")
        return

    for depth in df['Depth'].unique():
        depth_df = df[df['Depth'] == depth]

        for model in depth_df['Model'].unique():
            model_df = depth_df[depth_df['Model'] == model]

            terms = sorted(model_df['Term'].unique())
            # Order layers by their number, not alphabetically
            layers = sorted(model_df['Layer'].unique(), key=lambda x: int(x.split()[1]))
            num_layers = len(layers)

            # Checked before the figure is created so nothing is leaked
            if num_layers == 0:
                print(f"No layers found for model {model}")
                continue

            # num_layers + 1 rows (layers + training error). squeeze=False
            # always yields a 2-D array, so the single-layer case needs no
            # special handling (wrapping the array in a list broke it before).
            fig, axes = plt.subplots(
                num_layers + 1, 1,
                figsize=(12, 3 * (num_layers + 1)),
                sharex=True,
                squeeze=False,
            )
            axes = axes.ravel()

            try:
                # Plot load factors for each layer
                for ax, layer in zip(axes, layers):
                    layer_df = model_df[model_df['Layer'] == layer]

                    for j, term in enumerate(terms):
                        # Rows arrive in file order, which is not chronological;
                        # sort so the connecting line follows the epochs.
                        term_df = layer_df[layer_df['Term'] == term].sort_values('Epoch')
                        if term_df.empty:
                            continue
                        ax.plot(
                            term_df['Epoch'],
                            term_df['Load Factor'],
                            'o-',
                            label=term,
                            color=plt.cm.tab10(j % 10),  # one colour per term
                            alpha=0.7,
                            linewidth=2,
                            markersize=4,
                        )

                    ax.set_ylabel('Load Factor (RMSE)')
                    ax.set_title(f"{layer}")
                    ax.grid(True, alpha=0.3)
                    ax.legend(loc='center left', bbox_to_anchor=(1, 0.5))
                    # NOTE: epoch 0 (the initial snapshot) has no position on
                    # a logarithmic axis.
                    ax.set_xscale('log')

                # Load and plot training error in the last subplot
                loss_ax = axes[-1]
                training_data = load_training_data(output_folder, model)

                if training_data:
                    epochs, losses = training_data
                    loss_ax.plot(epochs, losses, 'r-', label='Training Loss', linewidth=2)
                    loss_ax.set_ylabel('Training Loss (MSE)')
                    loss_ax.set_yscale('log')
                    loss_ax.grid(True, alpha=0.3)
                    loss_ax.legend(loc='center left', bbox_to_anchor=(1, 0.5))
                else:
                    loss_ax.text(0.5, 0.5, 'No training data available',
                                 horizontalalignment='center', verticalalignment='center',
                                 transform=loss_ax.transAxes)

                # Set common x-label
                loss_ax.set_xlabel('Epoch')
                loss_ax.set_xscale('log')

                # Add title; rect leaves room for it above the panels
                fig_title = f"Load Factor Evolution by Layer and Term Combination\nModel: {model}, Depth: {depth}"
                fig.suptitle(fig_title, fontsize=16)
                fig.tight_layout(rect=[0, 0, 1, 0.96])

                # Save figure
                output_path = os.path.join(output_folder, f"load_factors_depth{depth}_{model}.png")
                fig.savefig(output_path, dpi=300, bbox_inches='tight')
            finally:
                # Release the figure even if plotting or saving failed
                plt.close(fig)

            print(f"Saved plot to {output_path}")


def plot_separate_terms(df, output_folder):
    """
    Create one figure per (depth, model) with each term in its own panel.

    Every panel shows the load factor of all layers for one term against the
    epoch. If ``load_training_data`` finds a training curve in
    ``output_folder`` it is overlaid on a secondary (right-hand, log-scaled)
    y-axis. Figures are saved as ``load_factors_by_term_depth<depth>_<model>.png``.

    Args:
        df: DataFrame with load factor data; the columns ``Depth``, ``Model``,
            ``Term``, ``Layer``, ``Epoch`` and ``Load Factor`` are read.
            ``Layer`` values must look like ``"Layer <n>"``.
        output_folder: Folder to save the plots (also searched for the
            training data files).
    """
    if df.empty:
        print("DataFrame is empty, cannot create plots.")
        return

    for depth in df['Depth'].unique():
        depth_df = df[df['Depth'] == depth]

        for model in depth_df['Model'].unique():
            model_df = depth_df[depth_df['Model'] == model]

            terms = sorted(model_df['Term'].unique())
            # Order layers by their number, not alphabetically
            layers = sorted(model_df['Layer'].unique(), key=lambda x: int(x.split()[1]))

            if not terms:
                # Nothing to draw; subplots() cannot create zero rows
                continue

            # One row per term. squeeze=False always yields a 2-D array, so
            # the single-term case needs no special handling.
            fig, axes = plt.subplots(
                len(terms), 1,
                figsize=(12, 4 * len(terms)),
                sharex=True,
                squeeze=False,
            )
            axes = axes.ravel()

            try:
                # Load training data once per model
                training_data = load_training_data(output_folder, model)

                for ax, term in zip(axes, terms):
                    term_df = model_df[model_df['Term'] == term]

                    # Plot all layers for this term on the same subplot
                    for j, layer in enumerate(layers):
                        # Rows arrive in file order, which is not chronological;
                        # sort so the connecting line follows the epochs.
                        layer_data = term_df[term_df['Layer'] == layer].sort_values('Epoch')
                        if layer_data.empty:
                            continue
                        ax.plot(
                            layer_data['Epoch'],
                            layer_data['Load Factor'],
                            'o-',
                            label=layer,
                            # Spread layers over the colormap; max() avoids a
                            # division by zero when there is a single layer
                            color=plt.cm.viridis(j / max(1, len(layers) - 1)),
                            linewidth=2,
                            markersize=4,
                        )

                    handles, labels = ax.get_legend_handles_labels()

                    # Overlay the training loss as a dashed line on a twin axis
                    if training_data:
                        epochs, losses = training_data
                        ax2 = ax.twinx()
                        ax2.plot(epochs, losses, 'r--', label='Training Loss', alpha=0.5, linewidth=1.5)
                        ax2.set_ylabel('Training Loss (MSE)', color='r')
                        ax2.set_yscale('log')
                        ax2.tick_params(axis='y', labelcolor='r')

                        # Lines on the twin axis are not picked up by
                        # ax.legend() on its own, so add them explicitly.
                        loss_handles, loss_labels = ax2.get_legend_handles_labels()
                        handles = handles + loss_handles
                        labels = labels + loss_labels

                    ax.set_ylabel('Load Factor (RMSE)')
                    ax.set_title(f"Term: {term}")
                    ax.grid(True, alpha=0.3)
                    ax.legend(handles, labels, loc='upper left')
                    ax.set_xscale('log')

                # Set common x-label on the bottom subplot
                axes[-1].set_xlabel('Epoch')

                # Add title; rect leaves room for it above the panels
                fig_title = f"Load Factor Evolution by Term\nModel: {model}, Depth: {depth}"
                fig.suptitle(fig_title, fontsize=16)
                fig.tight_layout(rect=[0, 0, 1, 0.96])

                # Save figure
                output_path = os.path.join(output_folder, f"load_factors_by_term_depth{depth}_{model}.png")
                fig.savefig(output_path, dpi=300, bbox_inches='tight')
            finally:
                # Release the figure even if plotting or saving failed
                plt.close(fig)

            print(f"Saved plot to {output_path}")


# Script entry point: analyse one results folder and write all figures
if __name__ == "__main__":
    # Point this at the results folder to analyse
    folder_path = "/home/goring/OnlineSGD/results_MSP/Phi_1304_mon_810_leap_load_layer2/complex_leap_exp_20250414_003722"

    # Builds the load factor table; the per-depth plots are produced by this call
    df = analyze_load_factors(folder_path)

    if df.empty:
        print("No data was successfully processed.")
    else:
        print(f"Successfully processed {len(df)} data points")

        # Additional figures with one panel per term for better visibility
        plot_separate_terms(df, folder_path)

        # Summary statistics for every (depth, layer, term) combination
        print("\nLoad Factor Statistics:")
        summary = df.groupby(['Depth', 'Layer', 'Term'])['Load Factor'].describe()
        print(summary)

Found 19 termloadfactor files
Keys in 10_leap_d40_h4096_depth3_lr0.05_b512_modemup_distbinary_exp1_loadfactor_layerwise_termloadfactor_initial.npz: ['load_factors', 'term_descriptions', 'metadata']
Keys in 10_leap_d40_h4096_depth3_lr0.05_b512_modemup_distbinary_exp1_loadfactor_layerwise_termloadfactor_iter150000.npz: ['load_factors', 'term_descriptions', 'metadata']
Keys in 10_leap_d40_h4096_depth3_lr0.05_b512_modemup_distbinary_exp1_loadfactor_layerwise_termloadfactor_iter125000.npz: ['load_factors', 'term_descriptions', 'metadata']
Keys in 10_leap_d40_h4096_depth3_lr0.05_b512_modemup_distbinary_exp1_loadfactor_layerwise_termloadfactor_iter175000.npz: ['load_factors', 'term_descriptions', 'metadata']
Keys in 10_leap_d40_h4096_depth3_lr0.05_b512_modemup_distbinary_exp1_loadfactor_layerwise_termloadfactor_iter100000.npz: ['load_factors', 'term_descriptions', 'metadata']
Keys in 10_leap_d40_h4096_depth3_lr0.05_b512_modemup_distbinary_exp1_loadfactor_layerwise_termloadfactor_iter200000.np