# Python code starts here

In [None]:
import numpy as np 
import matplotlib.pyplot as plt 
import os
import torch
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
import imageio
import scipy.signal

In [None]:
# Prefer the first CUDA GPU when one is visible; otherwise fall back to the CPU.
_cuda_ok = torch.cuda.is_available()
device = torch.device('cuda:0') if _cuda_ok else torch.device('cpu')
print(f'Running on device: {device}')

In [None]:
class PhysNet_padding_Encoder_Decoder_MAX(torch.nn.Module):
    """PhysNet-style 3D-CNN encoder/decoder mapping a video clip to an rPPG trace.

    Encoder: 3D conv blocks (Conv3d + BatchNorm3d + ReLU) interleaved with
    max-pooling -- two spatial-only pools (MaxpoolSpa) and two
    spatio-temporal pools (MaxpoolSpaTem) -- so time is halved twice and each
    spatial axis is halved four times. Decoder: two transposed convolutions
    that each double the temporal length, an adaptive average pool that
    collapses the spatial axes, and a 1x1x1 conv projecting 64 channels to 1.

    NOTE: the submodule attribute names (ConvBlock1..10, upsample, upsample2,
    ...) are the keys of saved state_dicts; renaming them breaks
    ``load_state_dict`` on existing checkpoints.

    Args:
        frames: temporal output length of ``poolspa``. ``forward`` reshapes the
            result with the *input* clip length, so this presumably must equal
            the clip length fed to the model (64 everywhere in this file) --
            NOTE(review): confirm before using other clip lengths.
        channels: number of input colour channels.
    """
    def __init__(self, frames=64, channels=3):  
        super(PhysNet_padding_Encoder_Decoder_MAX, self).__init__()
        
        # Spatial-only 1x5x5 conv: channels -> 16, keeps T, H, W.
        self.ConvBlock1 = torch.nn.Sequential(
            torch.nn.Conv3d(channels, 16, [1,5,5],stride=1, padding=[0,2,2]),
            torch.nn.BatchNorm3d(16),
            torch.nn.ReLU(inplace=True),
        )

        # 3x3x3 convs with padding=1 keep T, H, W; only channel counts change.
        self.ConvBlock2 = torch.nn.Sequential(
            torch.nn.Conv3d(16, 32, [3, 3, 3], stride=1, padding=1),
            torch.nn.BatchNorm3d(32),
            torch.nn.ReLU(inplace=True),
        )
        self.ConvBlock3 = torch.nn.Sequential(
            torch.nn.Conv3d(32, 64, [3, 3, 3], stride=1, padding=1),
            torch.nn.BatchNorm3d(64),
            torch.nn.ReLU(inplace=True),
        )
        
        self.ConvBlock4 = torch.nn.Sequential(
            torch.nn.Conv3d(64, 64, [3, 3, 3], stride=1, padding=1),
            torch.nn.BatchNorm3d(64),
            torch.nn.ReLU(inplace=True),
        )
        self.ConvBlock5 = torch.nn.Sequential(
            torch.nn.Conv3d(64, 64, [3, 3, 3], stride=1, padding=1),
            torch.nn.BatchNorm3d(64),
            torch.nn.ReLU(inplace=True),
        )
        self.ConvBlock6 = torch.nn.Sequential(
            torch.nn.Conv3d(64, 64, [3, 3, 3], stride=1, padding=1),
            torch.nn.BatchNorm3d(64),
            torch.nn.ReLU(inplace=True),
        )
        self.ConvBlock7 = torch.nn.Sequential(
            torch.nn.Conv3d(64, 64, [3, 3, 3], stride=1, padding=1),
            torch.nn.BatchNorm3d(64),
            torch.nn.ReLU(inplace=True),
        )
        self.ConvBlock8 = torch.nn.Sequential(
            torch.nn.Conv3d(64, 64, [3, 3, 3], stride=1, padding=1),
            torch.nn.BatchNorm3d(64),
            torch.nn.ReLU(inplace=True),
        )
        self.ConvBlock9 = torch.nn.Sequential(
            torch.nn.Conv3d(64, 64, [3, 3, 3], stride=1, padding=1),
            torch.nn.BatchNorm3d(64),
            torch.nn.ReLU(inplace=True),
        )
        
        # Temporal-only transposed conv: T_out = (T_in - 1)*2 - 2*1 + 4 = 2*T_in;
        # the 1x1 spatial kernel leaves H and W unchanged.
        self.upsample = torch.nn.Sequential(
            torch.nn.ConvTranspose3d(in_channels=64, out_channels=64, kernel_size=[4,1,1], stride=[2,1,1], padding=[1,0,0]),
            torch.nn.BatchNorm3d(64),
            torch.nn.ELU(),
        )
        # Same as `upsample`: doubles the temporal length again.
        self.upsample2 = torch.nn.Sequential(
            torch.nn.ConvTranspose3d(in_channels=64, out_channels=64, kernel_size=[4,1,1], stride=[2,1,1], padding=[1,0,0]),
            torch.nn.BatchNorm3d(64),
            torch.nn.ELU(),
        )
 
        # 1x1x1 projection from 64 feature channels to a single rPPG channel.
        self.ConvBlock10 = torch.nn.Conv3d(64, 1, [1,1,1],stride=1, padding=0)
        
        # Halve H and W only.
        self.MaxpoolSpa = torch.nn.MaxPool3d((1, 2, 2), stride=(1, 2, 2))
        # Halve T, H and W.
        self.MaxpoolSpaTem = torch.nn.MaxPool3d((2, 2, 2), stride=2)
        
        
        #self.poolspa = torch.nn.AdaptiveMaxPool3d((frames,1,1))    # pool only spatial space 
        # Average away the spatial axes and resample time to exactly `frames`.
        self.poolspa = torch.nn.AdaptiveAvgPool3d((frames,1,1))

        
    def forward(self, x):
        """Estimate an rPPG signal from a clip.

        Args:
            x: 5-D tensor, unpacked below as (batch, channel, length, width,
                height); ``length`` is the number of frames T. The inline
                shape comments assume a 128x128 spatial input and T divisible
                by 4 (pooling floors) -- TODO confirm for other inputs.

        Returns:
            Tuple ``(rPPG, x_visual, x_visual3232, x_visual1616)``:
            ``rPPG`` is the 1-channel output reshaped to ``(-1, length)``
            (i.e. [B, T] when ``frames == length``); ``x_visual`` is the input
            tensor unchanged; ``x_visual3232`` / ``x_visual1616`` are the
            64-channel feature maps after ConvBlock5 and ConvBlock7.
        """
        x_visual = x  # raw input, returned as-is
        [batch,channel,length,width,height] = x.shape
        # print(length)
        # print(torch.mean(x, dim=(1,2,3,4)))
          
        x = self.ConvBlock1(x)  # [B, 16, T, 128, 128]
        x = self.MaxpoolSpa(x)  # [B, 16, T, 64, 64]

        # print(torch.mean(x, dim=(1,2,3,4)))

        x = self.ConvBlock2(x)  # [B, 32, T, 64, 64]
        x_visual6464 = self.ConvBlock3(x)  # [B, 64, T, 64, 64]
        x = self.MaxpoolSpaTem(x_visual6464)  # [B, 64, T/2, 32, 32]  temporal halve

        x = self.ConvBlock4(x)  # [B, 64, T/2, 32, 32]
        x_visual3232 = self.ConvBlock5(x)  # [B, 64, T/2, 32, 32]
        x = self.MaxpoolSpaTem(x_visual3232)  # [B, 64, T/4, 16, 16]  temporal halve

        # print(x.shape)

        x = self.ConvBlock6(x)  # [B, 64, T/4, 16, 16]
        x_visual1616 = self.ConvBlock7(x)  # [B, 64, T/4, 16, 16]
        x = self.MaxpoolSpa(x_visual1616)  # [B, 64, T/4, 8, 8]

        # print(x.shape)

        x = self.ConvBlock8(x)  # [B, 64, T/4, 8, 8]
        x = self.ConvBlock9(x)  # [B, 64, T/4, 8, 8]
        x = self.upsample(x)  # [B, 64, T/2, 8, 8]
        x = self.upsample2(x)  # [B, 64, T, 8, 8]

        # print(x.shape)
        
        x = self.poolspa(x)  # [B, 64, frames, 1, 1]
        x = self.ConvBlock10(x)  # [B, 1, frames, 1, 1]

        # print(x.shape)
        # print(torch.mean(x, dim=(1,2,3,4)))

        
        # Flatten to rows of `length` samples. If frames != length, samples
        # from different batch items get regrouped -- NOTE(review).
        rPPG = x.view(-1,length)            

        return rPPG, x_visual, x_visual3232, x_visual1616

In [None]:
def eval_model(model, video_filename="output/ref.avi", n_frames=300, sequence_length = 64):
    """Run ``model`` over a video in non-overlapping clips and estimate heart rate.

    The first ``n_frames`` frames of ``video_filename`` are scaled to [0, 1],
    split into consecutive non-overlapping clips of ``sequence_length``
    frames, and each clip is fed to ``model`` as a ``(1, C, T, H, W)`` tensor.
    The per-clip predictions are concatenated into one rPPG trace. The heart
    rate of that trace is estimated with ``prpsd2`` (30 fps, 45-150 bpm band)
    and printed.

    Frames left over after the last complete clip are ignored. For example,
    with the defaults, 300 frames give 4 clips and 256 output samples.

    Args:
        model: network whose call returns a tuple with the rPPG prediction as
            its first element (e.g. ``PhysNet_padding_Encoder_Decoder_MAX``).
        video_filename: path of the video, read with ``imageio.mimread``.
        n_frames: maximum number of leading frames to use.
        sequence_length: number of frames per clip given to the model.

    Returns:
        1-D numpy array with the concatenated rPPG estimate.

    Raises:
        ValueError: if fewer than ``sequence_length`` frames are available, so
            no complete clip can be evaluated.
    """
    model.eval()  # inference mode: BatchNorm uses its running statistics
    print(f"Reading: {video_filename}")
    frames = np.array(imageio.mimread(video_filename))[:n_frames]

    # Feed inputs on the model's own device rather than hard-coding .cuda(),
    # so evaluation also works on CPU-only machines.
    model_device = next(model.parameters()).device

    n_clips = frames.shape[0] // sequence_length
    if n_clips == 0:
        raise ValueError(
            f"{video_filename}: need at least {sequence_length} frames, "
            f"got {frames.shape[0]}"
        )

    clip_ppgs = []
    with torch.no_grad():
        for start in range(0, n_clips * sequence_length, sequence_length):
            clip = frames[start:start + sequence_length]
            # (T, H, W, C) uint8 -> float in [0, 1] -> (1, C, T, H, W)
            clip_tensor = torch.from_numpy(clip.astype(np.uint8)).float() / 255
            clip_tensor = clip_tensor.permute(3, 0, 1, 2).unsqueeze(0).to(model_device)
            clip_ppg, _, _, _ = model(clip_tensor)
            # reshape(-1) (not squeeze) keeps a 1-D array even for 1-frame clips
            clip_ppgs.append(clip_ppg.reshape(-1).cpu().numpy())

    cur_est_ppgs = np.concatenate(clip_ppgs, -1)
    est_hr = prpsd2(cur_est_ppgs-np.mean(cur_est_ppgs), 30, 45, 150, BUTTER_ORDER=6, DETREND=False)
    print(f"Estimated HR: {est_hr}")
    return cur_est_ppgs

def prpsd2(BVP, FS, LL_PR, UL_PR, BUTTER_ORDER=6, DETREND=False, PlotTF=False, FResBPM = 0.1, RECT=True):
    '''
    Estimates pulse rate from the power spectral density of a BVP signal.

    Inputs
        BVP          : A BVP timeseries. (1d array-like)
        FS           : Sample rate of BVP (Hz / fps). (int or float)
        LL_PR        : Lower pulse-rate limit (bpm). This is also the low cutoff
                       of the band-pass filter. (int or float)
        UL_PR        : Upper pulse-rate limit (bpm). This is also the high
                       cutoff of the band-pass filter. (int or float)
        BUTTER_ORDER : Order of the Butterworth band-pass filter. A falsy value
                       (0 / None / False) skips filtering. (int)
        DETREND      : If True, integrate BVP (cumsum) and remove its slow
                       trend with a smoothness-priors detrend (lambda=100)
                       before filtering. (bool)
        PlotTF       : If True, plot the PSD (dB) and mark the chosen peak. (bool)
        FResBPM      : Resolution (bpm) of the power-spectrum bins used to
                       pick the peak. (float)
        RECT         : True -> rectangular window (no detrending inside the
                       periodogram). False -> Hann window. (bool)

    Outputs
        pulse_rate   : The estimated pulse rate in BPM. (float)
                       It is 0 in the degenerate case where no in-band bin
                       has positive power.

    When filtering, scipy.signal.filtfilt requires the signal to be longer
    than its default pad length.

    Daniel McDuff, Ethan Blackford, January 2019
    Copyright (c)
    Licensed under the MIT License and the RAIL AI License.
    '''

    # Number of FFT points giving FResBPM resolution. It must be an int:
    # periodogram slices the signal with nfft when the signal is longer, and a
    # float slice index raises TypeError.
    N = int(round((60*FS)/FResBPM))

    # Detrending + nth order butterworth + periodogram
    if DETREND:
        BVP = _smoothness_priors_detrend(np.cumsum(BVP), 100)
    BVP = np.asarray(BVP, dtype=np.float64)
    if BUTTER_ORDER:
        b, a = scipy.signal.butter(BUTTER_ORDER, [LL_PR/60, UL_PR/60], btype='bandpass', fs=FS)
        # Zero-phase filtering so the waveform is not shifted in time.
        BVP = scipy.signal.filtfilt(b, a, BVP)

    # Calculate the PSD and the mask for the desired range
    if RECT:
        F, Pxx = scipy.signal.periodogram(x=BVP, nfft=N, fs=FS, detrend=False)
    else:
        F, Pxx = scipy.signal.periodogram(x=BVP, window=np.hanning(len(BVP)), nfft=N, fs=FS)
    FMask = (F >= (LL_PR/60)) & (F <= (UL_PR/60))

    # Peak of the in-band power. Out-of-band bins are zeroed rather than
    # dropped, so argmax still indexes into F directly.
    FRange = F * FMask
    PRange = Pxx * FMask
    MaxInd = np.argmax(PRange)
    pulse_rate_freq = FRange[MaxInd]
    pulse_rate = pulse_rate_freq*60

    # Optionally Plot the PSD and peak frequency
    if PlotTF:
        # Plot PSD (in dB) and peak frequency
        plt.figure()
        plt.plot(F, 10 * np.log10(Pxx))
        plt.plot(pulse_rate_freq, 10 * np.log10(PRange[MaxInd]),'ro')
        plt.xlabel('Frequency (Hz)')
        plt.ylabel('Power (dB)')
        plt.xlim([0, 4.5])
        plt.title('Power Spectrum and Peak Frequency')

    return pulse_rate


def _smoothness_priors_detrend(signal, lambda_value):
    '''Remove the slow trend from a 1-D signal (smoothness-priors method).

    Returns ``signal - (I + lambda^2 D^T D)^-1 signal``, where ``D`` is the
    second-order difference matrix (Tarvainen et al., 2002). The solve is
    dense and O(T^3), which is fine for windows of a few thousand samples.
    '''
    signal = np.asarray(signal, dtype=np.float64)
    T = signal.shape[0]
    identity = np.eye(T)
    # Row i of D is e_i - 2 e_{i+1} + e_{i+2}.
    D = np.diff(identity, n=2, axis=0)
    trend = np.linalg.solve(identity + lambda_value ** 2 * (D.T @ D), signal)
    return signal - trend

In [None]:
model = PhysNet_padding_Encoder_Decoder_MAX(frames=64, channels=3).to(device)
# map_location remaps tensors saved on a GPU onto `device`, so the checkpoint
# also loads on CPU-only machines.
model.load_state_dict(torch.load("assets/PhysNet_state_dict_best.pth", map_location=device))

In [None]:
ROOT = '/home/pradyumnachari/Documents/ImplicitPPG/SIGGRAPH_Data/rgb_files/'
_trial = 'v_93_6'

# Sanity check: show the first RGB frame of the selected trial.
_first_rgb_path = os.path.join(ROOT, _trial, "rgbd_rgb_0.png")
plt.imshow(imageio.v2.imread(_first_rgb_path))
plt.show()

# Ground-truth PPG for the same trial (first 300 samples) and its HR.
_gt_ppg_path = os.path.join(ROOT, _trial, "rgbd_ppg.npy")
print(_gt_ppg_path)
gt_ppg = np.load(_gt_ppg_path)[:300]
_gt_centered = gt_ppg - np.mean(gt_ppg)
gt_hr = prpsd2(_gt_centered, 30, 45, 150, BUTTER_ORDER=6, DETREND=False)
print(f"GT: {gt_hr}")

# Plot the ground-truth waveform.
plt.figure()
plt.plot(gt_ppg)
plt.title('GT PPG')
plt.show()

In [None]:
def _plot_and_show(trace):
    """Plot a single 1-D trace in its own figure and display it."""
    plt.plot(trace)
    plt.show()


_separator = '-' * 100

# Model-estimated rPPG for the two reference videos.
temp = eval_model(model, "temp.avi", 300)
_plot_and_show(temp)
print(_separator)
pleth = eval_model(model, "pleth.avi", 300)
_plot_and_show(pleth)
print(_separator)

# Green-channel baseline: per-frame spatial mean, channel index 1.
_pleth_frames = np.array(imageio.v2.mimread("pleth.avi"))
pleth = _pleth_frames.mean(1).mean(1)[:, 1]
pleth_hr = prpsd2(pleth - np.mean(pleth), 30, 45, 150, BUTTER_ORDER=6, DETREND=False)
print(pleth_hr)
_plot_and_show(pleth)

In [None]:
def _minmax_normalize(trace):
    """Linearly rescale a 1-D trace to [0, 1]; a constant trace maps to zeros."""
    trace = np.asarray(trace, dtype=np.float64)
    trace = trace - trace.min()
    peak = trace.max()
    # Guard against 0/0 (NaNs) when the trace is constant.
    return trace / peak if peak > 0 else trace


N_FRAMES = 300  # leading frames / samples compared in this cell

orig_est = [0]  # placeholder; the real baseline is commented out below
# orig_est = eval_model(model, f"assets/v_101_2.avi", 300)
# orig_est = orig_est - orig_est.min()
# orig_est = orig_est / orig_est.max()
print('-'*100)
gt_ppg = _minmax_normalize(np.load(os.path.join(ROOT, _trial, "rgbd_ppg.npy"))[:N_FRAMES])
start_epoch = 1
end_epoch = 5
for i in range(start_epoch, end_epoch + 1):
    new_est = _minmax_normalize(eval_model(model, f"residual/epoch_{i:03d}.avi", N_FRAMES))
    motion_est = _minmax_normalize(eval_model(model, f"residual/motion_epoch_{i:03d}.avi", N_FRAMES))
    rescaled_residual = _minmax_normalize(
        eval_model(model, f"residual/rescaled_residual_epoch_{i:03d}.avi", N_FRAMES)
    )

    # Green-channel baseline. Take the per-frame spatial mean of channel 1,
    # truncated to the same frame count as the model estimates so all
    # plotted traces cover the same time span.
    pleth = np.array(imageio.v2.mimread(f'residual/rescaled_residual_epoch_{i:03d}.avi'))
    pleth = pleth.mean(1).mean(1)[:N_FRAMES, 1]
    pleth_hr = prpsd2(pleth-np.mean(pleth), 30, 45, 150, BUTTER_ORDER=6, DETREND=False)
    print(pleth_hr)
    # Min-max scaling is affine, so the HR above is unchanged by it.
    pleth = _minmax_normalize(pleth)

    plt.figure(figsize=(30,5))
    plt.plot(gt_ppg, label="GT")
    plt.plot(orig_est, label="orig")
    plt.plot(new_est, label="new")
    plt.plot(motion_est, label="motion")
    plt.plot(rescaled_residual, label="rescaled residual")
    plt.plot(pleth, label="green")
    plt.title('Estimated PPG')
    plt.legend()
    plt.show()

    print('-'*100)

In [None]:
# print('-'*100)
# gt_ppg = np.load(os.path.join(ROOT,_trial,"rgbd_ppg.npy"))[:300]
# gt_ppg = gt_ppg - gt_ppg.min()
# gt_ppg = gt_ppg / gt_ppg.max()
# start_epoch = 1
# end_epoch = 10
# for i in range(start_epoch,end_epoch+1):

#     new_est = eval_model(model, f"delta_motion/epoch_{i:03d}.avi", 300)
#     new_est = new_est - new_est.min()
#     new_est = new_est / new_est.max()
#     plt.figure(figsize=(30,5))
#     plt.plot(gt_ppg, label="GT")
#     plt.plot(new_est, label="new")
#     plt.title('Estimated PPG')
#     plt.legend()
#     plt.show()
    
#     print('-'*100)