In [1]:
# %%
import os
import argparse 

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F


import json as json

import torch.optim as optim

import matplotlib.pyplot as plt
plt.style.use('dark_background')

import models as models

import wandb
# from os import Path

import models 
import datasets
import dataset

import numpy as np
import time as time 
import util.misc as misc
# from util.misc import NativeScalerWithGradNormCount as NativeScaler
from util.callbacks import EarlyStop

from util.engine_train import train_one_epoch, evaluate # evaluate_online


wandb.login()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33madrian-dendorfer[0m ([33madrian_s_playground[0m). Use [1m`wandb login --relogin`[0m to force relogin


True

In [2]:
def get_args_parser():
    parser = argparse.ArgumentParser("NN training")

    parser.add_argument('--batch_size', default=64, type=int)
    parser.add_argument('--epochs', default=400, type=int)
    parser.add_argument('--acum_iter', default=1, type=int) 

    parser.add_argument('--model', default='shallow_conv_net', type=str, metavar='MODEL',
                        help='Name of model to train')
    
    # Model parameters
    parser.add_argument('--input_channels', type=int, default=1, metavar='N',
                        help='input channels')
    parser.add_argument('--input_electrodes', type=int, default=61, metavar='N',
                        help='input electrodes')
    parser.add_argument('--time_steps', type=int, default=100, metavar='N',
                        help='input length')
    # parser.add_argument('--length_samples', default=200, 
    #                     help='length of samples') 

    # Optimizer parameters
    parser.add_argument('--optimizer', type=str, default="adam_w", 
                        help='optimizer type') 
    parser.add_argument('--lr', type=float, default=1e-3, metavar='LR',
                        help='learning rate') 

    # Callback parameters
    parser.add_argument('--patience', default=-1, type=float,
                        help='Early stopping whether val is worse than train for specified nb of epochs (default: -1, i.e. no early stopping)')
    parser.add_argument('--max_delta', default=0, type=float,
                        help='Early stopping threshold (val has to be worse than (train+delta)) (default: 0)')


    # Dataset parameters
    parser.add_argument('--data_path', 
                        # default='_.pt',
                        default="/vol/aimspace/users/dena/Documents/mae/data/lemon/data_raw_train.pt",
                        type=str,
                        help='train dataset path')

    parser.add_argument('--labels_path', 
                        # default='_.pt', 
                        default="/vol/aimspace/users/dena/Documents/ad_benchmarking/ad_benchmarking/data/labels_bin_train.pt", #labels_raw_train.pt",
                        type=str,
                        help='train labels path')
    parser.add_argument('--val_data_path', 
                        # default='', 
                        default="/vol/aimspace/users/dena/Documents/mae/data/lemon/data_raw_val.pt",
                        type=str,
                        help='validation dataset path')
    parser.add_argument('--val_labels_path', 
                        # default='_.pt', 
                        default="/vol/aimspace/users/dena/Documents/ad_benchmarking/ad_benchmarking/data/labels_bin_val.pt", # "labels_raw_val.pt"
                        type=str,
                        help='validation labels path')
    parser.add_argument('--number_samples', default=1, type=int, # | str, 
                        help='number of samples on which network should train on. "None" means all samples.')
    
    
    # Wandb parameters
    parser.add_argument('--wandb', action='store_true', default=False)
    parser.add_argument('--wandb_project', default='',
                        help='project where to wandb log')
    parser.add_argument('--wandb_id', default='', type=str,
                        help='id of the current run')
    parser.add_argument('--device', default='cuda',
                        help='device to use for training / testing')
    parser.add_argument('--seed', default=0, type=int)

    # Saving Parameters
    parser.add_argument('--output_dir', default='',
                        help='path where to save, empty for no saving')
    
    # parser.add_argument('--mode', type=str, default="train")

    return parser

In [3]:
class Namespace:
    def __init__(self, **kwargs):
        self.__dict__.update(kwargs)

In [4]:
args = Namespace(batch_size=64,
    epochs=200,
    acum_iter=1,
    model='first_shallow_conv_net_regression', #shallow_conv_net',  # deep_conv_net, simple_classifier
    input_channels=1,
    input_electrodes=61,
    time_steps=100,
    optimizer='adamw', #'adam_w',
    criterion='mse',   
    lr=0.012,
    patience=100,
    sufficient_accuracy=0.001, #-np.inf, 
    max_delta=0,
    data_path='/vol/aimspace/users/dena/Documents/mae/data/lemon/data_raw_train.pt',
    # Classification
    # labels_path='/vol/aimspace/users/dena/Documents/ad_benchmarking/ad_benchmarking/data/labels_bin_train.pt',
    #Regression
    labels_path='/u/home/dena/Documents/mae/data/lemon/labels_raw_train.pt',
    val_data_path='/vol/aimspace/users/dena/Documents/mae/data/lemon/data_raw_val.pt',
    val_labels_path='/vol/aimspace/users/dena/Documents/ad_benchmarking/ad_benchmarking/data/labels_bin_val.pt',
    number_samples=1, #16, #64,
    num_workers=4,
    wandb=False,
    wandb_project='',
    wandb_id='',
    device='cpu', #cuda',
    seed=0,
    output_dir='')


# Training set size:  1
# Validation set size:  1

In [5]:
X = torch.load(args.labels_path, map_location=torch.device('cpu')) # load to ram
X

tensor([[62.5000],
        [22.5000],
        [22.5000],
        [32.5000],
        [27.5000],
        [22.5000],
        [62.5000],
        [22.5000],
        [27.5000],
        [67.5000],
        [22.5000],
        [22.5000],
        [27.5000],
        [72.5000],
        [22.5000],
        [72.5000],
        [67.5000],
        [22.5000],
        [27.5000],
        [22.5000],
        [67.5000],
        [22.5000],
        [27.5000],
        [27.5000],
        [22.5000],
        [22.5000],
        [27.5000],
        [22.5000],
        [27.5000],
        [62.5000],
        [22.5000],
        [22.5000],
        [72.5000],
        [22.5000],
        [32.5000],
        [22.5000],
        [62.5000],
        [22.5000],
        [22.5000],
        [27.5000],
        [22.5000],
        [67.5000],
        [27.5000],
        [62.5000],
        [22.5000],
        [27.5000],
        [22.5000],
        [72.5000],
        [27.5000],
        [67.5000],
        [77.5000],
        [27.5000],
        [22.

In [6]:
# print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__))))
# print("{}".format(args).replace(', ', ',\n'))

device = torch.device(args.device)

# Fix the seed for reproducibility
seed = args.seed 
torch.manual_seed(seed)
np.random.seed(seed)

dataset_train = dataset.EEGDataset(data_path=args.data_path, labels_path=args.labels_path, 
                            train=True, number_samples=args.number_samples, length_samples=args.time_steps,
                            args=args)
dataset_val = dataset.EEGDataset(data_path=args.data_path, labels_path=args.labels_path, 
                            train=True, number_samples=args.number_samples, length_samples=args.time_steps,
                            args=args)

print("Training set size: ", len(dataset_train))
print("Validation set size: ", len(dataset_val))

sampler_val = torch.utils.data.SequentialSampler(dataset_val)
sampler_train = torch.utils.data.RandomSampler(dataset_train) 

# # wandb logging
# if args.wandb == True:
#     config = vars(args)
#     if args.wandb_id:
#         wandb.init(project=args.wandb_project, id=args.wandb_id, config=config)
#     else:
#         wandb.init(project=args.wandb_project, config=config)
wandb.init(project=args.wandb_project, config=vars(args))

data_loader_train = torch.utils.data.DataLoader(
    dataset_train, 
    sampler=sampler_train,
    # shuffle=True,
    batch_size=args.batch_size,
    num_workers=args.num_workers,
    # pin_memory=args.pin_mem,
    drop_last=False,
)

data_loader_val = torch.utils.data.DataLoader(
    dataset_val, 
    sampler=sampler_val,
    # shuffle=False,
    batch_size=args.batch_size,
    num_workers=args.num_workers,
    # pin_memory=args.pin_mem,
    drop_last=False,
)

model = models.__dict__[args.model](
    n_channels=args.input_electrodes, 
    input_time_length=args.time_steps, 
)

model.to(device)

# eval_criterion = "bce"
if args.criterion == "bce": 
    criterion = torch.nn.BCELoss() # For classification

elif args.criterion == "mae": 
    criterion = torch.nn.L1Loss() # For regression 

elif args.criterion == "mse": 
    criterion = torch.nn.MSELoss() # For regression 



if args.optimizer == "sgd":
    optimizer = optim.SGD(model.parameters(),
                            lr=args.lr, momentum=0.9)
elif args.optimizer == "adam":
    optimizer = optim.Adam(model.parameters(),
                            lr=args.lr)
elif args.optimizer == "adamw": 
    optimizer = optim.AdamW(model.parameters(), lr=args.lr, betas=(0.9, 0.95))

else: 
    print("Attention: No optimier chosen.")

Training set size:  1
Validation set size:  1


In [7]:
for i in enumerate(data_loader_val): 
    break



In [8]:
i[1][1]

tensor([[62.5000]])

In [9]:
# for now: 
data_loader_train = data_loader_val

In [10]:
# CLASSIFICATION

# # Define callbacks
# # early_stop = EarlyStop(patience=args.patience, max_delta=args.max_delta)

# print(f"Start training for {args.epochs} epochs")

# min_val_metric = np.inf
# counter = 0 

# for epoch in range(args.epochs): 
    
#     mean_loss_epoch_train_bce, mean_loss_epoch_train_L1 = train_one_epoch(model, data_loader_train, optimizer, criterion, device, epoch, args=args) #loss_scaler, criterion
#     print(f"Loss / BCE on {len(dataset_train)} train samples: {mean_loss_epoch_train_bce}")

#     # mean_loss_epoch_val_bce, mean_loss_epoch_val_L1 = evaluate(model, data_loader_val, criterion, device, epoch, args=args) 
#     target, output, mean_loss_epoch_val_bce, mean_loss_epoch_val_L1 = evaluate(model, data_loader_val, criterion, device, epoch, args=args) 
#     print(target, output) 
#     print(f"Loss / BCE on {len(dataset_val)} val samples BCE: {mean_loss_epoch_val_bce}, val samples MAE: {mean_loss_epoch_val_L1}")
#     wandb.log({"mean train BCE loss": mean_loss_epoch_train_bce,
#                "mean train MAE loss": mean_loss_epoch_train_L1, 
#                "mean val BCE loss": mean_loss_epoch_val_bce, 
#                "mean val MAE loss": mean_loss_epoch_val_L1, 
#                "epoch": epoch})
    
#     # Early Stopping
#     print(f"Sufficient accuracy: {args.sufficient_accuracy}.")
#     print(f"patience: {args.patience > -1}.")
#     print(f"stuff: {mean_loss_epoch_train_L1 < args.sufficient_accuracy}.")
#     if args.patience > -1: 
#         if mean_loss_epoch_train_L1 < args.sufficient_accuracy: 
#             break
#         elif mean_loss_epoch_train_L1 < min_val_metric: 
#             min_val_metric = mean_loss_epoch_train_L1
#             counter == 0
#         elif mean_loss_epoch_train_L1 > min_val_metric: 
#             counter += 1
#             if counter > args.patience:
#                 print(f"stopped early at epoch {epoch}.")
#                 break 



In [11]:
# REGRESSION 

# Define callbacks
# early_stop = EarlyStop(patience=args.patience, max_delta=args.max_delta)

print(f"Start training for {args.epochs} epochs")

min_val_metric = np.inf
counter = 0 

for epoch in range(args.epochs): 
    
    mean_loss_epoch_train_mae = train_one_epoch(model, data_loader_train, optimizer, criterion, device, epoch, args=args) #loss_scaler, criterion
    print(f"Loss / MAE on {len(dataset_train)} train samples: {mean_loss_epoch_train_mae}")

    # mean_loss_epoch_val_bce, mean_loss_epoch_val_L1 = evaluate(model, data_loader_val, criterion, device, epoch, args=args) 
    target, output, mean_loss_epoch_val_mae = evaluate(model, data_loader_val, criterion, device, epoch, args=args) 
    print(target, output) 
    print(f"Loss / MAE on {len(dataset_val)} val samples: {mean_loss_epoch_val_mae}")
    wandb.log({"mean train BCE loss": mean_loss_epoch_train_mae,
               "mean val BCE loss": mean_loss_epoch_val_mae, 
               "epoch": epoch})
    
    # Early Stopping
    print(f"Sufficient accuracy: {args.sufficient_accuracy}.")
    print(f"patience: {args.patience > -1}.")
    if args.patience > -1: 
        if mean_loss_epoch_val_mae < args.sufficient_accuracy: 
            break
        elif mean_loss_epoch_val_mae < min_val_metric: 
            min_val_metric = mean_loss_epoch_val_mae
            counter == 0
        elif mean_loss_epoch_val_mae > min_val_metric: 
            counter += 1
            if counter > args.patience:
                print(f"stopped early at epoch {epoch}.")
                break 



Start training for 200 epochs
Loss / MAE on 1 train samples: 61.94758605441176


  return F.conv2d(input, weight, bias, self.stride,


Target: tensor([[62.5000]]), output: tensor([[0.9971]])
tensor([[62.5000]]) tensor([[0.9971]])
Loss / MAE on 1 val samples: 61.50289190672765
Sufficient accuracy: 0.001.
patience: True.


Loss / MAE on 1 train samples: 61.96462298725883
Target: tensor([[62.5000]]), output: tensor([[0.9998]])
tensor([[62.5000]]) tensor([[0.9998]])
Loss / MAE on 1 val samples: 61.50017069971229
Sufficient accuracy: 0.001.
patience: True.
Loss / MAE on 1 train samples: 61.92935980186417
Target: tensor([[62.5000]]), output: tensor([[1.0000]])
tensor([[62.5000]]) tensor([[1.0000]])
Loss / MAE on 1 val samples: 61.500011909297626
Sufficient accuracy: 0.001.
patience: True.
Loss / MAE on 1 train samples: 61.88625805093466
Target: tensor([[62.5000]]), output: tensor([[1.0000]])
tensor([[62.5000]]) tensor([[1.0000]])
Loss / MAE on 1 val samples: 61.5
Sufficient accuracy: 0.001.
patience: True.
Loss / MAE on 1 train samples: 61.816974844343925
Target: tensor([[62.5000]]), output: tensor([[1.0000]])
tensor([[62.5000]]) tensor([[1.0000]])
Loss / MAE on 1 val samples: 61.5
Sufficient accuracy: 0.001.
patience: True.
Loss / MAE on 1 train samples: 61.85598666206449
Target: tensor([[62.5000]]), output

Notes: 
- SDG works increadibly bad
- first_shallow_conv_net and deep_conv_net have currently hardcoded eeg channels (n_channels = 61) and input time lengths (input_time_length == 100)
- Created ShallowConvNet_Regression for regression, that does not have a final activation function (the one in the paper has a softmax) 

In [None]:
block1_conv = 534-25+1 # same
block1_conv 

510

In [None]:
block1_conv = 510
block1_max = block1_conv/3
block1_max

170.0

In [None]:
block1_max = 171
block2_conv = (block1_max-10)+1 # valid
print(block2_conv)
block2_max = block2_conv/3 # valid
block2_max

162


54.0

In [None]:
block3_conv = block2_max-10+1
print(block3_conv)
block3_max = block3_conv/3
print(block3_max)

45.0
15.0
