# Setup Code

In [None]:
# Enable IPython's autoreload extension so that edits to imported project
# modules are picked up live, without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
# imports
import torch
import torch.nn.functional as F
import numpy as np
import sys
import os
from matplotlib import pyplot as plt

In [None]:
# select devices
# Use the GPU when CUDA is usable, otherwise fall back to the CPU; the
# printed message tells the reader which branch was taken.
has_cuda = torch.cuda.is_available()
print("Good to go!" if has_cuda else "Bad to go!")
DEVICE = torch.device("cuda" if has_cuda else "cpu")

Load the scene configuration from the `configs` package (choose between `ship` and `lego`).

In [None]:
# load config
# Make the current working directory importable so the project-local
# `configs` package resolves. The membership check keeps this cell
# idempotent: re-running it no longer stacks duplicate entries on sys.path.
cwd = os.getcwd()
if cwd not in sys.path:
    sys.path.append(cwd)
# choose between ship, lego
import configs.ship, configs.lego
# Forwarded as `sample_t` to the ray-sampling / rendering helpers in later
# cells -- presumably the (near, far) range along each ray; TODO confirm.
sample_t: tuple = (2,6)
# Forwarded to load_blender as `scale_factor` -- presumably an image
# down-scaling factor; TODO confirm against nerf.data.
scale_factor = 3
# change config file here
config = configs.lego

Load the dataset, then show the first image together with its camera pose.

In [None]:
from nerf.data import load_blender

# Load images, poses and `int_mat` (presumably the camera intrinsics matrix --
# TODO confirm) onto DEVICE, then record the dataset dimensions.
imgs, poses, int_mat = load_blender(
    config.datadir,
    device=DEVICE,
    scale_factor=scale_factor,
)
img_n, img_h, img_w = imgs.shape[:3]

# Move the first sample to the CPU as NumPy arrays for display / printing.
demo_img = np.array(imgs[0].to(device="cpu"))
demo_pose = np.array(poses[0].to(device="cpu"))

# visualize
plt.imshow(demo_img)
plt.axis("off")
plt.title("demo image")
plt.show()
print("and its pose: ")
print(demo_pose)

# Test Functions

In [None]:
# compute rays
from nerf.graphics import compute_rays

# Build the rays for the first camera pose and spot-check a few entries.
# Both outputs are indexed below by (row, col).
image_hw = (img_h, img_w)
rays_o, rays_d = compute_rays(image_hw, int_mat, poses[0])

corner_origin = rays_o[0, 0]
mid_row, mid_col = img_h // 2, img_w // 2
print("origin: ", corner_origin)
print("normalized origin: ", F.normalize(corner_origin, dim=0))
print("center of ray: ", rays_d[mid_row, mid_col])

In [None]:
# query from rays
from nerf.graphics import queries_from_rays

# Drop any result left over from a previous run of this cell before sampling.
samples = None
# Fourth positional argument of queries_from_rays -- presumably the number of
# samples taken along each ray; TODO confirm against nerf.graphics.
n_per_ray = 8
samples, depths = queries_from_rays(rays_o, rays_d, sample_t, n_per_ray)
print("samples[0, 0]: ", samples[0, 0])
print("depths: ", depths)

In [None]:
# test pos encode

# Import only the helper under test. The previous `import *` pulled every
# public name of nerf_helper into the notebook namespace, where it could
# silently shadow existing variables and hid where names came from.
from nerf.nerf_helper import PosEncode

# Encode a single 3-D point with L = 6 and print the result for inspection.
# The third positional argument (True) is passed through unchanged --
# presumably an "include the raw input" flag; TODO confirm in nerf_helper.
L = 6
x = torch.tensor([[ 1.8013, -0.6242,  0.7009]])
enc_x = PosEncode(x, L, True)
print(enc_x)




In [None]:
# test render from nerf
from nerf.graphics import render_from_nerf

# Reinterpret the first image's 4 channels as a stand-in network output with
# a size-1 axis inserted before the channels (presumably the per-ray sample
# axis -- TODO confirm), paired with a single dummy depth value of 1.
fake_nerf_output = imgs[0].cpu().reshape(img_h, img_w, 1, 4)
fake_depth = torch.ones(1)

rgb, depth = render_from_nerf(fake_nerf_output, fake_depth)
plt.imshow(rgb)
plt.show()

# Train

In [None]:
# One iteration of TinyNeRF (forward pass).
# TODO train
# raise Exception("nothing wrong")


In [None]:
from nerf.model import NeRF, TinyNeRF
from nerf.nerf_helper import nerf_iter_once, tinynerf_iter_once
import os.path

###### parameters
L_pos = 10
L_dir = 4

num_samples = 32
batch_size = 512   # increase batchsize if u have large GPU MEM
fc_width = 64
fc_depth = 4
skips = [2]

lr = 5e-3
# betas=(0.9, 0.999)
num_it = 10000
display_every = 200
ckpt_path = 'nerf.pt'

###### reproducibility
# Seed BEFORE the model is built so that the random weight initialisation is
# reproducible as well as the sequence of sampled training images.
# NOTE: RNG state is not stored in the checkpoint, so a resumed run restarts
# the image-sampling sequence from this seed.
seed = 9458
torch.manual_seed(seed)
np.random.seed(seed)

###### models
# Input widths follow the 6*L+3 convention for the position / direction inputs.
model = NeRF(ch_in_pos=6*L_pos+3, ch_in_dir=6*L_dir+3, fc_width=fc_width, fc_depth=fc_depth, skips=skips)
# model = TinyNeRF(6*L_pos+3, fc_width=128)
model.to(DEVICE)
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=lr,
    # betas=betas,
)


def _render_view(c2w):
    """Run `nerf_iter_once` with the current model for camera pose `c2w`.

    Uses the notebook-level image size, `int_mat`, `sample_t`, encoding
    levels, `num_samples` and `batch_size`. `nerf_iter_once` returns a pair;
    only the first element (the predicted RGB image) is returned here.
    """
    pred, _ = nerf_iter_once(
        model,
        (img_h, img_w),
        int_mat,
        c2w,
        sample_t,
        L_pos,
        L_dir,
        num_samples=num_samples,
        batch_size=batch_size,
    )
    return pred


###### validation view
# NOTE(review): this view comes from the same `imgs` tensor the training loop
# samples from, so the PSNR curve is not a held-out metric. A separate split
# could be loaded instead, e.g.:
# imgs_val, poses_val, int_mat_val = load_blender(config.datadir, data_type="val",scale_factor=2, device=DEVICE)
val_idx = 1
val_img = imgs[val_idx].to(DEVICE)
val_c2w = poses[val_idx]

###### training state
psnrs = []
its = []
start_it = 0

###### check saved checkpoints
if os.path.exists(ckpt_path):
    print("checkpoint found! Loading...")
    # map_location lets a checkpoint written on a GPU machine be restored on
    # whatever DEVICE is selected now (including a CPU-only machine).
    checkpoint = torch.load(ckpt_path, map_location=DEVICE)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    i = checkpoint['epoch']
    loss = checkpoint['loss']
    psnrs = checkpoint['psnrs']
    its = checkpoint['its']
    # 'epoch' is the last iteration that already ran (step + validation), so
    # resume on the following one; restarting at `i` would repeat the step and
    # append a duplicate PSNR / iteration entry.
    start_it = i + 1
    print("checkpoint loaded, i =",i)
else:
    print("No checkpoint found")

###### train
for i in range(start_it, num_it):
    # Pick a random training view from the whole loaded set (was a
    # hard-coded 100, which breaks for datasets of any other size).
    gt_img_idx = np.random.randint(img_n)
    gt_img = imgs[gt_img_idx].to(DEVICE)
    gt_c2w = poses[gt_img_idx]

    pred_rgb = _render_view(gt_c2w)
    optimizer.zero_grad()
    # Only the first three channels of the ground-truth image are compared.
    loss = torch.nn.functional.mse_loss(pred_rgb, gt_img[...,:3])

    # print("train_it:", i, "img_idx: ", gt_img_idx, "loss:",float(loss))
    loss.backward()
    optimizer.step()

    # Validate / checkpoint periodically, and also on the very last iteration
    # so the final weights are never lost.
    if i % display_every == 0 or i == num_it - 1:
        # No gradients are needed for evaluation; skipping the autograd graph
        # saves memory, and separate names keep the training `loss` and
        # `pred_rgb` from being overwritten.
        with torch.no_grad():
            val_rgb = _render_view(val_c2w)
            val_loss = torch.nn.functional.mse_loss(val_rgb, val_img[...,:3])
        print("Iteration ", i)
        print("Val loss: ", val_loss.item())

        # PSNR from MSE; assumes pixel values lie in [0, 1] -- TODO confirm.
        psnr = -10. * torch.log10(val_loss)
        psnrs.append(psnr.item())
        its.append(i)

        fig, (ax_img, ax_psnr) = plt.subplots(1, 2, figsize=(10, 4))
        ax_img.imshow(val_rgb.detach().cpu().numpy())
        ax_img.set_title(f"Iteration {i}")
        ax_psnr.plot(its, psnrs)
        ax_psnr.set_title("PSNR")
        plt.show()
        plt.close(fig)  # avoid accumulating open figures over a long run

        # Same checkpoint schema as before: 'epoch' is the iteration just
        # completed and 'loss' is the validation loss tensor.
        torch.save({
            'epoch': i,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': val_loss,
            'psnrs': psnrs,
            'its': its
            }, ckpt_path)