In [None]:
import os
import glob
import sys
import numpy as np
import pandas as pd
import pickle
from pathlib import Path
import seaborn as sns
import matplotlib.pyplot as plt

import numpy as np
import matplotlib.pyplot as plt
from scipy.signal import butter, filtfilt
import pandas as pd

%matplotlib widget

sys.path.append('../datasets')
from load_intan_rhs_format.load_intan_rhs_format import read_data


In [None]:
# Root folder that holds every animal's recordings (relative to this notebook).
experiments_path = "../PEDOT-DES"

# Folder containing the Intan .rhs recording folders of the animal analysed here.
base_path = "/".join([experiments_path, "rat2-RightNerve", "rhs_recordings"])

### Load all recording folders at the beginning

In [None]:
def load_all_ports_rhs(path,fileType='rhs', downsample=1,verbose=0):
    """
    Load every ``*.{fileType}`` file in a folder and concatenate them sample-wise.

    Files are read in sorted filename order with ``read_data``. The sampling
    rate and the channel metadata are taken from the first file and assumed to
    be identical for the remaining files.

    Args:
        path (str): Folder containing the recording files (sub-folders are not searched).
        fileType (str): File extension to load, without the leading dot.
        downsample (int): Keep every ``downsample``-th sample of each file.
            This is plain decimation; no anti-alias filter is applied here.
        verbose (int): Truthy to list the files found; also forwarded to ``read_data``.

    Returns:
        df_by_port (dict): Port-wise DataFrames with channels as columns plus a 'time' column.
        fs (float): Amplifier sampling rate as stored in the file header.
            NOTE: it is NOT divided by ``downsample``; callers using
            ``downsample > 1`` must rescale it themselves.
        port_info (list): List of dicts with channel metadata.
        full_df (DataFrame): Complete DataFrame with all channels and 'time'.

    Raises:
        ValueError: If the folder contains no matching file.
    """

    # Gather all files (sorted so that concatenation follows filename order)
    files = sorted(glob.glob(os.path.join(path, f'*.{fileType}')))
    if verbose:
        print(f"Found {len(files)} {fileType} files in: {path}")
        print(files)
    if not files:
        # Fail early with a readable message instead of an np.vstack error later on.
        raise ValueError(f"No .{fileType} files found in: {path}")

    amp_chunks = []
    time_chunks = []
    fs = None
    all_port_info = None

    # Loop through each file
    for count, file in enumerate(files):
        print(f"Loading file {count+1}/{len(files)}: {file}")
        data = read_data(file, verbose=verbose)

        # Sampling frequency (assume constant across files)
        if fs is None:
            fs = data['frequency_parameters']['amplifier_sample_rate']

        # Channel metadata (assume constant across files)
        if all_port_info is None:
            all_port_info = data['amplifier_channels']

        # Amplifier data, transposed to [samples, channels] and decimated
        amp_chunks.append(data['amplifier_data'].T[::downsample])

        # Time vector, decimated the same way so rows stay aligned
        t = data['t'] if 't' in data else data['t_amplifier']
        time_chunks.append(t[::downsample])

    # Concatenate across files
    amp_data = np.vstack(amp_chunks)  # shape: [total_samples, channels]
    time_data = np.concatenate(time_chunks)

    # Create column labels from amplifier_channels
    channel_names = [f"{ch['port_name']}_{ch['native_channel_name']}" for ch in all_port_info]
    port_names = [ch['port_name'] for ch in all_port_info]

    df_all = pd.DataFrame(amp_data, columns=channel_names)
    df_all['time'] = time_data

    # Split into port-wise DataFrames; dict.fromkeys keeps first-seen port order
    # so the resulting dict is deterministic between runs (a set is not).
    df_by_port = {}
    for port in dict.fromkeys(port_names):
        cols = [name for name, owner in zip(channel_names, port_names) if owner == port]
        df_by_port[port] = df_all[cols + ['time']].copy()

    return df_by_port, fs, all_port_info, df_all

def is_valid_recording_folder(foldername):
    """Return False when the folder name contains 'calibration' or 'test' (case-insensitive)."""
    lowered = foldername.lower()
    for keyword in ('calibration', 'test'):
        if keyword in lowered:
            return False
    return True

def parse_condition_name(foldername):
    """
    Return the condition name used to index a recording folder.

    The folder name is currently used unchanged. The previous body searched for
    a 'pulses' token, but both of its branches returned ``foldername`` as-is,
    so that search was dead code and has been removed. The disabled truncation
    was ``'_'.join(parts[:idx+3])`` with ``parts = foldername.split('_')`` and
    ``idx`` the position of the first part containing 'pulses'.
    """
    return foldername

def save_impedance_info(port_info, save_path, rat_id):
    """
    Write the impedance magnitude and phase of every amplifier channel to CSV.

    The file is written to ``f'{save_path}_impedance_summary.csv'``, i.e.
    ``save_path`` is a path prefix, not a directory. ``rat_id`` is accepted
    for call-site symmetry but is not used.
    """
    fields = (
        'port_name',
        'native_channel_name',
        'electrode_impedance_magnitude',
        'electrode_impedance_phase',
    )
    # One row per channel, restricted to the impedance-related fields.
    rows = [{field: ch[field] for field in fields} for ch in port_info]
    summary = pd.DataFrame(rows)

    csv_path = f'{save_path}_impedance_summary.csv'
    print(csv_path)
    summary.to_csv(csv_path, index=False)
    print(f"[✓] Saved impedance summary to {csv_path}")
    
def _load_or_reuse_recording(folder_path, pickle_path, impedance_prefix, condition_name,
                             rat_id, downsample, force_reload, verbose, kind, verbose_label):
    """Return (data_dict, fs, port_info, full_df) for one folder, reusing its pickle cache when allowed."""
    if os.path.exists(pickle_path) and not force_reload:
        print(f"[✓] Loading saved {kind} data for {rat_id} from {pickle_path}")
        # NOTE(security): pickle.load can execute arbitrary code; only open
        # cache files that were written by this notebook.
        with open(pickle_path, 'rb') as f:
            data_dict, fs, port_info, full_df = pickle.load(f)
        return data_dict, fs, port_info, full_df

    # No usable cache: read the raw files
    df_by_port, fs, port_info, full_df = load_all_ports_rhs(path=folder_path, downsample=downsample)

    # Save impedance info next to the pickle
    save_impedance_info(port_info, impedance_prefix, rat_id)

    # Ports missing from the recording are stored as None
    ports = ('Port A', 'Port B', 'Port C')
    data_dict = {port: df_by_port.get(port) for port in ports}

    if verbose > 0:
        for port in ports:
            print(f"Data for {port} in {verbose_label} {condition_name}:")
            print(df_by_port.get(port, f"No data for {port}"))

    with open(pickle_path, 'wb') as f:
        pickle.dump((data_dict, fs, port_info, full_df), f)
    print(f"[✓] Saved {kind} data for {rat_id} to {pickle_path}")

    return data_dict, fs, port_info, full_df


def load_and_save_data(base_path, rat_id='rat', downsample=1, force_reload=False, verbose=0):
    """
    Load RHS recordings from all folders in the base path.
    Saves each cuff recording in a separate pickle and groups the stimulation recordings in one dict.

    Folders are processed in sorted name order. A folder whose name contains
    'stimulation' (case-insensitive) is treated as the stimulation folder and
    each of its sub-folders is loaded as one condition; every other valid
    folder without 'saved', 'calibration' or 'stim' in its name is a cuff
    recording. Pickles go to ``saved_pkls`` / ``saved_stim_pkls`` under
    ``base_path`` and are reused unless ``force_reload`` is True.

    Args:
        base_path (str): Folder containing one sub-folder per recording.
        rat_id (str): Label used in log messages only.
        downsample (int): Forwarded to ``load_all_ports_rhs``.
        force_reload (bool): Ignore existing pickles and re-read the raw files.
        verbose (int): > 0 prints the per-port DataFrames after a fresh load.

    Returns:
        cuff_data (dict): condition name -> {'Port A'|'Port B'|'Port C': DataFrame or None}.
        stim_data (dict): stimulation folder name -> {condition name -> port dict}.
        fs, port_info, full_df: values of the recording processed LAST
            (None when no recording was processed at all).
    """
    os.makedirs(f'{base_path}/saved_pkls', exist_ok=True)
    os.makedirs(f'{base_path}/saved_stim_pkls', exist_ok=True)

    # Sorted so that processing order (and therefore the trailing return values) is deterministic
    folders = sorted(
        f for f in os.listdir(base_path)
        if os.path.isdir(os.path.join(base_path, f)) and is_valid_recording_folder(f)
    )

    cuff_data = {}  # Dictionary for cuff recordings
    stim_data = {}  # Dictionary for grouped stimulation data
    # Defined up front so the final return cannot hit an unbound local when nothing is loaded
    fs = port_info = full_df = None

    # Separate the folders based on cuff and stimulation data
    cuff_folders = [f for f in folders if not any(kw in f.lower() for kw in ["saved", "calibration", "stim"])]
    print(cuff_folders)
    # Lower-cased like every other name filter, so 'Stimulation_x' is not silently ignored
    stim_folder = [f for f in folders if "stimulation" in f.lower()]
    print(stim_folder)

    # Process cuff recordings
    for folder in cuff_folders:
        print(f"Loading cuff data from folder: {folder}")
        condition_name = parse_condition_name(folder)
        data_dict, fs, port_info, full_df = _load_or_reuse_recording(
            folder_path=os.path.join(base_path, folder),
            pickle_path=f'{base_path}/saved_pkls/{condition_name}.pkl',
            impedance_prefix=f'{base_path}/saved_pkls/{condition_name}',
            condition_name=condition_name,
            rat_id=rat_id,
            downsample=downsample,
            force_reload=force_reload,
            verbose=verbose,
            kind='cuff',
            verbose_label='condition',
        )
        cuff_data[condition_name] = data_dict

    # Process stimulation folder
    if stim_folder:
        if len(stim_folder) > 1:
            print(f"[!] {len(stim_folder)} stimulation folders found; only '{stim_folder[0]}' is processed")
        stim_folder_path = os.path.join(base_path, stim_folder[0])
        print(f"Loading stimulation data from folder: {stim_folder_path}")

        # One entry per stimulation sub-folder
        grouped_stim_data = {}
        stim_subfolders = sorted(
            f for f in os.listdir(stim_folder_path)
            if os.path.isdir(os.path.join(stim_folder_path, f))
        )

        for subfolder in stim_subfolders:
            condition_name = parse_condition_name(subfolder)
            stim_data_dict, fs, port_info, full_df = _load_or_reuse_recording(
                folder_path=os.path.join(stim_folder_path, subfolder),
                pickle_path=f'{base_path}/saved_stim_pkls/{condition_name}.pkl',
                impedance_prefix=f'{base_path}/saved_stim_pkls/{condition_name}',
                condition_name=condition_name,
                rat_id=rat_id,
                downsample=downsample,
                force_reload=force_reload,
                verbose=verbose,
                kind='stimulation',
                verbose_label='stimulation condition',
            )
            grouped_stim_data[condition_name] = stim_data_dict

        # Store grouped stimulation data
        stim_data[stim_folder[0]] = grouped_stim_data

    return cuff_data, stim_data, fs, port_info, full_df

In [None]:
#---------------------- Change rat_ID
# Load every recording folder under base_path, reusing pickles that already
# exist (force_reload is left at its default, False).
# fs, port_info and full_df come from whichever recording was processed last.
raw_data, stim_data, fs, port_info, full_df  = load_and_save_data(base_path, rat_id='Rat_02-RN', downsample=1, verbose=0)

#### Plot and print electrode impedance (Z)

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

def plot_impedance_boxplot(port_info):
    """
    Draw a boxplot of electrode impedance magnitude per port.

    Args:
        port_info (list): Channel dicts providing 'port_name',
            'native_channel_name', 'electrode_impedance_magnitude' and
            'electrode_impedance_phase'.

    Returns:
        DataFrame: One row per channel with a strictly positive magnitude
        (columns: port, channel, impedance_mag, impedance_phase).
    """
    # Collect the impedance fields of every channel
    records = []
    for ch in port_info:
        records.append({
            'port': ch['port_name'],
            'channel': ch['native_channel_name'],
            'impedance_mag': ch['electrode_impedance_magnitude'],
            'impedance_phase': ch['electrode_impedance_phase']
        })
    imp_df = pd.DataFrame(records)

    # Keep only channels with a positive magnitude
    has_positive_magnitude = imp_df['impedance_mag'] > 0
    imp_df = imp_df[has_positive_magnitude]

    # Plot on a fresh figure; seaborn returns the axes it drew on
    plt.figure(figsize=(8, 5))
    ax = sns.boxplot(x='port', y='impedance_mag', data=imp_df, palette='Set2')
    ax.set_title("Electrode Impedance Magnitudes by Port")
    ax.set_ylabel("Impedance Magnitude (Ohms)")
    ax.set_xlabel("Port")
    plt.tight_layout()
    ax.grid(True, axis='y', linestyle='--', alpha=0.5)
    plt.show()

    return imp_df  # Returned so the caller can inspect the plotted values

# port_info is the channel metadata returned by load_and_save_data above,
# i.e. the metadata of the recording that was processed last.
imp_df = plot_impedance_boxplot(port_info)
# Summary statistics of the numeric columns (magnitude and phase)
print(imp_df.describe())


## Recording analysis

Loads the relevant pkl for the selected condition from the saved_pkls folder, then saves the outputs in a shared per-condition folder inside the PEDOT-DES folder, so that results can later be extracted quickly per animal.

### Load pre-saved data - non-stimulation

In [None]:
# Select rat
rat_id = 'Rat_02-RN'

base_path = "%s/%s/rhs_recordings"%(experiments_path, 'rat2-RightNerve' )

# Select file and condition - recording name per rat
condition_name = 'baseline_250721_164614' # 'Rat_02-RN'

condition = 'evoked' #

#------------------------------------------------------------------------------------
save_path = f'{experiments_path}/{condition}' # Save all animals into same folder
os.makedirs(save_path, exist_ok=True)

# Load the pickle of the selected recording.
# The file itself is checked (not just its folder) and a missing file stops the
# cell, so later cells never run on a stale `raw_data` left over from an earlier cell.
pkl_path = f'{base_path}/saved_pkls/{condition_name}.pkl'
if not os.path.exists(pkl_path):
    raise FileNotFoundError(
        f"No saved pickle for condition '{condition_name}' at {pkl_path}; "
        "run load_and_save_data first."
    )

print(f"[✓] Loading saved cuff data for {rat_id} from {pkl_path}")
# NOTE(security): pickle.load can execute arbitrary code; only open pickles written by this notebook.
with open(pkl_path, 'rb') as f:
    raw_data, fs, port_info, full_df = pickle.load(f)

print(raw_data)

#### Plot pre-saved Impedances

In [None]:
# Load the impedance summary CSV written next to the pickle of this condition
csv_path = f'{base_path}/saved_pkls/{condition_name}_impedance_summary.csv'  # Update as needed
imp_df = pd.read_csv(csv_path)
print(imp_df)

# Remove Port C from plot (imp_df no longer contains Port C rows after this line)
imp_df = imp_df[imp_df['port_name'] != 'Port C']

# Count Port A and Port B channels from the 'port_name' column
num_A_channels = imp_df['port_name'].str.contains('Port A').sum()
num_B_channels = imp_df['port_name'].str.contains('Port B').sum()

# Create a boxplot of impedance magnitudes grouped by port
plt.figure(figsize=(8, 6))
sns.boxplot(x='port_name', y='electrode_impedance_magnitude', data=imp_df, palette='Set2')

# Annotate the number of channels of each box, placed near the smallest plotted magnitude.
# NOTE(review): x=0 / x=1 assume Port A is drawn first and Port B second — confirm.
plt.text(x=0, y=imp_df['electrode_impedance_magnitude'].min() + 0.5, s=f"n={num_A_channels}", 
         ha='center', va='bottom', fontsize=20)
plt.text(x=1, y=imp_df['electrode_impedance_magnitude'].min() -0.5, s=f"n={num_B_channels}", 
         ha='center', va='bottom', fontsize=20)

# Plot aesthetics
plt.title('Electrode Impedance Magnitude by Port')
plt.xlabel('Port')
plt.ylabel('Impedance Magnitude (Ohms)')
plt.grid(True, linestyle='--', alpha=0.5)
plt.tight_layout()

# Save before showing so the file is written from the fully drawn figure
plt.savefig(f'{save_path}/{rat_id}_impedances.svg', dpi=300)
# Show plot
plt.show()


#### Remove high-impedance electrodes

In [None]:
# Remove high-impedance electrodes
HIGH_IMPEDANCE_OHMS = 100000  # channels with a magnitude above this value are excluded

# Step 1: Build the column names of the high-impedance channels listed in imp_df.
# Column names follow the '<port_name>_<native_channel_name>' pattern used when loading.
high_imp_rows = imp_df[imp_df['electrode_impedance_magnitude'] > HIGH_IMPEDANCE_OHMS]
columns_to_drop = (high_imp_rows['port_name'] + '_' + high_imp_rows['native_channel_name']).tolist()

# Step 2: Drop those columns from each DataFrame in raw_data.
# Ports absent from the recording are stored as None and are skipped;
# errors='ignore' lets each port drop only the columns it actually has.
for key, port_df in raw_data.items():
    if port_df is None:
        continue
    raw_data[key] = port_df.drop(columns=columns_to_drop, errors='ignore')

# Optional: print confirmation
print(f"Dropped {len(columns_to_drop)} channels from each DataFrame in raw_data.")
#print(raw_data)

#### Plot raw data

In [None]:

def plot_data_superimposed(df_port_a, df_port_b, title):
    """
    Plot every electrode of Port A (top) and Port B (bottom) on shared axes.

    Each DataFrame may carry a 'time' column, which is used as the x axis;
    otherwise the row number is used. Nothing is drawn (a message is printed
    and None returned) when either DataFrame is None or empty.
    """
    # Bail out before creating a figure when there is nothing to draw
    for frame in (df_port_a, df_port_b):
        if frame is None or frame.empty:
            print("No data found for the specified ports.")
            return

    # Two stacked panels sharing both axes
    fig, (top_ax, bottom_ax) = plt.subplots(2, 1, figsize=(10, 8), sharex=True, sharey=True)

    def _draw_port(ax, frame, port_label):
        """Draw one line per electrode of `frame` on `ax` and label the panel."""
        if 'time' in frame.columns:
            x_values = frame['time']
            traces = frame.drop(columns=['time'])  # 'time' is the x axis, not a trace
        else:
            x_values = range(len(frame))  # fall back to the sample index
            traces = frame
        for electrode in traces.columns:
            ax.plot(x_values, traces[electrode], label=electrode)
        ax.set_title((port_label + " - %s") % title)
        ax.set_ylabel('Signal Amplitude')

    # Port A panel
    _draw_port(top_ax, df_port_a, "Port A")
    top_ax.set_ylim([-100, 100])
    top_ax.legend(loc='best', title='Electrodes')
    top_ax.grid(True)

    # Port B panel
    _draw_port(bottom_ax, df_port_b, "Port B")
    bottom_ax.set_xlabel('Time (s)')
    bottom_ax.legend(loc='best', title='Electrodes')
    bottom_ax.grid(True)

    # Display the plot
    plt.tight_layout()
    plt.show()

In [None]:
# Plot the first 30000 rows of the raw Port A / Port B data.
# NOTE(review): 30000 rows equal 1 s only if the sampling rate is 30 kHz — confirm against `fs`.
plot_data_superimposed(raw_data['Port A'].iloc[0:1*30000], raw_data['Port B'].iloc[0:1*30000], title='Raw data')

# Save the figure.
# NOTE(review): plot_data_superimposed has already called plt.show(); whether the
# current figure is still available here depends on the Matplotlib backend —
# check that the SVG is not blank.
plt.savefig(f'{save_path}/{rat_id}_1sec_raw.svg', dpi=300)

#### Filter and plot data

In [None]:
def bandpass_filter(signal, fs, lowcut=200.0, highcut=3000.0, order=4):
    """
    Zero-phase Butterworth band-pass of `signal` between `lowcut` and `highcut` (Hz).

    The cut-offs are normalised by the Nyquist frequency (fs / 2) and the
    filter is applied forwards and backwards with `filtfilt`.
    """
    half_rate = 0.5 * fs
    normalized_band = [lowcut / half_rate, highcut / half_rate]

    numerator, denominator = butter(order, normalized_band, btype='band')
    return filtfilt(numerator, denominator, signal)

def apply_bandpass_to_rat_signal(raw_data, fs, lowcut=200.0, highcut=3000.0, order=4):
    """
    Band-pass every electrode column of 'Port A' and 'Port B'.

    Returns a new dict with one DataFrame per port: 'time' as first column
    followed by the filtered electrode columns. The input frames are not modified.
    """
    filtered_signal = {}

    for port in ('Port A', 'Port B'):
        port_df = raw_data[port].copy()

        # Separate the time axis from the electrode columns
        if 'time' in port_df.columns:
            time_axis = port_df['time']
            electrodes = port_df.drop(columns=['time'])
        else:
            time_axis = np.arange(len(port_df)) / fs  # synthesize seconds from the row number
            electrodes = port_df

        # Filter each electrode, then assemble the frame in one go
        filtered_columns = {
            name: bandpass_filter(electrodes[name].values, fs, lowcut, highcut, order)
            for name in electrodes.columns
        }
        port_filtered = pd.DataFrame(filtered_columns, index=electrodes.index)

        port_filtered.insert(0, 'time', time_axis)
        filtered_signal[port] = port_filtered

    return filtered_signal

#------------------------------------------------------
# Set your sampling frequency.
# NOTE(review): this overwrites the `fs` that was loaded from the pickle earlier
# in the notebook — confirm that 30000 matches the recording.
fs = 30000  # e.g. 30 kHz

# Apply bandpass filter to all electrodes of Port A and Port B (default band: 200-3000 Hz)
filtered_signal = apply_bandpass_to_rat_signal(raw_data, fs)


In [None]:
# Select one electrode column from each port: column 0 is 'time', so index 1 is
# the first electrode that survived the high-impedance filtering
ch_a = filtered_signal['Port A'].columns[1]  # Skip 'time'
ch_b = filtered_signal['Port B'].columns[1]

# Extract data (whole recording; the commented slice would restrict it to 0-4 min)
df_a = filtered_signal['Port A'][['time', ch_a]]#.iloc[0*60*fs:4*60*fs]
df_b = filtered_signal['Port B'][['time', ch_b]]#.iloc[0*60*fs:4*60*fs]

# Plot
plot_data_superimposed(df_a, df_b, title='Filtered data: Single channel')

# NOTE(review): saved after plt.show() ran inside plot_data_superimposed;
# depending on the Matplotlib backend the file may be blank — confirm.
plt.savefig(f'{save_path}/{rat_id}_singleCh_filtered.svg', dpi=300)

#### Spike detection and spike metrics

##### Functions

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

def detect_negative_spikes(signal, fs, threshold_factor=5, window_ms=2, threshold=-12):
    """
    Detect negative threshold crossings and extract the waveform around each one.

    Every sample below the threshold is reported, so an event that stays below
    threshold for several consecutive samples yields several overlapping
    waveforms (no refractory period or peak alignment is applied).

    Parameters:
    - signal (1D array-like): signal of one channel.
    - fs (float): sampling frequency in Hz.
    - threshold_factor (float): only used when ``threshold`` is None; the
      threshold is then ``-threshold_factor * std(signal)``.
    - window_ms (float): half-width of the extracted waveform in milliseconds.
    - threshold (float or None): absolute detection threshold in signal units.
      The default (-12) is the fixed value this function has always used.

    Returns:
    - spike_indices (np.ndarray): indices into ``signal`` of the detected samples.
    - waveforms (np.ndarray): shape (n_spikes, 2 * window_samples); row i is
      ``signal[idx - window_samples : idx + window_samples]``. With no spikes
      the array is empty (size 0).
    """
    signal = np.asarray(signal)
    window_size = int(window_ms * fs / 1000)  # ±window_size around spike

    if threshold is None:
        threshold = -threshold_factor * np.std(signal)

    below = np.flatnonzero(signal < threshold)

    # Keep only crossings whose full window lies inside the signal
    in_bounds = (below - window_size >= 0) & (below + window_size < len(signal))
    valid_spikes = below[in_bounds]

    # Gather all windows at once: one row of sample indices per spike
    offsets = np.arange(-window_size, window_size)
    waveforms = signal[valid_spikes[:, None] + offsets]

    return valid_spikes, waveforms
    
def extract_random_baseline_vpp(signal, num_samples, fs, window_ms=4, rng=None):
    """
    Extract peak-to-peak (VPP) values from randomly placed baseline windows.

    Parameters:
    - signal (1D np.array): baseline signal for one channel.
    - num_samples (int): number of windows to extract (e.g., same as spike count).
    - fs (int): sampling frequency in Hz.
    - window_ms (float): window size in milliseconds (default: 4 ms).
    - rng (np.random.Generator, optional): random generator used to place the
      windows; pass a seeded generator for reproducible results. When None the
      global ``np.random`` state is used.

    Returns:
    - vpp_list (list of floats): VPP values from baseline windows; empty when
      the signal is shorter than one window.
    """
    window_samples = int(window_ms * fs / 1000)
    max_start = len(signal) - window_samples  # last start index that still fits a full window

    if max_start < 0:
        return []

    draw_start = np.random.randint if rng is None else rng.integers

    vpp_list = []
    for _ in range(num_samples):
        # The upper bound is exclusive, hence +1 so the last valid window can be drawn too
        start = draw_start(0, max_start + 1)
        window = signal[start:start + window_samples]
        vpp_list.append(np.max(window) - np.min(window))

    return vpp_list
    
def analyze_port(df, fs, baseline_sec=50, threshold_factor=5):
    """
    Analyze all electrodes in a given port and return waveform stats and metrics.

    The first ``baseline_sec`` seconds of each channel are the baseline. Spikes
    are detected in the segment that starts 5 s after the baseline ends and
    runs to the second-to-last sample. Channels without detected spikes are
    reported and left out of the result.

    Args:
        df (DataFrame): One column per electrode plus a 'time' column.
        fs (float): Sampling frequency in Hz.
        baseline_sec (float): Baseline duration in seconds.
        threshold_factor (float): Forwarded to ``detect_negative_spikes``.

    Returns:
        dict: channel name -> metrics dict. ``spike_rate_Hz`` is the spike
        count divided by the duration of the analysed (post-baseline) segment;
        ``power_signal`` / ``rms_signal`` are computed over the whole channel.
    """
    results = {}
    channels = df.drop(columns=['time'])

    baseline_end = int(fs * baseline_sec)
    analysis_start = int((baseline_sec + 5) * fs)  # leave a 5 s gap after the baseline

    for ch in channels.columns:
        signal = channels[ch].values
        analysed_signal = signal[analysis_start:-1]
        spike_indices, waveforms = detect_negative_spikes(analysed_signal, fs, threshold_factor)

        if waveforms.size == 0:
            print(f"No spikes detected on {ch}")
            continue

        mean_waveform = np.mean(waveforms, axis=0)
        std_waveform = np.std(waveforms, axis=0)

        # Spikes are only searched in the analysed segment, so the rate must be
        # normalised by that segment's duration, not by the whole recording.
        analysed_sec = len(analysed_signal) / fs

        baseline_signal = signal[:baseline_end]  # baseline

        # As many random baseline windows as there are spikes
        n_spikes = len(spike_indices)
        baseline_vpps = extract_random_baseline_vpp(baseline_signal, n_spikes, fs)
        metrics = {
            'spike_count': n_spikes,
            'spike_rate_Hz': n_spikes / analysed_sec,
            'mean_spike_amplitude': np.mean(np.min(waveforms, axis=1)),
            'mean_spike_peak_to_peak': np.mean(np.ptp(waveforms, axis=1)),
            # NaN (without a warning) when the baseline is too short for a single window
            'baseline_peak_to_peak': np.mean(baseline_vpps) if len(baseline_vpps) else np.nan,
            'power_signal': np.sum(np.square(signal)),
            'power_baseline': np.sum(np.square(baseline_signal)),
            'rms_signal': np.sqrt(np.mean(signal ** 2)),
            'rms_baseline': np.sqrt(np.mean(baseline_signal ** 2)),
            'waveforms': waveforms,
            'mean_waveform': mean_waveform,
            'std_waveform': std_waveform,
        }

        results[ch] = metrics

    return results

def plot_average_waveforms(metrics_dict, fs, window_ms=2, port_name='Port A'):
    """
    Plot the mean spike waveform of each electrode with a ±1 std band.

    `window_ms` must be the half-window used when the waveforms were extracted,
    because the time axis is rebuilt from it (2 * int(window_ms * fs / 1000) points
    spanning -window_ms .. +window_ms).
    """
    half_window_samples = int(window_ms * fs / 1000)
    time_vector = np.linspace(-window_ms, window_ms, 2 * half_window_samples)

    plt.figure(figsize=(12, 6))
    for ch in metrics_dict:
        channel_metrics = metrics_dict[ch]
        mean_wf = channel_metrics['mean_waveform']
        std_wf = channel_metrics['std_waveform']
        lower_band = mean_wf - std_wf
        upper_band = mean_wf + std_wf
        plt.plot(time_vector, mean_wf, label=ch)
        plt.fill_between(time_vector, lower_band, upper_band, alpha=0.3)

    plt.title(f'Spike Waveforms in {port_name}')
    plt.xlabel('Time (ms)')
    plt.ylabel('Amplitude')
    plt.grid(True)
    plt.legend()
    plt.tight_layout()
    plt.show()

def print_metrics_summary(metrics_dict, port_name='Port A'):
    """Print one summary line per electrode (count, rate, mean amplitude, power, peak-to-peak)."""
    print(f"Summary for {port_name}")
    for ch, m in metrics_dict.items():
        fields = [
            f"{m['spike_count']} spikes",
            f"{m['spike_rate_Hz']:.1f} Hz",
            f"Mean amp: {m['mean_spike_amplitude']:.2f}",
            f"Power: {m['power_signal']:.2f}",
            f"PtP: {m['mean_spike_peak_to_peak']:.2f}",
        ]
        print(f"  {ch}: " + " | ".join(fields))
        
def compute_average_metrics(metrics_dict):
    """
    Average the scalar metrics over all electrodes of a port.

    Returns a dict with an 'avg_<metric>' entry for every metric followed by a
    'std_<metric>' entry for every metric, or an empty dict for empty input.
    """
    if not metrics_dict:
        return {}

    keys = (
        'spike_count', 'spike_rate_Hz', 'mean_spike_amplitude',
        'mean_spike_peak_to_peak', 'baseline_peak_to_peak',
        'power_signal', 'power_baseline', 'rms_signal', 'rms_baseline',
    )

    # Gather each metric across electrodes
    per_metric = {k: [ch_metrics[k] for ch_metrics in metrics_dict.values()] for k in keys}

    summary = {}
    for k, values in per_metric.items():
        summary[f'avg_{k}'] = np.mean(values)
    for k, values in per_metric.items():
        summary[f'std_{k}'] = np.std(values)
    return summary
def save_average_metrics_to_csv(avg_metrics_a, avg_metrics_b, filename='spike_metrics_summary.csv'):
    """
    Write the summary metrics of both ports side by side to a CSV file.

    Rows follow the key order of `avg_metrics_a`; metrics missing from
    `avg_metrics_b` are written as NaN.
    """
    metric_names = list(avg_metrics_a.keys())
    port_a_values = [avg_metrics_a[name] for name in metric_names]
    port_b_values = [avg_metrics_b.get(name, np.nan) for name in metric_names]

    table = pd.DataFrame({
        'Metric': metric_names,
        'Port A': port_a_values,
        'Port B': port_b_values,
    })
    table.to_csv(filename, index=False)
    print(f"Saved summary metrics to {filename}")

from scipy.stats import ttest_ind

def compare_std_metrics(metrics_dict_a, metrics_dict_b, filename="std_comparison_report.txt"):
    """
    Compare the per-channel metrics of two ports and save the results to a text file.

    For every metric, the per-channel values of Port A and Port B are compared
    with Welch's t-test (``ttest_ind(..., equal_var=False)``). The comparison is
    reported as NOT fair as soon as one metric differs with p < 0.05, as
    inconclusive when no metric differs but at least one could not be tested
    (fewer than two channels in a port, or an undefined p-value), and as fair
    otherwise.

    NOTE(review): despite the function name and the report wording
    ("standard deviations"), the test is run on the metric values themselves,
    i.e. it compares their means across channels — confirm this is intended.

    Args:
        metrics_dict_a (dict): Per-channel metrics for Port A.
        metrics_dict_b (dict): Per-channel metrics for Port B.
        filename (str): Path to the output report file.
    """
    keys = ['spike_count', 'spike_rate_Hz', 'mean_spike_amplitude', 'mean_spike_peak_to_peak', 'baseline_peak_to_peak', 'power_signal', 'power_baseline', 'rms_signal','rms_baseline' ]

    report_lines = ["Standard Deviation Comparison Report",
                    "Comparing standard deviations between Port A and Port B:\n"]

    any_significant = False
    any_untested = False

    for k in keys:
        values_a = [metrics[k] for metrics in metrics_dict_a.values()]
        values_b = [metrics[k] for metrics in metrics_dict_b.values()]

        if len(values_a) < 2 or len(values_b) < 2:
            # A t-test needs at least two observations per group
            any_untested = True
            result_line = (f"  {k}: not tested (need at least 2 channels per port, "
                           f"got {len(values_a)} and {len(values_b)})")
        else:
            t_stat, p_val = ttest_ind(values_a, values_b, equal_var=False)  # Welch's t-test
            if np.isnan(p_val):
                # e.g. zero variance in both groups: the statistic is undefined
                any_untested = True
                result_line = f"  {k}: t = {t_stat:.2f}, p = {p_val:.4f} (undefined)"
            else:
                result_line = f"  {k}: t = {t_stat:.2f}, p = {p_val:.4f} {'*' if p_val < 0.05 else ''}"
                if p_val < 0.05:
                    any_significant = True

        print(result_line)
        report_lines.append(result_line)

    # A NaN p-value must not count as "no difference", hence the separate inconclusive outcome
    if any_significant:
        fairness_conclusion = "Comparison is NOT fair between Port A and Port B due to significant differences in standard deviations."
    elif any_untested:
        fairness_conclusion = "Comparison is inconclusive: at least one metric could not be tested."
    else:
        fairness_conclusion = "Comparison is fair between Port A and Port B."

    report_lines.append(f"\nConclusion: {fairness_conclusion}")

    # Save to file
    with open(filename, "w") as f:
        f.write("\n".join(report_lines))

    print(f"\nReport saved to {os.path.abspath(filename)}")




##### Waveforms activity vs baseline

In [None]:
# Assuming filtered_signal['Port A'] / ['Port B'] were created by the filtering cell above.
# Only the first 10 min of each port are analysed; the first 300 s are the baseline.
metrics_a = analyze_port(filtered_signal['Port A'].iloc[int(0*60*fs):int(10*60*fs)], fs, baseline_sec=300) #0:240*fs 
metrics_b = analyze_port(filtered_signal['Port B'].iloc[int(0*60*fs):int(10*60*fs)], fs, baseline_sec=300)
print('--------------------')
# NOTE(review): each savefig below runs after plt.show() was called inside
# plot_average_waveforms; depending on the Matplotlib backend the file may be blank — confirm.
plot_average_waveforms(metrics_a, fs, port_name='Port A')
plt.savefig(f'{save_path}/{rat_id}_waveforms_PortA.svg', dpi=300)
plot_average_waveforms(metrics_b, fs, port_name='Port B')
plt.savefig(f'{save_path}/{rat_id}_waveforms_PortB.svg', dpi=300)
print('--------------------')

# Per-electrode summary lines
print_metrics_summary(metrics_a,port_name='Port A')
print_metrics_summary(metrics_b,port_name='Port B')
print('--------------------')

# Mean and std of every scalar metric across the electrodes of each port
avg_metrics_a = compute_average_metrics(metrics_a)
avg_metrics_b = compute_average_metrics(metrics_b)
print('--------------------')

save_average_metrics_to_csv(avg_metrics_a, avg_metrics_b, filename='%s/%s_spike_metrics_summary.csv'%(save_path, rat_id))

# The report written here is plain text even though the file name ends in .csv
compare_std_metrics(metrics_a, metrics_b, filename='%s/%s_std_comparison.csv'%(save_path, rat_id))


#### RMS metrics and plots

In [None]:
def compute_rms_per_electrode(data, fs, rat_id, start=None, end=None, save_dir='rms_outputs_electrode', condition_name=None):
    """
    Compute the RMS of every electrode of Port A and Port B and save it as CSV.

    Args:
        data (dict): port name -> DataFrame (electrode columns, optional 'time').
            Ports other than 'Port A' / 'Port B' and ports stored as None are skipped.
        fs (float): Sampling frequency in Hz, used to convert seconds to rows.
        rat_id (str): Animal label written to the table and the file name.
        start (float, optional): Start of the analysed window in seconds.
        end (float, optional): End of the analysed window in seconds.
            Either bound may be given on its own.
        save_dir (str): Output folder (created if missing).
        condition_name (str, optional): Value of the 'Condition' column. When
            None, the notebook-level variable ``condition_name`` is used if it
            exists (the value this function previously read implicitly).

    Returns:
        DataFrame: One row per electrode plus one 'Mean_port_A' / 'Mean_port_B'
        row per port, with columns Rat_ID, Condition, Cuff, Electrode, RMS.
    """
    if condition_name is None:
        # Backward compatibility: fall back to the global set earlier in the notebook
        condition_name = globals().get('condition_name')

    os.makedirs(save_dir, exist_ok=True)

    # Port -> (cuff label, label of the per-port mean row); other ports (e.g. Port C) are skipped
    port_labels = {
        'Port A': ('PEDOT_PSS', 'Mean_port_A'),
        'Port B': ('PEDOT_DES', 'Mean_port_B'),
    }

    # Convert the window to row positions; a missing bound leaves that side open
    first_row = int(start * fs) if start is not None else None
    last_row = int(end * fs) if end is not None else None

    all_rms = []
    for port, df_port in data.items():
        if port not in port_labels or df_port is None:
            continue
        cuff, mean_label = port_labels[port]

        df_port = df_port.iloc[first_row:last_row]

        # RMS per electrode
        rms_per_electrode = []
        for electrode in df_port.columns:
            if electrode.lower() == 'time':
                continue
            signal = df_port[electrode].values
            rms = np.sqrt(np.mean(signal ** 2))
            rms_per_electrode.append(rms)

            all_rms.append({
                'Rat_ID': rat_id,
                'Condition': condition_name,
                'Cuff': cuff,
                'Electrode': electrode,
                'RMS': rms
            })

        # Mean RMS for the port (NaN, without a warning, when the port has no electrode left)
        mean_rms_port = np.mean(rms_per_electrode) if rms_per_electrode else np.nan
        all_rms.append({
            'Rat_ID': rat_id,
            'Condition': condition_name,
            'Cuff': cuff,
            'Electrode': mean_label,
            'RMS': mean_rms_port
        })

    df_rms = pd.DataFrame(all_rms)
    out_path = os.path.join(save_dir, f"{rat_id}-{start}s_to_{end}s-RMS_per_electrode.csv")
    df_rms.to_csv(out_path, index=False)
    print(f"[✓] Saved RMS data to {out_path}")
    return df_rms


In [None]:
# Analyze and save the spontaneous RMS per cuff.
# To compute an SNR later, the signal and the baseline must be measured over
# windows of the same duration; otherwise the ratio does not reflect the true
# signal-to-noise relationship.

start = 5*60 # 300 #60
end = start+60 #310 #70
rms_df = compute_rms_per_electrode(filtered_signal, fs, rat_id, save_dir=save_path, start=start, end=end) # start and end in seconds

In [None]:
# Display the RMS table (one row per electrode plus one mean row per port)
rms_df

In [None]:
##### Plot violin per rat

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt

def plot_rms_distribution(df_rms, rat_id):
    """
    Violin plot of the RMS per cuff with the individual rows overlaid.

    Args:
        df_rms (DataFrame): Table with 'Cuff', 'Electrode' and 'RMS' columns,
            as returned by compute_rms_per_electrode.
        rat_id (str): Animal label used in the title.
    """
    # Keep only the per-electrode rows (drop the 'Mean_port_*' rows) for counting.
    # Uses the df_rms argument; the previous body read the notebook global `rms_df` here.
    non_mean_df = df_rms[~df_rms['Electrode'].str.contains('Mean')]

    # Count A and B channels based on 'Port A_' and 'Port B_' in the 'Electrode' column
    num_A_channels = non_mean_df['Electrode'].str.contains('Port A_').sum()
    num_B_channels = non_mean_df['Electrode'].str.contains('Port B_').sum()

    plt.figure(figsize=(8, 5))
    # Main violin plot (no hue).
    # NOTE(review): df_rms still contains the 'Mean_port_*' rows, so they are
    # drawn as data points too — confirm this is intended.
    sns.violinplot(data=df_rms, x='Cuff', y='RMS', inner='box', color='lightgray')

    # Overlay individual RMS values (colored per electrode)
    sns.stripplot(data=df_rms, x='Cuff', y='RMS',
                  hue='Electrode', palette='tab10',
                  dodge=True, jitter=True, alpha=0.8, size=6)

    # Annotate the number of channels of each cuff at the height of the smallest RMS
    plt.text(x=0, y=df_rms['RMS'].min(), s=f"n={num_A_channels}",
             ha='center', va='bottom', fontsize=20)
    plt.text(x=1, y=df_rms['RMS'].min(), s=f"n={num_B_channels}",
             ha='center', va='bottom', fontsize=20)

    plt.title(f"RMS per Electrode by Cuff ({rat_id})")
    plt.ylabel("RMS")
    plt.xlabel("Cuff Type")
    plt.legend(title='Electrode', bbox_to_anchor=(1.05, 1), loc='upper left')
    plt.tight_layout()
    plt.show()
# Plot the RMS table computed above for this animal
plot_rms_distribution(rms_df, rat_id)
# NOTE(review): saved after plt.show() ran inside plot_rms_distribution;
# depending on the Matplotlib backend the file may be blank — confirm.
plt.savefig(f'{save_path}/{rat_id}_RMS-{start}s_to_{end}s.svg', dpi=300)


### Load and analyse pre-saved stimulation pkl

Loads the pre-saved pkl from the saved_stim_pkls folder, then computes the RMS of the baseline and of the evoked activity (from 0.015 s to 0.2 s after the stimulation artifact, based on preliminary analysis). Saves the plots and CSV reports in the same folder.

#### Definitions for extracting waveforms

In [None]:
def extract_rms(signal):
    """Return the root-mean-square value of a 1D NumPy array."""
    mean_square = np.mean(signal ** 2)
    return np.sqrt(mean_square)

def extract_max(signal):
    """Return the largest value of a 1D NumPy array."""
    largest = np.max(signal)
    return largest
    
def extract_auc_positive(signal, sampling_rate=1.0):
    """
    Compute the area under the curve (AUC) for the positive portion of a 1D signal.

    Negative samples are replaced by 0 and the result is integrated with the
    trapezoidal rule using a sample spacing of ``1 / sampling_rate``.

    Parameters:
        signal (array-like): Input signal.
        sampling_rate (float): Sampling rate in Hz (default 1.0 for unitless AUC).

    Returns:
        float: AUC of the positive part of the signal.
    """
    signal = np.asarray(signal)
    positive_signal = np.where(signal > 0, signal, 0)
    dt = 1.0 / sampling_rate
    # np.trapz is deprecated in NumPy 2 in favour of np.trapezoid; use the new
    # name when it exists and fall back to the old one on older NumPy versions.
    integrate = getattr(np, 'trapezoid', None) or np.trapz
    return integrate(positive_signal, dx=dt)
def find_pulse_artifacts(emg, time, fs, save_path, threshold_factor=30, plot=True,
                         title="EMG Artifact Detection",
                         reference_windows=((10, 15), (25, 30))):
    """
    Detect stimulation artifacts in EMG and optionally plot and save the detection.

    The threshold is 0.7 x the smaller of the most negative amplitudes found in
    each reference window (samples with |amplitude| >= 1000 are ignored when
    estimating it). An artifact is the first sample at which the inverted signal
    rises through that threshold; crossings closer than 0.8 s to the previously
    kept one are discarded.

    Args:
        emg (np.array): EMG signal (1D).
        time (np.array): Time stamp of every EMG sample (used for plotting only).
        fs (float): Sampling frequency in Hz.
        save_path (str): Where the figure is written when ``plot`` is True.
        threshold_factor (float): Unused; kept so existing callers keep working.
        plot (bool): Whether to plot EMG signal with detection threshold and artifacts.
        title (str): Title for the plot.
        reference_windows (sequence of (start_s, stop_s)): Time windows, in
            seconds from the start of ``emg``, used to estimate the threshold.

    Returns:
        pulses (list): List of artifact indices (rising edge of the inverted signal).

    Raises:
        ValueError: If a reference window holds no usable sample (recording
            too short, or every sample in it is out of range).
    """
    emg = np.asarray(emg)

    # Estimate the artifact amplitude inside each window on the ORIGINAL time
    # axis; out-of-range samples are dropped per window so that removing them
    # cannot shift where the windows fall.
    window_depths = []
    for t_start, t_stop in reference_windows:
        segment = emg[int(t_start * fs):int(t_stop * fs)]
        segment = segment[(segment < 1000) & (segment > -1000)]
        if segment.size == 0:
            raise ValueError(
                f"No valid EMG samples between {t_start} s and {t_stop} s; "
                "cannot estimate the artifact threshold"
            )
        window_depths.append(np.abs(np.min(segment)))
    thresh = 0.7 * np.min(window_depths)

    # Detect threshold crossings (rising edge) on inverted signal
    inverted = -emg
    crossings = np.where(
        (inverted[:-1] < thresh) &
        (inverted[1:] >= thresh) &
        (inverted[1:] < 1500)                # discard if overly negative
    )[0] + 1

    # Keep only one index per pulse (ensure separation)
    min_separation = int(0.8 * fs)
    pulses = []
    for idx in crossings:
        if not pulses or (idx - pulses[-1]) > min_separation:
            pulses.append(idx)

    if plot:
        pulse_arr = np.asarray(pulses, dtype=int)
        fig = plt.figure(figsize=(10, 4))
        plt.plot(time, emg, label='EMG')
        # Report the threshold that is actually drawn (it is not factor x SD).
        plt.axhline(-thresh, color='red', linestyle='--', label=f'Threshold = {-thresh:.1f}')
        # Place markers on the same time axis as the trace.
        plt.scatter(np.asarray(time)[pulse_arr], emg[pulse_arr], color='orange', label='Detected Artifacts')
        plt.xlabel("Time (s)")
        plt.ylabel("Amplitude")
        plt.title(title)
        plt.legend()
        plt.tight_layout()
        # Save before showing: some backends release the figure on show().
        fig.savefig(save_path, dpi=300)
        plt.show()

    return pulses

def plot_waveforms_after_artifacts(emg, fs, pulse_idxs, save_path, title="EMG Waveforms"):
    """
    Superplot the mean +/- SD EMG waveform following the detected artifacts.

    Two groups are drawn: pulses 1-10 (blue) and pulses 11-20 (orange). Each
    snippet spans from the artifact index to 0.2 s after it; pulses whose
    window does not fit inside the recording are ignored.

    Args:
        emg (np.array): EMG signal (1D).
        fs (float): Sampling frequency in Hz.
        pulse_idxs (list): Indices of detected artifacts.
        save_path (str): Destination of the saved figure.
        title (str): Title for the plot.

    Returns:
        None. When no usable pulse exists, 'Not enough pulses' is printed and
        nothing is plotted or saved.
    """
    start_time = 0    # seconds after the artifact at which the snippet starts
    end_time = 0.2    # seconds after the artifact at which the snippet ends
    start_sample = int(start_time * fs)
    end_sample = int(end_time * fs)

    def _collect(indices):
        # Keep only pulses whose whole window lies inside the recording.
        return np.array([
            emg[idx + start_sample: idx + end_sample]
            for idx in indices
            if idx + start_sample >= 0 and idx + end_sample < len(emg)
        ])

    # First 10 pulses, then the next 10 (indices 10..19).
    groups = [
        (_collect(pulse_idxs[:10]), 'blue'),
        (_collect(pulse_idxs[10:20]), 'orange'),
    ]
    usable = [(snips, color) for snips, color in groups if len(snips) > 0]
    if len(usable) < len(groups):
        print('Not enough pulses')
    if not usable:
        return

    # Time axis matching the snippet length
    tvec = np.arange(end_sample - start_sample) / fs

    fig = plt.figure(figsize=(8, 6))
    for snips, color in usable:
        mean_wf = snips.mean(0)
        std_wf = snips.std(0)
        plt.plot(tvec, mean_wf, label='Mean waveform', color=color)
        plt.fill_between(tvec, mean_wf - std_wf, mean_wf + std_wf, alpha=0.3, color=color, label='± 1 SD')
    plt.title(title)
    plt.xlabel("Time (s)")
    plt.ylabel("Amplitude")
    plt.ylim([-100, 100])
    plt.grid(True)
    plt.legend()
    plt.tight_layout()
    # Save before showing: some backends release the figure on show().
    fig.savefig(save_path, dpi=300)
    plt.show()

# Zero-phase Butterworth band-pass filter
def bandpass_filter(data, lowcut, highcut, fs, order=4):
    """Band-pass `data` between `lowcut` and `highcut` (Hz) for a sampling rate `fs`.

    A Butterworth filter of the given order is designed on the band edges
    normalised by the Nyquist frequency and applied forwards and backwards
    with filtfilt.
    """
    half_rate = 0.5 * fs
    band_edges = [lowcut / half_rate, highcut / half_rate]
    numerator, denominator = butter(order, band_edges, btype='band')
    return filtfilt(numerator, denominator, data)
    
def analyze_evoked_emg_folder(pkl_folder, rat_id, output_folder=None, fs_override=None):
    """
    Analyse the evoked EMG of every .pkl recording found in a folder.

    For each file the Port C EMG trace is band-pass filtered (60-500 Hz),
    stimulation artifacts are detected, and RMS / positive AUC / maximum are
    computed for one baseline window and for a 0.2 s window following each of
    the first 20 detected pulses. Per file this writes:
      • <name>_signal_artifacts.svg  (artifact detection)
      • <name>_emg_waveforms.svg     (mean waveforms after the artifacts)
      • <name>_emg_analysis.csv      (one row per epoch plus summary_mean /
                                      summary_std rows carrying activation_count)

    Args:
        pkl_folder (str): Folder with the pickled recordings. Each pickle must
            unpack to (stim_data_dict, fs, port_info, full_df).
        rat_id (str): Animal identifier (not used by the analysis itself).
        output_folder (str, optional): Destination folder; defaults to pkl_folder.
        fs_override (float, optional): Sampling rate to use instead of the stored one.

    Returns:
        None
    """
    out_dir = output_folder or pkl_folder
    os.makedirs(out_dir, exist_ok=True)
    print(pkl_folder)
    for fname in os.listdir(pkl_folder):
        if not fname.lower().endswith('.pkl'):
            continue
        path = os.path.join(pkl_folder, fname)
        # Strip the extension regardless of its case (the filter above is case-insensitive).
        stem = os.path.splitext(fname)[0]
        # SECURITY: pickle.load can execute arbitrary code; only point this at
        # recordings produced by this project.
        with open(path, 'rb') as f:
            stim_data_dict, fs, port_info, full_df = pickle.load(f)
            print(fs)
        if fs_override:
            fs = fs_override
        print('=----------------------------------------')
        print(fname)

        # Get Port C as a 1D numpy array (column 0 = EMG, column 1 = time)
        if 'Port C' in stim_data_dict and stim_data_dict['Port C'] is not None:
            port_c_df = stim_data_dict['Port C']
            raw_emg = port_c_df.iloc[:, 0].values
            time = port_c_df.iloc[:, 1].values
        else:
            # Fallback to full_df; no time column is read here, so rebuild the
            # time axis from the sampling rate.
            raw_emg = full_df['Port C'].iloc[:, 0].values
            time = np.arange(len(raw_emg)) / fs

        # Apply the band-pass filter to the EMG signal (60–500 Hz)
        lowcut = 60
        highcut = 500
        emg = bandpass_filter(raw_emg, lowcut, highcut, fs)

        # Detect pulses
        pulse_idxs = find_pulse_artifacts(
            emg, time, fs, os.path.join(out_dir, stem + '_signal_artifacts.svg'))

        # NOTE(review): the baseline starts at SAMPLE index 2 and ends at 2.2 s,
        # i.e. it is ~2.2 s long while every pulse window is 0.2 s. If
        # (int(2 * fs), int(2.2 * fs)) was intended, AUC/MAX of the baseline are
        # not directly comparable with the pulse rows — confirm.
        baseline_epoch = (2, int(2.2 * fs))
        stim_epochs = []
        for idx in pulse_idxs[:20]:
            start = idx + int(0.011 * fs)  # right after the stim artifact has finished
            end = start + int(0.2 * fs)    # 0.2-second window per pulse
            stim_epochs.append((start, end))

        # Metrics for the baseline (first entry) followed by every pulse window.
        # NOTE(review): extract_auc_positive is called without a sampling rate,
        # so AUC is expressed in amplitude x samples, not amplitude x seconds.
        bdata = emg[baseline_epoch[0]:baseline_epoch[1]]
        windows = [bdata] + [emg[s:e] for (s, e) in stim_epochs]
        rms_arr = np.array([extract_rms(w) for w in windows])
        auc_arr = np.array([extract_auc_positive(w) for w in windows])
        max_arr = np.array([extract_max(w) for w in windows])

        # Average & std over all rows (baseline included)
        mean_rms, std_rms = rms_arr.mean(), rms_arr.std()
        mean_auc, std_auc = auc_arr.mean(), auc_arr.std()
        mean_max, std_max = max_arr.mean(), max_arr.std()

        # Activation: the pulse window peaks above 3 x SD of the baseline
        act_thresh = 3 * np.std(bdata)
        activations = [bool(emg[s:e].max() > act_thresh) for s, e in stim_epochs]
        n_activation = sum(activations)

        plot_waveforms_after_artifacts(
            emg, fs, pulse_idxs, os.path.join(out_dir, stem + '_emg_waveforms.svg'))

        # Save results. The summary rows are appended as ordinary rows and
        # 'activation_count' is a real column, so the count reaches the CSV
        # (assigning a dict with an unknown key through .loc dropped it).
        n_epoch_rows = len(stim_epochs) + 1
        df_res = pd.DataFrame({
            'epoch': ['baseline']
                     + [f'pulse_{i+1}' for i in range(len(stim_epochs))]
                     + ['summary_mean', 'summary_std'],
            'RMS': list(rms_arr) + [mean_rms, std_rms],
            'AUC': list(auc_arr) + [mean_auc, std_auc],
            'MAX': list(max_arr) + [mean_max, std_max],
            'activation': [False] + activations + [np.nan, np.nan],  # baseline: no activation
            'activation_count': [np.nan] * n_epoch_rows + [n_activation, n_activation],
        })
        out_csv = os.path.join(out_dir, stem + '_emg_analysis.csv')
        df_res.to_csv(out_csv, index=False)
        print(f"Saved analysis for {fname} → {out_csv}")



#### Extract waveforms

In [None]:
#------------------------
# Select the animal to analyse: base_path points at its rhs_recordings folder,
# rat_id names the output sub-folder under <experiments_path>/stimulation.
# Note that base_path and rat_id are notebook-level names reused by other cells.
base_path = "%s/%s/rhs_recordings"%(experiments_path, 'rat1-RightNerve' ) #'rat4-LeftNerve')
rat_id = 'Rat_01-RN' 
# Folder holding the pre-saved stimulation pickles for this animal
stim_load_path = f'{base_path}/saved_stim_pkls/'
# Run the per-file analysis; figures and CSV reports go to the stimulation folder
analyze_evoked_emg_folder(stim_load_path, rat_id, output_folder=f'{experiments_path}/stimulation/{rat_id}')


#### Definitions for creating overview

In [None]:
import os
import re
import pandas as pd

def clean_and_validate_df(df, fname):
    """
    Validate one per-file EMG report and extract its pulse rows.

    Args:
        df (pd.DataFrame): Report with columns 'epoch', 'RMS', 'AUC', 'MAX'
            and 'activation'.
        fname (str): File name, used only in the printed messages.

    Returns:
        tuple: (pulse_rows, baseline_rms, baseline_auc, baseline_max).
            pulse_rows holds the rows named pulse_<n> plus an integer
            'pulse_num' column. The baseline values are 0 when the report has
            no 'baseline' row. pulse_rows is None when the report is unusable:
            (None, None, None, None) for missing columns and (None, 0, 0, 0)
            when the number of pulse rows is not exactly 20.
    """
    # Every column read below must be present; always hand back a 4-tuple so
    # callers can unpack the result unconditionally.
    required_cols = {'epoch', 'RMS', 'AUC', 'MAX', 'activation'}
    if not required_cols.issubset(df.columns):
        print(f" Missing columns in {fname}, skipping.")
        return None, None, None, None

    # Extract baseline row
    baseline_row = df[df['epoch'].str.lower() == 'baseline']
    if baseline_row.empty:
        baseline_rms = baseline_auc = baseline_max = 0
    else:
        baseline_rms = baseline_row['RMS'].values[0]
        baseline_auc = baseline_row['AUC'].values[0]
        baseline_max = baseline_row['MAX'].values[0]

    # Keep only rows named pulse_<n> (summary rows and the baseline are excluded)
    pulse_rows = df[df['epoch'].str.match(r'^pulse_\d+$', na=False)].copy()

    if len(pulse_rows) != 20:
        print(f" File has {len(pulse_rows)} valid pulses instead of 20: {fname}")
        return None, 0, 0, 0

    pulse_rows['pulse_num'] = pulse_rows['epoch'].str.extract(r'pulse_(\d+)', expand=False).astype(int)
    return pulse_rows, baseline_rms, baseline_auc, baseline_max


def parse_filename(fname):
    """
    Extract stimulation metadata from an underscore-separated report file name.

    Recognised tokens:
      • current:  digits (optionally with a decimal part) followed by 'uA', e.g. '8.5uA'
      • port:     any token containing 'Port' or 'port', e.g. 'PortA'
      • duration: digits followed by 'us', e.g. '200us'

    Args:
        fname (str): File name such as '10uA_PortA_200us_emg_analysis.csv'.

    Returns:
        tuple: (current, port, duration). current and port are None when absent;
        duration falls back to '100us' when the name carries no duration token.
    """
    current = None
    port = None
    # Default pulse width. It is set once here rather than in an else branch
    # so that unrelated tokens (e.g. 'emg', 'analysis') cannot overwrite a
    # duration that was parsed from the name.
    duration = '100us'

    for part in fname.replace('.csv', '').split('_'):
        if re.match(r'^\d+\.?\d*uA$', part):
            current = part
        elif 'Port' in part or 'port' in part:
            port = part
        elif re.match(r'^\d+us$', part):
            duration = part

    return current, port, duration


def load_and_melt_emg_reports(folder, n_ch, twitch_th, side_nerve):
    """
    Gather every per-file EMG report of a folder into one long-format table.

    Each '*.csv' report (files starting with 'overview' are skipped) is
    validated, its metadata is parsed from the file name, and one record per
    pulse is produced.

    Args:
        folder (str): Folder containing the *_emg_analysis.csv reports.
        n_ch (int): Number of active channels, copied into every record.
        twitch_th (sequence): [threshold for Port A, threshold for Port B],
            compared against the stimulation current parsed from the file name.
        side_nerve (sequence): [side for Port A, side for Port B].

    Returns:
        pd.DataFrame | None: One row per pulse, or None when no report yielded data.
    """
    records = []

    for fname in os.listdir(folder):
        if fname.startswith('overview') or not fname.endswith('.csv'):
            continue
        path = os.path.join(folder, fname)
        try:
            df = pd.read_csv(path)
        except Exception as e:
            print(f" Error loading {fname}: {e}")
            continue

        print(f"📄 Found file: {fname}")

        # Clean and extract pulses. Only the first element is inspected before
        # unpacking so a rejected report can never break the unpacking.
        validated = clean_and_validate_df(df, fname)
        pulses = validated[0]
        if pulses is None:
            continue
        baseline_rms, baseline_auc, baseline_max = validated[1:4]

        # Extract metadata from filename
        current, port, duration = parse_filename(fname)
        if not all([current, port, duration]):
            print(f" Could not parse all metadata from: {fname}")
            continue

        # Numeric stimulation amplitude, decimals included ('8.5uA' -> 8.5).
        # It is the same for every pulse of the file, so parse it once.
        match = re.match(r"(\d+\.?\d*)", current)
        value = float(match.group(1)) if match else None

        is_port_a = port in ('PortA', 'portA')
        is_port_b = port in ('PortB', 'portB')

        for _, row in pulses.iterrows():
            pulse_num = int(row['pulse_num'])

            # Nerve side follows the port named in the file; the threshold
            # switches to the other port's value after the first 10 pulses.
            if is_port_a:
                side = side_nerve[0]
                threshold = twitch_th[0] if pulse_num <= 10 else twitch_th[1]
            elif is_port_b:
                side = side_nerve[1]
                threshold = twitch_th[1] if pulse_num <= 10 else twitch_th[0]
            else:
                side = None
                threshold = None

            # Determine twitch; an unknown amplitude or port counts as 'No'
            # instead of raising on a comparison with None.
            if value is None or threshold is None:
                twitch = 'No'
            else:
                twitch = 'Yes' if value > threshold else 'No'

            records.append({
                'file': fname,
                'side_nerve': side,
                'number_ch': n_ch,
                'current': current,
                'port': port,
                'duration': duration,
                'pulse': row['pulse_num'],
                'RMS': row['RMS'],
                'AUC': row['AUC'],
                'MAX': row['MAX'],
                'activation': row['activation'],
                'twitch': twitch,
                'baseline_RMS': baseline_rms,
                'baseline_AUC': baseline_auc,
                'baseline_MAX': baseline_max
            })

    df_all = pd.DataFrame.from_records(records)
    if df_all.empty:
        print(" No data extracted.")
        return None

    print("✅ Loaded Data Columns:", df_all.columns)
    print("✅ Unique ports:", df_all['port'].unique())
    return df_all

def plot_rms_traces(df_long):
    """Draw RMS against pulse number: one facet per port, one line per current."""
    sns.set(style="whitegrid")

    grid = sns.FacetGrid(df_long, col="port", height=4, aspect=1.3)
    line_options = dict(x="pulse", y="RMS", hue="current", marker="o")
    grid.map_dataframe(sns.lineplot, **line_options)

    # Facet titles, legend and axis labels
    grid.set_titles("First Port: {col_name}")
    grid.add_legend(title="Amplitude (µA)")
    grid.set_axis_labels("Pulse Number", "RMS (µV)")

    plt.ylim([0, 100])
    plt.suptitle(
        "EMG RMS over 20 Pulses\n(Grouped by FirstPort and Amplitude)",
        y=1.05,
        fontsize=16,
    )
    plt.tight_layout()
    plt.show()


#### Create overview file and plot RMS all pulses

In [None]:
# Analysis
folder = f'{experiments_path}/stimulation/{rat_id}' 
print(folder)
#----------------------------------
# Current amplitudes at which a motor response is first seen: [Port A, Port B]
twitch_th = [6, 9]   # R1_RN: [6, 8.5] // R1_LN: [8, 13] // R2_RN: [6, 35] // R2_LN: [32, 5] // R3_RN: [8, 9] // R3_LN: [3, 2] // R4_RN: [7, 9] // L4_LN: [8, 10] 
# Location in nerve: first DES, 2nd PSS
side_nerve = ['Right', 'Left'] # For Rats 1 and 2
#side_nerve = ['Left', 'Right'] # For Rats 3 and 4

# Number of active channels to stimulate both ports
n_ch = 7       #  R1_RN: 7  // R1_LN: 7 // R2_RN: 8 // R2_LN: 10 // R3_RN: 5 // R3_LN: 7 // R4_RN:  // R4_LN: 5

#-------------------------

df_long = load_and_melt_emg_reports(folder, n_ch, twitch_th, side_nerve)

# load_and_melt_emg_reports returns None when no report could be used, so check
# for that before touching DataFrame attributes.
if df_long is None or df_long.empty:
    print(f"No usable EMG reports in {folder}; overview not written.")
else:
    df_long.to_csv(f'{folder}/overview_stim.csv', index=False)
    plot_rms_traces(df_long)
    plt.savefig(f'{folder}/RMS_vs_pulse.svg', dpi=300)


In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

# ---- STEP 1: Assign 'material' and 'order' ---- #
def get_material(row):
    """
    Return the electrode material that delivered a given pulse.

    Files named after PortA start with PEDOT:DES, all others start with
    PEDOT:PSS; pulses 1-10 use the first material and later pulses the other.

    Args:
        row: Mapping/Series with the keys 'port' and 'pulse'.

    Returns:
        str: 'PEDOT:DES' or 'PEDOT:PSS' (spelled exactly like the plot keys).
    """
    first_material = 'PEDOT:DES' if row['port'] == 'PortA' else 'PEDOT:PSS'  # for pilot: '1stPortA' and PSS
    second_material = 'PEDOT:PSS' if first_material == 'PEDOT:DES' else 'PEDOT:DES'
    return first_material if row['pulse'] <= 10 else second_material

# Normalise the port spelling first: get_material and the 'order' mapping below
# only recognise 'PortA' / 'PortB'.
# NOTE: this cell adds columns to df_long in place.
df_long['port'] = df_long['port'].replace({'portA': 'PortA', 'portB': 'PortB'})

# Material that delivered each pulse (row-wise, from 'port' and 'pulse')
df_long['material'] = df_long.apply(get_material, axis=1)

# Stimulation order label derived from the port; ports other than
# PortA/PortB map to NaN.
df_long['order'] = df_long['port'].map({
    'PortA': 'DES first',
    'PortB': 'PSS first'
})

# Convert current to numeric (e.g. "15uA" → 15)
df_long['current_uA'] = df_long['current'].str.replace('uA', '').astype(float)

# ---- STEP 2: Create unique identifier for each file ---- #
# Create file_id based on file names (assuming 'file' column exists)
if 'file' in df_long.columns:
    df_long['file_id'] = df_long.groupby(['file']).ngroup()  # Assigning a unique ID for each file
else:
    # Only reports the problem; the groupby below still requires 'file_id'.
    print("Error: 'file' column is missing in df_long")

# ---- STEP 3: Aggregate by current, material, order, and file_id ---- #
# Mean and standard error for the pulse metrics, mean only for the baselines
agg_funcs = {
    'RMS': ['mean', 'sem'],
    'AUC': ['mean', 'sem'],
    'MAX': ['mean', 'sem'],
    'baseline_RMS': 'mean',
    'baseline_AUC': 'mean',
    'baseline_MAX': 'mean'
}

# Perform aggregation, but keep file_id in the grouping
plot_data = df_long.groupby(['current_uA', 'material', 'order', 'file_id']).agg(agg_funcs)

# Flatten the multi-level columns (e.g. ('RMS', 'mean') -> 'RMS_mean')
plot_data.columns = ['_'.join(col).strip() for col in plot_data.columns.values]
plot_data = plot_data.reset_index()

# ---- STEP 4: Plotting ---- #
# One panel per metric; colour encodes the material, line style and marker
# encode the stimulation order, and every file is drawn as its own series.
metrics = ['RMS', 'AUC', 'MAX']
colors = {
    'PEDOT:PSS': 'tab:blue',
    'PEDOT:DES': 'tab:orange'
}

# Initialize the figure and axes
fig, axs = plt.subplots(1, 3, figsize=(18, 5), sharex=True)

# Loop over each metric (RMS, AUC, MAX)
for i, metric in enumerate(metrics):
    ax = axs[i]

    # Loop over each stimulation order (e.g., PEDOT:PSS first or PEDOT:DES first)
    for order in plot_data['order'].unique():
        subset = plot_data[plot_data['order'] == order]
        pss_first = order == 'PSS first'

        # Loop over materials (e.g., PEDOT:PSS and PEDOT:DES)
        for material in ['PEDOT:PSS', 'PEDOT:DES']:
            mat_data = subset[subset['material'] == material]
            if mat_data.empty:
                continue

            # One errorbar series per file; these series carry no legend label,
            # the marker shapes are explained by the custom handles below.
            for file_id in mat_data['file_id'].unique():
                file_data = mat_data[mat_data['file_id'] == file_id]
                ax.errorbar(
                    file_data['current_uA'], file_data[f'{metric}_mean'],
                    yerr=file_data[f'{metric}_sem'],
                    label=None,
                    color=colors[material],
                    linestyle='-' if pss_first else '--',
                    marker='o' if pss_first else 's',
                    markersize=8,
                    alpha=0.7  # transparency to distinguish overlapping points
                )

    # Baseline level (mean over all rows) and the two twitch thresholds
    overall_baseline = df_long[f'baseline_{metric}'].mean()
    baseline_std = df_long[f'baseline_{metric}'].std()
    ax.text(0.02, 0.95, f'Baseline μ±σ:\n{overall_baseline:.2f} ± {baseline_std:.2f}',
            transform=ax.transAxes, fontsize=10, verticalalignment='top',
            bbox=dict(boxstyle='round,pad=0.3', facecolor='white', alpha=0.5))
    ax.axhline(overall_baseline, color='red', linestyle=':', linewidth=2, label='Baseline')
    #ax.axhline(1.5 * overall_baseline, color='red', linestyle='--', linewidth=2, label='1.5× Baseline')
    ax.axvline(twitch_th[0], color='orange', linestyle='--', linewidth=2, label='Twitch threshol Port A')
    ax.axvline(twitch_th[1], color='blue', linestyle='--', linewidth=2, label='Twitch threshol Port B')

    # Add the baseline annotation on the first subplot only
    if i == 0:
        ax.annotate('Baseline', xy=(0.95, overall_baseline), xycoords=('axes fraction', 'data'),
                    xytext=(-10, -5), textcoords='offset points',
                    ha='right', va='top', color='red', fontsize=12)

    # Customize titles and labels for each subplot
    ax.set_title(metric)
    ax.set_xlabel('Current (uA)')
    ax.set_ylabel(metric)
    ax.grid(True, alpha=0.3)

# Collect the labelled artists of the first panel; the dict removes duplicates
handles, labels = axs[0].get_legend_handles_labels()
by_label = dict(zip(labels, handles))

import matplotlib.lines as mlines

# Custom legend handles for stimulation order markers
circle_marker = mlines.Line2D([], [], color='black', marker='o', linestyle='None',
                              markersize=8, label='PSS first (circle)')
square_marker = mlines.Line2D([], [], color='black', marker='s', linestyle='None',
                              markersize=8, label='DES first (square)')
# Combine existing legend labels with marker shape legend
custom_handles = list(by_label.values()) + [circle_marker, square_marker]
custom_labels = list(by_label.keys()) + ['PSS first (circle)', 'DES first (square)']

# A single figure-level legend (it was previously added twice, stacking two
# identical legends on top of each other)
fig.legend(custom_handles, custom_labels, loc='upper center', ncol=4, bbox_to_anchor=(0.5, 0.1))
plt.tight_layout()
plt.subplots_adjust(top=0.88)
# Save before showing: some backends release the figure on show().
fig.savefig(f'{folder}/stim_vs_current.svg', dpi=300)
plt.show()


### Population analysis

#### Population SNR (average baseline vs activity for each port, 60 sec intervals)

This ratio normalizes each animal to its own baseline, accounting for absolute differences in signal strength or noise floor.
Since both numerator and denominator are from the same electrodes, hardware, and conditions, many inter-animal differences cancel out.

In [None]:
import os
import pandas as pd
import matplotlib.pyplot as plt
import re

# Define your folder path
# `condition` selects the sub-folder of ../PEDOT-DES that is analysed and is
# also used in the names of the output files written further down.
condition = 'evoked' 
directory =  "../PEDOT-DES/%s"%condition  # change if needed

# Collect all relevant files
all_files = os.listdir(directory)

# Files are classified purely by a substring of their name; each selected file
# is later read with pd.read_csv.
# NOTE(review): no extension filter is applied — presumably only CSV files
# contain these substrings; verify against the folder contents.
baseline_files = [f for f in all_files if "baseline" in f]
activity_files = [f for f in all_files if "activity" in f]

# Good channels
# Channel numbers kept for the SNR computation (matched against the number
# parsed from the 'Electrode' column)
good_channels = {1, 2, 3, 13, 18, 19, 29, 31, 22,10,21,11,28,4,27,26,24,7}  # small channels only


# Helper: pull the trailing channel number out of an Electrode string
def extract_channel_num(electrode):
    """Return the trailing channel number (leading zeros dropped) or None.

    The string must end in '<capital letter>-<digits>'.
    """
    found = re.search(r'[A-Z]-0*(\d+)$', electrode)
    if found is None:
        return None
    return int(found.group(1))
    
# Helper: rat identifier = first two dash-separated fields of the file name
def get_rat_id(filename):
    """Return the first two '-'-separated fields re-joined, e.g. 'Rat_01-RN'."""
    leading_fields = filename.split("-", 2)[:2]
    return "-".join(leading_fields)

# Match baseline and activity by rat ID
rat_ids = sorted(set(get_rat_id(f) for f in baseline_files))
print(rat_ids)

# Map rat_id -> filename
baseline_dict = {get_rat_id(f): f for f in baseline_files if get_rat_id(f)}
activity_dict = {get_rat_id(f): f for f in activity_files if get_rat_id(f)}

# Match only rats with both files
common_rats = baseline_dict.keys() & activity_dict.keys()

snr_data = []

# The loop variables are deliberately NOT called rat_id / df so that running
# this cell does not overwrite the notebook-level names used by other cells.
for rat_name in sorted(common_rats):
    baseline_df = pd.read_csv(os.path.join(directory, baseline_dict[rat_name]))
    activity_df = pd.read_csv(os.path.join(directory, activity_dict[rat_name]))

    # Apply channel number extraction to both datasets
    for frame in (baseline_df, activity_df):
        frame["channel_num"] = frame["Electrode"].apply(extract_channel_num)

    # Good-channel rows of each port, for baseline and activity
    filtered_data = {}
    for port in ["A", "B"]:
        port_str = f"Port {port}"
        filtered_data[port] = {
            kind: frame[frame["Electrode"].str.startswith(port_str)
                        & frame["channel_num"].isin(good_channels)]
            for kind, frame in (("baseline", baseline_df), ("activity", activity_df))
        }

    # Trim both ports to the same number of channels (the smallest count of
    # good channels across ports and conditions) so the comparison is fair.
    min_channels = min(len(filtered_data[port][kind])
                       for port in ("A", "B")
                       for kind in ("baseline", "activity"))
    if min_channels == 0:
        # At least one port/condition has no good channel: skip this rat.
        continue

    # SNR = mean activity RMS / mean baseline RMS, per port
    for port in ("A", "B"):
        baseline_rms = filtered_data[port]["baseline"].head(min_channels)["RMS"].mean()
        activity_rms = filtered_data[port]["activity"].head(min_channels)["RMS"].mean()
        snr_data.append({
            "Rat": rat_name,
            "Port": f"Port {port}",
            "SNR": activity_rms / baseline_rms if baseline_rms != 0 else None,
            "GoodChannelsUsed": min_channels
        })

# Convert to DataFrame
snr_df = pd.DataFrame(snr_data)
print(snr_df)
output_path = os.path.join(directory, f"SNR_fair_summary_{condition}.csv")
snr_df.to_csv(output_path, index=False)


# Plot paired SNR comparison
plt.figure(figsize=(8, 6))
sns.pointplot(data=snr_df, x="Port", y="SNR", hue="Rat", markers="o", dodge=True, join=True)
plt.title("Paired SNR Comparison Between Port A and Port B")
plt.ylabel("SNR (Activity / Baseline RMS)")
plt.grid(True)
plt.legend(title="Rat")
plt.tight_layout()
plt.savefig(f'{directory}/paired_fair_SNR_{condition}.svg', dpi=300)
plt.show()


# Pivot to wide format: SNR_Port_A and SNR_Port_B per rat
df_wide = snr_df.pivot(index="Rat", columns="Port", values="SNR").reset_index()
df_wide.columns.name = None  # clean up

# Add ΔSNR column
df_wide["Delta_SNR"] = df_wide["Port A"] - df_wide["Port B"]

# --- Plotting ---
# The two panels have different x axes (ports on top, rats below), so the
# x axis must not be shared between them.
fig, axs = plt.subplots(2, 1, figsize=(10, 8), sharex=False, gridspec_kw={'height_ratios': [3, 1]})

# ---- Plot 1: Paired SNR with error bars ----
for _, row in df_wide.iterrows():
    axs[0].plot(["Port A", "Port B"], [row["Port A"], row["Port B"]],
                marker="o", label=row["Rat"])

# Summary error bars (mean ± SEM)
means = df_wide[["Port A", "Port B"]].mean()
sems = df_wide[["Port A", "Port B"]].sem()
axs[0].errorbar(["Port A", "Port B"], means, yerr=sems, fmt='o', color='black', capsize=5, lw=2, label="Mean ± SEM")

axs[0].set_ylabel("SNR")
axs[0].set_title("SNR per Port (paired per rat)")
# Build the legend after every labelled artist exists so "Mean ± SEM" is listed.
axs[0].legend(loc="upper right", fontsize=8, ncol=2)

# ---- Plot 2: ΔSNR subplot ----
sns.barplot(x="Rat", y="Delta_SNR", data=df_wide, ax=axs[1], palette="coolwarm", edgecolor="black")
axs[1].axhline(0, color="gray", linestyle="--")
axs[1].set_ylabel("ΔSNR (A - B)")
axs[1].set_title("SNR Difference per Rat")

plt.tight_layout()
plt.savefig(f'{directory}/SNR_fair_increment_{condition}.svg', dpi=300)

# Summary statistics
print("\nSummary Statistics:")
print(df_wide[["Port A", "Port B"]].describe())

from scipy.stats import ttest_rel, wilcoxon, shapiro

# Only rats with an SNR for BOTH ports can enter a paired test.
paired_snr = df_wide[["Port A", "Port B"]].dropna().astype(float)

# --- Check normality of the differences ---
differences = paired_snr["Port A"] - paired_snr["Port B"]

if len(differences) < 3:
    # The Shapiro-Wilk test needs at least 3 observations.
    print(f"\nOnly {len(differences)} complete rat pair(s): "
          "at least 3 are needed for the normality check and the paired test.")
else:
    shapiro_stat, shapiro_p = shapiro(differences)

    print("\nNormality Check (Shapiro-Wilk Test on paired differences):")
    print(f"Shapiro-Wilk statistic = {shapiro_stat:.3f}, p = {shapiro_p:.3g}")

    # Decide which test to use based on p-value
    if shapiro_p > 0.05:
        print("Differences are normally distributed (p > 0.05). Using paired t-test.")
        t_stat, p_val = ttest_rel(paired_snr["Port A"], paired_snr["Port B"])
        print(f"\nPaired t-test result:\nt = {t_stat:.3f}, p = {p_val:.3g}")
    else:
        print("Differences are NOT normally distributed (p ≤ 0.05). Using Wilcoxon signed-rank test.")
        try:
            w_stat, p_val = wilcoxon(paired_snr["Port A"], paired_snr["Port B"])
            print(f"\nWilcoxon signed-rank test result:\nW = {w_stat:.3f}, p = {p_val:.3g}")
        except ValueError as e:
            print(f"Wilcoxon test failed: {e}")


#### Population impedances and equivalent RC 

In [None]:
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.stats import ttest_rel, wilcoxon, shapiro

# Configuration
# `condition` is the sub-folder of each rat's rhs_recordings directory that is
# scanned for impedance CSV files; it also appears in the output figure names.
# NOTE: this reassigns the notebook-level names `condition` and `base_path`.
condition = 'saved_pkls' #'saved_stim_pkls' #' saved_pkls
base_path = "../PEDOT-DES"


def compute_series_R_C(magnitude, phase_deg, frequency=1000):
    """
    Convert an impedance (magnitude, phase) into an equivalent series R and C.

    R is the real part, |Z|·cos(phase). C is derived from the absolute value of
    the imaginary part X = |Z|·sin(phase) as C = 1 / (2π·f·|X|). Scalars and
    array-likes (broadcast against each other) are both accepted.

    Args:
        magnitude (float or array-like): Impedance magnitude.
        phase_deg (float or array-like): Impedance phase in degrees.
        frequency (float): Frequency in Hz at which the impedance was measured.

    Returns:
        tuple: (R, C). Floats for scalar input, NumPy arrays otherwise. C is
        NaN wherever the imaginary part is exactly zero.
    """
    magnitude = np.asarray(magnitude, dtype=float)
    # Convert phase to radians
    phase_rad = np.radians(np.asarray(phase_deg, dtype=float))

    # Series resistance (real part)
    R = magnitude * np.cos(phase_rad)

    # Magnitude of the reactance (imaginary part)
    X_C = np.abs(magnitude * np.sin(phase_rad))

    # Series capacitance; the division is evaluated everywhere and masked
    # afterwards, hence the silenced divide-by-zero warning.
    with np.errstate(divide='ignore', invalid='ignore'):
        C = np.where(X_C != 0, 1.0 / (2 * np.pi * frequency * X_C), np.nan)

    if np.ndim(R) == 0:
        # Scalar in, scalar out
        return float(R), float(C)
    return R, C


all_rat_data = []

# Step 1: Load one impedance CSV per rat folder
good_channels = {1, 2, 3, 13, 18, 19, 29, 31, 22,10,21,11,28,4,27,26,24,7}  # small channels only
rat_shared_channels_map = {}  # maps rat_id -> set of shared (port, channel)

# Directory listings come back in arbitrary order; sorting makes both the rat
# order and the choice of CSV file reproducible across machines and runs.
for rat_folder in sorted(os.listdir(base_path)):
    if not rat_folder.startswith("rat"):
        continue

    rat_path = os.path.join(base_path, rat_folder, "rhs_recordings/%s"%condition)
    if not os.path.exists(rat_path):
        continue

    csv_files = sorted(f for f in os.listdir(rat_path) if f.endswith(".csv"))
    if not csv_files:
        continue

    rat_frames = []
    per_file_channels = []

    # Only the first CSV (in sorted order) of each rat is used.
    csv_name = csv_files[0]
    print(csv_name)
    print(rat_folder)
    try:
        df = pd.read_csv(os.path.join(rat_path, csv_name))
        df = df.rename(columns={
            "port_name": "port",
            "native_channel_name": "channel",
            "electrode_impedance_magnitude": "impedance",
            "electrode_impedance_phase": "phase"
        })
        df["rat"] = rat_folder
        df = df[["rat", "port", "channel", "impedance", "phase"]]
        # Extract numeric channel ID (trailing digits of the channel name)
        df["channel_num"] = df["channel"].str.extract(r"(\d+)$").astype(int)
        # Flag good channels
        df["is_good_channel"] = df["channel_num"].isin(good_channels)
        per_file_channels.append(set(zip(df["port"], df["channel"])))

        # Compute R and C for each row and store the results
        df[['equivalent_R', 'equivalent_C']] = df.apply(
            lambda row: pd.Series(compute_series_R_C(row['impedance'], row['phase'])), axis=1)

        rat_frames.append(df)
    except Exception as e:
        # A malformed file is reported and the rat is skipped.
        print(f"Error in {csv_name}: {e}")

    # Compute shared (port, channel) pairs
    if rat_frames and per_file_channels:
        shared = set.intersection(*per_file_channels)
        rat_shared_channels_map[rat_folder] = shared
        all_rat_data.append(pd.concat(rat_frames, ignore_index=True))

# Combine and clean
if not all_rat_data:
    raise ValueError(
        f"No impedance CSV could be loaded from {base_path}/rat*/rhs_recordings/{condition}"
    )
df_all = pd.concat(all_rat_data, ignore_index=True)
df_all = df_all[~df_all["port"].str.contains("Port C", na=False)]

# Report shared channel quality
print("\n--- SHARED CHANNEL QUALITY REPORT ---")
report = []
for rat, shared in rat_shared_channels_map.items():
    rat_df = df_all[df_all["rat"] == rat]
    for port in sorted(set(p for p, _ in shared if "C" not in p)):
        port_shared = [(p, ch) for p, ch in shared if p == port]
        good, bad = [], []
        for p, ch in port_shared:
            subset = rat_df[(rat_df["port"] == p) & (rat_df["channel"] == ch)]
            if not subset.empty and subset["is_good_channel"].any():
                good.append(ch)
            else:
                bad.append(ch)
        report.append({
            "rat": rat, "port": port,
            "n_shared": len(port_shared),
            "n_good": len(good), "good_channels": ", ".join(good),
            "n_bad": len(bad), "bad_channels": ", ".join(bad)
        })
# Show the report announced by the header above (it was built but never shown).
print(pd.DataFrame(report))

# Filter good and valid channels (impedance magnitude at most 100000)
df_filtered = df_all[df_all["is_good_channel"] & (df_all["impedance"] <= 100000)]

# Compute per-rat averages, keeping only rats that have both ports
metrics = ["impedance", "phase", "equivalent_R", "equivalent_C"]
summary_df = df_filtered.groupby(["rat", "port"])[metrics].mean().unstack()
summary_df = summary_df.dropna(subset=[("impedance", "Port A"), ("impedance", "Port B")])

# Paired statistics: paired t-test when the differences pass Shapiro-Wilk,
# Wilcoxon signed-rank otherwise
print("\n--- PAIRED STATISTICS ---")
for metric in metrics:
    a = summary_df[(metric, "Port A")]
    b = summary_df[(metric, "Port B")]
    diff = a - b
    if len(diff) < 3:
        # The Shapiro-Wilk test needs at least 3 observations.
        print(f"{metric.title()}: only {len(diff)} paired rat(s), test skipped")
        continue
    normal = shapiro(diff)[1] > 0.05
    if normal:
        t, p = ttest_rel(a, b)
        print(f"{metric.title()}: t = {t:.3f}, p = {p:.3g}")
    else:
        try:
            w, p = wilcoxon(a, b)
            print(f"{metric.title()} (nonparametric): W = {w:.3f}, p = {p:.3g}")
        except ValueError as e:
            # e.g. every paired difference is zero
            print(f"{metric.title()}: Wilcoxon test failed: {e}")

# The overlaid per-rat lines are drawn from "Port A" to "Port B"; fixing the
# category order keeps the boxes in the same positions as those lines.
port_order = ["Port A", "Port B"]

# Plot per-rat paired means on top of the channel-level distributions
fig, axs = plt.subplots(1, len(metrics), figsize=(6 * len(metrics), 6))
for i, metric in enumerate(metrics):
    sns.boxplot(x="port", y=metric, data=df_filtered, order=port_order, ax=axs[i], palette="Set2")
    sns.stripplot(x="port", y=metric, data=df_filtered, order=port_order, ax=axs[i],
                  color='black', alpha=0.3, jitter=True)
    # The loop variable is not called rat_id so the notebook-level rat_id used
    # by other cells is left untouched.
    for rat_label, row in summary_df.iterrows():
        axs[i].plot(port_order,
                    [row[(metric, "Port A")], row[(metric, "Port B")]],
                    marker="o", label=rat_label if i == 0 else "")
    axs[i].set_title(f"{metric.title()} per Rat")
    axs[i].set_ylabel(metric.title())
    axs[i].grid(True)
    if i == 0:
        axs[i].legend(bbox_to_anchor=(1.05, 1), loc='upper left')
plt.tight_layout()
plt.savefig(f"{base_path}/impedance_RC__pop_means_{condition}.svg", dpi=300)

# Boxplots for all values
fig, axs = plt.subplots(1, len(metrics), figsize=(6 * len(metrics), 6))
for i, metric in enumerate(metrics):
    sns.boxplot(x="port", y=metric, data=df_filtered, order=port_order, ax=axs[i], palette="Set2")
    sns.stripplot(x="port", y=metric, data=df_filtered, order=port_order, ax=axs[i],
                  color='black', alpha=0.3, jitter=True)
    axs[i].set_title(f"All Channels: {metric.title()}")
    axs[i].grid(True)
plt.tight_layout()
plt.savefig(f"{base_path}/impedance_RC_population_{condition}.svg", dpi=300)
plt.show()

In [None]:
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import ttest_rel

# Location of the per-rat summary CSVs for the chosen condition.
condition = 'evoked'  # 'baseline'
experiments_path = "../PEDOT-DES"
base_dir = f'{experiments_path}/{condition}'

# Every file whose name ends with the summary suffix is one rat.
summary_suffix = 'spike_metrics_summary.csv'
summary_files = [name for name in os.listdir(base_dir) if name.endswith(summary_suffix)]

# Build one long-format record per (rat, avg_* metric) holding the value of
# each port and their difference.
records = []
for file in summary_files:
    rat_id = file.replace('_spike_metrics_summary.csv', '')
    df = pd.read_csv(os.path.join(base_dir, file), index_col=0)

    # Only rows whose index label starts with 'avg_' are used.
    is_avg_row = df.index.str.startswith('avg_')
    avg_metrics = df[is_avg_row]

    for metric in avg_metrics.index:
        try:
            # Port A is stored under 'PEDOT:DES' and Port B under 'PEDOT:PSS'.
            val_des = avg_metrics.loc[metric, 'Port A']
            val_pss = avg_metrics.loc[metric, 'Port B']
            record = {
                'rat_id': rat_id,
                'metric': metric,
                'PEDOT:DES': val_des,
                'PEDOT:PSS': val_pss,
                'delta': val_des - val_pss,
            }
            records.append(record)
        except Exception as e:
            # Best effort: report the metric that could not be read and move on.
            print(f"Skipping {metric} in {rat_id} due to error: {e}")

df_all = pd.DataFrame(records)

# Pair every '*_signal' metric with its '*_baseline' counterpart, keeping only
# the pairs whose baseline metric is actually present in the data.
metrics = df_all['metric'].unique()
baseline_pairs = [
    (signal_name, signal_name.replace('_signal', '_baseline'))
    for signal_name in metrics
    if signal_name.endswith('_signal')
    and signal_name.replace('_signal', '_baseline') in metrics
]

# One row of three panels per metric:
#   column 0 - per-rat paired lines (PEDOT:DES vs PEDOT:PSS) with the paired t-test
#   column 1 - histogram of the per-rat difference (DES - PSS)
#   column 2 - increase over baseline, only for '*_signal' metrics that have a
#              matching '*_baseline' metric
n_rows = len(metrics)
# squeeze=False keeps `axs` two-dimensional even when there is a single metric,
# so axs[i] always yields the three panels of row i.
fig, axs = plt.subplots(n_rows, 3, figsize=(16, 4 * n_rows), squeeze=False)

results = []

for i, metric in enumerate(metrics):
    ax1, ax2, ax3 = axs[i]

    df_m = df_all[df_all['metric'] == metric]
    if df_m.shape[0] < 2:
        # A paired test needs at least two rats; hide the unused row.
        for unused_ax in (ax1, ax2, ax3):
            unused_ax.set_visible(False)
        continue

    # Paired t-test between materials.
    # The argument order is (PSS, DES), so a positive t means PSS > DES.
    t_stat, p_val = ttest_rel(df_m['PEDOT:PSS'], df_m['PEDOT:DES'])

    # Line plot per rat
    for _, row in df_m.iterrows():
        ax1.plot(['PEDOT:DES', 'PEDOT:PSS'], [row['PEDOT:DES'], row['PEDOT:PSS']], marker='o', label=row['rat_id'])
    ax1.set_title(f"{metric}\nt={t_stat:.2f}, p={p_val:.4f}")
    ax1.set_ylabel('Value')
    ax1.grid(True)
    if i == 0:
        ax1.legend()

    # Histogram of delta (hidden when every delta is NaN)
    valid_deltas = df_m['delta'].dropna()
    if valid_deltas.empty:
        ax2.set_visible(False)
    else:
        ax2.hist(valid_deltas, bins=3)
        ax2.axvline(0, linestyle='--', color='black')
        ax2.set_title(f"Δ (DES - PSS) for {metric}")
        ax2.set_xlabel('Δ Value')
        ax2.set_ylabel('Count')
        ax2.grid(True)

    # If it's a "signal" metric with baseline, plot the increase over baseline
    baseline_metric = metric.replace('_signal', '_baseline')
    increase_drawn = False
    if (metric, baseline_metric) in baseline_pairs:
        signal_df = df_all[df_all['metric'] == metric].set_index('rat_id')
        baseline_df = df_all[df_all['metric'] == baseline_metric].set_index('rat_id')

        # Only include rats in both, and with no NaNs
        merged_df = pd.merge(
            signal_df[['PEDOT:DES', 'PEDOT:PSS']],
            baseline_df[['PEDOT:DES', 'PEDOT:PSS']],
            left_index=True, right_index=True,
            suffixes=('_signal', '_baseline')
        ).dropna()

        # Everything that uses des_inc / pss_inc stays inside this guard: with
        # an empty merge those names would be undefined, or stale from a
        # previous metric.
        if not merged_df.empty:
            des_inc = merged_df['PEDOT:DES_signal'] - merged_df['PEDOT:DES_baseline']
            pss_inc = merged_df['PEDOT:PSS_signal'] - merged_df['PEDOT:PSS_baseline']

            # Paired t-test on the increase over baseline
            t_stat_inc, p_val_inc = ttest_rel(pss_inc, des_inc)

            for rat in merged_df.index:
                ax3.plot(['PEDOT:DES', 'PEDOT:PSS'], [des_inc[rat], pss_inc[rat]], marker='o', label=rat)
            ax3.axhline(0, linestyle='--', color='gray')
            ax3.set_title(f"Δ over Baseline ({metric})\nt={t_stat_inc:.2f}, p={p_val_inc:.4f}")
            ax3.set_ylabel('Increase')
            ax3.grid(True)
            if i == 0:
                ax3.legend()
            increase_drawn = True

            results.append({
                'metric': f"{metric} Δ",
                'n': merged_df.shape[0],
                't_stat': t_stat_inc,
                'p_val': p_val_inc
            })

    if not increase_drawn:
        # No baseline comparison for this metric: do not leave an empty frame.
        ax3.set_visible(False)

    results.append({
        'metric': metric,
        'n': df_m.shape[0],
        't_stat': t_stat,
        'p_val': p_val
    })

# Summary table. Explicit columns keep sort_values working when no metric
# produced a result.
results_df = pd.DataFrame(results, columns=['metric', 'n', 't_stat', 'p_val']).sort_values(by='p_val')

# Lay out after drawing so titles and labels are taken into account.
fig.tight_layout(pad=5)
fig.savefig(f'{base_dir}/overview_metrics_{condition}.svg', dpi=300)


#### Population analysis: stimulation

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import os
import glob
from scipy.stats import ttest_rel

# ------------------------------
# STEP 1: Find all `overview_stim.csv` files recursively
# ------------------------------
experiments_path = "../PEDOT-DES"
root_dir = f'{experiments_path}/stimulation/'  # Top-level folder
pattern = os.path.join(root_dir, '**', 'overview_stim.csv')
# Single discovery pass, sorted so the row order of df_all is reproducible.
overview_files = sorted(glob.glob(pattern, recursive=True))

print(f"Found {len(overview_files)} overview files.")
if not overview_files:
    # pd.concat([]) would fail with an opaque "No objects to concatenate".
    raise FileNotFoundError(f"No 'overview_stim.csv' found under {root_dir!r}")

all_data = []
for overview_path in overview_files:
    df = pd.read_csv(overview_path)
    # The name of the folder that contains the file identifies the nerve.
    df['nerve_id'] = os.path.basename(os.path.dirname(overview_path))
    all_data.append(df)

df_all = pd.concat(all_data, ignore_index=True)

# ---- STEP 2: Clean and annotate ---- #
# Convert current to numeric: strip the literal 'uA' suffix, then cast to float
df_all['current_uA'] = df_all['current'].str.replace('uA', '', regex=False).astype(float)

# Assign material per pulse number and port
def get_material(row):
    """Return the material label ('PEDOT:DES' or 'PEDOT:PSS') for one pulse row.

    Args:
        row: Mapping-like object (e.g. a DataFrame row) with keys 'port' and
            'pulse'.

    Returns:
        str: For 'PortA' the first material is 'PEDOT:DES'; for any other port
        it is 'PEDOT:PSS'. Pulses numbered 10 or lower get the first material,
        later pulses get the other one.
    """
    if row['port'] == 'PortA':
        first, second = 'PEDOT:DES', 'PEDOT:PSS'
    else:
        first, second = 'PEDOT:PSS', 'PEDOT:DES'

    if row['pulse'] <= 10:
        return first
    return second

# Add, in this order: the material of every pulse, the stimulation order
# implied by the port, and the RMS / MAX deltas relative to baseline.
stim_order_by_port = {
    'PortA': 'DES first',
    'PortB': 'PSS first'
}
df_all = df_all.assign(
    material=df_all.apply(get_material, axis=1),
    order=df_all['port'].map(stim_order_by_port),
    delta_RMS=df_all['RMS'] - df_all['baseline_RMS'],
    delta_MAX=df_all['MAX'] - df_all['baseline_MAX'],
)

# ---- STEP 3: Aggregate per nerve ---- #
def _aggregate_per_nerve(frame, value_col):
    """Summarise `value_col` per (current_uA, material, nerve_id).

    Returns a DataFrame with the group keys plus `<value_col>_mean`,
    `<value_col>_std`, `pulse_count`, `number_ch` and `<value_col>_sem`.
    The `_sem` column is a plain copy of `_std`; it is not divided by the
    pulse count.
    """
    summary = (
        frame
        .groupby(['current_uA', 'material', 'nerve_id'])
        .agg(**{
            f'{value_col}_mean': (value_col, 'mean'),
            f'{value_col}_std': (value_col, 'std'),
            'pulse_count': ('pulse', 'count'),
            'number_ch': ('number_ch', 'first'),
        })
        .reset_index()
    )
    summary[f'{value_col}_sem'] = summary[f'{value_col}_std']  # Optional: normalize by pulse_count
    return summary


agg_data = _aggregate_per_nerve(df_all, 'delta_RMS')
agg_data_MAX = _aggregate_per_nerve(df_all, 'delta_MAX')
agg_data_RMS = _aggregate_per_nerve(df_all, 'RMS')

# ---- STEP 4: Plotting ---- #
def _plot_population_panel(ax, per_nerve, value_prefix, label_offset, more_than_nerves=3):
    """Draw one errorbar curve per material on `ax`.

    Args:
        ax: Matplotlib axes to draw on.
        per_nerve: Per-nerve aggregate with columns 'current_uA', 'material',
            'nerve_id', 'pulse_count', '<value_prefix>_mean' and
            '<value_prefix>_sem'.
        value_prefix: Column prefix, e.g. 'delta_RMS' or 'RMS'.
        label_offset: Vertical offset of the "n=" annotations, in data units.
        more_than_nerves: A current is plotted only when strictly more than
            this many distinct nerves contribute to it.
    """
    mean_col = f'{value_prefix}_mean'
    sem_col = f'{value_prefix}_sem'

    for material, color in zip(['PEDOT:PSS', 'PEDOT:DES'], ['tab:blue', 'tab:orange']):
        mat_data = per_nerve[per_nerve['material'] == material]
        if mat_data.empty:
            # Nothing recorded for this material; skip instead of failing below.
            continue

        # Per current: pulse-count-weighted mean across nerves, and the
        # per-nerve spreads combined in quadrature and divided by the number
        # of nerves. Selecting the columns explicitly keeps the grouping key
        # out of the frame handed to the lambda.
        grouped = (
            mat_data
            .groupby('current_uA')[[mean_col, sem_col, 'pulse_count']]
            .apply(lambda g: pd.Series({
                'mean': np.average(g[mean_col], weights=g['pulse_count']),
                'sem': np.sqrt(np.sum(g[sem_col] ** 2)) / len(g),
            }))
            .reset_index()
        )

        # Compute n per current
        nerve_counts = (
            mat_data.groupby('current_uA')['nerve_id']
            .nunique()
            .reset_index(name='n_nerves')
        )
        grouped = grouped.merge(nerve_counts, on='current_uA')

        # Filter to valid currents
        valid = grouped[grouped['n_nerves'] > more_than_nerves]

        ax.errorbar(valid['current_uA'], valid['mean'], yerr=valid['sem'],
                    label=material, marker='o', color=color, capsize=4)

        # The nerve count is annotated on the PEDOT:PSS curve only.
        if material == 'PEDOT:PSS':
            for _, row in valid.iterrows():
                ax.text(row['current_uA'], row['mean'] + label_offset,
                        f"n={int(row['n_nerves'])}", ha='center', color=color, fontsize=15)


sns.set(style='whitegrid')
fig, axs = plt.subplots(2, 1, figsize=(16, 10), sharex=True)

# ----- ΔRMS plot ----- #
_plot_population_panel(axs[0], agg_data, 'delta_RMS', label_offset=20)
axs[0].set_title("Population ΔRMS (RMS - Baseline)")
axs[0].set_xlabel("Current (µA)")
axs[0].set_ylabel("ΔRMS")
axs[0].set_xlim([0, 32])
axs[0].legend()

# ----- RMS plot ----- #
_plot_population_panel(axs[1], agg_data_RMS, 'RMS', label_offset=10)
axs[1].set_title("Population RMS")
axs[1].set_xlabel("Current (µA)")
axs[1].set_ylabel("RMS")
axs[1].set_xlim([0, 32])
axs[1].legend()

plt.tight_layout()
# Save before show: on backends where plt.show() closes the figure (e.g. the
# inline backend), a later savefig would write an empty canvas.
fig.savefig(f'{root_dir}/population_stimulation_with_counts_PortA_only_above3.svg', dpi=300)
plt.show()

# ------------------------------
# STEP 5: Plot population ΔRMS
# ------------------------------
fig = plt.figure(figsize=(10, 6))
colors = {'PEDOT:PSS': 'tab:blue', 'PEDOT:DES': 'tab:orange'}

# Plot individual nerves (thin, semi-transparent, no legend entry)
for material in ['PEDOT:PSS', 'PEDOT:DES']:
    mat_data = agg_data[agg_data['material'] == material]
    for nerve_id in mat_data['nerve_id'].unique():
        sub = mat_data[mat_data['nerve_id'] == nerve_id]
        plt.plot(sub['current_uA'], sub['delta_RMS_mean'],
                 color=colors[material], alpha=0.4, linestyle='-', marker='o', label=None)

# Plot population mean ± SEM: Mean across nerves
pop = agg_data.groupby(['current_uA', 'material']).agg(
    delta_RMS_mean=('delta_RMS_mean', 'mean'),
    delta_RMS_sem=('delta_RMS_mean', 'sem'),
    nerve_count=('nerve_id', 'nunique')  # Count of unique nerve_ids
).reset_index()

for material in ['PEDOT:PSS', 'PEDOT:DES']:
    data = pop[pop['material'] == material]
    color = colors[material]
    for point_idx, point in enumerate(data.itertuples(index=False)):
        # One errorbar per current, as before. Only the first point of each
        # material is labelled, so the legend gets exactly one entry per
        # material instead of being empty.
        plt.errorbar(point.current_uA, point.delta_RMS_mean,
                     yerr=point.delta_RMS_sem,
                     color=color, linewidth=2.5, marker='o',
                     label=material if point_idx == 0 else None)

        # Add text box with nerve count for each current
        plt.text(point.current_uA, point.delta_RMS_mean,
                 f'N={point.nerve_count}', color=color, fontsize=10,
                 ha='center', va='bottom', bbox=dict(facecolor='white', edgecolor=color, boxstyle='round,pad=0.3'))

plt.xlabel("Current (μA)")
plt.ylabel("ΔRMS = RMS - Baseline RMS")
plt.title("Population ΔRMS: PEDOT:PSS vs PEDOT:DES")
plt.legend()
plt.grid(True)
plt.tight_layout()
# Save before show so the file is written while the figure is still current.
fig.savefig(f'{root_dir}/population_stimulation_allCurrents_individual_with_counts.svg', dpi=300)
plt.show()

# ------------------------------
# STEP 6: Paired statistical test at 10 μA and 30 μA
# ------------------------------
# shapiro and wilcoxon are used below but are not part of this cell's import
# list; import them here so the cell does not rely on an earlier cell.
from scipy.stats import shapiro, ttest_rel, wilcoxon

# Shapiro-Wilk is not defined for fewer than three observations.
min_paired_nerves = 3

for target_uA in [10, 30]:
    data = agg_data[agg_data['current_uA'] == target_uA]

    # Pivot to align materials by nerve_id. reindex guarantees both material
    # columns exist even when one material has no data at this current; such
    # nerves are then removed by dropna.
    pivot = (
        data.pivot(index='nerve_id', columns='material', values='delta_RMS_mean')
        .reindex(columns=['PEDOT:PSS', 'PEDOT:DES'])
        .dropna()
    )

    print(f"\n--- Statistical Comparison at {target_uA} μA ---")
    if pivot.shape[0] >= min_paired_nerves:
        differences = pivot['PEDOT:PSS'] - pivot['PEDOT:DES']
        shapiro_stat, shapiro_p = shapiro(differences)

        print(f"  N = {pivot.shape[0]} nerves")
        print(f"  Shapiro-Wilk normality test on differences:")
        print(f"    W = {shapiro_stat:.3f}, p = {shapiro_p:.4g}")

        if shapiro_p > 0.05:
            print("  → Differences are normally distributed (p > 0.05). Using paired t-test.")
            t_stat, p_val = ttest_rel(pivot['PEDOT:PSS'], pivot['PEDOT:DES'])
            print(f"  Paired t-test result:\n    t = {t_stat:.3f}, p = {p_val:.4g}")
        else:
            print("  → Differences are NOT normally distributed (p ≤ 0.05). Using Wilcoxon signed-rank test.")
            try:
                w_stat, p_val = wilcoxon(pivot['PEDOT:PSS'], pivot['PEDOT:DES'])
                print(f"  Wilcoxon signed-rank test result:\n    W = {w_stat:.3f}, p = {p_val:.4g}")
            except ValueError as e:
                # e.g. every paired difference is zero
                print(f"  Wilcoxon test failed: {e}")

        print("  Interpretation:")
        print("    Positive difference → PEDOT:PSS > PEDOT:DES")
        print("    Negative difference → PEDOT:DES > PEDOT:PSS")
    else:
        print(f"  Not enough matched nerves with both materials at {target_uA} μA for statistical comparison.")
        print(f"  (found {pivot.shape[0]}, need at least {min_paired_nerves})")

# Persist the annotated per-pulse table next to the per-nerve folders.
df_all.to_csv(f'{root_dir}/_all_nerves_overview_stim.csv', index=False)


In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from scipy.stats import ttest_rel

# Threshold data (R1 to R4)
# One column per experiment (rat + nerve). According to the comments further
# down in this cell, the first value of each pair is PEDOT:DES and the second
# is PEDOT:PSS.
data = {
    'R1_RN': [6, 8.5],
    'R1_LN': [8, 13],
    'R2_RN': [6, 35],
    'R2_LN': [32, 5],
    'R3_RN': [8, 9],
    'R3_LN': [3, 2],
    'R4_RN': [7, 9],
    'R4_LN': [8, 10]
}

# Create a DataFrame from the dictionary (row 0 and row 1 are the two materials)
df = pd.DataFrame(data)

# Plot the data
fig = plt.figure(figsize=(10, 6))

# One line per experiment, joining its two thresholds
for group in df.columns:
    plt.plot(df.index, df[group], marker='o', label=group)

# Customize the plot. The two x positions are the two materials, so label them
# as such instead of as anonymous sample indices.
plt.xticks(df.index, ['PEDOT:DES', 'PEDOT:PSS'])
plt.xlabel('Material')
plt.ylabel('Threshold (uA)')
plt.title('Threshold Comparison: RN vs LN for Each Rat Group')
plt.legend()
plt.grid(True)
# Save before show so the file is written while the figure is still current.
# NOTE: root_dir is not defined in this cell; it comes from the population
# stimulation cell.
fig.savefig(f'{root_dir}/stimulation_twitches_pairedComparison.svg', dpi=300)
plt.show()

# Calculate the difference (DES - PSS) for each experiment (each column)
des_thresholds = df.iloc[0]
pss_thresholds = df.iloc[1]
increments = des_thresholds - pss_thresholds  # DES - PSS for each experiment

# Plot the increments (DES - PSS) for each experiment
fig = plt.figure(figsize=(10, 6))
plt.bar(increments.index, increments.values, color='purple')

# Customize the plot
plt.xlabel('Experiments')
plt.ylabel('Threshold Increment (DES - PSS) (uA)')
plt.title('Threshold Increment (DES - PSS) for Each Experiment')
plt.xticks(rotation=45)
plt.grid(True)
# Save before show so the file is written while the bar chart is still the
# current figure.
fig.savefig(f'{root_dir}/stimulation_twitches.svg', dpi=300)
plt.show()

# Paired t-test between DES and PSS ACROSS experiments.
# Each experiment contributes exactly one (DES, PSS) pair, so a test inside a
# single experiment has n = 1 and can only return NaN. The pairs from all
# experiments are therefore tested together.
t_stat, p_val = ttest_rel(des_thresholds, pss_thresholds)
results = {f'DES vs PSS across {df.shape[1]} experiments': (t_stat, p_val)}

# Print the statistical results
for experiment, (t_stat, p_val) in results.items():
    print(f"\n{experiment}:")
    print(f"  t = {t_stat:.3f}, p = {p_val:.3g}")
    if p_val < 0.05:
        print("  The difference between DES and PSS is statistically significant.")
    else:
        print("  The difference between DES and PSS is not statistically significant.")
