In [None]:
import argparse
import os
gpus = [1]
os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(map(str, gpus))
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
import numpy as np
import math
import glob
import random
import itertools
import datetime
import time
import datetime
import sys
import scipy.io

import torchvision.transforms as transforms
from torchvision.utils import save_image, make_grid

import torch.utils.data as Data
from torch.utils.data import DataLoader
from torch.autograd import Variable
from torchsummary import summary
import torch.autograd as autograd
from torchvision.models import vgg19

import torch.nn as nn
import torch.nn.functional as F
import torch
import torch.nn.init as init

from torch.utils.data import Dataset
from PIL import Image
import torchvision.transforms as transforms
from sklearn.decomposition import PCA

import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt

from torch import nn
from torch import Tensor
from PIL import Image
from torchvision.transforms import Compose, Resize, ToTensor
from einops import rearrange, reduce, repeat
from einops.layers.torch import Rearrange, Reduce
# from common_spatial_pattern import csp

import matplotlib.pyplot as plt
import mne
from matplotlib import mlab as mlab
from torch.backends import cudnn
from utils import GradCAM, show_cam_on_image

cudnn.benchmark = False
cudnn.deterministic = True

import myimporter
from BCI_functions import *

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
import torch.utils.data as Data
import torch.nn.functional as F

import pandas as pd

In [None]:
# keep the overall model class, omitted here
class ViT(nn.Sequential):
    """Placeholder for the overall model class; its layers are omitted in this notebook.

    The constructor accepts ``emb_size``, ``depth``, ``n_classes`` and arbitrary
    keyword arguments but registers no sub-modules, so the container is empty.
    """

    def __init__(self, emb_size=40, depth=2, n_classes=2, **kwargs):
        # ... the model
        super().__init__()
        
# ! A crucial step for adaptation on Transformer
# reshape_transform  b 61 40 -> b 40 1 61
def reshape_transform(tensor):
    """Rearrange a (batch, tokens, emb) tensor into (batch, emb, 1, tokens).

    Used as the ``reshape_transform`` callback handed to ``GradCAM`` so that
    the token sequence is presented as a 2-D feature map of height 1.
    """
    return rearrange(tensor, 'b (h w) e -> b e (h) (w)', h=1)

In [None]:
# Convolution module
# Use convolutions to capture local features instead of a position embedding.
class PatchEmbedding(nn.Module):
    """Convolutional front-end that converts a 4-D input into a token sequence.

    ``shallownet`` applies a (1, 25) convolution, a (7, 1) convolution, batch
    normalisation, ELU, average pooling along the last axis and dropout.
    ``projection`` applies a 1x1 convolution to ``emb_size`` channels and
    flattens the two trailing axes into a token axis.

    Args:
        emb_size: number of channels (embedding width) produced per token.
    """

    def __init__(self, emb_size=40):
        super().__init__()

        # Layers are created in the same order as before so that parameter
        # initialisation and state_dict keys are unchanged.
        conv_stack = [
            nn.Conv2d(1, 40, (1, 25), (1, 1)),
            nn.Conv2d(40, 40, (7, 1), (1, 1)),  # 22 when using 64 channels
            nn.BatchNorm2d(40),
            nn.ELU(),
            # pooling acts as slicing to obtain 'patch' along the time dimension as in ViT
            nn.AvgPool2d((1, 75), (1, 15)),
            nn.Dropout(0.5),
        ]
        self.shallownet = nn.Sequential(*conv_stack)

        token_stack = [
            # 1x1 conv before the transpose; could enhance fitting ability slightly
            nn.Conv2d(40, emb_size, (1, 1), stride=(1, 1)),
            Rearrange('b e (h) (w) -> b (h w) e'),
        ]
        self.projection = nn.Sequential(*token_stack)

    def forward(self, x: Tensor) -> Tensor:
        """Return tokens of shape (batch, h*w, emb_size) for a 4-D input ``x``."""
        _n, _c, _h, _w = x.shape  # unpacking fails early if x is not 4-D
        features = self.shallownet(x.float())
        return self.projection(features)


class MultiHeadAttention(nn.Module):
    """Multi-head self-attention over a (batch, tokens, emb_size) tensor.

    Args:
        emb_size: embedding width; must be divisible by ``num_heads``.
        num_heads: number of attention heads the embedding is split into.
        dropout: dropout probability applied to the attention weights.
    """

    def __init__(self, emb_size, num_heads, dropout):
        super().__init__()
        self.emb_size = emb_size
        self.num_heads = num_heads
        self.keys = nn.Linear(emb_size, emb_size)
        self.queries = nn.Linear(emb_size, emb_size)
        self.values = nn.Linear(emb_size, emb_size)
        self.att_drop = nn.Dropout(dropout)
        self.projection = nn.Linear(emb_size, emb_size)

    def forward(self, x: Tensor, mask: Tensor = None) -> Tensor:
        """Apply self-attention to ``x``.

        Args:
            x: tensor of shape (batch, tokens, emb_size).
            mask: optional boolean tensor broadcastable to
                (batch, heads, tokens, tokens); positions where it is ``False``
                are excluded from attention.

        Returns:
            Tensor with the same shape as ``x``.
        """
        queries = rearrange(self.queries(x), "b n (h d) -> b h n d", h=self.num_heads)
        keys = rearrange(self.keys(x), "b n (h d) -> b h n d", h=self.num_heads)
        values = rearrange(self.values(x), "b n (h d) -> b h n d", h=self.num_heads)
        energy = torch.einsum('bhqd, bhkd -> bhqk', queries, keys)
        if mask is not None:
            # The Tensor method is `masked_fill` (not `mask_fill`) and it is
            # out-of-place, so the result has to be assigned back.
            fill_value = torch.finfo(energy.dtype).min
            energy = energy.masked_fill(~mask, fill_value)

        # Scaling uses sqrt(emb_size) rather than the per-head width; this is
        # kept as-is so that previously trained checkpoints behave the same.
        scaling = self.emb_size ** (1 / 2)
        att = F.softmax(energy / scaling, dim=-1)
        att = self.att_drop(att)
        out = torch.einsum('bhal, bhlv -> bhav', att, values)
        out = rearrange(out, "b h n d -> b n (h d)")
        out = self.projection(out)
        return out


class ResidualAdd(nn.Module):
    """Wrap a module ``fn`` with a skip connection: ``fn(x) + x``.

    Args:
        fn: module (or callable) whose output has the same shape as its input.
    """

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        """Return ``fn(x, **kwargs) + x``.

        The sum is computed out-of-place. An in-place ``+=`` on the output of
        ``fn`` would overwrite the caller's tensor whenever ``fn`` returns its
        input unchanged, and raises for a leaf tensor that requires grad.
        """
        return self.fn(x, **kwargs) + x


class FeedForwardBlock(nn.Sequential):
    """Position-wise feed-forward network: Linear -> GELU -> Dropout -> Linear.

    Args:
        emb_size: input and output width.
        expansion: multiplier giving the hidden width ``expansion * emb_size``.
        drop_p: dropout probability applied after the activation.
    """

    def __init__(self, emb_size, expansion, drop_p):
        hidden_size = expansion * emb_size
        layers = [
            nn.Linear(emb_size, hidden_size),
            nn.GELU(),
            nn.Dropout(drop_p),
            nn.Linear(hidden_size, emb_size),
        ]
        super().__init__(*layers)


class GELU(nn.Module):
    """Exact (erf-based) Gaussian Error Linear Unit: ``x * Phi(x)``."""

    def forward(self, input: Tensor) -> Tensor:
        # Phi(x) = 0.5 * (1 + erf(x / sqrt(2))) is the standard normal CDF.
        normal_cdf = 0.5 * (1.0 + torch.erf(input / math.sqrt(2.0)))
        return input * normal_cdf


class TransformerEncoderBlock(nn.Sequential):
    """One pre-norm encoder layer: residual self-attention, then residual feed-forward.

    Args:
        emb_size: embedding width of the token sequence.
        num_heads: number of attention heads.
        drop_p: dropout probability for attention weights and after each branch.
        forward_expansion: hidden-width multiplier of the feed-forward block.
        forward_drop_p: dropout probability inside the feed-forward block.
    """

    def __init__(self,
                 emb_size,
                 num_heads=10,
                 drop_p=0.5,
                 forward_expansion=4,
                 forward_drop_p=0.5):
        # Branches are built in the original order (attention first) so that
        # parameter initialisation and state_dict keys do not change.
        attention_branch = nn.Sequential(
            nn.LayerNorm(emb_size),
            MultiHeadAttention(emb_size, num_heads, drop_p),
            nn.Dropout(drop_p),
        )
        feed_forward_branch = nn.Sequential(
            nn.LayerNorm(emb_size),
            FeedForwardBlock(emb_size, expansion=forward_expansion, drop_p=forward_drop_p),
            nn.Dropout(drop_p),
        )
        super().__init__(
            ResidualAdd(attention_branch),
            ResidualAdd(feed_forward_branch),
        )


class TransformerEncoder(nn.Sequential):
    """Stack of ``depth`` ``TransformerEncoderBlock`` layers with default hyper-parameters."""

    def __init__(self, depth, emb_size):
        blocks = (TransformerEncoderBlock(emb_size) for _ in range(depth))
        super().__init__(*blocks)


class ClassificationHead(nn.Sequential):
    """Fully connected classifier applied to the flattened token sequence.

    Args:
        emb_size: embedding width of the incoming tokens.
        n_classes: number of output logits.
        fc_in_features: size of the flattened input (tokens * emb_size) seen by
            the first linear layer. The default of 2200 corresponds to the
            17-channel setup; the original notes give 3000 for 21 channels,
            and 8600 (1 s) / 25800 (2 s) for 64 channels.
    """

    def __init__(self, emb_size, n_classes, fc_in_features=2200):
        super().__init__()

        # global average pooling
        # Not used by forward(); kept registered so that state_dict keys stay
        # compatible with checkpoints saved from this class.
        self.clshead = nn.Sequential(
            Reduce('b n e -> b e', reduction='mean'),
            nn.LayerNorm(emb_size),
            nn.Linear(emb_size, n_classes)
        )
        self.fc = nn.Sequential(
            nn.Linear(fc_in_features, 256),
            nn.ELU(),
            nn.Dropout(0.5),
            nn.Linear(256, 32),
            nn.ELU(),
            nn.Dropout(0.3),
            nn.Linear(32, n_classes)
        )

    def forward(self, x):
        """Flatten ``x`` to (batch, -1) and return the class logits from ``fc``."""
        x = x.contiguous().view(x.size(0), -1)
        out = self.fc(x)
        return out


class Conformer(nn.Sequential):
    """Full model: convolutional patch embedding -> transformer encoder -> classifier.

    Args:
        emb_size: embedding width shared by all three stages.
        depth: number of transformer encoder blocks.
        n_classes: number of output classes.
        **kwargs: accepted and ignored.
    """

    def __init__(self, emb_size=40, depth=2, n_classes=2, **kwargs):
        embedding = PatchEmbedding(emb_size)
        encoder = TransformerEncoder(depth, emb_size)
        head = ClassificationHead(emb_size, n_classes)
        super().__init__(embedding, encoder, head)

In [None]:
# TODO: accept a list of subject ids so that data from several subjects can be combined later
# TODO: add a transform function that converts the dataset into train/test splits, to avoid repeating the same code

class EEGMMIDTrSet(Data.Dataset):
    """Training portion of one subject's EEG recording, served as (trial, label) pairs.

    Construction loads ``<root_dir>/<subject_id>.npy``, band-pass filters every
    feature column, keeps only rows labelled 2-5 (remapped to the two classes
    0 and 1), cuts the recording into non-overlapping windows with ``extract``,
    makes a stratified train/test split with a fixed random state, standardises
    with statistics fitted on the training part only, keeps a fixed subset of
    17 channels and stores the training part as tensors.

    Attributes:
        data: tensor of training trials, shaped (trial, 1, channel, time sample).
        targets: long tensor of training labels.
        data_seg_label: label column of all segments (train and test),
            shape (n_segments, 1).

    Args:
        subject_id: identifier used to build the ``.npy`` file name.
        transform: accepted but not used by this class.
    """

    def __init__(self, subject_id, transform=None):
        root_dir = "../../Deep-Learning-for-BCI/dataset/"
        dataset_raw = np.load(root_dir + str(subject_id) + '.npy')

        # Band-pass filter (8-30 Hz, fs = 160 Hz) every column except the last,
        # which holds the label.
        fs = 160.0
        lowcut = 8.0
        highcut = 30.0
        dataset = []  # feature after filtering
        for i in range(dataset_raw[:, :-1].shape[1]):
            y = butter_bandpass_filter(dataset_raw[:, i], lowcut, highcut, fs, order=3)
            dataset.append(y)
        dataset = np.array(dataset).T
        dataset = np.hstack((dataset, dataset_raw[:, -1:]))
        print(dataset.shape)

        # Keep labels 2-5 and drop the rest (the original note identifies 4/5 as
        # left/right fist open-close imagery; 2/3 are merged into the same two
        # classes below). Refer to 1-Data.ipynb for the details.
        removed_label = [0,1,6,7,8,9,10]  # [0,1,2,3,4,5,10] for hf # [0,1,6,7,8,9,10] for lr
        dataset = dataset[~np.isin(dataset[:, -1], removed_label)]

        # Pytorch needs labels to be sequentially ordered starting from 0
        labels = dataset[:, -1]  # view: writes go through to `dataset`
        labels[np.isin(labels, [2, 4])] = 0
        labels[np.isin(labels, [3, 5])] = 1

        # data segmentation
        n_class = 2 #int(11-len(removed_label))  # 0~9 classes ('10:rest' is not considered)
        no_feature = 64  # the number of the features
        segment_length = 160 #160  # selected time window; 16=160*0.1

        # Overlapping is removed to avoid training set overlap with test set
        data_seg = extract(dataset, n_classes=n_class, n_fea=no_feature,
                           time_window=segment_length, moving=(segment_length))  # /2 for 50% overlapping
        print('After segmentation, the shape of the data:', data_seg.shape)

        # split training and test data
        no_longfeature = no_feature*segment_length
        data_seg_feature = data_seg[:, :no_longfeature]
        self.data_seg_label = data_seg[:, no_longfeature:no_longfeature+1]

        # It is important to have random state set equal for Training and test dataset
        train_feature, test_feature, train_label, test_label = train_test_split(
            data_seg_feature, self.data_seg_label,random_state=0, shuffle=True,stratify=self.data_seg_label)

        # Check the class label splits to maintain balance
        unique, counts = np.unique(self.data_seg_label, return_counts=True)
        left_perc = counts[0]/sum(counts)
        if left_perc < 0.4 or left_perc > 0.6:
            print("Imbalanced dataset with split of: ",left_perc,1-left_perc)
        else:
            print("Classes balanced.")
        unique, counts = np.unique(train_label, return_counts=True)
        print("Class label splits in training set \n ",np.asarray((unique, counts)).T)
        unique, counts = np.unique(test_label, return_counts=True)
        print("Class label splits in test set\n ",np.asarray((unique, counts)).T)

        # normalization
        # before normalize reshape data back to raw data shape
        # (assumes each segment row from `extract` is laid out sample-by-sample
        # with `no_feature` values per sample — TODO confirm)
        train_feature_2d = train_feature.reshape([-1, no_feature])
        test_feature_2d = test_feature.reshape([-1, no_feature])

        # The scaler is fitted on the training part only and reused for the test part.
        scaler1 = StandardScaler().fit(train_feature_2d)
        train_fea_norm1 = scaler1.transform(train_feature_2d) # normalize the training data
        test_fea_norm1 = scaler1.transform(test_feature_2d) # normalize the test data
        print('After normalization, the shape of training feature:', train_fea_norm1.shape,
              '\nAfter normalization, the shape of test feature:', test_fea_norm1.shape)

        # Channel selection: top feature-relevance channels for 109 subjects.
        # (Alternative used previously: the first 21 columns, i.e. the MI-related
        # channels, with no_feature = 21.)
        selected_channels = [20, 21, 22, 23, 27, 28, 29, 37, 43, 53, 54, 56, 58, 59, 61, 62, 63]
        train_fea_norm1 = train_fea_norm1[:, selected_channels]
        test_fea_norm1 = test_fea_norm1[:, selected_channels]
        no_feature = len(selected_channels)

        # after normalization, reshape data to 3d
        train_fea_norm1 = train_fea_norm1.reshape([-1, segment_length, no_feature])
        test_fea_norm1 = test_fea_norm1.reshape([-1, segment_length, no_feature])
        print('After reshape, the shape of training feature:', train_fea_norm1.shape,
              '\nAfter reshape, the shape of test feature:', test_fea_norm1.shape)

        # reshape for data shape: (trial, conv channel, electrode channel, time samples)
        # earlier it was (trial,timesamples,electrode_channel)
        train_fea_reshape1 = np.swapaxes(np.expand_dims(train_fea_norm1,1),2,3)
        test_fea_reshape1 = np.swapaxes(np.expand_dims(test_fea_norm1,1),2,3)
        print('After expand dims, the shape of training feature:', train_fea_reshape1.shape,
              '\nAfter expand dims, the shape of test feature:', test_fea_reshape1.shape)

        # Only the training part is stored on the dataset.
        self.data = torch.tensor(train_fea_reshape1)
        self.targets = torch.tensor(train_label.flatten()).long()

        print("data and target type:",type(self.data),type(self.targets))

    def __len__(self):
        """Return the number of training trials."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the ``(data, target)`` pair at position ``idx``."""
        data, target = self.data[idx], self.targets[idx]
        return data, target

    def get_class_weights(self):
        """Return 'balanced' class weights computed over all segment labels.

        The weights are computed from ``data_seg_label`` (train and test
        segments together), one weight per distinct label in sorted order.
        """
        # Imported locally because no visible import brings `class_weight`
        # into scope; `classes` and `y` must be passed by keyword in current
        # scikit-learn releases.
        from sklearn.utils import class_weight

        labels = self.data_seg_label[:, 0]
        class_weights = class_weight.compute_class_weight(
            'balanced', classes=np.unique(labels), y=labels)
        return class_weights



## Get topo plots for visualisation

In [None]:
from mpl_toolkits.axes_grid1 import make_axes_locatable
device = torch.device("cpu") #cpu
model = Conformer(n_classes=2)
cat_dict = {0:'left_hand',1:'right_hand'}
biosemi_montage = mne.channels.make_standard_montage('biosemi64')
# NOTE(review): EEGMMIDTrSet uses this same integer list to select dataset
# columns, while here it indexes the biosemi64 montage. Confirm that both
# orderings agree, otherwise the electrode positions on the maps are off.
index = [20, 21, 22, 23, 27, 28, 29, 37, 43, 53, 54, 56, 58, 59, 61, 62, 63] # top17 FR for 16_109 subs
# Restrict the montage to the selected channels. The +3 offset skips the first
# three entries of `dig` (presumably the fiducial points — TODO confirm).
biosemi_montage.ch_names = [biosemi_montage.ch_names[i] for i in index]
biosemi_montage.dig = [biosemi_montage.dig[i+3] for i in index]
info = mne.create_info(ch_names=biosemi_montage.ch_names, sfreq=160,ch_types='eeg')

# Create the output directory up front so the final savefig cannot fail after
# the slow Grad-CAM loop has already run.
out_dir = './outimgs/'
os.makedirs(out_dir, exist_ok=True)


def _add_colorbar(fig, ax, im):
    """Attach a narrow vertical colorbar for ``im`` to the right of ``ax``."""
    cax = make_axes_locatable(ax).append_axes('right', size='5%', pad=0.05)
    fig.colorbar(im, cax=cax, orientation='vertical')


for sub_id in [42]:
    fig, axes = plt.subplots(1, 3, figsize=(10, 25)) #figsize=(15, 5)
    model.load_state_dict(torch.load("../transformer_results/eegmmid_ws_offline_"+ str(sub_id)+ "_model1" + '.pth', map_location=device))
    # Inference mode: disables dropout and uses the stored batch-norm
    # statistics, so the class activation maps are deterministic.
    model.eval()
    target_layers = [model[1]]  # set the target layer
    print(target_layers)
    cam = GradCAM(model=model, target_layers=target_layers, use_cuda=False, reshape_transform=reshape_transform)

    train_ds = EEGMMIDTrSet(subject_id=sub_id)
    data = train_ds.data

    # The trial-averaged, z-scored input does not depend on the target
    # category, so it is computed once per subject.
    test_all_data = np.squeeze(np.mean(data.detach().cpu().numpy(), axis=0))
    test_all_data = (test_all_data - np.mean(test_all_data)) / np.std(test_all_data)
    mean_all_test = np.mean(test_all_data, axis=1)

    evoked = mne.EvokedArray(test_all_data, info)
    evoked.set_montage(biosemi_montage)

    # First panel: topomap of the averaged input without Grad-CAM weighting.
    im, cn = mne.viz.plot_topomap(mean_all_test, evoked.info, axes=axes[0],
                                  show=False, res=600, size=5)
    axes[0].set_title("Raw", {'fontsize' : 10})
    _add_colorbar(fig, axes[0], im)

    for target_category in range(2):

        all_cam = []
        # this loop is used to obtain the cam of each trial/sample
        for i in range(data.shape[0]):
            test = torch.as_tensor(data[i:i+1, :, :, :], dtype=torch.float32).requires_grad_(True)

            grayscale_cam = cam(input_tensor=test,target_category=target_category)
            all_cam.append(grayscale_cam[0, :])

        # the mean of all cam
        test_all_cam = np.mean(all_cam, axis=0)

        # apply cam on the input data
        hyb_all = test_all_data * test_all_cam
        hyb_all = (hyb_all - np.mean(hyb_all)) / np.std(hyb_all)
        mean_hyb_all = np.mean(hyb_all, axis=1)

        # One panel per class: topomap of the CAM-weighted input.
        ax = axes[target_category+1]
        im, cn = mne.viz.plot_topomap(mean_hyb_all, evoked.info, axes=ax,
                                      show=False, res=600, size=5)
        ax.set_title("S" + str(sub_id) + "_" + cat_dict[target_category], {'fontsize' : 10})
        _add_colorbar(fig, ax, im)

    fig.savefig(out_dir + str(sub_id) + '_topo.png')
    plt.close(fig)



## Find most relevant channels for prediction

In [None]:
import pandas as pd
device = torch.device("cpu")
model = Conformer(n_classes=2)
cat_dict = {0:'left hand movement',1:'right hand movement'}
biosemi_montage = mne.channels.make_standard_montage('biosemi64')
# NOTE(review): this list has 21 entries, while EEGMMIDTrSet currently keeps 17
# columns. Confirm that the positional name lookup below matches the dataset's
# channel selection.
index = [8, 9, 10, 46, 45, 44, 43, 13, 12, 11, 47, 48, 49, 50, 16, 17, 18,
         31, 55, 54, 53]
biosemi_montage.ch_names = [biosemi_montage.ch_names[i] for i in index]
biosemi_montage.dig = [biosemi_montage.dig[i+3] for i in index]
info = mne.create_info(ch_names=biosemi_montage.ch_names, sfreq=160,ch_types='eeg')

# Maps (sub_id, target_category) -> names of the 10 highest-scoring channels.
top_channel_dict = {}

for sub_id in [7,15,29,32,35,42,43,46,48,49,54,56,62,93,94,108]:

    model.load_state_dict(torch.load("../transformer_results/eegmmid_ws_offline_"+ str(sub_id)+ "_model0" + '.pth', map_location=device))
    # Inference mode: disables dropout and uses the stored batch-norm
    # statistics, so the class activation maps are deterministic.
    model.eval()
    target_layers = [model[1]]  # set the target layer
    print(target_layers)
    cam = GradCAM(model=model, target_layers=target_layers, use_cuda=False, reshape_transform=reshape_transform)

    train_ds = EEGMMIDTrSet(subject_id=sub_id)
    data = train_ds.data

    # The trial-averaged, z-scored input does not depend on the target
    # category, so it is computed once per subject.
    test_all_data = np.squeeze(np.mean(data.detach().cpu().numpy(), axis=0))
    test_all_data = (test_all_data - np.mean(test_all_data)) / np.std(test_all_data)

    for target_category in range(2):

        all_cam = []
        # this loop is used to obtain the cam of each trial/sample
        for i in range(data.shape[0]):
            test = torch.as_tensor(data[i:i+1, :, :, :], dtype=torch.float32).requires_grad_(True)

            grayscale_cam = cam(input_tensor=test,target_category=target_category)
            all_cam.append(grayscale_cam[0, :])

        # the mean of all cam
        test_all_cam = np.mean(all_cam, axis=0)

        # apply cam on the input data
        hyb_all = test_all_data * test_all_cam
        hyb_all = (hyb_all - np.mean(hyb_all)) / np.std(hyb_all)
        mean_hyb_all = np.mean(hyb_all, axis=1)

        # indices of the 10 channels with the highest mean score (ascending) -> feature relevance
        nind = np.argsort(mean_hyb_all)[-10:]
        ch_names = [biosemi_montage.ch_names[i] for i in nind]
        print(ch_names)

        imp_ind = [mean_hyb_all[i] for i in nind]
        print(imp_ind)
        # Skip subjects/classes whose scores are NaN (e.g. zero variance above).
        if not math.isnan(imp_ind[0]):
            top_channel_dict[(sub_id,target_category)] = ch_names


In [None]:
# Pool the top-channel names of all subjects per class; keys of
# top_channel_dict are (sub_id, target_category) with 0 = left, 1 = right.
left_list = [name
             for key, names in top_channel_dict.items() if key[1] == 0
             for name in names]
right_list = [name
              for key, names in top_channel_dict.items() if key[1] == 1
              for name in names]

In [None]:
from collections import Counter,OrderedDict
# Count how often each channel name occurs in the pooled per-class lists.
freq_left, freq_right = (Counter(names) for names in (left_list, right_list))

In [None]:
# Bar chart of channel frequencies for the left-hand class, most common first.
f, ax = plt.subplots(figsize=(15, 6))
plt.xticks(rotation = 45)
ax.tick_params(axis='x', which='major')
ax.set_xlabel('Channel Names', fontsize=10)
ax.set_ylabel('Frequency of appearances in Top 10 features from GradCAM', fontsize=10)
left_counts = OrderedDict(freq_left.most_common())
ax.bar(left_counts.keys(), left_counts.values())
plt.tight_layout()

In [None]:
# Bar chart of channel frequencies for the right-hand class, most common first.
f, ax = plt.subplots(figsize=(15, 6))
plt.xticks(rotation = 45)
ax.tick_params(axis='x', which='major')
ax.set_xlabel('Channel Names', fontsize=10)
ax.set_ylabel('Frequency of appearances in Top 10 features from GradCAM', fontsize=10)
right_counts = OrderedDict(freq_right.most_common())
ax.bar(right_counts.keys(), right_counts.values())
plt.tight_layout()

In [None]:
# Ten most frequent channels per class as (name, count) pairs.
for side_label, side_freq in (("left: ", freq_left), ("right: ", freq_right)):
    print(side_label, side_freq.most_common(10))



## Code to generate TF plots

In [None]:
# Nz, F9, F10, FT9, FT10, A1, A2, TP9, TP10, P9, and P10
# [22,25,26,33,32,31,30,39,1,2,3,10,9,8,41,45,15,16,17,50,49,48,47]
# NOTE(review): `pre_montage` is not defined in any visible cell of this
# notebook; it must already exist in the kernel for this cell to run.
print(pre_montage.ch_names)

# Channel names in the order used by the dataset columns.
# 'T9','T10' removed and using P9,P10 instead
data_seq_ch = [
    'FC5', 'FC3', 'FC1', 'FCz', 'FC2', 'FC4', 'FC6',
    'C5', 'C3', 'C1', 'Cz', 'C2', 'C4', 'C6',
    'CP5', 'CP3', 'CP1', 'CPz', 'CP2', 'CP4', 'CP6',
    'Fp1', 'Fpz', 'Fp2',
    'AF7', 'AF3', 'AFz', 'AF4', 'AF8',
    'F7', 'F5', 'F3', 'F1', 'Fz', 'F2', 'F4', 'F6', 'F8',
    'FT7', 'FT8', 'T7', 'T8', 'P9', 'P10', 'TP7', 'TP8',
    'P7', 'P5', 'P3', 'P1', 'Pz', 'P2', 'P4', 'P6', 'P8',
    'PO7', 'PO3', 'POz', 'PO4', 'PO8',
    'O1', 'Oz', 'O2', 'Iz',
]

# Position of every dataset channel inside the montage's channel list.
final_index = [pre_montage.ch_names.index(ch_n) for ch_n in data_seq_ch]
print(final_index)

In [None]:
# Sanity check: data_seq_ch lists 64 channel names, so 64 indices are expected.
len(final_index)

In [None]:
import matplotlib.pyplot as plt
import numpy as np

import mne
from mne.datasets import eegbci
from mne.time_frequency import tfr_morlet

device = torch.device("cpu")
model = Conformer(n_classes=2)
cat_dict = {0:'left hand movement',1:'right hand movement'}
biosemi_montage = mne.channels.make_standard_montage('biosemi64')
# NOTE(review): this list has 21 entries, while EEGMMIDTrSet currently keeps 17
# columns. The EvokedArray below needs one `info` channel per data row, so
# confirm the two selections agree.
index = [8, 9, 10, 11, 12, 13, 16, 17, 18, 31, 43, 44, 45, 46, 47, 48, 49, 50, 53, 54, 55]
#[20, 21, 22, 23, 27, 28, 29, 37, 43, 53, 54, 56, 58, 59, 61, 62, 63] # For 109 subs all channels topfr
# [18, 20, 21, 22, 23, 27, 28, 29, 34, 38, 44, 53, 54, 55, 59, 60, 61, 63] # old
# [8, 9, 10, 11, 12, 13, 16, 17, 18, 31, 43, 44, 45, 46, 47, 48, 49, 50, 53, 54, 55] 21 MI channels
#range(64)#[37, 9, 10, 46, 45, 44, 13, 12, 11, 47, 48, 49, 50, 17, 18, 31, 55, 54, 19, 30, 56, 29]  # for bci competition iv 2a

biosemi_montage.ch_names = [biosemi_montage.ch_names[i] for i in index]
biosemi_montage.dig = [biosemi_montage.dig[i+3] for i in index]
info = mne.create_info(ch_names=biosemi_montage.ch_names, sfreq=160,ch_types='eeg')

top_channel_dict = {}

# Morlet wavelet settings: 8 log-spaced frequencies between 8 and 30 Hz.
freqs = np.logspace(*np.log10([8, 30]), num=8)
n_cycles = freqs / 2.0  # different number of cycle per frequency

for sub_id in [42]:#range(7,8):

    model.load_state_dict(torch.load("../transformer_results/eegmmid_ws_offline_"+ str(sub_id)+ "_model0" + '.pth', map_location=device))
    # Inference mode: disables dropout and uses the stored batch-norm
    # statistics, so the class activation maps are deterministic.
    model.eval()
    target_layers = [model[1]]  # set the target layer
    print(target_layers)
    cam = GradCAM(model=model, target_layers=target_layers, use_cuda=False, reshape_transform=reshape_transform)

    train_ds = EEGMMIDTrSet(subject_id=sub_id)
    data = train_ds.data

    # The trial-averaged, z-scored input does not depend on the target
    # category, so it (and the Evoked object built from it) is created once.
    test_all_data = np.squeeze(np.mean(data.detach().cpu().numpy(), axis=0))
    test_all_data = (test_all_data - np.mean(test_all_data)) / np.std(test_all_data)

    evoked = mne.EvokedArray(test_all_data, info)
    evoked.set_montage(biosemi_montage)

    for target_category in range(2):

        all_cam = []
        # this loop is used to obtain the cam of each trial/sample
        for i in range(data.shape[0]):
            test = torch.as_tensor(data[i:i+1, :, :, :], dtype=torch.float32).requires_grad_(True)

            grayscale_cam = cam(input_tensor=test,target_category=target_category)
            all_cam.append(grayscale_cam[0, :])

        # the mean of all cam
        test_all_cam = np.mean(all_cam, axis=0)

        # apply cam on the input data
        hyb_all = test_all_data * test_all_cam
        hyb_all = (hyb_all - np.mean(hyb_all)) / np.std(hyb_all)
        # Averaging over axis 0 (channels) leaves one score per time sample.
        mean_hyb_all = np.mean(hyb_all, axis=0)
        print(mean_hyb_all.shape)
        print(mean_hyb_all)

        # The three highest-scoring time samples, converted to seconds (fs = 160 Hz),
        # each paired with 10 Hz and 21 Hz for the joint plot.
        nind = np.argsort(mean_hyb_all)[-3:]
        print("final_timespan",nind/160)
        time_freq = list(itertools.product(nind/160,[10,21]))

        power = tfr_morlet(
            evoked,
            freqs=freqs,
            n_cycles=n_cycles,
            use_fft=True,
            return_itc=False,
            decim=3,
            n_jobs=None,
        )
        power.plot_joint(baseline=(0, 0), mode="mean", tmin=0, tmax=1,
                         timefreqs=time_freq)
        
        