In [1]:
import sys
sys.path.append("..")

#import jax
#jax.config.update("jax_enable_x64", True)

from utils import *
from model_tn import *
from keras_utils import *
from jax_utils import *

# Global Keras dtype policy: "mixed_float16" is the active setting. The two
# commented lines are alternative policies kept for experimentation; the float64
# one presumably also needs the commented-out jax_enable_x64 update above.
keras.mixed_precision.set_global_policy("mixed_float16")
#keras.mixed_precision.set_global_policy("float32")
#keras.mixed_precision.set_global_policy("float64")

In [2]:
@jax.jit
@partial(shard_map, mesh=mesh, in_specs=P('x'), out_specs=P('x'))
def u8_to_fp16(x, y):
    """Rescale raw frames by 1/255 and cast them to float16; pass labels through.

    Runs under ``shard_map`` with ``P('x')`` for every input and output, so both
    ``x`` and ``y`` are partitioned along their leading axis over the ``'x'`` mesh
    axis (presumably the batch axis — the leading size must divide evenly).

    Args:
        x: frame tensor, expected to hold uint8 pixel values in [0, 255].
        y: labels; returned unchanged.

    Returns:
        ``(x / 255 as float16, y)``.
    """
    frames = jnp.true_divide(x, 255.)
    return frames.astype(jnp.float16), y

def random_horizontal_flip(x, y, p=.5, root_key=[jax.random.PRNGKey(0)]):
    """Mirror ``x`` along its second-to-last axis with probability ``p``.

    A single uniform draw decides the flip for the whole call, so every sample in
    the batch is flipped together (or none is). For channels-last frames the
    second-to-last axis is the width axis — presumably ``(batch, time, H, W, C)``.
    ``y`` is returned unchanged.

    ``root_key`` is a one-element list used as persistent RNG state. The default
    list is created once, when the function is defined, and its key is replaced on
    every call. Successive calls therefore draw fresh randomness, and the stream is
    reproducible from seed 0. Pass your own ``[key]`` list to control the stream.
    """
    keys = jax.random.split(root_key[0], 2)
    root_key[0] = keys[0]  # advance the shared RNG state for the next call
    draw = jax.random.uniform(keys[1])
    if not (draw < p):
        return x, y
    return x[..., ::-1, :], y

batch_size = 16

# Pre-processed MMPD clips stored as uint8 (the directory name suggests
# 160 frames of 128x128).
# NOTE(review): absolute machine-local path — consider a configurable data root.
tape = "/root/ssd_cache/rppg_training_data/mmpd_160x128x128_all"

train_tape = load_datatape(tape, fold='train', batch=batch_size, dtype='uint8')
# `extended_hr` is passed as the *string* 'False' (echoed verbatim in the loader's
# selection message); presumably the loader filters on string metadata — confirm.
valid_tape = load_datatape(tape, fold='val', extended_hr='False', batch=batch_size, dtype='uint8')

training_set = KerasDataset(train_tape)
validation_set = KerasDataset(valid_tape)

# Training: flip raw uint8 frames, rescale to fp16, then apply compress_aug.
# Validation: rescale only, no augmentation.
training_set = training_set.apply_fn(random_horizontal_flip).apply_fn(u8_to_fp16)
validation_set = validation_set.apply_fn(u8_to_fp16)
training_set = training_set.apply_fn(compress_aug)



mmpd_160x128x128_all           datatape has been loaded.     7664 items total. fold=train selected
mmpd_160x128x128_all           datatape has been loaded.     1680 items total. fold=val&extended_hr=False selected


In [3]:
model = PhysFormer(TN=True)

# An ExponentialDecay schedule (initial 1e-4, rate 0.5 every 50 steps) with Adam
# weight_decay=1e-5 was tried; convergence was too slow, so compile() uses a
# constant learning rate instead.

# Combined-loss weights: `a` scales np_loss and `b` scales kl_ce_loss.
a_start = 1.   # initial / fixed weight of np_loss
b_start = .0   # initial weight of kl_ce_loss
exp_b = 1.     # amplitude of the exponential ramp applied to b during training
a = a_start
b = b_start
def combined_loss(y, pred):
    """Return ``a * np_loss(y, pred) + b * kl_ce_loss(y, pred)``.

    ``a`` and ``b`` are module-level globals looked up on each call, not bound at
    definition time, so reassigning them changes the weighting of later calls.
    NOTE(review): if the training step is jit-compiled and cached, the values of
    ``a``/``b`` are captured at trace time. Later reassignments may then have no
    effect until the step is retraced — TODO confirm how ``train`` compiles.
    """
    weighted_np = a * np_loss(y, pred)
    weighted_kl = b * kl_ce_loss(y, pred)
    return weighted_np + weighted_kl

model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=1e-4),
    loss=combined_loss,
    metrics=[np_loss, kl_ce_loss],
)
# Forward pass on a random batch, presumably to build the weights before summary().
y = model(np.random.random((4, 160, 128, 128, 3)))
model.summary()

In [5]:
# Two-phase training of the TN PhysFormer.
# Phase 1 (warm-up): one epoch per `train` call, so the KL/CE weight `b` can be
#   ramped between epochs as b = b_start + exp_b * (2**(epoch / warmup_epochs) - 1).
#   Assuming stat['epoch'] counts completed epochs, b reaches b_start + exp_b
#   after the last warm-up epoch.
# Phase 2: reset best-loss tracking (presumably because the combined-loss scale
#   changed while `b` was ramping), train with `b` held at its final value, then
#   reload the best checkpoint.
# Both phases write to the TN-specific checkpoint, so the plain PhysFormer
# weights ('physformer_mmpd.weights.h5') are never overwritten. Sharing the path
# is safe: after the reset, phase 2 writes its own best weights before the reload.
# NOTE(review): `combined_loss` reads the global `b`. If `train` reuses a cached
# jit-compiled train step, reassigning `b` may not reach the compiled loss — confirm.
warmup_epochs, finetune_epochs = 10, 10
checkpoint_path = '../weights/physformertn_mmpd.weights.h5'

stat = None
for _ in range(warmup_epochs):
    stat = train(model, training_set, validation_set, epochs=1,
                 check_point_path=checkpoint_path, training_stat=stat)
    b = b_start + exp_b * (2 ** (stat['epoch'] / warmup_epochs) - 1)  # ramp combined-loss weight
stat['best_loss'] = 1e20
train(model, training_set, validation_set, epochs=finetune_epochs,
      check_point_path=checkpoint_path, training_stat=stat)
model.load_weights(checkpoint_path)

Training:   0%|          | 0/479 [00:00<?, ?it/s]

Validating:   0%|          | 0/105 [00:00<?, ?it/s]

Epoch  11:  Training kl_ce_loss: 5.693, loss: 6.079, np_loss:0.3837	Validation kl_ce_loss:  6.29, loss: 6.846, np_loss: 0.592	


Training:   0%|          | 0/479 [00:00<?, ?it/s]

Validating:   0%|          | 0/105 [00:00<?, ?it/s]

Epoch  12:  Training kl_ce_loss: 5.638, loss: 6.023, np_loss:0.3739	Validation kl_ce_loss: 6.343, loss: 6.827, np_loss:0.5675	


Training:   0%|          | 0/479 [00:00<?, ?it/s]

Validating:   0%|          | 0/105 [00:00<?, ?it/s]

Epoch  13:  Training kl_ce_loss: 5.585, loss: 5.908, np_loss:0.3476	Validation kl_ce_loss: 6.202, loss: 6.825, np_loss:0.5647	


Training:   0%|          | 0/479 [00:00<?, ?it/s]

Validating:   0%|          | 0/105 [00:00<?, ?it/s]

Epoch  14:  Training kl_ce_loss: 5.527, loss: 5.845, np_loss:0.3169	Validation kl_ce_loss: 6.318, loss: 6.848, np_loss:0.5844	


Training:   0%|          | 0/479 [00:00<?, ?it/s]

Validating:   0%|          | 0/105 [00:00<?, ?it/s]

Epoch  15:  Training kl_ce_loss: 5.505, loss: 5.813, np_loss: 0.301	Validation kl_ce_loss: 6.375, loss: 6.806, np_loss:0.5951	


Training:   0%|          | 0/479 [00:00<?, ?it/s]

Validating:   0%|          | 0/105 [00:00<?, ?it/s]

Epoch  16:  Training kl_ce_loss: 5.493, loss: 5.788, np_loss:0.2989	Validation kl_ce_loss:  6.25, loss: 6.844, np_loss:0.5724	


Training:   0%|          | 0/479 [00:00<?, ?it/s]

Validating:   0%|          | 0/105 [00:00<?, ?it/s]

Epoch  17:  Training kl_ce_loss: 5.464, loss: 5.745, np_loss:0.2824	Validation kl_ce_loss:  6.28, loss: 6.883, np_loss:0.5605	


Training:   0%|          | 0/479 [00:00<?, ?it/s]

Validating:   0%|          | 0/105 [00:00<?, ?it/s]

Epoch  18:  Training kl_ce_loss: 5.451, loss: 5.713, np_loss:0.2738	Validation kl_ce_loss: 6.303, loss: 6.842, np_loss:0.5829	


Training:   0%|          | 0/479 [00:00<?, ?it/s]

Validating:   0%|          | 0/105 [00:00<?, ?it/s]

Epoch  19:  Training kl_ce_loss: 5.437, loss: 5.696, np_loss:0.2513	Validation kl_ce_loss: 6.232, loss: 6.843, np_loss:0.5804	


Training:   0%|          | 0/479 [00:00<?, ?it/s]

Validating:   0%|          | 0/105 [00:00<?, ?it/s]

Epoch  20:  Training kl_ce_loss: 5.431, loss:  5.68, np_loss:0.2371	Validation kl_ce_loss: 6.337, loss: 6.886, np_loss:0.5912	


In [6]:
# NOTE(review): the test data here is `dataset_H5_mmpd` (MMPD, the training
# dataset). However, the results file is named `*_MMPD_RLAP.h5`, and the
# scenes=['R1'..'R4'] filter looks like RLAP scene labels. Confirm the intended
# test dataset. If it is MMPD, confirm the evaluated videos do not overlap the
# training fold.
results_rlap = '../results/PhysFormertn_MMPD_RLAP.h5'
eval_on_dataset(
    dataset_H5_mmpd,
    pmodel(lambda x: model(x / 255.)),  # scale uint8 frames to [0, 1], as in training
    160, (128, 128),  # presumably clip length and frame size, matching the training tape
    step=4,
    batch=8,
    save=results_rlap,
    ipt_dtype='uint8',
    scenes=['R1', 'R2', 'R3', 'R4'],
)
get_metrics(results_rlap)

  0%|          | 0/660 [00:00<?, ?it/s]

Unnamed: 0,MAE,RMSE,MAPE,R
Window,3.784±0.096,9.004±2.114,4.669±0.132,0.85546
Whole,2.325±0.183,5.239±2.158,2.813±0.237,0.94162


In [7]:
# Cross-dataset evaluation on COHFACE; unlike the other datasets, fps=30 is passed
# explicitly.
# NOTE(review): COHFACE is commonly reported as 20 fps. Confirm that 30 matches
# the preprocessed H5 data, or that it is an intended resampling target.
results_cohface = '../results/PhysFormertn_MMPD_COHFACE.h5'
eval_on_dataset(
    dataset_H5_cohface,
    pmodel(lambda x: model(x / 255.)),  # scale uint8 frames to [0, 1], as in training
    160, (128, 128),
    step=4,
    batch=8,
    save=results_cohface,
    ipt_dtype='uint8',
    fps=30,
)
get_metrics(results_cohface)

  0%|          | 0/164 [00:00<?, ?it/s]

Unnamed: 0,MAE,RMSE,MAPE,R
Window,5.349±0.211,10.496±2.72,7.284±0.275,0.59742
Whole,3.873±0.455,6.995±3.245,5.104±0.552,0.81486


In [8]:
# Cross-dataset evaluation on PURE with the same windowing as the other test sets.
results_pure = '../results/PhysFormertn_MMPD_PURE.h5'
eval_on_dataset(
    dataset_H5_pure,
    pmodel(lambda x: model(x / 255.)),  # scale uint8 frames to [0, 1], as in training
    160, (128, 128),
    step=4,
    batch=8,
    save=results_pure,
    ipt_dtype='uint8',
)
get_metrics(results_pure)

  0%|          | 0/59 [00:00<?, ?it/s]

Unnamed: 0,MAE,RMSE,MAPE,R
Window,4.716±0.528,15.331±6.488,8.451±0.984,0.81113
Whole,3.976±1.087,9.247±5.695,7.171±2.162,0.92852


In [9]:
# Cross-dataset evaluation on UBFC-rPPG (subset 2) with the same windowing.
results_ubfc = '../results/PhysFormertn_MMPD_UBFC.h5'
eval_on_dataset(
    dataset_H5_ubfc_rppg2,
    pmodel(lambda x: model(x / 255.)),  # scale uint8 frames to [0, 1], as in training
    160, (128, 128),
    step=4,
    batch=8,
    save=results_ubfc,
    ipt_dtype='uint8',
)
get_metrics(results_ubfc)

  0%|          | 0/42 [00:00<?, ?it/s]

Unnamed: 0,MAE,RMSE,MAPE,R
Window,1.396±0.108,2.83±1.494,1.44±0.113,0.98596
Whole,0.545±0.105,0.872±0.529,0.535±0.097,0.99884
