# PhysioNet/Computing in Cardiology Challenge 2020
## Classification of 12-lead ECGs
### 3. Train Model

# Setup Notebook

In [3]:
# Import 3rd party libraries
import os
import sys
import ast
import time
import json
import numpy as np
import pandas as pd

# Import local Libraries
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(os.getcwd()))))))
from kardioml import DATA_PATH, ECG_LEADS, FS, LABELS_LOOKUP, LABELS_COUNT
from kardioml.models.physionet2017.training.xgboost_model import Model
from kardioml.data.data_loader import load_challenge_data

# Configure Notebook
import warnings
warnings.filterwarnings('ignore')
%matplotlib inline
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


# Import Data
### Meta Data

In [4]:
# Import to DataFrame
meta_data = pd.read_csv(os.path.join(DATA_PATH, 'physionet_2017', 'training', 'meta_data.csv'))

# View DataFrame
meta_data.head()

Unnamed: 0,age,channel_order,filename,label_train,labels,labels_full,labels_int,sex,shape,label_count,length,labels_concat
0,74.0,"['I', 'II', 'III', 'aVR', 'aVL', 'aVF', 'V1', ...",A0001,"[0, 0, 0, 0, 0, 0, 1, 0, 0]",['RBBB'],['Right bundle branch block'],[6],Male,"[12, 7500]",1,15.0,Right bundle branch block
1,49.0,"['I', 'II', 'III', 'aVR', 'aVL', 'aVF', 'V1', ...",A0002,"[0, 0, 0, 1, 0, 0, 0, 0, 0]",['Normal'],['Normal sinus rhythm'],[3],Female,"[12, 5000]",1,10.0,Normal sinus rhythm
2,81.0,"['I', 'II', 'III', 'aVR', 'aVL', 'aVF', 'V1', ...",A0003,"[1, 0, 0, 0, 0, 0, 0, 0, 0]",['AF'],['Atrial fibrillation'],[0],Female,"[12, 5000]",1,10.0,Atrial fibrillation
3,45.0,"['I', 'II', 'III', 'aVR', 'aVL', 'aVF', 'V1', ...",A0004,"[1, 0, 0, 0, 0, 0, 0, 0, 0]",['AF'],['Atrial fibrillation'],[0],Male,"[12, 5974]",1,11.948,Atrial fibrillation
4,53.0,"['I', 'II', 'III', 'aVR', 'aVL', 'aVF', 'V1', ...",A0005,"[0, 0, 0, 0, 0, 1, 0, 0, 0]",['PVC'],['Premature ventricular complex'],[5],Male,"[12, 12500]",1,25.0,Premature ventricular complex


### Features

In [None]:
# Import to DataFrame
features = pd.read_csv(os.path.join(DATA_PATH, 'physionet_2017', 'training', 'features.csv'))

# View DataFrame
features.head()

### Labels

In [None]:
# Import to DataFrame
labels = pd.read_csv(os.path.join(DATA_PATH, 'physionet_2017', 'training', 'labels.csv'))

# View DataFrame
labels.head()

# Hyper-Parameter Tuning

In [None]:
# Set parameter bounds
param_bounds = {'learning_rate': (0.01, 1.0),
                'n_estimators': (500, 1500),
                'max_depth': (2, 8),
                'subsample': (0.5, 1.0),  
                'colsample': (0.5, 1.0),
                'gamma': (0.001, 2.0),
                'min_child_weight': (0, 10),
                'max_delta_step': (0, 10)}

# Set number of iterations
n_iter = 40

# Set number CV folds
cv_folds = 4

# Get 1-D labels for stratifying
stratifier = meta_data['labels'].map(lambda val: ast.literal_eval(val)[0])

# Initialize model
model = Model(features=features, labels=labels, cv_folds=cv_folds, stratifier=stratifier)

# Run hyper-paramter search
model.tune_hyper_parameters(param_bounds=param_bounds, n_iter=n_iter)

# Save model
model.save()

# Test Inference

In [None]:
# Load test data
data, header_data = load_challenge_data(filename=os.path.join(DATA_PATH, 'raw', 'Training_WFDB', 'A0100.mat'))

# Run inference
model.challenge_prediction(data=data, header_data=header_data)