In [None]:
import pandas as pd
import blinklab_python_sdk.functions as sdk
import numpy as np
import matplotlib.pyplot as plt
import warnings
# NOTE(review): with no category or module filter this silences every warning
# for the whole kernel, including NumPy/pandas deprecation and parse warnings
# raised by the cells below — consider narrowing it to specific categories.
warnings.filterwarnings('ignore')

In [None]:
def convert_string_to_array(string):
    """Parse a bracketed, comma-separated string into a 1-D NumPy array.

    Args:
        string: Text such as ``"[0.1, 0.2, 0.3]"``. Square brackets and
            whitespace at either end are removed before parsing.

    Returns:
        The ``numpy.ndarray`` produced by ``np.fromstring`` with ``','`` as
        the separator (default float dtype).
    """
    # Strip brackets and whitespace in one pass. ``str.strip("[]")`` alone
    # stops at the first character that is not a bracket, so a stray space or
    # newline around the brackets would leave a bracket in place and make the
    # numeric parse fail.
    inner = string.strip("[] \t\r\n")
    return np.fromstring(inner, sep=',')


def build_x_trace(length: int, frame_rate_hz: float = 60.0):
    """Build the time axis, in milliseconds, for a trace of ``length`` frames.

    Args:
        length: Number of samples (frames) in the trace.
        frame_rate_hz: Frames per second; defaults to 60.

    Returns:
        A ``numpy.ndarray`` with exactly ``length`` elements, starting at 0 and
        spaced one frame duration apart.
    """
    ms_per_frame = 1 / frame_rate_hz * 1000  # milliseconds per frame
    # Scale an integer index instead of calling np.arange with a float step:
    # with a non-integer step the element count is derived from a rounded
    # division and can come out as length + 1, which would no longer match the
    # y data it is plotted against.
    return np.arange(length) * ms_per_frame


def plot_median(data, y_column, split_by, y_min=-0.1, y_max=1, show_mean=False):
    """Plot all traces of ``y_column`` plus one median curve per group.

    Rows where ``y_column`` is null are dropped; the remaining values are
    parsed with ``convert_string_to_array``. Traces shorter than the longest
    one are right-padded with NaN so they can be stacked, and the per-sample
    median ignores those NaNs.

    Args:
        data: DataFrame containing ``y_column`` and ``split_by``. It is not
            modified. If it has a ``'label'`` column, the first label of each
            group is used in the legend; otherwise the group key is used.
        y_column: Name of the column holding the string-encoded traces.
        split_by: Column (or columns) passed to ``DataFrame.groupby``.
        y_min: Lower y-axis limit.
        y_max: Upper y-axis limit.
        show_mean: When True, also draw the per-sample mean of each group.

    Raises:
        ValueError: If ``y_column`` has no non-null values.
    """
    # Validate and prepare the data before creating the figure so a failure
    # does not leave an empty figure behind.
    plot_df = data.dropna(subset=[y_column]).copy()
    if plot_df.empty:
        raise ValueError(f"No non-null values in column {y_column!r} to plot")
    plot_df[y_column] = plot_df[y_column].apply(convert_string_to_array)

    max_length = int(plot_df[y_column].apply(len).max())
    x_trace = build_x_trace(max_length)

    fig, ax = plt.subplots(figsize=(20, 4))
    ax.set_ylim(y_min, y_max)
    ax.set_title(y_column)
    ax.set_xlabel('Time (ms)')

    for unique_value, unique_df in plot_df.groupby(split_by):
        if unique_df.empty:
            continue

        # One row per trace, NaN wherever a trace is shorter than the longest.
        stacked = np.full((len(unique_df), max_length), np.nan)
        for row, trace in enumerate(unique_df[y_column]):
            stacked[row, :len(trace)] = trace

        label = unique_df['label'].iloc[0] if 'label' in unique_df.columns else unique_value

        # Individual traces sit below the summary curves (zorder) so a later
        # group's gray lines cannot hide an earlier group's median.
        ax.plot(x_trace, stacked.T, color='lightgray', alpha=0.6, zorder=1)

        with warnings.catch_warnings():
            # Sample positions beyond every trace of this group are all-NaN;
            # the resulting NaN is intended, the RuntimeWarning is just noise.
            warnings.simplefilter('ignore', category=RuntimeWarning)
            median_values = np.nanmedian(stacked, axis=0)
            mean_values = np.nanmean(stacked, axis=0) if show_mean else None

        ax.plot(x_trace, median_values, label=f'Median {label}', linewidth=2)
        if show_mean:
            ax.plot(x_trace, mean_values, label=f'Mean {label}', linewidth=2)

    ax.legend()
    plt.show()

In [None]:
target = "has_adhd"

# Per-result metadata; only the id, diagnosis and experiment columns are kept.
metadata = pd.read_csv("data/group/metadata.csv")[["experiment_result_id", "diagnosis_label", "experiment_id"]]

# Trace rows, each given a 'label' derived by sdk.make_label from its
# 'proto_trial_content' value.
traces = pd.read_csv("data/group/overview_traces.csv")
traces['label'] = traces['proto_trial_content'].apply(sdk.make_label)

# Keep only rows whose diagnosis_label contains 'ADHD' or 'Neurotypical'.
# na=False treats a missing label as "no match"; without it the mask contains
# NaN and boolean indexing raises ValueError.
metadata = metadata[metadata["diagnosis_label"].str.contains("ADHD|Neurotypical", na=False)]

print("Unique result ids ", metadata["experiment_result_id"].nunique())
print("Unique diagnosis labels ", metadata["diagnosis_label"].nunique())

# One row per experiment result; the target is 1 when the label mentions
# 'ADHD' and 0 otherwise (after the filter above, that means 'Neurotypical').
metadata = (
    metadata
    .drop_duplicates(subset="experiment_result_id")
    .assign(**{target: lambda df: df["diagnosis_label"].str.contains("ADHD", regex=False).astype(int)})
    .drop(columns=["diagnosis_label"])
)
data = metadata.merge(traces, on="experiment_result_id")

print(data[target].value_counts())

In [None]:
# Preview the first rows of the merged metadata/trace table.
data.head()

In [None]:
# Plot the 'baselineMedian' traces grouped by 'proto_trial_hash', with the
# y-axis limited to [-0.1, 0.5].
plot_median(data, 'baselineMedian', 'proto_trial_hash', y_min=-0.1, y_max=0.5)