In [None]:
import os

import pandas as pd
import numpy as np
import pickle
import matplotlib.pyplot as plt

import preprocessing_config as config
import random
from scipy.signal import butter, filtfilt

In [None]:
# Which participant and run this notebook execution operates on:
PARTICIPANT_ID = "ffff"
run = 1
participant_id = PARTICIPANT_ID  # Lower-case alias of the same ID

# True  -> produce data for deconvolution.
# False -> use the running baseline method instead.
DECONVOLUTION_OUTPUT = True
PRODUCE_BASELINES = not DECONVOLUTION_OUTPUT

In [None]:
# Read the processed recording for the configured participant/run (first CSV column is the index):
processed_csv_path = f"./data/processed_data/{participant_id}S{run}.csv"
test_df = pd.read_csv(processed_csv_path, index_col=0, low_memory=False)

In [None]:
# Rows whose "Message Type" is "Data" hold the metric samples; every other row is a message.
# Separating them removes the overlapping timestamps that the message rows introduce.
is_data_row = test_df["Message Type"] == "Data"

test_metrics_df = test_df[is_data_row].copy()    # metric samples only
test_messages_df = test_df[~is_data_row].copy()  # messages only

# The combined frame is not needed beyond this point:
del test_df

In [None]:
# Display the total element count (rows x columns) of the metric-only frame as a quick sanity check:
test_metrics_df.size

# Ensure all timepoints are covered without breaks.
The eye-tracker, despite being set to 1000Hz, may have skipped an entry or two. We should check this, and interpolate the data if it did.

However, there are some occasions where missing data is expected. For instance, each run should have 3 series of missing values representing the time of non-recording between the 4 sessions of a run. As such, we need to investigate these manually.

### CHECK THE OUTPUTS OF THE CELL BELOW:
The output will tell you whether there is missing data that is not explained by the design of the experiment itself (There will always be some missing data, but some of it is expected, such as during breaks and adaptation periods).

In [None]:
# Build every 1 ms timestamp between the first and last metric sample, then find which ones were never recorded:
min_ts = test_metrics_df['Timestamp'].min()
max_ts = test_metrics_df['Timestamp'].max()

expected_timestamps = pd.Series(range(min_ts, max_ts + 1), name="ExpectedTimestamp")
actual_timestamps = test_metrics_df['Timestamp']

missing_mask = ~expected_timestamps.isin(actual_timestamps)
missing_timestamps = expected_timestamps[missing_mask]

columns_of_interest = ["Timestamp", "Message Type", "Block", "Trial"] # Used for checking the session_end series are present (see above note)

# Everything the checks and the summary below read is initialised up front, so this cell also runs cleanly
# when there are no missing timestamps at all, or when there is no consecutive series of them:
series_val_manual_checks = [] # (start_ts, end_ts) of series that are not explained by a session end, therefore needing manual checks
single_val_manual_checks = [] # Isolated missing timestamps that are not explained by a block change
single_missing = pd.DataFrame(columns=["start_ts", "end_ts", "count"])
start_event_df = None
event_after_df = None

if not missing_timestamps.empty:
    print(f"Total individual missing timestamps: {len(missing_timestamps)}")

    # --- Find consecutive series of missing timestamps ---

    # 1. Calculate the difference between consecutive missing timestamps
    #    A 'break' in a consecutive series of missing items occurs when diff > 1
    #    The first item in missing_timestamps always starts a new series (or is a series of 1)
    diffs = missing_timestamps.diff()

    # 2. Identify the start of each new series of missing timestamps
    #    A new series starts where the diff is not 1 (or for the very first missing timestamp)
    #    We use .ne(1) which means "not equal to 1". For the first NaN, NaN.ne(1) is True.
    series_starts_mask = diffs.ne(1) # True where a new series begins

    # 3. Create a series ID by taking the cumulative sum of these 'new series' markers
    series_id = series_starts_mask.cumsum()

    # 4. Group the missing timestamps by these series IDs
    #    And aggregate to get the start, end, and count of each series
    missing_series_summary = missing_timestamps.groupby(series_id).agg(
        start_ts='first',
        end_ts='last',
        count='size'
    )

    # 5. Filter for series that have 2 or more consecutive missing values
    consecutive_missing_series = missing_series_summary[missing_series_summary['count'] >= 2]

    if not consecutive_missing_series.empty:
        print("\nReport of Consecutive Missing Timestamp Series (length >= 2):")
        for index, row in consecutive_missing_series.iterrows(): # series_id is the index here
            print(f"  - Missing from {row['start_ts']} to {row['end_ts']} (Count: {row['count']})")
        print(f"\nTotal number of such series: {len(consecutive_missing_series)}")

        print("\n--- Detailed Analysis of Missing Series ---")

        for index, row in consecutive_missing_series.iterrows():
            print(f"\nAnalyzing missing series from {row['start_ts']} to {row['end_ts']} (Count: {row['count']})")
            series_is_expected = True

            # --- Check the message logged AT the first missing timestamp ---
            # The lookup uses row['start_ts'] itself (not start_ts - 1): messages live in test_messages_df,
            # so a message can carry a timestamp for which no metric sample exists.
            timestamp_at_start_of_gap = row['start_ts']
            start_event_df = test_messages_df[test_messages_df["Timestamp"] == timestamp_at_start_of_gap]

            if not start_event_df.empty:
                # Only the first message at that timestamp is inspected
                starting_message_type = start_event_df['Message Type'].iloc[0]
                print(f"  Event at the start of the gap (at {timestamp_at_start_of_gap}): '{starting_message_type}'")
                if starting_message_type == "Session End":
                    print("    Context: Starts at a 'Session End'. This is expected, as it follows a break and adaptation period of no recording.")
                else:
                    print(f"    MANUAL CHECK REQUIRED - The gap did not begin when a session ended.")
                    series_is_expected = False
            else:
                print(f"  MANUAL CHECK REQUIRED - The gap did not begin when a session ended: No corresponding event found in test_messages_df.")
                series_is_expected = False

            # --- Check the event IMMEDIATELY AFTER the missing series ends ---
            # The missing series ends at row['end_ts'], so we look at row['end_ts'] + 1,
            # which should be the first trial of a session if the gap was due to a session ending.
            timestamp_after_gap = row['end_ts'] + 1
            event_after_df = test_messages_df[test_messages_df["Timestamp"] == timestamp_after_gap][columns_of_interest].head(1) # Some timestamps have multiple messages

            if not event_after_df.empty:
                # Extract scalar values since we know it's a single row DataFrame now
                trial_val = event_after_df["Trial"].item() # .item() extracts the single value
                block_val = event_after_df["Block"].item()
                message_type_after = event_after_df["Message Type"].item()

                print(f"  Event after gap (at {timestamp_after_gap}): Type='{message_type_after}', Block={block_val}, Trial={trial_val}")

                if trial_val == 1 and block_val == 1: # Assuming Trial 1, Block 1 is the start of a session
                    print("    Context: Ending point of gap is before a 'Session Start' (Block 1, Trial 1). This is expected.")
                else:
                    print(f"    MANUAL CHECK REQUIRED - Ending point of gap is before an event that is NOT Block 1, Trial 1 "
                          f"(Block={block_val}, Trial={trial_val}).")
                    series_is_expected = False

            else:
                print(f"  MANUAL CHECK REQUIRED - The gap didn't end at the start of a new session, so it needs to be investigated manually.")
                series_is_expected = False

            if not series_is_expected:
                series_val_manual_checks.append((row['start_ts'], row['end_ts']))

            print("-" * 40) # Separator for each series

    else:
        print("\nNo consecutive missing timestamp series (length >= 2) found.")

    # Optional: Report single missing timestamps if you want to differentiate
    single_missing = missing_series_summary[missing_series_summary['count'] == 1]
    if not single_missing.empty and not consecutive_missing_series.empty:
        print(f"\nAdditionally, there are {len(single_missing)} single isolated missing timestamps.")
    elif not single_missing.empty:
         print(f"\nFound {len(single_missing)} single isolated missing timestamps (no consecutive series >= 2).")


else:
    print("No missing integer timestamps found in the expected 1ms interval.")

# An isolated missing timestamp is treated as expected only when the message at the following timestamp
# is a "Trial Start" for trial 1:
for index, row in single_missing.iterrows():
    timestamp = row["start_ts"]
    following_messages_df = test_messages_df[test_messages_df["Timestamp"] == timestamp + 1]

    if following_messages_df.empty:
        # Previously this case crashed with an IndexError; it is now reported for manual inspection instead.
        print(f"MANUAL CHECK REQUIRED - Missing timestamp ({timestamp}) has no message logged at the following timestamp")
        single_val_manual_checks.append(timestamp)
        continue

    message_type = following_messages_df["Message Type"].values[0]
    trial_num = following_messages_df["Trial"].values[0]
    if not (message_type == "Trial Start" and trial_num == 1):
        print(f"MANUAL CHECK REQUIRED - Missing timestamp ({timestamp}) is not caused by one block ending and another beginning")
        single_val_manual_checks.append(timestamp)

# Report a summary of the missing value checks:
print("\n"*2)
print("=" * 40)
print("SUMMARY:")
print("-" * 40)

if len(series_val_manual_checks) > 0:
    print(f"Series Missing Timestamps REQUIRE MANUAL CHECKS:")
    for timestamps in series_val_manual_checks:
        print(f"{timestamps[0]} - {timestamps[1]}")
else:
    print("All series missing timestamps are EXPECTED - they are caused by the recording being paused during breaks and interim periods between sessions.")

if len(single_val_manual_checks) > 0:
    print(f"Single Missing Timestamps REQUIRE MANUAL CHECKS:")
    for timestamp in single_val_manual_checks:
        print(timestamp)
else:
    print("All single missing timestamps are EXPECTED - they are caused by delays between a block's offset and the next block's onset.")

# Delete dfs that are no longer useful to preserve memory (both names always exist thanks to the initialisation above):
del start_event_df, event_after_df


# Flag trials where blinks cover the main stimulus for later dropping
Add a new column that flags every trial in which a blink overlaps the main stimulus presentation period, and add a second boolean column that marks each sample containing a blink for later interpolation (after pre-processing).

In [None]:
# Snapshot the metrics before any trials are flagged or cleaned, so the fully pre-processed
# result can be compared against it later:
backup_df = test_metrics_df.copy(deep=True)

In [None]:
# Load the adaptation-period blinks and saccades for this run. They are used further down to detect a
# blink or saccade that started before a session and was still in progress when the session began.
adaptation_events_path = f"data/processed_data/final_blinks_and_saccades/{PARTICIPANT_ID}/run_{run}.pkl"
with open(adaptation_events_path, mode="rb") as pickle_file:  # binary read mode
    loaded_dict = pickle.load(pickle_file)

In [None]:
# Locate the starting and ending points of each blink - These are samples where there should be no recorded data.
# The result is a list of (padded_start_ts, padded_end_ts) integer tuples covering all four sessions.
blink_periods = []
for session in [1, 2, 3, 4]:
    session_df = test_messages_df[test_messages_df["Session"] == session]

    # Get the timestamps of the blinks, both when they begin and when they end:
    blink_starts = session_df.loc[session_df["Message Type"] == "Blink Start", "Timestamp"].to_numpy()
    blink_ends = session_df.loc[session_df["Message Type"] == "Blink End", "Timestamp"].to_numpy()

    # Add the blink padding, configured in preprocessing_config.py.
    # This is done out-of-place on purpose: the array handed back by pandas can be a read-only view
    # (Copy-on-Write), and an in-place '-=' also fails if the padding is not an integer.
    blink_starts = blink_starts - config.BLINK_PRE_PADDING
    blink_ends = blink_ends + config.BLINK_POST_PADDING

    # Check if there's a hidden blink start that happened during the adaptation period and ran over into the session data.
    # It is inserted after padding, so the stored value is used as-is:
    carried_over_blink = loaded_dict[f"session_{session}"]["blinks"]
    if carried_over_blink is not None:
        blink_starts = np.insert(blink_starts, 0, carried_over_blink) # Add the value to the front of the numpy array

    # zip() pairs starts with ends positionally and silently drops any surplus, so make a mismatch visible:
    if len(blink_starts) != len(blink_ends):
        print(f"WARNING - Session {session}: {len(blink_starts)} blink starts vs {len(blink_ends)} blink ends; unpaired events are ignored.")

    # Add a tuple containing the paired start and end times for each blink:
    for start_time, end_time in zip(blink_starts, blink_ends):
        blink_periods.append((int(start_time), int(end_time)))


In [None]:
# Add a boolean 'Blink' column: True for every sample whose timestamp lies inside any (padded) blink period.
test_metrics_df["Blink"] = False # Start with no sample marked

sample_timestamps = test_metrics_df["Timestamp"]
for blink_start_ts, blink_end_ts in blink_periods:
    # between() is inclusive at both ends, i.e. start <= timestamp <= end
    test_metrics_df.loc[sample_timestamps.between(blink_start_ts, blink_end_ts), "Blink"] = True

In [None]:
# Columns that together identify a single trial:
trial_identifier_cols = ["Run", "Session", "Block", "Trial"]

# Samples recorded during a blink while the main stimulus was visible:
blink_rows = test_metrics_df["Blink"].eq(True)
main_stim_rows = test_metrics_df["Main Stimulus Visibility"].eq(True)
contaminated_df = test_metrics_df[blink_rows & main_stim_rows]

# Reduce the contaminated samples to one row per unique trial...
contaminated_trials_df = contaminated_df[trial_identifier_cols].drop_duplicates()

# ...and express those trials as a list of tuples in the format: (run, session, block, trial)
contaminated_trials_list = list(contaminated_trials_df.itertuples(index=False, name=None))

print(f"Number of blink-contaminated trials: {len(contaminated_trials_list)}")

blink_trials_removed = contaminated_trials_list
contaminated_trials_list

In [None]:
# Flag the blink-contaminated trials:

# Perform a left merge with an indicator column called '_merge'. This will be 'both' if a match was found
# (i.e., the trial is in contaminated_trials_df), 'left_only' otherwise.
# NOTE: the merged frame gets a fresh 0..n-1 index, so test_metrics_df no longer carries the index read from the CSV.
merged_blink_df = pd.merge(
    test_metrics_df,        # Main data DataFrame at this stage
    contaminated_trials_df, # Trials contaminated by blinks during the main stimulus
    on=trial_identifier_cols,
    how='left',             # Keep all rows from test_metrics_df
    indicator=True          # Adds a column named '_merge'
)

# Every sample of a contaminated trial is flagged, not only the samples where the blink occurred:
merged_blink_df["Blink On Main Stim"] = merged_blink_df['_merge'] == 'both'

# Keep all rows and drop the temporary '_merge' column. The flag "Blink On Main Stim" is now part of the DataFrame.
test_metrics_df = merged_blink_df.drop(columns=['_merge'])

# Output to verify. The counts are computed outside the f-strings so the cell does not rely on
# re-using double quotes inside a double-quoted f-string (only valid from Python 3.12):
blink_flagged_sample_count = test_metrics_df["Blink On Main Stim"].sum()
blink_flagged_trial_count = (
    test_metrics_df.loc[test_metrics_df["Blink On Main Stim"], trial_identifier_cols]
    .drop_duplicates()
    .shape[0]
)
print(f"Flagged {blink_flagged_sample_count} samples belonging to blink-contaminated trials (during main stimulus).")
print(f"Number of unique trials flagged for blink during main stim: {blink_flagged_trial_count}")

# Remove the unneeded dfs:
del contaminated_df, contaminated_trials_df, merged_blink_df

# Flag trials where saccades cover the main stimulus for later dropping
Add a new column that flags every trial in which a saccade overlaps the main stimulus presentation period, and add a second boolean column that marks each sample containing a saccade for later interpolation (after pre-processing).

In [None]:
# Locate the starting and ending points of each saccade.
# These typically have data associated with them, but plotting the data usually shows artefacts in the timeseries where the pupil size
# forms a valley. As such, they need identifying.
# The result is a list of (padded_start_ts, padded_end_ts) integer tuples covering all four sessions.
saccade_periods = []
for session in [1, 2, 3, 4]:
    session_df = test_messages_df[test_messages_df["Session"] == session]

    # Get the timestamps of the saccades, both when they begin and when they end:
    saccade_starts = session_df.loc[session_df["Message Type"] == "Saccade Start", "Timestamp"].to_numpy()
    saccade_ends = session_df.loc[session_df["Message Type"] == "Saccade End", "Timestamp"].to_numpy()

    # Add the saccade padding, configured in preprocessing_config.py.
    # This is done out-of-place on purpose: the array handed back by pandas can be a read-only view
    # (Copy-on-Write), and an in-place '-=' also fails if the padding is not an integer.
    saccade_starts = saccade_starts - config.SACCADE_PRE_PADDING
    saccade_ends = saccade_ends + config.SACCADE_POST_PADDING

    # Check if there's a hidden saccade start that happened during the adaptation period and ran over into the session data.
    # It is inserted after padding, so the stored value is used as-is:
    carried_over_saccade = loaded_dict[f"session_{session}"]["saccades"]
    if carried_over_saccade is not None:
        saccade_starts = np.insert(saccade_starts, 0, carried_over_saccade) # Add the value to the front of the numpy array

    # zip() pairs starts with ends positionally and silently drops any surplus, so make a mismatch visible:
    if len(saccade_starts) != len(saccade_ends):
        print(f"WARNING - Session {session}: {len(saccade_starts)} saccade starts vs {len(saccade_ends)} saccade ends; unpaired events are ignored.")

    # Add a tuple containing the paired start and end times for each saccade:
    for start_time, end_time in zip(saccade_starts, saccade_ends):
        saccade_periods.append((int(start_time), int(end_time)))

In [None]:
# Add a boolean 'Saccade' column: True for every sample whose timestamp lies inside any (padded) saccade period.
test_metrics_df["Saccade"] = False # Start with no sample marked

sample_timestamps = test_metrics_df["Timestamp"]
for saccade_start_ts, saccade_end_ts in saccade_periods:
    # between() is inclusive at both ends, i.e. start <= timestamp <= end
    test_metrics_df.loc[sample_timestamps.between(saccade_start_ts, saccade_end_ts), "Saccade"] = True

In [None]:
# Columns that together identify a single trial:
trial_identifier_cols = ["Run", "Session", "Block", "Trial"]

# Samples recorded during a saccade while the main stimulus was visible:
saccade_rows = test_metrics_df["Saccade"].eq(True)
main_stim_rows = test_metrics_df["Main Stimulus Visibility"].eq(True)
contaminated_df = test_metrics_df[saccade_rows & main_stim_rows]

# Reduce the contaminated samples to one row per unique trial...
contaminated_trials_df = contaminated_df[trial_identifier_cols].drop_duplicates()

# ...and express those trials as a list of tuples in the format: (run, session, block, trial)
contaminated_trials_list = list(contaminated_trials_df.itertuples(index=False, name=None))

print(f"Number of saccade-contaminated trials: {len(contaminated_trials_list)}")
saccade_trials_removed = contaminated_trials_list
contaminated_trials_list

In [None]:
# Flag the saccade-contaminated trials:

# Perform a left merge with an indicator column called '_merge'. This will be 'both' if a match was found
# (i.e., the trial is in contaminated_trials_df), 'left_only' otherwise.
# NOTE: the merged frame gets a fresh 0..n-1 index.
merged_saccade_df = pd.merge(
    test_metrics_df,        # Main data DataFrame at this stage
    contaminated_trials_df, # Trials contaminated by saccades during the main stimulus
    on=trial_identifier_cols,
    how='left',             # Keep all rows from test_metrics_df
    indicator=True          # Adds a column named '_merge'
)

# Every sample of a contaminated trial is flagged, not only the samples where the saccade occurred:
merged_saccade_df["Saccade On Main Stim"] = merged_saccade_df['_merge'] == 'both'

# Keep all rows and drop the temporary '_merge' column. The flag "Saccade On Main Stim" is now part of the DataFrame.
test_metrics_df = merged_saccade_df.drop(columns=['_merge'])

# Output to verify. The counts are computed outside the f-strings so the cell does not rely on
# re-using double quotes inside a double-quoted f-string (only valid from Python 3.12):
saccade_flagged_sample_count = test_metrics_df["Saccade On Main Stim"].sum()
saccade_flagged_trial_count = (
    test_metrics_df.loc[test_metrics_df["Saccade On Main Stim"], trial_identifier_cols]
    .drop_duplicates()
    .shape[0]
)
print(f"Flagged {saccade_flagged_sample_count} samples belonging to saccade-contaminated trials (during main stimulus).")
print(f"Number of unique trials flagged for saccade during main stim: {saccade_flagged_trial_count}")

# Remove the unneeded dfs:
del contaminated_df, contaminated_trials_df, merged_saccade_df

In [None]:
# test_metrics_df[test_metrics_df["Saccade On Main Stim"] == True].drop_duplicates(subset=trial_identifier_cols, keep="first")

# Create a 'Since Target' column for behavioural task-relevance filtering
This allows me to screen trials later that may be confounded due to a target appearing within the period of interest, thereby affecting the pupil (potentially due to the LC's responsiveness to task-relevant stimuli).

I'll decide whether to overwrite these values based on settings in preprocessing_config.py

In [None]:
# Temp column 1: the sample's own timestamp where the target was showing, NaN everywhere else.
target_showing = test_metrics_df['Target Status'] == True
test_metrics_df["Target Timestamps Only"] = test_metrics_df['Timestamp'].where(target_showing)

# Temp column 2: forward-fill within each session, so samples before a session's first target stay NaN.
test_metrics_df["Last Target Timestamp"] = (
    test_metrics_df
    .groupby("Session")["Target Timestamps Only"]
    .ffill()
)

# Elapsed time between each sample and the most recent target sample in the same session:
test_metrics_df["Time Since Target"] = test_metrics_df["Timestamp"] - test_metrics_df["Last Target Timestamp"]

# The two helper columns are no longer needed:
for temp_col in ("Target Timestamps Only", "Last Target Timestamp"):
    del test_metrics_df[temp_col]

# Pre-Processing According to Kret & Sjak-Shie (2019)
## Step 1: Prepare the Raw Data - Remove Blinks, Saccades, and Identified Artifacts/Confounds
This involves handling the blinks, saccades, behavioural responses, etc that were identified above.

We handle them by setting their respective samples' pupil data to NaN for now to later interpolate over:

In [None]:
# Work on an independent copy so test_metrics_df stays available for the before/after comparison.
cleaned_df = test_metrics_df.copy()

# Only the running-baseline path blanks blink/saccade samples; the deconvolution output keeps them.
# The mask is built first and applied through .loc to avoid pandas' chained-assignment copy issue.
if not DECONVOLUTION_OUTPUT:
    blink_or_saccade = cleaned_df["Blink"].eq(True) | cleaned_df["Saccade"].eq(True)
    cleaned_df.loc[blink_or_saccade, "Pupil"] = np.nan

### Visual Check for outliers
The plots below show the timeseries of all the trials in a session, with lines marking 3SD from the mean.

There should be a reduction in large 'spikes' that extend towards the 3SD lines when we compare the cleaned version against the original. Make sure there are no obvious spikes remaining. If there are, inspect the data to find out why:

In [None]:
# 4 x 2 grid: one row per session, left column = original data, right column = cleaned data.
fig, axes = plt.subplots(nrows=4, ncols=2, figsize=(12, 12), sharey=True)

for col_idx, (df, name) in enumerate(zip((test_metrics_df, cleaned_df), ("Original", "Cleaned"))):
    for row_idx in range(4):
        session_number = row_idx + 1
        ax = axes[row_idx][col_idx]

        run_data = df[df["Session"] == session_number][["Pupil", "Timestamp"]]
        ax.plot(run_data["Timestamp"], run_data["Pupil"])
        ax.set_title(f"{name} - Session {session_number}")

        # Label the left-hand column only (the y-axis is shared). The previous test read a 'session' variable
        # left over from another cell, so its outcome depended on which cells had been run beforehand.
        if col_idx == 0:
            ax.set_ylabel("Pupil")

        # Plot ref lines: session mean (black) and mean +/- 3 SD (red)
        mean = run_data["Pupil"].mean()
        std = run_data["Pupil"].std()

        ax.axhline(y=mean, xmin=0, xmax=1, color='black', linestyle='--')
        ax.axhline(y=mean+(std*3), xmin=0, xmax=1, color='red', linestyle='--')
        ax.axhline(y=mean-(std*3), xmin=0, xmax=1, color='red', linestyle='--')

plt.tight_layout()
plt.show()


## Step 2.1: Reject dilation speed outliers
By calculating how quickly the pupil changes its area between its prior and next samples, we can detect artifacts in the pupil data.

Such artifacts could be from partial eyelid closure, a slight head movement, etc.

This filter detects extremely rapid changes - i.e. noise that is characterised by high-velocity change:

In [None]:
# 1. Get the data for the PREVIOUS and NEXT samples (row-wise neighbours):
prev_pupil = cleaned_df["Pupil"].shift(1)
prev_time = cleaned_df["Timestamp"].shift(1)
next_pupil = cleaned_df["Pupil"].shift(-1)
next_time = cleaned_df["Timestamp"].shift(-1)

# 2. Calculate the absolute rate of change towards the PREVIOUS and NEXT samples:
delta_pupil_backward = cleaned_df["Pupil"] - prev_pupil
delta_time_backward = cleaned_df["Timestamp"] - prev_time
velocity_backward = (delta_pupil_backward / delta_time_backward).abs()

delta_pupil_forward = next_pupil - cleaned_df["Pupil"]
delta_time_forward = next_time - cleaned_df["Timestamp"]
velocity_forward = (delta_pupil_forward / delta_time_forward).abs()

# 3. Store the larger of the two speeds for each sample.
# np.fmax ignores a NaN on one side, so the first/last row and samples bordering a NaN gap still get the
# one speed that can be computed. np.maximum would return NaN for them, which means they could never be
# rejected by the threshold below. The result is NaN only when both speeds are NaN.
cleaned_df["Dilation Speed"] = np.fmax(velocity_backward, velocity_forward)

In [None]:
# Median absolute deviation (MAD) of the dilation speeds: the median distance of each speed from the median speed.
dilation_speeds = cleaned_df["Dilation Speed"]
absolute_deviations = (dilation_speeds - dilation_speeds.median()).abs()
mad = absolute_deviations.median()

In [None]:
# Try a range of multipliers n for the rejection threshold (median + n * MAD) and keep those that would
# reject at least one sample but less than 1% of all samples with a dilation speed.
speed_median = cleaned_df["Dilation Speed"].median()
total_samples = cleaned_df["Dilation Speed"].count()

possible_ns = []
for possible_n in range(5, 25):
    test_threshold = speed_median + (possible_n * mad)
    above_threshold = cleaned_df["Dilation Speed"] > test_threshold
    test_outliers = cleaned_df.loc[above_threshold, "Subject"].count()

    pc_removed = ((test_outliers / total_samples) * 100).round(2)
    if test_outliers >= 1 and pc_removed < 1:
        possible_ns.append((possible_n, test_outliers, pc_removed))

# Report every candidate that passed:
for possible_n, test_outliers, pc_removed in possible_ns:
    print(f"n = {possible_n}:\t\tOutliers detected: {test_outliers}\t\tPercentage of total:{pc_removed}")

In [None]:
# Multiplier applied to the MAD to obtain the accept/reject threshold. The right value depends on the data:
n = 14

overall_speed_median = cleaned_df["Dilation Speed"].median()
d_speed_threshold = overall_speed_median + (n * mad)
total_samples = cleaned_df["Dilation Speed"].count()

# Figure 1: per-session scatter of dilation speed over time, with the median and the chosen threshold.
fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(12, 6), sharey=True)
axes_flat = axes.flatten()

for session, ax in enumerate(axes_flat):
    in_session = cleaned_df["Session"] == session+1
    run_data = cleaned_df[in_session][["Dilation Speed", "Timestamp"]]

    ax.scatter(run_data["Timestamp"], run_data["Dilation Speed"], s=1, color="black", alpha=0.15)
    ax.axhline(y=overall_speed_median, xmin=0, xmax=1, color='magenta', linestyle='-', label="Median")
    ax.axhline(y=d_speed_threshold, xmin=0, xmax=1, color='red', linestyle='-', label="Chosen Threshold")

    ax.legend(loc="upper right")
    ax.set_title(f"Session {session+1}")

plt.tight_layout()
plt.show()

# Figure 2: per-session histogram of dilation speed, with the same two reference lines.
fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(12, 6), sharey=True)
axes_flat = axes.flatten()

for session, ax in enumerate(axes_flat):
    in_session = cleaned_df["Session"] == session+1
    run_data = cleaned_df[in_session][["Dilation Speed", "Timestamp"]]

    ax.hist(run_data["Dilation Speed"], color="black")
    ax.axvline(x=overall_speed_median, color='magenta', linestyle='-', label="Median")
    ax.axvline(x=d_speed_threshold, color='red', linestyle='-', label="Chosen Threshold")

    ax.legend(loc="upper right")
    ax.set_title(f"Session {session+1}")

plt.tight_layout()
plt.show()

In [None]:
# Plot the raw pupil data around some of the flagged samples to see if they are obvious outliers.
# One session is picked at random; its frame gets its own name so the integer 'session' loop variable
# used by other cells is not replaced by a DataFrame.
outlier_session_df = cleaned_df[cleaned_df["Session"] == random.choice([1,2,3,4])]

num_of_plots = 10
window_period = 300 # Total width of the plotted window, in timestamp units
half_window = window_period / 2

# Timestamps in the chosen session whose dilation speed reaches the threshold:
outlier_timestamps = outlier_session_df.loc[
    outlier_session_df["Dilation Speed"] >= d_speed_threshold, "Timestamp"
].values

if len(outlier_timestamps) == 0:
    # random.choice() raises IndexError on an empty sequence, so report and skip the figure instead.
    print("No dilation speed outliers in the randomly selected session - re-run this cell to sample another session.")
else:
    fig, axes = plt.subplots(nrows=2, ncols=num_of_plots//2, figsize=(12, 8))
    axes_flat = axes.flatten()

    for ax in axes_flat:
        # Choose a random sample where the dilation speed is flagged as an outlier - get its timestamp:
        selection = random.choice(outlier_timestamps)

        # Get the samples within half a window (window_period / 2) either side of the selection:
        window = outlier_session_df[
            (outlier_session_df["Timestamp"] >= selection - half_window)
            & (outlier_session_df["Timestamp"] <= selection + half_window)
        ]

        ax.plot(window["Timestamp"], window["Pupil"], color="black")

        # Highlight the flagged sample itself:
        selection_pupil_data = window[window["Timestamp"] == selection]["Pupil"].values
        ax.scatter(selection, selection_pupil_data, color="red", s=100)

        difference_in_minmax = window["Pupil"].max() - window["Pupil"].min()
        ax.set_title(f"TS:{selection}\nRange={difference_in_minmax}")

    plt.tight_layout()
    plt.show()


# Count is computed outside the f-string so the cell does not rely on re-using double quotes inside a
# double-quoted f-string (only valid from Python 3.12):
outlier_count = cleaned_df.loc[cleaned_df["Dilation Speed"] > d_speed_threshold, "Subject"].count()
print(f"Outliers detected with a threshold of {d_speed_threshold} based on an n-value of {n}:\n{outlier_count}")


In [None]:
# Run this cell only once the threshold above looks right.
# Every sample whose dilation speed exceeds the threshold has its pupil value replaced with NaN:
exceeds_speed_threshold = cleaned_df["Dilation Speed"] > d_speed_threshold
cleaned_df.loc[exceeds_speed_threshold, "Pupil"] = np.nan

# Sanity check: count flagged samples that still hold a pupil value after pruning.
still_present = cleaned_df.loc[exceeds_speed_threshold & cleaned_df['Pupil'].notna(), 'Pupil'].count()
print(f"Outliers detected with a threshold of {d_speed_threshold} after pruning:\n{still_present}")

## Step 2.2: Perform a multi-pass deviation filter on the dilation speed data:
This is very similar to the first dilation speed calculation and rejection we did, but this time we're widening the window of time that we're comparing against to measure the velocity of change in pupil diameter.

This effectively means that we're looking for datapoints that vary considerably from an abstract 'smoothed' trend of pupil dilation. Effectively, we're screening for sustained noise - not rapidly corrected noise, potentially caused by things like the eyetracker temporarily losing signal, or a prolonged change in head position before correction.

Note: The paper calls this a "trend-line deviation" rather than a filter.

### Bug Fix:
Running this on the continuous dataframe (i.e. with all sessions at once) causes the trend line to deviate from the data in the timespan between sessions, so the first second or two of data between sessions will typically get flagged and deleted. We therefore need to feed it in one session at a time.

In [None]:
# MAD multiplier for the trend-line deviation filter: how many MADs above the median deviation a sample
# must be to count as an outlier. It currently re-uses the dilation-speed multiplier 'n'; adjust it based
# on the outcome of the plot at the end (a higher value makes this filter less aggressive).
n_trend_dev = n

# How many times the filter is applied per session. The paper suggests several passes help when an earlier
# pass's trend line was skewed by other outliers. A small number is good:
num_passes_trend = 3

# Width of the rolling-median window that produces the trend line, in samples.
# It should represent a duration, e.g. 151-251ms, and be odd:
smoothing_window_samples_for_trend = 201

# Longest run of consecutive NaNs that is linearly interpolated when generating the trend line.
# Only SMALL gaps are meant to be filled here; padded blinks and saccades will likely exceed this
# limit, which is fine - this setting is not intended to handle those gaps.
interpolation_limit_for_trend_gen = 50


In [None]:
# Record how many pupil samples are already NaN before the trend-line deviation filter runs:
nans_before_trend_filter = cleaned_df["Pupil"].isna().sum()
print(f"NaNs before trend-line deviation filter: {nans_before_trend_filter}")

# Snapshot for the later comparison plot: timestamps plus a forward-filled pupil trace.
before_pass_ts = cleaned_df["Timestamp"]
before_pass_pupil = cleaned_df["Pupil"].ffill()

# Second snapshot (NaNs left in place) for comparing the data before/after the multi-pass filter:
pupil_data_before_session_filtering = cleaned_df["Pupil"].copy()

In [None]:
# Session-wise, multi-pass trend-line deviation filter.
# For every session, up to num_passes_trend passes are run. Each pass:
#   1. builds a smoothed trend line (linear interpolation of small gaps -> bfill/ffill -> centred rolling median),
#   2. measures each non-NaN sample's absolute deviation from that trend line,
#   3. sets samples whose deviation exceeds median + n_trend_dev * MAD to NaN in cleaned_df["Pupil"].
# Because cleaned_df is updated in place, each pass builds its trend line from the previous pass's result.
# The trend line of the last pass that ran is kept per session for plotting.

# Store the trend lines for later plotting:
all_session_trend_lines = {}

for session_id in cleaned_df["Session"].unique():
    print(f"\n--- Processing Session: {session_id} ---")

    # Get the slice of the main DataFrame for the current session
    session_indices = cleaned_df[cleaned_df["Session"] == session_id].index

    # Store the trend line from the last effective pass for this session
    last_effective_trend_line_for_session = None

    for pass_num in range(1, num_passes_trend + 1):
        print(f"\nPass {pass_num} (Session {session_id}):")

        # Create a temporary version of the pupil data from this session for trend line generation.
        # It takes the current state of pupil data for this session from cleaned_df.
        pupil_for_trend_generation_session = cleaned_df.loc[session_indices, "Pupil"].copy()

        # Interpolate small gaps. This does NOT change the NaNs in cleaned_df["Pupil"] yet.
        interpolated_pupil_for_trend_session = pupil_for_trend_generation_session.interpolate(
            method='linear',
            limit_direction='both',
            limit=interpolation_limit_for_trend_gen
        )
        # If gaps are too large for 'limit', they might still be NaN.
        # Rolling median can handle some NaNs, but filling edges is good.
        # Fill any remaining NaNs (e.g. at ends or very large gaps) just for the trend.
        interpolated_pupil_for_trend_session = interpolated_pupil_for_trend_session.bfill().ffill()

        # Smooth the interpolated data to get the trend line:
        trend_line_session = interpolated_pupil_for_trend_session.rolling(
            window=smoothing_window_samples_for_trend,
            center=True,
            min_periods=1
        ).median()

        # Rolling operations can create NaNs at the very start/end if min_periods isn't 1 or window is large, so we fill them:
        trend_line_session = trend_line_session.bfill().ffill()

        # Update the 'last effective trend line' in each pass
        last_effective_trend_line_for_session = trend_line_session.copy()

        # Calculate absolute deviations of the current session's actual pupil data from this session's trend line:
        deviations_session = (cleaned_df.loc[session_indices, "Pupil"] - trend_line_session).abs()

        # Identify outliers based on these deviations using MAD (Eqs. 2 and 3 from paper).
        # Only consider deviations where the session's pupil data is NOT already NaN.
        valid_deviations_session = deviations_session.dropna()

        median_dev_session = valid_deviations_session.median()
        mad_dev_session = (valid_deviations_session - median_dev_session).abs().median()

        # Threshold = median deviation + n_trend_dev * MAD of the deviations
        deviation_threshold_session = median_dev_session + (n_trend_dev * mad_dev_session)

        print(f"Session {session_id}, Pass {pass_num} - Median Dev: {median_dev_session:.4f}, MAD Dev: {mad_dev_session:.4f}, Threshold: {deviation_threshold_session:.4f}")


        # Identify samples within this session in cleaned_df that exceed this threshold.
        # The condition must be on indices belonging to the current session.
        # deviations_session is already indexed by session_indices.
        outlier_condition_this_session = (deviations_session > deviation_threshold_session) & \
                                         (cleaned_df.loc[session_indices, "Pupil"].notna())

        # Get the actual global indices from outlier_condition_this_session (which is a boolean Series on session_indices)
        global_indices_to_nan_this_pass = outlier_condition_this_session[outlier_condition_this_session].index

        num_outliers_this_pass = len(global_indices_to_nan_this_pass)
        print(f"Session {session_id}, Pass {pass_num} - Outliers identified: {num_outliers_this_pass}")

        if num_outliers_this_pass == 0:
            print(f"Session {session_id}, Pass {pass_num} - No new trend deviation outliers found. Stopping passes for this session.")
            break # Optimization: if no outliers found, subsequent passes likely won't change much

        # Set these newly identified outliers in the main cleaned_df to NaN using the global indices:
        cleaned_df.loc[global_indices_to_nan_this_pass, "Pupil"] = np.nan

    # After all passes for a session (or early exit), store the last computed trend line for this session
    if last_effective_trend_line_for_session is not None:
        all_session_trend_lines[session_id] = last_effective_trend_line_for_session
    else:
        # Only reachable when no pass ran at all (i.e. num_passes_trend < 1)
        print(f"Warning: No effective trend line was computed for session {session_id}")


print("\n--- Session-wise Trend-Line Deviation Filtering Complete ---")
nans_after_session_filtering = cleaned_df["Pupil"].isna().sum()

print(f"NaNs before session-wise trend-line filter: {nans_before_trend_filter}")
print(f"NaNs after session-wise trend-line filter: {nans_after_session_filtering}")
print(f"Samples marked NaN by this session-wise filter: {nans_after_session_filtering - nans_before_trend_filter}")

# For plotting the final trend line (concatenated from all sessions, in ascending session-id order):
if all_session_trend_lines:
    sorted_session_ids = sorted(all_session_trend_lines.keys())
    final_concatenated_trend_line = pd.concat(
        [all_session_trend_lines[sid] for sid in sorted_session_ids]
    )
else:
    print("Warning: No session trend lines were stored.")
    # Placeholder Series on cleaned_df's index so the plotting cell still has something to draw
    final_concatenated_trend_line = pd.Series(index=cleaned_df.index, dtype=float)

In [None]:
import matplotlib.pyplot as plt
import pandas as pd # Ensure pandas is imported if not already

# --- Plot 1: Pupil Data Before Filter vs. Final (Concatenated) Trend Line ---

# The pre-filter pupil trace is forward-filled purely so the line is drawn without breaks;
# the x-axis for every trace below is cleaned_df["Timestamp"].
pupil_before_plot = pupil_data_before_session_filtering.ffill()

fig_trend, ax_trend = plt.subplots(figsize=(15, 5))

ax_trend.plot(
    cleaned_df["Timestamp"],
    pupil_before_plot,
    label='Pupil Data Before Session-wise Trend Filter (NaNs ffilled for plot)',
    color="cyan",
    alpha=0.7,
    zorder=1,
    linewidth=1.5,
)

# Trend line assembled from the per-session results of the trend filter.
ax_trend.plot(
    cleaned_df["Timestamp"],
    final_concatenated_trend_line,
    label='Final Concatenated Trend Line (Last Pass per Session)',
    color='red',
    alpha=0.7,
    zorder=2,
)

ax_trend.set_xlabel("Timestamp")
ax_trend.set_ylabel("Pupil Size")
ax_trend.set_title("Pupil Data Before Session-wise Trend Filter vs. Final Concatenated Trend Line")
ax_trend.legend(loc="upper right")
ax_trend.grid(True, linestyle=':', alpha=0.7)
plt.show()


# --- Plot 2: Pupil Data Before vs. After Session-wise Trend Filter ---
#    (No ffill here: NaNs are left in place and therefore show up as gaps.)

fig_before_after, ax_before_after = plt.subplots(figsize=(15, 5))

# Green: the pupil trace as it was when the session-wise filter started.
ax_before_after.plot(cleaned_df["Timestamp"], pupil_data_before_session_filtering, label="Before", color="green")

# Red: the current 'Pupil' column, including the NaNs this filter introduced.
ax_before_after.plot(cleaned_df["Timestamp"], cleaned_df["Pupil"], label="After", color="red")

ax_before_after.set_xlabel("Timestamp")
ax_before_after.set_ylabel("Pupil Size")
ax_before_after.set_title("Pupil Data: Before vs. After Session-wise Trend Filter")
ax_before_after.legend(loc="upper right")
ax_before_after.grid(True, linestyle=':', alpha=0.7)
plt.show()

In [None]:
# --- Visualization of Randomly Sampled Changes from Session-Wise Multi-Pass Trend Filter ---
# Draws a handful of short windows centred on samples that the session-wise filter turned
# into NaN, showing the data before/after the filter together with the trend line.

# Number of random segments to plot
num_plots_to_sample = 5 # Adjust as needed
# Half-width of the time window around the identified NaN'd point (in milliseconds).
# 50 ms is a tight view; consider e.g. 250 ms or 500 ms for more context.
plot_window_ms = 50

# Fail fast with a clear message instead of printing a warning and then crashing with an
# unrelated-looking NameError a few lines further down.
_missing_plot_inputs = [
    name for name in ("pupil_data_before_session_filtering", "final_concatenated_trend_line")
    if name not in globals()
]
if _missing_plot_inputs:
    raise NameError(
        f"Required data for plotting is missing: {_missing_plot_inputs}. "
        "Please ensure you have run the session-wise trend filtering cell completely."
    )


# Candidate indices are points that were valid *before* this session-wise filter
# but are NaN in cleaned_df *after* it.
candidate_indices = cleaned_df[
    pupil_data_before_session_filtering.notna() & cleaned_df["Pupil"].isna()
].index

print(f"\n--- Plotting {num_plots_to_sample} Segments Randomly Sampled from those NaN'd by Session-Wise Filter ---")
if len(candidate_indices) == 0:
    print("No points found that were specifically NaN'd by the session-wise multi-pass filter.")
else:
    if len(candidate_indices) < num_plots_to_sample:
        print(f"Warning: Fewer than {num_plots_to_sample} samples were NaN'd. Plotting all {len(candidate_indices)}.")
        sampled_indices = list(candidate_indices)
    else:
        sampled_indices = random.sample(list(candidate_indices), num_plots_to_sample)

    # Size the grid by what was actually sampled so no blank axes are left over when fewer
    # candidates exist than requested; squeeze=False keeps the result indexable for one row.
    num_axes = len(sampled_indices)
    fig_mp_session, axes_grid = plt.subplots(nrows=num_axes, ncols=1,
                                             figsize=(15, 5 * num_axes), sharex=False, squeeze=False)
    axes_mp_session = axes_grid[:, 0]

    for i, random_idx in enumerate(sampled_indices):
        ax = axes_mp_session[i]
        center_timestamp = cleaned_df.loc[random_idx, "Timestamp"]

        # Window bounds get their own names so the notebook-level min_ts / max_ts
        # (overall recording bounds computed earlier) are not overwritten.
        window_start_ts = center_timestamp - plot_window_ms
        window_end_ts = center_timestamp + plot_window_ms

        # Get window indices from the main cleaned_df based on timestamps
        window_df_indices = cleaned_df[
            (cleaned_df["Timestamp"] >= window_start_ts) & (cleaned_df["Timestamp"] <= window_end_ts)
        ].index

        # Ensure we have some data in the window
        if window_df_indices.empty:
            print(f"Warning: Window around index {random_idx} (Timestamp {center_timestamp}) is empty. Skipping plot.")
            ax.set_title(f"Empty window around Timestamp: {center_timestamp:.0f} ms (Index: {random_idx})")
            ax.text(0.5, 0.5, "No data in window", ha='center', va='center')
            continue

        timestamps_window = cleaned_df.loc[window_df_indices, "Timestamp"]

        # Data *before* this specific session-wise multi-pass filter
        pupil_before_filter_window = pupil_data_before_session_filtering.loc[window_df_indices]

        # Data *after* this filter (has the new NaNs)
        pupil_after_filter_window = cleaned_df.loc[window_df_indices, "Pupil"]

        # Trend line for this window; relies on final_concatenated_trend_line sharing
        # cleaned_df's index.
        trend_line_window = final_concatenated_trend_line.loc[window_df_indices]

        # Solid line for "before", large points for "after", dashed line for the trend.
        ax.plot(timestamps_window, pupil_before_filter_window,
                label='Pupil BEFORE session-wise filter', color='green', alpha=0.7, linestyle='-', marker='o', markersize=3)
        ax.plot(timestamps_window, pupil_after_filter_window,
                label='Pupil AFTER session-wise filter', color='black', alpha=0.9, marker='.', markersize=20, linestyle='None')
        ax.plot(timestamps_window, trend_line_window,
                label='Final Concatenated Trend Line', color='red', linestyle='--', alpha=0.8)

        # Highlight the specific point that was selected (which is now NaN)
        original_value_at_idx = pupil_data_before_session_filtering.loc[random_idx] # Value before this filter
        trend_value_at_idx = final_concatenated_trend_line.loc[random_idx]
        if pd.notna(original_value_at_idx):
            ax.scatter(center_timestamp, original_value_at_idx,
                       color='magenta', s=100, zorder=5,
                       label=f'Original value of this NaN point: {original_value_at_idx:.2f}')

        # Absolute distance of the removed sample from the trend line (NaN if either is missing).
        if pd.notna(original_value_at_idx) and pd.notna(trend_value_at_idx):
            deviation_at_point = abs(original_value_at_idx - trend_value_at_idx)
        else:
            deviation_at_point = float('nan')

        ax.set_title(f"Segment NaN'd by Session-Wise Filter (Index: {random_idx})\nTS: {center_timestamp:.0f} ms | Original Dev: {deviation_at_point:.2f}")
        ax.set_xlabel("Timestamp (ms)")
        ax.set_ylabel("Pupil Size")
        ax.legend(loc='best')
        ax.grid(True, linestyle=':', alpha=0.5)

    plt.tight_layout()
    plt.show()

In [None]:
# --- Visualization of Changes from Session-Wise Multi-Pass Trend Filter ---
# --- AND Comparison with Data that PASSED the filter ---
# Two figures: windows centred on samples the filter turned into NaN, and windows centred
# on samples that are still valid afterwards.

# Number of random segments to plot for EACH category (NaN'd and Passed)
num_plots_to_sample = 5 # Adjust as needed
# Half-width of the time window around the identified point (in milliseconds)
plot_window_ms = 500  # Plot before and after the point by this amount

# Fail fast with a clear message instead of printing a warning and then crashing with an
# unrelated-looking NameError a few lines further down.
_missing_plot_inputs = [
    name for name in ("pupil_data_before_session_filtering", "final_concatenated_trend_line", "cleaned_df")
    if name not in globals()
]
if _missing_plot_inputs:
    raise NameError(
        f"Required data for plotting is missing: {_missing_plot_inputs}. "
        "Please ensure you have run the session-wise trend filtering cell completely."
    )


# --- Plotting Segments NaN'd by the Session-Wise Multi-Pass Filter ---

# Candidate indices are points that were valid before the session-wise filter
# but are NaN in cleaned_df after it.
candidate_indices_nand = cleaned_df[
    pupil_data_before_session_filtering.notna() & cleaned_df["Pupil"].isna()
].index

print(f"\n--- Plotting {num_plots_to_sample} Segments NaN'd by Session-Wise Multi-Pass Filter ---")
if len(candidate_indices_nand) == 0:
    print("No points found that were specifically NaN'd by the session-wise multi-pass filter.")
else:
    if len(candidate_indices_nand) < num_plots_to_sample:
        print(f"Warning: Fewer than {num_plots_to_sample} samples were NaN'd. Plotting all {len(candidate_indices_nand)}.")
        sampled_indices_nand = list(candidate_indices_nand)
    else:
        sampled_indices_nand = random.sample(list(candidate_indices_nand), num_plots_to_sample)

    # Size the grid by what was actually sampled (no blank axes when there are fewer
    # candidates than requested); squeeze=False keeps it indexable for a single row.
    num_axes_nand = len(sampled_indices_nand)
    fig_nand, axes_nand_grid = plt.subplots(nrows=num_axes_nand, ncols=1,
                                            figsize=(15, 5 * num_axes_nand), sharex=False, squeeze=False)
    axes_nand = axes_nand_grid[:, 0]

    for i, random_idx in enumerate(sampled_indices_nand):
        ax = axes_nand[i]
        center_timestamp = cleaned_df.loc[random_idx, "Timestamp"]
        # Dedicated names so the notebook-level min_ts / max_ts are not overwritten.
        window_start_ts = center_timestamp - plot_window_ms
        window_end_ts = center_timestamp + plot_window_ms

        window_df_indices = cleaned_df[
            (cleaned_df["Timestamp"] >= window_start_ts) & (cleaned_df["Timestamp"] <= window_end_ts)
        ].index

        if window_df_indices.empty:
            ax.set_title(f"Empty window for NaN'd segment (Index: {random_idx}) TS: {center_timestamp:.0f} ms")
            ax.text(0.5, 0.5, "No data in window", ha='center', va='center')
            continue

        timestamps_window = cleaned_df.loc[window_df_indices, "Timestamp"]

        # Pupil data *before* this specific session-wise multi-pass filter step
        pupil_before_filter_window = pupil_data_before_session_filtering.loc[window_df_indices]

        # Pupil data *after* this filter (has the new NaNs)
        pupil_after_filter_window = cleaned_df.loc[window_df_indices, "Pupil"]

        # Trend line from the session-wise processing for this window
        trend_line_window = final_concatenated_trend_line.loc[window_df_indices]

        ax.plot(timestamps_window, pupil_before_filter_window,
                label='Pupil BEFORE session-wise filter', color='green', alpha=0.7, linestyle='-', markersize=3)
        ax.plot(timestamps_window, pupil_after_filter_window,
                label='Pupil AFTER session-wise filter (NaNs shown)', color='black', alpha=0.9, marker='.', markersize=5, linestyle='None')
        ax.plot(timestamps_window, trend_line_window,
                label='Final Concatenated Trend Line', color='red', linestyle='--', alpha=0.8)

        original_value_at_idx = pupil_data_before_session_filtering.loc[random_idx]
        trend_value_at_idx = final_concatenated_trend_line.loc[random_idx]
        if pd.notna(original_value_at_idx):
            ax.scatter(center_timestamp, original_value_at_idx,
                       color='magenta', s=100, zorder=5, label=f'Original: {original_value_at_idx:.2f}')

        # Absolute distance of the removed sample from the trend line (NaN if either is missing).
        if pd.notna(original_value_at_idx) and pd.notna(trend_value_at_idx):
            deviation_at_point = abs(original_value_at_idx - trend_value_at_idx)
        else:
            deviation_at_point = float('nan')
        ax.set_title(f"NaN'd Segment (Index: {random_idx}) TS: {center_timestamp:.0f} ms\nOriginal Dev: {deviation_at_point:.2f}")
        ax.set_xlabel("Timestamp (ms)")
        ax.set_ylabel("Pupil Size")
        ax.legend(loc='best')
        ax.grid(True, linestyle=':', alpha=0.5)

    plt.tight_layout()
    plt.show()


# --- Plotting Segments that PASSED the Session-Wise Multi-Pass Filter ---

# Candidate indices are those where pupil data is NOT NaN in the *final* cleaned_df
# And where final_concatenated_trend_line was also computed (not NaN)
candidate_indices_passed = cleaned_df[
    cleaned_df["Pupil"].notna() & final_concatenated_trend_line.notna()
].index

print(f"\n--- Plotting {num_plots_to_sample} Segments that PASSED All Filters (Session-Wise) ---")
if len(candidate_indices_passed) == 0:
    print("No data points found that passed all filters.")
else:
    if len(candidate_indices_passed) < num_plots_to_sample:
        print(f"Warning: Fewer than {num_plots_to_sample} samples passed. Plotting all {len(candidate_indices_passed)}.")
        sampled_indices_passed = list(candidate_indices_passed)
    else:
        # Plain uniform sampling; windows near the recording edges are simply narrower.
        sampled_indices_passed = random.sample(list(candidate_indices_passed), num_plots_to_sample)

    num_axes_passed = len(sampled_indices_passed)
    fig_passed, axes_passed_grid = plt.subplots(nrows=num_axes_passed, ncols=1,
                                                figsize=(15, 5 * num_axes_passed), sharex=False, squeeze=False)
    axes_passed = axes_passed_grid[:, 0]

    for i, random_idx in enumerate(sampled_indices_passed):
        ax = axes_passed[i]
        center_timestamp = cleaned_df.loc[random_idx, "Timestamp"]
        window_start_ts = center_timestamp - plot_window_ms
        window_end_ts = center_timestamp + plot_window_ms

        window_df_indices = cleaned_df[
            (cleaned_df["Timestamp"] >= window_start_ts) & (cleaned_df["Timestamp"] <= window_end_ts)
        ].index

        if window_df_indices.empty:
            ax.set_title(f"Empty window for Passed segment (Index: {random_idx}) TS: {center_timestamp:.0f} ms")
            ax.text(0.5, 0.5, "No data in window", ha='center', va='center')
            continue

        timestamps_window = cleaned_df.loc[window_df_indices, "Timestamp"]
        pupil_data_window = cleaned_df.loc[window_df_indices, "Pupil"] # Final "good" pupil data
        trend_line_window = final_concatenated_trend_line.loc[window_df_indices]

        ax.plot(timestamps_window, pupil_data_window,
                label='Pupil Data (Passed Filter)', color='blue', alpha=0.9, marker='.', markersize=3, linestyle='-')
        ax.plot(timestamps_window, trend_line_window,
                label='Final Concatenated Trend Line', color='red', linestyle='--', alpha=0.8)

        value_at_idx = cleaned_df.loc[random_idx, "Pupil"]
        trend_value_at_idx = final_concatenated_trend_line.loc[random_idx]
        ax.scatter(center_timestamp, value_at_idx,
                   color='cyan', s=100, zorder=5, label=f'Selected Point: {value_at_idx:.2f}')

        # Absolute distance of the retained sample from the trend line (NaN if either is missing).
        if pd.notna(value_at_idx) and pd.notna(trend_value_at_idx):
            deviation_at_point = abs(value_at_idx - trend_value_at_idx)
        else:
            deviation_at_point = float('nan')
        ax.set_title(f"Passed Segment (Index: {random_idx}) TS: {center_timestamp:.0f} ms\nActual Dev: {deviation_at_point:.2f}")
        ax.set_xlabel("Timestamp (ms)")
        ax.set_ylabel("Pupil Size")
        ax.legend(loc='best')
        ax.grid(True, linestyle=':', alpha=0.5)

    plt.tight_layout()
    plt.show()

## Step 2.3: Identifying Islands (Sparsity Filter):
Identifies 'small islands': short runs of potentially valid but noisy-looking samples that are surrounded by large gaps of missing data (for example, between blinks and saccades). Such islands are often indicative of noise and, if kept, can throw off the interpolation performed in the following steps.

In [None]:
# --- Parameters for identifying small islands ---
# These likely need fine-tuning. Both thresholds are counted in samples; assuming the
# 1000 Hz sampling rate used elsewhere in this notebook, one sample corresponds to 1 ms.
MAX_ISLAND_DURATION_SAMPLES = 50  # a run of valid data this long or shorter (50 samples) counts as a "small" island
MIN_SURROUNDING_GAP_SAMPLES = 40 # a neighbouring NaN gap must span at least this many samples (40) to count as "large"

# Working copy of the pupil trace, so island detection reads a stable snapshot of 'Pupil'.
pupil_series_for_island_removal = cleaned_df["Pupil"].copy()

# --- Helper function to get block information ---
def get_block_info(series):
    """Describe every contiguous run of NaN or non-NaN values in ``series``.

    Returns a DataFrame with one row per run, in order of appearance, with columns:
    ``id`` (1-based run counter), ``is_nan_block`` (True when the run is all NaN),
    ``start_index`` / ``end_index`` (first and last index label of the run) and
    ``duration_samples`` (number of samples in the run). An empty input yields an
    empty DataFrame.
    """
    nan_flags = series.isna()
    # A new run begins wherever the NaN flag differs from the previous sample. The first
    # sample has no predecessor, so its diff is NaN, which is != 0 and opens run number 1.
    run_starts = nan_flags.astype("int8").diff().ne(0)
    # Cumulative count of run starts gives every sample the id of the run it belongs to.
    run_ids = run_starts.cumsum()

    records = [
        {
            "id": run_id,
            "is_nan_block": run.isna().all(),
            "start_index": run.index[0],
            "end_index": run.index[-1],
            "duration_samples": len(run),
        }
        for run_id, run in series.groupby(run_ids)
    ]
    return pd.DataFrame(records)


nans_before_island_removal = pupil_series_for_island_removal.isna().sum()
print(f"NaNs before island removal: {nans_before_island_removal}")


def _is_large_nan_gap(block):
    """True when ``block`` (one get_block_info record) is a NaN run of sufficient length."""
    return bool(block["is_nan_block"]) and block["duration_samples"] >= MIN_SURROUNDING_GAP_SAMPLES


# Apply per session to handle session boundaries correctly:
indices_to_nan_from_islands = []
for session_id in cleaned_df["Session"].unique():
    session_mask = (cleaned_df["Session"] == session_id)
    session_pupil_data = pupil_series_for_island_removal[session_mask]
    # Index labels of this session, taken once. Previously cleaned_df[session_mask] was
    # re-materialised (a copy of the whole session's rows) twice for every island found.
    session_index = session_pupil_data.index

    # One record per contiguous run of NaN / non-NaN samples, in temporal order.
    session_blocks = get_block_info(session_pupil_data).to_dict("records")
    last_position = len(session_blocks) - 1

    # Look for short data runs ("islands") flanked by large NaN gaps on both sides.
    for position, current_block in enumerate(session_blocks):
        if current_block["is_nan_block"]:
            continue
        if current_block["duration_samples"] > MAX_ISLAND_DURATION_SAMPLES:
            continue

        # A missing neighbour (island at the very start or end of the session) is treated
        # as an infinitely large gap.
        preceding_nan_gap_sufficient = (
            position == 0 or _is_large_nan_gap(session_blocks[position - 1])
        )
        succeeding_nan_gap_sufficient = (
            position == last_position or _is_large_nan_gap(session_blocks[position + 1])
        )
        if not (preceding_nan_gap_sufficient and succeeding_nan_gap_sufficient):
            continue

        # Collect the island's original index labels (same value-range selection as before).
        island_original_indices = session_index[
            (session_index >= current_block["start_index"]) &
            (session_index <= current_block["end_index"])
        ]
        indices_to_nan_from_islands.extend(island_original_indices)
        print(f"Session {session_id}: Marking island from {current_block['start_index']} to {current_block['end_index']} (Duration: {current_block['duration_samples']}) as NaN.")


# Apply the NaNing to the pupil series
if indices_to_nan_from_islands:
    # Use .loc with the list of original DataFrame indices
    pupil_series_for_island_removal.loc[indices_to_nan_from_islands] = np.nan
    # To update the main DataFrame:
    cleaned_df.loc[indices_to_nan_from_islands, "Pupil"] = np.nan


nans_after_island_removal = pupil_series_for_island_removal.isna().sum() # or cleaned_df["Pupil"].isna().sum()
print(f"NaNs after island removal: {nans_after_island_removal}")
print(f"Samples NaN'd due to island removal: {nans_after_island_removal - nans_before_island_removal}")

## Step 3.1: Interpolate over sufficiently small gaps:

In [None]:
# --- Step: Interpolate Missing Data (NaNs) ---
print("\n--- Interpolating Missing Pupil Data ---")

# Parameters for interpolation
MAX_INTERPOLATION_GAP_MS = 250 # 250 based on paper, set to None to interpolate all gaps regardless of length.

# The limit is applied as a number of consecutive samples; this equals milliseconds only
# because the recording is assumed to be sampled at 1000 Hz (1 sample per ms).
limit_samples_for_interpolation = MAX_INTERPOLATION_GAP_MS if MAX_INTERPOLATION_GAP_MS is not None else None


def interpolate_short_gaps(pupil, max_gap_samples=None):
    """Linearly interpolate only those NaN runs that are at most ``max_gap_samples`` long.

    ``Series.interpolate(limit=N, limit_direction='both')`` does NOT skip long gaps: it
    fills N samples inward from each edge of every gap, so a gap longer than N ends up
    partially filled and a gap of up to 2N samples is filled completely. Here every gap is
    interpolated first and the gaps that are too long are then blanked out again in full.

    Parameters:
        pupil: Series of pupil sizes for one session (NaN marks missing samples).
        max_gap_samples: longest NaN run to fill; None fills every gap.

    Returns:
        A Series with the same index as ``pupil``. Leading/trailing NaN runs that are short
        enough are filled with the nearest valid value (limit_direction='both').
    """
    filled = pupil.interpolate(method='linear', limit_direction='both')
    if max_gap_samples is None:
        return filled

    is_missing = pupil.isna()
    # Label each contiguous run of missing / present samples, then look up each run's length.
    run_ids = is_missing.astype("int8").diff().ne(0).cumsum()
    run_lengths = run_ids.map(run_ids.value_counts())
    gap_too_long = is_missing & (run_lengths > max_gap_samples)
    return filled.mask(gap_too_long)


nans_before_interpolation = cleaned_df["Pupil"].isna().sum()
print(f"NaNs in 'Pupil' column before interpolation: {nans_before_interpolation}")

# Create a new column for interpolated data.
# This keeps the 'Pupil' column (with NaNs from artifact rejection) intact for reference.
# Interpolation runs per session so nothing is filled across the intended breaks between sessions;
# transform stitches the per-session results back into a series aligned with cleaned_df.
cleaned_df["Pupil Int"] = cleaned_df.groupby("Session")["Pupil"].transform(
    lambda session_pupil: interpolate_short_gaps(session_pupil, limit_samples_for_interpolation)
)

nans_after_interpolation = cleaned_df["Pupil Int"].isna().sum()
print(f"NaNs in 'Pupil Int' column after interpolation: {nans_after_interpolation}")

In [None]:
# --- Visual Check (Optional but Recommended) ---
# Basic illustration of the interpolation: pick samples that were NaN in 'Pupil' but have a
# value in 'Pupil Int', and plot a window of rows around each of them.

# Index labels where 'Pupil' was NaN but 'Pupil Int' is not
interpolated_indices = cleaned_df.index[cleaned_df["Pupil"].isna() & cleaned_df["Pupil Int"].notna()]

if interpolated_indices.empty:
    print("No segments found that were NaN before and interpolated after (e.g., all data was initially valid or all gaps were too long).")
else:
    num_plot_samples = min(3, len(interpolated_indices)) # Plot up to 3 examples
    fig, axes = plt.subplots(num_plot_samples, 1, figsize=(15, 4 * num_plot_samples), sharex=False)
    # A single subplot comes back as a bare Axes; wrap it so indexing below works.
    axes = [axes] if num_plot_samples == 1 else axes

    for i in range(num_plot_samples):
        # Random pick among the successfully interpolated samples (repeats are possible).
        center_idx = random.choice(interpolated_indices)
        center_position = cleaned_df.index.get_loc(center_idx)
        center_timestamp = cleaned_df.loc[center_idx, 'Timestamp']

        # 100 rows either side of the centre, clamped to the valid positional range.
        plot_start_idx = max(0, center_position - 100)
        plot_end_idx = min(len(cleaned_df) -1, center_position + 100)

        # Positional slice; +1 because the iloc end bound is exclusive.
        segment_df = cleaned_df.iloc[plot_start_idx:plot_end_idx + 1]

        if segment_df.empty:
            axes[i].text(0.5, 0.5, "Could not get segment for plotting", ha='center')
            continue

        ax = axes[i]
        # Thick light-blue trace: data before interpolation, gaps where it was NaN.
        ax.plot(segment_df["Timestamp"], segment_df["Pupil"],
                marker='o', linestyle='-', color='skyblue', alpha=1, markersize=5, linewidth=5, label="Original Pupil (with NaNs)")
        # Thin red trace on top: the interpolated column.
        ax.plot(segment_df["Timestamp"], segment_df["Pupil Int"],
                marker='.', linestyle='-', color='red', alpha=0.9, markersize=3, label="Pupil Int")
        ax.axvline(x=center_timestamp, label=f"Center Point, TS: {center_timestamp}", color="green")

        ax.set_title(f"Interpolation Example (Index {center_idx}, Timestamp {center_timestamp})")
        ax.set_xlabel("Timestamp")
        ax.set_ylabel("Pupil Size")
        ax.legend(loc="upper center")
        ax.grid(True)

    plt.tight_layout()
    plt.show()

## Step 3.2: Calculate Percentage of Trial That Contains Interpolated Data
This allows us to later reject trials if they have data that is mostly 'inferred'.

In [None]:
investigate_percentage_threshold = 20 # Change this to affect the printout (pc trials containing interpolated data)

trial_identifier_cols = ["Run", "Session", "Block", "Trial"]

# Start from NaN everywhere; rows of any trial that is skipped below keep this NaN.
cleaned_df["Interpolation PC"] = np.nan

# One row per distinct (Run, Session, Block, Trial) combination.
# Sorting only makes the processing order predictable; the result does not depend on it.
unique_trials_df = (
    cleaned_df[trial_identifier_cols]
    .drop_duplicates()
    .sort_values(by=trial_identifier_cols)
    .reset_index(drop=True)
)

print(f"Found {len(unique_trials_df)} unique trials to process.")

# Counters: trials with any interpolated data / trials at or above the investigate threshold.
intp_trials = 0
intp_trials_investigate = 0

for index, trial_info in unique_trials_df.iterrows():
    current_run, current_session, current_block, current_trial_num = (
        trial_info[col] for col in trial_identifier_cols
    )

    # Boolean mask over cleaned_df selecting every row of this trial.
    trial_mask = (
        (cleaned_df["Run"] == current_run)
        & (cleaned_df["Session"] == current_session)
        & (cleaned_df["Block"] == current_block)
        & (cleaned_df["Trial"] == current_trial_num)
    )
    current_trial_data = cleaned_df[trial_mask]

    if current_trial_data.empty:
        print(f"Warning: No data found for trial: Run {current_run}, Sess {current_session}, Block {current_block}, Trial {current_trial_num}. Skipping.")
        continue

    total_samples_in_trial = len(current_trial_data)

    # Samples that were NaN in 'Pupil' and received a value in 'Pupil Int'.
    filled_by_interpolation = current_trial_data["Pupil"].isna() & current_trial_data["Pupil Int"].notna()
    successfully_interpolated_samples = int(filled_by_interpolation.sum())

    # The trial is non-empty at this point, so the division is always defined.
    percent_interpolated = (successfully_interpolated_samples / total_samples_in_trial) * 100

    intp_trials += int(percent_interpolated > 0)
    intp_trials_investigate += int(percent_interpolated >= investigate_percentage_threshold)

    # Broadcast the trial-level percentage onto every row of the trial.
    cleaned_df.loc[trial_mask, "Interpolation PC"] = percent_interpolated

print("\n--- Finished Calculating Interpolation Percentages ---")

# --- Verification (Optional) ---
# Rows still NaN in 'Interpolation PC' belong to a trial that was skipped above.
if cleaned_df["Interpolation PC"].isna().any():
    print("\nWarning: Some trials have NaN for 'Interpolation PC'. This might indicate an issue.")

num_of_unique_trials = len(unique_trials_df)

print(f"Percentage of trials containing interpolated data: {intp_trials/num_of_unique_trials * 100:.2f}% ({intp_trials}/{num_of_unique_trials})")
print(f"Percentage of trials containing more than {investigate_percentage_threshold}% interpolated data: {intp_trials_investigate/num_of_unique_trials * 100:.2f}% ({intp_trials_investigate}/{num_of_unique_trials})")
cleaned_df["Interpolation PC"].describe()

## Step 3.3: Remove Trials That Still Contain NaNs
If a trial still contains NaN data, then it was likely involved in an exceptionally long blink or artifact due to recording error. In that case, we'll flag those trials for exclusion, as they'll be both unreliable and will mess with the smoothing in Step 4:

In [None]:
# Ad-hoc inspection: show the rows whose Timestamp lies between 2533052 and 2533251 (bounds inclusive).
cleaned_df.loc[cleaned_df["Timestamp"].between(2533052, 2533251)]

In [None]:
# --- Step: Identify and Handle Trials with Remaining NaNs After Interpolation ---
# Any trial that still has a NaN in 'Pupil Int' gets 'Exclude Trial' set to True on all of its rows.
print("\n--- Identifying and Handling Trials with Remaining NaNs Before Filtering ---")

trial_identifier_cols = ["Run", "Session", "Block", "Trial"]

# Create the flag column only when it is absent, so existing flags are left as they are.
if "Exclude Trial" not in cleaned_df.columns:
    cleaned_df["Exclude Trial"] = False
    print("Added 'Exclude Trial' column, initialized to False.")

# One row per distinct (Run, Session, Block, Trial) combination.
unique_trials_for_nan_check_df = cleaned_df[trial_identifier_cols].drop_duplicates()

trials_with_nans_count = 0
total_samples_in_nan_trials = 0

# 'Pupil Int' is not modified inside the loop, so its NaN mask can be computed once.
pupil_int_is_missing = cleaned_df["Pupil Int"].isna()

for index, trial_info in unique_trials_for_nan_check_df.iterrows():
    current_run, current_session, current_block, current_trial_num = (
        trial_info[col] for col in trial_identifier_cols
    )

    trial_mask = (
        (cleaned_df["Run"] == current_run)
        & (cleaned_df["Session"] == current_session)
        & (cleaned_df["Block"] == current_block)
        & (cleaned_df["Trial"] == current_trial_num)
    )

    # Nothing to do for trials without any remaining NaN.
    if not pupil_int_is_missing[trial_mask].any():
        continue

    trials_with_nans_count += 1
    total_samples_in_nan_trials += trial_mask.sum()

    # Mark every row of the trial for exclusion
    cleaned_df.loc[trial_mask, "Exclude Trial"] = True


if trials_with_nans_count == 0:
    print("No trials found with remaining NaNs in 'Pupil Int'. All trials are ready for filtering.")
else:
    print(f"Identified {trials_with_nans_count} trials with remaining NaNs in 'Pupil Int'.")
    print(f"These {total_samples_in_nan_trials} samples in these trials will be marked for exclusion (or handled as per your choice).")
    print("Example trials marked for exclusion:")
    print(cleaned_df[cleaned_df["Exclude Trial"] == True][trial_identifier_cols].drop_duplicates().head())

## Step 4: Smoothing the data
The paper recommends running the data through a zero-phase low-pass filter with 4Hz cutoff:

In [None]:
# Ad-hoc inspection: show the rows whose Timestamp lies between 2533052 and 2533251 (bounds inclusive).
cleaned_df.loc[cleaned_df["Timestamp"].between(2533052, 2533251)]

In [None]:
# --- Step: Low-Pass Filtering (e.g., Butterworth) --- #
print("\n--- Applying Zero-Phase Low-Pass Butterworth Filter ---")

# Filter Parameters
CUTOFF_FREQUENCY_HZ = 4.0
FILTER_ORDER = 4
SAMPLING_RATE_HZ = 1000  # Sampling rate assumed for the recording; determines the Nyquist frequency below.

# Design the Butterworth filter (digital design takes the cutoff as a fraction of Nyquist).
nyquist_freq = 0.5 * SAMPLING_RATE_HZ
normal_cutoff = CUTOFF_FREQUENCY_HZ / nyquist_freq
b, a = butter(FILTER_ORDER, normal_cutoff, btype='low', analog=False)

# The data to filter is 'Pupil Int'.
# Trials with remaining NaNs in 'Pupil Int' are marked by 'Exclude Trial = True'
# and we'll create 'Pupil Filtered' for all rows, setting it to NaN for these excluded trials

# Create the 'Pupil Filtered' column, initialized with NaN data and later filled only for non-excluded trials.
cleaned_df["Pupil Filtered"] = np.nan

# Working copy of 'Pupil Int' with excluded trials blanked out, so filtfilt only ever sees
# valid trial data. It stays a copy in every case (never an alias of the DataFrame column).
temp_pupil_for_filtering = cleaned_df["Pupil Int"].copy()
if cleaned_df["Exclude Trial"].any():
    temp_pupil_for_filtering.loc[cleaned_df["Exclude Trial"] == True] = np.nan
    print(f"Temporarily NaN'd Pupil Int for {cleaned_df['Exclude Trial'].sum()} samples from excluded trials before filtering.")

if temp_pupil_for_filtering.isna().any():
    print(f"Warning: {temp_pupil_for_filtering.isna().sum()} NaNs present in the data to be filtered (likely from excluded trials).")
    print("  The simplified filtfilt expects no NaNs within a processing segment.")
    print("  The apply_filtfilt_simplified_safe function will handle NaN segments by returning NaNs.")


# Simplified safe application of filtfilt. It ensures that if a segment passed to it contains NaNs, it returns NaNs.
# filtfilt pads each end of the signal with padlen samples and rejects any input whose length is
# not strictly greater than padlen. Its default padlen is 3 * max(len(a), len(b)), i.e. 15 for this
# 4th-order design, so the shortest filterable segment is 16 samples (not FILTER_ORDER * 3 + 1 = 13).
MIN_SAMPLES_FOR_FILTER = 3 * max(len(a), len(b)) + 1

def apply_filtfilt_simplified_safe(series, b_coeffs, a_coeffs):
    """Zero-phase filter a NaN-free pandas Series, degrading gracefully otherwise.

    Parameters
    ----------
    series : pd.Series
        Signal segment to filter.
    b_coeffs, a_coeffs : array-like
        Numerator and denominator coefficients of the IIR filter (e.g. from ``butter``).

    Returns
    -------
    pd.Series
        * the input unchanged if it is entirely NaN;
        * an all-NaN Series on the same index if NaNs and values are mixed;
        * the input unchanged if it is too short for ``filtfilt``'s default padding;
        * otherwise the filtered signal, carried on the input's index and name.
    """
    if series.isna().all(): # If the entire series segment (e.g., a session where all trials were excluded) is NaN
        return series # Return the all-NaN series
    if series.isna().any(): # If there are any NaNs mixed in (shouldn't happen if pre-processing is correct)
        print("Error: apply_filtfilt_simplified_safe received a series with mixed NaNs and non-NaNs. This indicates an issue in pre-filtering NaN handling.")
        # Fallback: return NaNs for the whole series to avoid errors, but this is a sign of a problem.
        return pd.Series(np.nan, index=series.index, dtype=series.dtype)

    # filtfilt pads with 3 * max(len(a), len(b)) samples by default and needs strictly more input
    # samples than that. Deriving the threshold from the coefficients keeps it correct for any order.
    min_samples = 3 * max(len(a_coeffs), len(b_coeffs)) + 1
    if len(series) < min_samples: # The series is NaN-free here, so its length is the count of valid samples
        print(f"Warning: Segment too short ({len(series)} non-NaN samples) for filtfilt. Returning original data for this segment.")
        return series

    # filtfilt returns a bare ndarray; wrap it so every branch returns a Series aligned to the input index.
    filtered_values = filtfilt(b_coeffs, a_coeffs, series.to_numpy(dtype=float))
    return pd.Series(filtered_values, index=series.index, name=series.name)

# Apply the filter one session at a time (explicit loop over the Session groups),
# so that filtfilt never runs across a session boundary or across NaN'd (excluded) samples.
per_session_results = []
for session_id, session_rows in cleaned_df.groupby("Session"):
    session_data_to_filter = temp_pupil_for_filtering.loc[session_rows.index]
    is_missing = session_data_to_filter.isna()

    if is_missing.all():
        # Nothing valid in this session (e.g., all trials excluded): leave it all NaN.
        session_result = session_data_to_filter
    elif not is_missing.any():
        # Fully valid session: filter it in one go, provided it is long enough for filtfilt.
        if len(session_data_to_filter) >= MIN_SAMPLES_FOR_FILTER:
            session_result = filtfilt(b, a, session_data_to_filter)
        else:
            print(f"Info: Entire session {session_id} data too short ({len(session_data_to_filter)} samples). Not filtering.")
            session_result = session_data_to_filter
    else:
        # Mixed session: some trials are valid, others were NaN'd (excluded).
        # Filter each contiguous run of valid samples independently.
        session_result = pd.Series(np.nan, index=session_data_to_filter.index, dtype=float)
        has_value = session_data_to_filter.notna()

        # A new run id starts wherever the valid/missing state flips.
        run_ids = has_value.diff().ne(0).cumsum()

        for _, data_block in session_data_to_filter[has_value].groupby(run_ids[has_value]):
            if len(data_block) < MIN_SAMPLES_FOR_FILTER:
                print(f"Info: Non-NaN sub-segment in session {session_id} too short ({len(data_block)} samples). Not filtering this part.")
                # Keep the unfiltered values for very short valid runs.
                session_result.loc[data_block.index] = data_block
                continue
            session_result.loc[data_block.index] = filtfilt(b, a, data_block)

    # Re-attach the session's index (filtfilt yields a bare array on the fully valid path).
    per_session_results.append(pd.Series(session_result, index=session_rows.index))

if not per_session_results:
    print("No data processed for filtering.")
else:
    cleaned_df["Pupil Filtered"] = pd.concat(per_session_results).sort_index()


# Report how many NaNs remain, alongside the number expected from excluded trials when that flag exists.
nans_after_filtering = cleaned_df["Pupil Filtered"].isna().sum()
has_remaining_nans = nans_after_filtering > 0
if "Exclude Trial" in cleaned_df.columns:
    expected_nans_from_excluded = cleaned_df["Exclude Trial"].sum()
    if has_remaining_nans:
        print(f"NaNs in 'Pupil Filtered' after filtering: {nans_after_filtering} (expected around {expected_nans_from_excluded} from excluded trials).")
elif has_remaining_nans:
    print(f"NaNs in 'Pupil Filtered' after filtering: {nans_after_filtering}.")


print("--- Butterworth Filtering Complete ---")

In [None]:
# --- Visual Check ---
# Plot a +/- 2000-sample window of 'Pupil Int' vs 'Pupil Filtered' for up to three sessions.
if not cleaned_df.empty:
    num_plot_samples_viz = min(3, len(cleaned_df["Session"].unique()))
    plot_indices_viz = []  # Index labels (in cleaned_df) on which each plotted window is centred

    # Try to pick segments from non-excluded trials if possible
    valid_for_plot_df = cleaned_df
    if "Exclude Trial" in cleaned_df.columns:
        valid_for_plot_df = cleaned_df[cleaned_df["Exclude Trial"] == False]

    if valid_for_plot_df.empty and "Exclude Trial" in cleaned_df.columns:
        print("No non-excluded trials to pick for visualization. Plotting from original df if available.")
        valid_for_plot_df = cleaned_df # Fallback to all data if no non-excluded

    if not valid_for_plot_df.empty:
        for session_id in valid_for_plot_df["Session"].unique()[:num_plot_samples_viz]:
            session_data = valid_for_plot_df[valid_for_plot_df["Session"] == session_id]
            if len(session_data) > 4000: # Ensure session is long enough for a good window
                # Centre on the session's own middle row. session_data keeps cleaned_df's index labels,
                # so the label can be used directly. (The previous code used the row's position inside
                # session_data as a position into cleaned_df, which pointed at the wrong rows.)
                plot_indices_viz.append(session_data.index[len(session_data) // 2])
            elif not session_data.empty:
                plot_indices_viz.append(session_data.index[0])

        if not plot_indices_viz:
             if len(cleaned_df) > 4000:
                plot_indices_viz.append(cleaned_df.index[len(cleaned_df)//2])
             elif not cleaned_df.empty:
                plot_indices_viz.append(cleaned_df.index[0])

        fig, axes = plt.subplots(len(plot_indices_viz), 1, figsize=(15, 5 * len(plot_indices_viz)), sharex=False)
        if len(plot_indices_viz) == 1: axes = [axes]  # plt.subplots returns a bare Axes for a single panel

        for i, center_idx in enumerate(plot_indices_viz):
            ax = axes[i]
            # Convert the centre label to a position, then clamp the window to the frame's bounds.
            center_iloc = cleaned_df.index.get_loc(center_idx)
            plot_start_iloc = max(0, center_iloc - 2000)
            plot_end_iloc = min(len(cleaned_df) - 1, center_iloc + 2000)
            segment_df_viz = cleaned_df.iloc[plot_start_iloc:plot_end_iloc + 1]

            if not segment_df_viz.empty:
                ax.plot(segment_df_viz["Timestamp"], segment_df_viz["Pupil Int"], label="Pupil Interpolated", color="skyblue", alpha=0.7)
                ax.plot(segment_df_viz["Timestamp"], segment_df_viz["Pupil Filtered"], label=f"Pupil Filtered ({CUTOFF_FREQUENCY_HZ}Hz)", color="red")
                # Title with the session of the centre sample; the window's first sample may belong to the previous session.
                current_session_id_viz = cleaned_df.loc[center_idx, "Session"]
                ax.set_title(f"Filtering Example (Session {current_session_id_viz}, around Timestamp {cleaned_df.loc[center_idx, 'Timestamp']})")
                ax.set_xlabel("Timestamp")
                ax.set_ylabel("Pupil Size")
                ax.legend(loc="lower center")
                ax.grid(True)
            else:
                ax.text(0.5, 0.5, "Could not get segment for plotting", ha='center')
        plt.tight_layout()
        plt.show()
    else:
        print("No suitable data for filter visualization.")

In [None]:
print("\n--- Plotting Original (Interpolated) vs. Filtered Pupil Data for the Entire Run ---")

if cleaned_df.empty:
    print("DataFrame is empty. Cannot generate plot.")
else:
    plt.figure(figsize=(20, 8)) # Adjust figure size as needed

    # Plot 1: Pupil Int (data after interpolation, before this Butterworth filter)
    # We plot this with some transparency to see the filtered line clearly.
    # If there are still NaNs in Pupil Int (e.g., from very long gaps not interpolated),
    # they will appear as breaks in this line.
    plt.plot(cleaned_df["Timestamp"], cleaned_df["Pupil Int"],
             label="Pupil Data (After Interpolation)", color="skyblue", alpha=0.7, linewidth=1.5)

    # Plot 2: Pupil Filtered (data after Butterworth filter)
    # This line should be smoother. NaNs here correspond to samples whose filter input was NaN
    # (e.g., trials excluded before filtering).
    plt.plot(cleaned_df["Timestamp"], cleaned_df["Pupil Filtered"],
             label=f"Pupil Data (Filtered - {CUTOFF_FREQUENCY_HZ}Hz Butterworth)", color="red", linewidth=1)

    # Add vertical lines for session boundaries if you have multiple sessions in `cleaned_df`
    # and want to visualize them.
    if "Session" in cleaned_df.columns and cleaned_df["Session"].nunique() > 1:
        # A boundary is any row whose Session differs from the previous row's.
        # The first row's diff is NaN -> 0, so it is never flagged.
        session_changes = cleaned_df["Session"].diff().fillna(0) != 0
        session_start_timestamps = cleaned_df.loc[session_changes, "Timestamp"]
        first_timestamp = cleaned_df["Timestamp"].iloc[0]

        # Label only the first line drawn so the legend gets a single 'Session Start' entry.
        # (Indexing the second boundary with .iloc[1] raised IndexError when there were exactly
        # two sessions, i.e. only one boundary.)
        boundary_label = 'Session Start'
        for ts in session_start_timestamps:
            if ts == first_timestamp: # Don't draw for the very first sample
                continue
            plt.axvline(x=ts, color='gray', linestyle='--', linewidth=1, label=boundary_label)
            boundary_label = None


    plt.title(f"Pupil Data: Interpolated vs. Butterworth Filtered ({CUTOFF_FREQUENCY_HZ}Hz) - Entire Run")
    plt.xlabel("Timestamp (ms)")
    plt.ylabel("Pupil Size")
    plt.legend(loc="upper right")
    plt.grid(True, linestyle=':', alpha=0.6)

    # You might want to adjust x-axis limits if the run is extremely long
    # to focus on specific parts, or let matplotlib auto-scale.
    # Example: plt.xlim(cleaned_df["Timestamp"].min(), cleaned_df["Timestamp"].min() + 60000) # First minute

    plt.tight_layout()
    plt.show()

# --- Optional: Plotting each session separately for better detail ---
# One subplot per session, laid out in at most two columns, sharing the y-axis across panels.
has_session_column = "Session" in cleaned_df.columns
if has_session_column and cleaned_df["Session"].nunique() > 1:
    print("\n--- Plotting each session separately ---")
    num_sessions = cleaned_df["Session"].nunique()

    # Grid shape: a single row for up to two sessions, otherwise two columns and as many rows as needed.
    if num_sessions <= 2:
        n_cols, n_rows = num_sessions, 1
    else:
        n_cols = 2
        n_rows = int(np.ceil(num_sessions / float(n_cols)))

    fig_sessions, axes_sessions = plt.subplots(n_rows, n_cols, figsize=(10 * n_cols, 6 * n_rows), sharey=True, squeeze=False)
    axes_sessions_flat = axes_sessions.flatten()

    # zip() stops at the shorter of the two, so a session is never drawn without an axis.
    axes_used = 0
    for ax, (session_id, session_df) in zip(axes_sessions_flat, cleaned_df.groupby("Session")):
        timestamps = session_df["Timestamp"]
        ax.plot(timestamps, session_df["Pupil Int"],
                label="Interpolated", color="skyblue", alpha=0.7, linewidth=1.5)
        ax.plot(timestamps, session_df["Pupil Filtered"],
                label=f"Filtered ({CUTOFF_FREQUENCY_HZ}Hz)", color="red", linewidth=1)

        ax.set_title(f"Session {session_id}")
        ax.set_xlabel("Timestamp (ms)")
        ax.set_ylabel("Pupil Size")
        ax.legend(loc="upper right")
        ax.grid(True, linestyle=':', alpha=0.6)
        axes_used += 1

    # Remove the grid cells that received no session.
    for spare_ax in axes_sessions_flat[axes_used:]:
        fig_sessions.delaxes(spare_ax)

    plt.tight_layout()
    plt.show()
elif not has_session_column:
    print("No 'Session' column found, cannot plot per session.")

# Create Baselines:
Baselines are needed to standardise the pupil's responses across trials, thereby accounting for and controlling factors such as changes in light and environment, which may otherwise skew the data when comparing timecourses.

I will take the final 200ms of the prior trial (-200ms relative to stimulus onset) as the trial's baseline. However, this means that the first trial of each session needs to be handled differently: for it, I shall use the final 500ms of the adaptation period to form the baseline. I chose a longer period to average over in the hope that it would be more consistent with the interpolated, smoothed pupil data, but honestly it's just a shot in the dark.

In [None]:
if PRODUCE_BASELINES:
    # Computes a per-trial baseline from 'Pupil Filtered' and writes three columns to cleaned_df:
    #   'Baseline'        - the baseline value used for the trial,
    #   'Pupil Normed'    - 'Pupil Filtered' minus the baseline,
    #   'Pupil Normed PC' - percent change relative to the baseline.

    # Parameters
    BASELINE_DURATION_MS = 200 # Your desired duration (e.g., -200ms to -1ms from stim onset)

    # Initialize new columns
    cleaned_df["Baseline"] = np.nan
    cleaned_df["Pupil Normed"] = np.nan      # For subtractive normalization
    cleaned_df["Pupil Normed PC"] = np.nan   # For percent change normalization

    # 1. Load adaptation baselines (for the first trial of each session), which were saved by the EDF Reader notebook
    # NOTE(review): pickle.load can execute arbitrary code - only load baseline files produced by this pipeline.
    with open(f"data/processed_data/baselines/{PARTICIPANT_ID}/run_{run}_baselines.pkl", 'rb') as f:
        adaptation_baselines = pickle.load(f) # Expected format: {'Session 1': value, 'Session 2': value, ...}

    trial_identifier_cols = ["Run", "Session", "Block", "Trial"]

    # Get unique trials, sorted to process them chronologically. We also need 'Timestamp' to find trial onsets
    unique_trials_df = cleaned_df[trial_identifier_cols + ["Timestamp"]].groupby(
        trial_identifier_cols, as_index=False
    ).agg(
        Trial_Onset_Timestamp=('Timestamp', 'min') # Get the min timestamp for each trial as its onset
    ).sort_values(by=trial_identifier_cols).reset_index(drop=True)

    print(f"Found {len(unique_trials_df)} unique trials for baseline calculation.")

    # The exclusion flag column is named "Exclude Trial" (with a space) everywhere else in this notebook.
    # The previous lookup of "Exclude_Trial" never matched, so excluded trials were never recognised here.
    has_exclusion_flag = "Exclude Trial" in cleaned_df.columns

    for index, trial_info in unique_trials_df.iterrows():
        current_run_num = trial_info["Run"] # Using a different var name to avoid conflict with 'run' if it's global
        current_session = trial_info["Session"]
        current_block = trial_info["Block"]
        current_trial_num = trial_info["Trial"]
        trial_onset_timestamp = trial_info["Trial_Onset_Timestamp"]

        # Create a mask to select all rows for the current trial in the main DataFrame
        trial_mask_main_df = (cleaned_df["Run"] == current_run_num) & \
                             (cleaned_df["Session"] == current_session) & \
                             (cleaned_df["Block"] == current_block) & \
                             (cleaned_df["Trial"] == current_trial_num)

        baseline_mean = np.nan

        # Check if this trial was marked for exclusion (if that column exists)
        is_excluded_trial = False
        if has_exclusion_flag:
            # Check the first row of the trial for the exclusion flag
            trial_exclusion_flags = cleaned_df.loc[trial_mask_main_df, "Exclude Trial"]
            if not trial_exclusion_flags.empty:
                # "== True" keeps a missing (NaN) flag from being treated as an exclusion.
                is_excluded_trial = bool(trial_exclusion_flags.iloc[0] == True)

        if is_excluded_trial:
            # Informational only: the baseline is still attempted below.
            # If an excluded trial has all NaNs in Pupil Filtered, the baseline will be NaN.
            # If it has some data, the baseline will be calculated, but the trial is still excluded.
            print(f"Info: Trial Run {current_run_num}, Sess {current_session}, Block {current_block}, Trial {current_trial_num} is marked for exclusion. Baseline will be NaN or calculated based on available data if not fully NaN.")


        if current_block == 1 and current_trial_num == 1: # First trial of a session
            # No preceding trial exists, so use the stored adaptation-period baseline for this session.
            baseline_key = f"Session {current_session}"
            if baseline_key in adaptation_baselines:
                baseline_mean = adaptation_baselines[baseline_key]
            else:
                print(f"  Warning: No adaptation baseline found for '{baseline_key}'. Baseline for this trial will be NaN.")
        else:
            # Baseline from the specified period before the current trial's onset timestamp
            baseline_start_ts = trial_onset_timestamp - BASELINE_DURATION_MS
            baseline_end_ts = trial_onset_timestamp - 1 # Up to, but NOT including, the trial onset

            # Extract 'Pupil Filtered' data from this baseline period
            baseline_period_data = cleaned_df[
                (cleaned_df["Timestamp"] >= baseline_start_ts) &
                (cleaned_df["Timestamp"] <= baseline_end_ts)
            ]["Pupil Filtered"]

            if not baseline_period_data.empty and not baseline_period_data.isna().all():
                baseline_mean = baseline_period_data.mean()
            elif baseline_period_data.empty:
                print(f"  Warning: Baseline period empty for S{current_session} B{current_block} T{current_trial_num} (Timestamps {baseline_start_ts}-{baseline_end_ts}). Baseline will be NaN.")
            else: # Not empty, but all NaNs
                 print(f"  Warning: Baseline period data is all NaNs for S{current_session} B{current_block} T{current_trial_num} (Timestamps {baseline_start_ts}-{baseline_end_ts}). Baseline will be NaN.")


        # If a baseline_mean was successfully determined (not NaN)
        if pd.notna(baseline_mean):
            cleaned_df.loc[trial_mask_main_df, "Baseline"] = baseline_mean

            # Get the 'Pupil Filtered' data for the current trial
            current_trial_Pupil_Filtered = cleaned_df.loc[trial_mask_main_df, "Pupil Filtered"]

            # Subtractive Normalization
            cleaned_df.loc[trial_mask_main_df, "Pupil Normed"] = current_trial_Pupil_Filtered - baseline_mean

            # Percent Change Normalization
            # Avoid division by zero
            if baseline_mean != 0:
                cleaned_df.loc[trial_mask_main_df, "Pupil Normed PC"] = ((current_trial_Pupil_Filtered - baseline_mean) / baseline_mean) * 100
            else:
                # A zero baseline makes percent change undefined, so leave it as NaN.
                cleaned_df.loc[trial_mask_main_df, "Pupil Normed PC"] = np.nan
                if not is_excluded_trial: # Don't warn for trials that are already problematic
                     print(f"  Warning: Baseline is zero for S{current_session} B{current_block} T{current_trial_num}. 'Pupil Normed PC' set to NaN.")
        else:
            # If baseline_mean is NaN, the normed columns will also remain NaN (as initialized)
            if not is_excluded_trial: # Don't warn for trials that are already problematic
                print(f"  Failed to determine baseline for S{current_session} B{current_block} T{current_trial_num}. Normed values will be NaN.")


    print("\n--- Baseline Calculation and Normalization Complete ---")

    # --- Verification (Optional) ---
    print("\nExample trials with baseline and normed values:")
    cols_to_show = trial_identifier_cols + ["Baseline", "Pupil Filtered", "Pupil Normed", "Pupil Normed PC"]
    # Show a few rows from a couple of trials to see the effect
    example_display_df = pd.concat([
        cleaned_df[cleaned_df["Trial"] == 1].head(3), # First trial (likely uses adaptation baseline)
        cleaned_df[cleaned_df["Trial"] == 2].head(3)  # Second trial (uses calculated baseline)
    ]).drop_duplicates(subset=trial_identifier_cols, keep='first') # Show one representative from each

    if not example_display_df.empty:
        print(example_display_df[cols_to_show])
    else:
        # If the above specific trials don't exist, just show some head() data
        print(cleaned_df[cols_to_show].head(10).drop_duplicates(subset=trial_identifier_cols, keep='first'))


    # Check for any trials where baseline might still be NaN
    nan_baseline_trials = cleaned_df[cleaned_df["Baseline"].isna()][trial_identifier_cols].drop_duplicates()
    if not nan_baseline_trials.empty:
        print(f"\nWarning: {len(nan_baseline_trials)} unique trials still have NaN for 'Baseline'. Review warnings above.")
        # print(nan_baseline_trials.head())

# Reject Trials:

In [None]:
def produce_pre_processed_csv(analysis_df, rejected_trials, filepath):
    """Write `analysis_df` to CSV with the rejected trials either dropped or flagged.

    Relies on two notebook-level globals: `trial_identifier_cols` (the columns that
    identify a trial) and `DECONVOLUTION_OUTPUT` (selects the output mode).

    Parameters
    ----------
    analysis_df : pd.DataFrame
        Sample-level data; must contain the `trial_identifier_cols` columns.
    rejected_trials : set of tuple
        Trial identifiers, one tuple per trial, ordered as `trial_identifier_cols`.
    filepath : str
        Destination CSV path.

    Behaviour
    ---------
    * `DECONVOLUTION_OUTPUT` False: rows belonging to rejected trials are removed.
    * `DECONVOLUTION_OUTPUT` True: every row is kept and a boolean column
      'Trial_Rejected_Least_Strict' marks rows of rejected trials.

    A summary of trial counts is printed and the result is saved to `filepath`.
    The index is written only when it is not a default RangeIndex. Returns None.
    """
    # rejected_trials is a set of tuples: (Run, Session, Block, Trial)
    print(f"\nTotal unique trials identified for potential rejection/flagging: {len(rejected_trials)}")

    # Start with a full copy. This will be modified.
    output_df = analysis_df.copy()

    if rejected_trials:
        # Convert the set of rejected trial tuples into a DataFrame for merging.
        # De-duplicate first: tuples containing NaN are distinct set members (NaN != NaN),
        # yet pandas matches NaN keys to each other in a merge, so a repeated NaN-keyed trial
        # would otherwise duplicate every matching sample row in the left join below.
        df_trials_to_flag_or_reject = pd.DataFrame(
            list(rejected_trials), columns=trial_identifier_cols
        ).drop_duplicates()

        # Merge analysis_df with the trials to flag/reject.
        # This adds an indicator column '_merge'.
        #   '_merge' == 'both': The trial was in df_trials_to_flag_or_reject.
        #   '_merge' == 'left_only': The trial was only in analysis_df (i.e., not flagged for rejection).
        merged_df = pd.merge(
            output_df,
            df_trials_to_flag_or_reject,
            on=trial_identifier_cols,
            how='left',  # Keep all rows from output_df
            indicator=True
        )

        if not DECONVOLUTION_OUTPUT:
            # Physically drop the rejected trials:
            # keep only rows where '_merge' is 'left_only' (i.e., trials not in rejected_trials)
            output_df = merged_df[merged_df['_merge'] == 'left_only'].drop(columns=['_merge'])
            action_taken = "rejected and dropped"
        else:
            # Keep all trials and derive the flag column from the merge indicator.
            output_df = merged_df.copy()
            output_df['Trial_Rejected_Least_Strict'] = (output_df['_merge'] == 'both')
            output_df = output_df.drop(columns=['_merge']) # Clean up the temporary indicator
            action_taken = "flagged in 'Trial_Rejected_Least_Strict' column"

        print(f"  Trials matching criteria have been {action_taken}.")
    else:
        print("  No trials matched rejection/flagging criteria.")
        if DECONVOLUTION_OUTPUT: # Ensure the flag column exists even if no trials are flagged
             output_df['Trial_Rejected_Least_Strict'] = False


    # --- Verify and Report Final Counts (of trials in the output_df) ---
    # For DECONVOLUTION_OUTPUT=True, total_trials_final should be same as total_trials_initial.
    total_trials_initial = len(analysis_df[trial_identifier_cols].drop_duplicates())
    total_trials_final = len(output_df[trial_identifier_cols].drop_duplicates())

    print(f"\nInitial unique trials in input analysis_df: {total_trials_initial}")
    if not DECONVOLUTION_OUTPUT: # If trials were dropped
        print(f"Trials remaining in output_df after rejection: {total_trials_final}")
        if total_trials_initial > 0:
            rejected_count = total_trials_initial - total_trials_final
            print(f"Number of unique trials rejected: {rejected_count}")
            print(f"Percentage of trials rejected: {(rejected_count / total_trials_initial) * 100:.2f}%")
        else:
            print("No trials to begin with in input analysis_df.")
    else: # If DECONVOLUTION_OUTPUT is True, trials are flagged, not dropped
        flagged_trials_count = output_df[output_df['Trial_Rejected_Least_Strict'] == True][trial_identifier_cols].drop_duplicates().shape[0]
        print(f"Total unique trials in output_df (all kept): {total_trials_final}")
        print(f"Number of unique trials flagged as 'Trial_Rejected_Least_Strict': {flagged_trials_count}")
        if total_trials_initial > 0:
             print(f"Percentage of trials flagged: {(flagged_trials_count / total_trials_initial) * 100:.2f}%")
        else:
            print("No trials to begin with in input analysis_df.")

    output_df.to_csv(filepath, index=(not isinstance(output_df.index, pd.RangeIndex))) # Save index if it's not default
    print(f"Output saved to: {filepath}")

In [None]:
# First, we need to add information about our epoch onsets and offsets to the dataframe for each trial:

trial_identifier_cols = ["Run", "Session", "Block", "Trial"]
epoch_info_cols = ["Trial Main Stim Onset", "Trial Epoch Start", "Trial Epoch End"]

# --- Get Main Stimulus Onset Timestamps for Each Trial ---
# Make sure to include trial identifiers for merging
main_stim_onsets_df = test_messages_df[test_messages_df["Message Type"] == "Main Stimulus Onset"][trial_identifier_cols + ["Timestamp"]].copy()

# Ensure one onset per trial (in case of duplicate messages)
main_stim_onsets_df = main_stim_onsets_df.drop_duplicates(subset=trial_identifier_cols, keep="first")
main_stim_onsets_df = main_stim_onsets_df.rename(columns={"Timestamp": "Trial Main Stim Onset"})

# --- Calculate Epoch Start and End Timestamps for Each Trial ---
# config.EPOCH_START / config.EPOCH_END are added to the stimulus onset timestamp as offsets.
main_stim_onsets_df["Trial Epoch Start"] = main_stim_onsets_df["Trial Main Stim Onset"] + config.EPOCH_START
main_stim_onsets_df["Trial Epoch End"] = main_stim_onsets_df["Trial Main Stim Onset"] + config.EPOCH_END

# --- Merge Epoch Information into cleaned_df ---
# This adds 'Trial Main Stim Onset', 'Trial Epoch Start', 'Trial Epoch End' to each sample based on its trial identifiers.
# Any epoch columns left over from a previous run of this cell are dropped first; otherwise the merge
# would create '_x'/'_y' suffixed duplicates and the lookups below would raise KeyError on re-run.
cleaned_df = pd.merge(cleaned_df.drop(columns=epoch_info_cols, errors="ignore"),
                      main_stim_onsets_df[trial_identifier_cols + epoch_info_cols],
                      on=trial_identifier_cols,
                      how='left') # Use 'left' to keep all rows from cleaned_df


# --- Create a Boolean Column 'Is In Epoch' ---
# This flags each sample if its Timestamp falls within its trial's defined epoch
cleaned_df["Is In Epoch"] = (
    (cleaned_df["Timestamp"] >= cleaned_df["Trial Epoch Start"]) &
    (cleaned_df["Timestamp"] <= cleaned_df["Trial Epoch End"])
)

# --- Create a Boolean Column 'Is Before Epoch':
cleaned_df["Is Before Epoch"] = cleaned_df["Timestamp"] < cleaned_df["Trial Epoch Start"]

# Handle cases where Trial Epoch Start might be NaN (if a trial somehow missed its onset message)
cleaned_df["Is In Epoch"] = cleaned_df["Is In Epoch"].fillna(False)

# Single quotes inside the f-string keep this cell parseable on Python versions before 3.12.
print(f"Number of samples within an epoch: {cleaned_df['Is In Epoch'].sum()}")


In [None]:
# Essential rejection criteria - the ones considered absolutely necessary.
# Each criterion contributes (Run, Session, Block, Trial) tuples to a shared set,
# which the stricter cells further down keep extending.
trials_to_reject_set = set()
os.makedirs(f"data/fully_preprocessed_data/{participant_id}", exist_ok=True)

# Criterion 1: the trial contained data that could not be interpolated.
excluded_by_nan_trials = cleaned_df.loc[cleaned_df["Exclude Trial"] == True, trial_identifier_cols].drop_duplicates()
trials_to_reject_set.update(tuple(row) for _, row in excluded_by_nan_trials.iterrows())
print(f"Trials to reject due to un-interpolatable NaNs: {len(excluded_by_nan_trials)}")

# Criterion 2: a blink occurred during the main stimulus presentation period.
blink_on_stim_trials = cleaned_df.loc[cleaned_df["Blink On Main Stim"] == True, trial_identifier_cols].drop_duplicates()
trials_to_reject_set.update(tuple(row) for _, row in blink_on_stim_trials.iterrows())
print(f"Trials to reject due to blink on main stimulus: {len(blink_on_stim_trials)}")

# Criterion 3: the trial's interpolated share reaches the allowed threshold (x% or more interpolated data).
high_interp_trials = cleaned_df.loc[cleaned_df["Interpolation PC"] >= config.INTERPOLATION_PC_THRESHOLD, trial_identifier_cols].drop_duplicates()
trials_to_reject_set.update(tuple(row) for _, row in high_interp_trials.iterrows())
print(f"Trials to reject due to high interpolation ({config.INTERPOLATION_PC_THRESHOLD}%+): {len(high_interp_trials)}")

# The output file name depends on the mode; the deconvolution mode additionally saves the message rows.
if DECONVOLUTION_OUTPUT:
    least_strict_output_path = f"data/fully_preprocessed_data/{participant_id}/{run}_deconvolution_input.csv"
else:
    least_strict_output_path = f"data/fully_preprocessed_data/{participant_id}/{run}_least_strict.csv"

produce_pre_processed_csv(
    analysis_df=cleaned_df,
    rejected_trials=trials_to_reject_set,
    filepath=least_strict_output_path,
)

if DECONVOLUTION_OUTPUT:
    test_messages_df.to_csv(f"data/fully_preprocessed_data/{participant_id}/{run}_messages.csv")

In [None]:
if not DECONVOLUTION_OUTPUT:
    # Recommended (but not strictly necessary) rejection criteria, added on top of the essential ones.

    # A saccade occurred during the main stimulus presentation period:
    saccade_on_stim_trials = cleaned_df.loc[cleaned_df["Saccade On Main Stim"] == True, trial_identifier_cols].drop_duplicates()
    trials_to_reject_set.update(tuple(row) for _, row in saccade_on_stim_trials.iterrows())
    print(f"Trials to reject due to saccade on main stimulus: {len(saccade_on_stim_trials)}")

    # A target appeared before or during the epoch of interest.
    # samples_in_epoch_df holds every sample flagged as inside its epoch or before it.
    in_or_before_epoch = cleaned_df["Is In Epoch"].eq(True) | cleaned_df["Is Before Epoch"].eq(True)
    samples_in_epoch_df = cleaned_df[in_or_before_epoch]
    target_in_epoch_trials_df = samples_in_epoch_df.loc[samples_in_epoch_df["Target Status"] == True, trial_identifier_cols].drop_duplicates()
    trials_to_reject_set.update(tuple(row) for _, row in target_in_epoch_trials_df.iterrows())
    print(f"Trials to reject due to a target appearing before or within the epoch ({config.EPOCH_START}-{config.EPOCH_END}ms from onset): {len(target_in_epoch_trials_df)}")

    produce_pre_processed_csv(
        analysis_df=cleaned_df,
        rejected_trials=trials_to_reject_set,
        filepath=f"data/fully_preprocessed_data/{participant_id}/{run}_mid_strict.csv",
    )

In [None]:
if not DECONVOLUTION_OUTPUT:
    # Strictest rejection criteria - arguably optional - added on top of everything rejected so far.
    # Uses samples_in_epoch_df built by the mid-strict cell (samples inside or before each trial's epoch).

    # Any blink before or during the epoch of interest:
    blink_in_epoch_trials_df = samples_in_epoch_df.loc[samples_in_epoch_df["Blink"] == True, trial_identifier_cols].drop_duplicates()
    trials_to_reject_set.update(tuple(row) for _, row in blink_in_epoch_trials_df.iterrows())
    print(f"Trials to reject due to any blink occurring before or within the epoch ({config.EPOCH_START}-{config.EPOCH_END}ms from onset): {len(blink_in_epoch_trials_df)}")

    # Any saccade before or during the epoch of interest:
    saccade_in_epoch_trials_df = samples_in_epoch_df.loc[samples_in_epoch_df["Saccade"] == True, trial_identifier_cols].drop_duplicates()
    trials_to_reject_set.update(tuple(row) for _, row in saccade_in_epoch_trials_df.iterrows())
    print(f"Trials to reject due to any saccade occurring before or within the epoch ({config.EPOCH_START}-{config.EPOCH_END}ms from onset): {len(saccade_in_epoch_trials_df)}")

    produce_pre_processed_csv(
        analysis_df=cleaned_df,
        rejected_trials=trials_to_reject_set,
        filepath=f"data/fully_preprocessed_data/{participant_id}/{run}_most_strict.csv",
    )

In [None]:
# Display the column labels of cleaned_df after all preprocessing steps above.
cleaned_df.columns