In [None]:
# Mount Google Drive into the Colab runtime at /content/drive so that the
# project folder used by the following cell is reachable through the filesystem.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
import os

# Move into the project's folder on the mounted Drive.
# Relative paths used later in the notebook (e.g. 'data/videoclip_tr_lookup.csv')
# are resolved against this working directory; `dirPath` is also reused as the
# root of every results/models path built in the cells below.
dirPath = "/content/drive/MyDrive/Colab Notebooks/Final project/Experiments"
os.chdir(dirPath)

In [None]:
# Install TensorFlow Addons into the runtime.
# NOTE(review): no `tensorflow_addons` import is visible in this notebook —
# presumably it is required by the saved models or the local helper modules; confirm.
!pip install tensorflow-addons

In [None]:
from Data_extraction_transformer import Get_Data
from Results import Get_Results

import collections
import logging
#import os already imported in code cell 2
import pathlib
import re
import string
import sys
import time
import math
import pickle

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from sklearn.metrics import accuracy_score, confusion_matrix, ConfusionMatrixDisplay

import tensorflow_datasets as tfds
import tensorflow as tf
from tensorflow import keras
from keras import backend as K
import matplotlib.pyplot as plt
import seaborn as sns
from IPython.display import display, HTML
from scipy.io import savemat

In [None]:
# NOTE(review): K_SEED is not referenced in the visible cells; presumably a
# random seed consumed elsewhere — confirm.
K_SEED = 330

class args:
  '''
  Experiment configuration.

  In this notebook the class object itself is used as a namespace: the
  assignments that follow the class definition set the same fields directly on
  the class. Creating an instance is supported as well.

  inputs: input_data, roi, net, roi_name - data parameters
          zscore - preprocessing setting
          train_size - training parameter
          K_RUNS - number of movie runs (optional, defaults to 4)
  '''

  def __init__(self,input_data,roi,net,roi_name,zscore,train_size,K_RUNS=4):
    # data parameters
    self.input_data = input_data
    self.roi = roi
    self.net = net
    self.roi_name = roi_name
    # K_RUNS used to be read from an undefined name, which made every
    # instantiation fail with NameError; it is now an explicit keyword.
    self.K_RUNS = K_RUNS
    # preprocessing
    self.zscore = zscore
    # training parameters
    self.train_size = train_size
    

# The settings below are stored as class-level attributes on `args` itself
# (the class is not instantiated in the visible cells); the class object is
# what gets passed to Get_Data(args).

# data parameters
args.input_data = 'data/roi_ts'
args.roi = 300
args.net = 7
args.roi_name = 'roi'
# number of MOVIE runs iterated by _get_clip_labels
args.K_RUNS = 4
# preprocessing
# NOTE(review): how zscore is interpreted is decided by the code that consumes
# `args`, which is not visible here.
args.zscore = 1
# training parameters
args.train_size = 100

In [None]:
#utils functions

def _get_clip_labels():
    '''
    Build the clip-name -> integer-label lookup for all movie runs.

    Reads 'data/videoclip_tr_lookup.csv' and walks runs MOVIE1..MOVIE<K_RUNS>
    in order, collecting the 'clip_name' of every row whose 'run' contains the
    run name. Every clip whose name contains 'testretest' is labelled 0; each
    other collected clip gets the next integer, starting at 1.

    outputs: clip_y - dict mapping clip name to label
    '''
    lookup = pd.read_csv('data/videoclip_tr_lookup.csv')

    # clip names in run order, then in file order within a run
    ordered_clips = []
    for run_number in range(1, args.K_RUNS + 1):
        in_run = lookup['run'].str.contains('MOVIE%d' % run_number)  # MOVIEx_7T_yz
        ordered_clips.extend(lookup.loc[in_run, 'clip_name'].tolist())

    clip_y = {}
    next_label = 1
    for clip in ordered_clips:
        if 'testretest' in clip:
            clip_y[clip] = 0
            continue
        clip_y[clip] = next_label
        next_label += 1

    return clip_y

def _majority_label(labels, ignore_label):
  '''
  most frequent value in `labels`, not counting `ignore_label`
  falls back to `ignore_label` only when nothing else is present
  ties are resolved in favour of the smallest value
  '''
  values, counts = np.unique(np.asarray(labels), return_counts=True)
  keep = values != ignore_label
  if keep.any():
    values, counts = values[keep], counts[keep]
  return values[np.argmax(counts)]

def GetAcc(model, X, y, ignore_label=15):
  '''
  get the accuracy of model.predict(X) after collapsing every sample to a single label
  each sample (row of y / row of the arg-maxed prediction) is reduced to its most
  frequent label; `ignore_label` is skipped in that vote
  inputs: model, X-Eager tensor of data, y-labels (one row of labels per sample),
          ignore_label-label excluded from the vote (15 by default; presumably
          the padding label - confirm)
  outputs: acc-accuracy
  '''

  # model.predict(X) is arg-maxed over axis 2 to get one predicted label per
  # position of every sample
  y_hat = np.argmax(model.predict(X), axis=2)

  y_overtime = []
  y_hat_overtime = []
  for true_row, pred_row in zip(y, y_hat):
    y_overtime.append(_majority_label(true_row, ignore_label))
    y_hat_overtime.append(_majority_label(pred_row, ignore_label))

  acc = accuracy_score(y_overtime,y_hat_overtime)
  return acc

def GetMaxInfo(accMat):
  '''
  finds where the maximum accuracy sits in the accuracy matrix
  inputs: accMat - matrix of model accuracies (in this notebook rows are the
          number of layers and columns are the number of heads)
  outputs: maxV - maximum accuracy,
           headsidx - column index of the maximum,
           layersidx - row index of the maximum
  '''
  best_per_column = np.max(accMat, axis=0)
  best_per_row = np.max(accMat, axis=1)
  headsidx = np.argmax(best_per_column, axis=0)
  layersidx = np.argmax(best_per_row, axis=0)
  maxV = accMat.max().max()
  return maxV, headsidx, layersidx

def GetAttnDict(X, Model, num_of_clips=None, att_block=5):
  '''
  builds a dict of attention tensors grouped per film and per subject
  inputs: X - Eager tensor, Model - trained model,
          num_of_clips - number of samples per subject; when None it is taken as
                         len(clip_time) + 3 at call time,
          att_block - index into Model.layers of the block whose attention is taken
  outputs: att_dict - 'film{k}_att' holds the attention of every subject for film k,
                      cropped to that film's length;
                      'subject{s}_att' holds the uncropped attention of all samples
                      of subject s
  '''
  # The default is resolved here rather than in the signature: `clip_time` is
  # created by a later cell, so evaluating it at definition time raised
  # NameError on a fresh kernel.
  if num_of_clips is None:
    num_of_clips = len(clip_time) + 3

  mask = Model.layers[0].compute_mask(X)
  # NOTE(review): the selected layer is called directly on X (not on the output
  # of the preceding layers) and is expected to return (output, attention) -
  # confirm this is the intended input for blocks after the first one.
  _, att = Model.layers[att_block]((X))

  # samples are assumed to be ordered subject by subject, num_of_clips per subject
  num_of_subjects = int(X.shape[0]/num_of_clips)
  att_dict = {}
  film0_att_ind = []
  for j in range(0,num_of_clips):
    film_att_ind = [j+i*num_of_clips for i in range(0,num_of_subjects)]
    if j<=3:
      # the first four samples of every subject are pooled into film 0
      film0_att_ind += film_att_ind
      if j==3:
        film0_att_ind = tf.sort(film0_att_ind)
        # length of the clip = number of unmasked positions of sample j
        clip_len = tf.reduce_sum(tf.cast(mask[j], tf.int32))
        att_dict[f'film{j-3}_att'] = tf.gather(att[:,:,0:clip_len,0:clip_len] , film0_att_ind)
    else:
      clip_len = tf.reduce_sum(tf.cast(mask[j], tf.int32))
      att_dict[f'film{j-3}_att'] = tf.gather(att[:,:,0:clip_len,0:clip_len] , film_att_ind)

  for j in range(0,num_of_subjects):
    subject_att_ind = [j*num_of_clips+i for i in range(0,num_of_clips)]
    att_dict[f'subject{j+1}_att'] = tf.gather(att , subject_att_ind)

  return att_dict

def GetGraphs(att_dict, num_of_clips):
  '''
  saves, for every film, a heatmap of the correlation between the attention
  matrices of the different samples of that film
  reads the globals `clip_names` (film titles) and `RES_DIR` (output folder)
  inputs: att_dict-dict of attention matrices (as built by GetAttnDict),
          num_of_clips - number of samples per subject
  outputs: none; one figure per film is written under RES_DIR
  '''
  for j in range(3,num_of_clips):
      corr_prepare = att_dict[f'film{j-3}_att']
      # average over axis 1 (presumably the attention heads - confirm)
      corr_prepare = tf.reduce_mean(corr_prepare,axis=1)
      # one flat vector per sample; the shape is passed positionally because
      # the `newshape` keyword is deprecated in recent NumPy releases
      corr_prepare = np.reshape(corr_prepare, (corr_prepare.shape[0],-1))
      corr_prepare_corrcoef = np.corrcoef(corr_prepare,rowvar=True) # find corr coef for observationXatt

      fig = plt.figure()
      #fig.set_size_inches(corr_prepare.shape[0]/4, corr_prepare.shape[0]/4)
      # fixed 38x38 inch canvas
      fig.set_size_inches(152/4, 152/4)
      if j==3:
        # film 0 pools four samples per subject, so cell annotations are left out
        sns.heatmap(corr_prepare_corrcoef)
      else:
        sns.heatmap(corr_prepare_corrcoef, annot=True)
      plt.title(f'film: {clip_names[j-3]} cross subjects corr')
      plt.grid()
      plt.savefig(RES_DIR + f'/film: {clip_names[j-3]} cross subjects corr')
      plt.close('all')

def printLoop(net):
  '''
  prints a three-line banner announcing the brain network being processed
  inputs: net - network name shown in the middle line
  '''
  rule = '-----------------------------------------------------------------------------------------------------'
  headline = f'---------------------------------------------NET: {net}------------------------------------------------'
  print(rule, headline, rule, sep='\n')

In [None]:
# get clips names
clip_y = _get_clip_labels()
# number of distinct labels: 0 for every 'testretest' clip plus one per other clip
k_class = len(np.unique(list(clip_y.values())))
print('number of classes = %d' %k_class)

# clip_names[label] -> clip name (read as a global by GetGraphs and the plotting cells).
# The array is created as the string form of zeros, so its element width is
# fixed by that initial dtype.
# NOTE(review): a name longer than that width would be silently truncated on
# assignment (fixed-width NumPy string array) - confirm all clip names fit.
clip_names = np.zeros(k_class).astype(str)
clip_names[0] = 'testretest'
for key, item in clip_y.items():
    if item!=0:
        clip_names[item] = key

In [None]:
# get the organized data from Get_Data function in Data_extraction_transformer.py
# NOTE(review): the exact contents of each returned value are defined in that
# module and are not visible here. Later cells slice X_* on their third axis,
# iterate y_test row by row, and use len(clip_time) to derive the clip count.
X_train, train_len, y_train, X_val, val_len, y_val, X_test, test_len, y_test, train_list, test_list, clip_time = Get_Data(args)

In [None]:
def GetNetwork(X_train, X_val, X_test,startLH,endLH,startRH,endRH):
  '''
  restricts the three data splits to the columns of one brain network
  the left-hemisphere and right-hemisphere column ranges are cut from the last
  axis of each split and joined (left first, then right)
  inputs: X_train,X_val,X_test: Eager tensor with all brain networks
          startLH,endLH,startRH,endRH: slice bounds (end exclusive) of the network
  outputs: X_train_end,X_val_end,X_test_end: Eager tensor with relevant brain network
  '''
  def _pick_network(X):
    left = X[:, :, startLH:endLH]
    right = X[:, :, startRH:endRH]
    return tf.concat([left, right], axis=2)

  return _pick_network(X_train), _pick_network(X_val), _pick_network(X_test)

# Per brain network: the left/right hemisphere column ranges (end exclusive) on
# the last axis of X, plus the train/val/test tensors restricted to that network.
# 'full' keeps the untouched tensors.
# NOTE(review): the LH ranges end at 150 and the RH ranges start at 151, so
# column 150 is not assigned to any network - confirm against the parcellation.
networksDict = {'vis':{'startLH':0,'endLH':24,'startRH':151,'endRH':174,'train':[], 'val':[], 'test':[]},
                'SomMot':{'startLH':24,'endLH':53,'startRH':174,'endRH':201,'train':[], 'val':[], 'test':[]},
                'Attn':{'startLH':53,'endLH':85,'startRH':201,'endRH':237,'train':[], 'val':[], 'test':[]},
                'limbic':{'startLH':85,'endLH':95,'startRH':237,'endRH':247,'train':[], 'val':[], 'test':[]},
                'Cont':{'startLH':95,'endLH':112,'startRH':247,'endRH':270,'train':[], 'val':[], 'test':[]},
                'DMN':{'startLH':112,'endLH':150,'startRH':270,'endRH':300,'train':[], 'val':[], 'test':[]},
                'full':{'train':X_train, 'val':X_val, 'test':X_test}
                }

for net in networksDict:
  # compare by value: `is not` on a string literal tests object identity, which
  # only works by accident of interning and raises a SyntaxWarning on Python >= 3.8
  if net != 'full':
    networksDict[net]['train'], networksDict[net]['val'], networksDict[net]['test'] = GetNetwork(
        X_train, X_val, X_test,
        networksDict[net]['startLH'], networksDict[net]['endLH'],
        networksDict[net]['startRH'], networksDict[net]['endRH'])
  print(net + ':')
  print('train: '+ str(networksDict[net]['train'].shape))
  print('val: '+ str(networksDict[net]['val'].shape))
  print('test: '+ str(networksDict[net]['test'].shape))

In [None]:
# Sets results dict for the brain networks

def _empty_model_grid():
  '''
  returns a fresh {layers: {heads: None}} grid
  a new object is built on every call so that no two networks share state
  '''
  return {layers_key: {heads_key: None for heads_key in ('1 Head', '4 Heads', '6 Heads', '8 Heads')}
          for layers_key in ('4 Layers', '6 Layers', '8 Layers')}

# kept as stand-alone empty templates for any cell that refers to them by name
modelsDict = _empty_model_grid()
attDict = _empty_model_grid()

# One entry per brain network. Every network owns its own 'Models' and
# 'att_dict' grids: previously all networks pointed at the same two dicts, so
# storing a model for one network overwrote it for every other network.
# 'acc' is a (3 layer settings x 4 head settings) accuracy matrix.
networksResults = {
    net_name: {'results':{}, 'results_prob':{}, 'acc':np.zeros((3,4)),
               'Models':_empty_model_grid(), 'att_dict':_empty_model_grid()}
    for net_name in ('vis', 'SomMot', 'Attn', 'limbic', 'Cont', 'DMN', 'full')
}

In [None]:
# set hyperparameters
# grid of evaluated models: rows <-> number of layers, columns <-> number of heads
num_layers = [4, 6, 8]
rows_names = [f'{n_layers} Layers' for n_layers in num_layers]
rows = len(rows_names)
num_heads = [1, 4, 6, 8]
columns_names = [f'{n_heads} Head' if n_heads == 1 else f'{n_heads} Heads' for n_heads in num_heads]
columns = len(columns_names)
# values that appear in the file names of the saved result pickles
EPOCHS = 45
BATCH_SIZE = 64

In [None]:
# read results

## files paths:
# RES_DIR + f'/net: {net} numHeads: {heads} layers: {layers} accuracy', dpi=fig.dpi     # summarize history for accuracy
# RES_DIR + f'/net: {net} numHeads: {heads} layers: {layers} loss', dpi=fig.dpi         # summarize history for loss
# RES_DIR + f'/net: {net} numHeads: {heads} layers: {layers} Cmat val', dpi=fig.dpi     # val Cmat
# RES_DIR + f'/net: {net} numHeads: {heads} layers: {layers} Cmat test', dpi=fig.dpi    # test Cmat
# f'/Model_{net}_numHeads_{heads}_num_layers_{layers}'                                  #Models


# For every brain network and every (heads, layers) configuration: load the
# pickled results and the saved model, evaluate the model on the test split,
# and store the accuracy in acc[i, j] (i = layers index, j = heads index).
for net in networksResults:
  i = j = 0
  printLoop(net)
  for heads , column in zip(num_heads, columns_names):
    for layers, row in zip(num_layers, rows_names):
      print(f'---------------------------NET: {net}  number of heads: {heads} number of layers: {layers}--------------------------')

      # results directory
      # (RES_DIR / RES_DIR_MODEL are module-level names and stay set after the loop)
      RES_DIR = f'{dirPath}/results/encoder/{net}/to sort'
      RES_DIR_MODEL = f'{dirPath}/models/encoder_models/{net}'

      res_path = (RES_DIR + 
                  '/%s_%d_net_%s' %(args.roi_name, args.roi, net) +
                  '_k_layers_%d' %(layers) +
                  '_heads_%d_batch_size_%d' %(heads, BATCH_SIZE) +
                  '_num_epochs_%d.pkl' %(EPOCHS))

      # load results
      # NOTE(review): 'results' / 'results_prob' are overwritten on every
      # iteration, so only the last configuration's values survive per network -
      # confirm that is intended.
      # NOTE(review): pickle.load executes arbitrary code from the file; only
      # load pickles produced by this project.
      with open(res_path ,"rb") as  f:
          networksResults[net]['results'], networksResults[net]['results_prob'] = pickle.load(f)   

      # load model
      networksResults[net]['Models'][row][column] = keras.models.load_model(f'{RES_DIR_MODEL}/Model_{net}_numHeads_{heads}_num_layers_{layers}')

      # acc matrix
      networksResults[net]['acc'][i,j] = GetAcc(networksResults[net]['Models'][row][column], networksDict[net]['test'], y_test)
      # next layers setting -> next row
      i += 1
    # next heads setting -> back to the first row, next column
    i = 0
    j += 1
  
  # labelled view of the accuracy matrix: index = layers, columns = heads
  networksResults[net]['df'] = pd.DataFrame(data=networksResults[net]['acc'], index=rows_names, columns=columns_names)


In [None]:
# get acc for networks
# For every brain network: show the accuracy table with its maximum highlighted,
# and save a bar chart and a CSV of the table under results/encoder/<net>.

for net in networksResults:
  printLoop(net)
  
  # results directory
  RES_DIR = f'{dirPath}/results/encoder/{net}'
  if not os.path.exists(RES_DIR):
      os.makedirs(RES_DIR)

  # headsidx indexes columns_names, layersidx indexes rows_names
  maxV, headsidx, layersidx = GetMaxInfo(networksResults[net]['acc'])
  display(networksResults[net]['df'].style.highlight_max(color = 'orange', axis=None))
  networksResults[net]['df'].plot(kind='bar', title=f'{net}: max acc: {maxV:.3} for: {columns_names[headsidx]}, {rows_names[layersidx]}', figsize=(8,6), ylim=(0,1.2))
  plt.savefig(RES_DIR + f'/net: {net} accuracy bar')
  networksResults[net]['df'].to_csv(RES_DIR + f'/net: {net} accuracy data.csv')

In [None]:
# get pearson corr
# For every network, model configuration and attention block: compute the
# attention dict on the validation split and save the per-film correlation
# heatmaps through GetGraphs.
num_of_clips=len(clip_time) + 3

for net in networksResults:
  printLoop(net)
  RES_DIR_MODEL = f'{dirPath}/models/encoder_models/{net}'
  for layers, row in zip(num_layers, rows_names):
    for heads, column in zip(num_heads, columns_names):
      # # load model
      # networksResults[net]['Models'][row][column] = keras.models.load_model(f'{RES_DIR_MODEL}/Model_{net}_numHeads_{heads}_num_layers_{layers}')
      # the model is kept in a local name here instead of being stored in networksResults
      Model = keras.models.load_model(f'{RES_DIR_MODEL}/Model_{net}_numHeads_{heads}_num_layers_{layers}')
      # layer indices 2 .. layers+1 are visited, one per attention block
      for att_block in range(2,layers+2):
        print(f'-------------------NET: {net}  number of heads: {heads} number of layers: {layers} att: {att_block-1}--------------------')
        # # attention dict
        # networksResults[net]['att_dict'][row][column] = GetAttnDict(networksDict[net]['val'], networksResults[net]['Models'][row][column], att_block=att_block)
        att_dict = GetAttnDict(networksDict[net]['val'], Model, att_block=att_block)


        # results directory
        # RES_DIR must be assigned before GetGraphs runs: GetGraphs reads it as a global
        RES_DIR = f'{dirPath}/results/encoder/{net}/{row}/{column}/att_block_num_{att_block-1}'
        if not os.path.exists(RES_DIR):
            os.makedirs(RES_DIR)

        GetGraphs(att_dict, num_of_clips)


In [None]:
# get pearson corr histogram
# For every network, model configuration and attention block: compute the
# attention dict on the validation split and save, per film, a histogram of the
# correlation coefficients between the samples' attention matrices.
num_of_clips=len(clip_time) + 3

for net in networksResults:
  printLoop(net)
  RES_DIR_MODEL = f'{dirPath}/models/encoder_models/{net}'
  for layers, row in zip(num_layers, rows_names):
    for heads , column in zip(num_heads, columns_names):
      # load model
      networksResults[net]['Models'][row][column] = keras.models.load_model(f'{RES_DIR_MODEL}/Model_{net}_numHeads_{heads}_num_layers_{layers}')
      # layer indices 2 .. layers+1 are visited, one per attention block
      for att_block in range(2,layers+2):
        print(f'-------------------NET: {net}  number of heads: {heads} number of layers: {layers} att: {att_block-1}--------------------')
        # attention dict
        networksResults[net]['att_dict'][row][column] = GetAttnDict(networksDict[net]['val'], networksResults[net]['Models'][row][column], att_block=att_block)

        # results directory (exist_ok avoids the separate check-then-create step)
        RES_DIR = f'{dirPath}/results/encoder/{net}/{row}/{column}/att_block_num_{att_block-1}'
        os.makedirs(RES_DIR, exist_ok=True)

        for j in range(3,num_of_clips):
          corr_prepare = networksResults[net]['att_dict'][row][column][f'film{j-3}_att']
          # average over axis 1 (presumably the attention heads - confirm)
          corr_prepare = tf.reduce_mean(corr_prepare,axis=1)
          # shapes are passed positionally: the `newshape` keyword is deprecated
          # in recent NumPy releases
          corr_prepare = np.reshape(corr_prepare, (corr_prepare.shape[0],-1))
          corr_prepare_corrcoef = np.corrcoef(corr_prepare,rowvar=True) # find corr coef for observationXatt
          corr_prepare_corrcoef_hist = np.reshape(corr_prepare_corrcoef, -1)
          # NOTE(review): the 0.98 cut-off presumably removes the self-correlations
          # on the diagonal; it also drops any genuine coefficient >= 0.98 - confirm.
          corr_prepare_corrcoef_hist = corr_prepare_corrcoef_hist[corr_prepare_corrcoef_hist < 0.98]

          fig = plt.figure()
          #fig.set_size_inches(corr_prepare.shape[0]/4, corr_prepare.shape[0]/4)
          #fig.set_size_inches(152/4, 152/4)
          # the former j==3 / else branches were identical, so a single call is used
          sns.histplot(corr_prepare_corrcoef_hist)
          plt.title(f'film: {clip_names[j-3]} hist\nmax: {corr_prepare_corrcoef_hist.max():.3}\nmin: {corr_prepare_corrcoef_hist.min():.3}')
          plt.savefig(RES_DIR + f'/film: {clip_names[j-3]} cross subjects corr hist')
          # plt.show()
          plt.close()

In [None]:
# get att heads
# For every network, model configuration and attention block: compute the
# attention dict on the validation split and save, per film, an image of the
# attention matrix averaged over axis 1 and then over the film's samples.
num_of_clips=len(clip_time) + 3

for net in networksResults:
  printLoop(net)
  RES_DIR_MODEL = f'{dirPath}/models/encoder_models/{net}'
  for layers, row in zip(num_layers, rows_names):
    for heads , column in zip(num_heads, columns_names):
      # load model
      networksResults[net]['Models'][row][column] = keras.models.load_model(f'{RES_DIR_MODEL}/Model_{net}_numHeads_{heads}_num_layers_{layers}')
      # layer indices 2 .. layers+1 are visited, one per attention block
      for att_block in range(2,layers+2):
        print(f'-------------------NET: {net}  number of heads: {heads} number of layers: {layers} att: {att_block-1}--------------------')
        # attention dict
        networksResults[net]['att_dict'][row][column] = GetAttnDict(networksDict[net]['val'], networksResults[net]['Models'][row][column], att_block=att_block)
      
        # results directory
        RES_DIR = f'{dirPath}/results/encoder/{net}/{row}/{column}/att_block_num_{att_block-1}'
        if not os.path.exists(RES_DIR):
            os.makedirs(RES_DIR)

        for j in range(3,num_of_clips):
          att_prepare = networksResults[net]['att_dict'][row][column][f'film{j-3}_att']
          # mean over axis 1 (presumably the attention heads - confirm),
          # then mean over the samples of this film
          att = tf.reduce_mean(att_prepare,axis=1)
          att = tf.reduce_mean(att,axis=0)
          # NOTE(review): the tick range comes from clip_time, while the matrix
          # itself was cropped in GetAttnDict using the mask - confirm both
          # lengths agree for every film.
          clip_len = clip_time[j-3]
          
          # savemat(RES_DIR + f'/film: {clip_names[j-3]} cross subjects mean attention.mat', {'att':np.asarray(att[0:clip_len,0:clip_len])})

          fig = plt.figure()
          # a quarter of an inch per position
          fig.set_size_inches(clip_len/4, clip_len/4)
          ax = plt.gca()
          ax.matshow(att)
          ax.set_xticks(range(clip_len))
          ax.set_yticks(range(clip_len))
          # tick labels are 1-based
          labels = range(1,clip_len+1)
          ax.set_xticklabels(labels, rotation=90)
          ax.set_yticklabels(labels)
          plt.title(f'film: {clip_names[j-3]} mean attention ')
          plt.grid()
          plt.savefig(RES_DIR + f'/film: {clip_names[j-3]} cross subjects mean attention')
          # plt.show()
          plt.close()
