In [14]:
import pandas as pd
import numpy as np
import glob
import os
from scipy import signal
from sklearn.metrics import mutual_info_score
import seaborn as sns
from scipy import stats
import warnings
warnings.filterwarnings('ignore')
# Ensure seaborn is set to use a white background


# Your existing functions
def compute_coherence(x, y, fs):
    f, Cxy = signal.coherence(x, y, fs)
    return f, Cxy

def crosscorr(datax, datay, lag=0):
    """ Lag-N cross correlation. Shift datax by N elements. """
    return datax.corr(datay.shift(lag))

def compute_mutual_information(x: np.ndarray, y: np.ndarray) -> float:
    """
    Compute mutual information between two time series using KDE.
    """
    # Clean data
    mask = ~np.isnan(x) & ~np.isnan(y)
    x = x[mask]
    y = y[mask]
    
    if len(x) < 2 or len(y) < 2:
        return 0.0
    
    # Standardize the data
    x = (x - np.mean(x)) / np.std(x)
    y = (y - np.mean(y)) / np.std(y)
    
    # Create KDE estimators
    kde_joint = stats.gaussian_kde(np.vstack([x, y]))
    kde_x = stats.gaussian_kde(x)
    kde_y = stats.gaussian_kde(y)
    
    # Sample points for numerical integration
    n_samples = 50
    x_range = np.linspace(min(x) - 1, max(x) + 1, n_samples)
    y_range = np.linspace(min(y) - 1, max(y) + 1, n_samples)
    X, Y = np.meshgrid(x_range, y_range)
    positions = np.vstack([X.ravel(), Y.ravel()])
    
    # Evaluate densities
    joint_density = kde_joint(positions).reshape(X.shape)
    x_density = kde_x(X[0,:])
    y_density = kde_y(Y[:,0])
    X_density, Y_density = np.meshgrid(x_density, y_density)
    
    # Compute MI
    with np.errstate(divide='ignore', invalid='ignore'):
        mi_density = joint_density * np.log(joint_density / (X_density * Y_density))
    mi = np.nansum(mi_density) * (x_range[1] - x_range[0]) * (y_range[1] - y_range[0])
    
    return max(0, mi)  # Ensure non-negative MI

def compute_coupling_statistics(name, motion_ts, sound_ts, time):
    """Your existing coupling statistics function"""
    # Ensure inputs are numpy arrays
    motion_ts = np.array(motion_ts)
    sound_ts = np.array(sound_ts)
    time = np.array(time)

    # Check if inputs are scalar (single values)
    if motion_ts.ndim == 0 or sound_ts.ndim == 0 or time.ndim == 0:
        return pd.DataFrame({
            'scene': [name],
            'lags': [np.nan],
            'crosscorr': [np.nan],
        }), np.nan, np.nan, np.nan, np.nan, np.nan, np.nan

    # Check if inputs have the same length
    if len(motion_ts) != len(sound_ts) or len(motion_ts) != len(time):
        raise ValueError("motion_ts, sound_ts, and time must have the same length")

    # normalize and center the data
    motion_ts = (motion_ts - np.min(motion_ts)) / (np.max(motion_ts) - np.min(motion_ts))
    motion_ts = motion_ts - np.mean(motion_ts)
    sound_ts = (sound_ts - np.min(sound_ts)) / (np.max(sound_ts) - np.min(sound_ts))
    sound_ts = sound_ts - np.mean(sound_ts)

    # check if values are finite
    if not np.all(np.isfinite(motion_ts)) or not np.all(np.isfinite(sound_ts)):
        return pd.DataFrame({
            'scene': [name],
            'lags': [np.nan],
            'crosscorr': [np.nan],
        }), np.nan, np.nan, np.nan, np.nan, np.nan, np.nan

    # Compute sampling frequency
    fs = 1/np.mean(np.diff(time))

    # compute the average mutual information
    mi = []
    lags = [-0.3, -0.2, -0.1, 0, 0.1, 0.2, 0.3]
    fs = 1/np.mean(np.diff(time))
    lags_samples = [int(lag * fs) for lag in lags]
    
    for lag in lags_samples:
        if lag < 0:
            motsub = motion_ts[:lag]
            soundsub = sound_ts[-lag:]
        elif lag > 0:
            motsub = motion_ts[lag:]
            soundsub = sound_ts[:-lag]
        else:
            motsub = motion_ts
            soundsub = sound_ts
            
        # Add smoothing to capture temporal dependencies
        window = int(0.1 * fs)  # 100ms window
        if window > 1:
            motsub = np.convolve(motsub, np.ones(window)/window, mode='valid')
            soundsub = np.convolve(soundsub, np.ones(window)/window, mode='valid')
        
        mi_value = compute_mutual_information(motsub, soundsub)
        mi.append(mi_value)
    
    max_mi = np.max(mi)
    optimal_lag = lags[np.argmax(mi)]
    
    #################################################### Coherence
    # Compute coherence
    f, Cxy = compute_coherence(motion_ts, sound_ts, fs)

    # keep all values lower than 10 Hz
    mask = f < 10
    Cxy = Cxy[mask]
    f = f[mask]

    # maximum coherence
    coherence = np.max(Cxy)

    # Frequency of max coherence
    f_max = f[np.argmax(Cxy)]
        
    #################################################### Cross-correlation
    lag_seconds = 0.3
    lag_points = int(lag_seconds * fs)

    # Create a range of lags to examine
    lags = np.arange(-lag_points, lag_points + 1)

    # calc cross-cor
    cc = [crosscorr(pd.Series(motion_ts), pd.Series(sound_ts), lag) for lag in lags]

    # Calculate cross-correlation
    crosscorrdf = pd.DataFrame({
        'scene': name,
        'lags': lags*(1/fs),
        'crosscorr': cc,
    })

    ########## Phase locking value
    # Compute analytic signal (using Hilbert transform)
    motion_analytic = signal.hilbert(motion_ts)
    sound_analytic = signal.hilbert(sound_ts)

    # Extract instantaneous phase
    motion_phase = np.angle(motion_analytic)
    sound_phase = np.angle(sound_analytic)

    # Compute phase difference
    phase_diff = motion_phase - sound_phase

    # compute the plv
    plv = np.abs(np.mean(np.exp(1j * phase_diff)))
    
    # make phase_diff a regular list not a numpy array
    try:
        motion_phase = motion_phase.tolist()[0] if hasattr(motion_phase, 'tolist') else motion_phase[0]
        phase_diff = phase_diff.tolist()[0] if hasattr(phase_diff, 'tolist') else phase_diff[0]
    except:
        motion_phase = np.mean(motion_phase)
        phase_diff = np.mean(phase_diff)

    return crosscorrdf, coherence, f_max, max_mi, phase_diff, plv, optimal_lag

def calculate_p1_p2_coupling_stats(merged_folder='../merged_filteredtimeseries/', output_file='p1_p2_coupling_statistics.csv'):
    """
    Loop through merged time series files and calculate coupling statistics between P1 and P2
    """
    
    # Find all merged CSV files
    csv_files = glob.glob(os.path.join(merged_folder, "*.csv"))
    print(f"Found {len(csv_files)} files to process")
    
    # Define P1-P2 variable pairs to analyze
    p1_modalities = ['Amplitude_Envelope_P1', 'heart_rate_P1', 'Filtered_Respiration_P1', 'Filtered_EMG_Bicep_P1', 'Filtered_EMG_Tricep_P1', 'right_index_x_P1', 'right_index_y_P1', 'right_index_z_P1']
    p2_modalities = ['Amplitude_Envelope_P2', 'heart_rate_P2', 'Filtered_Respiration_P2', 'Filtered_EMG_Bicep_P2', 'Filtered_EMG_Tricep_P2', 'right_index_x_P2', 'right_index_y_P2', 'right_index_z_P2']

    variable_pairs = []
    # Add same-participant, same-modality (if needed for matrix)
    for i in range(len(p1_modalities)):
        p1_var = p1_modalities[i]
        p2_var = p2_modalities[i] # Assumes corresponding P1/P2 vars are at same index
        if p1_var.replace('_P1', '') == p2_var.replace('_P2', ''): # Only if same modality
            variable_pairs.append((p1_var, p2_var))

    # Add ALL cross-participant pairs (same and cross-modality)
    for p1_var in p1_modalities:
        for p2_var in p2_modalities:
            if (p1_var.replace('_P1', '') == p2_var.replace('_P2', '')) and (p1_var.replace('_P1', '') == 'Amplitude_Envelope'):
                # This is the P1-P2 envelope, which is handled separately in plotting
                # You might want to exclude it here if you only want to compute it once for the scatter plots
                pass
            else:
                variable_pairs.append((p1_var, p2_var))
    
    # Initialize results list
    results = []
    
    # Loop through each file
    for file_path in csv_files:
        print(f"\nProcessing: {os.path.basename(file_path)}")
        
        try:
            # Load the data
            df = pd.read_csv(file_path)
            
            # Extract metadata from filename and dataframe
            filename = os.path.basename(file_path)
            
            # Try to extract condition information
                # if filename has NoVision or NoMovement, set to 'NoVision' or 'NoMovement'
            condition_vision = 'Vision' if 'NoVision' not in filename else 'NoVision'
            condition_movement = 'Movement' if 'NoMovement' not in filename else 'NoMovement'
            trial = df['Trial'].iloc[0] if 'Trial' in df.columns else 'Unknown'
            
            # Get time vector
            time = df['Time'].values
            
            # Check if we have enough data (at least 3 seconds)
            if time.max() - time.min() < 5.0:
                print(f"  Skipping {filename}: insufficient data duration ({time.max() - time.min():.2f}s)")
                continue
            
            # Set up sliding window analysis
            window_duration = 5.0  # seconds
            step_size = 1.0  # seconds
            n_windows = int((time.max() - time.min() - window_duration) / step_size) + 1
            
            print(f"  Analyzing {n_windows} windows of {window_duration}s each")
            
            # Loop through each variable pair
            for var1, var2 in variable_pairs:
                if var1 not in df.columns or var2 not in df.columns:
                    continue
                    
                print(f"    Processing {var1} vs {var2}")
                
                # Loop through sliding windows
                for window_idx in range(n_windows):
                    window_start = time.min() + window_idx * step_size
                    window_end = window_start + window_duration
                    
                    # Extract window data
                    window_mask = (time >= window_start) & (time <= window_end)
                    time_window = time[window_mask]
                    var1_window = df[var1].values[window_mask]
                    var2_window = df[var2].values[window_mask]
                    
                    # Skip if not enough data in window or too many NaNs
                    if len(time_window) < 100 or np.sum(~np.isnan(var1_window)) < 50 or np.sum(~np.isnan(var2_window)) < 50:
                        continue
                    
                    try:
                        # Calculate coupling statistics
                        crosscorrdf, coherence, f_max, max_mi, phase_diff, plv, optimal_lag = \
                            compute_coupling_statistics(filename, var1_window, var2_window, time_window)
                        
                        # Extract max cross-correlation
                        max_crosscorr = crosscorrdf['crosscorr'].abs().max() if not crosscorrdf.empty else np.nan
                        lag_at_max_crosscorr = crosscorrdf.loc[crosscorrdf['crosscorr'].abs().idxmax(), 'lags'] if not crosscorrdf.empty else np.nan
                        
                        # Store results
                        result = {
                            'filename': filename,
                            'condition_vision': condition_vision,
                            'condition_movement': condition_movement,
                            'trial': trial,
                            'window_idx': window_idx,
                            'window_start': window_start,
                            'window_end': window_end,
                            'var1': var1,
                            'var2': var2,
                            'variable_pair_type': classify_variable_pair(var1, var2),
                            'max_crosscorr': max_crosscorr,
                            'lag_at_max_crosscorr': lag_at_max_crosscorr,
                            'max_coherence': coherence,
                            'freq_at_max_coherence': f_max,
                            'max_mutual_info': max_mi,
                            'optimal_lag_mi': optimal_lag,
                            'phase_locking_value': plv,
                            'mean_phase_diff': phase_diff,
                            'n_samples': len(time_window),
                            'sampling_rate': 1/np.mean(np.diff(time_window))
                        }
                        
                        results.append(result)
                        
                    except Exception as e:
                        print(f"      Error in window {window_idx}: {e}")
                        continue
                        
        except Exception as e:
            print(f"  Error processing {filename}: {e}")
            continue
    
    # Convert to DataFrame
    if results:
        results_df = pd.DataFrame(results)
        
        # Save to CSV
        results_df.to_csv(output_file, index=False)
        print(f"\n✓ Saved {len(results_df)} coupling statistics to {output_file}")
        
        # Print summary statistics
        print(f"\nSummary:")
        print(f"  Files processed: {results_df['filename'].nunique()}")
        print(f"  Variable pairs analyzed: {results_df['variable_pair_type'].nunique()}")
        print(f"  Total windows analyzed: {len(results_df)}")
        print(f"  Conditions: {results_df['condition_vision'].unique()} x {results_df['condition_movement'].unique()}")
        
        return results_df
    else:
        print("No results generated - check your data files and variable names")
        return pd.DataFrame()

def classify_variable_pair(var1, var2):
    """
    Classify the type of variable pair for easier analysis
    """
    # Extract modality information
    modalities = []
    participants = []
    
    for var in [var1, var2]:
        if 'Amplitude_Envelope' in var:
            modalities.append('Audio')
        elif 'heart_rate' in var:
            modalities.append('ECG')
        elif 'Respiration' in var:
            modalities.append('Respiration')
        elif 'EMG' in var:
            modalities.append('EMG')
        elif 'right_index' in var:
            modalities.append('Motion')
        else:
            modalities.append('Other')
            
        if '_P1' in var:
            participants.append('P1')
        elif '_P2' in var:
            participants.append('P2')
        else:
            participants.append('Unknown')
    
    # Classify pair type
    if participants[0] == participants[1]:
        if participants[0] == 'P1':
            pair_type = 'Within_P1'
        else:
            pair_type = 'Within_P2'
    else:
        pair_type = 'Between_Participants'
    
    if modalities[0] == modalities[1]:
        modality_type = f"Same_{modalities[0]}"
    else:
        modality_type = f"Cross_{modalities[0]}_{modalities[1]}"
    
    return f"{pair_type}_{modality_type}"

def create_plots(results_df, output_folder='coupling_plots/'):
    """
    Create P1-P2 cross-modality matrix and envelope correlations.
    This version properly handles all cross-modality pairs for the matrix
    and correctly generates scatter plots for envelope vs all other P1-P2 coupling.
    """
    
    os.makedirs(output_folder, exist_ok=True)
    plt.style.use('seaborn-white')
    
    # First, let's identify what variables we actually have in the results_df
    # and specifically filter for P1 and P2 variables.
    print("\nAnalyzing available data...")
    all_var1 = results_df['var1'].unique()
    all_var2 = results_df['var2'].unique()

    p1_vars_in_data = sorted([v for v in all_var1 if '_P1' in v])
    p2_vars_in_data = sorted([v for v in all_var2 if '_P2' in v])
    
    print(f"P1 variables found in data: {p1_vars_in_data}")
    print(f"P2 variables found in data: {p2_vars_in_data}")
    
    # Extract modality function (re-defined for clarity, assuming it's available)
    def get_modality_detailed(var_name):
        if 'Amplitude_Envelope' in var_name:
            return 'Audio'
        elif 'heart_rate' in var_name:
            return 'Heart Rate'
        elif 'Filtered_Respiration' in var_name:
            return 'Respiration'
        elif 'EMG_Bicep' in var_name:
            return 'EMG Bicep'
        elif 'EMG_Tricep' in var_name:
            return 'EMG Tricep'
        elif 'right_index_x' in var_name:
            return 'Motion X'
        elif 'right_index_y' in var_name:
            return 'Motion Y'
        elif 'right_index_z' in var_name:
            return 'Motion Z'
        else:
            return 'Unknown'
            
    # 1. Create P1-P2 Cross-Modality Matrix (using Mutual Information)
    print("\nCreating cross-modality matrix...")
    
    # Get all unique modalities present in the data for both P1 and P2
    all_modalities = sorted(list(set(results_df['var1'].apply(get_modality_detailed).unique()) |
                                 set(results_df['var2'].apply(get_modality_detailed).unique())))
    all_modalities = [m for m in all_modalities if m != 'Unknown'] # Remove 'Unknown' if it appears

    print(f"Modalities found: {all_modalities}")
    
    # Get conditions
    conditions = results_df.groupby(['condition_vision', 'condition_movement']).size().reset_index()[['condition_vision', 'condition_movement']]
    
    # Create figure for cross-modality matrix
    fig, axes = plt.subplots(2, 2, figsize=(24, 24))
    axes = axes.flatten()
    fig.suptitle('P1-P2 Cross-Modality Coupling Matrix - Mutual Information', fontsize=24)
    
    for idx, (_, cond_row) in enumerate(conditions.iterrows()):
        if idx >= 4: # Limit to 4 plots if more conditions exist
            break
            
        vision = cond_row['condition_vision']
        movement = cond_row['condition_movement']
        cond_name = f"{vision} x {movement}"
        
        ax = axes[idx]
        
        # Filter data for this specific condition
        cond_data = results_df[
            (results_df['condition_vision'] == vision) & 
            (results_df['condition_movement'] == movement)
        ].copy()
        
        # Add modality columns for grouping
        cond_data['modality1'] = cond_data['var1'].apply(get_modality_detailed)
        cond_data['modality2'] = cond_data['var2'].apply(get_modality_detailed)
        
        # Create an empty matrix initialized with NaNs
        matrix = np.full((len(all_modalities), len(all_modalities)), np.nan)
        
        # Populate the matrix by averaging 'max_mutual_info' for each P1-P2 modality pair
        for i, p2_mod_name in enumerate(all_modalities): # Rows represent P2 modalities
            for j, p1_mod_name in enumerate(all_modalities): # Columns represent P1 modalities
                # Find data for the specific P1 modality -> P2 modality coupling
                pair_data = cond_data[
                    (cond_data['modality1'] == p1_mod_name) & 
                    (cond_data['modality2'] == p2_mod_name)
                ]
                
                if len(pair_data) > 0 and not pair_data['max_mutual_info'].isnull().all():
                    value = pair_data['max_mutual_info'].mean()
                    matrix[i, j] = value
                    # print(f"  {cond_name}: P1 {p1_mod_name} -> P2 {p2_mod_name} = {value:.3f}")
        
        # Plot heatmap
        mask = np.isnan(matrix) # Mask NaN values for cleaner plot
        sns.heatmap(matrix, annot=True, fmt='.3f', cmap='viridis', 
                    mask=mask, cbar_kws={'label': 'Mutual Information', 'shrink': 0.8},
                    xticklabels=all_modalities, yticklabels=all_modalities,
                    ax=ax, square=True, vmin=0, vmax=1.5,
                    annot_kws={'fontsize': 25}, linewidths=0.5)
        
        ax.set_xlabel('P1 Modality', fontsize=32)
        ax.set_ylabel('P2 Modality', fontsize=32)
        ax.set_title(cond_name, fontsize=36, pad=10)
        ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha='right', fontsize=26)
        ax.set_yticklabels(ax.get_yticklabels(), rotation=0, fontsize=26)
    
    plt.tight_layout()
    plt.savefig(os.path.join(output_folder, 'p1_p2_full_cross_modality_matrix_MI.png'), 
                dpi=300, bbox_inches='tight')
    plt.close()
    
    # 2. Create Envelope vs Other Modalities Correlation Plots (P1-P2 specific)
    print("\nCreating envelope correlations...")
    
    # Get P1-P2 Audio Envelope coupling data
    envelope_data = results_df[
        (results_df['var1'] == 'Amplitude_Envelope_P1') & 
        (results_df['var2'] == 'Amplitude_Envelope_P2')
    ].copy()
    
    if len(envelope_data) == 0:
        print("No P1-P2 Audio Envelope data found for correlation plots!")
        return
    
    # Average envelope coupling per trial/condition/window
    envelope_avg = envelope_data.groupby(
        ['trial', 'condition_vision', 'condition_movement', 'window_idx']
    )['max_mutual_info'].mean().reset_index()
    
    print(f"Envelope data shape for correlation: {envelope_avg.shape}")
    
    # Identify all *other* P1-P2 coupling pairs (excluding the envelope itself)
    # We need to consider all P1 variables coupled with all P2 variables.
    all_p1_p2_coupling_vars = results_df[
        (results_df['var1'].str.contains('_P1')) &
        (results_df['var2'].str.contains('_P2')) &
        ~((results_df['var1'] == 'Amplitude_Envelope_P1') & (results_df['var2'] == 'Amplitude_Envelope_P2')) # Exclude self
    ][['var1', 'var2']].drop_duplicates()

    # Consolidate EMG pairs for a single average EMG plot
    emg_pairs_to_process = []
    other_pairs_to_process = []

    for _, row in all_p1_p2_coupling_vars.iterrows():
        var1_name = row['var1']
        var2_name = row['var2']
        modality1 = get_modality_detailed(var1_name)
        modality2 = get_modality_detailed(var2_name)

        # For EMG, we want to group Bicep and Tricep if they are P1-P2 pairs
        if 'EMG' in modality1 and 'EMG' in modality2:
            emg_pairs_to_process.append((var1_name, var2_name, f"EMG ({modality1.replace('EMG ', '')} P1 - {modality2.replace('EMG ', '')} P2)"))
        else:
            other_pairs_to_process.append((var1_name, var2_name, f"{modality1} P1 - {modality2} P2"))

    # Remove duplicates from other_pairs_to_process (e.g., if 'Motion X P1 - Motion Y P2' and 'Motion Y P1 - Motion X P2' are both present, treat them distinctly)
    other_pairs_to_process = sorted(list(set(other_pairs_to_process)))
    
    print(f"\nFound {len(other_pairs_to_process)} non-EMG P1-P2 pairs for correlation.")
    if emg_pairs_to_process:
        print(f"Found {len(emg_pairs_to_process)} EMG P1-P2 pairs for combined correlation.")

    # Determine the number of plots needed
    n_plots = len(other_pairs_to_process) + (1 if len(emg_pairs_to_process) > 0 else 0)
    
    if n_plots == 0:
        print("No other P1-P2 coupling data found to correlate with envelope!")
        return

    n_cols = min(3, n_plots)
    n_rows = (n_plots + n_cols - 1) // n_cols
    
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(6*n_cols, 6*n_rows))
    if n_plots == 1:
        axes = [axes] # Ensure axes is iterable even for a single plot
    else:
        axes = axes.flatten()
    
    fig.suptitle('Envelope Coupling (P1-P2) vs Other Modality Coupling (P1-P2) - Mutual Information', fontsize=32)
    
    plot_idx = 0
    
    # Plot individual P1-P2 modality correlations
    for p1_var, p2_var, plot_title in other_pairs_to_process:
        if plot_idx >= len(axes):
            break
            
        ax = axes[plot_idx]
        
        # Get specific P1-P2 modality coupling data
        mod_data = results_df[
            (results_df['var1'] == p1_var) & 
            (results_df['var2'] == p2_var)
        ].copy()
        
        if len(mod_data) > 0:
            # Average per trial/condition/window
            mod_avg = mod_data.groupby(
                ['trial', 'condition_vision', 'condition_movement', 'window_idx']
            )['max_mutual_info'].mean().reset_index()
            
            # Merge with envelope data
            merged = pd.merge(
                envelope_avg,
                mod_avg,
                on=['trial', 'condition_vision', 'condition_movement', 'window_idx'],
                suffixes=('_env', '_mod')
            )
            
            if len(merged) > 0:
                x = merged['max_mutual_info_env'].values
                y = merged['max_mutual_info_mod'].values
                
                # Remove NaN values for linear regression and plotting
                mask = ~(np.isnan(x) | np.isnan(y))
                x_clean = x[mask]
                y_clean = y[mask]
                
                if len(x_clean) > 2: # Need at least 2 points for regression
                    # Plot with conditions as colors
                    conditions_for_plot = merged.loc[mask, 'condition_vision'].astype(str) + ' x ' + \
                                          merged.loc[mask, 'condition_movement'].astype(str)
                    unique_conds = sorted(conditions_for_plot.unique())
                    colors = plt.cm.tab10(np.linspace(0, 1, len(unique_conds)))
                    
                    for i, cond in enumerate(unique_conds):
                        cond_mask = conditions_for_plot == cond
                        ax.scatter(x_clean[cond_mask], y_clean[cond_mask], alpha=0.6, s=40,
                                   label=cond, color=colors[i])
                    
                    # Add regression line
                    slope, intercept, r_value, p_value, std_err = stats.linregress(x_clean, y_clean)
                    line_x = np.array([x_clean.min(), x_clean.max()])
                    line_y = slope * line_x + intercept
                    ax.plot(line_x, line_y, 'k--', linewidth=2)
                    
                    ax.text(0.05, 0.95, f'r = {r_value:.3f}\np = {p_value:.3f}',
                            transform=ax.transAxes, verticalalignment='top',
                            bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))
                    
                    ax.set_xlabel('max_mutual_info (Audio Envelope P1-P2)')
                    ax.set_ylabel(f'max_mutual_info ({plot_title})')
                    ax.set_title(plot_title)
                    
                    if plot_idx == 0: # Only show legend on the first plot
                        ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left', fontsize=16)
                else:
                    ax.text(0.5, 0.5, f'Insufficient data (n={len(x_clean)})', 
                            ha='center', va='center', transform=ax.transAxes, fontsize=24)
            else:
                ax.text(0.5, 0.5, 'No matched data with envelope', 
                        ha='center', va='center', transform=ax.transAxes, fontsize=24)
        else:
            ax.text(0.5, 0.5, f'No {plot_title} data in results_df', 
                    ha='center', va='center', transform=ax.transAxes, fontsize=24)
        
        ax.grid(True, alpha=0.3)
        plot_idx += 1
    
    # EMG Average (if we have EMG data)
    if len(emg_pairs_to_process) > 0 and plot_idx < len(axes):
        ax = axes[plot_idx]
        
        # Combine all EMG data from the identified EMG pairs
        emg_dfs = []
        for p1_var, p2_var, _ in emg_pairs_to_process:
            emg_data = results_df[
                (results_df['var1'] == p1_var) & 
                (results_df['var2'] == p2_var)
            ].copy()
            if len(emg_data) > 0:
                emg_dfs.append(emg_data)
        
        if emg_dfs:
            # Concatenate all relevant EMG coupling data and average
            all_emg = pd.concat(emg_dfs)
            emg_avg = all_emg.groupby(
                ['trial', 'condition_vision', 'condition_movement', 'window_idx']
            )['max_mutual_info'].mean().reset_index()
            
            # Merge with envelope
            merged = pd.merge(
                envelope_avg,
                emg_avg,
                on=['trial', 'condition_vision', 'condition_movement', 'window_idx'],
                suffixes=('_env', '_emg_avg')
            )
            
            if len(merged) > 0:
                x = merged['max_mutual_info_env'].values
                y = merged['max_mutual_info_emg_avg'].values
                
                mask = ~(np.isnan(x) | np.isnan(y))
                x_clean = x[mask]
                y_clean = y[mask]
                
                if len(x_clean) > 2:
                    conditions_for_plot = merged.loc[mask, 'condition_vision'].astype(str) + ' x ' + \
                                          merged.loc[mask, 'condition_movement'].astype(str)
                    unique_conds = sorted(conditions_for_plot.unique())
                    colors = plt.cm.tab10(np.linspace(0, 1, len(unique_conds)))
                    
                    for i, cond in enumerate(unique_conds):
                        cond_mask = conditions_for_plot == cond
                        ax.scatter(x_clean[cond_mask], y_clean[cond_mask], alpha=0.6, s=40,
                                   label=cond, color=colors[i])
                    
                    slope, intercept, r_value, p_value, std_err = stats.linregress(x_clean, y_clean)
                    line_x = np.array([x_clean.min(), x_clean.max()])
                    line_y = slope * line_x + intercept
                    ax.plot(line_x, line_y, 'k--', linewidth=2)
                    
                    ax.text(0.05, 0.95, f'r = {r_value:.3f}\np = {p_value:.3f}',
                            transform=ax.transAxes, verticalalignment='top',
                            bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))
            
                ax.set_xlabel('max_mutual_info (Audio Envelope P1-P2)')
                ax.set_ylabel('max_mutual_info (EMG P1-P2 Average)')
                ax.set_title('EMG P1-P2 Average')
                ax.grid(True, alpha=0.3)
            else:
                ax.text(0.5, 0.5, 'No matched EMG data with envelope', 
                        ha='center', va='center', transform=ax.transAxes, fontsize=12)
        else:
            ax.text(0.5, 0.5, 'No EMG P1-P2 data in results_df', 
                    ha='center', va='center', transform=ax.transAxes, fontsize=12)
        plot_idx += 1
    
    # Remove any unused subplots
    for i in range(plot_idx, len(axes)):
        if axes[i].has_children(): # Check if it has any plotted elements before trying to delete
            fig.delaxes(axes[i])
        else:
            axes[i].set_visible(False) # Make it invisible if it's empty

    plt.tight_layout()
    plt.savefig(os.path.join(output_folder, 'envelope_vs_modality_correlations_MI.png'), 
                dpi=300, bbox_inches='tight')
    plt.close()
    
    print("\n✓ Plots saved successfully!")


In [None]:
merged_folder = '../4a_PROCESSED/merged_filteredtimeseries/'  # Update this path
output_file = './p1_p2_coupling_statistics.csv'

# Calculate coupling statistics
results_df = calculate_p1_p2_coupling_stats(merged_folder, output_file)




Found 40 files to process

Processing: NoVisionMovement_Trial0.csv
  Analyzing 8 windows of 5.0s each
    Processing Amplitude_Envelope_P1 vs Amplitude_Envelope_P2
    Processing Filtered_Respiration_P1 vs Filtered_Respiration_P2
    Processing Filtered_EMG_Bicep_P1 vs Filtered_EMG_Bicep_P2
    Processing Filtered_EMG_Tricep_P1 vs Filtered_EMG_Tricep_P2
    Processing Amplitude_Envelope_P1 vs Filtered_Respiration_P2
    Processing Amplitude_Envelope_P1 vs Filtered_EMG_Bicep_P2
    Processing Amplitude_Envelope_P1 vs Filtered_EMG_Tricep_P2
    Processing Filtered_Respiration_P1 vs Amplitude_Envelope_P2
    Processing Filtered_Respiration_P1 vs Filtered_Respiration_P2
    Processing Filtered_Respiration_P1 vs Filtered_EMG_Bicep_P2
    Processing Filtered_Respiration_P1 vs Filtered_EMG_Tricep_P2
    Processing Filtered_EMG_Bicep_P1 vs Amplitude_Envelope_P2
    Processing Filtered_EMG_Bicep_P1 vs Filtered_Respiration_P2
    Processing Filtered_EMG_Bicep_P1 vs Filtered_EMG_Bicep_P2
    Proc

In [15]:
import matplotlib.pyplot as plt

# Create summary plots if results were generated
output_file = './p1_p2_coupling_statistics.csv'
results_df = pd.read_csv(output_file)
if not results_df.empty:
    create_plots(results_df, output_folder='coupling_plots/')
    
    # Display some example results
    print("\nExample results:")
    print(results_df.head())
    
    print("\nVariable pair types found:")
    print(results_df['variable_pair_type'].value_counts())


Analyzing available data...
P1 variables found in data: ['Amplitude_Envelope_P1', 'Filtered_EMG_Bicep_P1', 'Filtered_EMG_Tricep_P1', 'Filtered_Respiration_P1']
P2 variables found in data: ['Amplitude_Envelope_P2', 'Filtered_EMG_Bicep_P2', 'Filtered_EMG_Tricep_P2', 'Filtered_Respiration_P2']

Creating cross-modality matrix...
Modalities found: ['Audio', 'EMG Bicep', 'EMG Tricep', 'Respiration']

Creating envelope correlations...
Envelope data shape for correlation: (320, 5)

Found 1 non-EMG P1-P2 pairs for correlation.
Found 2 EMG P1-P2 pairs for combined correlation.

✓ Plots saved successfully!

Example results:
                      filename condition_vision condition_movement trial  \
0  NoVisionMovement_Trial0.csv         NoVision           Movement     0   
1  NoVisionMovement_Trial0.csv         NoVision           Movement     0   
2  NoVisionMovement_Trial0.csv         NoVision           Movement     0   
3  NoVisionMovement_Trial0.csv         NoVision           Movement     0  