This notebook uses the trained PyTorch models to rank the best EEG samples per class for each participant.

In [1]:
import sys
import numpy as np
import scipy.io
import torch
from matplotlib import pyplot as plt
from torchvision import transforms
sys.path.append('./models/')
sys.path.append('./source/')

import utils as myUtils
from DataTypes import EEGDataset
from CNN_1 import Network

# Report whether PyTorch can use a CUDA GPU on this machine.
print(f"GPU_available={torch.cuda.is_available()}")

GPU_available=True


In [2]:
import os

mat = scipy.io.loadmat('EEG_samples_pruned.mat')  # make sure to choose the right sample set
mode = 'Kfold_pruned'  # make sure to change the mode to get the correct trained models
method = 'sorting'  # This code sorts/chooses the best EEG samples
EEG_samples = mat['EEG_samples']
trial_no = 'trial21'
no_participants = 29
top_k = 10  # number of best samples kept per class and participant
model_classes = [0, 1]  # 0 => Sick Class; 1 => Normal Class
export_matlab = {}  # To be exported for matlab analysis
none_sick_stack = []  # best non sick samples
sick_stack = []  # best sick samples

# Use the GPU when present, otherwise fall back to the CPU instead of
# failing on a hard-coded .cuda() call.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# The transform does not depend on the participant, so build it only once.
data_transform = transforms.Compose([transforms.ToTensor()])


def _top_k_sample_ids(probs, sample_ids, k):
    """Return the dataset indices of the (up to) k highest-probability samples.

    probs[j] is the class probability of dataset sample sample_ids[j].
    The result is ordered by ascending probability, so the best sample is last.
    Fewer than k ids are returned when fewer samples are available.
    """
    if not probs:
        return []
    order = np.argsort(probs)[-k:]
    # Map positions in `probs` back to positions in the dataset.
    return [sample_ids[j] for j in order]


def _stack_samples(dataset, sample_ids):
    """Concatenate the selected dataset samples along axis 0; [] if none were selected."""
    if not sample_ids:
        return []
    return np.concatenate([dataset[idx][0].numpy() for idx in sample_ids], axis=0)


for partic_id in range(no_participants):

    # Per-class probabilities, plus the dataset index each probability belongs to.
    sick_probs, sick_ids = [], []
    none_sick_probs, none_sick_ids = [], []

    # Get the dataset (leave-one-out split for this participant)
    full_data, train_data, test_data = myUtils.generateTrainTest(EEG_samples, partic_id, normalize=False, LOU=True)
    test_set = EEGDataset(data_set=test_data[0], label_set=test_data[1], transform=data_transform)

    # Load the trained model for this participant
    model = Network(num_classes=2).to(device)
    model_path = f'./trained_models/{trial_no}/{mode}/Arch{partic_id + 1}.pth'
    state_dict = torch.load(model_path, map_location=device)
    model.load_state_dict(state_dict)
    model.eval()  # Set the model to evaluation mode.

    with torch.no_grad():  # No gradients are needed for inference.
        for sample_id in range(len(test_set)):
            eeg_sample, label = test_set[sample_id]
            output = model(eeg_sample.unsqueeze(0).float().to(device))
            scores = torch.nn.functional.softmax(output[0], dim=0)

            # Record the probability of the sample's own class together with
            # its dataset index, so the ranking maps back to the right sample
            # even if the test set contains both classes.
            if label[0] == 1:
                none_sick_probs.append(scores[1].item())
                none_sick_ids.append(sample_id)
            else:
                sick_probs.append(scores[0].item())
                sick_ids.append(sample_id)

    # Get the (up to) top_k best sick and non-sick sample indexes
    best_index_non_sick = _top_k_sample_ids(none_sick_probs, none_sick_ids, top_k)
    best_index_sick = _top_k_sample_ids(sick_probs, sick_ids, top_k)

    none_sick_stack = _stack_samples(test_set, best_index_non_sick)
    sick_stack = _stack_samples(test_set, best_index_sick)

    field_name = f'field{partic_id + 1}'
    # Store the selected samples of both classes for this participant.
    export_matlab[field_name] = {'sick': sick_stack, 'normal': none_sick_stack}

# Save results to directory (create it first so savemat does not fail)
output_dir = './Results/Plots/' + trial_no + '/' + mode + '/' + method
os.makedirs(output_dir, exist_ok=True)
scipy.io.savemat(output_dir + '/matrix_jitter_output.mat', {'export_matlab': export_matlab})


  x = self.softmax8(x)
