In [None]:
import torch
import torch.nn as nn
import numpy as np

import os
import imageio
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader

from dataset import get_rays
from model import Nerf
from ml_helpers import training
from ml_helpers import testing
# from visualization import get_c2w_poses
# from visualization import visualize_camera_poses
# from visualization import visualize_rays

# Camera poses and ray dataset

In [None]:
# poses = get_c2w_poses(datapath='fox', mode='train')

# visualize_camera_poses(poses)

In [None]:
batch_size = 1024
height = 400
width = 400


def _stack_rays(origins, directions, target_values, rows=slice(None), cols=slice(None)):
    """Pack ray origins, directions and target pixel values into one float tensor.

    Each numpy input is viewed as (-1, height, width, 3) — the leading axis
    presumably indexes the camera views — so that an optional pixel crop
    (`rows`, `cols`) can be applied to every view before flattening.

    Returns a (num_rays, 9) float32 tensor: columns 0:3 are origins, 3:6 are
    directions and 6:9 are target pixel values.
    """
    def flatten(array):
        per_view = torch.from_numpy(array).reshape(-1, height, width, 3)
        return per_view[:, rows, cols, :].reshape(-1, 3).type(torch.float)

    return torch.cat((flatten(origins), flatten(directions), flatten(target_values)), dim=1)


o, d, target_px_values = get_rays('fox', mode='train')

# Full training set: every pixel of every training view.
dataloader = DataLoader(_stack_rays(o, d, target_px_values), batch_size=batch_size, shuffle=True)

# Warm-up set: only the central half of each view (rows/cols 100:300 for a 400x400
# image). Derived from height/width instead of a hard-coded 90x400x400 layout.
warmup_rows = slice(height // 4, 3 * height // 4)
warmup_cols = slice(width // 4, 3 * width // 4)
dataloader_warmup = DataLoader(_stack_rays(o, d, target_px_values, warmup_rows, warmup_cols),
                               batch_size=batch_size, shuffle=True)

test_o, test_d, test_target_px_values = get_rays('fox', mode='test')

In [None]:
#visualize_rays(origins=o, directions=-d, num_rays_to_sample_per_set=1)

# Training — set `pth_file` below to a new checkpoint name before each run

In [None]:
# Checkpoint path written by `training` and read back in the testing section.
pth_file = 'nerf_models/fox_a6.pth'
# Make sure the checkpoint directory exists so saving the model cannot fail on a fresh checkout.
os.makedirs(os.path.dirname(pth_file) or '.', exist_ok=True)

In [None]:
# Use the GPU when present; otherwise fall back to CPU instead of failing at
# `.to('cuda')` (training will be much slower on CPU).
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Seed torch's global RNG so weight initialisation and DataLoader shuffling repeat across runs.
torch.manual_seed(0)

tn = 8.          # near bound passed to `training` (presumably ray-sampling start distance)
tf = 12.         # far bound passed to `training` (presumably ray-sampling end distance)
nb_epochs = 10   # epochs for the full-image run; the warm-up below is a single epoch
lr = 1e-3
gamma = .5       # LR decay factor MultiStepLR applies at each milestone
nb_bins = 100    # passed to `training`; presumably samples per ray — confirm in ml_helpers


def _plot_loss(losses, title):
    """Plot the loss history returned by `training` in its own labelled figure."""
    fig, ax = plt.subplots()
    ax.plot(losses)
    ax.set(title=title, xlabel='step', ylabel='loss')
    plt.show()


model = Nerf(hidden_dim=128).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
# NOTE(review): milestones count scheduler.step() calls made inside `training`; if that is
# once per epoch, 1 warm-up + nb_epochs full epochs only reach milestone 10 on the final
# step — confirm the intended schedule.
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[5, 10], gamma=gamma)

# Warm-up: one epoch on the centre-cropped rays, then the full run on every pixel.
# Both calls receive model_name=pth_file; presumably the second run's checkpoint
# replaces the warm-up one.
warmup_loss = training(model, optimizer, scheduler, tn, tf, nb_bins, 1, dataloader_warmup, model_name=pth_file, device=device)
_plot_loss(warmup_loss, 'Warm-up loss (centre crop, 1 epoch)')

training_loss = training(model, optimizer, scheduler, tn, tf, nb_bins, nb_epochs, dataloader, model_name=pth_file, device=device)
_plot_loss(training_loss, f'Training loss ({nb_epochs} epochs, full images)')

# Testing

In [None]:
# Evaluate on CPU. The device is chosen before loading so that map_location can remap
# tensors saved from the CUDA training run: the model then sits on the same device as
# the input rays, and the cell also works on machines without a GPU.
device = 'cpu'
test_img_idx = 1

# NOTE(review): pth_file is passed straight to `testing` as the model, so it presumably
# holds a whole pickled nn.Module. On PyTorch >= 2.6 torch.load defaults to
# weights_only=True and will refuse such a file; add weights_only=False there
# (only for checkpoints you produced yourself).
model_pth = torch.load(pth_file, map_location=device)

test_target_img = test_target_px_values[test_img_idx].reshape(height, width, 3)

# Render with the same near/far bounds and bin count used for training. Sampling far
# outside [tn, tf] queries the network where it was never trained and spreads the
# bins far too thinly over the scene.
img, _, psnr = testing(model=model_pth,
                       o=torch.from_numpy(test_o[test_img_idx]).to(device).float(),
                       d=torch.from_numpy(test_d[test_img_idx]).to(device).float(),
                       tn=tn, tf=tf, nb_bins=nb_bins, chunk_size=60,
                       H=height, W=width, target=test_target_img)

fig, (ax_target, ax_render) = plt.subplots(1, 2)
ax_target.imshow(test_target_img)
ax_target.set_title('Ground truth')
ax_render.imshow(img)
ax_render.set_title(f'Rendered view {test_img_idx}')

print("PSNR of test image:", np.round(psnr, 4))
print("Test view: ", test_img_idx)
plt.show()