Load and visualize IBL encoding results

In [26]:
import numpy as np
import os
import matplotlib.pyplot as plt
import glob

# Root directory searched (recursively) for per-session encoding results.
RESULTS_DIR = "../logs/encoding"

# Text file listing the test-session eids, one per line.
ENCODING_TEST_EIDS_PATH = "../data/encoding_test_eids.txt"

with open(ENCODING_TEST_EIDS_PATH, "r") as f:
    # Eids are later matched by substring against result paths, so strip
    # surrounding whitespace and drop blank lines: an empty string would be a
    # substring of every path (matching all of them), and a padded eid would
    # match none.
    EIDS = [line.strip() for line in f if line.strip()]

# get all results_dict.npy files in RESULTS_DIR; sorted so the order does not
# depend on the filesystem's directory-listing order.
RESULTS_FILES = sorted(
    glob.glob(os.path.join(RESULTS_DIR, "**", "results_dict.npy"), recursive=True)
)
print(f"Found {len(RESULTS_FILES)} results files.")

def load_results(file_path):
    """Return the single Python object stored in a ``.npy`` results file.

    ``np.load`` yields an object array wrapping the stored value; ``.item()``
    unwraps it. Callers in this notebook index the result like a nested dict
    (presumably it was written with ``np.save`` on a dict — confirm).

    NOTE(review): ``allow_pickle=True`` can execute arbitrary code while
    loading, so only use this on trusted result files.
    """
    wrapped = np.load(file_path, allow_pickle=True)
    return wrapped.item()

Found 10 results files.


In [27]:
# Experiment selection: each string must appear as a substring of a results path.
model = 'model-ijepa/'
avail_views = 'views-left/'
dataset = 'ds-ibl-mouse-separate'
# Resume-checkpoint tag inferred from the model name
# ('sv' / 'mv' presumably single-view / multi-view — confirm).
if 'sv' in model:
    resume_model = 'sv'
elif 'mv' in model:
    resume_model = 'mv'
else:
    resume_model = 'none'
    # With no resumed model, the dataset filter below matches the literal
    # substring 'None' in the results path.
    dataset = 'None'

# filter results files for the given model
results_files = [f for f in RESULTS_FILES if model in f]
print(f"Found {len(results_files)} results files for model {model}.")

rrr_bps_list = []
tcn_bps_list = []
skipped_eids = []  # sessions with zero or multiple matching result files
for eid in EIDS:
    print(f"Processing eid: {eid}")
    matches = [
        f for f in results_files
        if eid in f and avail_views in f and dataset in f
    ]
    # Explicit check instead of `assert`: asserts are stripped under `python -O`,
    # which would turn a missing file into an IndexError and silently pick the
    # first of several ambiguous matches.
    if len(matches) != 1:
        print(f"Expected one result file for eid {eid}, found {len(matches)}")
        skipped_eids.append(eid)
        continue
    result_file = matches[0]
    results = load_results(result_file)
    # Read both metrics before appending so the two lists stay index-aligned.
    rrr_bps = results['rrr']['bps']
    tcn_bps = results['tcn']['bps']
    rrr_bps_list.append(rrr_bps)
    tcn_bps_list.append(tcn_bps)

rrr_bps_list, tcn_bps_list = np.array(rrr_bps_list), np.array(tcn_bps_list)
if skipped_eids:
    # Make it visible that the summary below covers fewer sessions than EIDS.
    print(f"Skipped {len(skipped_eids)}/{len(EIDS)} sessions: {skipped_eids}")
if tcn_bps_list.size == 0:
    # mean()/std() of an empty array would print nan with a RuntimeWarning.
    print("No sessions loaded; BPS summary unavailable.")
else:
    # np.std defaults to ddof=0, i.e. the population std across sessions.
    print(f"TCN BPS: {tcn_bps_list.mean():.4f} ± {tcn_bps_list.std():.4f}")
    print(f"RRR BPS: {rrr_bps_list.mean():.4f} ± {rrr_bps_list.std():.4f}")

Found 10 results files for model model-ijepa/.
Processing eid: d23a44ef-1402-4ed7-97f5-47e9a7a504d9
Processing eid: db4df448-e449-4a6f-a0e7-288711e7a75a
Processing eid: 3638d102-e8b6-4230-8742-e548cd87a949
Processing eid: 4b7fbad4-f6de-43b4-9b15-c7c7ef44db4b
Processing eid: 03d9a098-07bf-4765-88b7-85f8d8f620cc
Processing eid: 0841d188-8ef2-4f20-9828-76a94d5343a4
Processing eid: 9b528ad0-4599-4a55-9148-96cc1d93fb24
Processing eid: f140a2ec-fd49-4814-994a-fe3476f14e66
Processing eid: 687017d4-c9fc-458f-a7d5-0979fe1a7470
Processing eid: d04feec7-d0b7-4f35-af89-0232dd975bf0
TCN BPS: 0.2246 ± 0.1011
RRR BPS: 0.0983 ± 0.0618


In [28]:
# Per-session bits-per-spike for each encoding model, one line per session in
# the order the sessions were loaded above (loop names kept as in the original
# because notebook loop variables persist as globals).
print("TCN BPS per session:")
for tcn_bps in tcn_bps_list:
    print(format(tcn_bps, ".4f"))
print("-" * 40)
print("RRR BPS per session:")
for rrr_bps in rrr_bps_list:
    print(format(rrr_bps, ".4f"))

TCN BPS per session:
0.4860
0.2164
0.2572
0.2423
0.2364
0.1573
0.1159
0.2572
0.1331
0.1440
----------------------------------------
RRR BPS per session:
0.2119
0.1012
0.1438
0.0830
0.0571
0.0722
0.0490
0.1969
0.0248
0.0428
