In [1]:
import pathlib
import urllib.request

import matplotlib as mpl
import matplotlib.pyplot as plt
import mne
import numpy as np
import pandas as pd
import pytorch_lightning as pl
import sklearn
import torch
import torcheeg
import wandb
import xgboost

In [2]:
# Show the absolute working directory; the relative data paths used further
# down are resolved against it.
pathlib.Path().resolve()

PosixPath('/u/home/swth/brain-age/drafts/schwarz')

In [3]:
import urllib

In [4]:
# Download the HBN_R10 phenotype CSV next to the notebook.
# `urllib.request` is a submodule that a bare `import urllib` does not load;
# it is imported explicitly in the notebook's import cell.
urllib.request.urlretrieve(
    "http://fcon_1000.projects.nitrc.org/indi/cmi_healthy_brain_network/File/_pheno/HBN_R10_Pheno.csv",
    "./HBN_R10_Pheno.csv",
)


('./HBN_R10_Pheno.csv', <http.client.HTTPMessage at 0x7f87aae50ca0>)

In [5]:
# Display the subject identifier column ("EID") of the downloaded table.
pd.read_csv("./HBN_R10_Pheno.csv")["EID"]

0      NDARJK487UCN
1      NDARFW670TY2
2      NDARHU395FP0
3      NDARHV885JFU
4      NDARCP753UEW
           ...     
847    NDARUP748AXG
848    NDARYE131DBX
849    NDARMG355BP1
850    NDARZC445DDK
851    NDAREK130FBX
Name: EID, Length: 852, dtype: object

In [6]:
class InterpolateElectrodes(object):
    """
    Callable transform that maps signals from one electrode layout onto another.

    The channel positions of both montages are stored when the object is
    created; the interpolation matrix is rebuilt from them on every call.
    """

    def __init__(self, from_montage, to_montage):
        # Stack the per-channel position vectors of each montage, in the order
        # the montage reports them, into one array per layout.
        source_positions = from_montage.get_positions()["ch_pos"]
        self.from_ch_pos = np.stack(list(source_positions.values()))

        target_positions = to_montage.get_positions()["ch_pos"]
        self.to_ch_pos = np.stack(list(target_positions.values()))

    def __call__(self, x):
        # Relies on MNE's private interpolation helper; the resulting matrix is
        # applied to `x` from the left.
        build_matrix = mne.channels.interpolation._make_interpolation_matrix
        return np.matmul(build_matrix(self.from_ch_pos, self.to_ch_pos), x)

### Interface with data

In [7]:
# Locate the data folders relative to the working directory and fail early,
# naming the missing folder, when the expected layout is not present.
data_dir = pathlib.Path("../../data/")
assert data_dir.is_dir(), f"data directory not found: {data_dir.resolve()}"

data_dir_healthy = data_dir / "healthy_controls/preprocessed"
assert data_dir_healthy.is_dir(), f"healthy-controls directory not found: {data_dir_healthy.resolve()}"

data_dir_chronic_pain = data_dir / "chronic_pain_patients"
assert data_dir_chronic_pain.is_dir(), f"chronic-pain directory not found: {data_dir_chronic_pain.resolve()}"

AssertionError: 

In [None]:
# Collect every BrainVision header (*.vhdr) below the data directory.
# rglob makes no promise about the order of its results, so the paths are
# sorted to make every later loop and split independent of the file system.
eeg_files_all = sorted(data_dir.rglob("*.vhdr"))

# A file counts as preprocessed when "preprocessed" occurs anywhere in its path.
eeg_files_raw = [f for f in eeg_files_all if "preprocessed" not in str(f)]
eeg_files_preprocessed = [f for f in eeg_files_all if "preprocessed" in str(f)]

eeg_files_healthy = sorted(data_dir_healthy.rglob("*.vhdr"))
eeg_files_chronic_pain = [f for f in eeg_files_preprocessed if "chronic_pain_patients" in str(f)]


### Analyze the subject data

#### Load clinical data

In [None]:
# Spreadsheet holding the clinical records; it must exist before it is parsed.
f_subj = data_dir.joinpath("clinical_data_updated_2020-08-04.ods")
assert f_subj.is_file()

In [None]:
def zfill_ints(x):
    """Return ``x`` as a string, left-padded with zeros to at least three characters."""
    return str(x).zfill(3)


# Sheet 0 is read as the chronic-pain table (first row skipped), sheet 1 as
# the healthy-controls table.
df_subj_chronic_pain = pd.read_excel(f_subj, sheet_name=0, engine="odf", skiprows=1)
df_subj_healthy = pd.read_excel(f_subj, sheet_name=1, engine="odf", skiprows=0)

# ignore_index gives the combined table unique row labels instead of repeating
# the labels of both sheets.
df_subj = pd.concat([df_subj_chronic_pain, df_subj_healthy], ignore_index=True)
# The age lives in "Age(years)" or "Age (years)"; merge both spellings into "Age(years)".
df_subj["Age(years)"] = df_subj["Age(years)"].fillna(df_subj["Age (years)"])

print(f"# recorded subjects:      {len(df_subj)}")
print(f"# raw eeg files:          {len(eeg_files_raw)}")
print(f"# preprocessed eeg files: {len(eeg_files_preprocessed)}")

# Normalise the subject IDs of all three tables to zero-padded strings so they
# can be compared with the IDs parsed from the file names.
for frame in (df_subj_chronic_pain, df_subj_healthy, df_subj):
    frame["Subject ID"] = frame["Subject ID"].astype(str).map(zfill_ints)

df_subj

#### Check the metadata and set channel types

In [None]:
# Open the first chronic-pain recording; later cells take the channel names
# from it.
example_raw = mne.io.read_raw_brainvision(eeg_files_chronic_pain[0])
example_raw

In [None]:
# Mark every channel of the example recording as "eeg", except "LE" and "RE",
# which are typed as "misc" and left out of the list of EEG channels.
channel_to_channel_type = dict.fromkeys(example_raw.ch_names, "eeg")
channel_to_channel_type["LE"] = "misc"
channel_to_channel_type["RE"] = "misc"
eeg_chs = [ch for ch in example_raw.ch_names if ch not in ("RE", "LE")]
example_raw.set_channel_types(channel_to_channel_type)


In [None]:
# Build MNE's "standard_1020" montage and attach it to the example recording.
montage = mne.channels.make_standard_montage("standard_1020")
example_raw = example_raw.set_montage(montage)

In [None]:
# Sanity check: print every collected header path that is not a regular file.
for f in eeg_files_preprocessed:
    if f.is_file():
        continue
    print(f)

### Extract PSD dataset

In [None]:
def get_subj_id(eeg_file: pathlib.Path):
    """Return the third-from-last underscore-separated token of the file stem.

    Raises IndexError when the stem has fewer than three tokens.
    """
    stem_tokens = eeg_file.stem.split("_")
    return stem_tokens[-3]

def get_age(df_subj, eeg_file: pathlib.Path):
    """Look up the age of the subject a recording belongs to.

    Parameters
    ----------
    df_subj : DataFrame with the columns "Subject ID" and "Age(years)".
    eeg_file : path of the recording; its subject ID is parsed with
        ``get_subj_id``.

    Returns
    -------
    The "Age(years)" value of the first row whose "Subject ID" equals the
    parsed ID.

    Raises
    ------
    IndexError
        If no row matches. This is the same exception type as before, but it
        now names the missing ID and the file.
    """
    subj_id = get_subj_id(eeg_file)
    # Single .loc lookup instead of chained indexing.
    matching_ages = df_subj.loc[df_subj["Subject ID"] == subj_id, "Age(years)"]
    if matching_ages.empty:
        raise IndexError(
            f"no clinical record with Subject ID {subj_id!r} (file: {eeg_file.name})"
        )
    return matching_ages.values[0]
    

In [None]:
# One log-PSD array and one age per preprocessed recording.
ages = []
psds = []

for f in eeg_files_preprocessed:

    age = get_age(df_subj, f)
    # Load, set the channel types, remove the 50 Hz line with a notch filter
    # and band-pass to 1-100 Hz.
    raw = (
        mne.io.read_raw_brainvision(f, verbose=False, preload=True)
        .set_channel_types(channel_to_channel_type, verbose=False)
        .notch_filter(freqs=50, notch_widths=0.5)
        .filter(l_freq=1, h_freq=100)
    )
    log_psd = np.log(raw.compute_psd().get_data())
    ages.append(age)
    psds.append(log_psd)

# Stack into arrays with the recording as leading axis.
psds = np.stack(psds)
ages = np.array(ages)

In [None]:
from sklearn.ensemble import RandomForestRegressor
from sklearn.ensemble import GradientBoostingRegressor
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import RandomizedSearchCV, RepeatedKFold, cross_val_score, KFold, train_test_split
from sklearn.pipeline import Pipeline

### Baselines with train-validation split

In [None]:
# Montage positions of the channels typed as "eeg", in recording order.
ch_name_to_ch_pos = montage.get_positions()["ch_pos"]

pos = np.stack([
    ch_name_to_ch_pos[ch_name]
    for ch_name in example_raw.ch_names
    if channel_to_channel_type[ch_name] == "eeg"
])

In [None]:
# Flatten each recording's log-PSD into one feature vector.
X, y = psds.reshape(len(psds), -1), ages
# Hold out 33 % of the recordings, then split that hold-out into validation
# (34 %) and test (66 %). Fixed seeds make the split reproducible.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.33, random_state=42)
X_val, X_test, y_val, y_test = train_test_split(
    X_test, y_test, test_size=0.66, random_state=42)
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)

rfg = RandomForestRegressor(random_state=42)
rfg.fit(X_train, y_train)
# X_train is already standardised, so it is scored as-is. Transforming it a
# second time would evaluate the forest on inputs it was never fitted on.
r_squared_rfg_train = rfg.score(X_train, y_train)
r_squared_rfg_val = rfg.score(scaler.transform(X_val), y_val)

In [None]:
# percent_feat = 0.5
# feat_mask = rfg.feature_importances_ > np.percentile(rfg.feature_importances_, 100-percent_feat)
# X_train = X_train[:, feat_mask]
# X_val = X_val[:, feat_mask]
# scaler = StandardScaler()
# X_train = scaler.fit_transform(X_train)

gbr = GradientBoostingRegressor(n_estimators=400, random_state=42)
gbr.fit(X_train, y_train)
# X_train is already standardised by the cell that fitted `scaler`, so the
# training score must not transform it again; only X_val still needs scaling.
r_squared_gbr_train = gbr.score(X_train, y_train)
r_squared_gbr_val = gbr.score(scaler.transform(X_val), y_val)

print("\nTraining performance (R²) \n")
print(f"random forest regressor:     {r_squared_rfg_train:.3}")
print(f"gradient boosting regressor: {r_squared_gbr_train:.3}")
print("\nValidation performance (R²) \n")
print(f"random forest regressor:     {r_squared_rfg_val:.3}")
print(f"gradient boosting regressor: {r_squared_gbr_val:.3}")

### K-fold cross-validation

In [None]:
# Flatten each recording's log-PSD into one feature vector and hold out 33 %
# as a test set; the seed makes the split reproducible.
X, y = psds.reshape(len(psds), -1), ages
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.33, random_state=42)

# The scaler is left unfitted and X_train unscaled: every pipeline below has
# the scaler as its first step, so it is fitted inside each CV fold on that
# fold's training part only. Pre-scaling here would leak validation-fold
# statistics and scale the data twice.
scaler = StandardScaler()

In [None]:
# 10-fold cross-validated R² of a scaler + random-forest pipeline on the
# training split, reported as mean +/- standard deviation over the folds.
rfg = RandomForestRegressor()
rfg_pipe = Pipeline([("scaler", scaler), ("rfg", rfg)])

cv = KFold(n_splits=10)

r_squared_rfg_val = cross_val_score(rfg_pipe, X_train, y_train, scoring="r2", cv=cv, n_jobs=-1)
r_squared_rfg_val_mean = r_squared_rfg_val.mean()
r_squared_rfg_val_std = r_squared_rfg_val.std()

print(f"{r_squared_rfg_val_mean:.3} +/- {r_squared_rfg_val_std:.3}")


In [None]:
# 10-fold cross-validated R² of a scaler + gradient-boosting pipeline on the
# training split, reported as mean +/- standard deviation over the folds.
gbr = GradientBoostingRegressor()
gbr_pipe = Pipeline([("scaler", scaler), ("gbr", gbr)])

cv = KFold(n_splits=10)

r_squared_gbr_val = cross_val_score(gbr_pipe, X_train, y_train, scoring="r2", cv=cv, n_jobs=-1)
r_squared_gbr_val_mean = r_squared_gbr_val.mean()
r_squared_gbr_val_std = r_squared_gbr_val.std()

print(f"{r_squared_gbr_val_mean:.3} +/- {r_squared_gbr_val_std:.3}")


In [None]:
# 10-fold cross-validated R² of a scaler + XGBoost pipeline on the training
# split. The regressor is reached through the `xgboost` module imported at the
# top of the notebook, because the bare name `XGBRegressor` is only imported
# in a later cell and is undefined here on a fresh kernel.
xgbr = xgboost.XGBRegressor(n_estimators=160, max_depth=2, learning_rate=0.1)
xgbr_pipe = Pipeline(steps=[("scaler", scaler), ("xgbr", xgbr)])

cv = KFold(n_splits=10)

r_squared_xgbr_val = cross_val_score(xgbr_pipe, X=X_train, y=y_train, scoring='r2', cv=cv, n_jobs=-1)
r_squared_xgbr_val_mean, r_squared_xgbr_val_std = r_squared_xgbr_val.mean(), r_squared_xgbr_val.std()

print(f"{r_squared_xgbr_val_mean:.3} +/- {r_squared_xgbr_val_std:.3}")

### XGBoost

In [None]:
from xgboost import XGBRegressor

In [None]:
# Candidate values for the random search; the "xg_gbr__" prefix routes each
# parameter to the pipeline step named "xg_gbr".
param_distributions = dict(
    xg_gbr__n_estimators=[20, 80, 160, 320, 640, 1280],
    xg_gbr__max_depth=[2, 3, 4, 5, 6],
    xg_gbr__learning_rate=[1e-3, 1e-2, 1e-1, 0.2],
)

xgb_regr = XGBRegressor()
xgb_pipe = Pipeline([("scaler", scaler), ("xg_gbr", xgb_regr)])

inner_cv = KFold(n_splits=10)
outer_cv = KFold(n_splits=10)

# The inner loop draws 10 parameter settings and scores each with inner_cv.
xgb_regr_cv = RandomizedSearchCV(xgb_pipe, param_distributions, n_iter=10, cv=inner_cv)
search = xgb_regr_cv.fit(X_train, y_train)

# best_score_ is the search's score for its best setting. The outer loop then
# cross-validates the whole search procedure (nested CV).
r_squared_xgbr_train = xgb_regr_cv.best_score_
r_squared_xgbr_val = cross_val_score(xgb_regr_cv, X_train, y_train, scoring="r2", cv=outer_cv)

print(r_squared_xgbr_train, r_squared_xgbr_val.mean(), r_squared_xgbr_val.std())

### Neural network baseline

In [None]:
from torcheeg.models import EEGNet

In [None]:
# Rebuild the channel-name -> channel-type mapping: everything "eeg" except
# "LE" and "RE", which are "misc".
# NOTE(review): `raw` is the recording left behind by the last iteration of an
# earlier loop; `example_raw` may be the intended source - confirm.
channel_to_channel_type = {ch_name:"eeg" for ch_name in raw.ch_names}
channel_to_channel_type.update({"LE":"misc", "RE":"misc"})

In [None]:
sfreq = 300              # sampling rate (Hz) every recording is resampled to
epoch_len = sfreq * 10   # samples per epoch, i.e. 10 s at `sfreq`

# One entry per recording: a list of (channels, epoch_len) arrays and a list
# holding the subject's age once per epoch.
epochs = []
ages = []

for f in eeg_files_preprocessed:

    age = get_age(df_subj, f)
    raw = mne.io.read_raw_brainvision(f, verbose=False, preload=True)

    raw.set_channel_types(channel_to_channel_type)
    # Drop the first and last 30 s, band-pass to 0.5-100 Hz, notch out 50 Hz.
    raw = raw.crop(raw.tmin+30, raw.tmax-30)
    raw = raw.filter(l_freq=0.5, h_freq=100)
    raw = raw.notch_filter(freqs=50, notch_widths=0.5)

    raw = raw.resample(sfreq=sfreq)
    data = raw.get_data(picks=eeg_chs)

    # Cut the signal into as many complete epochs as fit and discard the
    # incomplete remainder. Counting epochs explicitly keeps the last epoch
    # when the length is an exact multiple of epoch_len.
    n_epochs = data.shape[-1] // epoch_len
    if n_epochs == 0:
        # An empty entry would break the later concatenation over recordings.
        print(f"skipping {f.name}: shorter than one epoch after cropping")
        continue
    epochs_subj = np.split(data[:, :n_epochs * epoch_len], n_epochs, axis=1)

    epochs.append(epochs_subj)
    ages.append(n_epochs * [age])



In [None]:
# Hyper-parameters of the EEGNet baseline. "learning_rate" and "batch_size"
# are read by the optimiser and the data loaders; the rest configure the network.
hparams_eegnet = dict(
    learning_rate=1e-3,
    batch_size=64,
    dropout=0.5,
    kernel_1=150,
    kernel_2=16,
    F1=8,
    F2=8,
    depth_multiplier=2,
)

# A single output unit (num_classes=1), since the target is a scalar age.
eegnet = EEGNet(
    chunk_size=epoch_len,
    num_electrodes=63,
    num_classes=1,
    F1=hparams_eegnet["F1"],
    F2=hparams_eegnet["F2"],
    D=hparams_eegnet["depth_multiplier"],
    kernel_1=hparams_eegnet["kernel_1"],
    kernel_2=hparams_eegnet["kernel_2"],
    dropout=hparams_eegnet["dropout"],
)

In [None]:
from functools import partial

In [None]:
def _totensor(x):
    return torch.tensor(x).float()

def _channelwise_norm(x):
    return (x - x.mean(-1, keepdims=True)) / x.std(-1, keepdims=True)

def _toimshape(x):
    return x.unsqueeze(0)

def _compose(x, transforms):
    for transform in transforms:
        x = transform(x)
    return x

def _labelcenter(x, mean_age):
    """Return the label ``x`` shifted by ``mean_age`` (i.e. ``x - mean_age``)."""
    return x - mean_age

# Rebind _labelcenter with mean_age fixed to the mean of df_subj["Age(years)"],
# so that it can be called with the label alone.
# NOTE(review): the mean is taken over every row of df_subj, not only over the
# training subjects - confirm this is intended.
_labelcenter = partial(_labelcenter, mean_age=df_subj["Age(years)"].mean())
# Input pipeline for one epoch: _totensor, then _channelwise_norm, then _toimshape.
transforms = partial(_compose, transforms=[_totensor, _channelwise_norm, _toimshape])

In [None]:
class BrainAgeDataset(torch.utils.data.Dataset):
    """Map-style dataset that pairs each epoch with its age label.

    ``transforms`` is applied to the epoch and ``target_transforms`` to the
    label each time an item is fetched; both default to the identity.
    """

    def __init__(self, epochs, ages, transforms=lambda x:x, target_transforms=lambda x:x):
        self.epochs, self.ages = epochs, ages
        self.transforms, self.target_transforms = transforms, target_transforms

    def __getitem__(self, idx):
        sample = self.transforms(self.epochs[idx])
        target = self.target_transforms(self.ages[idx])
        return sample, target

    def __len__(self):
        # The number of labels defines the size of the dataset.
        return len(self.ages)

In [None]:
# Split at recording level: `epochs` and `ages` hold one entry per recording,
# so all epochs of a recording end up on the same side of the split.
# A seeded permutation of indices replaces the in-place shuffles, so `epochs`
# and `ages` are left untouched and re-running the cell gives the same split.
rng = np.random.default_rng(42)
subject_order = rng.permutation(len(epochs))

n_train = int(0.8*len(epochs))
train_subjects, val_subjects = subject_order[:n_train], subject_order[n_train:]

epochs_train = [epochs[i] for i in train_subjects]
ages_train = [ages[i] for i in train_subjects]

epochs_val = [epochs[i] for i in val_subjects]
ages_val = [ages[i] for i in val_subjects]

dataset_train = BrainAgeDataset(np.concatenate(epochs_train), np.concatenate(ages_train),
                                transforms=transforms, target_transforms=_labelcenter)
dataset_val = BrainAgeDataset(np.concatenate(epochs_val), np.concatenate(ages_val),
                              transforms=transforms, target_transforms=_labelcenter)

# The concatenated training data is ordered recording by recording. Without
# shuffling, each batch would hold consecutive epochs of the same few subjects.
dataloader_train = torch.utils.data.DataLoader(dataset_train, batch_size=hparams_eegnet["batch_size"],
                                               shuffle=True, drop_last=True)
dataloader_val = torch.utils.data.DataLoader(dataset_val, batch_size=hparams_eegnet["batch_size"], drop_last=True)

In [None]:
# Total number of epochs and labels, followed by the corresponding counts in
# the training split.
len(np.concatenate(epochs)), len(np.concatenate(ages)), \
len(np.concatenate(epochs_train)), len(np.concatenate(ages_train))

In [None]:
class BrainAgeModel(pl.LightningModule):
    """Lightning wrapper that trains ``model`` to regress age with an L1 loss.

    Parameters
    ----------
    model : network mapping a batch of inputs to one value per sample.
    hparams : dict of hyper-parameters; "learning_rate" is read by the optimiser.
    """

    def __init__(self, model, hparams):
        super().__init__()
        self.model = model
        # Stores the `hparams` argument under self.hparams.hparams; the
        # network itself is excluded from the saved hyper-parameters.
        self.save_hyperparameters(ignore=['model'])

    def forward(self, x):
        return self.model(x)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.hparams.hparams["learning_rate"])

    def _l1_loss(self, batch):
        """Mean absolute error between prediction and label for one batch."""
        x, y = batch
        # Bring prediction and label to the same 1-D shape and dtype. Passing
        # (B, 1) predictions and (B,) labels to l1_loss broadcasts to (B, B)
        # and averages over all prediction/label pairs.
        y_hat = self.forward(x).reshape(-1)
        y = y.reshape(-1).to(y_hat.dtype)
        if y_hat.shape != y.shape:
            raise ValueError(
                f"model produced {y_hat.numel()} values for {y.numel()} labels"
            )
        return torch.nn.functional.l1_loss(y_hat, y)

    def validation_step(self, batch, batch_idx):
        val_loss = self._l1_loss(batch)
        self.log("validation loss", val_loss, on_step=True, on_epoch=True, prog_bar=True)

    def training_step(self, batch, batch_idx):
        loss = self._l1_loss(batch)
        self.log("training loss", loss, on_step=True, on_epoch=True, prog_bar=True)
        return loss

In [None]:
# The first item yielded by `eegnet.modules()` is the network itself, so it
# is passed directly.
model = BrainAgeModel(eegnet, hparams_eegnet)
print(model)

In [None]:
# Number of batches the training loader yields per pass.
len(dataloader_train)

In [None]:
def compute_r2(model, val_dataloader):
    """Coefficient of determination (R²) of ``model`` over a whole data loader.

    The model is switched to eval mode and moved to the CPU (both side effects
    are kept from the previous implementation). Predictions and labels of all
    batches are collected first, so the total sum of squares is taken around
    the mean over the entire loader and not around per-batch means.

    Parameters
    ----------
    model : module whose ``forward`` maps a batch ``x`` to one value per sample.
    val_dataloader : iterable of ``(x, y)`` batches.

    Returns
    -------
    0-dim tensor holding ``1 - RSS / TSS``. The value is not finite when all
    labels are identical (TSS = 0).

    Raises
    ------
    ValueError
        If the loader yields no batches.
    """
    model.eval()
    model.cpu()
    predictions, targets = [], []
    # No gradients are needed for evaluation.
    with torch.no_grad():
        for x, y in val_dataloader:
            y_hat = model.forward(x).reshape(-1)
            predictions.append(y_hat)
            targets.append(y.reshape(-1).to(y_hat.dtype))
    if not targets:
        raise ValueError("val_dataloader yielded no batches")
    predictions = torch.cat(predictions)
    targets = torch.cat(targets)
    rss = torch.sum((targets - predictions) ** 2)
    tss = torch.sum((targets - targets.mean()) ** 2)
    return 1 - rss/tss
    

In [None]:
from pytorch_lightning.callbacks.early_stopping import EarlyStopping

In [None]:
import os
# Restrict the process to the first GPU.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

wandb.login()
logger = pl.loggers.WandbLogger(project="brain-age-ssl", name="EEGNet baseline on 1.0 batches",
                                save_dir="/data0/practical-sose23/brain-age", log_model=False)

# The monitored quantity is a loss, so improvement means a decrease. With
# mode="max", training would be stopped while the loss is still going down.
early_stop_callback = EarlyStopping(monitor="validation loss", min_delta=0.00, patience=3, verbose=False, mode="min")

# NOTE(review): overfit_batches is a Lightning debugging option; confirm that
# it should stay enabled for a baseline run.
trainer = pl.Trainer(callbacks=[early_stop_callback], overfit_batches=1.0, max_epochs=200, accelerator="gpu", logger=logger)
trainer.fit(model, dataloader_train, dataloader_val)
wandb.finish()

In [None]:
# R² of the trained model on the validation loader and on the training loader,
# in that order.
compute_r2(model, dataloader_val), \
compute_r2(model, dataloader_train)

In [None]:
# Run the NVIDIA command-line tool to show the current GPU state.
!nvidia-smi