In [1]:
import torch
from tqdm import tqdm
from data_loader import DataLoader
from display_helper import display_image, create_video, save_image
from models.small_NeRF_model import SmallNerfModel
from nerf_forward_pass import NeRFManager
from positional_encoding import positional_encoding
from query_points import QueryPointSamplerFromRays
from ray_bundle import RaysFromCameraBuilder
from setup_utils import set_random_seeds, load_training_config_yaml, get_tensor_device
import os
import json
from PIL import Image
import numpy as np

In [2]:
# Directory of the Blender "lego" scene. Kept as a plain string ending in '/'
# because later cells build paths by simple string concatenation.
file_path = '/'.join(('nerf_example_data', 'nerf_synthetic', 'lego')) + '/'

In [3]:
# Load the camera metadata for the training split.
# `with` closes the file even if json.load raises on malformed input, and the
# encoding is explicit (JSON is UTF-8) instead of depending on the platform locale.
with open(file_path + 'transforms_train.json', encoding='utf-8') as f:
    train_data = json.load(f)
train_data

{'camera_angle_x': 0.6911112070083618,
 'frames': [{'file_path': './train/r_0',
   'rotation': 0.012566370614359171,
   'transform_matrix': [[-0.9999021887779236,
     0.004192245192825794,
     -0.013345719315111637,
     -0.05379832163453102],
    [-0.013988681137561798,
     -0.2996590733528137,
     0.95394366979599,
     3.845470428466797],
    [-4.656612873077393e-10,
     0.9540371894836426,
     0.29968830943107605,
     1.2080823183059692],
    [0.0, 0.0, 0.0, 1.0]]},
  {'file_path': './train/r_1',
   'rotation': 0.012566370614359171,
   'transform_matrix': [[-0.9305422306060791,
     0.11707554012537003,
     -0.34696459770202637,
     -1.398659110069275],
    [-0.3661845624446869,
     -0.29751041531562805,
     0.8817007541656494,
     3.5542497634887695],
    [7.450580596923828e-09,
     0.9475130438804626,
     0.3197172284126282,
     1.2888214588165283],
    [0.0, 0.0, 0.0, 1.0]]},
  {'file_path': './train/r_2',
   'rotation': 0.012566370614359171,
   'transform_matrix'

In [None]:
def get_focal_length(camera_angle_x, image_width):
    """Return the pinhole focal length, in pixels, for a synthetic-scene camera.

    The transforms_*.json files store a horizontal field-of-view angle
    ``camera_angle_x`` (e.g. ``train_data['camera_angle_x']``). For a pinhole
    camera whose image is ``image_width`` pixels wide, half the width subtends
    half the field of view, so ``tan(camera_angle_x / 2) = (image_width / 2) / focal``.

    Args:
        camera_angle_x: Horizontal field of view in radians; must lie in (0, pi).
        image_width: Image width in pixels; must be positive.

    Returns:
        The focal length in pixels as a Python float.

    Raises:
        ValueError: If ``camera_angle_x`` is outside (0, pi) or ``image_width`` <= 0.
    """
    # Outside (0, pi) the tangent is zero, negative or undefined, which would
    # yield an infinite or negative focal length.
    if not 0.0 < camera_angle_x < np.pi:
        raise ValueError(f"camera_angle_x must be in (0, pi) radians, got {camera_angle_x!r}")
    if image_width <= 0:
        raise ValueError(f"image_width must be positive, got {image_width!r}")
    return float(0.5 * image_width / np.tan(0.5 * camera_angle_x))

In [114]:
## Get per-frame camera transform matrices and images as numpy arrays

In [115]:
# Convert each frame's nested-list 'transform_matrix' into a numpy array,
# preserving the order of train_data['frames'].
main_data = train_data['frames']
numpy_transform_matrix_list = [np.asarray(frame['transform_matrix']) for frame in main_data]

In [None]:
# Load every image referenced in train_data['frames'] as a numpy array,
# in the same frame order used for numpy_transform_matrix_list.
main_data = train_data['frames']
numpy_image_list = []
for mini_dict in main_data:
    # 'file_path' entries are relative to the scene directory and omit the
    # '.png' extension (e.g. './train/r_0').
    filename = mini_dict['file_path']
    image_path = os.path.join(file_path, filename + '.png')
    # The context manager releases the file handle deterministically; np.array
    # copies the pixel data out before the image is closed.
    with Image.open(image_path) as image:
        image_array = np.array(image)
    numpy_image_list.append(image_array)

In [100]:
# List the top-level keys of the transforms file, one per line.
for top_level_key in train_data.keys():
    print(top_level_key)

camera_angle_x
frames


In [None]:
# Train the small NeRF model: seed RNGs, load the config and dataset, build the
# model, optimizer and helper classes, then run the optimization loop.
# NOTE(review): statement order matters — set_random_seeds() runs first and the
# later calls (model init, get_random_image_and_pose_example, forward passes)
# may consume RNG state, so reordering can change results.
set_random_seeds()
training_config = load_training_config_yaml()
device = get_tensor_device()
# NOTE(review): this trains on 'tiny_nerf_data.npz', not on the lego scene whose
# transforms/images are loaded in the cells above — confirm that is intended.
data_manager = DataLoader('tiny_nerf_data.npz', device=device)

# training parameters
lr = training_config['training_variables']['learning_rate']
num_iters = training_config['training_variables']['num_iters']
num_encoding_functions = training_config['positional_encoding']['num_encoding_functions']

# Misc parameters
display_every = training_config['display_variables']['display_every']

# Specify encoding function.
# Binds num_encoding_functions so callers only pass the input tensor.
encode = lambda x: positional_encoding(x, num_encoding_functions)

# Initialize model and optimizer
model = SmallNerfModel(num_encoding_functions).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

# Setup classes
query_sampler = QueryPointSamplerFromRays(training_config)
rays_from_camera_builder = RaysFromCameraBuilder(data_manager, device)
NeRF_manager = NeRFManager(encode, rays_from_camera_builder, query_sampler, model)

# PSNR values, appended once every `display_every` iterations.
psnrs = []
# NOTE(review): test_img / test_pose are never used below — the reported PSNR
# and displayed image come from the training view of each step. Either render
# this held-out pose for evaluation or drop the call (dropping it may shift RNG state).
test_img, test_pose = data_manager.get_random_image_and_pose_example()
for i in tqdm(range(num_iters)):

    # Ground-truth image and camera pose for this step, selected by iteration index.
    # NOTE(review): assumes get_image_and_pose copes with i >= number of images
    # (e.g. by wrapping) — confirm in DataLoader.
    target_img, target_tform_cam2world = data_manager.get_image_and_pose(i)

    # Predicted RGB for that pose; compared directly with target_img below, so
    # presumably the same shape — TODO confirm in NeRFManager.forward.
    rgb_predicted = NeRF_manager.forward(target_tform_cam2world)

    # Photometric MSE against the ground truth, followed by one Adam step.
    loss = torch.nn.functional.mse_loss(rgb_predicted, target_img)
    loss.backward()

    optimizer.step()
    optimizer.zero_grad()

    if i % display_every == 0:
        # PSNR = -10 * log10(MSE) of the current training step; the usual dB
        # interpretation assumes pixel values in [0, 1] — TODO confirm.
        psnr = -10. * torch.log10(loss)
        psnrs.append(psnr.item())

        print("Loss:", loss.item())
        display_image(i, display_every, psnrs, rgb_predicted)

    # On the final iteration, save the last prediction and render a video.
    if i == num_iters-1:
        save_image(display_every, psnrs, rgb_predicted)
        create_video(NeRF_manager, device)

print('Done!')

