In [1]:
import h5py
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset
import torch.nn as nn
import torchvision
import torch.nn.functional as F
from torch.autograd import Variable
import torch.optim as optim
import copy
import matplotlib.pyplot as plt
import os
import pandas as pd
import nibabel as nib
from nilearn import input_data, plotting,datasets
import pickle
import sys
from models import get_models
from sklearn.linear_model import Ridge
from sklearn.metrics import r2_score
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import LeaveOneGroupOut,KFold
from sklearn.utils import shuffle




In [2]:
subject=6

In [3]:
## HYPERPARAMETER CELL 
model_subject = subject

load_epoch = 3000
model_num = 1
    
LOAD_PATH = '../models/'+'sub-'+str(model_subject)+'-model-'+str(model_num)+'/'+str(load_epoch)+'.torch'
load_model = True

frame_skip = 5
sequence_len = int(90/frame_skip)
batch_size = 1

h5_filename = '../data/bw-'+'sub-'+str(subject)+'-len-'+str(sequence_len)+'.h5'#data file to train
df_filepath_stage1 = '../temp_files/'+'sub-'+str(subject)+'-stage-1-df.pkl'
runshape_file_path = '../temp_files/'+'sub-'+str(subject)+'-stage-2-runshapes-'+str(sequence_len)+'.pkl'
fmri_file_path = '../temp_files/'+'sub-'+str(subject)+'-stage-3-parcel-confounds9-nohigh.pkl'

In [4]:
if subject==1:
    exclude_runs = []
if subject==2:
    exclude_runs = []   
if subject==4:
    exclude_runs = []
if subject==6:
    exclude_runs = [0]

In [5]:
class Concatdataset(torch.utils.data.Dataset):
    def __init__(self, *datasets):
        self.datasets = datasets

    def __getitem__(self, i):
        return tuple(d[i] for d in self.datasets)

    def __len__(self):
        return min(len(d) for d in self.datasets)
    
dataset = datasets.fetch_atlas_basc_multiscale_2015()
atlas_filename = dataset.scale444
masker = input_data.NiftiLabelsMasker(labels_img=atlas_filename,high_pass=0.01 ,standardize=True,t_r=1.49,smoothing_fwhm=8,memory='nilearn_cache')
masker.fit()

NiftiLabelsMasker(high_pass=0.01,
                  labels_img='/home/ani686/nilearn_data/basc_multiscale_2015/template_cambridge_basc_multiscale_nii_sym/template_cambridge_basc_multiscale_sym_scale444.nii.gz',
                  memory='nilearn_cache', smoothing_fwhm=8, standardize=True,
                  t_r=1.49)

In [6]:
f = h5py.File(h5_filename,'r')
data = f['state']
label = f['action']
sess = f['session']
run = f['run']
data_all = np.array(data)
label_all = np.array(label)
sess_all = np.array(sess)
run_all = np.array(run)
main_dataset = Concatdataset(data_all,label_all,sess_all,run_all)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

print('frame data loaded')

model = get_models(model_num,sequence_len)
model = model.to(device)
if load_model:
    saved_state = torch.load(LOAD_PATH ,map_location=device)
    model.load_state_dict(saved_state)
    print('Loaded model',flush=True)

frame data loaded
Loaded model


In [7]:
def stage_5(model,main_dataset):
    if True:
    
        test_loader = DataLoader(dataset=main_dataset, shuffle=False, batch_size=1)

 
        allinput =[]
        sess_all = []
        run_all = []

        if True:
            test_accuracy=0.0
            model.eval()
            
            with torch.no_grad():
                for batchid,traindata in enumerate(test_loader):
                    data,labels,sess,run = traindata
                    
                    labels = labels.reshape(-1).to(device=device)
                   
                    data = data.float().to(device=device)
                    
                    inputdata = data.detach().cpu().numpy()
                                     
                    
                    inputdata = inputdata[0].reshape(sequence_len,-1)
                    allinput.append(inputdata)
                    sess_all.append(sess)
                    run_all.append(run)
             
            print("Accuracy: %.3f"%(test_accuracy))
            all_dict = {'inputdata':allinput,'session':sess_all,'run':run_all}
            return all_dict

In [8]:
def stage_6(dl_activations):
    if True:

        inputdata = np.array(dl_activations['inputdata']) 
        session = np.array(dl_activations['session'])
        run = np.array(dl_activations['run'])

        with open(runshape_file_path, 'rb') as f:
            test_shapes = pickle.load(f)
        dl_all = np.arange(inputdata.shape[0])
        
        #split into runs
        dl_final = []
        start_idx = 0
        stop_idx = 0
        for i in range(len(test_shapes)):
            start_idx = stop_idx
            stop_idx = start_idx+test_shapes[i]
            dl_final.append(dl_all[start_idx:stop_idx])
        
        tsum = 0
        for i in range(len(dl_final)):
            tsum = tsum + dl_final[i].shape[0]
        #shape match check
        if not(sum(test_shapes)==dl_all.shape[0]==tsum):
            print('not pass 1')
        
        with open(fmri_file_path, 'rb') as f:
            fmri_all = pickle.load(f)
        df = pd.read_pickle(df_filepath_stage1)
        
        fmri_all2 = []
        for i in range(len(df)):
            onset_index = int(df['onset'][i]/1.49)
            fmri_all2.append(fmri_all[i][onset_index:])
        for i in range(len(fmri_all2)):
            if(fmri_all2[i].shape[0]*1.49<df['duration'][i]):
                print('not pass 2',i)
                
        fmri_all3 = []
        for i in range(len(df)):
            stop_index = int(df['duration'][i]/1.49)
            fmri_all3.append(fmri_all2[i][:stop_index]) 
        for i in range(len(fmri_all3)):
            diff = fmri_all3[i].shape[0]*1.49 - df['duration'][i]
            if not(0>diff and diff>-1.49):
                print('Not pass 3',i,diff)
          
        for i in range(len(df)):
            if(fmri_all3[i].shape[0]-dl_final[i].shape[0]<0):
                print('not pass 4',i)
          
                
        #the dl is always small above
        for i in range(len(df)):
            fmri_all3[i]= fmri_all3[i][:dl_final[i].shape[0]]
        for i in range(len(df)):
            if not(fmri_all3[i].shape[0]==dl_final[i].shape[0]):
                print('not pass 5',i)

        lag_range = [4] # each lag = 1.5 secs
        # find out common len for all lags
        # largest lag will have shortest array size
        common_length = np.zeros(len(df))
        for i in range(len(df)):
            common_length[i] = fmri_all3[i].shape[0]-max(lag_range)

        for lag in lag_range:
            #below note dl_final deals with indices
            dl_final_all = []
            fmri_final_all = []
            for i in range(len(df)):
                
                dl_final_all.append(dl_final[i][:(fmri_all3[i].shape[0]-lag)])
                fmri_final_all.append(fmri_all3[i][lag:])    

            #for same size vectors for all considered lags
            for i in range(len(df)):
                dl_final_all[i] = dl_final_all[i][int(-common_length[i]):]
                fmri_final_all[i] = fmri_final_all[i][int(-common_length[i]):]
                
            for i in range(len(df)):
                if not(dl_final_all[i].shape[0]==fmri_final_all[i].shape[0]):
                    print('not pass 6',i)
                    
            for i in range(len(df)):
                scaler = StandardScaler()
                scaler.fit(fmri_final_all[i])
                fmri_final_all[i] = scaler.transform(fmri_final_all[i])
            
            
            
            new_dl_final = []
            new_fmri_final = []
            for i in range(len(df)):
                if i not in exclude_runs:
                    new_dl_final.append(dl_final_all[i])
                    new_fmri_final.append(fmri_final_all[i])
                
            
            
            dl = np.concatenate(new_dl_final[:],axis=0)
            fmri = np.concatenate(new_fmri_final[:],axis=0)
            
            tinput = inputdata[dl]
            tsession = session[dl]
            trun = run[dl]

        all_dict = {'inputdata':tinput,'fmri':fmri,'session':tsession,'run':trun}
        return all_dict
        

In [9]:
dl_activations = stage_5(model,main_dataset)

Accuracy: 0.000


In [10]:
dl_fmri = stage_6(dl_activations)

not pass 4 0
not pass 5 0


In [11]:
inputdata = dl_fmri['inputdata']
fmri = dl_fmri['fmri']


def transform_layer_activations2(layer,pca1):
    layer = layer.reshape((-1,layer.shape[-1]))
    pca = PCA(n_components=pca1)
    pca.fit(layer)
    print(pca.explained_variance_ratio_.sum())
    layer = pca.transform(layer)
    layer = layer.reshape((-1,sequence_len,layer.shape[-1]))
    layer = layer.reshape(layer.shape[0],-1)
    pca = PCA(n_components=800)
    pca.fit(layer)
    print(pca.explained_variance_ratio_.sum())
    layer = pca.transform(layer)
    return layer


inputdata = transform_layer_activations2(inputdata,800)


0.9664241945563473
0.90584487


In [12]:
dl_fin_temp = inputdata
dl_fin = dl_fin_temp
alpha = 0.2
logo = LeaveOneGroupOut()


all_test=[]
for train_index, test_index in logo.split(dl_fin, fmri, dl_fmri['session']):

    clf = Ridge(alpha=alpha,normalize=True)

    clf.fit(dl_fin[train_index], fmri[train_index])
    train_predict = clf.predict(dl_fin[train_index])
    train_true = fmri[train_index]
    r2_train = r2_score(train_true,train_predict,multioutput='raw_values').clip(min=0)

    logo2 =  LeaveOneGroupOut()
    print("sess num of runs:",logo2.get_n_splits(dl_fin[test_index],fmri[test_index], dl_fmri['run'][test_index]))
    
    for run_train_index, run_test_index in logo2.split(dl_fin[test_index],fmri[test_index], dl_fmri['run'][test_index]):
        test_predict = clf.predict(dl_fin[test_index[run_test_index]])
        test_true = fmri[test_index[run_test_index]]
        r2_test = r2_score(test_true,test_predict,multioutput='raw_values').clip(min=0)
        all_test.append(r2_test)
        print(np.histogram(r2_test, bins=[0.1, 0.2, 0.3,0.4,0.5])[0])
            
all_test = np.array(all_test[:])
mean_r2  = all_test.mean(axis=0)
            
print('-----------')
print(np.histogram(mean_r2, bins=[0.1, 0.2, 0.3,0.4,0.5])[0])

sess num of runs: 4
[3 0 0 0]
[15  0  0  0]
[47  1  0  0]
[55  8  0  0]
sess num of runs: 4
[7 0 0 0]
[0 0 0 0]
[24  1  0  0]
[39  5  0  0]
sess num of runs: 4
[5 0 0 0]
[36  1  0  0]
[28  1  0  0]
[17  0  0  0]
sess num of runs: 4
[10  0  0  0]
[83  4  0  0]
[45  2  0  0]
[0 0 0 0]
sess num of runs: 4
[3 0 0 0]
[42  0  0  0]
[57  4  0  0]
[9 0 0 0]
sess num of runs: 5
[45  5  0  0]
[62 12  0  0]
[34  0  0  0]
[29  1  0  0]
[5 0 0 0]
sess num of runs: 4
[19  2  0  0]
[18  0  0  0]
[8 1 0 0]
[15  1  0  0]
sess num of runs: 5
[1 0 0 0]
[44  1  0  0]
[5 0 0 0]
[11  0  0  0]
[22  0  0  0]
sess num of runs: 4
[33  0  0  0]
[32  3  0  0]
[29  0  0  0]
[23  0  0  0]
sess num of runs: 5
[27  1  0  0]
[22  1  0  0]
[32  6  0  0]
[0 0 0 0]
[29  1  0  0]
-----------
[8 0 0 0]


In [13]:
np.save('results/main-'+str(subject)+'-3.npy',all_test)