In [None]:
import numpy as np
import pandas as pd
import sklearn
import matplotlib.pyplot as plt
import matplotlib as mpl
import mne
import pathlib
import pytorch_lightning as pl
import torch
import torcheeg
import xgboost
import wandb

In [None]:
# Reload imported modules automatically before each cell execution, so edits to the
# project-local helper modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

import sys
# Put ../../code on the import path; the preprocessing / transforms / data modules
# imported below are presumably located there — verify against the repository layout.
sys.path.append('../../code')
from preprocessing import preprocess_data
from transforms import _compose, _randomcrop, totensor, \
channelwide_norm, channelwise_norm, _clamp, toimshape, \
_labelcenter, _labelnorm, _labelbin
from functools import partial
from data import BrainAgeDataset

In [None]:
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.tuner import Tuner

### Interface with data

In [None]:
# Dataset root and the two cohort sub-directories. Each path is asserted to be
# an existing directory right after it is built, so a wrong working directory
# fails here rather than further down.
data_dir = pathlib.Path("../../data/bap/")
assert data_dir.is_dir()

data_dir_healthy = data_dir.joinpath("healthy_controls/preprocessed")
assert data_dir_healthy.is_dir()

data_dir_chronic_pain = data_dir.joinpath("chronic_pain_patients")
assert data_dir_chronic_pain.is_dir()

In [None]:
# Collect every BrainVision header file below the dataset root.
# The listing is sorted: glob results come back in filesystem order, and the
# seeded train/validation shuffle later in this notebook is only reproducible
# across machines when the input order is deterministic.
eeg_files_all = sorted(data_dir.rglob("*.vhdr"))

# Split by whether the path contains the substring "preprocessed".
eeg_files_raw = [f for f in eeg_files_all if "preprocessed" not in str(f)]
eeg_files_preprocessed = [f for f in eeg_files_all if "preprocessed" in str(f)]

# Per-cohort views: healthy controls are globbed from their own directory,
# chronic-pain recordings are filtered from the preprocessed list by path.
eeg_files_healthy = sorted(data_dir_healthy.rglob("*.vhdr"))
eeg_files_chronic_pain = [f for f in eeg_files_preprocessed if "chronic_pain_patients" in str(f)]


### Analyze the subject data

In [None]:
def get_subj_id(eeg_file: pathlib.Path):
    """Return the third-from-last "_"-separated token of the file stem.

    This token is what the notebook uses as the subject ID. An ``IndexError``
    is raised when the stem contains fewer than three tokens.
    """
    stem_tokens = eeg_file.stem.split("_")
    return stem_tokens[-3]

def get_age(df_subj, eeg_file: pathlib.Path):
    """Look up the age of the subject a recording belongs to.

    Parameters
    ----------
    df_subj : pandas.DataFrame
        Clinical table; must contain the columns "Subject ID" and "Age(years)".
    eeg_file : pathlib.Path
        Recording file; its subject ID is derived with ``get_subj_id``.

    Returns
    -------
    The "Age(years)" value of the first row whose "Subject ID" equals the
    subject ID of ``eeg_file``.

    Raises
    ------
    IndexError
        If no row of ``df_subj`` matches the subject ID. The exception type is
        the one the previous ``.values[0]`` lookup produced on an empty
        selection, but the message now names the subject and the file.
    """
    subj_id = get_subj_id(eeg_file)
    matching_ages = df_subj.loc[df_subj["Subject ID"] == subj_id, "Age(years)"]
    if matching_ages.empty:
        raise IndexError(
            f"no row with Subject ID {subj_id!r} in the clinical table (file: {eeg_file})"
        )
    return matching_ages.values[0]

#### Load clinical data

In [None]:
# Spreadsheet with the clinical metadata; stop right away if it is not on disk.
f_subj = data_dir.joinpath("clinical_data_updated_2020-08-04.ods")
assert f_subj.is_file()

In [None]:
# Read the two sheets of the clinical spreadsheet (sheet 0 skips its first row).
df_subj_chronic_pain = pd.read_excel(f_subj, sheet_name=0, engine="odf", skiprows=1)
df_subj_healthy = pd.read_excel(f_subj, sheet_name=1, engine="odf", skiprows=0)

df_subj = pd.concat([df_subj_chronic_pain, df_subj_healthy])
# The combined table carries the age under two spellings, "Age(years)" and
# "Age (years)"; fill the gaps of the former from the latter.
df_subj["Age(years)"] = df_subj["Age(years)"].fillna(df_subj["Age (years)"])

print(f"# recorded subjects:      {len(df_subj)}")
print(f"# raw eeg files:          {len(eeg_files_raw)}")
print(f"# preprocessed eeg files: {len(eeg_files_preprocessed)}")


def zfill_ints(x):
    """Return ``x`` as a string, left-padded with zeros to at least 3 characters.

    The previous lambda guarded the padding with ``if type(x)``, which is
    always true, so every value was padded anyway; this keeps that behaviour
    without the dead branch.
    """
    return str(x).zfill(3)


# Normalise the subject IDs of all three tables to zero-padded strings so they
# can be compared with the ID token taken from the EEG file names.
df_subj_chronic_pain["Subject ID"] = df_subj_chronic_pain["Subject ID"].astype(str).map(zfill_ints)
df_subj_healthy["Subject ID"] = df_subj_healthy["Subject ID"].astype(str).map(zfill_ints)
df_subj["Subject ID"] = df_subj["Subject ID"].astype(str).map(zfill_ints)

df_subj

#### Check the metadata and set channel types

In [None]:
# Open the first chronic-pain recording as a reference object; its channel names
# are used below to build the channel-type mapping.
example_raw = mne.io.read_raw_brainvision(eeg_files_chronic_pain[0])
example_raw

In [None]:
# Every channel of the example recording is typed "eeg"; "LE" and "RE" are then
# set to "misc" (and added to the mapping if they were not among the names).
channel_to_channel_type = {
    **dict.fromkeys(example_raw.ch_names, "eeg"),
    "LE": "misc",
    "RE": "misc",
}
# Names of all channels other than "RE" / "LE", in recording order.
eeg_chs = [ch for ch in example_raw.ch_names if ch not in ("RE", "LE")]
example_raw.set_channel_types(channel_to_channel_type)


In [None]:
# Build MNE's built-in "standard_1020" montage and attach it to the example recording.
montage = mne.channels.make_standard_montage("standard_1020")
example_raw = example_raw.set_montage(montage)

In [None]:
# Channel-type mapping applied to every recording in the epoching loop:
# "eeg" for all channel names of the example recording, "misc" for "LE" and "RE".
channel_to_channel_type = dict.fromkeys(example_raw.ch_names, "eeg")
channel_to_channel_type["LE"] = "misc"
channel_to_channel_type["RE"] = "misc"

In [None]:
# Print every listed header path that is not a regular file on disk.
for missing_file in filter(lambda path: not path.is_file(), eeg_files_preprocessed):
    print(missing_file)

### Neural network baseline

In [None]:
from torcheeg.models import EEGNet

In [None]:
sfreq = 128               # target sampling rate after resampling [Hz]
epoch_len = sfreq * 20    # samples per epoch (20 s at the target rate)


def _split_into_epochs(data, epoch_len):
    """Cut ``data`` into consecutive, non-overlapping, full-length epochs.

    ``data`` is sliced along its last axis. The result is a list of
    ``data.shape[-1] // epoch_len`` views, each ``epoch_len`` samples long;
    any shorter remainder at the end is discarded. Unlike the former
    ``np.split(...)[1:-1]`` approach, a final epoch is kept when the number of
    samples is an exact multiple of ``epoch_len``.
    """
    n_epochs = data.shape[-1] // epoch_len
    return [data[..., i * epoch_len:(i + 1) * epoch_len] for i in range(n_epochs)]


# One list entry per recording: `epochs[i]` holds that recording's epochs and
# `ages[i]` the matching age label repeated once per epoch.
epochs = []
ages = []

for f in eeg_files_preprocessed:

    age = get_age(df_subj, f)
    raw = mne.io.read_raw_brainvision(f, verbose=False, preload=True)

    raw.set_channel_types(channel_to_channel_type)
    # Drop the first and last 30 s of the recording.
    raw = raw.crop(raw.tmin+30, raw.tmax-30)
    raw = raw.notch_filter(freqs=50, notch_widths=0.5)
    # Band-pass from 0.5 Hz to sfreq//3 (= 42 Hz for sfreq = 128) before resampling.
    raw = raw.filter(l_freq=0.5, h_freq=sfreq//(2+1))

    raw = raw.resample(sfreq=sfreq)
    data = raw.get_data(picks=eeg_chs)

    epochs_subj = _split_into_epochs(data, epoch_len)
    if not epochs_subj:
        # Recording too short for a single epoch: an empty entry would break
        # the np.concatenate calls that build the datasets, so skip it.
        continue

    epochs.append(epochs_subj)
    ages.append(len(epochs_subj)*[age])



In [None]:
# Hyper-parameters of the EEGNet baseline. The chunk length and both kernel
# sizes are expressed relative to the sampling rate `sfreq`.
hparams_eegnet = dict(
    learning_rate=1e-4,
    batch_size=64,
    chunk_size=int(sfreq*1),
    dropout=0.25,
    kernel_1=sfreq//2,
    kernel_2=sfreq//8,
    F1=8,
    F2=16,
    depth_multiplier=2,
)

# Instantiate the network from the hyper-parameters above; the electrode count
# and the single output unit are fixed here.
eegnet = EEGNet(
    num_electrodes=63,
    num_classes=1,
    chunk_size=hparams_eegnet["chunk_size"],
    dropout=hparams_eegnet["dropout"],
    kernel_1=hparams_eegnet["kernel_1"],
    kernel_2=hparams_eegnet["kernel_2"],
    F1=hparams_eegnet["F1"],
    F2=hparams_eegnet["F2"],
    D=hparams_eegnet["depth_multiplier"],
)

In [None]:
# Mean of the "Age(years)" column over the whole clinical table, rounded to 3
# decimals. It is handed to _labelbin as `y_lower` — presumably the threshold
# for binarising the age label; confirm in transforms.py.
mean_age = torch.tensor(round(df_subj["Age(years)"].mean(), 3))

# Pre-bind the configurable arguments of the project-local transform helpers.
randomcrop = partial(_randomcrop, seq_len=hparams_eegnet["chunk_size"])
# labelcenter = partial(_labelcenter, mean_age=round(df_subj["Age(years)"].mean(), 3))
labelbin = partial(_labelbin, y_lower=mean_age)
# Input pipeline and label pipeline handed to BrainAgeDataset; the listed
# callables are presumably applied in order by _compose — TODO confirm.
transforms = partial(_compose, transforms=[totensor, randomcrop, channelwise_norm, toimshape])
target_transforms = partial(_compose, transforms=[labelbin, totensor])

In [None]:
# Shuffle the two per-recording lists in place. The global NumPy seed is reset
# to the same value before each shuffle so that both lists (which have the same
# length) receive the same permutation and epochs stay aligned with their ages.
# NOTE: because the shuffle is in place, re-running this cell permutes the
# already shuffled lists again.
np.random.seed(42)
np.random.shuffle(epochs)
np.random.seed(42)
np.random.shuffle(ages)

# 80 / 20 split over list entries, i.e. over recordings rather than single epochs.
n_train = int(0.8*len(epochs))
epochs_train, epochs_val = epochs[:n_train], epochs[n_train:]
ages_train, ages_val = ages[:n_train], ages[n_train:]

# Flatten the per-recording lists into one array each and wrap them in datasets.
dataset_train = BrainAgeDataset(
    np.concatenate(epochs_train),
    np.concatenate(ages_train),
    transforms=transforms,
    target_transforms=target_transforms,
)
dataset_val = BrainAgeDataset(
    np.concatenate(epochs_val),
    np.concatenate(ages_val),
    transforms=transforms,
    target_transforms=target_transforms,
)

# Only the training loader shuffles; both drop the last incomplete batch.
dataloader_train = torch.utils.data.DataLoader(
    dataset_train,
    batch_size=hparams_eegnet["batch_size"],
    shuffle=True,
    drop_last=True,
)
dataloader_val = torch.utils.data.DataLoader(
    dataset_val,
    batch_size=hparams_eegnet["batch_size"],
    drop_last=True,
)

print(len(dataset_train), len(dataset_val))

In [None]:
def percent_label_sum(dataset):
    """Return 100 * (sum of all labels in ``dataset``) / len(dataset).

    ``dataset`` is iterated as ``(x, y)`` pairs and ``y.item()`` is summed, so
    for 0/1 labels this is the percentage of samples labelled 1. (The former
    accumulator was called ``n_class_0`` although it summed the labels.)
    """
    label_sum = sum(y.item() for _, y in dataset)
    return 100 * label_sum / len(dataset)


train_percentage = percent_label_sum(dataset_train)

print("Subjects above the age threshold")
print(f"train set: {train_percentage:.3}")

print(f"validation set: {percent_label_sum(dataset_val):.3}")

In [None]:
# Take the first (input, label) pair of the training dataset and visualise it:
# an image of eeg[0] with the same data over-plotted as one offset line per row.
eeg, age = next(iter(dataloader_train.dataset))

plt.figure(figsize=(12,5))
print(eeg.mean(), eeg.std())
plt.title(f"age label: {str(age.item())}")
plt.imshow(eeg[0])

# Shift each trace by its row index so the lines do not overlap.
row_offsets = torch.arange(eeg.shape[1]).unsqueeze(0)
plt.plot(eeg[0].T + row_offsets, "r", alpha=0.25)
plt.xlabel(f"time [@{sfreq}Hz]")

plt.show()



In [None]:
# Sanity check: number of epochs and of age labels, overall and in the training split.
(
    len(np.concatenate(epochs)),
    len(np.concatenate(ages)),
    len(np.concatenate(epochs_train)),
    len(np.concatenate(ages_train)),
)

In [None]:
class BrainAgeModel(pl.LightningModule):
    """LightningModule that trains a wrapped network with a pluggable objective.

    Parameters
    ----------
    model : torch.nn.Module
        Network called in ``forward``.
    hparams : dict
        Merged into ``self.hparams``; must contain "learning_rate".
    loss_func : callable
        Called as ``loss_func(y_hat.squeeze(), y.squeeze())``.
    metric : dict, optional
        ``{"name": str, "func": callable}``; ``func`` is called with the same
        squeezed tensors as ``loss_func`` and logged during training.
    negate_loss : bool, default True
        When True (the historical behaviour of this class) the value returned
        by ``loss_func`` is negated before it is logged and minimised, i.e.
        ``loss_func`` is treated as a score to maximise.
        NOTE(review): for a genuine loss such as cross-entropy or MSE this
        makes the optimiser *increase* the loss. Pass ``negate_loss=False`` in
        that case and monitor "validation loss" with ``mode="min"``.
    """

    def __init__(self, model, hparams, loss_func, metric=None, negate_loss=True):
        super().__init__()
        self.model = model
        self.hparams.update(hparams)
        self.loss_func = loss_func
        self.metric = metric
        self.negate_loss = negate_loss

    def forward(self, x):
        """Run the wrapped network on ``x``."""
        return self.model(x)

    def configure_optimizers(self):
        """Adam over all parameters with the configured learning rate."""
        return torch.optim.Adam(self.parameters(), lr=self.hparams["learning_rate"])

    def _signed_loss(self, y_hat, y):
        """Evaluate ``loss_func`` on squeezed tensors and apply the sign convention."""
        value = self.loss_func(y_hat.squeeze(), y.squeeze())
        return -value if self.negate_loss else value

    def validation_step(self, batch, batch_idx):
        """Log the objective for one validation batch under "validation loss"."""
        x, y = batch
        y_hat = self.forward(x)
        val_loss = self._signed_loss(y_hat, y)
        self.log("validation loss", val_loss, on_step=True, on_epoch=True, prog_bar=True)

    def training_step(self, batch, batch_idx):
        """Compute, log and return the objective (plus optional metric) for one batch."""
        x, y = batch
        y_hat = self.forward(x)
        loss = self._signed_loss(y_hat, y)
        self.log("training loss", loss, on_step=True, on_epoch=True, prog_bar=True)
        if self.metric:
            # Squeeze exactly as for the loss: comparing a (B, 1) prediction
            # with a (B,) target would broadcast to (B, B) inside an
            # element-wise metric and yield a meaningless score.
            metric_val = self.metric["func"](y_hat.squeeze(), y.squeeze())
            self.log(self.metric["name"], metric_val, on_step=True, on_epoch=True, prog_bar=True)

        return loss

In [None]:
def score_accuracy(y_hat, y):
    """Fraction of entries of ``y`` equal to the arg-max of ``y_hat`` over its last axis."""
    predicted = torch.argmax(y_hat, dim=-1)
    is_correct = y == predicted
    return is_correct.float().mean()

def score_binary_accuracy(y_hat, y):
    """Fraction of targets that equal the rounded prediction.

    Parameters
    ----------
    y_hat : torch.Tensor
        Predictions; they are rounded with ``torch.round`` (half-to-even), so
        they are assumed to be probabilities in [0, 1]. Raw logits would need a
        sigmoid first — TODO confirm what the model emits.
    y : torch.Tensor
        Targets with the same number of elements as ``y_hat``.

    Returns
    -------
    torch.Tensor
        0-dim float tensor with the mean agreement.

    Notes
    -----
    When the shapes differ but the element counts match (e.g. ``(B, 1)``
    predictions against ``(B,)`` targets), ``y_hat`` is reshaped to ``y``'s
    shape. Without this, ``==`` broadcasts to a ``(B, B)`` matrix and the mean
    is taken over all prediction/target pairs instead of matching ones.
    """
    if y_hat.shape != y.shape and y_hat.numel() == y.numel():
        y_hat = y_hat.reshape(y.shape)
    return (y == torch.round(y_hat)).float().mean()

def compute_r2(model, val_dataloader):
    """Coefficient of determination of ``model`` over all batches of a loader.

    The model is switched to eval mode and moved to the CPU (and left that
    way). Predictions and targets of every batch are flattened and collected,
    then ``R^2 = 1 - RSS / TSS`` is computed once, with the total sum of
    squares taken around the *global* target mean. The earlier implementation
    added up per-batch MSE and per-batch variance, which is not R^2 whenever
    batch means differ, and it tracked gradients while doing so.

    Parameters
    ----------
    model : object with ``eval()``, ``cpu()`` and ``forward(x)``
    val_dataloader : iterable of ``(x, y)`` batches

    Returns
    -------
    torch.Tensor
        0-dim float64 tensor; ``nan``/``-inf`` if all targets are identical.

    Raises
    ------
    ValueError
        If the loader yields no batches.
    """
    model.eval(), model.cpu()
    targets, predictions = [], []
    with torch.no_grad():  # evaluation only: do not build an autograd graph
        for x, y in val_dataloader:
            y_hat = model.forward(x)
            targets.append(y.reshape(-1).double())
            predictions.append(y_hat.reshape(-1).double())
    if not targets:
        raise ValueError("compute_r2 received a dataloader without batches")
    targets = torch.cat(targets)
    predictions = torch.cat(predictions)
    rss = ((targets - predictions) ** 2).sum()
    tss = ((targets - targets.mean()) ** 2).sum()
    return 1 - rss/tss

In [None]:
# Wrap the EEGNet instance in the Lightning module. (`eegnet` is passed directly:
# the first item yielded by `eegnet.modules()` is the module itself.)
# NOTE(review): BrainAgeModel negates the value of `loss_func` before minimising
# it, so with cross-entropy the optimiser is pushed to increase the
# cross-entropy — confirm this is intended.
binary_accuracy_metric = {"name":"binary accuracy", "func":score_binary_accuracy}
model = BrainAgeModel(
    model=eegnet,
    hparams=hparams_eegnet,
    loss_func=torch.nn.functional.cross_entropy,
    metric=binary_accuracy_metric,
)
print(model)

# Smoke test: push one training batch through the model and inspect the output range.
example_inputs = next(iter(dataloader_train))[0]
example_out = model(example_inputs)
(example_out.shape, example_out.min(), example_out.max())

In [None]:
import os
# Expose only GPU 0 to this process (takes effect as long as CUDA has not been
# initialised yet in this kernel).
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# Number of batches to overfit on; used for both the run name and the Trainer
# so the two cannot drift apart.
overfit_batches = 1
wandb.login()
logger = pl.loggers.WandbLogger(project="brain-age-ssl", name=f"EEGNet baseline on {overfit_batches} batches",
                                save_dir="/data0/practical-sose23/brain-age", log_model=False)

# "validation loss" is logged by BrainAgeModel; mode="max" treats larger values as better.
early_stop_callback = EarlyStopping(monitor="validation loss", min_delta=0.00, patience=500, verbose=False, mode="max")

trainer = pl.Trainer(callbacks=[early_stop_callback], overfit_batches=overfit_batches, max_epochs=500, accelerator="gpu", logger=logger)
# Tuner(trainer).lr_find(model, dataloader_train)

try:
    trainer.fit(model, dataloader_train, dataloader_val)
finally:
    # Close the W&B run even when training is interrupted or raises.
    wandb.finish()

In [None]:
# R^2 of the trained model, first on the validation loader, then on the training loader.
(
    compute_r2(model, dataloader_val),
    compute_r2(model, dataloader_train),
)

In [None]:
# Shell escape: run the NVIDIA CLI tool to show the current GPU status.
!nvidia-smi