# Libraries

In [None]:
import os, os.path

import torch
tr = torch
import torch.nn as nn
from torch.nn import init
import torchvision as tv
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
import torch.nn.functional as F
import torch.nn.utils.spectral_norm as spectral_norm

import re
import numpy as np
import cv2
import random
from PIL import Image 

from scipy.io import loadmat
import pandas as pd
import matplotlib.pyplot as plt
from pathlib import Path
from numpy import linalg as LA

from torch.utils.data import DataLoader
import json
import h5py

from torchvision.transforms import RandomRotation, ToPILImage, ToTensor, ColorJitter
import torchvision.transforms.functional as TF
from sklearn.model_selection import train_test_split


import sys
import argparse
from tqdm import tqdm
from types import SimpleNamespace
import glob

# Fix random seeds for reproducibility. Python's `random` module is seeded
# too because DatasetUBFC draws its flip augmentations from it.
rseed = 43
random.seed(rseed)
np.random.seed(rseed)
torch.manual_seed(rseed)
if torch.cuda.is_available():
  torch.cuda.manual_seed_all(rseed)
# Use deterministic cuDNN kernels; benchmark mode may auto-select
# non-deterministic algorithms, so switch it off explicitly.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Dataloader

In [None]:
import pandas as pd
class DatasetUBFC(Dataset):
  """Random-clip rPPG dataset backed by per-session HDF5 files.

  Each session file must contain a ``dataset_1`` array of video frames and a
  ``ppg`` array that is indexed with the same frame offsets (presumably one
  PPG value per frame). ``__getitem__`` ignores ``idx``: every call draws a
  random session and a random start frame, so the epoch length is set by
  ``num_samples`` rather than by the amount of data.
  """

  def __init__(
      self, root_dir, session_names, num_samples,
      seq_length, resize_shape, val=False, transform=False, num_datasets=1):
    """
    :param root_dir: list of data directories. ``root_dir[0]`` holds the real
        sessions; ``root_dir[1]`` holds the synthetic sessions and is only
        listed when ``num_datasets > 1``.
    :param session_names: one list per dataset. Entry 0 holds session paths
        relative to ``root_dir[0]`` without the ``.h5`` suffix; entry 1 holds
        integer indices into the *sorted* file listing of ``root_dir[1]``.
    :param num_samples: value returned by ``__len__``.
    :param seq_length: number of consecutive frames per clip.
    :param resize_shape: ``dsize`` passed to ``cv2.resize`` (width, height).
    :param val: accepted for API compatibility; currently unused.
    :param transform: if truthy, apply random horizontal/vertical flips.
    :param num_datasets: how many entries of ``session_names`` to load. The
        default of 1 loads only the real sessions (previously hard-coded).
    """
    self.resize_shape = resize_shape
    self.transform = transform
    self.sessions = session_names
    self.seq_length = seq_length
    self.length = num_samples

    # Kept for attribute compatibility; only frames/ppg_der are populated.
    self.all_sessions = []
    self.avg_img = {}
    self.hr = {}
    self.frames = {}
    self.ppg_der = {}

    # Synthetic sessions are addressed by index, so sort the listing to make
    # the index -> file mapping independent of filesystem order.
    self.files = sorted(os.listdir(root_dir[1])) if num_datasets > 1 else []

    i = 0
    for num_dataset in range(num_datasets):
      for session in self.sessions[num_dataset]:
        if num_dataset == 0:
          path = os.path.join(root_dir[num_dataset], session + '.h5')
        else:
          path = os.path.join(root_dir[num_dataset], self.files[session])
        # The file stays open: frames are read lazily in __getitem__.
        db = h5py.File(path, 'r')
        self.frames[i] = db['dataset_1']
        self.ppg_der[i] = self._normalize(db['ppg'])
        i += 1

    print(f'Loaded {len(self.frames)} sessions')

  @staticmethod
  def _normalize(ppg):
    """Return ``ppg`` as a zero-mean, unit-std float64 array.

    A constant signal (std == 0) is only mean-centred, to avoid NaNs.
    """
    ppg = np.asarray(ppg, dtype=np.float64)
    ppg = ppg - ppg.mean()
    std = ppg.std()
    return ppg / std if std > 0 else ppg

  def __len__(self):
    # Fixed epoch length: items are sampled randomly, not indexed.
    return self.length

  def __getitem__(self, idx):
    """Return a random clip as ``{'next_frame': (T, C, H, W), 'next_ppg_value': (T, ...)}``."""
    # idx is ignored: pick a random session ...
    session_num = np.random.randint(low=0, high=len(self.frames))
    frames = self.frames[session_num]
    ppg = self.ppg_der[session_num]

    # ... and a random start. Any start in [0, len(frames) - seq_length]
    # yields a full clip; randint's upper bound is exclusive, hence the +1.
    max_start = len(frames) - self.seq_length
    if max_start < 0:
      raise ValueError(
          f'session {session_num} has {len(frames)} frames, '
          f'fewer than seq_length={self.seq_length}')
    start = np.random.randint(low=0, high=max_start + 1)
    stop = start + self.seq_length

    # One contiguous HDF5 read instead of seq_length separate ones.
    raw_frames = frames[start:stop]
    clip = torch.stack([
        torch.from_numpy(
            cv2.resize(frame, self.resize_shape, interpolation=cv2.INTER_LINEAR)
        ).permute(2, 0, 1).float()
        for frame in raw_frames
    ])
    # Scale pixel values to [-1, 1] (assumes 8-bit frames — TODO confirm).
    clip = clip / 127.5 - 1

    # Augmentation: both coin flips are always drawn, in this order.
    if self.transform:
      if random.randint(0, 1) == 1:
        clip = torch.flip(clip, [3])  # flip along width
      if random.randint(0, 1) == 1:
        clip = torch.flip(clip, [2])  # flip along height

    ppg_clip = torch.from_numpy(np.asarray(ppg[start:stop], dtype=np.float32))

    return {
        'next_frame': clip,
        'next_ppg_value': ppg_clip,
    }

# Loss

In [None]:
class Neg_Pearson(nn.Module):
  """Negative Pearson correlation loss: mean of ``1 - r`` over the batch.

  ``r`` is the Pearson correlation between a prediction row and its label
  row, so each per-sample term lies in roughly [0, 2]: 0 when perfectly
  correlated, 2 when perfectly anti-correlated. Samples whose
  standardisation is not finite (e.g. a constant signal, std == 0) are
  skipped and excluded from the average.
  """

  def __init__(self):
    super().__init__()

  def forward(self, preds, labels):
    """
    :param preds: tensor of shape (batch, seq_len), or (seq_len,) for a
        single sample.
    :param labels: tensor with the same shape as ``preds``.
    :return: scalar loss tensor on ``preds.device``. It is 0 (without grad)
        when no sample is valid.
    """
    # Treat a 1-D input (e.g. a batch of one squeezed away) as one sample.
    if preds.dim() == 1:
      preds = preds.unsqueeze(0)
    if labels.dim() == 1:
      labels = labels.unsqueeze(0)

    loss = torch.tensor(0).float().to(preds.device)
    n = preds.shape[1]
    num_valid = 0
    for pred, label in zip(preds, labels):
      preds_nor = (pred - torch.mean(pred)) / torch.std(pred)
      labels_nor = (label - torch.mean(label)) / torch.std(label)

      # Zero std gives NaN/inf; such samples carry no correlation signal.
      if not (torch.isfinite(preds_nor).all() and torch.isfinite(labels_nor).all()):
        continue

      sum_x = torch.sum(preds_nor)
      sum_y = torch.sum(labels_nor)
      sum_xy = torch.sum(preds_nor * labels_nor)
      sum_x2 = torch.sum(preds_nor ** 2)
      sum_y2 = torch.sum(labels_nor ** 2)
      pearson = (n * sum_xy - sum_x * sum_y) / (
          torch.sqrt((n * sum_x2 - sum_x ** 2) * (n * sum_y2 - sum_y ** 2)) + 1e-7)

      loss = loss + (1 - pearson)
      num_valid += 1

    # Average over the samples that actually contributed.
    if num_valid == 0:
      return loss
    return loss / num_valid

# Network

In [None]:
class RPPGNetResnet(nn.Module):
  """3D CNN with 1x1x1-conv shortcuts mapping a video clip to one value per frame.

  Input shape is ``(batch, 3, T, H, W)``. The spatial size shrinks 4x in each
  of the first two stages, then through a padded 2x pool in the third. For
  example, 80x80 -> 20x20 -> 5x5 -> 3x3. Each stage's shortcut applies a
  matching average pool so the residual sum lines up (this holds for H and W
  that are multiples of 16). ``pooling`` then collapses space and resamples
  time to ``seq_length``, so the output is ``(batch, 1, seq_length, 1, 1)``.
  """

  def __init__(self, seq_length):
    super().__init__()

    # Stage 1: 3 -> 16 channels, spatial /4.
    self.learned_shortcut1 = nn.Conv3d(3, 16, kernel_size=1, bias=False)
    self.layers1 = nn.Sequential(
        nn.Conv3d(in_channels=3, out_channels=16, kernel_size=3, stride=1, padding=1, bias=True),
        nn.BatchNorm3d(16),
        nn.ReLU(),
        nn.AvgPool3d(kernel_size=(1, 2, 2), stride=(1, 2, 2), padding=0),
        nn.Conv3d(in_channels=16, out_channels=16, kernel_size=3, stride=1, padding=1, bias=True),
        nn.BatchNorm3d(16),
        nn.ReLU(),
        nn.AvgPool3d(kernel_size=(1, 2, 2), stride=(1, 2, 2), padding=0)
    )
    self.pool1 = nn.AvgPool3d(kernel_size=(1, 4, 4), stride=(1, 4, 4), padding=0)

    # Stage 2: 16 -> 64 channels, spatial /4.
    self.learned_shortcut2 = nn.Conv3d(16, 64, kernel_size=1, bias=False)
    self.layers2 = nn.Sequential(
        nn.Conv3d(in_channels=16, out_channels=32, kernel_size=3, stride=1, padding=1, bias=True),
        nn.BatchNorm3d(32),
        nn.ReLU(),
        nn.AvgPool3d(kernel_size=(1, 2, 2), stride=(1, 2, 2), padding=0),
        nn.Conv3d(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1, bias=True),
        nn.BatchNorm3d(64),
        nn.ReLU(),
        nn.AvgPool3d(kernel_size=(1, 2, 2), stride=(1, 2, 2), padding=0)
    )
    self.pool2 = nn.AvgPool3d(kernel_size=(1, 4, 4), stride=(1, 4, 4), padding=0)

    # Stage 3: 64 -> 256 channels, one padded 2x spatial pool.
    self.learned_shortcut3 = nn.Conv3d(64, 256, kernel_size=1, bias=False)
    self.layers3 = nn.Sequential(
        nn.Conv3d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1, bias=True),
        nn.BatchNorm3d(128),
        nn.ReLU(),
        nn.AvgPool3d(kernel_size=(1, 2, 2), stride=(1, 2, 2), padding=(0, 1, 1)),
        nn.Conv3d(in_channels=128, out_channels=256, kernel_size=3, stride=1, padding=1, bias=True),
        nn.BatchNorm3d(256),
        nn.ReLU()
    )

    self.pool3 = nn.AvgPool3d(kernel_size=(1, 2, 2), stride=(1, 2, 2), padding=(0, 1, 1))

    # Head: collapse space, keep seq_length time steps, project to 1 channel.
    self.pooling = nn.AdaptiveAvgPool3d((seq_length, 1, 1))
    self.final_conv = nn.Conv3d(in_channels=256, out_channels=1, kernel_size=1, stride=1, padding=0, bias=True)

  def forward(self, A):
    # Residual sums are out-of-place on purpose: layers3 ends with a ReLU
    # whose output autograd saves for backward, so an in-place ``+=`` on it
    # makes loss.backward() fail with an in-place modification error.
    out = self.layers1(A) + self.learned_shortcut1(self.pool1(A))
    out = self.layers2(out) + self.learned_shortcut2(self.pool2(out))
    out = self.layers3(out) + self.learned_shortcut3(self.pool3(out))

    out = self.pooling(out)
    return self.final_conv(out)

# Train function

In [None]:
def train_model(model, dataloaders, criterion, optimizer, scheduler, params):
  """Train ``model``, validating and checkpointing every epoch.

  Each epoch runs a ``'train'`` phase (gradient updates) and a ``'val'``
  phase (no autograd). It prints the mean batch loss of each phase, steps
  ``scheduler`` once, and saves
  ``checkpoints/<check_point_dir>/model_ep<epoch>.pt``. Whenever the
  validation loss improves, ``model_best.pt`` is overwritten too. The
  checkpoint directory must already exist.

  :param model: network mapping ``(B, C, T, H, W)`` clips to outputs with
      ``B * T`` elements (e.g. ``(B, 1, T, 1, 1)``).
  :param dataloaders: dict with ``'train'`` and ``'val'`` iterables
      (supporting ``len``) that yield dicts with ``'next_frame'``
      ``(B, T, C, H, W)`` and ``'next_ppg_value'`` ``(B, T)``.
  :param criterion: loss called as ``criterion(outputs, targets)`` on
      ``(B, T)`` tensors.
  :param optimizer: optimizer over ``model``'s parameters.
  :param scheduler: LR scheduler, stepped once per epoch.
  :param params: dict with at least ``'num_epochs'`` and ``'check_point_dir'``.
  :return: ``(train_loss_history, val_loss_history)``, the per-epoch mean losses.
  """
  # Use whatever device the model lives on rather than a global.
  device = next(model.parameters()).device
  checkpoint_dir = f"checkpoints/{params['check_point_dir']}"

  train_loss_history = []
  val_loss_history = []
  min_val_loss = float('inf')

  for epoch in range(params['num_epochs']):
    print('Epoch {}/{}'.format(epoch, params['num_epochs']))
    print('-' * 10)

    for phase in ('train', 'val'):
      is_train = phase == 'train'
      model.train(is_train)  # train() for 'train', eval() for 'val'
      running_loss = 0.0

      for inputs in dataloaders[phase]:
        # (B, T, C, H, W) -> (B, C, T, H, W), the layout Conv3d expects.
        next_frames = torch.transpose(inputs['next_frame'].to(device), 1, 2)
        targets = inputs['next_ppg_value'].to(device)

        optimizer.zero_grad()
        # Only build the autograd graph during training.
        with torch.set_grad_enabled(is_train):
          outputs = model(next_frames)
          # Flatten to (B, T) explicitly; .squeeze() would also drop the
          # batch axis when B == 1.
          outputs = outputs.reshape(outputs.shape[0], -1)
          loss = criterion(outputs, targets)
          if is_train:
            loss.backward()
            optimizer.step()

        running_loss += loss.item()

      epoch_loss = running_loss / len(dataloaders[phase])
      print('{} Loss: {:.4f} '.format(phase, epoch_loss))

      if is_train:
        train_loss_history.append(epoch_loss)
      else:
        val_loss_history.append(epoch_loss)
        if epoch_loss < min_val_loss:
          print('saving in epoch', epoch)
          min_val_loss = epoch_loss
          print(f'new min loss: {min_val_loss}')
          torch.save(model.state_dict(), f"{checkpoint_dir}/model_best.pt")

    scheduler.step()
    print('\n')

    torch.save(model.state_dict(), f"{checkpoint_dir}/model_ep{epoch}.pt")

  return train_loss_history, val_loss_history

**Train/test split**

In [None]:
# Load the per-subject label spreadsheet. The next cell expects it to have
# 'Study ID' and 'Fitzpatrick Skin Type' columns.
# NOTE: `$Path of label$` is a placeholder; replace it with the real path
# (as a string) before running this cell.
fp_df = pd.read_excel($Path of label$)
print(len(fp_df))
fp_df.info()

In [None]:
# Keep subjects that have both an ID and a skin-type label, remove the
# excluded subjects entirely, then trim the table to those two columns.
keep_columns = ['Study ID', 'Fitzpatrick Skin Type']
fp_df = fp_df.dropna(subset=keep_columns)
print(len(fp_df))
exception_subjects = [31, 99]  # every video of these subjects is dropped
fp_df = fp_df.loc[~fp_df['Study ID'].isin(exception_subjects)]
print(len(fp_df))
fp_df = fp_df.drop(columns=[col for col in fp_df.columns if col not in keep_columns])
fp_df.info()

In [None]:
# Subject-level split stratified by skin type: 50% of subjects go to train.
# The remaining 50% is split 20/80 into validation/test, i.e. 10% / 40% of
# all subjects.
# TODO: repeat this split 5 times and retrain.
train_subject, holdout_subject = train_test_split(
    fp_df,
    test_size=0.5,
    random_state=rseed,
    stratify=fp_df['Fitzpatrick Skin Type'],
)
validation_subject, test_subject = train_test_split(
    holdout_subject,
    test_size=0.8,
    random_state=rseed,
    stratify=holdout_subject['Fitzpatrick Skin Type'],
)
for split_df in (train_subject, validation_subject, test_subject):
  print(split_df['Fitzpatrick Skin Type'].value_counts())

In [None]:
# Display the IDs of the training subjects.
train_subject.loc[:, 'Study ID']

In [None]:
def _session_ids(subjects, takes=range(1, 6)):
  """Return '<subject>_<take>' ids for every subject ID and take number."""
  return [f'{int(subject)}_{take}' for subject in subjects for take in takes]

# Videos removed from every split (presumably unusable recordings).
excepetion_videos = ['4_1', '8_3', '10_5', '16_5', '49_5', '34_1', '58_5']
_excluded_videos = set(excepetion_videos)  # O(1) membership tests

train_videos = [v for v in _session_ids(train_subject['Study ID']) if v not in _excluded_videos]
validation_videos = [v for v in _session_ids(validation_subject['Study ID']) if v not in _excluded_videos]
test_videos = [v for v in _session_ids(test_subject['Study ID']) if v not in _excluded_videos]
print(len(train_videos), train_videos)
print(len(validation_videos), validation_videos)
print(len(test_videos), test_videos)


# Train

In [None]:
params = {
  'lr': 0.0003,
  'min_lr_ratio': 0.01,          # cosine schedule floor = min_lr_ratio * lr
  'weight_decay': 0.01,
  'num_samples': 160,            # random clips per epoch (dataset length)
  'seq_length': 256,             # frames per clip
  'batch_size': 4,
  'img_shape': (80, 80),         # frame resize target
  'num_epochs': 60,
  'snr_epoch': 10,               # not read by train_model in this notebook
  'check_point_dir': '1440finetune_real_prn_v2',
  'pretrained_weights': None
}

# Create the checkpoint directory. If it already exists, ask before
# training overwrites the weights stored there.
if params['check_point_dir']:
  try:
    os.makedirs(f"checkpoints/{params['check_point_dir']}")
    print("Output directory is created")
  except FileExistsError:
    reply = input('Override existing weights? [y/n]')
    # Accept 'n', 'N', 'no', ' n ' ...; any other answer continues and
    # overwrites, as before.
    if reply.strip().lower() in ('n', 'no'):
      print('Add another output path then!')
      sys.exit()

In [None]:
# Data roots: index 0 = real sessions, index 1 = synthetic sessions.
# NOTE: both entries are placeholders to replace with real path strings;
# the second one is also missing its closing `$`.
root_dir = [$Path of real data$, $Path of syn data]

In [None]:
def _subject_session_paths(session_ids):
  """Map '<subject>_<take>' ids to 'subject<subject>/<subject>_<take>' paths."""
  return [f"subject{sid.split('_')[0]}/{sid}" for sid in session_ids]

# Real sessions: relative paths (without .h5) for each split.
train_session_nums = train_videos
val_session_nums = validation_videos
train_session_names = _subject_session_paths(train_session_nums)
val_session_names = _subject_session_paths(val_session_nums)
print(train_session_names)
print(val_session_names)

# Per split: entry 0 = real session paths, entry 1 = synthetic file indices.
train_all = [train_session_names]
val_all = [val_session_names]

# Synthetic sessions: random 80/20 split of file indices 0..len_synthetic-1.
len_synthetic = 480
print(len_synthetic)
session_nums = np.random.permutation(np.arange(len_synthetic))

# `np.int` was removed in NumPy 1.24; it was an alias of the builtin int.
split_idx = int(len_synthetic * 0.8)
train_session_syn_nums = [session_nums[:split_idx]]
val_session_syn_nums = [session_nums[split_idx:]]

train_all.append(list(train_session_syn_nums[0]))
val_all.append(list(val_session_syn_nums[0]))

In [None]:
# Training dataset over the session lists in train_all; flip augmentation
# (transform) stays at its default.
train_set = DatasetUBFC(
    root_dir,
    train_all,
    params['num_samples'],
    params['seq_length'],
    params['img_shape'],
)

In [None]:
# Validation dataset: same clip settings as training, built from val_all.
val_set = DatasetUBFC(
    root_dir,
    val_all,
    params['num_samples'],
    params['seq_length'],
    params['img_shape'],
    val=True,
)

In [None]:

# Both loaders share their batching options; only the training one shuffles.
loader_options = dict(
    batch_size=params['batch_size'],
    num_workers=0,
    pin_memory=True,
)
ppg_train_loader = DataLoader(train_set, shuffle=True, **loader_options)
ppg_val_loader = DataLoader(val_set, shuffle=False, **loader_options)

dataloaders = {'train': ppg_train_loader, 'val': ppg_val_loader}
print('\nDataLoaders succesfully constructed!')


In [None]:
model = RPPGNetResnet(params['seq_length'])

# Device (module-level so other cells can use it).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Device:', device)

# Optionally warm-start from saved weights. map_location lets a checkpoint
# saved on a GPU be loaded on a CPU-only machine (and vice versa).
if params['pretrained_weights']:
  state_dict = torch.load(params['pretrained_weights'], map_location=device)
  model.load_state_dict(state_dict)
  print('\nPre-trained weights are loaded!')

# Copy model to working device
model = model.to(device)

# Negative Pearson loss on the rPPG signal.
criterion = Neg_Pearson()

print('====loss====', criterion)
print('====lr====',  params['lr'])
print('====weight_decay====',  params['weight_decay'])

optimizer = optim.AdamW(
    model.parameters(),
    lr=params['lr'],
    betas=(0.5, 0.999),
    weight_decay=params['weight_decay']
    )

# Cosine decay from lr down to min_lr_ratio * lr over num_epochs epochs.
scheduler = optim.lr_scheduler.CosineAnnealingLR(
    optimizer, T_max=params['num_epochs'],
    eta_min=params['min_lr_ratio'] * params['lr'], verbose=True
    )

train_model(
    model,
    dataloaders,
    criterion,
    optimizer,
    scheduler,
    params
    )

print('\nTraining is finished without flaw!')