# Setup and environment

## Import Packages

In [341]:
import os

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.neural_network import MLPClassifier
from sklearn.model_selection import GridSearchCV
from sklearn.metrics import classification_report
from tqdm import tqdm

import torch
from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present
from nn_architecture.ae_networks import TransformerAutoencoder, TransformerDoubleAutoencoder, TransformerFlattenAutoencoder
from helpers.dataloader import Dataloader

## Define functions

In [397]:
def norm(data):
    """Min-max scale ``data`` to [0, 1] using its global minimum and maximum.

    Args:
        data: Array-like accepted by ``np.min``/``np.max`` that supports
            element-wise arithmetic (e.g. a NumPy array).

    Returns:
        ``(data - min) / (max - min)``. When every element is equal the value
        range is zero, and an all-zero array is returned instead of NaNs.

    Note:
        A function with the same name is defined again further down in this
        cell; after the cell runs, that later definition is the one in effect.
    """
    lowest = np.min(data)
    value_range = np.max(data) - lowest
    if value_range == 0:
        # Constant input: dividing by 1 keeps the zeros of (data - lowest)
        # instead of evaluating 0/0, which would yield NaN.
        value_range = 1
    return (data - lowest) / value_range

def load_data(data_checkpoint = 'data/seeg_ae_training_data.csv'):
    """Load a dataset in its stored order and split it into labels and scaled signal.

    Args:
        data_checkpoint: Path handed to ``Dataloader``.

    Returns:
        Tuple ``(labels, dataset)`` of NumPy arrays: ``labels`` is index 0 of
        the second axis (kept as a length-1 axis), ``dataset`` is everything
        from index 1 onwards on that axis after ``norm`` has been applied.
    """
    full_tensor = Dataloader(
        data_checkpoint, col_label='Condition', channel_label='Electrode'
    ).get_data(shuffle=False)

    # Index 0 along the second axis holds the labels; re-insert the axis so the
    # result stays three-dimensional.
    label_rows = full_tensor[:, 0, :][:, None, :]
    signal_rows = full_tensor[:, 1:, :]

    return label_rows.detach().numpy(), norm(signal_rows.detach().numpy())

def initiate_autoencoder(ae_dict, dataset):
    """Rebuild a trained transformer autoencoder on the CPU from a checkpoint dict.

    Args:
        ae_dict: Checkpoint dictionary. ``ae_dict['configuration']`` supplies
            ``target`` and the hyperparameters read below;
            ``ae_dict['model']`` is the state dict to load.
        dataset: Array whose ``shape[-1]`` is used as the number of channels
            and whose ``shape[1] - 1`` is used as the sequence length.

    Returns:
        The autoencoder with its weights loaded, placed on the CPU and
        switched to evaluation mode.

    Raises:
        ValueError: If ``target`` is not ``'channels'``, ``'time'`` or ``'full'``.

    Note:
        A ``'module.'`` prefix, if present, is stripped from the keys of
        ``ae_dict['model']`` in place.
    """
    config = ae_dict['configuration']
    target = config['target']
    if target not in ('channels', 'time', 'full'):
        raise ValueError(f"Encode target '{target}' not recognized, options are 'channels', 'time', or 'full'.")

    n_channels = dataset.shape[-1]
    sequence_length = dataset.shape[1] - 1

    # Hyperparameters that every architecture variant receives unchanged.
    shared_kwargs = dict(hidden_dim=config['hidden_dim'],
                         num_layers=config['num_layers'],
                         num_heads=config['num_heads'])

    if target == 'channels':
        autoencoder = TransformerAutoencoder(input_dim=n_channels,
                                             output_dim=config['channels_out'],
                                             output_dim_2=sequence_length,
                                             target=TransformerAutoencoder.TARGET_CHANNELS,
                                             **shared_kwargs)
    elif target == 'time':
        autoencoder = TransformerAutoencoder(input_dim=sequence_length,
                                             output_dim=config['timeseries_out'],
                                             output_dim_2=n_channels,
                                             target=TransformerAutoencoder.TARGET_TIMESERIES,
                                             **shared_kwargs)
    else:  # target == 'full'
        autoencoder = TransformerDoubleAutoencoder(input_dim=n_channels,
                                                   output_dim=config['output_dim'],
                                                   output_dim_2=config['output_dim_2'],
                                                   sequence_length=sequence_length,
                                                   **shared_kwargs)
    autoencoder = autoencoder.to('cpu')

    # Checkpoints saved from a wrapped (e.g. DDP) model carry a 'module.' key prefix.
    consume_prefix_in_state_dict_if_present(ae_dict['model'], 'module.')
    autoencoder.load_state_dict(ae_dict['model'])
    autoencoder.device = torch.device('cpu')

    # The model is only used for inference here: evaluation mode disables
    # train-time behaviour such as dropout so encodings are deterministic.
    autoencoder.eval()

    return autoencoder
    
def norm(data):
    """Min-max scale ``data`` to [0, 1] using its global minimum and maximum.

    Args:
        data: Array-like accepted by ``np.min``/``np.max`` that supports
            element-wise arithmetic (e.g. a NumPy array).

    Returns:
        ``(data - min) / (max - min)``. When every element is equal the value
        range is zero, and an all-zero array is returned instead of NaNs.

    Note:
        This repeats a definition from the top of the cell; being the later
        one, it is the definition in effect after the cell runs.
    """
    lowest = np.min(data)
    value_range = np.max(data) - lowest
    if value_range == 0:
        # Constant input: dividing by 1 keeps the zeros of (data - lowest)
        # instead of evaluating 0/0, which would yield NaN.
        value_range = 1
    return (data - lowest) / value_range

def encode_data(data_checkpoint, ae_dict, type='flatten'):
    """Encode every sample of a dataset with the trained autoencoder.

    Args:
        data_checkpoint: Path handed to ``Dataloader``.
        ae_dict: Checkpoint dictionary forwarded to ``initiate_autoencoder``.
        type: ``'mean'`` averages each encoding over axis 2 and keeps the
            first (only) batch entry; any other value flattens the encoding
            into a 1-D vector. The name shadows the built-in ``type`` but is
            kept so existing keyword callers continue to work.

    Returns:
        Tuple ``(labels, encoded_samples)`` of NumPy arrays; ``labels`` is
        read from ``[:, 0, 0]`` of the loaded data.
    """
    dataloader = Dataloader(data_checkpoint, col_label='Condition', channel_label='Electrode')
    loaded = dataloader.get_data()
    labels = loaded[:, 0, 0]

    # NOTE(review): the min/max scaling runs over the whole array, including
    # the label row at index 0, whereas load_data() removes that row before
    # scaling. Confirm which variant matches how the autoencoder was trained.
    dataset = norm(loaded.detach().numpy())

    autoencoder = initiate_autoencoder(ae_dict, dataset)
    # Inference only: evaluation mode switches off train-time behaviour such
    # as dropout, so repeated runs give the same encodings.
    autoencoder.eval()

    encoded_samples = []
    # No gradients are needed for encoding; skip building autograd graphs.
    with torch.no_grad():
        for sample in dataset:
            # Drop the label row and add a batch dimension of size 1.
            model_input = torch.from_numpy(sample[1:, :]).unsqueeze(0)
            encoded = autoencoder.encode(model_input).detach().numpy()
            if type == 'mean':
                encoded_samples.append(np.mean(encoded, axis=2)[0, :])
            else:
                encoded_samples.append(encoded.flatten())

    return np.array(labels), np.array(encoded_samples)

# Data and Autoencoder

## Load Autoencoder

In [398]:
# Checkpoint of the trained autoencoder, mapped onto the CPU while loading.
# NOTE(review): torch.load may unpickle arbitrary objects, so only open
# checkpoint files that come from a trusted source.
ae_checkpoint = 'trained_ae/ae_ddp_5000ep_20240119_023414.pt'
ae_dict = torch.load(ae_checkpoint, map_location=torch.device('cpu'))

# Report the stored configuration, leaving out the 'dataloader' and 'history' entries
for key, value in ae_dict['configuration'].items():
    if key not in ('dataloader', 'history'):
        print(f"{key}: {value}")

device: cuda
model_class: TransformerDoubleAutoencoder
batch_size: 4
n_epochs: 5000
sample_interval: 100
learning_rate: 0.0001
hidden_dim: 256
path_dataset: data/seeg_ae_training_data.csv
path_checkpoint: trained_ae/checkpoint.pt
timeseries_out: 320
channels_out: 6
target: full
channel_label: Electrode
trained_epochs: 5000
input_dim: 12
output_dim: 6
output_dim_2: 320
num_layers: 2
num_heads: 8
activation: sigmoid


## Load and Encode Data

In [399]:
# Encode both splits with the same trained autoencoder (flattened encodings).
print('Encoding training set...')
y_train, x_train = encode_data(data_checkpoint='data/seeg_ae_training_data.csv',
                               ae_dict=ae_dict)

print('Encoding validation set...')
y_test, x_test = encode_data(data_checkpoint='data/seeg_ae_validation_data.csv',
                             ae_dict=ae_dict)


Encoding training set...
Encoding validation set...


# Classifier

In [400]:
# Seeded generator so the shuffle below is reproducible
rng = np.random.default_rng(42)

# Make sure all four splits are NumPy arrays
x_train, y_train, x_test, y_test = (
    np.array(split) for split in (x_train, y_train, x_test, y_test)
)

# Draw one permutation of the training indices ...
indices = np.arange(len(y_train))
rng.shuffle(indices)

# ... and apply it to features and labels alike so the pairs stay aligned
x_train, y_train = x_train[indices], y_train[indices]

# 4.2 Train the Classifier

A classifier has many different settings and hyperparameter configurations. Instead of manually trying different settings, one can leverage the `GridSearchCV` functionality from the sklearn package. In the cell below, the hyperparameters are set directly.

Now, we train the network.

In [402]:
# Single-hidden-layer MLP with fixed hyperparameters; random_state pins the
# weight initialisation so the fit is repeatable.
neuralNetOutput = MLPClassifier(
    hidden_layer_sizes=(20,),
    activation='identity',
    solver='adam',
    learning_rate='constant',
    alpha=0.0001,
    max_iter=1000,
    random_state=1251,
)

# Fit on the shuffled, encoded training set
neuralNetOutput.fit(x_train, y_train)

# 4.3 Validation/Prediction

We predict the labels for the held-out 'NextDayRecall' data and report the macro-averaged F1 score.

In [403]:
# Predict labels for the encoded validation samples
y_pred = neuralNetOutput.predict(x_test)
predictResults = classification_report(y_test, y_pred, output_dict=True)

# Report the macro-averaged F1 score as a rounded percentage, followed by the
# true and the predicted labels
print('Prediction: ' + str(round(predictResults['macro avg']['f1-score']*100)) + '%')
print(y_test)
print(y_pred)

Prediction: 12%
[1. 2. 2. 1. 1. 1. 2.]
[1. 1. 1. 2. 2. 2. 1.]
