# Import Packages

In [None]:
# Cap the thread pools of the numerical backends so the notebook does not
# grab every core when it is run outside a job scheduler (e.g. slurm) that
# would otherwise limit resources. These variables must be set before
# numpy / torch are imported, hence their position at the top of the cell.
import os
os.environ.update(
    dict.fromkeys(
        (
            "OMP_NUM_THREADS",
            "OPENBLAS_NUM_THREADS",
            "MKL_NUM_THREADS",
            "VECLIB_MAXIMUM_THREADS",
            "NUMEXPR_NUM_THREADS",
        ),
        "12",
    )
)
# cuBLAS workspace setting required for deterministic cuBLAS kernels, which
# this notebook enables via torch.use_deterministic_algorithms(True).
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
import copy
import random
import numpy as np
import pandas as pd

# IMPORT TORCH
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

# IMPORT SELFEEG 
import selfeeg
import selfeeg.models as zoo
import selfeeg.dataloading as dl
import selfeeg.augmentation as aug
from selfeeg.ssl import fine_tune as train_model

# IMPORT REPOSITORY FUNCTIONS
from AllFnc import split
from AllFnc.models import TransformEEG, TransformeegEncoder
from AllFnc.pretraining import pretrain_model, VICReg, vicreg_loss
from AllFnc.training import loadEEG, load_eeg_pretrain, set_augmenter

import warnings

# Hide the "numpy.core.numeric is deprecated" DeprecationWarning so it does
# not clutter the notebook output.
warnings.filterwarnings(
    action="ignore",
    category=DeprecationWarning,
    message="numpy.core.numeric is deprecated",
)

In [None]:
def _reset_seed(seed):
    if seed is not None:
        torch.manual_seed(seed)
        np.random.seed(seed)
        random.seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed(seed)
            torch.cuda.manual_seed_all(seed)

# Set Parameters

In [None]:
# --- data location -----------------------------------------------------------
dataPath       = '/data/datasets/eegpickle/'   # root folder of the pickled EEG files
pipelineToEval = 'ica'                         # preprocessing pipeline sub-folder
modelToEval    = 'transformeeg'

# --- preprocessing / loading options ------------------------------------------
chans_reduced  = False   # reduce the montage to 19 channels
downsample     = True    # 125 Hz instead of 250 Hz
z_score        = True
rem_interp     = True    # keep only the original (non-interpolated) channels
rem_noise      = False
window         = 16      # window length in seconds
overlap        = 0.25

# --- augmentation ---------------------------------------------------------------
# List handed to set_augmenter; per the author's note masking is applied only
# once even though it is listed twice.
augment_list   = ['masking', 'masking']

# --- optimisation / runtime -----------------------------------------------------
batchsize      = 1024
lr             = 2.5e-4
weight_decay   = None    # None means no weight decay
# NOTE(review): `workers` is not currently passed to the sampler/DataLoader.
workers        = 0
verbose        = True
device         = "cuda:0"  # None -> CUDA if available, otherwise CPU
seed           = 42

In [None]:
# Configure PyTorch for reproducible runs. Deterministic cuBLAS kernels also
# need the CUBLAS_WORKSPACE_CONFIG environment variable, set in the first cell.
torch.use_deterministic_algorithms(True, warn_only=False)
torch.backends.cudnn.deterministic = True
# Make cuDNN's algorithm selection deterministic too (this is the default,
# but it is stated explicitly here because reproducibility depends on it).
torch.backends.cudnn.benchmark = False

# Resolve the device to train on. None selects CUDA when available. An
# explicit CUDA request on a machine without CUDA falls back to the CPU with
# a warning, instead of failing later when the model is moved to the device.
if device is None:
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
else:
    device = torch.device(device)
    if device.type == 'cuda' and not torch.cuda.is_available():
        warnings.warn(
            f"Requested device '{device}' but CUDA is not available; "
            "falling back to CPU."
        )
        device = torch.device('cpu')

In [None]:
# Normalise the data root so it ends with a path separator, then build the
# folder of the selected preprocessing pipeline (also separator-terminated).
if dataPath[-1] != os.sep:
    dataPath = dataPath + os.sep
eegpath = dataPath + pipelineToEval
if pipelineToEval[-1] != os.sep:
    eegpath += os.sep

# Number of EEG channels: 61 when interpolated channels are kept, otherwise
# 32 (or 19 with the reduced montage).
Chan = (19 if chans_reduced else 32) if rem_interp else 61

# Sampling rate after optional downsampling, and samples per window.
freq = 250
if downsample:
    freq = 125
Samples = int(freq * window)

# Data loading

## Load and split data

In [None]:
# Keyword arguments forwarded to the EEG loading function for every file.
load_eeg_args = {
    'downsample':         downsample,
    'use_only_original':  rem_interp,
    'reduce_to_nineteen': chans_reduced,
    'apply_zscore':       z_score,
    'detect_noise':       rem_noise,
    'winlen':             window,
    'overlap':            overlap
}

# File patterns of the datasets used for pretraining.
# See BIDSAlign info table at
# https://github.com/MedMaxLab/BIDSAlign/blob/main/DATASET_INFO.tsv
glob_input = [
    '2_*.pickle',  # ds004148 - Eyes Open Closed    - 60 Subj
    '5_*.pickle',  # ds003490 - PD vs CTL           - 50 Subj
    '8_*.pickle',  # ds002778 - PD vs CTL           - 31 Subj
    '6_*.pickle',  # tdbrain  - Multiple conditions - 1274 Subj
    '10_*.pickle', # ds004504 - AD vs FTD vs CTL    - 88 Subj
    '19_*.pickle', # ds004584 - CTL vs PD           - 149 Subj
    '21_*.pickle', # CAUEEG   - Multiple neurodeg   - 1379 Subj
    '26_*.pickle', # BrainLat - Multiple neurodeg   - 125 Subj
]

# The partition table computed by dl.get_eeg_partition_number is cached in a
# CSV file. Reuse the cache when it can be read; otherwise rebuild it from the
# raw files (which also re-saves the CSV).
# NOTE(review): the cache is looked up by file name only and is not tied to
# the loading parameters above. Delete the CSV after changing the pipeline,
# window, overlap or channel settings.
partition_csv = 'pretrain_dataset_length.csv'
try:
    EEGlen = pd.read_csv(partition_csv, index_col=0)
except (OSError, ValueError) as err:
    # OSError: the cache file is missing or unreadable. ValueError covers
    # pandas' ParserError / EmptyDataError and decoding failures, i.e. a
    # corrupted cache. Other exceptions indicate real bugs and propagate.
    if verbose:
        print(f"Could not read {partition_csv} ({err!r}); rebuilding it")
    EEGlen = dl.get_eeg_partition_number(
        eegpath, freq, window, overlap,
        file_format             = glob_input,
        load_function           = load_eeg_pretrain,
        optional_load_fun_args  = load_eeg_args,
        includePartial          = True,
        verbose                 = verbose,
        save                    = True,
        save_path               = partition_csv
    )

In [None]:
def dataset_id_ex(path):
    """Return the dataset ID: the first '_'-separated field of the file name."""
    return int(path.rpartition(os.sep)[2].split('_')[0])


def subject_id_ex(path):
    """Return the subject ID: the second '_'-separated field of the file name."""
    return int(path.rpartition(os.sep)[2].split('_')[1])


# Every dataset goes to training except dataset 19 (ds004584), which is held
# out entirely for validation. There is no test split and no random
# validation ratio. An earlier variant held out datasets 5 and 8 instead.
EEGsplit = dl.get_eeg_split_table(
    partition_table=EEGlen,
    exclude_data_id=None,
    val_data_id={19: None},
    val_ratio=0.0,
    test_ratio=0.0,
    split_tolerance=0.0001,
    dataset_id_extractor=dataset_id_ex,
    subject_id_extractor=subject_id_ex,
    perseverance=10000,
    seed=seed,
)

## Build dataloader

In [None]:
def _make_pretrain_set(split_name, message):
    """Build an unsupervised EEGDataset for *split_name* and preload it."""
    dataset = dl.EEGDataset(
        EEGlen, EEGsplit, [freq, window, overlap], split_name,
        supervised             = False,
        label_on_load          = False,
        load_function          = load_eeg_pretrain,
        optional_load_fun_args = load_eeg_args,
    )
    if verbose:
        print(message)
    dataset.preload_dataset()
    return dataset


# Training set first, then validation set, each preloaded into memory.
trainset = _make_pretrain_set('train', "Loading training set")
valset = _make_pretrain_set('validation', "Loading validation set")

In [None]:
# The training loader draws indices from selfEEG's EEGSampler. The validation
# loader iterates in order.
# NOTE(review): the `workers` parameter is not used here. The sampler is
# built with Workers=1 and the DataLoaders keep their default num_workers.
trainsampler = dl.EEGSampler(
    data_source=trainset,
    BatchSize=batchsize,
    Workers=1,
)
trainloader = DataLoader(
    dataset=trainset,
    batch_size=batchsize,
    sampler=trainsampler,
)
valloader = DataLoader(
    dataset=valset,
    batch_size=batchsize,
    shuffle=False,
)

# Model pretraining

## Define Model

In [None]:
# VICReg loss from the repository, used for training (and for validation
# unless it is overridden before pretraining).
lossFnc = vicreg_loss
lossVal = vicreg_loss

# Augmentation applied before the main augmenter.
preaugmenter = set_augmenter(['phase_swap', 'phase_swap'], fs=freq, winlen=window)

# Main augmenter, built from `augment_list` as set in the parameter cell
# (None disables it). The list is intentionally not re-assigned here, so that
# changes made in the parameter cell take effect.
if augment_list is None:
    augmenter = None
else:
    augmenter = set_augmenter(augment_list, fs=freq, winlen=window, p=None)

# Re-seed so the encoder and projector get the same initial weights for a
# given seed.
_reset_seed(seed)
mdl_encoder = TransformeegEncoder(Chan, seed=seed)
# Projector layer sizes for the VICReg head. The first entry presumably has to
# match the encoder's output size, which depends on the montage — TODO confirm.
projector_dims = [76, 128] if chans_reduced else [128, 128]
mdl_siamese = VICReg(mdl_encoder, projector_dims)

In [None]:
# Keep an untouched CPU copy of the freshly initialised model, then move the
# working model to the target device in training mode.
mdl_initialized = copy.deepcopy(mdl_siamese).to(device='cpu')
mdl_siamese.to(device = device)
mdl_siamese.train()
if verbose:
    print(' ')
    ParamTab = selfeeg.utils.count_parameters(mdl_siamese, False, True, True)
    print(' ')
    print(' ')
    print('used learning rate', lr)

# Use the `weight_decay` parameter (None means no decay) instead of a
# hard-coded 0, so the parameter cell actually controls it.
adam_weight_decay = 0 if weight_decay is None else weight_decay
gamma = 0.995
optimizer = torch.optim.Adam(
    mdl_siamese.parameters(),
    betas = (0.75, 0.999),
    lr = lr,
    weight_decay = adam_weight_decay
)
# Multiply the learning rate by `gamma` at every scheduler step.
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma = gamma)
# Increasing schedule: 0.2 at x=0, approaching 0.5 as x grows (presumably a
# mask percentage per epoch).
# NOTE(review): mask_func is not used in this cell and is not passed to
# pretrain_model, which receives a fixed mask_percentage. Confirm whether it
# is still needed.
mask_func = lambda x: 0.5  - 0.3**(x/25 + 1)

# Define selfEEG's EarlyStopper with large patience to act as a model checkpoint
earlystop = selfeeg.ssl.EarlyStopping(
    patience  = 500,
    min_delta = 1e-04,
    record_best_weights = True
)

In [None]:
# Validation uses selfEEG's VICReg loss, called with no extra arguments.
# NOTE(review): training uses AllFnc.pretraining.vicreg_loss (lossFnc) while
# validation uses selfeeg.losses.vicreg_loss. If their default coefficients
# differ, the two losses are not comparable and the early stopper tracks a
# different objective from the one being optimised. Confirm this is intended.
lossVal = selfeeg.losses.vicreg_loss
validation_loss_args = []

# Options for the repository's pretraining loop. The token-masking settings
# (mask_tokens, both_mask_and_aug, mask_percentage, token_num) are passed
# straight to pretrain_model; their exact meaning is defined in
# AllFnc.pretraining.
pretrain_kwargs = dict(
    model=mdl_siamese,
    train_dataloader=trainloader,
    epochs=50,
    optimizer=optimizer,
    loss_func=lossFnc,
    preaugmenter=preaugmenter,
    augmenter=augmenter,
    lr_scheduler=scheduler,
    EarlyStopper=earlystop,
    validation_dataloader=valloader,
    validation_loss_func=lossVal,
    validation_loss_args=validation_loss_args,
    verbose=verbose,
    device=device,
    return_loss_info=True,
    mask_tokens=True,
    both_mask_and_aug=True,
    mask_percentage=0.2,
    token_num=498,
)
loss_summary = pretrain_model(**pretrain_kwargs)

In [None]:
# Display the loss information returned by pretrain_model (requested with
# return_loss_info=True) as the cell output.
loss_summary

## Save model

In [None]:
# Load the weights recorded by the EarlyStopping checkpoint
# (record_best_weights=True) back into the model before saving it.
earlystop.restore_best_weights(mdl_siamese)

In [None]:
# Save the pretrained model's state_dict, moved to the CPU and in eval mode.
# NOTE(review): the file name says "100ep", but the pretraining cell runs
# 50 epochs. "32ch"/"128emb" are also fixed, although the channel count and
# projector sizes depend on rem_interp / chans_reduced. Confirm the name
# before reusing it downstream.
savepath = 'Results/Pretraining/Models/vicreg_32ch_128emb_50lam_75mu_10nu_100ep.pt'
# torch.save does not create missing folders, so create the target directory.
os.makedirs(os.path.dirname(savepath), exist_ok=True)
mdl_siamese.to(device='cpu')
mdl_siamese.eval()
torch.save(mdl_siamese.state_dict(), savepath)