# Import

In [None]:
import os
from tqdm import tqdm
import numpy as np
import commentjson as json
import imageio.v2 as iio2
import matplotlib.pyplot as plt

import torch
import torch.utils.data
import tinycudann as tcnn
import argparse

from utils import Dict2Class, CNN3D, VideoGridDataset, prpsd2

# Config Handling

In [None]:
# Parse the residual-plethysmograph configuration (commentjson, imported as `json`).
with open("./configs/residual_plethysmograph.json") as config_file:
    json_config = json.load(config_file)

# Always run this notebook with verbose output.
json_config["verbose"] = True

# Expose the config keys as attributes (mimicking an argparse namespace) and
# turn the device string into a real torch.device.
args = Dict2Class(json_config)
args.pleth_device = torch.device(args.pleth_device)

# Pleth Model

In [None]:
# tiny-cuda-nn encoding followed by a tiny-cuda-nn network, both configured
# from the JSON config; chained together and moved to the pleth device.
pleth_enc = tcnn.Encoding(
    args.pleth_encoding["input_dims"],
    args.pleth_encoding,
)
pleth_net = tcnn.Network(
    pleth_enc.n_output_dims,
    args.pleth_network["output_dims"],
    args.pleth_network,
)
pleth_model = torch.nn.Sequential(pleth_enc, pleth_net)
pleth_model.to(args.pleth_device)

# Mask Model

In [None]:
# We use the same device as the pleth model. This can be changed to another
# device if there is a lack of GPU memory.
# NOTE(review): channels=6 presumably matches the video (3) + residual pleth (3)
# channels concatenated before the mask model is called — verify against CNN3D.
mask_model = CNN3D(frames=64, sidelen=128, channels=6).to(args.pleth_device)

# Create the Query Grid

`Set the video path, frame range, and ground-truth path here`

In [None]:
# Input video and the frame window [start_frame, start_frame + num_frames).
video_path = "./assets/vid.avi"
start_frame, num_frames = 0, 300

# Reference PPG recording used for the ground-truth comparison.
gt_path = './assets/ppg.npy'

In [None]:
# Load the selected frame window of the video. pixel_norm=255 presumably scales
# 8-bit pixel values to [0, 1] — confirm in VideoGridDataset.
dset = VideoGridDataset(
    video_path,
    start_frame=start_frame,
    num_frames=num_frames,
    pixel_norm=255,
    verbose=True,
)
# Locations that are fed to the pleth model, placed on its device.
trace_loc = dset.loc.to(args.pleth_device)

# Load Models

`Paths of the model checkpoints to load`

In [None]:
# Checkpoints to restore: the mask network and the residual-plethysmograph
# model (epoch_5 checkpoint).
mask_model_path = "./assets/mask_model.pth"
pleth_model_path = "./residual_plethysmograph/epoch_5.pth"

In [None]:
# map_location makes loading independent of the device the checkpoints were
# saved from; the weights are then copied into the models on pleth_device.
pleth_checkpoint = torch.load(pleth_model_path, map_location=args.pleth_device)
pleth_model.load_state_dict(pleth_checkpoint['model_state_dict'])
mask_model.load_state_dict(torch.load(mask_model_path, map_location=args.pleth_device))

# Both models are only used for inference below; switch them to evaluation mode
# so layers such as dropout / batch-norm (if CNN3D has any) behave accordingly.
pleth_model.eval()
mask_model.eval()

# Query and Generate the Residual Signal

In [None]:
# Evaluate the pleth model at every location with gradient tracking disabled.
with torch.set_grad_enabled(False):
    pleth_tensor = pleth_model(trace_loc)

In [None]:
# Restore the dataset's shape, move axis 2 (presumably time) to the front and
# add a leading batch dimension.
pleth_tensor = pleth_tensor.reshape(*dset.shape)
pleth_tensor = pleth_tensor.permute(2, 0, 1, 3)
pleth_tensor = pleth_tensor.unsqueeze(0)
print(pleth_tensor.shape)

# Generate the Mask

In [None]:
# Video pixels from the dataset, rearranged exactly like the pleth output so
# the two tensors can be concatenated along the last axis.
vid_tensor = (
    dset.vid.to(args.pleth_device)
    .reshape(*dset.shape)
    .permute(2, 0, 1, 3)
    .unsqueeze(0)
)
print(vid_tensor.shape)

In [None]:
# Stack video and residual pleth signal along the last (channel) axis.
inp_to_model = torch.cat((vid_tensor, pleth_tensor), dim=-1)
# Due to a compute limit, the model was only trained on the first 64 frames.
# Inference only: disable autograd so no graph is built or kept for the 3D CNN.
with torch.no_grad():
    mask = mask_model(inp_to_model[:, 0:64])

In [None]:
# NumPy copies on the host with the leading batch dimension dropped.
pleth_full_vid_npy = pleth_tensor[0].detach().cpu().numpy()
mask_npy = mask[0].detach().cpu().numpy()

In [None]:
# Show the mask produced by the mask model, without axes.
plt.imshow(mask_npy)
plt.title("Generated Mask")
plt.axis("off")
plt.show()

# Define a Detrend Function

Based on the detrend function from rPPG-Toolbox.

In [None]:
from scipy import sparse

def utils_detrend(input_signal, lambda_value):
    """Remove the slowly varying trend from a signal.

    The trend ``z`` is the minimiser of ``||x - z||^2 + lambda^2 * ||D z||^2``
    with ``D`` the second-order difference operator, i.e.
    ``z = (I + lambda^2 D^T D)^{-1} x``; the function returns ``x - z``.

    Args:
        input_signal: Signal whose first axis is time; a 1-D vector or an
            ``(N, 1)`` column both work.
        lambda_value: Smoothing parameter; larger values make the removed
            trend smoother.

    Returns:
        The detrended signal, with the same shape as ``input_signal``.
    """
    signal_length = input_signal.shape[0]
    # (N-2) x N second-order difference operator: each row is [1, -2, 1].
    ones = np.ones(signal_length)
    diags_data = np.array([ones, -2 * ones, ones])
    diags_index = np.array([0, 1, 2])
    D = sparse.spdiags(diags_data, diags_index,
                       signal_length - 2, signal_length)
    # D^T D is banded, so form it with sparse arithmetic before densifying.
    system = np.identity(signal_length) + (lambda_value ** 2) * (D.T @ D).toarray()
    # Solve for the trend instead of forming an explicit inverse: same result,
    # cheaper and numerically better conditioned.
    trend = np.linalg.solve(system, input_signal)
    return input_signal - trend

def detrend_signal(BVP, fs=30, lambda_value=100):
    """Flatten a signal to 1-D and detrend it with ``utils_detrend``.

    Args:
        BVP: Array-like signal; it is flattened to 1-D first.
        fs: Unused; kept for backward compatibility (presumably the sampling
            rate in Hz).
        lambda_value: Smoothing parameter forwarded to ``utils_detrend``
            (defaults to 100, the value previously hard-coded).

    Returns:
        1-D NumPy array holding the detrended signal.
    """
    # A plain 1-D array replaces the former np.mat(...).H column vector:
    # np.mat was removed in NumPy 2.0, and the output is the same 1-D array.
    signal = np.asarray(BVP).reshape(-1)
    return np.asarray(utils_detrend(signal, lambda_value)).reshape(-1)

# Green Estimate

In [None]:
# Per-frame spatial average of the green channel of the residual signal,
# detrended before plotting.
greeen_est = detrend_signal(pleth_full_vid_npy[..., 1].mean(axis=1).mean(axis=1))

plt.figure(figsize=(10, 3))
plt.plot(greeen_est)
plt.title("Green Estimate")
plt.show()

# Heart rate from the zero-mean signal (FS=30; LL_PR/UL_PR presumably bound the
# pulse-rate search range in BPM).
print(f"Predicted Heart Rate: {prpsd2(greeen_est-np.mean(greeen_est), FS=30, LL_PR=45, UL_PR=180)}")

# Masked Green Estimate

In [None]:
# Per-frame mask-weighted spatial average of the green channel of the residual
# signal, normalised by the total mask weight.
greeen_masked_est = ((pleth_full_vid_npy[..., 1] * mask_npy).sum(axis=(1, 2))
                     / mask_npy.sum(axis=(0, 1)))
# Store the detrended signal back into greeen_masked_est so that the plot and
# the heart-rate estimate below both use it (greeen_est is left untouched).
greeen_masked_est = detrend_signal(greeen_masked_est)
plt.figure(figsize=(10, 3))
plt.plot(greeen_masked_est)
plt.title("Green Masked Estimate")
plt.show()
# Heart rate from the zero-mean signal (FS=30; LL_PR/UL_PR presumably bound the
# pulse-rate search range in BPM).
print(f"Predicted Heart Rate: {prpsd2(greeen_masked_est-np.mean(greeen_masked_est), FS=30, LL_PR=45, UL_PR=180)}")

# Ground Truth

In [None]:
# Reference PPG signal for comparison with the estimates above.
gt = np.load(gt_path)
# We process in chunks of 300 frames.
# If the default values were used, then only the first 300 samples were processed.
# NOTE(review): slicing by frame indices assumes one PPG sample per video frame
# (FS=30 below) — confirm against the recording.
gt = gt[start_frame : start_frame+num_frames]
plt.figure(figsize=(10,3))
plt.plot(gt)
plt.title("Ground Truth")
plt.show()
print(f"Ground Truth Heart Rate: {prpsd2(gt-np.mean(gt), FS=30, LL_PR=45, UL_PR=180)}")