In [None]:
# %matplotlib widget

In [None]:
import itertools
import os

import matplotlib.pyplot as plt
import nibabel as nib
import numpy as np
import pandas as pd
import seaborn as sns
from matplotlib import gridspec
# NOTE: matplotlib.cm.get_cmap was deprecated in matplotlib 3.7 and removed in
# 3.9; switch to `matplotlib.colormaps[name]` if this import fails.
from matplotlib.cm import get_cmap
from matplotlib.colors import LinearSegmentedColormap, Normalize
from nilearn import datasets, image, plotting, surface
from sklearn.cluster import KMeans
from sklearn.manifold import TSNE
from tqdm import tqdm

from src.regress import ridge
# from scipy import stats

## Set up

In [None]:
def mkNifti(arr, mask, im, nii=True):
    """Scatter masked values back into a full (flattened) volume.

    Parameters
    ----------
    arr : array-like
        One value per nonzero entry of ``mask`` (or a scalar broadcast to all
        of them), ordered like the C-order flattened nonzero positions of
        ``mask``. Its dtype becomes the dtype of the output.
    mask : array-like
        Boolean (or 0/1) selection mask. May be 1-D (already flattened) or
        N-D; it is flattened in C order before use.
    im : image-like
        Reference image. Only ``im.shape`` and ``im.affine`` are read, and only
        when ``nii`` is True.
    nii : bool, default True
        If True, reshape the filled array to ``im.shape`` and wrap it in a
        ``nib.Nifti1Image`` with ``im.affine``. If False, return the flat array.

    Returns
    -------
    numpy.ndarray or nib.Nifti1Image
        A flat array of length ``mask.size`` with zeros outside the mask, or
        the corresponding NIfTI image when ``nii`` is True.
    """
    arr = np.asarray(arr)
    flat_mask = np.asarray(mask).ravel()
    out_im = np.zeros(flat_mask.size, dtype=arr.dtype)
    # flatnonzero yields flat indices for a mask of any rank; np.where(mask)[0]
    # would only give first-axis indices, which is wrong for N-D masks.
    out_im[np.flatnonzero(flat_mask)] = arr
    if nii:
        # C-order reshape mirrors the C-order flattening of the mask above.
        # NOTE(review): for a 1-D mask this assumes it was flattened in C order
        # from a volume of shape im.shape — confirm against how it was saved.
        out_im = out_im.reshape(im.shape)
        out_im = nib.Nifti1Image(out_im, affine=im.affine)
    return out_im

In [None]:
process = 'components'
# Project root. Set the SIFMRI_TOP_DIR environment variable to run this
# notebook on another machine; the default is the original author's checkout.
top_dir = os.environ.get('SIFMRI_TOP_DIR',
                         '/Users/emcmaho7/Dropbox/projects/SI_fmri/SIfMRI_analysis')
data_dir = f'{top_dir}/data/raw'
out_dir = f'{top_dir}/data/interim'
figure_dir = f'{top_dir}/reports/figures/{process}'
# makedirs also creates missing parents (os.mkdir would fail if
# reports/figures did not exist yet), and exist_ok makes re-running safe.
os.makedirs(figure_dir, exist_ok=True)

In [None]:
# Group reliability mask plus the reference image whose shape and affine
# mkNifti uses when rebuilding volumes.
mask = np.load(f'{out_dir}/Reliability/sub-all_reliability-mask.npy').astype(bool)
im = nib.load(f'{out_dir}/Reliability/sub-all_stat-rho_statmap.nii.gz')

# `rs` carries one flag per voxel of the reliability mask; scattering it back
# through mkNifti (flat output, no NIfTI wrapper) yields the narrower mask.
rs_flags = f'{out_dir}/VoxelPermutation/sub-all/sub-all_feature-all_rs-filtered.npy'
rs = np.load(rs_flags).astype(bool)
mask = mkNifti(rs, mask, im, nii=False)

In [None]:
# Average the training betas over subjects, restricted to the voxels in `mask`.
n_subjects = 4
n_voxels = np.count_nonzero(mask)

X = None
for sid_ in range(n_subjects):
    sid = str(sid_ + 1).zfill(2)
    betas = np.load(f'{out_dir}/grouped_runs/sub-{sid}/sub-{sid}_train-data.npy')

    # Filter the beta values to the selected voxels (first axis of the saved
    # array) and transpose so voxels become columns.
    subject_X = betas[mask, :].T

    if X is None:
        # Copy into a floating-point accumulator: float inputs keep their
        # dtype, integer inputs are promoted so `X /= n_subjects` is valid.
        X = subject_X.astype(np.promote_types(subject_X.dtype, np.float32))
    else:
        X += subject_X
X /= n_subjects

In [None]:
# Keep only the training videos' annotations, ordered by video name, and use
# every remaining column as a feature.
annotations = pd.read_csv(f'{data_dir}/annotations/annotations.csv')
train = pd.read_csv(f'{data_dir}/annotations/train.csv')
df = (
    annotations
    .merge(train)
    .sort_values(by=['video_name'])
    .drop(columns=['video_name'])
)
features = np.array(df.columns)

y = df.to_numpy()

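## Analysis

Relate the annotation features to the subject-averaged voxel responses with ridge regression, group voxels by their coefficient profiles (k-means), embed those profiles with t-SNE, and map the clusters onto the brain.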

### Regression

In [None]:
# `ridge` (project helper from src.regress) returns two values; only the first,
# the coefficients, is used below. The second is bound to a named throwaway
# instead of `_`: assigning `_` in a notebook stops IPython from updating its
# last-output cache (`_`, `__`, `___`).
coef, _ridge_unused = ridge(X, y)

### KMeans

In [None]:
# Cluster the rows of coef.T (one coefficient profile per voxel).
n_clusters = 4
# n_init is pinned to 10 (the default before scikit-learn 1.4, which switched
# to 'auto') so the clustering does not change with the installed version.
kmeans = KMeans(n_clusters=n_clusters, n_init=10, random_state=0)
cluster_labels = kmeans.fit_predict(coef.T)

### TSNE

In [None]:
# 2-D t-SNE embedding of the per-voxel coefficient profiles, coloured by
# k-means cluster (shown 1-based).
tsne = TSNE(n_components=2, random_state=0)  # seeded for reproducibility
embed = tsne.fit_transform(coef.T)

fig, ax = plt.subplots()
# seaborn >= 0.12 accepts only `data` positionally, so x/y must be keywords.
sns.scatterplot(x=embed[:, 0], y=embed[:, 1],
                hue=cluster_labels + 1,
                palette=sns.color_palette('husl', n_clusters),
                ax=ax)
ax.set(title='t-SNE of voxel coefficient profiles',
       xlabel='t-SNE 1', ylabel='t-SNE 2');

### Plotting the clusters on the brain

In [None]:
# Number of voxels assigned to cluster 0.
(cluster_labels == 0).sum()

In [None]:
# Number of voxels assigned to cluster 3.
(cluster_labels == 3).sum()

In [None]:
# Cluster label values and the number of voxels carrying each one.
# Inspects cluster_labels directly: `brain_map` is only built in the
# surface-plotting cell further down, so referencing it here fails on a
# fresh kernel.
label_values, label_counts = np.unique(cluster_labels, return_counts=True)
pd.DataFrame({'cluster': label_values, 'n_voxels': label_counts})

In [None]:
# Offset the cluster labels by `base` so every cluster value
# (base .. base + n_clusters - 1) is at or above `threshold=base`, while
# voxels outside the mask stay 0 and are left undrawn.
base = 10
# Top of the label range; equals the vmax used for the volume view
# (n_clusters + base - 1), so each cluster gets the same colour in both.
label_vmax = base + n_clusters - 1
cmap = sns.color_palette('Paired', n_clusters, as_cmap=True)
brain_map = mkNifti(cluster_labels + base, mask, im)
fsaverage = datasets.fetch_surf_fsaverage(mesh='fsaverage')

# Project the label volume onto each hemisphere's pial surface.
# NOTE(review): check whether vol_to_surf blends several samples per vertex in
# the installed nilearn version; if so, vertices on cluster borders can get
# non-label values before the int cast.
texture = {
    hemi: surface.vol_to_surf(brain_map, fsaverage[f'pial_{hemi}'],
                              interpolation='nearest').astype('int')
    for hemi in ['left', 'right']
}

for hemi in ['left', 'right']:
    view = plotting.view_surf(fsaverage[f'infl_{hemi}'],
                              title=hemi,
                              surf_map=texture[hemi],
                              symmetric_cmap=False,
                              vmin=base, vmax=label_vmax,
                              threshold=base,
                              cmap=cmap, bg_map=fsaverage[f'sulc_{hemi}'])
    view.open_in_browser()

In [None]:
# Interactive volume rendering of the cluster label map, opened in a browser.
# Values below `base` (voxels outside the mask) are thresholded away.
view = plotting.view_img(
    brain_map,
    threshold=base,
    vmin=base,
    vmax=base + n_clusters - 1,
    cmap=cmap,
    symmetric_cmap=False,
    colorbar=True,
)
view.open_in_browser()

In [None]:
# For every cluster: z-score its centroid across features, then list the
# features from the highest to the lowest z-scored weight.
for cluster_id in range(n_clusters):
    center = kmeans.cluster_centers_[cluster_id, :]
    z_center = (center - center.mean()) / center.std()
    descending = np.argsort(z_center)[::-1]
    n_members = sum(cluster_labels == cluster_id)
    print(f'N voxels in cluster: {n_members}')
    ranked = pd.DataFrame({'cluster centers': z_center[descending],
                           'features': features[descending]})
    print(ranked)
    print()