# W2 Distance Analysis

This notebook calculates **Energy W2** and **Torsion W2** distances between generated samples and a reference dataset.

## Features:
- **Energy W2**: Wasserstein-2 distance between energy distributions
- **Torsion W2**: Wasserstein-2 distance between torsion angle distributions  
- **Modular design**: Easy to configure and extend
- **Lightweight**: Focused on W2 calculations only

Based on functions from `infer_ad2.py`.

## Setup and Imports

In [1]:
import numpy as np
import torch
import mdtraj as md
import ot  # optimal transport
import matplotlib.pyplot as plt
from pathlib import Path
import json
import os

# Import required functions from local modules
from dataset.ad2_dataset import get_alanine_atom_types, get_alanine_implicit_dataset
from bgflow.utils import remove_mean
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"   
os.environ["CUDA_VISIBLE_DEVICES"]="7" #Swap this every time we run

print("Imports completed successfully!")


****** PyMBAR will use 64-bit JAX! *******
* JAX is currently set to 32-bit bitsize *
* which is its default.                  *
*                                        *
* PyMBAR requires 64-bit mode and WILL   *
* enable JAX's 64-bit mode when called.  *
*                                        *
* This MAY cause problems with other     *
* Uses of JAX in the same code.          *
******************************************



Imports completed successfully!


## Configuration

In [2]:
# =============================================================================
# CONFIGURATION
# =============================================================================

CONFIG = {
    # Data file paths - now supporting 6 different generated sample files
    'generated_data_paths': [
        "/net/galaxy/home/koes/rishal/nce/BoltzNCE/BoltzNCE/generated/ebm_endpoint_MCMC0_0_numpy_dict.npz",  # File 1
        "/net/galaxy/home/koes/rishal/nce/BoltzNCE/BoltzNCE/generated/ebm_ot_vector_MCMC0_0_numpy_dict.npz",  # File 2
        "/net/galaxy/home/koes/rishal/nce/BoltzNCE/BoltzNCE/generated/ebm_vector_MCMC0_0_numpy_dict.npz",  # File 3
        "/net/galaxy/home/koes/rishal/nce/BoltzNCE/BoltzNCE/generated/ecnf_biased_generated.npz",  # File 4
        "/net/galaxy/home/koes/rishal/nce/BoltzNCE/BoltzNCE/generated/endpoint_ot_ema_0_numpy_dict.npz",  # File 5
        "/net/galaxy/home/koes/rishal/nce/BoltzNCE/BoltzNCE/generated/endpoint_ot_0_numpy_dict.npz",  # File 6
        "/net/galaxy/home/koes/rishal/nce/BoltzNCE/BoltzNCE/generated/vector_ot_0_numpy_dict.npz",  # File 6
        "/net/galaxy/home/koes/rishal/nce/BoltzNCE/BoltzNCE/generated/vector_ot_ema_0_numpy_dict.npz",  # File 6
    ],
    
    # Labels for each file (for identification in results)
    'file_labels': [
        'ebm_endpoint_MCMC0_0_', 'Model_ebm_ot_vector_MCMC0_0_1', 'ebm_vector_MCMC0_0_', 'ecnf_biased_generated', 'endpoint_ot_ema_0_', 'vector_ot_ema_5_','vector_ot_0_','vector_ot_ema_0_'
    ],
    
    'reference_data_path': None,  # Reference dataset path (to be provided)
    
    # Data keys in .npz files
    'samples_key': 'samples',    # Key for samples in generated data
    'energies_key': 'energies',  # Key for energies in generated data
    
    # Analysis settings
    'n_samples_analysis': 10000,  # Number of samples to use for W2 calculation (for computational efficiency)
    'n_runs': 5,                 # Number of repeated sampling runs for statistical analysis
    'random_seed': 42,           # Random seed for reproducible sampling
    
    # Output settings
    'output_dir': './w2_analysis_results/',
    'save_results': True,
    'create_plots': True
}

print("Configuration:")
print(f"  Number of generated data files: {len(CONFIG['generated_data_paths'])}")
for i, (path, label) in enumerate(zip(CONFIG['generated_data_paths'], CONFIG['file_labels'])):
    print(f"    {label}: {path}")
print(f"  reference_data_path: {CONFIG['reference_data_path']}")
print(f"  samples_key: {CONFIG['samples_key']}")
print(f"  energies_key: {CONFIG['energies_key']}")
print(f"  n_samples_analysis: {CONFIG['n_samples_analysis']}")
print(f"  n_runs: {CONFIG['n_runs']}")
print(f"  random_seed: {CONFIG['random_seed']}")
print(f"  output_dir: {CONFIG['output_dir']}")
print(f"  save_results: {CONFIG['save_results']}")
print(f"  create_plots: {CONFIG['create_plots']}")

Configuration:
  Number of generated data files: 8
    ebm_endpoint_MCMC0_0_: /net/galaxy/home/koes/rishal/nce/BoltzNCE/BoltzNCE/generated/ebm_endpoint_MCMC0_0_numpy_dict.npz
    Model_ebm_ot_vector_MCMC0_0_1: /net/galaxy/home/koes/rishal/nce/BoltzNCE/BoltzNCE/generated/ebm_ot_vector_MCMC0_0_numpy_dict.npz
    ebm_vector_MCMC0_0_: /net/galaxy/home/koes/rishal/nce/BoltzNCE/BoltzNCE/generated/ebm_vector_MCMC0_0_numpy_dict.npz
    ecnf_biased_generated: /net/galaxy/home/koes/rishal/nce/BoltzNCE/BoltzNCE/generated/ecnf_biased_generated.npz
    endpoint_ot_ema_0_: /net/galaxy/home/koes/rishal/nce/BoltzNCE/BoltzNCE/generated/endpoint_ot_ema_0_numpy_dict.npz
    vector_ot_ema_5_: /net/galaxy/home/koes/rishal/nce/BoltzNCE/BoltzNCE/generated/endpoint_ot_0_numpy_dict.npz
    vector_ot_0_: /net/galaxy/home/koes/rishal/nce/BoltzNCE/BoltzNCE/generated/vector_ot_0_numpy_dict.npz
    vector_ot_ema_0_: /net/galaxy/home/koes/rishal/nce/BoltzNCE/BoltzNCE/generated/vector_ot_ema_0_numpy_dict.npz
  refere

## W2 Distance Functions

These functions are adapted from `infer_ad2.py` for calculating Wasserstein-2 distances.

In [3]:
def calc_energy_w2(gen_energies, ref_energies):
    """Calculate energy W2 distance between generated and reference samples.
    
    Args:
        gen_energies: Generated sample energies (numpy array)
        ref_energies: Reference sample energies (numpy array)
        
    Returns:
        float: W2 distance between energy distributions
    """
    # Flatten both to 1-D
    gen_energies = gen_energies.ravel()
    ref_energies = ref_energies.ravel()
    
    # Convert to numpy if torch tensors
    if hasattr(ref_energies, 'numpy'):
        ref_energies = ref_energies.numpy(force=True).ravel()
        
    # Calculate 1D Earth Mover's Distance (EMD) squared
    loss = ot.emd2_1d(gen_energies, ref_energies, metric="sqeuclidean")
    
    # Take sqrt to get W2
    w2_distance = np.sqrt(loss)
    
    print(f"Energy W2 distance: {w2_distance:.6f}")
    return w2_distance


def calc_torsion_w2(gen_angles, ref_angles):
    """Calculate torsion angle W2 distance between generated and reference samples.
    
    Args:
        gen_angles: Generated torsion angles (numpy array, shape: [n_samples, n_angles])
        ref_angles: Reference torsion angles (numpy array, shape: [n_samples, n_angles])
        
    Returns:
        float: W2 distance between torsion angle distributions
    """
    # Calculate distance matrix accounting for circular nature of angles
    dist = np.expand_dims(gen_angles, 0) - np.expand_dims(ref_angles, 1)
    dist = np.sum((dist % np.pi)**2, axis=-1)
    
    # Uniform weights for both distributions
    a, b = ot.unif(gen_angles.shape[0]), ot.unif(ref_angles.shape[0])
    
    # Calculate EMD with detailed log
    W, log = ot.emd2(a, b, dist, log=True, numItermax=int(1e9))
    
    # Take sqrt to get W2
    w2_distance = np.sqrt(W)
    
    print(f"Torsion W2 distance: {w2_distance:.6f}")
    return w2_distance


def reweight_samples_with_log_w(samples, energies, log_w, n_samples, label):
    """Reweight and resample data using log_w importance weights.
    
    Args:
        samples: Original samples (numpy array)
        energies: Original energies (numpy array) 
        log_w: Log importance weights (numpy array)
        n_samples: Number of samples to draw from reweighted distribution
        label: Label for logging
        
    Returns:
        tuple: (reweighted_samples, reweighted_energies) or (None, None) if failed
    """
    if samples is None or log_w is None:
        print(f"‚ö†Ô∏è  Cannot reweight {label} - missing samples or log_w")
        return None, None
        
    try:
        # Convert log_w to probabilities
        log_w_flat = log_w.ravel()
        
        # Handle numerical stability - subtract max before exp
        log_w_stable = log_w_flat - np.max(log_w_flat)
        weights = np.exp(log_w_stable)
        
        # Normalize to get probabilities
        probs = weights / np.sum(weights)
        
        # Check for valid probabilities
        if np.any(np.isnan(probs)) or np.any(np.isinf(probs)):
            print(f"‚ö†Ô∏è  Invalid probabilities for {label} - using uniform sampling")
            return None, None
            
        # Resample using importance weights
        n_available = len(samples)
        n_use = min(n_samples, n_available)
        
        indices = np.random.choice(n_available, n_use, replace=False, p=probs)
        
        reweighted_samples = samples[indices]
        reweighted_energies = energies[indices] if energies is not None else None
        
        print(f"‚úÖ {label} reweighted successfully:")
        print(f"   Original samples: {n_available}")
        print(f"   Reweighted samples: {n_use}")
        print(f"   Weight range: [{weights.min():.6f}, {weights.max():.6f}]")
        print(f"   Effective sample size: {1/np.sum(probs**2):.1f}")
        
        return reweighted_samples, reweighted_energies
        
    except Exception as e:
        print(f"‚ùå Error reweighting {label}: {e}")
        return None, None


def get_torsion_angles(samples_np):
    """Extract torsion angles from molecular samples.
    
    Args:
        samples: Molecular samples (numpy array, shape: [n_samples, 66])
        
    Returns:
        numpy array: Torsion angles (shape: [n_samples, 2]) for phi and psi angles
    """
    def determine_chirality_batch(cartesian_coords_batch):
        """Determine chirality for batch of coordinates."""
        coords_batch = np.array(cartesian_coords_batch)
        
        if coords_batch.shape[-2:] != (4, 3):
            raise ValueError("Input should be a batch of four 3D Cartesian coordinates")
        
        # Calculate vectors from chirality centers to connected atoms
        vectors_batch = coords_batch - coords_batch[:, 0:1, :]
        
        # Calculate normal vectors of planes
        normal_vectors_batch = np.cross(vectors_batch[:, 1, :], vectors_batch[:, 2, :])
        
        # Calculate dot products
        dot_products_batch = np.einsum('...i,...i->...', normal_vectors_batch, vectors_batch[:, 3, :])
        
        # Determine chirality labels
        chirality_labels_batch = np.where(dot_products_batch > 0.000, 'L', 'D')
        
        return chirality_labels_batch
    
    # Get atom types and identify carbon positions
    atom_types = get_alanine_atom_types()
    atom_types[[4, 6, 8, 14, 16]] = np.arange(4, 9)
    carbon_pos = np.where(atom_types == 1)[0]
    
    # Reshape samples to [n_samples, 22, 3]
    carbon_samples_np = samples_np.reshape(-1, 22, 3)[:, carbon_pos]
    carbon_distances = np.linalg.norm(samples_np.reshape(-1, 22, 3)[:, [8]] - carbon_samples_np, axis=-1)
    
    # Find C-beta atom index
    cb_idx = np.where(carbon_distances == carbon_distances.min(1, keepdims=True))
    
    # Get backbone and C-beta samples
    back_bone_samples = samples_np.reshape(-1, 22, 3)[:, np.array([8, 6, 14])]
    cb_samples = samples_np.reshape(-1, 22, 3)[cb_idx[0], carbon_pos[cb_idx[1]]][:, None, :]
    
    # Determine chirality and apply mapping
    chirality = determine_chirality_batch(np.concatenate([back_bone_samples, cb_samples], axis=1))
    samples_np_mapped = samples_np.copy()
    samples_np_mapped[chirality == "D"] *= -1
    
    # Create trajectory and compute dihedral angles
    dataset = get_alanine_implicit_dataset()
    traj_samples = md.Trajectory(samples_np_mapped.reshape(-1, 22, 3), topology=dataset.system.mdtraj_topology)
    
    # Define phi and psi dihedral indices
    phi_indices, psi_indices = [4, 6, 8, 14], [6, 8, 14, 16]
    angles = md.compute_dihedrals(traj_samples, [phi_indices, psi_indices])
    
    return angles

print("W2 distance and reweighting functions loaded successfully!")

W2 distance and reweighting functions loaded successfully!


## Load Data

In [4]:
# Set random seed for reproducibility
np.random.seed(CONFIG['random_seed'])
torch.manual_seed(CONFIG['random_seed'])

print("Loading generated data from multiple files...")

# Initialize storage for all generated data
all_gen_data = {}
gen_samples_list = []
gen_energies_list = []

# Load each generated data file
for i, (path, label) in enumerate(zip(CONFIG['generated_data_paths'], CONFIG['file_labels'])):
    print(f"\nLoading {label} from: {path}")
    try:
        gen_data = np.load(path)
        gen_samples = gen_data[CONFIG['samples_key']]
        gen_energies = gen_data[CONFIG['energies_key']]
        
        # Try to load log_w values if available
        log_w = None
        if 'log_w' in gen_data.files:
            log_w = gen_data['log_w']
            print(f"   log_w found with shape: {log_w.shape}")
        else:
            print(f"   ‚ö†Ô∏è  No log_w values found")
        
        # Store in dictionary for individual access
        all_gen_data[label] = {
            'samples': gen_samples,
            'energies': gen_energies,
            'log_w': log_w,
            'keys': list(gen_data.keys())
        }
        
        # Add to lists for combined analysis if needed
        gen_samples_list.append(gen_samples)
        gen_energies_list.append(gen_energies)
        
        print(f"‚úÖ {label} loaded successfully:")
        print(f"   Available keys: {list(gen_data.keys())}")
        print(f"   Samples shape: {gen_samples.shape}")
        print(f"   Energies shape: {gen_energies.shape}")
        print(f"   Energy range: [{gen_energies.min():.3f}, {gen_energies.max():.3f}]")
        if log_w is not None:
            print(f"   log_w shape: {log_w.shape}")
            print(f"   log_w range: [{log_w.min():.3f}, {log_w.max():.3f}]")
        
    except Exception as e:
        print(f"‚ùå Error loading {label}: {e}")
        all_gen_data[label] = None

print(f"\n‚úÖ Successfully loaded {len([x for x in all_gen_data.values() if x is not None])}/{len(CONFIG['generated_data_paths'])} files")

# Count files with log_w values
files_with_log_w = len([x for x in all_gen_data.values() if x is not None and x['log_w'] is not None])
print(f"üìä Files with log_w values: {files_with_log_w}/{len([x for x in all_gen_data.values() if x is not None])}")

# Load reference data (same as before)
print("\nLoading reference data...")
ref_samples = np.load("/net/galaxy/home/koes/rishal/nce/BoltzNCE/data/AD2_relaxed_holdout.npy")
ref_energies = np.load("/net/galaxy/home/koes/rishal/nce/BoltzNCE/BoltzNCE/generated/energies_data_holdout.npy")
energy_offset = 34600
ref_energies += energy_offset

print(f"‚úÖ Reference data loaded:")
print(f"   Samples shape: {ref_samples.shape}")
print(f"   Energies shape: {ref_energies.shape}")
print(f"   Energy range: [{ref_energies.min():.3f}, {ref_energies.max():.3f}]")

Loading generated data from multiple files...

Loading ebm_endpoint_MCMC0_0_ from: /net/galaxy/home/koes/rishal/nce/BoltzNCE/BoltzNCE/generated/ebm_endpoint_MCMC0_0_numpy_dict.npz
   log_w found with shape: (100000, 1)
‚úÖ ebm_endpoint_MCMC0_0_ loaded successfully:
   Available keys: ['samples', 'log_w', 'dlogf', 'energies']
   Samples shape: (100000, 66)
   Energies shape: (100000, 1)
   Energy range: [-112.575, 221.725]
   log_w shape: (100000, 1)
   log_w range: [-289.545, -1.701]

Loading Model_ebm_ot_vector_MCMC0_0_1 from: /net/galaxy/home/koes/rishal/nce/BoltzNCE/BoltzNCE/generated/ebm_ot_vector_MCMC0_0_numpy_dict.npz
   log_w found with shape: (100000, 1)
‚úÖ Model_ebm_ot_vector_MCMC0_0_1 loaded successfully:
   Available keys: ['samples', 'log_w', 'dlogf', 'energies']
   Samples shape: (100000, 66)
   Energies shape: (100000, 1)
   Energy range: [-109.144, 347.882]
   log_w shape: (100000, 1)
   log_w range: [-398.839, -4.295]

Loading ebm_vector_MCMC0_0_ from: /net/galaxy/home

In [5]:
## Sample Data for Analysis

def sample_data_for_analysis(samples, energies, n_samples, label):
    """Sample data for W2 analysis to manage computational cost."""
    if samples is None:
        return None, None
     
    print("energies", energies.shape if energies is not None else "None")    
    total_samples = len(samples)
    total_energies = len(energies) if energies is not None else 0
    n_use_energies = min(n_samples, total_energies) if energies is not None else 0
    n_use = min(n_samples, total_samples)
    
    print(f"\nüìä Sampling {label} data:")
    print(f"   Total available samples: {total_samples}")
    print(f"   Total available energies: {total_energies}")
    print(f"   Using for analysis: {n_use}")
    
    if n_use < total_samples:
        # Random sampling for samples
        indices = np.random.choice(total_samples, n_use, replace=False)
        sampled_samples = samples[indices]
        print(f"   Randomly sampled {n_use} from {total_samples} samples")
        
        # Random sampling for energies (handle different lengths)
        if energies is not None and total_energies > 0:
            if n_use_energies < total_energies:
                energies_indices = np.random.choice(total_energies, n_use_energies, replace=False)
                sampled_energies = energies[energies_indices]
                print(f"   Randomly sampled {n_use_energies} from {total_energies} energies")
            else:
                sampled_energies = energies
                print(f"   Using all {total_energies} energies")
        else:
            sampled_energies = None
            print("   No energies available")
    else:
        sampled_samples = samples
        sampled_energies = energies if energies is not None else None
        print(f"   Using all {n_use} samples")
        if energies is not None:
            print(f"   Using all {total_energies} energies")
        
    return sampled_samples, sampled_energies

In [6]:
# Sample reference data once for all comparisons
ref_samples_analysis, ref_energies_analysis = sample_data_for_analysis(
    ref_samples, ref_energies, CONFIG['n_samples_analysis'], "Reference"
)

# Sample data for each generated file multiple times (UNIFORM SAMPLING)
all_gen_data_analysis_runs = {}

print(f"\n{'='*70}")
print("UNIFORM SAMPLING DATA FOR W2 ANALYSIS - MULTIPLE RUNS")
print(f"{'='*70}")

for label in CONFIG['file_labels']:
    if all_gen_data[label] is not None:
        print(f"\nüìä Processing {label} - {CONFIG['n_runs']} runs...")
        
        gen_samples = all_gen_data[label]['samples']
        gen_energies = all_gen_data[label]['energies']
        
        runs_data = []
        for run_idx in range(CONFIG['n_runs']):
            print(f"   üîÑ Run {run_idx + 1}/{CONFIG['n_runs']}")
            
            # Set different seed for each run to sample different subsets
            np.random.seed(CONFIG['random_seed'] + run_idx * 1000)  # Use larger offset to ensure different samples
            
            sampled_samples, sampled_energies = sample_data_for_analysis(
                gen_samples, gen_energies, CONFIG['n_samples_analysis'], f"{label}_run_{run_idx+1}"
            )
            
            runs_data.append({
                'samples': sampled_samples,
                'energies': sampled_energies,
                'run_id': run_idx
            })
        
        all_gen_data_analysis_runs[label] = runs_data
        print(f"   ‚úÖ Completed {CONFIG['n_runs']} uniform sampling runs for {label}")
    else:
        print(f"\n‚ö†Ô∏è  Skipping {label} - data not loaded")
        all_gen_data_analysis_runs[label] = None

print(f"\n‚úÖ Uniform data sampling completed for all files!")

# REWEIGHTED SAMPLING using log_w values - MULTIPLE RUNS
all_gen_data_reweighted_runs = {}

print(f"\n{'='*70}")
print("REWEIGHTED SAMPLING DATA FOR W2 ANALYSIS - MULTIPLE RUNS") 
print(f"{'='*70}")

for label in CONFIG['file_labels']:
    if all_gen_data[label] is not None:
        print(f"\nüéØ Reweighting {label} - {CONFIG['n_runs']} runs...")
        
        gen_samples = all_gen_data[label]['samples']
        gen_energies = all_gen_data[label]['energies']  
        log_w = all_gen_data[label]['log_w']
        
        if log_w is not None:
            runs_data = []
            for run_idx in range(CONFIG['n_runs']):
                print(f"   üîÑ Run {run_idx + 1}/{CONFIG['n_runs']}")
                
                # Set different seed for each run to sample different weighted subsets
                np.random.seed(CONFIG['random_seed'] + run_idx * 1000)  # Use larger offset to ensure different samples
                
                reweighted_samples, reweighted_energies = reweight_samples_with_log_w(
                    gen_samples, gen_energies, log_w, CONFIG['n_samples_analysis'], f"{label}_run_{run_idx+1}"
                )
                
                runs_data.append({
                    'samples': reweighted_samples,
                    'energies': reweighted_energies,
                    'run_id': run_idx
                })
            
            all_gen_data_reweighted_runs[label] = runs_data
            print(f"   ‚úÖ Completed {CONFIG['n_runs']} reweighted sampling runs for {label}")
        else:
            print(f"‚ö†Ô∏è  No log_w values for {label} - skipping reweighting")
            all_gen_data_reweighted_runs[label] = None
    else:
        print(f"\n‚ö†Ô∏è  Skipping {label} - data not loaded")
        all_gen_data_reweighted_runs[label] = None

reweighted_files_count = len([x for x in all_gen_data_reweighted_runs.values() if x is not None])
print(f"\n‚úÖ Reweighted data sampling completed!")
print(f"üìä Successfully prepared {reweighted_files_count} files for reweighted analysis")

# Reset random seed for consistent analysis
np.random.seed(CONFIG['random_seed'])
print(f"üé≤ Note: Each run samples different subsets from the same file populations to capture sampling variance")

energies (11112, 1)

üìä Sampling Reference data:
   Total available samples: 100000
   Total available energies: 11112
   Using for analysis: 10000
   Randomly sampled 10000 from 100000 samples
   Randomly sampled 10000 from 11112 energies

UNIFORM SAMPLING DATA FOR W2 ANALYSIS - MULTIPLE RUNS

üìä Processing ebm_endpoint_MCMC0_0_ - 5 runs...
   üîÑ Run 1/5
energies (100000, 1)

üìä Sampling ebm_endpoint_MCMC0_0__run_1 data:
   Total available samples: 100000
   Total available energies: 100000
   Using for analysis: 10000
   Randomly sampled 10000 from 100000 samples
   Randomly sampled 10000 from 100000 energies
   üîÑ Run 2/5
energies (100000, 1)

üìä Sampling ebm_endpoint_MCMC0_0__run_2 data:
   Total available samples: 100000
   Total available energies: 100000
   Using for analysis: 10000
   Randomly sampled 10000 from 100000 samples
   Randomly sampled 10000 from 100000 energies
   üîÑ Run 3/5
energies (100000, 1)

üìä Sampling ebm_endpoint_MCMC0_0__run_3 data:
   Total

   Randomly sampled 10000 from 100000 samples
   Randomly sampled 10000 from 100000 energies
   üîÑ Run 4/5
energies (100000, 1)

üìä Sampling ebm_endpoint_MCMC0_0__run_4 data:
   Total available samples: 100000
   Total available energies: 100000
   Using for analysis: 10000
   Randomly sampled 10000 from 100000 samples
   Randomly sampled 10000 from 100000 energies
   üîÑ Run 5/5
energies (100000, 1)

üìä Sampling ebm_endpoint_MCMC0_0__run_5 data:
   Total available samples: 100000
   Total available energies: 100000
   Using for analysis: 10000
   Randomly sampled 10000 from 100000 samples
   Randomly sampled 10000 from 100000 energies
   ‚úÖ Completed 5 uniform sampling runs for ebm_endpoint_MCMC0_0_

üìä Processing Model_ebm_ot_vector_MCMC0_0_1 - 5 runs...
   üîÑ Run 1/5
energies (100000, 1)

üìä Sampling Model_ebm_ot_vector_MCMC0_0_1_run_1 data:
   Total available samples: 100000
   Total available energies: 100000
   Using for analysis: 10000
   Randomly sampled 10000 fro

## W2 Distance Calculations

In [7]:
# Initialize results dictionary for all files (UNIFORM ANALYSIS - MULTIPLE RUNS)
all_results_uniform_runs = {}

print("\n" + "="*80)
print("                W2 DISTANCE ANALYSIS - UNIFORM SAMPLING (MULTIPLE RUNS)")
print("="*80)

# Process each generated file with multiple runs
for label in CONFIG['file_labels']:
    if all_gen_data_analysis_runs[label] is not None:
        print(f"\nüîç ANALYZING {label.upper()} - {CONFIG['n_runs']} RUNS")
        print("="*60)
        
        runs_results = []
        
        for run_idx, run_data in enumerate(all_gen_data_analysis_runs[label]):
            print(f"\n   üìä Run {run_idx + 1}/{CONFIG['n_runs']}")
            
            gen_samples_analysis = run_data['samples']
            gen_energies_analysis = run_data['energies']
            
            run_result = {
                'run_id': run_idx,
                'energy_w2': None,
                'torsion_w2': None,
                'sample_count': len(gen_samples_analysis) if gen_samples_analysis is not None else 0
            }
            
            if gen_samples_analysis is not None and ref_samples_analysis is not None:
                
                # 1. Energy W2 Distance
                if gen_energies_analysis is not None and ref_energies_analysis is not None:
                    try:
                        energy_w2 = calc_energy_w2(gen_energies_analysis, ref_energies_analysis)
                        run_result['energy_w2'] = float(energy_w2)
                        print(f"      üîã Energy W2: {energy_w2:.6f}")
                    except Exception as e:
                        print(f"      ‚ùå Error calculating energy W2: {e}")
                
                # 2. Torsion W2 Distance  
                try:
                    gen_angles = get_torsion_angles(gen_samples_analysis)
                    
                    # Compute reference angles only once (if not already computed)
                    if 'ref_angles' not in locals():
                        print("      Computing torsion angles for reference samples...")
                        ref_angles = get_torsion_angles(ref_samples_analysis)
                        print(f"      Reference angles shape: {ref_angles.shape}")
                    
                    torsion_w2 = calc_torsion_w2(gen_angles, ref_angles)
                    run_result['torsion_w2'] = float(torsion_w2)
                    print(f"      üîÑ Torsion W2: {torsion_w2:.6f}")
                    
                except Exception as e:
                    print(f"      ‚ùå Error calculating torsion W2: {e}")
            
            runs_results.append(run_result)
        
        # Calculate statistics across runs
        energy_w2_values = [r['energy_w2'] for r in runs_results if r['energy_w2'] is not None]
        torsion_w2_values = [r['torsion_w2'] for r in runs_results if r['torsion_w2'] is not None]
        
        summary_result = {
            'config': CONFIG.copy(),
            'file_label': label,
            'analysis_type': 'uniform_multiple_runs',
            'n_runs': CONFIG['n_runs'],
            'individual_runs': runs_results,
            'statistics': {
                'energy_w2': {
                    'mean': float(np.mean(energy_w2_values)) if energy_w2_values else None,
                    'std': float(np.std(energy_w2_values)) if len(energy_w2_values) > 1 else 0.0,
                    'values': energy_w2_values
                },
                'torsion_w2': {
                    'mean': float(np.mean(torsion_w2_values)) if torsion_w2_values else None,
                    'std': float(np.std(torsion_w2_values)) if len(torsion_w2_values) > 1 else 0.0,
                    'values': torsion_w2_values
                }
            }
        }
        
        if energy_w2_values and torsion_w2_values:
            print(f"\n   üìà {label} Statistics:")
            print(f"      Energy W2:  {summary_result['statistics']['energy_w2']['mean']:.6f} ¬± {summary_result['statistics']['energy_w2']['std']:.6f}")
            print(f"      Torsion W2: {summary_result['statistics']['torsion_w2']['mean']:.6f} ¬± {summary_result['statistics']['torsion_w2']['std']:.6f}")
        
        all_results_uniform_runs[label] = summary_result
        
    else:
        print(f"\n‚ö†Ô∏è  Skipping {label} - data not available")
        all_results_uniform_runs[label] = None

print("\n" + "="*80)
print("UNIFORM W2 ANALYSIS COMPLETED FOR ALL FILES")
print("="*80)

# Initialize results dictionary for reweighted analysis (MULTIPLE RUNS)
all_results_reweighted_runs = {}

print("\n" + "="*80)
print("                W2 DISTANCE ANALYSIS - REWEIGHTED SAMPLING (MULTIPLE RUNS)")
print("="*80)

# Process each reweighted generated file with multiple runs
for label in CONFIG['file_labels']:
    if all_gen_data_reweighted_runs[label] is not None:
        print(f"\nüéØ ANALYZING REWEIGHTED {label.upper()} - {CONFIG['n_runs']} RUNS")
        print("="*60)
        
        runs_results = []
        
        for run_idx, run_data in enumerate(all_gen_data_reweighted_runs[label]):
            print(f"\n   üìä Run {run_idx + 1}/{CONFIG['n_runs']}")
            
            gen_samples_reweighted = run_data['samples']
            gen_energies_reweighted = run_data['energies']
            
            run_result = {
                'run_id': run_idx,
                'energy_w2': None,
                'torsion_w2': None,
                'sample_count': len(gen_samples_reweighted) if gen_samples_reweighted is not None else 0
            }
            
            if gen_samples_reweighted is not None and ref_samples_analysis is not None:
                
                # 1. Energy W2 Distance (Reweighted)
                if gen_energies_reweighted is not None and ref_energies_analysis is not None:
                    try:
                        energy_w2_reweighted = calc_energy_w2(gen_energies_reweighted, ref_energies_analysis)
                        run_result['energy_w2'] = float(energy_w2_reweighted)
                        print(f"      üîã Reweighted Energy W2: {energy_w2_reweighted:.6f}")
                    except Exception as e:
                        print(f"      ‚ùå Error calculating reweighted energy W2: {e}")
                
                # 2. Torsion W2 Distance (Reweighted)
                try:
                    gen_angles_reweighted = get_torsion_angles(gen_samples_reweighted)
                    torsion_w2_reweighted = calc_torsion_w2(gen_angles_reweighted, ref_angles)
                    run_result['torsion_w2'] = float(torsion_w2_reweighted)
                    print(f"      üîÑ Reweighted Torsion W2: {torsion_w2_reweighted:.6f}")
                    
                except Exception as e:
                    print(f"      ‚ùå Error calculating reweighted torsion W2: {e}")
            
            runs_results.append(run_result)
        
        # Calculate statistics across runs
        energy_w2_values = [r['energy_w2'] for r in runs_results if r['energy_w2'] is not None]
        torsion_w2_values = [r['torsion_w2'] for r in runs_results if r['torsion_w2'] is not None]
        
        summary_result = {
            'config': CONFIG.copy(),
            'file_label': label,
            'analysis_type': 'reweighted_multiple_runs',
            'n_runs': CONFIG['n_runs'],
            'individual_runs': runs_results,
            'statistics': {
                'energy_w2': {
                    'mean': float(np.mean(energy_w2_values)) if energy_w2_values else None,
                    'std': float(np.std(energy_w2_values)) if len(energy_w2_values) > 1 else 0.0,
                    'values': energy_w2_values
                },
                'torsion_w2': {
                    'mean': float(np.mean(torsion_w2_values)) if torsion_w2_values else None,
                    'std': float(np.std(torsion_w2_values)) if len(torsion_w2_values) > 1 else 0.0,
                    'values': torsion_w2_values
                }
            }
        }
        
        if energy_w2_values and torsion_w2_values:
            print(f"\n   üìà {label} Reweighted Statistics:")
            print(f"      Energy W2:  {summary_result['statistics']['energy_w2']['mean']:.6f} ¬± {summary_result['statistics']['energy_w2']['std']:.6f}")
            print(f"      Torsion W2: {summary_result['statistics']['torsion_w2']['mean']:.6f} ¬± {summary_result['statistics']['torsion_w2']['std']:.6f}")
        
        all_results_reweighted_runs[label] = summary_result
        
    else:
        print(f"\n‚ö†Ô∏è  Skipping {label} - reweighted data not available")
        all_results_reweighted_runs[label] = None

print("\n" + "="*80)
print("REWEIGHTED W2 ANALYSIS COMPLETED FOR ALL FILES")
print("="*80)


                W2 DISTANCE ANALYSIS - UNIFORM SAMPLING (MULTIPLE RUNS)

üîç ANALYZING EBM_ENDPOINT_MCMC0_0_ - 5 RUNS

   üìä Run 1/5
Energy W2 distance: 6.091695
      üîã Energy W2: 6.091695
Using downloaded and verified file: /tmp/A.pdb
Using downloaded and verified file: /tmp/A.pdb
      Computing torsion angles for reference samples...
Using downloaded and verified file: /tmp/A.pdb
Using downloaded and verified file: /tmp/A.pdb
      Reference angles shape: (10000, 2)
Torsion W2 distance: 1.112463
      üîÑ Torsion W2: 1.112463

   üìä Run 2/5
Energy W2 distance: 5.683302
      üîã Energy W2: 5.683302
Using downloaded and verified file: /tmp/A.pdb
Using downloaded and verified file: /tmp/A.pdb
Torsion W2 distance: 1.128744
      üîÑ Torsion W2: 1.128744

   üìä Run 3/5
Energy W2 distance: 6.936323
      üîã Energy W2: 6.936323
Using downloaded and verified file: /tmp/A.pdb
Using downloaded and verified file: /tmp/A.pdb
Torsion W2 distance: 1.121568
      üîÑ Torsion W2:

## Results Summary

In [8]:
print("\nüìä W2 DISTANCE RESULTS SUMMARY - UNIFORM vs REWEIGHTED (MULTIPLE RUNS):")
print("=" * 100)

# Create a comprehensive comparison table with statistics
print(f"\n{'File':<20} {'Uniform Energy W2 (Mean¬±Std)':<30} {'Reweighted Energy W2 (Mean¬±Std)':<32} {'Uniform Torsion W2 (Mean¬±Std)':<31} {'Reweighted Torsion W2 (Mean¬±Std)':<33} {'Status'}")
print("-" * 150)

valid_uniform_results = []
valid_reweighted_results = []
comparison_results = []

for label in CONFIG['file_labels']:
    # Get uniform results
    uniform_energy_mean = uniform_energy_std = uniform_torsion_mean = uniform_torsion_std = None
    if all_results_uniform_runs[label] is not None:
        uniform_energy_mean = all_results_uniform_runs[label]['statistics']['energy_w2']['mean']
        uniform_energy_std = all_results_uniform_runs[label]['statistics']['energy_w2']['std']
        uniform_torsion_mean = all_results_uniform_runs[label]['statistics']['torsion_w2']['mean']
        uniform_torsion_std = all_results_uniform_runs[label]['statistics']['torsion_w2']['std']
        
        if uniform_energy_mean is not None and uniform_torsion_mean is not None:
            valid_uniform_results.append((label, uniform_energy_mean, uniform_energy_std, uniform_torsion_mean, uniform_torsion_std))
    
    # Get reweighted results
    reweighted_energy_mean = reweighted_energy_std = reweighted_torsion_mean = reweighted_torsion_std = None
    if all_results_reweighted_runs[label] is not None:
        reweighted_energy_mean = all_results_reweighted_runs[label]['statistics']['energy_w2']['mean']
        reweighted_energy_std = all_results_reweighted_runs[label]['statistics']['energy_w2']['std']
        reweighted_torsion_mean = all_results_reweighted_runs[label]['statistics']['torsion_w2']['mean']
        reweighted_torsion_std = all_results_reweighted_runs[label]['statistics']['torsion_w2']['std']
        
        if reweighted_energy_mean is not None and reweighted_torsion_mean is not None:
            valid_reweighted_results.append((label, reweighted_energy_mean, reweighted_energy_std, reweighted_torsion_mean, reweighted_torsion_std))
    
    # Format values for display
    uniform_energy_str = f"{uniform_energy_mean:.4f}¬±{uniform_energy_std:.4f}" if uniform_energy_mean is not None else "N/A"
    reweighted_energy_str = f"{reweighted_energy_mean:.4f}¬±{reweighted_energy_std:.4f}" if reweighted_energy_mean is not None else "N/A"
    uniform_torsion_str = f"{uniform_torsion_mean:.4f}¬±{uniform_torsion_std:.4f}" if uniform_torsion_mean is not None else "N/A"
    reweighted_torsion_str = f"{reweighted_torsion_mean:.4f}¬±{reweighted_torsion_std:.4f}" if reweighted_torsion_mean is not None else "N/A"
    
    # Determine status
    has_uniform = uniform_energy_mean is not None and uniform_torsion_mean is not None
    has_reweighted = reweighted_energy_mean is not None and reweighted_torsion_mean is not None
    
    if has_uniform and has_reweighted:
        status = "‚úÖ Both"
        comparison_results.append({
            'label': label,
            'uniform_energy_mean': uniform_energy_mean,
            'uniform_energy_std': uniform_energy_std,
            'reweighted_energy_mean': reweighted_energy_mean,
            'reweighted_energy_std': reweighted_energy_std,
            'uniform_torsion_mean': uniform_torsion_mean,
            'uniform_torsion_std': uniform_torsion_std,
            'reweighted_torsion_mean': reweighted_torsion_mean,
            'reweighted_torsion_std': reweighted_torsion_std
        })
    elif has_uniform:
        status = "üîµ Uniform only"
    elif has_reweighted:
        status = "üü° Reweighted only"
    else:
        status = "‚ùå Failed"
    
    print(f"{label:<20} {uniform_energy_str:<30} {reweighted_energy_str:<32} {uniform_torsion_str:<31} {reweighted_torsion_str:<33} {status}")

print(f"\nüìã Analysis Configuration:")
print(f"   Reference samples: {len(ref_samples_analysis) if ref_samples_analysis is not None else 'N/A'}")
print(f"   Generated samples per file per run: {CONFIG['n_samples_analysis']}")
print(f"   Number of runs per file: {CONFIG['n_runs']}")
print(f"   Random seed base: {CONFIG['random_seed']} (runs use seeds: {CONFIG['random_seed']}, {CONFIG['random_seed']+1000}, {CONFIG['random_seed']+2000})")

# Statistics and comparisons for uniform results
if valid_uniform_results:
    print(f"\nüî¢ UNIFORM SAMPLING STATISTICS (ACROSS {CONFIG['n_runs']} RUNS):")
    uniform_energy_means = [x[1] for x in valid_uniform_results]
    uniform_energy_stds = [x[2] for x in valid_uniform_results]
    uniform_torsion_means = [x[3] for x in valid_uniform_results]
    uniform_torsion_stds = [x[4] for x in valid_uniform_results]
    
    print(f"   Energy W2 - Mean across models: {np.mean(uniform_energy_means):.6f}")
    print(f"   Energy W2 - Average std within models: {np.mean(uniform_energy_stds):.6f}")
    print(f"   Torsion W2 - Mean across models: {np.mean(uniform_torsion_means):.6f}")
    print(f"   Torsion W2 - Average std within models: {np.mean(uniform_torsion_stds):.6f}")
    
    best_uniform_energy_idx = np.argmin(uniform_energy_means)
    best_uniform_torsion_idx = np.argmin(uniform_torsion_means)
    print(f"   üèÜ Best Energy W2: {valid_uniform_results[best_uniform_energy_idx][0]} ({uniform_energy_means[best_uniform_energy_idx]:.6f}¬±{uniform_energy_stds[best_uniform_energy_idx]:.6f})")
    print(f"   üèÜ Best Torsion W2: {valid_uniform_results[best_uniform_torsion_idx][0]} ({uniform_torsion_means[best_uniform_torsion_idx]:.6f}¬±{uniform_torsion_stds[best_uniform_torsion_idx]:.6f})")

# Statistics and comparisons for reweighted results
if valid_reweighted_results:
    print(f"\nüéØ REWEIGHTED SAMPLING STATISTICS (ACROSS {CONFIG['n_runs']} RUNS):")
    reweighted_energy_means = [x[1] for x in valid_reweighted_results]
    reweighted_energy_stds = [x[2] for x in valid_reweighted_results]
    reweighted_torsion_means = [x[3] for x in valid_reweighted_results]
    reweighted_torsion_stds = [x[4] for x in valid_reweighted_results]
    
    print(f"   Energy W2 - Mean across models: {np.mean(reweighted_energy_means):.6f}")
    print(f"   Energy W2 - Average std within models: {np.mean(reweighted_energy_stds):.6f}")
    print(f"   Torsion W2 - Mean across models: {np.mean(reweighted_torsion_means):.6f}")
    print(f"   Torsion W2 - Average std within models: {np.mean(reweighted_torsion_stds):.6f}")
    
    best_reweighted_energy_idx = np.argmin(reweighted_energy_means)
    best_reweighted_torsion_idx = np.argmin(reweighted_torsion_means)
    print(f"   üèÜ Best Energy W2: {valid_reweighted_results[best_reweighted_energy_idx][0]} ({reweighted_energy_means[best_reweighted_energy_idx]:.6f}¬±{reweighted_energy_stds[best_reweighted_energy_idx]:.6f})")
    print(f"   üèÜ Best Torsion W2: {valid_reweighted_results[best_reweighted_torsion_idx][0]} ({reweighted_torsion_means[best_reweighted_torsion_idx]:.6f}¬±{reweighted_torsion_stds[best_reweighted_torsion_idx]:.6f})")

# Direct comparison for models with both uniform and reweighted results
if comparison_results:
    print(f"\n‚öñÔ∏è  UNIFORM vs REWEIGHTED COMPARISON:")
    print(f"   Models with both analyses: {len(comparison_results)}")
    
    energy_improvements = []
    torsion_improvements = []
    
    print(f"\n   {'Model':<20} {'Energy Œî (Mean)':<15} {'Torsion Œî (Mean)':<16} {'Energy Status':<15} {'Torsion Status'}")
    print("   " + "-" * 85)
    
    for comp in comparison_results:
        energy_delta = comp['uniform_energy_mean'] - comp['reweighted_energy_mean']  # Positive = improvement
        torsion_delta = comp['uniform_torsion_mean'] - comp['reweighted_torsion_mean']  # Positive = improvement
        
        energy_improvements.append(energy_delta)
        torsion_improvements.append(torsion_delta)
        
        energy_status = "üü¢ Better" if energy_delta > 0 else "üî¥ Worse" if energy_delta < 0 else "üü° Same"
        torsion_status = "üü¢ Better" if torsion_delta > 0 else "üî¥ Worse" if torsion_delta < 0 else "üü° Same"
        
        print(f"   {comp['label']:<20} {energy_delta:+.6f}       {torsion_delta:+.6f}        {energy_status:<15} {torsion_status}")
    
    print(f"\n   üìà Overall Reweighting Impact:")
    avg_energy_improvement = np.mean(energy_improvements)
    avg_torsion_improvement = np.mean(torsion_improvements)
    
    print(f"      Energy W2 - Average change: {avg_energy_improvement:+.6f} ({'improvement' if avg_energy_improvement > 0 else 'degradation'})")
    print(f"      Torsion W2 - Average change: {avg_torsion_improvement:+.6f} ({'improvement' if avg_torsion_improvement > 0 else 'degradation'})")
    
    energy_better_count = sum(1 for x in energy_improvements if x > 0)
    torsion_better_count = sum(1 for x in torsion_improvements if x > 0)
    
    print(f"      Models improved by reweighting - Energy: {energy_better_count}/{len(comparison_results)}, Torsion: {torsion_better_count}/{len(comparison_results)}")
    
    # Statistical significance of improvements
    if len(comparison_results) > 1:
        print(f"\n   üìä Statistical Analysis:")
        energy_improvement_std = np.std(energy_improvements)
        torsion_improvement_std = np.std(torsion_improvements)
        print(f"      Energy improvement std: {energy_improvement_std:.6f}")
        print(f"      Torsion improvement std: {torsion_improvement_std:.6f}")

print(f"\nüí° Interpretation:")
print(f"   üìà Lower W2 distances indicate better similarity to reference")
print(f"   üîã Energy W2 measures thermodynamic similarity")
print(f"   üîÑ Torsion W2 measures conformational similarity")
print(f"   üéØ Reweighting uses log_w importance weights to emphasize higher-probability samples")
print(f"   ‚öñÔ∏è  Positive Œî values indicate reweighting improved the W2 distance")
print(f"   üìä Standard deviations show sampling variance across {CONFIG['n_runs']} runs")
print(f"   üé≤ Each run samples different subsets from the same file populations (not just different seeds)")
print(f"   üî¨ Small std values indicate robust, repeatable results independent of sample selection")


üìä W2 DISTANCE RESULTS SUMMARY - UNIFORM vs REWEIGHTED (MULTIPLE RUNS):

File                 Uniform Energy W2 (Mean¬±Std)   Reweighted Energy W2 (Mean¬±Std)  Uniform Torsion W2 (Mean¬±Std)   Reweighted Torsion W2 (Mean¬±Std)  Status
------------------------------------------------------------------------------------------------------------------------------------------------------
ebm_endpoint_MCMC0_0_ 6.2825¬±0.4228                  2.7786¬±0.0283                    1.1210¬±0.0067                   0.5865¬±0.0059                     ‚úÖ Both
Model_ebm_ot_vector_MCMC0_0_1 7.3496¬±0.1936                  0.2249¬±0.0306                    1.1122¬±0.0152                   0.5562¬±0.0049                     ‚úÖ Both
ebm_vector_MCMC0_0_  7.2504¬±0.2034                  0.2674¬±0.0149                    1.1177¬±0.0115                   0.5631¬±0.0115                     ‚úÖ Both
ecnf_biased_generated 8.3433¬±0.5447                  0.3533¬±0.0263                    1.0963¬±0.0062       

## Optional: Create Comparison Plots

In [9]:
if CONFIG['create_plots']:
    
    print("\nüìà Creating comparison plots for all files...")
    
    # Count valid results for plot arrangement
    valid_files = [label for label in CONFIG['file_labels'] if all_results[label] is not None and all_results[label]['energy_w2'] is not None]
    n_files = len(valid_files)
    
    if n_files > 0:
        # Create figure with subplots for energy and torsion comparisons
        fig, axes = plt.subplots(2, 3, figsize=(20, 12))
        fig.suptitle('W2 Distance Analysis - Multiple Files Comparison', fontsize=16, fontweight='bold')
        
        # Define colors for each file
        colors = ['blue', 'red', 'green', 'orange', 'purple', 'brown']
        
        # Plot 1: Energy distributions comparison
        ax_energy = axes[0, 0]
        if ref_energies_analysis is not None:
            ax_energy.hist(ref_energies_analysis.flatten(), bins=50, alpha=0.7, 
                          label='Reference', density=True, color='black', linewidth=2, histtype='step')
            
            for i, label in enumerate(valid_files):
                if all_gen_data_analysis[label]['energies'] is not None:
                    energies = all_gen_data_analysis[label]['energies']
                    energy_w2 = all_results[label]['energy_w2']
                    ax_energy.hist(energies.flatten(), bins=50, alpha=0.6, 
                                  label=f'{label} (W2={energy_w2:.3f})', 
                                  density=True, color=colors[i % len(colors)])
        
        ax_energy.set_xlabel('Energy')
        ax_energy.set_ylabel('Density')
        ax_energy.set_title('Energy Distributions Comparison')
        ax_energy.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
        ax_energy.grid(True, alpha=0.3)
        
        # Plot 2: Torsion angles comparison (Ramachandran plots)
        ax_torsion = axes[0, 1]
        if 'reference_angles' in all_results:
            ref_angles = all_results['reference_angles']
            ax_torsion.scatter(ref_angles[:, 0], ref_angles[:, 1], 
                              alpha=0.3, s=0.5, label='Reference', color='black')
            
            for i, label in enumerate(valid_files):
                if all_results[label]['gen_angles'] is not None:
                    gen_angles = all_results[label]['gen_angles']
                    torsion_w2 = all_results[label]['torsion_w2']
                    ax_torsion.scatter(gen_angles[:, 0], gen_angles[:, 1], 
                                      alpha=0.4, s=0.5, 
                                      label=f'{label} (W2={torsion_w2:.3f})', 
                                      color=colors[i % len(colors)])
        
        ax_torsion.set_xlabel('œÜ (phi) angle')
        ax_torsion.set_ylabel('œà (psi) angle')
        ax_torsion.set_title('Ramachandran Plot Comparison')
        ax_torsion.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
        ax_torsion.grid(True, alpha=0.3)
        ax_torsion.set_xlim([-np.pi, np.pi])
        ax_torsion.set_ylim([-np.pi, np.pi])
        
        # Plot 3: W2 Distance Comparison Bar Chart
        ax_bar = axes[0, 2]
        if valid_results:
            file_names = [x[0] for x in valid_results]
            energy_w2_vals = [x[1] for x in valid_results]
            torsion_w2_vals = [x[2] for x in valid_results]
            
            x = np.arange(len(file_names))
            width = 0.35
            
            bars1 = ax_bar.bar(x - width/2, energy_w2_vals, width, label='Energy W2', alpha=0.8)
            bars2 = ax_bar.bar(x + width/2, torsion_w2_vals, width, label='Torsion W2', alpha=0.8)
            
            ax_bar.set_xlabel('Files')
            ax_bar.set_ylabel('W2 Distance')
            ax_bar.set_title('W2 Distances Comparison')
            ax_bar.set_xticks(x)
            ax_bar.set_xticklabels(file_names, rotation=45)
            ax_bar.legend()
            ax_bar.grid(True, alpha=0.3)
            
            # Add value labels on bars
            for bar in bars1:
                height = bar.get_height()
                ax_bar.annotate(f'{height:.3f}', xy=(bar.get_x() + bar.get_width()/2, height),
                               xytext=(0, 3), textcoords="offset points", ha='center', va='bottom', fontsize=8)
            for bar in bars2:
                height = bar.get_height()
                ax_bar.annotate(f'{height:.3f}', xy=(bar.get_x() + bar.get_width()/2, height),
                               xytext=(0, 3), textcoords="offset points", ha='center', va='bottom', fontsize=8)
        
        # Individual Ramachandran plots for first 3 models
        for i, label in enumerate(valid_files[:3]):
            ax = axes[1, i]
            
            if 'reference_angles' in all_results and all_results[label]['gen_angles'] is not None:
                ref_angles = all_results['reference_angles']
                gen_angles = all_results[label]['gen_angles']
                torsion_w2 = all_results[label]['torsion_w2']
                
                ax.scatter(ref_angles[:, 0], ref_angles[:, 1], 
                          alpha=0.5, s=1, label='Reference', color='red')
                ax.scatter(gen_angles[:, 0], gen_angles[:, 1], 
                          alpha=0.5, s=1, label=label, color=colors[i % len(colors)])
                
                ax.set_xlabel('œÜ (phi) angle')
                ax.set_ylabel('œà (psi) angle')
                ax.set_title(f'{label} vs Reference\\n(Torsion W2 = {torsion_w2:.4f})')
                ax.legend()
                ax.grid(True, alpha=0.3)
                ax.set_xlim([-np.pi, np.pi])
                ax.set_ylim([-np.pi, np.pi])
            else:
                ax.text(0.5, 0.5, f'{label}\\nData not available', 
                       ha='center', va='center', transform=ax.transAxes, fontsize=12)
                ax.set_title(f'{label}')
        
        # Hide unused subplots
        for i in range(len(valid_files[:3]), 3):
            axes[1, i].set_visible(False)
        
        plt.tight_layout()
        
        # Save plot if requested
        if CONFIG['save_results']:
            output_dir = Path(CONFIG['output_dir'])
            output_dir.mkdir(parents=True, exist_ok=True)
            plot_path = output_dir / 'w2_comparison_plots_all_files.png'
            plt.savefig(plot_path, dpi=300, bbox_inches='tight')
            print(f"üìä Plot saved to: {plot_path}")
        
        plt.show()
        
    else:
        print("‚ö†Ô∏è  No valid data available for plotting")
        
else:
    print("‚ö†Ô∏è  Plotting disabled in configuration")


üìà Creating comparison plots for all files...


NameError: name 'all_results' is not defined

## Save Results

In [None]:
if CONFIG['save_results']:
    
    print("\nüíæ Saving results for all files (uniform and reweighted with multiple runs)...")
    
    # Create output directory
    output_dir = Path(CONFIG['output_dir'])
    output_dir.mkdir(parents=True, exist_ok=True)
    
    # Save individual UNIFORM results for each file (with multiple runs)
    print("\nüìÅ Saving uniform analysis results (multiple runs)...")
    for label in CONFIG['file_labels']:
        if all_results_uniform_runs[label] is not None:
            # Add metadata to results
            result = all_results_uniform_runs[label].copy()
            result['metadata'] = {
                'analysis_type': 'W2_distance_comparison_uniform_multiple_runs',
                'file_label': label,
                'generated_data_path': CONFIG['generated_data_paths'][CONFIG['file_labels'].index(label)],
                'reference_data_path': CONFIG['reference_data_path'],
                'n_runs': CONFIG['n_runs'],
                'completed_successfully': result['statistics']['energy_w2']['mean'] is not None or result['statistics']['torsion_w2']['mean'] is not None
            }
            
            # Save as JSON
            results_path = output_dir / f'w2_distances_uniform_multirun_{label}.json'
            with open(results_path, 'w') as f:
                json.dump(result, f, indent=2)
            print(f"üìÑ {label} uniform multi-run results saved to: {results_path}")
    
    # Save individual REWEIGHTED results for each file (with multiple runs)
    print("\nüìÅ Saving reweighted analysis results (multiple runs)...")
    for label in CONFIG['file_labels']:
        if all_results_reweighted_runs[label] is not None:
            # Add metadata to results
            result = all_results_reweighted_runs[label].copy()
            result['metadata'] = {
                'analysis_type': 'W2_distance_comparison_reweighted_multiple_runs',
                'file_label': label,
                'generated_data_path': CONFIG['generated_data_paths'][CONFIG['file_labels'].index(label)],
                'reference_data_path': CONFIG['reference_data_path'],
                'n_runs': CONFIG['n_runs'],
                'completed_successfully': result['statistics']['energy_w2']['mean'] is not None or result['statistics']['torsion_w2']['mean'] is not None
            }
            
            # Save as JSON
            results_path = output_dir / f'w2_distances_reweighted_multirun_{label}.json'
            with open(results_path, 'w') as f:
                json.dump(result, f, indent=2)
            print(f"üìÑ {label} reweighted multi-run results saved to: {results_path}")
    
    # Save COMPREHENSIVE SUMMARY results
    print("\nüìÅ Saving comprehensive summary results...")
    summary_results = {
        'config': CONFIG.copy(),
        'analysis_type': 'W2_distance_comparison_comprehensive_multiple_runs_summary',
        'analysis_metadata': {
            'n_runs_per_file': CONFIG['n_runs'],
            'n_samples_per_run': CONFIG['n_samples_analysis'],
            'total_comparisons_per_file': CONFIG['n_runs'],
            'random_seed_base': CONFIG['random_seed']
        },
        'uniform_results_summary': {},
        'reweighted_results_summary': {},
        'comparison_summary': {}
    }
    
    # Add uniform results summary with statistics
    for label in CONFIG['file_labels']:
        if all_results_uniform_runs[label] is not None:
            stats = all_results_uniform_runs[label]['statistics']
            summary_results['uniform_results_summary'][label] = {
                'energy_w2_mean': stats['energy_w2']['mean'],
                'energy_w2_std': stats['energy_w2']['std'],
                'energy_w2_values': stats['energy_w2']['values'],
                'torsion_w2_mean': stats['torsion_w2']['mean'],
                'torsion_w2_std': stats['torsion_w2']['std'],
                'torsion_w2_values': stats['torsion_w2']['values'],
                'n_successful_runs': len(stats['energy_w2']['values']) if stats['energy_w2']['values'] else 0
            }
        else:
            summary_results['uniform_results_summary'][label] = None
    
    # Add reweighted results summary with statistics
    for label in CONFIG['file_labels']:
        if all_results_reweighted_runs[label] is not None:
            stats = all_results_reweighted_runs[label]['statistics']
            summary_results['reweighted_results_summary'][label] = {
                'energy_w2_mean': stats['energy_w2']['mean'],
                'energy_w2_std': stats['energy_w2']['std'],
                'energy_w2_values': stats['energy_w2']['values'],
                'torsion_w2_mean': stats['torsion_w2']['mean'],
                'torsion_w2_std': stats['torsion_w2']['std'],
                'torsion_w2_values': stats['torsion_w2']['values'],
                'n_successful_runs': len(stats['energy_w2']['values']) if stats['energy_w2']['values'] else 0
            }
        else:
            summary_results['reweighted_results_summary'][label] = None
    
    # Calculate comprehensive statistics
    valid_uniform_summary = [(k, v) for k, v in summary_results['uniform_results_summary'].items() if v is not None and v['energy_w2_mean'] is not None]
    valid_reweighted_summary = [(k, v) for k, v in summary_results['reweighted_results_summary'].items() if v is not None and v['energy_w2_mean'] is not None]
    
    # Uniform statistics across all models
    if valid_uniform_summary:
        uniform_energy_means = [v['energy_w2_mean'] for k, v in valid_uniform_summary]
        uniform_energy_stds = [v['energy_w2_std'] for k, v in valid_uniform_summary]
        uniform_torsion_means = [v['torsion_w2_mean'] for k, v in valid_uniform_summary]
        uniform_torsion_stds = [v['torsion_w2_std'] for k, v in valid_uniform_summary]
        
        summary_results['uniform_aggregate_stats'] = {
            'energy_w2': {
                'mean_of_means': float(np.mean(uniform_energy_means)),
                'std_of_means': float(np.std(uniform_energy_means)),
                'mean_of_stds': float(np.mean(uniform_energy_stds)),
                'best_model': min(valid_uniform_summary, key=lambda x: x[1]['energy_w2_mean'])[0],
                'best_value': float(min(uniform_energy_means))
            },
            'torsion_w2': {
                'mean_of_means': float(np.mean(uniform_torsion_means)),
                'std_of_means': float(np.std(uniform_torsion_means)),
                'mean_of_stds': float(np.mean(uniform_torsion_stds)),
                'best_model': min(valid_uniform_summary, key=lambda x: x[1]['torsion_w2_mean'])[0],
                'best_value': float(min(uniform_torsion_means))
            }
        }
    
    # Reweighted statistics across all models
    if valid_reweighted_summary:
        reweighted_energy_means = [v['energy_w2_mean'] for k, v in valid_reweighted_summary]
        reweighted_energy_stds = [v['energy_w2_std'] for k, v in valid_reweighted_summary]
        reweighted_torsion_means = [v['torsion_w2_mean'] for k, v in valid_reweighted_summary]
        reweighted_torsion_stds = [v['torsion_w2_std'] for k, v in valid_reweighted_summary]
        
        summary_results['reweighted_aggregate_stats'] = {
            'energy_w2': {
                'mean_of_means': float(np.mean(reweighted_energy_means)),
                'std_of_means': float(np.std(reweighted_energy_means)),
                'mean_of_stds': float(np.mean(reweighted_energy_stds)),
                'best_model': min(valid_reweighted_summary, key=lambda x: x[1]['energy_w2_mean'])[0],
                'best_value': float(min(reweighted_energy_means))
            },
            'torsion_w2': {
                'mean_of_means': float(np.mean(reweighted_torsion_means)),
                'std_of_means': float(np.std(reweighted_torsion_means)),
                'mean_of_stds': float(np.mean(reweighted_torsion_stds)),
                'best_model': min(valid_reweighted_summary, key=lambda x: x[1]['torsion_w2_mean'])[0],
                'best_value': float(min(reweighted_torsion_means))
            }
        }
    
    # Comparison statistics for models with both uniform and reweighted results
    comparison_data = []
    for label in CONFIG['file_labels']:
        uniform_result = summary_results['uniform_results_summary'].get(label)
        reweighted_result = summary_results['reweighted_results_summary'].get(label)
        
        if (uniform_result is not None and reweighted_result is not None and 
            uniform_result['energy_w2_mean'] is not None and reweighted_result['energy_w2_mean'] is not None and
            uniform_result['torsion_w2_mean'] is not None and reweighted_result['torsion_w2_mean'] is not None):
            
            energy_improvement = uniform_result['energy_w2_mean'] - reweighted_result['energy_w2_mean']
            torsion_improvement = uniform_result['torsion_w2_mean'] - reweighted_result['torsion_w2_mean']
            
            comparison_data.append({
                'model': label,
                'uniform_energy_w2_mean': uniform_result['energy_w2_mean'],
                'uniform_energy_w2_std': uniform_result['energy_w2_std'],
                'reweighted_energy_w2_mean': reweighted_result['energy_w2_mean'],
                'reweighted_energy_w2_std': reweighted_result['energy_w2_std'],
                'uniform_torsion_w2_mean': uniform_result['torsion_w2_mean'],
                'uniform_torsion_w2_std': uniform_result['torsion_w2_std'],
                'reweighted_torsion_w2_mean': reweighted_result['torsion_w2_mean'],
                'reweighted_torsion_w2_std': reweighted_result['torsion_w2_std'],
                'energy_improvement': energy_improvement,
                'torsion_improvement': torsion_improvement
            })
    
    if comparison_data:
        energy_improvements = [x['energy_improvement'] for x in comparison_data]
        torsion_improvements = [x['torsion_improvement'] for x in comparison_data]
        
        summary_results['comparison_summary'] = {
            'models_with_both_analyses': len(comparison_data),
            'energy_improvement_stats': {
                'mean': float(np.mean(energy_improvements)),
                'std': float(np.std(energy_improvements)),
                'models_improved': int(sum(1 for x in energy_improvements if x > 0)),
                'models_degraded': int(sum(1 for x in energy_improvements if x < 0)),
                'improvement_values': energy_improvements
            },
            'torsion_improvement_stats': {
                'mean': float(np.mean(torsion_improvements)),
                'std': float(np.std(torsion_improvements)),
                'models_improved': int(sum(1 for x in torsion_improvements if x > 0)),
                'models_degraded': int(sum(1 for x in torsion_improvements if x < 0)),
                'improvement_values': torsion_improvements
            },
            'per_model_comparison': comparison_data
        }
    
    # Save comprehensive summary
    summary_path = output_dir / 'w2_distances_comprehensive_multirun_summary.json'
    with open(summary_path, 'w') as f:
        json.dump(summary_results, f, indent=2)
    print(f"üìÑ Comprehensive multi-run summary results saved to: {summary_path}")
    
    # Save reference data once
    if ref_samples_analysis is not None:
        np.save(output_dir / 'reference_samples_analysis.npy', ref_samples_analysis)
        print(f"üíæ Reference samples saved to: {output_dir / 'reference_samples_analysis.npy'}")
        
    if 'ref_angles' in locals():
        np.save(output_dir / 'reference_angles_analysis.npy', ref_angles)
        print(f"üíæ Reference angles saved to: {output_dir / 'reference_angles_analysis.npy'}")
        
    print(f"\n‚úÖ All results saved to: {CONFIG['output_dir']}")
    
else:
    print("\n‚ö†Ô∏è  Result saving disabled in configuration")

print("\nüéØ Multi-File W2 Distance Analysis with Reweighting and Statistical Analysis Complete!")
uniform_count = len([x for x in all_results_uniform_runs.values() if x is not None])
reweighted_count = len([x for x in all_results_reweighted_runs.values() if x is not None])
print(f"üìä Analyzed {uniform_count} files with uniform sampling ({CONFIG['n_runs']} runs each)")
print(f"üéØ Analyzed {reweighted_count} files with reweighted sampling ({CONFIG['n_runs']} runs each)")
both_count = len([label for label in CONFIG['file_labels'] if all_results_uniform_runs.get(label) is not None and all_results_reweighted_runs.get(label) is not None])
print(f"‚öñÔ∏è  {both_count} files have both uniform and reweighted analysis for comparison")
print(f"üìà Total W2 calculations performed: {(uniform_count + reweighted_count) * CONFIG['n_runs'] * 2}")  # *2 for energy and torsion
print(f"üé≤ Statistical analysis provides mean ¬± std across {CONFIG['n_runs']} independent sampling runs")