In [1]:
import sys
sys.path.append("..")

#import jax
#jax.config.update("jax_enable_x64", True)

from utils import *
from model_tn import *
from keras_utils import *
from jax_utils import *

# Global Keras dtype policy: "mixed_float16" (float16 compute, float32 variables).
# Other policies tried in this notebook: "float32" and "float64"; float64 only takes
# effect in JAX when jax_enable_x64 is enabled (see the commented config lines above).
keras.mixed_precision.set_global_policy(policy="mixed_float16")

In [2]:
@jax.jit
@partial(shard_map, mesh=mesh, in_specs=P('x'), out_specs=P('x'))
def u8_to_fp16(x, y):
    """Divide ``x`` by 255 and cast it to float16; ``y`` is passed through unchanged.

    Runs under ``shard_map`` over ``mesh``: both inputs and both outputs are split
    along their leading axis by ``P('x')``, so each device converts its own slice.
    Nothing here checks the input dtype; the division by 255 assumes 0..255
    (uint8-range) pixel values, which then land in [0, 1].
    """
    unit_range = x / 255.
    return unit_range.astype(jnp.float16), y

def random_horizontal_flip(x, y, p=.5, root_key=[jax.random.PRNGKey(0)]):
    """Reverse ``x`` along its second-to-last axis with probability ``p``.

    ``y`` is always returned as-is. One uniform draw is made per call, so when ``x``
    is a batch the whole batch is flipped (or kept) together. The reversed axis is
    presumably image width with channels last; confirm against the data layout.

    ``root_key`` is a one-element list used as mutable RNG state. The default list
    object is created once at definition time and shared by all calls that omit
    the argument, so each such call advances the same key chain (seeded with 0).
    Pass your own ``[key]`` list to keep an independent stream.
    """
    carried_key, draw_key = jax.random.split(root_key[0], 2)
    root_key[0] = carried_key  # persist the advanced key for the next call
    flip = jax.random.uniform(draw_key) < p
    return (x[..., ::-1 ,:], y) if flip else (x, y)

batch_size = 16

# Datatape directory (absolute, machine-specific path); per its name it holds
# MMPD clips at 160x128x128, read back here as uint8.
tape = "/root/ssd_cache/rppg_training_data/mmpd_160x128x128_all"

train_tape = load_datatape(tape, fold='train', batch=batch_size, dtype='uint8')
# NOTE(review): extended_hr is passed as the *string* 'False' and echoed by the loader
# as a selection filter ("fold=val&extended_hr=False") -- confirm a string is intended.
valid_tape = load_datatape(tape, fold='val', extended_hr='False', batch=batch_size, dtype='uint8')

training_set = KerasDataset(train_tape)
validation_set = KerasDataset(valid_tape)

# Training: random flip on the raw uint8 batch, then scale to float16, then
# compress_aug. Validation: float16 scaling only. Call order matches the original.
training_set = training_set.apply_fn(random_horizontal_flip).apply_fn(u8_to_fp16)
validation_set = validation_set.apply_fn(u8_to_fp16)
training_set = training_set.apply_fn(compress_aug)


mmpd_160x128x128_all           datatape has been loaded.     7664 items total. fold=train selected
mmpd_160x128x128_all           datatape has been loaded.     1680 items total. fold=val&extended_hr=False selected


In [3]:
# Loss weighting: total = a * np_loss + b * kl_ce_loss.
# a and b are module-level globals that combined_loss looks up each time it runs, so
# reassigning them (the training loop ramps b) changes the weighting whenever train()
# re-traces the loss. The logged epoch losses show that this happens.
a_start, b_start, exp_b = 1., .0, 1.
a = a_start
b = b_start

def combined_loss(y, pred):
    """Return ``a * np_loss(y, pred) + b * kl_ce_loss(y, pred)`` using the current global weights."""
    weighted_np = a * np_loss(y, pred)
    weighted_kl = b * kl_ce_loss(y, pred)
    return weighted_np + weighted_kl

model = PhysFormer()
# Previously tried: Adam on ExponentialDecay(initial_learning_rate=1e-4, decay_steps=50,
# decay_rate=0.5) with weight_decay=1e-5. Convergence was too slow, so a constant 1e-4 is used.
model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=1e-4),
    loss=combined_loss,
    metrics=[np_loss, kl_ce_loss],
)
# One forward pass on a random 4-clip batch so the model is built before summary().
y = model(np.random.random((4, 160, 128, 128, 3)));
model.summary()

In [4]:
# Phase 1 (warm-up): ten single-epoch runs. After each run the kl_ce_loss weight b is
# raised to b_start + exp_b * (2**(epoch/10) - 1). The log prints Epoch 1..10 here,
# so b reaches b_start + exp_b once this loop ends.
stat = None
for _ in range(10):
    stat = train(
        model, training_set, validation_set,
        epochs=1,
        check_point_path='../weights/physformer_mmpd.weights.h5',
        training_stat=stat,
    )
    b = b_start + exp_b * (2 ** (stat['epoch'] / 10) - 1)

# Phase 2: ten more epochs at the final weighting. Losses logged under the earlier,
# smaller b are not comparable, so best_loss is reset first. This presumes train()
# checkpoints when validation loss beats stat['best_loss'] -- confirm in train().
stat['best_loss'] = 1e20
train(
    model, training_set, validation_set,
    epochs=10,
    check_point_path='../weights/physformer_mmpd.weights.h5',
    training_stat=stat,
)
model.load_weights('../weights/physformer_mmpd.weights.h5')

Training:   0%|          | 0/479 [00:00<?, ?it/s]

Validating:   0%|          | 0/105 [00:00<?, ?it/s]

Epoch   1:  Training kl_ce_loss: 7.492, loss:0.9736, np_loss:0.9745	Validation kl_ce_loss: 7.054, loss:0.9443, np_loss:0.9432	


Training:   0%|          | 0/479 [00:00<?, ?it/s]

Validating:   0%|          | 0/105 [00:00<?, ?it/s]

Epoch   2:  Training kl_ce_loss: 6.756, loss: 1.229, np_loss:0.7419	Validation kl_ce_loss: 6.407, loss: 1.184, np_loss:0.7234	


Training:   0%|          | 0/479 [00:00<?, ?it/s]

Validating:   0%|          | 0/105 [00:00<?, ?it/s]

Epoch   3:  Training kl_ce_loss: 6.577, loss: 1.638, np_loss:0.6612	Validation kl_ce_loss:  6.24, loss: 1.652, np_loss:0.7088	


Training:   0%|          | 0/479 [00:00<?, ?it/s]

Validating:   0%|          | 0/105 [00:00<?, ?it/s]

Epoch   4:  Training kl_ce_loss: 6.459, loss: 2.118, np_loss:0.6256	Validation kl_ce_loss: 6.207, loss: 2.119, np_loss:0.6658	


Training:   0%|          | 0/479 [00:00<?, ?it/s]

Validating:   0%|          | 0/105 [00:00<?, ?it/s]

Epoch   5:  Training kl_ce_loss:  6.36, loss: 2.615, np_loss:0.5748	Validation kl_ce_loss:  6.24, loss: 2.677, np_loss:0.6678	


Training:   0%|          | 0/479 [00:00<?, ?it/s]

Validating:   0%|          | 0/105 [00:00<?, ?it/s]

Epoch   6:  Training kl_ce_loss: 6.314, loss: 3.176, np_loss:0.5668	Validation kl_ce_loss: 6.308, loss: 3.243, np_loss:0.6704	


Training:   0%|          | 0/479 [00:00<?, ?it/s]

Validating:   0%|          | 0/105 [00:00<?, ?it/s]

Epoch   7:  Training kl_ce_loss: 6.275, loss: 3.782, np_loss:0.5513	Validation kl_ce_loss: 6.227, loss: 3.913, np_loss:0.6502	


Training:   0%|          | 0/479 [00:00<?, ?it/s]

Validating:   0%|          | 0/105 [00:00<?, ?it/s]

Epoch   8:  Training kl_ce_loss: 6.225, loss: 4.429, np_loss:0.5468	Validation kl_ce_loss:  6.24, loss: 4.593, np_loss: 0.693	


Training:   0%|          | 0/479 [00:00<?, ?it/s]

Validating:   0%|          | 0/105 [00:00<?, ?it/s]

Epoch   9:  Training kl_ce_loss: 6.144, loss: 5.087, np_loss: 0.527	Validation kl_ce_loss: 6.228, loss: 5.259, np_loss:0.6342	


Training:   0%|          | 0/479 [00:00<?, ?it/s]

Validating:   0%|          | 0/105 [00:00<?, ?it/s]

Epoch  10:  Training kl_ce_loss: 6.048, loss: 5.772, np_loss:0.5002	Validation kl_ce_loss: 6.257, loss: 6.091, np_loss:0.6361	


Training:   0%|          | 0/479 [00:00<?, ?it/s]

Validating:   0%|          | 0/105 [00:00<?, ?it/s]

Epoch  11:  Training kl_ce_loss: 6.045, loss: 6.565, np_loss:0.4985	Validation kl_ce_loss: 6.258, loss: 6.898, np_loss:0.6361	


Training:   0%|          | 0/479 [00:00<?, ?it/s]

Validating:   0%|          | 0/105 [00:00<?, ?it/s]

Epoch  12:  Training kl_ce_loss: 6.008, loss: 6.506, np_loss:0.4998	Validation kl_ce_loss: 6.293, loss: 6.924, np_loss:0.6978	


Training:   0%|          | 0/479 [00:00<?, ?it/s]

Validating:   0%|          | 0/105 [00:00<?, ?it/s]

Epoch  13:  Training kl_ce_loss: 5.966, loss: 6.454, np_loss:0.4736	Validation kl_ce_loss: 6.223, loss: 6.872, np_loss: 0.647	


Training:   0%|          | 0/479 [00:00<?, ?it/s]

Validating:   0%|          | 0/105 [00:00<?, ?it/s]

Epoch  14:  Training kl_ce_loss: 5.893, loss: 6.345, np_loss:0.4589	Validation kl_ce_loss: 6.344, loss:  6.92, np_loss:0.6652	


Training:   0%|          | 0/479 [00:00<?, ?it/s]

Validating:   0%|          | 0/105 [00:00<?, ?it/s]

Epoch  15:  Training kl_ce_loss: 5.857, loss: 6.321, np_loss: 0.451	Validation kl_ce_loss: 6.303, loss:  6.96, np_loss:0.6595	


Training:   0%|          | 0/479 [00:00<?, ?it/s]

Validating:   0%|          | 0/105 [00:00<?, ?it/s]

Epoch  16:  Training kl_ce_loss: 5.844, loss: 6.282, np_loss: 0.439	Validation kl_ce_loss: 6.276, loss: 6.896, np_loss:0.6496	


Training:   0%|          | 0/479 [00:00<?, ?it/s]

Validating:   0%|          | 0/105 [00:00<?, ?it/s]

Epoch  17:  Training kl_ce_loss: 5.816, loss: 6.248, np_loss: 0.439	Validation kl_ce_loss: 6.234, loss: 6.917, np_loss: 0.632	


Training:   0%|          | 0/479 [00:00<?, ?it/s]

Validating:   0%|          | 0/105 [00:00<?, ?it/s]

Epoch  18:  Training kl_ce_loss: 5.783, loss: 6.193, np_loss:0.4233	Validation kl_ce_loss: 6.269, loss: 6.901, np_loss:0.6559	


Training:   0%|          | 0/479 [00:00<?, ?it/s]

Validating:   0%|          | 0/105 [00:00<?, ?it/s]

Epoch  19:  Training kl_ce_loss: 5.746, loss: 6.158, np_loss:0.4074	Validation kl_ce_loss:   6.3, loss: 6.969, np_loss:0.6616	


Training:   0%|          | 0/479 [00:00<?, ?it/s]

Validating:   0%|          | 0/105 [00:00<?, ?it/s]

Epoch  20:  Training kl_ce_loss: 5.727, loss: 6.139, np_loss:0.4059	Validation kl_ce_loss: 6.318, loss: 6.937, np_loss:0.6368	


In [5]:
# NOTE(review): this cell evaluates dataset_H5_mmpd but saves to ..._MMPD_RLAP.h5 and
# filters scenes R1-R4. Confirm whether the RLAP test set was intended. If it really is
# MMPD, the evaluated clips may overlap the 'train' fold used for training above.
# Frames arrive as uint8 and are divided by 255 inside the wrapped model.
# 160 and (128, 128) presumably mean clip length and frame size, matching the training tapes.
eval_on_dataset(
    dataset_H5_mmpd,
    pmodel(lambda clip: model(clip / 255.)),
    160, (128, 128),
    step=4, batch=8,
    save='../results/PhysFormer_MMPD_RLAP.h5',
    ipt_dtype='uint8',
    scenes=['R1', 'R2', 'R3', 'R4'],
)
get_metrics('../results/PhysFormer_MMPD_RLAP.h5')

  0%|          | 0/660 [00:00<?, ?it/s]

Unnamed: 0,MAE,RMSE,MAPE,R
Window,5.253±0.106,10.472±2.119,6.371±0.144,0.80564
Whole,3.536±0.224,6.746±2.464,4.176±0.274,0.90357


In [6]:
# Cross-dataset evaluation on COHFACE. Unlike the other datasets, fps is passed
# explicitly (30); presumably the stored frame rate needs overriding -- confirm.
eval_on_dataset(
    dataset_H5_cohface,
    pmodel(lambda clip: model(clip / 255.)),
    160, (128, 128),
    step=4, batch=8,
    save='../results/PhysFormer_MMPD_COHFACE.h5',
    ipt_dtype='uint8',
    fps=30,
)
get_metrics('../results/PhysFormer_MMPD_COHFACE.h5')

  0%|          | 0/164 [00:00<?, ?it/s]

Unnamed: 0,MAE,RMSE,MAPE,R
Window,11.419±0.243,15.445±2.915,16.791±0.378,0.22748
Whole,8.445±0.504,10.625±3.566,12.677±0.861,0.43844


In [7]:
# Cross-dataset evaluation on PURE. Frames arrive as uint8 and are divided by 255 inside the wrapped model.
eval_on_dataset(
    dataset_H5_pure,
    pmodel(lambda clip: model(clip / 255.)),
    160, (128, 128),
    step=4, batch=8,
    save='../results/PhysFormer_MMPD_PURE.h5',
    ipt_dtype='uint8',
)
get_metrics('../results/PhysFormer_MMPD_PURE.h5')

  0%|          | 0/59 [00:00<?, ?it/s]

Unnamed: 0,MAE,RMSE,MAPE,R
Window,17.71±0.663,25.467±6.134,29.929±1.178,0.44946
Whole,16.947±1.749,21.627±8.447,28.861±3.387,0.64177


In [8]:
# Cross-dataset evaluation on UBFC-rPPG. Frames arrive as uint8 and are divided by 255 inside the wrapped model.
eval_on_dataset(
    dataset_H5_ubfc_rppg2,
    pmodel(lambda clip: model(clip / 255.)),
    160, (128, 128),
    step=4, batch=8,
    save='../results/PhysFormer_MMPD_UBFC.h5',
    ipt_dtype='uint8',
)
get_metrics('../results/PhysFormer_MMPD_UBFC.h5')

  0%|          | 0/42 [00:00<?, ?it/s]

Unnamed: 0,MAE,RMSE,MAPE,R
Window,3.29±0.401,9.736±4.683,3.163±0.355,0.84839
Whole,2.577±0.935,6.587±5.138,2.497±0.859,0.92607
