In [None]:
# Enable IPython's autoreload extension in mode 2, so imported modules are
# re-imported before each cell executes and edits to the local project
# packages are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import os
from tqdm import tqdm
import numpy as np
import commentjson as json
import imageio.v2 as iio2
import matplotlib.pyplot as plt

import torch
import torch.utils.data
import tinycudann as tcnn
import argparse

from implicitpleth.models.siren import Siren
from implicitpleth.models.combinations import MotionNet
from implicitpleth.data.datasets import VideoGridDataset
from implicitpleth.utils.utils import Dict2Class

In [None]:
# Read the experiment configuration (parsed with commentjson, imported as `json`).
with open("./configs/delta_motion.json") as config_file:
    json_config = json.load(config_file)

# Notebook-specific overrides applied on top of the file contents.
json_config.update({
    "video_path": "./assets/v_84_2.avi",
    "verbose": True,
    "append_save_path": None,
    "append_load_path": None,
})

# Wrap the dict so entries are reachable as attributes (args.<key>).
args = Dict2Class(json_config)

# Replace the configured device specifiers with torch.device objects.
for device_field in ("spatiotemporal_device", "deltaspatial_device"):
    setattr(args, device_field, torch.device(getattr(args, device_field)))

In [None]:
# Build the motion model from the encoding/network configs of its two branches
# (spatiotemporal and delta-spatial), as named in the configuration.
model = MotionNet(
    args.spatiotemporal_encoding,
    args.spatiotemporal_network,
    args.deltaspatial_encoding,
    args.deltaspatial_network,
)
# Hand the model the device configured for each branch.
model.set_device(
    args.spatiotemporal_device,
    args.deltaspatial_device,
)

In [None]:
# Adam optimizer over all model parameters; hyperparameters come from the "opt" config section.
opt_cfg = args.opt
opt = torch.optim.Adam(
    model.parameters(),
    lr=opt_cfg["lr"],
    betas=(opt_cfg["beta1"], opt_cfg["beta2"]),
    eps=opt_cfg["eps"],
)

In [None]:
# Prepare the output locations for rendered traces and checkpoints.
#
# When a save suffix is configured it is appended to both output locations.
# NOTE: this mutates `args` in place, so re-running the cell appends the suffix again.
if args.append_save_path is not None:
    # The trace folder may be disabled (None). Only suffix it when it is set;
    # otherwise `None + str` raises a TypeError before the None check below runs.
    if args.trace["folder"] is not None:
        args.trace["folder"] = args.trace["folder"] + args.append_save_path
    args.checkpoints["dir"] = args.checkpoints["dir"] + args.append_save_path

# Create the trace folder only when trace output is enabled.
if args.trace["folder"] is not None:
    os.makedirs(args.trace["folder"], exist_ok=True)
    if args.verbose:
        print(f'Saving trace to {args.trace["folder"]}')

# Create the checkpoint directory only when checkpoint saving is enabled.
if args.checkpoints["save"]:
    os.makedirs(args.checkpoints["dir"], exist_ok=True)
    if args.verbose:
        print(f'Saving checkpoints to {args.checkpoints["dir"]}')

In [None]:
# Decide where training starts: resume from the latest checkpoint if one exists,
# otherwise begin at epoch 1.
epochs = args.train["epochs"]
# Digit count of the final epoch number; used to zero-pad epoch numbers in output file names.
ndigits_epoch = int(np.log10(epochs) + 1)
latest_ckpt_path = os.path.join(
    args.checkpoints["dir"],
    args.checkpoints["latest"],
)

if not os.path.exists(latest_ckpt_path):
    if args.verbose:
        print('Start from scratch.')
    start_epoch = 1
else:
    if args.verbose:
        print('Loading latest checkpoint...')
    saved_dict = torch.load(latest_ckpt_path)
    model.load_state_dict(saved_dict["model_state_dict"])
    # Optimizer state is optional in the checkpoint; restore it only when present.
    if "optimizer_state_dict" in saved_dict:
        opt.load_state_dict(saved_dict["optimizer_state_dict"])
    # The checkpoint stores the last stored epoch number; continue with the next one.
    start_epoch = saved_dict["epoch"] + 1
    if args.verbose:
        print(f'Continuing from epoch {start_epoch}.')

In [None]:
# Dense query grid used to render the full video volume after each epoch.
# Extents are hard-coded as (x, y, t) = (128, 128, 300) index counts.
grid_extents = (128, 128, 300)

# Build the integer index grids, flatten each one, and map indices to [-0.5, 0.5)
# by dividing by the axis extent and shifting by one half.
X, Y, T = (
    index_grid.ravel() / extent - 0.5
    for index_grid, extent in zip(np.meshgrid(*map(np.arange, grid_extents)), grid_extents)
)

# One row per grid point, columns ordered (x, y, t).
trace_loc = torch.tensor(np.stack([X, Y, T], axis=-1))

In [None]:
# Load the video through the project's grid dataset, using the "data" config section.
data_cfg = args.data
dset = VideoGridDataset(
    args.video_path,
    verbose=args.verbose,
    num_frames=data_cfg["num_frames"],
    start_frame=data_cfg["start_frame"],
    pixel_norm=data_cfg["norm_value"],
)

# The loader shuffles and batches sample *indices* (0 .. len(dset)-1) rather than
# samples; the training loop uses each index batch to index dset.loc and dset.vid.
sample_indices = range(len(dset))
dloader = torch.utils.data.DataLoader(
    sample_indices,
    batch_size=data_cfg["batch_size"],
    shuffle=True,
)

In [None]:
# Training loop. Each epoch:
#   1. fits the model to shuffled batches of (location, pixel) samples with an L2 loss,
#   2. renders the model on the dense trace grid and writes it as a video
#      (only when a trace folder is configured),
#   3. writes the "latest" checkpoint (only when checkpoint saving is enabled),
#      so the resume logic that reads latest_ckpt_path has something to load.
for epoch in range(start_epoch, epochs + 1):
    train_loss = 0
    model.train()
    # Each `item` is a batch of sample indices produced by the index DataLoader.
    for item in tqdm(dloader, total=len(dloader)):
        loc = dset.loc[item].half()
        pixel = dset.vid[item].half()
        output, _ = model(loc)
        # Since the model takes care of moving the data to different devices, move GT correspondingly.
        pixel = pixel.to(output.dtype).to(output.device)
        # Backpropagation.
        opt.zero_grad()
        l2_error = (output - pixel)**2
        loss = l2_error.mean()
        loss.backward()
        opt.step()
        train_loss += loss.item()
    # Report the mean per-batch loss for this epoch.
    print(f'Epoch: {epoch}, Loss: {train_loss/len(dloader)}', flush=True)

    # The trace folder may be None (trace output disabled); skip rendering in
    # that case instead of failing inside os.path.join.
    if args.trace["folder"] is not None:
        with torch.no_grad():
            trace, _ = model(trace_loc)
            # NOTE(review): trace_loc is built on a fixed 128x128x300 grid, so this
            # reshape presumably requires dset.shape to describe the same number of
            # samples - confirm against VideoGridDataset. The permute moves the third
            # axis to the front before writing frames.
            trace = trace.detach().cpu().float().reshape(dset.shape).permute(2,0,1,3).numpy()
            # Clamp to [0, 1] and convert to 8-bit for video encoding.
            trace = (np.clip(trace, 0, 1)*255).astype(np.uint8)
            save_path = os.path.join(args.trace["folder"], f'{args.trace["file_tag"]}{str(epoch).zfill(ndigits_epoch)}.avi')
            iio2.mimwrite(save_path, trace, fps=30)

    # Persist the state needed to resume. The keys match what the resume cell
    # reads back: "epoch", "model_state_dict" and "optimizer_state_dict".
    if args.checkpoints["save"]:
        torch.save(
            {
                "epoch": epoch,
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": opt.state_dict(),
            },
            latest_ckpt_path,
        )