In [None]:
%%capture
!pip install -U git+https://github.com/UN-GCPDS/python-gcpds.databases
!pip install -U git+https://github.com/UN-GCPDS/python-gcpds.visualizations.git >/dev/null

In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import matplotlib.pyplot as plt
import numpy as np # linear algebra
from scipy.signal import butter as bw
from scipy.signal import freqz, filtfilt, resample
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.metrics import cohen_kappa_score
from sklearn.model_selection import StratifiedShuffleSplit

from gcpds.databases import GIGA_MI_ME
from gcpds.visualizations.topoplots import topoplot
# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [None]:
# The 64 EEG channel names, in the order used throughout this notebook: passed as
# `eeg_ch_names` to load_GIGA (which selects/reorders the recorded channels to
# this order) and as the channel list of every topoplot call.
channels = ['Fp1','Fpz','Fp2',
            'AF7','AF3','AFz','AF4','AF8',
            'F7','F5','F3','F1','Fz','F2','F4','F6','F8',
            'FT7','FC5','FC3','FC1','FCz','FC2','FC4','FC6','FT8',
            'T7','C5','C3','C1','Cz','C2','C4','C6','T8',
            'TP7','CP5','CP3','CP1','CPz','CP2','CP4','CP6','TP8',
            'P9','P7','P5','P3','P1','Pz','P2','P4','P6','P8','P10',
            'PO7','PO3','POz','PO4','PO8',
            'O1','Oz','O2',
            'Iz']

# Channel groupings by scalp region (midline, right and left hemisphere).
# Note that 'Cz' is grouped under 'Posterior' while the separate 'Central' entry is disabled.
# NOTE(review): not referenced by any code visible in this notebook — confirm it is still needed.
areas = {
    'Frontal': ['Fpz', 'AFz', 'Fz', 'FCz'],
    'Frontal Right': ['Fp2','AF4','AF8','F2','F4','F6','F8',],
    'Central Right': ['FC2','FC4','FC6','FT8','C2','C4','C6','T8','CP2','CP4','CP6','TP8',],
    'Posterior Right': ['P2','P4','P6','P8','P10','PO4','PO8','O2',],
    #'Central': ['Cz'],
    'Posterior': ['CPz','Pz', 'Cz','POz','Oz','Iz',],
    'Posterior Left': ['P1','P3','P5','P7','P9','PO3','PO7','O1',],
    'Central Left': ['FC1','FC3','FC5','FT7','C1','C3','C5','T7','CP1','CP3','CP5','TP7',],
    'Frontal Left': ['Fp1','AF3','AF7','F1','F3','F5','F7',],
}

# Grouping levels ('areas' -> the dict above, 'channels' -> the list above);
# 'hemispheres' is disabled.
# NOTE(review): not referenced by any code visible in this notebook — confirm it is still needed.
arcs = [
    #'hemispheres',
    'areas',
    'channels',
]

In [None]:
def butterworth_digital_filter(X, N, Wn, btype, fs, axis=-1, padtype=None, padlen=0, method='pad', irlen=None):
  """
  Apply a zero-phase (forward-backward) digital Butterworth filter.

  The filter coefficients come from scipy.signal.butter (b, a form) and are applied
  with scipy.signal.filtfilt, so the effective order is 2*N and there is no phase shift.

  INPUT
  ------
  1. X: (ndarray)
    Array with signals.
  2. N: (int+)
    The order of the filter.
  3. Wn: (float+ or 1D array)
    The critical frequency or frequencies, in the same units as fs (e.g. Hz).
    For lowpass and highpass filters, Wn is a scalar; for bandpass and bandstop filters,
    Wn is a length-2 sequence [low, high]. For a Butterworth filter, this is the point
    at which the gain drops to 1/sqrt(2) of the passband (the "-3 dB point").
  4. btype: (str) {'lowpass', 'highpass', 'bandpass', 'bandstop'}
    The type of filter.
  5. fs: (float+)
    The sampling frequency of the digital system.
  6. axis: (int), Default=-1.
    The axis of X along which the filter is applied.
  7. padtype: (str) or None, {'odd', 'even', 'constant'}, Default=None
    Type of extension used for the padded signal. None (the default here) means no padding.
  8. padlen: (int+) or None, Default=0
    Number of elements by which to extend X at both ends of axis before filtering.
    Must be less than X.shape[axis] - 1. padlen=0 (the default here) implies no padding.
  9. method: (str), {'pad', 'gust'}, Default='pad'
    Edge handling. With 'pad' the signal is padded according to padtype and padlen and
    irlen is ignored. With 'gust', Gustafsson's method is used and padtype/padlen are ignored.
  10. irlen: (int) or None, Default=None
    When method is 'gust', the length of the filter's impulse response. If None, no part
    of the impulse response is ignored. For long signals it can significantly speed up filtering.
  OUTPUT
  ------
  X_fil: (ndarray)
    Filtered signals, same shape as X.
  """
  b, a = bw(N, Wn, btype, analog=False, output='ba', fs=fs)
  return filtfilt(b, a, X, axis=axis, padtype=padtype, padlen=padlen, method=method, irlen=irlen)

class TimeFrequencyRpr(BaseEstimator, TransformerMixin):
  """
  Time frequency representation of EEG signals.

  Each trial is first band-pass filtered with every band of `f_bank` (on the whole
  trial, to keep the filter's edge effects away from the windows) and then cut into
  the time windows of `vwt`.

  Parameters
  ----------
    1. sfreq:  (float) Sampling frequency in Hz. Must be > 0.
    2. f_bank: (2D array-like) Filter bank, one [low, high] row per band, in Hz. Default=None (no filtering).
    3. vwt:    (2D array-like) Time windows of interest, one [start, end] row per window, in seconds.
               All windows must span the same number of samples. Default=None (whole trial).
  Methods
  -------
    1. fit(X, y=None): no-op, returns self (so fit_transform works).
    2. transform(X, y=None): 5D array (trials, channels, window_time_samples, number_of_windows, frequency_bands).
  """
  def __init__(self, sfreq, f_bank=None, vwt=None):
    self.sfreq = sfreq
    self.f_bank = f_bank
    self.vwt = vwt

  # ----------------------------------------------------------------------------
  def _validation_param(self):
    """
    Validate the Time-Frequency parameters and set the `flag_f_bank` / `flag_vwt` switches.

    RAISES
    ------
      ValueError: if sfreq <= 0, or if f_bank / vwt are given but are not 2D.
    """
    if self.sfreq <= 0:
      raise ValueError('Only positive sampling frequency is accepted')

    if self.f_bank is None:
      self.flag_f_bank = False
    elif np.ndim(self.f_bank) != 2:
      raise ValueError('Band frequencies have to be a 2D array')
    else:
      self.flag_f_bank = True

    if self.vwt is None:
      self.flag_vwt = False
    elif np.ndim(self.vwt) != 2:
      raise ValueError('Time windows have to be a 2D array')
    else:
      self.flag_vwt = True

  # ----------------------------------------------------------------------------
  def _filter_bank(self, X):
    """
    Filter bank characterization (5th-order zero-phase Butterworth band-pass per band).
    INPUT
    -----
      1. X: (3D array) set of EEG signals, shape (trials, channels, time_samples)
    OUTPUT
    ------
      1. X_f: (4D array) filtered EEG signals, shape (trials, channels, time_samples, frequency_bands)
    """
    f_bank = np.asarray(self.f_bank, dtype=float)
    X_f = np.zeros(X.shape[:3] + (f_bank.shape[0],))
    for f, band in enumerate(f_bank):
      X_f[..., f] = butterworth_digital_filter(X, N=5, Wn=band, btype='bandpass', fs=self.sfreq)
    return X_f

  # ----------------------------------------------------------------------------
  def _sliding_windows(self, X):
    """
    Sliding windows characterization.
    INPUT
    -----
      1. X: (3D array) set of EEG signals, shape (trials, channels, time_samples)
    OUTPUT
    ------
      1. X_w: (4D array) shape (trials, channels, window_time_samples, number_of_windows)
    RAISES
    ------
      ValueError: if the windows do not all span the same number of samples.
    """
    vwt = np.asarray(self.vwt, dtype=float)
    # Sample bounds use the same int() truncation as the slices below, so the window
    # length always matches the slices (int(a*fs - b*fs) can differ by one sample).
    bounds = [(int(self.sfreq * t0), int(self.sfreq * t1)) for t0, t1 in zip(vwt[:, 0], vwt[:, 1])]
    lengths = {stop - start for start, stop in bounds}
    if len(lengths) != 1:
      raise ValueError('All time windows must span the same number of samples')
    X_w = np.zeros((X.shape[0], X.shape[1], lengths.pop(), len(bounds)))
    for w, (start, stop) in enumerate(bounds):
      X_w[:, :, :, w] = X[:, :, start:stop]
    return X_w

  # ----------------------------------------------------------------------------
  def fit(self, X, y=None):
    """
    No-op fit (the transform has no learned state).
    INPUT
    -----
      1. X: (3D array) set of EEG signals, shape (trials, channels, time_samples)
      2. y: (1D array) target labels. Default=None
    OUTPUT
    ------
      1. self, as required by the scikit-learn API (fit_transform chains fit(...).transform).
    """
    return self

  # ----------------------------------------------------------------------------
  def transform(self, X, y=None):
    """
    Time frequency representation of EEG signals.
    INPUT
    -----
      1. X: (3D array) set of EEG signals, shape (trials, channels, times)
    OUTPUT
    ------
      1. X_wf: (5D array) shape (trials, channels, window_time_samples, number_of_windows, frequency_bands).
         Without f_bank the band axis has size 1; without vwt the window axis has size 1.
    """
    self._validation_param()  # validate sfreq, f_bank, vwt

    # Filter first, then window, to avoid the digital filter's edge effects inside the windows.
    if self.flag_f_bank:
      X_f = self._filter_bank(X)
    else:
      X_f = X[..., np.newaxis]

    if self.flag_vwt:
      X_wf = np.stack([self._sliding_windows(X_f[..., f]) for f in range(X_f.shape[3])], axis=-1)
    else:
      X_wf = X_f[:, :, :, np.newaxis, :]

    return X_wf

def kappa(y_true, y_pred):
    """
    Cohen's kappa between true and predicted labels.

    Each argument may be a 1D vector of class labels, or a 2D (samples, classes)
    array of one-hot targets / class scores, which is reduced to labels with
    argmax over axis 1.
    """
    def _to_labels(y):
        y = np.asarray(y)
        return np.argmax(y, axis=1) if y.ndim == 2 else y

    return cohen_kappa_score(_to_labels(y_true), _to_labels(y_pred))

# def minmax_norm(x):
#     x = x - np.min(x)
#     x = x/np.max(x)
#     return x

In [None]:
def load_GIGA(db,
              sbj,
              eeg_ch_names,
              fs,
              f_bank,
              vwt,
              new_fs,
              run=None):
    """
    Load the left/right-hand MI trials of one GIGA subject as a time-frequency array.

    Parameters
    ----------
    db : gcpds GIGA_MI_ME database object.
    sbj : subject identifier passed to ``db.load_subject``.
    eeg_ch_names : channel names; the trials are reordered to this channel order.
    fs : original sampling rate (Hz).
    f_bank, vwt : filter bank (Hz) and time windows (s) forwarded to ``TimeFrequencyRpr``.
    new_fs : target sampling rate (Hz); the data is resampled only if it differs from ``fs``.
    run : optional run selector for ``db.get_run``; None loads everything via ``db.get_data``.

    Returns
    -------
    X : array (trials, channels, time[, windows][, bands]); singleton window/band axes are dropped.
    y : labels returned by the database.
    """
    # NOTE(review): the "- 1" presumably converts 1-based selector indices to 0-based — confirm.
    index_eeg_chs = db.format_channels_selectors(channels=eeg_ch_names) - 1

    tf_repr = TimeFrequencyRpr(sfreq=fs, f_bank=f_bank, vwt=vwt)

    db.load_subject(sbj)
    classes = ['left hand mi', 'right hand mi']
    if run is None:
        X, y = db.get_data(classes=classes)  # Load MI classes, all channels {EEG}, reject bad trials, uV
    else:
        X, y = db.get_run(run, classes=classes)  # Same, restricted to one run
    X = X[:, index_eeg_chs, :]  # spatial rearrangement to eeg_ch_names order

    X_tf = tf_repr.transform(X)  # (trials, channels, time, windows, bands)
    # Drop only singleton window/band axes: a bare np.squeeze would also remove
    # a trials or channels axis of size 1.
    singleton_axes = tuple(ax for ax in (3, 4) if X_tf.shape[ax] == 1)
    X = np.squeeze(X_tf, axis=singleton_axes) if singleton_axes else X_tf

    if new_fs != fs:
        print("Resampling from {:f} to {:f} Hz.".format(fs, new_fs))
        # Time is always axis 2; the last axis is the window/band axis when those are kept.
        X = resample(X, int((X.shape[2] / fs) * new_fs), axis=2)

    return X, y

In [None]:
from sklearn.neighbors import KernelDensity

def shannon_entropy_kde(vector, bandwidth=1.0, num_points=1000):
    """
    Shannon entropy (bits) of a 1D sample, estimated from a kernel density estimate.

    A sklearn KernelDensity (default kernel) with the given bandwidth is fitted to the
    sample. Its density is evaluated on `num_points` evenly spaced points between the
    sample's min and max and normalized to sum to 1. The discrete entropy of that grid
    distribution is returned. The result therefore depends on `num_points` and is at
    most log2(num_points); a flat density, including a constant sample (where every
    grid point coincides), gives that maximum.

    Parameters
    ----------
    vector : array-like, flattened to one sample per element.
    bandwidth : float, KDE bandwidth.
    num_points : int, number of grid points (default 1000, the previously fixed value).

    Raises
    ------
    ValueError : if `vector` is empty.
    """
    # KernelDensity expects a 2D (n_samples, n_features) array.
    vector = np.asarray(vector).reshape(-1, 1)
    if vector.size == 0:
        raise ValueError('shannon_entropy_kde needs at least one sample')

    kde = KernelDensity(bandwidth=bandwidth)
    kde.fit(vector)

    grid = np.linspace(vector.min(), vector.max(), num_points).reshape(-1, 1)
    # score_samples returns log-densities.
    density = np.exp(kde.score_samples(grid))
    density /= np.sum(density)

    # eps avoids log2(0) for grid points with vanishing density.
    return -np.sum(density * np.log2(density + np.finfo(float).eps))

def entropy_of_matrix_columns_kde(matrix, bandwidth=1.0):
    """
    KDE-based Shannon entropy (bits) of every column of a 2D matrix.

    Returns a list with one value per column, in column order, each computed by
    shannon_entropy_kde with the given bandwidth.
    """
    return [shannon_entropy_kde(matrix[:, col_idx], bandwidth=bandwidth)
            for col_idx in range(matrix.shape[1])]


In [None]:
# GIGA MI/ME database and the fixed keyword arguments shared by every load_GIGA call.
db = GIGA_MI_ME('/kaggle/input/giga-science-gcpds/GIGA_MI_ME')
load_args = {
    'db': db,
    'eeg_ch_names': channels,
    'fs': db.metadata['sampling_rate'],
    'f_bank': np.asarray([[4., 40.]]),  # a single 4-40 Hz band
    'vwt': np.asarray([[2.5, 5]]),      # a single 2.5-5 s window
    'new_fs': 128.,
}

In [None]:
# Subject groups, given as 0-based positions into `subjects` (the 50 subjects kept
# in the analysis cell below). Together the three groups cover positions 0-49 exactly once.
# They index the first axis of the per-subject result arrays to build group averages.
# NOTE(review): the criterion used to form the groups is not stated here — document it.
g1 = [0, 2, 3, 4, 5, 9, 12, 13, 22, 25, 27, 32, 34, 38, 40, 41, 43, 44, 45, 46, 47]
g2 = [7, 8, 11, 14, 17, 19, 20, 21, 23, 24, 29, 33, 36, 39, 42, 48, 49]
g3 = [1, 6, 10, 15, 16, 18, 26, 28, 30, 31, 35, 37]
g_ = [g1,g2,g3]

In [None]:
import os

# ---- Configuration -----------------------------------------------------------
plot_ = True          # save one topomap figure per subject and model
fs = 128              # sampling rate (Hz) of the CAMs; keep equal to load_args['new_fs']
num_classes = 2       # label 0 = left hand MI, label 1 = right hand MI
num_channels = 64
num_topo = 4          # per-channel maps: 4-40 Hz, theta, alpha, beta

subjects = np.arange(52)+1
subjects = np.delete(subjects, [28, 33])  # delete subjects 29 and 34
num_sbj = len(subjects)

# Per-subject results; overwritten for every model, so after the loop they hold the last model.
S_mean_topomaps_per_ch_ = np.zeros((num_sbj, num_channels, num_topo, num_classes))
S_mean_topomaps_false_true = np.zeros((num_sbj, num_channels, num_classes+1, num_classes))
S_entropy = np.zeros((num_sbj, num_topo, num_classes))

model_ = ['EEGNet', 'KREEGNet', 'KCS-FCNet', 'DeepConv', 'ShallowConv', 'TCFusion']
cams_dir = '/kaggle/input/true-false-cams/cams'
vmin_, vmax_ = 0, 1   # shared colour limits; every plotted map is max-normalized


def _div_by_max(a, ref=None):
    """Divide `a` by max(ref) (ref defaults to `a`); return an unscaled copy if that max is 0 (avoids NaN)."""
    m = np.max(a if ref is None else ref)
    return a / m if m != 0 else a.copy()


def _test_split(sbj):
    """Load subject `sbj` and return (test indices, test labels) of the 5th StratifiedShuffleSplit partition."""
    X, y = load_GIGA(sbj=sbj, **load_args)
    sss = StratifiedShuffleSplit(n_splits=5, test_size=0.2, random_state=23)
    test_idx = list(sss.split(X, y))[4][1]
    return test_idx, y[test_idx]


def _band_limited_magnitude(cam, band_mask):
    """Zero the rFFT bins of `cam` (channels, time) outside `band_mask` and return |irfft| at the original length."""
    spec = np.fft.rfft(cam, axis=1)
    spec[:, ~band_mask] = 0
    # n= keeps the original length (irfft alone drops a sample for odd lengths).
    return np.abs(np.fft.irfft(spec, n=cam.shape[1], axis=1))


def _trial_maps(true_cam, false_cam, filter_mask, band_masks):
    """Per-trial channel maps.

    Returns
      ft:    (channels, 3) time-averaged [false, true, diff] CAMs after band-limiting to `filter_mask`
      bands: (channels, len(band_masks)) mean |FFT(diff)| inside each band, max-normalized
    """
    true_bl = _band_limited_magnitude(true_cam, filter_mask)
    false_bl = _band_limited_magnitude(false_cam, filter_mask)
    # Joint scaling keeps true and false CAMs comparable.
    joint = np.concatenate((true_bl, false_bl), axis=0)
    true_bl = _div_by_max(true_bl, joint)
    false_bl = _div_by_max(false_bl, joint)
    diff = true_bl - false_bl
    diff = _div_by_max(diff, np.abs(diff))
    ft = np.stack([false_bl.mean(axis=1), true_bl.mean(axis=1), diff.mean(axis=1)], axis=1)
    spec = _div_by_max(np.abs(np.fft.rfft(diff, axis=1)))
    bands = np.stack([spec[:, mask].mean(axis=1) for mask in band_masks], axis=1)
    return ft, _div_by_max(bands)


def _plot_false_true(mean_ft, false_title, out_path):
    """Draw the 2x2 false/true topomap grid (rows: class 0 / class 1), save it to `out_path` and close it."""
    fig, ax = plt.subplot_mosaic([['false_l', 'true_l'], ['false_r', 'true_r']],
                                 figsize=(4, 8), gridspec_kw={'wspace': 0.02, 'hspace': -0.55})
    topo_kw = dict(cmap='viridis', vlim=(vmin_, vmax_), show=False)
    img, _ = topoplot(mean_ft[:, 0, 0], channels, ax=ax['false_l'], **topo_kw)
    topoplot(mean_ft[:, 1, 0], channels, ax=ax['true_l'], **topo_kw)
    topoplot(mean_ft[:, 0, 1], channels, ax=ax['false_r'], **topo_kw)
    topoplot(mean_ft[:, 1, 1], channels, ax=ax['true_r'], **topo_kw)
    ax['false_l'].set_title(false_title, size=20)
    ax['true_l'].set_title('True', size=20)
    # Colorbar on its own axes so the last head is not shrunk.
    cbar_ax = fig.add_axes([0.92, 0.3, 0.015, 0.4])
    fig.colorbar(img, cax=cbar_ax)
    cbar_ax.tick_params(labelsize=20)
    fig.savefig(out_path, format='pdf', bbox_inches='tight')
    plt.close(fig)  # hundreds of figures are produced; keep memory bounded


# The train/test split depends only on the subject, so each subject is loaded once, not once per model.
test_splits = {}
for ii, sbj in enumerate(subjects):
    print(f'sbj {ii+1}/{num_sbj}: Num#: {sbj}')
    test_splits[sbj] = _test_split(sbj)

for mod_ in model_:
    print(f'model: {mod_}')
    os.makedirs(mod_, exist_ok=True)

    # CAM arrays are indexed [subject position, trial] -> (channels, time).
    false_cams = np.load(f'{cams_dir}/False_Cams_{mod_}.npy')
    true_cams = np.load(f'{cams_dir}/True_Cams_{mod_}.npy')
    assert false_cams.shape == true_cams.shape, 'true/false CAM arrays must have the same shape'

    vfreq = np.fft.rfftfreq(true_cams.shape[-1], 1/fs)
    band_masks = [
        (vfreq >= 4) & (vfreq <= 40),   # 4-40 Hz (also used to band-limit the CAMs)
        (vfreq >= 4) & (vfreq < 8),     # theta
        (vfreq >= 8) & (vfreq < 13),    # alpha
        (vfreq >= 13) & (vfreq <= 30),  # beta
    ]

    for ii, sbj in enumerate(subjects):
        test_idx, y_test = test_splits[sbj]

        # Average each class over its own trials only (the classes may differ in size).
        ft_per_class, bands_per_class = [], []
        for c in range(num_classes):
            per_trial = [_trial_maps(true_cams[ii, t], false_cams[ii, t], band_masks[0], band_masks)
                         for t in test_idx[y_test == c]]
            ft_per_class.append(np.mean([ft for ft, _ in per_trial], axis=0))
            bands_per_class.append(np.mean([bands for _, bands in per_trial], axis=0))

        mean_topo_false_true = _div_by_max(np.stack(ft_per_class, axis=-1))      # (ch, 3, classes)
        mean_topomaps_per_ch_ = _div_by_max(np.stack(bands_per_class, axis=-1))  # (ch, num_topo, classes)

        S_mean_topomaps_false_true[ii] = mean_topo_false_true
        S_mean_topomaps_per_ch_[ii] = mean_topomaps_per_ch_
        for c in range(num_classes):
            S_entropy[ii, :, c] = entropy_of_matrix_columns_kde(mean_topomaps_per_ch_[:, :, c], bandwidth=0.1)

        if plot_:
            _plot_false_true(mean_topo_false_true, 'False', f'{mod_}/Sbj{sbj}_{mod_}.pdf')

    # Group-averaged false/true maps.
    for g, group in enumerate(g_):
        mean_ft = _div_by_max(np.mean(S_mean_topomaps_false_true[group], axis=0))
        _plot_false_true(mean_ft, 'False' + '--G' + str(g + 1), 'Topo_' + mod_ + '_G' + str(g + 1) + '.pdf')

In [None]:
#true_cam_ = true_cams[sbj_index[0], test_trials[trial]]
#true_cam_f = np.fft.rfft(true_cam_,axis=1)
#true_cam_f[:,~all_] = 0
#true_cam_f = abs(np.fft.irfft(true_cam_f,axis=1))
#plt.plot(true_cam_f.T)
#plt.show()

#plt.plot(true_cam_.T)
#plt.show()

In [None]:
#S_mean_topomaps_per_ch_ = np.zeros((num_sbj,num_channels,num_topo,num_classes))
#S_entropy = np.zeros((num_sbj,num_topo,num_classes))
#file_ = 'cams_diff_norm'
#np.savez(file_,S_mean_topomap=S_mean_topomaps_per_ch_,entropy_cam=S_entropy,sbjs = subjects)