In [None]:
import numpy as np
# At the start of your notebook
from IPython.display import clear_output
import gc

# Wipe this cell's previous output (wait=True defers the wipe until new output arrives)
# and force a full garbage-collection pass; the collected-object count is the cell output.
clear_output(wait=True)
gc.collect(generation=2)

In [None]:
from steps import setup_and_train_models, analyze_seizure_propagation
import torch
from datasetConstruct import construct_channel_recognition_dataset
from models import Wavenet, train_using_optimizer
import pickle
import os
import matplotlib.pyplot as plt
import pandas as pd

In [None]:
# Output directories for results and models.
RESULT_FOLDER, MODEL_FOLDER = "result", "model"

# Architectures to use. Other available options: 'CNN1D', 'Wavenet', 'LSTM', 'S4'.
model_names = ["ResNet"]

In [None]:
# Grid for the batch analysis used to pick hyperparameters.
seizures = [1, 2, 3, 5, 7]   # seizure numbers to analyse
thresholds = [0.8]           # candidate `threshold` values
smooth_windows = [80]        # candidate `smooth_window` values

In [None]:
from steps import extract_sEEG_features
from datasetConstruct import load_seizure_across_patients

dataset = load_seizure_across_patients(data_folder='data')

# Keep the features of every seizure. The previous loop assigned to a single
# `seizure_new` on each iteration, so all but the last result were discarded.
seizure_features = [
    extract_sEEG_features(seizure, sampling_rate=seizure.samplingRate)
    for seizure in dataset
]

In [None]:
# Optional keys noted by the original author: epochs, checkpoint_freq, lr,
# batch_size, device, patience, gradient_clip.
training_params = {'epochs': 100, 'batch_size': 4096, 'checkpoint_freq': 20}

# Uses every architecture listed in `model_names`. train=False presumably loads
# existing checkpoints from `model_folder` instead of retraining; confirm in
# steps.setup_and_train_models.
results, models = setup_and_train_models(
    data_folder="data",
    model_folder="checkpoints",
    model_names=model_names,
    train=False,
    input_type='transformed',
    params=training_params,
)

In [None]:
from typing import List, Tuple, Dict
from utils import split_data, find_seizure_related_channels
from datasetConstruct import load_single_seizure
from models import output_to_probability

# Configuration for the single-seizure debug run below.
data_folder = 'data'
marking_file = os.path.join(data_folder, 'Seizure_Onset_Type_ML_USC.xlsx')
patient_no = 66
seizure_no = 1

# Set up paths
single_seizure_folder = os.path.join(data_folder, f"P{patient_no}")
save_folder = os.path.join(RESULT_FOLDER, f"P{patient_no}", f"Seizure{seizure_no}")
os.makedirs(save_folder, exist_ok=True)

model_name = model_names[0]
model = models[model_name]
params = {
    'threshold': 0.8,
    'smooth_window': 10,
    'n_seconds': 60,
    'seizure_start': 10,
    'overlap': 0.9,
    # Fall back to CPU so the notebook still runs on machines without a GPU.
    'device': 'cuda:0' if torch.cuda.is_available() else 'cpu',
}

def load_seizure_data(seizure_number=None, patient_number=None) -> Tuple[object, List[str], List[str]]:
    """Load one seizure recording together with its annotated channel lists.

    Args:
        seizure_number: Seizure to load. Defaults to the notebook-level
            ``seizure_no``, read at call time (the original behaviour).
        patient_number: Patient to load. Defaults to the notebook-level
            ``patient_no`` and the precomputed ``single_seizure_folder``. When
            given, the folder is rebuilt as ``<data_folder>/P<patient_number>``.

    Returns:
        ``(seizure_obj, seizure_channels, seizure_onset_channels)``, where the
        object comes from ``load_single_seizure`` and the channel lists come
        from ``find_seizure_related_channels``.
    """
    # Explicit arguments avoid silently picking up notebook globals that later
    # cells may have overwritten.
    seizure_id = seizure_no if seizure_number is None else seizure_number
    if patient_number is None:
        patient_id = patient_no
        seizure_folder = single_seizure_folder
    else:
        patient_id = patient_number
        seizure_folder = os.path.join(data_folder, f"P{patient_number}")

    # Load seizure marking data
    seizure_marking = pd.read_excel(marking_file)

    # Find seizure-related channels
    seizure_channels, seizure_onset_channels = find_seizure_related_channels(
        seizure_marking, seizure_id, patient_id
    )

    # Load seizure data
    seizure_obj = load_single_seizure(seizure_folder, seizure_id)

    return seizure_obj, seizure_channels, seizure_onset_channels


def process_data(seizure_obj, overlap=None) -> Tuple[np.ndarray, np.ndarray, float]:
    """Stack pre-ictal and ictal samples and split them into overlapping windows.

    Args:
        seizure_obj: Object exposing ``samplingRate``, ``ictal`` and ``preictal2``.
        overlap: Window overlap forwarded to ``split_data``. Defaults to the
            notebook-level ``params['overlap']``, read at call time (the
            original behaviour).

    Returns:
        ``(total_data, total_windows, fs)``: the concatenated samples, the
        output of ``split_data``, and the sampling rate.
    """
    if overlap is None:
        overlap = params['overlap']

    fs = seizure_obj.samplingRate
    ictal_data = seizure_obj.ictal
    preictal_data = seizure_obj.preictal2

    # Collapse the leading axes of the ictal array, keeping its axis 2, so it
    # can be stacked under preictal2 along axis 0. Presumably
    # (segments, samples, channels) -> (segments*samples, channels); TODO confirm.
    ictal_combined = ictal_data.reshape(-1, ictal_data.shape[2])
    total_data = np.concatenate((preictal_data, ictal_combined), axis=0)

    # Split data into windows
    total_windows = split_data(total_data, fs, overlap=overlap)

    return total_data, total_windows, fs


def compute_probabilities(data: np.ndarray, model, device: str) -> np.ndarray:
    """Run ``model`` over the windowed data one channel at a time.

    Args:
        data: Windowed array indexed as ``data[window, sample, channel]``
            (axis 2 is iterated as channels; axis 1 is the per-window length).
        model: Model handed to ``output_to_probability``.
        device: Torch device string each per-channel batch is moved to.

    Returns:
        Array of shape ``(data.shape[0], data.shape[2])`` holding, for every
        channel, the values ``output_to_probability`` produced for each window.
    """
    window_len = data.shape[1]
    n_channels = data.shape[2]
    probs = np.zeros((data.shape[0], n_channels))

    for ch in range(n_channels):
        # (windows, samples) -> (windows, 1, samples): one single-channel input per window.
        channel_windows = data[:, :, ch].reshape(-1, 1, window_len)
        batch = torch.tensor(channel_windows, dtype=torch.float32).to(device)
        probs[:, ch] = output_to_probability(model, batch, device)

    return probs


In [None]:
# Debug pass on the single seizure configured above:
# load -> window -> per-channel probabilities.
seizure_obj, seizure_channels, seizure_onset_channels = load_seizure_data()
total_data, windowed_data, fs = process_data(seizure_obj)
probabilities = compute_probabilities(windowed_data, model, device=params['device'])

In [None]:
channel = 51
seconds = 100
tick_every_seconds = 10  # spacing of the x-axis labels

# The window hop is (1 - overlap), so there are 1 / (1 - overlap) windows per second.
# This assumes split_data uses 1-second windows; TODO confirm against utils.split_data.
windows_per_second = 1 / (1 - params['overlap'])
# round() instead of int(): int() truncates float error such as 1999.9999 (overlap=0.95).
nsamples = int(round(seconds * windows_per_second))

# Plot the per-window mean signal and the seizure probability for one channel
raw_data = np.mean(windowed_data, axis=1)[:nsamples, channel]
probability = probabilities[:nsamples, channel]

# Smooth the probability with a moving average (mode='same' keeps the length)
kernel = np.ones(params['smooth_window']) / params['smooth_window']
probability = np.convolve(probability, kernel, mode='same')

fig, ax = plt.subplots(2, 1, figsize=(10, 6))
ax[0].plot(raw_data)
ax[0].set_title(f'Channel {channel} - Raw Data')
ax[0].set_xticks([])  # this panel shares the time axis with the one below

ax[1].plot(probability)
ax[1].set_title(f'Channel {channel} - Seizure Probability')
ax[1].set_xlabel('Time (s)')
# Tick positions are window indices and labels are seconds. Deriving both from the
# same seconds grid keeps them the same length for any overlap.
tick_seconds = np.arange(0, seconds, tick_every_seconds)
ax[1].set_xticks(tick_seconds * windows_per_second)
ax[1].set_xticklabels(tick_seconds)

plt.tight_layout()
plt.show()

In [None]:
# Calculate probability correlation to check the contamination of common noise
from scipy.stats import pearsonr
# Absolute Pearson correlation between every pair of channel probability traces.
# Strong off-diagonal structure suggests common noise shared by many channels.
# np.corrcoef (columns as variables) replaces the O(C^2) pearsonr loop with one call.
# atleast_2d keeps the single-channel case a 1x1 matrix for imshow.
correlation = np.abs(np.atleast_2d(np.corrcoef(probabilities, rowvar=False)))

fig, ax = plt.subplots()
im = ax.imshow(correlation)
fig.colorbar(im, ax=ax, label='|Pearson r|')
ax.set_xlabel('Channel')
ax.set_ylabel('Channel')
ax.set_title('Correlation of channel seizure probabilities')
plt.show()

In [None]:
results_propagation_total = []
model_name = model_names[0]
model = models[model_name]
threshold = 0.8
smooth_window = 10
LOAD = False
filename = f'{RESULT_FOLDER}/results_propagation_{model_name}_{threshold}_{smooth_window}.pkl'

# Reuse cached results only when LOAD is set.
# NOTE: pickle.load can execute arbitrary code; only load files this notebook wrote.
if LOAD and os.path.exists(filename):
    with open(filename, 'rb') as f:
        results_propagation_total = pickle.load(f)

if len(results_propagation_total) == 0:
    for PAT_NO in [65, 66]:
        # `seizure_idx` / `propagation_params` deliberately avoid reusing the notebook
        # globals `seizure_no` and `params`. load_seizure_data, process_data and the
        # channel plot read those globals, including params['overlap'].
        for seizure_idx in seizures:
            # Patient 66 is only analysed for seizure numbers <= 3.
            if PAT_NO == 66 and seizure_idx > 3:
                continue
            # Fresh dict per call, in case analyze_seizure_propagation mutates it.
            propagation_params = {
                'threshold': threshold,
                'smooth_window': smooth_window,
                'n_seconds': 60,
                'seizure_start': 10,
            }
            results_propagation = analyze_seizure_propagation(
                patient_no=PAT_NO,
                seizure_no=seizure_idx,
                model=model,
                data_folder='data',
                params=propagation_params,
                save_results_ind=True
            )
            results_propagation_total.append(results_propagation)

    os.makedirs(RESULT_FOLDER, exist_ok=True)
    with open(filename, 'wb') as f:
        pickle.dump(results_propagation_total, f)

In [None]:
# Examine the result:
from plotFun import plot_eeg_style
# 1. Plot the smoothed result:
# Inspect one analysed seizure (index 5): plot its smoothed per-channel probabilities.
example = results_propagation_total[5]
sample_result = example['smoothed_probabilities'][20:350]
sample_result2 = example['probabilities'][20:300]  # NOTE(review): unused in this cell
fig = plot_eeg_style(sample_result.T, sampling_rate=5)
plt.show()

In [None]:
# # Load and test the augmented data
# augdata = pd.read_csv('data/clips.tar.gz', compression='gzip', header=0, sep=' ;', encoding='ISO-8859-2', quotechar='"', engine='python')


In [None]:
# Build train/validation datasets for two tasks: seizure-channel and seizure-onset recognition.
# The positional 50 is forwarded unchanged; its meaning is defined in datasetConstruct (TODO document).
(
    seizure_channels_dataset_train,
    seizure_channels_dataset_val,
    seizure_onset_dataset_train,
    seizure_onset_dataset_val,
) = construct_channel_recognition_dataset(
    results_propagation_total,
    50,
    batch_size=128,
    data_aug=False,
)

In [None]:
# Seed the torch and numpy RNGs so weight initialisation (and any torch/numpy-based
# shuffling) is repeatable across runs. GPU kernels may still be non-deterministic.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

EPOCHS = 200
CHECKPOINT_FREQ = 20  # the loss-plot cell assumes validation is logged every 20 epochs

# Define the model
model_seizure_channel = Wavenet(input_dim=1, output_dim=2, lr=0.001)

# Train the model
train_loss, val_loss, val_accuracy = train_using_optimizer(
    model_seizure_channel,
    seizure_channels_dataset_train,
    seizure_channels_dataset_val,
    epochs=EPOCHS,
    checkpoint_freq=CHECKPOINT_FREQ,
)

In [None]:
# Plot the training/validation loss and, on a twin y-axis, the validation accuracy.
# Validation values are plotted at epochs 0, 20, ..., 180 (one point per checkpoint).
x_ticks = range(0, 200, 20)

fig, ax_loss = plt.subplots()
ax_loss.plot(train_loss, label='Train')
ax_loss.plot(x_ticks, val_loss, label='Validation')
ax_loss.set_xlabel('Epoch')
ax_loss.set_ylabel('Loss')

ax_acc = ax_loss.twinx()
ax_acc.plot(x_ticks, val_accuracy, label='Validation Accuracy', color='red')
ax_acc.set_ylabel('Accuracy')

# A plain plt.legend() after twinx() only sees the twin axis and would drop the
# Train/Validation loss entries, so merge the handles from both axes.
loss_handles, loss_labels = ax_loss.get_legend_handles_labels()
acc_handles, acc_labels = ax_acc.get_legend_handles_labels()
ax_acc.legend(loss_handles + acc_handles, loss_labels + acc_labels)

fig.savefig('result/loss_seizure_channels.png')

plt.show()

In [None]:
predicted_labels_total = []

time_lengths = [10, 20, 30, 40, 50, 60]  # NOTE(review): not used in this cell

# Time steps of the smoothed probabilities fed to the classifier, and the cut-off
# applied to the class-1 output column.
window_start, window_end = 50, 300
decision_threshold = 0.5

# Inference only. eval() disables training-time behaviour (dropout, batch-norm updates)
# if the model has any, and no_grad() avoids building an autograd graph.
if hasattr(model_seizure_channel, 'eval'):
    model_seizure_channel.eval()

with torch.no_grad():
    for result in results_propagation_total:
        sample_seizure = result['smoothed_probabilities'][window_start:window_end, :]

        # (time, channels) -> (channels, 1, time): each channel becomes one sample.
        sample_seizure = np.expand_dims(sample_seizure.T, axis=1)
        # NOTE(review): input stays on the CPU, as before. Move it to the model's
        # device if the model lives on a GPU.
        sample_seizure = torch.tensor(sample_seizure, dtype=torch.float32)

        predictions = model_seizure_channel(sample_seizure)

        # A channel is labelled seizure-related when its class-1 score exceeds 0.5.
        # This assumes the model outputs probabilities (e.g. softmax); TODO confirm.
        predicted_labels = predictions.detach().to('cpu').numpy()[:, 1] > decision_threshold

        predicted_labels_total.append(predicted_labels)

In [None]:
# Build per-seizure binary ground truth: 1 for the channel indices listed in
# 'true_seizure_channels', 0 for every other channel.
gound_truth_total = []
for result in results_propagation_total:
    y_true = np.zeros(result['smoothed_probabilities'].shape[1])
    y_true[result['true_seizure_channels']] = 1
    gound_truth_total.append(y_true)

# Flatten the ground truth of all seizures into a single array
gound_truth_total = np.concatenate(gound_truth_total)

# Flatten the predictions the same way. np.ravel on each element keeps this cell
# idempotent: on a re-run predicted_labels_total is already a flat array, and a
# plain np.concatenate over its 0-d elements would raise.
predicted_labels_total = np.concatenate([np.ravel(p) for p in predicted_labels_total])

assert gound_truth_total.shape == predicted_labels_total.shape, (
    f"ground truth {gound_truth_total.shape} vs predictions {predicted_labels_total.shape}"
)

In [None]:
# Plot the confusion matrix
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix

# Rows are ground truth and columns are predictions. labels=[0, 1] keeps the matrix
# 2x2 even when a class is absent from both arrays.
conf_matrix = confusion_matrix(
    gound_truth_total.flatten(), predicted_labels_total.flatten(), labels=[0, 1]
)
class_names = ['Other', 'Seizure']

fig, ax = plt.subplots()
sns.heatmap(
    conf_matrix, annot=True, fmt='d',
    xticklabels=class_names, yticklabels=class_names, ax=ax,
)
ax.set_xlabel('Predicted')
ax.set_ylabel('True')
ax.set_title('Seizure-channel classification')
fig.savefig('result/confusion_matrix_seizure_channels.png')

plt.show()


In [None]:
# Calculate the accuracy, precision, recall, and F1 score
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

# Flatten once, then score the binary seizure-channel predictions.
y_true_flat = gound_truth_total.flatten()
y_pred_flat = predicted_labels_total.flatten()

accuracy, precision, recall, f1 = (
    scorer(y_true_flat, y_pred_flat)
    for scorer in (accuracy_score, precision_score, recall_score, f1_score)
)

for metric_label, metric_value in (
    ('Accuracy', accuracy),
    ('Precision', precision),
    ('Recall', recall),
    ('F1', f1),
):
    print(f'{metric_label}: {metric_value}')