In [1]:
import os
import numpy as np
import pandas as pd
from utils.utils import *
from utils.kabsch2D import *
from nsd_access import NSDAccess 
from scipy import stats
import matplotlib.pyplot as plt

In [2]:
# Subject ids 'subj01'..'subj08' (presumably the eight NSD subjects, cf. the NSDAccess import)
subj_list = [f'subj{n:02d}' for n in range(1, 9)]
# Session count per subject, matched to subj_list by position; passed as `sess` into the MDS file names
sessions = [37, 37, 29, 27, 37, 29, 37, 27]

In [7]:
def create_rotation_df(subj, rois, sess, data_dir='/home/stan/thesis-repo/data/MDS'):
    """
    Align every ROI's 2-D MDS embedding onto every other ROI's embedding.

    For each ordered pair (source, base) of distinct ROIs, the source embedding
    is aligned with ``kabsch2D`` twice: once onto the base embedding as loaded
    and once onto a mirrored copy whose first coordinate is negated (that row
    is labelled ``<base>_flipped``).

    Parameters
    ----------
    subj : str
        Subject id used in the file path (e.g. ``'subj01'``).
    rois : mapping
        Only its keys (ROI names) are used; their iteration order fixes the
        row order of the result.
    sess : int
        Session count used in the file name.
    data_dir : str, optional
        Root directory containing
        ``{subj}/{subj}_{sess}_{roi}_mds_betas_train.npy``. Defaults to the
        location that was previously hard-coded.

    Returns
    -------
    pandas.DataFrame
        One row per (source, base, flip) with columns
        ``source, base, target, U, error``. ``U`` is the matrix returned by
        ``kabsch2D``. ``error`` is ``error(rotate(source, U), target, t)``,
        cast to float32.
    """
    cols = ['source', 'base', 'target', 'U', 'error']
    # Right-multiplying by this matrix negates the first MDS coordinate (mirror image).
    flip_x = np.array([[-1, 0], [0, 1]])

    def _load_mds(roi):
        # Load one ROI's training-set MDS embedding for this subject/session.
        return np.load(f'{data_dir}/{subj}/{subj}_{sess}_{roi}_mds_betas_train.npy', allow_pickle=False)

    records = []
    for roi_source in rois.keys():
        mds_source = _load_mds(roi_source)
        for roi_base in rois.keys():
            if roi_base == roi_source:
                continue
            # Read the base embedding once per pair; both the plain and the flipped pass use it.
            mds_base = _load_mds(roi_base)
            for flipped in (False, True):
                if flipped:
                    mds_target = np.dot(mds_base, flip_x)
                    target_label = roi_base + "_flipped"
                else:
                    # Copy so any in-place change made by the alignment helpers
                    # cannot leak into the flipped pass (the original reloaded from disk).
                    mds_target = mds_base.copy()
                    target_label = roi_base
                U, t = kabsch2D(mds_source, mds_target, translate=True)
                rotated_source = rotate(mds_source, U)
                records.append({
                    'source': roi_source,
                    'base': roi_base,
                    'target': target_label,
                    'U': U,
                    'error': error(rotated_source, mds_target, t),
                })
    df = pd.DataFrame(records, columns=cols)
    df = df.astype({'error': 'float32'})  # needed for later indexing (idxmin in get_means)
    return df

def get_means(df):
    """
    Mean best-fit alignment error per source ROI.

    Within each (source, base) group, only the row with the smallest
    ``error`` is kept. In practice this picks the better of the plain and the
    flipped alignment. Those minima are then averaged per ``source``.

    Parameters
    ----------
    df : pandas.DataFrame
        Needs ``source``, ``base`` and a numeric ``error`` column, and index
        labels usable with ``.loc``.

    Returns
    -------
    pandas.Series
        Named ``error`` and indexed by ``source``.
    """
    best_row_labels = df.groupby(['source', 'base'])['error'].idxmin()
    return (
        df.loc[best_row_labels]
          .groupby('source')['error']
          .mean()
    )

In [8]:
# Observed (unshuffled) pairwise alignments for the first subject
df = create_rotation_df(subj=subj_list[0], rois=rois, sess=sessions[0])
df

Unnamed: 0,source,base,target,U,error
0,V1,V2,V2,"[[0.9937364, 0.111750215], [-0.111750215, 0.99...",5843.815918
1,V1,V2,V2_flipped,"[[0.2857228217385375, 0.9583123025077827], [-0...",4923.302246
2,V1,V3,V3,"[[0.25467932, 0.9670255], [-0.9670255, 0.25467...",6156.311523
3,V1,V3,V3_flipped,"[[0.9206654035364477, 0.39035267993324924], [-...",5277.029297
4,V1,hV4,hV4,"[[-0.8642162, -0.50312066], [0.50312066, -0.86...",6783.592773
...,...,...,...,...,...
259,TO-2,LO-1,LO-1_flipped,"[[0.740268298892458, 0.6723115688837029], [-0....",7223.580078
260,TO-2,LO-2,LO-2,"[[-0.026073754, -0.99966], [0.99966, -0.026073...",6625.286621
261,TO-2,LO-2,LO-2_flipped,"[[0.6885233656526, 0.7252141579908771], [-0.72...",5522.831055
262,TO-2,TO-1,TO-1,"[[-0.6986541, -0.7154597], [0.7154597, -0.6986...",4708.388184


In [9]:
# Mean best-fit (plain vs. flipped) error per source ROI for the observed alignments
means = get_means(df=df)
means

source
LO-1     6075.488770
LO-2     5409.333984
PHC-1    6010.904785
PHC-2    6169.129395
TO-1     6525.177734
TO-2     6574.104492
V1       7091.245117
V2       6597.642578
V3       6543.895508
VO-1     5858.212891
VO-2     6028.265137
hV4      6777.357910
Name: error, dtype: float32

In [11]:
def create_rotation_df_shuffle(subj, rois, sess, data_dir='/home/stan/thesis-repo/data/MDS', rng=None):
    """
    Null-model version of ``create_rotation_df``: same ROI pairs and flips,
    but the rows of both embeddings are randomly permuted before each alignment.

    Permuting rows independently breaks the row-to-row correspondence
    between the two embeddings. The resulting errors therefore serve as one
    draw from a permutation null distribution.

    Parameters
    ----------
    subj : str
        Subject id used in the file path (e.g. ``'subj01'``).
    rois : mapping
        Only its keys (ROI names) are used; their iteration order fixes the
        row order of the result.
    sess : int
        Session count used in the file name.
    data_dir : str, optional
        Root directory containing
        ``{subj}/{subj}_{sess}_{roi}_mds_betas_train.npy``. Defaults to the
        location that was previously hard-coded.
    rng : numpy.random.Generator, optional
        Source of randomness. ``None`` (default) keeps the previous behaviour
        of using the global ``np.random`` state, which can be seeded with
        ``np.random.seed``.

    Returns
    -------
    pandas.DataFrame
        Columns ``source, base, target, U, error`` (``error`` is float32).
        There is one row per (source, base, flip), with flipped targets
        labelled ``<base>_flipped``.
    """
    cols = ['source', 'base', 'target', 'U', 'error']
    # Both the np.random module and a Generator expose an in-place, axis-0 `shuffle`.
    shuffler = np.random if rng is None else rng
    # Right-multiplying by this matrix negates the first MDS coordinate (mirror image).
    flip_x = np.array([[-1, 0], [0, 1]])

    def _load_mds(roi):
        # Load one ROI's training-set MDS embedding for this subject/session.
        return np.load(f'{data_dir}/{subj}/{subj}_{sess}_{roi}_mds_betas_train.npy', allow_pickle=False)

    records = []
    for roi_source in rois.keys():
        mds_source = _load_mds(roi_source)
        for roi_base in rois.keys():
            if roi_base == roi_source:
                continue
            # Read the base embedding once per pair. It is never shuffled itself:
            # each pass works on a fresh array, exactly as the old per-pass reload did.
            mds_base = _load_mds(roi_base)
            for flipped in (False, True):
                if flipped:
                    mds_target = np.dot(mds_base, flip_x)
                    target_label = roi_base + "_flipped"
                else:
                    mds_target = mds_base.copy()
                    target_label = roi_base
                # mds_source is shuffled in place and keeps being reshuffled across
                # pairs; shuffle order and RNG consumption match the previous version.
                shuffler.shuffle(mds_source)
                shuffler.shuffle(mds_target)
                U, t = kabsch2D(mds_source, mds_target, translate=True)
                rotated_source = rotate(mds_source, U)
                records.append({
                    'source': roi_source,
                    'base': roi_base,
                    'target': target_label,
                    'U': U,
                    'error': error(rotated_source, mds_target, t),
                })
    df = pd.DataFrame(records, columns=cols)
    df = df.astype({'error': 'float32'})  # needed for later indexing (idxmin in get_means)
    return df

In [13]:
# One shuffled (null) realisation for the same subject, to compare against `df`
df_s = create_rotation_df_shuffle(subj=subj_list[0], rois=rois, sess=sessions[0])
df_s

Unnamed: 0,source,base,target,U,error
0,V1,V2,V2,"[[-0.9680869, -0.250615], [0.250615, -0.9680869]]",8511.281250
1,V1,V2,V2_flipped,"[[-0.7029601299816798, 0.711229256749284], [-0...",8520.349609
2,V1,V3,V3,"[[0.84960616, 0.5274176], [-0.5274176, 0.84960...",8497.887695
3,V1,V3,V3_flipped,"[[0.6550483380746871, 0.7555869736738391], [-0...",8484.265625
4,V1,hV4,hV4,"[[0.9240173, 0.38235068], [-0.38235068, 0.9240...",8453.744141
...,...,...,...,...,...
259,TO-2,LO-1,LO-1_flipped,"[[-0.9792959305412029, -0.2024338914940863], [...",8855.619141
260,TO-2,LO-2,LO-2,"[[0.38006416, -0.9249601], [0.9249601, 0.38006...",8954.554688
261,TO-2,LO-2,LO-2_flipped,"[[-0.6597559870035391, -0.7514798983425872], [...",9024.000977
262,TO-2,TO-1,TO-1,"[[0.80668473, 0.5909821], [-0.5909821, 0.80668...",9096.779297


In [41]:
# Permutation test. Row 0 holds the observed per-ROI mean alignment error; rows
# 1..N_PERMUTATIONS hold the same statistic after randomly permuting embedding rows.
N_PERMUTATIONS = 1000
PERMUTATION_SEED = 42
SUBJ_IDX = 0

iterations = N_PERMUTATIONS
roi_names = list(rois.keys())
subj, sess = subj_list[SUBJ_IDX], sessions[SUBJ_IDX]

# create_rotation_df_shuffle draws from the global NumPy RNG (np.random.shuffle),
# so seed it to make the null distribution reproducible.
np.random.seed(PERMUTATION_SEED)

df_param = pd.DataFrame(np.zeros((iterations + 1, len(roi_names))), columns=list(roi_names))

df_test = create_rotation_df(subj, rois, sess)
means_test = get_means(df_test)
# get_means returns a Series indexed by ROI name (sorted, not in rois order);
# reindex to the column order explicitly rather than relying on implicit alignment.
df_param.loc[0, :] = means_test.reindex(roi_names).to_numpy()

for i in range(1, iterations + 1):
    df_it = create_rotation_df_shuffle(subj, rois, sess)
    means_it = get_means(df_it)
    df_param.loc[i, :] = means_it.reindex(roi_names).to_numpy()

df_param


Unnamed: 0,V1,V2,V3,hV4,VO-1,VO-2,PHC-1,PHC-2,LO-1,LO-2,TO-1,TO-2
0,0.751698,0.715576,0.710606,0.731694,0.669528,0.683722,0.679085,0.687655,0.688435,0.641302,0.712962,0.717176
1,0.848028,0.851150,0.853128,0.851833,0.855344,0.857017,0.855295,0.865658,0.855800,0.860848,0.866157,0.860445
2,0.847321,0.850614,0.852177,0.850532,0.854758,0.855725,0.855641,0.863260,0.855774,0.859855,0.866201,0.861002
3,0.848258,0.850197,0.851521,0.851263,0.855468,0.856951,0.855536,0.863473,0.855316,0.861427,0.867245,0.860218
4,0.848407,0.850432,0.851330,0.850800,0.854783,0.857206,0.856143,0.866992,0.856139,0.860622,0.866640,0.860433
...,...,...,...,...,...,...,...,...,...,...,...,...
996,0.848093,0.850807,0.852166,0.852468,0.854266,0.856783,0.855343,0.864421,0.854742,0.861672,0.866913,0.860844
997,0.847433,0.850067,0.851185,0.853117,0.853946,0.857618,0.855295,0.864632,0.855533,0.861159,0.866737,0.860925
998,0.847561,0.850614,0.851218,0.850832,0.855696,0.856227,0.856504,0.865839,0.856857,0.860384,0.866437,0.860629
999,0.848509,0.850386,0.851511,0.851030,0.854709,0.856663,0.855662,0.866204,0.855527,0.861110,0.867829,0.859614


In [42]:
# Persist the permutation results (row 0 = observed, later rows = shuffled) so the long loop need not be rerun
df_param.to_csv(path_or_buf='parametric_test_output.csv', index=False)

In [2]:
# Reload the saved permutation results (e.g. in a fresh kernel) instead of recomputing them
df_param = pd.read_csv(filepath_or_buffer='parametric_test_output.csv')
df_param

Unnamed: 0,V1,V2,V3,hV4,VO-1,VO-2,PHC-1,PHC-2,LO-1,LO-2,TO-1,TO-2
0,0.751698,0.715576,0.710606,0.731694,0.669528,0.683722,0.679085,0.687655,0.688435,0.641302,0.712962,0.717176
1,0.848028,0.851150,0.853128,0.851833,0.855344,0.857017,0.855295,0.865658,0.855800,0.860848,0.866157,0.860445
2,0.847321,0.850614,0.852177,0.850532,0.854758,0.855725,0.855641,0.863260,0.855774,0.859855,0.866201,0.861002
3,0.848258,0.850197,0.851521,0.851263,0.855468,0.856951,0.855536,0.863473,0.855316,0.861427,0.867245,0.860218
4,0.848407,0.850432,0.851330,0.850800,0.854783,0.857206,0.856143,0.866992,0.856139,0.860622,0.866640,0.860433
...,...,...,...,...,...,...,...,...,...,...,...,...
996,0.848093,0.850807,0.852166,0.852468,0.854266,0.856783,0.855343,0.864421,0.854742,0.861672,0.866913,0.860844
997,0.847433,0.850067,0.851185,0.853117,0.853946,0.857618,0.855295,0.864632,0.855533,0.861159,0.866737,0.860925
998,0.847561,0.850614,0.851218,0.850832,0.855696,0.856227,0.856504,0.865839,0.856857,0.860384,0.866437,0.860629
999,0.848509,0.850386,0.851511,0.851030,0.854709,0.856663,0.855662,0.866204,0.855527,0.861110,0.867829,0.859614
