# Phase 1: Setup and Data Loading

In [None]:
# --- Step 1.1: Imports ---
import pandas as pd
import numpy as np
import statsmodels.api as sm # For GLM
import matplotlib.pyplot as plt
import seaborn as sns
import os # For file/directory operations if needed later

In [None]:
# --- Step 1.2: Configuration ---
PARTICIPANT_ID = "dddd"  # Example: "aaaa"
RUN = 1                 # Example: 1
PLOT_REGRESSORS = True
regressor_session = 1 # The session for which regressors will be inspected if PLOT_REGRESSORS is set to True

# Path to the deconvolution input CSV.
# This file should contain all original samples within continuous recording blocks (sessions),
# with flags for trial rejection, but no trials physically dropped if it creates discontinuities.
# It should also NOT have undergone trial-specific baselining.
DECONV_INPUT_FILENAME = f"{RUN}_deconvolution_input.csv"
input_data_path = f"./data/fully_preprocessed_data/{PARTICIPANT_ID}/{DECONV_INPUT_FILENAME}"
message_data_path = f"./data/fully_preprocessed_data/{PARTICIPANT_ID}/{RUN}_messages.csv"

# Output directory for figures (e.g. the residual plots saved by run_fir_glm).
image_dir = f"./data/deconv_figs/{PARTICIPANT_ID}/"
os.makedirs(image_dir, exist_ok=True)

pupil_signal_column_name = 'Pupil Filtered'  # This should be the filtered, interpolated data

# Columns required from the input CSV
# TODO: Make sure all events are flagged - go to the pre-processing step and see what details have been left out of the deconvolution.csv output
required_cols = [
    'Subject', 'Run', 'Session', 'Block', 'Trial', 'Timestamp',
    pupil_signal_column_name,
    'Session Condition',      # e.g., 'attend', 'divert'
    'Trial Condition',        # e.g., 'Standard', 'Oddball'
    'Main Stimulus Type',     # e.g., 'coarse_gabor', 'fine_gabor', 'noise_disk'
    'Session Common Oddball', # Name of the common oddball type for that session
    'Trial Main Stim Onset',  # Timestamp of main visual stimulus onset for the current trial
    'Trial_Rejected_Least_Strict', # The flag indicating if a trial met rejection criteria
    'Blink On Main Stim',
    'Saccade On Main Stim',
    'Exclude Trial',          # Flag for trials with un-interpolatable NaNs in Pupil Int
    'Interpolation PC',
    # Columns needed for Nuisance Event Regressors
    'Target Status',          # Boolean: True if an attention task target is on screen
    'Attention Letter',       # The letter shown (for divert) or NaN (for attend)
    'Blink',                  # Boolean: True if current sample is within a padded blink period
    'Saccade'                 # Boolean: True if current sample is within a padded saccade period
]

# --- Step 1.3: Load Data ---
# messages df: event messages, used later to extract keypress / attention-task onsets.
print(f"Attempting to load: {message_data_path}")
try:
    messages_df = pd.read_csv(message_data_path)
except FileNotFoundError:
    print(f"ERROR: File not found at {message_data_path}")
    print("Please ensure the 'Data Pre-processing.ipynb' notebook has been run with DECONVOLUTION_OUTPUT=True,")
    print(f"and the file '{RUN}_messages.csv' exists in the correct directory.")
    raise  # Stop execution
except ValueError as e:
    # pandas parser errors (ParserError, EmptyDataError) subclass ValueError.
    # No column check is performed on this file, so report it as a parsing problem.
    print(f"ERROR while parsing the messages file {message_data_path}: {e}")
    raise  # Stop execution
except Exception as e:
    print(f"An unexpected error occurred while loading the messages file: {e}")
    raise

# data df:
print(f"Attempting to load: {input_data_path}")
try:
    # A callable usecols simply skips columns that are absent (a list would raise ValueError),
    # so absent required columns are reported explicitly below instead of disappearing silently.
    raw_df = pd.read_csv(input_data_path, usecols=lambda col_name: col_name in required_cols)
    print(f"\nSuccessfully loaded data for Participant {PARTICIPANT_ID}, Run {RUN}.")
    print(f"Shape of loaded DataFrame (raw_df): {raw_df.shape}")
    # Used after the GLM loop to check that (original samples / downsampled samples) matches the downsample factor.
    original_sample_count = raw_df.shape[0]

    missing_required_cols = [col for col in required_cols if col not in raw_df.columns]
    if missing_required_cols:
        print(f"WARNING: Columns missing from {DECONV_INPUT_FILENAME}: {missing_required_cols}")

    # Basic checks for the columns the analysis cannot run without
    if pupil_signal_column_name not in raw_df.columns:
        raise ValueError(f"Pupil signal column '{pupil_signal_column_name}' not found in loaded data.")
    if 'Trial Main Stim Onset' not in raw_df.columns:
        raise ValueError("'Trial Main Stim Onset' column not found. It's critical for event alignment.")
    if 'Trial_Rejected_Least_Strict' not in raw_df.columns:
        print("WARNING: 'Trial_Rejected_Least_Strict' column not found.\n\tWill assume all trials are good for event selection. Ensure this is intended.")
        raw_df['Trial_Rejected_Least_Strict'] = False # Add it as False if missing, for safety

    # Convert the rejection flag to bool for the `== False` good-trial selection later.
    # NOTE(review): astype(bool) maps NaN to True ("rejected"); confirm the exported flag is never empty.
    raw_df['Trial_Rejected_Least_Strict'] = raw_df['Trial_Rejected_Least_Strict'].astype(bool)

except FileNotFoundError:
    print(f"ERROR: File not found at {input_data_path}")
    print("Please ensure the 'Data Pre-processing.ipynb' notebook has been run with DECONVOLUTION_OUTPUT=True,")
    print(f"and the file '{DECONV_INPUT_FILENAME}' exists in the correct directory.")
    raise  # Stop execution
except ValueError as e:
    print(f"ERROR during data loading or column check: {e}")
    print("This might be due to missing columns in the CSV or an issue with the file path.")
    raise  # Stop execution
except Exception as e:
    print(f"An unexpected error occurred during data loading: {e}")
    raise

# Show how many samples are flagged as rejected vs. kept
if 'Trial_Rejected_Least_Strict' in raw_df.columns:
    print("\nValue counts for 'Trial_Rejected_Least_Strict':")
    print(raw_df['Trial_Rejected_Least_Strict'].value_counts(dropna=False))

# Display a snippet of the loaded data
print("\nSnippet of loaded data:")
raw_df

# Setup for the Finite Impulse Response Model and the Downsampling:

In [None]:
# --- Step 1.4: Define FIR Model Parameters ---
# These parameters define how we'll model the shape of the pupil response.
fir_window_start_ms = 0    # When to start estimating responses from an event onset
fir_window_end_ms = 3000   # Estimate response up to 3000ms after the event onset
fir_bin_width_ms = 200     # Each FIR beta will represent average activity over a fir_bin_width_ms millisecond bin

# Calculate the number of FIR bins/regressors per event type
if fir_bin_width_ms <= 0:
    raise ValueError("fir_bin_width_ms must be positive.")
fir_window_length_ms = fir_window_end_ms - fir_window_start_ms
num_fir_bins = int(fir_window_length_ms / fir_bin_width_ms)

if num_fir_bins <= 0:
    raise ValueError("Number of FIR bins is not positive. Check fir_window_start_ms, fir_window_end_ms, and fir_bin_width_ms.")
if fir_window_length_ms % fir_bin_width_ms != 0:
    # int() truncates, so a trailing partial bin would otherwise be dropped without notice.
    print(f"Warning: The FIR window ({fir_window_length_ms}ms) is not a whole number of {fir_bin_width_ms}ms bins; "
          f"the modelled window ends at {fir_window_start_ms + num_fir_bins * fir_bin_width_ms}ms.")

print("--- FIR Model Parameters ---")
print(f"  Response window: {fir_window_start_ms}ms to {fir_window_end_ms}ms post-event")
print(f"  Bin width: {fir_bin_width_ms}ms")
print(f"  Number of FIR bins per event type: {num_fir_bins}")


def _as_int_if_whole(value):
    """Return `value` as an int when it is integral (e.g. 10.0 -> 10), otherwise unchanged."""
    return int(value) if float(value).is_integer() else value


# Sampling rate of the exported recording (assumed, not read from the data). Every sampling
# period below is derived from this single value so they cannot drift apart.
original_sr_hz = 1000 # Hz

# Sampling period of the data *before* any downsampling.
initial_sampling_period_ms = _as_int_if_whole(1000 / original_sr_hz)
# The actual sampling period used by run_fir_glm is set below, after potential downsampling.



# --- Step 1.5: Define Downsampling Parameters ---
DO_DOWNSAMPLING = True  # Set to False to run on original 1000Hz data (slower)
target_sr_hz = 100      # Target sampling rate in Hz (e.g., 100Hz for 10ms resolution)

downsample_factor = 1
if DO_DOWNSAMPLING:
    if target_sr_hz <= 0:
        raise ValueError("target_sr_hz must be positive.")
    if target_sr_hz >= original_sr_hz:
        print(f"Target sampling rate ({target_sr_hz}Hz) is >= original ({original_sr_hz}Hz). No downsampling will be performed.")
        DO_DOWNSAMPLING = False # Override: there is nothing to downsample
    else:
        if original_sr_hz % target_sr_hz != 0:
            print(f"Warning: Original sampling rate ({original_sr_hz}Hz) is not an integer multiple of the target rate ({target_sr_hz}Hz). "
                  "Downsampling might result in slightly uneven binning or require careful handling.")
        downsample_factor = int(original_sr_hz / target_sr_hz)
        if downsample_factor <= 1:
            print(f"Warning: Downsample factor is {downsample_factor}. Check original_sr_hz and target_sr_hz. Will effectively not downsample if factor is 1.")

actually_downsampling = DO_DOWNSAMPLING and downsample_factor > 1

# Each downsampled sample aggregates `downsample_factor` original samples, so the new period
# follows from the factor actually applied (target_sr_hz need not divide original_sr_hz).
new_sampling_period_ms = _as_int_if_whole(initial_sampling_period_ms * downsample_factor)

print("\n--- Downsampling Parameters ---")
if actually_downsampling:
    print("  Downsampling ENABLED.")
    print(f"  Original sampling rate: {original_sr_hz}Hz (assumed)")
    print(f"  Target sampling rate: {target_sr_hz}Hz")
    print(f"  Downsample factor: {downsample_factor}")
    print(f"  New sampling period for GLM: {new_sampling_period_ms}ms")
elif DO_DOWNSAMPLING:
    # Enabled, but target rate was too close to the original rate to give a factor > 1.
    print(f"  Downsampling was enabled, but factor is {downsample_factor}. Effective sampling rate will be {original_sr_hz}Hz.")
else:
    print("  Downsampling DISABLED.")
    print(f"  Sampling period for GLM: {new_sampling_period_ms}ms")

# Store the effective sampling period to be used by the GLM and FIR construction
effective_sampling_period_ms_for_glm = new_sampling_period_ms

# FIR parameters dictionary that will be passed around.
# Note: fir_bin_width_ms and num_fir_bins are defined based on desired *model resolution*,
# not directly by the sampling rate of the data, but the sampling rate determines
# how many actual data points fall into each FIR bin.
fir_parameters_for_glm = {
    'fir_window_start_ms': fir_window_start_ms,
    'fir_window_end_ms': fir_window_end_ms,
    'fir_bin_width_ms': fir_bin_width_ms, # This is the width of the conceptual bin in ms
    'num_fir_bins': num_fir_bins,
    'sampling_period_ms': effective_sampling_period_ms_for_glm # This is critical for run_fir_glm
}

print(f"\nEffective sampling period for GLM construction: {fir_parameters_for_glm['sampling_period_ms']}ms")

# Define the GLM function:
This builds the design matrix and fits the OLS model for a single continuous segment of data.

In [None]:
# Pupil_Deconvolution.ipynb

# ... (Previous cells: Imports, Config, Data Loading, FIR/Downsample Params - Steps 1.1 to 1.5) ...

# --- Step 2.1: Define Function to Create Design Matrix and Run GLM ---
def run_fir_glm(pupil_data_segment, event_onsets_dict, fir_params, pupil_signal_col, segment_name="segment"):
    """
    Creates an FIR design matrix and fits an OLS GLM for a continuous pupil data segment.

    FIR regressors: for every event type and every bin i, with
    lag_i = fir_window_start_ms + i * fir_bin_width_ms, the regressor is 1 at every sample whose
    timestamp lies in [onset + lag_i, onset + lag_i + fir_bin_width_ms). Overlapping events add up.
    Each FIR beta is therefore the average response over that bin. If a bin contains no sample
    (bin narrower than the sampling period), the sample nearest the bin start is used instead.
    That sample must lie within half a sampling period.
    Nuisance regressors: Blink/Saccade boxcars (when those columns exist), a centred linear drift
    and an intercept ('const').

    When PLOT_REGRESSORS is True and segment_name matches the session given by the notebook
    globals PARTICIPANT_ID / RUN / regressor_session, the residuals are plotted and saved
    to image_dir.

    Args:
        pupil_data_segment (pd.DataFrame): DataFrame with 'Timestamp' and `pupil_signal_col`.
                                           MUST be a continuous recording segment.
                                           Its index should be a simple RangeIndex if reset.
        event_onsets_dict (dict): Dict where keys are event_names (e.g., 'Std_Attend')
                                  and values are numpy arrays of onset timestamps (in original ms)
                                  for that event. NaN onsets are ignored.
        fir_params (dict): Containing 'fir_window_start_ms', 'fir_window_end_ms',
                           'fir_bin_width_ms', 'num_fir_bins', 'sampling_period_ms'.
                           'sampling_period_ms' here is the period of pupil_data_segment.
        pupil_signal_col (str): Name of the column in pupil_data_segment containing the pupil signal.
        segment_name (str): Identifier for logging/printing.

    Returns:
        tuple: (betas, (r_squared, adj_r_squared)) on success, where betas is a pd.Series of
            coefficients indexed by regressor name. (None, None) if the GLM was skipped or
            failed, so callers can always unpack two values.
    """
    _pupil_timestamps = pupil_data_segment['Timestamp'].values # Timestamps of the (potentially downsampled) data
    _y_pupil_signal = pupil_data_segment[pupil_signal_col].values

    if len(_pupil_timestamps) == 0:
        print(f"  GLM_FUNC: No pupil data in {segment_name}. Skipping GLM.")
        return None, None

    n_samples = len(_pupil_timestamps)
    bin_width_ms = fir_params['fir_bin_width_ms']
    half_period_ms = fir_params['sampling_period_ms'] / 2.0

    # Collect regressors in an (insertion-ordered) dict and build the DataFrame once,
    # avoiding pandas' fragmented-frame slowdown from many single-column inserts.
    regressors = {}

    for event_name, onsets_for_event_type_ms in event_onsets_dict.items():
        if len(onsets_for_event_type_ms) == 0:
            continue # Skip if no onsets for this event type

        for i_bin in range(fir_params['num_fir_bins']):
            # current_lag_ms is the start of the bin in ms from event onset
            current_lag_ms = fir_params['fir_window_start_ms'] + (i_bin * bin_width_ms)
            regressor_name = f'{event_name}_fir_lag{i_bin:02d}_{current_lag_ms}ms'
            regressor_column = np.zeros(n_samples)

            for onset_ts_ms in onsets_for_event_type_ms:
                if pd.isna(onset_ts_ms):
                    continue # A missing onset cannot be placed on the timeline
                bin_start_ms = onset_ts_ms + current_lag_ms
                bin_end_ms = bin_start_ms + bin_width_ms

                in_bin = (_pupil_timestamps >= bin_start_ms) & (_pupil_timestamps < bin_end_ms)
                if in_bin.any():
                    # Boxcar over the whole bin, so the beta is the bin-average response.
                    regressor_column[in_bin] += 1
                else:
                    # Bin falls between samples: use the nearest sample if it is within half a period.
                    time_diff = np.abs(_pupil_timestamps - bin_start_ms)
                    idx_nearest = np.argmin(time_diff)
                    if time_diff[idx_nearest] <= half_period_ms:
                        regressor_column[idx_nearest] += 1

            regressors[regressor_name] = regressor_column

    for flag_col, boxcar_name in (('Blink', 'Blink_Boxcar'), ('Saccade', 'Saccade_Boxcar')):
        if flag_col in pupil_data_segment.columns:
            # astype(float) first so missing flags become NaN -> 0 instead of crashing astype(int).
            regressors[boxcar_name] = pupil_data_segment[flag_col].astype(float).fillna(0).astype(int).to_numpy()
        else:
            print(f"  GLM_FUNC: Warning - '{flag_col}' column not found in pupil_data_segment for {segment_name}.")

    if not regressors:
        # No event FIR regressors and no blink/saccade boxcars: only drift/intercept would remain,
        # which says nothing about event-related responses.
        print(f"  GLM_FUNC: Design matrix effectively empty (only drift/intercept possible) for {segment_name}. Skipping GLM.")
        return None, None

    design_matrix_fir = pd.DataFrame(regressors, index=pupil_data_segment.index)

    # Demeaned linear drift (runs from -0.5 to 0.5 across the segment) to absorb slow trends
    design_matrix_fir['Linear_Drift'] = np.linspace(-0.5, 0.5, n_samples)

    # Add intercept (baseline); has_constant='skip' avoids duplicating an existing constant column
    design_matrix_fir = sm.add_constant(design_matrix_fir, prepend=True, has_constant='skip')

    # Prepare Y and X for OLS: drop samples where Y is NaN (e.g. un-interpolated blinks)
    # and keep X aligned with them.
    valid_y_mask = ~np.isnan(_y_pupil_signal)
    y_glm = _y_pupil_signal[valid_y_mask]

    if len(y_glm) == 0:
        print(f"  GLM_FUNC: No valid (non-NaN) pupil data points in y_glm for {segment_name}. Skipping GLM.")
        return None, None

    X_glm = design_matrix_fir.loc[valid_y_mask, :]

    # Drop all-zero columns (e.g. lags with no events in this segment) and all-NaN columns
    X_glm = X_glm.loc[:, (X_glm != 0).any(axis=0) & ~X_glm.isna().all(axis=0)]

    # Ensure 'const' is still there if other columns were dropped, and X_glm is not empty
    if 'const' not in X_glm.columns and not X_glm.empty:
        X_glm = sm.add_constant(X_glm, prepend=True, has_constant='skip')

    # Trivial design: no columns at all, or nothing but the intercept
    has_only_intercept = X_glm.drop(columns=['const'], errors='ignore').empty
    if X_glm.empty or has_only_intercept or len(y_glm) < X_glm.shape[1]:
        print(f"  GLM_FUNC: Not enough data or valid regressors to fit GLM for {segment_name}. "\
              f"Y shape: {y_glm.shape}, X shape after cleanup: {X_glm.shape}. Skipping.")
        if not X_glm.empty: print(f"  GLM_FUNC: X_glm columns: {X_glm.columns.tolist()}")
        return None, None

    try:
        model = sm.OLS(y_glm, X_glm.astype(float)) # Ensure X_glm is float
        results = model.fit()

        r_squared = results.rsquared
        adj_r_squared = results.rsquared_adj
        r_values = (r_squared, adj_r_squared)
        print(f"  GLM_FUNC: Fit for {segment_name} - R-squared: {r_squared:.3f}, Adjusted R-squared: {adj_r_squared:.3f}")

        if PLOT_REGRESSORS:
            if segment_name == f"P{PARTICIPANT_ID}_R{RUN}_S{regressor_session}": # Or any session of interest
                print(f"\n--- Residuals for {segment_name} ---")
                plt.figure(figsize=(15, 5))
                plt.plot(results.resid) # results.resid are the residuals
                plt.title(f"Residuals for {segment_name}")
                plt.xlabel("Sample Index (within segment)")
                plt.ylabel("Residual Pupil Value")
                resid_output_path = os.path.join(image_dir, f"residuals_R{RUN}_S{regressor_session}.png")
                plt.savefig(resid_output_path, dpi=300, bbox_inches='tight')
                plt.show()

        return results.params, r_values # Beta coefficients and (R^2, adjusted R^2)

    except np.linalg.LinAlgError as e:
        print(f"  GLM_FUNC: LinAlgError during GLM fitting for {segment_name}: {e}. Check for perfect multicollinearity.")
        # For diagnosis, statsmodels.stats.outliers_influence.variance_inflation_factor on
        # X_glm without 'const' highlights collinear regressors (VIF > 10).
        return None, None
    except Exception as e:
        print(f"  GLM_FUNC: Unexpected error during GLM fitting for {segment_name}: {e}")
        return None, None

print("--- run_fir_glm function defined ---")

## Step 2.2 - Mark different events in the timeseries for FIR model construction
This crucial step prepares the specific event markers that will define the Finite Impulse Response (FIR) regressors in our General Linear Model (GLM). The GLM attempts to explain the _entire continuous pupil signal_ for a session. However, we strategically choose which types of events and which instances of those events contribute to building the specific FIR predictors we are interested in (e.g., the response to a "Standard" stimulus).
### Key Actions:
- *Filter for "Good" Trials:* We first identify all trials within the current session that have not been flagged for rejection by our earlier preprocessing criteria (e.g., Trial_Rejected_Least_Strict == False). This ensures that only data from high-quality trials will inform the estimation of our primary event-related responses.
- *Extract Onset Timestamps:* For each type of event we want to model (e.g., "Standard stimulus in attend condition," "Oddball stimulus in divert condition," "Attention Task Target," "Keypress"), we extract the precise onset Timestamp (in original milliseconds). These onsets are sourced only from the "good" trials identified above.
### Impact on the GLM:
- *Selective FIR Regressor Construction:* The FIR regressors for our main experimental events (like "Standard_attend") are built only using the onsets from these "good" instances.
- *Model Explains All Data, But Event Betas Reflect "Good" Instances:* The GLM still uses the entire continuous pupil time series of the session as the data to be explained (the y_pupil_signal).
    - The beta coefficients estimated for the FIR regressors (e.g., for Standard_attend_fir_lag01_200ms) will therefore reflect the average pupil response at that lag, specifically following the "good" instances of "Standard_attend" events.
    - The pupil data during "bad" or unmodeled periods is accounted for by other components of the GLM. These include the intercept (baseline), the linear drift term, any nuisance regressors (like the blink and saccade boxcars, or explicitly modeled "bad trial noise" if you choose to add such a regressor), and the model's error term.
- *Effective "Rejection" for Specific Event Responses:* By not including event markers from "bad" trials when building the FIR regressors for our main conditions of interest, we are effectively ensuring that those "bad" trial instances do not directly influence the estimated average response shape (the betas) for those specific conditions. The model doesn't try to fit a "Standard_attend" response shape using data from a trial where, for instance, a blink obscured the standard stimulus.

In essence, this step precisely curates the set of event occurrences that will define and shape the estimated pupillary responses for each experimental condition we care about, while the GLM still considers the full observed pupil dynamics of the session.


In [None]:
all_betas_list = [] # To store DataFrames of betas from each successful GLM run

# Group by 'Session' because each session is a continuous recording block.
# The GLM should be run on continuous data.
# raw_df should contain all original samples for the run, with flags, not with rows dropped.
# If raw_df had rows dropped earlier, this groupby might yield discontinuous segments if
# a session was broken by dropped trials. This assumes raw_df is continuous per session.

print(f"\n--- Starting GLM Processing for Participant {PARTICIPANT_ID}, Run {RUN} ---")
if 'Session' not in raw_df.columns:
    raise ValueError("'Session' column not found in raw_df. Cannot group by session.")

n_sample_count = 0 # Counts the num of samples produced in the downsampling - checked against original downsampling to ensure the factor used is correct

# Define trial identifiers, useful for merging message onsets with good trial flags
trial_id_cols = ['Subject', 'Run', 'Session', 'Block', 'Trial']

model_r_values = []
session_orders = []
for session_id, session_data_continuous in raw_df.groupby('Session'):
    segment_name = f"P{PARTICIPANT_ID}_R{RUN}_S{session_id}"
    print(f"\nProcessing {segment_name}...")

    # --- Step 2.2.1: Current Session Data ---
    current_session_condition_str = session_data_continuous['Session Condition'].iloc[0]
    session_orders.append(current_session_condition_str)

    # --- Step 2.2.2: Downsample (Optional) ---
    current_fir_params_for_glm = fir_parameters_for_glm.copy() # Use the globally defined effective params

    if DO_DOWNSAMPLING and downsample_factor > 1:
        print(f"  Downsampling {segment_name} from {original_sr_hz}Hz to {target_sr_hz}Hz (factor {downsample_factor})...")
        # Ensure sorted by Timestamp (groupby should preserve order from sorted raw_df, but belt-and-suspenders)
        session_data_sorted = session_data_continuous.sort_values('Timestamp').reset_index(drop=True)
        group_idx_ds = session_data_sorted.index // downsample_factor
        agg_dict_ds = {pupil_signal_column_name: 'mean', 'Timestamp': 'first'}
        static_cols_for_ds = [
            'Subject', 'Run', 'Session', 'Block', 'Trial',
            'Session Condition', 'Trial Condition', 'Main Stimulus Type',
            'Session Common Oddball', 'Trial Main Stim Onset',
            'Trial_Rejected_Least_Strict', 'Blink', 'Saccade', # Keep Blink/Saccade for boxcar regressors
            'Target Status', 'Attention Letter' # Needed for nuisance onset logic if done on downsampled
        ]
        for col in static_cols_for_ds:
            if col in session_data_sorted.columns:
                agg_dict_ds[col] = 'first'
        session_data_for_glm = session_data_sorted.groupby(group_idx_ds).agg(agg_dict_ds).reset_index(drop=True)
        print(f"  Downsampled data shape: {session_data_for_glm.shape}")
        n_sample_count += session_data_for_glm.shape[0]
    else:
        session_data_for_glm = session_data_continuous.reset_index(drop=True)
        print(f"  No downsampling for {segment_name}. Using original resolution.")

    if session_data_for_glm.empty:
        print(f"  {segment_name} is empty after potential downsampling. Skipping.")
        continue

    # --- Step 2.2.3: Identify "Good" Event Onsets for this Session ---
    event_onsets_this_session = {}

    # Get "good" trial identifiers from the continuous session data (before downsampling for flags)
    good_trial_flags_df = session_data_continuous[
        session_data_continuous['Trial_Rejected_Least_Strict'] == False
    ][trial_id_cols].drop_duplicates()

    if good_trial_flags_df.empty:
        print(f"  {segment_name}: No 'good' trials found based on Trial_Rejected_Least_Strict flag. "
              "No event-related FIR regressors will be created for main/task events.")
    # else:
    #    print(f"  {segment_name}: Found {len(good_trial_flags_df)} good trials for event onset extraction.")


    # --- Main Stimulus Onsets (from "good" trials only) ---
    # Filter session_data_continuous for good trials first, then extract onsets
    data_for_main_event_onsets = pd.merge(session_data_continuous, good_trial_flags_df, on=trial_id_cols, how='inner')

    if not data_for_main_event_onsets.empty:
        std_onsets = data_for_main_event_onsets[
            data_for_main_event_onsets['Trial Condition'] == 'Standard'
        ]['Trial Main Stim Onset'].drop_duplicates().values
        event_onsets_this_session[f'Standard_{current_session_condition_str}'] = std_onsets

        oddball_trials_for_onsets = data_for_main_event_onsets[
            data_for_main_event_onsets['Trial Condition'] == 'Oddball'
        ].drop_duplicates(subset=['Block', 'Trial']) # Unique oddball trials from good trials

        common_odd_onsets = []
        rare_odd_onsets = []
        if not oddball_trials_for_onsets.empty:
            session_common_oddball_type = oddball_trials_for_onsets['Session Common Oddball'].iloc[0]
            for _, trial_row in oddball_trials_for_onsets.iterrows():
                if trial_row['Main Stimulus Type'] == session_common_oddball_type:
                    common_odd_onsets.append(trial_row['Trial Main Stim Onset'])
                else:
                    rare_odd_onsets.append(trial_row['Trial Main Stim Onset'])
        event_onsets_this_session[f'CommOdd_{current_session_condition_str}'] = np.array(common_odd_onsets)
        event_onsets_this_session[f'RareOdd_{current_session_condition_str}'] = np.array(rare_odd_onsets)
    else: # No good trials, so initialize main event types with empty arrays
        event_onsets_this_session[f'Standard_{current_session_condition_str}'] = np.array([])
        event_onsets_this_session[f'CommOdd_{current_session_condition_str}'] = np.array([])
        event_onsets_this_session[f'RareOdd_{current_session_condition_str}'] = np.array([])


    # --- Nuisance Event Onsets (from "good" trials using messages_df) ---
    # Filter messages_df for the current session first
    current_session_messages_df = messages_df[
        (messages_df['Subject'] == PARTICIPANT_ID) &
        (messages_df['Run'] == RUN) &
        (messages_df['Session'] == session_id)
    ].copy()

    # Merge with good_trial_flags_df to get only messages from good trials
    messages_from_good_trials = pd.merge(current_session_messages_df, good_trial_flags_df, on=trial_id_cols, how='inner')

    if not messages_from_good_trials.empty:
        # Keypress Onsets
        keypress_mask = (messages_from_good_trials['Message Type'] == 'Key Response')
        keypress_onsets = messages_from_good_trials.loc[keypress_mask, 'Timestamp'].values
        event_onsets_this_session[f'Keypress_{current_session_condition_str}'] = keypress_onsets

        # Attention Task Related Onsets from 'Fixation Stimulus Onset' messages
        fix_stim_onset_msgs = messages_from_good_trials[
            messages_from_good_trials['Message Type'] == 'Fixation Stimulus Onset'
        ]

        if not fix_stim_onset_msgs.empty:
            # Attention Targets (color change or 'X')
            # 'Target Status' in messages_df for 'Fixation Stimulus Onset' message should indicate if it's a target
            att_target_mask = (fix_stim_onset_msgs['Target Status'] == True)
            att_target_onsets = fix_stim_onset_msgs.loc[att_target_mask, 'Timestamp'].values
            event_onsets_this_session[f'AttTarget_{current_session_condition_str}'] = att_target_onsets

            # Divert Condition: Distractor Letter Onsets (non-'X' letters)
            if current_session_condition_str == 'divert':
                # Distractors are Fixation Stim Onsets that are NOT targets
                # AND Attention Letter is not NaN (it shouldn't be for divert) and not 'X'
                distractor_mask = (
                    (fix_stim_onset_msgs['Target Status'] == False) &
                    (fix_stim_onset_msgs['Attention Letter'].notna()) &
                    (fix_stim_onset_msgs['Attention Letter'] != 'X')
                )
                distractor_onsets = fix_stim_onset_msgs.loc[distractor_mask, 'Timestamp'].values
                event_onsets_this_session['Distractor_divert_Onset'] = distractor_onsets
    else: # No messages from good trials, initialize nuisance with empty arrays
        event_onsets_this_session[f'Keypress_{current_session_condition_str}'] = np.array([])
        event_onsets_this_session[f'AttTarget_{current_session_condition_str}'] = np.array([])
        if current_session_condition_str == 'divert':
            event_onsets_this_session['Distractor_divert_Onset'] = np.array([])


    # --- Sanity check event counts ---
    print(f"  Event counts for {segment_name} (from good trials):")
    for ev, ons in event_onsets_this_session.items():
        if isinstance(ons, np.ndarray): print(f"    {ev}: {len(ons)}")

    # --- Step 2.2.4: Call `run_fir_glm` ---
    # Pass the (potentially downsampled) session_data_for_glm.
    # Blink and Saccade columns for boxcar regressors should exist in session_data_for_glm
    # if they were in static_cols_for_ds during downsampling.

    session_betas, r_vals = run_fir_glm(
        pupil_data_segment=session_data_for_glm,
        event_onsets_dict=event_onsets_this_session,
        fir_params=current_fir_params_for_glm,
        pupil_signal_col=pupil_signal_column_name,
        segment_name=segment_name
    )
    model_r_values.append(r_vals)

    # --- Step 2.2.5: Store Betas ---
    if session_betas is not None and not session_betas.empty:
        betas_formatted_df = session_betas.reset_index()
        betas_formatted_df.columns = ['Regressor', 'Beta']
        betas_formatted_df['Subject'] = PARTICIPANT_ID
        betas_formatted_df['Run'] = RUN
        betas_formatted_df['Session'] = session_id
        betas_formatted_df['Session_Processed_Condition'] = current_session_condition_str
        all_betas_list.append(betas_formatted_df)
        print(f"  Successfully processed and stored betas for {segment_name}.")
    else:
        print(f"  No betas returned or betas were empty for {segment_name}.")

# Sanity-check the effective downsampling factor over all processed sessions.
# Guarded: if no downsampled samples were counted, the ratio would raise
# ZeroDivisionError and abort the run before the completion message.
if n_sample_count > 0:
    print(f"\nTotal downsampled samples across sessions: {n_sample_count}. Original sample number: {original_sample_count}. "
          f"\nClosest int to factor used: {int(round(original_sample_count/n_sample_count, 0))} - Expected factor: {downsample_factor}.")
else:
    print(f"\nTotal downsampled samples across sessions: {n_sample_count}. Original sample number: {original_sample_count}. "
          f"\nCannot compute the factor used (no downsampled samples were counted). Expected factor: {downsample_factor}.")

print(f"\n--- GLM Processing Complete for Participant {PARTICIPANT_ID}, Run {RUN} ---")

In [None]:
# Check that betas have been produced for each processed session.
# Sessions whose GLM returned no betas are never appended to all_betas_list, so a
# list position need not equal the session number: label each table by the values
# in its own 'Session' column instead.
print(f"Length of all_betas_list: {len(all_betas_list)}")

for df in all_betas_list:
    session_ids = df['Session'].unique() if 'Session' in df.columns else []
    session_label = ", ".join(str(s) for s in session_ids) or "unknown"
    print(f"\nSession {session_label} df in all_betas_list:")
    print(f"\tShape of df: {df.shape}")
    print(f"\tCols in df: {df.columns}")
    print(f"\tSegment of df: {df.head(5)}")
    print("-"*100)


# Concatenate Betas into a single DF and Parse Regressor Names
## Step 3.1: Concatenate Betas (pd.concat)
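- It first checks whether all_betas_list (which should contain one DataFrame of betas per session) is empty. If it is, it prints a message and raises a RuntimeError so that later cells do not run on missing data. If the concatenation itself fails, the error is printed and final_betas_df is set to an empty DataFrame, so the parsing step is skipped.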
- If all_betas_list is not empty, pd.concat(all_betas_list, ignore_index=True) stacks all the individual session-beta DataFrames into one long DataFrame called final_betas_df.
- ignore_index=True ensures that the resulting DataFrame has a clean, new RangeIndex.
- It prints the shape and head/tail of this concatenated DataFrame.

## Step 3.2: Parse Regressor Names (parse_regressor_name function)
- This is a crucial step to make the final_betas_df more usable for plotting and analysis.
- The parse_regressor_name function is designed to take the Regressor column (which contains strings like Standard_attend_fir_lag00_0ms or const).
- *For FIR regressors:*
    - It splits the string to extract the base Event_Type (e.g., Standard_attend).
    - It extracts the Lag_Index (e.g., 00 -> 0). This is useful for ordered plotting if names aren't perfectly sortable by ms alone.
    - It extracts the Lag_ms (e.g., 0ms -> 0). This will be our x-axis for plotting PRFs.
    - It sets Is_FIR to True.
- For *non-FIR regressors (like const):*
    - It sets Event_Type to the regressor name itself.
    - Lag_Index and Lag_ms are set to np.nan.
    - Is_FIR is set to False.
- The function takes the whole Series of regressor names and returns a DataFrame with the new parsed columns, which is then joined back onto final_betas_df. Building all the new columns in one pass is generally more efficient than using .apply(pd.Series) for row-wise operations that create multiple new columns.
- Finally, it prints snippets and unique event types to help verify the parsing.


In [None]:
# --- Step 3.1: Concatenate Betas from all GLM runs ---

# Guard clause: every later step needs at least one session's betas, so stop the
# notebook here if the GLM loop produced none.
if not all_betas_list:
    print("The list 'all_betas_list' is empty. No betas were collected from GLM runs.")
    print("Please check the output of the previous processing loop (Step 2.2) for errors or warnings.")
    raise RuntimeError("No betas collected, cannot proceed with Phase 3.")

# Stack the per-session beta tables into one long table with a fresh RangeIndex.
# On failure, keep the name bound to an empty DataFrame so the parsing step skips.
try:
    final_betas_df = pd.concat(all_betas_list, ignore_index=True)
    for report_item in (
        "\n--- Step 3.1: All Betas Concatenated ---",
        f"Shape of final_betas_df: {final_betas_df.shape}",
        "Snippet of final_betas_df (head):",
        final_betas_df.head(),
        "\nSnippet of final_betas_df (tail):",
        final_betas_df.tail(),
    ):
        print(report_item)
except Exception as e:
    print(f"Error during concatenation of betas: {e}")
    final_betas_df = pd.DataFrame()  # defined but empty after a failure

# --- Step 3.2: Parse Regressor Names to Extract Event Type and Lag Information ---

def parse_regressor_name(reg_name_series):
    """
    Parse regressor names into Event_Type, Lag_Index, Lag_ms and Is_FIR columns.

    FIR regressor names are parsed as ``<Event_Type>_fir_lag<Lag_Index>_<Lag_ms>ms``
    (e.g. ``Standard_attend_fir_lag00_0ms``). Every other entry is treated as a
    non-FIR regressor: names without ``_fir_lag`` (e.g. ``const``), non-string values,
    and FIR-style names whose lag parts are missing or not integers. Such an entry
    keeps its original value as Event_Type, gets NaN for Lag_Index and Lag_ms, and
    has Is_FIR=False.

    Parameters
    ----------
    reg_name_series : pandas.Series or iterable
        Regressor names, e.g. the 'Regressor' column of the betas DataFrame.

    Returns
    -------
    pandas.DataFrame
        One row per name with columns Event_Type, Lag_Index, Lag_ms, Is_FIR.
        When a Series is passed, the result reuses its index so that
        ``DataFrame.join`` aligns each parsed row with its source row.
    """
    event_types = []
    lag_indices = []
    lag_ms_values = []
    is_fir_flags = []

    for reg_name in reg_name_series:
        parsed = None
        if isinstance(reg_name, str) and '_fir_lag' in reg_name:
            try:
                parts = reg_name.split('_fir_lag')
                lag_info = parts[1].split('_')  # e.g. "00_0ms" -> ["00", "0ms"]
                parsed = (parts[0], int(lag_info[0]), int(lag_info[1].replace('ms', '')))
            except (ValueError, IndexError):
                # Malformed FIR-style name (missing or non-integer lag parts):
                # treat it like a non-FIR regressor below.
                parsed = None

        if parsed is None:
            # Non-FIR regressor (e.g. 'const' or a nuisance boxcar): keep its name.
            event_types.append(reg_name)
            lag_indices.append(np.nan)
            lag_ms_values.append(np.nan)
            is_fir_flags.append(False)
        else:
            event_type, lag_idx, lag_ms = parsed
            event_types.append(event_type)
            lag_indices.append(lag_idx)
            lag_ms_values.append(lag_ms)
            is_fir_flags.append(True)

    index = reg_name_series.index if isinstance(reg_name_series, pd.Series) else None
    return pd.DataFrame(
        {
            'Event_Type': event_types,
            'Lag_Index': lag_indices,
            'Lag_ms': lag_ms_values,
            'Is_FIR': is_fir_flags,
        },
        index=index,
    )


# The existence check must come before any attribute access on final_betas_df,
# otherwise a missing variable raises NameError instead of reaching this message.
if 'final_betas_df' not in locals():
    print("Variable 'final_betas_df' was not created. Skipping parsing.")
elif final_betas_df.empty:
    print("'final_betas_df' is empty. Skipping parsing.")
else:
    # Parse all names in one pass and join the new columns back on the shared index.
    parsed_info_df = parse_regressor_name(final_betas_df['Regressor'])
    final_betas_df = final_betas_df.join(parsed_info_df)

    print("\n--- Step 3.2: Final Betas DF with Parsed Regressor Info ---")
    print(f"Shape after parsing: {final_betas_df.shape}")
    print("Snippet of final_betas_df with new columns (head):")
    print(final_betas_df.head())

    # Verify: show unique event types and some FIR / non-FIR betas
    print("\nUnique parsed Event_Types:")
    print(final_betas_df['Event_Type'].unique())

    is_fir_mask = final_betas_df['Is_FIR'].astype(bool)
    print("\nExample of FIR betas (first few rows where Is_FIR is True):")
    print(final_betas_df[is_fir_mask].head())

    print("\nExample of non-FIR betas (e.g., const):")
    print(final_betas_df[~is_fir_mask].head())

# Visualize Estimated Pupil Response Functions (PRFs)

In [None]:
# Filter out non-FIR regressors (like 'const') for these plots
fir_betas_to_plot_df = final_betas_df[final_betas_df['Is_FIR'] == True].copy()
# Ensure Lag_ms is numeric for plotting
fir_betas_to_plot_df['Lag_ms'] = pd.to_numeric(fir_betas_to_plot_df['Lag_ms'], errors='coerce')
fir_betas_to_plot_df.dropna(subset=['Lag_ms'], inplace=True)  # Drop if any Lag_ms became NaN

# Main stimulus events in a fixed, interpretable order; only those present are kept.
unique_event_types = sorted(fir_betas_to_plot_df['Event_Type'].unique())
preferred_event_order = ['Standard_attend', 'CommOdd_attend', 'RareOdd_attend',
                         'Standard_divert', 'CommOdd_divert', 'RareOdd_divert']
desired_event_order = [et for et in preferred_event_order if et in unique_event_types]
if not desired_event_order:
    # None of the expected names exist (e.g. a different naming scheme): an empty
    # hue_order would leave the plot blank, so fall back to every FIR event type.
    print("Warning: none of the expected main event types were found; plotting all FIR event types.")
    desired_event_order = unique_event_types

# --- Plot 1: Averaged PRF for each Event_Type across all sessions for this Subject/Run ---
plt.figure(figsize=(20, 8))
sns.lineplot(
    data=fir_betas_to_plot_df,
    x='Lag_ms',
    y='Beta',
    hue='Event_Type',
    hue_order=desired_event_order,
    errorbar='se', # Standard error across sessions (since we have multiple sessions per Event_Type)
    linewidth=1.5
)

# Format the R2 / AdjR2 values into a string. Repeated session conditions are
# numbered by occurrence in session_orders ("Attend 1", "Attend 2", "Attend 3", ...).
# NOTE(review): assumes session_orders and model_r_values are aligned one-to-one;
# zip silently stops at the shorter of the two.
r2_text = "R2 / AdjR2\n"
condition_occurrences = {}
for session_con, (r2, adj_r2) in zip(session_orders, model_r_values):
    condition_occurrences[session_con] = condition_occurrences.get(session_con, 0) + 1
    r2_text += f"{session_con.title()} {condition_occurrences[session_con]}: {r2:.3f} / {adj_r2:.3f}\n"

# Add R2 / AdjR2 text box to the right of the axes
plt.text(
    x=1.07, y=0.5, s=r2_text,
    transform=plt.gca().transAxes,  # position relative to axes
    fontsize=12,
    verticalalignment='center',
    bbox=dict(boxstyle='round,pad=0.5', facecolor='lightgrey', alpha=0.8)
)

plt.axhline(0, color='black', linestyle='--', linewidth=1.5, label='Baseline (0 Beta)')
plt.axvline(0, color='green', linestyle=':', linewidth=1, label='Event Onset (Lag 0ms)')

plt.title(f'Estimated PRFs (Averaged Across Sessions)\nSubject {PARTICIPANT_ID}, Run {RUN}', fontsize=16)
plt.xlabel('Time Lag from Event Onset (ms)', fontsize=14)
plt.ylabel('Beta Coefficient (Change in Pupil Signal)', fontsize=14)
plt.legend(title='Event Type', bbox_to_anchor=(1.05, 1), loc='upper left', title_fontsize='13', fontsize='11')
plt.grid(True, linestyle=':', alpha=0.6)
plt.tight_layout(rect=[0, 0, 0.82, 1]) # Adjust layout to make space for legend (may need tuning)

combined_betas_output_path = os.path.join(image_dir, f"combo_R{RUN}.png")
plt.savefig(combined_betas_output_path, dpi=300, bbox_inches='tight')
plt.show()

# --- Plot 2: PRFs faceted by base event type, coloured by attention condition ---

def get_base_event_and_attention(event_type_str):
    """
    Split an event-type label into its base event and its attention condition.

    The '_attend' marker is looked for first, then '_divert'. When a marker is found,
    every occurrence of it is removed to form the base event, and the condition name
    ('attend' or 'divert') is reported. Strings without either marker keep their
    original text and get 'unknown'. Non-string inputs are passed through unchanged
    as the base event and get NaN as the condition.

    Intended for ``Series.apply`` so that the result expands into two columns.

    Parameters
    -----------
    event_type_str : str or any
        Event-type label such as "Standard_attend" or "CommOdd_divert".

    Returns
    --------
    pandas.Series
        Two entries, in this order:
        - 'Base_Event': label with the attention marker removed (e.g. "Standard").
        - 'Attention_From_Event': "attend", "divert", "unknown", or np.nan for
          non-string input.
    """
    if isinstance(event_type_str, str):
        for condition in ('attend', 'divert'):
            marker = f'_{condition}'
            if marker in event_type_str:
                base_name = event_type_str.replace(marker, '')
                attention = condition
                break
        else:  # neither marker present; should not happen with current naming
            base_name = event_type_str
            attention = 'unknown'
    else:
        base_name = event_type_str
        attention = np.nan

    return pd.Series({'Base_Event': base_name, 'Attention_From_Event': attention})



if fir_betas_to_plot_df.empty:
    # Series.apply on an empty Series would not expand into named columns, and the
    # join below would fail, so skip the faceted plot entirely.
    print("No FIR betas available; skipping the faceted PRF plot.")
else:
    # Split e.g. 'CommOdd_divert' into Base_Event='CommOdd', Attention_From_Event='divert'.
    parsed_event_components = fir_betas_to_plot_df['Event_Type'].apply(get_base_event_and_attention)
    fir_betas_to_plot_df = fir_betas_to_plot_df.join(parsed_event_components)

    attention_conditions = sorted(fir_betas_to_plot_df['Attention_From_Event'].dropna().unique())

    # Fixed colours for the known conditions; any other level gets a seaborn default
    # colour, so every plotted condition always has an entry in the palette.
    preferred_attention_colors = {"attend": "green", "divert": "purple", "unknown": "grey"}
    default_colors = sns.color_palette(n_colors=max(len(attention_conditions), 1))
    color_palette = {
        condition: preferred_attention_colors.get(condition, default_colors[i])
        for i, condition in enumerate(attention_conditions)
    }

    plt.close('all')
    # FacetGrid creates its own figure, so no separate plt.figure() is needed.
    g = sns.FacetGrid(
        data=fir_betas_to_plot_df,
        row='Base_Event',  # One row per base event type
        height=4, aspect=2.5,
        sharey=True,  # Keep y-axis consistent for comparison
        margin_titles=True
    )
    # hue/palette are handled by lineplot; legend=False prevents one legend per facet.
    g.map_dataframe(
        sns.lineplot,
        x='Lag_ms',
        y='Beta',
        hue='Attention_From_Event',
        hue_order=attention_conditions,
        palette=color_palette,
        errorbar='se',
        linewidth=1.5,
        legend=False
    )

    # Reference lines (no label, so they do not produce legend entries)
    g.map(plt.axhline, y=0, color='black', linestyle='--', linewidth=1)
    g.map(plt.axvline, x=0, color='grey', linestyle=':', linewidth=1)

    g.set_axis_labels('Time Lag from Event Onset (ms)', 'Beta Coefficient (Change in Pupil Signal)')
    g.set_titles(row_template="{row_name}", col_template="{col_name}")
    # Lay out the facets before adding figure-level artists and saving, so the saved
    # image reflects the final layout.
    g.tight_layout()

    # Single consolidated legend built from proxy line handles.
    handles = [plt.Line2D([0], [0], color=color_palette[condition], lw=1.5)
               for condition in attention_conditions]
    if handles:
        g.figure.legend(
            handles,
            list(attention_conditions),
            title="Attention Condition",
            loc='upper right',
            title_fontsize='13',
            fontsize='11'
        )

    g.figure.suptitle(f'Estimated PRFs by Base Event and Attention\nSubject {PARTICIPANT_ID}, Run {RUN}', y=1.04, fontsize=16)

    condition_betas_output_path = os.path.join(image_dir, f"conditions_R{RUN}.png")
    # bbox_inches='tight' grows the saved area to include the suptitle and legend.
    g.figure.savefig(condition_betas_output_path, dpi=300, bbox_inches='tight')
    plt.show()

# Summarise the betas over the chosen time windows for analysis

In [None]:
# --- Step 3.4: Summarize Betas for 2nd Level Analysis ---

# Identifier columns kept in each summary; FIR betas are averaged over lags within these groups.
SUMMARY_GROUP_COLS = ['Subject', 'Run', 'Session', 'Session_Processed_Condition', 'Event_Type']


def summarize_betas_in_window(fir_betas_df, window_start_ms, window_end_ms, window_label):
    """
    Average FIR betas over the lags that fall inside one analysis window.

    ``Lag_ms`` is the START of each FIR bin. A bin is included when
    ``window_start_ms <= Lag_ms < window_end_ms``. For example, with 200 ms bins and
    a 3000 ms window end, the last included bin starts at 2800 ms (covering 2800-2999 ms).

    Parameters
    ----------
    fir_betas_df : pandas.DataFrame
        FIR rows of the betas table; must contain 'Lag_ms', 'Beta' and SUMMARY_GROUP_COLS.
    window_start_ms, window_end_ms : int
        Inclusive start and exclusive end of the window, compared against Lag_ms.
    window_label : str
        Name used in the warning ("... specified <label> Window.") and in the output
        column 'Mean_Beta_<label>_<start>-<end>ms'.

    Returns
    -------
    pandas.DataFrame
        One row per group with the mean beta. If no lag falls inside the window, a
        warning is printed and an empty DataFrame is returned.
    """
    betas_in_window = fir_betas_df[
        (fir_betas_df['Lag_ms'] >= window_start_ms) &
        (fir_betas_df['Lag_ms'] < window_end_ms)
    ]
    if betas_in_window.empty:
        print(f"Warning: No FIR betas found within the specified {window_label} Window.")
        return pd.DataFrame()

    summary = betas_in_window.groupby(SUMMARY_GROUP_COLS)['Beta'].mean().reset_index()
    return summary.rename(
        columns={'Beta': f'Mean_Beta_{window_label}_{window_start_ms}-{window_end_ms}ms'}
    )


if 'final_betas_df' not in locals() or final_betas_df.empty:
    raise ValueError("DataFrame 'final_betas_df' is not available or is empty. Cannot summarize betas.")

# Only FIR betas have lags to average over; 'const' and boxcar regressors are excluded.
fir_betas_for_summary_df = final_betas_df[final_betas_df['Is_FIR'].eq(True)].copy()

print("\n--- Summarizing FIR Betas into Analysis Windows ---")

# --- 1. Cognitive Window ---
cog_window_start_ms = 1250
cog_window_end_ms = 3000  # exclusive: bins starting before 3000 ms are included
print(f"Cognitive Window: {cog_window_start_ms}ms to {cog_window_end_ms}ms (lags included if lag_start >= start AND lag_start < end)")
summarized_betas_cognitive = summarize_betas_in_window(
    fir_betas_for_summary_df, cog_window_start_ms, cog_window_end_ms, 'Cognitive'
)

# --- 2. PLR Window (for secondary analysis) ---
plr_window_start_ms = 200
plr_window_end_ms = 1200  # exclusive: bins starting before 1200 ms are included
print(f"\nPLR Window: {plr_window_start_ms}ms to {plr_window_end_ms}ms (lags included if lag_start >= start AND lag_start < end)")
summarized_betas_plr = summarize_betas_in_window(
    fir_betas_for_summary_df, plr_window_start_ms, plr_window_end_ms, 'PLR'
)

# Display unique event types in the cognitive summary to check
if not summarized_betas_cognitive.empty:
    print("\nUnique Event_Types in Cognitive Summary:")
    print(summarized_betas_cognitive['Event_Type'].unique())

In [None]:
# Show the per-session, per-event mean betas for the cognitive window.
# The bare name on the last line is the cell's final expression, which the notebook
# renders as the cell output.
print("\nSummarized Betas (Cognitive Window):")
summarized_betas_cognitive

In [None]:
# Inspect the PLR-window summary: CommOdd_divert rows from divert sessions.
# Guarded because Step 3.4 may store an empty, column-less DataFrame when no lags
# fall in the PLR window; indexing its columns would then raise KeyError.
print("\nSummarized Betas (PLR Window):")
if summarized_betas_plr.empty:
    print("'summarized_betas_plr' is empty; nothing to display.")
    plr_display = summarized_betas_plr
else:
    div = summarized_betas_plr[summarized_betas_plr["Session_Processed_Condition"] == "divert"]
    plr_display = div[div["Event_Type"] == "CommOdd_divert"]
plr_display

In [None]:
output_dir = f"./data/decon_betas_summaries/{PARTICIPANT_ID}"
os.makedirs(output_dir, exist_ok=True)

# Writes are best-effort: a failed save is reported and the next one is still attempted.

# --- Save Cognitive Window Betas ---
has_cognitive_betas = 'summarized_betas_cognitive' in locals() and not summarized_betas_cognitive.empty
if not has_cognitive_betas:
    print("\n'summarized_betas_cognitive' DataFrame not found or is empty. Nothing to save for cognitive window.")
else:
    cog_filename = f"{PARTICIPANT_ID}_R{RUN}_betas_cognitive_window.csv"
    cog_filepath = os.path.join(output_dir, cog_filename)
    try:
        summarized_betas_cognitive.to_csv(cog_filepath, index=False)
    except Exception as e:
        print(f"Error saving cognitive window betas to {cog_filepath}: {e}")
    else:
        print(f"\nSuccessfully saved summarized COGNITIVE window betas to: {cog_filepath}")

# --- Save PLR Window Betas ---
has_plr_betas = 'summarized_betas_plr' in locals() and not summarized_betas_plr.empty
if not has_plr_betas:
    print("\n'summarized_betas_plr' DataFrame not found or is empty. Nothing to save for PLR window.")
else:
    plr_filename = f"{PARTICIPANT_ID}_R{RUN}_betas_plr_window.csv"
    plr_filepath = os.path.join(output_dir, plr_filename)
    try:
        summarized_betas_plr.to_csv(plr_filepath, index=False)
    except Exception as e:
        print(f"Error saving PLR window betas to {plr_filepath}: {e}")
    else:
        print(f"Successfully saved summarized PLR window betas to: {plr_filepath}")

print("\n--- End of processing for this participant/run ---")