# In[None]:
## import os.path as op
import pyplr
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import math
from scipy.signal import savgol_filter
import scipy.io as sio
import statistics as stats
from scipy.interpolate import interp1d

sns.set_context('notebook', font_scale=1.2)
from pyplr import graphing, utils, preproc
from pyplr.plr import PLR

np.set_printoptions(threshold=np.inf)

# Trial numbers per condition: blocks of ten trials alternate Neutral / Visual.
# Only the Visual trials are analysed below.
Neutral = [n for block_start in (1, 21, 41) for n in range(block_start, block_start + 10)]
Visual = [n for block_start in (11, 31, 51) for n in range(block_start, block_start + 10)]

# Accumulators for per-trial durations and their short/long classification
trial_durations, short_duration_trials, long_duration_trials = [], [], []

def get_condition_label(duration, median_duration):
    """Classify a trial by its duration relative to the median.

    Returns "Short" when ``duration`` is at or below ``median_duration`` and
    "Long" otherwise (including when the comparison is False, e.g. for NaN).
    """
    return "Short" if duration <= median_duration else "Long"

# Baseline (mean of the leading samples) of every processed Visual trial
baseline_avg_array = []

# Reward-related accumulators; they appear unused in the analysis below — NOTE(review)
self_reward_data_series, other_reward_data_series = [], []
self_reward_timestamps, other_reward_timestamps = [], []
self_reward_confidence, other_reward_confidence = [], []

# Per-condition storage: trial dicts (variable-length traces) and scalar baselines
short_condition_trials, long_condition_trials = [], []
short_condition_baselines, long_condition_baselines = [], []

# Load the session's .mat file; trials are looked up below by 'Trial<N>' keys
_session_file = '001_sirisha_v.mat'
mat_data = sio.loadmat(_session_file)

# First pass: measure each Visual trial's duration (first LSL timestamp -> first
# occurrence of event code 2) so the trials can be median-split into Short/Long.
print("Calculating trial durations for median split (Visual trials only)...")
for trial_num in Visual:  # Only process Visual trials
    trial_key = f'Trial{trial_num}'
    if trial_key not in mat_data:
        continue
    trial_data = mat_data[trial_key]

    # Event codes and their times for this trial
    behavioral_codes = trial_data['BehavioralCodes'][0, 0]
    code_times = behavioral_codes['CodeTimes'][0, 0].flatten()
    code_numbers = behavioral_codes['CodeNumbers'][0, 0].flatten()
    time_2 = code_times[code_numbers == 2]
    if len(time_2) == 0:
        continue  # trial never reached event code 2

    lsl_data = trial_data['AnalogData'][0, 0]['LSL'][0, 0]['LSL1'][0, 0]
    if lsl_data.size == 0:
        # Without samples there is no start time; timestamps[0] would raise IndexError
        print(f"Warning: No LSL samples in trial {trial_num}. Excluding it from the median split.")
        continue
    timestamps = lsl_data[:, 0]

    # NOTE(review): assumes LSL timestamps and CodeTimes share one clock (ms) — confirm
    trial_durations.append({
        'trial_num': trial_num,
        'duration': time_2[0] - timestamps[0],
    })

# Fail with a clear message instead of a nan median followed by min() on an empty list
if not trial_durations:
    raise ValueError(
        "No Visual trial contains event code 2 with LSL samples; cannot compute the median split."
    )

durations_only = [trial['duration'] for trial in trial_durations]
median_duration = np.median(durations_only)
print(f"Median trial duration (Visual trials): {median_duration:.2f} ms")
print(f"Trial duration range (Visual trials): {min(durations_only):.2f} - {max(durations_only):.2f} ms")

# Split trial numbers at the median duration (ties go to the Short group)
short_duration_trial_nums = [
    entry['trial_num'] for entry in trial_durations if entry['duration'] <= median_duration
]
long_duration_trial_nums = [
    entry['trial_num'] for entry in trial_durations if not entry['duration'] <= median_duration
]

_median_text = f"{median_duration:.2f}"
print(f"Short Duration Visual trials (≤{_median_text}ms): {len(short_duration_trial_nums)} trials")
print(f"Long Duration Visual trials (>{_median_text}ms): {len(long_duration_trial_nums)} trials")
print(f"Short Duration Visual trial numbers: {short_duration_trial_nums}")
print(f"Long Duration Visual trial numbers: {long_duration_trial_nums}")

# Second pass: preprocess each Visual trial (analysis window -> NaN removal -> blink
# masking -> interpolation -> Savitzky-Golay smoothing -> subtractive baseline) and
# store the baseline-corrected trace in the Short or Long group from the median split.

ANALYSIS_WINDOW_MS = 3000  # keep samples from event code 2 up to this many ms after it
BLINK_CONFIDENCE_THRESHOLD = 0.4  # samples with confidence below this are blink candidates
BLINK_FRINGE_SAMPLES = 50  # extra samples masked on EACH side of every blink candidate
BASELINE_SAMPLES = 50  # leading samples (closest to onset) averaged for the baseline


def _trial_figure(num):
    """Select (creating if needed) figure *num* and size it to 20x10 inches."""
    fig = plt.figure(num)
    fig.set_figheight(10)
    fig.set_figwidth(20)
    return fig


def _blink_indices(confidence, n_samples, threshold=BLINK_CONFIDENCE_THRESHOLD,
                   fringe=BLINK_FRINGE_SAMPLES):
    """Return sorted unique sample indices to mask as blinks.

    Every sample whose confidence is below *threshold* is flagged together with
    *fringe* samples on each side; indices outside ``[0, n_samples)`` are dropped.
    """
    low = np.flatnonzero(np.asarray(confidence, dtype=float) < threshold)
    if low.size == 0:
        return np.array([], dtype=int)
    padded = np.unique((low[:, None] + np.arange(-fringe, fringe + 1)).ravel())
    return padded[(padded >= 0) & (padded < n_samples)]


def _interpolate_nans_inplace(data, trial_num):
    """Fill NaN samples of *data* in place by linear interpolation over sample index.

    NaNs before the first / after the last valid sample take that sample's value
    (``np.interp`` holds the edge values). If every sample is NaN, a warning is
    printed and the array is zero-filled.
    """
    missing = np.isnan(data)
    if missing.all():
        print(f"Warning: All data is NaN in trial {trial_num}. Filling with zeros.")
        data.fill(0)
    elif missing.any():
        present = np.flatnonzero(~missing)
        data[missing] = np.interp(np.flatnonzero(missing), present, data[present])
    return data


def _savgol_smooth(data, max_window=11, max_polyorder=3):
    """Savitzky-Golay smooth *data*, shrinking the window for short traces.

    The window is the largest odd length <= min(max_window, len(data)) (at least 1)
    and the polynomial order is kept below the window length, so ``savgol_filter``
    always accepts the parameters. Empty input is returned as a copy.
    """
    n_samples = len(data)
    if n_samples == 0:
        return data.copy()
    window = min(max_window, n_samples)
    if window % 2 == 0:
        window -= 1
    window = max(1, window)
    polyorder = min(max_polyorder, window - 1)
    return savgol_filter(data, window, polyorder)


# O(1) duration lookup instead of scanning trial_durations for every trial
duration_by_trial = {entry['trial_num']: entry['duration'] for entry in trial_durations}

for trial_num in Visual:  # Only process Visual trials
    trial_key = f'Trial{trial_num}'
    if trial_key not in mat_data or trial_num not in duration_by_trial:
        continue
    trial_data = mat_data[trial_key]
    condition_label = get_condition_label(duration_by_trial[trial_num], median_duration)

    behavioral_codes = trial_data['BehavioralCodes'][0, 0]
    code_times = behavioral_codes['CodeTimes'][0, 0].flatten()
    code_numbers = behavioral_codes['CodeNumbers'][0, 0].flatten()
    time_2 = code_times[code_numbers == 2]
    if len(time_2) == 0:
        continue
    onset = time_2[0]

    lsl_data = trial_data['AnalogData'][0, 0]['LSL'][0, 0]['LSL1'][0, 0]
    timestamps = lsl_data[:, 0]
    # Average the two diameter columns (19, 20); column 1 is the confidence used for blinks
    average_diameter = (lsl_data[:, 19].astype(float) + lsl_data[:, 20].astype(float)) / 2
    confidence = lsl_data[:, 1]

    # Keep [onset, onset + window] and re-reference time so that onset == 0
    in_window = (timestamps >= onset) & (timestamps <= onset + ANALYSIS_WINDOW_MS)
    window_diameter = average_diameter[in_window]
    window_confidence = confidence[in_window]
    window_timestamps = timestamps[in_window] - onset

    # Drop samples without a diameter. Confidence is filtered with the SAME mask so the
    # blink indices derived from it address the same samples as working_pupil_data.
    has_diameter = ~np.isnan(window_diameter)
    working_pupil_data = window_diameter[has_diameter]
    working_timestamps = window_timestamps[has_diameter]
    working_confidence = window_confidence[has_diameter]

    if working_pupil_data.size == 0:
        print(f"Warning: No data points in trial {trial_num}. Skipping this trial.")
        continue

    _trial_figure(1)
    plt.plot(working_timestamps, working_pupil_data, alpha=1)
    plt.title("Raw pupil Data after NaNs removal (Visual Trials)")
    plt.xlabel("Time in milliseconds (after trial onset)", fontsize="20")
    plt.ylabel("Pupil Diameter in arbitrary units", fontsize="20")

    # Sample-to-sample difference; np.diff is per sample, not per second
    velocity_series = np.diff(working_pupil_data)
    _trial_figure(3)
    plt.title("Velocity Calculated (Visual Trials)")
    plt.plot(working_timestamps[:-1], velocity_series)
    plt.xlabel("Time in ms (after trial onset)", fontsize="20")
    plt.ylabel("Velocity (au/sample)")

    blink_idx = _blink_indices(working_confidence, len(working_pupil_data))

    _trial_figure(4)
    plt.plot(working_timestamps, working_pupil_data)
    plt.title("Blink Detection (Visual Trials)")
    if blink_idx.size:
        plt.vlines(working_timestamps[blink_idx],
                   ymin=np.nanmin(working_pupil_data), ymax=np.nanmax(working_pupil_data),
                   colors='gray')

    working_pupil_data[blink_idx] = np.nan

    _trial_figure(5)
    plt.plot(working_timestamps, working_pupil_data)
    plt.xlabel("Time in milliseconds (after trial onset)", fontsize="20")
    plt.ylabel("Pupil Diameter in arbitrary units", fontsize="20")
    plt.title("NaNs Inserted in place of blinks (Visual Trials)")

    _interpolate_nans_inplace(working_pupil_data, trial_num)

    _trial_figure(6)
    plt.plot(working_timestamps, working_pupil_data)
    plt.xlabel("Time in milliseconds (after trial onset)", fontsize="20")
    plt.ylabel("Pupil Diameter in arbitrary units", fontsize="20")
    plt.title("Interpolated over NaNs (Visual Trials)")

    filtered_pupil_data = _savgol_smooth(working_pupil_data)

    _trial_figure(2)
    plt.plot(working_timestamps, filtered_pupil_data)
    plt.title("Filtered Pupil Data using Savgol (Visual Trials)")
    plt.xlabel("Time in milliseconds (after trial onset)", fontsize="20")
    plt.ylabel("Pupil Diameter in arbitrary units", fontsize="20")

    # Subtractive baseline: mean of the samples closest to onset (time 0)
    baseline = filtered_pupil_data[:BASELINE_SAMPLES]
    valid_baseline = baseline[~np.isnan(baseline)]
    if valid_baseline.size == 0:
        print(f"Warning: No valid baseline data in trial {trial_num}. Skipping this trial.")
        continue
    avg = float(np.mean(valid_baseline))
    baseline_corrected = filtered_pupil_data - avg

    _trial_figure(7)
    plt.plot(working_timestamps, baseline_corrected)
    plt.xlabel("Time in milliseconds (after trial onset)", fontsize="20")
    plt.ylabel("Pupil Diameter in arbitrary units", fontsize="20")
    plt.title("Subtractive Baseline Correction (Visual Trials)")

    baseline_avg_array.append(avg)

    record = {
        'trial_num': trial_num,
        'data': baseline_corrected,
        'timestamps': working_timestamps.copy(),
    }
    if condition_label == "Short":
        short_condition_trials.append(record)
        short_condition_baselines.append(avg)
    elif condition_label == "Long":
        long_condition_trials.append(record)
        long_condition_baselines.append(avg)

# Baselines are one scalar per trial, so each group converts to a flat array
short_condition_baselines, long_condition_baselines = map(
    np.array, (short_condition_baselines, long_condition_baselines)
)

# Report how many trials ended up in each duration group
print(f"\nNumber of Short Duration Visual trials: {len(short_condition_trials)}")
print(f"Number of Long Duration Visual trials: {len(long_condition_trials)}")

# Preview the first (up to) three trials of each non-empty group
for group_label, group_trials in (("Short", short_condition_trials), ("Long", long_condition_trials)):
    if not group_trials:
        continue
    print(f"\nFirst few {group_label} Duration Visual trials:")
    for shown in group_trials[:3]:
        print(f"Trial {shown['trial_num']}: {len(shown['data'])} data points")
        print(f"First 5 values: {shown['data'][:5]}")

# Distribution of per-trial baseline pupil sizes: all trials and each duration group
baseline_avg_array = np.array(baseline_avg_array)
fig, ax = plt.subplots(figsize=(20, 10))
for hist_values, hist_alpha, hist_label in (
    (baseline_avg_array, 0.7, "All Visual trials"),
    (short_condition_baselines, 0.5, "Short Duration Visual"),
    (long_condition_baselines, 0.5, "Long Duration Visual"),
):
    ax.hist(hist_values, bins=20, alpha=hist_alpha, label=hist_label)
ax.legend()
ax.set(
    title="Baseline Pupil Diameter Distribution by Duration (Visual Trials)",
    xlabel="Baseline Pupil Diameter (arbitrary units)",
    ylabel="Count",
)
plt.show()

# Resample trials to a common time axis for averaging
def resample_trial_data(trials, target_length=500, time_window=(0, 3000)):
    """Resample every trial onto a common grid so trials can be averaged point-wise.

    When *time_window* is given and a trial carries a ``'timestamps'`` array of
    the same length as its ``'data'``, the samples are interpolated against those
    timestamps onto ``target_length`` evenly spaced points covering
    ``time_window`` (same units as the timestamps). Before the first sample and
    after the last one, the edge value is held constant rather than linearly
    extrapolated. This way a trial that lost samples, or ended early, is not
    stretched in time.

    Otherwise (no timestamps, length mismatch, or ``time_window=None``) the
    samples are spread evenly over the grid by index, with linear extrapolation.

    Parameters
    ----------
    trials : iterable of dict
        Each dict has a ``'data'`` array and optionally a ``'timestamps'`` array.
    target_length : int
        Number of points in the output grid.
    time_window : tuple(float, float) or None
        ``(start, end)`` of the time grid; ``None`` forces index-based resampling.

    Returns
    -------
    numpy.ndarray
        One row per kept trial (``target_length`` columns). Trials with fewer
        than 2 samples are skipped; if none remain, an empty array is returned.
    """
    time_grid = None
    if time_window is not None:
        time_grid = np.linspace(time_window[0], time_window[1], target_length)

    resampled_data = []
    for trial in trials:
        data = np.asarray(trial['data'], dtype=float)
        if len(data) < 2:  # Skip trials with insufficient data
            continue

        timestamps = trial.get('timestamps')
        if time_grid is not None and timestamps is not None and len(timestamps) == len(data):
            ts = np.asarray(timestamps, dtype=float)
            order = np.argsort(ts, kind='stable')  # np.interp needs increasing x
            resampled = np.interp(time_grid, ts[order], data[order])
        else:
            x_original = np.linspace(0, 1, len(data))
            x_new = np.linspace(0, 1, target_length)
            interpolator = interp1d(x_original, data, kind='linear',
                                    bounds_error=False, fill_value='extrapolate')
            resampled = interpolator(x_new)
        resampled_data.append(resampled)

    return np.array(resampled_data)

# Resample each duration group onto a common grid and summarise it point-wise
target_length = 500  # number of points on the common grid


def _mean_std_sem(resampled):
    """Return point-wise mean, sample SD (ddof=1) and SEM across the rows of *resampled*.

    With a single trial there is no between-trial spread, so SD and SEM are zero
    instead of the NaNs that ``ddof=1`` would produce.
    """
    n_trials = resampled.shape[0]
    mean = resampled.mean(axis=0)
    std = resampled.std(axis=0, ddof=1) if n_trials > 1 else np.zeros_like(mean)
    return mean, std, std / np.sqrt(n_trials)


if short_condition_trials:
    short_resampled = resample_trial_data([t for t in short_condition_trials if len(t['data']) >= 2], target_length)
    if len(short_resampled) > 0:
        short_mean, short_std, short_sem = _mean_std_sem(short_resampled)
        print(f"\nResampled Short Duration Visual data shape: {short_resampled.shape}")
    else:
        print("\nNo valid Short Duration Visual trials for resampling")

if long_condition_trials:
    long_resampled = resample_trial_data([t for t in long_condition_trials if len(t['data']) >= 2], target_length)
    if len(long_resampled) > 0:
        long_mean, long_std, long_sem = _mean_std_sem(long_resampled)
        print(f"Resampled Long Duration Visual data shape: {long_resampled.shape}")
    else:
        print("No valid Long Duration Visual trials for resampling")

# Common time axis (ms after trial onset) matching the 0-3000 ms resampling grid
time_axis = np.linspace(0, 3000, target_length)


def _comparison_figure(num):
    """Select (creating if needed) figure *num* and size it to 20x10 inches."""
    fig = plt.figure(num)
    fig.set_figheight(10)
    fig.set_figwidth(20)
    return fig


def _plot_condition(x, y, err, fmt, colour, label):
    """Plot one group's curve with format *fmt* and a shaded band of +/- *err*."""
    plt.plot(x, y, fmt, label=label, linewidth=2)
    plt.fill_between(x, y - err, y + err, color=colour, alpha=0.3)


# short_mean / long_mean only exist when that group had resampled trials;
# plotting unconditionally raised NameError for an empty group.
if globals().get('short_mean') is None or globals().get('long_mean') is None:
    print("Skipping Short vs Long comparison plots: a duration group has no resampled trials.")
else:
    for fig_num, y_label, fig_title in (
        (8, "Pupil Diameter in arbitrary units",
         "Pupil Size Time Series Comparison - Short vs Long Duration Visual (with SEM)"),
        (11, "Change in Pupil size over time wrt Baseline size",
         "Subtractive Baseline Corrected Pupil Size Series -  "),
    ):
        _comparison_figure(fig_num)
        _plot_condition(time_axis, short_mean, short_sem, 'bo', 'blue', "Short Duration Visual")
        _plot_condition(time_axis, long_mean, long_sem, 'ro', 'red', "Long Duration Visual")
        plt.xlabel("Time in milliseconds (after trial onset)", fontsize="20")
        plt.ylabel(y_label, fontsize="20")
        plt.legend(loc="upper left", fontsize="16")
        plt.title(fig_title)

    # Baseline-corrected means as a percentage of each group's mean raw baseline
    short_percent_change = (short_mean / np.mean(short_condition_baselines)) * 100
    long_percent_change = (long_mean / np.mean(long_condition_baselines)) * 100
    short_percent_sem = (short_sem / np.mean(short_condition_baselines)) * 100
    long_percent_sem = (long_sem / np.mean(long_condition_baselines)) * 100

    _comparison_figure(12)
    _plot_condition(time_axis, long_percent_change, long_percent_sem, 'ro', 'red', "Long Duration Visual")
    _plot_condition(time_axis, short_percent_change, short_percent_sem, 'bo', 'blue', "Short Duration Visual")
    plt.xlabel("Time in milliseconds (after trial onset)", fontsize="20")
    plt.ylabel("Percentage Change in Pupil Size from the Baseline Size")
    plt.legend(loc="upper left", fontsize="16")
    plt.title("Percentage Change in Pupil Size -  ")



def _percent_of_z_baseline(corrected_z, base_z, sem_raw, pooled_std):
    """Express baseline-corrected z-scores as a percentage of the baseline's magnitude.

    Dividing by ``abs(base_z)`` keeps the sign of the change. A dilation stays
    positive even when the baseline's z-score is negative; dividing by the signed
    value would flip it. The SEM is propagated from the raw-unit SEM: the z-scale
    divides by *pooled_std*, and the percentage scale multiplies by
    100 / |base_z|. A zero baseline yields zeros for both outputs.

    Returns
    -------
    (percentage, percentage_sem) : tuple of numpy.ndarray
    """
    if base_z == 0:
        return np.zeros_like(corrected_z), np.zeros_like(corrected_z)
    scale = 100.0 / abs(base_z)
    return corrected_z * scale, (np.asarray(sem_raw) / pooled_std) * scale


if globals().get('short_mean') is None or globals().get('long_mean') is None:
    print("Skipping z-scored percentage-change plot: a duration group has no resampled trials.")
else:
    # Pool both group means so that they share one z-scale
    final_reward = np.concatenate([short_mean, long_mean])
    mean_final_reward = float(np.mean(final_reward))
    std_final = float(np.std(final_reward, ddof=1))  # sample SD, like statistics.stdev

    if not std_final > 0:
        print("Skipping z-scored percentage-change plot: pooled condition means have zero variance.")
    else:
        zscored = (final_reward - mean_final_reward) / std_final
        n_short_points = len(short_mean)
        short_z = zscored[:n_short_points]
        long_z = zscored[n_short_points:]

        # Baseline = first 50 grid points (closest to trial onset)
        mean_base_short = float(np.mean(short_z[:50]))
        mean_base_long = float(np.mean(long_z[:50]))
        final_z_b_short = short_z - mean_base_short
        final_z_b_long = long_z - mean_base_long

        # NOTE(review): a percentage of a z-score baseline is unstable when the baseline is near 0
        percentage_short, z_short_sem = _percent_of_z_baseline(final_z_b_short, mean_base_short, short_sem, std_final)
        percentage_long, z_long_sem = _percent_of_z_baseline(final_z_b_long, mean_base_long, long_sem, std_final)

        f = plt.figure(14)
        f.set_figheight(10)
        f.set_figwidth(20)
        plt.plot(time_axis, percentage_long, 'ro', label="Long Duration Visual", linewidth=2)
        plt.fill_between(time_axis, percentage_long - z_long_sem, percentage_long + z_long_sem, color='red', alpha=0.3)
        plt.plot(time_axis, percentage_short, 'bo', label="Short Duration Visual", linewidth=2)
        plt.fill_between(time_axis, percentage_short - z_short_sem, percentage_short + z_short_sem, color='blue', alpha=0.3)
        plt.xlabel("Time (ms)", fontsize="20")
        plt.ylabel("Percentage change in pupil size")
        plt.legend(loc="upper left", fontsize="16")
        plt.title("Percentage change in pupil size -  ")
plt.show()

# Overview figure: mean baseline-corrected trace per duration group with a ±SEM band.
# At module level globals() is the same namespace as locals().
if 'short_mean' in globals() and 'long_mean' in globals():
    fig, ax = plt.subplots(figsize=(20, 10))

    for curve, band, series_name, colour in (
        (short_mean, short_sem, "Short Duration Visual", "blue"),
        (long_mean, long_sem, "Long Duration Visual", "red"),
    ):
        ax.plot(time_axis, curve, label=series_name, color=colour, linewidth=2)
        ax.fill_between(time_axis, curve - band, curve + band,
                        color=colour, alpha=0.3, label=f"{series_name} ± SEM")

    ax.legend()
    ax.set_title("Average Pupil Response by Trial Duration ")
    ax.set_xlabel("Time (ms)")
    ax.set_ylabel("Baseline-corrected Pupil Diameter")
    plt.show()

# Closing report: the median split threshold and how many trials survived preprocessing
_n_short_final = len(short_condition_trials)
_n_long_final = len(long_condition_trials)
_cutoff_text = f"{median_duration:.2f}"

print("\nFinal Summary:")
print(f"Median duration threshold: {_cutoff_text} ms")
print(f"Short Duration Visual trials: {_n_short_final} (≤ {_cutoff_text} ms)")
print(f"Long Duration Visual trials: {_n_long_final} (> {_cutoff_text} ms)")
print(f"Total processed Visual trials: {_n_short_final + _n_long_final}")