This notebook uses trained PyTorch models to rank the best EEG samples per class.

In [1]:
import os
import sys

import numpy as np
import scipy.io
import torch
from matplotlib import pyplot as plt
from torchvision import transforms

# The project modules live outside the notebook folder; extend the search path
# before importing them.
sys.path.append('./models/')
sys.path.append('./source/')

import utils as myUtils
from DataTypes import EEGDataset
from CNN_1 import Network

# Report whether PyTorch can see a CUDA-capable GPU on this machine.
print(f"GPU_available={torch.cuda.is_available()}")

GPU_available=True


In [2]:
# --- Run configuration ---
# `trial_no`, `mode` and `method` are path components used further down when
# loading checkpoints and saving results; keep them consistent with the sample
# file chosen below.
trial_no = 'trial21'
mode = 'Kfold_roi'  # change the mode to pick up the matching trained models
method = 'sorting'  # this code sorts/chooses the 10 best EEG samples
no_participants = 29
model_classes = [0, 1]  # 0 => Sick Class; 1 => Normal Class

# --- Input data ---
# Make sure to choose the right sample set.
mat = scipy.io.loadmat('EEG_samples_roi.mat')
EEG_samples = mat['EEG_samples']

# --- Output container ---
# Filled per participant and exported for MATLAB analysis.
export_matlab = {}

# --- Accumulators (all start empty) ---
# Stacks of the best and worst samples for the non-sick class ...
best_none_sick_stack = []
worst_none_sick_stack = []
# ... and for the sick class.
best_sick_stack = []
worst_sick_stack = []
# Most recently fetched sample of each kind.
best_none_sick = []
best_sick = []
worst_none_sick = []
worst_sick = []

# Run on the GPU when one is available; otherwise fall back to the CPU so the
# cell still works on machines without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
top_k = 10  # number of best and worst samples kept per class


def _rank_extremes(probs, dataset_ids, k):
    """Return the dataset indices of the k highest and k lowest probabilities.

    `probs[j]` is the true-class probability of the sample stored at
    `dataset_ids[j]` in the test set. Both results are ordered by ascending
    probability. Fewer than k indices are returned when fewer samples exist,
    and two empty lists when there are none.
    """
    if not probs:
        return [], []
    order = np.argsort(probs)
    ids = np.asarray(dataset_ids)
    # Map positions in the per-class list back to indices into the test set.
    return ids[order[-k:]], ids[order[:k]]


def _stack_samples(dataset, dataset_ids):
    """Stack the selected samples of `dataset` along the last axis.

    Each sample tensor is transposed with (1, 2, 0) before concatenation on
    axis 2, so the sample index ends up last.
    Assumes every sample is a (1, H, W) tensor, giving an (H, W, n) result
    (TODO confirm against EEGDataset).
    Returns an empty list when `dataset_ids` is empty.
    """
    if len(dataset_ids) == 0:
        return []
    samples = [
        np.transpose(dataset[int(idx)][0].numpy(), (1, 2, 0))
        for idx in dataset_ids
    ]
    return np.concatenate(samples, axis=2)


for partic_id in range(no_participants):

    # True-class probability of every test sample, kept together with the
    # sample's index in the test set. The index is required because the
    # per-class lists are shorter than the test set whenever both classes are
    # present, so a position in `sick_probs` is not a valid dataset index.
    sick_probs, sick_ids = [], []
    none_sick_probs, none_sick_ids = [], []

    # Get the dataset
    full_data, train_data, test_data = myUtils.generateTrainTest(EEG_samples, partic_id, normalize = False, LOU = True)

    # Transform the dataset to standard tensor datatype
    data_transform = transforms.Compose([transforms.ToTensor()])
    test_set = EEGDataset(data_set=test_data[0], label_set=test_data[1], transform=data_transform)

    # Load the trained model for this participant onto the selected device
    model = Network(num_classes=2).to(device)
    state_dict = torch.load(
        './trained_models/'+trial_no+'/'+mode+'/'+'Arch'+str(partic_id+1)+'.pth',
        map_location=device,  # lets GPU-saved checkpoints load on CPU too
    )
    model.load_state_dict(state_dict)
    model.eval()  # Set the model to evaluation mode.

    # Score every test sample; no gradients are needed for inference.
    with torch.no_grad():
        for sample_id in range(len(test_set)):
            eeg_sample, label = test_set[sample_id]
            logits = model(eeg_sample.unsqueeze(0).float().to(device))
            scores = torch.nn.functional.softmax(logits[0], dim=0)

            # Record the probability assigned to the sample's own class.
            if label[0] == 1:
                none_sick_probs.append(scores[1].item())
                none_sick_ids.append(sample_id)
            else:
                sick_probs.append(scores[0].item())
                sick_ids.append(sample_id)

    # Dataset indices of the (up to) `top_k` best and worst samples per class
    best_ids_none_sick, worst_ids_none_sick = _rank_extremes(none_sick_probs, none_sick_ids, top_k)
    best_ids_sick, worst_ids_sick = _rank_extremes(sick_probs, sick_ids, top_k)

    # Appending the selected samples of the 2 classes for each participant
    field_name = f'field{partic_id+1}'
    export_matlab[field_name] = {
        'true_sick' : _stack_samples(test_set, best_ids_sick),
        'true_normal' : _stack_samples(test_set, best_ids_none_sick),
        'false_sick' : _stack_samples(test_set, worst_ids_sick),
        'false_normal' : _stack_samples(test_set, worst_ids_none_sick),
    }

# Save results to directory.
# Create the output folder first: savemat does not create missing directories,
# and failing here would throw away the rankings computed for every model.
output_dir = './Results/Plots/'+trial_no+'/'+mode+'/'+method
os.makedirs(output_dir, exist_ok=True)
scipy.io.savemat(output_dir+'/matrix_jitter_output.mat', {'export_matlab': export_matlab})


  x = self.softmax8(x)
