In [None]:

# Experiment setup:
# First, train a model on the original training data.
# Second, train a model with noise added to the 0-250 ms region.
# Third, train a model with noise added to the 250-500 ms region.
# Lastly, compute 5-10 SHAP approximations per model (depending on the computing time),
# and make sure the SHAP approximations for each participant use the same fixed sets.

In [None]:
from src.model.eegnet_variance import EEGNetMultiHeaded
from pytorch_lightning.loggers import CometLogger
from src.datamodule import DataModule, predictDataSet
from src.preprocessing import create_dataset
import pytorch_lightning as pl
import numpy as np
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from src.preprocessing import *
from src.noise_addition import zero_signal, add_gaussian_noise
from src.model.variance_wrapper import EEGVarianceWrapper
import copy
import random
import pickle
import torch
import shap

In [None]:
import os

# Never commit credentials: the Comet API key is read from the environment instead of
# being hard-coded. The key that used to live here is in the notebook history and should
# be rotated. Set it before starting Jupyter, e.g. `export COMET_API_KEY=...`.
comet_api_key = os.environ.get("COMET_API_KEY")
if not comet_api_key:
    raise RuntimeError(
        "COMET_API_KEY is not set; export it before starting Jupyter "
        "(e.g. `export COMET_API_KEY=...`)."
    )

comet_logger = CometLogger(
    api_key=comet_api_key,
    workspace="marwo22",  # Optional
    project_name="bachelors-project"  # Optional
)

In [None]:
def add_guassian_noise_to_dataset(dataset, severity_index, low: int = 0, high: int = 308, num_channels: int = 64):
    """Corrupt a random subset of the episodes in ``dataset`` with Gaussian noise.

    The (misspelled) name is kept so existing callers keep working.
    ``dataset[0]`` must be an indexable, item-assignable sequence of episodes; it is
    modified in place, so pass a copy if the original must be preserved.

    Args:
        dataset: Indexable whose first element holds the episodes (presumably the
            ``(data, labels)`` pairs produced by ``create_dataset`` -- TODO confirm).
        severity_index: Severity in ``[0, 10]``. ``0`` returns ``dataset`` untouched.
            Otherwise ``int(0.1 * severity_index * n_episodes)`` episodes (10%-100%) are
            corrupted, each on ``int(num_channels / 10 * severity_index)`` channels.
            ``2 * severity_index`` is forwarded as the second positional argument of
            ``add_gaussian_noise`` (presumably the noise amplitude -- TODO confirm).
        low, high: Sample range forwarded to ``add_gaussian_noise``; whether ``high``
            is inclusive is defined by that helper (TODO confirm).
        num_channels: Number of channels per episode (64 for this EEG setup).

    Returns:
        The same ``dataset`` object.

    Raises:
        ValueError: If ``severity_index`` lies outside ``[0, 10]`` (larger values would
            request more episodes/channels than exist).
    """
    if severity_index == 0:
        return dataset
    if not 0 <= severity_index <= 10:
        raise ValueError(f"severity_index must be in [0, 10], got {severity_index}")

    episodes = dataset[0]
    n_episodes = len(episodes)
    # 10% of the episodes per severity step: 10% at severity 1, 100% at severity 10.
    episodes_to_corrupt = random.sample(range(n_episodes), int(0.1 * severity_index * n_episodes))
    # 10% of the channels per severity step; identical for every episode, so computed once.
    channels_to_corrupt = int(num_channels / 10 * severity_index)
    for episode in episodes_to_corrupt:
        episodes[episode] = add_gaussian_noise(episodes[episode], 2 * severity_index, channels_to_corrupt, low, high)

    return dataset

In [None]:
def add_zero_to_dataset(dataset, severity_index, low: int = 0, high: int = 308, num_channels: int = 64):
    """Zero out channels in a random subset of the episodes in ``dataset``.

    ``dataset[0]`` must be an indexable, item-assignable sequence of episodes; it is
    modified in place, so pass a copy if the original must be preserved.

    Args:
        dataset: Indexable whose first element holds the episodes (presumably the
            ``(data, labels)`` pairs produced by ``create_dataset`` -- TODO confirm).
        severity_index: Severity in ``[0, 10]``. ``0`` returns ``dataset`` untouched.
            Otherwise ``int(0.1 * severity_index * n_episodes)`` episodes (10%-100%) are
            corrupted, each on ``int(num_channels / 20 * severity_index)`` channels
            (5%-50% of the channels).
        low, high: Sample range forwarded to ``zero_signal``; whether ``high`` is
            inclusive is defined by that helper (TODO confirm).
        num_channels: Number of channels per episode (64 for this EEG setup).

    Returns:
        The same ``dataset`` object.

    Raises:
        ValueError: If ``severity_index`` lies outside ``[0, 10]`` (larger values would
            request more episodes than exist).
    """
    if severity_index == 0:
        return dataset
    if not 0 <= severity_index <= 10:
        raise ValueError(f"severity_index must be in [0, 10], got {severity_index}")

    episodes = dataset[0]
    n_episodes = len(episodes)
    # 10% of the episodes per severity step: 10% at severity 1, 100% at severity 10.
    episodes_to_corrupt = random.sample(range(n_episodes), int(0.1 * severity_index * n_episodes))
    # 5% of the channels per severity step; identical for every episode, so computed once.
    channels_to_corrupt = int(num_channels / 20 * severity_index)
    for episode in episodes_to_corrupt:
        # The trailing 100 is presumably the percentage of the window to zero (the original
        # note said this zeroes the entire signal) -- TODO confirm against zero_signal.
        episodes[episode] = zero_signal(episodes[episode], channels_to_corrupt, low, high, 100)

    return dataset

In [None]:
# Experiment constants (see the setup description at the top of the notebook).
N_PARTICIPANTS = 6
# Condition 0 = clean baseline, 1 = noise on samples [0, 127], 2 = noise on samples [128, 255].
N_CONDITIONS = 3
NOISE_WINDOW = 128
# Severity for add_guassian_noise_to_dataset; 10 is the top of its range, so every
# episode of every split is corrupted.
NOISE_SEVERITY = 10
# Number of test episodes explained per model; all other test episodes form the SHAP
# background set.
N_EXPLAINED_EPISODES = 10
# Base seed so noise placement, weight initialisation and shuffling are reproducible.
SEED = 42


def _stack_test_inputs(dm):
    """Concatenate the input tensors of every batch of ``dm.test_dataloader()``."""
    batches = [data for data, _ in dm.test_dataloader()]
    if not batches:
        raise ValueError("The test dataloader yielded no batches.")
    # One concatenation instead of re-copying a growing tensor on every batch.
    return torch.cat(batches, 0)


def _variance_shap_values(model, dm, n_explained=N_EXPLAINED_EPISODES):
    """Explain the last ``n_explained`` test episodes of ``model`` with SHAP DeepExplainer.

    ``model`` is wrapped in ``EEGVarianceWrapper`` before being explained, and every
    earlier test episode is used as the background set.

    Raises:
        ValueError: If the test set has no more than ``n_explained`` episodes, which
            would leave the background set empty.
    """
    test_inputs = _stack_test_inputs(dm)
    if len(test_inputs) <= n_explained:
        raise ValueError(
            f"Need more than {n_explained} test episodes for a non-empty SHAP background, "
            f"got {len(test_inputs)}."
        )
    background = test_inputs[:-n_explained]
    explained = test_inputs[-n_explained:]

    model.eval()  # disable dropout before computing attributions
    explainer = shap.DeepExplainer(EEGVarianceWrapper(model), background)
    return explainer.shap_values(explained)


# shap_values[condition] collects one SHAP result per participant.
shap_values = [[] for _ in range(N_CONDITIONS)]

for j in range(N_PARTICIPANTS):
    # Load the datasets with participant j + 1 as the test set
    train, val, test = create_dataset('./src/pickle_df', j + 1)
    # Seeds python `random` (used for the noise), numpy and torch
    pl.seed_everything(SEED + j)

    for i in range(N_CONDITIONS):
        # Deep copies so the noise added for one condition never leaks into the next
        train_copy = copy.deepcopy(train)
        val_copy = copy.deepcopy(val)
        test_copy = copy.deepcopy(test)

        # Condition 0 stays clean; conditions 1 and 2 add noise to consecutive
        # NOISE_WINDOW-sample regions ([0, 127] and [128, 255] respectively).
        if i != 0:
            low, high = (i - 1) * NOISE_WINDOW, i * NOISE_WINDOW - 1
            train_copy = add_guassian_noise_to_dataset(train_copy, NOISE_SEVERITY, low=low, high=high)
            val_copy = add_guassian_noise_to_dataset(val_copy, NOISE_SEVERITY, low=low, high=high)
            test_copy = add_guassian_noise_to_dataset(test_copy, NOISE_SEVERITY, low=low, high=high)

        dm = DataModule(train=train_copy, val=val_copy, test=test_copy, batch_size=16)

        model = EEGNetMultiHeaded(chunk_size=308,
                                  num_electrodes=64,
                                  dropout=0.5,
                                  kernel_1=64,
                                  kernel_2=16,
                                  F1=8,
                                  F2=16,
                                  D=2,
                                  num_classes=2)
        # Train for 25 epochs for this example; the final results will most likely use 50.
        # NOTE(review): the same comet_logger instance is reused for every run, so all
        # runs appear to log into one Comet experiment -- consider one logger per run.
        trainer = pl.Trainer(
            max_epochs=25,
            logger=comet_logger
        )
        trainer.fit(model, datamodule=dm)

        shap_values[i].append(_variance_shap_values(model, dm))

    # Just to visualize progress easier, the output is very messy
    print("\n\n\n\n\nFinished" + str(j) + "\n\n\n\n\n")