# Library

In [None]:
import functools
import json
import os, os.path
import random
from pathlib import Path

import cv2
import h5py
import matplotlib.pyplot as plt
import numpy as np

import torch # Tested with PyTorch version 1.7.1
import torch.nn as nn
from torch.nn import init
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim

# Helper

In [None]:
def tensor2img(tensor, normalization=False):
  """
  Convert a channels-first (C, H, W) tensor into an (H, W, C) NumPy image for display.

  Args:
    tensor: image tensor, channels first; it may live on any device and may
      require grad (it is detached and copied to the CPU).
    normalization: if True, min-max rescale the image to [0, 1]; otherwise
      assume values in [-1, 1] and map them linearly to [0, 1].

  Returns:
    np.ndarray of shape (H, W, C).
  """
  img = tensor.detach().cpu().numpy().transpose((1, 2, 0))
  if normalization:  # Normalize to [0, 1] for visualization
    value_range = img.max() - img.min()
    # A constant image has zero range; return zeros instead of NaNs from 0/0.
    img = (img - img.min()) / value_range if value_range > 0 else np.zeros_like(img)
  else:  # Convert [-1, 1] to [0, 1]
    img = (img + 1) / 2
  return img

def multiple_tensors_show(tensors, normalization_list=None):
  """
  Plot several (C, H, W) image tensors side by side in one matplotlib row.

  Args:
    tensors: sequence of tensors, each converted with ``tensor2img``.
    normalization_list: per-tensor ``normalization`` flags for ``tensor2img``;
      ``None`` means False for every tensor.
  """
  if normalization_list is None:
    normalization_list = [False] * len(tensors)
  # squeeze=False keeps the axes array 2-D even for a single tensor; otherwise
  # plt.subplots returns a bare Axes when there is one column and indexing fails.
  _, axarr = plt.subplots(
      1, len(tensors), figsize=(5 * len(tensors), 5), squeeze=False
      )
  for i, tensor in enumerate(tensors):
    axarr[0][i].imshow(tensor2img(tensor, normalization=normalization_list[i]))
  plt.show()

# Dataloader

In [None]:
class DatasetUBFC(Dataset):
  """
      Dataset class for training network.

      Each ``__getitem__`` call ignores ``idx`` and instead draws a random
      session and a random start frame. It returns ``seq_length`` consecutive
      frames, the z-normalized PPG values at the same indices and, when
      ``root_dir_transfered`` is set, the matching pseudo-GT ("transfered")
      frames. ``num_samples`` only sets the nominal epoch length.

      Returned dict:
        'next_frame': float tensor (seq_length, C, H, W), pixel values mapped
          with x / 127.5 - 1 (i.e. to [-1, 1] assuming 8-bit input frames)
        'next_ppg_value': float32 tensor with the PPG samples of those frames
        'transfered_next_frame': same layout as 'next_frame' (only present
          when ``root_dir_transfered`` is set)
  """

  def __init__(
      self, root_dir, root_dir_transfered,
      session_names, num_samples,
      seq_length, device, resize_shape
      ):
    """
    Args:
      root_dir: directory holding one ``<session>.h5`` per session, read for
        the 'dataset_1' (frames) and 'ppg' datasets.
      root_dir_transfered: directory holding ``<session>_af.h5`` files whose
        'dataset_1' are the pseudo-GT frames; falsy to disable.
      session_names: session file stems to load.
      num_samples: value reported by ``__len__``.
      seq_length: number of consecutive frames per sample.
      device: unused; kept for interface compatibility.
      resize_shape: ``dsize`` passed to ``cv2.resize`` (width, height).
    """
    self.resize_shape = resize_shape
    self.sessions = session_names
    self.seq_length = seq_length

    self.all_sessions = []
    self.length = num_samples
    self.ppg_der = {}
    self.frames = {}
    self.masks = {}
    self.frames_transfered = {}
    self.root_dir_transfered = root_dir_transfered

    for session in self.sessions:
      # The HDF5 files stay open so frames are read lazily in __getitem__.
      db = h5py.File(os.path.join(root_dir, session + '.h5'), 'r')
      self.frames[session] = db['dataset_1']

      # Z-normalize the PPG signal; a flat signal is only mean-centered
      # instead of being divided by a zero std (which would give NaNs).
      target = np.asarray(db['ppg'])
      target = target - np.mean(target)
      std = np.std(target)
      if std > 0:
        target = target / std
      self.ppg_der[session] = target

      if self.root_dir_transfered:
        db_transfered = h5py.File(
            os.path.join(root_dir_transfered, session + '_af.h5'), 'r')
        self.frames_transfered[session] = db_transfered['dataset_1']

  def __len__(self):
    return self.length

  def _frame_to_tensor(self, frame):
    """Resize a channels-last (H, W, C) frame and return a float (C, H', W') tensor mapped with x / 127.5 - 1."""
    frame = cv2.resize(frame, self.resize_shape, interpolation=cv2.INTER_LINEAR)
    frame = torch.from_numpy(frame).permute(2, 0, 1).float()
    return frame / 127.5 - 1

  def __getitem__(self, idx):
    # idx is ignored: a random session and start frame are drawn on each call.
    session_num = np.random.randint(low=0, high=len(self.sessions))
    subject = self.sessions[session_num]
    frames = self.frames[subject]
    cur_ppg_signal = self.ppg_der[subject]
    use_transfered = bool(self.root_dir_transfered)

    # Bound the start index by every stream indexed below, so no stream is
    # read past its end and this also works without a pseudo-GT directory.
    num_frames = min(len(frames), len(cur_ppg_signal))
    if use_transfered:
      frames_transfered = self.frames_transfered[subject]
      num_frames = min(num_frames, len(frames_transfered))

    if num_frames <= self.seq_length:
      raise ValueError(
          'session %s has %d usable frames; need more than seq_length=%d'
          % (subject, num_frames, self.seq_length))

    # Exclusive upper bound, so the window ending on the very last frame is
    # never chosen (Can't pick the last frame).
    cur_frame_num = np.random.randint(low=0, high=num_frames - self.seq_length)
    window = slice(cur_frame_num, cur_frame_num + self.seq_length)

    # One contiguous read per stream instead of one read per frame.
    data = {
      'next_frame': torch.stack(
          [self._frame_to_tensor(frame) for frame in frames[window]]),
      'next_ppg_value': torch.from_numpy(
          np.asarray(cur_ppg_signal[window]).astype('float32')),
    }
    if use_transfered:
      data['transfered_next_frame'] = torch.stack(
          [self._frame_to_tensor(frame) for frame in frames_transfered[window]])

    return data

# Generator

In [None]:
"""
The code is modified from https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix.
"""

def init_weights(net, init_type='normal', init_gain=0.02):
    """
    Initialize network weights in place.

    Conv*/ConvTranspose*/Linear layers get their weight drawn according to
    ``init_type`` and their bias (if any) set to 0. BatchNorm layers of any
    dimensionality (1d/2d/3d) get weight ~ N(1, init_gain) and bias 0; the
    weight/bias are skipped when the layer has none (affine=False).

    Args:
        net: module whose submodules are visited with ``net.apply``.
        init_type: 'normal' | 'xavier' | 'kaiming' | 'orthogonal'.
        init_gain: std for 'normal', gain for 'xavier'/'orthogonal';
            not used by 'kaiming'.

    Raises:
        NotImplementedError: for an unknown ``init_type`` (raised when the
            first Conv/Linear layer is visited).
    """
    def init_func(m):  # define the initialization function
        classname = m.__class__.__name__
        if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
            if init_type == 'normal':
                init.normal_(m.weight.data, 0.0, init_gain)
            elif init_type == 'xavier':
                init.xavier_normal_(m.weight.data, gain=init_gain)
            elif init_type == 'kaiming':
                init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
            elif init_type == 'orthogonal':
                init.orthogonal_(m.weight.data, gain=init_gain)
            else:
                raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
            if hasattr(m, 'bias') and m.bias is not None:
                init.constant_(m.bias.data, 0.0)
        elif classname.find('BatchNorm') != -1:
            # Match every BatchNorm variant: the networks in this file are 3-D
            # (BatchNorm3d), which a 'BatchNorm2d'-only match silently skipped.
            # BatchNorm's weight is not a matrix; only a normal distribution applies.
            if getattr(m, 'weight', None) is not None:
                init.normal_(m.weight.data, 1.0, init_gain)
            if getattr(m, 'bias', None) is not None:
                init.constant_(m.bias.data, 0.0)

    print('initialize network with %s' % init_type)
    net.apply(init_func)  # apply the initialization function <init_func>


def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=None):
    """
    Optionally move a network to GPU(s), then initialize its weights.

    Args:
        net: the network to initialize.
        init_type: forwarded to ``init_weights``.
        init_gain: forwarded to ``init_weights``.
        gpu_ids: GPU ids; when non-empty the network is moved to
            ``gpu_ids[0]`` and wrapped in ``torch.nn.DataParallel``.
            ``None`` (the default) or an empty list keeps it where it is.

    Returns:
        The (possibly DataParallel-wrapped) initialized network.

    Raises:
        RuntimeError: if ``gpu_ids`` is non-empty but CUDA is unavailable.
    """
    if gpu_ids is None:  # avoid a shared mutable default
        gpu_ids = []
    if len(gpu_ids) > 0:
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if not torch.cuda.is_available():
            raise RuntimeError('gpu_ids=%s requested but CUDA is not available' % (gpu_ids,))
        net.to(gpu_ids[0])
        net = torch.nn.DataParallel(net, gpu_ids)  # multi-GPUs
    init_weights(net, init_type, init_gain=init_gain)
    return net

In [None]:
class ResnetGenerator3d(nn.Module):
    """
    Resnet-based 3-D generator: Resnet blocks between a few downsampling and
    upsampling convolutions (3-D adaptation of the pix2pix/CycleGAN
    ResnetGenerator).

    Layout of ``self.model`` (the order fixes the state_dict keys):
      ReplicationPad3d(3) -> Conv3d(k=7) -> norm -> ReLU,
      2 x [Conv3d(k=3, stride=2) -> norm -> ReLU]          (channels doubled),
      n_blocks x ResnetBlock,
      2 x [ConvTranspose3d(k=3, stride=2) -> norm -> ReLU] (channels halved),
      ReplicationPad3d(3) -> Conv3d(k=7) -> Tanh.

    Input/output are 5-D (batch, channels, D, H, W); the training loop in this
    notebook feeds time as D. Each of D/H/W is halved twice and doubled twice,
    so the output size equals the input size when D, H and W are multiples of 4.
    """

    def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm3d, use_dropout=False, n_blocks=6, padding_type='reflect'):
        """
        Construct a Resnet-based generator.

        Args:
            input_nc: number of input channels.
            output_nc: number of output channels.
            ngf: number of filters in the first conv layer.
            norm_layer: normalization layer class, or a functools.partial of one.
            use_dropout: add Dropout(0.5) inside each ResnetBlock.
            n_blocks: number of ResnetBlocks (must be >= 0).
            padding_type: forwarded to ResnetBlock ('reflect' | 'replicate' | 'zero').
                NOTE: the stem and head always use ReplicationPad3d, and
                ResnetBlock also implements 'reflect' with ReplicationPad3d
                (presumably because nn.ReflectionPad3d is missing in the
                tested PyTorch 1.7.1 — confirm before relying on 'reflect').
        """
        assert n_blocks >= 0
        super(ResnetGenerator3d, self).__init__()
        # As in pix2pix, convs keep a bias only when followed by InstanceNorm3d.
        base_norm = norm_layer.func if isinstance(norm_layer, functools.partial) else norm_layer
        use_bias = base_norm == nn.InstanceNorm3d

        model = [nn.ReplicationPad3d(3),
                 nn.Conv3d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias),
                 norm_layer(ngf),
                 nn.ReLU()]

        n_downsampling = 2
        for i in range(n_downsampling):  # add downsampling layers
            mult = 2 ** i
            model += [nn.Conv3d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1, bias=use_bias),
                      norm_layer(ngf * mult * 2),
                      nn.ReLU()]

        mult = 2 ** n_downsampling
        for _ in range(n_blocks):  # add ResNet blocks at the bottleneck
            model += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)]

        for i in range(n_downsampling):  # add upsampling layers
            mult = 2 ** (n_downsampling - i)
            model += [nn.ConvTranspose3d(ngf * mult, ngf * mult // 2,
                                         kernel_size=3, stride=2,
                                         padding=1, output_padding=1,
                                         bias=use_bias),
                      norm_layer(ngf * mult // 2),
                      nn.ReLU()]
        model += [nn.ReplicationPad3d(3)]
        model += [nn.Conv3d(ngf, output_nc, kernel_size=7, padding=0)]
        model += [nn.Tanh()]

        self.model = nn.Sequential(*model)

    def forward(self, input):
        """Standard forward."""
        return self.model(input)


class ResnetBlock(nn.Module):
    """
    Residual 3-D block computing ``x + F(x)``.

    F is: pad -> Conv3d(3x3x3) -> norm -> ReLU -> [Dropout(0.5)] ->
    pad -> Conv3d(3x3x3) -> norm, with channel count ``dim`` throughout.
    """

    def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias):
        """
        Args:
            dim: number of input/output channels.
            padding_type: 'reflect' | 'replicate' | 'zero'.
            norm_layer: callable building a norm layer from a channel count.
            use_dropout: insert Dropout(0.5) after the first ReLU.
            use_bias: whether the convolutions have a bias.
        """
        super(ResnetBlock, self).__init__()
        self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias)

    def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias):
        """
        Build the residual branch F as an nn.Sequential.

        Raises:
            NotImplementedError: for an unsupported ``padding_type``.
        """
        def padded_conv():
            # Both 'reflect' and 'replicate' are realised with ReplicationPad3d;
            # 'zero' lets the convolution pad by itself.
            if padding_type in ('reflect', 'replicate'):
                return [nn.ReplicationPad3d(1),
                        nn.Conv3d(dim, dim, kernel_size=3, padding=0, bias=use_bias)]
            if padding_type == 'zero':
                return [nn.Conv3d(dim, dim, kernel_size=3, padding=1, bias=use_bias)]
            raise NotImplementedError('padding [%s] is not implemented' % padding_type)

        layers = padded_conv() + [norm_layer(dim), nn.ReLU(True)]
        if use_dropout:
            layers.append(nn.Dropout(0.5))
        layers.extend(padded_conv())
        layers.append(norm_layer(dim))
        return nn.Sequential(*layers)

    def forward(self, x):
        """Return the input plus the residual branch (skip connection)."""
        residual = self.conv_block(x)
        return x + residual

In [None]:
import functools

def get_norm_layer3d(norm_type='instance'):
    """
    Return a factory for a 3-D normalization layer.

    Args:
        norm_type: 'batch' -> BatchNorm3d(affine=True, track_running_stats=True);
                   'instance' -> InstanceNorm3d(affine=False, track_running_stats=False).

    Returns:
        functools.partial that takes the number of channels.

    Raises:
        NotImplementedError: for any other ``norm_type``.
    """
    if norm_type == 'batch':
        return functools.partial(nn.BatchNorm3d, affine=True, track_running_stats=True)
    if norm_type == 'instance':
        return functools.partial(nn.InstanceNorm3d, affine=False, track_running_stats=False)
    raise NotImplementedError('normalization layer [%s] is not found' % norm_type)

# PhysResNet (PRN)


In [None]:
class RPPGNetResnet(nn.Module):
  def __init__(self, seq_length):
    super().__init__()

    self.learned_shortcut1 = nn.Conv3d(3, 16, kernel_size=1, bias=False)
    self.layers1 = nn.Sequential(
        nn.Conv3d(in_channels=3, out_channels=16, kernel_size=3, stride=1, padding=1, bias=True),
        nn.BatchNorm3d(16),
        nn.ReLU(),
        nn.AvgPool3d(kernel_size=(1, 2, 2), stride=(1, 2, 2), padding=0),
        nn.Conv3d(in_channels=16, out_channels=16, kernel_size=3, stride=1, padding=1, bias=True),
        nn.BatchNorm3d(16),
        nn.ReLU(),
        nn.AvgPool3d(kernel_size=(1, 2, 2), stride=(1, 2, 2), padding=0)
    )
    self.pool1 = nn.AvgPool3d(kernel_size=(1, 4, 4), stride=(1, 4, 4), padding=0)

    self.learned_shortcut2 = nn.Conv3d(16, 64, kernel_size=1, bias=False)
    self.layers2 = nn.Sequential(
        nn.Conv3d(in_channels=16, out_channels=32, kernel_size=3, stride=1, padding=1, bias=True),
        nn.BatchNorm3d(32),
        nn.ReLU(),
        nn.AvgPool3d(kernel_size=(1, 2, 2), stride=(1, 2, 2), padding=0),
        nn.Conv3d(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1, bias=True),
        nn.BatchNorm3d(64),
        nn.ReLU(),
        nn.AvgPool3d(kernel_size=(1, 2, 2), stride=(1, 2, 2), padding=0)
    )
    self.pool2 = nn.AvgPool3d(kernel_size=(1, 4, 4), stride=(1, 4, 4), padding=0)

    self.learned_shortcut3 = nn.Conv3d(64, 256, kernel_size=1, bias=False)
    self.layers3 = nn.Sequential(
        nn.Conv3d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1, bias=True),
        nn.BatchNorm3d(128),
        nn.ReLU(),
        nn.AvgPool3d(kernel_size=(1, 2, 2), stride=(1, 2, 2), padding=(0, 1, 1)),
        nn.Conv3d(in_channels=128, out_channels=256, kernel_size=3, stride=1, padding=1, bias=True),
        nn.BatchNorm3d(256),
        nn.ReLU()
    )
    
    self.pool3 = nn.AvgPool3d(kernel_size=(1, 2, 2), stride=(1, 2, 2), padding=(0, 1, 1))

    self.pooling = nn.AdaptiveAvgPool3d((seq_length, 1, 1))
    self.final_conv = nn.Conv3d(in_channels=256, out_channels=1, kernel_size=1, stride=1, padding=0, bias=True)
    
  def forward(self, A):   

    identity = A
    out = self.layers1(A)   
    out += self.learned_shortcut1(self.pool1(identity))
  
    identity = out 
    out = self.layers2(out)
    out += self.learned_shortcut2(self.pool2(identity))
    
    identity = out
    out = self.layers3(out)
    out += self.learned_shortcut3(self.pool3(identity))

    out = self.pooling(out)
    out = self.final_conv(out)

    return out

# Loss

In [None]:
tr = torch

class NegPeaLoss(nn.Module):
  """
  Negative Pearson correlation loss.

  For each row b computes ``1 - r(x_b, y_b)``, where r is the Pearson
  correlation along dim 1, and returns the mean over rows. Values lie in
  [0, 2]; the result is NaN if a row of x or y is constant (zero denominator).
  """
  def __init__(self):
    super(NegPeaLoss, self).__init__()

  def forward(self, x, y):
    """
    Args:
      x: predicted signal(s), shape (T,) or (B, T).
      y: reference signal(s), shape (T,) or (B, T).

    Returns:
      Scalar tensor: mean over the batch of ``1 - pearson(x_b, y_b)``.
    """
    # Promote each 1-D input on its own, so a (T,) prediction (e.g. from
    # ``.squeeze()`` with batch size 1) still pairs with a (1, T) target.
    if x.dim() == 1:
      x = tr.unsqueeze(x, 0)
    if y.dim() == 1:
      y = tr.unsqueeze(y, 0)
    T = x.shape[1]
    # Single-pass Pearson:
    # (T*sum(xy) - sum(x)sum(y)) / sqrt((T*sum(x^2) - sum(x)^2) * (T*sum(y^2) - sum(y)^2))
    sum_x = tr.sum(x, 1)
    sum_y = tr.sum(y, 1)
    p_coeff = T * tr.sum(x * y, 1) - sum_x * sum_y
    norm = tr.sqrt((T * tr.sum(x ** 2, 1) - sum_x ** 2) * (T * tr.sum(y ** 2, 1) - sum_y ** 2))
    p_coeff = p_coeff / norm
    # Python scalar instead of a CPU tensor, so no cross-device constant is built.
    losses = 1.0 - p_coeff
    return tr.mean(losses)

class ThresholdLoss(nn.Module):
  """
  Mean absolute error over only the elements whose error exceeds ``threshold``.

  Elements with ``|est - gt| <= threshold`` are ignored. If no element exceeds
  the threshold, the loss is 0, not NaN, and stays connected to the graph.
  """
  def __init__(self, threshold=0.1):
    """
    Args:
      threshold: absolute-error level above which elements contribute.
    """
    super(ThresholdLoss, self).__init__()
    self.threshold = threshold

  def forward(self, est, gt):
    """Return the mean of ``|est - gt|`` over elements above the threshold (0 if none)."""
    l1_map = torch.abs(est - gt)
    selected = l1_map[l1_map > self.threshold]
    if selected.numel() == 0:
      # mean() of an empty tensor is NaN, which would poison the optimizer;
      # the sum of the empty selection is a 0 that still supports backward().
      return selected.sum()
    return torch.mean(selected)

# Training 

In [None]:
## Training parameters
params = {
    'batch_size': 2, # Batch size of the train and val dataloaders.
    'num_samples': 160, # Random clips per epoch (the dataset's __len__).
    'seq_length': 256, # Consecutive frames per clip; also the rPPG net's output length.
    'img_shape': (80, 80), # Frames are resized to this (cv2 dsize = (width, height)).
    'num_epochs': 601, # Number of training epochs (also T_max of both cosine schedulers).
    'lr': 0.0001, # Generator (Adam) learning rate.
    'lr_rppgnet': 0.0003, # rPPG network (Adam) learning rate.
    'display_tensor': 15, # Show sample frames every K epochs (0 disables).
    'val_epoch': 15, # Run validation every K epochs.
    'save_path': './model_checkpoint_folder', # Checkpoint directory (created if missing; falsy disables saving).
    'save_name_generator': 'generator_checkpoint.pt', # Generator checkpoint file name inside save_path.
    'save_name_rppgnet': 'rppgnet_checkpoint.pt', # rPPG network checkpoint file name inside save_path.
    'rgb_loss_weight': 1.0, # Weight of the ThresholdLoss between generated and pseudo-GT frames.
    'rppgnet_loss_weight': 1.0, # Weight of the rPPG (negative Pearson) term in the generator loss.
    'input_nc': 3, # Generator input channels.
    'output_nc': 3, # Generator output channels.
    'ngf': 64, # Generator filters in the first conv layer.
    'dropout': False, # Dropout inside the generator's ResNet blocks.
    'norm': 'instance', # Generator normalization: 'instance' or 'batch'.
    'weight_decay': 0.0, # Adam weight decay for both optimizers (regularization).
    'beta1': 0.5, # Adam beta1
    'beta2': 0.999, # Adam beta2
    'init_type': 'kaiming', # Generator init: 'normal', 'xavier', 'kaiming' or 'orthogonal'.
    'init_gain': 0.02, # Init std/gain (not used by 'kaiming').
    'generator_pretrained_checkpoint': '/path/to/pretrained_generator.pt', # Falsy to skip loading.
    'rppgnet_pretrained_checkpoint': '/path/to/pretrained_rppgnet.pt', # Falsy to skip loading.
    'gpu_ids': [0] # GPUs for the generator (moved/wrapped by init_net); the rPPG net uses `device`.
}

# NOTE(review): `opt` is not referenced anywhere in the visible code — confirm
# whether it is still needed.
opt = {
    'crop_size': 80,
    'batch_size': params['batch_size'],
}

In [None]:
# Make sure the checkpoint directory exists (skipped when save_path is falsy)
if params['save_path']:
  Path(params['save_path']).mkdir(parents=True, exist_ok=True)

# Dataloaders
root_dir = '/path/to/real_subjects' # Dir for the real subjects
root_dir_transfered = '/path/to/pseudo_gt' # Dir for the pseudo GT

train_session_nums = [
    1, 3, 4, 5, 8, 9, 10, 11, 12, 13, 16, 17, 18, 20,
    23, 24, 25, 26, 27, 31, 32, 33, 35, 36, 38, 39, 40,
    41, 42, 45, 47, 49,
]
val_session_nums = [14, 15, 22, 30, 34, 37, 43, 44, 46, 48]

# Report any subject that appears in both splits (train/val leakage)
overlapping_sessions = [num for num in val_session_nums if num in train_session_nums]
for num in overlapping_sessions:
  print('Session overlap!', num)

# Session names are the HDF5 file stems read by DatasetUBFC, e.g. 'subject1'
train_session_names = ['subject' + str(num) for num in train_session_nums]
val_session_names = ['subject' + str(num) for num in val_session_nums]

# Device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Device:', device)


def _build_split(session_names):
  """Create the DatasetUBFC for *session_names* together with its DataLoader."""
  dataset = DatasetUBFC(
      root_dir=root_dir,
      root_dir_transfered=root_dir_transfered,
      session_names=session_names,
      num_samples=params['num_samples'],
      seq_length=params['seq_length'],
      device=device,
      resize_shape=params['img_shape'],
      )
  loader = DataLoader(
      dataset,
      batch_size=params['batch_size'],
      shuffle=True,
      num_workers=0,
      pin_memory=True
      )
  return dataset, loader


train_set, ppg_train_loader = _build_split(train_session_names)
val_set, ppg_val_loader = _build_split(val_session_names)

In [None]:
# Model
norm_layer = get_norm_layer3d(norm_type=params['norm'])
model = ResnetGenerator3d(params['input_nc'], params['output_nc'], params['ngf'], norm_layer=norm_layer, use_dropout=params['dropout'], n_blocks=6)
# With a non-empty gpu_ids, init_net moves the generator to gpu_ids[0] and
# wraps it in DataParallel (so its state_dict keys get a 'module.' prefix).
model = init_net(model, params['init_type'], params['init_gain'], params['gpu_ids'])


def _extract_state_dict(checkpoint):
  """Return the weights from a {'model_state_dict': ...} checkpoint (the
  format the training loop saves) or from a bare state dict."""
  if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
    return checkpoint['model_state_dict']
  return checkpoint


# NOTE: torch.load unpickles the file, so only load trusted checkpoints.
# map_location='cpu' lets GPU-saved checkpoints load anywhere; load_state_dict
# then copies the tensors onto the device each model already lives on.
if params['generator_pretrained_checkpoint']:
  print('Loading generator checkpoint:', params['generator_pretrained_checkpoint'])
  checkpoint = torch.load(params['generator_pretrained_checkpoint'], map_location='cpu')
  model.load_state_dict(_extract_state_dict(checkpoint))

model_rppgnet = RPPGNetResnet(params['seq_length'])
if params['rppgnet_pretrained_checkpoint']:
  print('Loading rppg checkpoint:', params['rppgnet_pretrained_checkpoint'])
  # Accept both a bare state dict and the dict format saved by the training loop.
  model_rppgnet.load_state_dict(_extract_state_dict(
      torch.load(params['rppgnet_pretrained_checkpoint'], map_location='cpu')))
model_rppgnet.to(device)

# Optimizers (freshly created; optimizer/scheduler state stored in
# checkpoints is not restored here)
beta1 = params['beta1']
beta2 = params['beta2']
optimizer = optim.Adam(
    model.parameters(), lr=params['lr'],
    betas=(beta1, beta2), weight_decay=params['weight_decay']
    )
optimizer_rppgnet = optim.Adam(
    model_rppgnet.parameters(), lr=params['lr_rppgnet'],
    betas=(beta1, beta2), weight_decay=params['weight_decay']
    )

# Schedulers: cosine decay from the initial lr to 0 over num_epochs epochs
scheduler = optim.lr_scheduler.CosineAnnealingLR(
    optimizer, T_max=params['num_epochs'], eta_min=0, verbose=True
    )

scheduler_rppgnet = optim.lr_scheduler.CosineAnnealingLR(
    optimizer_rppgnet, T_max=params['num_epochs'], eta_min=0, verbose=True
    )

# Loss
criterion = ThresholdLoss()
criterion_rppg = NegPeaLoss()

In [None]:
def _show_random_sample(next_frame, next_frame_transfered, frame_transfered_hat):
  """
  Display input, pseudo-GT and generated frames for one random (sample, time) pair.

  ``next_frame`` and ``next_frame_transfered`` are (batch, time, C, H, W) as
  produced by the dataloader; ``frame_transfered_hat`` is the generator output
  (batch, C, time, H, W) and is transposed back before indexing.
  """
  # Draw the sample and the frame independently; reusing one index for both
  # would only ever show frames 0..batch_size-1.
  b = np.random.randint(0, next_frame.shape[0])
  t = np.random.randint(0, next_frame.shape[1])
  multiple_tensors_show(
      [next_frame[b, t].detach().cpu(),
       next_frame_transfered[b, t].detach().cpu(),
       torch.transpose(frame_transfered_hat, 1, 2)[b, t].detach().cpu()],
      normalization_list=[False, False, False]
      )


val_loss_history = []
train_loss_history = []

for epoch in range(params['num_epochs']):
  print('\nEpoch {}/{}'.format(epoch, params['num_epochs'] - 1))
  print('-' * 10)

  # Training
  model.train()
  model_rppgnet.train()

  running_g_loss = 0.0
  running_rppg_loss = 0.0
  running_rppg_loss_real = 0.0
  running_rppg_loss_fake = 0.0

  # Iterate over data
  for i, data in enumerate(ppg_train_loader):
    next_ppg = data['next_ppg_value'].to(device)
    next_frame = data['next_frame'].to(device)
    next_frame_transfered = data['transfered_next_frame'].to(device)

    # Generator step.
    # The rPPG term must stay differentiable w.r.t. the generator output:
    # running the rPPG net under torch.no_grad() made that term a constant,
    # so 'rppgnet_loss_weight' had no effect on the generator. Instead, the
    # rPPG net's parameters are frozen while its graph is built, so gradients
    # reach only the generator.
    optimizer.zero_grad()

    frame_transfered_hat = model(torch.transpose(next_frame, 1, 2))
    model_rppgnet.requires_grad_(False)
    est_ppg_hat = model_rppgnet(frame_transfered_hat)
    model_rppgnet.requires_grad_(True)
    g_loss = params['rgb_loss_weight'] * criterion(
        torch.transpose(frame_transfered_hat, 1, 2), next_frame_transfered
    ) + params['rppgnet_loss_weight'] * criterion_rppg(est_ppg_hat.squeeze(), next_ppg)

    g_loss.backward()
    optimizer.step()

    # rPPG step: fit the rPPG net on generated (detached) and real frames
    optimizer_rppgnet.zero_grad()

    est_ppg_transfered = model_rppgnet(frame_transfered_hat.detach())
    rppg_loss_fake = criterion_rppg(est_ppg_transfered.squeeze(), next_ppg)
    est_ppg_original = model_rppgnet(torch.transpose(next_frame, 1, 2))
    rppg_loss_real = criterion_rppg(est_ppg_original.squeeze(), next_ppg)

    rppg_loss = rppg_loss_real + rppg_loss_fake
    rppg_loss.backward()
    optimizer_rppgnet.step()

    running_g_loss += g_loss.item()
    running_rppg_loss += rppg_loss.item()
    running_rppg_loss_real += rppg_loss_real.item()
    running_rppg_loss_fake += rppg_loss_fake.item()

    # Display some training frames
    if params['display_tensor'] and i == 0 and epoch % params['display_tensor'] == 0:
      _show_random_sample(next_frame, next_frame_transfered, frame_transfered_hat)

  epoch_g_loss = running_g_loss / len(ppg_train_loader)
  epoch_rppg_loss = running_rppg_loss / len(ppg_train_loader)
  epoch_rppg_loss_real = running_rppg_loss_real / len(ppg_train_loader)
  epoch_rppg_loss_fake = running_rppg_loss_fake / len(ppg_train_loader)

  print('G Loss: {:.4f} '.format(epoch_g_loss))
  print('RPPGNet Loss: {:.4f} '.format(epoch_rppg_loss))
  print('RPPGNet Loss Real: {:.4f} '.format(epoch_rppg_loss_real))
  print('RPPGNet Loss Fake: {:.4f} '.format(epoch_rppg_loss_fake))

  train_loss_history.append(epoch_g_loss)

  # Adjust the schedulers
  scheduler.step()
  scheduler_rppgnet.step()

  # Validation
  if epoch % params['val_epoch'] == 0:
    model.eval()
    model_rppgnet.eval()

    running_loss = 0.0
    running_loss_rppg = 0.0
    running_rppg_loss_real = 0.0
    running_rppg_loss_fake = 0.0

    with torch.no_grad():
      # Iterate over data.
      for i, data in enumerate(ppg_val_loader):
        next_ppg = data['next_ppg_value'].to(device)
        next_frame = data['next_frame'].to(device)
        next_frame_transfered = data['transfered_next_frame'].to(device)

        frame_transfered_hat = model(torch.transpose(next_frame, 1, 2))

        loss = criterion(torch.transpose(frame_transfered_hat, 1, 2), next_frame_transfered)
        running_loss += loss.item()

        est_ppg_transfered = model_rppgnet(frame_transfered_hat)
        rppg_loss_fake = criterion_rppg(est_ppg_transfered.squeeze(), next_ppg)
        est_ppg_original = model_rppgnet(torch.transpose(next_frame, 1, 2))
        rppg_loss_real = criterion_rppg(est_ppg_original.squeeze(), next_ppg)

        running_rppg_loss_real += rppg_loss_real.item()
        running_rppg_loss_fake += rppg_loss_fake.item()
        running_loss_rppg += rppg_loss_real.item() + rppg_loss_fake.item()

        # Display some frames
        if params['display_tensor'] and i == 0 and epoch % params['display_tensor'] == 0:
          _show_random_sample(next_frame, next_frame_transfered, frame_transfered_hat)

    epoch_loss_rgb = running_loss / len(ppg_val_loader)
    epoch_loss_rppg = running_loss_rppg / len(ppg_val_loader)
    epoch_rppg_loss_real = running_rppg_loss_real / len(ppg_val_loader)
    epoch_rppg_loss_fake = running_rppg_loss_fake / len(ppg_val_loader)

    print('Val Loss RGB: {:.4f} '.format(epoch_loss_rgb))
    print('Val Loss rPPG: {:.4f} '.format(epoch_loss_rppg))
    print('RPPGNet Loss Real: {:.4f} '.format(epoch_rppg_loss_real))
    print('RPPGNet Loss Fake: {:.4f} '.format(epoch_rppg_loss_fake))

    # Model selection uses the rPPG net's validation loss on real frames
    epoch_loss = epoch_rppg_loss_real
    val_loss_history.append(epoch_loss)

    # Save the checkpoints whenever the validation loss is the best so far
    if params['save_path'] and epoch_loss <= min(val_loss_history):
      print('Saving in Epoch', epoch)
      torch.save({
          'epoch': epoch,
          'model_state_dict': model.state_dict(),
          'optimizer_state_dict': optimizer.state_dict(),
          'scheduler_state_dict': scheduler.state_dict(),
          'train_loss': train_loss_history,
          'val_loss': val_loss_history,
          }, os.path.join(params['save_path'], params['save_name_generator']))
      torch.save({
        'epoch': epoch,
        'model_state_dict': model_rppgnet.state_dict(),
        'optimizer_state_dict': optimizer_rppgnet.state_dict(),
        'scheduler_state_dict': scheduler_rppgnet.state_dict(),
        }, os.path.join(params['save_path'], params['save_name_rppgnet']))