# ECG Rhythm Classification
## 1. Train Model
### Sebastian D. Goodfellow, Ph.D.

# Setup Notebook

In [None]:
# 3rd party libraries
import os
import sys
import numpy as np
import tensorflow as tf

# Local Libraries
sys.path.insert(0, r'.\..')
from haifanet.utils.plotting.training_data_validation import interval_plot_interact
from haifanet.utils.devices.device_check import print_device_counts
from haifanet.train.train import train
from haifanet.model.model import Model
from haifanet import DATA_DIR, LOG_DIR

# Configure Notebook
import warnings
warnings.filterwarnings('ignore')
%matplotlib inline
%load_ext autoreload
%autoreload 2

# 1. View Training Data

In [None]:
# Location of the 'training' folder inside DATA_DIR
data_path = os.path.join(DATA_DIR, 'training')

# Show the interval plot for the 'val' dataset
# NOTE(review): this section is titled "View Training Data" but the 'val'
# dataset is requested — confirm this is intended.
interval_plot_interact(dataset='val', path=data_path)

# 2. Initialize Model

In [None]:
# Print device counts
print_device_counts()

# Make sure the save path for graphs, summaries, and checkpoints exists
os.makedirs(LOG_DIR, exist_ok=True)

# Model and network names
model_name = '1'
network_name = 'HaifaNetV1'

# Maximum number of checkpoints to keep
max_to_keep = 1

# Random state seed
seed = 0

# Training dataset dimensions (hard-coded here, not read from the data)
length = 18000
channels = 1

# Number of classes
classes = 4

# Inputs handed to the network
network_parameters = dict(length=length, channels=channels, classes=classes, seed=seed)

# Build the model; data_path comes from the "View Training Data" cell
model = Model(
    model_name=model_name,
    network_name=network_name,
    network_parameters=network_parameters,
    save_path=LOG_DIR,
    data_path=data_path,
    max_to_keep=max_to_keep,
)

# 3. Train Model

In [None]:
# Hyper-parameters
learning_rate_start = 0.001
batch_size = 16
epochs = 200

# NOTE(review): learning_rate_start is assigned above but is not passed to
# train() — confirm whether train() is expected to receive it.

# Run training
train(model=model, batch_size=batch_size, epochs=epochs)