In [None]:
# Enable IPython's autoreload extension in mode 2, so imported modules are
# re-imported before each cell runs and edits to library code are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import os
from tqdm import tqdm
import numpy as np
import commentjson as json
import imageio.v2 as iio2
import matplotlib.pyplot as plt
from copy import deepcopy
import random

import torch
import torch.utils.data
import tinycudann as tcnn
import argparse

from implicitpleth.models.siren import Siren
from implicitpleth.models.combinations import MotionNet
from implicitpleth.data.datasets import VideoGridDataset
from implicitpleth.utils.utils import trace_video, Dict2Class

In [None]:
# NOTE(review): seeding is currently disabled -- every line below is commented
# out, so this cell fixes no random seed and the shuffled DataLoader order is
# not pinned by it. Uncomment to make runs repeatable.
# # Fix random seed
# sd = 0
# np.random.seed(sd)
# torch.backends.cudnn.deterministic = True
# torch.manual_seed(sd)
# random.seed(sd)
# if torch.cuda.is_available():
#     torch.cuda.manual_seed_all(sd)

In [None]:
# Load the base experiment configuration (parsed with commentjson, imported as `json`).
with open("./configs/residual.json") as f:
    json_config = json.load(f)

# Per-run overrides applied on top of the file-based configuration.
# NOTE(review): the paths below are absolute and machine-specific; consider
# moving them into the config file or an environment variable.
_trial = 'v_99_2'
json_config["video_path"] = f'/home/pradyumnachari/Documents/ImplicitPPG/SIGGRAPH_Data/rgb_files/{_trial}'
json_config["verbose"] = True
json_config["append_save_path"] = None
json_config["append_load_path"] = None

# Motion-model checkpoint for this trial. An earlier assignment of
# 'delta_motion/motion_model.pth' to the same key was overwritten immediately
# and has been removed as a dead store.
json_config["motion_model"]["load_path"] = f'/home/pradyumnachari/Documents/FastImplicitPleth/dataset_motion/motion_0_{_trial}/epoch_10.pth'

# Wrap the dict for attribute-style access and turn the device entries into
# torch.device objects (the raw values are printed first for reference).
args = Dict2Class(json_config)
print(args.spatiotemporal_device, args.deltaspatial_device, args.pleth_device)
args.spatiotemporal_device = torch.device(args.spatiotemporal_device)
args.deltaspatial_device = torch.device(args.deltaspatial_device)
args.pleth_device = torch.device(args.pleth_device)

In [None]:
# Build the motion model from its own config file and restore its weights.
print(args.motion_model)
with open(args.motion_model["config"]) as mmf:
    config = json.load(mmf)
motion_model = MotionNet(config["spatiotemporal_encoding"], 
                         config["spatiotemporal_network"],
                         config["deltaspatial_encoding"], 
                         config["deltaspatial_network"])
# Deserialize onto the CPU so loading does not depend on the device the
# checkpoint was saved from; load_state_dict copies the values into the
# model's existing parameters and set_device() below places the model.
motion_checkpoint = torch.load(args.motion_model["load_path"], map_location="cpu")
motion_model.load_state_dict(motion_checkpoint["model_state_dict"])
# Switch to eval mode. This does NOT freeze the parameters: gradients are
# still computed for them because the requires_grad loop below is disabled.
motion_model.eval()
# for params in motion_model.parameters():
#     params.requires_grad = False
# Set the model device
motion_spatiotemporal_device = args.spatiotemporal_device
motion_deltaspatial_device = args.deltaspatial_device
motion_model.set_device(motion_spatiotemporal_device, motion_deltaspatial_device)

In [None]:
# Folders for the traced video data and checkpoints.
# NOTE: re-running this cell appends the suffix again, since args is mutated in place.
if args.append_save_path is not None:
    # A trace folder of None is a supported value (see the check below), so
    # only extend it when it is actually set instead of failing on None + str.
    if args.trace["folder"] is not None:
        args.trace["folder"] = args.trace["folder"] + args.append_save_path
    args.checkpoints["dir"] = args.checkpoints["dir"] + args.append_save_path
if args.trace["folder"] is not None:
    os.makedirs(args.trace["folder"], exist_ok=True)
    if args.verbose: print(f'Saving trace to {args.trace["folder"]}')
if args.checkpoints["save"]:
    os.makedirs(args.checkpoints["dir"], exist_ok=True)
    if args.verbose: print(f'Saving checkpoints to {args.checkpoints["dir"]}')

In [None]:
# Get info before iterating.
epochs = args.train["epochs"]
# Number of digits used to zero-pad the epoch index in output file names.
ndigits_epoch = int(np.log10(epochs)+1)
latest_ckpt_path = os.path.join(args.checkpoints["dir"], args.checkpoints["latest"])
# Resuming is not implemented (the restore code below is disabled), so always
# start at epoch 1. Previously start_epoch was left undefined whenever the
# checkpoint file existed, which made the training loop fail with NameError.
start_epoch = 1
if os.path.exists(latest_ckpt_path):
    if args.verbose: print('Found latest checkpoint, but resuming is disabled. Starting from scratch.')
    # saved_dict = torch.load(latest_ckpt_path)
    # pleth_model.load_state_dict(saved_dict["model_state_dict"])
    # if "optimizer_spatial_state_dict" in saved_dict.keys():
    #     opt_spatial.load_state_dict(saved_dict["optimizer_spatial_state_dict"])
    # if "optimizer_temporal_state_dict" in saved_dict.keys():
    #     opt_temporal.load_state_dict(saved_dict["optimizer_temporal_state_dict"])
    # start_epoch = saved_dict["epoch"] + 1
    # if args.verbose: print(f'Continuing from epoch {start_epoch}.')
else:
    if args.verbose: print('Starting from scratch.')

In [None]:
# Disabled experiment: evaluate the motion model once over all grid locations
# and keep a detached copy. NOTE(review): `dset` is only created in a later
# cell, so these lines cannot simply be uncommented at this position.
# with torch.no_grad():
# motion_tensor, _ = motion_model(dset.loc)
# motion_tensor = motion_tensor.reshape(dset.shape).to(args.pleth_device)
# motion_orig = deepcopy(motion_tensor.detach())

In [None]:
# Load the video as a grid dataset; frame window and pixel normalisation come
# from the "data" section of the configuration.
dset = VideoGridDataset(
    args.video_path,
    verbose=args.verbose,
    num_frames=args.data["num_frames"],
    start_frame=args.data["start_frame"],
    pixel_norm=args.data["norm_value"],
)

In [None]:
# The loader iterates over sample *indices* (not samples): each batch is a
# shuffled set of indices that is later used to slice dset.loc and dset.vid.
dloader = torch.utils.data.DataLoader(
    range(len(dset)),
    batch_size=args.data["batch_size"],
    shuffle=True,
    num_workers=1,
)

In [None]:
# tiny-cuda-nn configuration of the residual ("pleth") model: a hash-grid
# encoding over 3 input dimensions followed by a small sine-activated MLP
# with 3 outputs.
pleth_encoding_config = dict(
    otype="HashGrid",
    input_dims=3,
    n_levels=8,
    n_features_per_level=2,
    log2_hashmap_size=24,
    base_resolution=16,
    per_level_scale=1.5,
)
pleth_network_config = dict(
    otype="CutlassMLP",
    activation="Sine",
    output_activation="none",
    n_neurons=64,
    n_hidden_layers=2,
    output_dims=3,
)

# Encoding and network are kept as separate modules so that each can have its
# own optimizer; the Sequential wrapper is what gets called during training.
pleth_enc = tcnn.Encoding(pleth_encoding_config["input_dims"], pleth_encoding_config)
pleth_net = tcnn.Network(pleth_enc.n_output_dims, pleth_network_config["output_dims"], pleth_network_config)
pleth_model = torch.nn.Sequential(pleth_enc, pleth_net)
pleth_model.to(args.pleth_device)

# Both optimizers share the learning rate, betas and eps; only the MLP
# optimizer applies weight decay.
lr = 1e-4
adam_betas = (args.opt["beta1"], args.opt["beta2"])
opt_enc = torch.optim.Adam(
    pleth_enc.parameters(),
    lr=lr,
    betas=adam_betas,
    eps=args.opt["eps"],
)
opt_net = torch.optim.Adam(
    pleth_net.parameters(),
    lr=lr,
    weight_decay=1e-6,
    betas=adam_betas,
    eps=args.opt["eps"],
)
def _as_frames(flat_output):
    """Detach a flat model output and return it as a float numpy array.

    The tensor is reshaped to ``dset.shape`` and axis 2 is moved to the front
    (presumably the frame axis, since the result is written as a video --
    TODO confirm against VideoGridDataset).
    """
    return flat_output.detach().cpu().float().reshape(dset.shape).permute(2, 0, 1, 3).numpy()


def _rescale_to_unit(frames):
    """Min-max rescale ``frames`` to [0, 1] over the first three axes.

    Min and max are taken separately for each index of the last axis. Where
    the range is zero the divisor is replaced by 1, so constant data maps to
    0 instead of producing NaN from 0/0.
    """
    lowest = np.amin(frames, axis=(0, 1, 2), keepdims=True)
    span = np.amax(frames, axis=(0, 1, 2), keepdims=True) - lowest
    span = np.where(span > 0, span, 1)
    return (frames - lowest) / span


def _save_trace(frames, prefix, epoch):
    """Clip ``frames`` to [0, 1], convert to uint8 and write them as an .avi.

    The file lands in ``args.trace["folder"]`` and is named from ``prefix``,
    the configured file tag and the zero-padded epoch number.
    """
    file_name = f'{prefix}{args.trace["file_tag"]}{str(epoch).zfill(ndigits_epoch)}.avi'
    frames_u8 = (np.clip(frames, 0, 1)*255).astype(np.uint8)
    iio2.mimwrite(os.path.join(args.trace["folder"], file_name), frames_u8, fps=30)


# NOTE(review): this overrides args.train["epochs"], while ndigits_epoch was
# derived from the configured value; the zero padding of the file names
# therefore follows the config, not this constant.
epochs = 10
for epoch in range(start_epoch,epochs+1):
    train_loss = 0
    pleth_model.train()
    motion_model.train()
    # Each batch is a tensor of sample indices into dset.loc / dset.vid.
    for item in tqdm(dloader, total=len(dloader)):
        loc = dset.loc[item].half().to(args.pleth_device)
        pixel = dset.vid[item].half().to(args.pleth_device)
        motion_output, _ = motion_model(loc)
        pleth_output = pleth_model(loc)
        # The prediction is the motion model output plus the learned residual.
        output = motion_output + pleth_output
        # Since the model takes care of moving the data to different devices, move GT correspondingly.
        pixel = pixel.to(output.dtype).to(output.device)
        # Backpropagation. Only the residual model is stepped; no optimizer
        # is attached to the motion model in this cell.
        opt_enc.zero_grad()
        opt_net.zero_grad()
        l2_error = (output - pixel)**2
        loss = l2_error.mean()
        loss.backward()
        opt_enc.step()
        opt_net.step()
        train_loss += loss.item()
    print(f'Epoch: {epoch}, Loss: {train_loss/len(dloader)}', flush=True)

    # After every epoch, render the full grid and save three videos.
    with torch.no_grad():
        motion_model.eval()
        pleth_model.eval()
        trace_loc = dset.loc.half().to(args.pleth_device)
        motion_output, _ = motion_model(trace_loc)
        pleth_output = pleth_model(trace_loc)

        # 1) Full reconstruction (motion + residual).
        _save_trace(_as_frames(motion_output + pleth_output), '', epoch)
        # 2) Residual alone, stretched to the full [0, 1] range for visibility.
        _save_trace(_rescale_to_unit(_as_frames(pleth_output)), 'rescaled_residual_', epoch)
        # 3) Motion model output alone.
        _save_trace(_as_frames(motion_output), 'motion_', epoch)

In [None]:
# Evaluate both models on every grid location once training has finished.
motion_model.eval()
pleth_model.eval()
# No gradients are needed here; running under no_grad (as the per-epoch
# tracing inside the training loop already does) avoids building an autograd
# graph for the whole video.
with torch.no_grad():
    trace_loc = dset.loc.half().to(args.pleth_device)
    motion_output, _ = motion_model(trace_loc)
    pleth_output = pleth_model(trace_loc)

In [None]:
# Bring both outputs to the CPU as float numpy arrays, reshaped to the dataset
# grid with axis 2 moved to the front. The same conversion is applied to each.
motion_trace, pleth_trace = (
    model_output.detach().cpu().float().reshape(dset.shape).permute(2, 0, 1, 3).numpy()
    for model_output in (motion_output, pleth_output)
)

In [None]:
# Freeze the motion output at its first frame and repeat it for the whole clip.
motion_frame = [motion_trace[0]]
# Repeat once per frame actually present instead of a hard-coded 300, so the
# later addition with pleth_trace still lines up for other clip lengths.
rep_motion_trace = np.stack(motion_frame*motion_trace.shape[0], axis=0)

In [None]:
# Static first motion frame plus the per-frame residual.
static_with_xyt = np.add(rep_motion_trace, pleth_trace)

In [None]:
# Clip to [0, 1], convert to 8-bit and write the result as a 30 fps video.
trace = np.clip(static_with_xyt, 0, 1)
trace = (trace*255).astype(np.uint8)
iio2.mimwrite('temp.avi', trace, fps=30)
trace.shape

In [None]:
# Min-max rescale the residual over the first three axes (separately for each
# index of the last axis) so the small residual becomes visible.
trace_min = np.amin(pleth_trace, axis=(0,1,2), keepdims=True)
trace_span = np.amax(pleth_trace, axis=(0,1,2), keepdims=True) - trace_min
# A constant channel has zero range; divide by 1 there so it maps to 0
# instead of producing NaN (0/0) that is then cast to uint8.
trace_span = np.where(trace_span > 0, trace_span, 1)
trace = (pleth_trace - trace_min) / trace_span
trace = (np.clip(trace, 0, 1)*255).astype(np.uint8)
iio2.mimwrite('pleth.avi', trace, fps=30)
trace.shape