In [None]:
# Standard library
import argparse
import importlib
import os
import random
import sys
import time
from datetime import timedelta  # `timedelta` lives in `datetime`; `import timedelta` raises ModuleNotFoundError

# Third-party
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import sklearn
from sklearn import mixture
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
from mpl_toolkits.axes_grid1 import ImageGrid
import numpy as np
import cv2

# Project root that holds the helper modules imported below; fill in before running.
# An empty string on sys.path refers to the current working directory.
ROOT = ''  # set
sys.path.append(ROOT)  # must reference ROOT (previously an undefined name `EOL`, raising NameError)

# Project-local helper modules, resolved through ROOT on sys.path.
import classical
import read_experiments
import latent_space
import saliency
import augmentations2d
import train_model
import plotters
import models2d
import dataloader_umc2d
import dataloader_physionet2d
import utils

# Inject a CSS rule so notebook cells use 80% of the browser width.
from IPython.display import HTML, display
display(HTML(data="<style>.container { width:80% !important; }</style>"))

### Define paths

In [None]:
# Resolve the data and experiment-output folders under ROOT
# (utils.check_folder presumably creates them if missing — confirm).
DATA, EXPERIMENTS = (utils.check_folder(os.path.join(ROOT, sub)) for sub in ('data', 'experiments'))

### Setup parameters

In [None]:
def _str2bool(value):
    """Convert a command-line string such as 'true'/'false' to a bool.

    argparse's ``type=bool`` calls ``bool(str)``, which is True for every
    non-empty string (so ``--valid False`` would yield True). This converter
    accepts the usual spellings explicitly and rejects anything else.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean spelling.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ('true', 't', 'yes', 'y', '1'):
        return True
    if normalized in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


# Experiment configuration. Later cells overwrite most of these fields directly on `args`.
parser = argparse.ArgumentParser(description='PyTorch Training')
parser.add_argument('--dataset', default='PhysioNet(spec)', type=str)
parser.add_argument('--seed_data', default=3, type=int, help='dataset seed when selecting fraction of training set')
parser.add_argument('--valid', default=False, type=_str2bool, help='test model against validation set (when True) or against test set (when False)')
parser.add_argument('--model', default='ResNet', type=str)
parser.add_argument('--method', default='base', type=str)
parser.add_argument('--depth', default=0, type=int)
parser.add_argument('--n_fraction', default=1.0, type=float, help='fraction of train data to be used')
parser.add_argument('--train_balance', default=True, type=_str2bool, help='whether to balance train data')
parser.add_argument('--num_epochs', default=100, type=int)
parser.add_argument('--num_steps', default=100, type=int)
parser.add_argument('--batch_size', default=128, type=int, help='train batchsize')
parser.add_argument('--op', default='adam', type=str, help='optimizer')
parser.add_argument('--lr', default=0.001, type=float, help='initial learning rate')
parser.add_argument('--lr_max', default=0.0025, type=float, help='maximum allowed learning rate')
parser.add_argument('--use_sched', default=True, type=_str2bool, help='whether to use learning rate scheduler')
parser.add_argument('--weight_decay', default=1e-4, type=float, help='weight decay (L2 penalty)')
parser.add_argument('--grad_clip', default=0.1, type=float, help='gradient clipping to prevent exploding gradients')
parser.add_argument('--seed', default=4, type=int)  # type=int so CLI values are not left as str
parser.add_argument('--num_classes', default=2, type=int, help='number of classes')
parser.add_argument('--sample_rate', default=1000, type=int, help='signal sample rate')
parser.add_argument('--num_channels', default=4, type=int, help='signal channel number')
parser.add_argument('--sig_len', default=2500, type=int, help='signal length')
parser.add_argument('--latent_space', default=False, type=_str2bool, help='whether to calculate (and plot) latent space hidden features')
parser.add_argument('--classical_space', default=False, type=_str2bool, help='whether to calculate (and plot) classical features')
parser.add_argument('--EXPERIMENTS', default = EXPERIMENTS, type=str, help='path to experiment results')
# Jupyter launches the kernel with `-f <connection file>`; declaring it keeps parse_args() from erroring.
parser.add_argument('-f')
args = parser.parse_args()

# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print('Device set to: cuda' if device.type == 'cuda' else 'Device set to: cpu')

## PhysioNet dataset

### Load dataset

In [None]:
%%time
args.dataset = 'PhysioNet(spec128)'
DATASET = os.path.join(DATA, 'physionet', f'zbytes_physionet_spectrograms128_dataset_selection.dat')
dataset = utils.file2dict(DATASET)
print(f'Dataset {DATASET} has been loaded')

### Train a single baseline model

In [None]:
# Reload project modules so edits to their source take effect without restarting the kernel.
for _module in (latent_space, saliency, augmentations2d, train_model,
                models2d, plotters, dataloader_physionet2d, utils):
    importlib.reload(_module)

# Single baseline (no augmentation) run on the full training set, evaluated on the test split.
args.model = 'resnet9'
args.method = 'base'
# Pin explicitly: the sweep cell below changes args.n_fraction, so re-running this cell
# afterwards would otherwise silently train on a fraction of the data.
args.n_fraction = 1.0
args.seed_data = 10005
args.valid = False  # False -> evaluate against the test set
args.num_epochs = 10
args.batch_size = 64
args.lr_max = 0.01
# NOTE: args.seed is not set here; on a fresh run it is the parser default (4).

train_model.train_model(args, dataset, device)

### Run baseline and augmentation methods

In [None]:
# Reload project modules so edits to their source take effect without restarting the kernel.
for _module in (latent_space, saliency, augmentations2d, train_model,
                models2d, plotters, dataloader_physionet2d, utils):
    importlib.reload(_module)

args.model = 'resnet9'

aug_methods = [
    'durratiocutmix',
    'mixup(same)',
    'durratiomixup', # PCGmix
    'cutmix',
    'freqmask(0.1)',
    'timemask(0.1)',
    'cutout(0.25,0.25)',
    'latentmixup',
    'durmixfreqmask(0.1)',
    'durmixtimemask(0.1)',
    'durmixcutout(0.25,0.25)',
    '(saloptsum)durratiomixup',
    '(saloptenv)durratiomixup',
    '(saloptsum-1)durratiomixup',
    '(saloptenv-1)durratiomixup',
]

args.valid = False
args.num_epochs = 50
args.batch_size = 64
args.lr_max = 0.01

# Training-set fraction -> (augmentation probabilities to try, data seeds to run).
# Comments record the wider grid that the active value was narrowed from.
FRACTION_GRID = {
    0.015: ([1.0], np.arange(1001001, 1001201)),
    0.052: ([1.0], np.arange(1005001, 1005061)),
    0.1: ([1.0], np.arange(1010001, 1010002)),  # full sweep seeds: np.arange(1010001, 1010031)
    0.2: ([0.8], np.arange(1020001, 1020016)),  # full sweep probas: [0.6, 0.8, 1.0]
    0.3: ([0.2, 0.4, 0.6, 0.8, 1.0], np.arange(1030001, 1030011)),
    0.4: ([0.2, 0.4, 0.6, 0.8, 1.0], np.arange(1040001, 1040009)),
    0.6: ([0.2, 0.4, 0.6, 0.8, 1.0], np.arange(1060001, 1060006)),
    0.8: ([0.2, 0.4, 0.6], np.arange(1080001, 1080005)),
    1.0: ([0.2, 0.4, 0.6], [1100001]),
}

# Fractions to run now; use list(FRACTION_GRID) for the full sweep.
n_fractions = [0.1]

for cm in aug_methods:
    for n_fraction in n_fractions:
        args.n_fraction = n_fraction
        # Fail loudly instead of silently reusing the previous fraction's grid.
        if n_fraction not in FRACTION_GRID:
            raise ValueError(f'No augmentation-probability/seed grid configured for n_fraction={n_fraction}')
        aug_probas, seed_datas = FRACTION_GRID[n_fraction]
        # Only the full training set is repeated over several model seeds.
        seeds_test = [1, 2, 3] if n_fraction == 1.0 else [1]
        # Baseline first, then the augmentation at each probability ('<method>+<probability>').
        methods = ['base'] + [f'{cm}+{cp}' for cp in aug_probas]
        for seed_data in seed_datas:
            args.seed_data = seed_data
            for method in methods:
                args.method = method
                for seed in seeds_test:
                    args.seed = seed
                    if utils.experiment_already_done(args):
                        print(f'Already done: {args.seed_data=}, {args.seed=}, {args.valid=}, {args.method=}')
                        continue
                    train_model.train_model(args, dataset, device)

## UMC dataset (proprietary chronic heart failure dataset)

### Load dataset

In [None]:
%%time
args.dataset = 'UMC(spec64)'
DATASET = os.path.join(DATA, 'UMC', f'zbytes_UMC_dataset_spectrograms64.dat')
dataset = utils.file2dict(DATASET)
print(f'Dataset {DATASET} has been loaded')

### Baseline hyperparameter search (batch size × epochs × learning rate)

In [None]:
importlib.reload(latent_space)
importlib.reload(saliency)
importlib.reload(augmentations2d)
importlib.reload(train_model)
importlib.reload(models2d)
importlib.reload(plotters)
importlib.reload(dataloader_physionet2d)
importlib.reload(dataloader_umc2d)
importlib.reload(utils)

args.model = 'resnet9'

# Baseline grid: every (batch size, epochs, max learning rate) combination, each over 10 data seeds.
lr_arr = [0.01, 0.005, 0.001, 0.0005, 0.0001, 0.00005, 0.00001]
epochs_arr = [10]
bs_arr = [128, 64, 32]

seed_datas = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
args.seed = 1
args.method = 'base'

# NOTE(review): args.n_fraction is not set in this cell, so it is inherited from earlier cells
# (after a linear Run-All it is the last PhysioNet sweep fraction, 0.1). The next cell pins
# n_fraction = 1.0 — confirm which one this baseline search is meant to use.
for batch_size, num_epochs, lr_max, seed_data in (
    (bs, ep, lr, sd)
    for bs in bs_arr
    for ep in epochs_arr
    for lr in lr_arr
    for sd in seed_datas
):
    args.batch_size = batch_size
    args.num_epochs = num_epochs
    args.lr_max = lr_max
    args.seed_data = seed_data
    if not utils.experiment_already_done(args):
        train_model.train_model(args, dataset, device)

### Train with augmentations

In [None]:
# Reload project modules so edits to their source take effect without restarting the kernel.
for _module in (latent_space, saliency, augmentations2d, train_model, models2d,
                plotters, dataloader_physionet2d, dataloader_umc2d, utils):
    importlib.reload(_module)

args.latent_space = False
args.classical_space = False
args.valid = False  # False -> evaluate against the test set
args.num_epochs = 50
args.batch_size = 64
args.lr_max = 0.01

args.model = 'resnet9'
args.aug = False  # NOTE(review): not a declared parser option; presumably read by train_model — confirm

args.n_fraction = 1.0

seed_datas = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
args.seed = 1

aug_methods = [
    'durratiomixup+1.0',
    'durmixmagwarp(0.2,4)+1.0',
]

for am in aug_methods:
    for seed_data in seed_datas:
        args.seed_data = seed_data
        # Check the baseline and the augmented run independently, so an already-finished
        # baseline does not cause the augmented run for this seed to be skipped.
        for method in ('base', am):
            args.method = method
            if utils.experiment_already_done(args):
                print(f'Already done: {args.seed_data=}, {args.seed=}, {args.valid=}, {args.method=}')
                continue
            train_model.train_model(args, dataset, device)