In [1]:
import os
import re

import imageio
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
from tqdm import tqdm

from nerf_model import TinyNerfModel
from nerf_dataset import TinyCybertruckDataset
from nerf_functions import get_rays, render_rays, get_device

In [2]:
# Pick the compute device via the project helper (nerf_functions.get_device);
# judging by the cell output, it also prints which device was selected.
device = get_device()

Using device: NVIDIA GeForce RTX 4080 SUPER


In [3]:
timestamp = "" # Add timestamp here. Will use the latest trial if empty
output_dir = "output"


def _natural_key(name):
    """Sort key that compares digit runs numerically, so 'ckpt_900' < 'ckpt_1000'."""
    # re.split with a capturing group alternates text / digit-run, so odd positions are
    # always digit runs: ints are only ever compared with ints, strs with strs.
    return [int(part) if i % 2 else part for i, part in enumerate(re.split(r"(\d+)", name))]


def _latest_entry(directory, want_dir):
    """
    Return the name of the "latest" entry in `directory`.

    Hidden entries (leading '.') are skipped, and only sub-directories
    (want_dir=True) or regular files (want_dir=False) are considered.
    "Latest" is the maximum under natural ordering: identical to plain
    lexicographic order for zero-padded names, but also correct for
    unpadded step numbers (e.g. 'model_1000' after 'model_900').

    Raises:
        FileNotFoundError: if `directory` contains no matching entry.
    """
    is_kind = os.path.isdir if want_dir else os.path.isfile
    candidates = [
        name for name in os.listdir(directory)
        if not name.startswith(".") and is_kind(os.path.join(directory, name))
    ]
    if not candidates:
        kind = "trial directories" if want_dir else "checkpoint files"
        raise FileNotFoundError(f"No {kind} found in {directory!r}")
    return max(candidates, key=_natural_key)


if timestamp == "":
    timestamp = _latest_entry(output_dir, want_dir=True) # Fetching the latest trial

output_dir = os.path.join(output_dir, timestamp)
model_file_name = _latest_entry(output_dir, want_dir=False) # Fetching the latest checkpoint
# map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine.
# weights_only avoids unpickling arbitrary objects; the file is used as a plain state_dict below.
checkpoint_state_dict = torch.load(
    os.path.join(output_dir, model_file_name), map_location=device, weights_only=True
)

# load latest model
model = TinyNerfModel().to(device) # Make sure to use same embedding size as the model used for training
model.load_state_dict(checkpoint_state_dict)
model.eval()

TinyNerfModel(
  (block1): Sequential(
    (0): Linear(in_features=39, out_features=256, bias=True)
    (1): ReLU()
    (2): Linear(in_features=256, out_features=256, bias=True)
    (3): ReLU()
    (4): Linear(in_features=256, out_features=256, bias=True)
    (5): ReLU()
    (6): Linear(in_features=256, out_features=256, bias=True)
    (7): ReLU()
  )
  (block2): Sequential(
    (0): Linear(in_features=295, out_features=256, bias=True)
    (1): ReLU()
    (2): Linear(in_features=256, out_features=256, bias=True)
    (3): ReLU()
    (4): Linear(in_features=256, out_features=256, bias=True)
    (5): ReLU()
    (6): Linear(in_features=256, out_features=256, bias=True)
    (7): ReLU()
    (8): Linear(in_features=256, out_features=4, bias=True)
  )
)

In [4]:
def pose_spherical(theta, phi, radius):
    """
    Build the camera-to-world pose of a camera orbiting the origin.

    The camera is pushed `radius` units along +z, tilted about the x-axis by
    `phi`, swung about the y-axis by `theta`, and finally re-expressed in the
    world frame by a fixed axis remap (x -> -x, y <-> z). The resulting
    camera centre is
        radius * (sin(theta) cos(phi), cos(theta) cos(phi), -sin(phi)).

    Inputs:
        theta: float, azimuth in degrees; sweeps the camera around the world z-axis
        phi: float, tilt in degrees applied about the x-axis (an elevation-style
            angle, not a polar angle from the pole): 0 keeps the camera centre
            at world z = 0, negative values raise it to z > 0
        radius: float, distance from the origin
    Outputs:
        c2w: torch.tensor, 4x4 float32 camera-to-world matrix
    """
    def _mat(rows):
        # Every intermediate matrix is materialised in float32 before multiplying.
        return torch.tensor(rows, dtype=torch.float32)

    phi_rad = phi / 180. * np.pi
    theta_rad = theta / 180. * np.pi
    cos_p, sin_p = np.cos(phi_rad), np.sin(phi_rad)
    cos_t, sin_t = np.cos(theta_rad), np.sin(theta_rad)

    translate = _mat([
        [1, 0, 0, 0],
        [0, 1, 0, 0],
        [0, 0, 1, radius],
        [0, 0, 0, 1],
    ])
    tilt = _mat([
        [1, 0, 0, 0],
        [0, cos_p, -sin_p, 0],
        [0, sin_p, cos_p, 0],
        [0, 0, 0, 1],
    ])
    spin = _mat([
        [cos_t, 0, -sin_t, 0],
        [0, 1, 0, 0],
        [sin_t, 0, cos_t, 0],
        [0, 0, 0, 1],
    ])
    to_world = _mat([
        [-1, 0, 0, 0],
        [0, 0, 1, 0],
        [0, 1, 0, 0],
        [0, 0, 0, 1],
    ])

    # Applied right to left: translate, then tilt, then spin, then remap axes.
    return to_world @ (spin @ (tilt @ translate))

In [5]:
# Take image size and focal length from the first test view; in this notebook the
# image itself is only used for its shape, and the middle item is discarded.
# NOTE(review): assumes testimg is laid out (H, W, C) so shape[:2] == (H, W) —
# TODO confirm against TinyCybertruckDataset (a CHW tensor would give (C, H) here).
testimg, _, testfocal = TinyCybertruckDataset("test")[0]
H, W = testimg.shape[:2]

In [6]:
# Turntable render settings (orbit around the object, one frame per azimuth step).
N_FRAMES = 120          # frames per full 360-degree orbit
ELEVATION_DEG = -15.    # phi passed to pose_spherical; negative lifts the camera above z = 0
ORBIT_RADIUS = 5.       # camera distance from the origin
NEAR, FAR = 3., 7.      # ray sampling bounds passed to render_rays
N_SAMPLES = 64          # samples per ray
FPS = 30
VIDEO_PATH = 'tiny_nerf.mp4'

frames = []

# Pure inference: disable autograd so no computation graph is built and kept per frame.
with torch.no_grad():
    for th in tqdm(np.linspace(0., 360., N_FRAMES, endpoint=False)):
        c2w = pose_spherical(theta=th, phi=ELEVATION_DEG, radius=ORBIT_RADIUS).to(device)

        rays_o, rays_d = get_rays(H, W, testfocal, c2w, device=device)

        rgb = render_rays(model, rays_o, rays_d, near=NEAR, far=FAR, N_samples=N_SAMPLES, device=device)

        img = np.clip(rgb.detach().cpu().numpy(), 0, 1)
        # Round rather than truncate so e.g. 0.999 maps to 255, not 254.
        frames.append(np.rint(img * 255).astype(np.uint8))

imageio.mimwrite(VIDEO_PATH, frames, fps=FPS, quality=10)

100%|██████████| 120/120 [00:11<00:00, 10.04it/s]
