<a href="https://colab.research.google.com/github/UN-GCPDS/python-gcpds.EEG_Tensorflow_models/blob/main/Experimental/DW_LCAM/%5B3%5D_Main_attention_maps_comparison.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Mount Google Drive

In [None]:
#-------------------------------------------------------------------------------
# Mount Google Drive at /content/drive so the data/results paths used below resolve.
from google.colab import drive
drive.mount('/content/drive')
#-------------------------------------------------------------------------------

# Install the tf-keras-vis toolbox




In [None]:
#-------------------------------------------------------------------------------
!pip install tf-keras-vis tensorflow
#-------------------------------------------------------------------------------

# Import supporting modules

In [None]:
#-------------------------------------------------------------------------------
import numpy as np
import scipy.io as sio
import matplotlib.pyplot as plt
import tensorflow as tf
import pickle
from tf_keras_vis.utils.scores import CategoricalScore
from tf_keras_vis.utils.model_modifiers import ReplaceToLinear
from matplotlib import cm
from tf_keras_vis.gradcam import Gradcam
from sklearn.model_selection import ShuffleSplit
from tensorflow import keras
from tensorflow.keras.constraints import max_norm
from tensorflow.keras import backend as K
from tf_keras_vis.saliency import Saliency
from tf_keras_vis.scorecam import Scorecam
from tf_keras_vis.gradcam_plus_plus import GradcamPlusPlus
from sklearn.preprocessing import MinMaxScaler
from sklearn.preprocessing import StandardScaler
from sklearn.manifold import TSNE
from sklearn.decomposition import KernelPCA
from sklearn.metrics import pairwise_distances
%matplotlib inline
#-------------------------------------------------------------------------------

# Define the data-loading, normalization, visualization, and CNN model functions

In [None]:
#-------------------------------------------------------------------------------
def TW_data(sbj, time_inf, time_sup,
            data_dir='/content/drive/MyDrive/Colab Notebooks/GradCam_Paper/GigaData/data'):
    """Load the CWT and CSP image datasets of one subject for one time window.

    Two pickles are read from ``data_dir``:
    ``CWT_CSP_data_mubeta_8_30_Tw_<inf>s_<sup>s_subject<sbj>_{cwt,csp}_resized_10.pickle``.
    Each holds the tuple ``(X_train, X_test, y_train, y_test)``.

    Parameters
    ----------
    sbj : int
        Subject id (inserted into the file name with ``str``).
    time_inf, time_sup : float
        Start / end of the time window (inserted with ``str``, e.g. ``-1.5``).
    data_dir : str, optional
        Directory holding the pickles; defaults to the original Drive folder.

    Returns
    -------
    tuple
        ``(X_train_cwt, X_train_csp, X_test_cwt, X_test_csp, y_train, y_test)``.
        The labels are the ones stored in the CSP pickle (the file read last);
        both files are expected to share the same labels.

    Notes
    -----
    ``pickle.load`` can execute arbitrary code: only load files you created.
    """
    import os  # local import keeps this notebook cell self-contained

    # Shared file-name stem; only the feature suffix differs between the files.
    stem = os.path.join(
        data_dir,
        'CWT_CSP_data_mubeta_8_30_Tw_' + str(time_inf) + 's_' + str(time_sup)
        + 's_subject' + str(sbj))

    with open(stem + '_cwt_resized_10.pickle', 'rb') as f:
        X_train_re_cwt, X_test_re_cwt, y_train, y_test = pickle.load(f)
    with open(stem + '_csp_resized_10.pickle', 'rb') as f:
        X_train_re_csp, X_test_re_csp, y_train, y_test = pickle.load(f)

    return X_train_re_cwt, X_train_re_csp, X_test_re_cwt, X_test_re_csp, y_train, y_test
#-------------------------------------------------------------------------------
def _standardize_split(X, train_index, valid_index, X_test):
    """Split X into train/valid, z-score train/valid/test with train statistics, add a channel axis."""
    X_tr = X[train_index]
    X_va = X[valid_index]
    # Per-pixel statistics of the training part only; 1e-7 avoids division by zero.
    mean = X_tr.mean(axis=0, keepdims=True)
    std = X_tr.std(axis=0, keepdims=True) + 1e-7
    return tuple(((part - mean) / std)[..., np.newaxis] for part in (X_tr, X_va, X_test))


def norm_data(XF_train_cwt, XF_train_csp, XF_test_cwt, XF_test_csp, n_fb, Ntw, y_train, y_test, fld,
              n_classes=None):
    """Split off a validation set, standardize every input image stack and one-hot the labels.

    Order of the returned input lists:
    ``[CWT_fb1_TW1, CWT_fb2_TW1, ..., CWT_fb1_TWN, CWT_fb2_TWN, CSP_fb1_TW1, ..., CSP_fb2_TWN]``.

    Parameters
    ----------
    XF_train_cwt, XF_train_csp, XF_test_cwt, XF_test_csp : list
        One array per time window, indexed as ``[trial, band, row, col]``.
    n_fb : int
        Number of frequency bands per time window.
    Ntw : int
        Number of time windows.
    y_train, y_test : ndarray
        Class labels; they are shifted by -1, so they are assumed to start at 1.
    fld : int
        ``random_state`` of the 90/10 train/validation ``ShuffleSplit``.
    n_classes : int, optional
        Number of classes for the one-hot encoding. Defaults to the
        notebook-level global ``num_classes`` (the original behaviour).

    Returns
    -------
    tuple
        ``(XT_train, XT_valid, XT_test, y_train, y_valid, y_test, train_index, valid_index)``,
        where the ``XT_*`` are lists of arrays with a trailing channel axis and
        the ``y_*`` are one-hot matrices.
    """
    if n_classes is None:
        n_classes = num_classes  # notebook-level global, kept for backward compatibility

    # A ShuffleSplit with a fixed seed gives the same split for every input
    # (all inputs share the trial axis), so compute it once.
    splitter = ShuffleSplit(n_splits=1, test_size=.1, random_state=fld)
    train_index, valid_index = next(splitter.split(XF_train_cwt[0]))

    XT_train_cwt, XT_valid_cwt, XT_test_cwt = [], [], []
    XT_train_csp, XT_valid_csp, XT_test_csp = [], [], []
    for tw in range(Ntw):
        for fb in range(n_fb):
            # NOTE(review): astype(np.uint8) wraps/truncates values outside 0..255;
            # this assumes the resized images are already 8-bit.
            cwt_tr = XF_train_cwt[tw][:, fb, :, :].astype(np.uint8)
            cwt_te = XF_test_cwt[tw][:, fb, :, :].astype(np.uint8)
            csp_tr = XF_train_csp[tw][:, fb, :, :].astype(np.uint8)
            csp_te = XF_test_csp[tw][:, fb, :, :].astype(np.uint8)

            tr, va, te = _standardize_split(cwt_tr, train_index, valid_index, cwt_te)
            XT_train_cwt.append(tr)
            XT_valid_cwt.append(va)
            XT_test_cwt.append(te)

            tr, va, te = _standardize_split(csp_tr, train_index, valid_index, csp_te)
            XT_train_csp.append(tr)
            XT_valid_csp.append(va)
            XT_test_csp.append(te)

    # Labels: apply the same split, flatten, shift to 0-based, then one-hot encode.
    y_trainF = y_train[train_index].reshape((-1,)) - 1
    y_validF = y_train[valid_index].reshape((-1,)) - 1
    y_testF = y_test.reshape((-1,)) - 1
    y_train_oh = keras.utils.to_categorical(y_trainF, n_classes)
    y_valid_oh = keras.utils.to_categorical(y_validF, n_classes)
    y_test_oh = keras.utils.to_categorical(y_testF, n_classes)

    XT_train = XT_train_cwt + XT_train_csp
    XT_valid = XT_valid_cwt + XT_valid_csp
    XT_test = XT_test_cwt + XT_test_csp
    return XT_train, XT_valid, XT_test, y_train_oh, y_valid_oh, y_test_oh, train_index, valid_index
#-------------------------------------------------------------------------------
def vis_heatmap(HmapT,Ntw,names_x,norm):
    """Plot attention heatmaps in a 4 x Ntw grid (rows: CWT mu/beta, CSP mu/beta).

    ``HmapT`` is ordered like the model inputs: the first half holds the CWT
    maps and the second half the CSP maps; within each half the maps
    alternate mu / beta per time window ``[mu_TW1, beta_TW1, mu_TW2, ...]``.

    Parameters
    ----------
    HmapT : list
        2-D heatmaps. Normalized IN PLACE (every item is replaced).
    Ntw : int
        Number of time windows, i.e. figure columns.
    names_x : sequence of str
        Column labels, one per time window.
    norm : int
        1 -> divide every map by the global maximum; otherwise divide each map
        by its own maximum (printing its max/min first).
    """
    n_maps = len(HmapT)
    half = n_maps // 2  # index of the first CSP map
    if norm == 1:
        hmap_max = np.max(np.array(HmapT))
        for i in range(n_maps):
            HmapT[i] = tf.math.divide_no_nan(HmapT[i], hmap_max)
    else:
        for i in range(n_maps):
            print(np.max(np.array(HmapT[i])), np.min(np.array(HmapT[i])))
            HmapT[i] = tf.math.divide_no_nan(HmapT[i], np.max(np.array(HmapT[i])))
    # Shared colour range so all panels are comparable.
    new_max = np.max(np.array(HmapT))
    new_min = np.min(np.array(HmapT))

    fig, axs = plt.subplots(4, Ntw, figsize=(12,7.3), squeeze=False)
    fig.subplots_adjust(hspace = 0.1, wspace=.0001)
    for tw in range(Ntw):
        ids_tw = (2*tw, 2*tw + 1, half + 2*tw, half + 2*tw + 1)
        for row, idx in enumerate(ids_tw):
            axs[row, tw].matshow(HmapT[idx], vmin=new_min, vmax=new_max)
        axs[3, tw].set(xlabel=names_x[tw])
        axs[3, tw].xaxis.get_label().set_fontsize(15)

    row_labels = (r'$CWT \mu$', r'$CWT \beta$', r'$CSP \mu$', r'$CSP \beta$')
    for row, label in enumerate(row_labels):
        axs[row, 0].set(ylabel=label)
        axs[row, 0].yaxis.get_label().set_fontsize(15)

    for ax in axs.flat:
        ax.label_outer()
        ax.set_xticks([])
        ax.set_yticks([])
#-------------------------------------------------------------------------------
def vis_render(HmapT,new_input,Ntw):
    """Overlay each attention map (jet, alpha 0.5) on its input image (gray).

    ``HmapT`` and ``new_input`` are aligned lists ordered like the model inputs:
    first half CWT, second half CSP, alternating mu / beta per time window.
    The grid has 4 rows (CWT mu, CWT beta, CSP mu, CSP beta) and Ntw columns.
    Each input is squeezed and drawn with ``vmin=0, vmax=1``.
    """
    half = len(HmapT) // 2  # index of the first CSP map
    fig, axs = plt.subplots(nrows=4, ncols=Ntw, figsize=(12,7.3), squeeze=False)
    fig.subplots_adjust(hspace = 0.1, wspace=.0001)
    for tw in range(Ntw):
        ids_tw = (2*tw, 2*tw + 1, half + 2*tw, half + 2*tw + 1)
        for row, idx in enumerate(ids_tw):
            heatmap = np.uint8(cm.jet(HmapT[idx])[..., :3] * 255)
            axs[row, tw].imshow(np.squeeze(new_input[idx]), cmap='gray', vmin=0, vmax=1)
            axs[row, tw].imshow(heatmap, cmap='jet', alpha=0.5)  # overlay

    row_labels = (r'$CWT \mu$', r'$CWT \beta$', r'$CSP \mu$', r'$CSP \beta$')
    for row, label in enumerate(row_labels):
        axs[row, 0].set(ylabel=label)
        axs[row, 0].yaxis.get_label().set_fontsize(15)

    # Distinct loop name: the original `for ax in ax.flat` shadowed the axes array.
    for axis in axs.flat:
        axis.set_xticks([])
        axis.set_yticks([])
#-------------------------------------------------------------------------------
def cnn_network(n_fb,Nkfeats,Ntw,shape_,n_filt,units,l1p,l2p,lrate,sbj):
    """Build and compile the multi-branch CNN (one convolutional branch per input image).

    Each of the ``Ntw * n_fb * Nkfeats`` inputs (``shape_ x shape_ x 1``) goes
    through Conv2D(relu) -> BatchNormalization -> MaxPooling2D. The pooled maps
    are concatenated, flattened and classified by
    BN -> Dense(units, relu, L1/L2) -> BN -> Dense(2, softmax).

    NOTE: Keras auto-names the conv layers 'conv2d', 'conv2d_1', ... in creation
    order after ``clear_session``. The attention code refers to those names, and
    HDF5 weight loading relies on the topology, so the creation order below
    must stay as is.

    ``sbj`` is accepted for interface compatibility but is not used.
    """
    keras.backend.clear_session()
    np.random.seed(123)
    tf.compat.v1.random.set_random_seed(123)

    image_shape = (shape_, shape_, 1)
    inputs = []
    pooled = []
    for _ in range(Ntw * n_fb * Nkfeats):
        branch_in = keras.layers.Input(shape=image_shape)
        x = keras.layers.Conv2D(filters=n_filt, kernel_size=3, strides=1, activation='relu',
                                padding='SAME', input_shape=image_shape)(branch_in)
        x = keras.layers.BatchNormalization()(x)
        pooled.append(keras.layers.MaxPooling2D(pool_size=2)(x))
        inputs.append(branch_in)

    x = keras.layers.concatenate(pooled)
    x = keras.layers.Flatten()(x)
    x = keras.layers.BatchNormalization()(x)
    x = keras.layers.Dense(units=units, activation='relu',
                           kernel_regularizer=keras.regularizers.l1_l2(l1=l1p, l2=l2p),
                           kernel_constraint=max_norm(1.))(x)
    x = keras.layers.BatchNormalization()(x)
    prediction = keras.layers.Dense(units=2, activation='softmax', kernel_constraint=max_norm(1.))(x)
    model = keras.models.Model(inputs=inputs, outputs=[prediction])

    schedule = tf.keras.optimizers.schedules.PolynomialDecay(lrate, 4000, power=1.0, cycle=False, name=None)
    optimizer = keras.optimizers.Adam(learning_rate=schedule)
    model.compile(loss='mean_squared_error', optimizer=optimizer, metrics=['accuracy'])
    return model
#-------------------------------------------------------------------------------

# Compute attention maps (Grad-CAM, Grad-CAM++, Score-CAM, Saliency, and weight-based relevance)

In [None]:
# Attention maps for the wide (multi-input) CNN model

from mpl_toolkits.axes_grid1 import make_axes_locatable

def centroid_(X):
    """Return the index of the sample closest to the mean of X and the mean distance to it.

    ``X`` is a 2-D (samples x features) array. Distances are Euclidean, computed
    with ``pairwise_distances``.
    """
    center = X.mean(axis=0).reshape(1, -1)
    dist_to_center = pairwise_distances(X, center)  # one column: distance of each sample
    closest = np.argmin(dist_to_center)
    return closest, dist_to_center.mean()



def _block_centers(length, n_blocks):
    """Return the integer centre coordinate of each of n_blocks equal blocks spanning length pixels."""
    if n_blocks <= 0:
        return []
    width = length / n_blocks
    return [int(width * (0.5 + k)) for k in range(n_blocks)]


def plot_attention(tmpr_,rel_model_name,layer_name,list_class,figsize=(10,5), transpose=False):
    """Show a tiled attention matrix with one tick per class block and per layer/input block.

    Parameters
    ----------
    tmpr_ : 2-D ndarray
        Tiled attention matrix. With ``transpose=True`` its rows are split into
        ``len(list_class)`` class blocks and its columns into
        ``len(layer_name)`` layer/input blocks; with ``transpose=False`` the
        roles of rows and columns are swapped.
    rel_model_name : str
        Relevance-method name (only used by the commented-out ``savefig``).
    layer_name : list of str
        One entry per layer/input block. When it has as many entries as the 20
        model inputs, the descriptive input names below are shown instead.
    list_class : sequence
        Class labels, one per class block.
    figsize : tuple, optional
    transpose : bool, optional
    """
    names_feats = [r'CWT-$\mu$-TW1',r'CWT-$\beta$-TW1',r'CWT-$\mu$-TW2',r'CWT-$\beta$-TW2',r'CWT-$\mu$-TW3',r'CWT-$\beta$-TW3',r'CWT-$\mu$-TW4',r'CWT-$\beta$-TW4',r'CWT-$\mu$-TW5',r'CWT-$\beta$-TW5',
                   r'CSP-$\mu$-TW1',r'CSP-$\beta$-TW1',r'CSP-$\mu$-TW2',r'CSP-$\beta$-TW2',r'CSP-$\mu$-TW3',r'CSP-$\beta$-TW3',r'CSP-$\mu$-TW4',r'CSP-$\beta$-TW4',r'CSP-$\mu$-TW5',r'CSP-$\beta$-TW5']
    # One label per class; str() of the whole label (`+= str(c)` split "10" into "1","0").
    class_labels = [str(c) for c in list_class]
    # Label count must match the tick count (newer matplotlib raises otherwise);
    # e.g. Saliency passes layer_name=[''] -> a single block.
    if len(layer_name) == len(names_feats):
        layer_labels = names_feats
    else:
        layer_labels = list(layer_name)

    n_rows, n_cols = tmpr_.shape
    if transpose:
        # rows: class blocks, columns: layer/input blocks
        yticks, ylabels = _block_centers(n_rows, len(list_class)), class_labels
        xticks, xlabels = _block_centers(n_cols, len(layer_name)), layer_labels
    else:
        # rows: layer/input blocks, columns: class blocks
        yticks, ylabels = _block_centers(n_rows, len(layer_name)), layer_labels
        xticks, xlabels = _block_centers(n_cols, len(list_class)), class_labels

    plt.figure(figsize=figsize)
    ax = plt.gca()
    im = ax.imshow(tmpr_)
    ax.set_yticks(yticks)
    ax.set_yticklabels(ylabels)
    ax.set_xticks(xticks)
    ax.set_xticklabels(xlabels, rotation='vertical')
    cax = make_axes_locatable(ax).append_axes("right", size="5%", pad=0.1)

    vmin, vmax = tmpr_.min(), tmpr_.max()
    # min / midpoint / max ticks (the midpoint was previously 0.5*(max-min)).
    plt.colorbar(im, cax=cax, extend='both',
                 ticks=[np.round(vmin, 3), np.round(0.5*(vmin + vmax), 3), np.round(vmax, 3)])
    # To save: plt.savefig(<results>/resulting_attention_maps/attention_map_<sbj>_<rel_model_name>.svg, format='svg')
    plt.tight_layout()
    plt.show()

import cv2
def attention_wide(modelw,rel_model_name,layer_name,X_train,y_train,
                   normalize_cam=False,norm_max_min=False,norm_c=True,
                   plot_int=False,centroid_=False,smooth_samples=20,
                   smooth_noise=0.20,transpose=False):
    """Compute per-class attention maps for a (multi-input) model and plot them.

    For each class in ``y_train`` the relevance method ``rel_model_name`` is run
    on all samples of that class. The per-sample maps are reduced with the
    median over samples and tiled into one matrix ``tmpr``. ``tmpr`` is divided
    by its maximum and shown with ``plot_attention``.

    Parameters
    ----------
    modelw : keras model
        Model to explain (wrapped by the tf-keras-vis explainers).
    rel_model_name : str
        'Weights', 'Gradcam', 'Gradcam++', 'Scorecam' or 'Saliency'.
        NOTE(review): any other value falls through every branch and only
        fails later with a NameError; an explicit check would be clearer.
    layer_name : list of str
        Penultimate layers to explain, one map per layer. Replaced by [''] for
        'Saliency', which yields one map per model input instead.
    X_train : ndarray or list of ndarray
        Model input(s); with a list, each element is one model input.
    y_train : 1-D ndarray
        Integer class label per sample (the caller passes
        ``np.argmax(one_hot, axis=1)``).
    normalize_cam : bool
        Forwarded as ``normalize_cam`` (CAM methods) / ``normalize_map`` (Saliency).
    norm_max_min : bool
        Min-max scale each sample's map(s) to [0, 1] before taking the median.
    norm_c : bool
        Divide each class's tiled median map by its own maximum.
    plot_int : bool
        Also show each class's tiled map as it is computed.
    centroid_ : bool
        Unused; it also shadows the module-level ``centroid_`` function here.
    smooth_samples, smooth_noise : int, float
        SmoothGrad settings, used only by 'Saliency'.
    transpose : bool
        True tiles layers/inputs horizontally and classes vertically; False
        does the opposite.

    Returns
    -------
    relM : list
        One entry per class with the per-sample maps (stacked on the last axis
        per layer or, for Saliency, per input). For 'Weights' every entry is None.
    tmpr : ndarray
        The tiled, max-normalized map that was plotted.

    NOTE(review): the 'Weights' branch reads the notebook globals ``n_sbj``,
    ``sbj`` and ``opt_fld`` to build its pickle path. It always tiles the 20
    maps horizontally and stacks two copies vertically, regardless of
    ``transpose`` and of the number of classes in ``y_train``.
    """
    #-------------------------------------------------------------------------------
    # ReplaceToLinear swaps the last layer's activation for a linear one; it is
    # passed as model_modifier to Gradcam, Gradcam++ and Saliency below.
    replace2linear = ReplaceToLinear()
    #relevance model
    
    if rel_model_name == 'Weights':
      # Precomputed weight-based relevance, one array per feature/band:
      #[topo_avg_muT_cwt,topo_avg_beT_cwt,topo_avg_muT_csp,topo_avg_beT_csp]
      # NOTE(review): depends on notebook globals n_sbj, sbj and opt_fld.
      path='/content/drive/MyDrive/Colab Notebooks/GradCam_Paper/GigaData/results/matrix_data/WeightsRel_sbj_'+str(n_sbj[sbj])+'_fold_'+str(opt_fld[sbj])+'.pickle'
      with open(path, 'rb') as f:
         w_data = pickle.load(f)
      
      # Tile CWT maps horizontally as mu_TW1, beta_TW1, mu_TW2, ... after
      # nearest-neighbour resizing to 40x40 (presumably index i = time window).
      for i in range(5):
        if i ==0:
          amw_cwt = cv2.resize(w_data[0][i,:,:],(40, 40),interpolation = cv2.INTER_NEAREST)
        else: 
          amw_cwt = np.c_[amw_cwt,cv2.resize(w_data[0][i,:,:],(40, 40),interpolation = cv2.INTER_NEAREST)]
        amw_cwt = np.c_[amw_cwt,cv2.resize(w_data[1][i,:,:],(40, 40),interpolation = cv2.INTER_NEAREST)]
      
      # Same tiling for the CSP maps.
      for i in range(5):
        if i ==0:
          amw_csp = cv2.resize(w_data[2][i,:,:],(40, 40),interpolation = cv2.INTER_NEAREST)
        else: 
          amw_csp = np.c_[amw_csp,cv2.resize(w_data[2][i,:,:],(40, 40),interpolation = cv2.INTER_NEAREST)]
        amw_csp   = np.c_[amw_csp,cv2.resize(w_data[3][i,:,:],(40, 40),interpolation = cv2.INTER_NEAREST)]
  
      # CWT block then CSP block side by side; two identical rows (one per class).
      amw = np.concatenate((amw_cwt,amw_csp),axis=1)
      amw = np.r_[amw,amw]
      relM = [None]*len(np.unique(y_train))
      #---------------------------------------------------------------------------
      tmpr = amw/(1e-8+amw.max())
      #---------------------------------------------------------------------------

    else:
      # Build the explainer for the requested method.
      if rel_model_name == 'Gradcam':
          gradcamw = Gradcam(modelw,
                          model_modifier=replace2linear,
                          clone=True)
      elif rel_model_name == 'Gradcam++':
          gradcamw = GradcamPlusPlus(modelw,
                                model_modifier=replace2linear,
                                clone=True) 
          
      elif rel_model_name == 'Scorecam':
          scorecamw = Scorecam(modelw)
          
      elif rel_model_name == 'Saliency':
            saliencyw = Saliency(modelw,
                                model_modifier=replace2linear,
                                clone=True)
            layer_name = [''] #saliency doesn't depend on different layers    
      nC = len(np.unique(y_train))
      relM = [None]*nC
      if type(X_train)==list:
          n_inputs = len(X_train)
          new_input = [None]*n_inputs

      # One pass per class: explain all samples carrying that label.
      for c in range(len(np.unique(y_train))):  
        id_sample = y_train == np.unique(y_train)[c]

        # Preallocate per-sample maps: last axis = layers (CAM) or inputs (Saliency).
        if (type(X_train)==list) and (rel_model_name != 'Saliency'):
          relM[c] = np.zeros((sum(id_sample),X_train[0].shape[1],X_train[0].shape[2],len(layer_name)))
          #print(1,relM[c].shape)
        elif (type(X_train)==list) and (rel_model_name == 'Saliency'):   
          relM[c] = np.zeros((sum(id_sample),X_train[0].shape[1],X_train[0].shape[2],len(X_train)))
          #print(2,relM[c].shape)        
        else:
          relM[c] = np.zeros((sum(id_sample),X_train.shape[1],X_train.shape[2],len(layer_name)))
          #print(3,relM[c].shape)
        score = CategoricalScore(list(y_train[id_sample])) #-> e.g. [0] to score against a different class
        if type(X_train)==list:
            for ni in range(n_inputs):
                new_input[ni] = X_train[ni][id_sample]
        else:
          new_input = X_train[id_sample]        
        #print('rel',rel_model_name,'layer',layer_name[l])
        for l in range(len(layer_name)):
            #print(rel_model_name,'class', np.unique(y_train)[c],'layer',layer_name[l])
        # label score -> target label according to the database
        #-----------------------------------------------------------------------------
        # generate the relevance map with the selected method
            if (rel_model_name == 'Gradcam') or (rel_model_name == 'Gradcam++'):
                rel = gradcamw(score,
                            new_input,
                            penultimate_layer=layer_name[l], #layer to be analyzed
                            expand_cam=True,
                            normalize_cam=normalize_cam)
            elif rel_model_name == 'Saliency': #saliency map is too noisy, so let’s remove noise in the saliency map using SmoothGrad!
                  rel = saliencyw(score, new_input,smooth_samples=smooth_samples,
                                  smooth_noise=smooth_noise,
                                  normalize_map=normalize_cam) #, smooth_samples=20,smooth_noise=0.20) # The number of calculating gradients iterations.
                              
            elif rel_model_name == 'Scorecam':     
                rel = scorecamw(score, new_input, penultimate_layer=layer_name[l], #layer to be analyzed
                            expand_cam=True,
                            normalize_cam=normalize_cam) #max_N=10 -> faster scorecam
        
            #save model

            if rel_model_name != 'Saliency':
              # Presumably one cam per model input; only the first is kept.
              # NOTE(review): confirm the per-input cams coincide.
              if type(X_train)==list: 
                tcc = rel[0]
              else: 
                tcc = rel
              # Replace NaNs with 0.
              dimc = tcc.shape
              tccv = tcc.ravel()
              tccv[np.isnan(tccv)] = 0
              tcc = tccv.reshape(dimc)
              if norm_max_min: #normalizing each sample to [0, 1]
                tcc = MinMaxScaler().fit_transform(tcc.reshape(dimc[0],-1).T).T
                tcc = tcc.reshape(dimc)
              relM[c][...,l] = tcc
              # Median over samples, tiled across layers (columns if transpose, else rows).
              if l==0: 
                tmp = np.median(relM[c][...,l],axis=0)#relM[c][...,l].mean(axis=0)
              else: 
                if transpose:
                  tmp = np.c_[tmp,np.median(relM[c][...,l],axis=0)]#np.r_[tmp,relM[c][...,l].mean(axis=0)]  #centroid
                else:  
                  tmp = np.r_[tmp,np.median(relM[c][...,l],axis=0)]#np.r_[tmp,relM[c][...,l].mean(axis=0)]  #centroid
            else: #saliency
              # Stack the per-input saliency maps on the last axis.
              if type(X_train)==list: 
                tcc = np.zeros((rel[0].shape[0],rel[0].shape[1],rel[0].shape[2],len(rel)))
                for ii in range(len(rel)):
                    tcc[...,ii] = rel[ii]
              else: 
                tcc = rel
              # Replace NaNs with 0.
              dimc = tcc.shape
              tccv = tcc.ravel()
              tccv[np.isnan(tccv)] = 0
              tcc = tccv.reshape(dimc)
              if norm_max_min: #normalizing each sample (all inputs jointly) to [0, 1]
                tcc = MinMaxScaler().fit_transform(tcc.reshape(dimc[0],-1).T).T
                tcc = tcc.reshape(dimc)
              relM[c] = tcc
              # Median over samples, tiled across inputs.
              if type(X_train)==list: 
                tmp = np.median(tcc[...,0],axis=0)
                for ii in range(len(rel)-1):
                    if transpose: 
                      tmp = np.c_[tmp,np.median(tcc[...,ii+1],axis=0)]
                    else:
                      tmp = np.r_[tmp,np.median(tcc[...,ii+1],axis=0)]
              else:
                tmp = np.median(tcc,axis=0)
                  
        if norm_c: #normalizing along layers
          tmp = tmp/(1e-8+tmp.max())
        # Tile classes (rows if transpose, else columns).
        if c==0: 
          tmpr = tmp
        else:  
          if transpose:
            tmpr = np.r_[tmpr,tmp]  
          else:
            tmpr = np.c_[tmpr,tmp]  
        #print(tmp.shape,tmp.max())    
        if plot_int: #plot every class
          plt.imshow(tmp)
          plt.colorbar(orientation='horizontal')
          plt.axis('off')
          plt.show()

      #---------------------------------------------------------------------------
      tmpr = tmpr/(1e-8+tmpr.max())
      #---------------------------------------------------------------------------
    
    list_class = np.unique(y_train)
    plot_attention(tmpr,rel_model_name,layer_name,list_class,transpose=transpose)

    return relM,tmpr

In [None]:
#-------------------------------------------------------------------------------
# define parameters
partitions    = ['train','valid','test']
# time-window labels (all rendered as mathtext so the fonts match)
names_x       = [r'$-1.5s-0.5s$',r'$-0.5s-1.5s$',r'$0.5s-2.5s$',r'$1.5s-3.5s$',r'$2.5s-4.5s$']
learning_rate = 1e-4
th_name       = np.array([[-1.5, 0.5],[-0.5, 1.5],[0.5, 2.5],[1.5, 3.5],[2.5, 4.5]])  # time windows [s]
n_fb          = 2    # frequency bands per window (mu, beta)
Ntw           = 5    # number of time windows
Nkfeats       = 2    # feature types (CWT, CSP)
num_classes   = 2
n_filt        = 2
n_fld         = 3
n_conv_layers = 20   # one conv branch per model input
if n_conv_layers != Ntw*n_fb*Nkfeats:
  raise ValueError('n_conv_layers must equal Ntw*n_fb*Nkfeats (one conv layer per model input)')
RESULTS_DIR   = '/content/drive/MyDrive/Colab Notebooks/GradCam_Paper/GigaData/results'
#-------------------------------------------------------------------------------
n_sbj         = [41]#[1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,30,31,32,33,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52]
opt_neurons   = [200]#[100,200,100,300,200,100,300,200,100,100,200,200,200,300,300,100,100,100,300,200,300,300,200,100,100,300,200,300,300,200,300,200,300,300,100,200,300,300,200,100,200,200,100,300,300,100,100,300,100,300]
opt_l1        = [0.005]#[0.0005,0.0005,0.005,0.005,0.001,0.001,0.0005,0.0005,0.0005,0.005,0.005,0.005,0.005,0.005,0.005,0.0005,0.0005,0.001,0.0005,0.0005,0.0005,0.0005,0.005,0.001,0.001,0.005,0.005,0.0005,0.0005,0.001,0.005,0.001,0.001,0.005,0.005,0.001,0.005,0.005,0.005,0.001,0.005,0.0005,0.005,0.005,0.0005,0.005,0.0005,0.005,0.005,0.005]
opt_l2        = [0.0005]#[0.005,0.001,0.0005,0.005,0.005,0.001,0.0005,0.005,0.005,0.001,0.005,0.005,0.001,0.005,0.001,0.001,0.005,0.0005,0.0005,0.0005,0.0005,0.005,0.0005,0.005,0.005,0.0005,0.001,0.0005,0.0005,0.005,0.0005,0.005,0.0005,0.005,0.005,0.001,0.0005,0.001,0.0005,0.005,0.001,0.0005,0.001,0.0005,0.005,0.001,0.001,0.005,0.0005,0.001]
opt_fld       = [1]#[3,1,1,3,3,1,2,2,3,2,1,2,1,2,1,3,2,1,2,1,1,1,3,1,1,2,1,3,3,1,1,2,2,1,1,3,2,1,1,3,2,2,3,1,2,1,1,3,1,1]
#-------------------------------------------------------------------------------
# NOTE: `sbj` must stay a module-level name: attention_wide's 'Weights' branch reads it.
for sbj in range(len(n_sbj)):
  print('subject ', n_sbj[sbj])
  #-----------------------------------------------------------------------------
  # load train/test data through all time windows
  XF_train_cwt = []
  XF_train_csp = []
  XF_test_cwt  = []
  XF_test_csp  = []
  for i in range(th_name.shape[0]):
    X_train_re_cwt, X_train_re_csp, X_test_re_cwt, X_test_re_csp, y_trainF, y_testF = TW_data(n_sbj[sbj],th_name[i,0],th_name[i,1])
    XF_train_cwt.append(X_train_re_cwt)
    XF_train_csp.append(X_train_re_csp)
    XF_test_cwt.append(X_test_re_cwt)
    XF_test_csp.append(X_test_re_csp)
  #-----------------------------------------------------------------------------
  # partition of data
  XT_train, XT_valid, XT_test, y_train, y_valid, y_test, train_index, valid_index = norm_data(XF_train_cwt, XF_train_csp, XF_test_cwt, XF_test_csp, n_fb, Ntw, y_trainF, y_testF, opt_fld[sbj]-1)
  #-----------------------------------------------------------------------------
  # define model
  model = cnn_network(n_fb,Nkfeats,Ntw,40,n_filt,opt_neurons[sbj],opt_l1[sbj],opt_l2[sbj],learning_rate,n_sbj[sbj])
  #-----------------------------------------------------------------------------
  # Inside a loop the returned image is not displayed; the diagram is presumably
  # still written to plot_model's default to_file path.
  tf.keras.utils.plot_model(model)
  #-----------------------------------------------------------------------------
  # loading best model weights
  filepath        = RESULTS_DIR+'/parameter_setting/weights_sbj_'+str(n_sbj[sbj])+'_filters_2_units_'+str(int(opt_neurons[sbj]))+'_l1_'+str(opt_l1[sbj])+'_l2_'+str(opt_l2[sbj])+'_fld_'+str(opt_fld[sbj])+'.hdf5'
  checkpoint_path = filepath
  model.load_weights(checkpoint_path)
  #-----------------------------------------------------------------------------
  rel_model_name = ['Gradcam++','Scorecam','Saliency'] #,'Gradcam++','Scorecam','Saliency'
  # Keras auto-names the conv layers 'conv2d', 'conv2d_1', ..., 'conv2d_<n-1>'.
  layer_name     = ['conv2d'] + ['conv2d_'+str(k) for k in range(1, n_conv_layers)]
  #
  print('norm_c = False')
  relM_ = [None]*len(rel_model_name) #relM[m] -> number classes x input image resolution x number of layers
  tmpr_ = [None]*len(rel_model_name)
  for m in range(len(rel_model_name)):
    relM_[m],tmpr_[m] = attention_wide(model,rel_model_name[m],layer_name,XT_train,np.argmax(y_train,axis=1),
                                      norm_c=False,norm_max_min=False,plot_int=False,transpose=True)
  #-----------------------------------------------------------------------------
  with open(RESULTS_DIR+'/resulting_attention_maps/score_attmaps_'+str(n_sbj[sbj])+'.pickle', 'wb') as f:
    pickle.dump([relM_, tmpr_], f)
  #-----------------------------------------------------------------------------
  del model
  #-----------------------------------------------------------------------------