In [None]:
import numpy as np
import os
import pickle
import h5py
import pandas as pd

import tensorflow as tf
from tensorflow.keras.models import load_model

from hfnet import HFNet

In [None]:
# Pin TensorFlow to a single GPU and let its memory allocation grow on demand.
# CUDA_VISIBLE_DEVICES only takes effect if it is set before TensorFlow
# initialises the CUDA runtime, i.e. before any earlier cell touches the GPU.
GPU_INDEX = 0
os.environ["CUDA_VISIBLE_DEVICES"] = str(GPU_INDEX)

gpus = tf.config.list_physical_devices('GPU')
if gpus:
    try:
        # With CUDA_VISIBLE_DEVICES masking, the selected GPU is the first one listed.
        tf.config.set_visible_devices(gpus[0], 'GPU')
        tf.config.experimental.set_memory_growth(gpus[0], True)
    except RuntimeError as err:
        # Visible devices / memory growth cannot be changed once the GPU runtime
        # has been initialised (e.g. when this cell is re-run); keep the
        # existing configuration instead of crashing the notebook.
        print(f"GPU configuration unchanged: {err}")
else:
    # Previously this raised IndexError on gpus[0]; fall back to CPU instead.
    print("No GPU found; TensorFlow will run on the CPU.")

In [None]:
# Layout of the tabular input fed to the model alongside the ECG. Each key
# names a feature group and maps to a pair of integer bounds. There is a single
# group here: two demographic features (age and sex).
# NOTE(review): the pair is presumably a [start, end) column range into the
# tabular input vector — confirm against HFNet.
TAB_SIZE_SPEC = dict(
    demo=[0, 2],
)

In [None]:
## Load the model from file

ECG_ENCODER_PATH = './PCLR.h5'
MODEL_WEIGHTS = './best_model_weights.h5'

# Per-example input geometry used to build HFNet (batch dimension left open).
# NOTE(review): 2500 x 12 is presumably samples x ECG leads — confirm against
# the encoder's input layer.
ECG_SAMPLES = 2500
ECG_LEADS = 12
# Width of the tabular input; must agree with TAB_SIZE_SPEC (age, sex).
N_TAB_FEATURES = 2

# Fail early with a clear message instead of an opaque HDF5/Keras error.
for model_file in (ECG_ENCODER_PATH, MODEL_WEIGHTS):
    if not os.path.exists(model_file):
        raise FileNotFoundError(f"Required model file not found: {model_file}")

# Load the ECG encoder (architecture and weights) from the checkpoint.
ecg_encoder = load_model(ECG_ENCODER_PATH)

# Sub-model that exposes the output of the encoder's 'embed' layer as the
# latent representation. NOTE(review): expected to be 320-dimensional — confirm.
latent = tf.keras.Model(ecg_encoder.inputs, ecg_encoder.get_layer('embed').output)

enc = HFNet(TAB_SIZE_SPEC, latent)
# Build with an unknown batch dimension for the (ECG, tabular) input pair so
# that the model's variables exist before weights are loaded into them.
enc.build([(None, ECG_SAMPLES, ECG_LEADS), (None, N_TAB_FEATURES)])

# Load the pre-trained HFNet weights.
enc.load_weights(MODEL_WEIGHTS)

In [None]:
# Smoke test: one forward pass on a small batch of random inputs to check that
# the loaded model runs end to end. The generator is seeded so the displayed
# output is reproducible on Restart & Run All.
rng = np.random.default_rng(0)
n_examples = 2
# Same per-example shapes as passed to enc.build; float32 matches Keras' default dtype.
ecg_batch = rng.standard_normal((n_examples, 2500, 12)).astype(np.float32)
tab_batch = rng.standard_normal((n_examples, 2)).astype(np.float32)
# Explicit inference mode (e.g. dropout disabled).
op = enc((ecg_batch, tab_batch), training=False)

In [None]:
# Shape of the smoke-test output; the leading dimension is presumably the batch
# size (2). NOTE(review): trailing dimension(s) come from HFNet's output head,
# which is not visible here.
op.shape

In [None]:
# Raw model outputs for the random smoke-test batch. The inputs are noise, so
# the values only confirm that the forward pass runs.
op