In [1]:
# %%
import os
import argparse 

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F


import json as json

import torch.optim as optim

import matplotlib.pyplot as plt
plt.style.use('dark_background')

import models as models

import wandb
# from os import Path

import models 
import datasets
import dataset

import numpy as np
import time as time 
import util.misc as misc
# from util.misc import NativeScalerWithGradNormCount as NativeScaler
from util.callbacks import EarlyStop

from util.engine_train import train_one_epoch, evaluate # evaluate_online


wandb.login()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33madrian-dendorfer[0m ([33madrian_s_playground[0m). Use [1m`wandb login --relogin`[0m to force relogin


True

In [2]:
def get_args_parser():
    parser = argparse.ArgumentParser("NN training")

    parser.add_argument('--batch_size', default=64, type=int)
    parser.add_argument('--epochs', default=400, type=int)
    parser.add_argument('--acum_iter', default=1, type=int) 

    parser.add_argument('--model', default='shallow_conv_net', type=str, metavar='MODEL',
                        help='Name of model to train')
    
    # Model parameters
    parser.add_argument('--input_channels', type=int, default=1, metavar='N',
                        help='input channels')
    parser.add_argument('--input_electrodes', type=int, default=61, metavar='N',
                        help='input electrodes')
    parser.add_argument('--time_steps', type=int, default=100, metavar='N',
                        help='input length')
    # parser.add_argument('--length_samples', default=200, 
    #                     help='length of samples') 

    # Optimizer parameters
    parser.add_argument('--optimizer', type=str, default="adam_w", 
                        help='optimizer type') 
    parser.add_argument('--lr', type=float, default=1e-3, metavar='LR',
                        help='learning rate') 

    # Callback parameters
    parser.add_argument('--patience', default=-1, type=float,
                        help='Early stopping whether val is worse than train for specified nb of epochs (default: -1, i.e. no early stopping)')
    parser.add_argument('--max_delta', default=0, type=float,
                        help='Early stopping threshold (val has to be worse than (train+delta)) (default: 0)')


    # Dataset parameters
    parser.add_argument('--data_path', 
                        # default='_.pt',
                        default="/vol/aimspace/users/dena/Documents/mae/data/lemon/data_raw_train.pt",
                        type=str,
                        help='train dataset path')

    parser.add_argument('--labels_path', 
                        # default='_.pt', 
                        default="/vol/aimspace/users/dena/Documents/ad_benchmarking/ad_benchmarking/data/labels_bin_train.pt", #labels_raw_train.pt",
                        type=str,
                        help='train labels path')
    parser.add_argument('--val_data_path', 
                        # default='', 
                        default="/vol/aimspace/users/dena/Documents/mae/data/lemon/data_raw_val.pt",
                        type=str,
                        help='validation dataset path')
    parser.add_argument('--val_labels_path', 
                        # default='_.pt', 
                        default="/vol/aimspace/users/dena/Documents/ad_benchmarking/ad_benchmarking/data/labels_bin_val.pt", # "labels_raw_val.pt"
                        type=str,
                        help='validation labels path')
    parser.add_argument('--number_samples', default=1, type=int, # | str, 
                        help='number of samples on which network should train on. "None" means all samples.')
    
    
    # Wandb parameters
    parser.add_argument('--wandb', action='store_true', default=False)
    parser.add_argument('--wandb_project', default='',
                        help='project where to wandb log')
    parser.add_argument('--wandb_id', default='', type=str,
                        help='id of the current run')
    parser.add_argument('--device', default='cuda',
                        help='device to use for training / testing')
    parser.add_argument('--seed', default=0, type=int)

    # Saving Parameters
    parser.add_argument('--output_dir', default='',
                        help='path where to save, empty for no saving')
    
    # parser.add_argument('--mode', type=str, default="train")

    return parser

In [3]:
class Namespace:
    def __init__(self, **kwargs):
        self.__dict__.update(kwargs)

In [44]:
args = Namespace(batch_size=64,
    epochs=100,
    acum_iter=1,
    model='shallow_conv_net',
    input_channels=1,
    input_electrodes=61,
    time_steps=100,
    optimizer='sdg', #'adam_w',
    criterion='bce',    #####
    lr=0.01,
    patience=-1,
    max_delta=0,
    data_path='/vol/aimspace/users/dena/Documents/mae/data/lemon/data_raw_train.pt',
    labels_path='/vol/aimspace/users/dena/Documents/ad_benchmarking/ad_benchmarking/data/labels_bin_train.pt',
    val_data_path='/vol/aimspace/users/dena/Documents/mae/data/lemon/data_raw_val.pt',
    val_labels_path='/vol/aimspace/users/dena/Documents/ad_benchmarking/ad_benchmarking/data/labels_bin_val.pt',
    number_samples=64,
    num_workers=4,
    wandb=False,
    wandb_project='',
    wandb_id='',
    device='cuda',
    seed=0,
    output_dir='')


# Training set size:  1
# Validation set size:  1

In [45]:
device = torch.device(args.device)

# Fix the seed for reproducibility
seed = args.seed 
torch.manual_seed(seed)
np.random.seed(seed)

dataset_train = dataset.EEGDataset(data_path=args.data_path, labels_path=args.labels_path, 
                            train=True, number_samples=args.number_samples, length_samples=args.time_steps,
                            args=args)
dataset_val = dataset.EEGDataset(data_path=args.data_path, labels_path=args.labels_path, 
                            train=True, number_samples=args.number_samples, length_samples=args.time_steps,
                            args=args)

print("Training set size: ", len(dataset_train))
print("Validation set size: ", len(dataset_val))

sampler_val = torch.utils.data.SequentialSampler(dataset_val)
sampler_train = torch.utils.data.RandomSampler(dataset_train) 

data_loader_train = torch.utils.data.DataLoader(
    dataset_train, 
    sampler=sampler_train,
    # shuffle=True,
    batch_size=args.batch_size,
    num_workers=args.num_workers,
    drop_last=False,
)

data_loader_val = torch.utils.data.DataLoader(
    dataset_val, 
    sampler=sampler_val,
    # shuffle=False,
    batch_size=args.batch_size,
    num_workers=args.num_workers,
    drop_last=False,
)

model = models.__dict__[args.model](
    n_channels=args.input_electrodes, 
    input_time_length=args.time_steps,
)

Training set size:  64
Validation set size:  64


In [46]:
# for now: 
wandb.init(project=args.wandb_project, config=vars(args))
data_loader_train = data_loader_val



0,1
batch train loss,▇▇▇█▆▄▆▃▅▄▄▄▄▃▄▄▃▃▃▃▄▃▃▃▃▃▃▂▃▃▂▂▃▃▂▁▃▂▃▁
batch val loss,▇▄█▅▇▄▄▄▃▅▃▃▃▃▃▃▃▃▃▃▂▂▂▃▂▂▂▂▂▁▂▂▂▁▂▂▂▂▁▂

0,1
batch train loss,0.34903
batch val loss,0.47987


In [47]:
model.to(device)

# eval_criterion = "bce"
# criterion = nn.BCELoss()  # !!!! 

optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, betas=(0.9, 0.95))

# Define callbacks
early_stop = EarlyStop(patience=args.patience, max_delta=args.max_delta)

print(f"Start training for {args.epochs} epochs")


for epoch in range(args.epochs):
    start_time = time.time()
    
    mean_loss_epoch_train = train_one_epoch(model, data_loader_train, optimizer, device, epoch, args=args) #loss_scaler, criterion
    print(f"Loss / BCE on {len(dataset_train)} train samples: {mean_loss_epoch_train}")

    mean_loss_epoch_val = evaluate(model, data_loader_val, device, epoch, args=args) 
    print(f"Loss / BCE on {len(dataset_val)} val samples: {mean_loss_epoch_val}")


Start training for 100 epochs


Loss / BCE on 64 train samples: 0.8654236197471619
Loss / BCE on 64 val samples: 0.6264200806617737
Loss / BCE on 64 train samples: 0.6264022588729858
Loss / BCE on 64 val samples: 0.6480107307434082
Loss / BCE on 64 train samples: 0.6480274796485901
Loss / BCE on 64 val samples: 0.6613976955413818
Loss / BCE on 64 train samples: 0.6618583798408508
Loss / BCE on 64 val samples: 0.6617220640182495
Loss / BCE on 64 train samples: 0.661530077457428
Loss / BCE on 64 val samples: 0.6527092456817627
Loss / BCE on 64 train samples: 0.6527494192123413
Loss / BCE on 64 val samples: 0.6409313678741455
Loss / BCE on 64 train samples: 0.6409450173377991
Loss / BCE on 64 val samples: 0.6285372972488403
Loss / BCE on 64 train samples: 0.628294825553894
Loss / BCE on 64 val samples: 0.621527910232544
Loss / BCE on 64 train samples: 0.6233018636703491
Loss / BCE on 64 val samples: 0.6199936866760254
Loss / BCE on 64 train samples: 0.6207655668258667
Loss / BCE on 64 val samples: 0.6205404996871948
Los