In [1]:
import numpy as np
import matplotlib.pyplot as plt
from scipy import stats
import seaborn as sns
import os
import json
from datetime import datetime

def create_results_directory():
    """
    Create the directory that stores correlation results, if it is missing.

    The directory name is fixed (no timestamp suffix), so repeated runs reuse
    the same folder and files with the same name inside it are overwritten.

    Returns:
    --------
    str
        Relative path of the results directory.
    """
    results_dir = "correlation_monkey_and_model_temp"

    # Try to create the directory directly instead of checking for it first:
    # this avoids the check-then-create race when the path appears between
    # the existence test and the makedirs call.
    try:
        os.makedirs(results_dir)
    except FileExistsError:
        # Already present: reuse it silently, as before.
        pass
    else:
        print(f"Created results directory: {results_dir}")

    return results_dir

def load_and_correlate_behavior(model_path, model_template, monkey_path, monkey_template, delay):
    """
    Compute the Pearson correlation between model and monkey behavior for one delay.

    Each file name is obtained by formatting the template with ``delay`` and
    appending the result directly to the matching path string (no path
    separator is inserted, so a directory path must end with one).

    Parameters:
    -----------
    model_path : str
        Prefix prepended to the formatted model file name
    model_template : str
        ``str.format`` template for the model file name; receives ``delay``
    monkey_path : str
        Prefix prepended to the formatted monkey file name
    monkey_template : str
        ``str.format`` template for the monkey file name; receives ``delay``
    delay : int
        Delay value substituted into both templates

    Returns:
    --------
    tuple
        (model_data, monkey_data, correlation_coefficient, p_value), where both
        data arrays are returned flattened to one dimension
    """
    # Build both file locations up front
    model_file = "".join((model_path, model_template.format(delay)))
    monkey_file = "".join((monkey_path, monkey_template.format(delay)))

    # Read the arrays from disk (model first, then monkey)
    model_data = np.load(model_file)
    monkey_data = np.load(monkey_file)

    # Quick sanity print of the two means
    print(f"{model_data.mean() = }, {monkey_data.mean() = }")

    # Collapse to 1-D so the two arrays are compared element by element
    model_data, monkey_data = model_data.flatten(), monkey_data.flatten()

    r, p = stats.pearsonr(model_data, monkey_data)
    return model_data, monkey_data, r, p

def plot_correlation(model_data, monkey_data, delay, r, p):
    """
    Create a scatter plot of model vs monkey behavioral data.

    Monkey values go on the x-axis and model values on the y-axis. A
    first-degree least-squares fit and a unity (y = x) reference line are
    drawn on top of the scatter.

    Parameters:
    -----------
    model_data : array-like
        Model behavioral data (y-axis)
    monkey_data : array-like
        Monkey behavioral data (x-axis)
    delay : int
        Delay shown in the plot title
    r : float
        Correlation coefficient shown in the plot title
    p : float
        P-value of the correlation shown in the plot title

    Returns:
    --------
    matplotlib.figure.Figure
        The figure holding the plot; the caller is responsible for saving
        and closing it.
    """
    fig, ax = plt.subplots(figsize=(8, 8))

    # Create scatter plot
    sns.scatterplot(x=monkey_data, y=model_data, alpha=0.5, ax=ax)

    # Add the least-squares line. It is drawn between the two x extremes only:
    # plotting it through every (unsorted) sample makes the path retrace
    # itself, so the dashes overlap and the line no longer looks dashed.
    slope, intercept = np.polyfit(monkey_data, model_data, 1)
    x_ends = np.array([np.min(monkey_data), np.max(monkey_data)])
    ax.plot(x_ends, slope * x_ends + intercept, "r--", alpha=0.8, label='Correlation Line')

    # Add labels and title
    ax.set_xlabel('Monkey Behavior')
    ax.set_ylabel('Model Behavior')
    ax.set_title(f'Model vs Monkey Behavior Correlation\nDelay: {delay}ms, r={r:.3f}, p={p:.3e}')

    # Add unity line spanning the combined extent of both axes
    x_low, x_high = ax.get_xlim()
    y_low, y_high = ax.get_ylim()
    min_val = min(x_low, y_low)
    max_val = max(x_high, y_high)
    ax.plot([min_val, max_val], [min_val, max_val], 'k--', alpha=0.3, label='Unity Line')

    ax.legend()
    ax.grid(True, alpha=0.3)
    ax.tick_params(axis="both", labelsize=12)
    ax.spines["top"].set_visible(False)
    ax.spines["right"].set_visible(False)

    # Run the layout pass last so it accounts for the final tick label size.
    fig.tight_layout()

    return fig

def plot_correlation_summary(results, results_dir):
    """
    Create a bar plot comparing correlations across all delays and save it.

    Each bar is annotated with its correlation value and, when the p-value is
    below 0.05, with significance stars (* < 0.05, ** < 0.01, *** < 0.001).
    The figure is written to ``correlation_summary.png`` inside
    ``results_dir`` and then closed.

    Parameters:
    -----------
    results : dict
        Maps each delay (a key convertible with ``int``) to a dict with the
        keys ``'correlation'`` and ``'p_value'``
    results_dir : str
        Directory to save the plot

    Raises:
    -------
    ValueError
        If ``results`` is empty.
    """
    # Fail before any figure is created; otherwise min()/max() below would
    # raise on empty input and leave an open figure behind.
    if not results:
        raise ValueError("results is empty; there is nothing to plot")

    # Extract delays and correlations
    delays = [int(delay) for delay in results]
    correlations = [entry['correlation'] for entry in results.values()]
    p_values = [entry['p_value'] for entry in results.values()]

    plt.figure(figsize=(10, 6))

    # Create bar plot with thicker bars
    bars = plt.bar(delays, correlations,
                   width=200,
                   alpha=1,
                   edgecolor='black',
                   linewidth=1.5)

    # Annotate every bar with its value and, if significant, with stars.
    # Labels are placed just beyond the bar tip: above positive bars and
    # below negative ones. The vertical anchor flips with the sign so the
    # text always grows away from the bar instead of into it.
    for bar, p_val in zip(bars, p_values):
        height = bar.get_height()
        x_center = bar.get_x() + bar.get_width() / 2
        if height >= 0:
            direction, anchor = 1, 'bottom'
        else:
            direction, anchor = -1, 'top'

        plt.text(x_center, height + direction * 0.02,
                 f'{height:.3f}',
                 ha='center', va=anchor,
                 fontsize=11)

        if p_val < 0.001:
            stars = '***'
        elif p_val < 0.01:
            stars = '**'
        elif p_val < 0.05:
            stars = '*'
        else:
            stars = ''

        if stars:
            # Stars sit one step further out than the numeric label.
            plt.text(x_center, height + direction * 0.06,
                     stars,
                     ha='center', va=anchor,
                     fontsize=12)

    # Customize plot
    plt.xlabel('Delay (ms)', fontsize=12)
    plt.ylabel('Correlation Coefficient', fontsize=12)
    plt.title('Model-Monkey Behavior Correlation vs Delay', fontsize=14)

    # No gridlines on the summary plot
    plt.grid(False)

    # Set y-axis limits to prevent text cutoff
    plt.ylim(min(correlations) - 0.2, max(correlations) + 0.15)

    # One tick per delay
    plt.xticks(delays, fontsize=12)
    plt.yticks(fontsize=12)
    plt.gca().spines["top"].set_visible(False)
    plt.gca().spines["right"].set_visible(False)

    # Adjust layout and save
    plt.tight_layout()

    # Save plot with higher DPI for better quality
    summary_plot_filename = os.path.join(results_dir, 'correlation_summary.png')
    plt.savefig(summary_plot_filename, dpi=300)
    plt.close()
    print(f"Saved summary plot: {summary_plot_filename}")

def process_all_delays(delays, model_path, model_template, monkey_path, monkey_template):
    """
    Run the correlation analysis for every delay and write all outputs to disk.

    For each delay this loads and correlates the two data sets, saves a
    scatter plot (PNG) and the flattened arrays (NPZ). Afterwards it saves the
    summary bar plot and a JSON file with the correlation and p-value per
    delay.

    Parameters:
    -----------
    delays : iterable
        Delays to process, in order
    model_path, model_template, monkey_path, monkey_template : str
        Passed through unchanged to ``load_and_correlate_behavior``

    Returns:
    --------
    tuple
        (results_dir, all_results), where ``all_results`` maps ``str(delay)``
        to a dict with the keys ``'correlation'`` and ``'p_value'``
    """
    results_dir = create_results_directory()

    def in_results(name):
        # Location of an output file inside the results directory
        return os.path.join(results_dir, name)

    all_results = {}

    for delay in delays:
        print(f"\nProcessing delay: {delay}ms")

        # Correlate model and monkey behavior for this delay
        model_data, monkey_data, r, p = load_and_correlate_behavior(
            model_path, model_template, monkey_path, monkey_template, delay
        )

        # Keep plain Python floats so the summary can be serialized to JSON
        all_results[str(delay)] = {'correlation': float(r), 'p_value': float(p)}

        # Scatter plot for this delay
        plot_filename = in_results(f'correlation_plot_{delay}ms.png')
        fig = plot_correlation(model_data, monkey_data, delay, r, p)
        fig.savefig(plot_filename)
        plt.close(fig)
        print(f"Saved plot: {plot_filename}")

        # Flattened arrays behind the plot
        data_filename = in_results(f'correlation_data_{delay}ms.npz')
        np.savez(data_filename, model_data=model_data, monkey_data=monkey_data)
        print(f"Saved data: {data_filename}")

    # Bar plot across all delays
    plot_correlation_summary(all_results, results_dir)

    # Numeric summary as JSON
    results_filename = in_results('correlation_results.json')
    with open(results_filename, 'w') as out:
        json.dump(all_results, out, indent=4)
    print(f"\nSaved summary results: {results_filename}")

    return results_dir, all_results


# Input locations. Each path is prepended as-is to its formatted template,
# and the "{}" placeholder in a template receives the delay value.
model_behavior_path = "./"
model_behavior_template = "B_I1_hvm200_{}ms.npy"

monkey_behavior_path = "data/monkey_behavioral_data/"
monkey_behavior_template = "b_i1_delay_{}.npy"

# Delays to analyze
delays = [100, 400, 800, 1200]

# Run the full analysis; plots, arrays and the JSON summary are written to
# the returned results directory.
results_dir, results = process_all_delays(
    delays=delays,
    model_path=model_behavior_path,
    model_template=model_behavior_template,
    monkey_path=monkey_behavior_path,
    monkey_template=monkey_behavior_template,
)

# Report the correlation and p-value for each delay
print("\nSummary of correlations:")
print("----------------------")
for delay, result in results.items():
    print(f"Delay {delay}ms:")
    print(f"  Correlation: {result['correlation']:.3f}")
    print(f"  P-value: {result['p_value']:.3e}")

Created results directory: correlation_monkey_and_model_temp

Processing delay: 100ms
model_data.mean() = 0.6497701163731606, monkey_data.mean() = 0.918082398023673
Saved plot: correlation_monkey_and_model_temp/correlation_plot_100ms.png
Saved data: correlation_monkey_and_model_temp/correlation_data_100ms.npz

Processing delay: 400ms
model_data.mean() = 0.6341095927948143, monkey_data.mean() = 0.922275953163999
Saved plot: correlation_monkey_and_model_temp/correlation_plot_400ms.png
Saved data: correlation_monkey_and_model_temp/correlation_data_400ms.npz

Processing delay: 800ms
model_data.mean() = 0.5716948009513231, monkey_data.mean() = 0.8921462201243031
Saved plot: correlation_monkey_and_model_temp/correlation_plot_800ms.png
Saved data: correlation_monkey_and_model_temp/correlation_data_800ms.npz

Processing delay: 1200ms
model_data.mean() = 0.5516149908845951, monkey_data.mean() = 0.8416626776251266
Saved plot: correlation_monkey_and_model_temp/correlation_plot_1200ms.png
Saved da