In [1]:
%load_ext autoreload
%autoreload 2

import os
import sys
import time
from pprint import pprint
from pathlib import Path
from random import randint

import numpy as np
import torch
import pandas as pd
import matplotlib.pyplot as plt
from ipywidgets import interact
from tqdm.notebook import tqdm
import nibabel as nib
import glmsingle
from glmsingle.glmsingle import GLM_single
import bids
from bids import BIDSLayout
from scipy.ndimage import zoom, binary_dilation
import h5py
import nibabel as nib
from einops import rearrange

dir2 = os.path.abspath('..')
dir1 = os.path.dirname(dir2)
if not dir1 in sys.path: 
    sys.path.append(dir1)
    
from noise_ceiling import group_repetitions, compute_nc, compute_ncsnr
from sklearn.model_selection import KFold
from fracridge import FracRidgeRegressorCV
from metrics import (
    cosine_distance, squared_euclidean_distance, r2_score, two_versus_two,
    two_versus_two_slow
)
import warnings
from kamitani import load_data, convert_ids

In [2]:
# Enter the path to the Kamitani dataset
derivatives_path = Path('X:\\Datasets\\Deep-Image-Reconstruction\\derivatives\\')
derivatives_path_ssd = Path('D:\\Datasets\\Deep-Image-Reconstruction\\derivatives\\')
dataset_path = derivatives_path / 'fmriprep-20.2.4'

In [None]:
bold, stimulus_ids, mask, affine = load_data(
    derivatives_path_ssd / 'kamitani-bold.hdf5', 
    'sub-02', 
    tr_offset=4,
    run_normalize='linear_trend',
    session_normalize=False
)
stimulus_images = h5py.File(derivatives_path / 'stimulus_images.hdf5', "r")
stimulus_ids = np.array(convert_ids(stimulus_ids, list(stimulus_images.keys())))

In [None]:
group_repetitions(stimulus_ids, 5).shape

In [None]:
X = bold.copy()
X_nan = np.isnan(X)
X[X_nan] = 0.
X_nan.sum()

In [None]:
select_group_ids[train_image_ids]

In [None]:


warnings.filterwarnings('ignore')
np.seterr(all="ignore")

search_space = [(3, 50), (6, 25), (10, 15), (15, 10)]

num_images = 50
num_repetitions = 24
group_ids = group_repetitions(stimulus_ids, num_repetitions)

def shuffle_along_axis(a, axis):
    idx = np.random.rand(*a.shape).argsort(axis=axis)
    return np.take_along_axis(a,idx,axis=axis)

for select_num_repetitions, select_num_images in search_space:
    
    select_image_ids = np.arange(num_images)
    np.random.shuffle(select_image_ids)
    select_image_ids = select_image_ids[:select_num_images]
    
    select_group_ids = shuffle_along_axis(group_ids, 1)
    select_group_ids = select_group_ids[select_image_ids, :select_num_repetitions]
    
    folds = KFold(n_splits=5, shuffle=True, random_state=0)
    
    Y_true = []
    Y_pred = []
    Y_stimulus_ids = []
    for train_image_ids, val_image_ids in folds.split(np.arange(select_num_images)):
        train_ids = select_group_ids[train_image_ids].flatten()
        val_ids = select_group_ids[val_image_ids].flatten()
        
        X_train, Y_train = X[train_ids], Y[train_ids]
        X_val, Y_val = X[val_ids], Y[val_ids]
        
        ncsnr = compute_ncsnr(X_train, np.arange(train_ids.shape[0]).reshape(train_image_ids.shape[0], -1))
        nc = compute_nc(ncsnr, num_averages=1)
        
        threshold = 30
        X_train = X_train[:, nc > threshold]
        X_val = X_val[:, nc > threshold]
        
        model = FracRidgeRegressorCV()
        model.fit(X_train, Y_train)
        Y_pred.append(model.predict(X_val))
        Y_true.append(Y_val)
        Y_stimulus_ids.append(stimulus_ids[val_ids])
    Y_pred = np.concatenate(Y_pred)
    Y_true = np.concatenate(Y_true)
    Y_stimulus_ids = np.concatenate(Y_stimulus_ids)
    
    distances = cosine_distance(torch.from_numpy(Y_true[None]).float(), torch.from_numpy(Y_pred[:, None]).float())
    accuracy = round(two_versus_two(distances, stimulus_ids=Y_stimulus_ids).item() * 100, 2) 

    print(f'{select_num_repetitions=}, {select_num_images=}, {accuracy=}')
    

In [9]:
warnings.filterwarnings('ignore')
np.seterr(all="ignore")

search_space = [
    #(2, 1200), (3, 800), (4, 600), (5, 480),
    #(2, 600), (3, 400), (4, 300), (5, 240),
    #(2, 300), (3, 200), (4, 150), (5, 120),
    #(2, 150), (3, 100), (4, 75), (5, 60),
    (2, 75), (3, 50), (4, 38), (5, 30),
    (2, 30), (3, 10), (4, 15), (5, 12),
]

num_train_images = 1200
num_train_repetitions = 5

tr_offset = 4
subjects = ['sub-02', 'sub-03']

top_k_voxels_space = [2500]

model_name = 'ViT-B=32'
embedding_name = 'embedding'
with h5py.File(derivatives_path / f'{model_name}-features.hdf5', 'r') as f:
    stimulus = f[embedding_name][:]
stimulus_images = h5py.File(derivatives_path / 'stimulus_images.hdf5', "r")


def shuffle_along_axis(a, axis):
    idx = np.random.rand(*a.shape).argsort(axis=axis)
    return np.take_along_axis(a,idx,axis=axis)

for subject in subjects:
    print(subject)
    train_bold, train_stimulus_ids, _, _ = load_data(
        derivatives_path_ssd / 'kamitani-bold.hdf5', 
        subject, 
        tr_offset=tr_offset,
        run_normalize='linear_trend',
        session_normalize=False
    )
    train_stimulus_ids = np.array(convert_ids(train_stimulus_ids, list(stimulus_images.keys())))
    train_bold_nan = np.isnan(train_bold)
    train_bold[train_bold_nan] = 0.
    train_group_ids = group_repetitions(train_stimulus_ids, num_train_repetitions)
    
    bold, stimulus_ids, mask, affine = load_data(
        derivatives_path_ssd / 'kamitani-test-bold.hdf5', 
        subject, 
        tr_offset=tr_offset,
        run_normalize='linear_trend',
        session_normalize=False
    )
    stimulus_ids = np.array(convert_ids(stimulus_ids, list(stimulus_images.keys())))
    bold_nan = np.isnan(bold)
    bold[bold_nan] = 0.
    
    stimulus_images = h5py.File(derivatives_path / 'stimulus_images.hdf5', "r")
    
    Y_train = stimulus[train_stimulus_ids]
    Y = stimulus[stimulus_ids]
    
    for select_num_repetitions, select_num_images in search_space:
        select_image_ids = np.arange(num_train_images)
        np.random.shuffle(select_image_ids)
        select_image_ids = select_image_ids[:select_num_images]

        select_group_ids = shuffle_along_axis(train_group_ids, 1)
        select_group_ids = select_group_ids[select_image_ids, :select_num_repetitions]
        select_group_ids = select_group_ids.flatten()
        
        X_nc = train_bold[select_group_ids]
        
        ncsnr = compute_ncsnr(X_nc, group_repetitions(train_stimulus_ids[select_group_ids], select_num_repetitions))
        nc = compute_nc(ncsnr, num_averages=1)
        nc[np.isnan(nc)] = 0.
        nc_volume = np.zeros_like(mask, dtype=float)
        nc_volume[mask] = nc

        subject_path = derivatives_path_ssd / 'noise_ceiling' / subject
        subject_path.mkdir(exist_ok=True, parents=True)
        out_file_name = f'{subject}__repetitions-{select_num_repetitions}__images-{select_num_images}__noise-ceiling.nii.gz'
        
        image = nib.Nifti1Image(nc_volume, affine)
        nib.save(image, subject_path / out_file_name)
        
        for top_k_voxels in top_k_voxels_space:
            selected_voxels = np.argsort(nc)[::-1][:top_k_voxels]
            X_train = train_bold[:, selected_voxels]
            X = bold[:, selected_voxels]

            model = FracRidgeRegressorCV()
            model.fit(X_train, Y_train)

            Y_pred = model.predict(X)

            distances = cosine_distance(torch.from_numpy(Y[None]).float(), torch.from_numpy(Y_pred[:, None]).float())
            accuracy = round(two_versus_two(distances, stimulus_ids=stimulus_ids).item() * 100, 2) 

            print(f'num_voxels={top_k_voxels}, num_repetitions={select_num_repetitions}, num_images={select_num_images}, {accuracy=}')
        

sub-02
num_voxels=2500, num_repetitions=2, num_images=75, accuracy=71.66
num_voxels=2500, num_repetitions=3, num_images=50, accuracy=84.12
num_voxels=2500, num_repetitions=4, num_images=38, accuracy=78.81
num_voxels=2500, num_repetitions=5, num_images=30, accuracy=82.02
num_voxels=2500, num_repetitions=2, num_images=30, accuracy=71.84
num_voxels=2500, num_repetitions=3, num_images=10, accuracy=70.96
num_voxels=2500, num_repetitions=4, num_images=15, accuracy=64.82
num_voxels=2500, num_repetitions=5, num_images=12, accuracy=72.76
sub-03
num_voxels=2500, num_repetitions=2, num_images=75, accuracy=82.02
num_voxels=2500, num_repetitions=3, num_images=50, accuracy=85.5
num_voxels=2500, num_repetitions=4, num_images=38, accuracy=87.87
num_voxels=2500, num_repetitions=5, num_images=30, accuracy=82.52
num_voxels=2500, num_repetitions=2, num_images=30, accuracy=79.46
num_voxels=2500, num_repetitions=3, num_images=10, accuracy=77.54
num_voxels=2500, num_repetitions=4, num_images=15, accuracy=75.

In [None]:
np.isnan(nc).sum()