In [2]:
%load_ext autoreload
%autoreload 2

In [1]:
import numpy as np
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import os
from stimrgs_v1.collector import *
from stimrgs_v1.new_fn import *

In [3]:
def plot_rgs_histogram(experiment:RGSExperiment, output_dir:str="plots") -> None:
    """
    Plot histogram for a single RGS experiment with error bars.
    
    Args:
        experiment: RGSExperiment object
        output_dir: Directory to save the plot
    """
    # Create output directory if it doesn't exist
    os.makedirs(output_dir, exist_ok=True)
    
    # Get data for plotting
    edges, means, stds = experiment.get_histogram_data()
    
    # Skip if no data
    if not edges:
        print(f"No data available for {experiment.name}")
        return
    
    # Create the plot
    fig, ax = plt.subplots(figsize=(14, 8))
    
    # Create bar plot with error bars
    bars = ax.bar(edges, means, width=0.8, color='skyblue', alpha=0.7)
    ax.errorbar(edges, means, yerr=stds, fmt='none', ecolor='black', capsize=5)
    
    # Add data labels
    for i, (edge, mean) in enumerate(zip(edges, means)):
        ax.annotate(f'{mean:.2f}', 
                   xy=(edge, mean),
                   xytext=(0, 5),
                   textcoords='offset points',
                   ha='center', 
                   va='bottom',
                   fontsize=9)
    
    # Customize plot
    ax.set_xlabel('Number of Edges Between Rows', fontsize=12)
    ax.set_ylabel('Mean Number of Bell Pairs', fontsize=12)
    ax.set_title(f"RGS {experiment.num_cols} Arms - Bell Pairs Found", fontsize=14)
    ax.grid(True, linestyle='--', alpha=0.7)
    
    # Set x-ticks to be integers
    ax.set_xticks(edges)
    
    # Rotate x-tick labels for better readability if many edges
    if len(edges) > 15:
        plt.xticks(rotation=45)
    
    # Add horizontal line at peak value
    if means:
        max_mean = max(means)
        max_idx = means.index(max_mean)
        ax.axhline(y=max_mean, color='red', linestyle='--', alpha=0.5)
        ax.annotate(f'Max: {max_mean:.2f} at {edges[max_idx]} edges', 
                   xy=(edges[max_idx], max_mean),
                   xytext=(0, 10),
                   textcoords='offset points',
                   ha='center',
                   va='bottom',
                   color='red',
                   fontsize=10)
    
    plt.tight_layout()
    
    # Save the plot
    filename = f"rgs_{experiment.num_cols}_arms_histogram.png"
    plt.savefig(os.path.join(output_dir, filename), dpi=300)
    plt.close()
    
    print(f"Histogram saved as {filename}")

def plot_rgs_detailed(experiment: RGSExperiment, output_dir:str="plots") -> None:
    """
    Create a more detailed plot with both histogram and line graph for a single RGS experiment.
    
    Args:
        experiment: RGSExperiment object
        output_dir: Directory to save the plot
    """
    # Create output directory if it doesn't exist
    os.makedirs(output_dir, exist_ok=True)
    
    # Get data for plotting
    edges, means, stds = experiment.get_histogram_data()
    
    # Skip if no data
    if not edges:
        print(f"No data available for {experiment.name}")
        return
    
    # Create the plot
    fig, ax = plt.subplots(figsize=(14, 8))
    
    # Create bar plot with error bars
    bars = ax.bar(edges, means, width=0.8, color='skyblue', alpha=0.6, label='Mean Bell Pairs')
    ax.errorbar(edges, means, yerr=stds, fmt='none', ecolor='black', capsize=5)
    
    # Add line plot to show trend
    ax.plot(edges, means, 'ro-', linewidth=2, markersize=6, label='Trend')
    
    # Shade area between mean±std
    ax.fill_between(edges, 
                   np.array(means) - np.array(stds), 
                   np.array(means) + np.array(stds), 
                   color='blue', alpha=0.2, label='Standard Deviation')
    
    # Highlight maximum point
    if means:
        max_idx = np.argmax(means)
        max_edge = edges[max_idx]
        max_mean = means[max_idx]
        ax.plot(max_edge, max_mean, 'D', color='green', markersize=10, 
                label=f'Maximum: {max_mean:.2f} at {max_edge} edges')
        
        # Add text annotation for peak value
        ax.annotate(f'Peak: {max_mean:.2f} ± {stds[max_idx]:.2f}', 
                   xy=(max_edge, max_mean),
                   xytext=(max_edge + 1, max_mean + 0.1),
                   arrowprops=dict(facecolor='black', shrink=0.05, width=1.5),
                   fontsize=10)
    
    # Customize plot
    ax.set_xlabel('Number of Edges Between Rows', fontsize=12)
    ax.set_ylabel('Mean Number of Bell Pairs', fontsize=12)
    ax.set_title(f"RGS {experiment.num_cols} Arms - Bell Pairs Distribution", fontsize=14)
    ax.grid(True, linestyle='--', alpha=0.7)
    
    # Set x-ticks to be integers
    ax.set_xticks(edges)
    
    # Add legend
    ax.legend(loc='upper right')
    
    plt.tight_layout()
    
    # Save the plot
    filename = f"rgs_{experiment.num_cols}_arms_detailed.png"
    plt.savefig(os.path.join(output_dir, filename), dpi=300)
    plt.close()
    
    print(f"Detailed plot saved as {filename}")

def run_rgs_experiments(key:jnp.ndarray, arms_list:list, num_rows:int=4, bsm_prob:float=1.0, num_trials:int=1000, generate_plots=True, plots_dir="rgs_plots") -> dict:
    """
    Run RGS experiments for multiple arm configurations and return a collection of results.
    Optionally generate plots for each experiment.
    
    Args:
        key: JAX random key
        arms_list: List of arm sizes to test
        num_rows: Number of rows in the RGS
        num_trials: Number of trials per configuration
        generate_plots: Whether to generate plots for each experiment
    
    Returns:
        Dictionary mapping arm sizes to RGSExperiment objects
    """
    experiments = {}    
    
    # Create output directory if it doesn't exist and plots are enabled
    if generate_plots:
        os.makedirs(plots_dir, exist_ok=True)
    
    for arm in arms_list:
        print(f'RGS {arm} arms')
        experiment = simulate_rgs_V2(key, num_rows=num_rows, num_cols=arm, num_trials=num_trials, bsm_prob=bsm_prob)
        experiments[arm] = experiment
        
        # Generate plots for this experiment if enabled
        if generate_plots:
            plot_rgs_histogram(experiment, output_dir=plots_dir)
            plot_rgs_detailed(experiment, output_dir=plots_dir)
        
        print('- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -')
    
    return experiments

def plot_best_configurations(experiments:RGSExperiment, output_dir="plots") -> None:
    """
    Plot the best configuration (most Bell pairs) for each arm size.
    
    Args:
        experiments: Dictionary mapping arm sizes to RGSExperiment objects
        output_dir: Directory to save the figure
    """
    # Create output directory if it doesn't exist
    os.makedirs(output_dir, exist_ok=True)
    
    fig, ax = plt.subplots(figsize=(10, 6))
    
    arm_sizes = []
    best_arms = []
    best_means = []
    best_stds = []
    
    for arm_size, experiment in sorted(experiments.items()):
        summary = experiment.summary()
        best_arm = summary["best_arm"]
        
        if best_arm is not None:
            arm_sizes.append(arm_size)
            best_arms.append(best_arm)
            best_means.append(summary["mean_per_arm"][best_arm])
            best_stds.append(summary["std_per_arm"][best_arm])
    
    # Plot with error bars
    ax.bar(arm_sizes, best_means, color='skyblue', alpha=0.7)
    ax.errorbar(arm_sizes, best_means, yerr=best_stds, fmt='none', ecolor='black', capsize=5)
    
    # Add data labels
    for i, (arm_size, best_arm, mean) in enumerate(zip(arm_sizes, best_arms, best_means)):
        ax.annotate(f'Edges: {best_arm}', 
                   (arm_size, mean), 
                   xytext=(0, 5), 
                   textcoords='offset points',
                   ha='center', 
                   va='bottom')
    
    ax.set_xlabel('Number of RGS Arms')
    ax.set_ylabel('Maximum Mean Bell Pairs Found')
    ax.set_title('Best Configurations for Different RGS Arm Sizes')
    ax.grid(True, linestyle='--', alpha=0.7)
    
    # Set x-ticks to be integers
    ax.set_xticks(arm_sizes)
    
    plt.tight_layout()
    
    # Save the plot
    filename = "rgs_best_configurations.png"
    plt.savefig(os.path.join(output_dir, filename), dpi=300)
    plt.close()
    
    print(f"Best configurations plot saved as {filename}")

def save_experiment_data(experiments:RGSExperiment, output_dir:str="data") -> None:
    """
    Save the experiment data to CSV files for later analysis.
    
    Args:
        experiments: Dictionary mapping arm sizes to RGSExperiment objects
        output_dir: Directory to save the data files
    """
    # Create output directory if it doesn't exist
    os.makedirs(output_dir, exist_ok=True)
    
    # Save summary data
    summary_file = os.path.join(output_dir, "rgs_summary.csv")
    with open(summary_file, 'w') as f:
        f.write("arm_size,edge_count,mean_bell_pairs,std_dev\n")
        
        for arm_size, experiment in sorted(experiments.items()):
            edges, means, stds = experiment.get_histogram_data()
            
            for edge, mean, std in zip(edges, means, stds):
                f.write(f"{arm_size},{edge},{mean},{std}\n")
    
    print(f"Summary data saved to {summary_file}")
    
    # Save best configurations
    best_file = os.path.join(output_dir, "rgs_best_configurations.csv")
    with open(best_file, 'w') as f:
        f.write("arm_size,best_edge_count,mean_bell_pairs,std_dev\n")
        
        for arm_size, experiment in sorted(experiments.items()):
            summary = experiment.summary()
            best_arm = summary["best_arm"]
            
            if best_arm is not None:
                mean = summary["mean_per_arm"][best_arm]
                std = summary["std_per_arm"][best_arm]
                f.write(f"{arm_size},{best_arm},{mean},{std}\n")
    
    print(f"Best configurations data saved to {best_file}")

In [4]:
directory_name = 'rgs_plots_bsm_0.5'

if __name__ == "__main__":
    key = jax.random.PRNGKey(42)  # Use a fixed seed for reproducibility
    arms_to_test = [3, 4, 5, 6, 7, 8, 9, 10, 11, 12]
    
    # Run all experiments and generate plots automatically
    all_experiments = run_rgs_experiments(
        key, 
        arms_list=arms_to_test, 
        bsm_prob=0.5,
        generate_plots=True,
        plots_dir=directory_name)
    
    # Create a summary plot of the best configurations
    plot_best_configurations(all_experiments, output_dir=directory_name)
    
    # Save the experiment data for later analysis
    save_experiment_data(all_experiments, output_dir=f"data_{directory_name}")
    
    # Print summary of all experiments
    print("\nSUMMARY OF ALL EXPERIMENTS:")
    for arm, experiment in all_experiments.items():
        summary = experiment.summary()
        best_arm = summary["best_arm"]
        
        if best_arm is not None:
            mean = summary["mean_per_arm"][best_arm]
            std = summary["std_per_arm"][best_arm]
            print(f"RGS {arm} arms - Best configuration: {best_arm} edges with {mean:.4f} ± {std:.4f} bell pairs")

RGS 3 arms
Num edges connected 1, Bell found: 0.08000000566244125 ± 0.2712932229042053
Num edges connected 2, Bell found: 0.2151481956243515 ± 0.42665186524391174
Num edges connected 3, Bell found: 0.3103448152542114 ± 0.5073989033699036
Num edges connected 4, Bell found: 0.3988995850086212 ± 0.5404149889945984
Num edges connected 5, Bell found: 0.3814432919025421 ± 0.5208554863929749
Num edges connected 6, Bell found: 0.3912310302257538 ± 0.5373635292053223
Num edges connected 7, Bell found: 0.39529916644096375 ± 0.5696558952331543
Num edges connected 8, Bell found: 0.5310077667236328 ± 0.6417391300201416
Num edges connected 9, No bell pairs found
Histogram saved as rgs_3_arms_histogram.png
Detailed plot saved as rgs_3_arms_detailed.png
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
RGS 4 arms
Num edges connected 1, Bell found: 0.06700000166893005 ± 0.25002196431159973
Num edges connected 2, Bell found: 0.24126267433166504 ± 0.43826261162757874
Num edges conne