In [None]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import sklearn
from sklearn import mixture
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
from mpl_toolkits.axes_grid1 import ImageGrid
import numpy as np
import time
from datetime import timedelta  # `timedelta` is not a top-level module; `import timedelta` raises ImportError
import sys
import random
import cv2
import importlib
import argparse

# Project root that holds the local modules imported below and the data/ and
# experiments/ folders used later; must be set before running the notebook.
ROOT = '' # set
sys.path.append(ROOT)

import classical
import read_experiments
import latent_space
import saliency
import augmentations
import train_model
import plotters
import models
import dataloader_umc
import dataloader_physionet
import utils

# change the width of the cells
from IPython.display import display, HTML
display(HTML("<style>.container { width:80% !important; }</style>"))

### Define paths

In [None]:
# Resolve the data and experiment-results folders under ROOT through utils.check_folder
# (called for 'data' first, then 'experiments').
DATA, EXPERIMENTS = [utils.check_folder(os.path.join(ROOT, sub_dir)) for sub_dir in ('data', 'experiments')]

### Set up parameters

In [None]:
def _str2bool(value):
    """Parse a boolean command-line value.

    argparse's ``type=bool`` calls ``bool(str)``, which is True for every
    non-empty string, so ``--valid False`` would silently enable the flag.
    This converter accepts the usual spellings instead.

    Args:
        value: raw string from the command line (or an already-parsed bool).

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: if the string is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ('true', 't', 'yes', 'y', '1'):
        return True
    if normalized in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


parser = argparse.ArgumentParser(description='PyTorch Training')
parser.add_argument('--dataset', default='PhysioNet', type=str)
parser.add_argument('--seed_data', default=3, type=int, help='dataset seed when selecting fraction of training set')
parser.add_argument('--valid', default=False, type=_str2bool, help='test model against validation set (when True) or against test set (when False)')
parser.add_argument('--model', default='ResNet', type=str)
parser.add_argument('--method', default='base', type=str)
parser.add_argument('--depth', default=0, type=int)
parser.add_argument('--n_fraction', default=1.0, type=float, help='fraction of train data to be used')
parser.add_argument('--train_balance', default=True, type=_str2bool, help='whether to balance train data')
parser.add_argument('--num_epochs', default=100, type=int)
parser.add_argument('--num_steps', default=100, type=int)
parser.add_argument('--batch_size', default=128, type=int, help='train batchsize')
parser.add_argument('--op', default='adam', type=str, help='optimizer')
parser.add_argument('--lr', default=0.001, type=float, help='initial learning rate')
parser.add_argument('--lr_max', default=0.0025, type=float, help='maximum allowed learning rate')
parser.add_argument('--use_sched', default=True, type=_str2bool, help='whether to use learning rate scheduler')
parser.add_argument('--weight_decay', default=1e-4, type=float, help='weight decay (L2 penalty)')
parser.add_argument('--grad_clip', default=0.1, type=float, help='gradient clipping to prevent exploding gradients')
parser.add_argument('--seed', default=4, type=int)  # without type=int a CLI value would stay a str
parser.add_argument('--num_classes', default=2, type=int, help='number of classes')
parser.add_argument('--sample_rate', default=1000, type=int, help='signal sample rate')
parser.add_argument('--num_channels', default=4, type=int, help='signal channel number')
parser.add_argument('--sig_len', default=2500, type=int, help='signal length')
parser.add_argument('--latent_space', default=False, type=_str2bool, help='whether to calculate (and plot) latent space hidden features')
parser.add_argument('--classical_space', default=False, type=_str2bool, help='whether to calculate (and plot) classical features')
parser.add_argument('--EXPERIMENTS', default = EXPERIMENTS, type=str, help='path to experiment results')
parser.add_argument('-f') # dummy argument to prevent an error, since argparse is a module designed to parse the arguments passed from the command line
# parse_known_args also tolerates any other launcher-specific arguments a
# notebook kernel may receive, instead of aborting with "unrecognized arguments".
args, _unknown_args = parser.parse_known_args()

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Device set to: {device.type}')

## PhysioNet dataset

### Load dataset

In [None]:
%%time
# Point the experiment configuration at PhysioNet (signal length 2500) and load
# the serialized dataset through utils.file2dict.
args.dataset, args.sig_len = 'PhysioNet', 2500
DATASET = os.path.join(DATA, 'physionet', 'zbytes_physionet_dataset.dat')
dataset = utils.file2dict(DATASET)
print(f'Dataset {DATASET} has been loaded')

### Run the baseline and all augmentation methods across all training-data fractions

In [None]:
importlib.reload(read_experiments)
importlib.reload(latent_space)
importlib.reload(saliency)
importlib.reload(augmentations)
importlib.reload(train_model)
importlib.reload(models)
importlib.reload(plotters)
importlib.reload(dataloader_physionet)
importlib.reload(utils)

models_arr = [
#     'ResNet', 
#     'ResNetPlus',
#     'XResNet1d18', 
#     'XResNet1d18Plus', 
#     'ResCNN', 
#     'InceptionTime', 
#     'InceptionTimePlus', 
#     'XceptionTime', 
#     'XceptionTimePlus', 
#     'gMLP', 
#     'XCM', 
#     'XCMPlus', 
#     'FCN', 
    'resnet9',
    'Potes',
#     'FCNPlus',
#     'RNN', 
#     'LSTM', 
#     'GRU', 
#     'mWDN', 
#     'OmniScaleCNN'
    ]

aug_methods = [
      'gaussiannoise(25,40)',
      'timemask(0.2)',
      'timewarp(0.05,4)',
      'magnitudewarp(0.2,4)',
      'latentmixup', # manifold mixup
      'mixup(same)',
      'respiratoryscale(12,20)',  # without this comma Python would fuse it with the next literal
      'durratiomixup', #PCGmix
      'durmixmagwarp(0.2,4)', #PCGmix+
        ]

# Fix parameters shared by every baseline run in this sweep
args.valid = False
args.num_epochs = 50
args.batch_size = 64
args.lr_max = 0.01

# Data-sampling seeds per training-data fraction; each seed selects a different
# training subset. Counts are chosen so that fraction * number_of_subsets is ~5
# (the full set uses a single subset but 5 test seeds instead).
SEED_DATAS_BY_FRACTION = {
    0.015: np.arange(1001001, 1001334),  # 333 subsets
    0.052: np.arange(1005001, 1005101),  # 100 subsets
    0.1: np.arange(1010001, 1010051),    # 50 subsets
    0.2: np.arange(1020001, 1020026),    # 25 subsets
    0.3: np.arange(1030001, 1030017),    # 16 subsets
    0.4: np.arange(1040001, 1040013),    # 12 subsets
    0.6: np.arange(1060001, 1060009),    # 8 subsets
    0.8: np.arange(1080001, 1080007),    # 6 subsets
    1.0: [1100001],
}

# Full sweep: [0.015, 0.052, 0.1, 0.2, 0.3, 0.4, 0.6, 0.8, 1.0]
n_fractions = [0.015]

for model_ar in models_arr:
    args.model = model_ar
    for cm in aug_methods:
        for n_fraction in n_fractions:
            if n_fraction not in SEED_DATAS_BY_FRACTION:
                raise ValueError(f'No data-sampling seeds defined for n_fraction={n_fraction}')
            args.n_fraction = n_fraction
            seed_datas = SEED_DATAS_BY_FRACTION[n_fraction]
            seeds_test = [1, 2, 3, 4, 5] if n_fraction == 1.0 else [1]
            for seed_data in seed_datas:
                args.seed_data = seed_data
                # First run no-augmentation baseline
                for seed in seeds_test:
                    args.seed = seed
                    args.method = 'base'
                    if utils.experiment_already_done(args):
                        print(f'Already done: {args.seed_data=}, {args.seed=}, {args.valid=}, {args.method=}')
                        continue
                    train_model.train_model(args, dataset, device)
                # Then run the augmentation method. hyperparameters_robust returns a
                # (possibly modified) namespace; it is applied to a copy so that any
                # per-method overrides cannot leak into the next baseline run or
                # into later cells through the shared `args`.
                for seed in seeds_test:
                    args.seed = seed
                    args.method = cm
                    method_args = read_experiments.hyperparameters_robust(argparse.Namespace(**vars(args)))
                    if utils.experiment_already_done(method_args):
                        print(f'Already done: {args.seed_data=}, {args.seed=}, {args.valid=}, {args.method=}')
                        continue
                    train_model.train_model(method_args, dataset, device)

### Baseline across different "true data sampling" seeds for train balancing (n_fraction=1.0)

In [None]:
importlib.reload(read_experiments)
importlib.reload(latent_space)
importlib.reload(saliency)
importlib.reload(augmentations)
importlib.reload(train_model)
importlib.reload(models)
importlib.reload(plotters)
importlib.reload(dataloader_physionet)
importlib.reload(utils)

# Fixed training parameters for this sweep
args.valid = False
args.num_epochs = 50
args.batch_size = 64
args.lr_max = 0.01

args.n_fraction = 1.0

# Baseline runs on the full training set, repeated for several values of
# args.true_seed; the seed is encoded in the method name so each run is tracked
# separately by utils.experiment_already_done.
seeds_true = [19, 20, 21, 22, 23]
try:
    for m in ['resnet9']:
        args.model = m
        seed_datas = [1100001]
        seeds_test = [1, 2, 3, 4, 5]
        for seed_true in seeds_true:
            args.true_seed = seed_true
            for seed_data in seed_datas:
                args.seed_data = seed_data
                for seed in seeds_test:
                    args.seed = seed
                    args.method = f'base-trueseed={args.true_seed}'
                    if utils.experiment_already_done(args):
                        print(f'Already done: {args.seed_data=}, {args.seed=}, {args.valid=}, {args.method=}')
                        continue
                    train_model.train_model(args, dataset, device)
finally:
    # Reset even if a run fails or is interrupted, so later cells do not silently
    # inherit one of the sweep values (18 is presumably the default — TODO confirm).
    args.true_seed = 18

### Out-of-manifold intrusion

In [None]:
importlib.reload(read_experiments)
importlib.reload(latent_space)
importlib.reload(saliency)
importlib.reload(augmentations)
importlib.reload(train_model)
importlib.reload(models)
importlib.reload(plotters)
importlib.reload(dataloader_physionet)
importlib.reload(utils)

# First train latent space model (ResCNN on the full training set)
args.model = 'ResCNN'
args.num_epochs = 10
args.batch_size = 32
args.lr_max = 0.00089
args.n_fraction = 1.0
args.seed_data = 3
args.seed = 1
args.method = 'base'

if not utils.experiment_already_done(args):
    train_model.train_model(args, dataset, device)

# Then, run out-of-manifold intrusion
# NOTE: requires manually setting the corresponding flag to True in latent_space.py

args.model = 'resnet9'

aug_methods = [
      #'(closestknn=1)durmixmagwarp(0.2,4)',
      #'(closestknn=2)durmixmagwarp(0.2,4)',
       '(closestknn=4)durmixmagwarp(0.2,4)',
      #'(closestknn=6)durmixmagwarp(0.2,4)',
       '(closestknn=8)durmixmagwarp(0.2,4)',
      #'(closestknn=12)durmixmagwarp(0.2,4)',
       '(closestknn=16)durmixmagwarp(0.2,4)',
      #'(closestknn=20)durmixmagwarp(0.2,4)',
       '(closestknn=26)durmixmagwarp(0.2,4)',
      #'(closestknn=32)durmixmagwarp(0.2,4)',
       '(closestknn=64)durmixmagwarp(0.2,4)',
        ]

args.valid = False
args.num_epochs = 50
args.batch_size = 64
args.lr_max = 0.01

# Per training-data fraction: (augmentation probabilities to try, data-sampling seeds).
# The augmentation probability decreases as the training fraction grows.
# (Larger seed sets with twice as many subsets were used in earlier runs.)
FRACTION_SETTINGS = {
    0.015: ([1.0], np.arange(1001001, 1001334)),
    0.052: ([1.0], np.arange(1005001, 1005101)),
    0.1: ([1.0], np.arange(1010001, 1010051)),
    0.2: ([0.8], np.arange(1020001, 1020026)),
    0.3: ([0.6], np.arange(1030001, 1030017)),
    0.4: ([0.6], np.arange(1040001, 1040013)),
    0.6: ([0.4], np.arange(1060001, 1060009)),
    0.8: ([0.2], np.arange(1080001, 1080007)),
    1.0: ([0.2], [1100001]),
}

# Full sweep: [0.015, 0.052, 0.1, 0.2, 0.3, 0.4, 0.6, 0.8, 1.0]
n_fractions = [0.1]

# Baseline training is disabled in this cell (the call used to be commented out);
# set to True to also train baselines that are not yet done.
TRAIN_BASELINE = False

for n_fraction in n_fractions:
    if n_fraction not in FRACTION_SETTINGS:
        raise ValueError(f'No augmentation/seed settings defined for n_fraction={n_fraction}')
    aug_probas, seed_datas = FRACTION_SETTINGS[n_fraction]
    seeds_test = [1, 2, 3, 4, 5] if n_fraction == 1.0 else [1]
    args.n_fraction = n_fraction
    for cm in aug_methods:
        for seed_data in seed_datas:
            args.seed_data = seed_data
            for seed in seeds_test:
                args.seed = seed
                args.method = 'base'
                if utils.experiment_already_done(args):
                    print(f'Already done: {args.seed_data=}, {args.seed=}, {args.valid=}, {args.method=}')
                    continue
                if TRAIN_BASELINE:
                    train_model.train_model(args, dataset, device)
            for cp in aug_probas:
                args.method = f'{cm}+{cp}'
                for seed in seeds_test:
                    args.seed = seed
                    if utils.experiment_already_done(args):
                        print(f'Already done: {args.seed_data=}, {args.seed=}, {args.valid=}, {args.method=}')
                        continue
                    train_model.train_model(args, dataset, device)

### Different mapping functions

In [None]:
importlib.reload(read_experiments)
importlib.reload(latent_space)
importlib.reload(saliency)
importlib.reload(augmentations)
importlib.reload(train_model)
importlib.reload(models)
importlib.reload(plotters)
importlib.reload(dataloader_physionet)
importlib.reload(utils)

args.model = 'resnet9'

# durmixmagwarp variants that differ only in the prefix (presumably which
# recordings may be mixed together — see augmentations; TODO confirm).
aug_methods = [
                 '(samePCG)durmixmagwarp(0.2,4)',
                 '(sameCVD)durmixmagwarp(0.2,4)',
                 '(sameDataset)durmixmagwarp(0.2,4)',
                 '(mixAll)durmixmagwarp(0.2,4)',
                ]

args.valid = False
args.num_epochs = 50
args.batch_size = 64
args.lr_max = 0.01

# Per training-data fraction: (augmentation probabilities to try, data-sampling seeds).
# The augmentation probability decreases as the training fraction grows.
# (Larger seed sets with twice as many subsets were used in earlier runs.)
FRACTION_SETTINGS = {
    0.015: ([1.0], np.arange(1001001, 1001334)),
    0.052: ([1.0], np.arange(1005001, 1005101)),
    0.1: ([1.0], np.arange(1010001, 1010051)),
    0.2: ([0.8], np.arange(1020001, 1020026)),
    0.3: ([0.6], np.arange(1030001, 1030017)),
    0.4: ([0.6], np.arange(1040001, 1040013)),
    0.6: ([0.4], np.arange(1060001, 1060009)),
    0.8: ([0.2], np.arange(1080001, 1080007)),
    1.0: ([0.2], [1100001]),
}

n_fractions = [0.015, 0.052, 0.1, 0.2, 0.3, 0.4, 0.6, 0.8, 1.0]

for n_fraction in n_fractions:
    if n_fraction not in FRACTION_SETTINGS:
        raise ValueError(f'No augmentation/seed settings defined for n_fraction={n_fraction}')
    aug_probas, seed_datas = FRACTION_SETTINGS[n_fraction]
    seeds_test = [1, 2, 3, 4, 5] if n_fraction == 1.0 else [1]
    args.n_fraction = n_fraction
    for cm in aug_methods:
        for seed_data in seed_datas:
            args.seed_data = seed_data
            # Baseline (no augmentation) for this data subset
            for seed in seeds_test:
                args.seed = seed
                args.method = 'base'
                if utils.experiment_already_done(args):
                    print(f'Already done: {args.seed_data=}, {args.seed=}, {args.valid=}, {args.method=}')
                    continue
                train_model.train_model(args, dataset, device)
            # Augmented runs, one per augmentation probability
            for cp in aug_probas:
                args.method = f'{cm}+{cp}'
                for seed in seeds_test:
                    args.seed = seed
                    if utils.experiment_already_done(args):
                        print(f'Already done: {args.seed_data=}, {args.seed=}, {args.valid=}, {args.method=}')
                        continue
                    train_model.train_model(args, dataset, device)

### Saliency utilization optimization

In [None]:
importlib.reload(read_experiments)
importlib.reload(latent_space)
importlib.reload(saliency)
importlib.reload(augmentations)
importlib.reload(train_model)
importlib.reload(models)
importlib.reload(plotters)
importlib.reload(dataloader_physionet)
importlib.reload(utils)

args.model = 'resnet9'

aug_methods = [
                  '(rand)durratiomixup', # random variant: checks whether non-saliency-guided mixing also helps
                  '(rand)durmixmagwarp(0.2,4)', # random variant: checks whether non-saliency-guided mixing also helps
                  '(saloptsum)durratiomixup',
                  '(saloptsum)durmixmagwarp(0.2,4)',
                  '(saloptenv)durratiomixup',
                  '(saloptenv)durmixmagwarp(0.2,4)',
                  '(saloptenv-1)durratiomixup',
                  '(saloptenv-2)durmixmagwarp(0.2,4)',
                ]

args.valid = False
args.num_epochs = 50
args.batch_size = 64
args.lr_max = 0.01

# Per training-data fraction: (augmentation probabilities to try, data-sampling seeds).
# Wider probability grids explored earlier: 0.2 -> [0.6..1.0], 0.3 -> [0.2, 0.4, ..., 1.0],
# 0.4 -> [0.5..1.0], 0.6 / 0.8 / 1.0 -> [0.1..0.5] (step 0.1 unless stated).
FRACTION_SETTINGS = {
    0.015: ([1.0], np.arange(1001001, 1001334)),
    0.052: ([1.0], np.arange(1005001, 1005101)),
    0.1: ([1.0], np.arange(1010001, 1010051)),
    0.2: ([0.8], np.arange(1020001, 1020026)),
    0.3: ([0.6], np.arange(1030001, 1030017)),
    0.4: ([0.6], np.arange(1040001, 1040013)),
    0.6: ([0.4], np.arange(1060001, 1060009)),
    0.8: ([0.2], np.arange(1080001, 1080007)),
    1.0: ([0.2], [1100001]),
}

# Full sweep: [0.015, 0.052, 0.1, 0.2, 0.4, 0.6, 0.8, 1.0]
n_fractions = [0.1]

for n_fraction in n_fractions:
    if n_fraction not in FRACTION_SETTINGS:
        raise ValueError(f'No augmentation/seed settings defined for n_fraction={n_fraction}')
    aug_probas, seed_datas = FRACTION_SETTINGS[n_fraction]
    seeds_test = [1, 2, 3, 4, 5] if n_fraction == 1.0 else [1]
    args.n_fraction = n_fraction
    for cm in aug_methods:
        for seed_data in seed_datas:
            args.seed_data = seed_data
            # Baseline (no augmentation) for this data subset
            for seed in seeds_test:
                args.seed = seed
                args.method = 'base'
                if utils.experiment_already_done(args):
                    print(f'Already done: {args.seed_data=}, {args.seed=}, {args.valid=}, {args.method=}')
                    continue
                train_model.train_model(args, dataset, device)
            # Augmented runs, one per augmentation probability
            for cp in aug_probas:
                args.method = f'{cm}+{cp}'
                for seed in seeds_test:
                    args.seed = seed
                    if utils.experiment_already_done(args):
                        print(f'Already done: {args.seed_data=}, {args.seed=}, {args.valid=}, {args.method=}')
                        continue
                    train_model.train_model(args, dataset, device)

### Different $\alpha$ values of the Beta distribution

In [None]:
importlib.reload(read_experiments)
importlib.reload(latent_space)
importlib.reload(saliency)
importlib.reload(augmentations)
importlib.reload(train_model)
importlib.reload(models)
importlib.reload(plotters)
importlib.reload(dataloader_physionet)
importlib.reload(utils)

args.model = 'resnet9'

# durmixmagwarp with augmentation probability 1.0 (the "+1.0" suffix, same
# "{method}+{probability}" convention as the sweeps above) over a range of alpha values.
cutmix_methods = [
         '(alpha=0.05)durmixmagwarp(0.2,4)+1.0',
         '(alpha=0.25)durmixmagwarp(0.2,4)+1.0',
         '(alpha=0.5)durmixmagwarp(0.2,4)+1.0',
         '(alpha=0.75)durmixmagwarp(0.2,4)+1.0',
         '(alpha=1.25)durmixmagwarp(0.2,4)+1.0',
         '(alpha=1.5)durmixmagwarp(0.2,4)+1.0',
         '(alpha=1.75)durmixmagwarp(0.2,4)+1.0',
         '(alpha=2)durmixmagwarp(0.2,4)+1.0',
         '(alpha=3)durmixmagwarp(0.2,4)+1.0',
         '(alpha=4)durmixmagwarp(0.2,4)+1.0',
         '(alpha=5)durmixmagwarp(0.2,4)+1.0',
         '(alpha=6)durmixmagwarp(0.2,4)+1.0',
         '(alpha=7)durmixmagwarp(0.2,4)+1.0',
         '(alpha=8)durmixmagwarp(0.2,4)+1.0',
         '(alpha=9)durmixmagwarp(0.2,4)+1.0',
         '(alpha=10)durmixmagwarp(0.2,4)+1.0',       
        ]

args.valid = False
args.num_epochs = 50
args.batch_size = 64
args.lr_max = 0.01

# Data-sampling seeds per training-data fraction (one training subset per seed).
SEED_DATAS_BY_FRACTION = {
    0.015: np.arange(1001001, 1001334),
    0.052: np.arange(1005001, 1005101),
    0.1: np.arange(1010001, 1010051),
    0.2: np.arange(1020001, 1020026),
    0.3: np.arange(1030001, 1030017),
    0.4: np.arange(1040001, 1040013),
    0.6: np.arange(1060001, 1060009),
    0.8: np.arange(1080001, 1080007),
    1.0: [1100001],
}

# Full sweep: [0.015, 0.052, 0.1, 0.2, 0.3, 0.4, 0.6, 0.8, 1.0]
n_fractions = [0.1]

for cm in cutmix_methods:
    for n_fraction in n_fractions:
        if n_fraction not in SEED_DATAS_BY_FRACTION:
            raise ValueError(f'No data-sampling seeds defined for n_fraction={n_fraction}')
        args.n_fraction = n_fraction
        seed_datas = SEED_DATAS_BY_FRACTION[n_fraction]
        seeds_test = [1, 2, 3, 4, 5] if n_fraction == 1.0 else [1]
        for seed_data in seed_datas:
            args.seed_data = seed_data
            # Baseline (no augmentation) for this data subset
            for seed in seeds_test:
                args.seed = seed
                args.method = 'base'
                if utils.experiment_already_done(args):
                    print(f'Already done: {args.seed_data=}, {args.seed=}, {args.valid=}, {args.method=}')
                    continue
                train_model.train_model(args, dataset, device)
            # Augmented run for this alpha value
            for seed in seeds_test:
                args.seed = seed
                args.method = cm
                if utils.experiment_already_done(args):
                    print(f'Already done: {args.seed_data=}, {args.seed=}, {args.valid=}, {args.method=}')
                    continue
                train_model.train_model(args, dataset, device)

## UMC dataset (proprietary chronic heart failure dataset)

### Load dataset

In [None]:
%%time
# Point the experiment configuration at the UMC dataset (signal length 2000) and
# load the serialized dataset through utils.file2dict.
args.dataset, args.sig_len = 'UMC', 2000
DATASET = os.path.join(DATA, 'UMC', 'zbytes_UMC_dataset.dat')
dataset = utils.file2dict(DATASET)
print(f'Dataset {DATASET} has been loaded')

### Run the baseline over a hyperparameter grid

In [None]:
importlib.reload(latent_space)
importlib.reload(saliency)
importlib.reload(augmentations)
importlib.reload(train_model)
importlib.reload(models)
importlib.reload(plotters)
importlib.reload(dataloader_umc)
importlib.reload(utils)

# Baseline (no augmentation) on the full, balanced UMC training set.
args.train_balance = True
args.seed_data = 0
args.model = 'resnet9'
args.latent_space = False
args.aug = False
args.n_fraction = 1.0

# Grid: batch size x epochs x max learning rate, each repeated over 10 data seeds.
lr_arr = [0.01, 0.001]
epochs_arr = [50, 25, 10]
bs_arr = [128, 64, 32]
seed_datas = list(range(1, 11))

args.seed = 1
args.method = 'base'

grid_points = [(bs, ep, lr) for bs in bs_arr for ep in epochs_arr for lr in lr_arr]
for batch_size, num_epochs, lr_max in grid_points:
    args.batch_size, args.num_epochs, args.lr_max = batch_size, num_epochs, lr_max
    for seed_data in seed_datas:
        args.seed_data = seed_data
        if not utils.experiment_already_done(args):
            train_model.train_model(args, dataset, device)

### Train with augmentations and extract classical features

In [None]:
importlib.reload(latent_space)
importlib.reload(saliency)
importlib.reload(augmentations)
importlib.reload(train_model)
importlib.reload(models)
importlib.reload(plotters)
importlib.reload(dataloader_umc)
importlib.reload(utils)

# Fixed settings for this sweep; classical features are calculated for every run.
args.latent_space = False
args.classical_space = True
args.valid = False
args.num_epochs = 50  # 20 epochs was also tried
args.batch_size = 512
args.lr_max = 0.01

args.model = 'resnet9'
args.aug = False

args.n_fraction = 1.0

seed_datas = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
args.seed = 1

aug_methods = [
      '(samePCG)durratiomixup+1.0',
      '(samePCG)durmixmagwarp(0.2,4)+1.0',
        ]

# For each data seed, ensure both the baseline and the augmented run exist.
# The two "already done" checks are independent, so an existing baseline no
# longer causes the augmented run for that seed to be skipped.
for am in aug_methods:
    for seed_data in seed_datas:
        args.seed_data = seed_data
        for method in ('base', am):
            args.method = method
            if utils.experiment_already_done(args):
                print(f'Already done: {args.seed_data=}, {args.seed=}, {args.valid=}, {args.method=}')
                continue
            train_model.train_model(args, dataset, device)