In [None]:
import os
from tqdm import tqdm
import numpy as np
import imageio.v3 as iio3
import imageio.v2 as iio2
import commentjson as json
import matplotlib.pyplot as plt

import torch
import torch.utils.data
import tinycudann as tcnn

In [None]:
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
# The bare expression on the last line makes the notebook display the choice.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
device

In [None]:
# Experiment configuration, grouped into three sections:
#   optimizer - Adam hyper-parameters
#   encoding  - multi-resolution hash-grid settings (passed to tcnn.Encoding)
#   network   - fully-fused MLP description
config = dict(
    optimizer=dict(
        otype='Adam',
        learning_rate=0.01,
        beta1=0.9,
        beta2=0.99,
        epsilon=1e-15,
        l2_reg=1e-06,
    ),
    encoding=dict(
        otype='HashGrid',
        n_levels=32,
        n_features_per_level=2,
        log2_hashmap_size=20,
        base_resolution=16,
        per_level_scale=1.5,
    ),
    network=dict(
        otype='FullyFusedMLP',
        activation='ReLU',
        output_activation='None',
        n_neurons=64,
        n_hidden_layers=4,
    ),
)

In [None]:
class VidData(object):
	"""Per-pixel dataset over a video stored as numbered image files.

	The frames are stacked into one array laid out as (row, column, frame,
	channel). Each dataset item is a single (row, column, frame) location:
	its colour and its normalised coordinates.
	"""

	def __init__(self, vid_folder, num_frames=900, img_str="rgbd_rgb_", ext=".png", verbose=True):
		"""Load the frames and build the flat index of pixel locations.

		Args:
			vid_folder (str): Directory that contains the frame images.
			num_frames (int): Number of frames to load. Frame ``i`` is read
				from ``<vid_folder>/<img_str><i><ext>`` for every ``i`` in
				``range(num_frames)``.
			img_str (str): Filename prefix placed before the frame number.
			ext (str): Filename extension, including the leading dot.
			verbose (bool): Print shape information while loading.

		Raises:
			AssertionError: If the loaded video does not end up with exactly
				three values per (row, column, frame) location.
		"""
		self.verbose = verbose
		self.vid_folder = vid_folder
		self.img_str = img_str
		self.ext = ext
		self.num_frames = num_frames

		frames = [
			iio3.imread(os.path.join(self.vid_folder, f"{self.img_str}{frame_num}{self.ext}"))
			for frame_num in range(self.num_frames)
		]
		self.vid = np.stack(frames, axis=2) # R C T Ch (R C T for grayscale frames)
		if self.vid.ndim == 3:
			if self.verbose: print("Grayscale. Converting to RGB")
			# Replicate the single channel three times along a new last axis.
			self.vid = np.repeat(self.vid[..., np.newaxis], 3, axis=-1)
		# Record the shape only after the grayscale expansion, so that it always
		# describes the 4-D array that __getitem__ indexes.
		self.shape = self.vid.shape

		if self.verbose: print(f"Shape of the Video: {self.shape}")
		# With the default 'xy' indexing the three grids have shape (R, C, T);
		# X holds column indices, Y row indices and T frame indices.
		X, Y, T  = np.meshgrid(np.arange(self.shape[1]), np.arange(self.shape[0]), np.arange(self.shape[2]))
		if self.verbose: print(f"Grid Shape -> X: {X.shape}, Y: {Y.shape}, T: {T.shape}")
		# Row-major flattening: the frame index varies fastest, then column, then row.
		self.x = X.ravel()
		self.y = Y.ravel()
		self.t = T.ravel()
		self.num_pixels = len(self.x)

		# x, y and t come from the same meshgrid, so one length check covers all three.
		assert self.num_pixels*3 == np.prod(self.shape), (
			f"expected 3 channels per pixel, got video shape {self.shape}")

	def __len__(self):
		"""Return the number of (row, column, frame) locations in the video."""
		return self.num_pixels

	def __getitem__(self, idx):
		"""Return the colour and normalised location of flat pixel index ``idx``.

		Returns:
			dict: ``"pixel"`` is the 3-channel value stored in the video at that
			location; ``"loc"`` is ``np.array([x, y, t])`` where each index is
			divided by the size of its axis and shifted by -0.5, i.e. every
			coordinate lies in [-0.5, 0.5).
		"""
		# NOTE(review): hash-grid encoders are usually fed coordinates in [0, 1];
		# confirm that the [-0.5, 0.5) range is what the downstream encoder expects.
		item_x = (self.x[idx] / self.shape[1]) - 0.5
		item_y = (self.y[idx] / self.shape[0]) - 0.5
		item_t = (self.t[idx] / self.shape[2]) - 0.5
		item_loc = np.array([item_x,item_y,item_t])
		item_rgb = self.vid[self.y[idx],self.x[idx],self.t[idx]]
		return {"pixel": item_rgb, "loc": item_loc}

In [None]:
def trace_video(model, dataset, dloader, device, save_path=None):
    """Render the full video with ``model``, preview up to five frames, optionally save it.

    Args:
        model: Module whose call returns a sequence; its first element is the
            predicted 3-channel value for every location in the batch.
        dataset: Object exposing ``shape`` (rows, columns, frames, ...) and
            ``num_frames``; used to reshape the flat predictions.
        dloader: Iterable of batches holding a ``"loc"`` tensor. It must visit
            every dataset item exactly once and in dataset order (no
            shuffling), otherwise the reshape scrambles the video.
        device: Device the input batches are moved to.
        save_path: If not None, the clipped video is written there at 30 fps.

    Note:
        The model is switched to eval mode and left in that mode.
    """
    model.eval()
    print("Eval Mode")
    chunks = []
    with torch.no_grad():
        for item in tqdm(dloader, leave=False):
            inp = item["loc"].half().to(device)
            # reshape(-1, 3) instead of squeeze(): a final batch holding a single
            # location must stay 2-D so that the concatenation below works.
            pred = model(inp)[0].reshape(-1, 3)
            chunks.append(pred.cpu().detach().float().numpy())

    # Use the dataset that was passed in rather than a notebook-level global.
    rows, cols, frames = dataset.shape[:3]
    video = np.concatenate(chunks, axis=0).reshape(rows, cols, frames, 3)
    video = np.clip(video, a_min=0, a_max=1)

    # Evenly spaced preview frames; the step is clamped to 1 so that videos
    # with fewer frames than panels do not produce a zero step.
    n_panels = 5
    step = max(1, dataset.num_frames // n_panels)
    plot_frame_num = np.arange(0, dataset.num_frames, step)[:n_panels]

    f, axarr = plt.subplots(1, n_panels, figsize=(20, 20))
    for panel, ax in enumerate(axarr):
        if panel < len(plot_frame_num):
            ax.imshow(video[:, :, plot_frame_num[panel]])
            ax.set_title(f"Frame: {plot_frame_num[panel]}")
        # Panels without a frame are left blank.
        ax.axis('off')
    plt.show()
    # Release the figure; this function is called once per epoch.
    plt.close(f)

    if save_path is not None:
        # Writer expects frame-major layout: (T, R, C, Ch), 8-bit.
        out_frames = np.transpose(video*255, (2,0,1,3)).astype(np.uint8)
        iio2.mimwrite(save_path, out_frames, fps=30)

In [None]:
class LinearSineAct(torch.nn.Module):
    """Three linear layers with a sine applied after the first two.

    The last layer is purely linear. ``outermost_linear``, ``first_omega_0``
    and ``hidden_omega_0`` are accepted by the constructor but are not used
    by this implementation.
    """

    def __init__(self, in_features, hidden_features, out_features, 
                 outermost_linear=True, first_omega_0=30, hidden_omega_0=30.):
        """Create the layers.

        Args:
            in_features: Width of the input vectors.
            hidden_features: Width of both hidden layers.
            out_features: Width of the output vectors.
            outermost_linear: Unused.
            first_omega_0: Unused.
            hidden_omega_0: Unused.
        """
        super().__init__()

        # in -> hidden -> hidden -> out
        self.hidden1 = torch.nn.Linear(in_features, hidden_features)
        self.hidden2 = torch.nn.Linear(hidden_features, hidden_features)
        self.final = torch.nn.Linear(hidden_features, out_features)

    def forward(self, coords):
        """Run the network on ``coords`` (cast to float32 first).

        Returns:
            tuple: ``(out, activations)`` where ``out`` is the output of the
            final linear layer and ``activations`` maps ``"sine1out"`` and
            ``"sine2out"`` to the outputs of the two sine stages.
        """
        activations = {}

        features = self.hidden1(coords.float())
        activations["sine1out"] = torch.sin(features)

        features = self.hidden2(activations["sine1out"])
        activations["sine2out"] = torch.sin(features)

        return self.final(activations["sine2out"]), activations

In [None]:
# Load the first 300 frames found in the "v_84_2" folder.
dset = VidData(vid_folder="v_84_2", num_frames=300)

# Shuffled batches for optimisation.
train_dataloader = torch.utils.data.DataLoader(
    dataset=dset,
    batch_size=8192,
    shuffle=True,
)
# Same data in dataset order, so predictions can be reshaped back into a video.
test_dataloader = torch.utils.data.DataLoader(
    dataset=dset,
    batch_size=8192,
    shuffle=False,
)

In [None]:
# Model: hash-grid encoding of the 3-D (x, y, t) location followed by a small
# sine-activated MLP that outputs the 3-channel colour.
encoding = tcnn.Encoding(3, config["encoding"])
network = LinearSineAct(encoding.n_output_dims, 64, 3)
model = torch.nn.Sequential(encoding, network)
model.to(device)
print(model)


# Optimizer hyper-parameters are read from config["optimizer"] rather than
# repeated here as literals (the configured values are the ones that were
# previously hard-coded). Weight decay is applied to the MLP only, not to the
# encoding parameters.
opt_cfg = config["optimizer"]
adam_kwargs = dict(
    lr=opt_cfg["learning_rate"],
    betas=(opt_cfg["beta1"], opt_cfg["beta2"]),
    eps=opt_cfg["epsilon"],
)
opt_enc = torch.optim.Adam(encoding.parameters(), **adam_kwargs)
opt_net = torch.optim.Adam(network.parameters(), weight_decay=opt_cfg["l2_reg"], **adam_kwargs)
save_dir = "lin_sine"
os.makedirs(save_dir, exist_ok=True)

epochs = 30
for epoch in range(1,epochs+1):
    train_loss = 0
    # trace_video() leaves the model in eval mode, so re-enable train mode every epoch.
    model.train()
    for item in tqdm(train_dataloader, total=len(train_dataloader)):
        loc = item["loc"].half().to(device)
        # Normalise the target to [0, 1] in float32; dividing in half precision
        # first would round the target before it is compared with the output.
        pixel = item["pixel"].to(device).float()/255
        output, _ = model(loc)
        opt_enc.zero_grad()
        opt_net.zero_grad()
        # Plain squared error. A relative variant would divide by
        # (output.detach()**2 + 0.01); it is intentionally not applied.
        squared_error = (output - pixel.to(output.dtype))**2
        loss = squared_error.mean()
        loss.backward()
        opt_enc.step()
        opt_net.step()
        train_loss += loss.item()
    print(f"Epoch: {epoch}, Loss: {train_loss/len(train_dataloader)}")
    # Render the current reconstruction and save it as one video per epoch.
    trace_video(model, dset, test_dataloader, device, save_path=f"./{save_dir}/epoch_{epoch:02d}.avi")