# PhysioNet/Computing in Cardiology Challenge 2020
## Classification of 12-lead ECGs
### 5. Tune Model

# Setup Notebook

In [None]:
# 3rd party libraries
import os
import sys
import numpy as np
import tensorflow as tf

# Local Libraries
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(os.getcwd()))))))
from kardioml.models.deepecg.utils.devices.device_check import print_device_counts
from kardioml.models.deepecg.train.train import train
from kardioml.models.deepecg.model.model import Model
from kardioml import DATA_PATH, OUTPUT_PATH, LABELS_COUNT, ECG_LEADS, FS

# Configure Notebook
import warnings
warnings.filterwarnings('ignore')
%matplotlib inline
%load_ext autoreload
%autoreload 2

# 1. Hyper-Parameter Search

In [None]:
# Standard-library helpers used by this cell
import itertools
import traceback

# Set data path
data_path = os.path.join(DATA_PATH, 'formatted')

# Set lookup path (identical for every experiment, so computed once)
lookup_path = os.path.join(DATA_PATH, 'deepecg')

# Set training run name; every experiment in the sweep is saved under this directory
training_run = 'hpt_15'

# Set save path for graphs, summaries, and checkpoints
output_path = os.path.join(OUTPUT_PATH, training_run)
os.makedirs(output_path, exist_ok=True)

# Set sample length (seconds)
duration = 60

# Settings shared by every experiment in the sweep
max_to_keep = 1              # Maximum number of checkpoints to keep
seed = 0                     # Random state passed to the network
classes = LABELS_COUNT       # Number of classes
network_name = 'DeepECGV1'   # Network architecture to build
epochs = 50
batch_size = 16

# Hyper-parameter search grid
num_res_layers_grid = [6, 7, 8, 9, 10, 11, 12]
num_filts_grid = [128, 256, 512, 1024]
drop_rate_grid = [0.3]
kernel_size_grid = [3]
fs_grid = [300]

# Print devices once; they do not change between experiments
print_device_counts()

# Loop through every hyper-parameter combination. itertools.product varies the
# last grid fastest, i.e. the same order as the equivalent nested for-loops.
for num_res_layers, num_filts, drop_rate, kernel_size, fs in itertools.product(
        num_res_layers_grid, num_filts_grid, drop_rate_grid, kernel_size_grid, fs_grid):

    experiment_name = 'num_res_layers-{}_num_filts-{}_drop_rate-{}_kernel_size-{}_fs-{}'.format(
        num_res_layers, num_filts, drop_rate, kernel_size, fs)
    print('Hyper-Parameter: {}'.format(experiment_name))

    # Set model name (one session per experiment)
    model_name = 'sess_{}'.format(experiment_name)

    # Get training dataset dimensions: samples per example depends on the sampling rate.
    # NOTE(review): channels is hard-coded to 13 while ECG_LEADS is imported but unused;
    # confirm against the formatted data.
    length, channels = int(duration * fs), 13

    # Set hyper-parameters; the same filter count is used for conv, residual and skip layers
    hyper_params = {'num_res_layers': num_res_layers, 'drop_rate': drop_rate, 'kernel_size': kernel_size,
                    'conv_filts': num_filts, 'res_filts': num_filts, 'skip_filts': num_filts,
                    'dilation': True, 'fs': fs}

    # Set network inputs
    network_parameters = {'length': length, 'channels': channels, 'classes': classes, 'seed': seed,
                          'hyper_params': hyper_params}

    try:
        # Create model
        print('Initializing Model')
        model = Model(model_name=model_name,
                      network_name=network_name,
                      network_parameters=network_parameters,
                      save_path=output_path,
                      data_path=data_path,
                      lookup_path=lookup_path,
                      max_to_keep=max_to_keep)

        # Train model
        print('Training Start')
        train(model=model, epochs=epochs, batch_size=batch_size)
        print('Training End\n')

    except Exception:
        # A single configuration may fail; report why and move on to the next one.
        # KeyboardInterrupt is deliberately not caught, so interrupting the kernel
        # stops the whole sweep instead of silently skipping to the next experiment.
        print('Training Failure: {}'.format(experiment_name))
        traceback.print_exc()