In [13]:
import pandas as pd
import tensorflow as tf
from pathlib import Path
import numpy as np
from json import dump
from pathlib import Path

In [14]:
# Directory that holds the serialized ECG records.
DATA_PATH = Path("/scratch/ajb5d/ecg/tfrecords/")

# Every `*.tfrecords` file found directly inside DATA_PATH (no recursion).
ALL_RECS = [record_file for record_file in DATA_PATH.glob("*.tfrecords")]

In [15]:
# Number of records grouped into one batch by the input pipeline.
BATCH_SIZE = 512

# Feature specification handed to `tf.io.parse_single_example`.
record_format = {
    # Variable-length flat float32 sequence; reshaped by the parser afterwards.
    'ecg/data': tf.io.FixedLenSequenceFeature(
        shape=[],
        dtype=tf.float32,
        allow_missing=True,
    ),
    # Scalar int64 stored under 'file_name'.
    'file_name': tf.io.FixedLenFeature(shape=[], dtype=tf.int64),
    # Scalar float32; it has no default, so every record must carry it.
    'hospital_expire_flag': tf.io.FixedLenFeature(shape=[], dtype=tf.float32),
}

def _parse_record(record):
    """Decode one serialized example into an `(ecg, file_name)` pair.

    Args:
        record: a single serialized example, parsed against `record_format`.

    Returns:
        A 2-tuple: the 'ecg/data' values reshaped to `[5000, 12]`, and the
        'file_name' feature exactly as parsed.
    """
    features = tf.io.parse_single_example(record, record_format)
    # NOTE(review): presumably 5000 samples x 12 leads -- confirm with the
    # code that wrote the records. The reshape fails on any other length.
    signal = tf.reshape(features['ecg/data'], [5000, 12])
    study_id = features['file_name']
    return signal, study_id

def load_dataset(filenames):
    """Build an unbatched dataset of parsed records from TFRecord files.

    Element order is explicitly not guaranteed: determinism is switched off so
    that parallel file reads and parallel parsing may emit records as soon as
    they are ready.

    Args:
        filenames: TFRecord file path(s) accepted by `tf.data.TFRecordDataset`.

    Returns:
        A `tf.data.Dataset` whose elements are whatever `_parse_record` returns.
    """
    options = tf.data.Options()
    # `experimental_deterministic` is the deprecated spelling of this option.
    options.deterministic = False

    # Read several files concurrently; without `num_parallel_reads` the files
    # are consumed strictly one after another.
    dataset = tf.data.TFRecordDataset(
        filenames,
        num_parallel_reads=tf.data.AUTOTUNE,
    )
    dataset = dataset.with_options(options)
    dataset = dataset.map(_parse_record, num_parallel_calls=tf.data.AUTOTUNE)
    return dataset

def get_dataset(filenames, labeled=True):
    """Return the batched, prefetched dataset for the given TFRecord files.

    Args:
        filenames: TFRecord file path(s), forwarded to `load_dataset`.
        labeled: currently unused; kept so existing callers keep working.

    Returns:
        A `tf.data.Dataset` yielding batches of up to `BATCH_SIZE` elements.
    """
    dataset = load_dataset(filenames)
    dataset = dataset.batch(BATCH_SIZE)
    # Prefetch goes last so complete batches are prepared in the background
    # while the consumer works on the previous one. Prefetching before
    # `batch` only buffers individual records.
    dataset = dataset.prefetch(buffer_size=tf.data.AUTOTUNE)
    return dataset

In [16]:
# Batched dataset over every record file; iterated once per model below.
all_recs = get_dataset(filenames=ALL_RECS)

In [19]:
MODELS = [
    'resnet-age', 'cnn-age',
    'resnet-potassium', 'cnn-potassium',
    'cnn-gender', 'resnet-gender',
    'cnn-sodium', 'resnet-sodium',
]


def _scores_are_current(model_path, scores_path):
    """Return True when `scores_path` exists and was modified after `model_path`."""
    return (
        scores_path.exists()
        and model_path.stat().st_mtime < scores_path.stat().st_mtime
    )


def _score_dataset(model, dataset):
    """Run `model` over every batch of `dataset`, collecting one score per record.

    Args:
        model: object exposing `predict_on_batch(inputs)`.
        dataset: iterable of `(inputs, ids)` batches, where `ids` has `.numpy()`.

    Returns:
        dict mapping each id (Python int) to a float: the first element of the
        prediction for that record. A repeated id keeps the last value seen.
    """
    scores = {}
    for inputs, ids in dataset:
        # One forward pass per batch. Keras documents `Model.predict` as not
        # intended for use inside a loop that feeds it one batch at a time.
        preds = model.predict_on_batch(inputs)
        # Convert the id tensor once per batch instead of once per record.
        for record_id, pred in zip(ids.numpy().tolist(), preds):
            scores[record_id] = float(pred[0])
    return scores


for model_name in MODELS:
    input_path = Path(f"data/models/{model_name}.keras")
    output_path = Path(f"data/{model_name}-scores.json")

    print(model_name)
    if _scores_are_current(input_path, output_path):
        print("skipping -- model is older than results")
        continue

    model = tf.keras.models.load_model(str(input_path))
    scores = _score_dataset(model, all_recs)

    # Write to a sibling temp file and rename it into place. An interrupted
    # run then cannot leave a truncated JSON file that is newer than the
    # model, which the freshness check above would skip on every later run.
    tmp_path = output_path.with_name(output_path.name + ".tmp")
    with open(tmp_path, "w") as fh:
        dump(scores, fh)
    tmp_path.replace(output_path)

resnet-age
skipping -- model is older than results
cnn-age
skipping -- model is older than results
resnet-potassium
skipping -- model is older than results
cnn-potassium


2023-11-29 15:49:11.541725: I tensorflow/compiler/xla/service/service.cc:168] XLA service 0x7f71cf81a710 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
2023-11-29 15:49:11.541771: I tensorflow/compiler/xla/service/service.cc:176]   StreamExecutor device (0): NVIDIA A100-SXM4-80GB, Compute Capability 8.0
2023-11-29 15:49:13.239398: I ./tensorflow/compiler/jit/device_compiler.h:186] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.

KeyboardInterrupt

