In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import os
from tqdm import tqdm
import numpy as np
import commentjson as json
import imageio.v2 as iio2
import matplotlib.pyplot as plt
from copy import deepcopy
import random

import torch
import torch.utils.data
import tinycudann as tcnn
import argparse

from implicitpleth.models.siren import Siren
from implicitpleth.models.combinations import MotionNet
from implicitpleth.data.datasets import VideoGridDataset
from implicitpleth.utils.utils import trace_video, Dict2Class

In [None]:
# # Fix random seed
# sd = 0
# np.random.seed(sd)
# torch.backends.cudnn.deterministic = True
# torch.manual_seed(sd)
# random.seed(sd)
# if torch.cuda.is_available():
#     torch.cuda.manual_seed_all(sd)

In [None]:
with open("./configs/spatiotemporal_residual.json") as config_file:
    json_config = json.load(config_file)

json_config["video_path"] = "./assets/v_101_2.avi"
# Alternative sample videos:
# json_config["video_path"] = "./assets/v_100_2.avi"
# json_config["video_path"] = "./assets/v_84_2.avi"
# Notebook-level overrides of the JSON config.
json_config.update(verbose=True, append_save_path=None, append_load_path=None)
args = Dict2Class(json_config)
# Replace each device string from the config with a torch.device.
for device_attr in ("spatiotemporal_device", "deltaspatial_device",
                    "pleth_spatial_device", "pleth_temporal_device", "io_device"):
    setattr(args, device_attr, torch.device(getattr(args, device_attr)))

In [None]:
class PlethSpatioTemporalModel(torch.nn.Module):
    """Two tiny-cuda-nn encoding+MLP networks: a spatial one and a temporal one.

    Each sub-network tracks its own device; ``forward`` runs exactly one of
    them, selected by ``flag``.
    """

    def __init__(self, pleth_spatial_encoding, pleth_spatial_network, 
                 pleth_temporal_encoding, pleth_temporal_network):
        super().__init__()

        def _build(encoding_cfg, network_cfg):
            # Input width is read from the encoding config, output width from the network config.
            return tcnn.NetworkWithInputEncoding(encoding_cfg["input_dims"],
                                                 network_cfg["output_dims"],
                                                 encoding_cfg,
                                                 network_cfg)

        self.spatial_model = _build(pleth_spatial_encoding, pleth_spatial_network)
        self.temporal_model = _build(pleth_temporal_encoding, pleth_temporal_network)
        # Device bookkeeping defaults to CPU until set_device() is called.
        self.spatial_device = torch.device("cpu")
        self.temporal_device = torch.device("cpu")

    def forward(self, inp, flag = False):
        """Run the spatial branch when ``flag`` is truthy, otherwise the temporal one.

        The input is moved to the selected branch's device first. Returns
        ``(output, None)`` because all custom models in the repo return 2 values.
        """
        if flag:
            branch, device = self.spatial_model, self.spatial_device
        else:
            branch, device = self.temporal_model, self.temporal_device
        return branch(inp.to(device)), None

    def set_device(self, spatial_device, temporal_device):
        """Record the target devices and move each sub-network onto its device."""
        self.spatial_device, self.temporal_device = spatial_device, temporal_device
        self.spatial_model.to(spatial_device)
        self.temporal_model.to(temporal_device)

    def load_spatial_checkpoint(self, load_path, key='model_state_dict'):
        """Load the spatial sub-network's weights from ``checkpoint[key]`` in ``load_path``."""
        checkpoint = torch.load(load_path)
        self.spatial_model.load_state_dict(checkpoint[key])


In [None]:
print(args.motion_model)
with open(args.motion_model["config"]) as mmf:
    config = json.load(mmf)
motion_model = MotionNet(config["spatiotemporal_encoding"], 
                         config["spatiotemporal_network"],
                         config["deltaspatial_encoding"], 
                         config["deltaspatial_network"])
motion_model.load_state_dict(torch.load(args.motion_model["load_path"])["model_state_dict"])
# Freeze the pre-trained motion model. eval() alone does not stop gradients:
# without requires_grad_(False), every loss.backward() in the residual training
# would compute and accumulate .grad on motion weights that no optimizer updates.
motion_model.eval()
for param in motion_model.parameters():
    param.requires_grad_(False)
# Both MotionNet sub-networks are placed on the first GPU.
# NOTE(review): args.spatiotemporal_device / args.deltaspatial_device are parsed
# in the config cell but not used here — confirm which should take precedence.
motion_spatiotemporal_device = torch.device("cuda:0")
motion_deltaspatial_device = torch.device("cuda:0")
motion_model.set_device(motion_spatiotemporal_device, motion_deltaspatial_device)

In [None]:
# Folders for the traced video data and checkpoints.
# Folders for the traced video data and checkpoints.
# When append_save_path is set, it is appended to both base folders.
if args.append_save_path is not None:
    args.trace["folder"] += args.append_save_path
    args.checkpoints["dir"] += args.append_save_path

# Create the trace folder unless it is configured as None.
if args.trace["folder"] is not None:
    os.makedirs(args.trace["folder"], exist_ok=True)
    if args.verbose:
        print(f'Saving trace to {args.trace["folder"]}')

# Create the checkpoint folder only when checkpoint saving is enabled.
if args.checkpoints["save"]:
    os.makedirs(args.checkpoints["dir"], exist_ok=True)
    if args.verbose:
        print(f'Saving checkpoints to {args.checkpoints["dir"]}')

In [None]:
# Get info before iterating.
# Get info before iterating.
epochs = args.train["epochs"]
# Zero-padding width for epoch numbers in trace file names. len(str()) avoids
# log10 round-off and does not fail for epochs == 0.
ndigits_epoch = len(str(int(epochs)))
latest_ckpt_path = os.path.join(args.checkpoints["dir"], args.checkpoints["latest"])
# Resuming is not wired up in this notebook: restoring needs the pleth model
# and its optimizers, which are only created in the training cell further down.
# Define start_epoch unconditionally so training never hits a NameError when a
# checkpoint file happens to exist.
start_epoch = 1
if os.path.exists(latest_ckpt_path):
    if args.verbose: print(f'Found {latest_ckpt_path}, but resuming is not supported here; starting from epoch {start_epoch}.')
else:
    if args.verbose: print('Starting from scratch.')

In [None]:
import scipy.signal
def pulse_rate_from_power_spectral_density(pleth_sig: np.array, FS: float,
                                           LL_PR: float, UL_PR: float,
                                           BUTTER_ORDER: int = 6,
                                           DETREND: bool = False,
                                           FResBPM: float = 0.1) -> float:
    """ Estimate the pulse rate from the power spectral density of a plethysmography signal.

    The signal is optionally detrended and band-pass filtered. Its periodogram
    is then computed with enough zero padding to reach a frequency resolution
    of ``FResBPM`` beats per minute. The pulse rate is the frequency of the
    largest PSD value inside ``[LL_PR, UL_PR]``.

    Args:
        pleth_sig (np.array): Plethysmography signal (1-D).
        FS (float): Sampling frequency in Hz.
        LL_PR (float): Lower pulse-rate bound in beats per minute. Used as the
            lower band-pass cutoff and as the lower edge of the peak search.
        UL_PR (float): Upper pulse-rate bound in beats per minute. Used as the
            upper band-pass cutoff and as the upper edge of the peak search.
        BUTTER_ORDER (int, optional): Order of the Butterworth band-pass filter.
            Give None (or 0) to skip filtering. Defaults to 6.
        DETREND (bool, optional): Boolean flag for detrending the cumulative sum
            of the signal with ``custom_detrend``. Defaults to False.
        FResBPM (float, optional): Frequency resolution of the PSD in beats per
            minute. Defaults to 0.1.
    Returns:
        pulse_rate (float): Estimated pulse rate in beats per minute. This is 0
        if no periodogram bin falls inside the band.
    
    Daniel McDuff, Ethan Blackford, January 2019
    Copyright (c)
    Licensed under the MIT License and the RAIL AI License.
    """
    # The periodogram bin spacing is FS / N Hz. N = 60 * FS / FResBPM makes it
    # FResBPM / 60 Hz, i.e. FResBPM beats per minute. N is an FFT length, so round to int.
    N = int(round((60 * FS) / FResBPM))

    # Optional detrending.
    if DETREND:
        # NOTE(review): custom_detrend is neither defined nor imported in this
        # notebook, so DETREND=True raises NameError unless it is provided.
        pleth_sig = custom_detrend(np.cumsum(pleth_sig), 100)
    pleth_sig = np.asarray(pleth_sig, dtype=np.float64)

    # Optional zero-phase nth-order Butterworth band-pass over [LL_PR, UL_PR] BPM
    # (converted to Hz). The filter is applied only when one was designed, so
    # BUTTER_ORDER=None really skips it instead of failing on undefined b, a.
    if BUTTER_ORDER:
        b, a = scipy.signal.butter(BUTTER_ORDER, [LL_PR / 60, UL_PR / 60],
                                   btype='bandpass', fs=FS)
        pleth_sig = scipy.signal.filtfilt(b, a, pleth_sig)

    # Calculate the PSD and the mask for the desired range.
    F, Pxx = scipy.signal.periodogram(x=pleth_sig, nfft=N, fs=FS)
    FMask = (F >= (LL_PR / 60)) & (F <= (UL_PR / 60))

    # Out-of-band bins are zeroed rather than removed. An empty band therefore
    # gives argmax 0 and a pulse rate of 0.
    FRange = F * FMask
    PRange = Pxx * FMask
    MaxInd = np.argmax(PRange)
    pulse_rate = FRange[MaxInd] * 60
    return pulse_rate

In [None]:
# Video dataset (exposes .loc coordinates and .vid pixels, as indexed below);
# the frame window and pixel normalization come from the "data" config section.
dset = VideoGridDataset(
    args.video_path,
    verbose=args.verbose,
    num_frames=args.data["num_frames"],
    start_frame=args.data["start_frame"],
    pixel_norm=args.data["norm_value"],
)

In [None]:
# with torch.no_grad():
# motion_tensor, _ = motion_model(dset.loc)
# motion_tensor = motion_tensor.reshape(dset.shape).to(args.io_device)
# motion_orig = deepcopy(motion_tensor.detach())

In [None]:
shape = (128, 128, 300, 3)
# Flattened row-major pixel coordinates: X cycles through columns fastest and
# Y advances once per row. Each is scaled by its image size and shifted into [-0.5, 0.5).
X = torch.arange(shape[1]).repeat(shape[0]) / shape[1] - 0.5
Y = torch.arange(shape[0]).repeat_interleave(shape[1]) / shape[0] - 0.5
# Frame indices, normalized the same way.
T = torch.arange(shape[2]) / shape[2] - 0.5
# (x, y) pairs per pixel, and (t, 0) pairs per frame.
XY = torch.stack((X, Y), dim=-1)
T0 = torch.stack((T, torch.zeros_like(T)), dim=-1)

In [None]:
# Shuffled mini-batches of flat sample indices; each batch indexes dset.loc / dset.vid.
dloader = torch.utils.data.DataLoader(
    dataset=range(len(dset)),
    batch_size=args.data["batch_size"],
    shuffle=True,
)

In [None]:
# pte = tcnn.Encoding(args.pleth_temporal_encoding["input_dims"], args.pleth_temporal_encoding)
# ptn = tcnn.Network(pte.n_output_dims, args.pleth_temporal_network["output_dims"], args.pleth_temporal_network)
# pleth_temporal_model = torch.nn.Sequential(pte,ptn)
# pleth_temporal_model.to(args.io_device)
# pleth_spatial_model = tcnn.NetworkWithInputEncoding(args.pleth_spatial_encoding["input_dims"],
#                                                     args.pleth_spatial_network["output_dims"],
#                                                     args.pleth_spatial_encoding,
#                                                     args.pleth_spatial_network)
# # ckpt = torch.load(os.path.join(args.pre_train_checkpoints["dir"], args.pre_train_checkpoints["latest"]))
# # pleth_spatial_model.load_state_dict(ckpt['model_state_dict'])
# pleth_spatial_model.to(args.io_device)
# opt_temporal_enc = torch.optim.Adam(pte.parameters(), lr=args.opt["lr"],
#                                 betas=(args.opt["beta1"], args.opt["beta2"]), eps=args.opt["eps"])
# opt_temporal_net = torch.optim.Adam(ptn.parameters(), lr=args.opt["lr"],
#                                 betas=(args.opt["beta1"], args.opt["beta2"]), eps=args.opt["eps"])
# opt_spatial = torch.optim.Adam(pleth_spatial_model.parameters(), lr=args.opt["lr"],
#                                 betas=(args.opt["beta1"], args.opt["beta2"]), eps=args.opt["eps"])

# flag=False
# switch = 3
# for epoch in range(start_epoch,epochs+1):
#     train_loss = 0
#     pleth_temporal_model.train()
#     pleth_spatial_model.train()
#     motion_model.train()
#     for count, item in tqdm(enumerate(dloader),total=len(dloader)):
#         if epoch > switch:
#             # print("Yes")
#             loc = dset.loc[item].half().to(args.io_device)
#             pixel = dset.vid[item].half().to(args.io_device)
#             XY_loc = loc[...,0:2]
#             T_loc = torch.cat((loc[...,2:3], torch.zeros_like(loc[...,2:3])), dim=-1)
#             motion_output, _ = motion_model(loc)
#             pleth_spatial_output = pleth_spatial_model(XY_loc)
#             pleth_temporal_output = pleth_temporal_model(T_loc)
#             output = motion_output + (pleth_spatial_output*pleth_temporal_output)
#             # Since the model takes care of moving the data to different devices, move GT correspondingly.
#             pixel = pixel.to(output.dtype).to(output.device)
#             # Backpropagation.
#             if flag:
#                 opt_temporal_enc.zero_grad()
#                 opt_temporal_net.zero_grad()
#                 l2_error = (output - pixel)**2
#                 loss = l2_error.mean()
#                 loss.backward()
#                 opt_temporal_enc.step()
#                 opt_temporal_net.step()
#             else:
#                 opt_spatial.zero_grad()
#                 l2_error = (output - pixel)**2
#                 loss = l2_error.mean()
#                 loss.backward()
#                 opt_spatial.step()
#             train_loss += loss.item()
#             if epoch % 1:
#                 flag = not flag
#             pass
#         else:
#             # print("No")
#             loc = dset.loc[item].half().to(args.io_device)
#             pixel = dset.vid[item].half().to(args.io_device)
#             T_loc = torch.cat((loc[...,2:3], torch.zeros_like(loc[...,2:3])), dim=-1)
#             motion_output, _ = motion_model(loc)
#             # pleth_output = pleth_model(loc)
#             pleth_output = pleth_temporal_model(T_loc)
#             output = motion_output + pleth_output
#             # Since the model takes care of moving the data to different devices, move GT correspondingly.
#             pixel = pixel.to(output.dtype).to(output.device)
#             # Backpropagation.
#             opt_temporal_enc.zero_grad()
#             opt_temporal_net.zero_grad()
#             l2_error = (output - pixel)**2
#             loss = l2_error.mean()
#             loss.backward()
#             opt_temporal_enc.step()
#             opt_temporal_net.step()
#             train_loss += loss.item()


#     print(f'Epoch: {epoch}, Loss: {train_loss/len(dloader)}', flush=True)
#     with torch.no_grad():
#         if epoch > switch:
#             motion_model.eval()
#             pleth_spatial_model.eval()
#             pleth_temporal_model.eval()

#             trace_loc = dset.loc.half().to(args.io_device)
#             trace_motion, _ = motion_model(trace_loc)
#             trace_XY_loc = trace_loc[...,0:2]
#             trace_pleth_spatial = pleth_spatial_model(trace_XY_loc)
#             trace_T_loc = torch.cat((trace_loc[...,2:3], torch.zeros_like(trace_loc[...,2:3])), dim=-1)
#             trace_pleth_temporal = pleth_temporal_model(trace_T_loc)
            
#             trace = trace_motion + (trace_pleth_spatial*trace_pleth_temporal)
#             trace = trace.detach().cpu().float().reshape(dset.shape).permute(2,0,1,3).numpy()
#             trace = (np.clip(trace, 0, 1)*255).astype(np.uint8)
#             save_path = os.path.join(args.trace["folder"], f'{args.trace["file_tag"]}{str(epoch).zfill(ndigits_epoch)}.avi')
#             iio2.mimwrite(save_path, trace, fps=30)
            
#             trace = trace_motion.detach().cpu().float().reshape(dset.shape).permute(2,0,1,3).numpy()
#             trace = (np.clip(trace, 0, 1)*255).astype(np.uint8)
#             save_path = os.path.join(args.trace["folder"], f'motion_{args.trace["file_tag"]}{str(epoch).zfill(ndigits_epoch)}.avi')
#             iio2.mimwrite(save_path, trace, fps=30)
            
#             trace = trace_pleth_spatial.detach().cpu().float().reshape(dset.shape[:3]).permute(2,0,1).numpy()
#             trace = (np.clip(trace, 0, 1)*255).astype(np.uint8)
#             save_path = os.path.join(args.trace["folder"], f'mask_{args.trace["file_tag"]}{str(epoch).zfill(ndigits_epoch)}.avi')
#             iio2.mimwrite(save_path, trace, fps=30)
#             pass
#         else:
#             motion_model.eval()
#             pleth_temporal_model.eval()

#             trace_loc = dset.loc.half().to(args.io_device)
#             motion_output, _ = motion_model(trace_loc)

#             trace_T_loc = torch.cat((trace_loc[...,2:3], torch.zeros_like(trace_loc[...,2:3])), dim=-1)
#             pleth_output = pleth_temporal_model(trace_T_loc)

#             trace = motion_output + pleth_output

#             trace = trace.detach().cpu().float().reshape(dset.shape).permute(2,0,1,3).numpy()
#             trace = (np.clip(trace, 0, 1)*255).astype(np.uint8)
#             save_path = os.path.join(args.trace["folder"], f'{args.trace["file_tag"]}{str(epoch).zfill(ndigits_epoch)}.avi')
#             iio2.mimwrite(save_path, trace, fps=30)
            
#             trace = motion_output.detach().cpu().float().reshape(dset.shape).permute(2,0,1,3).numpy()
#             trace = (np.clip(trace, 0, 1)*255).astype(np.uint8)
#             save_path = os.path.join(args.trace["folder"], f'motion_{args.trace["file_tag"]}{str(epoch).zfill(ndigits_epoch)}.avi')
#             iio2.mimwrite(save_path, trace, fps=30)
            
#             # trace = pleth_output.detach().cpu().float().reshape(dset.shape).permute(2,0,1,3).numpy()
#             # trace = (np.clip(trace, 0, 1)*255).astype(np.uint8)
#             # save_path = os.path.join(args.trace["folder"], f'pleth_{args.trace["file_tag"]}{str(epoch).zfill(ndigits_epoch)}.avi')
#             # iio2.mimwrite(save_path, trace, fps=30)
            
#             # trace = pleth_output.detach().cpu().float().reshape(dset.shape).permute(2,0,1,3).numpy()
#             # trace = (trace - np.amin(trace, axis=(0,1,2), keepdims=True)) / (np.amax(trace, axis=(0,1,2), keepdims=True) - np.amin(trace, axis=(0,1,2), keepdims=True))
#             # trace = trace/20 + 0.5
#             # trace = (np.clip(trace, 0, 1)*255).astype(np.uint8)
#             # save_path = os.path.join(args.trace["folder"], f'pleth_scaled_{args.trace["file_tag"]}{str(epoch).zfill(ndigits_epoch)}.avi')
#             # iio2.mimwrite(save_path, trace, fps=30)

In [None]:
# 100%|██████████| 600/600 [00:12<00:00, 47.74it/s]Epoch: 1, Loss: 0.00011248618364334106

# 100%|██████████| 600/600 [00:12<00:00, 49.90it/s]Epoch: 2, Loss: 5.8252314726511634e-05

# 100%|██████████| 600/600 [00:12<00:00, 47.62it/s]Epoch: 3, Loss: 6.430208683013916e-05

# 100%|██████████| 600/600 [00:12<00:00, 46.75it/s]Epoch: 4, Loss: 4.203597704569499e-05

# 100%|██████████| 600/600 [00:12<00:00, 48.58it/s]Epoch: 5, Loss: 4.323959350585938e-05

# 100%|██████████| 600/600 [00:12<00:00, 47.33it/s]Epoch: 6, Loss: 4.204720258712768e-05

# 100%|██████████| 600/600 [00:12<00:00, 47.14it/s]Epoch: 7, Loss: 4.2064587275187176e-05

In [None]:
# pleth_temporal_model = tcnn.NetworkWithInputEncoding(args.pleth_temporal_encoding["input_dims"],
#                                                      args.pleth_temporal_network["output_dims"],
#                                                      args.pleth_temporal_encoding,
#                                                      args.pleth_temporal_network)
# pleth_temporal_model.to(args.io_device)
# pleth_spatial_model = tcnn.NetworkWithInputEncoding(args.pleth_spatial_encoding["input_dims"],
#                                                     args.pleth_spatial_network["output_dims"],
#                                                     args.pleth_spatial_encoding,
#                                                     args.pleth_spatial_network)
# ckpt = torch.load(os.path.join(args.pre_train_checkpoints["dir"], args.pre_train_checkpoints["latest"]))
# pleth_spatial_model.load_state_dict(ckpt['model_state_dict'])
# pleth_spatial_model.to(args.io_device)
# opt_temporal = torch.optim.Adam(pleth_temporal_model.parameters(), lr=args.opt["lr"],
#                                 betas=(args.opt["beta1"], args.opt["beta2"]), eps=args.opt["eps"])
# opt_spatial = torch.optim.Adam(pleth_spatial_model.parameters(), lr=args.opt["lr"],
#                                 betas=(args.opt["beta1"], args.opt["beta2"]), eps=args.opt["eps"])

# flag=True
# for epoch in range(start_epoch,epochs+1):
#     train_loss = 0
#     pleth_temporal_model.train()
#     pleth_spatial_model.train()
#     motion_model.train()
#     for count, item in tqdm(enumerate(dloader),total=len(dloader)):
#         loc = dset.loc[item].half().to(args.io_device)
#         pixel = dset.vid[item].half().to(args.io_device)
#         XY_loc = loc[...,0:2]
#         T_loc = torch.cat((loc[...,2:3], torch.zeros_like(loc[...,2:3])), dim=-1)
#         motion_output, _ = motion_model(loc)
#         pleth_spatial_output = pleth_spatial_model(XY_loc)
#         pleth_temporal_output = pleth_temporal_model(T_loc)
#         output = motion_output + (pleth_spatial_output*pleth_temporal_output)
#         # Since the model takes care of moving the data to different devices, move GT correspondingly.
#         pixel = pixel.to(output.dtype).to(output.device)
#         # Backpropagation.
#         if flag:
#             opt_temporal.zero_grad()
#             l2_error = (output - pixel)**2
#             loss = l2_error.mean()
#             loss.backward()
#             opt_temporal.step()
#         else:
#             opt_spatial.zero_grad()
#             l2_error = (output - pixel)**2
#             loss = l2_error.mean()
#             loss.backward()
#             opt_spatial.step()
#         train_loss += loss.item()
#         if epoch % 1:
#             flag = not flag
#     print(f'Epoch: {epoch}, Loss: {train_loss/len(dloader)}', flush=True)
#     with torch.no_grad():
#         motion_model.eval()
#         pleth_spatial_model.eval()
#         pleth_temporal_model.eval()

#         trace_loc = dset.loc.half().to(args.io_device)
#         trace_motion, _ = motion_model(trace_loc)
#         trace_XY_loc = trace_loc[...,0:2]
#         trace_pleth_spatial = pleth_spatial_model(trace_XY_loc)
#         trace_T_loc = torch.cat((trace_loc[...,2:3], torch.zeros_like(trace_loc[...,2:3])), dim=-1)
#         trace_pleth_temporal = pleth_temporal_model(trace_T_loc)
        
#         trace = trace_motion + (trace_pleth_spatial*trace_pleth_temporal)
#         trace = trace.detach().cpu().float().reshape(dset.shape).permute(2,0,1,3).numpy()
#         trace = (np.clip(trace, 0, 1)*255).astype(np.uint8)
#         save_path = os.path.join(args.trace["folder"], f'{args.trace["file_tag"]}{str(epoch).zfill(ndigits_epoch)}.avi')
#         iio2.mimwrite(save_path, trace, fps=30)
        
#         trace = trace_motion.detach().cpu().float().reshape(dset.shape).permute(2,0,1,3).numpy()
#         trace = (np.clip(trace, 0, 1)*255).astype(np.uint8)
#         save_path = os.path.join(args.trace["folder"], f'motion_{args.trace["file_tag"]}{str(epoch).zfill(ndigits_epoch)}.avi')
#         iio2.mimwrite(save_path, trace, fps=30)
        
#         trace = trace_pleth_spatial.detach().cpu().float().reshape(dset.shape).permute(2,0,1,3).numpy()
#         trace = (np.clip(trace, 0, 1)*255).astype(np.uint8)
#         save_path = os.path.join(args.trace["folder"], f'mask_{args.trace["file_tag"]}{str(epoch).zfill(ndigits_epoch)}.avi')
#         iio2.mimwrite(save_path, trace, fps=30)

In [None]:
# pleth_model = tcnn.NetworkWithInputEncoding(args.pleth_temporal_encoding["input_dims"],
#                                             args.pleth_temporal_network["output_dims"],
#                                             args.pleth_temporal_encoding,
#                                             args.pleth_temporal_network, seed=sd)
# pleth_model.to(args.io_device)
# opt = torch.optim.Adam(pleth_model.parameters(), lr=args.opt["lr"],
#                        betas=(args.opt["beta1"], args.opt["beta2"]), eps=args.opt["eps"])

# for epoch in range(start_epoch,epochs+1):
#     train_loss = 0
#     pleth_model.train()
#     motion_model.train()
#     for count, item in tqdm(enumerate(dloader),total=len(dloader)):
#         loc = dset.loc[item].half().to(args.io_device)
#         pixel = dset.vid[item].half().to(args.io_device)
#         T_loc = torch.cat((loc[...,2:3], torch.zeros_like(loc[...,2:3])), dim=-1)
#         motion_output, _ = motion_model(loc)
#         # pleth_output = pleth_model(loc)
#         pleth_output = pleth_model(T_loc)
#         output = motion_output + pleth_output
#         # Since the model takes care of moving the data to different devices, move GT correspondingly.
#         pixel = pixel.to(output.dtype).to(output.device)
#         # Backpropagation.
#         opt.zero_grad()
#         l2_error = (output - pixel)**2
#         loss = l2_error.mean()
#         loss.backward()
#         opt.step()
#         train_loss += loss.item()
#     print(f'Epoch: {epoch}, Loss: {train_loss/len(dloader)}', flush=True)
#     # with torch.no_grad():
#     #     motion_model.eval()
#     #     pleth_model.eval()
#     #     trace_loc = dset.loc.half().to(args.io_device)
#     #     trace_T_loc = torch.cat((trace_loc[...,2:3], torch.zeros_like(trace_loc[...,2:3])), dim=-1)
#     #     motion_output, _ = motion_model(trace_loc)
#     #     # pleth_output = pleth_model(trace_loc)
#     #     pleth_output = pleth_model(trace_T_loc)
#     #     trace = motion_output + pleth_output
#     #     trace = trace.detach().cpu().float().reshape(dset.shape).permute(2,0,1,3).numpy()
#     #     trace = (np.clip(trace, 0, 1)*255).astype(np.uint8)
#     #     save_path = os.path.join(args.trace["folder"], f'{args.trace["file_tag"]}{str(epoch).zfill(ndigits_epoch)}.avi')
#     #     iio2.mimwrite(save_path, trace, fps=30)
#     #     trace = motion_output.detach().cpu().float().reshape(dset.shape).permute(2,0,1,3).numpy()
#     #     trace = (np.clip(trace, 0, 1)*255).astype(np.uint8)
#     #     save_path = os.path.join(args.trace["folder"], f'motion_{args.trace["file_tag"]}{str(epoch).zfill(ndigits_epoch)}.avi')
#     #     iio2.mimwrite(save_path, trace, fps=30)
#     #     trace = pleth_output.detach().cpu().float().reshape(dset.shape).permute(2,0,1,3).numpy()
#     #     trace = (np.clip(trace, 0, 1)*255).astype(np.uint8)
#     #     save_path = os.path.join(args.trace["folder"], f'pleth_{args.trace["file_tag"]}{str(epoch).zfill(ndigits_epoch)}.avi')
#     #     iio2.mimwrite(save_path, trace, fps=30)
#     #     trace = pleth_output.detach().cpu().float().reshape(dset.shape).permute(2,0,1,3).numpy()
#     #     if not np.array_equal(np.amax(trace, axis=(1,2), keepdims=True), np.amin(trace, axis=(1,2), keepdims=True)):
#     #         trace = (trace - np.amin(trace, axis=(1,2), keepdims=True)) / (np.amax(trace, axis=(1,2), keepdims=True) - np.amin(trace, axis=(1,2), keepdims=True))
#     #     else:
#     #         trace = (trace - np.amin(trace, axis=(0,1,2), keepdims=True)) / (np.amax(trace, axis=(0,1,2), keepdims=True) - np.amin(trace, axis=(0,1,2), keepdims=True))
#     #         trace = trace/20 + 0.5
#     #     trace = (np.clip(trace, 0, 1)*255).astype(np.uint8)
#     #     save_path = os.path.join(args.trace["folder"], f'pleth_scaled_{args.trace["file_tag"]}{str(epoch).zfill(ndigits_epoch)}.avi')
#     #     iio2.mimwrite(save_path, trace, fps=30)

In [None]:
# Residual ("pleth") model trained on top of the frozen motion model: a
# hash-grid encoding of the loc coordinates followed by a small sine MLP whose
# output is added to the motion model's reconstruction.
pleth_encoding_config = {
    "otype": "HashGrid",
    # Number of loc coordinates fed to the encoding below (it is built with 3 inputs).
    "input_dims": 3,
    "n_levels": 8,
    "n_features_per_level": 2,
    "log2_hashmap_size": 24,
    "base_resolution": 16,
    "per_level_scale": 1.5
}
pleth_network_config = {
    "otype": "CutlassMLP",
    "activation": "Sine",
    "output_activation": "none",
    "n_neurons": 64,
    "n_hidden_layers": 2,
    "output_dims": 3
}
pleth_enc = tcnn.Encoding(pleth_encoding_config["input_dims"], pleth_encoding_config)
pleth_net = tcnn.Network(pleth_enc.n_output_dims, pleth_network_config["output_dims"], pleth_network_config)
pleth_model = torch.nn.Sequential(pleth_enc, pleth_net)
pleth_model.to(args.io_device)
# Two optimizers: weight decay is applied only to the MLP parameters, not the encoding.
opt_enc = torch.optim.Adam(pleth_enc.parameters(), lr=1e-3,
                           betas=(args.opt["beta1"], args.opt["beta2"]), eps=args.opt["eps"])
opt_net = torch.optim.Adam(pleth_net.parameters(), lr=1e-3, weight_decay=1e-6,
                           betas=(args.opt["beta1"], args.opt["beta2"]), eps=args.opt["eps"])
# This cell overrides args.train["epochs"]; recompute the epoch zero-padding
# used in trace file names so it matches the new epoch count.
epochs = 10
ndigits_epoch = len(str(epochs))


def _to_video(tensor):
    """Turn a flat model output into a float32 numpy video via reshape(dset.shape).permute(2, 0, 1, 3)."""
    return tensor.detach().cpu().float().reshape(dset.shape).permute(2, 0, 1, 3).numpy()


def _save_trace_video(video, prefix, epoch):
    """Clip `video` to [0, 1], quantize to uint8 and write it to the trace folder at 30 fps."""
    frames = (np.clip(video, 0, 1) * 255).astype(np.uint8)
    save_path = os.path.join(args.trace["folder"],
                             f'{prefix}{args.trace["file_tag"]}{str(epoch).zfill(ndigits_epoch)}.avi')
    iio2.mimwrite(save_path, frames, fps=30)


for epoch in range(start_epoch, epochs + 1):
    train_loss = 0
    pleth_model.train()
    # The motion model is frozen (it belongs to no optimizer), so keep it in eval mode.
    motion_model.eval()
    for item in tqdm(dloader):
        loc = dset.loc[item].half().to(args.io_device)
        pixel = dset.vid[item].half().to(args.io_device)
        # No autograd through the frozen motion model. Gradients w.r.t. the pleth
        # parameters are unchanged (output is a plain sum), and no unused .grad
        # accumulates on the motion weights.
        with torch.no_grad():
            motion_output, _ = motion_model(loc)
        pleth_output = pleth_model(loc)
        output = motion_output + pleth_output
        # Since the model takes care of moving the data to different devices, move GT correspondingly.
        pixel = pixel.to(device=output.device, dtype=output.dtype)
        # Backpropagation of the mean squared error.
        opt_enc.zero_grad()
        opt_net.zero_grad()
        loss = ((output - pixel) ** 2).mean()
        loss.backward()
        opt_enc.step()
        opt_net.step()
        train_loss += loss.item()
    print(f'Epoch: {epoch}, Loss: {train_loss/len(dloader)}', flush=True)
    with torch.no_grad():
        motion_model.eval()
        pleth_model.eval()
        trace_loc = dset.loc.half().to(args.io_device)
        motion_output, _ = motion_model(trace_loc)
        pleth_output = pleth_model(trace_loc)

        # Full reconstruction: motion + residual.
        _save_trace_video(_to_video(motion_output + pleth_output), '', epoch)

        # Residual alone, min-max rescaled per last-axis channel for visibility.
        # Constant channels are divided by 1 instead of 0, which would produce NaNs.
        residual = _to_video(pleth_output)
        lo = residual.min(axis=(0, 1, 2), keepdims=True)
        hi = residual.max(axis=(0, 1, 2), keepdims=True)
        residual = (residual - lo) / np.where(hi > lo, hi - lo, 1)
        _save_trace_video(residual, 'rescaled_residual_', epoch)

        # Motion model reconstruction alone.
        _save_trace_video(_to_video(motion_output), 'motion_', epoch)

In [None]:
# Full-video inference with both models. no_grad avoids building an autograd
# graph over every sample of the video; outputs are only inspected/saved below.
motion_model.eval()
pleth_model.eval()
with torch.no_grad():
    trace_loc = dset.loc.half().to(args.io_device)
    motion_output, _ = motion_model(trace_loc)
    pleth_output = pleth_model(trace_loc)

In [None]:
# Convert both model outputs into float32 numpy videos with the same layout:
# reshape to dset.shape, then move axis 2 first (presumably frames — confirm dset.shape order).
motion_trace, pleth_trace = (
    out.detach().cpu().float().reshape(dset.shape).permute(2, 0, 1, 3).numpy()
    for out in (motion_output, pleth_output)
)

In [None]:
# Build a video that holds the first motion-model frame fixed for every frame.
# The frame count is taken from motion_trace instead of a hard-coded 300, so the
# result always matches pleth_trace's shape for any configured num_frames.
motion_frame = [motion_trace[0]]
rep_motion_trace = np.repeat(motion_trace[:1], motion_trace.shape[0], axis=0)

In [None]:
# Static first motion frame plus the per-frame pleth residual.
static_with_xyt = np.add(rep_motion_trace, pleth_trace)

In [None]:
# Quantize the [0, 1] video to 8 bits, write it out, and display its shape.
trace = (255 * np.clip(static_with_xyt, 0, 1)).astype(np.uint8)
iio2.mimwrite('temp.avi', trace, fps=30)
trace.shape