## NoobHeart's JAX implementation  

PhysBench does not rely on any deep learning framework, it handles all preprocessing, postprocessing, visualization, and then exposes the core model development completely to users who can flexibly choose their own development tools.  

### Preparation  

First, you need to prepare the dataset and training data. Please refer to `Noob Heart.ipynb`, this part is the same.

In [1]:
import sys
sys.path.append("..")
from utils import *

import jax
import jax.numpy as jnp
from jax import jit, grad, lax
from jax.example_libraries import optimizers

Initialize the convolutional kernel weights of NoobHeart and define the model structure.

In [2]:
def NoobHeart(seeds=jax.random.PRNGKey(range(6))):
    weights = [
        jax.random.normal(seeds[0], (4, 3, 2, 2, 2,)),  # kernel 1
        jax.random.normal(seeds[1], (4, 1, 1, 1)),      # basis  1
        jax.random.normal(seeds[2], (2, 4, 2, 2, 2,)),  # kernel 2
        jax.random.normal(seeds[3], (2, 1, 1, 1)),      # basis  2
        jax.random.normal(seeds[4], (1, 1, 1, 1, 1,)),  # kernel 3
        jax.random.normal(seeds[5], (1, 1, 1, 1)),      # basis  3
    ]
    N = lambda x, ax=2:(x-jnp.expand_dims(x.mean(axis=ax), ax))/jnp.expand_dims(x.std(axis=ax), ax)
    @jit
    def model(x, weights=weights):
        x = N(jnp.transpose(x, (0, 4, 1, 2, 3)))                                # BCDHW
        x = N(lax.tanh(lax.conv(x, weights[0], (1, 2, 2), 'SAME')+weights[1]))  # Conv Tanh Norm
        x = N(lax.tanh(lax.conv(x, weights[2], (1, 2, 2), 'SAME')+weights[3]))
        x = lax.conv(x, jnp.ones((1, 2, 1, 2, 2)), (1, 2, 2), 'VALID')          # Pool
        x = lax.conv(x, weights[4], (1, 1, 1), 'SAME')+weights[5]
        return x.reshape(x.shape[0], -1)                                        # Flatten
    model.weights = weights
    return model

model = NoobHeart()

Write the training loop

In [3]:
batch = 32
train = load_datatape("train_tape.h5", batch=batch) # Training set
valid = load_datatape("valid_tape.h5", batch=batch) # Validation set

opt_init, opt_update, get_params = optimizers.adam(0.01) # Adam optimizer
opt_state = opt_init(model.weights)

best = None
for epoch in range(10): # 10 epochs
    for step, (data, label) in enumerate(train): # train
        loss = lambda weights: jnp.mean(abs((model(data, weights)-label)))
        grads = grad(loss)(get_params(opt_state))
        opt_state = opt_update(step, grads, opt_state)
    vloss = []
    for step, (data, label) in enumerate(valid): # val
        vloss.append(jnp.mean(abs((model(data, get_params(opt_state))-label))))
    vloss = jnp.mean(jnp.stack(vloss))
    print(f'{epoch=},\tval_loss={vloss}')
    if best is None or best[1]>vloss:
        best = opt_state, vloss
        print('Best model checked')
        
for i, j in enumerate(get_params(best[0])): # load the best
    model.weights[i] = j


epoch=0,	val_loss=0.6488630175590515
Best model checked
epoch=1,	val_loss=0.6216443777084351
Best model checked
epoch=2,	val_loss=0.602000892162323
Best model checked
epoch=3,	val_loss=0.5860676169395447
Best model checked
epoch=4,	val_loss=0.5778675079345703
Best model checked
epoch=5,	val_loss=0.5712727308273315
Best model checked
epoch=6,	val_loss=0.5722319483757019
epoch=7,	val_loss=0.5705991387367249
Best model checked
epoch=8,	val_loss=0.5706881880760193
epoch=9,	val_loss=0.5693888664245605
Best model checked


Use `eval_on_dataset` to test and obtain metrics.

In [4]:
eval_on_dataset('ubfc_dataset.h5', model, 32, (8, 8), step=1, batch=32, ipt_dtype=np.float32, save='../results/NoobHeart_PURE_UBFC.h5')
r = get_metrics('../results/NoobHeart_PURE_UBFC.h5')['Whole video']
print(f'HR metrics: MAE:{r["MAE"]}, RMSE:{r["RMSE"]}, R:{r["R"]}')
r = get_metrics_HRV('../results/NoobHeart_PURE_UBFC.h5')['SDNN']
print(f'HRV metrics: MAE:{r["MAE"]}, RMSE:{r["RMSE"]}, R:{r["R"]}')

  0%|          | 0/42 [00:00<?, ?it/s]

HR metrics: MAE:1.009, RMSE:1.433, R:0.9969
HRV metrics: MAE:38.703, RMSE:43.339, R:0.62763
