In [None]:
import os

from utils import *

# Run training and validation

Run the bash script: `micromamba run -n IV25 bash run_experiments.sh`. 

This runs all of the model training and evaluation for the paper sequentially on a single GPU and saves the results to `~/wandb`.

The experiments use 5-fold cross-validation repeated with 5 different random seeds, so they take a while to complete.

In [None]:
# Set the wandb results path. "~" is only expanded by the shell, not by Python
# file APIs, so expand it here; expanding it again downstream is harmless.
wandb_path = os.path.expanduser("~/wandb/")
# All experiments are repeated with 5 random seeds. Seeds are kept as strings:
# they are used as dict keys and inside experiment names such as "lite1-all-split".
random_seed_list = [str(i) for i in range(1, 6)]

# Analyze results

### CNN vs Foundation Model

In [None]:
# For every seed, collect the LR-max scores of the MOMENT model and the CNN
# baseline, both from the "all" training configuration.
d = {}
for seed in random_seed_list:
    d[seed] = {}
    for model_key, exp_name in (
        ("moment_all", f"lite{seed}-all-split"),
        ("cnn_all", f"lite{seed}-all-cnn-split"),
    ):
        results = aggregate_results(exp_name, wandb_path)
        d[seed][model_key] = get_LR_max_avg_scores(results)

def aggregate_seed_runs(d):
    """Pool per-seed scores and summarise them across seeds.

    Args:
        d: nested mapping ``{seed: {experiment: {metric: {test: score}}}}``.

    Returns:
        Nested dict ``{experiment: {metric: {test: {"results": [...],
        "mean": ..., "std": ...}}}}``. ``results`` lists the score from every
        seed in iteration order; ``mean`` / ``std`` are ``np.mean`` /
        ``np.std`` (population std, ddof=0) of that list. Keys appear in the
        order they are first seen in ``d``.
    """
    agg = {}
    # First pass: gather every seed's score under (experiment, metric, test).
    for per_exp in d.values():
        for exp, per_metric in per_exp.items():
            exp_bucket = agg.setdefault(exp, {})
            for metric, per_test in per_metric.items():
                metric_bucket = exp_bucket.setdefault(metric, {})
                for test, score in per_test.items():
                    metric_bucket.setdefault(test, {"results": []})["results"].append(score)

    # Second pass: attach the across-seed summary statistics.
    for exp_bucket in agg.values():
        for metric_bucket in exp_bucket.values():
            for stats in metric_bucket.values():
                stats["mean"] = np.mean(stats["results"])
                stats["std"] = np.std(stats["results"])
    return agg

agg = aggregate_seed_runs(d)

# Summarise each model by pooling every (test split, seed) score into a single
# distribution and reporting its overall mean and population std.
for metric in ["F1", "BAcc"]:
    print(f"** Mean and std {metric}")
    for model_name, exp_key in (("MOMENT", "moment_all"), ("CNN", "cnn_all")):
        per_test = agg[exp_key][metric]
        pooled = np.array([per_test[test]['results'] for test in per_test.keys()]).flatten()
        print(f"{model_name}: {pooled.mean():0.2f} mean, {pooled.std():0.2f} std")


### Figure 7(a+b) - F1 and BAcc matrix

In [None]:
# Training-data sources: each single task on its own, plus all tasks pooled
# ("all"). Every model is scored on every task's test split.
metric_keys = ["F1", "BAcc"]
split_keys = ["gaze_tracking", "fixed_gaze", "silent_reading", "choice_reaction"]

d = {}
for seed in random_seed_list:
    d[seed] = {}
    for split in split_keys:
        results = aggregate_results(f"lite{seed}-{split}", wandb_path)
        d[seed][split] = get_LR_max_avg_scores(results)
    results = aggregate_results(f"lite{seed}-all-split", wandb_path)
    d[seed]['all'] = get_LR_max_avg_scores(results)

agg = aggregate_seed_runs(d)

# Heatmap rows follow the insertion order of `agg` (the four tasks, then "all").
data_keys = list(agg.keys())

set_plot_style(font_size=16, grid_line_width="0.0")

for metric in metric_keys:
    # (training data x test split) grids of the across-seed mean and std.
    mean_data = np.array([[agg[train][metric][test]["mean"] for test in split_keys] for train in data_keys])
    std_data = np.array([[agg[train][metric][test]["std"] for test in split_keys] for train in data_keys])

    fig, ax = plt.subplots(figsize=(6, 4))
    ax.imshow(mean_data, cmap='viridis', aspect='auto')
    ax.set_xticks(np.arange(len(split_keys)))
    ax.set_xticklabels([LABEL_TRANSLATE[s] for s in split_keys])
    ax.set_yticks(np.arange(len(data_keys)))
    ax.set_yticklabels([LABEL_TRANSLATE[s] for s in data_keys])
    ax.set_xlabel("Test")
    ax.set_ylabel("Training data")

    # Write "mean ± std" in each cell; low (dark in viridis) cells get white text.
    for row, row_means in enumerate(mean_data):
        for col, value in enumerate(row_means):
            ax.text(col, row, f"{value:.2f} ± {std_data[row, col]:.2f}", fontsize=12,
                    ha='center', va='center', color="white" if value < 0.5 else "black")

    fig.savefig(f"../docs/figures/train_vs_test_{metric}_LRmax.pdf", bbox_inches='tight')
    print(metric)
    plt.show()

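### Figure 7c - per-subject analysis
For a single random seed (seed 1), average each participant's score over all training-data/test-split combinations and plot the mean F1 per participant as a bar chart, sorted from highest to lowest.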

In [None]:
def get_avg_scores_per_pid(results):
    """Compute per-participant F1, accuracy and balanced accuracy.

    For every experiment in ``results`` and every participant key starting
    with ``"P"``, the flat ``labels`` / ``predicted`` sequences are grouped
    into consecutive pairs (left / right eye) and each pair is reduced with
    ``max``. A window therefore counts as positive if either eye is positive.
    Scores are computed from the binary confusion counts of the reduced
    sequences.

    Args:
        results: ``{exp: {pid: {"labels": [...], "predicted": [...]}, ...}}``.
            Keys of ``results[exp]`` that do not start with ``"P"`` are
            skipped. Labels/predictions are assumed binary (0/1 or bool);
            nonzero is treated as positive. TODO confirm against the producer.

    Returns:
        ``{"F1" | "Acc" | "BAcc": {exp: {pid: float}}}``.

    Raises:
        AssertionError: if a participant has an odd number of labels, or a
            different number of labels and predictions.

    Edge cases:
        * F1 is 0.0 when undefined (no positive labels and no positive
          predictions), matching scikit-learn's default ``zero_division``.
        * Acc is NaN for a participant with no samples.
        * In BAcc, a class with no support contributes 0 to its half of the
          average (e.g. an all-negative participant predicted perfectly gets
          BAcc 0.5).
    """
    scores = {'F1': {}, 'Acc': {}, 'BAcc': {}}
    for exp in results.keys():
        for metric_scores in scores.values():
            metric_scores[exp] = {}
        for pid in results[exp]:
            if not pid.startswith("P"):
                continue
            labels = np.asarray(results[exp][pid]['labels'])
            predicted = np.asarray(results[exp][pid]['predicted'])

            assert len(labels) % 2 == 0, f"{exp}/{pid}: odd number of labels ({len(labels)})"
            assert len(labels) == len(predicted), (
                f"{exp}/{pid}: {len(labels)} labels vs {len(predicted)} predictions"
            )

            # Reshape to (*, 2) and take max across L & R eyes.
            y_true = np.max(labels.reshape(-1, 2), axis=1).astype(bool)
            y_pred = np.max(predicted.reshape(-1, 2), axis=1).astype(bool)

            # Count the binary confusion cells directly: this always yields all
            # four counts, even when only one class is present, so no
            # special-casing or exception handling is needed.
            tp = int(np.sum(y_true & y_pred))
            tn = int(np.sum(~y_true & ~y_pred))
            fp = int(np.sum(~y_true & y_pred))
            fn = int(np.sum(y_true & ~y_pred))

            # F1 is undefined when 2*tp + fp + fn == 0; report 0.0 instead of
            # raising ZeroDivisionError.
            f1_den = 2 * tp + fp + fn
            scores['F1'][exp][pid] = (2 * tp) / f1_den if f1_den else 0.0

            n_total = tp + tn + fp + fn
            scores['Acc'][exp][pid] = (tp + tn) / n_total if n_total else float("nan")

            # Balanced accuracy = mean of TPR and TNR; a rate whose class has
            # no support contributes 0.
            tpr = tp / (tp + fn) if tp + fn else 0.0
            tnr = tn / (tn + fp) if tn + fp else 0.0
            scores['BAcc'][exp][pid] = 0.5 * tpr + 0.5 * tnr
    return scores

In [None]:
# Per-participant scores for seed 1 only: one entry per single-task training
# set, plus the pooled "all" training set.
d = {}
for task in ["gaze_tracking", "fixed_gaze", "silent_reading", "choice_reaction"]:
    results = aggregate_results(f"lite1-{task}", wandb_path)
    d[task] = get_avg_scores_per_pid(results)
results = aggregate_results("lite1-all-split", wandb_path)
d['all'] = get_avg_scores_per_pid(results)

In [None]:
# Participant-level summary for seed 1: average each participant's score over
# every (training data, test split) combination.
data_keys = list(d.keys())
metric_keys = ["F1", "BAcc"]
split_keys = ["gaze_tracking", "fixed_gaze", "silent_reading", "choice_reaction"]

# Participant IDs come from `d` itself, not from a `results` variable left over
# by an earlier cell, so this cell is self-contained. get_avg_scores_per_pid
# only stores keys starting with "P", so no extra filtering is needed.
pids = list(d['all'][metric_keys[0]][split_keys[0]].keys())

mean_metrics = {}
for metric in metric_keys:
    # Scores ordered by training data, then by test split.
    results_by_pid = {
        pid: [d[train][metric][test][pid] for train in data_keys for test in split_keys]
        for pid in pids
    }
    mean_metrics[metric] = [np.mean(results_by_pid[pid]) for pid in pids]

In [None]:
# Bar chart of each participant's mean F1 (seed 1), sorted best to worst.
f1_sorted = sorted(mean_metrics["F1"], reverse=True)
fig, ax = plt.subplots(figsize=(6, 4))
ax.bar(range(len(f1_sorted)), f1_sorted, zorder=2)
ax.set_xlabel("Sorted Participants")
ax.set_ylabel("F1 Score")
ax.set_ylim((0.5, 0.9))
set_plot_style(font_size=14, grid_line_width="2.0")
ax.grid(True, zorder=1)
fig.savefig("../docs/figures/F1_per_participant.pdf", bbox_inches='tight')
plt.show()

### Figure 8a - effect of changing sampling rate S

In [None]:

# Figure 8a: models trained on all data at progressively lower input sampling
# rates (Hz). 60 Hz is the default "all-split" run; the others are resampled.
rate_suffixes = {"60": "", "30": "-resample30", "20": "-resample20", "10": "-resample10"}
d = {}
for seed in random_seed_list:
    d[seed] = {}
    for rate, suffix in rate_suffixes.items():
        results = aggregate_results(f"lite{seed}-all{suffix}-split", wandb_path)
        d[seed][rate] = get_LR_max_avg_scores(results)
agg = aggregate_seed_runs(d)

# x-axis categories (sampling rates as strings, highest first) and their positions.
x = list(rate_suffixes)
x_int = list(range(len(x)))

# Across-seed mean/std on the choice-reaction (CR) test split, one entry per rate.
F1 = {"choice_reaction": {"mean": [], "std": []}}
BAcc = {"choice_reaction": {"mean": [], "std": []}}
for rate in x:
    for stats, metric_name in ((F1, "F1"), (BAcc, "BAcc")):
        for stat in ("mean", "std"):
            stats["choice_reaction"][stat].append(agg[rate][metric_name]["choice_reaction"][stat])

plt.figure(figsize=(6, 4))

def plot_with_bounds(x, values, stds, label, color=None):
    """Plot ``values`` against ``x`` with a shaded ``values ± stds`` band.

    Draws on the current pyplot axes: a dotted-marker line labelled ``label``,
    plus a translucent (alpha 0.2) band between ``value - std`` and
    ``value + std`` in the same ``color``.
    """
    lower, upper = [], []
    for value, spread in zip(values, stds):
        lower.append(value - spread)
        upper.append(value + spread)
    plt.plot(x, values, '.-', label=label, color=color)
    plt.fill_between(x, lower, upper, alpha=0.2, color=color, zorder=2)

from itertools import cycle
colors = cycle(plt.rcParams['axes.prop_cycle'].by_key()['color'])

# One line plus a ±1 std band per metric on the choice-reaction (CR) test split.
for metric_name, series in (("F1", F1), ("BAcc", BAcc)):
    cr_scores = series["choice_reaction"]
    plot_with_bounds(x, cr_scores["mean"], cr_scores["std"], f"CR Test {metric_name}", next(colors))

set_plot_style(font_size=14, grid_line_width="2.0")
ax = plt.gca()
ax.grid(True, zorder=1)
ax.set_xticks(x_int)
ax.set_xticklabels(x)
ax.set_xlabel("Sampling rate (Hz)")
ax.set_ylabel("Test Performance")
ax.set_ylim((0.4, 0.7))
ax.set_yticks([0.4, 0.5, 0.6, 0.7])
ax.legend(ncol=2, loc="lower center")
plt.savefig("../docs/figures/changing_S.pdf", bbox_inches='tight')
plt.show()


### Figure 8b - effect of shorter W

In [None]:

# Figure 8b: models trained on all data with progressively shorter input
# windows W (seconds). The first entry is the default "all-split" run; the
# others are the "-windowN" runs (N presumably = window length in samples at
# 60 Hz, e.g. 384 / 60 = 6.4 s; TODO confirm).
windows = [8.53,6.4,4.27,2.13]
window_runs = ["all", "all-window384", "all-window256", "all-window128"]
d = {}
for seed in random_seed_list:
    d[seed] = {}
    for window, run in zip(windows, window_runs):
        results = aggregate_results(f"lite{seed}-{run}-split", wandb_path)
        d[seed][window] = get_LR_max_avg_scores(results)
agg = aggregate_seed_runs(d)

# Across-seed mean/std on the choice-reaction (CR) test split, one entry per window.
cr = "choice_reaction"
F1 = {cr: {stat: [agg[w]["F1"][cr][stat] for w in windows] for stat in ("mean", "std")}}
BAcc = {cr: {stat: [agg[w]["BAcc"][cr][stat] for w in windows] for stat in ("mean", "std")}}

plt.figure(figsize=(6, 4))

colors = cycle(plt.rcParams['axes.prop_cycle'].by_key()['color'])

for metric_name, series in (("F1", F1), ("BAcc", BAcc)):
    plot_with_bounds(windows, series[cr]["mean"], series[cr]["std"], f"CR Test {metric_name}", next(colors))

ax = plt.gca()
ax.grid(True, zorder=1)
ax.set_xlabel("Input Window (s)")
ax.set_ylabel("Test Performance")
ax.set_ylim((0.4, 0.7))
ax.set_yticks([0.4, 0.5, 0.6, 0.7])
ax.legend(ncol=2, loc="lower center")
plt.savefig("../docs/figures/changing_W.pdf", bbox_inches='tight')
plt.show()

### Figure 8c - effect of changing D

In [None]:
# Figure 8c: models trained on all data with different input dimensions D.
# "event+gaze" is the default "all-split" run; "event" uses the "-event-only"
# run and "event+gaze+pupil" uses the "-pd" run.
categories = ["event","event+gaze","event+gaze+pupil"]
category_runs = ["all-event-only", "all", "all-pd"]
d = {}
for seed in random_seed_list:
    d[seed] = {}
    for category, run in zip(categories, category_runs):
        results = aggregate_results(f"lite{seed}-{run}-split", wandb_path)
        d[seed][category] = get_LR_max_avg_scores(results)
agg = aggregate_seed_runs(d)

# Across-seed mean/std on the choice-reaction (CR) test split, one entry per input set.
cr = "choice_reaction"
F1 = {cr: {stat: [agg[c]["F1"][cr][stat] for c in categories] for stat in ("mean", "std")}}
BAcc = {cr: {stat: [agg[c]["BAcc"][cr][stat] for c in categories] for stat in ("mean", "std")}}

plt.figure(figsize=(6, 4))

colors = cycle(plt.rcParams['axes.prop_cycle'].by_key()['color'])

for metric_name, series in (("F1", F1), ("BAcc", BAcc)):
    plot_with_bounds(categories, series[cr]["mean"], series[cr]["std"], f"CR Test {metric_name}", next(colors))

ax = plt.gca()
ax.grid(True, zorder=1)
ax.set_xlabel("Input dimensions")
ax.set_ylabel("Test Performance")
ax.set_ylim((0.4, 0.7))
ax.set_yticks([0.4, 0.5, 0.6, 0.7])
ax.legend(ncol=2, loc="lower center")
plt.savefig("../docs/figures/changing_D.pdf", bbox_inches='tight')
plt.show()