In [1]:
import tensorflow as tf
import numpy as np
from pathlib import Path

2023-10-16 15:07:44.595771: I tensorflow/core/platform/cpu_feature_guard.cc:183] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: SSE3 SSE4.1 SSE4.2 AVX, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [2]:
# Directory that is searched for TFRecord shards.
DATA_PATH = Path("/scratch/ajb5d/ecg/tfrecords/")

# Path.glob yields matches in arbitrary (filesystem-dependent) order, so the
# shard lists are sorted to make the file order reproducible between runs.
TRAIN_RECS = sorted(DATA_PATH.glob("train*.tfrecords"))

# NOTE(review): this uses the same "train*" pattern as TRAIN_RECS, so the
# validation set is identical to the training set and validation metrics will
# not measure generalization. Presumably a separate validation shard pattern
# was intended -- confirm the shard naming on disk and update this glob.
VAL_RECS = sorted(DATA_PATH.glob("train*.tfrecords"))

In [7]:
# Number of examples per batch emitted by get_dataset.
BATCH_SIZE = 64

# Feature spec passed to tf.io.parse_single_example for every serialized record.
record_format = {
    # Flat float sequence; _parse_record reshapes it to [5000, 12].
    "ecg/data": tf.io.FixedLenSequenceFeature(
        shape=[], dtype=tf.float32, allow_missing=True
    ),
    # Scalar float used as the regression label.
    "age": tf.io.FixedLenFeature(shape=[], dtype=tf.float32),
    # Parsed from the record but not returned by _parse_record.
    "gender": tf.io.FixedLenFeature(shape=[], dtype=tf.int64),
}

def _parse_record(record):
    """Decode one serialized example into an ``(ecg_data, label)`` pair.

    Args:
        record: A serialized example, parsed against ``record_format``.

    Returns:
        A tuple ``(ecg_data, label)``: the ``'ecg/data'`` feature reshaped to
        ``[5000, 12]`` and the unscaled ``'age'`` feature.
    """
    features = tf.io.parse_single_example(record, record_format)
    signal = tf.reshape(features['ecg/data'], [5000, 12])
    # The label is the raw age. A scaled variant, (age - 50) / 100, was tried
    # by the author and left disabled.
    return signal, features['age']

def drop_na_ages(x, y):
    """Filter predicate: keep an example only when its label has no NaN.

    Args:
        x: Feature component of the dataset element (unused; accepted because
            the dataset yields ``(features, label)`` pairs).
        y: Label tensor to check for NaN values.

    Returns:
        A scalar boolean tensor that is True when no element of ``y`` is NaN.
    """
    has_nan = tf.math.reduce_any(tf.math.is_nan(y))
    # Use the TensorFlow op instead of Python's `not`: `not` on a symbolic
    # tensor only works when AutoGraph rewrites it, and fails if the predicate
    # is ever traced with AutoGraph disabled.
    return tf.math.logical_not(has_nan)

def load_dataset(filenames):
    """Build an unbatched dataset of ``(ecg_data, label)`` pairs from TFRecord files.

    Records are parsed with ``_parse_record`` and examples rejected by
    ``drop_na_ages`` (NaN label) are removed.

    Args:
        filenames: TFRecord file path(s) to read.

    Returns:
        A ``tf.data.Dataset`` of parsed, NaN-filtered examples. Element order
        is not guaranteed to be deterministic.
    """
    options = tf.data.Options()
    # `experimental_deterministic` is deprecated; `deterministic` is the
    # supported spelling of the same option. Relaxing determinism lets the
    # parallel map hand back elements as soon as they are ready.
    options.deterministic = False

    return (
        tf.data.TFRecordDataset(filenames)
        .with_options(options)
        .map(_parse_record, num_parallel_calls=tf.data.AUTOTUNE)
        .filter(drop_na_ages)
    )

def get_dataset(filenames, labeled=True):
    """Build the shuffled, batched input pipeline used for training/evaluation.

    Args:
        filenames: TFRecord file path(s) forwarded to ``load_dataset``.
        labeled: Currently unused; kept so existing callers keep working.

    Returns:
        A ``tf.data.Dataset`` yielding batches of ``BATCH_SIZE`` examples
        (the final batch may be smaller).
    """
    dataset = load_dataset(filenames)
    dataset = dataset.shuffle(2048)
    dataset = dataset.batch(BATCH_SIZE)
    # Prefetch must be the last step so that whole batches are prepared ahead
    # of the consumer; prefetching before batch() only buffers single examples
    # and leaves batch assembly on the critical path.
    dataset = dataset.prefetch(buffer_size=tf.data.AUTOTUNE)
    return dataset

In [8]:
# Build one input pipeline per split from its list of TFRecord shards.
train_dataset = get_dataset(filenames=TRAIN_RECS)
val_dataset = get_dataset(filenames=VAL_RECS)

In [14]:
def conv_block(x, n, k, mp, groups = 12):
    """Apply grouped Conv1D -> BatchNormalization -> ReLU, then optional max pooling.

    Args:
        x: Input Keras tensor.
        n: Filters per group; the convolution has ``n * groups`` filters in total.
        k: Convolution kernel size.
        mp: Pool size for ``MaxPooling1D``; a falsy value skips pooling.
        groups: Number of convolution groups (default 12).

    Returns:
        The Keras tensor produced by the last layer of the block.
    """
    stages = [
        tf.keras.layers.Conv1D(n * groups, k, padding='same', groups=groups),
        tf.keras.layers.BatchNormalization(),
        tf.keras.layers.Activation('relu'),
    ]
    if mp:
        stages.append(tf.keras.layers.MaxPooling1D(mp))

    # Thread the tensor through the layers in order.
    for layer in stages:
        x = layer(x)
    return x

# NOTE(review): kernel_initializer is assigned here but is not passed to any
# layer in the visible cells -- confirm whether it was meant to be used.
kernel_initializer = 'he_normal'
# Input shape matches the [5000, 12] reshape performed in _parse_record.
input_layer = tf.keras.layers.Input(shape=(5000, 12))
x = input_layer

# Stack of grouped-convolution blocks. Arguments are (filters per group,
# kernel size, pool size); with the default groups=12 each Conv1D has
# 12 * n filters.
x = conv_block(x, 16, 7, 2)
x = conv_block(x, 16, 5, 4)
x = conv_block(x, 32, 5, 2)
x = conv_block(x, 32, 5, 4)
x = conv_block(x, 64, 5, 2)
x = conv_block(x, 64, 3, 2)
x = conv_block(x, 64, 3, 2)
x = conv_block(x, 64, 3, 2)

# Ungrouped convolution declared with data_format='channels_first', unlike the
# blocks above, which use the default layout.
# NOTE(review): with channels_first this layer treats axis 1 (the pooled time
# axis) as channels and convolves along axis 2 (the feature axis). The
# BatchNormalization and MaxPooling1D below keep their default settings, which
# presumably assume a channels-last layout -- confirm this mix is intended and
# check the resulting shapes in model.summary().
x = tf.keras.layers.Conv1D(128, 12, padding = 'same', data_format = 'channels_first')(x)
x = tf.keras.layers.BatchNormalization()(x)
x = tf.keras.layers.Activation('relu')(x)
x = tf.keras.layers.MaxPooling1D()(x)

# Regression head: two Dense -> BatchNormalization -> ReLU -> Dropout(0.5)
# stages followed by a single linear output unit (the predicted age).
x = tf.keras.layers.Flatten()(x)
x = tf.keras.layers.Dense(128)(x)
x = tf.keras.layers.BatchNormalization()(x)
x = tf.keras.layers.Activation('relu')(x)
x = tf.keras.layers.Dropout(0.5)(x)
x = tf.keras.layers.Dense(64)(x)
x = tf.keras.layers.BatchNormalization()(x)
x = tf.keras.layers.Activation('relu')(x)
x = tf.keras.layers.Dropout(0.5)(x)
x = tf.keras.layers.Dense(1)(x)

model = tf.keras.models.Model(input_layer, x)

# Train with mean squared error. The 'mse' metric repeats the loss value;
# 'mae' additionally reports the mean absolute error.
model.compile(
    optimizer='adam',
    loss=tf.keras.losses.MeanSquaredError(),
    metrics=['mse', 'mae']
)

In [None]:
# Abort training as soon as the loss becomes NaN.
callbacks = [
    tf.keras.callbacks.TerminateOnNaN(),
]

# Keep the returned History object so the per-epoch loss and metric values
# remain available after training instead of being discarded.
history = model.fit(
    train_dataset,
    epochs=3,
    validation_data=val_dataset,
    callbacks=callbacks,
)

Epoch 1/3
   2922/Unknown - 333s 114ms/step - loss: 201.5732 - mse: 201.5732 - mae: 11.2941

In [15]:
# Print the layer-by-layer architecture (output shapes and parameter counts).
model.summary()

Model: "model_1"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_5 (InputLayer)        [(None, 5000, 12)]        0         
                                                                 
 conv1d_35 (Conv1D)          (None, 5000, 192)         1536      
                                                                 
 batch_normalization_36 (Ba  (None, 5000, 192)         768       
 tchNormalization)                                               
                                                                 
 activation_35 (Activation)  (None, 5000, 192)         0         
                                                                 
 max_pooling1d_32 (MaxPooli  (None, 2500, 192)         0         
 ng1D)                                                           
                                                                 
 conv1d_36 (Conv1D)          (None, 2500, 192)         1555