In [None]:
import jittor as jt
from jittor import nn
import numpy as np
import pandas as pd
from PIL import Image
import cv2 as cv
import os
import matplotlib.pyplot as plt

from dataloaders import NeRFDataset
from raymarching import NeRF_RayGen
from encoders import PositionalEncoder
from networks import NeRF_Net
from renders import NeRF_Render
from utils import Losser

In [None]:
# Training hyper-parameters.
lr = 5e-4        # learning rate passed to the Adam optimizer
N_iters = 100    # number of optimisation steps run by the training loop
batchsize = 2    # items per batch; the preview code indexes items 0 and 1, so keep >= 2
# Preview/log period in iterations. Integer division keeps `i % show_iter == 0`
# exact when N_iters is not a multiple of 10, and the clamp to 1 keeps the
# period valid when N_iters < 10.
show_iter = max(1, N_iters // 10)


In [None]:
# --- Data -------------------------------------------------------------------
# Dataset wrapper over the bundled .npz archive; batches are pulled from it
# later with next().
dataloader = NeRFDataset(data_type='npz', root_dir='./data/tiny_nerf_data.npz', batch_size=batchsize)

# Presumably image height, width and focal length -- confirm against
# NeRFDataset.get_para.
H, W, focal = dataloader.get_para()

# --- Pipeline components ----------------------------------------------------
rays_gen = NeRF_RayGen(H, W, focal)
encoder = PositionalEncoder()
model = NeRF_Net()

# NOTE(review): (2, 6) and 100 are passed positionally; they look like the
# near/far sampling bounds and the number of samples per ray -- confirm
# against the NeRF_Render signature.
render = NeRF_Render(model, (2, 6), dataloader.batch_size, 100, encoder)

# --- Optimisation -----------------------------------------------------------
optimizer = nn.Adam(model.parameters(), lr)
losser = Losser()

In [None]:
def train_one_step():
    """Run a single optimisation step on the next batch from ``dataloader``.

    Pulls one batch, generates rays from its poses, renders each batch item
    separately, and applies one optimizer update on the mean squared error
    between the rendered output and the batch images.

    Returns:
        tuple: ``(loss, rgbs, imgs)`` -- the scalar MSE loss, the rendered
        output concatenated over the batch, and the images returned by the
        loader for this batch.
    """
    # The first element returned by the loader is not used here.
    _, imgs, poses = next(dataloader)

    rays_o, rays_d = rays_gen.get_rays(poses)

    # Render one item at a time: split both ray tensors into size-1 chunks
    # along dim 0 and pair them up. Iterating over the chunks themselves
    # (instead of range(batchsize)) keeps this correct even if a batch holds
    # a different number of items than the global `batchsize`.
    origin_chunks = jt.split(rays_o, 1)
    direction_chunks = jt.split(rays_d, 1)
    rgbs = jt.concat([
        render.rendering(origins, directions)
        for origins, directions in zip(origin_chunks, direction_chunks)
    ])

    loss = jt.mean(jt.sqr(rgbs - imgs))
    # Jittor optimizers take the loss directly: step(loss) performs the
    # backward pass and the parameter update.
    optimizer.step(loss)
    return loss, rgbs, imgs

In [None]:
# Training loop: optimise for N_iters steps and, every `show_iter` steps,
# show the first two rendered outputs (top row) above the corresponding
# batch images (bottom row), followed by the current loss and PSNR.
for i in range(N_iters):
    loss, rgbs, imgs = train_one_step()
    if i % show_iter != 0:
        continue

    # Use a fresh figure per preview and show it immediately; otherwise every
    # preview is drawn onto the same axes and only one figure appears once
    # the cell finishes.
    fig, axes = plt.subplots(2, 2)
    for col in range(2):  # requires at least two items per batch
        axes[0, col].imshow(rgbs[col].numpy())
        axes[0, col].set_title(f"rendered [{col}]")
        axes[1, col].imshow(imgs[col].numpy())
        axes[1, col].set_title(f"target [{col}]")
    plt.show()
    plt.close(fig)  # release the figure so previews do not accumulate in memory

    print(f"\nLoss = {round(float(loss.numpy()),4)}, PSNR = {round(float(losser.mse2psnr(loss)),2)}")
        