In [1]:
import tensorflow as tf
from tensorflow import keras
import numpy as np
from pathlib import Path

2023-10-26 09:46:43.120602: I tensorflow/core/util/port.cc:110] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2023-10-26 09:46:43.710169: I tensorflow/core/platform/cpu_feature_guard.cc:183] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: SSE3 SSE4.1 SSE4.2 AVX, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [2]:
DATA_PATH = Path("/scratch/ajb5d/ecg/tfrecords/")
TRAIN_RECS = list(DATA_PATH.glob("train*.tfrecords"))
VAL_RECS = list(DATA_PATH.glob("val*.tfrecords"))

BATCH_SIZE = 64

record_format = {
    'ecg/data': tf.io.FixedLenSequenceFeature([], tf.float32, allow_missing=True),
    'age': tf.io.FixedLenFeature([], tf.float32),
    'gender': tf.io.FixedLenFeature([], tf.int64),
    'hospital_expire_flag': tf.io.FixedLenFeature([], tf.float32),
}

def _parse_record(record):
    example = tf.io.parse_single_example(record, record_format)
    ecg_data = tf.reshape(example['ecg/data'], [5000,12])
    label = example['hospital_expire_flag']
    return (ecg_data, example['age'], example['gender']), label

@tf.function
def drop_na_labels(x,y):
    return not tf.math.reduce_any(tf.math.is_nan(y))

@tf.function
def drop_na_age(x,y):
    return not tf.math.reduce_any(tf.math.is_nan(x[1]))

def load_dataset(filenames):
    ignore_order = tf.data.Options()
    ignore_order.experimental_deterministic = False
    dataset = tf.data.TFRecordDataset(filenames)
    dataset = dataset.with_options(ignore_order)
    dataset = dataset.map(_parse_record, num_parallel_calls=tf.data.AUTOTUNE)
    dataset = dataset.filter(drop_na_labels)
    dataset = dataset.filter(drop_na_age)
    return dataset

def get_dataset(filenames, labeled=True):
    dataset = load_dataset(filenames)
    dataset = dataset.shuffle(2048)
    dataset = dataset.prefetch(buffer_size=tf.data.AUTOTUNE)
    dataset = dataset.batch(BATCH_SIZE)
    return dataset

train_dataset = get_dataset(TRAIN_RECS)
val_dataset = get_dataset(VAL_RECS)

2023-10-26 09:46:48.071387: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1636] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 31042 MB memory:  -> device: 0, name: Tesla V100-SXM2-32GB, pci bus id: 0000:18:00.0, compute capability: 7.0


In [3]:
from datetime import datetime
import os

def make_checkpoint_dir(data_path, label):
    current_datetime = datetime.now()
    formatted_datetime = current_datetime.strftime("%Y-%m-%d_%H-%M-%S")
    output_dir = f"{label}-{formatted_datetime}"
    output_path = f"{data_path}/{output_dir}"
    
    if not os.path.exists(output_path):
        os.makedirs(output_path)
    
    return output_path

In [4]:
model = keras.models.load_model("data/models/resnet-age.keras")

In [5]:
for i in range(len(model.layers)):
    model.layers[i].trainable = False

In [6]:
model.layers[-2]

<keras.src.layers.reshaping.flatten.Flatten at 0x7f6b304e01c0>

In [7]:
x = model.layers[-2].output
age_input = tf.keras.layers.Input(shape=(1,), name="age_input")
gender_input = tf.keras.layers.Input(shape=(1,), name="gender_input")
x = keras.layers.Add(name="merge")([x, age_input, gender_input])
x = tf.keras.layers.Dense(128, name="tl_dense_3")(x)
x = tf.keras.layers.Dense(1, activation='sigmoid', name = "tl_dense_4")(x)

new_model = keras.Model([model.input, age_input, gender_input],outputs=x)

new_model.compile(
    optimizer='adam',
    loss=tf.keras.losses.BinaryCrossentropy(),
    metrics=['accuracy', tf.keras.metrics.AUC()]
)

In [8]:
callbacks = [
    tf.keras.callbacks.TerminateOnNaN(),
    tf.keras.callbacks.ReduceLROnPlateau(),
    tf.keras.callbacks.ModelCheckpoint(make_checkpoint_dir("data/models", "resnet-tl-mort"))
]

new_model.fit(train_dataset, epochs=3, validation_data=val_dataset, callbacks=callbacks)

Epoch 1/3


2023-10-26 09:47:22.935932: I tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:432] Loaded cuDNN version 8904
2023-10-26 09:47:24.894616: I tensorflow/compiler/xla/service/service.cc:168] XLA service 0x7f6aadaf7c00 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
2023-10-26 09:47:24.894648: I tensorflow/compiler/xla/service/service.cc:176]   StreamExecutor device (0): Tesla V100-SXM2-32GB, Compute Capability 7.0
2023-10-26 09:47:24.965224: I tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc:255] disabling MLIR crash reproducer, set env var `MLIR_CRASH_REPRODUCER_DIRECTORY` to enable.
2023-10-26 09:47:25.434890: I ./tensorflow/compiler/jit/device_compiler.h:186] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.


    453/Unknown - 92s 184ms/step - loss: 114.4933 - accuracy: 0.7850 - auc: 0.5020

2023-10-26 09:49:11.963659: I tensorflow/core/framework/local_rendezvous.cc:405] Local rendezvous recv item cancelled. Key hash: 14487482424551181521


INFO:tensorflow:Assets written to: data/models/resnet-tl-mort-2023-10-26_09-47-16/assets


INFO:tensorflow:Assets written to: data/models/resnet-tl-mort-2023-10-26_09-47-16/assets


Epoch 2/3


INFO:tensorflow:Assets written to: data/models/resnet-tl-mort-2023-10-26_09-47-16/assets


Epoch 3/3


INFO:tensorflow:Assets written to: data/models/resnet-tl-mort-2023-10-26_09-47-16/assets




<keras.src.callbacks.History at 0x7f6b30321d20>