# Eye-tracker data preprocessing <br>
<b>Input</b>: subject-by-subject files with eye position coordinates as provided by the PsychoPy ASRT code; settings file. <br>
    
<b>Output</b>: trial-by-trial data with RTs & anticipatory eye movements; data quality checks; interference files. Important: output files currently use period decimals and are mostly comma-delimited (some data-quality files, e.g. the missing-data ratios, are tab-delimited). If you open them with a Hungarian-locale CSV reader (e.g. Excel), you will need to replace the periods with commas manually at the end.

<br>

Perks: a) it can handle missing data; b) if it stops with an error at any point, you don't have to run everything again (you only re-read the data that is needed).

<b>Instruction: </b> you will need to provide inputs and follow some instructions at the beginning! Please pay attention; I will let you know when you can move on to running the cells mindlessly.

## functions

In [174]:
#functions

import os
import sys
import shutil
import time
import pandas as pd
import warnings
import numpy as np
import math
import glob
import shutil

#general functions

def strToFloat(data):
    """Convert *data* to a float, accepting a comma as the decimal separator.

    The value is first turned into its string form, so numbers, numeric strings
    such as "0,25" or "1.5", and "nan" are all accepted.
    """
    text = str(data)
    normalized = text.replace(",", ".")
    return float(normalized)

def calcRMS(values):
    """Return the root mean square of *values*, i.e. sqrt(mean(v ** 2)).

    Raises ZeroDivisionError when *values* is empty.
    """
    sum_of_squares = 0.0
    for value in values:
        sum_of_squares += value ** 2
    return math.sqrt(sum_of_squares / float(len(values)))

def convertToAngle(value_cm, eye_screen_distance_cm=65.0):
    """Convert an on-screen extent in cm to an angle in degrees.

    Computes ``degrees(atan(value_cm / eye_screen_distance_cm))``.

    Parameters
    ----------
    value_cm : float
        Extent on the screen, in centimetres.
    eye_screen_distance_cm : float, optional
        Viewing distance in centimetres. The default of 65 cm is the optimal
        eye-to-screen distance for the Tobii Pro X3-120 eye tracker and is used as an
        approximation of the actual distance; pass a measured value (for example a
        subject's median eye-screen distance) when one is available.

    Returns
    -------
    float
        The angle in degrees.
    """
    return math.degrees(math.atan(value_cm / eye_screen_distance_cm))

def check_variables_dict(example_file_path, variables_dict):
    """Check that the column names listed in ``variables_dict`` exist in a data file.

    The values stored under the keys ``"stim"``, ``"PCode"``, ``"epoch"``, ``"trial"``
    and ``"ID"`` are treated as column names and looked up in the header of
    ``example_file_path`` (a tab-separated, latin-1 encoded file). A message listing
    any missing names is printed.

    Parameters
    ----------
    example_file_path : str
        Path to one subject's data file.
    variables_dict : dict
        Settings dictionary mapping the keys above to column names.

    Returns
    -------
    list
        The column names that were not found (empty if all are present).

    Raises
    ------
    KeyError
        If one of the keys above is missing from ``variables_dict``.
    """
    # Only the header is needed, so skip parsing the (potentially large) data rows.
    header = pd.read_csv(example_file_path, sep='\t', encoding="latin-1", nrows=0)
    column_names = set(header.columns)

    # Look the names up in the dict that was passed in, not in a global settings dict.
    column_type_values = [variables_dict[key] for key in ("stim", "PCode", "epoch", "trial", "ID")]
    missing_variables = [var for var in column_type_values if var not in column_names]

    if missing_variables:
        print("The following variables in 'variables_dict' are not present in the DataFrame columns and will make your code fail miserably:")
        print(missing_variables)
    else:
        print("All variables in 'variables_dict' are present in the DataFrame columns.")
    return missing_variables

#data validator functions:

def data_structure_validator(vars_dict, example_file_path):
    """Sanity-check one example data file against the design described in ``vars_dict``.

    Prints one message for each of these checks:

    1. The number of distinct values in the epoch column equals ``vars_dict["epochN"]``.
    2. The number of distinct values in the trial column equals
       ``vars_dict["randoms"] + vars_dict["trialN"]``.
    3. If ``vars_dict["IFepoch"]`` is set, the PCode values found in the
       original-sequence epoch (``vars_dict["OSepoch"]``) do not overlap with those
       found in the interference epoch.

    Column names are taken from ``vars_dict["epoch"]``, ``vars_dict["trial"]`` and
    ``vars_dict["PCode"]``.

    Parameters
    ----------
    vars_dict : dict
        Settings dictionary (see above for the keys used).
    example_file_path : str
        Path of a tab-separated subject data file.

    Returns
    -------
    None
    """
    if not os.path.exists(example_file_path):
        print(f"The file {example_file_path} doesn't exist, not in that folder anyways. Use another file as testfile.")
        return
    print(f"Using the file {example_file_path}")

    # latin-1 like every other reader in this file.
    data_table = pd.read_csv(example_file_path, sep='\t', encoding="latin-1", low_memory=False)

    epoch_col = vars_dict["epoch"]
    trial_col = vars_dict["trial"]
    pcode_col = vars_dict["PCode"]

    # len(unique()) (not nunique()) so that a NaN epoch still counts as a value.
    found_epochN = len(data_table[epoch_col].unique())
    if vars_dict["epochN"] != found_epochN:
        print(f'You said there are {vars_dict["epochN"]} epochs, but there are {found_epochN}, you liar!')
    else:
        print("All good, your nr of epochs match with what you specified in the vars_dict.")

    if vars_dict["randoms"] + vars_dict["trialN"] != len(data_table[trial_col].unique()):
        print("Number of trials / epoch does not match with what you said it should be. Check if the vars_dict is wrong or your data.")
    else:
        print("All good, the specified nr of random and sequential trials match the data structure.")

    if_epoch = vars_dict.get("IFepoch")
    if if_epoch is None:
        return

    os_epoch = vars_dict["OSepoch"]
    # Compare sets: comparing the raw unique() arrays with == is ambiguous (or raises)
    # as soon as an epoch has zero or several PCodes.
    os_codes = set(data_table.loc[data_table[epoch_col] == os_epoch, pcode_col].dropna().unique())
    if_codes = set(data_table.loc[data_table[epoch_col] == if_epoch, pcode_col].dropna().unique())

    if not os_codes or not if_codes:
        empty_epoch = os_epoch if not os_codes else if_epoch
        print(f"I could not find any PCode in epoch {empty_epoch}. Check the OSepoch / IFepoch values in the vars_dict.")
    elif os_codes & if_codes:
        print("I found the same PCode in the epochs you specified as original sequence and interference sequence epochs. U sure you typed it right?")
    else:
        print("Aaand yepp, the interference and original sequence epochs indeed differ. Good. You are all set.")

def file_validator(indat_path):
    """Report every file under ``indat_path`` whose name does not contain ".txt".

    The folder is walked recursively and one warning is printed per offending file.
    If no such file exists, a confirmation message is printed instead.
    """
    offenders = [
        name
        for _root, _subdirs, names in os.walk(indat_path)
        for name in names
        if ".txt" not in name
    ]
    for name in offenders:
        print(f"{name} has invalid data format (not a .txt). Remove it from the 'indat' folder.")
    if not offenders:
        print("All good, all file formats are valid in the 'indat' folder.")

def subject_dropper(indat_path, subs_with_bad_data, vars_dict, junk_folder=None):
    """Move the files of subjects with the wrong number of epochs into a junk folder.

    Every file in ``indat_path`` named like ``<prefix>_<subject number>_...`` is read,
    and the number of distinct values in its epoch column (``vars_dict["epoch"]``,
    default ``"epoch"``) is compared with ``vars_dict["epochN"]``. Subjects with a
    mismatch are added to ``subs_with_bad_data``, which is modified in place.

    Afterwards, every file belonging to a subject in ``subs_with_bad_data`` is moved to
    the junk folder, which is created if needed. This includes subjects that were
    already in the set when it was passed in.

    Parameters
    ----------
    indat_path : str
        Folder containing the subject data files.
    subs_with_bad_data : set of int
        Subject numbers considered bad; extended in place.
    vars_dict : dict
        Settings dictionary; ``"epochN"`` is required, ``"epoch"`` is optional.
    junk_folder : str, optional
        Destination folder. If omitted, a module-level ``junk_folder`` variable is used
        when one is defined; otherwise a ``junk`` folder next to ``indat_path``.

    Returns
    -------
    None
    """
    if junk_folder is None:
        # Backward compatibility: this function used to read a notebook-level global.
        junk_folder = globals().get("junk_folder")
    if junk_folder is None:
        junk_folder = os.path.join(os.path.dirname(os.path.abspath(indat_path)), "junk")

    epoch_col = vars_dict.get("epoch", "epoch")
    expected_epochN = vars_dict["epochN"]

    sub_with_bad_epochN = False
    for file in os.listdir(indat_path):
        file_path = os.path.join(indat_path, file)
        # Ensure we are working with a file (not a folder).
        if not os.path.isfile(file_path):
            continue
        segments = file.split("_")
        if len(segments) < 2:
            continue
        try:
            sub_number = int(segments[1])
            # Only the epoch column is needed for this check.
            df = pd.read_csv(file_path, sep='\t', usecols=[epoch_col], encoding="latin-1", low_memory=False)
        except (OSError, ValueError) as e:
            # Per-file best effort: report and keep checking the remaining files.
            print(f"Error processing {file}: {e}")
            continue

        n_epochs = len(df[epoch_col].unique())
        if n_epochs != expected_epochN:
            print(f'File {file} has {n_epochs} epochs insted of {expected_epochN}. Marking subject {sub_number} as bad data.')
            subs_with_bad_data.add(sub_number)
            sub_with_bad_epochN = True

    if not sub_with_bad_epochN:
        print("All good, all subjects seem to have sufficient amount of epochs")

    print(f"bad sub list is: {subs_with_bad_data}")
    if subs_with_bad_data:
        os.makedirs(junk_folder, exist_ok=True)

    for file in os.listdir(indat_path):
        segments = file.split("_")
        if len(segments) < 2:
            continue
        try:
            sub_number = int(segments[1])
        except ValueError:
            print(f"Skipping {file}: Cannot extract subject number.")
            continue
        if sub_number in subs_with_bad_data:
            shutil.move(os.path.join(indat_path, file), os.path.join(junk_folder, file))
            print(f"Moved {file} to junk folder.")
        
def test_gaze_coordinate_coding(indat):
    """
    Check that the stimulus positions in the data match what ``getAOI`` expects.

    Processes the first ``.txt`` file (in alphabetical order) in ``indat``. For each
    stimulus, it averages the left-eye gaze coordinates (ADCS columns) at the last
    sample of every trial in which that stimulus was shown.

    Steps:
    1. Reads the first .txt file as a dataframe.
    2. Converts comma decimal separators in the gaze columns to periods and makes them numeric.
    3. Keeps only the last sample of each trial. Trial numbers repeat across blocks, so a
       trial is identified by its epoch, block and trial number (whichever of the epoch
       and block columns exist).
    4. Computes mean X and Y gaze coordinates for each stimulus.
    5. Checks the means against the expected position (relative to 0.5) of stimuli 1-4.

    Parameters:
        indat (str): Path to the folder containing .txt data files.

    Returns:
        dict or None: mean (X, Y) coordinates per stimulus, rounded to 2 decimals, or
        None if no .txt file exists or a required column is missing.
    """
    # Sorted so the choice of "first" file is reproducible (os.listdir order is arbitrary).
    files = sorted(f for f in os.listdir(indat) if f.endswith(".txt"))
    if not files:
        print("No text files found in the folder.")
        return None

    file_path = os.path.join(indat, files[0])
    print(f"Processing file: {files[0]}")

    df = pd.read_csv(file_path, sep='\t', encoding="latin-1", low_memory=False)

    x_col = "left_gaze_data_X_ADCS"
    y_col = "left_gaze_data_Y_ADCS"
    required_cols = ["stimulus", "trial", x_col, y_col]
    if not all(col in df.columns for col in required_cols):
        print("Missing required columns in the data.")
        return None

    # Convert only the gaze columns, replacing comma decimals in string cells.
    for col in (x_col, y_col):
        df[col] = pd.to_numeric(
            df[col].map(lambda v: v.replace(",", ".") if isinstance(v, str) else v),
            errors="coerce",
        )

    trial_keys = [col for col in ("epoch", "block", "trial") if col in df.columns]

    stim_means = {}
    for stim in df["stimulus"].unique():
        stim_df = df[df["stimulus"] == stim]
        # The last sample of a trial is when the stimulus disappeared, so the gaze should
        # be on the stimulus there (otherwise it would not have disappeared).
        last_samples = stim_df.groupby(trial_keys).tail(1)

        # Left eye only; either eye would do for this check.
        x_mean = last_samples[x_col].mean(skipna=True)
        y_mean = last_samples[y_col].mean(skipna=True)
        stim_means[stim] = (round(x_mean, 2), round(y_mean, 2))

    # Expected positions - change these if your stimulus coding differs!
    expected_coords = {
        1: (lambda x, y: x <= 0.5 and y <= 0.5),
        2: (lambda x, y: x >= 0.5 and y <= 0.5),
        3: (lambda x, y: x <= 0.5 and y >= 0.5),
        4: (lambda x, y: x >= 0.5 and y >= 0.5),
    }
    incorrect_coords = {
        stim: coords
        for stim, coords in stim_means.items()
        if stim in expected_coords and not expected_coords[stim](*coords)
    }

    if incorrect_coords:
        print("Stimuli with incorrect average coordinates:")
        for stim, coords in incorrect_coords.items():
            print(f"Stimulus {stim}: X_mean={coords[0]:.3f}, Y_mean={coords[1]:.3f}")
        print("You gotta change your getAOI function, as it expects different coordinates. Change it so it matches the coordinates in your data:")
        print(stim_means)
    else:
        print("All stimuli met expectations, your getAOI function functions.")

    return stim_means


# The name starts with "test_"; keep pytest from collecting this helper as a test.
test_gaze_coordinate_coding.__test__ = False


#data quality calculator functions:

def computeMissingDataRatio(input_dir, output_file, vars_dict):
    """
    Computes the missing data ratio for each subject in a given directory and saves the results to a file.

    NOTE: a later definition of ``computeMissingDataRatio`` in this cell (which writes a
    tab-separated file) overrides this one when the cell is run.

    Parameters:
    -----------
    input_dir : str
        The directory containing input .txt files for each subject. Only files directly
        inside it are processed; subfolders are not visited.
    output_file : str
        The path where the computed missing data ratios are saved as a comma-separated file.
    vars_dict : dict
        A dictionary containing necessary variable definitions:
        - "randoms": trials numbered <= this value (preparatory trials) are ignored.
        - "epochN": Total number of epochs to process.

    Returns:
    --------
    None
        Writes a table with the columns ``epoch`` (label ``subject_<ID>_<epoch>``) and
        ``missing_data_percent`` to ``output_file``.

    Notes:
    ------
    - The subject ID is the second element of the file name when split by ``_``; files
      without one are skipped with a message.
    - Files whose name does not contain ".txt" are skipped with a warning.
    - Subjects for which ``computeMissingDataRatioImpl`` returns None (missing columns)
      are skipped, so labels and values stay aligned.
    """
    blockprepN = vars_dict["randoms"]
    epochN = vars_dict["epochN"]

    missing_data_ratios = []
    subject_epochs = []

    # Top level of input_dir only (equivalent to walking once and breaking).
    root, _dirs, files = next(os.walk(input_dir), (input_dir, [], []))
    for subject_file in files:
        if ".txt" not in subject_file:
            print(f'{subject_file} has invalid data format (not a .txt)')
            continue

        segments = subject_file.split('_')
        if len(segments) < 2:
            print(f"Skipping {subject_file}: cannot extract subject ID from the file name.")
            continue
        subject = segments[1]

        print("Compute missing data ratio for subject: " + subject)
        input_file = os.path.join(root, subject_file)

        result = computeMissingDataRatioImpl(input_file, blockprepN, epochN)
        if result is None:
            # The Impl already printed why. Adding labels without values would misalign
            # the two columns (and `list += None` raises TypeError).
            continue

        subject_epochs.extend(f"subject_{subject}_{i}" for i in range(1, epochN + 1))
        missing_data_ratios.extend(result)

    missing_data = pd.DataFrame({'epoch': subject_epochs, 'missing_data_percent': missing_data_ratios})
    missing_data.to_csv(output_file, index=False)

def computeMissingDataRatioImpl(input, blockprepN, epochN):
    """
    Computes the percentage of missing eye-tracking data per epoch for one subject file.

    Parameters:
    - input (str): Path to the tab-separated input file containing eye-tracking data.
    - blockprepN (int): Trials numbered <= blockprepN (preparatory trials) are excluded.
    - epochN (int): The total number of epochs in the experiment.

    Returns:
    - list of float, or None: element ``e - 1`` is the percentage of missing samples in
      epoch ``e``. Epochs without any counted sample stay 0.0. None is returned if the
      file lacks one of the required columns.

    Notes:
    - A sample is missing when neither the left nor the right gaze is valid.
    - An epoch's missing data ratio is
        (number of missing samples in the epoch / total samples in the epoch) * 100
    - Preparatory trials (``trial <= blockprepN``) are ignored.
    - Calibration validation blocks (``block == '0'``) are also ignored.
    - Epoch numbers outside ``1..epochN`` are reported and ignored.
    """
    required_cols = ['block', 'trial', 'epoch', 'left_gaze_validity', 'right_gaze_validity']

    try:
        data_table = pd.read_csv(input, sep='\t', encoding="latin1", usecols=required_cols)
    except ValueError as e:  # Raised by usecols when a column is missing.
        print(f"Error: Missing required columns for this participant. Skipping to next.")
        print(f"Details: {e}")
        return None  # Lets the caller's loop continue with the next subject.

    epoch_all_data = {}
    epoch_missing_data = {}
    rows = zip(
        data_table["trial"],
        data_table["block"],
        data_table["epoch"],
        data_table["left_gaze_validity"],
        data_table["right_gaze_validity"],
    )
    for trial, block, epoch, left_valid, right_valid in rows:
        # We ignore preparatory trials.
        if int(trial) <= blockprepN:
            continue
        # We ignore calibration validation blocks.
        if str(block) == '0':
            continue

        current_epoch = int(epoch)
        epoch_all_data[current_epoch] = epoch_all_data.get(current_epoch, 0) + 1
        if not bool(left_valid) and not bool(right_valid):
            epoch_missing_data[current_epoch] = epoch_missing_data.get(current_epoch, 0) + 1

    epoch_summary = np.zeros(epochN).tolist()
    for epoch, n_samples in epoch_all_data.items():
        if not 1 <= epoch <= epochN:
            # epoch - 1 would otherwise wrap around (epoch 0) or overflow the list.
            print(f"Warning: epoch {epoch} is outside 1..{epochN}; its data is ignored.")
            continue
        # .get: an epoch with no missing samples has no entry in epoch_missing_data.
        epoch_summary[epoch - 1] = (epoch_missing_data.get(epoch, 0) / n_samples) * 100.0

    return epoch_summary

def computeMissingDataRatio(input_dir, output_file, vars_dict):
    """
    Computes the missing data ratio for each subject in a given directory and saves the results to a file.

    Parameters:
    -----------
    input_dir : str
        The directory containing input .txt files for each subject. Only files directly
        inside it are processed; subfolders are not visited.
    output_file : str
        The path where the computed missing data ratios are saved (tab-separated).
    vars_dict : dict
        A dictionary containing necessary variable definitions:
        - "randoms": trials numbered <= this value (preparatory trials) are ignored.
        - "epochN": Total number of epochs to process.

    Returns:
    --------
    None
        Writes a tab-separated table with the columns ``epoch`` (label
        ``subject_<ID>_<epoch>``) and ``missing_data_percent`` to ``output_file``.

    Notes:
    ------
    - The subject ID is the second element of the file name when split by ``_``; files
      without one are skipped with a message.
    - Files whose name does not contain ".txt" are skipped with a warning.
    - Subjects for which ``computeMissingDataRatioImpl`` returns None (missing columns)
      are skipped, so labels and values stay aligned.
    """
    blockprepN = vars_dict["randoms"]
    epochN = vars_dict["epochN"]

    missing_data_ratios = []
    subject_epochs = []

    # Top level of input_dir only (equivalent to walking once and breaking).
    root, _dirs, files = next(os.walk(input_dir), (input_dir, [], []))
    for subject_file in files:
        if ".txt" not in subject_file:
            print(f'{subject_file} has invalid data format (not a .txt)')
            continue

        segments = subject_file.split('_')
        if len(segments) < 2:
            print(f"Skipping {subject_file}: cannot extract subject ID from the file name.")
            continue
        subject = segments[1]

        print("Compute missing data ratio for subject: " + subject)
        input_file = os.path.join(root, subject_file)

        result = computeMissingDataRatioImpl(input_file, blockprepN, epochN)
        if result is None:
            # The Impl already printed why. Adding labels without values would misalign
            # the two columns (and `list += None` raises TypeError).
            continue

        subject_epochs.extend(f"subject_{subject}_{i}" for i in range(1, epochN + 1))
        missing_data_ratios.extend(result)

    missing_data = pd.DataFrame({'epoch': subject_epochs, 'missing_data_percent': missing_data_ratios})
    missing_data.to_csv(output_file, sep='\t', index=False)

def computeDistanceImpl(input, blockprepN, epochN):
    """Return the per-epoch median eye-to-screen distance (in cm) for one subject file.

    Samples from preparatory trials (``trial <= blockprepN``) and calibration-validation
    blocks (``block == '0'``) are skipped. For every other sample:

    - the eye distance of the valid eye is used, or the mean of both when both are valid;
    - the value is divided by 10 (documented here as mm -> cm);
    - it is collected per epoch, keeping only positive, non-NaN values.

    NOTE: a later ``computeDistanceImpl`` definition in this cell overrides this one.

    Parameters
    ----------
    input : str
        Path of the tab-separated subject data file.
    blockprepN : int
        Trials numbered <= blockprepN are ignored.
    epochN : int
        Total number of epochs (length of the returned list).

    Returns
    -------
    list of float, or None
        Element ``e - 1`` is the median distance of epoch ``e``; 0.0 for epochs without
        usable samples. None if the file lacks one of the required columns.
    """
    required_cols = ['block', 'trial', 'epoch', 'left_gaze_validity', 'right_gaze_validity',
                     'left_eye_distance', 'right_eye_distance']

    try:
        data_table = pd.read_csv(input, sep='\t', encoding="latin1", usecols=required_cols, low_memory=False)
    except ValueError as e:  # Raised by usecols when a column is missing.
        # Report the file path; the previous message used an undefined name (NameError).
        print(f"Error: Missing required columns in {input}. Skipping to next.")
        print(f"Details: {e}")
        return None  # Lets the caller's loop continue with the next subject.

    epoch_distances = {}
    rows = zip(
        data_table["trial"],
        data_table["block"],
        data_table["epoch"],
        data_table["left_gaze_validity"],
        data_table["right_gaze_validity"],
        data_table["left_eye_distance"],
        data_table["right_eye_distance"],
    )
    for trial, block, epoch, left_valid, right_valid, left_dist, right_dist in rows:
        # We ignore preparatory trials.
        if int(trial) <= blockprepN:
            continue
        # We ignore calibration validation blocks.
        if str(block) == '0':
            continue

        # Use the distance from the valid eye; average when both are valid.
        if bool(left_valid) and bool(right_valid):
            distance_mm = (strToFloat(left_dist) + strToFloat(right_dist)) / 2.0
        elif bool(left_valid):
            distance_mm = strToFloat(left_dist)
        elif bool(right_valid):
            distance_mm = strToFloat(right_dist)
        else:
            continue

        # `not > 0.0` also drops NaN distances.
        if not distance_mm > 0.0:
            continue
        epoch_distances.setdefault(int(epoch), []).append(distance_mm / 10.0)

    epoch_summary = np.zeros(epochN).tolist()
    for epoch, distances in epoch_distances.items():
        if not 1 <= epoch <= epochN:
            # epoch - 1 would otherwise wrap around (epoch 0) or overflow the list.
            print(f"Warning: epoch {epoch} is outside 1..{epochN}; its data is ignored.")
            continue
        # Plain floats, consistent with the placeholder 0.0 entries.
        epoch_summary[epoch - 1] = float(np.median(distances))

    return epoch_summary

def computeDistance(input_dir, output_file, vars_dict):
    """Write the per-epoch median eye-to-screen distance of every subject to a file.

    Each ``.txt`` file directly inside ``input_dir`` (subfolders are not visited) is
    processed with ``computeDistanceImpl``. The output is a tab-separated table with one
    row per subject and epoch, with the columns ``epoch`` (label
    ``subject_<ID>_<epoch>``) and ``median_distance_cm``.

    NOTE: a later ``computeDistance`` definition in this cell overrides this one.

    Parameters
    ----------
    input_dir : str
        Folder containing the subject files.
    output_file : str
        Destination path of the tab-separated table.
    vars_dict : dict
        Uses ``"randoms"`` (preparatory trials to ignore) and ``"epochN"``.

    Returns
    -------
    None
    """
    blockprepN = vars_dict["randoms"]
    epochN = vars_dict["epochN"]

    median_distances = []
    subject_epochs = []

    # Top level of input_dir only (equivalent to walking once and breaking).
    root, _dirs, files = next(os.walk(input_dir), (input_dir, [], []))
    for subject_file in files:
        if ".txt" not in subject_file:
            print(f'{subject_file} has invalid data format (not a .txt)')
            continue

        segments = subject_file.split('_')
        if len(segments) < 2:
            print(f"Skipping {subject_file}: cannot extract subject ID from the file name.")
            continue
        subject = segments[1]

        print("Compute eye-screen distance data for subject: " + subject)
        input_file = os.path.join(root, subject_file)

        epoch_medians = computeDistanceImpl(input_file, blockprepN, epochN)
        if epoch_medians is None:
            # Skip failed subjects so labels and values stay aligned (`+= None` would raise).
            continue

        subject_epochs.extend("subject_" + subject + "_" + str(i) for i in range(1, epochN + 1))
        median_distances.extend(epoch_medians)

    distance_data = pd.DataFrame({'epoch': subject_epochs, 'median_distance_cm': median_distances})
    distance_data.to_csv(output_file, sep='\t', index=False)

def computeDistanceImpl(input, blockprepN, epochN):
    """Return the per-epoch median eye-to-screen distance (in cm) for ONE subject file.

    Samples belonging to preparatory trials (``trial <= blockprepN``) or to
    calibration-validation blocks (``block == '0'``) are skipped. For the remaining
    samples:

    - the distance of the valid eye is taken, or the mean of both eyes when both are valid;
    - it is divided by 10 (mm -> cm) and grouped by epoch;
    - only positive, non-NaN distances are kept.

    Returns a list of length ``epochN`` whose element ``e - 1`` is the median for epoch
    ``e`` (0.0 when the epoch has no usable sample), or None if the file lacks one of
    the required columns.
    """
    wanted_columns = ['block', 'trial', 'epoch', 'left_gaze_validity', 'right_gaze_validity',
                      'left_eye_distance', 'right_eye_distance']

    try:
        table = pd.read_csv(input, sep='\t', encoding="latin1", usecols=wanted_columns)
    except ValueError as e:  # usecols raises ValueError for absent columns
        print(f"Error: Missing required columns for this participant. Skipping to next.")
        print(f"Details: {e}")
        return None

    per_epoch = {}
    samples = zip(
        table["trial"],
        table["epoch"],
        table["left_gaze_validity"],
        table["right_gaze_validity"],
        table["left_eye_distance"],
        table["right_eye_distance"],
        table["block"],
    )
    for trial_nr, epoch_nr, left_ok, right_ok, left_mm, right_mm, block_id in samples:
        if int(trial_nr) <= blockprepN or str(block_id) == '0':
            continue

        if bool(left_ok) and bool(right_ok):
            mm = (strToFloat(left_mm) + strToFloat(right_mm)) / 2.0
        elif bool(left_ok):
            mm = strToFloat(left_mm)
        elif bool(right_ok):
            mm = strToFloat(right_mm)
        else:
            continue

        # Written as `not (> 0)` so that NaN distances are skipped too.
        if not (mm > 0.0):
            continue
        per_epoch.setdefault(int(epoch_nr), []).append(mm / 10.0)

    result = np.zeros(epochN).tolist()
    for epoch_nr, cm_values in per_epoch.items():
        result[epoch_nr - 1] = np.median(cm_values)
    return result

def computeDistance(input_dir, output_file, vars_dict):
    """Write the per-epoch median eye-to-screen distance of every subject to a CSV file.

    Each ``.txt`` file directly inside ``input_dir`` (subfolders are not visited) is
    processed with ``computeDistanceImpl``. The output is a comma-separated table with
    one row per subject and epoch, with the columns ``epoch`` (label
    ``subject_<ID>_<epoch>``) and ``median_distance_cm``.

    Parameters
    ----------
    input_dir : str
        Folder containing the subject files.
    output_file : str
        Destination path of the comma-separated table.
    vars_dict : dict
        Uses ``"randoms"`` (preparatory trials to ignore) and ``"epochN"``.

    Returns
    -------
    None

    Notes
    -----
    - Files whose name does not contain ".txt" are skipped with a warning, like the other
      data-quality functions in this file.
    - Subjects whose file lacks a required column (Impl returns None) are skipped, so
      labels and values stay aligned.
    """
    blockprepN = vars_dict["randoms"]
    epochN = vars_dict["epochN"]

    median_distances = []
    subject_epochs = []

    # Top level of input_dir only (equivalent to walking once and breaking).
    root, _dirs, files = next(os.walk(input_dir), (input_dir, [], []))
    for subject_file in files:
        if ".txt" not in subject_file:
            print(f'{subject_file} has invalid data format (not a .txt)')
            continue

        segments = subject_file.split('_')
        if len(segments) < 2:
            print(f"Skipping {subject_file}: cannot extract subject ID from the file name.")
            continue
        subject = segments[1]

        print("Compute eye-screen distance data for subject: " + subject)
        input_file = os.path.join(root, subject_file)

        epoch_medians = computeDistanceImpl(input_file, blockprepN, epochN)
        if epoch_medians is None:
            # Skip failed subjects so labels and values stay aligned (`+= None` would raise).
            continue

        subject_epochs.extend("subject_" + subject + "_" + str(i) for i in range(1, epochN + 1))
        median_distances.extend(epoch_medians)

    distance_data = pd.DataFrame({'epoch': subject_epochs, 'median_distance_cm': median_distances})
    distance_data.to_csv(output_file, index=False)

def calcDistancesForFixation(j, k, data_table):
    """
    Return the angular separations between the left- and right-eye gaze points for rows j..k.

    For every row label ``j`` through ``k`` (inclusive) where both eyes have valid gaze:

    - compute the Euclidean distance between the left and right gaze points in the
      ``*_PCMCS`` coordinate columns;
    - convert that distance to degrees with ``convertToAngle``;
    - append it to the result, but only when the distance is strictly positive (so
      zero and NaN distances are left out).

    NOTE(review): ``convertToAngle`` expects centimetres; this assumes the PCMCS columns
    are in cm. Confirm this.

    Parameters:
    -----------
    j : int
        First row label of the window.
    k : int
        Last row label of the window (inclusive).
    data_table : pd.DataFrame
        Must provide the gaze-validity and ``*_PCMCS`` gaze-position columns.

    Returns:
    --------
    list of float
        One angle per row with valid, positive separation, in row order.
    """
    left_ok = data_table["left_gaze_validity"]
    right_ok = data_table["right_gaze_validity"]
    lx = data_table["left_gaze_data_X_PCMCS"]
    ly = data_table["left_gaze_data_Y_PCMCS"]
    rx = data_table["right_gaze_data_X_PCMCS"]
    ry = data_table["right_gaze_data_Y_PCMCS"]

    angles = []
    for row in range(j, k + 1):
        if not (bool(left_ok[row]) and bool(right_ok[row])):
            continue
        dx = strToFloat(lx[row]) - strToFloat(rx[row])
        dy = strToFloat(ly[row]) - strToFloat(ry[row])
        separation = math.sqrt(dx ** 2 + dy ** 2)
        if separation > 0.0:
            angles.append(convertToAngle(separation))
    return angles

def computeRMSEyeToEyeImpl(input, blockprepN, fixation_duration_threshold, epochN):
    """
    Computes, per epoch, the median over trials of the RMS eye-to-eye (E2E) distance
    during the final samples of each trial.

    Samples from preparatory trials and calibration-validation blocks are skipped. For
    every other trial:

    - the last `fixation_duration_threshold` samples of the trial are passed to
      `calcDistancesForFixation`, which returns the angular distances between the
      left- and right-eye gaze points for samples where both eyes are valid;
    - the RMS of those distances (`calcRMS`) is stored under the trial's epoch.

    For each epoch, the median of its per-trial RMS values is reported.

    Parameters:
    -----------
    input : str
        The path to the tab-separated input file containing gaze data.
    blockprepN : int
        Trials numbered <= blockprepN (preparatory trials) are ignored.
    fixation_duration_threshold : int
        Number of trailing samples of each trial treated as the trial's last fixation.
        NOTE(review): nothing checks that these samples actually form a fixation. If a
        trial has fewer samples than this, the window reaches back into the preceding
        rows (the previous trial). Confirm this is acceptable.
    epochN : int
        The total number of epochs (length of the returned list).

    Returns:
    --------
    list or None
        A list of length `epochN`. Entry `e - 1` holds the median RMS(E2E) of epoch `e`
        as a string, or the float 0.0 if the epoch had no usable trial.
        None if the file lacks one of the required columns.

    Process:
    --------
    - Reads gaze data from a tab-separated file.
    - Iterates over samples, skipping preparatory trials and calibration blocks.
    - At the last sample of each trial, computes distances over the trailing window.
    - Calculates the RMS of the distances for each such trial.
    - Returns the median RMS values for each epoch as a list.

    """
    required_cols = ['block', 'trial', 'epoch', 'left_gaze_validity', 'right_gaze_validity',
                                          'left_gaze_data_X_PCMCS', 'left_gaze_data_Y_PCMCS', 'right_gaze_data_X_PCMCS', 
                                          'right_gaze_data_Y_PCMCS']
    
    try:
        # Read file with only necessary columns
        data_table = pd.read_csv(input, sep='\t', encoding="latin1", usecols=required_cols)
    
    except ValueError as e:  # Catches missing column errors
        print(f"Error: Missing required columns for this participant. Skipping to next.")
        print(f"Details: {e}")
        return None  # Return None to indicate failure, allowing the loop to continue


    trial_column = data_table["trial"]
    block_column = data_table["block"]
    epoch_column = data_table["epoch"]

    # epoch number -> list of per-trial RMS(E2E) values
    epoch_rmss= {}
    # Stops one row early because row i + 1 is read below; the final row is only
    # reached through the `i + 1 == len(trial_column) - 1` special case.
    for i in range(len(trial_column) - 1):
        # We ignore preparatory trials.
        if int(trial_column[i]) <= blockprepN:
            continue

        # We ignore calibration validation blocks.
        if str(block_column[i]) == '0':
            continue

        # end of trial -> check samples of the last fixation (duration threshold shows the number of samples)
        # True when row i is the last sample of its trial, or when row i + 1 is the last row of the file.
        if trial_column[i] != trial_column[i + 1] or i + 1 == len(trial_column) - 1:
            # Distance values for the fixation samples.
            if i + 1 == len(trial_column) - 1:
                # Window ends at the final row of the file.
                # NOTE(review): if row i also ends a trial, that trial's own window is never
                # evaluated. The prep/calibration checks above were applied to row i, not
                # to the final row.
                all_distances = calcDistancesForFixation(i - fixation_duration_threshold + 2, i + 1, data_table)
            else:
                # Window = the `fixation_duration_threshold` samples ending at row i.
                all_distances = calcDistancesForFixation(i - fixation_duration_threshold + 1, i, data_table)

            # Trials whose window has no row with both eyes valid contribute nothing.
            if len(all_distances) > 0:
                current_epoch = int(epoch_column[i])

                # Calc RMS of all collected distances.
                new_RMS = calcRMS(all_distances)
                if current_epoch in epoch_rmss.keys():
                    epoch_rmss[current_epoch].append(new_RMS)
                else:
                    epoch_rmss[current_epoch] = [new_RMS]

    # We compute median RMS(E2E) for all epochs.
    # Epochs without data keep the float 0.0 placeholder; filled entries are strings.
    epoch_summary = np.zeros(epochN).tolist()
    for epoch in epoch_rmss.keys():
        epoch_summary[epoch - 1] = str(np.median(epoch_rmss[epoch]))

    return epoch_summary

def computeRMSEyeToEye(input_dir, output_file, vars_dict):
    """
    Computes the median Root Mean Square (RMS) of eye-to-eye distances for each subject in the dataset.

    Parameters:
    -----------
    input_dir : str
        The directory containing input files for each subject. Only its top
        level is scanned (sub-directories are not visited).
    output_file : str
        The path where the computed RMS values will be saved as a comma-separated .csv file.
    vars_dict : dict
        A dictionary containing key experimental parameters:
        - "randoms": Number of preparatory trials to ignore.
        - "epochN": Total number of epochs.
        - "fixation_threshold": Minimum fixation duration threshold.

    Returns:
    --------
    None
        The function saves the results to `output_file` (comma-separated, no index),
        with one row per subject and epoch: 'epoch' holds "subject_<id>_<epoch>",
        'RMS(E2E)_median' the corresponding value.

    Process:
    --------
    - Iterates through `.txt` files in `input_dir`; other files are reported and skipped.
    - Extracts subject ID from filenames.
    - Calls `computeRMSEyeToEyeImpl()` to compute RMS values.
    - Subjects for which the impl returns None (missing columns) are skipped, so
      the id column and the value column always stay the same length.
    - Stores results in a pandas DataFrame and exports it as a `.csv` file.

    Notes:
    ------
    - Assumes filenames follow a format where the subject ID is the second element when split by `_`.
    - TODO: Improve subject/epoch formatting and store results in a structured DataFrame instead of a list.

    """

    blockprepN = vars_dict["randoms"]
    epochN = vars_dict["epochN"]
    fix_threshold = vars_dict["fixation_threshold"]

    median_rms = []
    subject_epochs = []
    for root, dirs, files in os.walk(input_dir):
        for subject_file in files:
            if not subject_file.endswith(".txt"):
                print(f'{subject_file} has invalid data format (not a .txt)')
                continue

            subject = subject_file.split('_')[1]

            print("Compute RMS(E2E) for subject:  " + subject)
            input_file = os.path.join(root, subject_file)

            result = computeRMSEyeToEyeImpl(input_file, blockprepN, fix_threshold, epochN)
            if result is None:
                # The impl already printed why (missing columns). Skip the subject
                # before adding ids, otherwise the two output columns diverge.
                continue

            subject_epochs.extend(f"subject_{subject}_{i}" for i in range(1, epochN + 1))
            median_rms.extend(result)
        # Only the top level of input_dir is processed.
        break

    RMS_E2E_data = pd.DataFrame({'epoch': subject_epochs, 'RMS(E2E)_median': median_rms})
    RMS_E2E_data.to_csv(output_file, index=False)

def getEyePos(data_table, i):
    """
    Returns the gaze position of sample ``i`` from the *_PCMCS coordinate columns.

    If both eyes are flagged valid, their coordinates are averaged. If only one
    eye is valid, that eye's coordinates are used on their own (hybrid computation).

    Parameters:
    -----------
    data_table : pd.DataFrame
        Gaze data containing the validity flags and the *_PCMCS X/Y columns.
    i : int
        Row label of the sample (equal to its position for a default RangeIndex).

    Returns:
    --------
    tuple or str
        ``(pos_X, pos_Y)`` as floats, or the string ``'none'`` when neither eye is valid.
    """

    left_valid = data_table["left_gaze_validity"][i]
    right_valid = data_table["right_gaze_validity"][i]
    lx = data_table["left_gaze_data_X_PCMCS"][i]
    ly = data_table["left_gaze_data_Y_PCMCS"][i]
    rx = data_table["right_gaze_data_X_PCMCS"][i]
    ry = data_table["right_gaze_data_Y_PCMCS"][i]

    # NOTE(review): bool() treats NaN as True; confirm the validity columns never hold NaN.
    use_left = bool(left_valid)
    use_right = bool(right_valid)

    if not (use_left or use_right):
        # No valid data from either eye.
        return 'none'

    if use_left and use_right:
        return ((strToFloat(lx) + strToFloat(rx)) / 2.0,
                (strToFloat(ly) + strToFloat(ry)) / 2.0)

    if use_left:
        return (strToFloat(lx), strToFloat(ly))

    return (strToFloat(rx), strToFloat(ry))

def calcDistancesForFixation_s2s(j, k, data_table):
    """
    Computes visual-angle distances between consecutive gaze samples j..k.

    For every pair of neighbouring samples (m, m + 1) with j <= m < k where both
    samples have a valid position (getEyePos does not return 'none'), the
    Euclidean distance is converted to degrees with convertToAngle.

    Parameters:
    -----------
    j : int
        The starting index of the fixation period.
    k : int
        The ending index of the fixation period (inclusive as the last sample).
    data_table : pd.DataFrame
        The dataframe containing gaze data.

    Returns:
    --------
    list of float
        One visual-angle distance per valid consecutive pair; empty if k <= j.
    """

    if k <= j:
        return []

    angles = []
    # Each sample's position is read once and reused as the "earlier" sample of the next pair.
    earlier = getEyePos(data_table, j)
    for m in range(j + 1, k + 1):
        later = getEyePos(data_table, m)
        if earlier != 'none' and later != 'none':
            # Distance in cm, based on the psychopy coordinate system.
            dx = earlier[0] - later[0]
            dy = earlier[1] - later[1]
            angles.append(convertToAngle(math.sqrt(pow(dx, 2) + pow(dy, 2))))
        earlier = later

    return angles

def computeRMSSampleToSampleImpl(input, preparatory_trial_number, fixation_duration_threshold, vars_dict):
    """
    Computes the per-epoch median Root Mean Square (RMS) of sample-to-sample distances for fixations.

    For every trial that is neither preparatory nor part of a calibration
    validation block (block 0), the last ``fixation_duration_threshold`` samples
    of the trial are taken. The distances between consecutive gaze positions are
    computed (calcDistancesForFixation_s2s) and their RMS is stored for the
    trial's epoch. The median over all trials of an epoch is then reported.

    Parameters:
    -----------
    input : str
        Path to the tab-separated input file containing gaze data.
    preparatory_trial_number : int
        Trials with a trial number up to this value are ignored.
    fixation_duration_threshold : int
        Number of end-of-trial samples that make up the fixation window.
    vars_dict : dict
        Dictionary containing experiment parameters, including "epochN" (epoch count).

    Returns:
    --------
    list of float or None
        ``vars_dict["epochN"]`` values, element 0 belonging to epoch 1. Epochs
        without any usable trial keep 0.0. Epoch numbers outside 1..epochN are
        ignored with a warning. None is returned when a required column is missing.
    """
    required_cols = ['block', 'trial', 'epoch', 'left_gaze_validity', 'right_gaze_validity',
                     'left_gaze_data_X_PCMCS', 'left_gaze_data_Y_PCMCS', 'right_gaze_data_X_PCMCS', 'right_gaze_data_Y_PCMCS']

    try:
        # Read file with only necessary columns
        data_table = pd.read_csv(input, sep='\t', encoding="latin1", usecols=required_cols)
    except ValueError as e:  # Catches missing column errors
        print(f"Error: Missing required columns for this participant. Skipping to next.")
        print(f"Details: {e}")
        return None  # Return None to indicate failure, allowing the loop to continue

    trial_column = data_table["trial"]
    block_column = data_table["block"]
    epoch_column = data_table["epoch"]
    n_samples = len(trial_column)

    epoch_rmss = {}
    for i in range(n_samples - 1):
        # We ignore preparatory trials.
        if int(trial_column[i]) <= preparatory_trial_number:
            continue

        # We ignore calibration validation blocks.
        if str(block_column[i]) == '0':
            continue

        # End of trial -> check samples of the last fixation (duration threshold = number of samples).
        is_last_pair = i + 1 == n_samples - 1
        if trial_column[i] != trial_column[i + 1] or is_last_pair:
            if is_last_pair:
                # No row follows the final sample, so the window ends at i + 1.
                all_distances = calcDistancesForFixation_s2s(i - fixation_duration_threshold + 2, i + 1, data_table)
            else:
                all_distances = calcDistancesForFixation_s2s(i - fixation_duration_threshold + 1, i, data_table)

            if all_distances:
                # Calc RMS of all collected distances.
                epoch_rmss.setdefault(int(epoch_column[i]), []).append(calcRMS(all_distances))

    epoch_n = vars_dict["epochN"]

    # Epoch numbers must map onto slots 1..epochN. Otherwise an epoch > epochN
    # raises IndexError and an epoch <= 0 silently overwrites a slot from the end.
    unexpected = sorted(e for e in epoch_rmss if not 1 <= e <= epoch_n)
    if unexpected:
        warnings.warn(f'The input data contains epoch number(s) {unexpected} outside 1..{epoch_n}; '
                      f'they are ignored in the RMS(S2S) summary.')

    # We compute median RMS(S2S) for all epochs.
    epoch_summary = np.zeros(epoch_n).tolist()
    for epoch, rms_values in epoch_rmss.items():
        if 1 <= epoch <= epoch_n:
            epoch_summary[epoch - 1] = np.median(rms_values)

    return epoch_summary

def computeRMSSampleToSample(input_dir, output_file, vars_dict):
    """
    Computes and saves the median Root Mean Square (RMS) of sample-to-sample distances for each subject.

    Parameters:
    -----------
    input_dir : str
        Directory containing input files for each subject. Only the top level
        is scanned, and only files ending in '.txt' are processed.
    output_file : str
        Path to save the computed RMS values in a CSV file.
    vars_dict : dict
        Dictionary containing experiment parameters, including:
        - "randoms": Number of preparatory trials to ignore.
        - "epochN": Total number of epochs.
        - "fixation_threshold": Minimum fixation duration threshold.

    Returns:
    --------
    None
        Saves the results to `output_file` (comma-separated, no index) with one
        row per subject and epoch: 'epoch' = "subject_<id>_<epoch>",
        'RMS(S2S)_median' = the median RMS of that epoch. Subjects for which
        computeRMSSampleToSampleImpl returns None (missing columns) are skipped.
    """

    epoch_n = vars_dict["epochN"]

    median_rmss = []
    subject_epochs = []
    for root, dirs, files in os.walk(input_dir):
        for subject_file in files:
            if not subject_file.endswith(".txt"):
                print(f'{subject_file} has invalid data format (not a .txt)')
                continue

            # Subject id is the second '_'-separated part of the file name.
            subject = subject_file.split('_')[1]

            print("Compute RMS(S2S) for subject: " + subject)

            input_file = os.path.join(root, subject_file)

            RMS = computeRMSSampleToSampleImpl(input_file, vars_dict["randoms"], vars_dict["fixation_threshold"], vars_dict)
            if RMS is None:
                # The impl already reported the missing columns. Adding ids
                # without values would make the two columns differ in length.
                continue

            subject_epochs.extend(f"subject_{subject}_{i}" for i in range(1, epoch_n + 1))
            median_rmss.extend(RMS)
        # Only the top level of input_dir is processed.
        break

    RMS_S2S_data = pd.DataFrame({'epoch': subject_epochs, 'RMS(S2S)_median': median_rmss})
    RMS_S2S_data.to_csv(output_file, index=False)


#actual calculations of the output dataset 

#trialdata
def computeTrialLevelData(input_dir, output_dir):

    """
    Processes raw data files to compute trial-level reaction times (RT) and last area of interest (AOI).

    For every .txt file directly inside input_dir (sub-directories are not
    visited), RTs (calcRTColumn) and last AOIs (calcLastAOIColumn) are computed.
    generateOutput then writes them, together with the per-trial information, to
    ``subject_<id>__trial_log.csv`` in output_dir. The subject id is the second
    '_'-separated part of the file name. A participant for whom either helper
    yields no data is reported and skipped, so no output file is written for them.

    Parameters:
        input_dir (str): Path to the directory containing raw .txt data files.
        output_dir (str): Path to the directory where output CSV files will be saved.
    """

    for root, _, files in os.walk(input_dir):
        for subject_file in files:
            if not subject_file.endswith(".txt"):
                print(f'{subject_file} has invalid data format (not a .txt)')
                continue

            subject = subject_file.split('_')[1]
            print(f"Compute trial level data for subject: {subject}")

            raw_data_path = os.path.join(root, subject_file)
            RT_data_path = os.path.join(output_dir, f'subject_{subject}__trial_log.csv')

            RT_data = calcRTColumn(raw_data_path)
            last_AOI_data = calcLastAOIColumn(raw_data_path)

            # Both helpers signal failure with an empty list rather than None,
            # so test emptiness; `is not None` alone would never skip.
            if not RT_data or not last_AOI_data:
                print(f"Missing data for participant #{subject}")
                continue

            generateOutput(raw_data_path, RT_data_path, RT_data, last_AOI_data)
        # Only the top level of input_dir is processed.
        break

def calcRTColumn(raw_file_name):
    """
    Computes the reaction time (RT) for each trial based on gaze timestamp data.

    A trial runs from its first 'stimulus_on_screen' sample to the sample just
    before its first 'after_reaction' sample. If the trial has no
    'after_reaction' sample, its last sample is used instead. calcRTTrial turns
    these timestamps into the RT value. Trials of block 0 (calibration
    validation) get no entry. Trials are delimited by changes in the 'trial'
    column between consecutive rows.

    Parameters:
        raw_file_name (str): Path to the raw, tab-separated data file.

    Returns:
        list: One RT per non-calibration trial, in file order. An empty list is
        returned when a required column is missing or cannot be parsed.
    """

    required_cols = ['block', 'trial', 'trial_phase', 'gaze_data_time_stamp']

    try:
        # Read file with only necessary columns
        data_table = pd.read_csv(raw_file_name, sep='\t', encoding="latin1",
                                 usecols=required_cols,
                                 dtype={'gaze_data_time_stamp': 'float64'})

    except ValueError as e:  # Catches missing column (and unparsable timestamp) errors
        print(f"Error: Missing required columns for this participant in calcRTcol. Skipping to next.")
        print(f"Details: {e}")
        return []  # An empty list signals failure, allowing the loop to continue

    RT_data = []
    trial_column = data_table["trial"]
    block_column = data_table["block"]
    trial_phase_column = data_table["trial_phase"]
    time_stamp_column = data_table["gaze_data_time_stamp"]
    last_row = len(trial_column) - 1

    start_time = 0
    end_time = 0
    start_time_found = False
    end_time_found = False

    for i in range(len(trial_column)):

        # We reached the next trial's first sample: finalize the previous trial's RT.
        if i > 0 and trial_column[i] != trial_column[i - 1]:
            # We don't compute RT for 0 indexed blocks, which are calibration validation blocks.
            if str(block_column[i - 1]) != "0":
                RT_data.append(calcRTTrial(start_time_found, start_time, end_time_found, end_time, int(time_stamp_column[i - 1])))

            start_time_found = False
            end_time_found = False

        # Stimulus appears on the screen -> start time.
        if trial_phase_column[i] == "stimulus_on_screen" and not start_time_found:
            start_time = int(time_stamp_column[i])
            start_time_found = True

        # Stimulus disappears from the screen -> end time is the last sample before the reaction phase.
        if trial_phase_column[i] == "after_reaction" and not end_time_found:
            end_time = int(time_stamp_column[i - 1])
            end_time_found = True
            # If no end time is found, calcRTTrial falls back to the trial's last time stamp.

        # Last sample of the file: finalize the final trial. Use this row's own
        # block. Row i - 1 may belong to another trial, and it does not exist for a one-row file.
        if i == last_row and str(block_column[i]) != "0":
            RT_data.append(calcRTTrial(start_time_found, start_time, end_time_found, end_time, int(time_stamp_column[i])))

    return RT_data

def calcLastAOIColumn(raw_file_name):
    """
    Determines the last area of interest (AOI) before the stimulus appears.

    Only samples whose trial_phase is 'before_stimulus' are considered. Samples
    for which getAOI returns -1 (no usable gaze) do not overwrite the current
    value. Trials of block 0 (calibration validation) get no entry. Trials are
    delimited by changes in the 'trial' column between consecutive rows.

    Parameters:
        raw_file_name (str): Path to the raw, tab-separated data file.

    Returns:
        list: One entry per non-calibration trial, in file order: the last AOI
        number from getAOI, or 'none' if no valid sample was found. An empty list
        is returned when a required column is missing.
    """

    required_cols = ['block', 'trial', 'trial_phase', 'left_gaze_validity', 'right_gaze_validity',
                     'left_gaze_data_X_ADCS', 'left_gaze_data_Y_ADCS', 'right_gaze_data_X_ADCS', 'right_gaze_data_Y_ADCS']

    try:
        # Read file with only necessary columns
        data_table = pd.read_csv(raw_file_name, sep='\t', encoding="latin1", usecols=required_cols)

    except ValueError as e:  # Catches missing column errors
        print(f"Error: Missing required columns for this participant in lastAOI. Skipping to next.")
        print(f"Details: {e}")
        return []  # An empty list signals failure, allowing the loop to continue

    anticipation_data = []
    trial_column = data_table["trial"]
    block_column = data_table["block"]
    trial_phase_column = data_table["trial_phase"]
    left_gaze_validity_column = data_table["left_gaze_validity"]
    right_gaze_validity_column = data_table["right_gaze_validity"]
    left_gaze_data_X_column = data_table["left_gaze_data_X_ADCS"]
    left_gaze_data_Y_column = data_table["left_gaze_data_Y_ADCS"]
    right_gaze_data_X_column = data_table["right_gaze_data_X_ADCS"]
    right_gaze_data_Y_column = data_table["right_gaze_data_Y_ADCS"]
    last_row = len(trial_column) - 1

    def _store(aoi):
        # -1 means no valid AOI was seen during the trial's RSI.
        anticipation_data.append('none' if aoi == -1 else aoi)

    last_AOI = -1
    for i in range(len(trial_column)):
        # We reached the next trial's first sample: store the previous trial's last visited AOI.
        if i > 0 and trial_column[i] != trial_column[i - 1]:
            # We don't compute last AOI data for 0 indexed blocks, which are calibration validation blocks.
            if str(block_column[i - 1]) != "0":
                _store(last_AOI)
            last_AOI = -1

        # Get AOI during RSI.
        if trial_phase_column[i] == 'before_stimulus':
            current_AOI = getAOI(bool(left_gaze_validity_column[i]), bool(right_gaze_validity_column[i]),
                                 strToFloat(left_gaze_data_X_column[i]), strToFloat(left_gaze_data_Y_column[i]),
                                 strToFloat(right_gaze_data_X_column[i]), strToFloat(right_gaze_data_Y_column[i]))
            if current_AOI != -1:
                last_AOI = current_AOI

        # Last sample of the file: store the final trial. Use this row's own
        # block. Row i - 1 may belong to another trial, and it does not exist for a one-row file.
        if i == last_row and str(block_column[i]) != "0":
            _store(last_AOI)

    return anticipation_data

def calcRTTrial(start_time_found, start_time, end_time_found, end_time, last_time_stamp_of_trial):
    """
    Returns the time the stimulus was on the screen, (end - start) / 1000.

    The division by 1000 is meant to yield milliseconds, so the timestamps are
    presumably microseconds (TODO confirm against the recording software).

    Parameters:
        start_time_found (bool): Whether a stimulus-onset sample was found.
        start_time (int): Timestamp of the stimulus onset.
        end_time_found (bool): Whether an end-of-stimulus sample was found.
        end_time (int): Timestamp of the end of stimulus presentation.
        last_time_stamp_of_trial (int): Fallback end timestamp (last sample of the trial).

    Returns:
        float: The RT. 0.0 when no stimulus onset was found; a float rather than
        the string "0", so the RT column keeps a single numeric type.
    """
    if not start_time_found:
        # The stimulus disappeared instantly.
        return 0.0

    if not end_time_found:
        # No end time means the software stepped on to the next trial instantly
        # after the stimulus was hidden, so the trial's last sample is used.
        end_time = last_time_stamp_of_trial

    return (end_time - start_time) / 1000.0
      
def getAOI(left_gaze_validity, right_gaze_validity, left_gaze_X, left_gaze_Y, right_gaze_X, right_gaze_Y):
    """
    Checks in which of the four AOIs (screen quadrants) the gaze was.

    The gaze point is the mean of both eyes when both are usable, otherwise the
    single usable eye. An eye counts as usable when its validity flag is truthy
    and neither of its coordinates is NaN. Without this NaN check, a NaN point
    fails every comparison below and would be misclassified as AOI 4.

    Coordinates are compared with 0.5, so they are expected in a normalized
    0..1 system (the caller passes the *_ADCS columns). Points exactly on 0.5
    belong to the lower-numbered quadrant:
        1: X <= 0.5, Y <= 0.5      2: X > 0.5, Y <= 0.5
        3: X <= 0.5, Y > 0.5       4: X > 0.5, Y > 0.5

    Returns:
        int: 1-4, or -1 when no eye has a usable sample.
    """
    left_ok = bool(left_gaze_validity) and not (math.isnan(left_gaze_X) or math.isnan(left_gaze_Y))
    right_ok = bool(right_gaze_validity) and not (math.isnan(right_gaze_X) or math.isnan(right_gaze_Y))

    if left_ok and right_ok:
        X = (left_gaze_X + right_gaze_X) / 2.0
        Y = (left_gaze_Y + right_gaze_Y) / 2.0
    elif left_ok:
        X = left_gaze_X
        Y = left_gaze_Y
    elif right_ok:
        X = right_gaze_X
        Y = right_gaze_Y
    else:
        return -1

    on_left_half = X <= 0.5
    if Y <= 0.5:
        return 1 if on_left_half else 2
    return 3 if on_left_half else 4

def generateOutput(raw_file_name, new_file_name, RT_data, last_AOI_data):
    """
    Generates and saves the trial-level dataset with reaction times and last AOI information.

    One row is kept per (epoch, block, trial) combination, taken from the first
    sample of that trial. Rows of block 0 (calibration validation) are dropped.
    Decimal commas in text cells are replaced by periods. The result is written
    tab-separated without index.

    Parameters:
        raw_file_name (str): Path to the raw, tab-separated data file.
        new_file_name (str): Path to save the generated output file.
        RT_data (list): Computed reaction times, one per output row.
        last_AOI_data (list): Computed last AOI values, one per output row.
        If a list's length does not match the number of output rows, its column
        is filled with None (NaN) and a warning is issued.

    Returns:
        None. Nothing is written when a required column is missing.
    """

    required_cols = ['computer_name', 'monitor_width_pixel', 'monitor_height_pixel', 'subject_group',
                     'subject_number', 'subject_sex', 'subject_age', 'asrt_type', 'PCode', 'session',
                     'epoch', 'block', 'trial', 'frame_rate', 'frame_time', 'frame_sd', 'stimulus_color',
                     'trial_type_pr', 'triplet_type_hl', 'stimulus']

    try:
        # Read file with only necessary columns
        input_data = pd.read_csv(raw_file_name, sep='\t', encoding="latin1", usecols=required_cols, dtype={'PCode': 'str'})

    except ValueError as e:  # Catches missing column errors
        print(f"Error: Missing required columns for this participant in generateOUTPUT. Skipping to next.")
        print(f"Details: {e}")
        return None  # Return None to indicate failure, allowing the loop to continue

    # Decimal commas -> periods in every text cell.
    input_data = input_data.replace(',', '.', regex=True)

    # One row per trial, calibration validation blocks (block 0) excluded.
    # .copy() makes output_data an independent frame, so the column assignments
    # below do not act on a slice of input_data.
    output_data = (input_data[input_data['block'] != 0]
                   .drop_duplicates(subset=['epoch', 'block', 'trial'])
                   .copy())
    n_rows = len(output_data.index)

    # Reaction time of the trial.
    if len(RT_data) == n_rows:
        output_data['RT'] = RT_data
    else:
        output_data['RT'] = [None] * n_rows
        warnings.warn("The length of your data table and RT data do not match. A list of NaN added to the data. Check the calcRTColumn function.")

    # Last AOI during RSI (useful for anticipation data calculation).
    if len(last_AOI_data) == n_rows:
        output_data['last_AOI_before_stimulus'] = last_AOI_data
    else:
        output_data['last_AOI_before_stimulus'] = [None] * n_rows
        warnings.warn("The length of your data table and lastAOI data do not match. A list of NaN added to the data. Check the calcLastAOIColumn function.")

    output_data.to_csv(new_file_name, sep='\t', index=False)
    

#extended trial data

def computeHighLowBasedOnLearningSequence(data_table, vars_dict):
    """
    Labels every trial as a 'high' or 'low' triplet according to the learning sequence.

    The learning sequence is the single PCode that findSequence finds in epoch
    vars_dict["OSepoch"]. Trial n is 'high' when the string
    str(stimulus[n-2]) + str(stimulus[n]) occurs in the cyclic sequence, i.e.
    the two stimuli are consecutive sequence elements. Otherwise it is 'low'.

    Parameters:
        data_table (pd.DataFrame): Trial-level data with 'stimulus', 'trial',
            'epoch' and the vars_dict["PCode"] column.
        vars_dict (dict): Uses "OSepoch", "PCode", "notSeq" and "randoms".

    Returns:
        list: Per row, 'none' for the first max(randoms, 2) trials of a block,
        otherwise 'high' or 'low'. [] when no valid learning sequence is found.
    """
    # TODO gotta switch those to not hardcoded var names
    stimulus_column = data_table["stimulus"]
    trial_column = data_table["trial"]

    epoch_data = data_table.loc[data_table["epoch"] == vars_dict["OSepoch"]]

    # Get the learning sequence.
    learning_sequence = findSequence(epoch_data, vars_dict["PCode"], vars_dict["notSeq"])
    if learning_sequence == "":
        warnings.warn(
            "Warning: could not find a valid learning sequence in the data for this subject. They might haven't finished the whole session")
        return []

    # Close the cycle so the last->first element pair is found as a substring too.
    learning_sequence += learning_sequence[0]

    # A triplet needs the two preceding trials of the same block. Trials 1-2 are
    # never labelled, even if fewer preparatory ('randoms') trials are configured;
    # otherwise they would look back into the previous block (or before row 0).
    skip_up_to = max(vars_dict["randoms"], 2)

    high_low_column = []
    for i in range(len(stimulus_column)):
        if trial_column[i] <= skip_up_to:
            high_low_column.append('none')
        elif (str(stimulus_column[i - 2]) + str(stimulus_column[i])) in learning_sequence:
            high_low_column.append('high')
        else:
            high_low_column.append('low')

    return high_low_column

def computeTrillColumn(data_table):
    """
    Flags trills: triplets shaped a-b-a (first and third stimulus equal, middle one different).

    Parameters:
        data_table (pd.DataFrame): Trial-level data with 'stimulus' and 'trial' columns.

    Returns:
        list: Per row, 'none' for the first two trials of a block (trial <= 2, no
        complete triplet yet), otherwise True for a trill and False otherwise.
    """
    stimuli = data_table["stimulus"]
    trials = data_table["trial"]

    def _classify(row):
        if trials[row] <= 2:
            return 'none'
        current = stimuli[row]
        return bool(current != stimuli[row - 1] and current == stimuli[row - 2])

    return [_classify(row) for row in range(len(stimuli))]

def computeRepetitionColumn(data_table):
    """
    Flags repetitions: trials whose stimulus equals the previous trial's stimulus.

    Parameters:
        data_table (pd.DataFrame): Trial-level data with 'stimulus' and 'trial' columns.

    Returns:
        list: Per row, 'none' for the first trial of a block (trial <= 1, no
        previous trial), otherwise True for a repetition and False otherwise.
    """
    stimuli = data_table["stimulus"]
    trials = data_table["trial"]

    def _classify(row):
        if trials[row] <= 1:
            return 'none'
        return bool(stimuli[row] == stimuli[row - 1])

    return [_classify(row) for row in range(len(stimuli))]

def computeAnticipationColumn(data_table):
    """
    Flags anticipatory eye movements during the RSI.

    A trial counts as an anticipation when the last AOI recorded before its
    stimulus differs from the previous trial's stimulus, i.e. the eye moved away
    after the last trial. Stimulus codes are compared directly with AOI numbers,
    presumably because both number the four screen positions (TODO confirm).

    Parameters:
        data_table (pd.DataFrame): Trial-level data with 'stimulus', 'trial' and
            'last_AOI_before_stimulus' columns.

    Returns:
        list: Per row, 'none' for the first trial of a block, otherwise True/False.
        A missing AOI ('none', or NaN as written by the length-mismatch fallback
        of the trial-level step) yields False instead of crashing on int(NaN).
    """
    anticipation_column = []
    stimulus_column = data_table["stimulus"]
    last_AOI_column = data_table["last_AOI_before_stimulus"]
    trial_column = data_table["trial"]

    for i in range(len(stimulus_column)):
        aoi = last_AOI_column[i]
        # Can't calculate for the first trial of the block, because there is no previous trial.
        if trial_column[i] <= 1:
            anticipation_column.append('none')
        # There is no valid AOI data recorded during RSI.
        elif pd.isna(aoi) or aoi == 'none':
            anticipation_column.append(False)
        else:
            anticipation_column.append(int(aoi) != int(stimulus_column[i - 1]))

    return anticipation_column

def findSequence(epoch_data_table, PCode_var, noPattern):
    """
    Returns the single sequence code (PCode) used in the given epoch data.

    Codes are compared as strings. Integral floats such as 1234.0 (which pandas
    produces when the column also holds NaN) are written as '1234'. NaN cells
    are ignored. The noPattern value is normalized the same way before it is
    removed, so an int or str noPattern both match.

    Parameters:
        epoch_data_table (pd.DataFrame): Rows of one epoch.
        PCode_var (str): Name of the PCode column.
        noPattern: Code that marks "no sequence" and is excluded.

    Returns:
        str: The code, or "" (with a warning) when zero or several codes remain.
    """
    def _normalize(code):
        if isinstance(code, (float, np.floating)) and float(code).is_integer():
            return str(int(code))
        return str(code)

    codes = {_normalize(code) for code in epoch_data_table[PCode_var].dropna().unique()}
    codes.discard(_normalize(noPattern))

    if not codes:
        warnings.warn("There's no valid PCode for this subject.")
        return ""
    if len(codes) > 1:
        warnings.warn("There's more than one PCode for this subject.")
        return ""
    return codes.pop()

def computeLearntAnticipationColumn(data_table, vars_dict, is_learning_seq):
    """
    Flags anticipatory eye movements that follow a given sequence.

    The sequence is the single PCode that findSequence finds in epoch
    vars_dict["OSepoch"] (is_learning_seq=True) or vars_dict["IFepoch"]
    (is_learning_seq=False, interference sequence). Trial n is a learnt
    anticipation when the last AOI before its stimulus differs from stimulus
    n-1, and str(stimulus[n-2]) + str(last AOI) occurs in the cyclic sequence.
    Stimulus codes are compared directly with AOI numbers (presumably both number
    the four screen positions; TODO confirm).

    Parameters:
        data_table (pd.DataFrame): Trial-level data with 'stimulus', 'trial',
            'epoch', 'last_AOI_before_stimulus' and the vars_dict["PCode"] column.
        vars_dict (dict): Uses "OSepoch", "IFepoch", "PCode", "notSeq" and "randoms".
        is_learning_seq (bool): Use the original (True) or the interference (False) sequence.

    Returns:
        list: Per row, 'none' for the first max(randoms, 2) trials of a block,
        otherwise True/False. A missing AOI ('none' or NaN) yields False.
        [] when there is no interference epoch or no valid sequence.
    """
    # TODO: switch these to non-hardcoded var names from vars_dict
    stimulus_column = data_table["stimulus"]
    last_AOI_column = data_table["last_AOI_before_stimulus"]
    trial_column = data_table["trial"]

    if is_learning_seq:
        epoch_for_finding_seq = vars_dict["OSepoch"]
    elif vars_dict["IFepoch"] is None:
        warnings.warn("This design doesn't have an interference epoch. Computation of interference anticipation is terminated")
        return []
    else:
        epoch_for_finding_seq = vars_dict["IFepoch"]

    epoch_data_table = data_table.loc[data_table["epoch"] == epoch_for_finding_seq]

    sequence = findSequence(epoch_data_table, vars_dict["PCode"], vars_dict["notSeq"])
    if sequence == "":
        warnings.warn(
            "Warning: could not find a valid learning sequence in the data for this subject. They might haven't finished the whole session")
        return []

    # Close the cycle so the last->first element pair is found as a substring too.
    sequence += sequence[0]

    # A triplet needs the two preceding trials of the same block, regardless of
    # how many preparatory ('randoms') trials are configured.
    skip_up_to = max(vars_dict["randoms"], 2)

    learnt_anticipation_data = []
    for i in range(len(stimulus_column)):
        aoi = last_AOI_column[i]
        if trial_column[i] <= skip_up_to:
            learnt_anticipation_data.append('none')
        # There is no valid AOI data recorded during RSI.
        elif pd.isna(aoi) or aoi == 'none':
            learnt_anticipation_data.append(False)
        # No anticipation eye movement. Eye is in the same AOI where the previous stimulus was.
        elif int(aoi) == int(stimulus_column[i - 1]):
            learnt_anticipation_data.append(False)
        else:
            # Does the last AOI follow the sequence? int() avoids '2.0' when the column is float-typed.
            learnt_anticipation_data.append(str(stimulus_column[i - 2]) + str(int(aoi)) in sequence)

    return learnt_anticipation_data

def _assignColumnOrNaN(data_table, column, values, data_label, col_label, func_name):
    """Store `values` as `column`, or an all-None placeholder plus a warning if the lengths differ."""
    if len(values) == len(data_table.index):
        data_table[column] = values
    else:
        data_table[column] = [None] * len(data_table.index)
        warnings.warn(f"The length of your data table and {data_label} data do not match. "
                      f"A list of NaN addet to the {col_label} col. Check the {func_name} function.")

def extendTrialLevelDataForOneSubject(input_file, output_file, vars_dict):
    """
    Adds derived per-trial columns to one subject's trial log.

    Reads the tab-separated trial log and adds the columns 'repetition', 'trill',
    'high_low_learning', 'has_anticipation' and 'has_learnt_anticipation_OS'.
    When vars_dict["IFepoch"] is not None it also adds 'has_learnt_anticipation_IF'.
    Any column whose computed list has the wrong length is filled with None and
    a warning is issued. The result is written tab-separated to output_file.

    Parameters:
        input_file (str): Path to the trial-level log.
        output_file (str): Path of the extended log to write.
        vars_dict (dict): Experiment settings forwarded to the column functions.
    """
    data_table = pd.read_csv(input_file, sep='\t')

    # Previous trial has the stimulus at the same position -> repetition.
    _assignColumnOrNaN(data_table, "repetition", computeRepetitionColumn(data_table),
                       "repetition", "repetition", "computeRepetitionColumn")

    # Trill: first and third item of the triplet are the same, e.g. 1x1, 2x2, etc.
    _assignColumnOrNaN(data_table, "trill", computeTrillColumn(data_table),
                       "trill", "trill", "computeTrillColumn")

    # Frequency (high/low triplet) based on the learning sequence.
    _assignColumnOrNaN(data_table, "high_low_learning", computeHighLowBasedOnLearningSequence(data_table, vars_dict),
                       "HL learning", "high_low_learning", "computeHighLowBasedOnLearningSequence")

    # Whether an anticipatory eye movement happened.
    _assignColumnOrNaN(data_table, "has_anticipation", computeAnticipationColumn(data_table),
                       "anticipation", "has_anticipation", "computeAnticipationColumn")

    # Whether a learning-dependent anticipatory eye movement happened (original sequence).
    _assignColumnOrNaN(data_table, "has_learnt_anticipation_OS",
                       computeLearntAnticipationColumn(data_table, vars_dict, True),
                       "has_learnt_anticipation_original_seq", "has_learnt_anticipation_original_seq",
                       "computeLearntAnticipationColumn")

    # Same for the interference sequence, if the design has an interference epoch.
    if vars_dict["IFepoch"] is not None:
        _assignColumnOrNaN(data_table, "has_learnt_anticipation_IF",
                           computeLearntAnticipationColumn(data_table, vars_dict, False),
                           "has_learnt_anticipation_interference", "has_learnt_anticipation_interference",
                           "computeLearntAnticipationColumn")

    data_table.to_csv(output_file, sep='\t', index=False)

def extendTrialLevelData(input_dir, output_dir, vars_dict):
    """Extend every trial-level file found directly inside *input_dir*.

    Only the top level of *input_dir* is scanned; sub-directories are ignored.
    The subject identifier is the second '_'-separated token of the file name,
    and each result is written by extendTrialLevelDataForOneSubject to
    ``<output_dir>/subject_<id>__trial_extended_log.csv``.
    """
    # os.walk yields the top directory first; nothing at all if it does not exist
    _, _, top_level_files = next(os.walk(input_dir), (None, None, []))

    for file_name in top_level_files:
        subject_id = file_name.split('_')[1]

        print("Extend trial level data with additional fields for subject: " + subject_id)

        extendTrialLevelDataForOneSubject(
            os.path.join(input_dir, file_name),
            os.path.join(output_dir, 'subject_' + subject_id + '__trial_extended_log.csv'),
            vars_dict,
        )
        
        
#cleanup to create HL LMM-compatible data + ofc our beloved simple pivot

def clean_and_organize(list_cols, variables, sub_to_drop, ext_trialdata_path):
    '''
    Load and tidy the extended trial-level data of all subjects.

    Steps:
      1) read every ``*.csv`` (tab-separated) file in *ext_trialdata_path* and
         concatenate them;
      2) keep only the columns in *list_cols* (if any of them is missing, the
         missing names are printed and no column filtering is done);
      3) make the dependent variable (``variables["dep"]``) numeric, turning
         comma decimals into periods;
      4) drop the subjects listed in *sub_to_drop* (matched against the
         ``variables["ID"]`` column).
    Block-starting random trials are NOT dropped here; see
    drop_blockstarting_randoms / get_trialdat_ready_for_HL_LMM.

    INPUT: list of cols needed, variables dict as defined in the notebook,
        collection (list/set) of subject IDs to exclude, folder with the
        extended trial-level files
    OUTPUT: one DataFrame with cleaned columns, numeric RT and the excluded
        subjects removed
    RAISES: ValueError if the folder contains no .csv files
    '''
    all_files = glob.glob(os.path.join(ext_trialdata_path, '*.csv'))
    if not all_files:
        # pd.concat would otherwise fail with the less helpful "No objects to concatenate"
        raise ValueError(f"No .csv files found in {ext_trialdata_path}")

    raw_dat = pd.concat((pd.read_csv(f, sep='\t') for f in all_files), ignore_index=True)

    # step1: keep only the requested columns
    missing_cols = [col for col in list_cols if col not in raw_dat.columns]
    for col in missing_cols:
        print(f'{col} not in data.')
    if missing_cols:
        print("couldn't do col cleaning. Check the list_cols and try again.")
    else:
        # .copy() so that the column assignment below does not act on a view
        raw_dat = raw_dat[list_cols].copy()

    # step2: make the dependent variable numeric
    dep = variables["dep"]
    if raw_dat[dep].dtypes != 'float64':
        dep_values = raw_dat[dep]
        if not pd.api.types.is_numeric_dtype(dep_values):
            # replace comma decimals in actual strings only; numbers stored in
            # an object column would otherwise become NaN
            dep_values = dep_values.map(
                lambda value: value.replace(',', '.') if isinstance(value, str) else value
            )
        raw_dat[dep] = pd.to_numeric(dep_values)

    # step3: drop excluded subjects
    if len(sub_to_drop) > 0:
        raw_dat = raw_dat.loc[~raw_dat[variables["ID"]].isin(sub_to_drop)].copy()

    return raw_dat

def drop_blockstarting_randoms(trial_data, vars_dict):
    """Drop the practice random trials at the start of every block.

    Rows whose trial number (column ``vars_dict["trial"]``) is not greater than
    ``vars_dict["randoms"]`` are removed.  Returns a new DataFrame; the input
    is not modified.
    """
    blockprepN = vars_dict["randoms"]
    trial = vars_dict["trial"]
    # use the configured number of block-start randoms instead of a fixed 5
    return trial_data.loc[trial_data[trial] > blockprepN].copy()

def get_trialdat_ready_for_HL_LMM(raw_dat_clean, vars_dict, drop0RT = True, is_interference = True):
    """Prepare trial-level data for the high/low triplet LMM.

    Steps:
      1) drop the first ``vars_dict["randoms"]`` (non-sequential) trials of
         each block, using the trial column ``vars_dict["trial"]``
         (default "trial");
      2) add 0/1 integer copies of ``has_anticipation`` and
         ``has_learnt_anticipation_OS`` (plus ``has_learnt_anticipation_IF``
         when ``vars_dict["IFepoch"]`` is not None), accepting booleans as
         well as the strings "True"/"False"/"TRUE"/"FALSE"; any other value
         becomes NaN;
      3) if *drop0RT* is True, drop trials whose reaction time
         (``vars_dict["dep"]``) is 0;
      4) drop trills and repetitions.

    *is_interference* is kept for backward compatibility but is not used;
    whether the IF column is converted depends on ``vars_dict["IFepoch"]``.
    Returns a new DataFrame; the input is not modified.
    """
    to_int = {"True": 1, "False": 0, True: 1, False: 0, "TRUE": 1, "FALSE": 0}
    true_list = [True, "TRUE", "True"]
    RT = vars_dict["dep"]
    trial = vars_dict.get("trial", "trial")

    #1) drop first few trials
    data = raw_dat_clean.loc[raw_dat_clean[trial] > vars_dict["randoms"]].copy()

    #2) whether the anticipation cols are string or bool, code them as 0 and 1
    flag_cols = ["has_anticipation", "has_learnt_anticipation_OS"]
    if vars_dict["IFepoch"] is not None:
        flag_cols.append("has_learnt_anticipation_IF")
    for col in flag_cols:
        data[col + "_int"] = data[col].map(to_int)

    #3) drop zero reaction times if requested
    # (block-start trials were already removed in step 1; this step looks at RT only)
    if drop0RT:
        data = data.loc[data[RT] != 0]

    #4) drop trills and reps
    data = data.loc[~data["trill"].isin(true_list)]
    data = data.loc[~data["repetition"].isin(true_list)]

    return data

def simple_pivot(data_in, outcome_var, indeces, cols, aggfunct):
    """Pivot long data to wide format with flat, concatenated column names.

    ``pd.pivot_table`` is called with *indeces* as index, *cols* as columns,
    *outcome_var* as values and *aggfunct* as aggregation.  Multi-level column
    labels are flattened by concatenating their parts, e.g. ``(1, 'high')`` ->
    ``'1high'``.  Parentheses, quotes, commas and spaces are then stripped from
    every column name.  Rows with any missing value (or empty string) are
    dropped.  If *data_in* has a ``group_string`` column, its ``group`` and
    ``group_string`` columns are merged back onto the result by the pivot
    index column(s).
    """
    data_out = pd.pivot_table(data_in, values=outcome_var, index=indeces, columns=cols, aggfunc=aggfunct)
    # join the parts of each label, not the characters of its repr
    data_out.columns = [
        "".join(str(part) for part in col) if isinstance(col, tuple) else str(col)
        for col in data_out.columns
    ]
    data_out.reset_index(inplace=True)
    data_out.columns = data_out.columns.str.replace(r"[()', ]", "", regex=True)
    data_out.replace('', np.nan, inplace=True)
    data_out.dropna(axis=0, inplace=True)
    if "group_string" in data_in.columns:
        # merge on the pivot index (previously hard-coded to 'Subject')
        keys = [indeces] if isinstance(indeces, str) else list(indeces)
        groupfile = data_in[keys + ["group", "group_string"]].drop_duplicates()
        data_out = pd.merge(data_out, groupfile, on=keys, how='outer')
    return data_out

def calcHighLowDiff(data_wide, vars_dict, is_RT = False):
    """Add a ``TRI_<epoch>`` column per epoch holding the high-low triplet difference.

    Columns are expected to be named ``<epoch><H>`` / ``<epoch><L>`` (e.g.
    ``1high`` / ``1low``, as produced by simple_pivot), with the labels taken
    from ``vars_dict["H"]`` and ``vars_dict["L"]``.  For RT (*is_RT* True)
    the score is low - high; otherwise it is high - low.  Each low column is
    paired with the high column of the same epoch by name, so multi-digit
    epochs (e.g. ``10low``) are handled.  Epochs lacking either column are
    skipped.  *data_wide* is modified in place and also returned.
    """
    high = vars_dict["H"]
    low = vars_dict["L"]
    # snapshot the names so the TRI_ columns added below are not iterated
    for col in list(data_wide.columns):
        if not isinstance(col, str) or not col.endswith(low):
            continue
        epoch_label = col[: len(col) - len(low)]
        high_col = epoch_label + high
        if not epoch_label or high_col not in data_wide.columns:
            continue
        if is_RT:
            data_wide["TRI_" + epoch_label] = data_wide[col] - data_wide[high_col]
        else:
            data_wide["TRI_" + epoch_label] = data_wide[high_col] - data_wide[col]
    return data_wide


#processing the IF epoch only:
def calculate_HLLH_RT(dat_IF_epochs, vars_dict):
    """
    Label every trial with its triplet status under both sequences.

    The first letter reflects 'high_low_learning', the second 'triplet_type_hl';
    each is "H" when the value equals vars_dict["H"] and "L" when it equals
    vars_dict["L"].  Any other combination is labelled "none".

    Parameters:
    - dat_IF_epochs: DataFrame with IF epochs (the new column is written into it)
    - vars_dict: Dictionary containing 'H' and 'L' keys

    Returns:
    - The same DataFrame with an added 'HLLH_RT' column ("HH", "HL", "LH", "LL" or "none")

    Raises:
    - IndexError if the number of labels does not match the number of rows
    """
    high, low = vars_dict["H"], vars_dict["L"]

    def _pair_label(learning_status, triplet_status):
        if learning_status == high:
            if triplet_status == high:
                return "HH"
            if triplet_status == low:
                return "HL"
        elif learning_status == low:
            if triplet_status == high:
                return "LH"
            if triplet_status == low:
                return "LL"
        return "none"

    labels = [
        _pair_label(row["high_low_learning"], row["triplet_type_hl"])
        for _, row in dat_IF_epochs.iterrows()
    ]

    if dat_IF_epochs.shape[0] != len(labels):
        raise IndexError("Mismatch between DataFrame length and computed HLLH list.")

    dat_IF_epochs.loc[:, "HLLH_RT"] = labels
    return dat_IF_epochs

def calculate_HLLH_LDAEM(dat_IF_epochs, vars_dict):
    """
    Label each trial by learning-dependent (LD) anticipation under both sequences.

    Combines 'has_learnt_anticipation_OS' (original sequence) and
    'has_learnt_anticipation_IF' (interference sequence) into "LD-LD",
    "LD-NLD", "NLD-LD" or "NLD-NLD".  Rows where either value is not
    recognisable as a boolean get "none".  Real booleans and the strings
    "True"/"False"/"TRUE"/"FALSE" are all accepted, matching the coding used
    in get_trialdat_ready_for_HL_LMM.  This matters because pandas' read_csv
    turns True/False text columns into real booleans.

    Parameters:
    - dat_IF_epochs: DataFrame with IF epochs (the new column is written into it)
    - vars_dict: kept for interface symmetry with calculate_HLLH_RT; not used

    Returns:
    - Updated DataFrame with 'HLLH_LDAEM' column

    Raises:
    - IndexError if the number of labels does not match the number of rows
    """
    true_values = frozenset([True, "True", "TRUE"])
    false_values = frozenset([False, "False", "FALSE"])

    def _as_bool(value):
        """Return True/False for recognised boolean codings, None otherwise."""
        try:
            if value in true_values:
                return True
            if value in false_values:
                return False
        except TypeError:  # unhashable cell value
            pass
        return None

    LDNLD_list = []
    for _, row in dat_IF_epochs.iterrows():
        os_flag = _as_bool(row["has_learnt_anticipation_OS"])
        if_flag = _as_bool(row["has_learnt_anticipation_IF"])
        if os_flag is None or if_flag is None:
            LDNLD_list.append("none")
        else:
            LDNLD_list.append(("LD" if os_flag else "NLD") + "-" + ("LD" if if_flag else "NLD"))

    if len(LDNLD_list) != dat_IF_epochs.shape[0]:
        raise IndexError("Mismatch between DataFrame length and computed LDNLD list.")

    dat_IF_epochs.loc[:, "HLLH_LDAEM"] = LDNLD_list
    return dat_IF_epochs

def process_IF_epochs(raw_dat_clean, vars_dict):
    """
    Keep only the interference (seqB) epochs and add the HLLH labels.

    Parameters:
    - raw_dat_clean: trial-level DataFrame (not modified; a filtered copy is used)
    - vars_dict: must provide 'seqB_list' (epochs with the interference
      sequence), 'epoch' (epoch column name) and the 'H'/'L' labels

    Returns:
    - The filtered DataFrame with 'HLLH_RT' and 'HLLH_LDAEM' columns, or
      None (after printing a message) when 'seqB_list' is empty
    """
    interference_epochs = vars_dict["seqB_list"]
    if len(interference_epochs) == 0:
        print("You didn't even have interference epochs here. Calculating HLLH scores failed miserably.")
        return None

    epoch_column = vars_dict["epoch"]
    in_interference = raw_dat_clean[epoch_column].isin(interference_epochs)
    interference_trials = raw_dat_clean.loc[in_interference].copy()

    with_rt_labels = calculate_HLLH_RT(interference_trials, vars_dict)
    return calculate_HLLH_LDAEM(with_rt_labels, vars_dict)

#generating response type files

def filter_random_epochs(raw_data, variables_dict):
    """Return a copy of *raw_data* without the non-sequential epochs.

    A row is dropped when its ``variables_dict['PCode']`` column equals the
    ``variables_dict['notSeq']`` marker.
    """
    pattern_column = variables_dict['PCode']
    is_sequential = raw_data[pattern_column] != variables_dict['notSeq']
    return raw_data.loc[is_sequential].copy()

def is_response_LD(PCode, current_triplet, response):
    """
    Check whether *response* is learning-dependent (corresponds to an H triplet).

    The response counts as LD when "<first element of the triplet><response>"
    occurs as two consecutive characters of the sequence code written twice.
    Writing it twice includes the wrap-around from the last element back to
    the first.  Example: PCode 1234, triplet "1x?" and response 2 gives "12",
    which is in "12341234", so the result is True.

    INPUT: sequence code (e.g. 1234), the triplet whose last element is the
        current stimulus, the participant's response (e.g. last AOI)
    OUTPUT: True if the response is LD, False otherwise
    RAISES: IndexError if the triplet has fewer than 3 characters
    """
    if len(str(current_triplet)) < 3:
        raise IndexError('Uh-oh, it seems your triplet isnt long enough if you know what I mean.')

    pattern_loop = str(PCode) * 2
    return str(current_triplet)[0] + str(response) in pattern_loop
    
def is_response_correct(response, stimulus):
    """
    Tell whether the response location matches the stimulus location.

    Both values are compared as strings.  A response of -1 (no interpretable
    last AOI) prints a notice and yields None instead of a boolean.
    """
    response_text = str(response)
    if response_text == "-1":
        print("no interpretable lastAOI. Skipping this one.")
        return None
    return response_text == str(stimulus)
    
def update_triplet(current_triplet, stim):
    """Append *stim* to the running triplet string.

    While fewer than three characters have been collected, *stim* is simply
    appended.  Otherwise only characters 1-2 of the old string are kept, so
    the oldest element is dropped, and *stim* is appended.
    """
    triplet_text = str(current_triplet)
    kept = triplet_text if len(triplet_text) < 3 else triplet_text[1:3]
    return kept + str(stim)
    
def response_data_calculator(raw_data_in, variables_dict, is_interference=False):
    """
    This function calculates the four response types for each trial.

    Non-sequential epochs are removed first (filter_random_epochs).  A running
    triplet of stimuli is built per block; it is reset whenever the trial
    number equals 1.  Trials whose triplet has fewer than 3 elements get
    'none'.  Every other trial gets LD_CORRECT / NLD_CORRECT / LD_ERROR /
    NLD_ERROR:
      * correctness: is_response_correct(response, stimulus);
      * learning dependence: is_response_LD(sequence code, triplet, response).
    Each helper is evaluated once per trial.

    NOTE(review): is_response_correct returns None for a -1 response (no
    interpretable last AOI).  Such trials land in the *error* branch here
    (LD_ERROR / NLD_ERROR), even though the helper prints "Skipping this one".
    Confirm this is intended.

    INPUT: raw ext trial data with filtered cols and subjects (as in clean_and_organize);
        variables dictionary as defined in the notebook.
        If you check an interference epoch, set is_interference = True
        (the 'original_PCode' column is then used as the sequence code).
    OUTPUT: the filtered trial data with response_type and current_triplet cols,
        or with an original_response_type col if is_interference == True
    """
    # drop nonseq epochs
    raw_data = filter_random_epochs(raw_data_in, variables_dict)

    pcode_col = variables_dict['original_PCode'] if is_interference else variables_dict['PCode']
    trial = variables_dict['trial']
    stim = variables_dict['stim']
    lastAOI = variables_dict['response']

    response_type = []
    triplet = []
    current_triplet = ''

    for _, row in raw_data.iterrows():
        # new block: start a fresh triplet
        if int(row[trial]) == 1:
            current_triplet = ''

        current_triplet = update_triplet(current_triplet, row[stim])
        triplet.append(current_triplet)

        if len(current_triplet) < 3:
            response_type.append('none')
            continue

        correct = is_response_correct(row[lastAOI], row[stim])
        learning_dependent = is_response_LD(row[pcode_col], current_triplet, row[lastAOI])
        prefix = 'LD' if learning_dependent else 'NLD'
        suffix = 'CORRECT' if correct else 'ERROR'
        response_type.append(f'{prefix}_{suffix}')

    if is_interference:
        raw_data['original_response_type'] = response_type
    else:
        raw_data['response_type'] = response_type
        raw_data['current_triplet'] = triplet

    return raw_data


#likelihoods

def count_values_epochwise(data, variables_dict, counted_var):
    """
    Count, per subject and epoch, how many trials fall in each *counted_var* category.

    INPUT: trial-level data, variables dictionary (uses "ID" and "epoch"),
        name of the categorical column to count (e.g. "response_type")
    OUTPUT: (counts, trial_total_counts), both wide tables built by simple_pivot:
        counts has one column per epoch x category (e.g. "1LD_CORRECT");
        trial_total_counts has one column per epoch with the number of trials.

    The caller's DataFrame is not modified; the helper column of ones lives on a copy.
    NOTE(review): simple_pivot drops every row that has a missing cell.  A
    subject with zero trials of some category in some epoch therefore likely
    disappears from *counts*.  Confirm this is intended.
    """
    ID = variables_dict["ID"]
    epoch = variables_dict["epoch"]
    # a column of ones, summed by the pivots below
    counting_data = data.assign(count_support=1)
    trial_total_counts = simple_pivot(counting_data, "count_support", ID, epoch, "sum")
    counts = simple_pivot(counting_data, "count_support", ID, [epoch, counted_var], "sum")

    return counts, trial_total_counts

def get_percentage_data(counts, trial_total_counts, variables_dict):
    """
    Calculate what proportion (0-1) of the trials in a given epoch were a given response type.

    INPUT: wide response count data as produced by count_values_epochwise
        (columns like "<epoch><type>", e.g. "1LD_CORRECT"), wide trial count
        data (epoch nr as column), variables dictionary to define col names ("ID")
    OUTPUT: wide response proportion data with the same columns as *counts*

    Each count column is divided by the epoch-total column whose name is the
    longest prefix of the count column's name.  So "10LD_CORRECT" uses epoch
    "10", never "1", whatever the column order.  The two tables are matched
    by subject ID rather than by row label.  Count columns without a matching
    epoch stay empty (NaN).
    """
    ID = variables_dict['ID']
    totals_by_subject = trial_total_counts.set_index(ID)
    epoch_cols = list(totals_by_subject.columns)

    dat_percentage = pd.DataFrame(columns=counts.columns)
    dat_percentage[ID] = counts[ID]
    for col in counts.columns:
        if col == ID:
            continue
        matching = [epoch_col for epoch_col in epoch_cols if str(col).startswith(str(epoch_col))]
        if not matching:
            continue
        epoch_col = max(matching, key=lambda name: len(str(name)))
        # look up each subject's total for that epoch by ID
        subject_totals = counts[ID].map(totals_by_subject[epoch_col])
        dat_percentage[col] = counts[col] / subject_totals
    return dat_percentage

def divide_cols(dat_percentage, variables, probability_dict):
    """
    Divide each proportion column by the baseline probability of its category.

    A column "<epoch><key>" is divided by probability_dict[key].  Here key is
    the longest key that the column name ends with, so "1NLD_ERROR" uses
    "NLD_ERROR" and not "LD_ERROR", and the epoch prefix must be non-empty.
    This works for any epoch number (e.g. "10LD_ERROR").  The ID column
    (variables['ID']) is copied unchanged.  Columns containing the
    random-trial marker (variables['rando']) and columns matching no key are
    kept but left empty (NaN).

    INPUT: wide (response) proportion data, variable dictionary, probabilities
        in dictionary form
    OUTPUT: the new shiny likelihood ratio wide data
    """
    ID = variables['ID']
    random = variables['rando']

    dat_LLR = pd.DataFrame(columns=dat_percentage.columns)
    for colname in dat_percentage.columns:
        if colname == ID:
            dat_LLR[colname] = dat_percentage[colname]
            continue
        if random in colname:
            continue
        candidates = [
            key for key in probability_dict
            if colname.endswith(key) and len(colname) > len(key)
        ]
        if candidates:
            key = max(candidates, key=len)
            dat_LLR[colname] = dat_percentage[colname] / probability_dict[key]
    return dat_LLR

def likelihood_generator(trialdata, variables_dict, count_variable, probability_dict = None):
    """
    Create the likelihood-ratio data: P(category | epoch) divided by the baseline
    probability of that category.

    INPUT: trial-based data, variables dictionary, the name of the counted
        column (e.g. "response_type"), and a dict of baseline probabilities.
        When *probability_dict* is None, the response-type baselines are used:
        LD_CORRECT 0.15625, LD_ERROR 0.09375, NLD_ERROR 0.65625,
        NLD_CORRECT 0.09375.
    OUTPUT: the absolutely beautiful likelihood ratio data (wide)
    """
    # build the default per call instead of sharing one mutable default dict
    if probability_dict is None:
        probability_dict = {"LD_CORRECT": 0.15625, "LD_ERROR": 0.09375, "NLD_ERROR": 0.65625, "NLD_CORRECT": 0.09375}

    dat_resp_type, dat_epochtrial = count_values_epochwise(trialdata, variables_dict, count_variable)

    percdata = get_percentage_data(dat_resp_type, dat_epochtrial, variables_dict)

    LLR_data = divide_cols(percdata, variables_dict, probability_dict)

    return LLR_data

#updates

def was_there_an_update(tri_dict, row, lastAOI):
    """
    Decide whether the response to the current triplet differs from the
    response given at the triplet's previous occurrence.

    tri_dict maps triplet strings to [response, response_type, index] (as
    filled by calculate_update_vars); *row* must provide "current_triplet"
    and the *lastAOI* column.

    Returns False if the response equals the stored one, pd.NA if the
    current response is -1 (and differs from the stored one), True otherwise.
    Because the equality check runs first, two consecutive -1 responses give False.
    Raises ValueError if the triplet has not been seen before.
    """
    current = str(row["current_triplet"])
    if current not in tri_dict:
        raise ValueError(f'Man, we dont have any {current} in the triplet dictionary, check the is_prior_updated_function or dunno')

    response = row[lastAOI]
    if tri_dict[current][0] == response:
        # there was no update
        return False
    if response == -1:
        return pd.NA
    return True
    
def update_type(row, update_type, prev_resp):
    """
    Classify how the response to a repeated triplet relates to the previous one.

    update_type False: "NLD_SAME" if the current response_type contains "NLD",
        otherwise "LD_SAME".
    update_type True: "MOD_CORR" if the current response is LD_CORRECT/LD_ERROR;
        otherwise "MOD_DISR" if the previous response (prev_resp) was LD,
        else "EXPL".
    Anything that is not True/False (pd.NA, NaN, None, ...) gives "".
    """
    # `pd.NA in (True, False)` raises TypeError (NA's truth value is ambiguous),
    # and was_there_an_update can return pd.NA, so screen it out by identity first
    if update_type is pd.NA or update_type not in (True, False):
        return ""
    if not update_type:
        return "NLD_SAME" if "NLD" in row["response_type"] else "LD_SAME"
    if row["response_type"] in ("LD_CORRECT", "LD_ERROR"):
        return "MOD_CORR"
    # cuz it cannot be LD AND update unless the previous response was NLD
    return "MOD_DISR" if prev_resp in ("LD_CORRECT", "LD_ERROR") else "EXPL"
        
def calculate_update_vars(dat_with_responses, variables_dict):
    """
    Describe, trial by trial, how the response to a triplet changed since the
    same subject's previous occurrence of that triplet.

    INPUT: unfiltered trial-by-trial response data (needs "current_triplet",
        "response_type", the ID column and the response column named in
        variables_dict); variables dictionary ("ID", "response")
    OUTPUT: a copy of the data with four added columns:
        update               - was_there_an_update() result (pd.NA when not applicable)
        trials_since_triplet - difference of DataFrame index labels since the
                               previous occurrence of the triplet (NaN when not applicable)
        update_type          - update_type() result, or 'none' / 'first'
        prev_resp            - response type at the previous occurrence, or 'none' / 'first'
    The same four columns are also written onto *dat_with_responses* itself.
    The triplet memory is reset (and a line printed) whenever the subject ID changes.
    NOTE(review): trials_since_triplet counts index labels, so rows removed
    before this call are still counted in the gap; confirm this is intended.
    """
    id_column = variables_dict["ID"]
    response_column = variables_dict["response"]

    update_flags, gaps, kinds, previous_kinds = [], [], [], []

    def _record(flag, gap, kind, previous_kind):
        update_flags.append(flag)
        gaps.append(gap)
        kinds.append(kind)
        previous_kinds.append(previous_kind)

    seen = {}  # triplet -> [response, response_type, index label] at its last occurrence
    subject = ''

    for label, row in dat_with_responses.iterrows():
        if str(subject) != str(row[id_column]):
            seen = {}
            print(f'current_subject {subject} updated to {row[id_column]}')
            subject = str(row[id_column])

        triplet_key = str(row["current_triplet"])
        response = row[response_column]

        if len(triplet_key) < 3:
            _record(pd.NA, np.nan, 'none', 'none')
            continue

        if triplet_key not in seen:
            seen[triplet_key] = [response, row['response_type'], label]
            _record(pd.NA, np.nan, 'first', 'first')
            continue

        earlier_response, earlier_kind, earlier_label = seen[triplet_key]
        changed = was_there_an_update(seen, row, response_column)
        _record(changed, label - earlier_label, update_type(row, changed, earlier_kind), earlier_kind)
        seen[triplet_key] = [response, row['response_type'], label]

    dat_with_responses['update'] = update_flags
    dat_with_responses['trials_since_triplet'] = gaps
    dat_with_responses['update_type'] = kinds
    dat_with_responses['prev_resp'] = previous_kinds

    return dat_with_responses.copy()




TODO: reading the settings file

## create folder tree

Create the required folders if they don't exist already, and save their paths.

In [12]:

# List of folder names to be created (all inside the current working directory)
folders = [
    "indat", 
    "trialdata", 
    "ext_trialdata", 
    "DATAQUAL_missing_data_ratio", 
    "DATAQUAL_distance", 
    "DATAQUAL_E2E", 
    "DATAQUAL_S2S", 
    "anticipatory_eye_movements", 
    "statistical_learning", 
    "interference",
    "output"
]
# Create a dictionary to store the paths, keyed "<folder>_path"
folder_paths = {}

# Get the current working directory
current_directory = os.getcwd()

# Loop through the folders, create the missing ones, and store their paths
for folder in folders:
    folder_path = os.path.join(current_directory, folder)
    folder_paths[f"{folder}_path"] = folder_path
    if not os.path.exists(folder_path):
        os.makedirs(folder_path)
        print(f"Folder '{folder}' created at: {folder_path}")
    
    else:
        print(f"Folder '{folder}' already exists.")

# Expose every path as a notebook-level variable (e.g. indat_path, output_path);
# the later cells rely on these names
globals().update(folder_paths)


Folder 'indat' already exists.
Folder 'trialdata' already exists.
Folder 'ext_trialdata' already exists.
Folder 'DATAQUAL_missing_data_ratio' already exists.
Folder 'DATAQUAL_distance' already exists.
Folder 'DATAQUAL_E2E' already exists.
Folder 'DATAQUAL_S2S' already exists.
Folder 'anticipatory_eye_movements' already exists.
Folder 'statistical_learning' already exists.
Folder 'interference' already exists.
Folder 'output' already exists.


## Put the input files in the "indat" folder we just created

(subject-by-subject, created by PsychoPy, name structure as in "subject_##__log") 

## Your design 
Tell me about your design. <br>
In the future, this should be read from the settings file, but the thing is, I'm a lazy but time-efficient bastard, so bear with me.

In [184]:
## we will use this vars dictionary for the data preparation phase.
# Most string values are column names of the trial data; 'H', 'L', 'rando'
# and 'notSeq' are label values that appear inside those columns.

vars_dict = {
    # number of epochs in the files
    "epochN": 8,
    # number of blocks in one epoch
    "blocks": 5,
    # number of practice random trials at the start of each block
    "randoms": 5,
    # number of sequential trials in a block
    "trialN": 80,
    # if you use the Tobii-optimized ASRT, this is always 12
    "fixation_threshold": 12,
    # column with the location of the appearing stimulus
    'stim': 'stimulus',
    # column with the high/low (H/L) triplet type
    'triplet_type': 'triplet_type_hl',
    # column with the participant's response (e.g., last AOI in eye-tracking data, or the button pressed in the motor version)
    'response': 'last_AOI_before_stimulus',
    # label marking the random trials at the beginning of blocks (response_type 'none'); columns containing it are skipped by divide_cols
    'rando': 'none',
    # the label for high triplets
    'H': 'high',
    # the label for low triplets
    'L': 'low',
    # column holding the sequence as a string (e.g., '1234')
    'PCode': 'PCode',
    # column with the original sequence's code, used for interference; if you have no interference, change to None
    'original_PCode': 'original_PCode',
    # time-unit column (epoch or block)
    'epoch': 'epoch',
    # value of the PCode column that marks non-pattern blocks
    'notSeq': 'noPattern',
    # column with the trial number within the block
    'trial': 'trial',
    # column with the response type (LD_CORRECT, LD_ERROR, NLD_CORRECT, NLD_ERROR)
    'response_type': 'response_type',
    # dependent variable, typically RT
    'dep': 'RT',
    # triplet type column
    'TT': 'triplet_type_hl',
    # subject ID column
    'ID': 'subject_number',
    # first sequential epoch number of the original sequence (OS)
    'OSepoch': 6,
    # IFepoch: your first interference epoch; if there is no interference epoch, write None (several functions check for None)
    'IFepoch': 1,
    # list of epochs with seqA
    "seqA_list": [6,8],
    # list of epochs with seqB (if none, use an empty list; process_IF_epochs then skips the HLLH scores)
    "seqB_list": [1,2,3,4,5,7],
    # column with the high/low status based on the original sequence (OS)
    "HL_OS_var": "high_low_learning"
}



# Set of subject IDs with bad data (passed to subject_dropper and clean_and_organize)
subs_with_bad_data = {47}


## copy here any valid file name from the indat folder; it is used to test the settings
testfile = "subject_87__log.txt"

## Data validity

Let's check if your data looks ok and throw out some bad data so you don't get fcked later on. <br>

This section requires you to fill the subs_with_bad_data set (defined in the cell above) with the subjects you want to drop!!

### data type validation

In [21]:
# Define paths
# NOTE(review): junk_folder is created here but not passed to file_validator,
# which is defined elsewhere in this file; confirm how it is used
junk_folder = os.path.join(os.getcwd(), "junk")
os.makedirs(junk_folder, exist_ok=True)  # Ensure junk folder exists
 
# Checks the file formats in the indat folder (see the printed message below)
file_validator(indat_path)

All file formats are valid in the 'indat' folder.


### is vars_dict correct

Checking whether you told the truth, i.e. whether vars_dict matches your data.

In [25]:
# Both checks read the example file named in `testfile` from the indat folder
check_variables_dict(indat_path + "/" + testfile, vars_dict)
data_structure_validator(vars_dict, indat_path + "/" + testfile)


All variables in 'variables_dict' are present in the DataFrame columns.
Using the file C:\Users\porsi\BML-MEMO LAB Dropbox\bml memo members\Orsi_Pesthy\everuseful\ET_py\NEW_IMPROVED_ET_analyser\indat/subject_87__log.txt
Your nr of epochs match with what you specified in the vars_dict.
The specified nr of random and sequential trials match the data structure.
Aaand yepp, the interference and original sequence epochs indeed differ. Good. You are all set.


### drop wrong datafiles

In [29]:
# subject_dropper is defined elsewhere in this file; judging from its output, it checks epoch counts and reports the bad-subject set
subject_dropper(indat_path, subs_with_bad_data, vars_dict)

All good, all subjects seem to have sufficient amount of epochs
bad sub list is: {47}


### stimulus col validator
 If it warns you about the getAOI function, you have to fix the coordinates belonging to each stimulus

In [32]:
# Checks the stimulus coordinates against the getAOI coding (see the printed summary below)
test_gaze_coordinate_coding(indat_path)

Processing file: subject_100__log.txt


  df = df.applymap(lambda x: str(x).replace(",", ".") if isinstance(x, str) else x)


All stimuli met expectations, your getAOI function functions.


{3: (0.36, 0.74), 1: (0.36, 0.25), 2: (0.63, 0.25), 4: (0.64, 0.74)}

# (Now you can just mindlessly run all the remaining cells; the following code doesn't require your attention — at least it shouldn't.)

# data quality measures

## missing data ratio

doublechecked, it works. 

15012025 update: triplechecked, F, C, O yeee


In [23]:
# Output: DATAQUAL_missing_data_ratio/missing_data.csv
computeMissingDataRatio(indat_path, DATAQUAL_missing_data_ratio_path+"/missing_data.csv", vars_dict)

Compute missing data ratio for subject: 100
Compute missing data ratio for subject: 106
Compute missing data ratio for subject: 75
Compute missing data ratio for subject: 87


## distance 

In [24]:
# Output: DATAQUAL_distance/distance.csv
computeDistance(indat_path, DATAQUAL_distance_path+"/distance.csv", vars_dict)

Compute eye-screen distance data for subject: 100
Compute eye-screen distance data for subject: 106
Compute eye-screen distance data for subject: 75
Compute eye-screen distance data for subject: 87


## compute RMS_E2E

doublechecked, works

In [25]:
# Output: DATAQUAL_E2E/E2E.csv
computeRMSEyeToEye(indat_path, DATAQUAL_E2E_path+"/E2E.csv" , vars_dict)

Compute RMS(E2E) for subject:  100
Compute RMS(E2E) for subject:  106
Compute RMS(E2E) for subject:  75
Compute RMS(E2E) for subject:  87


## S2S

In [27]:
# Output: DATAQUAL_S2S/S2S.csv
computeRMSSampleToSample(indat_path, DATAQUAL_S2S_path+ "/S2S.csv", vars_dict)

Compute RMS(S2S) for subject: 100
Compute RMS(S2S) for subject: 106
Compute RMS(S2S) for subject: 75
Compute RMS(S2S) for subject: 87


## compute trial-level data

You will find the trial-level data in the trialdata folder created earlier. <br>
This will take a while, especially if you have a lot of data. Make a coffee in the meantime.

In [26]:
# Reads the raw files from indat and writes the trial-level files into trialdata
computeTrialLevelData(indat_path, trialdata_path)

Compute trial level data for subject: 100
Compute trial level data for subject: 106
Compute trial level data for subject: 75
Compute trial level data for subject: 87


TODO: trial-level data validation; fix file saving so that the files are not all written only at the end.

## extend trial level data

In [16]:
# Reads every file in trialdata and writes subject_<id>__trial_extended_log.csv files into ext_trialdata
extendTrialLevelData(trialdata_path, ext_trialdata_path, vars_dict)

Extend trial level data with additional fields for subject: 100
Extend trial level data with additional fields for subject: 106
Extend trial level data with additional fields for subject: 75
Extend trial level data with additional fields for subject: 87


# Chapter 2: creating trial-level datasets to use

## cleanup

In [18]:
## which cols do you want to keep?
# clean_and_organize keeps only these columns; if any of them is missing it
# prints the missing names and keeps all columns instead.
# 'has_learnt_anticipation_IF' only exists when vars_dict['IFepoch'] is not None.

list_cols = ['subject_number', 'PCode', 'epoch', 'block', 'trial','triplet_type_hl', 'stimulus', 'RT',
       'last_AOI_before_stimulus', 'repetition', 'trill', 'high_low_learning',
       'has_anticipation', 'has_learnt_anticipation_OS', 'has_learnt_anticipation_IF']



In [20]:
# Concatenate all extended trial files, keep list_cols, make RT numeric and handle subs_with_bad_data
raw_dat_clean = clean_and_organize(list_cols, vars_dict, subs_with_bad_data, ext_trialdata_path)


In [22]:
# Save the cleaned trial-level data (comma-delimited, index included)
raw_dat_clean.to_csv(output_path + "/trialdat_raw.csv")

## create an analyzable file

In [72]:
# Prepare the trial data for the high/low LMM (see get_trialdat_ready_for_HL_LMM for the exact filtering steps)
trialdat_HL_LMM= get_trialdat_ready_for_HL_LMM(raw_dat_clean, vars_dict)

In [171]:
# Save the LMM-ready trial data
trialdat_HL_LMM.to_csv(output_path + "/trialdat_HL_LMM.csv")

## use this latter one to calculate wide HLdiff data

In [200]:
#step1: create wide data
# median RT per subject for every epoch x high/low (original-sequence) cell, e.g. column '1high'
wide_RT = simple_pivot(trialdat_HL_LMM, vars_dict["dep"], vars_dict["ID"], [vars_dict["epoch"], vars_dict["HL_OS_var"]], aggfunct = "median")


In [188]:
# Quick look at the first rows of the wide median-RT table
wide_RT.head()

Unnamed: 0,subject_number,1high,1low,2high,2low,3high,3low,4high,4low,5high,5low,6high,6low,7high,7low,8high,8low
0,75,275.0405,274.983,274.916,275.093,275.043,275.891,283.2725,274.997,266.704,274.93,274.895,266.774,275.065,274.987,274.954,283.287
1,87,325.068,308.3715,325.045,308.442,333.346,308.402,341.682,324.893,333.376,308.3985,308.477,308.391,308.499,316.733,324.988,316.884
2,100,299.958,291.654,299.979,308.3835,304.1035,291.677,283.3075,291.687,291.85,291.683,283.43,291.606,291.724,291.6875,283.356,283.3255
3,106,308.428,324.9765,316.631,316.887,308.35,324.9565,308.343,324.9065,312.3885,325.033,324.933,324.97,315.991,333.301,316.718,325.057


In [192]:
#step2: calc HLdiff
# adds TRI_<epoch> columns (low - high RT) to wide_RT in place and returns the same table
wide_RT_HLdiff = calcHighLowDiff(wide_RT, vars_dict, is_RT = True)

In [206]:
# LD anticipations: ratio of learning-dependent anticipatory eye movements (LDAEM)
# to all anticipatory eye movements (AEM), per subject and epoch

# step1: proportion of trials with a learnt (LD) anticipation
wide_LDAEM = simple_pivot(trialdat_HL_LMM, "has_learnt_anticipation_OS_int", vars_dict["ID"], vars_dict["epoch"], aggfunct = "mean")
# step2: proportion of trials with any anticipation
wide_AEM = simple_pivot(trialdat_HL_LMM, "has_anticipation_int", vars_dict["ID"], vars_dict["epoch"], aggfunct = "mean")
# step3: their ratio. Index both tables by subject ID first so that the ID
# column itself is not divided (it would become 1.0) and the rows are aligned
# by subject rather than by integer row label. A subject with no anticipations
# in an epoch gives NaN/inf in that cell.
wide_LDAR = (
    wide_LDAEM.set_index(vars_dict["ID"]) / wide_AEM.set_index(vars_dict["ID"])
).reset_index()

In [214]:
# Save the wide LD-anticipation ratio and high-low RT difference tables
wide_LDAR.to_csv(output_path + "/wide_LDAR.csv")
wide_RT_HLdiff.to_csv(output_path + "/wide_RT_HLdiff.csv")


## creating HLLH interference measure col

In [142]:
# Trials of the interference (seqB) epochs with HLLH_RT / HLLH_LDAEM labels; None if vars_dict['seqB_list'] is empty
dat_IF_epochs = process_IF_epochs(raw_dat_clean, vars_dict)

## compute response types

In [81]:
# Adds response_type and current_triplet columns (non-sequential epochs are removed first)
dat_resp_all = response_data_calculator(raw_dat_clean, vars_dict)


In [82]:
# Remove the practice random trials at the start of each block
dat_resp_all = drop_blockstarting_randoms(dat_resp_all, vars_dict)

In [152]:
# Save the trial-level response-type data
dat_resp_all.to_csv(output_path + "/dat_resp_all.csv")

## likelihood calculations

In [86]:
# Per subject x epoch: proportion of each response type divided by its baseline probability
resp_likelihood_wide = likelihood_generator(dat_resp_all, vars_dict, "response_type")

In [154]:
# Save the wide response-type likelihood ratios
resp_likelihood_wide.to_csv(output_path + "/resp_likelihood_wide.csv")

## update calculation

### update column (trial-based)

In [92]:
# Adds update, trials_since_triplet, update_type and prev_resp columns (the printed lines show subject switches)
dat_update_all = calculate_update_vars(dat_resp_all, vars_dict)

current_subject  updated to 100
current_subject 100 updated to 106
current_subject 106 updated to 75
current_subject 75 updated to 87


In [158]:
# Drop each triplet's first occurrence (there is no earlier response to compare with)
dat_update_all_firstdropped = dat_update_all.loc[dat_update_all["update_type"] != "first"].copy()

In [160]:
# Save into the output folder like every other result table (not the working directory)
dat_update_all_firstdropped.to_csv(output_path + "/dat_update_resptype_LMM.csv")

### update type likelihood calculation (wide)

In [166]:
# Same likelihood computation as for response types, using the update-type baseline probabilities
update_likelihood_wide_dropped_first = likelihood_generator(dat_update_all_firstdropped, vars_dict, "update_type", probability_dict = {"NLD_SAME": 0.1875, "LD_SAME": 0.0625, "EXPL": 0.375, "MOD_CORR": 0.1875, "MOD_DISR": 0.1875})

In [168]:
# Save the wide update-type likelihood ratios
update_likelihood_wide_dropped_first.to_csv(output_path + "/update_likelihood_wide.csv")