In [None]:
# Credits to https://towardsdatascience.com/multi-task-learning-for-computer-vision-classification-with-keras-36c52e6243d2
#            https://github.com/GATECH-EIC/TinyML-Contest-Solution

In [None]:
# Imports

import csv
import os
import random
import time

import matplotlib.pyplot as plt
import numpy as np
import sklearn
import sklearn.utils
import tensorflow as tf
from tensorflow import keras

In [None]:
""" Functions for loading data """
def txt_to_numpy(filename, row=1250):
    """Load the first column of a text file into a 1-D float array.

    Every non-blank line contributes one sample: the line is split on
    whitespace and its first field is parsed as a float.

    Args:
        filename: Path of the text file to read.
        row: Length of the returned array.

    Returns:
        A ``numpy.ndarray`` of shape ``(row,)`` and dtype ``float64``. When the
        file holds fewer than ``row`` samples the remaining entries are 0.0.

    Raises:
        IndexError: If the file holds more than ``row`` samples.
        ValueError: If the first field of a line is not a valid float.
    """
    # np.float was removed in NumPy 1.24; float64 is the type it used to alias.
    # Start from zeros so a short file is zero-padded instead of leaking
    # placeholder values into the signal.
    datamat = np.zeros(row, dtype=np.float64)
    row_count = 0
    # The context manager guarantees the handle is closed even on a parse error.
    with open(filename) as file:
        for line in file:
            fields = line.split()
            if not fields:
                # Tolerate blank lines (e.g. a trailing newline at end of file).
                continue
            if row_count >= row:
                raise IndexError(
                    f"{filename} contains more than the expected {row} samples"
                )
            datamat[row_count] = float(fields[0])
            row_count += 1

    return datamat

"""
AFb,Atrial Fibrillation
AFt,Atrial Flutter
SR,Sinus Rhythm
SVT,Supraventricular Tachycardia
VFb,Ventricular Fibrillation
VFt,Ventricular Flutter
VPD,Ventricular Premature Depolarizations
VT,Ventricular Tachycardia
"""
# Rhythm labels in one-hot index order, stored upper-case for case-insensitive lookup.
_DISEASE_LABELS = ("AFB", "AFT", "SR", "SVT", "VFB", "VFT", "VPD", "VT")


def txt_to_disease_type(txt):
    """One-hot encode a rhythm label.

    The lookup is case-insensitive and ignores surrounding whitespace. The
    index order is AFb, AFt, SR, SVT, VFb, VFt, VPD, VT.

    Args:
        txt: Rhythm label string, e.g. ``"AFb"`` or ``"VT"``.

    Returns:
        A list of eight ints with a single 1 at the label's index.

    Raises:
        ValueError: If ``txt`` is not one of the eight known labels.
    """
    key = txt.strip().upper()
    if key not in _DISEASE_LABELS:
        raise ValueError(
            f"{txt!r} was not recognized as a valid input on txt_to_disease_type func!"
        )
    one_hot = [0] * len(_DISEASE_LABELS)
    one_hot[_DISEASE_LABELS.index(key)] = 1
    return one_hot


def read_data(csv_path, imgs_folder="./tinyml_contest_data_training", augmentation=False, flip_peak=False, flip_time=False, add_noise=False):
    """Load signals and both label sets listed in an index CSV.

    The CSV's first row is skipped as a header. In every following row,
    column 0 is the life-threat label and column 1 is the signal file name.
    The rhythm label is taken from the second "-"-separated field of the
    file name.

    Args:
        csv_path: Path of the index CSV.
        imgs_folder: Directory containing the signal text files.
        augmentation: Master switch; the three flags below only take effect
            when this is True.
        flip_peak: Negate each signal with probability 0.5.
        flip_time: Reverse each signal in time with probability 0.5.
        add_noise: Add zero-mean Gaussian noise whose standard deviation is a
            random fraction of 5% of the signal's peak absolute amplitude.

    Returns:
        A tuple ``(x, y1, y2)`` of equally long lists shuffled in unison:
        ``x`` holds 1-D signal arrays of length 1250, ``y1`` holds
        ``[1, 0]`` (label 0) or ``[0, 1]`` (any other label), and ``y2``
        holds the 8-way one-hot rhythm encodings.
    """
    x, y1, y2 = [], [], []
    # newline="" is what the csv module expects for files it parses.
    with open(csv_path, "r", newline="") as csv_file:
        reader = csv.reader(csv_file)
        next(reader, None)  # skip the header; the default avoids StopIteration on an empty file
        for item in reader:  # item[0] = label; item[1] = filename
            if not item:
                continue  # ignore blank rows
            x.append(txt_to_numpy(os.path.join(imgs_folder, item[1]), 1250))
            y1.append([1, 0] if int(item[0]) == 0 else [0, 1])  # Label for life threat
            y2.append(txt_to_disease_type(item[1].split("-")[1]))
    x, y1, y2 = sklearn.utils.shuffle(x, y1, y2)
    if augmentation:
        for i in range(len(x)):
            flip_p = random.random()
            flip_t = random.random()
            if flip_p < 0.5 and flip_peak:
                x[i] = -x[i]
            if flip_t < 0.5 and flip_time:
                x[i] = np.flip(x[i])
            if add_noise:
                # Use the absolute peak so the scale is never negative
                # (np.random.normal rejects a negative standard deviation).
                max_peak = np.abs(x[i]).max() * 0.05
                factor = random.random()
                # Draw noise with the signal's own shape; a (n, 1) draw would
                # broadcast against the (n,) signal into an (n, n) matrix.
                noise = np.random.normal(0, factor * max_peak, x[i].shape)
                x[i] = x[i] + noise
    return x, y1, y2


In [None]:
""" Loading data """
# Each call yields (signals, life-threat labels, rhythm labels) for one split.
x_train, y_train_1, y_train_2 = read_data(csv_path="./train_indice.csv")
x_test, y_test_1, y_test_2 = read_data(csv_path="./test_indice.csv")



In [None]:
""" Functions for creating, compiling and training model"""

# Architecture:
# Main branch
# Branch1 => Decides whether life threatening (VF or VT style) or not (SR and others)
# Branch2 => Classify the input


# labels,Rhythm
# AFb,Atrial Fibrillation
# AFt,Atrial Flutter
# SR,Sinus Rhythm
# SVT,Supraventricular Tachycardia
# VFb,Ventricular Fibrillation
# VFt,Ventricular Flutter
# VPD,Ventricular Premature Depolarizations
# VT,Ventricular Tachycardia


def _dense_head(features, n_classes, output_name):
    """Build one task head: 3 x (Dense(128) + BatchNorm), dropout, softmax output."""
    head = features
    for _ in range(3):
        head = keras.layers.Dense(128, activation="relu")(head)
        head = keras.layers.BatchNormalization()(head)
    head = keras.layers.Dropout(0.2)(head)
    return keras.layers.Dense(n_classes, activation="softmax", name=output_name)(head)


def gen_model():
    """Build the two-headed multi-task network.

    A shared convolutional/dense trunk feeds two softmax heads:

    * ``task_1_output`` (2 classes): life threatening or not.
    * ``task_2_output`` (8 classes): rhythm classification.

    Returns:
        An uncompiled ``tf.keras.Model`` taking an input named ``input`` of
        shape ``(1250, 1)`` and producing ``[task_1_output, task_2_output]``.
    """
    inputs = keras.layers.Input(shape=(1250, 1), name='input')

    # padding="same" is required here: with the default "valid" padding the
    # sequence length shrinks 1250 -> 39 -> 1, and the third convolution
    # (kernel_size=20) cannot be applied to a length-1 sequence, so the model
    # fails to build. With "same" the lengths are 1250 -> 40 -> 2 -> 1.
    main_branch = inputs
    for n_filters in (128, 64, 32):
        main_branch = keras.layers.Conv1D(filters=n_filters, kernel_size=20, strides=32,
                                          padding="same", activation="relu")(main_branch)
        main_branch = keras.layers.BatchNormalization()(main_branch)
    main_branch = keras.layers.Flatten()(main_branch)
    main_branch = keras.layers.Dropout(0.3)(main_branch)
    main_branch = keras.layers.Dense(512, activation="relu")(main_branch)
    main_branch = keras.layers.Dropout(0.1)(main_branch)
    main_branch = keras.layers.Dense(512, activation="relu")(main_branch)
    main_branch = keras.layers.Dropout(0.1)(main_branch)

    # Decides whether life threatening (VF or VT style) or not (SR and others)
    branch1 = _dense_head(main_branch, 2, 'task_1_output')

    # Classify the input
    branch2 = _dense_head(main_branch, 8, 'task_2_output')

    model = tf.keras.Model(inputs=inputs, outputs=[branch1, branch2])

    return model

# loss_weight -> [0, 1]
# the weight for the second task is calculated by '1 - loss_weight'
def compile_model(model, loss_weight):
    """Compile the two-headed model with weighted per-task losses.

    Args:
        model: Model exposing outputs named ``task_1_output`` and
            ``task_2_output``.
        loss_weight: Weight of the task-1 loss; the task-2 loss is weighted
            by ``1 - loss_weight``.

    Returns:
        The same ``model`` instance, compiled.
    """
    def _task_metrics():
        # Fresh instances per output so no metric state is shared.
        # CategoricalAccuracy compares argmax of prediction and one-hot label;
        # plain Accuracy tests exact equality with the softmax values and so
        # reports ~0. Explicit names keep the history keys stable
        # ('<output>_accuracy', '<output>_precision', '<output>_recall');
        # unnamed Precision()/Recall() get auto-suffixed names such as
        # 'precision_1' on every model after the first.
        return [tf.keras.metrics.CategoricalAccuracy(name='accuracy'),
                tf.keras.metrics.Precision(name='precision'),
                tf.keras.metrics.Recall(name='recall')]

    model.compile(optimizer='adam',
                  loss={'task_1_output': 'binary_crossentropy',
                        'task_2_output': 'categorical_crossentropy'},
                  loss_weights={'task_1_output': loss_weight,
                                'task_2_output': 1 - loss_weight},
                  metrics={'task_1_output': _task_metrics(),
                           'task_2_output': _task_metrics()})

    return model

def _as_model_input(signals):
    """Stack a sequence of 1-D signals into a float32 array with a trailing channel axis."""
    stacked = np.asarray(signals, dtype=np.float32)
    if stacked.ndim == 2:
        # (samples, length) -> (samples, length, 1) to match the model's input shape.
        stacked = stacked[..., np.newaxis]
    return stacked


def fit_batch(gamma_values, epochs=30, batch_size=32, save_models=False, models_dir="./trained_models", verbose=0):
    """Train and evaluate one fresh model per task-1 loss weight.

    Reads the notebook globals ``x_train``, ``y_train_1``, ``y_train_2``,
    ``x_test``, ``y_test_1`` and ``y_test_2``.

    Args:
        gamma_values: Iterable of task-1 loss weights; one model is trained
            for each value.
        epochs: Number of training epochs per model.
        batch_size: Training batch size.
        save_models: If True, save each trained model under ``models_dir``
            as ``"<i>th-model"``, where ``i`` is the position of its gamma.
        models_dir: Directory that receives the saved models.
        verbose: Verbosity passed to ``model.fit``.

    Returns:
        A tuple ``(history, trained_models, test_scores)`` of lists with one
        entry per gamma value, in input order.
    """
    history = list()
    trained_models = list()
    test_scores = list()

    if save_models:
        # makedirs with exist_ok tolerates nested paths and notebook re-runs.
        os.makedirs(models_dir, exist_ok=True)

    # The loaders return plain Python lists; convert once to dense arrays so
    # Keras sees one tensor per input/output rather than a nested structure.
    train_inputs = {'input': _as_model_input(x_train)}
    train_targets = {'task_1_output': np.asarray(y_train_1, dtype=np.float32),
                     'task_2_output': np.asarray(y_train_2, dtype=np.float32)}
    test_inputs = {'input': _as_model_input(x_test)}
    test_targets = {'task_1_output': np.asarray(y_test_1, dtype=np.float32),
                    'task_2_output': np.asarray(y_test_2, dtype=np.float32)}

    print('Starting training on batch of models for gamma values ', gamma_values, '\n\n')

    for i, gamma in enumerate(gamma_values):

        print('Training model for gamma equal to ', gamma)
        model = gen_model()
        model = compile_model(model, gamma)
        start = time.time()
        model_history = model.fit(train_inputs, train_targets,
                                  epochs=epochs, batch_size=batch_size, verbose=verbose)
        print(f'Training time: {time.time() - start}\n')
        history.append(model_history)
        trained_models.append(model)
        if save_models:
            model.save(os.path.join(models_dir, f"{i}th-model"))

        test_score = model.evaluate(test_inputs, test_targets)
        print("Results:", test_score)
        print("Metrics:", model.metrics_names)

        test_scores.append(test_score)

    return history, trained_models, test_scores

In [None]:
def _metric_series(history_dict, key):
    """Return the per-epoch values stored under ``key``, tolerating a numeric suffix.

    Keras appends ``_<n>`` to automatically named metrics when several are
    created in one session, so ``<key>_<n>`` is accepted as a fallback.
    """
    if key in history_dict:
        return history_dict[key]
    for candidate, values in history_dict.items():
        if candidate.startswith(key + '_') and candidate[len(key) + 1:].isdigit():
            return values
    raise KeyError(key)


def plot_multitask_accuracies(gammas, training_history):
    """Plot training accuracy, precision and recall per epoch for each model.

    One figure is shown per entry of ``training_history``. Each figure has one
    subplot per metric, with task 1 in red and task 2 in blue, so the curves
    stay distinguishable.

    Args:
        gammas: Sequence of gamma values, aligned with ``training_history``.
        training_history: Sequence of objects with a ``history`` dict holding
            ``task_<n>_output_<metric>`` series (as returned by ``model.fit``).
    """
    metric_names = ('accuracy', 'precision', 'recall')
    tasks = (('task_1_output', 'r', 'Task 1'), ('task_2_output', 'b', 'Task 2'))

    for counter, history in enumerate(training_history):
        gamma = gammas[counter]
        print(f'\nPlotting Accuracy/Precision/Recall vs Epochs for value of gamma number {gamma}\n')

        fig, axes = plt.subplots(1, len(metric_names), figsize=(15, 4))
        for ax, metric in zip(axes, metric_names):
            for output_name, colour, label in tasks:
                series = _metric_series(history.history, f'{output_name}_{metric}')
                ax.plot(range(len(series)), series, c=colour, label=label)
            ax.set_title(metric.capitalize())
            ax.set_xlabel('Epochs')
            ax.set_ylabel(metric.capitalize())
            ax.legend()
        fig.suptitle(f'gamma = {gamma}')
        fig.tight_layout()
        plt.show()
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)

In [None]:
"""Actually training the model"""
# Each gamma is passed to compile_model as the loss weight of task 1
# (life-threat detection); task 2 (rhythm classification) gets 1 - gamma.
# Higher value => higher importance on life-threat detection, less importance
# on rhythm class identification.
gamma_values = [0.5, 0.6, 0.7, 0.8]

# One model is trained per gamma; the test scores (third return value) are discarded here.
history, trained_models, _ = fit_batch(gamma_values, save_models=True, verbose=1)