# Physionet 2017 | ECG Rhythm Classification
## 4. Train Model
### Sebastian D. Goodfellow, Ph.D.

# Setup Notebook

In [1]:
# Import 3rd party libraries
import os
import sys
import numpy as np
import pickle

# Deep learning libraries
import tensorflow as tf

# Import local Libraries
sys.path.insert(0, r'C:\Users\sebastian goodfellow\Documents\code\deep_ecg')
from deepecg.training.utils.plotting.training_data_validation import interval_plot_interact
from deepecg.training.utils.devices.device_check import print_device_counts
from deepecg.training.train.disc.data_generator import DataGenerator
from deepecg.training.train.disc.train import train
from deepecg.training.model.disc.model import Model
from deepecg.config.config import DATA_DIR

# Configure Notebook
import warnings
warnings.filterwarnings('ignore')
%matplotlib inline
%load_ext autoreload
%autoreload 2

# Resources

In [2]:
# Objective Function
# https://stackoverflow.com/questions/44560549/unbalanced-data-and-weighted-cross-entropy

# Global Average Pooling
# https://alexisbcook.github.io/2017/global-average-pooling-layers-for-object-localization/
# https://github.com/philipperemy/tensorflow-class-activation-mapping/blob/master/class_activation_map.py
# https://github.com/AndersonJo/global-average-pooling

# 1. View Training Data (validation split)

In [3]:
# Directory holding the prepared 'disc' training data
data_path = os.path.join(DATA_DIR, 'training', 'disc')

# Interactive browser over the validation split; the widget is the displayed cell output
interval_plot_interact(dataset='val', path=data_path)

interactive(children=(IntSlider(value=852, description='label_id', max=1705), Output()), _dom_classes=('widget…

# 2. Test Data Generator

In [4]:
# Build a training-mode input pipeline purely to sanity-check the DataGenerator
generator = DataGenerator(
    path=data_path,
    mode='train',
    shape=[18000, 1],
    batch_size=32,
    prefetch_buffer=1000,
    seed=0,
    num_parallel_calls=24,
)

# Show the resulting dataset object (its output shapes and dtypes)
generator.dataset

<PrefetchDataset shapes: ((?, 18000, 1), (?,)), types: (tf.float32, tf.int32)>

# 3. Initialize Model

In [5]:
# Report the CPU / GPU device counts on this workstation
print_device_counts()

# Set save path for graphs, summaries, and checkpoints.
# Can be overridden with the DEEPECG_SAVE_PATH environment variable so the notebook
# is not tied to a single user's machine; the original location is the fallback.
save_path = os.environ.get(
    'DEEPECG_SAVE_PATH',
    r'C:\Users\sebastian goodfellow\Desktop\tensorboard\deep_ecg\tests',
)

# Set model name
model_name = 'test_3'

# Maximum number of checkpoints to keep
max_to_keep = 20

# Set random state
seed = 0

# Input dimensions (length, channels). These are hardcoded, not read from the data;
# keep them in sync with the `shape=[18000, 1]` passed to DataGenerator above.
length, channels = (18000, 1)

# Number of classes
classes = 4

# Choose network
network_name = 'DeepECGV1'

# Set network inputs
network_parameters = dict(
    length=length,
    channels=channels,
    classes=classes,
    seed=seed,
)

# Create model
model = Model(
    model_name=model_name,
    network_name=network_name,
    network_parameters=network_parameters,
    save_path=save_path,
    data_path=data_path,
    max_to_keep=max_to_keep,
)

Workstation has 1 CPUs.
Workstation has 2 GPUs.


# 4. Train Model

In [None]:
# Training hyper-parameters (all forwarded to train() below)
epochs = 100
# NOTE(review): differs from the batch_size=32 used when testing DataGenerator
# above — confirm which value is intended.
batch_size = 10
learning_rate_start = 0.001  # starting learning rate passed to train()

# Launch training
train(
    model=model,
    learning_rate_start=learning_rate_start,
    epochs=epochs,
    batch_size=batch_size,
)