In [5]:
import tensorflow as tf
import numpy as np
from pathlib import Path

In [6]:
# Directory holding the TFRecord shards for this experiment.
DATA_PATH = Path("/scratch/ajb5d/ecg/tfrecords/")

# Path.glob yields entries in arbitrary filesystem order; sorting keeps the
# shard lists identical from run to run.
TRAIN_RECS = sorted(DATA_PATH.glob("train*.tfrecords"))
VAL_RECS = sorted(DATA_PATH.glob("val*.tfrecords"))

In [7]:
# Number of examples per batch produced by get_dataset.
BATCH_SIZE = 64

# Feature spec handed to tf.io.parse_single_example for every serialized example.
record_format = {
    # Variable-length float list; _parse_record reshapes it to [5000, 12].
    "ecg/data": tf.io.FixedLenSequenceFeature(
        shape=[], dtype=tf.float32, allow_missing=True
    ),
    # Scalar float features.
    "age": tf.io.FixedLenFeature(shape=[], dtype=tf.float32),
    "sodium": tf.io.FixedLenFeature(shape=[], dtype=tf.float32),
}

def _parse_record(record):
    """Decode one serialized example into an ``(ecg, label)`` pair.

    Args:
        record: A scalar string tensor holding a serialized ``tf.train.Example``.

    Returns:
        A tuple ``(ecg, label)`` where ``ecg`` is the ``'ecg/data'`` feature
        reshaped to ``[5000, 12]`` and ``label`` is the scalar ``'sodium'``
        feature.
    """
    parsed = tf.io.parse_single_example(record, record_format)
    # The ECG is stored as a flat float list; restore the 2-D layout.
    signal = tf.reshape(parsed['ecg/data'], [5000, 12])
    return signal, parsed['sodium']

def drop_na_ages(x,y):
    """Filter predicate: keep an example only if its label contains no NaN.

    Despite the name, the check is applied to whatever label ``y`` carries
    (``_parse_record`` currently returns the ``'sodium'`` feature), not to age.

    Args:
        x: The example's features (unused).
        y: The example's label tensor.

    Returns:
        A boolean tensor that is True when ``y`` has no NaN values.
    """
    # Use the TensorFlow op rather than Python `not`: Python negation of a
    # symbolic tensor only works when AutoGraph rewrites it.
    return tf.math.logical_not(tf.math.reduce_any(tf.math.is_nan(y)))

def load_dataset(filenames):
    """Build an unbatched dataset of parsed examples from TFRecord files.

    Records are read from several files in parallel and ordering guarantees
    are relaxed, so elements may arrive in a different order on every pass.
    Examples whose label contains NaN are dropped.

    Args:
        filenames: TFRecord file paths to read.

    Returns:
        A ``tf.data.Dataset`` yielding ``(ecg, label)`` pairs as produced by
        ``_parse_record``.
    """
    options = tf.data.Options()
    # `deterministic` supersedes the deprecated `experimental_deterministic`.
    options.deterministic = False

    # Without num_parallel_reads the files are consumed one after another,
    # which leaves little for the relaxed ordering to speed up.
    dataset = tf.data.TFRecordDataset(
        filenames, num_parallel_reads=tf.data.AUTOTUNE
    )
    dataset = dataset.with_options(options)
    dataset = dataset.map(_parse_record, num_parallel_calls=tf.data.AUTOTUNE)
    dataset = dataset.filter(drop_na_ages)
    return dataset

def get_dataset(filenames, labeled=True):
    """Build the shuffled, batched, prefetched input pipeline.

    Args:
        filenames: TFRecord file paths passed on to ``load_dataset``.
        labeled: Currently unused; kept so existing callers keep working.

    Returns:
        A ``tf.data.Dataset`` yielding batches of ``BATCH_SIZE`` examples.
    """
    dataset = load_dataset(filenames)
    dataset = dataset.shuffle(2048)
    dataset = dataset.batch(BATCH_SIZE)
    # Prefetch goes last so that whole batches, not single examples, are
    # prepared in the background while the model is busy with the current one.
    dataset = dataset.prefetch(buffer_size=tf.data.AUTOTUNE)
    return dataset

In [8]:
# Build the input pipelines for the training and validation record files.
train_dataset, val_dataset = get_dataset(TRAIN_RECS), get_dataset(VAL_RECS)

In [9]:
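# Mixed precision is currently disabled; uncomment the next line to enable the 'mixed_float16' policy.
# tf.keras.mixed_precision.set_global_policy('mixed_float16')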

In [10]:
def conv_block(x, n, k, mp, groups = 12):
    """Apply grouped Conv1D -> BatchNormalization -> ReLU -> optional max-pool.

    Args:
        x: Input Keras tensor.
        n: Number of filters per group; the convolution has ``n * groups``
            filters in total.
        k: Convolution kernel size.
        mp: Max-pooling size; a falsy value (``0`` / ``None``) skips pooling.
        groups: Number of convolution groups.

    Returns:
        The output Keras tensor of the block.
    """
    features = tf.keras.layers.Conv1D(
        filters=n * groups,
        kernel_size=k,
        padding='same',
        groups=groups,
    )(x)
    features = tf.keras.layers.BatchNormalization()(features)
    features = tf.keras.layers.Activation('relu')(features)

    if not mp:
        return features

    return tf.keras.layers.MaxPooling1D(pool_size=mp)(features)

# Model input: one example of shape (5000, 12), matching the reshape in
# _parse_record.
input_layer = tf.keras.layers.Input(shape=(5000, 12))

# Grouped convolution stack, one row per conv_block call:
# (filters per group, kernel size, max-pool size).
conv_stack = [
    (16, 7, 2),
    (16, 5, 4),
    (32, 5, 2),
    (32, 5, 4),
    (64, 5, 2),
    (64, 3, 2),
    (64, 3, 2),
    (64, 3, 2),
]

x = input_layer
for filters_per_group, kernel_size, pool_size in conv_stack:
    x = conv_block(x, filters_per_group, kernel_size, pool_size)

# NOTE(review): data_format='channels_first' makes this convolution treat
# axis 1 as channels and slide along the last axis, while the
# BatchNormalization (default axis=-1) and MaxPooling1D (default
# channels_last) that follow use the channels-last convention. Confirm this
# mix is intended.
x = tf.keras.layers.Conv1D(
    128, 12, padding='same', data_format='channels_first'
)(x)
x = tf.keras.layers.BatchNormalization()(x)
x = tf.keras.layers.Activation('relu')(x)
x = tf.keras.layers.MaxPooling1D()(x)

# Regression head: two Dense -> BatchNorm -> ReLU -> Dropout stages followed
# by a single linear output unit.
x = tf.keras.layers.Flatten()(x)
for units in (128, 64):
    x = tf.keras.layers.Dense(units)(x)
    x = tf.keras.layers.BatchNormalization()(x)
    x = tf.keras.layers.Activation('relu')(x)
    x = tf.keras.layers.Dropout(0.5)(x)
x = tf.keras.layers.Dense(1)(x)

model = tf.keras.models.Model(input_layer, x)

# Regression setup: Adam optimizer, mean-squared-error loss, and MSE / MAE
# reported as metrics.
mse_loss = tf.keras.losses.MeanSquaredError()
model.compile(optimizer='adam', loss=mse_loss, metrics=['mse', 'mae'])

In [11]:
from datetime import datetime
import os

def make_checkpoint_dir(data_path, label):
    """Create and return a timestamped run directory under ``data_path``.

    The directory is named ``{label}-{YYYY-MM-DD_HH-MM-SS}`` using the current
    local time. Missing parent directories are created as well.

    Args:
        data_path: Base directory in which the run directory is created.
        label: Prefix for the directory name (for example the model name).

    Returns:
        The run directory path as a string:
        ``"{data_path}/{label}-{timestamp}"``.
    """
    timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    output_path = f"{data_path}/{label}-{timestamp}"

    # exist_ok=True replaces the separate exists() check, which could race
    # with another process creating the same directory in between.
    os.makedirs(output_path, exist_ok=True)

    return output_path

# NOTE(review): the run is named "potassium", but _parse_record returns the
# 'sodium' feature as the label. Confirm which of the two is intended.
model_name = "cnn-potassium"

# Timestamped directory that receives this run's checkpoints.
output_path = make_checkpoint_dir(data_path="data/models", label=model_name)

print(f"Model: {model_name} Run Path: {output_path}")

In [None]:
callbacks = [
    # Abort training as soon as the loss becomes NaN.
    tf.keras.callbacks.TerminateOnNaN(),
    # Lower the learning rate when the monitored metric stops improving
    # (Keras defaults are used).
    tf.keras.callbacks.ReduceLROnPlateau(),
    # Keep only the best model seen so far. The keyword is `save_best_only`;
    # the previous `save_best=True` is not a ModelCheckpoint argument, and the
    # model was being written after every epoch.
    tf.keras.callbacks.ModelCheckpoint(output_path, save_best_only=True),
    # Per-epoch metrics log.
    tf.keras.callbacks.CSVLogger(f"data/models/{model_name}-history.csv")
]

history = model.fit(train_dataset, epochs=50, validation_data=val_dataset, callbacks=callbacks)

Epoch 1/10


2023-10-31 09:19:02.179085: I tensorflow/compiler/xla/service/service.cc:168] XLA service 0x56180db14dc0 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
2023-10-31 09:19:02.179144: I tensorflow/compiler/xla/service/service.cc:176]   StreamExecutor device (0): NVIDIA A100-SXM4-40GB, Compute Capability 8.0
2023-10-31 09:19:02.306179: I tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:432] Loaded cuDNN version 8904
2023-10-31 09:19:03.550261: I ./tensorflow/compiler/jit/device_compiler.h:186] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.
2023-10-31 09:19:05.184999: I tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:625] TensorFloat-32 will be used for the matrix multiplication. This will only be logged once.
2023-10-31 09:19:05.367601: I tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc:255] disabling MLIR crash reproducer, set env var `MLIR_CRASH_REPRODUCER_DIRECTORY` to en

   4212/Unknown - 368s 83ms/step - loss: 2837.2769 - mse: 2837.2769 - mae: 35.5184

2023-10-31 09:25:05.414175: I tensorflow/core/framework/local_rendezvous.cc:405] Local rendezvous recv item cancelled. Key hash: 18421638274924574341


INFO:tensorflow:Assets written to: data/models/cnn-potassium-2023-10-31_09-18-57/assets


INFO:tensorflow:Assets written to: data/models/cnn-potassium-2023-10-31_09-18-57/assets


Epoch 2/10


INFO:tensorflow:Assets written to: data/models/cnn-potassium-2023-10-31_09-18-57/assets


Epoch 3/10


INFO:tensorflow:Assets written to: data/models/cnn-potassium-2023-10-31_09-18-57/assets


Epoch 4/10


INFO:tensorflow:Assets written to: data/models/cnn-potassium-2023-10-31_09-18-57/assets


Epoch 5/10


INFO:tensorflow:Assets written to: data/models/cnn-potassium-2023-10-31_09-18-57/assets


Epoch 6/10


INFO:tensorflow:Assets written to: data/models/cnn-potassium-2023-10-31_09-18-57/assets


Epoch 7/10


INFO:tensorflow:Assets written to: data/models/cnn-potassium-2023-10-31_09-18-57/assets


Epoch 8/10


INFO:tensorflow:Assets written to: data/models/cnn-potassium-2023-10-31_09-18-57/assets


Epoch 9/10


INFO:tensorflow:Assets written to: data/models/cnn-potassium-2023-10-31_09-18-57/assets


Epoch 10/10

In [None]:
# Save the current state of `model` to a .keras file named after the model.
final_model_path = f"data/models/{model_name}.keras"
model.save(final_model_path)