In [None]:
# At the start of your notebook
from IPython.display import clear_output
import gc

# Clear this cell's displayed output and force a garbage-collection pass.
# wait=True defers the clear until new output arrives (avoids flicker).
# NOTE(review): at the very top of a fresh kernel this frees little; it is
# mainly useful when re-run after memory-heavy cells.
clear_output(wait=True)
gc.collect()

In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
import torch

from utils import split_data
from datasetConstruct import combine_loaders, load_seizure_across_patients, create_dataset, load_single_seizure
from models import CNN1D, train_using_optimizer, Wavenet, LSTM, evaluate_model, output_to_probability
from plotFun import plot_time_limited_heatmap, plot_eeg_style

# Root directories: patient recordings are read from `data_folder`;
# model checkpoints are written to / loaded from MODEL_FOLDER.
data_folder, MODEL_FOLDER = "data", "checkpoints"

In [None]:
# Batch size shared by the per-seizure datasets and the combined loaders
# (previously a duplicated literal that could drift apart).
BATCH_SIZE = 4096

seizure_across_patients = load_seizure_across_patients(data_folder)

ml_datasets = [create_dataset(seizure, batch_size=BATCH_SIZE) for seizure in seizure_across_patients]
# Fail early with a clear message if no recordings were found.
assert ml_datasets, f"No seizures loaded from {data_folder!r}"

train_loader, val_loader = combine_loaders(ml_datasets, batch_size=BATCH_SIZE)

# Input shape of one training sample, used to size the models below.
# Assumes dataset items are (input, label) with input shaped (channels, time_steps) — TODO confirm.
channels, time_steps = train_loader.dataset[0][0].shape

In [None]:
# Create the models, then either train them or load their best checkpoints.
epochs = 40
checkpoint_freq = 5
lr = 0.001 # DO NOT CHANGE!
TRAIN = True
# Fall back to CPU so the notebook still runs on machines without CUDA.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

model1 = CNN1D(input_dim=channels, kernel_size=time_steps, output_dim=2, lr=lr)
model2 = Wavenet(input_dim=channels, output_dim=2, kernel_size=time_steps, lr=lr)
# NOTE(review): model3 is constructed but neither trained nor evaluated below.
model3 = LSTM(input_dim=channels, output_dim=2, lr=lr)


def _train(model):
    """Train `model` with this cell's shared settings.

    Returns the result of ``train_using_optimizer``, which is unpacked below as
    (train_loss, val_loss, val_accuracy) histories that are plotted per epoch.
    """
    return train_using_optimizer(
        model=model,
        trainloader=train_loader,
        valloader=val_loader,
        save_location=MODEL_FOLDER,  # same directory the else-branch loads from
        epochs=epochs,
        device=device,
        patience=7,
        gradient_clip=1.0,
        checkpoint_freq=checkpoint_freq,
    )


if TRAIN:
    CNNtrain_loss, CNNval_los, CNNval_accuracy = _train(model1)
    Wavetrain_loss, Waveval_los, Waveval_accuracy = _train(model2)
    # LSTM training is disabled; to enable it:
    # LSTMtrain_loss, LSTMval_los, LSTMval_accuracy = _train(model3)
else:
    # map_location='cpu' lets GPU-saved checkpoints load on any machine;
    # load_state_dict then copies the weights into the model's own parameters.
    model1.load_state_dict(torch.load(os.path.join(MODEL_FOLDER, "CNN1D_best.pth"), map_location='cpu'))
    model2.load_state_dict(torch.load(os.path.join(MODEL_FOLDER, "Wavenet_best.pth"), map_location='cpu'))

# Evaluate on the validation set.
# NOTE(review): val_loader is also passed to training as `valloader`; if it drives
# "best" checkpoint selection, this score is optimistic — TODO confirm.
loss_CNN, acuracy_CNN = evaluate_model(model1, val_loader, device)
loss_Wavenet, acuracy_Wavenet = evaluate_model(model2, val_loader, device)

In [None]:
if TRAIN:
    # Training-loss curves of the two trained models (validation losses are
    # available as CNNval_los / Waveval_los but are not drawn here).
    loss_series = (
        (CNNtrain_loss, "CNN Training Loss"),
        (Wavetrain_loss, "Wavenet Training Loss"),
    )
    fig, ax = plt.subplots()
    for history, label in loss_series:
        ax.plot(history, label=label)
    ax.legend()
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Loss")
    ax.set_title("Loss vs Epoch")
    plt.show()

    # Validation accuracy per epoch for the same two models.
    accuracy_series = (
        (CNNval_accuracy, "CNN Validation Accuracy"),
        (Waveval_accuracy, "Wavenet Validation Accuracy"),
    )
    fig, ax = plt.subplots()
    for history, label in accuracy_series:
        ax.plot(history, label=label)
    ax.legend()
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Accuracy")
    ax.set_title("Accuracy vs Epoch")
    plt.show()


In [None]:
# Load a single recording (one patient, one seizure); the next cells compute and
# plot the model's per-channel seizure probability over time for it.
pat_No = 66
seizure_no = 2
# Use a dedicated name: reassigning `data_folder` here would make a later re-run
# of the data-loading cell read from data/P66 instead of the dataset root.
patient_folder = os.path.join(data_folder, f"P{pat_No}")
seizure1 = load_single_seizure(patient_folder, seizure_no)
fs = seizure1.samplingRate

In [None]:
model = model2
model_name = model.__class__.__name__
# Fall back to CPU so inference also works without CUDA.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

seizure1_data = seizure1.ictal
seizure1_preictal = seizure1.preictal2

# Flatten the ictal recording to 2-D (rows, last-axis channels). preictal2 is used
# as-is; presumably it is already (samples, channels) — TODO confirm.
seizure1_data_combined = seizure1_data.reshape(-1, seizure1_data.shape[2])
seizure1_preictal_combined = seizure1_preictal

# Continuous recording: the preictal segment followed by the ictal segment.
seizure1_total_con = np.concatenate((seizure1_preictal_combined, seizure1_data_combined), axis=0)

# Cut into overlapping windows (1 s windows with 80% overlap per the original note —
# TODO confirm split_data's window length). Indexed as [window, sample, channel].
seizure1_total = split_data(seizure1_total_con, fs, overlap=0.8)
n_windows = seizure1_total.shape[0]
window_len = seizure1_total.shape[1]
n_channels = seizure1_total.shape[2]

probabilities_matrix = np.zeros((n_windows, n_channels))

# Pure inference: disable dropout/batch-norm updates and autograd graph building.
model.eval()
with torch.no_grad():
    for channel in range(n_channels):
        # Each channel is scored independently as a single-channel input (N, 1, window_len).
        input_data = seizure1_total[:, :, channel].reshape(-1, 1, window_len)
        input_data = torch.tensor(input_data, dtype=torch.float32).to(device)
        probabilities_matrix[:, channel] = output_to_probability(model, input_data, device)
    

In [None]:
n_seconds = 80
# Time between consecutive probability windows: 1 s windows with 80% overlap
# give a 0.2 s hop (assumes that window length — TODO confirm in split_data).
prob_step_s = 0.2

# Time axis (seconds) for the per-window probabilities.
time_prob = np.arange(0, probabilities_matrix.shape[0]) * prob_step_s

# Continuous raw recording (preictal followed by ictal), used by the raw-EEG plot.
seizure_total = np.concatenate((seizure1_preictal_combined,
                              seizure1_data_combined), axis=0)

print("Mean of seizure data:", np.mean(seizure1_data_combined))
print("Mean of preictal data:", np.mean(seizure1_preictal_combined))
print("Max of seizure data:", np.max(seizure1_data_combined))
print("Max of preictal data:", np.max(seizure1_preictal_combined))
print("Min of seizure data:", np.min(seizure1_data_combined))
print("Min of preictal data:", np.min(seizure1_preictal_combined))
print("Mean of seizure probability:", np.mean(probabilities_matrix))

# Build the title explicitly: in `a + b if cond else c` the conditional binds to the
# whole concatenation, which dropped the model name when n_seconds was falsy.
title_suffix = (f"Probability of Seizure{seizure_no} (First {n_seconds}s)"
                if n_seconds else "Probability of Seizure")

# savefig fails if the output directory does not exist yet.
os.makedirs("result", exist_ok=True)

# Plot probability data (channels as rows, time as columns)
plot_time_limited_heatmap(
    data=probabilities_matrix.T,
    time_axis=time_prob,
    n_seconds=n_seconds,
    preictal_boundary=50,  # NOTE(review): hard-coded; presumably preictal length in s — TODO derive from data
    title=f"{model_name} {title_suffix}",
    cmap='hot',
    save_path=f"result/Seizure{seizure_no}{model_name}Probability.png",
    flip_yaxis=True
)

In [None]:
# Raw multichannel EEG for the first n_seconds, drawn one trace per channel.
# int(): the sampling rate may be a float, and float slice bounds raise TypeError.
n_raw_samples = int(fs * n_seconds)
sub_seizure_total = seizure_total[:n_raw_samples, :]
fig = plot_eeg_style(sub_seizure_total.T, fs, spacing_factor=2, color='black', linewidth=0.5)
# savefig fails if the output directory does not exist yet.
os.makedirs("result", exist_ok=True)
fig.savefig(f"result/Seizure{seizure_no}RawDataEEG.png", dpi=300, bbox_inches='tight')
plt.show()

In [None]:
# Per-channel probability traces for the first n_seconds, drawn EEG-style.
# Probabilities are spaced 0.2 s apart (see time_prob), i.e. 5 values per second.
prob_rate_hz = 5
sub_probability_total = probabilities_matrix[: prob_rate_hz * n_seconds, :]
fig = plot_eeg_style(sub_probability_total.T, prob_rate_hz, spacing_factor=2, color='black', linewidth=0.5)
# savefig fails if the output directory does not exist yet.
os.makedirs("result", exist_ok=True)
fig.savefig(f"result/Seizure{seizure_no}{model_name}_Probability.png", dpi=300, bbox_inches='tight')
plt.show()

In [None]:
# Re-rank channels by when their smoothed seizure probability first exceeds the threshold.
threshold = 0.6
# Moving-average window in probability ticks; each tick is 0.2 s, so 50 ticks ≈ 10 s.
smooth_window = 50
n_seconds = 80

# Smooth each channel with a boxcar (moving-average) kernel. mode='same' keeps the
# original length; near the edges the average includes implicit zeros.
kernel = np.ones(smooth_window) / smooth_window
probabilities_matrix_smoothed = np.zeros_like(probabilities_matrix)
for i in range(probabilities_matrix.shape[1]):
    probabilities_matrix_smoothed[:, i] = np.convolve(probabilities_matrix[:, i], kernel, mode='same')

# First tick above threshold per channel. argmax alone returns 0 both for
# "crossed at tick 0" and "never crossed", so any() tells them apart; channels
# that never cross get len(...) so they rank as the latest.
above_threshold = probabilities_matrix_smoothed > threshold
first_threshold_indices = np.where(
    above_threshold.any(axis=0),
    above_threshold.argmax(axis=0),
    len(probabilities_matrix_smoothed),
)
# Descending order: channels that cross latest (or never) come first.
sorted_indices = np.argsort(first_threshold_indices)[::-1]

# savefig fails if the output directory does not exist yet.
os.makedirs("result", exist_ok=True)

# Plot the re-ranked, smoothed probability data
plot_time_limited_heatmap(
    data=probabilities_matrix_smoothed[:, sorted_indices].T,
    time_axis=time_prob,
    n_seconds=n_seconds,
    preictal_boundary=50,  # NOTE(review): hard-coded; presumably preictal length in s — TODO derive from data
    title=f"{model_name} " + f"Probability of Seizure{seizure_no} (First {n_seconds}s) - Reranked",
    cmap='hot',
    save_path=f"result/Seizure{seizure_no}{model_name}ProbabilityReranked.png",
    flip_yaxis=False
)