<a href="https://colab.research.google.com/github/alexandertaoadams/AlexanderAdamsMastersThesis/blob/main/NB_EEG_Neonatal_Stoch_PerPatient.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Mount Google Drive at /content/drive; the annotation CSVs and EDF recordings
# used later in this notebook are read from paths under that mount point.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
!pip install gpjax
!pip install sktime
!pip install mne

In [67]:
# jax libraries
import numpy as np
import jax
import jax.numpy as jnp

# gpjax libraries
import gpjax as gpx

# core libraries
from flax import nnx
import optax as ox

# data manipulation and visualisation libraries
import mne
import pandas as pd
from sktime.datasets import load_from_tsfile
from matplotlib import pyplot as plt
import seaborn as sns

In [68]:
!git clone https://github.com/alexandertaoadams/AlexanderAdamsMastersThesis.git

import sys
sys.path.insert(0, '/content/AlexanderAdamsMastersThesis')
import AlexanderAdamsMastersThesis.src as src

from src.kernels import SignatureKernel
from src.inducing_variables import initial_inducing_variables
from src.utils import normalise

fatal: destination path 'AlexanderAdamsMastersThesis' already exists and is not an empty directory.


### **Data loading and preprocessing**

In [69]:
# Annotation tables, one CSV per annotator (A, B and C). Each table is indexed
# below by patient columns named "1" ... "79".
file_path_A = "/content/drive/MyDrive/DATA_EEG_Neonatal/annotations_2017_A_fixed.csv"
file_path_B = "/content/drive/MyDrive/DATA_EEG_Neonatal/annotations_2017_B.csv"
file_path_C = "/content/drive/MyDrive/DATA_EEG_Neonatal/annotations_2017_C.csv"

df_A = pd.read_csv(file_path_A)
df_B = pd.read_csv(file_path_B)
df_C = pd.read_csv(file_path_C)

# Single source of truth for the patient identifiers (as column-name strings).
all_patients = [str(i) for i in range(1, 80)]

# A patient is valid when all three annotators provide the same number of
# non-missing annotations for that patient's column.
valid_patients = []
for patient_index in all_patients:
    annotation_counts = {
        int(df[patient_index].notna().sum()) for df in (df_A, df_B, df_C)
    }
    if len(annotation_counts) == 1:
        valid_patients.append(patient_index)

# Preserve patient order (instead of going through an unordered set) so the
# reported list is identical from run to run.
valid_lookup = set(valid_patients)
invalid_patients = [p for p in all_patients if p not in valid_lookup]

print("The following patients are invalid, because the annotation lengths between annotators are not equal:", invalid_patients)

def find_concensus(patient_index):
    """Return the annotations of one patient on which all three annotators agree.

    Args:
        patient_index: Patient identifier; converted to ``str`` and used as the
            column name in the module-level tables ``df_A``, ``df_B``, ``df_C``.

    Returns:
        A single-column ``pd.DataFrame`` (column named after the patient) holding
        annotator A's non-missing values at the rows where A equals B and B
        equals C. The original row index is preserved.
    """
    column = str(patient_index)

    # Non-missing annotations of this patient, one Series per annotator.
    labels_a, labels_b, labels_c = (
        table.loc[table[column].notna(), column] for table in (df_A, df_B, df_C)
    )

    # Unanimous rows: A matches B and B matches C (hence all three match).
    unanimous = (labels_a == labels_b) & (labels_b == labels_c)

    return pd.DataFrame(labels_a[unanimous])

def get_eegs_and_labels_as_list(patient_index, length=1, sampling_rate=256):
    """Cut one patient's EEG recording into labelled, fixed-length windows.

    For every consensus annotation (see ``find_concensus``) a window of
    ``length * sampling_rate`` samples is sliced out of the recording. The
    annotation's index value is treated as a time in seconds: the window spans
    ``[value - (length - 1) / 2, value + (length + 1) / 2)`` seconds.

    Windows that would start before the first sample or end after the last one
    are skipped together with their label. Without this guard a negative start
    would wrap around and an overlong end would be truncated, yielding slices
    of a different width that cannot be stacked with the others.

    Args:
        patient_index: Patient identifier, used for both the annotation column
            and the ``eeg{patient_index}.edf`` file name.
        length: Window length in seconds.
        sampling_rate: Samples per second used to convert seconds to sample
            offsets. Defaults to 256, the value previously hard-coded here.

    Returns:
        ``(train_data, train_labels)``: a list of arrays sliced as
        ``eeg_signal[:, start:end]`` and the list of matching annotation values.
    """
    label_column = str(patient_index)
    annotations = find_concensus(label_column)

    file = f'/content/drive/MyDrive/DATA_EEG_Neonatal/eeg{patient_index}.edf'
    data = mne.io.read_raw_edf(file)
    raw_data = data.get_data()
    eeg_signal = jnp.array(raw_data)
    n_samples = eeg_signal.shape[1]

    train_data = []
    train_labels = []
    for value in annotations.index:
        eeg_start = int(sampling_rate * (value - 0.5 * (length - 1)))
        eeg_end = int(sampling_rate * (value + 0.5 * (length + 1)))

        # Drop windows that do not fit entirely inside the recording.
        if eeg_start < 0 or eeg_end > n_samples:
            continue

        train_data.append(eeg_signal[:, eeg_start:eeg_end])
        train_labels.append(annotations.loc[value, label_column])

    return train_data, train_labels

The following patients are invalid, because the annotation lengths between annotators are not equal: []


In [70]:
def bipolar_montage(data):
    """Turn per-channel signals into 16 bipolar (channel-difference) derivations.

    Args:
        data: Array indexed as ``(n_sequences, n_channels, seq_length)``. The
            largest channel number referenced is 19, so at least 19 channels
            are required.

    Returns:
        Array of shape ``(n_sequences, 16, seq_length)`` where feature ``k`` is
        ``data[:, a_k - 1, :] - data[:, b_k - 1, :]`` for the k-th pair below.
    """
    # Channel pairs are written 1-based; they are shifted to 0-based on use.
    channel_pairs = (
        (2, 6), (6, 11), (11, 16), (16, 19),
        (1, 4), (4, 9), (9, 14), (14, 18),
        (2, 7), (7, 12), (12, 17), (17, 19),
        (1, 3), (3, 8), (8, 13), (13, 18),
    )

    derivations = [
        data[:, first - 1, :] - data[:, second - 1, :]
        for first, second in channel_pairs
    ]

    # Stack the derivations directly on the feature axis:
    # (n_sequences, n_features, seq_length).
    return jnp.stack(derivations, axis=1)

In [71]:
def train_test_split(sequences, labels, key=None, n_per_class=100):
    """Draw a class-balanced training set and use everything else for testing.

    ``n_per_class`` samples with label 1 and ``n_per_class`` samples with
    label 0 are drawn without replacement for training; all remaining rows
    form the test set.

    Args:
        sequences: Array whose first axis indexes samples.
        labels: 1-D array of 0/1 labels aligned with ``sequences``.
        key: JAX PRNG key. Defaults to ``jax.random.PRNGKey(0)`` so the split
            is reproducible when no key is given.
        n_per_class: Number of training samples drawn from each class.
            Defaults to 100, the value previously hard-coded here.

    Returns:
        ``(train_data, train_labels, test_data, test_labels)``. Training rows
        are ordered positives first, then negatives; test rows follow
        ascending sample index.

    Raises:
        ValueError: If either class has fewer than ``n_per_class`` samples.
    """
    if key is None:
        key = jax.random.PRNGKey(0)

    positive_idx = jnp.where(labels == 1)[0]
    negative_idx = jnp.where(labels == 0)[0]

    # Fail early with a message that names the class that is too small,
    # rather than relying on the sampler's generic error.
    for class_name, class_idx in (("positive", positive_idx), ("negative", negative_idx)):
        if class_idx.shape[0] < n_per_class:
            raise ValueError(
                f"Cannot draw {n_per_class} {class_name} training samples: "
                f"only {class_idx.shape[0]} available."
            )

    key_pos, key_neg = jax.random.split(key)

    pos_selected = jax.random.choice(
        key_pos, positive_idx, shape=(n_per_class,), replace=False
    )
    neg_selected = jax.random.choice(
        key_neg, negative_idx, shape=(n_per_class,), replace=False
    )
    train_idx = jnp.concatenate([pos_selected, neg_selected])
    # Every index not chosen for training goes to the test set.
    test_idx = jnp.setdiff1d(jnp.arange(sequences.shape[0]), train_idx)

    train_data = sequences[train_idx]
    train_labels = labels[train_idx]
    test_data = sequences[test_idx]
    test_labels = labels[test_idx]

    return train_data, train_labels, test_data, test_labels


In [72]:
# Collect the EEG windows and their labels for every selected patient
# (currently only patient 14).
eegs_list, labels_list = [], []
for patient_id in (14,):
    eegs_next, labels_next = get_eegs_and_labels_as_list(patient_id)
    eegs_list += eegs_next
    labels_list += labels_next

# Turn the per-window lists into single arrays with a leading sample axis.
labels_arr = jnp.array(labels_list)
eegs_arr = jnp.stack(eegs_list)

Extracting EDF parameters from /content/drive/MyDrive/DATA_EEG_Neonatal/eeg14.edf...
Setting channel info structure...
Creating raw.info structure...


  data = mne.io.read_raw_edf(file)
  data = mne.io.read_raw_edf(file)


In [73]:
train_data, train_labels, test_data, test_labels = train_test_split(eegs_arr, labels_arr)

# Apply the bipolar montage to both splits before scaling.
train_montage = bipolar_montage(train_data)
test_montage = bipolar_montage(test_data)

# The mean/std returned alongside the normalised training data are reused to
# scale the test split, so no separate statistics are computed from test data.
# NOTE(review): assumes `normalise` returns (normalised_data, mean, std) — confirm in src.utils.
xtrain, train_mean, train_std = normalise(train_montage)
xtest = (test_montage - train_mean) / train_std

ytrain, ytest = train_labels, test_labels

### **Training**

In [None]:
# Initialising model
q_kernel = SignatureKernel(16, 256, 2)
q_mean_function = gpx.mean_functions.Constant()
q_prior = gpx.gps.Prior(mean_function=q_mean_function, kernel=q_kernel)
q_likelihood = gpx.likelihoods.Bernoulli(xtrain.shape[0])
q_posterior = q_likelihood * q_prior

D = gpx.dataset.Dataset(jnp.reshape(xtrain, (xtrain.shape[0], -1)), jnp.expand_dims(ytrain, axis=1))
Z = initial_inducing_variables(xtrain, ytrain, 64)

# Model
q = gpx.variational_families.VariationalGaussian(
    posterior=q_posterior,
    inducing_inputs=Z
)

In [None]:
# Training
optimised_model, history = gpx.fit(
    model=q,
    objective= lambda model, data: -gpx.objectives.elbo(model, data),
    train_data=D,
    optim=ox.adam(learning_rate=1e-3),
    trainable=gpx.parameters.Parameter,
    num_iters=500,
    batch_size=25,
    verbose=True
)

### **Model Evaluation**

In [76]:
@jax.jit
def predict_batch(model, batch):
    """Return the predictive mean of `model` for each row of `batch`.

    Args:
        model: Object exposing ``predict(x)`` and ``posterior.likelihood(dist)``.
        batch: Array whose first axis indexes the points to predict.

    Returns:
        Array with one predictive mean per row of ``batch``.
    """
    def _mean_for_point(point):
        # Add a leading axis so the point is passed to the model as a batch of one.
        latent = model.predict(point[None, :])
        predictive = model.posterior.likelihood(latent)
        return predictive.mean.squeeze()

    per_point = jax.vmap(_mean_for_point)
    return per_point(batch)

def batched_predict(xtest, model, batch_size=64):
    """Run `predict_batch` over `xtest` in consecutive chunks and join the results.

    Args:
        xtest: Array whose first axis indexes the points to predict.
        model: Model forwarded unchanged to ``predict_batch``.
        batch_size: Maximum number of rows per chunk; the final chunk may be shorter.

    Returns:
        The per-chunk outputs of ``predict_batch`` concatenated along axis 0.
    """
    total = xtest.shape[0]
    # Ceiling division: number of chunks needed to cover all rows.
    n_chunks = (total + batch_size - 1) // batch_size

    chunk_means = [
        predict_batch(model, xtest[k * batch_size:(k + 1) * batch_size])
        for k in range(n_chunks)
    ]

    return jnp.concatenate(chunk_means, axis=0)

In [77]:
import sklearn.metrics as skm
def display_results(pred_labels, true_labels):
    """Summarise binary classification results as a one-column table.

    Args:
        pred_labels: Predicted labels (0 or 1), array-like.
        true_labels: Ground-truth labels (0 or 1), array-like, same length.

    Returns:
        ``pd.DataFrame`` indexed by "Metric" with a single "Value" column
        holding the test size, class counts, confusion-matrix counts, the
        Matthews correlation coefficient and the F1 score (both rounded to
        three decimals).
    """
    y_pred = np.array(pred_labels)
    y_true = np.array(true_labels)

    # Boolean masks reused for the class counts and the confusion-matrix cells.
    truly_pos, truly_neg = (y_true == 1), (y_true == 0)
    called_pos, called_neg = (y_pred == 1), (y_pred == 0)

    rows = [
        ("Test Size", int(len(y_true))),
        ("Negative Samples", int(np.count_nonzero(truly_neg))),
        ("Positive Samples", int(np.count_nonzero(truly_pos))),
        ("True Positives (TP)", int(np.count_nonzero(truly_pos & called_pos))),
        ("True Negatives (TN)", int(np.count_nonzero(truly_neg & called_neg))),
        ("False Positives (FP)", int(np.count_nonzero(truly_neg & called_pos))),
        ("False Negatives (FN)", int(np.count_nonzero(truly_pos & called_neg))),
        ("MCC", round(float(skm.matthews_corrcoef(y_true, y_pred)), 3)),
        ("F1 score", round(float(skm.f1_score(y_true, y_pred)), 3)),
    ]

    metric_names, metric_values = zip(*rows)
    results_table = pd.DataFrame(
        {"Metric": list(metric_names), "Value": list(metric_values)}
    ).set_index("Metric")
    return results_table

In [78]:
# Get predicted means
predicted_mean = batched_predict(jnp.reshape(xtest, (xtest.shape[0], -1)), model=optimised_model)

# Get predicted class labels (0 or 1)
predicted_labels = jnp.round(predicted_mean)

In [79]:
# Display results
results = display_results(predicted_labels, ytest)
results

Unnamed: 0_level_0,Value
Metric,Unnamed: 1_level_1
Test Size,3105.0
Negative Samples,1121.0
Positive Samples,1984.0
True Positives (TP),1309.0
True Negatives (TN),844.0
False Positives (FP),277.0
False Negatives (FN),675.0
MCC,0.397
F1 score,0.733
