In [1]:
%matplotlib inline
%load_ext autoreload
%autoreload 2

In [16]:
import numpy as np
import pandas as pd
import sklearn
import matplotlib.pyplot as plt
import matplotlib as mpl
import mne
import pathlib
# import pytorch_lightning as pl
import lightning.pytorch as pl
import torch
import torcheeg
import xgboost
import wandb
import autoreject
from tqdm.notebook import tqdm
from torcheeg.models import EEGNet
from functools import partial
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import os
import csv
import wandb
from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint, LearningRateMonitor

In [17]:
import sys
sys.path.append('../../src/utils')
from transforms import channelwide_norm, channelwise_norm, _clamp, _randomcrop, _compose
# from transforms import _compose, _randomcrop, totensor, \
# channelwide_norm, channelwise_norm, _clamp, toimshape, \
# _labelcenter, _labelnorm, _labelbin

os.environ["CUDA_VISIBLE_DEVICES"] = "0"
# os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:512"

In [None]:
config = {
    'seed': 42
}

In [None]:
torch.set_float32_matmul_precision('medium')

torch.manual_seed(config['seed'])
np.random.seed(config['seed'])

In [15]:
class EEGDataset(Dataset):
    def __init__(self, dataset_names, splits, transforms, sfreq=135, len_in_sec=30, oversample=False):
        self.sfreq = sfreq
        self.len_in_sec = len_in_sec
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.transforms = transforms

        assert all(split in ['train', 'val', 'test'] for split in splits)
        assert all(dataset_name in ['hbn', 'bap', 'lemon'] for dataset_name in dataset_names)

        file_paths = {}
        age_dict = {}

        for dataset_name in dataset_names:
            dataset_path = os.path.join('/data0/practical-sose23/brain-age/data', dataset_name, 'preprocessed/v2.0')
            file_paths[dataset_name] = {}
            age_dict[dataset_name] = {}
            for split in splits:
                split_path = os.path.join(dataset_path, dataset_name + '_{}'.format(split) + '.csv')
                data = np.loadtxt(split_path, dtype=str ,delimiter=',',skiprows=1)
                file_paths[dataset_name][split] = data[:, 0]
                age_dict[dataset_name][split] = data[:, 1]
                
                # file_paths.extend(data[:, 0])

                # age = np.concatenate((age, data[:, 1]))
                # file_paths.extend([line.strip() for line in lines])

        if oversample and 'train' in split and len(dataset_names) > 1: 
            minority_len = len(file_paths['bap']['train'])
            majority_len = len(file_paths['hbn']['train'])

            ratio = majority_len // minority_len
            num_to_oversample = int(minority_len * (ratio - 1))

            oversample_indices = np.random.choice(np.arange(0, minority_len, 1), size=num_to_oversample)
            oversampled_data_path = file_paths['bap']['train'][oversample_indices]
            oversampled_age = age_dict['bap']['train'][oversample_indices]

            file_paths['bap']['train'] = np.concatenate((file_paths['bap']['train'], oversampled_data_path))
            age_dict['bap']['train'] = np.concatenate((age_dict['bap']['train'], oversampled_age),)
        
        age = np.array([])
        data_paths = np.array([])
        for data_name, splits in file_paths.items():
            for split, data in splits.items(): 
                data_paths = np.concatenate((data_paths, file_paths[data_name][split]))
                age = np.concatenate((age, age_dict[data_name][split]))

        self.target = torch.tensor(np.round((np.array(age)).astype(float)).astype(int))
        self.data_paths = data_paths


    def __len__(self):
        return len(self.data_paths)

    def __getitem__(self, index):
        data_path = self.data_paths[index]
        target = self.target[index]

        with open(data_path, 'rb') as in_file:
            eeg_npy = np.load(in_file)

        # data = eeg_npy[:, :self.sfreq * self.len_in_sec].astype(np.float32)
        data = eeg_npy.astype(np.float32)
        # data_with_channel = torch.unsqueeze(torch.tensor(data), 0)
        data = torch.unsqueeze(torch.tensor(data), 0)

        return self.transforms(data), target

In [None]:
# train_dataset = EEGDataset(args.train_dataset, ['train'], transforms=composed_transforms, oversample=args.oversample)
# autoencoder_train_dataset = EEGDataset(['lemon'], ['train'], transforms=composed_transforms, oversample=False)
# autoencoder_train_dataloader = DataLoader(autoencoder_train_dataset, 
#                             batch_size=args.batch_size, 
#                             num_workers=args.num_workers, 
#                             pin_memory=True, 
#                             shuffle=True)
train_dataset = EEGDataset(['lemon'], ['train'], transforms=composed_transforms, oversample=False)
train_dataloader = DataLoader(train_dataset, 
                            batch_size=args.batch_size, 
                            num_workers=args.num_workers, 
                            pin_memory=True, 
                            shuffle=True)

val_dataset = EEGDataset(['lemon'], ['val'], transforms=composed_transforms)
validation_dataloader =  DataLoader(val_dataset, 
                                    batch_size=args.batch_size, 
                                    num_workers=args.num_workers, 
                                    pin_memory=True, 
                                    # shuffle=True
                                    )

In [13]:
# class BrainAgeDataset(Dataset):
#     def __init__(self, epochs, ages, transforms=lambda x:x, target_transforms=lambda x:x):
#         self.epochs = epochs
#         self.ages = ages
#         self.transforms = transforms
#         self.target_transforms = target_transforms

#     def __getitem__(self, idx):
#         return self.transforms(self.epochs[idx]), self.target_transforms(self.ages[idx])
    
#     def __len__(self):
#         return len(self.ages)

In [None]:
hparams_eegnet = {
    "learning_rate":1e-3,
    "batch_size":128,
    "chunk_size":int(sfreq*4),
    "dropout":0.2,
    "kernel_1": int(sfreq//2),
    "kernel_2": int(sfreq//8),
    "F1":16,
    "F2":32,
    "depth_multiplier":2
}

In [3]:
eegnet = EEGNet(chunk_size=hparams_eegnet["chunk_size"],
               num_electrodes=63,
               dropout=hparams_eegnet["dropout"],
               kernel_1=hparams_eegnet["kernel_1"],
               kernel_2=hparams_eegnet["kernel_2"],
               F1=hparams_eegnet["F1"],
               F2=hparams_eegnet["F2"],
               D=hparams_eegnet["depth_multiplier"],
               num_classes=1)

NameError: name 'sfreq' is not defined

In [None]:
# mean_age = torch.tensor(round(df_subj["Age"].mean(), 3))

# randomcrop = partial(_randomcrop, seq_len=hparams_eegnet["chunk_size"])
# clamp = partial(_clamp, dev_val=20.0)
# labelcenter = partial(_labelcenter, mean_age=round(df_subj["Age(years)"].mean(), 3))
# labelbin = partial(_labelbin, y_lower=mean_age)
# transforms = partial(_compose, transforms=[totensor, randomcrop, channelwise_norm, clamp, toimshape])
# target_transforms = partial(_compose, transforms=[labelcenter, totensor])

In [None]:
if args.standardization == "channelwise":
    norm = channelwise_norm
elif args.standardization == "channelwide":
    norm = channelwide_norm 
randomcrop = partial(_randomcrop, seq_len=args.crop_len)
clamp = partial(_clamp, dev_val=args.clamp_val)
composed_transforms = partial(_compose, transforms=[randomcrop, norm, clamp])

In [None]:
wandb.login()
logger = pl.loggers.WandbLogger(project="brain-age", name=args.experiment_name, 
                                save_dir="wandb/", log_model=True)

lr_monitor = LearningRateMonitor(logging_interval='epoch')


trainer = pl.Trainer(
                    # overfit_batches=1,
                    deterministic=True, # to ensure reproducibility 
                    devices=[0], 
                    callbacks=[lr_monitor, 
                    # checkpoint_callback, 
                    # early_stop_callback
                    ], 
                    max_epochs=args.epochs, 
                    accelerator="gpu", 
                    logger=logger,
                    precision="bf16-mixed", 
                    # fast_dev_run=True, 
                    )

trainer.fit(model=model, 
            train_dataloaders=autoencoder_train_dataloader, 
            val_dataloaders=[autoencoder_train_dataloader, validation_dataloader])

wandb.finish()

In [18]:
!nvidia-smi

Sun Jul  2 12:30:18 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 515.105.01   Driver Version: 515.105.01   CUDA Version: 11.7     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA A40          On   | 00000000:01:00.0 Off |                    0 |
|  0%   64C    P0   276W / 300W |  38310MiB / 46068MiB |     99%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   1  NVIDIA A40          On   | 00000000:25:00.0 Off |                    0 |
|  0%   74C    P0   287W / 300W |  12335MiB / 46068MiB |    100%      Default |
|       