In [4]:
%load_ext autoreload
%autoreload 2

import os
import sys
import time
from pprint import pprint
from pathlib import Path
from random import randint

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from ipywidgets import interact
from tqdm.notebook import tqdm
import nibabel as nib
import glmsingle
from glmsingle.glmsingle import GLM_single
import bids
from bids import BIDSLayout
from scipy.ndimage import zoom, binary_dilation
from scipy.io import loadmat
import h5py
import nibabel as nib
import re

dir2 = os.path.abspath('..')
dir1 = os.path.dirname(dir2)
if not dir1 in sys.path: 
    sys.path.append(dir1)
    
from noise_ceiling import (
    compute_ncsnr,
    compute_nc,
    group_repetitions
)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [5]:
dataset_root = Path('D:\Datasets')

In [6]:
tc2see_version = 3 # [1, 2]
dataset_path = dataset_root / f"TC2See_v{tc2see_version}"
derivatives_path = dataset_path / "derivatives_TC2See_prdgm"
num_runs = 6

task = "bird"
subject = '03' # ['03', '04']
glm_run = 'level1_native_singletrial'
singletrial = True

# Initialize BIDSLayouts for querying files.
dataset_layout = BIDSLayout(dataset_path / 'TC2See_prdgm')
#betas_path = derivatives_path / 'spm' / f"sub-{subject}" / glm_run
betas_path = derivatives_path / 'spm' / glm_run

In [67]:
write_file = True

spm = loadmat(betas_path / 'SPM.mat')
descrip = spm['SPM'][0][0]['Vbeta']['descrip'][0]
stimulus_ids = []
for i in range(len(descrip)):
    if singletrial:
        name = re.search ('\d+_([a-z_]+)_\d+', descrip[i][0])
    else:
        name = re.search('\) (\w+)\*', descrip[i][0])

    if name != None:
        stimulus_ids.append(name.groups(1)[0])

# write file
if write_file:
    with open(betas_path / 'image_labels.txt', 'w') as f:
        for name in stimulus_ids:
            f.write(name + '\n')
            
# load mask
mask = nib.load(betas_path / 'mask.nii').get_fdata().astype(bool)
print(mask.shape)

# load betas images
betas = []
num_betas = 432 if singletrial else 72

for fnum in range(1, num_betas+1):
    img = nib.load(betas_path / f"beta_{fnum:04}.nii").get_fdata()
    img = img[mask]
    betas.append(img)

betas = np.array(betas)
print(betas.shape)

betas = (betas - betas.mean(axis=0, keepdims=True)) /  betas.std(axis=0, keepdims=True)
stimulus_ids = np.array(stimulus_ids)

model_name = 'ViT-B=32'
embedding_name = 'embedding'

with h5py.File(derivatives_path / f'{model_name}-features.hdf5', 'r') as f:
    stimulus = f[embedding_name][:]
    
stimulus_images = h5py.File(derivatives_path / 'stimulus-images.hdf5', 'r')
stimulus_keys = {k[:-4]: i for i, k in enumerate(stimulus_images.keys())}
stimulus_ids = np.array([stimulus_keys[i] for i in stimulus_ids])
Y = stimulus[stimulus_ids]

run_mask = np.array([1,1,0,1,1,0], dtype=bool).repeat(72)

X = betas[run_mask]
Y = Y[run_mask]
stimulus_ids = stimulus_ids[run_mask]

(102, 102, 64)
(432, 125506)


In [68]:
Y.shape

(288, 512)

In [69]:
from sklearn.model_selection import KFold
from fracridge import FracRidgeRegressorCV
from metrics import (
    cosine_distance, squared_euclidean_distance, r2_score, two_versus_two,
    two_versus_two_slow
)
import warnings
import torch

warnings.filterwarnings('ignore')
np.seterr(all="ignore")

k = 500

folds = KFold(n_splits=5, shuffle=True, random_state=0)

Y_pred = np.zeros_like(Y)
for train_ids, val_ids in folds.split(X):
    X_train, Y_train = X[train_ids], Y[train_ids]
    X_val, Y_val = X[val_ids], Y[val_ids]
    
    grouped_repetitions = []
    for i in range(2, 13):
        x = group_repetitions(stimulus_ids[train_ids], num_repetitions=i)
        if x is not None:
            grouped_repetitions.append(x)
    
    ncsnr = compute_ncsnr(X_train, grouped_repetitions)
    nc = compute_nc(ncsnr, num_averages=1)
    top_ids = np.argsort(nc)[::-1][:k]

    model = FracRidgeRegressorCV()
    model.fit(X_train[:, top_ids], Y_train)
    Y_pred[val_ids] = model.predict(X_val[:, top_ids])

distances = cosine_distance(torch.from_numpy(Y[None]).float(), torch.from_numpy(Y_pred[:, None]).float())
accuracy = round(two_versus_two(distances, stimulus_ids=stimulus_ids).item() * 100, 2) 
accuracy2 = round(two_versus_two_slow(distances, stimulus_ids=stimulus_ids) * 100, 2)

print(f'{accuracy=}')


accuracy=46.54
