In [1]:
import sys
sys.path.append("..")

#import jax
#jax.config.update("jax_enable_x64", True)

from utils import *
from model_tn import *
from keras_utils import *
from jax_utils import *

# Select the global Keras dtype policy for every layer created after this point.
# "mixed_float16" is the active choice; the float32 / float64 lines below are kept
# as switchable alternatives (float64 would presumably also need the commented-out
# jax_enable_x64 lines near the imports -- confirm before enabling).
keras.mixed_precision.set_global_policy("mixed_float16")
#keras.mixed_precision.set_global_policy("float32")
#keras.mixed_precision.set_global_policy("float64")

In [2]:
@jax.jit
@partial(shard_map, mesh=mesh, in_specs=P('x'), out_specs=P('x'))
def u8_to_fp16(x, y):
    """Divide `x` by 255 and cast the result to float16; `y` is returned untouched.

    The function is wrapped by `shard_map` over `mesh`, with `P('x')` used as the
    partition spec for the inputs and the outputs, and the wrapped function is
    then jit-compiled.

    Args:
        x: array to rescale. Presumably uint8 image data in [0, 255], as the
            function name and the `dtype='uint8'` loaders suggest -- confirm.
        y: passed through unchanged.

    Returns:
        `(x / 255 as float16, y)`.
    """
    scaled = x / 255.
    return scaled.astype('float16'), y

def random_horizontal_flip(x, y, p=.5, root_key=[jax.random.PRNGKey(0)]):
    """Reverse `x` along its second-to-last axis with probability `p`.

    A single random number is drawn per call, so the flip decision applies to
    the whole of `x` at once.

    Args:
        x: array indexed as `x[..., ::-1, :]`, so it needs at least two axes.
            The second-to-last axis is presumably image width -- confirm
            against the data layout.
        y: returned unchanged.
        p: probability of flipping.
        root_key: one-element list holding a JAX PRNG key. The default list is
            created once (seeded with `PRNGKey(0)`) and is deliberately mutated
            on every call so that successive calls use fresh randomness.

    Returns:
        `(x, y)` with `x` possibly reversed along axis -2.
    """
    # Advance the persistent key and keep one fresh subkey for this call.
    next_root, draw_key = jax.random.split(root_key[0], 2)
    root_key[0] = next_root
    should_flip = jax.random.uniform(draw_key) < p
    return (x[..., ::-1, :], y) if should_flip else (x, y)

# Batch size handed to both datatape loaders below.
batch_size = 16

# Location of the RLAP datatape. NOTE(review): absolute, machine-specific path.
rlap_tape = "/root/ssd_cache/rppg_training_data/rlap_160x128x128_all"

# Load the 'train' and 'val' folds as uint8. The validation fold is additionally
# filtered with extended_hr='False' (the value is passed as a string).
train_rlap = load_datatape(rlap_tape, fold='train', batch=batch_size, dtype='uint8')
valid_rlap = load_datatape(rlap_tape, fold='val', extended_hr='False', batch=batch_size, dtype='uint8')

# Only the training tape is wrapped in DatatapeMonitor; both tapes are then
# wrapped as KerasDataset objects.
train_rlap = DatatapeMonitor(train_rlap)
training_set, validation_set = KerasDataset(train_rlap), KerasDataset(valid_rlap)

# Functions registered on each dataset (presumably applied in registration
# order -- confirm against KerasDataset.apply_fn):
#   training:   random_horizontal_flip -> u8_to_fp16 -> compress_aug
#   validation: u8_to_fp16 only (no augmentation)
training_set = training_set.apply_fn(random_horizontal_flip)
training_set = training_set.apply_fn(u8_to_fp16)
validation_set = validation_set.apply_fn(u8_to_fp16)
training_set = training_set.apply_fn(compress_aug)


rlap_160x128x128_all           datatape has been loaded.    24704 items total. fold=train selected
rlap_160x128x128_all           datatape has been loaded.     3056 items total. fold=val&extended_hr=False selected


In [3]:
# Instantiate the network; PhysFormer is provided by the star imports in the first cell.
model = PhysFormer()
# Earlier optimizer attempt kept for reference (exponential LR decay + weight decay).
#lr_schedule = keras.optimizers.schedules.ExponentialDecay(initial_learning_rate=1e-4, decay_steps=50, decay_rate=0.5)
#opti = keras.optimizers.Adam(learning_rate=lr_schedule, weight_decay=1e-5) # convergence is too slow

# Loss-term weights read by combined_loss: `a` scales np_loss and `b` scales kl_ce_loss.
# Training starts with a=1, b=0; `exp_b` is the multiplier used when the training
# cell recomputes `b` after each epoch.
a_start, b_start, exp_b = 1., .0, 1.
a, b = a_start, b_start
def combined_loss(y, pred):
    """Weighted sum of the two loss terms: `a * np_loss + b * kl_ce_loss`.

    `a` and `b` are not parameters; they are looked up in the notebook's global
    namespace each time this Python function executes, so reassigning the
    global `b` changes the weighting used from then on.

    Args:
        y: targets, forwarded unchanged to both loss functions.
        pred: predictions, forwarded unchanged to both loss functions.
    """
    np_term = a * np_loss(y, pred)
    kl_term = b * kl_ce_loss(y, pred)
    return np_term + kl_term

# Compile with plain Adam at a fixed 1e-4 learning rate. The optimized quantity is
# combined_loss; np_loss and kl_ce_loss are also tracked individually as metrics.
model.compile(optimizer=keras.optimizers.Adam(learning_rate=1e-4), loss=combined_loss, metrics=[np_loss, kl_ce_loss])
# One forward pass on random input of shape (4, 160, 128, 128, 3), presumably so the
# model's variables are created before summary() is printed -- confirm. The result
# `y` is not used again in this cell; the trailing ';' suppresses cell output.
y = model(np.random.random((4, 160, 128, 128, 3)));
model.summary()

In [4]:
# Phase 1: ten single-epoch runs. After each run the global weight `b` (read by
# combined_loss) is recomputed from the epoch counter returned by train().
stat = None
for _ in range(10):
    stat = train(
        model,
        training_set,
        validation_set,
        epochs=1,
        check_point_path='../weights/physformer.weights.h5',
        training_stat=stat,
    )
    # update combined loss weight
    b = b_start + exp_b * (2 ** (stat['epoch'] / 10) - 1)

# Phase 2: ten more epochs with `b` left at its final value. best_loss is reset
# to a huge number first -- presumably because losses recorded under smaller `b`
# are not comparable with the final weighting (confirm against train()).
stat['best_loss'] = 1e20
train(
    model,
    training_set,
    validation_set,
    epochs=10,
    check_point_path='../weights/physformer.weights.h5',
    training_stat=stat,
)
# Reload whatever train() stored at the checkpoint path.
model.load_weights('../weights/physformer.weights.h5')

HTML(value='')

Training:   0%|          | 0/1544 [00:00<?, ?it/s]

Validating:   0%|          | 0/191 [00:00<?, ?it/s]

Epoch   1:  Training kl_ce_loss: 6.191, loss:0.4174, np_loss:0.4178	Validation kl_ce_loss: 5.799, loss:0.2244, np_loss:0.2214	


Training:   0%|          | 0/1544 [00:00<?, ?it/s]

Validating:   0%|          | 0/191 [00:00<?, ?it/s]

Epoch   2:  Training kl_ce_loss: 5.735, loss:0.6025, np_loss:0.1923	Validation kl_ce_loss: 5.721, loss:0.5934, np_loss:0.1761	


Training:   0%|          | 0/1544 [00:00<?, ?it/s]

Validating:   0%|          | 0/191 [00:00<?, ?it/s]

Epoch   3:  Training kl_ce_loss: 5.671, loss: 1.014, np_loss:0.1702	Validation kl_ce_loss: 5.659, loss: 1.022, np_loss:0.1719	


Training:   0%|          | 0/1544 [00:00<?, ?it/s]

Validating:   0%|          | 0/191 [00:00<?, ?it/s]

Epoch   4:  Training kl_ce_loss: 5.631, loss: 1.465, np_loss:0.1654	Validation kl_ce_loss: 5.658, loss:  1.49, np_loss:0.1763	


Training:   0%|          | 0/1544 [00:00<?, ?it/s]

Validating:   0%|          | 0/191 [00:00<?, ?it/s]

Epoch   5:  Training kl_ce_loss: 5.601, loss: 1.948, np_loss:0.1579	Validation kl_ce_loss: 5.629, loss: 1.978, np_loss:0.1733	


Training:   0%|          | 0/1544 [00:00<?, ?it/s]

Validating:   0%|          | 0/191 [00:00<?, ?it/s]

Epoch   6:  Training kl_ce_loss: 5.583, loss: 2.467, np_loss:0.1616	Validation kl_ce_loss: 5.657, loss: 2.521, np_loss:0.1815	


Training:   0%|          | 0/1544 [00:00<?, ?it/s]

Validating:   0%|          | 0/191 [00:00<?, ?it/s]

Epoch   7:  Training kl_ce_loss: 5.556, loss: 3.026, np_loss:0.1571	Validation kl_ce_loss: 5.641, loss: 3.088, np_loss:0.1799	


Training:   0%|          | 0/1544 [00:00<?, ?it/s]

Validating:   0%|          | 0/191 [00:00<?, ?it/s]

Epoch   8:  Training kl_ce_loss: 5.553, loss: 3.625, np_loss:0.1604	Validation kl_ce_loss: 5.612, loss: 3.699, np_loss:0.1768	


Training:   0%|          | 0/1544 [00:00<?, ?it/s]

Validating:   0%|          | 0/191 [00:00<?, ?it/s]

Epoch   9:  Training kl_ce_loss: 5.538, loss: 4.266, np_loss:0.1598	Validation kl_ce_loss: 5.643, loss:  4.36, np_loss:0.1887	


Training:   0%|          | 0/1544 [00:00<?, ?it/s]

Validating:   0%|          | 0/191 [00:00<?, ?it/s]

Epoch  10:  Training kl_ce_loss: 5.536, loss: 4.954, np_loss: 0.161	Validation kl_ce_loss: 5.663, loss:  5.06, np_loss:0.1924	


Training:   0%|          | 0/1544 [00:00<?, ?it/s]

Validating:   0%|          | 0/191 [00:00<?, ?it/s]

Epoch  11:  Training kl_ce_loss: 5.527, loss: 5.693, np_loss:0.1615	Validation kl_ce_loss: 5.648, loss: 5.816, np_loss:0.1979	


Training:   0%|          | 0/1544 [00:00<?, ?it/s]

Validating:   0%|          | 0/191 [00:00<?, ?it/s]

Epoch  12:  Training kl_ce_loss: 5.518, loss: 5.689, np_loss:0.1624	Validation kl_ce_loss: 5.618, loss: 5.813, np_loss:0.1838	


Training:   0%|          | 0/1544 [00:00<?, ?it/s]

Validating:   0%|          | 0/191 [00:00<?, ?it/s]

Epoch  13:  Training kl_ce_loss: 5.532, loss: 5.689, np_loss:0.1653	Validation kl_ce_loss: 5.616, loss: 5.809, np_loss:0.1869	


Training:   0%|          | 0/1544 [00:00<?, ?it/s]

Validating:   0%|          | 0/191 [00:00<?, ?it/s]

Epoch  14:  Training kl_ce_loss: 5.526, loss: 5.687, np_loss:0.1628	Validation kl_ce_loss: 5.637, loss: 5.809, np_loss:0.1931	


Training:   0%|          | 0/1544 [00:00<?, ?it/s]

Validating:   0%|          | 0/191 [00:00<?, ?it/s]

Epoch  15:  Training kl_ce_loss: 5.521, loss: 5.685, np_loss:0.1633	Validation kl_ce_loss: 5.596, loss: 5.809, np_loss:0.1873	


Training:   0%|          | 0/1544 [00:00<?, ?it/s]

Validating:   0%|          | 0/191 [00:00<?, ?it/s]

Epoch  16:  Training kl_ce_loss: 5.519, loss: 5.685, np_loss:0.1635	Validation kl_ce_loss: 5.564, loss: 5.808, np_loss:0.1781	


Training:   0%|          | 0/1544 [00:00<?, ?it/s]

Validating:   0%|          | 0/191 [00:00<?, ?it/s]

Epoch  17:  Training kl_ce_loss: 5.527, loss: 5.685, np_loss:0.1638	Validation kl_ce_loss: 5.615, loss: 5.809, np_loss:0.1873	


Training:   0%|          | 0/1544 [00:00<?, ?it/s]

Validating:   0%|          | 0/191 [00:00<?, ?it/s]

Epoch  18:  Training kl_ce_loss: 5.524, loss: 5.685, np_loss:0.1644	Validation kl_ce_loss: 5.565, loss: 5.808, np_loss:0.1856	


Training:   0%|          | 0/1544 [00:00<?, ?it/s]

Validating:   0%|          | 0/191 [00:00<?, ?it/s]

Epoch  19:  Training kl_ce_loss: 5.527, loss: 5.686, np_loss: 0.162	Validation kl_ce_loss: 5.658, loss: 5.809, np_loss:0.1895	


Training:   0%|          | 0/1544 [00:00<?, ?it/s]

Validating:   0%|          | 0/191 [00:00<?, ?it/s]

Epoch  20:  Training kl_ce_loss: 5.521, loss: 5.684, np_loss:0.1616	Validation kl_ce_loss: 5.595, loss: 5.808, np_loss:0.1854	


In [2]:
# NOTE(review): the execution counter restarts at this cell, so it appears to have
# been run in a fresh kernel; it relies on the import cell having been re-run first.
# Rebuild the architecture, run one forward pass on random input of shape
# (4, 160, 128, 128, 3) (presumably to create the variables -- confirm), then load
# the weights saved by the training cell.
model = PhysFormer()
model(np.random.random((4, 160, 128, 128, 3)))
model.load_weights('../weights/physformer.weights.h5')

In [3]:
# Evaluate the RLAP-trained model on MMPD. Inputs are declared as uint8 and divided
# by 255 inside the lambda before reaching the model. The positional 160 and
# (128, 128) presumably are clip length and frame size -- confirm against eval_on_dataset.
eval_on_dataset(
    dataset_H5_mmpd,
    pmodel(lambda x: model(x/255.)),
    160,
    (128, 128),
    step=4,
    batch=8,
    save='../results/PhysFormer_RLAP_MMPD.h5',
    ipt_dtype='uint8',
)
# Summarise the saved predictions; dropped='False' is passed as a string.
get_metrics('../results/PhysFormer_RLAP_MMPD.h5', dropped='False')

  0%|          | 0/660 [00:00<?, ?it/s]

Unnamed: 0,MAE,RMSE,MAPE,R
Window,15.386±0.212,23.685±3.511,17.112±0.225,0.18374
Whole,12.974±0.544,19.047±5.295,14.246±0.536,0.27479


In [4]:
# Evaluate the RLAP-trained model on COHFACE. Unlike the other evaluation cells,
# this call also passes fps=30. Inputs are declared as uint8 and divided by 255
# inside the lambda before reaching the model.
eval_on_dataset(
    dataset_H5_cohface,
    pmodel(lambda x: model(x/255.)),
    160,
    (128, 128),
    step=4,
    batch=8,
    save='../results/PhysFormer_RLAP_COHFACE.h5',
    ipt_dtype='uint8',
    fps=30,
)
# Summarise the saved predictions.
get_metrics('../results/PhysFormer_RLAP_COHFACE.h5')

  0%|          | 0/164 [00:00<?, ?it/s]

Unnamed: 0,MAE,RMSE,MAPE,R
Window,10.198±0.298,16.323±3.506,13.775±0.388,0.19363
Whole,7.909±0.709,12.038±4.796,10.298±0.826,0.38177


In [5]:
# Evaluate the RLAP-trained model on PURE. Inputs are declared as uint8 and
# divided by 255 inside the lambda before reaching the model.
eval_on_dataset(
    dataset_H5_pure,
    pmodel(lambda x: model(x/255.)),
    160,
    (128, 128),
    step=4,
    batch=8,
    save='../results/PhysFormer_RLAP_PURE.h5',
    ipt_dtype='uint8',
)
# Summarise the saved predictions.
get_metrics('../results/PhysFormer_RLAP_PURE.h5')

  0%|          | 0/59 [00:00<?, ?it/s]

Unnamed: 0,MAE,RMSE,MAPE,R
Window,6.983±0.557,16.878±5.706,11.026±0.928,0.70383
Whole,6.01±1.41,12.385±7.738,9.459±2.351,0.85275


In [6]:
# Evaluate the RLAP-trained model on UBFC-rPPG (dataset_H5_ubfc_rppg2). Inputs are
# declared as uint8 and divided by 255 inside the lambda before reaching the model.
eval_on_dataset(
    dataset_H5_ubfc_rppg2,
    pmodel(lambda x: model(x/255.)),
    160,
    (128, 128),
    step=4,
    batch=8,
    save='../results/PhysFormer_RLAP_UBFC.h5',
    ipt_dtype='uint8',
)
# Summarise the saved predictions.
get_metrics('../results/PhysFormer_RLAP_UBFC.h5')

  0%|          | 0/42 [00:00<?, ?it/s]

Unnamed: 0,MAE,RMSE,MAPE,R
Window,1.215±0.088,2.341±1.042,1.247±0.094,0.99054
Whole,0.472±0.082,0.711±0.426,0.457±0.072,0.99931
