In [None]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import torch

from dataloader import get_dataloaders, get_datasets
from utils import seed_everything
from train import load_data
from config import Config
from train import train

class CFG(Config):
    """Notebook-local overrides of the project ``Config`` defaults.

    Only class attributes are defined here; how each value is consumed is
    decided by ``Config`` / ``train`` elsewhere in the project.
    """

    # Backbone choices noted by the author:
    # resnet18/34/50, efficientnet_b0/b1, convnext_atto/femto/tiny, repvit_m0_9/1_0/1_1/1_5/2_3, efficientvit_b0/b1 tiny_vit_5/11/21m_224
    base_model = 'efficientnet_b0'

    # Optimisation settings
    batch_size = 16
    epochs = 10
    base_lr = 0.001
    optimizer = 'AdamW'
    loss = 'KLDivLoss'
    scheduler = 'StepLR'
    scheduler_step_size = 2
    lr_gamma = 0.1
    sgd_momentum = 0.9
    freeze_epochs = 0

    # Data selection
    spec_random_trial_num = 1
    eeg_random_trial_num = 1
    # Values handled by the branch below: 'spec', 'eeg_tf', 'eeg',
    # 'spec+eeg_tf', 'spec+eeg_tf+eeg'
    data_type = 'eeg'
    eeg_tf_data = 'eeg_tf_data_globalnorm'
    train_type = 'rater_num_split'
    init_epochs = 5
    wavenet_params = {'eeg_ch': 18, 'dropout': 0.0, 'hidden_features': 64}
    filt_hp = 0.5
    drop_ecg = False

    # Augmentation
    use_mixup = False
    mixup_alpha = 2.0
    coarse_dropout_args = {'p': 0.5, 'max_holes': 8, 'max_height': 128, 'max_width': 128}
    xymasking_args = {'p': 0.5, 'num_masks_y': (2, 4), 'mask_y_length': (500, 2000)}
    # horizontal_flip_p = 0.5

    # Derived settings: evaluated once, at class-definition time, from the
    # `data_type` chosen above.
    if data_type == 'spec':
        in_channels = 1
        spec_trial_selection = 'first'
        eeg_trial_selection = 'all'
    elif data_type in ('eeg_tf', 'eeg'):
        in_channels = 1
        spec_trial_selection = 'all'
        eeg_trial_selection = 'first'
    elif data_type in ('spec+eeg_tf', 'spec+eeg_tf+eeg'):
        # NOTE(review): in_channels is not set in this branch; presumably it
        # is inherited from Config -- confirm.
        spec_trial_selection = 'all'
        eeg_trial_selection = 'first'

    # Run control
    use_wandb = False
    one_fold = True

def format_targets(values):
    """Return *values* rendered as ``'[a, b, ...]'`` with two decimals each."""
    return '[' + ', '.join(f'{v:0.2f}' for v in values) + ']'


# Set to True to preview a few training samples before training starts.
plot_samples = False

# Show training data
if plot_samples:
    seed_everything(CFG.seed)
    df, data = load_data(CFG)
    # The same dataframe is passed as both train and validation split; only
    # the 'train' loader is used for this preview.
    dataloaders = get_dataloaders(CFG, get_datasets(CFG, data, df_train=df, df_validation=df))
    with torch.no_grad():
        for X, y in dataloaders['train']:
            # Preview the first two samples of the batch, one figure each.
            # (No enclosing figure is opened: it would stay empty and only
            # produce a blank output.)
            for sample, target in zip(X[:2], y[:2]):
                plt.figure(figsize=(10, 10))
                if CFG.data_type == 'eeg':
                    # After the transpose each row is drawn as one trace
                    # (presumably one EEG channel per row -- confirm against
                    # the dataset).
                    raw_data = sample.cpu().numpy().T
                    if len(raw_data):
                        # Vertical spacing between traces: the peak-to-peak
                        # range of the first trace, computed once.
                        offset = raw_data[0].max() - raw_data[0].min()
                        x_axis = np.arange(raw_data.shape[1])
                        for k, trace in enumerate(raw_data):
                            # Label each trace with its row index so the
                            # legend has entries to show.
                            plt.plot(x_axis, trace - k * offset, label=str(k))
                        plt.legend()
                    plt.yticks([])
                else:
                    # Move the leading axis last so imshow receives an
                    # (H, W, C)-style array.
                    img_data = sample.permute(1, 2, 0).cpu().numpy()
                    # Fixed colour limits are used because the tensors contain
                    # negative values.
                    plt.imshow(img_data, vmin=-3, vmax=3, cmap='RdBu_r')
                plt.title(format_targets(target.cpu().numpy()), fontdict={'fontsize': 8})
            # Only the first batch is previewed.
            break
    # `display` is provided by the IPython/Jupyter runtime.
    display(df.head())

In [None]:
# Run training with the CFG class defined in this notebook; `train` is
# imported from the project's `train` module.
train(CFG)