In [None]:
import numpy as np 
import matplotlib.pyplot as plt 
import os
import torch
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as T
import h5py
from tqdm import tqdm
import imageio
import scipy
import random
from utils import getErrors, CNN3D

import math
import torch.nn as nn

# Use the first CUDA device when one is visible to PyTorch, otherwise the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print('Running on device: {}'.format(device))

# Seed every random number generator this notebook touches (NumPy, PyTorch,
# Python's `random`, and all CUDA devices) with the same value.
sd = 42

np.random.seed(sd)
# Ask cuDNN to pick deterministic algorithms.
torch.backends.cudnn.deterministic = True
torch.manual_seed(sd)
random.seed(sd)
if torch.cuda.is_available():
  torch.cuda.manual_seed_all(sd)

In [None]:
'''
Code of 'Remote Photoplethysmograph Signal Measurement from Facial Videos Using Spatio-Temporal Networks' 
By Zitong Yu, 2019/05/05
If you use the code, please cite:
@inproceedings{yu2019remote,
    title={Remote Photoplethysmograph Signal Measurement from Facial Videos Using Spatio-Temporal Networks},
    author={Yu, Zitong and Li, Xiaobai and Zhao, Guoying},
    booktitle= {British Machine Vision Conference (BMVC)},
    year = {2019}
}
Only for research purpose, and commercial use is not allowed.
MIT License
Copyright (c) 2019 
      How to use it
    #1. Inference the model
    rPPG, x_visual, x_visual3232, x_visual1616 = model(inputs)
    
    #2. Normalized the Predicted rPPG signal and GroundTruth BVP signal
    rPPG = (rPPG-torch.mean(rPPG)) /torch.std(rPPG)	 	# normalize
    BVP_label = (BVP_label-torch.mean(BVP_label)) /torch.std(BVP_label)	 	# normalize
    
    #3. Calculate the loss
    loss_ecg = Neg_Pearson(rPPG, BVP_label)
'''
########################################


class Neg_Pearson2(torch.nn.Module):
    """Batch-averaged squared distance between the Pearson correlation and 1.

    Every row of ``preds`` and ``labels`` is first passed through
    ``normalize_signal2``; the Pearson correlation ``r`` of the two normalised
    rows is then computed and ``(1 - r) ** 2`` is averaged over the batch.
    """

    def __init__(self):
        super().__init__()

        # Added to each variance term under the square root so the denominator
        # of the correlation cannot reach zero.
        self.epsilon = 1e-2

    def forward(self, preds, labels):       # tensor [Batch, Temporal]
        batch_size = preds.shape[0]
        n_steps = preds.shape[1]

        total = 0
        for row in range(batch_size):
            x = normalize_signal2(preds[row])
            y = normalize_signal2(labels[row])

            sum_x = torch.sum(x)
            sum_y = torch.sum(y)

            # Raw-sum formulation of the Pearson correlation coefficient.
            covariance_term = n_steps*torch.sum(x*y) - sum_x*sum_y
            variance_x = n_steps*torch.sum(torch.pow(x, 2)) - torch.pow(sum_x, 2) + self.epsilon
            variance_y = n_steps*torch.sum(torch.pow(y, 2)) - torch.pow(sum_y, 2) + self.epsilon
            pearson = covariance_term/(torch.sqrt(variance_x*variance_y))

            total += (1 - pearson)**2

        return total/batch_size


def normalize_signal2(sig):
    """Centre ``sig`` on its mean and scale it by its standard deviation.

    1e-6 is added to the standard deviation so a constant signal maps to zeros
    instead of dividing by zero.
    """
    centered = sig - torch.mean(sig)
    spread = torch.std(sig) + 1.00e-6
    return centered / spread

In [None]:
##UNIMODAL DATALOADER
class RppgData2(Dataset):
    """Dataset of fixed-length RGB + residual clips with their PPG segments.

    For every usable session folder, ``num_samps`` random start frames are drawn
    once at construction time (through the global NumPy RNG). Each item is the
    clip of ``frame_length`` frames starting there, plus the matching slice of
    the z-scored PPG signal, which is shifted by ``ppg_offset`` samples.

    Expected layout (taken from the paths this class builds):
        <datapath>/<folder>/rgbd_rgb_<frame>.png     one image per frame
        <datapath>/<folder>/rgbd_ppg.npy             PPG signal of the session
        <datapath2>/residual_<start>_<folder>.npy    one file per segment, where
                                                     <start> = segment * part_length

    Args:
        datapath: Root directory of the per-session RGB frames and PPG files.
        datapath2: Directory holding the residual ``.npy`` segment files.
        datapaths: Iterable of session folder names. It is NOT modified;
            sessions with missing files are skipped and reported via ``print``.
        recording_str: Accepted for interface compatibility; not used.
        video_length: Number of frames per session.
        num_segments: Number of residual files each session is split into.
        frame_length: Number of frames per returned clip.
        fs, l_freq_bpm, u_freq_bpm, fft_resolution: Stored on the instance only.
    """

    def __init__(self, datapath, datapath2, datapaths, recording_str, video_length = 900, num_segments = 3, frame_length = 300, fs=30, l_freq_bpm=45, u_freq_bpm=180, fft_resolution = 48) -> None:
        self.ppg_offset = 25   # leading PPG samples dropped from every session
        self.num_samps = 30    # random clips drawn per session

        self.no_att_frms = 64
        #Data structure for videos
        self.video_length = video_length
        self.num_segments = num_segments
        self.part_length = int(video_length/num_segments)
        self.datapath = datapath
        self.datapath2 = datapath2

        self.fs = fs
        self.l_freq_bpm = l_freq_bpm
        self.u_freq_bpm = u_freq_bpm
        self.fft_resolution = fft_resolution

        # Build our own list of usable sessions instead of aliasing `datapaths`
        # and removing entries from it: the caller's list must stay untouched.
        self.video_list = []
        self.signal_list = []
        self.item_list = []
        for folder in datapaths:
            session_dir = os.path.join(self.datapath, folder)
            # Same file names that __getitem__ loads, derived from the segment
            # layout (0/300/600 with the default arguments).
            residual_files = [self._residual_path(folder, seg) for seg in range(self.num_segments)]
            if not (os.path.exists(session_dir) and all(os.path.exists(p) for p in residual_files)):
                print(folder, " doesn't exist.")
                print("removed", folder)
                continue

            ppg_file = os.path.join(self.datapath, folder, "rgbd_ppg.npy")
            if not os.path.exists(ppg_file):
                print(folder, "ppg doesn't exist.")
                print("removed", folder)
                continue

            # Z-score the whole recording, then drop the first ppg_offset samples.
            signal = np.load(ppg_file)
            signal = (signal - np.mean(signal))/np.std(signal)
            self.signal_list.append(signal[self.ppg_offset:])
            self.video_list.append(folder)

        self.signal_list = np.array(self.signal_list)

        # Pairs of (index into video_list, start frame) that define the items.
        self.frame_length = frame_length
        self.video_nums = np.arange(0, len(self.video_list))

        self.all_idxs = []
        for num in self.video_nums:
            # Start frames are drawn without replacement; the upper bound keeps
            # start + frame_length inside the offset-shifted PPG signal.
            cur_frame_nums = np.random.choice(np.arange(self.video_length - self.frame_length - self.ppg_offset), size=self.num_samps, replace=False)
            for cur_frame_num in cur_frame_nums:
                self.all_idxs.append((num,cur_frame_num))

    def _residual_path(self, folder, segment):
        """Path of the residual file holding segment number `segment` of `folder`."""
        return os.path.join(self.datapath2, 'residual_'+str(int(segment*self.part_length))+'_'+folder+'.npy')

    def __len__(self):
        return int(len(self.all_idxs))

    def __getitem__(self, idx):
        """Return ``(frames, ppg)`` for item ``idx``.

        ``frames`` is a float64 array of shape ``(frame_length, 128, 128, 6)``:
        channels 0-2 come from the RGB images and channels 3-5 from the residual
        files. ``ppg`` is the matching slice of the normalised PPG signal.
        """
        video_number, frame_start = self.all_idxs[idx]
        video_number = int(video_number)
        frame_start = int(frame_start)
        frame_stop = frame_start + self.frame_length

        #Get signal
        item_sig = self.signal_list[video_number][frame_start:frame_stop]

        # RGB frames of the clip, one image file per frame.
        folder = self.video_list[video_number]
        rgb_frames = np.zeros((self.frame_length,128,128,3))
        for out_ix, frm in enumerate(range(frame_start, frame_stop)):
            rgb_frames[out_ix] = imageio.imread(os.path.join(self.datapath, folder,'rgbd_rgb_'+str(frm)+'.png'))

        # Residual frames of the same window. Only the segment files that
        # overlap the window are loaded, and they are written straight into a
        # clip-sized buffer, rather than materialising the whole video for
        # every sample. The resulting array is identical.
        residual_frames = np.zeros((self.frame_length,128,128,3))
        for seg in range(self.num_segments):
            seg_start = seg*self.part_length
            seg_stop = seg_start + self.part_length
            lo = max(seg_start, frame_start)
            hi = min(seg_stop, frame_stop)
            if lo >= hi:
                continue
            vid = np.load(self._residual_path(folder, seg))
            residual_frames[lo-frame_start:hi-frame_start] = vid[lo-seg_start:hi-seg_start]

        frmOuts = np.concatenate((rgb_frames, residual_frames), axis=-1)

        return frmOuts, item_sig
        

In [None]:
def extract_video(path, path2, cur_session):
    """Load one full session as a single ``(900, 128, 128, 6)`` array.

    Channels 0-2 are read from the images ``<path>/<cur_session>/rgbd_rgb_<i>.png``
    for ``i`` in 0..899. Channels 3-5 are read from the three residual files
    ``<path2>/residual_<start>_<cur_session>.npy`` with ``start`` in
    (0, 300, 600), each filling 300 consecutive frames.
    """
    n_frames = 900
    length_seg = 300
    session_dir = os.path.join(path, cur_session)

    # RGB stream: one image file per frame.
    rgb = np.zeros((n_frames,128,128,3))
    for frame_ix in range(n_frames):
        rgb[frame_ix] = imageio.imread(os.path.join(session_dir,'rgbd_rgb_'+str(frame_ix)+'.png'))

    # Residual stream: one .npy file per segment of length_seg frames.
    residual = np.zeros((n_frames,128,128,3))
    for seg_start in range(0, n_frames, length_seg):
        residual[seg_start:seg_start+length_seg] = np.load(
            os.path.join(path2, 'residual_'+str(seg_start)+'_'+cur_session+'.npy'))

    # Stack both streams along the channel axis.
    return np.concatenate((rgb, residual), axis=-1)

##Eval model
def eval_model(root_dir, root_dir2, session_names, model, model2, in_frames=64, 
               hr_window_size = 300, stride = 128, video_fps = 30, ppg_offset = 25, fft_resolution=48):
    """Evaluate the mask network and the regressor on a list of sessions.

    Every session is cut into windows of ``hr_window_size`` frames, taken every
    ``stride`` frames. ``model2`` predicts a mask from the first ``in_frames``
    frames of each window and ``model`` predicts a waveform from the window and
    that mask. A pulse rate is then estimated with ``prpsd2`` for both the
    predicted and the ground-truth waveform of every window.

    Args:
        root_dir: Root directory of the per-session RGB frames and PPG files.
        root_dir2: Directory holding the residual ``.npy`` files.
        session_names: Iterable of session folder names to evaluate.
        model: Waveform regressor, called as ``model(frames, mask)``.
        model2: Mask network, called as ``model2(frames[:, :in_frames])``.
        in_frames: Number of leading frames of each window given to ``model2``.
        hr_window_size: Window length in frames.
        stride: Step between consecutive window starts, in frames.
        video_fps: Sampling rate handed to ``prpsd2``.
        ppg_offset: Leading PPG samples dropped before windowing.
        fft_resolution: Accepted for interface compatibility; not used.

    Returns:
        ``(mae, (all_hr_est, all_hr_gt))``: ``mae`` is an array with one MAE per
        session as returned by ``getErrors``; the two lists hold the per-window
        pulse-rate estimates and ground truths of every session.

    Side effects:
        Both models are switched to eval mode and are left in that mode.
    """
    model.eval()
    model2.eval()
    video_samples = [
        {"video_path" : root_dir, "video_path2" : root_dir2, "cur_session" : cur_session}
        for cur_session in session_names
    ]

    # Pulse-rate search band (bpm) handed to prpsd2.
    hr_lower_bpm = 45
    hr_upper_bpm = 150

    for cur_video_sample in tqdm(video_samples):
        cur_video_path = cur_video_sample["video_path"]
        cur_video_path2 = cur_video_sample["video_path2"]
        cur_session = cur_video_sample["cur_session"]

        # RGB and residual channels stacked on the last axis: (frames, 128, 128, 6).
        frames = extract_video(path=cur_video_path, path2=cur_video_path2, cur_session=cur_session)
        target = np.load(os.path.join(cur_video_path, cur_session,'rgbd_ppg.npy'))
        ##Apply offset to target
        target = target[ppg_offset:]

        #Normalize target
        target = (target-np.mean(target,axis=0,keepdims=True))/np.std(target,axis=0,keepdims=True)

        # Window starts; the bound keeps each window inside the offset-shifted target.
        start_indices = np.arange(0,frames.shape[0]+1-ppg_offset-hr_window_size,stride)

        # All windows of a session form one batch. Potential high GPU usage.
        batched_ip = torch.Tensor(np.array([frames[ix:ix+hr_window_size] for ix in start_indices])).to(device)
        batched_tgt = np.array([target[ix:ix+hr_window_size] for ix in start_indices])

        with torch.no_grad():
            msk = model2(batched_ip[:,0:in_frames])
            est_wvfrm = model(batched_ip,msk)

            # Flatten to (windows, samples) while keeping the batch axis. A bare
            # squeeze() would also drop it when a session yields one window, and
            # the loop below would then iterate over time samples.
            est_wvfrm = est_wvfrm.reshape(est_wvfrm.shape[0], -1).cpu().numpy()

        # Save
        cur_video_sample['est_wvfrm'] = est_wvfrm
        cur_video_sample['gt_wvs'] = batched_tgt
    print('All finished!')

    #Estimate using waveforms
    mae_list = []
    all_hr_est = []
    all_hr_gt = []
    for cur_video_sample in video_samples:
        est_wvfrm = cur_video_sample['est_wvfrm']
        gt_wvs = cur_video_sample['gt_wvs']

        # Pulse rate of every window, for both prediction and ground truth.
        hr_est_temp = []
        hr_gt_temp = []
        for ixx in range(est_wvfrm.shape[0]):
            est_hr = prpsd2(est_wvfrm[ixx], video_fps, hr_lower_bpm, hr_upper_bpm, BUTTER_ORDER=6, DETREND=False)
            gt_hr = prpsd2(gt_wvs[ixx], video_fps, hr_lower_bpm, hr_upper_bpm, BUTTER_ORDER=6, DETREND=False)
            hr_est_temp.append(est_hr)
            hr_gt_temp.append(gt_hr)

        # NOTE(review): the estimate is wrapped in an extra list (shape (1, windows))
        # while the ground truth stays 1-D; this presumably matches what
        # utils.getErrors expects. Kept as-is; confirm against utils.
        hr_est_windowed = np.array([hr_est_temp])
        hr_gt_windowed = np.array(hr_gt_temp)
        all_hr_est.append(hr_est_temp)
        all_hr_gt.append(hr_gt_temp)

        # Errors
        RMSE, MAE, MAX, PCC = getErrors(hr_est_windowed, hr_gt_windowed)

        mae_list.append(MAE)
    print('Mean MAE:', np.mean(np.array(mae_list)))
    return np.array(mae_list), (all_hr_est, all_hr_gt)


In [None]:

def prpsd2(BVP, FS, LL_PR, UL_PR, BUTTER_ORDER=6, DETREND=False, PlotTF=False, FResBPM = 0.1):
    '''
    Estimates pulse rate from the power spectral density of a BVP signal.

    The signal is band-pass filtered (zero-phase Butterworth) to
    [LL_PR, UL_PR] bpm, its periodogram is computed, and the frequency of the
    strongest in-band bin is returned in beats per minute.

    Inputs
        BVP          : A BVP timeseries. (1d numpy array)
        FS           : The sample rate of the BVP time series (Hz/fps). (int)
        LL_PR        : The lower limit for pulse rate (bpm). (int)
        UL_PR        : The upper limit for pulse rate (bpm). (int)
        BUTTER_ORDER : Order of the Butterworth filter. (int)
        DETREND      : Accepted for interface compatibility; currently ignored. (bool)
        PlotTF       : Plot the power spectrum and the selected peak. (bool)
        FResBPM      : Resolution (bpm) of the bins in the power spectrum used to
                       determine the pulse rate. (float)

    Outputs
        pulse_rate   : The estimated pulse rate in BPM. (float)

    Daniel McDuff, Ethan Blackford, January 2019
    Copyright (c)
    Licensed under the MIT License and the RAIL AI License.
    '''
    from scipy import signal

    # Number of FFT points that yields FResBPM-wide bins. It must be an integer:
    # periodogram uses it to slice the input when it is shorter than the signal,
    # and a float fails there.
    N = int(round((60*FS)/FResBPM))

    # Zero-phase band-pass restricted to the plausible pulse-rate range.
    [b, a] = signal.butter(BUTTER_ORDER, [LL_PR/60, UL_PR/60], btype='bandpass', fs = FS)
    BVP = signal.filtfilt(b, a, np.double(BVP))

    F, Pxx = signal.periodogram(x=BVP, nfft=N, fs=FS, detrend=False)

    # Zero out every bin outside the pulse-rate band before picking the peak.
    FMask = (F >= (LL_PR/60)) & (F <= (UL_PR/60))

    # Calculate predicted pulse rate:
    FRange = F * FMask
    PRange = Pxx * FMask
    MaxInd = np.argmax(PRange)
    pulse_rate_freq = FRange[MaxInd]
    pulse_rate = pulse_rate_freq*60

    # Optionally Plot the PSD and peak frequency
    if PlotTF:
        # Plot PSD (in dB) and peak frequency
        plt.figure()
        plt.plot(F, 10 * np.log10(Pxx))
        plt.plot(pulse_rate_freq, 10 * np.log10(PRange[MaxInd]),'ro')
        plt.xlabel('Frequency (Hz)')
        plt.ylabel('Power (dB)')
        plt.xlim([0, 4.5])
        plt.title('Power Spectrum and Peak Frequency')

    return pulse_rate

In [None]:
import pickle

# Root of the per-session folders (passed to RppgData2 as `datapath`).
destination_folder = "VideoDataset/rgb_files/"
# Directory of the residual .npy files (passed to RppgData2 as `datapath2`).
destination_folder2 = "ResidualData/NpyRes/"


# NOTE(review): pickle.load can execute arbitrary code; only open split files
# that come from a trusted source.
with open("assets/demo_fold.pkl", "rb") as fpf:
        out = pickle.load(fpf)

# Only the first entry of the loaded object is used; it supplies the
# train / val / test session lists.
train = out[0]["train"]
val = out[0]["val"]
test = out[0]["test"]

# Training dataset built from the train split.
dataset = RppgData2(datapath=destination_folder, datapath2=destination_folder2, datapaths=train, recording_str="rgbd_rgb")

# Shuffled mini-batches of 8 clips, loaded by 2 worker processes.
train_dataloader = DataLoader(dataset, batch_size=8, shuffle=True, num_workers = 2)

In [None]:
# Set the checkpoint directory: <ckpt_parent_path>/<colab_filename>.
colab_filename = 'Example_Train_Run'
ckpt_parent_path = 'checkpoints_mask'
# The parent folder must already exist; only the run sub-folder is created here.
assert os.path.exists(ckpt_parent_path), "Check folder to save checkpoint"
ckpt_path = os.path.join(ckpt_parent_path, colab_filename)
os.makedirs(ckpt_path, exist_ok=True)
print(f"Checkpoints will be saved in {ckpt_path}")

In [None]:
from torch import nn
class PlethRegressor(nn.Module):
    """1-D convolutional regressor from mask-pooled channel traces to a waveform.

    ``forward`` collapses the two spatial axes of the input with a weighted
    average (weights ``wts``), standardises each resulting channel trace along
    the last axis, and runs eight Conv1d -> BatchNorm1d -> ReLU stages followed
    by a 1x1 convolution that outputs a single channel.

    ``inp_len``, ``out_len`` and ``latent`` are stored on the instance only.
    """

    def __init__(self, inp_len, out_len, latent=512):
        super().__init__()
        self.inp_len = inp_len
        self.latent = latent
        self.out_len = out_len

        def conv_stage(in_channels, out_channels):
            # Kernel 9 with padding 4 and stride 1 keeps the sequence length.
            return nn.Sequential(
                nn.Conv1d(in_channels, out_channels, kernel_size=9, stride=1, padding=4),
                nn.BatchNorm1d(out_channels),
                nn.ReLU(),
            )

        # Stages are registered in the order they are applied in forward().
        self.Enc1 = conv_stage(6, 64)
        self.Enc2 = conv_stage(64, 128)
        self.Enc3 = conv_stage(128, 128)
        self.Enc4 = conv_stage(128, 128)

        self.Dec1 = conv_stage(128, 128)
        self.Dec2 = conv_stage(128, 128)
        self.Dec3 = conv_stage(128, 64)
        self.Dec4 = conv_stage(64, 16)

        self.FinalLayer = nn.Conv1d(16, 1, kernel_size=1, stride=1)

    def forward(self, inp, wts):
        # Give the weights singleton axes at positions 1 and 4 so they broadcast
        # against `inp`, then take the weighted mean over axes 2 and 3.
        weights = wts.reshape(wts.shape[0], 1, wts.shape[1], wts.shape[2], 1)
        pooled = torch.sum(inp*weights, (2,3))/torch.sum(weights, (2,3))

        # Swap the last two axes so Conv1d sees (batch, channels, length).
        traces = pooled.permute(0,2,1)

        # Mean and std normalization of every channel along the last axis.
        traces = (traces - torch.mean(traces, 2, True))/torch.std(traces, dim=2, keepdim=True)

        stages = (self.Enc1, self.Enc2, self.Enc3, self.Enc4,
                  self.Dec1, self.Dec2, self.Dec3, self.Dec4)
        for stage in stages:
            traces = stage(traces)

        return torch.squeeze(self.FinalLayer(traces), 1)

In [None]:
import math
import torch.nn as nn
from torch.nn.modules.utils import _triple
import torch as tr
import pdb

class SNRLoss_dB_Signals(nn.Module):
  """Negative signal-to-noise ratio (dB) of a predicted pulse waveform.

  For every sample, the reference pulse frequency is the strongest bin of the
  target's power spectrum inside the 45-180 bpm band. The "signal" power is
  the predicted spectrum summed over a window of +/- ``wind_sz`` bins around
  that frequency and around twice that frequency. The "noise" power is the
  predicted spectrum summed over the regions between ``min_idx``, those two
  windows and ``max_idx``. The loss is ``-10 * log10(signal / noise)``
  averaged over the batch, and is returned on the device of ``outputs``.
  """
  def __init__(self):
    super(SNRLoss_dB_Signals, self).__init__()
  def forward(self, outputs: tr.Tensor, targets: tr.Tensor, Fs=30):
    """Return the mean negative SNR in dB as a 0-dim float32 tensor.

    Args:
      outputs: Predicted waveforms; the last axis is time, all leading axes are
        flattened into the batch axis.
      targets: Ground-truth waveforms with the same number of time samples.
      Fs: Sampling rate in Hz used to map FFT bins to frequencies.
    """
    device = outputs.device
    N = 600                      # FFT length (signals are zero-padded / cropped to it)
    N_samp = outputs.shape[-1]
    pulse_band = tr.tensor([45/60., 180/60.], dtype=tr.float32).to(device)
    wind_sz = int(1*N/64)        # half-width, in bins, of the signal windows

    # Bin frequencies of the one-sided spectrum and the band-limit indices.
    f = tr.linspace(0, Fs/2, int(N/2)+1, dtype=tr.float32).to(device)
    min_idx = int(tr.argmin(tr.abs(f - pulse_band[0])))
    max_idx = int(tr.argmin(tr.abs(f - pulse_band[1])))

    outputs = outputs.view(-1, N_samp)
    targets = targets.view(-1, N_samp)

    #Generate GT heart indices from GT signals
    Y = torch.fft.rfft(targets, n=N, dim=1, norm='forward')
    Y2 = tr.abs(Y) ** 2
    HRixs = tr.argmax(Y2[:,min_idx:max_idx],axis=1)+min_idx

    X = torch.fft.rfft(outputs, n=N, dim=1, norm='forward')
    P1 = tr.abs(X) ** 2

    # SNR of every sample. The terms are collected in a list and stacked so the
    # result stays on the input device; the previous CPU buffer followed by a
    # discarded `.to(device)` always handed back a CPU tensor.
    per_sample = []
    for count, ref_idx in enumerate(HRixs.tolist()):
      spectrum = P1[count]
      pulse_freq_amp = tr.sum(spectrum[ref_idx-wind_sz:ref_idx+wind_sz])+tr.sum(spectrum[2*ref_idx-wind_sz:2*ref_idx+wind_sz])
      other_avrg = (tr.sum(spectrum[min_idx:ref_idx-wind_sz])+tr.sum(spectrum[ref_idx+wind_sz:2*ref_idx-wind_sz]) + tr.sum(spectrum[2*ref_idx+wind_sz:max_idx]))
      # 1e-7 keeps the ratio finite when the noise regions hold no power.
      per_sample.append(-10*tr.log10(pulse_freq_amp/(other_avrg+1e-7)))

    if not per_sample:
      # Empty batch: same result as taking the mean of an empty tensor.
      return tr.full((), float('nan'), dtype=tr.float32, device=device)

    losses = tr.stack(per_sample).to(dtype=tr.float32)
    return tr.mean(losses)

In [None]:
def total_variation_loss(img):
    """Squared total variation of a ``(batch, height, width)`` tensor.

    Sums the squared differences between neighbouring entries along axis 1 and
    along axis 2, then divides by the total number of entries.
    """
    batch, height, width = img.size()

    vertical_steps = img[:, 1:, :] - img[:, :-1, :]
    horizontal_steps = img[:, :, 1:] - img[:, :, :-1]

    energy = torch.pow(vertical_steps, 2).sum() + torch.pow(horizontal_steps, 2).sum()
    return energy / (batch*height*width)

In [None]:
from tqdm import tqdm
# Path for a "latest context" checkpoint.
# NOTE(review): PATH is not referenced anywhere else in this cell.
PATH = os.path.join(os.getcwd(), f"{ckpt_path}/latest_context.pth")
# `model` regresses the waveform from the clip and a mask; `model2` predicts
# that mask from the first 64 frames of the clip (see the loop below).
model = PlethRegressor(inp_len=300, out_len=300).to(device)
model2 = CNN3D(frames=64, sidelen = 128, channels=6).to(device)

loss_fn  = Neg_Pearson2()
loss_fn2  = SNRLoss_dB_Signals()

# Weights of the Pearson term (lam) and of the mask total-variation term (lam2)
# in the combined loss; the SNR term has weight 1.
lam = 3
lam2 = 5

# One Adam optimizer per network, both with the same hyper-parameters.
learning_rate = 1e-4
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-6)
optimizer2 = torch.optim.Adam(model2.parameters(), lr=learning_rate, weight_decay=1e-6)

# Train configurations
epochs = 2
checkpoint_period = 1
epoch_start = 1
# Best validation MAE seen so far; starts high so the first evaluation wins.
mae_best_loss = 1000

for epoch in range(epoch_start, epochs+1):
    # Training Phase
    loss_train = 0
    no_batches = 0
    for batch, (imgs, signal) in tqdm(enumerate(train_dataloader), total=len(train_dataloader)):
        # Re-enable train mode on every batch: eval_model() below switches both
        # networks to eval mode and does not switch them back.
        model.train()
        model2.train()

        imgs = imgs.float().to(device)
        signal = signal.float().to(device)
        
        
        

        # Get attention mask
        msk = model2(imgs[:,0:64])

        # Predict the PPG signal and find the loss
        pred_signal = model(imgs,msk)
        
        loss1 = loss_fn(pred_signal, signal)    # Pearson-based term
        loss2 = loss_fn2(pred_signal, signal)   # negative SNR term
        loss3 = total_variation_loss(msk)       # smoothness of the mask
        
        loss = loss2+lam*loss1+lam2*loss3

        # Backprop through both networks with a single backward pass
        optimizer.zero_grad()
        optimizer2.zero_grad()
        loss.backward()
        optimizer.step()
        optimizer2.step()
        
        # Accumulate the total loss
        loss_train += loss.item()
        no_batches+=1

    # Save the model every few epochs
    if(epoch % checkpoint_period == 0):
        torch.save(model.state_dict(), os.path.join(os.getcwd(), f"{ckpt_path}/PhysNet_state_dict_{epoch}_epochs.pth"))
        torch.save(model2.state_dict(), os.path.join(os.getcwd(), f"{ckpt_path}/PhysNetAtt_state_dict_{epoch}_epochs.pth"))
        # Keep a separate copy of the weights with the lowest mean validation MAE.
        mae_loss_list, hrs = eval_model(root_dir=destination_folder, root_dir2=destination_folder2, session_names=val, model=model, model2=model2)
        current_loss = np.mean(mae_loss_list) 
        if(current_loss < mae_best_loss):
            mae_best_loss = current_loss
            torch.save(model.state_dict(), os.path.join(os.getcwd(), f"{ckpt_path}/PhysNet_state_dict_best.pth"))
            torch.save(model2.state_dict(), os.path.join(os.getcwd(), f"{ckpt_path}/PhysNetAtt_state_dict_best.pth"))
            print("Best checkpoint saved!")
        print("Saved Checkpoint!")

    # Mean training loss per batch for this epoch.
    print(f"Epoch: {epoch} ; Loss: {loss_train/no_batches:>7f}")
    
    # Plot an example mask: the first sample of the last training batch of the epoch.
    ixx = 0
    og_msk = msk[ixx].cpu().detach().numpy()
    
    plt.figure()
    plt.imshow(og_msk)
    plt.show()
    

        