In [1]:
import tensorflow as tf
from tensorflow import keras
import numpy as np
from pathlib import Path

2023-10-19 08:06:36.961179: I tensorflow/core/platform/cpu_feature_guard.cc:183] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: SSE3 SSE4.1 SSE4.2 AVX, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [2]:
DATA_PATH = Path("/scratch/ajb5d/ecg/tfrecords/")

# Sort the shard lists so the input order is reproducible across runs;
# Path.glob yields files in arbitrary filesystem order.
TRAIN_RECS = sorted(DATA_PATH.glob("train*.tfrecords"))
# Validation must not reuse the training shards (globbing "train*" here would
# validate the model on its own training data).
# NOTE(review): assumes validation shards are named "val*.tfrecords" (this
# pattern also matches "validation*") — confirm against the export step.
VAL_RECS = sorted(DATA_PATH.glob("val*.tfrecords"))

# Fail fast on a wrong path or naming pattern instead of surfacing it later
# as an empty or failing tf.data pipeline.
assert TRAIN_RECS, f"no training shards found in {DATA_PATH}"
assert VAL_RECS, f"no validation shards found in {DATA_PATH}"
assert not set(TRAIN_RECS) & set(VAL_RECS), "train and validation shards overlap"

BATCH_SIZE = 64

# Feature spec for one serialized example. 'ecg/data' is a flat float list
# (reshaped to [5000, 12] in _parse_record); 'age' is parsed but not used as
# the label here; 'gender' is the training target.
record_format = {
    'ecg/data': tf.io.FixedLenSequenceFeature([], tf.float32, allow_missing=True),
    'age': tf.io.FixedLenFeature([], tf.float32),
    'gender': tf.io.FixedLenFeature([], tf.int64),
}

def _parse_record(record):
    """Decode one serialized example into an ``(ecg, label)`` pair.

    The flat ``ecg/data`` float list is reshaped to ``[5000, 12]``
    (presumably samples x leads — confirm against the export step) and the
    ``gender`` field is returned as the label. ``age`` is parsed because it
    is in ``record_format`` but is not returned.
    """
    parsed = tf.io.parse_single_example(record, record_format)
    return tf.reshape(parsed['ecg/data'], [5000, 12]), parsed['gender']

def load_dataset(filenames):
    """Read TFRecord shards and decode them into ``(ecg_data, label)`` pairs.

    Element order is deliberately non-deterministic: with
    ``deterministic = False`` the parallel ``map`` may emit elements as soon
    as they are ready, trading reproducible ordering for throughput.

    Args:
        filenames: TFRecord file paths to read.

    Returns:
        A ``tf.data.Dataset`` whose elements are produced by ``_parse_record``.
    """
    options = tf.data.Options()
    # `experimental_deterministic` is deprecated; `deterministic` is the
    # supported spelling of the same option.
    options.deterministic = False
    dataset = tf.data.TFRecordDataset(filenames).with_options(options)
    return dataset.map(_parse_record, num_parallel_calls=tf.data.AUTOTUNE)

def get_dataset(filenames, labeled=True, shuffle=True):
    """Build a batched, prefetched input pipeline from TFRecord shards.

    Args:
        filenames: TFRecord file paths, passed to ``load_dataset``.
        labeled: currently unused; kept so existing calls keep working.
        shuffle: shuffle examples with a 2048-element buffer before batching.
            Pass ``False`` for evaluation, where ordering does not matter.

    Returns:
        A ``tf.data.Dataset`` of batches of ``BATCH_SIZE`` examples.
    """
    dataset = load_dataset(filenames)
    if shuffle:
        dataset = dataset.shuffle(2048)
    dataset = dataset.batch(BATCH_SIZE)
    # Prefetch last so whole batches are prepared while the model trains on
    # the previous one; prefetching before `batch` only buffers single examples.
    return dataset.prefetch(buffer_size=tf.data.AUTOTUNE)

# Build the training and validation pipelines, in that order.
train_dataset, val_dataset = [get_dataset(recs) for recs in (TRAIN_RECS, VAL_RECS)]

2023-10-19 08:06:43.141241: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1636] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 46583 MB memory:  -> device: 0, name: NVIDIA RTX A6000, pci bus id: 0000:01:00.0, compute capability: 8.6


In [17]:
# Base network to transfer from (presumably the CNN trained for age
# prediction, judging by the file name — confirm).
BASE_MODEL_PATH = "data/models/cnn-age.keras"
model = keras.models.load_model(BASE_MODEL_PATH)

In [18]:
# Index of the base model's first Flatten layer; the new head is attached to
# its output. Match on the layer class rather than on the string repr of its
# type, and fail loudly if there is none (np.argmax over an all-False list
# would silently return 0, i.e. the input layer).
flatten_index = next(
    (i for i, layer in enumerate(model.layers) if isinstance(layer, keras.layers.Flatten)),
    None,
)
if flatten_index is None:
    raise ValueError("Base model has no Flatten layer to attach the new head to.")

In [19]:
# Freeze every layer of the base model so that only the new head trains.
for layer in model.layers:
    layer.trainable = False

In [20]:
def _dense_block(inputs, units, index):
    """Apply one Dense -> BatchNorm -> ReLU -> Dropout(0.5) stage of the new head.

    Layers are named ``tl_<kind>_<index>`` so they are easy to tell apart
    from the frozen base-model layers.
    """
    x = keras.layers.Dense(units, name=f"tl_dense_{index}")(inputs)
    x = keras.layers.BatchNormalization(name=f"tl_bn_{index}")(x)
    x = keras.layers.Activation('relu', name=f"tl_act_{index}")(x)
    return keras.layers.Dropout(0.5, name=f"tl_do_{index}")(x)


x = model.layers[flatten_index].output
for index, units in enumerate((128, 64), start=1):
    x = _dense_block(x, units, index)
# Single sigmoid unit for the binary target.
# NOTE(review): assumes 'gender' is encoded as 0/1 — confirm.
x = keras.layers.Dense(1, activation='sigmoid', name="tl_dense_3")(x)

new_model = keras.Model(inputs=model.input, outputs=x)

new_model.compile(
    optimizer='adam',
    loss=keras.losses.BinaryCrossentropy(),
    # A fixed metric name keeps the logged key as "auc" rather than the
    # auto-uniquified "auc_1", "auc_2", ... on every re-run of this cell.
    metrics=['accuracy', keras.metrics.AUC(name="auc")],
)

In [15]:
from datetime import datetime
import os

def make_checkpoint_dir(data_path, label):
    """Create and return a timestamped directory for model checkpoints.

    The directory is ``<data_path>/<label>-YYYY-mm-dd_HH-MM-SS``. Any missing
    parent directories are created as well.

    Args:
        data_path: parent directory (str or path-like).
        label: prefix for the directory name, e.g. the experiment name.

    Returns:
        The directory path as a string.
    """
    timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    output_path = f"{data_path}/{label}-{timestamp}"
    # exist_ok=True replaces the exists-then-makedirs check, which could race
    # if the directory appeared in between.
    os.makedirs(output_path, exist_ok=True)
    return output_path

In [21]:
EPOCHS = 10

callbacks = [
    tf.keras.callbacks.TerminateOnNaN(),
    # The default patience (10) equals EPOCHS, so the learning rate could
    # never be reduced within this run; use a shorter patience.
    tf.keras.callbacks.ReduceLROnPlateau(monitor="val_loss", patience=3),
    # Keep the checkpoint with the best validation loss instead of
    # overwriting it with whatever the last epoch produced.
    tf.keras.callbacks.ModelCheckpoint(
        make_checkpoint_dir("data/models", "cnn-tl-gender"),
        monitor="val_loss",
        save_best_only=True,
    ),
]

# Keep the History object so the per-epoch metrics can be inspected later.
history = new_model.fit(
    train_dataset,
    epochs=EPOCHS,
    validation_data=val_dataset,
    callbacks=callbacks,
)

Epoch 1/10
   1492/Unknown - 34s 21ms/step - loss: 0.7431 - accuracy: 0.5189 - auc_2: 0.5198

KeyboardInterrupt: 

In [22]:
# Print the layer-by-layer architecture of the combined model: the base
# network's layers followed by the tl_* head layers.
new_model.summary()

Model: "model_2"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_1 (InputLayer)        [(None, 5000, 12)]        0         
                                                                 
 conv1d (Conv1D)             (None, 5000, 192)         1536      
                                                                 
 batch_normalization (Batch  (None, 5000, 192)         768       
 Normalization)                                                  
                                                                 
 activation (Activation)     (None, 5000, 192)         0         
                                                                 
 max_pooling1d (MaxPooling1  (None, 2500, 192)         0         
 D)                                                              
                                                                 
 conv1d_1 (Conv1D)           (None, 2500, 192)         1555