### Functions and script to construct a data frame of trial information from raw Bonsai outputs

**Inputs:**
- .csv files containing raw harp data
    - poke_events.csv, containing timestamps of nose pokes to port 0 and port 1
    - sound_events.csv, containing timestamps of sound events
    - photodiode_data.csv, containing the photodiode signal (AnalogInput0)
- experimental-data.csv, in the 'Experimental-data' subdirectory, containing the trial-level behavioural data output by the Bonsai workflow.

**Key outputs**
- trials_df: a data frame summarising behavioural events in each trial, including harp timestamps for dot onset and offset, nose pokes, and audio cue onset and offset. Note that this data frame is deliberately redundant, so that trial information can be cross-checked between this script and the Bonsai output.

**Overview** 
1. Read in raw data.
2. Align key behavioural events to trials in a pandas data frame in which each row represents one trial. 
3. Append harp data frame to behavioural summary data frame containing trial-level information.


In [None]:
# Import main libraries and define data folder
import harp
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

#==============================================================================

# Choose example session to analyze
animal_ID = 'FNT099'
session_ID = '2024-05-13T11-03-59'

# Path to raw data on Ceph
raw_data_dir = "W:\\projects\\FlexiVexi\\raw_data"
#==============================================================================

# Create reader for behavior.
bin_b_path = os.path.join(raw_data_dir, animal_ID, session_ID, "Behavior.harp")
behavior_reader = harp.create_reader(bin_b_path)

# Specify mapping from sound index to reward port (This shouldn't change unless 
# you reprogramme the soundcard!)
soundIdx0 = 14
soundIdx1 = 10
soundOffIdx = 18

# Folder of intermediate harp data (.csv files), inside the session folder of the raw data directory
harp_data_dir = os.path.join(raw_data_dir, animal_ID, session_ID, "harp_data")

## Part 1: Read in raw data

**Get all poke events**

In [None]:
# Read in poke events .csv as a pandas dataframe
poke_events_filename = animal_ID + '_' + session_ID + '_' + 'poke_events.csv'
poke_events_filepath = os.path.join(harp_data_dir, poke_events_filename)
poke_events = pd.read_csv(poke_events_filepath)
poke_events

**Parse photodiode data**

In [None]:
# Read in photodiode data .csv as a Series
photodiode_filename = animal_ID + '_' + session_ID + '_' + 'photodiode_data.csv'
photodiode_filepath = os.path.join(harp_data_dir, photodiode_filename)
photodiode_data = pd.read_csv(photodiode_filepath)

# Set 'Time' as the index and extract 'AnalogInput0' as a Series
photodiode_data = photodiode_data.set_index('Time')['AnalogInput0']

photodiode_data

**Get all audio events**

In [None]:
# Read in sound events .csv as a pandas dataframe
sound_events_filename = animal_ID + '_' + session_ID + '_' + 'sound_events.csv'
sound_events_filepath = os.path.join(harp_data_dir, sound_events_filename)
sound_events = pd.read_csv(sound_events_filepath)
sound_events

**Get data frame with Bonsai-triggered event timestamps and trial logic**

This step is required to construct the trial information data frame downstream.

We also transform the data frame into a format that is more useful for analysis, as follows:

1. Append animal and session ID to trials_df

2. Reparameterise TrialCompletionCode into a more useful format:
    - CorrectTrial
    - AbortTrial
    - ChoicePort
    - CorrectPort

3. Rename 'DotOnsetTime' and 'DotOffsetTime' to 'DotOnsetTrigger' and 'DotOffsetTrigger'. This distinguishes the timestamps at which Bonsai triggered the dot to appear or disappear from the true dot onset/offset times, which are extracted from the photodiode data downstream. The trigger timestamps serve only as a sanity check, because our method of extracting dot onset/offset times from the photodiode signal is not completely robust.
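To make step 2 concrete, here is a minimal single-trial sketch of how a `TrialCompletionCode` string is decoded, mirroring the string slicing used in `reparameterise_TrialCompletionCode` below. The helper `decode_completion_code` is hypothetical, and the example codes (`'RewardedNosepoke1'`, `'AbortedTrial_0'`, etc.) are illustrative guesses inferred from the slicing rules (the last character is read as the port ID), not codes taken from the Bonsai workflow.

```python
# Mapping from audio cue ID to correct port, copied from
# reparameterise_TrialCompletionCode (10 -> port 0, 14 -> port 1)
CUE_TO_PORT = {10: 0, 14: 1}

def decode_completion_code(code, audio_cue_identity):
    """Decode one TrialCompletionCode string (hypothetical single-trial helper).

    Returns (CorrectTrial, AbortTrial, ChoicePort, CorrectPort) with the same
    conventions as reparameterise_TrialCompletionCode: AbortTrial is 1 for an
    aborted nose poke, -1 for a dot time-limit abort and 0 otherwise; ChoicePort
    and CorrectPort are -1 when undefined.
    """
    correct_trial = code[:-1] == 'RewardedNosepoke'
    abort_trial = 0
    if code[:-2] == 'AbortedTrial':
        abort_trial = 1
    elif code[:-2] == 'DotTimeLimitReached':
        abort_trial = -1
    # The last character of the code is read as the port ID
    choice_port = int(code[-1])
    if choice_port == 2:  # port 2 is flagged as an aborted trial
        abort_trial = 1
    if abort_trial != 0:  # no choice port for aborted trials
        choice_port = -1
    correct_port = CUE_TO_PORT.get(audio_cue_identity, -1)
    return correct_trial, abort_trial, choice_port, correct_port

print(decode_completion_code('RewardedNosepoke1', 14))  # (True, 0, 1, 1)
```

The data-frame version below applies these same rules column-wise with `.apply` and `.loc`, which is what allows the results to be stored as new columns of `trials_df`.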

In [None]:
def get_experimental_data(root_dir):
    """
    Recursively searches for the 'experimental-data.csv' file within the given root directory.

    Args:
        root_dir (str): The root directory to start the search from.

    Returns:
        str: The full path to the 'experimental-data.csv' file if found, otherwise None.
    """
    for root, dirs, files in os.walk(root_dir):
        for file in files:
            if file.endswith("experimental-data.csv"):
                return os.path.join(root, file)
    return None

def reparameterise_TrialCompletionCode(trial_data):
    """
    Parses the TrialCompletionCode into a more useful parameterization.

    Parameters:
    trial_data (pd.DataFrame): DataFrame containing trial data with the following columns:
        - 'TrialCompletionCode' (str): Code indicating the outcome of the trial.
        - 'AudioCueIdentity' (int): Identifier for the audio cue used in the trial.
        - 'TrialNumber' (int): Identifier for the trial number.
        - 'Animal_ID' (str): Identifier for the animal.
        - 'Session_ID' (str): Identifier for the session.

    Returns:
    pd.DataFrame: The modified DataFrame with additional columns:
        - 'CorrectTrial' (bool): Indicates if the trial was correct.
        - 'AbortTrial' (int): Indicates the type of aborted trial (1 for nosepoke, -1 for dot offset, 0 otherwise).
        - 'ChoicePort' (int): The chosen port ID for non-aborted trials.
        - 'CorrectPort' (int): The correct port ID based on the audio cue identity.
    
    Example usage:
    trial_data = reparameterise_TrialCompletionCode(trial_data)
    """
    
    # Extract logical value for whether trial was correct
    trial_data['CorrectTrial'] = trial_data['TrialCompletionCode'].apply(lambda x: x[:-1] == 'RewardedNosepoke')

    # Distinguish between abort types:
    # annotate aborted nose pokes with 1 and dot time-limit (dot offset) aborts with -1
    trial_data['AbortTrial'] = 0
    abort_nosepoke = trial_data['TrialCompletionCode'].apply(lambda x: x[:-2] == 'AbortedTrial')
    abort_dot_offset = trial_data['TrialCompletionCode'].apply(lambda x: x[:-2] == 'DotTimeLimitReached')
    trial_data.loc[abort_nosepoke, 'AbortTrial'] = 1
    trial_data.loc[abort_dot_offset, 'AbortTrial'] = -1

    # Extract chosen port ID for all non-aborted trials
    trial_data['ChoicePort'] = trial_data['TrialCompletionCode'].apply(lambda x: int(x[-1]))
    trial_data.loc[trial_data['ChoicePort'] == 2, 'AbortTrial'] = 1  # Flag trials with ChoicePort 2 as aborted trials
    trial_data.loc[trial_data['AbortTrial'] != 0, 'ChoicePort'] = np.nan  # Omit choice port for all aborted trials

    # Extract correct port ID
    trial_data['CorrectPort'] = np.nan
    trial_data.loc[trial_data['AudioCueIdentity'] == 10, 'CorrectPort'] = 0
    trial_data.loc[trial_data['AudioCueIdentity'] == 14, 'CorrectPort'] = 1

    # Convert columns to int, filling NaN with -1
    trial_data['AbortTrial'] = trial_data['AbortTrial'].astype(int)
    trial_data['ChoicePort'] = trial_data['ChoicePort'].fillna(-1).astype(int)
    trial_data['CorrectPort'] = trial_data['CorrectPort'].fillna(-1).astype(int)

    # Reorder variables
    cols = trial_data.columns.tolist()
    cols.insert(cols.index('TrialNumber'), cols.pop(cols.index('Animal_ID')))
    cols.insert(cols.index('TrialNumber'), cols.pop(cols.index('Session_ID')))
    cols.insert(cols.index('TrialCompletionCode') + 1, cols.pop(cols.index('ChoicePort')))
    cols.insert(cols.index('TrialCompletionCode') + 1, cols.pop(cols.index('AbortTrial')))
    cols.insert(cols.index('TrialCompletionCode') + 1, cols.pop(cols.index('CorrectTrial')))
    cols.insert(cols.index('ChoicePort') + 1, cols.pop(cols.index('CorrectPort')))
    trial_data = trial_data[cols]

    return trial_data

# Example usage
# trial_data = reparameterise_TrialCompletionCode(trial_data)

# Import behavioral data as data frame
experimental_data_filepath = get_experimental_data(
    os.path.join(
        raw_data_dir, 
        animal_ID, 
        session_ID
    )
)
trials_df = pd.read_csv(experimental_data_filepath)

# Rename 'DotOnsetTime' and 'DotOffsetTime' to 'DotOnsetTrigger' and 'DotOffsetTrigger'
trials_df = trials_df.rename(
    columns={
        'DotOnsetTime': 'DotOnsetTrigger', 
        'DotOffsetTime': 'DotOffsetTrigger'
    }
)

# Add animal_ID and session_ID to trials_df
trials_df['Animal_ID'] = animal_ID
trials_df['Session_ID'] = session_ID

# Reparameterise TrialCompletionCode into more useful variables
trials_df = reparameterise_TrialCompletionCode(trials_df)
trials_df[['Animal_ID', 'Session_ID', 'TrialNumber', 'CorrectTrial', 'AbortTrial', 'ChoicePort', 'CorrectPort']]

## Part 2: Align each data stream into trials

**Align poke events with trials**

Plot poke events relative to trial start times:

In [None]:
import matplotlib.pyplot as plt

def plot_nose_pokes_and_trial_starts(poke_events, trials_df, start_trial_idx, end_trial_idx):
    """
    Plots nose pokes and trial start times between specified trial indices.

    Parameters:
    poke_events (pd.DataFrame): DataFrame containing poke events with the following columns:
        - 'DIPort0' (bool): Indicates if a nose poke occurred at port 0.
        - 'DIPort1' (bool): Indicates if a nose poke occurred at port 1.
        - 'Time' (float): Timestamp of the nose poke event.
    trials_df (pd.DataFrame): DataFrame containing trial start times with the following column:
        - 'TrialStart' (float): Timestamp of the trial start.
    start_trial_idx (int): Starting index of the trial to plot.
    end_trial_idx (int): Ending index of the trial to plot.

    Returns:
    None: This function does not return any value. It displays a plot.
    
    Example usage:
    plot_nose_pokes_and_trial_starts(poke_events, trials_df, 20, 30)
    """
    # Filter poke events for DIPort0 or DIPort1 being True
    poke_events_filtered = poke_events[(poke_events['DIPort0'] == True) | (poke_events['DIPort1'] == True)]

    # Create a Series with the timestamp as the index and the port ID as the value
    port_id_series = poke_events_filtered.apply(lambda row: 0 if row['DIPort0'] else 1, axis=1)
    port_id_series.index = poke_events_filtered['Time']

    # Plot the port ID Series
    fig, ax = plt.subplots(figsize=(12, 2))

    # Set y-ticks to show labels
    ax.set_yticks([0, 1])
    ax.set_yticklabels(['Port 0', 'Port 1'])

    # Scatter plot for port choices
    ax.scatter(port_id_series.index, port_id_series.values, label='Port Choices')

    # Mark trial start times
    ax.vlines(trials_df['TrialStart'], 0, 1, colors='g', label='Trial Starts')

    # Set plot labels
    ax.set_xlabel('timestamp (s)')
    ax.set_ylabel('port ID')
    ax.set_title('Port ID vs. timestamp')

    # Set x-axis limits to the start and end of the selected trials
    ax.set_xlim(trials_df['TrialStart'].iloc[start_trial_idx], trials_df['TrialStart'].iloc[end_trial_idx])

    # Add legend
    ax.legend(loc='center left', bbox_to_anchor=(1, 0.5))

    # Show plot
    plt.show()

# Example usage
plot_nose_pokes_and_trial_starts(poke_events, trials_df, 20, 30)

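Get a data frame listing, for each trial, the timestamps and port IDs of all nose pokes between that trial's start and the next trial's start (or up to 100 s after the last trial start). The port choice for each trial is taken as the first nose poke within the response window (between dot offset and trial end); if the trial is aborted, the port ID and timestamp are both taken as NaN.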
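The core of this alignment is assigning each event to the most recent trial start. A minimal vectorised sketch uses `np.searchsorted` instead of the per-trial filtering loop below; `assign_events_to_trials` is a hypothetical helper. One design difference: the loop below uses closed intervals `[start, next_start]`, so an event exactly at a trial boundary is counted in both trials, whereas this sketch uses half-open intervals and assigns every event to at most one trial (events before the first trial start get -1).

```python
import numpy as np

def assign_events_to_trials(event_times, trial_start_times, last_trial_duration=100.0):
    """Return, for each event, the index of the trial it falls in (-1 if none).

    Each event is assigned to the latest trial start <= event time, i.e. trials
    are half-open intervals [start_i, start_{i+1}). Events more than
    `last_trial_duration` seconds after the final trial start are unassigned,
    mirroring the 100 s window used for the last trial in this notebook.
    """
    event_times = np.asarray(event_times, dtype=float)
    starts = np.asarray(trial_start_times, dtype=float)  # must be non-empty and sorted ascending
    trial_idx = np.searchsorted(starts, event_times, side='right') - 1
    # Events beyond the window of the last trial are not assigned to any trial
    too_late = (trial_idx == len(starts) - 1) & (event_times > starts[-1] + last_trial_duration)
    trial_idx[too_late] = -1
    return trial_idx

# Toy example (made-up times)
starts = [0.0, 10.0, 20.0]
events = [-1.0, 0.0, 5.0, 10.0, 19.9, 25.0, 150.0]
idx = assign_events_to_trials(events, starts)
pokes_per_trial = np.bincount(idx[idx >= 0], minlength=len(starts))
print(idx, pokes_per_trial)
```

With the trial index in hand, per-trial lists like those in `trial_pokes_df` can be built with, for example, `pd.Series(times).groupby(idx).agg(list)`.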

In [None]:
def parse_trial_pokes(trial_start_times, poke_events):
    """
    Parses nose poke events within each trial and returns a DataFrame in which each
    row gives the timestamps and port IDs of all nose pokes that occurred within
    that trial.

    Args:
        trial_start_times (pd.Series): Series of trial start times.
        poke_events (pd.DataFrame): columns:
            - Time: Timestamp of the event.
            - DIPort0: Boolean in which a value changing from false to true indicates a 
            nose poke into port 0, and vice versa indicates a nosepoke out of port 0.
            - DIPort1: Boolean in which a value changing from false to true indicates a 
            nose poke into port 1, and vice versa indicates a nosepoke out of port 1.

    Returns:
        pd.DataFrame: DataFrame containing nose poke events for each trial.
    """
    num_trials = len(trial_start_times)
    NosePokeIn = [[] for _ in range(num_trials)]
    NosePokeOut = [[] for _ in range(num_trials)]
    PortID = [[] for _ in range(num_trials)]
    NumPokes = [0] * num_trials

    # Iterate through trial start times and extract data from harp stream
    for i, start_time in enumerate(trial_start_times):
        if i < num_trials - 1:
            end_time = trial_start_times.iloc[i + 1]
        else:
            end_time = start_time + 100  # 100 seconds after the last trial start time

        # Extract events that occur within the time range of this trial
        trial_events = poke_events[(poke_events['Time'] >= start_time) & (poke_events['Time'] <= end_time)]

        # Create lists for nose pokes within trial
        NosePokeIn_trial, NosePokeOut_trial, PortID_trial = [], [], []

        for _, nosePokeEvent in trial_events.iterrows():
            # Get the timestamp of the event (either a nose poke in or out of a port).
            # Use the 'Time' column: the row label (.name) is only the integer row index.
            event_time = nosePokeEvent['Time']

            # Nose poke into port 0
            if nosePokeEvent.DIPort0:
                NosePokeIn_trial.append(event_time)
                PortID_trial.append(0)

            # Nose poke into port 1
            elif nosePokeEvent.DIPort1:
                NosePokeIn_trial.append(event_time)
                PortID_trial.append(1)

            # Nose poke out of port 0 or port 1
            elif not nosePokeEvent.DIPort0 and not nosePokeEvent.DIPort1:
                NosePokeOut_trial.append(event_time)

        NosePokeIn[i] = NosePokeIn_trial
        NosePokeOut[i] = NosePokeOut_trial
        PortID[i] = PortID_trial
        NumPokes[i] = len(NosePokeIn_trial)

    trial_pokes_df = pd.DataFrame({
        'NosePokeIn': NosePokeIn,
        'NosePokeOut': NosePokeOut,
        'PortID': PortID,
        'NumPokes': NumPokes
    })

    return trial_pokes_df

trial_pokes_df = parse_trial_pokes(trials_df['TrialStart'], poke_events)

# Plot histogram of NumPokes
fig, ax = plt.subplots(figsize=(6, 4))
trial_pokes_df['NumPokes'].hist(ax=ax)
ax.set_xlabel('Number of pokes')
ax.set_ylabel('Frequency')
ax.set_title('Histogram of number of pokes per trial')
plt.show()

trial_pokes_df.head()
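The port choice (the first nose poke in the response window) can be read off the list-valued columns of `trial_pokes_df`. A minimal sketch, assuming the response window opens at the Bonsai `DotOffsetTrigger` timestamp, that this timestamp is on the same clock as the harp poke times, and that each trial's poke list is in time order; `first_poke_after` is a hypothetical helper, not part of the pipeline above, and the toy values are made up.

```python
import numpy as np
import pandas as pd

def first_poke_after(poke_times, port_ids, window_start):
    """Return (time, port) of the first nose poke at or after window_start.

    Hypothetical helper: poke_times and port_ids are one trial's lists from
    trial_pokes_df['NosePokeIn'] and trial_pokes_df['PortID'] (assumed sorted
    by time). Returns (nan, nan) if there is no poke in the window or
    window_start is NaN (e.g. for aborted trials).
    """
    if pd.isna(window_start):
        return np.nan, np.nan
    for t, port in zip(poke_times, port_ids):
        if t >= window_start:
            return t, port
    return np.nan, np.nan

# Toy per-trial lists (made-up numbers)
toy = pd.DataFrame({
    'NosePokeIn': [[1.0, 3.5, 4.0], [12.0], []],
    'PortID': [[0, 1, 0], [1], []],
    'DotOffsetTrigger': [3.0, 15.0, 22.0],
})
choices = [
    first_poke_after(t, p, w)
    for t, p, w in zip(toy['NosePokeIn'], toy['PortID'], toy['DotOffsetTrigger'])
]
print(choices)
```

Because each list already stops at the next trial start, the upper bound of the response window (trial end) is implicit here.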

**Align sound events to trials**

In [None]:
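def parse_trial_sounds(trial_start_times, sound_events, OFF_index=18):
    """
    Parses sound events within each trial and returns a DataFrame in which each row
    lists the onset times, offset times and sound IDs of all audio cues in that trial.

    Args:
        trial_start_times (pd.Series): Series of trial start times.
        sound_events (pd.DataFrame): columns:
            - Time: Timestamp of the event.
            - PlaySoundOrFrequency: Sound index sent to the soundcard.
        OFF_index (int): Sound index that marks an audio cue offset (default 18).

    Returns:
        pd.DataFrame: DataFrame with columns 'AudioCueStartTimes', 'AudioCueEndTimes'
        and 'AudioCueIdentities'.
    """
    # Create lists to store the sound IDs and timestamps for all trials
    ON_S, OFF_S, ID_S = [], [], []

    # Iterate through trial start times and extract data from harp stream
    for i, start_time in enumerate(trial_start_times):
        if i < len(trial_start_times) - 1:
            end_time = trial_start_times.iloc[i + 1]
        else:
            end_time = start_time + 100  # 100 seconds after the last trial start time

        # Extract events that occur within the time range of this trial
        trial_events = sound_events[(sound_events.Time >= start_time) & (sound_events.Time <= end_time)]

        # Create lists for the sounds in this trial
        ON, OFF, ID = [], [], []
        for _, sound in trial_events.iterrows():
            event_time = sound['Time']
            sound_id = int(sound['PlaySoundOrFrequency'])

            # Any index other than OFF_index is a cue onset, so record its ID;
            # OFF_index marks a cue offset, for which only the time is recorded
            if sound_id != OFF_index:
                ON.append(event_time)
                ID.append(sound_id)
            else:
                OFF.append(event_time)

        ON_S.append(ON)
        OFF_S.append(OFF)
        ID_S.append(ID)

    # Create data frame from all sound events
    trial_sounds_df = pd.DataFrame({'AudioCueStartTimes': ON_S, 'AudioCueEndTimes': OFF_S, 'AudioCueIdentities': ID_S})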

    return trial_sounds_df

# Get data frame with sound IDs and timestamps for each trial
trial_sounds_df = parse_trial_sounds(trials_df['TrialStart'], sound_events, OFF_index=soundOffIdx)

# Append sound ID to trials_df
trials_df = pd.concat([trials_df, trial_sounds_df],axis=1)

# Show sound data frame
trial_sounds_df.head()
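Since `trials_df` now holds both the Bonsai-logged `AudioCueIdentity` and the harp-derived `AudioCueIdentities`, the redundancy mentioned in the key outputs can be used as a consistency check. A minimal sketch, assuming the first cue onset recorded by harp in each trial is the cue that Bonsai logged; `audio_ids_consistent` is a hypothetical helper and the toy frame is made up.

```python
import pandas as pd

def audio_ids_consistent(trials_df):
    """Return a boolean Series: True where Bonsai's AudioCueIdentity matches the
    first harp-recorded cue onset ID in the trial.

    Trials with no recorded harp onset get NaN as their harp ID and therefore
    compare as False, so they show up for inspection.
    """
    first_harp_id = trials_df['AudioCueIdentities'].apply(
        lambda ids: ids[0] if len(ids) > 0 else float('nan')
    )
    return first_harp_id == trials_df['AudioCueIdentity']

# Toy example: trial 1 matches, trial 2 mismatches, trial 3 has no harp onset
toy_trials = pd.DataFrame({
    'AudioCueIdentity': [14, 10, 14],
    'AudioCueIdentities': [[14], [14], []],
})
print(audio_ids_consistent(toy_trials).tolist())
```

On real data, `trials_df[~audio_ids_consistent(trials_df)]` would list the trials worth inspecting.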

**Save trials_df dataframe for further analysis**

In [None]:
# save trials_df as a .pkl file to be used for further analysis
# Note: output_dir is not defined in this notebook and must be set before uncommenting
# session_output_dir = os.path.join(output_dir, animal_ID, session_ID)
# trials_df.to_pickle(os.path.join(session_output_dir, animal_ID + '_' + session_ID + '_trial_data_harp.pkl'))
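The save step uses a pickle rather than a CSV. One likely reason is that `trials_df` contains list-valued columns (e.g. `AudioCueStartTimes`). A CSV stores only their string representation, whereas a pickle round trip restores them as Python lists. The sketch below demonstrates this on a made-up toy frame, writing to a temporary directory rather than the project output folder.

```python
import os
import tempfile
import pandas as pd

# Toy frame with a list-valued column (made-up values)
toy_df = pd.DataFrame({'TrialNumber': [1, 2], 'AudioCueStartTimes': [[0.5, 1.5], []]})

with tempfile.TemporaryDirectory() as tmp_dir:
    pkl_path = os.path.join(tmp_dir, 'toy_trial_data.pkl')
    csv_path = os.path.join(tmp_dir, 'toy_trial_data.csv')
    toy_df.to_pickle(pkl_path)
    toy_df.to_csv(csv_path, index=False)
    from_pickle = pd.read_pickle(pkl_path)
    from_csv = pd.read_csv(csv_path)

# The pickle restores real lists; the CSV returns strings such as '[0.5, 1.5]'
print(type(from_pickle.loc[0, 'AudioCueStartTimes']))  # <class 'list'>
print(type(from_csv.loc[0, 'AudioCueStartTimes']))     # <class 'str'>
```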

**Parse photodiode data into dot onset, dot offset and fail state events within each trial**

Parse photodiode data into a time series with 3 distinct states:
- 0 indicates a resting state
- 1 indicates a state in which the dot is projected
- 2 indicates a fail state (in which the arena lights are on)

1. Apply the dot onset threshold (`dot_on_threshold`, set to 3000 AU in the code below) to the raw signal to distinguish state 0 from states 1 and 2.
2. Average the signal over each instance of states 1 and 2, and distinguish them with a second, fail-state threshold (`fail_state_threshold`, 3922 AU).

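**Note:** the outputs of this method must be checked in stage 5, where the signal can go straight from state 1 to state 2. Both states exceed the dot onset threshold, so step 1 merges them into a single segment, and that segment's averaged signal may fall on either side of the fail-state threshold.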
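The two-step procedure and the stage-5 caveat can be illustrated with a short vectorised sketch. `classify_segments` is a hypothetical helper that follows the same two-step logic as `map_photodiode_state` below, but labels runs with `shift`/`cumsum` and uses `groupby` instead of an explicit loop. The toy trace values are made up. In its last dot presentation the signal goes straight from the dot level to the fail level, so both end up in one run whose mean (about 3833) stays below the 3922 fail-state threshold, and the fail state is hidden.

```python
import numpy as np
import pandas as pd

def classify_segments(signal, dot_on_threshold, fail_state_threshold):
    """Hypothetical vectorised sketch of the two-step photodiode classification.

    signal: pd.Series of photodiode values indexed by (unique) timestamp.
    Returns one row per run of samples on the same side of dot_on_threshold,
    labelled 0 (rest), 1 (dot) or 2 (fail).
    """
    # Step 1: threshold each sample (0 = below, 1 = at or above dot_on_threshold)
    above = (signal >= dot_on_threshold).astype(int)
    # A new run starts wherever the thresholded value changes
    run_id = (above != above.shift()).cumsum()
    timestamps = pd.Series(signal.index, index=signal.index)
    segments = pd.DataFrame({
        'timestamp': timestamps.groupby(run_id).first(),
        'state': above.groupby(run_id).first(),
        'AvgPhotodiodeSignal': signal.groupby(run_id).mean(),
    }).reset_index(drop=True)
    # Step 2: runs whose mean exceeds the fail-state threshold become state 2
    segments.loc[segments['AvgPhotodiodeSignal'] > fail_state_threshold, 'state'] = 2
    return segments

# Toy trace: rest, dot, rest, then dot followed directly by fail, then rest
values = [100, 100, 3500, 3500, 100, 100, 3500, 3500, 4000, 4000, 4000, 4000, 100]
trace = pd.Series(values, index=np.arange(len(values)) * 0.5)
segments = classify_segments(trace, dot_on_threshold=3000, fail_state_threshold=3922)
print(segments)
```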

In [None]:
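def get_square_wave(df):
    """
    Converts a table of state changes into points that trace a square wave when plotted.

    Each state is held from its own timestamp until the next state change, so every
    timestamp except the first appears twice: once at the end of the previous state
    and once at the start of the new one.

    Args:
        df (pd.DataFrame): columns 'timestamp' (time of each state change) and 'state'.

    Returns:
        pd.DataFrame: columns 'timestamp' and 'state', ready to plot as a line.
    """
    # Create a new DataFrame with repeated elements
    square_wave = {
        'timestamp': df['timestamp'].repeat(2).tolist()[1:],
        'state': df['state'].repeat(2).tolist()[:-1],
    }
    square_wave = pd.DataFrame(square_wave)
    return square_wave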

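def map_photodiode_state(dot_on_threshold, fail_state_threshold, photodiode_data):
    """
    Segments the photodiode signal into runs of state 0 (rest), 1 (dot) and 2 (fail).

    Args:
        dot_on_threshold (float): Signal level separating state 0 from states 1 and 2.
        fail_state_threshold (float): Mean signal level above which a segment is state 2.
        photodiode_data (pd.Series): Photodiode signal indexed by harp timestamp.

    Returns:
        pd.DataFrame: one row per segment, with columns 'timestamp' (segment start),
        'state' and 'AvgPhotodiodeSignal' (mean signal over the segment).
    """

    # Map values to 0 or 1 based on the threshold using a lambda function
    photodiode_state = photodiode_data.apply(lambda x: 0 if x < dot_on_threshold else 1)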

    # Get indices at which photodiode state changes
    photodiode_state_diff = photodiode_state.diff()
    photodiode_state_change_indices = pd.Index([photodiode_state.index[0]]).union(
        photodiode_state_diff[photodiode_state_diff != 0].index
    )

    # Iterate through the indices of state changes
    rows = []
    for i in range(len(photodiode_state_change_indices) - 1):
        start_idx = photodiode_state_change_indices[i]
        end_idx = photodiode_state_change_indices[i + 1]
        state = photodiode_state[start_idx]
        avg_signal = photodiode_data[start_idx:end_idx].mean()
        rows.append([start_idx, state, avg_signal])

    # Handle the last state change to the end of the series
    start_idx = photodiode_state_change_indices[-1]
    state = photodiode_state[start_idx]
    avg_signal = photodiode_data[start_idx:].mean()
    rows.append([start_idx, state, avg_signal])

    photodiode_state_df = pd.DataFrame(rows, columns=['timestamp', 'state', 'AvgPhotodiodeSignal'])

    # Mark state 2 (fail state) based on a second threshold
    photodiode_state_df['state'] = photodiode_state_df.apply(
        lambda row: 2 if row['AvgPhotodiodeSignal'] > fail_state_threshold else row['state'], axis=1
    )

    return photodiode_state_df

dot_on_threshold = 3000
fail_state_threshold = 3922

photodiode_state_df = map_photodiode_state(dot_on_threshold, fail_state_threshold, photodiode_data)

# Get a Series where the index is Timestamp from df_state_changes and the values are State
photodiode_state = pd.Series(photodiode_state_df['state'].values, index=photodiode_state_df['timestamp'])

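# Choose an example trial to plot: the window runs from its start to the next
# trial's start, padded by 2 seconds on either side
example_trial_idx = 20
t_start = trials_df['TrialStart'].iloc[example_trial_idx]
t_end = trials_df['TrialStart'].iloc[example_trial_idx + 1]

# Plot the photodiode state trace as a square wave within this window
plt.figure(figsize=(20, 4))
plt.xlim(t_start - 2, t_end + 2)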
state_changes_trace = get_square_wave(photodiode_state_df)
plt.plot(state_changes_trace['timestamp'], state_changes_trace['state'], label='Photodiode State')
plt.legend()
plt.show()

photodiode_state_df
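The segment table can then be turned into dot onset and offset times, the stated goal of this step. A minimal sketch, assuming each state-1 segment is a single dot presentation (which the stage-5 note above shows may not hold) and that photodiode timestamps share the harp clock with `TrialStart`. `extract_dot_events` is a hypothetical helper and the toy table is made up. Each onset is then assigned to the latest trial start that precedes it.

```python
import numpy as np
import pandas as pd

def extract_dot_events(photodiode_state_df):
    """Return (onsets, offsets) arrays from a segment table with columns
    'timestamp' (segment start) and 'state' (0 rest, 1 dot, 2 fail).

    Onset: start of a state-1 segment. Offset: start of the segment that follows
    a state-1 segment. A dot segment at the very end of the recording has no offset.
    """
    states = photodiode_state_df['state'].to_numpy()
    times = photodiode_state_df['timestamp'].to_numpy()
    is_dot = states == 1
    onsets = times[is_dot]
    # The dot ends where the next segment begins
    next_idx = np.flatnonzero(is_dot) + 1
    next_idx = next_idx[next_idx < len(times)]
    offsets = times[next_idx]
    return onsets, offsets

# Toy segment table: one dot followed by rest, one dot followed by a fail state
toy_states = pd.DataFrame({
    'timestamp': [0.0, 1.0, 2.0, 3.0, 4.0, 6.0],
    'state': [0, 1, 0, 1, 2, 0],
})
onsets, offsets = extract_dot_events(toy_states)

# Assign each onset to the latest trial start at or before it (toy trial starts)
toy_trial_starts = np.array([0.0, 2.5])
onset_trial_idx = np.searchsorted(toy_trial_starts, onsets, side='right') - 1
print(onsets, offsets, onset_trial_idx)
```

On real data, `photodiode_state_df` and `trials_df['TrialStart']` would take the place of the toy inputs.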