In [None]:
import os

# Enumerate GPUs by PCI bus ID (a stable ordering) and expose only devices 0, 1
# and 2 to this process. These must be set before CUDA is initialised, which is
# why this cell runs before torch is imported.
os.environ.update(
    CUDA_DEVICE_ORDER="PCI_BUS_ID",
    CUDA_VISIBLE_DEVICES="0,1,2",
)

In [None]:
from nerf.data_helper import load_data
import torch

images_train, poses_train, int_mat = load_data("bottles", "train", 1)
images_val, poses_val, _ = load_data("bottles", "val", 1)

# Validation views kept OUT of the training set. 23, 39 and 44 are the views
# rendered by the final evaluation cell; index 0 is withheld as well.
# NOTE(review): the training cells validate on val_idx = 1, which IS folded into
# training below, so that PSNR is not a held-out score -- index 0 is the only
# withheld view that is otherwise unused; confirm which was intended.
HELD_OUT_VAL_IDXS = (0, 23, 39, 44)
train_from_val = torch.tensor(
    [i for i in range(len(images_val)) if i not in HELD_OUT_VAL_IDXS],
    dtype=torch.long,
)

# Enlarge the training split with every remaining validation view
# (same views, same order as the original slice-based concatenation).
images_train = torch.cat([images_train, images_val[train_from_val]], dim=0)
poses_train = torch.cat([poses_train, poses_val[train_from_val]], dim=0)
images_train.shape, poses_train.shape

In [None]:
from nerf.nerf_helper import get_rays
import torch
from tqdm import tqdm

# Cast rays for every training camera. get_rays yields (origins, directions) for
# an 800x800 image; the per-view results are stacked along a new leading view axis.
per_view_rays = [get_rays((800, 800), int_mat, pose) for pose in tqdm(poses_train)]
rays_o_list = torch.stack([origins for origins, _ in per_view_rays])
rays_d_list = torch.stack([directions for _, directions in per_view_rays])
# Release the per-view list so only the stacked tensors stay resident.
del per_view_rays

In [None]:
# Collapse the stacked (num_views, H, W, C) ray tensors into one row per ray.
# reshape(-1, C) equals flatten(start_dim=0, end_dim=2) on the 4-D input but is a
# no-op on an already-flattened tensor, so re-running this cell no longer raises
# "Dimension out of range".
rays_o_list = rays_o_list.reshape(-1, rays_o_list.shape[-1])
rays_d_list = rays_d_list.reshape(-1, rays_d_list.shape[-1])
rays_d_list.shape

In [None]:
import torch

# Collapse images into one row per pixel: (num_views, H, W, C) -> (num_views*H*W, C),
# in the same (view, row, col) order as the flattened ray tensors.
# reshape is idempotent, so re-running the cell no longer crashes the way
# flatten(start_dim=0, end_dim=2) did on an already-2-D tensor.
images_train = images_train.reshape(-1, images_train.shape[-1])
images_train.shape

In [None]:
# `mask` holds the row indices of non-white pixels: those whose first three channels
# are not all exactly 1.0 (presumably the white background). Stage 1 appends a
# second copy of these rows to the training pool.
is_white = (images_train[:, :3] == 1.).all(dim=1)
# as_tuple=True always yields a 1-D index tensor; the previous nonzero().squeeze()
# collapsed to a 0-d tensor (breaking len() and indexing) when exactly one pixel matched.
mask = torch.nonzero(~is_white, as_tuple=True)[0]
len(mask)

In [None]:
'''
Stage 1: train a single (coarse-only) NeRF with data augmentation
(foreground pixels are duplicated in the training pool).
'''

from nerf.model import NeRF
import torch.backends.cudnn as cudnn

# Define all the hyper-parameters here.
N_pos = 10            # presumably positional-encoding frequencies for xyz; model input width is 6*N_pos
N_dir = 4             # presumably encoding frequencies for view direction; input width is 6*N_dir
N_sample = 256
N_importance = 128    # presumably only used when a fine model exists (nerf_fine is None here)
batch_size = 1024*48
fc_width = 258        # NOTE(review): possibly meant 256; every later stage must use the same width to load these weights
fc_depth = 8
skips = [4]
lr = 5e-4
num_it = 10001
val_idx = 1           # NOTE(review): val view 1 is also folded into the training set by the data cell
val_gap = 1000
threshold = (0,5)     # presumably near/far ray bounds -- confirm in train_sampled
checkpoint_path_coarse = "NERF_STAGE1.pt"
checkpoint_path_fine = "whatever"   # placeholder: this stage has no fine model
psnrs = []
val_iters = []
losses = []
mini_batch = 1
cudnn.benchmark = True
cudnn.enabled = True
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

nerf_coarse = torch.nn.DataParallel(NeRF(6*N_pos, 6*N_dir, skips, fc_depth, fc_width)).to(device)
nerf_fine = None

# Optimise whichever networks exist. Compare against None explicitly rather than
# relying on nn.Module truthiness.
if nerf_fine is not None:
    optimizer = torch.optim.Adam(
        list(nerf_coarse.parameters()) + list(nerf_fine.parameters()), lr=lr)
else:
    optimizer = torch.optim.Adam(nerf_coarse.parameters(), lr=lr)
# Multiply the learning rate by 0.9 every `step_size` calls to scheduler.step().
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, gamma=0.9, step_size=1000)

In [None]:
from nerf.train import train_sampled

# Stage 1 run. Augmentation: every foreground row (indices in `mask`) is appended a
# second time to the flattened pixels and to both ray tensors, so those pixels
# appear twice in the training pool.
train_sampled(
    # networks and optimisation
    nerf_coarse=nerf_coarse, nerf_fine=nerf_fine,
    optimizer=optimizer, scheduler=scheduler,
    # flattened, augmented training data
    imgs_train=torch.cat([images_train, images_train[mask]], dim=0),
    rays_o_list=torch.cat([rays_o_list, rays_o_list[mask]], dim=0),
    rays_d_list=torch.cat([rays_d_list, rays_d_list[mask]], dim=0),
    # validation view and camera intrinsics
    imgs_val=images_val, poses_val=poses_val, val_idx=val_idx, int_mat=int_mat,
    # ray-sampling and encoding settings
    threshold=threshold, N_pos=N_pos, N_dir=N_dir,
    N_sample=N_sample, N_importance=N_importance,
    # checkpoint destinations
    checkpoint_path_coarse=checkpoint_path_coarse,
    checkpoint_path_fine=checkpoint_path_fine,
    # loop control
    batch_size=batch_size, epochs=num_it, val_gap=val_gap, mini_batch=mini_batch,
    # history lists handed to train_sampled (presumably appended to in place)
    psnrs=psnrs, val_iters=val_iters, losses=losses,
    device=device,
)

In [None]:
'''
Stage 2: keep training the single (coarse-only) NeRF without data augmentation.
'''

from nerf.model import NeRF
import torch.backends.cudnn as cudnn

# Define all the hyper-parameters here.
N_pos = 10
N_dir = 4
N_sample = 256
N_importance = 128
batch_size = 1024*48
fc_width = 258        # must match the Stage 1 model being continued
fc_depth = 8
skips = [4]
lr = 5e-4
num_it = 10001
val_idx = 1
val_gap = 1000
threshold = (0,5)
checkpoint_path_coarse = "NERF_STAGE2.pt"
checkpoint_path_fine = "whatever"   # placeholder: still no fine model in this stage
psnrs = []
val_iters = []
losses = []
mini_batch = 1
cudnn.benchmark = True
cudnn.enabled = True
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# This stage continues from Stage 1's in-memory `nerf_coarse`. On a fresh kernel
# (Stage 1 not run in this session) rebuild it and restore Stage 1's checkpoint
# instead of failing with a NameError. Assumes NERF_STAGE1.pt has the same
# 'model_state_dict' layout that Stage 3 reads from NERF_STAGE2.pt.
if 'nerf_coarse' not in globals():
    nerf_coarse = torch.nn.DataParallel(NeRF(6*N_pos, 6*N_dir, skips, fc_depth, fc_width)).to(device)
    nerf_coarse.load_state_dict(
        torch.load("NERF_STAGE1.pt", map_location=device)['model_state_dict'])
    nerf_fine = None

# A new optimiser/scheduler: Adam moments and the LR schedule restart here.
# Compare against None explicitly rather than relying on nn.Module truthiness.
if nerf_fine is not None:
    optimizer = torch.optim.Adam(
        list(nerf_coarse.parameters()) + list(nerf_fine.parameters()), lr=lr)
else:
    optimizer = torch.optim.Adam(nerf_coarse.parameters(), lr=lr)
# Multiply the learning rate by 0.9 every `step_size` calls to scheduler.step().
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, gamma=0.9, step_size=1000)

In [None]:
# Stage 2 run: same call as Stage 1, but on the plain (un-augmented) flattened
# pixels and rays.
train_sampled(
    # networks and optimisation
    nerf_coarse=nerf_coarse, nerf_fine=nerf_fine,
    optimizer=optimizer, scheduler=scheduler,
    # flattened training pixels and their rays (no augmentation)
    imgs_train=images_train, rays_o_list=rays_o_list, rays_d_list=rays_d_list,
    # validation view and camera intrinsics
    imgs_val=images_val, poses_val=poses_val, val_idx=val_idx, int_mat=int_mat,
    # ray-sampling and encoding settings
    threshold=threshold, N_pos=N_pos, N_dir=N_dir,
    N_sample=N_sample, N_importance=N_importance,
    # checkpoint destinations
    checkpoint_path_coarse=checkpoint_path_coarse,
    checkpoint_path_fine=checkpoint_path_fine,
    # loop control
    batch_size=batch_size, epochs=num_it, val_gap=val_gap, mini_batch=mini_batch,
    # history lists handed to train_sampled (presumably appended to in place)
    psnrs=psnrs, val_iters=val_iters, losses=losses,
    device=device,
)

In [None]:
'''
Stage 3: hard-copy the single NeRF into a coarse and a fine network and train
them jointly (DoubleNeRF).
'''
from nerf.model import NeRF
import torch.backends.cudnn as cudnn

# Define all the hyper-parameters here.
N_pos = 10
N_dir = 4
N_sample = 96
N_importance = 128
batch_size = 1024*32
fc_width = 258        # must match the Stage 2 checkpoint loaded below
fc_depth = 8
skips = [4]
lr = 1e-4
num_it = 20001
val_idx = 1
val_gap = 1000
threshold = (0,5)
checkpoint_path_coarse = "NERF_STAGE3_COARSE.pt"
checkpoint_path_fine = "NERF_STAGE3_FINE.pt"
psnrs = []
val_iters = []
losses = []
mini_batch = 1
cudnn.benchmark = True
cudnn.enabled = True
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

nerf_coarse = torch.nn.DataParallel(NeRF(6*N_pos, 6*N_dir, skips, fc_depth, fc_width)).to(device)
nerf_fine = torch.nn.DataParallel(NeRF(6*N_pos, 6*N_dir, skips, fc_depth, fc_width)).to(device)

# Read the Stage 2 checkpoint once and copy it into both networks (the hard copy).
# load_state_dict copies values into each model's own parameters, so the two
# networks start identical but train independently. map_location lets the CPU
# fallback for `device` load a checkpoint that was saved from a GPU.
stage2_state = torch.load('NERF_STAGE2.pt', map_location=device)['model_state_dict']
nerf_coarse.load_state_dict(stage2_state)
nerf_fine.load_state_dict(stage2_state)
del stage2_state

# Both networks always exist in this stage, so optimise them jointly.
optimizer = torch.optim.Adam(
    list(nerf_coarse.parameters()) + list(nerf_fine.parameters()), lr=lr)
# Multiply the learning rate by 0.9 every `step_size` calls to scheduler.step().
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, gamma=0.9, step_size=1000)

In [None]:
# Stage 3 run: coarse + fine networks trained jointly on the plain flattened
# pixels and rays (no augmentation in this stage).
train_sampled(
    # networks and optimisation
    nerf_coarse=nerf_coarse, nerf_fine=nerf_fine,
    optimizer=optimizer, scheduler=scheduler,
    # flattened training pixels and their rays
    imgs_train=images_train, rays_o_list=rays_o_list, rays_d_list=rays_d_list,
    # validation view and camera intrinsics
    imgs_val=images_val, poses_val=poses_val, val_idx=val_idx, int_mat=int_mat,
    # ray-sampling and encoding settings
    threshold=threshold, N_pos=N_pos, N_dir=N_dir,
    N_sample=N_sample, N_importance=N_importance,
    # checkpoint destinations (one per network)
    checkpoint_path_coarse=checkpoint_path_coarse,
    checkpoint_path_fine=checkpoint_path_fine,
    # loop control
    batch_size=batch_size, epochs=num_it, val_gap=val_gap, mini_batch=mini_batch,
    # history lists handed to train_sampled (presumably appended to in place)
    psnrs=psnrs, val_iters=val_iters, losses=losses,
    device=device,
)

In [None]:
'''
Stage 4: fine-tune the Stage 3 coarse + fine networks with a small learning rate.
'''
from nerf.model import NeRF
import torch.backends.cudnn as cudnn

# Define all the hyper-parameters here.
N_pos = 10
N_dir = 4
N_sample = 256
N_importance = 96
batch_size = 1024*32
fc_width = 258
fc_depth = 8
skips = [4]
lr = 1e-5
num_it = 30001
val_idx = 1
val_gap = 2000
threshold = (0,5)
checkpoint_path_coarse = "NERF_STAGE4_COARSE.pt"
checkpoint_path_fine = "NERF_STAGE4_FINE.pt"
psnrs = []
val_iters = []
losses = []
mini_batch = 1
cudnn.benchmark = True
cudnn.enabled = True
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# This stage continues from the `nerf_coarse` / `nerf_fine` left in memory by
# Stage 3; it does not reload NERF_STAGE3_*.pt, so Stage 3 must have run in this
# kernel session. A new optimiser is built, so Adam moments restart here.
# Compare against None explicitly rather than relying on nn.Module truthiness.
if nerf_fine is not None:
    optimizer = torch.optim.Adam(
        list(nerf_coarse.parameters()) + list(nerf_fine.parameters()), lr=lr)
else:
    optimizer = torch.optim.Adam(nerf_coarse.parameters(), lr=lr)
# Gentle decay: multiply the LR by 0.99 every `step_size` calls to scheduler.step().
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1000, gamma=0.99)

In [None]:
# Stage 4 run: fine-tune both networks on the plain flattened pixels and rays.
train_sampled(
    # networks and optimisation
    nerf_coarse=nerf_coarse, nerf_fine=nerf_fine,
    optimizer=optimizer, scheduler=scheduler,
    # flattened training pixels and their rays
    imgs_train=images_train, rays_o_list=rays_o_list, rays_d_list=rays_d_list,
    # validation view and camera intrinsics
    imgs_val=images_val, poses_val=poses_val, val_idx=val_idx, int_mat=int_mat,
    # ray-sampling and encoding settings
    threshold=threshold, N_pos=N_pos, N_dir=N_dir,
    N_sample=N_sample, N_importance=N_importance,
    # checkpoint destinations (one per network)
    checkpoint_path_coarse=checkpoint_path_coarse,
    checkpoint_path_fine=checkpoint_path_fine,
    # loop control
    batch_size=batch_size, epochs=num_it, val_gap=val_gap, mini_batch=mini_batch,
    # history lists handed to train_sampled (presumably appended to in place)
    psnrs=psnrs, val_iters=val_iters, losses=losses,
    device=device,
)

In [None]:
import os

from nerf.nerf_helper import nerf_step_sampled
import matplotlib.pyplot as plt
import torch.nn.functional as F
import numpy as np
from PIL import Image

# Render the validation views that were withheld from training, tile by tile,
# report PSNR against ground truth, and save each rendering as a PNG.
EVAL_IDXS = [23, 39, 44]     # must stay out of the training set (see the data cell)
TILE = 200                   # render TILE x TILE pixel tiles to bound memory per call
OUT_DIR = "val_pred/final_stage"
os.makedirs(OUT_DIR, exist_ok=True)   # Image.save does not create missing directories

for idx in EVAL_IDXS:
    with torch.no_grad():
        gt_rgb = images_val[idx].to(device)
        H, W = gt_rgb.shape[:2]   # assumes images_val[idx] is (H, W, 3), as the MSE below requires
        rays_o, rays_d = get_rays((H, W), int_mat.cpu(), poses_val[idx])

        pred_rgb = torch.zeros((H, W, 3), device=device)
        for i in range(0, H, TILE):
            for j in range(0, W, TILE):
                sub_rays_o = rays_o[i:i + TILE, j:j + TILE, :].to(device)
                sub_rays_d = rays_d[i:i + TILE, j:j + TILE, :].to(device)
                # Edge tiles are smaller when H or W is not a multiple of TILE.
                h, w = sub_rays_o.shape[:2]

                _, fine_out = nerf_step_sampled(
                    nerf_coarse,
                    nerf_fine,
                    (h, w),
                    sub_rays_o,
                    sub_rays_d,
                    threshold,
                    device,
                    True,
                    True,
                    N_pos,
                    N_dir,
                    N_sample,
                    N_importance,
                    batch_size
                )

                # fine_out[0] is written as the RGB tile (presumably (h, w, 3)).
                pred_rgb[i:i + h, j:j + w, :] = fine_out[0]

        mse = F.mse_loss(pred_rgb, gt_rgb)
        psnr = -10. * torch.log10(mse)
        print(f'psnr of {idx}:', psnr.item())
        # Clip before the uint8 cast: values outside [0, 255] would otherwise wrap around.
        rgb_u8 = np.clip(np.round(pred_rgb.cpu().numpy() * 255), 0, 255).astype(np.uint8)
        Image.fromarray(rgb_u8).convert('RGB').save(f"{OUT_DIR}/{idx}.png")