## final comparisons

In [4]:
%pip install seaborn
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import os
import re
from collections import defaultdict
import pickle

def parse_filename(filename):
    """Parse a SABRA coefficient filename into a dict of run parameters.

    Filenames containing ``'W_final_'`` are treated as outputs of the batch
    approach; every other filename is treated as the single approach. Each
    recognised ``<name><value>`` token is converted to ``int``, ``float`` or
    ``bool``. Tokens that are absent from the filename are simply left out of
    the returned dict.

    Args:
        filename: Base name of the file (a directory prefix is harmless).

    Returns:
        dict: always contains ``'approach'`` (``'batch'`` or ``'single'``)
        plus whichever parameters were found in the name.
    """
    int_keys = {'shell', 'nbatch', 'tstart', 'tend', 'nn', 'minocc'}
    float_keys = {'lambda', 'sigthresh'}
    # A decimal number, optionally in scientific notation. The old pattern
    # ``[\d.]+`` also swallowed the '.' of a file extension
    # ("lambda0.01.npy" -> "0.01.", which made float() raise) and cut
    # scientific notation short ("lambda1e-05" -> 1.0).
    number_re = r'(\d*\.?\d+(?:[eE][-+]?\d+)?)'

    params = {}
    if 'W_final_' in filename:
        params['approach'] = 'batch'
        names = ('shell', 'nbatch', 'tstart', 'tend', 'allint', 'subtractD',
                 'knorm', 'lassoCV', 'lambda', 'randombatch', 'sigthresh',
                 'minocc')
    else:
        params['approach'] = 'single'
        names = ('shell', 'nn', 'tstart', 'lassoCV', 'subtractD', 'knorm')

    for name in names:
        if name in int_keys:
            value_re = r'(\d+)'
        elif name in float_keys:
            value_re = number_re
        else:
            value_re = r'(True|False)'

        match = re.search(re.escape(name) + value_re, filename)
        if not match:
            continue

        raw = match.group(1)
        if name in int_keys:
            params[name] = int(raw)
        elif name in float_keys:
            params[name] = float(raw)
        else:
            params[name] = raw == 'True'

    return params

def load_coefficient_data(filepath):
    """Load coefficient data from *filepath*.

    Tries ``pickle`` first and falls back to ``np.load`` (which handles
    ``.npy``/``.npz`` files). Returns ``None`` and prints a message if
    neither loader succeeds.

    Security: both loaders may unpickle data (``allow_pickle=True``), which
    can execute arbitrary code. Only use this on files you produced yourself.

    Args:
        filepath: Path to the file to load.

    Returns:
        The loaded object, or ``None`` on failure.
    """
    try:
        with open(filepath, 'rb') as f:
            return pickle.load(f)
    except Exception:
        # Not a pickle (e.g. a .npy file) or unreadable -- try numpy next.
        # Catching Exception (not a bare except) lets Ctrl-C still interrupt.
        pass

    try:
        return np.load(filepath, allow_pickle=True)
    except Exception as err:
        print(f"Could not load {filepath}: {err}")
        return None

def compare_sabra_approaches(data_directory, nn=20, max_shells=None,
                             filter_params=None, save_plots=True):
    """
    Compare SABRA regression approaches and visualize results.

    Every file in *data_directory* whose name starts with ``'W_'`` is parsed
    with ``parse_filename``, optionally filtered, loaded, and analysed with
    ``analyze_coefficients``. Results are grouped by every parsed parameter
    except ``shell``, ``tstart`` and ``tend``, and each group is plotted.

    Args:
        data_directory: Path to directory containing W files
        nn: Total number of shells
        max_shells: Maximum number of shells to analyze (None for all)
        filter_params: Dict of parameter filters (e.g., {'subtractD': False}).
            A parameter that does not appear in a filename is not filtered on.
        save_plots: Whether to save plots to files

    Returns:
        defaultdict mapping a sorted tuple of (param, value) pairs to the list
        of per-shell result dicts that were plotted.
    """
    # Sorted so runs are reproducible (os.listdir order is arbitrary).
    files = sorted(f for f in os.listdir(data_directory) if f.startswith('W_'))

    results = defaultdict(list)

    for filename in files:
        params = parse_filename(filename)

        # Apply filters if specified; missing keys do not exclude a file.
        if filter_params and any(key in params and params[key] != value
                                 for key, value in filter_params.items()):
            continue

        if 'shell' not in params:
            print(f"Skipping {filename}: no shell number in filename")
            continue
        shell_idx = params['shell'] - 1  # Convert to 0-based indexing

        # Check the shell limit before paying for the file load.
        # `is not None` so that max_shells=0 means "none", not "all".
        if max_shells is not None and shell_idx >= max_shells:
            continue

        filepath = os.path.join(data_directory, filename)
        W_data = load_coefficient_data(filepath)
        if W_data is None:
            continue

        # A (50, 1, 1600) array stacks several W's along axis 0. The single
        # approach uses the last one (presumably the largest sample size --
        # confirm against the producer) and the batch approach the first one.
        # getattr guards against pickled lists, which have no .shape.
        if getattr(W_data, 'shape', None) == (50, 1, 1600):
            if params['approach'] == 'single':
                W = W_data[-1]
                print(f"Using last W for single approach: {W.shape}")
            else:
                W = W_data[0]
                print(f"Using first W for batch approach: {W.shape}")
        else:
            W = W_data

        # Check the structure of your data
        print(f"Type of W: {type(W)}")
        print(f"Shape: {W.shape if hasattr(W, 'shape') else 'No shape'}")
        print(f"First few elements: {W[:5] if hasattr(W, '__getitem__') else W}")

        analysis = analyze_coefficients(W, nn, shell_idx,
                                        params.get('subtractD', False),
                                        params.get('allint', True))

        result_key = tuple(sorted((k, v) for k, v in params.items()
                                  if k not in ('shell', 'tstart', 'tend')))

        results[result_key].append({
            'shell': params['shell'],
            'approach': params['approach'],
            'analysis': analysis,
            'params': params
        })

    # Create visualizations for each parameter combination
    for param_combo, shell_results in results.items():
        create_comparison_plots(shell_results, param_combo, nn, save_plots)

    create_grouped_figures(results, nn, save_plots)

    return results

def analyze_coefficients(W, nn, shell_idx, subtractD, allint=True):
    """Split learned coefficient magnitudes into expected/unexpected buckets.

    Every coefficient whose magnitude is at least 1e-10 and whose index is
    present in the feature mapping is classified as one of:
      * expected / unexpected interaction -- membership of its (i, j, type)
        triple in the expected interactions for ``shell_idx``;
      * expected / unexpected dissipation -- whether the dissipation term's
        shell equals ``shell_idx``.

    The helpers ``get_expected_interactions_single_shell`` and
    ``build_feature_mapping`` are defined elsewhere in the notebook. The
    mapping is assumed to map index -> ``(shell, 'dissipation')`` or
    ``(i, j, type)``; verify against its definition.

    Args:
        W: coefficient array, or a list/tuple whose first element (if it is
            itself a sequence) holds the coefficients.
        nn: number of shells.
        shell_idx: 0-based shell the coefficients were fitted for.
        subtractD: if True, dissipation features are excluded from the mapping.
        allint: forwarded as ``use_all_interactions`` to the feature mapping.

    Returns:
        dict of summed magnitudes, counts, total and percentages per bucket.
        An empty list/tuple yields an all-zero dict.
    """
    def _summary(exp_sum, unexp_sum, exp_d_sum, unexp_d_sum,
                 exp_n, unexp_n, exp_d_n, unexp_d_n):
        total = exp_sum + unexp_sum + exp_d_sum + unexp_d_sum

        def pct(part):
            return (part / total * 100) if total > 0 else 0

        return {
            'expected_sum': exp_sum,
            'unexpected_sum': unexp_sum,
            'expected_dissipation_sum': exp_d_sum,
            'unexpected_dissipation_sum': unexp_d_sum,
            'expected_count': exp_n,
            'unexpected_count': unexp_n,
            'expected_dissipation_count': exp_d_n,
            'unexpected_dissipation_count': unexp_d_n,
            'total_sum': total,
            'expected_percentage': pct(exp_sum),
            'unexpected_percentage': pct(unexp_sum),
            'expected_dissipation_percentage': pct(exp_d_sum),
            'unexpected_dissipation_percentage': pct(unexp_d_sum),
        }

    # Normalise the input into a 1-D array of coefficients.
    if isinstance(W, (list, tuple)):
        if not W:
            return _summary(0, 0, 0, 0, 0, 0, 0, 0)
        head = W[0]
        coeffs = np.array(head) if hasattr(head, '__len__') else np.array(W)
    else:
        coeffs = np.array(W)

    if coeffs.ndim > 1:
        coeffs = coeffs.flatten()

    expected_interactions = get_expected_interactions_single_shell(nn, shell_idx)
    feature_mapping = build_feature_mapping(nn, use_all_interactions=allint,
                                            include_dissipation=not subtractD)

    buckets = ('exp', 'unexp', 'exp_d', 'unexp_d')
    sums = dict.fromkeys(buckets, 0)
    counts = dict.fromkeys(buckets, 0)

    for position, raw in enumerate(coeffs):
        # Vector-valued entries are reduced to their Euclidean norm.
        is_vector = hasattr(raw, '__len__') and len(raw) > 1
        magnitude = np.linalg.norm(raw) if is_vector else float(raw)

        if abs(magnitude) < 1e-10:
            continue
        if position not in feature_mapping:
            continue

        term = feature_mapping[position]
        if len(term) == 2 and term[1] == 'dissipation':
            bucket = 'exp_d' if term[0] == shell_idx else 'unexp_d'
        else:
            i, j, kind = term
            bucket = 'exp' if (i, j, kind) in expected_interactions else 'unexp'

        sums[bucket] += abs(magnitude)
        counts[bucket] += 1

    return _summary(sums['exp'], sums['unexp'], sums['exp_d'], sums['unexp_d'],
                    counts['exp'], counts['unexp'], counts['exp_d'], counts['unexp_d'])

def plot_interaction_counts(ax, single_results, batch_results, subtractD=False):
    """Plot stacked bar counts of learned interactions per shell.

    For every shell up to four bars are drawn side by side, centred on the
    shell's tick: single-approach interactions, single-approach dissipation,
    batch interactions, batch dissipation. Each bar stacks the unexpected
    count on top of the expected count.

    Args:
        ax: Matplotlib axes to draw on.
        single_results: list of dicts with keys 'shell' and 'analysis'
            (the dict returned by analyze_coefficients) for the single approach.
        batch_results: same structure for the batch approach.
        subtractD: when False (dissipation terms are part of the model), a
            reference line at one expected dissipation term is drawn.
    """
    shells_single = [r['shell'] for r in single_results]
    shells_batch = [r['shell'] for r in batch_results]

    # Get all shells for x-axis
    all_shells = sorted(set(shells_single + shells_batch))

    width = 0.35
    bar_w = width / 2
    # Four bars of width bar_w centred on each tick: offsets of -1.5, -0.5,
    # +0.5 and +1.5 bar widths. (Previously they sat at -1, 0, +1, +2 bar
    # widths, i.e. shifted right of the tick.)
    x_single = np.array(shells_single) - 1.5 * bar_w
    x_batch = np.array(shells_batch) + 0.5 * bar_w

    # High contrast colors
    colors = {
        'expected_single': "#007191",      # teal blue
        'exp_dissip_single': '#32CD32',    # lime green
        'unexpected_single': '#8B0000',    # dark red
        'unexp_dissip_single': '#FF4500',  # orange red
        'expected_batch': '#000080',       # navy blue
        'exp_dissip_batch': "#E1E141",     # yellow
        'unexpected_batch': '#800080',     # purple
        'unexp_dissip_batch': "#FF7B00"    # orange
    }

    def draw(x, results, tag):
        """Draw the interaction and dissipation bars for one approach."""
        suffix = tag.lower()

        def column(key):
            return np.array([r['analysis'][key] for r in results])

        expected = column('expected_count')
        exp_dissip = column('expected_dissipation_count')

        # Interaction bar on the left, dissipation bar to its right; the
        # unexpected counts are stacked on top of the expected ones.
        ax.bar(x, expected, bar_w, label=f'Expected ({tag})',
               color=colors[f'expected_{suffix}'], alpha=0.9)
        ax.bar(x + bar_w, exp_dissip, bar_w,
               label=f'Expected Dissipation ({tag})',
               color=colors[f'exp_dissip_{suffix}'], alpha=0.9)
        ax.bar(x, column('unexpected_count'), bar_w, bottom=expected,
               label=f'Unexpected ({tag})',
               color=colors[f'unexpected_{suffix}'], alpha=0.9)
        ax.bar(x + bar_w, column('unexpected_dissipation_count'), bar_w,
               bottom=exp_dissip, label=f'Unexpected Dissipation ({tag})',
               color=colors[f'unexp_dissip_{suffix}'], alpha=0.9)

    if single_results:
        draw(x_single, single_results, 'Single')
    if batch_results:
        draw(x_batch, batch_results, 'Batch')

    # Reference line for the maximum number of expected interactions
    ax.axhline(y=6, color='red', linestyle='--', linewidth=2, alpha=0.7, label='Maximum Expected Interactions (6)')

    # Add line at y=1 only when subtractD is False (dissipation terms are included)
    if not subtractD:
        ax.axhline(y=1, color='orange', linestyle='--', linewidth=2, alpha=0.7,
                   label='Maximum Expected Dissipation (1)')

    ax.set_xlabel('Shell Number')
    ax.set_ylabel('Number of Interactions')
    ax.set_xticks(all_shells)
    ax.set_xticklabels([str(int(s)) for s in all_shells])
    ax.legend(loc='upper right')
    ax.grid(True, alpha=0.3)

def plot_total_magnitude(ax, single_results, batch_results):
    """Plot each shell's total coefficient magnitude on a logarithmic y-axis.

    Args:
        ax: Matplotlib axes to draw on.
        single_results: result dicts (keys 'shell', 'analysis') for the
            single approach; plotted in the order given.
        batch_results: same structure for the batch approach.
    """
    single_x = [res['shell'] for res in single_results]
    batch_x = [res['shell'] for res in batch_results]
    tick_positions = sorted(set(single_x + batch_x))

    series = (
        (single_results, single_x, 'o-', 'Single Approach', 'green'),
        (batch_results, batch_x, 's-', 'Batch Approach', 'blue'),
    )
    for group, xs, marker, label, color in series:
        if not group:
            continue
        ys = [res['analysis']['total_sum'] for res in group]
        ax.semilogy(xs, ys, marker, label=label, color=color,
                    linewidth=2, markersize=8)

    ax.set_xlabel('Shell Number')
    ax.set_ylabel('Total Coefficient Magnitude (log scale)')
    # One integer-labelled tick per shell that has data.
    ax.set_xticks(tick_positions)
    ax.set_xticklabels([str(int(s)) for s in tick_positions])
    ax.legend()
    ax.grid(True, alpha=0.3)

def create_summary_table(single_results, batch_results, param_dict, save_plots):
    """Print a per-shell summary table of both approaches and optionally save it.

    For each shell present in either result list, the first matching result
    of each approach contributes its interaction counts, total magnitude and
    expected-interaction ratio. A missing approach is filled with zeros/"N/A".

    Args:
        single_results: result dicts (keys 'shell', 'analysis') for the single approach.
        batch_results: same structure for the batch approach.
        param_dict: parameters shown in the header and used in the CSV filename.
        save_plots: if True, also write ``sabra_summary_table_<params>.csv``
            to the current directory.

    Returns:
        pandas.DataFrame with one row per shell.
    """
    def approach_columns(prefix, matches):
        """Columns for one approach; ``matches`` are that shell's results."""
        if not matches:
            return {
                f'{prefix}_Expected': 0,
                f'{prefix}_Unexpected': 0,
                f'{prefix}_ExpDissip': 0,
                f'{prefix}_UnexpDissip': 0,
                f'{prefix}_TotalMag': "0.00e+00",
                f'{prefix}_ExpectedRatio': "N/A",
            }
        stats = matches[0]['analysis']
        n_exp = stats['expected_count']
        n_unexp = stats['unexpected_count']
        n_interactions = n_exp + n_unexp
        ratio = f"{n_exp/n_interactions*100:.1f}%" if n_interactions > 0 else "N/A"
        return {
            f'{prefix}_Expected': n_exp,
            f'{prefix}_Unexpected': n_unexp,
            f'{prefix}_ExpDissip': stats['expected_dissipation_count'],
            f'{prefix}_UnexpDissip': stats['unexpected_dissipation_count'],
            f'{prefix}_TotalMag': f"{stats['total_sum']:.2e}",
            f'{prefix}_ExpectedRatio': ratio,
        }

    shell_numbers = sorted({r['shell'] for r in single_results + batch_results})

    rows = []
    for shell in shell_numbers:
        row = {'Shell': int(shell)}
        row.update(approach_columns('Single', [r for r in single_results if r['shell'] == shell]))
        row.update(approach_columns('Batch', [r for r in batch_results if r['shell'] == shell]))
        rows.append(row)

    df = pd.DataFrame(rows)

    rule = '=' * 80
    params_text = ', '.join(f'{k}={v}' for k, v in param_dict.items())
    print(f"\n{rule}")
    print("QUANTITATIVE SUMMARY")
    print(f"Parameters: {params_text}")
    print(rule)
    print(df.to_string(index=False))
    print(rule)

    if save_plots:
        suffix = '_'.join(f"{k}{v}" for k, v in param_dict.items())
        csv_name = f"sabra_summary_table_{suffix}.csv"
        df.to_csv(csv_name, index=False)
        print(f"Saved summary table: {csv_name}")

    return df

def plot_expected_ratio(ax, single_results, batch_results):
    """Plot, per shell, the fraction of learned interactions that were expected.

    The fraction is expected_count / (expected_count + unexpected_count); a
    shell with no interactions is plotted as 0. Dissipation terms are not
    counted.

    Args:
        ax: Matplotlib axes to draw on.
        single_results: result dicts (keys 'shell', 'analysis') for the single approach.
        batch_results: same structure for the batch approach.
    """
    def expected_fraction(res):
        n_exp = res['analysis']['expected_count']
        n_all = n_exp + res['analysis']['unexpected_count']
        return n_exp / n_all if n_all > 0 else 0

    single_x = [res['shell'] for res in single_results]
    batch_x = [res['shell'] for res in batch_results]
    single_y = [expected_fraction(res) for res in single_results]
    batch_y = [expected_fraction(res) for res in batch_results]

    if single_results:
        ax.plot(single_x, single_y, 'o-', label='Single Approach',
                color='#2E8B57', linewidth=2, markersize=6)
    if batch_results:
        ax.plot(batch_x, batch_y, 's-', label='Batch Approach',
                color='#1E3A8A', linewidth=2, markersize=6)

    ax.set_xlabel('Shell Number')
    ax.set_ylabel('Expected Interactions / Total Interactions')
    ax.set_ylim(-0.05, 1.05)  # small margin so points at 0/1 clear the frame
    ax.legend()
    ax.grid(True, alpha=0.3)

    # Integer ticks only at shells that actually have data.
    ticks = sorted(set(single_x + batch_x))
    if ticks:
        ax.set_xticks(ticks)
        ax.set_xticklabels([str(int(s)) for s in ticks])

def plot_coefficient_percentages(ax, single_results, batch_results):
    """Plot stacked bars of each term type's share of the total coefficient magnitude.

    One stacked bar per approach per shell (single on the left, batch on the
    right). Segments from bottom to top: expected, expected dissipation,
    unexpected, unexpected dissipation. An approach with no results is not
    drawn at all, so it adds no empty bars or legend entries. This matters
    for the batch-only and single-only grouped figures.

    Args:
        ax: Matplotlib axes to draw on.
        single_results: result dicts (keys 'shell', 'analysis') for the single approach.
        batch_results: same structure for the batch approach.
    """
    shells = sorted({r['shell'] for r in single_results + batch_results})
    x = np.arange(len(shells))
    width = 0.35

    # High contrast colors
    colors = {
        'expected_single': '#006400',
        'exp_dissip_single': '#32CD32',
        'unexpected_single': '#8B0000',
        'unexp_dissip_single': "#FF4F0F",
        'expected_batch': '#000080',
        'exp_dissip_batch': '#4169E1',
        'unexpected_batch': '#800080',
        'unexp_dissip_batch': '#DA70D6'
    }

    # (legend text, analysis key, colour key prefix), bottom to top.
    segments = (
        ('Expected', 'expected_percentage', 'expected'),
        ('Expected Dissipation', 'expected_dissipation_percentage', 'exp_dissip'),
        ('Unexpected', 'unexpected_percentage', 'unexpected'),
        ('Unexpected Dissipation', 'unexpected_dissipation_percentage', 'unexp_dissip'),
    )

    for results, tag, offset in ((single_results, 'Single', -width / 2),
                                 (batch_results, 'Batch', width / 2)):
        if not results:
            continue
        # First result per shell wins (same rule as the summary table).
        by_shell = {}
        for r in results:
            by_shell.setdefault(r['shell'], r['analysis'])

        bottom = np.zeros(len(shells))
        for label, key, color_key in segments:
            heights = np.array([by_shell[s][key] if s in by_shell else 0 for s in shells],
                               dtype=float)
            ax.bar(x + offset, heights, width, bottom=bottom,
                   label=f'{label} ({tag})',
                   color=colors[f'{color_key}_{tag.lower()}'], alpha=0.9)
            bottom = bottom + heights

    ax.set_xlabel('Shell Number')
    ax.set_ylabel('Percentage of Total Magnitude')
    ax.set_ylim(-2, 102)  # Add padding to avoid overlap with borders
    ax.set_xticks(x)
    ax.set_xticklabels([str(int(shell)) for shell in shells])
    if single_results or batch_results:
        ax.legend(loc='lower left', borderaxespad=0.)
    ax.grid(True, alpha=0.3)

def create_comparison_plots(shell_results, param_combo, nn, save_plots):
    """Create, optionally save, and close the four comparison figures for one
    parameter combination, then print its summary table.

    Figures are never shown. When *save_plots* is True each figure is written
    to ``<prefix>_<params>.png`` in the current directory. Matplotlib's
    interactive mode is switched off while drawing and restored to its
    previous state afterwards, even if plotting raises.

    Args:
        shell_results: result dicts (keys 'shell', 'approach', 'analysis').
        param_combo: sorted tuple of (param, value) pairs identifying the group.
        nn: unused; kept for interface compatibility.
        save_plots: whether to write PNG files and the summary CSV.
    """
    by_shell = lambda r: r['shell']
    single_results = sorted((r for r in shell_results if r['approach'] == 'single'), key=by_shell)
    batch_results = sorted((r for r in shell_results if r['approach'] == 'batch'), key=by_shell)

    param_dict = dict(param_combo)
    title_suffix = ", ".join(f"{k}={v}" for k, v in param_dict.items()
                             if k in ('subtractD', 'knorm', 'lassoCV', 'allint'))
    # Built once up front: it used to be defined only inside the first
    # save block, which plots 2-4 silently relied on.
    param_str = '_'.join(f"{k}{v}" for k, v in param_dict.items())
    subtract_d = param_dict.get('subtractD', False)

    plots = (
        (lambda ax: plot_interaction_counts(ax, single_results, batch_results, subtract_d),
         'Number of Interactions Learned', 'sabra_interaction_counts'),
        (lambda ax: plot_coefficient_percentages(ax, single_results, batch_results),
         'Coefficient Magnitude Distribution', 'sabra_coefficient_percentages'),
        (lambda ax: plot_expected_ratio(ax, single_results, batch_results),
         'Expected vs Unexpected Interaction Ratio', 'sabra_expected_ratio'),
        (lambda ax: plot_total_magnitude(ax, single_results, batch_results),
         'Total Coefficient Magnitude', 'sabra_total_magnitude'),
    )

    was_interactive = plt.isinteractive()
    plt.ioff()
    try:
        for draw, title, prefix in plots:
            fig, ax = plt.subplots(figsize=(12, 8))
            try:
                draw(ax)
                ax.set_title(f'{title}\n{title_suffix}', fontsize=14, fontweight='bold')
                if save_plots:
                    filename = f"{prefix}_{param_str}.png"
                    fig.savefig(filename, dpi=300, bbox_inches='tight')
                    print(f"Saved plot: {filename}")
            finally:
                plt.close(fig)  # never leak figures, even on error
    finally:
        if was_interactive:
            plt.ion()

    # Create and display quantitative summary table
    create_summary_table(single_results, batch_results, param_dict, save_plots)


def plot_interaction_counts_expected(ax, single_results, batch_results, subtractD=False):
    """Plot only the *expected* interaction and dissipation counts per shell.

    Uses the same bar layout as ``plot_interaction_counts``: four bars centred
    on each shell tick (single interactions, single dissipation, batch
    interactions, batch dissipation), without the unexpected stacks.

    Args:
        ax: Matplotlib axes to draw on.
        single_results: result dicts (keys 'shell', 'analysis') for the single approach.
        batch_results: same structure for the batch approach.
        subtractD: when False (dissipation terms included), a reference line
            at one expected dissipation term is drawn.
    """
    shells_single = [r['shell'] for r in single_results]
    shells_batch = [r['shell'] for r in batch_results]

    # Get all shells for x-axis
    all_shells = sorted(set(shells_single + shells_batch))

    width = 0.35
    bar_w = width / 2
    # Offsets -1.5, -0.5, +0.5, +1.5 bar widths keep the group centred on the
    # tick (the old offsets -1..+2 bar widths shifted it to the right).
    x_single = np.array(shells_single) - 1.5 * bar_w
    x_batch = np.array(shells_batch) + 0.5 * bar_w

    # High contrast colors
    colors = {
        'expected_single': "#007191",      # teal blue
        'exp_dissip_single': '#32CD32',    # lime green
        'expected_batch': '#000080',       # navy blue
        'exp_dissip_batch': "#E1E141"      # yellow
    }

    for x, results, tag in ((x_single, single_results, 'Single'),
                            (x_batch, batch_results, 'Batch')):
        if not results:
            continue
        suffix = tag.lower()
        ax.bar(x, [r['analysis']['expected_count'] for r in results], bar_w,
               label=f'Expected ({tag})', color=colors[f'expected_{suffix}'], alpha=0.9)
        ax.bar(x + bar_w, [r['analysis']['expected_dissipation_count'] for r in results], bar_w,
               label=f'Expected Dissipation ({tag})', color=colors[f'exp_dissip_{suffix}'], alpha=0.9)

    # Reference line for the maximum number of expected interactions
    ax.axhline(y=6, color='red', linestyle='--', linewidth=2, alpha=0.7, label='Maximum Expected Interactions (6)')

    # Add line at y=1 only when subtractD is False (dissipation terms are included)
    if not subtractD:
        ax.axhline(y=1, color='orange', linestyle='--', linewidth=2, alpha=0.7,
                   label='Maximum Expected Dissipation (1)')

    ax.set_xlabel('Shell Number')
    ax.set_ylabel('Number of Expected Interactions')
    ax.set_xticks(all_shells)
    ax.set_xticklabels([str(int(s)) for s in all_shells])
    ax.legend(loc='upper right')
    ax.grid(True, alpha=0.3)

def create_grouped_figures(results, nn, save_plots):
    """Create one 2x2 overview figure per approach for every parameter combination.

    For each combination in *results*, a batch figure is drawn if batch
    results exist and a single figure if single results exist. When
    *save_plots* is True they are saved as ``sabra_<approach>_grouped_<params>.png``.
    Figures are closed, never shown. Matplotlib's interactive mode is
    restored to its previous state afterwards.

    Args:
        results: mapping of sorted (param, value) tuples to lists of result
            dicts (keys 'shell', 'approach', 'analysis').
        nn: unused; kept for interface compatibility.
        save_plots: whether to write the PNG files.
    """
    was_interactive = plt.isinteractive()
    plt.ioff()
    try:
        for param_combo, shell_results in results.items():
            param_dict = dict(param_combo)

            by_shell = lambda r: r['shell']
            single_results = sorted((r for r in shell_results if r['approach'] == 'single'), key=by_shell)
            batch_results = sorted((r for r in shell_results if r['approach'] == 'batch'), key=by_shell)

            title_suffix = ", ".join(f"{k}={v}" for k, v in param_dict.items()
                                     if k in ('subtractD', 'knorm', 'lassoCV', 'allint'))
            param_str = '_'.join(f"{k}{v}" for k, v in param_dict.items())
            subtract_d = param_dict.get('subtractD', False)

            if batch_results:
                _draw_grouped_figure([], batch_results, 'Batch', 'batch',
                                     title_suffix, param_str, subtract_d, save_plots)
            if single_results:
                _draw_grouped_figure(single_results, [], 'Single', 'single',
                                     title_suffix, param_str, subtract_d, save_plots)
    finally:
        if was_interactive:
            plt.ion()


def _draw_grouped_figure(single_results, batch_results, approach, file_tag,
                         title_suffix, param_str, subtract_d, save_plots):
    """Draw, optionally save, and close one 2x2 overview figure for one approach.

    Exactly one of *single_results* / *batch_results* is expected to be
    non-empty. Passing them positionally in the right slot matters: the old
    code handed single results to the "expected only" panel as batch data,
    so they were labelled and coloured as "Batch".
    """
    fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(20, 10))
    try:
        plot_interaction_counts(ax1, single_results, batch_results, subtract_d)
        ax1.set_title('Number of Interactions by Type', fontsize=12)

        plot_coefficient_percentages(ax2, single_results, batch_results)
        ax2.set_title('Coefficient Magnitude Distribution', fontsize=12)

        # Expected counts only; subtractD decides whether the dissipation
        # reference line is meaningful (it was previously always drawn).
        plot_interaction_counts_expected(ax3, single_results, batch_results, subtract_d)
        ax3.set_title('Expected Interactions Only', fontsize=12)

        plot_total_magnitude(ax4, single_results, batch_results)
        ax4.set_title('Total Coefficient Magnitude', fontsize=12)

        fig.suptitle(f'{approach} Approach: {title_suffix}', fontsize=14, fontweight='bold')
        fig.tight_layout()  # act on this figure, not whatever is "current"

        if save_plots:
            filename = f'sabra_{file_tag}_grouped_{param_str}.png'
            fig.savefig(filename, dpi=300, bbox_inches='tight')
            print(f"Saved grouped {file_tag} figure: {filename}")
    finally:
        plt.close(fig)

    



Note: you may need to restart the kernel to use updated packages.


In [None]:
# Run the comparison on the cross-validated W files, keeping only files
# whose parsed parameters match these filters. Files that do not carry a
# given parameter in their name are not excluded by that filter.
compare_sabra_approaches(
    '/home/vale/SABRA/params_bin2/Ws_CV',
    nn=20,
    filter_params={
        'randombatch': True,
        'knorm': False,
        'subtractD': True,
    },
)

## assembling all the interactions from each shell


In [17]:

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import os
import re
from collections import defaultdict
import pickle




def get_top_interactions_per_shell(W, n_top=50, threshold=1e-10):
    """
    Return the ``n_top`` largest-magnitude coefficients of ``W``.

    Args:
        W: Coefficient array. For a list/tuple whose first element is itself a
            sequence, only ``W[0]`` is used. Any other list/tuple, array or
            scalar is converted with ``np.asarray``. Multi-dimensional input is
            flattened (row-major), so returned indices refer to the flat array.
        n_top: Number of top interactions to keep.
        threshold: Coefficients with ``|value| < threshold`` are ignored.

    Returns:
        dict: {feature_idx: coefficient_value}, inserted in order of decreasing
        magnitude (ties keep their original index order). Signs are preserved.
        Empty when ``W`` is an empty list/tuple or nothing passes the threshold.
    """
    if isinstance(W, (list, tuple)):
        if len(W) == 0:
            return {}
        W_array = np.asarray(W[0]) if hasattr(W[0], '__len__') else np.asarray(W)
    else:
        W_array = np.asarray(W)

    # ravel (rather than flattening only when ndim > 1) also turns a 0-d scalar
    # into a length-1 vector; iterating a 0-d array would raise TypeError.
    flat = np.ravel(W_array)

    if flat.dtype == object:
        # Ragged/object entries: reduce vector-valued coefficients to their norm.
        values = np.array(
            [np.linalg.norm(c) if hasattr(c, '__len__') and len(c) > 1 else float(c)
             for c in flat],
            dtype=float,
        )
    else:
        values = flat.astype(float)

    magnitudes = np.abs(values)
    candidates = np.flatnonzero(magnitudes >= threshold)  # NaN never passes

    # Stable sort on -|w|: largest first, ties keep index order
    # (same ordering as list.sort(key=abs, reverse=True)).
    order = np.argsort(-magnitudes[candidates], kind='stable')
    top = candidates[order][:n_top]

    return {int(i): float(values[i]) for i in top}

def analyze_interactions(W, nn, use_all_interactions=False, threshold=1e-5, expected_interactions=None, 
                        shell_idx=None, include_dissipation=False):
    """Analyze significant interactions in the learned model.

    Feature layout assumed by the indexing below:
      interaction (i, j, type) -> column ``(i*nn + j)*n_types + t``, where
      ``n_types`` is 4 ('regular', 'j_conj', 'i_conj', 'both_conj') when
      ``use_all_interactions`` and 1 ('regular') otherwise;
      dissipation of shell i   -> column ``nn*nn*n_types + i``.
    Columns beyond the width of W are silently skipped.

    Args:
        W: array-like [n_shells, n_features] or [n_features] - Learned coefficients.
            Anything ``np.asarray`` accepts works; non-2-D input becomes one row.
        nn: int - Number of shells
        use_all_interactions: bool - Whether the 4-type layout is used
        threshold: float - Coefficients with |w| <= threshold are ignored
        expected_interactions: dict {0-based shell: set of (i, j, type)} or None
        shell_idx: int or None - 0-based shell that row 0 of W belongs to.
            If None, row r is treated as shell r (only rows < nn are used).
        include_dissipation: bool - Whether dissipation terms are included in the dictionary

    Returns:
        list of dicts with keys 'target', 'i', 'j' (all 1-based), 'type',
        'weight', 'highlight', 'status' ('expected'/'unexpected') and 'term_type'.
    """
    types = ['regular', 'j_conj', 'i_conj', 'both_conj'] if use_all_interactions else ['regular']
    n_types = len(types)

    W = np.asarray(W)  # accept lists as well as arrays
    Wmat = W if W.ndim == 2 else W.reshape(1, -1)
    n_rows, n_features = Wmat.shape

    # (row of Wmat, 0-based target shell) pairs to analyze
    if shell_idx is not None:
        row_shell_pairs = [(0, shell_idx)]
    else:
        row_shell_pairs = list(zip(range(n_rows), range(min(n_rows, nn))))

    significant = []
    for row, target_shell in row_shell_pairs:
        weights = Wmat[row]
        # Depends only on the target shell, so look it up once per row.
        expected_set = expected_interactions.get(target_shell, set()) if expected_interactions else set()

        # Interaction terms
        for i in range(nn):
            for j in range(nn):
                for t, type_name in enumerate(types):
                    idx = (i * nn + j) * n_types + t
                    if idx >= n_features:
                        continue

                    weight = weights[idx]
                    if abs(weight) <= threshold:
                        continue

                    status = 'expected' if (i, j, type_name) in expected_set else 'unexpected'
                    significant.append({
                        'target': target_shell + 1,  # Convert to 1-based indexing
                        'i': i + 1,
                        'j': j + 1,
                        'type': type_name,
                        'weight': weight,
                        'highlight': status == 'expected',
                        'status': status,
                        'term_type': 'interaction'
                    })

        # Dissipation terms follow all interaction columns
        if include_dissipation:
            dissipation_start_idx = nn * nn * n_types
            for i in range(nn):
                idx = dissipation_start_idx + i
                if idx >= n_features:
                    continue

                weight = weights[idx]
                if abs(weight) <= threshold:
                    continue

                significant.append({
                    'target': target_shell + 1,
                    'i': i + 1,
                    'j': i + 1,  # Self-interaction
                    'type': 'dissipation',
                    'weight': weight,
                    'highlight': (i == target_shell),  # Dissipation of the target shell itself
                    'status': 'expected' if (i == target_shell) else 'unexpected',
                    'term_type': 'dissipation'
                })

    return significant


def assemble_final_W_from_shells(data_directory, nn=20, n_top_per_shell=50, 
                                filter_params=None, combination_method='average'):
    """
    Assemble final W by collecting top interactions from each shell separately.

    Files in ``data_directory`` whose names start with ``W_final`` or ``W_list``
    are parsed with ``parse_filename``. They are grouped by every parsed
    parameter except ``shell``, ``tstart`` and ``tend``. Within a group, one file
    per shell is loaded: when several files share a shell, the alphabetically
    first filename is used, so the result does not depend on ``os.listdir``
    order. The ``n_top_per_shell`` strongest coefficients of each shell are
    kept. A feature selected by several shells has its values merged with
    ``combination_method``.

    Args:
        data_directory: Path to directory containing W files
        nn: Total number of shells
        n_top_per_shell: Number of top interactions to keep from each shell
        filter_params: Dict of parameter filters (excluding shell). A file is
            skipped when it has the parameter and the value differs.
        combination_method: 'average', 'max' (value with the largest |w|, sign
            kept) or 'median'. Any other value falls back to 'average'.

    Returns:
        dict: {param_combo: (final_W, interaction_origins, feature_mapping)}.
        ``param_combo`` is a sorted tuple of (name, value) pairs.
        ``final_W`` has length max(feature_mapping)+1.
        ``interaction_origins`` maps feature_idx -> [(shell, approach), ...].
    """
    combiners = {
        'average': np.mean,
        'max': lambda vals: vals[int(np.argmax(np.abs(vals)))],
        'median': np.median,
    }
    combine = combiners.get(combination_method, np.mean)  # unknown -> average

    # Sorted so that the file picked for a duplicated shell is deterministic.
    files = sorted(f for f in os.listdir(data_directory)
                   if f.startswith(('W_final', 'W_list')))

    # Group files by parameter combination (excluding shell and time window)
    param_groups = defaultdict(list)
    for filename in files:
        params = parse_filename(filename)

        if filter_params and any(key in params and params[key] != value
                                 for key, value in filter_params.items()):
            continue

        if 'shell' not in params:
            print(f"Skipping {filename}: no shell number found in filename")
            continue

        param_key = tuple(sorted((k, v) for k, v in params.items()
                                 if k not in ('shell', 'tstart', 'tend')))
        param_groups[param_key].append((filename, params))

    results = {}
    for param_combo, file_list in param_groups.items():
        print(f"\nProcessing parameter combination: {dict(param_combo)}")

        rep_params = dict(param_combo)
        include_dissipation = not rep_params.get('subtractD', False)
        use_all_interactions = rep_params.get('allint', True)

        feature_mapping = build_feature_mapping(nn,
                                                use_all_interactions=use_all_interactions,
                                                include_dissipation=include_dissipation)
        total_features = (max(feature_mapping.keys()) if feature_mapping else 0) + 1

        # First file (in sorted order) for each shell
        shell_files = {}
        for filename, params in file_list:
            shell_files.setdefault(params['shell'], (filename, params))

        print(f"Found shells: {sorted(shell_files)}")

        all_interactions = defaultdict(list)  # {feature_idx: [coeff_values]}
        interaction_origins = {}              # {feature_idx: [(shell, approach)]}

        for shell_num in sorted(shell_files):
            shell_file, shell_params = shell_files[shell_num]
            print(f"  Processing shell {shell_num} from {shell_file}")

            W_data = load_coefficient_data(os.path.join(data_directory, shell_file))
            if W_data is None:
                continue

            # A stack of W's: single approach keeps the last one (comment in the
            # original: highest sample size), batch keeps the first.
            if hasattr(W_data, 'shape') and len(W_data.shape) > 1 and W_data.shape[0] > 1:
                W = W_data[-1] if shell_params['approach'] == 'single' else W_data[0]
            else:
                W = W_data

            top_interactions = get_top_interactions_per_shell(W, n_top_per_shell)
            print(f"    Found {len(top_interactions)} significant interactions")

            for feature_idx, coeff_val in top_interactions.items():
                if feature_idx >= total_features:  # Safety check
                    continue
                all_interactions[feature_idx].append(coeff_val)
                interaction_origins.setdefault(feature_idx, []).append(
                    (shell_num, shell_params['approach']))

        final_W = np.zeros(total_features)
        for feature_idx, coeff_values in all_interactions.items():
            final_W[feature_idx] = combine(coeff_values)

        print(f"  Final W has {np.sum(np.abs(final_W) > 1e-10)} non-zero coefficients")

        results[param_combo] = (final_W, interaction_origins, feature_mapping)

    return results

def create_interaction_origin_summary_dict(interaction_origins, feature_mapping, nn):
    """
    Tally which shells / approaches contributed each assembled interaction.

    Args:
        interaction_origins: {feature_idx: [(shell, approach), ...]}
        feature_mapping: {feature_idx: interaction}. Dissipation entries are
            (shell, 'dissipation'); anything else is treated as (i, j, ...).
        nn: Number of shells (accepted for API compatibility; not used here).

    Returns:
        dict with, in this order:
          'by_shell', 'by_approach', 'by_interaction_type': defaultdict(int)
              counts, one increment per (feature, origin) pair;
          'shell_contributions': shell -> defaultdict(int) keyed by
              'interaction' / 'dissipation';
          'detailed_interactions': one dict per feature, following the
              iteration order of interaction_origins.
        Features missing from feature_mapping are skipped.
    """
    by_shell = defaultdict(int)
    by_approach = defaultdict(int)
    by_type = defaultdict(int)
    per_shell = defaultdict(lambda: defaultdict(int))
    details = []

    for feat, origin_list in interaction_origins.items():
        if feat not in feature_mapping:
            continue

        term = feature_mapping[feat]
        is_dissipation = len(term) == 2 and term[1] == 'dissipation'
        kind = 'dissipation' if is_dissipation else 'interaction'
        # Only the selected branch is evaluated, so term[0] + term[1] is never
        # attempted on a (shell, 'dissipation') pair.
        label = f"D_{term[0]}" if is_dissipation else f"{term[0]}+{term[1]}→{term[0]+term[1]}"

        details.append({
            'feature_idx': feat,
            'interaction': term,
            'interaction_str': label,
            'type': kind,
            'origins': origin_list,
        })

        for shell, approach in origin_list:
            by_shell[shell] += 1
            by_approach[approach] += 1
            by_type[kind] += 1
            per_shell[shell][kind] += 1

    return {
        'by_shell': by_shell,
        'by_approach': by_approach,
        'by_interaction_type': by_type,
        'shell_contributions': per_shell,
        'detailed_interactions': details,
    }


def save_final_W_and_origins(final_W, interaction_origins, feature_mapping, 
                           output_directory, filename_prefix="final_W", nn=20):
    """
    Save the final W and origin information to files.

    Writes into ``output_directory`` (created if it does not exist):
      - ``{filename_prefix}.npy``: ``final_W`` via ``np.save``
      - ``{filename_prefix}_origins.pkl``: pickled ``interaction_origins``
      - ``{filename_prefix}_mapping.pkl``: pickled ``feature_mapping``
      - ``{filename_prefix}_summary.txt``: per-shell / per-approach /
        per-type counts from ``create_interaction_origin_summary_dict``

    Args:
        final_W: Final coefficient array
        interaction_origins: Origin information
        feature_mapping: Feature mapping
        output_directory: Where to save files
        filename_prefix: Prefix for output files
        nn: Number of shells passed to the summary helper (default 20)
    """
    os.makedirs(output_directory, exist_ok=True)

    np.save(os.path.join(output_directory, f"{filename_prefix}.npy"), final_W)

    with open(os.path.join(output_directory, f"{filename_prefix}_origins.pkl"), 'wb') as f:
        pickle.dump(interaction_origins, f)

    with open(os.path.join(output_directory, f"{filename_prefix}_mapping.pkl"), 'wb') as f:
        pickle.dump(feature_mapping, f)

    # The summary keys read below ('by_shell', 'by_approach', ...) are the ones
    # produced by create_interaction_origin_summary_dict.
    origin_summary = create_interaction_origin_summary_dict(interaction_origins, feature_mapping, nn)

    with open(os.path.join(output_directory, f"{filename_prefix}_summary.txt"), 'w',
              encoding='utf-8') as f:
        f.write("FINAL W INTERACTION ORIGINS SUMMARY\n")
        f.write("=" * 40 + "\n\n")

        f.write("Contributions by Shell:\n")
        for shell in sorted(origin_summary['by_shell'].keys()):
            f.write(f"  Shell {shell}: {origin_summary['by_shell'][shell]} interactions\n")

        f.write("\nContributions by Approach:\n")
        for approach, count in origin_summary['by_approach'].items():
            f.write(f"  {approach}: {count} interactions\n")

        f.write("\nContributions by Interaction Type:\n")
        for int_type, count in origin_summary['by_interaction_type'].items():
            f.write(f"  {int_type}: {count} interactions\n")

        f.write("\nDetailed Shell Contributions:\n")
        for shell in sorted(origin_summary['shell_contributions'].keys()):
            f.write(f"  Shell {shell}:\n")
            for int_type, count in origin_summary['shell_contributions'][shell].items():
                f.write(f"    {int_type}: {count} interactions\n")


def plot_coupled_results2(W, expected_shells, nn=20, n_strongest=30, allint=True, 
                         subtractD=False, filename=None, interaction_origins=None, 
                         feature_mapping=None, show_origin_info=True):
    """
    Plot SABRA regression results with optional interaction origin visualization.
    Maintains backward compatibility with original plot structure.

    Layout (2x3 grid):
      - Left, spanning two cells: the strongest coefficients coloured by category.
      - Top right: a pie of |coefficient| sums per category.
      - Bottom row, only when ``interaction_origins`` is given and
        ``show_origin_info`` is True: origin-shell counts, an approach pie and a
        text list of origin mismatches.

    An interaction is "expected" if it belongs to the expected set of any shell
    in ``expected_shells``. It is an "origin mismatch" if it is expected
    globally but at least one shell that selected it (per
    ``interaction_origins``) does not list it as expected for itself.

    Args:
        W: Coefficient array or final assembled W. For a list/tuple whose first
            element is a sequence, only ``W[0]`` is used. Arrays are flattened.
        expected_shells: Iterable of 1-based shell numbers. 1 is subtracted
            before looking up expected interactions and dissipation shells.
        nn: Number of shells
        n_strongest: Number of strongest coefficients to highlight
        allint: Whether all interactions were used (only used to build
            feature_mapping when it is not supplied)
        subtractD: Whether dissipation was subtracted (same use as allint)
        filename: Optional filename to save plot (the figure is then closed;
            otherwise plt.show() is called)
        interaction_origins: {feature_idx: [(origin_shell, approach), ...]} with
            1-based origin shells (optional)
        feature_mapping: {feature_idx: (i, j, type) or (shell, 'dissipation')}
            (optional; built with build_feature_mapping if None)
        show_origin_info: Whether to show origin information in plots

    Returns:
        dict with 'expected_percentage', 'expected_coeffs', 'unexpected_coeffs',
        'expected_dissipation_coeffs', 'unexpected_dissipation_coeffs' (lists of
        (feature_idx, value, interaction)), 'origin_mismatch_info' and
        'top_coeffs'. Returns None when W is an empty list/tuple.
    """

    # Handle W format
    if isinstance(W, (list, tuple)):
        if len(W) > 0:
            W_array = np.array(W[0]) if hasattr(W[0], '__len__') else np.array(W)
        else:
            print("Empty W provided")
            return
    else:
        W_array = np.array(W)

    if W_array.ndim > 1:
        W_array = W_array.flatten()

    # Build feature mapping if not provided
    if feature_mapping is None:
        include_dissipation = not subtractD
        feature_mapping = build_feature_mapping(nn, use_all_interactions=allint,
                                                include_dissipation=include_dissipation)

    # Materialize once: expected_shells is iterated more than once below.
    expected_shells = list(expected_shells)
    # 0-based shells whose dissipation counts as expected (built once, not per coefficient)
    expected_dissipation_shells = {s - 1 for s in expected_shells}

    # Memoize expected sets per 0-based shell; the origin check re-queries the same shells.
    expected_by_shell = {}

    def expected_for(shell0):
        if shell0 not in expected_by_shell:
            expected_by_shell[shell0] = get_expected_interactions_single_shell(nn, shell0)
        return expected_by_shell[shell0]

    # Get expected interactions for all expected shells
    all_expected_interactions = set()
    for shell_idx in expected_shells:
        all_expected_interactions.update(expected_for(shell_idx - 1))  # Convert to 0-based

    # Running sums / counts per category
    expected_coeff_sum = 0
    unexpected_coeff_sum = 0
    expected_dissipation_sum = 0
    unexpected_dissipation_sum = 0
    expected_count = 0
    unexpected_count = 0
    expected_dissipation_count = 0
    unexpected_dissipation_count = 0

    # Track coefficients for detailed plotting
    expected_coeffs = []
    unexpected_coeffs = []
    expected_dissipation_coeffs = []
    unexpected_dissipation_coeffs = []
    origin_mismatch_info = {}  # {feature_idx: [(origin_shell, approach), ...]}

    # Find significant coefficients
    significant_coeffs = []
    for idx, coeff in enumerate(W_array):
        if hasattr(coeff, '__len__') and len(coeff) > 1:
            coeff_val = np.linalg.norm(coeff)
        else:
            coeff_val = float(coeff)

        if abs(coeff_val) > 1e-10:
            significant_coeffs.append((idx, coeff_val))

    # Sort and get top coefficients
    significant_coeffs.sort(key=lambda x: abs(x[1]), reverse=True)
    top_coeffs = significant_coeffs[:n_strongest]

    # Categorize every significant coefficient (not only the top ones)
    for idx, coeff_val in significant_coeffs:
        if idx not in feature_mapping:
            continue
        interaction = feature_mapping[idx]

        # Check for origin mismatch if we have origin info
        has_origin_mismatch = False
        mismatch_details = []
        if interaction_origins and idx in interaction_origins and show_origin_info:
            if len(interaction) == 3:  # Regular interaction (not dissipation)
                i, j, int_type = interaction
                if (i, j, int_type) in all_expected_interactions:
                    for origin_shell, approach in interaction_origins[idx]:
                        if (i, j, int_type) not in expected_for(origin_shell - 1):
                            has_origin_mismatch = True
                            mismatch_details.append((origin_shell, approach))

        if len(interaction) == 2 and interaction[1] == 'dissipation':
            # Dissipation term
            if interaction[0] in expected_dissipation_shells:
                expected_dissipation_sum += abs(coeff_val)
                expected_dissipation_count += 1
                expected_dissipation_coeffs.append((idx, coeff_val, interaction))
            else:
                unexpected_dissipation_sum += abs(coeff_val)
                unexpected_dissipation_count += 1
                unexpected_dissipation_coeffs.append((idx, coeff_val, interaction))
        else:
            # Regular interaction
            i, j, int_type = interaction
            if (i, j, int_type) in all_expected_interactions:
                expected_coeff_sum += abs(coeff_val)
                expected_count += 1
                expected_coeffs.append((idx, coeff_val, interaction))

                # Track origin mismatch for expected interactions
                if has_origin_mismatch:
                    origin_mismatch_info[idx] = mismatch_details
            else:
                unexpected_coeff_sum += abs(coeff_val)
                unexpected_count += 1
                unexpected_coeffs.append((idx, coeff_val, interaction))

    total_sum = expected_coeff_sum + unexpected_coeff_sum + expected_dissipation_sum + unexpected_dissipation_sum
    interaction_sum = expected_coeff_sum + unexpected_coeff_sum
    expected_percentage = (expected_coeff_sum / interaction_sum * 100) if interaction_sum > 0 else 0

    # Create the original plot structure
    fig = plt.figure(figsize=(20, 12))

    # Main plot: Coefficient values (original)
    ax1 = plt.subplot(2, 3, (1, 2))

    # Color / marker by category (original logic)
    colors = []
    markers = []
    for idx, coeff_val in top_coeffs:
        if idx in feature_mapping:
            interaction = feature_mapping[idx]

            if len(interaction) == 2 and interaction[1] == 'dissipation':
                if interaction[0] in expected_dissipation_shells:
                    colors.append('blue')  # Expected dissipation
                else:
                    colors.append('lightblue')  # Unexpected dissipation
                markers.append('s')  # Square for dissipation
            else:
                i, j, int_type = interaction
                if (i, j, int_type) in all_expected_interactions:
                    if idx in origin_mismatch_info and show_origin_info:
                        colors.append('orange')  # Expected but wrong origin
                        markers.append('^')  # Triangle for mismatch
                    else:
                        colors.append('green')  # Expected
                        markers.append('o')  # Circle for expected
                else:
                    colors.append('red')  # Unexpected
                    markers.append('o')  # Circle for unexpected
        else:
            colors.append('gray')
            markers.append('o')

    # Create scatter plot (preserving original style)
    for i, (idx_val, coeff_val) in enumerate(top_coeffs):
        ax1.scatter(i, coeff_val, c=colors[i], marker=markers[i], s=60, alpha=0.7)

    ax1.axhline(y=0, color='black', linestyle='-', alpha=0.3)
    ax1.set_xlabel('Coefficient Index (ranked by magnitude)')
    ax1.set_ylabel('Coefficient Value')
    # Report how many are actually shown (can be fewer than n_strongest)
    ax1.set_title(f'Top {len(top_coeffs)} Strongest Coefficients')
    ax1.grid(True, alpha=0.3)

    # Original legend with origin info addition
    legend_elements = [
        plt.Line2D([0], [0], marker='o', color='w', markerfacecolor='green', 
                   markersize=8, label='Expected', alpha=0.7),
        plt.Line2D([0], [0], marker='o', color='w', markerfacecolor='red', 
                   markersize=8, label='Unexpected', alpha=0.7),
        plt.Line2D([0], [0], marker='s', color='w', markerfacecolor='blue', 
                   markersize=8, label='Expected Dissipation', alpha=0.7),
        plt.Line2D([0], [0], marker='s', color='w', markerfacecolor='lightblue', 
                   markersize=8, label='Unexpected Dissipation', alpha=0.7)
    ]

    if origin_mismatch_info and show_origin_info:
        legend_elements.append(
            plt.Line2D([0], [0], marker='^', color='w', markerfacecolor='orange', 
                       markersize=8, label='Origin Mismatch', alpha=0.7)
        )

    ax1.legend(handles=legend_elements, loc='upper right')

    # Summary pie chart (original)
    ax2 = plt.subplot(2, 3, 3)

    pie_data = []
    pie_labels = []
    pie_colors = []

    if expected_coeff_sum > 0:
        pie_data.append(expected_coeff_sum)
        pie_labels.append(f'Expected\n({expected_count})')
        pie_colors.append('green')

    if unexpected_coeff_sum > 0:
        pie_data.append(unexpected_coeff_sum)
        pie_labels.append(f'Unexpected\n({unexpected_count})')
        pie_colors.append('red')

    if expected_dissipation_sum > 0:
        pie_data.append(expected_dissipation_sum)
        pie_labels.append(f'Exp. Dissip.\n({expected_dissipation_count})')
        pie_colors.append('blue')

    if unexpected_dissipation_sum > 0:
        pie_data.append(unexpected_dissipation_sum)
        pie_labels.append(f'Unexp. Dissip.\n({unexpected_dissipation_count})')
        pie_colors.append('lightblue')

    if pie_data:
        ax2.pie(pie_data, labels=pie_labels, colors=pie_colors, autopct='%1.1f%%', startangle=90)
        ax2.set_title('Coefficient Sum Distribution')

    # Origin analysis plots (only if origin info available)
    if interaction_origins and show_origin_info:
        # Origin shell distribution (top coefficients only)
        ax3 = plt.subplot(2, 3, 4)

        origin_shell_counts = defaultdict(int)
        origin_approach_counts = defaultdict(int)

        for idx, coeff_val in top_coeffs:
            if idx in interaction_origins:
                for origin_shell, approach in interaction_origins[idx]:
                    origin_shell_counts[origin_shell] += 1
                    origin_approach_counts[approach] += 1

        if origin_shell_counts:
            shells = sorted(origin_shell_counts.keys())
            counts = [origin_shell_counts[s] for s in shells]

            bars = ax3.bar(shells, counts, alpha=0.7, color='skyblue', edgecolor='navy')
            ax3.set_xlabel('Origin Shell')
            ax3.set_ylabel('Number of Top Interactions')
            ax3.set_title('Interaction Origins by Shell')
            ax3.set_xticks(shells)

            # Highlight shells with mismatches
            if origin_mismatch_info:
                mismatch_shells = {shell
                                   for details in origin_mismatch_info.values()
                                   for shell, _ in details}
                for i, shell in enumerate(shells):
                    if shell in mismatch_shells:
                        bars[i].set_color('orange')
                        bars[i].set_alpha(0.8)

        # Approach distribution
        ax4 = plt.subplot(2, 3, 5)

        if origin_approach_counts:
            approaches = list(origin_approach_counts.keys())
            counts = list(origin_approach_counts.values())

            ax4.pie(counts, labels=approaches, autopct='%1.1f%%', startangle=90)
            ax4.set_title('Interactions by Approach')

        # Origin mismatch details
        ax5 = plt.subplot(2, 3, 6)

        if origin_mismatch_info:
            mismatch_text = "Origin Mismatches:\n\n"
            for idx, mismatch_details in list(origin_mismatch_info.items())[:10]:  # Show first 10
                if idx in feature_mapping:
                    interaction = feature_mapping[idx]
                    mismatch_text += f"Interaction {interaction}:\n"
                    for shell, approach in mismatch_details:
                        mismatch_text += f"  From shell {shell} ({approach})\n"
                    mismatch_text += "\n"

            ax5.text(0.05, 0.95, mismatch_text, transform=ax5.transAxes, 
                    verticalalignment='top', fontsize=8, fontfamily='monospace')
            ax5.set_xlim(0, 1)
            ax5.set_ylim(0, 1)
            ax5.axis('off')
            ax5.set_title('Origin Mismatch Details')
        else:
            ax5.text(0.5, 0.5, 'No origin mismatches detected', 
                    transform=ax5.transAxes, ha='center', va='center')
            ax5.axis('off')
            ax5.set_title('Origin Analysis')

    plt.tight_layout()

    # Print original summary with origin info
    print(f"\nOriginal Analysis Summary:")
    print(f"Expected interactions: {expected_count} (sum: {expected_coeff_sum:.4f})")
    print(f"Unexpected interactions: {unexpected_count} (sum: {unexpected_coeff_sum:.4f})")
    print(f"Expected dissipation: {expected_dissipation_count} (sum: {expected_dissipation_sum:.4f})")
    print(f"Unexpected dissipation: {unexpected_dissipation_count} (sum: {unexpected_dissipation_sum:.4f})")
    print(f"Total sum: {total_sum:.4f}")

    if interaction_sum > 0:
        print(f"Expected percentage: {expected_percentage:.1f}%")

    if origin_mismatch_info and show_origin_info:
        print(f"\nOrigin Analysis:")
        print(f"Interactions with origin mismatches: {len(origin_mismatch_info)}")
        # Counts over all origins, not only the top coefficients
        origin_shell_counts = defaultdict(int)
        for origins in interaction_origins.values():
            for shell, approach in origins:
                origin_shell_counts[shell] += 1
        print(f"Origin shell distribution: {dict(origin_shell_counts)}")

    if filename:
        plt.savefig(filename, dpi=300, bbox_inches='tight')
        print(f"\nPlot saved as: {filename}")
        plt.close(fig)
    else:
        plt.show()

    return {
        'expected_percentage': expected_percentage,
        'expected_coeffs': expected_coeffs,
        'unexpected_coeffs': unexpected_coeffs,
        'expected_dissipation_coeffs': expected_dissipation_coeffs,
        'unexpected_dissipation_coeffs': unexpected_dissipation_coeffs,
        'origin_mismatch_info': origin_mismatch_info,
        'top_coeffs': top_coeffs
    }

def get_expected_interactions(nn):
    """
    Map every 0-based shell index n in range(nn) to its expected (i, j, type) triads.

    Only the three model terms are used. Each is added together with its
    mirrored form: (a, b, 'i_conj') pairs with (b, a, 'j_conj'), and
    (a, b, 'regular') pairs with (b, a, 'regular').

      1) i k_{n+1} A u_{n+1}^* u_{n+2}  -> (n+1, n+2, 'i_conj'), needs n+2 < nn
      2) i k_n     B u_{n-1}^* u_{n+1}  -> (n-1, n+1, 'i_conj'), needs n-1 >= 0 and n+1 < nn
      3) -i k_{n-1} C u_{n-2}   u_{n-1} -> (n-2, n-1, 'regular'), needs n-2 >= 0

    Returns:
        dict {n: set of (i, j, type)}. Shells near the boundaries only get the
        terms whose partner shells exist.
    """
    mirror_of = {'i_conj': 'j_conj', 'regular': 'regular'}

    def _both_orders(a, b, kind):
        return {(a, b, kind), (b, a, mirror_of[kind])}

    def _terms_for(n):
        found = set()
        if n + 2 < nn:                     # forward-forward
            found |= _both_orders(n + 1, n + 2, 'i_conj')
        if n - 1 >= 0 and n + 1 < nn:      # backward-forward
            found |= _both_orders(n - 1, n + 1, 'i_conj')
        if n - 2 >= 0:                     # backward-backward
            found |= _both_orders(n - 2, n - 1, 'regular')
        return found

    return {n: _terms_for(n) for n in range(nn)}





In [None]:
nn = 20
data_dir = '/home/vale/SABRA/params_bin/Ws'
expected_shells = list(range(nn))  # all shells, 0-based
subtract_dissipation = False

results = assemble_final_W_from_shells(
    data_directory=data_dir,
    nn=nn,
    n_top_per_shell=7,
    filter_params={'subtractD': subtract_dissipation},  # example filters
)

# plot_coupled_results2 subtracts 1 from each entry of expected_shells
# ("Convert to 0-based"), i.e. it expects 1-based shell numbers. Passing
# range(nn) would request shell index -1 and omit the last shell, so hand it
# the 1-based version of the same "all shells" selection.
expected_shells_1based = [s + 1 for s in expected_shells]

# Then for each parameter combination:
for param_combo, (final_W, interaction_origins, feature_mapping) in results.items():
    # param_combo is a sorted tuple of (name, value) pairs; reuse it as a suffix
    plot_suffix = f"_{'_'.join(f'{k}{v}' for k, v in param_combo)}"
    filename = f"stats_Wfinal{plot_suffix}.png"

    plot_coupled_results2(final_W, expected_shells=expected_shells_1based, nn=nn,
                          interaction_origins=interaction_origins,
                          feature_mapping=feature_mapping, filename=filename)

    # NOTE(review): confirm whether plot_coupled_results expects 0- or 1-based
    # shells; it still receives the original 0-based list here.
    plot_coupled_results(final_W, expected_shells=expected_shells, nn=nn, top_n=90,
                         include_dissipation=not subtract_dissipation,
                         use_all_interactions=True, plot_suffix=plot_suffix)

## Ranking interactions shell by shell

In [43]:
import os
import re
import numpy as np
import matplotlib.pyplot as plt

# Bar colour for each interaction category in the per-shell ranking plots.
_RANKING_TYPE_COLORS = {
    'forward-forward': 'green',
    'backward-backward': 'orange',
    'backward-forward': 'blue',
    'dissipation': 'red',
}


def _sabra_expected_ranking(shell_idx, nn, include_dissipation, r=2, A=1, nu=1e-8):
    """Return the theoretical SABRA terms of one shell, largest coefficient first.

    Args:
        shell_idx (int): 0-based shell index (must satisfy 0 <= shell_idx < nn).
        nn (int): Number of shells.
        include_dissipation (bool): Whether to include the viscous term nu * k_n**2.
        r (float): Shell ratio, k_n = r**(n-1).
        A (float): Forward coefficient; B = A/r - A and C = -A/r are derived from it.
        nu (float): Viscosity used for the dissipation magnitude.

    Returns:
        list[tuple]: Identifiers with 1-based shell numbers, either
        ('interaction', i, j, 'i_conj' | 'regular') or ('dissipation', n),
        ordered by decreasing |coefficient| (ties keep forward, middle,
        backward, dissipation order).

    Raises:
        ValueError: If shell_idx is outside 0..nn-1.
    """
    if not 0 <= shell_idx < nn:
        raise ValueError(f"shell_idx {shell_idx} out of range for nn={nn}")

    B = A / r - A
    C = -A / r
    # k[m] holds k_{m+1}; floats avoid integer overflow for large nn.
    k = np.power(float(r), np.arange(nn))
    n = shell_idx + 1  # SABRA equations are written with 1-based shell numbers

    candidates = []
    # Forward term  A k_{n+1} u_{n+1}^* u_{n+2}: requires shell n+2 to exist.
    if n + 2 <= nn:
        candidates.append((k[n] * abs(A), ('interaction', n + 1, n + 2, 'i_conj')))
    # Middle term  B k_n u_{n-1}^* u_{n+1}: requires shells n-1 and n+1.
    if n >= 2 and n + 1 <= nn:
        candidates.append((k[n - 1] * abs(B), ('interaction', n - 1, n + 1, 'i_conj')))
    # Backward term  C k_{n-1} u_{n-2} u_{n-1}: requires shell n-2.
    if n >= 3:
        candidates.append((k[n - 2] * abs(C), ('interaction', n - 2, n - 1, 'regular')))
    if include_dissipation:
        candidates.append((nu * k[n - 1] ** 2, ('dissipation', n)))

    candidates.sort(key=lambda item: item[0], reverse=True)
    return [ident for _, ident in candidates]


def _match_shell_file(fname, nn, t_start, normalize_by_k, subtract_dissipation,
                      use_lassoCV, lambda_reg, nbatch, minocc):
    """Return the 0-based shell index encoded in a W file name, or None if it does not match.

    Two naming schemes are recognised:
      * new: W_final_shell<s>_nbatch<b>_tstart<t>_..._subtractD<..>_knorm<..>_lassoCV<..>..._minocc<m>
        (with ``_lambda<l>_`` right after ``lassoCVFalse``); nbatch/minocc are
        additionally filtered when they are not None;
      * old: ..._tstart<t>_shell<s>_nn<nn>_lambda<l>_lassoCV<..>_subtractD<..>_knorm<..>
        (lambda is empty when LassoCV was used).
    """
    if not (fname.startswith("W_") and fname.endswith(".npy")):
        return None

    # Escape numeric values so e.g. the '.' in "500.0" is matched literally.
    t_str = re.escape(str(t_start))
    lam_str = re.escape(str(lambda_reg))

    common = (rf"W_final_shell(\d+)_nbatch(\d+)_tstart{t_str}_.*"
              rf"_subtractD{subtract_dissipation}_knorm{normalize_by_k}")
    if use_lassoCV:
        new_pattern = common + r"_lassoCVTrue.*_minocc(\d+)"
    else:
        new_pattern = common + rf"_lassoCVFalse_lambda{lam_str}_.*_minocc(\d+)"

    match = re.search(new_pattern, fname)
    if match:
        file_nbatch, file_minocc = int(match.group(2)), int(match.group(3))
        if (nbatch is None or file_nbatch == nbatch) and (minocc is None or file_minocc == minocc):
            return int(match.group(1)) - 1
        # nbatch/minocc differ: fall through to the old scheme.

    lam_part = '' if use_lassoCV else lam_str
    old_pattern = (rf"_tstart{t_str}_shell(\d+)_nn{nn}_lambda{lam_part}_"
                   rf"lassoCV{use_lassoCV}_subtractD{subtract_dissipation}_knorm{normalize_by_k}")
    match = re.search(old_pattern, fname)
    return int(match.group(1)) - 1 if match else None


def _interaction_category(entry):
    """Classify a term relative to its target shell for colouring.

    Dissipation entries are recognised first, so they need no 'j' key.
    """
    if entry['type'] == 'dissipation':
        return 'dissipation'
    target = entry['target']
    if entry['i'] > target and entry['j'] > target:
        return 'forward-forward'
    if entry['i'] < target and entry['j'] < target:
        return 'backward-backward'
    return 'backward-forward'


def _term_label(entry):
    """Tick label 'i,j (type)'; entries without a 'j' key are shown as 'i (type)'."""
    if 'j' in entry:
        return f"{entry['i']},{entry['j']} ({entry['type']})"
    return f"{entry['i']} ({entry['type']})"


def plot_expected_interactions_by_shell(nn, refined, t_start, normalize_by_k, subtract_dissipation,
                                        use_lassoCV, lambda_reg, expected_interactions,
                                        include_dissipation, nbatch=None, minocc=None,
                                        save_dir='Ws', output_dir='rankings'):
    """
    Load per-shell W files and plot the expected terms of each shell ordered by magnitude.

    Every file in ``save_dir`` whose name matches the run parameters is loaded
    and passed to ``analyze_interactions``. The terms whose status is
    'expected' are sorted by |weight|. Each subplot title gets a ✓ when that
    learned order equals the theoretical SABRA order exactly, and a ✗
    otherwise; a missing expected term therefore also counts as a mismatch.
    Shells without a matching file are drawn empty and titled "(no data)".

    Args:
        nn (int): Number of shells.
        refined (str): Refinement label, used only in the output filename (can be empty).
        t_start (float): Starting time index used in filenames.
        normalize_by_k (bool): Whether normalization by k was used.
        subtract_dissipation (bool): Whether dissipation terms were subtracted.
        use_lassoCV (bool): Whether LassoCV was used.
        lambda_reg (float): Regularization strength (ignored if LassoCV is True).
        expected_interactions (dict): Expected interactions per shell, passed to analyze_interactions.
        include_dissipation (bool): Whether dissipation terms are included in the model.
        nbatch (int, optional): Keep only new-format files with this nbatch value.
        minocc (int, optional): Keep only new-format files with this minocc value.
        save_dir (str): Folder where the W_*.npy files are read from.
        output_dir (str): Folder where the plot is saved (created if needed).

    Returns:
        None. Prints a message and returns early when no file matches.
    """
    shell_data = {}

    # Sorted for a deterministic choice when several files match the same shell.
    for fname in sorted(os.listdir(save_dir)):
        shell_idx = _match_shell_file(fname, nn, t_start, normalize_by_k, subtract_dissipation,
                                      use_lassoCV, lambda_reg, nbatch, minocc)
        if shell_idx is None:
            continue
        if not 0 <= shell_idx < nn:
            print(f"⚠️ Skipping {fname}: shell {shell_idx + 1} is outside 1..{nn}")
            continue
        if shell_idx in shell_data:
            print(f"⚠️ Several files match shell {shell_idx + 1}; using {fname}")

        # NOTE: allow_pickle executes pickled code; only point save_dir at trusted output.
        W = np.load(os.path.join(save_dir, fname), allow_pickle=True)
        W = W[()] if isinstance(W, np.ndarray) and W.shape == () else W

        significant = analyze_interactions(
            W=W,
            nn=nn,
            use_all_interactions=True,
            threshold=1e-5,
            expected_interactions=expected_interactions,
            shell_idx=shell_idx,
            include_dissipation=include_dissipation,
        )

        expected_terms = sorted((e for e in significant if e['status'] == 'expected'),
                                key=lambda e: abs(e['weight']), reverse=True)
        learned_ranking = [('dissipation', e['i']) if e['type'] == 'dissipation'
                           else ('interaction', e['i'], e['j'], e['type'])
                           for e in expected_terms]

        shell_data[shell_idx] = {
            'terms': expected_terms,
            'ranking_match': learned_ranking == _sabra_expected_ranking(
                shell_idx, nn, include_dissipation),
        }

    if not shell_data:
        print("❌ No matching files found for given parameters.")
        return

    max_shell = max(shell_data)
    ncols = 5
    nrows = int(np.ceil((max_shell + 1) / ncols))
    # squeeze=False keeps axs 2-D even for a single row of subplots.
    fig, axs = plt.subplots(nrows=nrows, ncols=ncols, figsize=(ncols * 4.5, nrows * 3.5),
                            constrained_layout=True, squeeze=False)

    for idx, ax in enumerate(axs.flat):
        if idx > max_shell:
            ax.axis('off')
            continue

        info = shell_data.get(idx)
        ax.set_ylabel("Magnitude")
        ax.grid(True)
        if info is None:
            ax.set_xticks([])
            ax.set_title(f"Shell {idx + 1} (no data)", color='gray', fontweight='bold')
            continue

        terms = info['terms']
        labels = [_term_label(e) for e in terms]
        weights = [abs(e['weight']) for e in terms]
        colors = [_RANKING_TYPE_COLORS[_interaction_category(e)] for e in terms]

        if len(terms) == 1:
            # A lone bar would otherwise stretch across the whole axis.
            ax.bar([0], weights, width=0.3, color=colors, align='center')
            ax.set_xlim(-0.5, 0.5)
            ax.set_xticks([0])
            ax.set_xticklabels(labels, fontsize=8)
        elif terms:
            # Numeric positions keep one bar per term even if two labels coincide.
            positions = np.arange(len(terms))
            ax.bar(positions, weights, color=colors)
            ax.set_xticks(positions)
            ax.set_xticklabels(labels, rotation=45, ha='right', fontsize=8)
        else:
            ax.set_xticks([])

        ranking_match = info['ranking_match']
        ax.set_title(f"Shell {idx + 1} {'✓' if ranking_match else '✗'}",
                     color='green' if ranking_match else 'red', fontweight='bold')

    handles = [plt.Line2D([0], [0], marker='o', color='w', label=key,
                          markerfacecolor=color, markersize=10)
               for key, color in _RANKING_TYPE_COLORS.items()]
    fig.legend(handles=handles, loc='upper center', ncol=len(_RANKING_TYPE_COLORS),
               title="Interaction Types")

    os.makedirs(output_dir, exist_ok=True)
    nbatch_str = f"_nbatch{nbatch}" if nbatch is not None else ""
    minocc_str = f"_minocc{minocc}" if minocc is not None else ""
    plot_filename = (f"expected_terms_{refined}_tstart{t_start}{nbatch_str}{minocc_str}_nn{nn}_"
                     f"lambda{(lambda_reg if not use_lassoCV else '')}_"
                     f"lassoCV{use_lassoCV}_subtractD{subtract_dissipation}_knorm{normalize_by_k}.png")
    output_path = os.path.join(output_dir, plot_filename)

    fig.savefig(output_path, dpi=300)
    plt.close(fig)
    print(f"✅ Plot saved to {output_path}")


In [52]:

# Rank the expected SABRA terms shell by shell for one fitted configuration.
include_dissipation = True  # the fit kept the dissipation term (it was not subtracted)
nn = 20
save_dir = '/home/vale/SABRA/params_bin2/Ws_CV'       # W_*.npy files are read from here
output_path = '/home/vale/SABRA/params_bin2/rankings'  # the ranking figure is written here
n_batch = 5   # keep only files produced with this nbatch
min_occ = 5   # keep only files produced with this minocc

expected_interactions = get_expected_interactions(nn=nn)

plot_expected_interactions_by_shell(
    nn=nn,
    refined="",
    t_start=1,
    normalize_by_k=False,
    subtract_dissipation=not include_dissipation,
    use_lassoCV=True,
    lambda_reg=5e2,  # not used for file matching while use_lassoCV=True
    expected_interactions=expected_interactions,
    include_dissipation=include_dissipation,
    nbatch=n_batch,
    minocc=min_occ,
    save_dir=save_dir,
    output_dir=output_path,
)


✅ Plot saved to /home/vale/SABRA/params_bin2/rankings/expected_terms__tstart1_nbatch5_minocc5_nn20_lambda_lassoCVTrue_subtractDFalse_knormFalse.png
