In [1]:
import os
import sys
sys.path.append("..")

#import jax
#jax.config.update("jax_enable_x64", True)

from utils import *
from model_tn import *
from keras_utils import *
from jax_utils import *

# Global Keras dtype policy. "mixed_float16" computes in float16 while keeping
# variables in float32. Other policies that can be swapped in: "float32" and
# "float64" (float64 under JAX additionally requires jax_enable_x64, see the
# commented-out jax config lines above).
# Note: u8_to_fp16 below casts inputs to float16 regardless of this setting.
PRECISION_POLICY = "mixed_float16"
keras.mixed_precision.set_global_policy(PRECISION_POLICY)

In [2]:
@jax.jit
@partial(shard_map, mesh=mesh, in_specs=P('x'), out_specs=P('x'))
def u8_to_fp16(x, y):
    """Rescale a uint8 frame batch to [0, 1] and cast it to float16.

    Runs jit-compiled and sharded over the mesh axis 'x' (the leading axis
    of the inputs and outputs, per ``P('x')``). ``y`` is returned unchanged.
    """
    normalized = x / 255.
    return jnp.astype(normalized, 'float16'), y

def random_horizontal_flip(x, y, p=.5, root_key=[jax.random.PRNGKey(0)]):
    """Mirror the whole batch ``x`` along its second-to-last axis with probability ``p``.

    One random draw decides for the entire batch, so either every sample is
    flipped or none is. ``y`` is always returned unchanged.

    ``root_key`` is a one-element list used on purpose as a mutable default.
    Every call splits the stored key and writes the new root back into the
    list, so successive calls see fresh randomness while the whole sequence
    stays reproducible from ``PRNGKey(0)``.
    """
    next_root, draw_key = jax.random.split(root_key[0], 2)
    root_key[0] = next_root
    if not jax.random.uniform(draw_key) < p:
        return x, y
    return x[..., ::-1, :], y

# RLAP training / validation data, stored as uint8 and loaded in batches.
batch_size = 16

# Root directory of the pre-processed rPPG datatapes. Override it with the
# RPPG_DATA_ROOT environment variable instead of editing this notebook.
DATA_ROOT = os.environ.get("RPPG_DATA_ROOT", "/root/ssd_cache/rppg_training_data")
rlap_tape = os.path.join(DATA_ROOT, "rlap_160x128x128_all")

train_rlap = load_datatape(rlap_tape, fold='train', batch=batch_size, dtype='uint8')
# extended_hr is passed as the string 'False'. The loader echoes it in its
# filter text ("fold=val&extended_hr=False"), so it appears to be a text
# filter; NOTE(review): confirm a bool False would not behave differently.
valid_rlap = load_datatape(rlap_tape, fold='val', extended_hr='False', batch=batch_size, dtype='uint8')

train_rlap = DatatapeMonitor(train_rlap)
training_set, validation_set = KerasDataset(train_rlap), KerasDataset(valid_rlap)

# Training gets flip augmentation before the uint8 -> float16 conversion;
# validation only gets the conversion.
training_set = training_set.apply_fn(random_horizontal_flip)
training_set = training_set.apply_fn(u8_to_fp16)
validation_set = validation_set.apply_fn(u8_to_fp16)
# NOTE(review): assuming apply_fn chains in call order, compress_aug sees
# float16 data already scaled to [0, 1]; confirm that is what it expects.
training_set = training_set.apply_fn(compress_aug)


rlap_160x128x128_all           datatape has been loaded.    24704 items total. fold=train selected
rlap_160x128x128_all           datatape has been loaded.     3056 items total. fold=val&extended_hr=False selected


In [3]:
model = PhysFormer(TN=True)
# Previously tried: Adam with weight_decay=1e-5 and an ExponentialDecay
# schedule (initial_learning_rate=1e-4, decay_steps=50, decay_rate=0.5).
# Convergence was too slow, so a constant learning rate of 1e-4 is used below.

# Loss weights used by combined_loss: `a` scales np_loss, `b` scales
# kl_ce_loss. `b` starts at b_start and the training cell ramps it upward,
# with exp_b setting how far it grows.
a_start, b_start, exp_b = 1., .0, 1.
a, b = a_start, b_start


def combined_loss(y, pred):
    """Return a * np_loss(y, pred) + b * kl_ce_loss(y, pred).

    ``a`` and ``b`` are module-level globals read when this function runs (or
    is traced), so reassigning them changes the weighting of later calls.
    """
    np_term = np_loss(y, pred)
    kl_term = kl_ce_loss(y, pred)
    return a * np_term + b * kl_term


model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=1e-4),
    loss=combined_loss,
    metrics=[np_loss, kl_ce_loss],
)
# A forward pass on a random (4, 160, 128, 128, 3) batch builds the model so
# summary() can report its layers and parameters.
y = model(np.random.random((4, 160, 128, 128, 3)))
model.summary()

In [4]:
# Two-phase training.
# Phase 1 (warm-up) runs one epoch per train() call so the kl_ce_loss weight
# `b` read by combined_loss can be ramped between epochs:
#     b = b_start + exp_b * (2 ** (epoch / WARMUP_EPOCHS) - 1)
# This reaches b_start + exp_b after WARMUP_EPOCHS epochs, assuming
# stat['epoch'] counts completed epochs (consistent with the logged losses).
# Because `b` grows, the reported "loss" rises during warm-up even while
# np_loss falls.
# Phase 2 then trains FINETUNE_EPOCHS more epochs with the final weights.
WARMUP_EPOCHS = 10
FINETUNE_EPOCHS = 10
# Both checkpoint files are named after this experiment (physformertn), so
# the warm-up phase cannot overwrite another model's weights file.
WARMUP_CKPT = '../weights/physformertn_warmup.weights.h5'
FINAL_CKPT = '../weights/physformertn.weights.h5'

stat = None
for _ in range(WARMUP_EPOCHS):
    stat = train(model, training_set, validation_set, epochs=1,
                 check_point_path=WARMUP_CKPT, training_stat=stat)
    b = b_start + exp_b * (2 ** (stat['epoch'] / WARMUP_EPOCHS) - 1)  # update combined loss weight

# The combined-loss scale changed while `b` was ramped, so warm-up losses are
# not comparable with phase-2 losses. Reset the best loss; presumably train()
# compares against it when deciding whether to checkpoint.
stat['best_loss'] = 1e20
train(model, training_set, validation_set, epochs=FINETUNE_EPOCHS,
      check_point_path=FINAL_CKPT, training_stat=stat)
# Restore the weights that phase 2 saved to FINAL_CKPT.
model.load_weights(FINAL_CKPT)

HTML(value='')

Training:   0%|          | 0/1544 [00:00<?, ?it/s]

Validating:   0%|          | 0/191 [00:00<?, ?it/s]

Epoch   1:  Training kl_ce_loss: 5.798, loss:0.2233, np_loss:0.2238	Validation kl_ce_loss: 5.766, loss:0.1658, np_loss: 0.167	


Training:   0%|          | 0/1544 [00:00<?, ?it/s]

Validating:   0%|          | 0/191 [00:00<?, ?it/s]

Epoch   2:  Training kl_ce_loss: 5.662, loss:0.5572, np_loss: 0.151	Validation kl_ce_loss: 5.646, loss:0.5557, np_loss:0.1349	


Training:   0%|          | 0/1544 [00:00<?, ?it/s]

Validating:   0%|          | 0/191 [00:00<?, ?it/s]

Epoch   3:  Training kl_ce_loss:   5.6, loss:  0.97, np_loss:0.1361	Validation kl_ce_loss: 5.655, loss:0.9829, np_loss:0.1423	


Training:   0%|          | 0/1544 [00:00<?, ?it/s]

Validating:   0%|          | 0/191 [00:00<?, ?it/s]

Epoch   4:  Training kl_ce_loss: 5.571, loss: 1.417, np_loss:0.1307	Validation kl_ce_loss: 5.645, loss: 1.442, np_loss:0.1404	


Training:   0%|          | 0/1544 [00:00<?, ?it/s]

Validating:   0%|          | 0/191 [00:00<?, ?it/s]

Epoch   5:  Training kl_ce_loss: 5.541, loss: 1.898, np_loss:0.1306	Validation kl_ce_loss: 5.683, loss: 1.939, np_loss:0.1452	


Training:   0%|          | 0/1544 [00:00<?, ?it/s]

Validating:   0%|          | 0/191 [00:00<?, ?it/s]

Epoch   6:  Training kl_ce_loss: 5.519, loss: 2.416, np_loss: 0.131	Validation kl_ce_loss: 5.661, loss:  2.47, np_loss:0.1449	


Training:   0%|          | 0/1544 [00:00<?, ?it/s]

Validating:   0%|          | 0/191 [00:00<?, ?it/s]

Epoch   7:  Training kl_ce_loss: 5.505, loss:  2.97, np_loss:0.1318	Validation kl_ce_loss:  5.53, loss: 3.041, np_loss:0.1383	


Training:   0%|          | 0/1544 [00:00<?, ?it/s]

Validating:   0%|          | 0/191 [00:00<?, ?it/s]

Epoch   8:  Training kl_ce_loss: 5.492, loss: 3.566, np_loss:0.1331	Validation kl_ce_loss:  5.61, loss: 3.648, np_loss:0.1551	


Training:   0%|          | 0/1544 [00:00<?, ?it/s]

Validating:   0%|          | 0/191 [00:00<?, ?it/s]

Epoch   9:  Training kl_ce_loss:  5.48, loss: 4.202, np_loss:0.1342	Validation kl_ce_loss: 5.587, loss: 4.303, np_loss:0.1532	


Training:   0%|          | 0/1544 [00:00<?, ?it/s]

Validating:   0%|          | 0/191 [00:00<?, ?it/s]

Epoch  10:  Training kl_ce_loss: 5.484, loss: 4.885, np_loss: 0.139	Validation kl_ce_loss: 5.577, loss: 5.004, np_loss:0.1593	


Training:   0%|          | 0/1544 [00:00<?, ?it/s]

Validating:   0%|          | 0/191 [00:00<?, ?it/s]

Epoch  11:  Training kl_ce_loss: 5.478, loss: 5.618, np_loss:0.1409	Validation kl_ce_loss: 5.618, loss: 5.753, np_loss:0.1638	


Training:   0%|          | 0/1544 [00:00<?, ?it/s]

Validating:   0%|          | 0/191 [00:00<?, ?it/s]

Epoch  12:  Training kl_ce_loss: 5.476, loss: 5.616, np_loss:0.1405	Validation kl_ce_loss: 5.565, loss: 5.752, np_loss:0.1574	


Training:   0%|          | 0/1544 [00:00<?, ?it/s]

Validating:   0%|          | 0/191 [00:00<?, ?it/s]

Epoch  13:  Training kl_ce_loss:  5.47, loss: 5.616, np_loss:0.1388	Validation kl_ce_loss: 5.554, loss: 5.752, np_loss:0.1567	


Training:   0%|          | 0/1544 [00:00<?, ?it/s]

Validating:   0%|          | 0/191 [00:00<?, ?it/s]

Epoch  14:  Training kl_ce_loss: 5.476, loss: 5.614, np_loss:0.1428	Validation kl_ce_loss: 5.614, loss:  5.75, np_loss:0.1563	


Training:   0%|          | 0/1544 [00:00<?, ?it/s]

Validating:   0%|          | 0/191 [00:00<?, ?it/s]

Epoch  15:  Training kl_ce_loss: 5.469, loss: 5.615, np_loss: 0.139	Validation kl_ce_loss: 5.572, loss: 5.751, np_loss:0.1557	


Training:   0%|          | 0/1544 [00:00<?, ?it/s]

Validating:   0%|          | 0/191 [00:00<?, ?it/s]

Epoch  16:  Training kl_ce_loss: 5.477, loss: 5.615, np_loss:0.1424	Validation kl_ce_loss: 5.572, loss: 5.751, np_loss:0.1666	


Training:   0%|          | 0/1544 [00:00<?, ?it/s]

Validating:   0%|          | 0/191 [00:00<?, ?it/s]

Epoch  17:  Training kl_ce_loss: 5.474, loss: 5.614, np_loss:0.1425	Validation kl_ce_loss: 5.611, loss: 5.751, np_loss:0.1605	


Training:   0%|          | 0/1544 [00:00<?, ?it/s]

Validating:   0%|          | 0/191 [00:00<?, ?it/s]

Epoch  18:  Training kl_ce_loss: 5.475, loss: 5.614, np_loss:0.1413	Validation kl_ce_loss: 5.587, loss: 5.751, np_loss:0.1573	


Training:   0%|          | 0/1544 [00:00<?, ?it/s]

Validating:   0%|          | 0/191 [00:00<?, ?it/s]

Epoch  19:  Training kl_ce_loss: 5.476, loss: 5.615, np_loss:0.1414	Validation kl_ce_loss: 5.608, loss: 5.751, np_loss:0.1657	


Training:   0%|          | 0/1544 [00:00<?, ?it/s]

Validating:   0%|          | 0/191 [00:00<?, ?it/s]

Epoch  20:  Training kl_ce_loss: 5.479, loss: 5.614, np_loss:0.1413	Validation kl_ce_loss: 5.578, loss: 5.751, np_loss:0.1576	


In [5]:
# Cross-dataset evaluation on MMPD. Frames are fed as uint8 and divided by 255
# inside the wrapped model, the same scaling u8_to_fp16 applies in training.
# 160 and (128, 128) match the (160, 128, 128) clip shape used to build the
# model (presumably clip length and frame size).
eval_on_dataset(
    dataset_H5_mmpd,
    pmodel(lambda x: model(x / 255.)),
    160,
    (128, 128),
    step=4,
    batch=8,
    save='../results/PhysFormertn_RLAP_MMPD.h5',
    ipt_dtype='uint8',
)
get_metrics('../results/PhysFormertn_RLAP_MMPD.h5', dropped='False')

  0%|          | 0/660 [00:00<?, ?it/s]

Unnamed: 0,MAE,RMSE,MAPE,R
Window,11.201±0.196,20.061±3.455,12.796±0.225,0.37795
Whole,8.136±0.407,13.229±4.132,9.179±0.437,0.5856


In [6]:
# Cross-dataset evaluation on COHFACE, with uint8 input rescaled by 1/255 as
# during training. This is the only evaluation that passes fps explicitly (30).
eval_on_dataset(
    dataset_H5_cohface,
    pmodel(lambda x: model(x / 255.)),
    160,
    (128, 128),
    step=4,
    batch=8,
    save='../results/PhysFormertn_RLAP_COHFACE.h5',
    ipt_dtype='uint8',
    fps=30,
)
get_metrics('../results/PhysFormertn_RLAP_COHFACE.h5')

  0%|          | 0/164 [00:00<?, ?it/s]

Unnamed: 0,MAE,RMSE,MAPE,R
Window,3.523±0.2,9.227±3.077,4.777±0.261,0.70438
Whole,2.597±0.423,6.01±3.32,3.415±0.52,0.85647


In [7]:
# Cross-dataset evaluation on PURE, with uint8 input rescaled by 1/255 as
# during training.
eval_on_dataset(
    dataset_H5_pure,
    pmodel(lambda x: model(x / 255.)),
    160,
    (128, 128),
    step=4,
    batch=8,
    save='../results/PhysFormertn_RLAP_PURE.h5',
    ipt_dtype='uint8',
)
get_metrics('../results/PhysFormertn_RLAP_PURE.h5')

  0%|          | 0/59 [00:00<?, ?it/s]

Unnamed: 0,MAE,RMSE,MAPE,R
Window,0.43±0.048,1.398±0.756,0.659±0.08,0.99815
Whole,0.238±0.07,0.59±0.478,0.363±0.125,0.9997


In [8]:
# Cross-dataset evaluation on UBFC-rPPG, with uint8 input rescaled by 1/255 as
# during training.
eval_on_dataset(
    dataset_H5_ubfc_rppg2,
    pmodel(lambda x: model(x / 255.)),
    160,
    (128, 128),
    step=4,
    batch=8,
    save='../results/PhysFormertn_RLAP_UBFC.h5',
    ipt_dtype='uint8',
)
get_metrics('../results/PhysFormertn_RLAP_UBFC.h5')

  0%|          | 0/42 [00:00<?, ?it/s]

Unnamed: 0,MAE,RMSE,MAPE,R
Window,1.171±0.084,2.252±1.049,1.204±0.091,0.99121
Whole,0.446±0.078,0.673±0.42,0.436±0.069,0.99933
