# PhysioNet/Computing in Cardiology Challenge 2020
## Classification of 12-lead ECGs
### 4. Cross-Validate Model

# Setup Notebook

In [None]:
# Standard library
import os
import sys
# 3rd party libraries
import numpy as np
import tensorflow as tf

# Local Libraries
# Put the directory four levels above the current working directory at the
# front of sys.path so the `kardioml` imports below can be resolved.
# NOTE(review): presumably that directory is the repository root containing
# the `kardioml` package; this depends on where the notebook is launched from.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(os.getcwd()))))))
from kardioml.models.deepecg_binary.utils.devices.device_check import print_device_counts
from kardioml.models.deepecg_binary.train.train import train
from kardioml.models.deepecg_binary.model.model import Model
from kardioml import DATA_PATH, OUTPUT_PATH, ECG_LEADS

# Configure Notebook
import warnings
# Suppress every warning for the rest of the session (this also hides
# deprecation warnings raised by the libraries imported above).
warnings.filterwarnings('ignore')
# IPython magics: draw matplotlib figures inline in the notebook, and reload
# all modules before each cell execution so edits to local code are picked up
# without restarting the kernel.
%matplotlib inline
%load_ext autoreload
%autoreload 2

# 1. Run Cross-Validation

In [None]:
# ---------------------------------------------------------------------------
# Settings shared by every cross-validation fold
# ---------------------------------------------------------------------------

# Directory holding the formatted dataset
data_path = os.path.join(DATA_PATH, 'formatted')

# Identifier of this training run; also used as the output sub-directory name
training_run = '26'

# Sample length (seconds) and sample frequency
duration = 60
fs = 350

# Network input dimensions: samples per recording and one channel per entry
# in ECG_LEADS
length, channels = int(duration * fs), len(ECG_LEADS)

# Number of output classes and the network architecture to build
classes = 2
network_name = 'DeepECGV1'

# Random seed handed to the network
seed = 0

# Maximum number of checkpoints to keep
max_to_keep = 1

# Training schedule
epochs = 60
batch_size = 16

# Graphs, summaries, and checkpoints of all folds go below this directory
output_path = os.path.join(OUTPUT_PATH, training_run)

# ---------------------------------------------------------------------------
# Train one model per fold (folds are numbered 1 through 5)
# ---------------------------------------------------------------------------
for cv_fold in range(1, 6):

    print('Cross-Validation Fold {}'.format(cv_fold))

    # Fold-specific lookup directory
    lookup_path = os.path.join(DATA_PATH, 'deepecg', 'cross_validation', str(cv_fold))

    # Report the available compute devices
    print_device_counts()

    # Make sure the output directory exists before the model writes to it
    os.makedirs(output_path, exist_ok=True)

    # Each fold gets its own model name
    model_name = 'cv_{}'.format(cv_fold)

    # Architecture hyper-parameters (a fresh dict is built for every fold)
    hyper_params = {
        'num_res_layers': 9,
        'drop_rate': 0.3,
        'kernel_size': 3,
        'conv_filts': 128,
        'res_filts': 128,
        'skip_filts': 128,
        'dilation': True,
        'fs': fs,
    }

    # Everything the network constructor needs
    network_parameters = {
        'length': length,
        'channels': channels,
        'classes': classes,
        'seed': seed,
        'hyper_params': hyper_params,
    }

    # Build the model for this fold
    print('Initializing Model')
    model = Model(
        model_name=model_name,
        network_name=network_name,
        network_parameters=network_parameters,
        save_path=output_path,
        data_path=data_path,
        lookup_path=lookup_path,
        max_to_keep=max_to_keep,
    )

    # Train the model for this fold
    print('Training Start')
    train(model=model, epochs=epochs, batch_size=batch_size)
    print('Training End\n')