## Tiny NeRF
This is a simplified version of the method presented in *NeRF: Representing Scenes as Neural Radiance Fields for View Synthesis*

[Project Website](http://www.matthewtancik.com/nerf)

[arXiv Paper](https://arxiv.org/abs/2003.08934)

[Full Code](https://github.com/bmild/nerf)

Components not included in the notebook
*   5D input including view directions
*   Hierarchical Sampling




________________________________________________________________________________
# Cambios/comentarios que dejamos

## Cambios grandes

- Cambiamos el config_loader para que pueda crearse fácil una configuración como un objeto con parámetros. La clase ahora se llama ConfigurationLoader.

- Cambiamos el tracker de métricas para que guarde más cosas. Fijate que el loss, PSNR ahora se guarda ahí.

- En el ciclo de entrenamiento, dividimos la parte de train y la parte de test en dos funciones separadas.

- La función de renderizado de imágenes ahora solo hace el plot. Al plot agregamos también un gráfico del loss.

- La información sobre PSNR y loss ahora se guardan en todas iteraciones (usando el monitor) pero se muestran como antes (cada 50 iters). Fijate que las curvas son más suaves.


## Comentarios menores y dudas

- Dejamos cosas comentadas. La idea es que veas los cambios. Cuando estés conforme, borrá las porciones de código comentadas.

- El valor de 1024*8 de chunk_size de donde sale? Tiene que ver con el tamaño de las imágenes? Habría que parametrizarlo? Habría que ponerlo en el configurationLoader?

    -) Y por qué multiplica por 8?

- Importante: necesitamos entender qué es lo que retorna la función volume_render (y por tanto render_ray).

- Revisar bien qué cosas deberían estar como variables globales, y qué cosas deberían pasarse como parámetros.

    - lista de imágenes? trainingMonitor? configurationLoader?



In [None]:
import os, sys
import json
import torch
import cv2
import pandas as pd
import json
import imageio
import time
import random

from tqdm.notebook import tqdm
import numpy as np
import matplotlib.pyplot as plt

In [None]:
def generate_random_seed():
    """Draw a fresh seed uniformly from the integers 1..1000 (both inclusive)."""
    lowest, highest = 1, 1000
    return random.randint(lowest, highest)

def set_random_seed(seed):
    """Seed Python's, NumPy's and PyTorch's random generators with ``seed``.

    The CUDA generators are seeded as well, but only when CUDA is available.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)

In [None]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Compare the device *type*: a torch.device object is not equal to the plain
# string 'cuda', so comparing `device == 'cuda'` never released the cache.
if device.type == 'cuda':
    torch.cuda.empty_cache()
print(device)

# Loading dataset

## Load default data

In [None]:
# Download the default dataset only when it is not already in the working
# directory (the `!wget` line is a notebook shell command, not Python).
if not os.path.exists('tiny_nerf_data.npz'):
    !wget http://cseweb.ucsd.edu/~viscomp/projects/LF/papers/ECCV20/nerf/tiny_nerf_data.npz
# Name of the experiment; later cells load `<exp_name>.npz`.
exp_name = 'tiny_nerf_data'

## Load custom data (in progress)

Enter your experiment name (the name of the dataset that contains the images and poses) and the desired factor by which to downsize the images in the dataset. A factor of 1 keeps the original image size; the larger the factor, the lower the final resolution of the images.

*(Keep in mind that more resolution also means more GPU usage so adjust chunk_size in configuration and your Accelerator accordingly)

**(The required poses are generated by the script given in https://github.com/NVlabs/instant-ngp/blob/master/docs/nerf_dataset_tips.md)

TODO : find a way to adjust focal and load dataset from link

In [None]:
# exp_name names the experiment (and the generated .npz file);
# dataset_name is the folder under /kaggle/input/ holding images and poses.
exp_name = 'mate'
dataset_name = 'mate-dataset'

# Downsizing factor: image height and width are divided by this value
# (1 would keep the original size; 16 shrinks each side sixteen-fold).
resize_ratio = 16

In [None]:
# Load custom data (poses generated by script and stored in transforms.json)
focal = 130 # Hardcoded focal length (TODO: derive it from the dataset)

# Path to your .json file
json_file_path = "/kaggle/input/"+dataset_name+"/transforms.json"

# Open and load the JSON data; `raw_poses` keeps the list stored under the
# 'frames' key (each entry is read later for 'file_path' and 'transform_matrix').
with open(json_file_path, 'r') as file:
    data = json.load(file)
    raw_poses = data['frames']

def get_pose_from_image(image_name):
    """Look up the pose recorded for ``image_name`` in ``raw_poses``.

    Scans the module-level ``raw_poses`` list in order and returns the
    ``'transform_matrix'`` of the first entry whose ``'file_path'`` ends with
    ``image_name``. Returns ``False`` when no entry matches.
    """
    for frame in raw_poses:
        if frame['file_path'].endswith(image_name):
            return frame['transform_matrix']
    return False

exp_image_dir = "/kaggle/input/"+dataset_name+"/images/"
exp_image_files = os.listdir(exp_image_dir)

images_list = []
poses_list = []
index = 0

for image_file in exp_image_files:
    image_path = os.path.join(exp_image_dir, image_file)

    # Search pose for image
    raw_pose = get_pose_from_image(image_file)

    # Load the image
    img = cv2.imread(image_path)

    # Skip files that could not be decoded (cv2.imread returns None) and
    # images without an associated pose. This check must come before any
    # access to img.shape, otherwise an unreadable file aborts the loop.
    if img is None or not raw_pose:
        continue

    (height, width, _) = img.shape
    resized_height = height / resize_ratio
    resized_width = width / resize_ratio
    # cv2.resize expects the target size as (width, height)
    resized_dimensions = (int(resized_width), int(resized_height))

    # Convert BGR to RGB since OpenCV loads images in BGR format
    img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    # Resize image if needed (the rest of the notebook expects float32 images)
    img_rgb = cv2.resize(img_rgb, resized_dimensions, interpolation = cv2.INTER_CUBIC).astype(np.float32)
    # Scale the 8-bit pixel values by 1/255
    img_rgb /= 255

    # Keep image and pose at the same position in both lists
    images_list.append(img_rgb)
    poses_list.append(raw_pose)

# Convert the list of images to a NumPy array with shape (num_images, height, width, 3)
# (OpenCV images are indexed rows first, i.e. height before width).
images_array = np.array(images_list)
poses_array = np.array(poses_list).astype(np.float32)

print(images_array.dtype)
print(poses_array.dtype)

# Check the shape of the array to confirm
print(images_array.shape, poses_array.shape)

# The archive is written as <exp_name>.npz in the Kaggle working directory
output_path = '/kaggle/working/'+exp_name+'.npz'

# Store images, poses and the focal length under the keys read back by the loading cell
np.savez_compressed(output_path, images=images_array, poses=poses_array, focal=focal)

# Load Input Images and Poses

In [None]:
# Loading dataset (either using custom or given data)
load_path = exp_name + '.npz'
data = np.load(load_path)

# Retrieving images, poses, focal length and dimensions for training
images = data['images']

print("1.", len(images))

poses = data['poses']
poses = torch.from_numpy(poses).to(device)

focal = data['focal']
focal = torch.from_numpy(focal).to(device)
H, W = images.shape[1:3]

print(images.shape, poses.shape, focal)

amount_images = images.shape[0]-1 # The last one is for the holdout view

print(amount_images)

# Holdout view: the LAST image/pose pair. The previous index
# (amount_images-4) fell inside the training slice below, so the model was
# being trained on the very image used to evaluate it.
testimg = images[amount_images]
testimg = torch.from_numpy(testimg).to(device)

testpose = poses[amount_images]

# Trim the training data so that it does not contain the holdout image
images = torch.from_numpy(images[:amount_images]).to(device)
poses = poses[:amount_images]

print("2.", len(images))


# Show image used for holdout view
plt.imshow(testimg.detach().cpu().numpy())
plt.show()

# Encoding

In [None]:
def positionalEncoding(x, L_embed=None):
    """Apply the NeRF positional encoding to ``x``.

    Every coordinate ``p`` along the last dimension is expanded to
    ``(p, sin(2^0 p), cos(2^0 p), ..., sin(2^(L-1) p), cos(2^(L-1) p))``
    (the frequencies are NOT multiplied by pi here).

    Args:
        x: tensor of coordinates. In this notebook ``render_rays`` passes the
            flattened 3D sample points, shape ``(num_points, 3)``.
        L_embed: number of frequency bands. When omitted, the value is read
            from the module-level ``config_loader.L_embed``.

    Returns:
        Tensor whose last dimension has size ``d + 2 * d * L_embed`` where
        ``d = x.shape[-1]``: the raw input followed by the sin/cos pairs.
    """
    if L_embed is None:
        L_embed = config_loader.L_embed

    output = [x]
    for i in range(L_embed):
        frequency = 2.**i
        output.append(torch.sin(frequency * x))
        output.append(torch.cos(frequency * x))
    return torch.cat(output, -1)

# Configuration

## Definitions

In [None]:
# Class for setting up configuration of experiments
# TODO see which arguments store or change (and how to do it)
class ConfigurationLoader:
    """Hyper-parameters of one experiment, with JSON export/import.

    Attributes:
        seed: random seed of the experiment.
        L_embed: number of frequency bands of the positional encoding.
        learning_rate: learning rate handed to the optimizer.
        N_samples: number of samples taken along each ray.
        N_iters: maximum number of training iterations.
        iterations_to_plot_image: the holdout view is plotted every this many
            iterations.
        chunk_size: number of encoded sample points sent through the network
            per forward pass (decrease it if running out of memory).
        image_selection_mode: how the training image is picked.
        embed_fn: embedding function (``positionalEncoding``).

    ``N_iters``, ``iterations_to_plot_image`` and ``embed_fn`` are neither
    exported by ``get_configuration`` nor restored by ``load_configuration``.
    """

    # Fields exported to / restored from the JSON file, in file order.
    _PERSISTED_FIELDS = (
        "seed",
        "L_embed",
        "learning_rate",
        "N_samples",
        "chunk_size",
        "image_selection_mode",
    )

    def __init__(self,
                 seed,
                 L_embed = 6, learning_rate = 1e-3, 
                 N_samples = 32, N_iters = 1000, 
                 iterations_to_plot_image = 100, 
                 chunk_size = 512, 
                 image_selection_mode = 'random'):
        # Persisted hyper-parameters
        self.seed = seed
        self.L_embed = L_embed
        self.learning_rate = learning_rate
        self.N_samples = N_samples
        self.chunk_size = chunk_size
        self.image_selection_mode = image_selection_mode

        # Run-control settings (not persisted)
        self.N_iters = N_iters
        self.iterations_to_plot_image = iterations_to_plot_image

        # Embedding function
        self.embed_fn = positionalEncoding

    def get_configuration(self):
        """Return the persisted hyper-parameters as a plain dictionary."""
        return {name: getattr(self, name) for name in self._PERSISTED_FIELDS}

    def load_configuration(self, path):
        """Overwrite the persisted hyper-parameters with those stored at ``path``.

        Returns ``True`` when the file exists and was loaded, ``False`` (after
        printing a notice) when ``path`` does not exist.
        """
        if not os.path.exists(path):
            print('Configuration file not found')
            return False

        with open(path, 'r') as file:
            data = json.load(file)
        for name in self._PERSISTED_FIELDS:
            setattr(self, name, data[name])
        print('Configuration loaded from ',path)
        return True
        

## Default configuration

Setup custom configuration or load configuration from a file (if configuration file not found we setup a default configuration)

In [None]:
# Ask for an optional configuration file (an empty/unknown path keeps the defaults)
path_to_config = input("Path to .json file containing configuration")

# Start from a default configuration driven by a freshly drawn seed
random_seed = generate_random_seed()
set_random_seed(random_seed)
config_loader = ConfigurationLoader(random_seed) # default configuration


# If a configuration file was loaded, re-seed with the seed stored in it
if (config_loader.load_configuration(path_to_config)):
    set_random_seed(config_loader.seed)
# Show the seed that is actually in effect
print(config_loader.seed)


## Auxiliary functions

Functions for tracking data

In [None]:
def get_current_time():
    """Return the current local time formatted as ``YYYY-MM-DD-HH-MM-SS``."""
    return time.strftime("%Y-%m-%d-%H-%M-%S")
    
# Class for storing information of experiments
# class DebugLogger:

class TrainingMonitor:
    """Collects per-iteration metrics of an experiment and writes them to disk.

    Output files are named ``<exp_name>(<start time>)`` followed by ``.csv``
    (metric snapshots) or ``_configuration.json`` (experiment configuration).
    """

    def __init__(self, exp_name):
        self.exp_name = exp_name
        self.exp_start_time = get_current_time()

        # Empty table holding one row per stored snapshot
        self.exp_data = pd.DataFrame(columns=['Time','Iteration','Loss','PSNR'])

        # In-memory series used for plotting
        self.psnrs, self.iternums, self.losses = [], [], []

    def _file_prefix(self):
        """Common prefix of every file written for this experiment."""
        return self.exp_name+"("+str(self.exp_start_time)+")"

    def get_psnrs(self):
        """Return the list of recorded PSNR values."""
        return self.psnrs

    def get_losses(self):
        """Return the list of recorded loss values."""
        return self.losses

    def get_iternums(self):
        """Return the list of recorded iteration numbers."""
        return self.iternums

    def save_psnr(self, psnr):
        """Append one PSNR value to the in-memory series."""
        self.psnrs.append(psnr)

    def save_loss(self, loss):
        """Append one loss value to the in-memory series."""
        self.losses.append(loss)

    def save_iternum(self, iternum):
        """Append one iteration number to the in-memory series."""
        self.iternums.append(iternum)

    def store_snapshot(self, iter_num, loss, psnr):
        """Add a timestamped row to the table and rewrite the CSV file."""
        row = pd.DataFrame([{
            'Time':get_current_time(),
            'Iteration':iter_num,
            'Loss':loss,
            'PSNR':psnr,
        }])
        self.exp_data = pd.concat([self.exp_data, row], ignore_index=True)
        self.exp_data.to_csv(self._file_prefix()+".csv", index=False)

    def save_exp_configuration(self, config_loader):
        """Write ``config_loader.get_configuration()`` to the configuration JSON file."""
        target = self._file_prefix()+"_configuration.json"
        with open(target, mode="w", encoding="utf-8") as config_json_file:
            json.dump(config_loader.get_configuration(), config_json_file)


# NerfModel

In [None]:
class TinyNerfModel(torch.nn.Module):
    """Fully connected network with a skip connection, mapping an encoded point to 4 raw values.

    Args:
        D: number of hidden linear layers.
        W: width (number of units) of every hidden layer.
        L_embed: number of positional-encoding frequency bands; the expected
            input size is ``3 + 3 * 2 * L_embed``.
    """

    def __init__(self, D=8, W=256, L_embed=6):
        super().__init__()
        self.D = D
        self.W = W
        self.input_dim = 3 + 3 * 2 * L_embed

        # `in_features` follows the size of the activations entering each layer,
        # which grows by `input_dim` right after a skip concatenation.
        self.layers = torch.nn.ModuleList()
        in_features = self.input_dim
        for index in range(D):
            self.layers.append(torch.nn.Linear(in_features, W))
            in_features = W + self.input_dim if self._is_skip_layer(index) else W

        # Output layer without activation: 4 raw values per point
        # (volume_render reads the first three as colour and the fourth as density).
        self.final_layer = torch.nn.Linear(in_features, 4)

    @staticmethod
    def _is_skip_layer(index):
        """True for the hidden layers whose output is concatenated with the input."""
        return index > 0 and index % 4 == 0

    def forward(self, x):
        """Run ``x`` (last dimension ``input_dim``) through the network."""
        hidden = x
        for index, layer in enumerate(self.layers):
            hidden = torch.relu(layer(hidden))
            if self._is_skip_layer(index):
                # Re-inject the original input after this layer
                hidden = torch.cat((hidden, x), dim=-1)
        return self.final_layer(hidden)

## Auxiliary functions for the model

In [None]:
from torchmetrics.image import StructuralSimilarityIndexMeasure

def cumprod_exclusive(tensor: torch.Tensor) -> torch.Tensor:
    r"""Exclusive cumulative product along the last dimension.

    PyTorch counterpart of ``tf.math.cumprod(..., exclusive=True)``: entry
    ``k`` of the result is the product of the entries ``0..k-1`` of the input,
    so the first entry is always 1.

    Args:
        tensor (torch.Tensor): input whose last dimension is reduced cumulatively.

    Returns:
        torch.Tensor: tensor of the same shape as ``tensor``.
    """
    # Inclusive cumprod shifted one slot to the right along the last axis ...
    exclusive = tensor.cumprod(dim=-1).roll(shifts=1, dims=-1)
    # ... whose wrapped-around first slot is replaced by the empty product.
    exclusive[..., 0] = 1.
    return exclusive

def get_image_random_index(images):
    """Return a random valid index along the first dimension of ``images``.

    The index is drawn with NumPy's global random generator.
    """
    num_images = images.shape[0]
    return np.random.randint(num_images)

def get_image_by_index(img_i):
    """Return entry ``img_i`` of the module-level ``images`` collection."""
    return images[img_i]

def get_pose_by_index(img_i):
    """Return entry ``img_i`` of the module-level ``poses`` collection."""
    return poses[img_i]

def has_to_plot(i, iterations_to_plot_image):
    """Tell whether iteration ``i`` is a multiple of ``iterations_to_plot_image``."""
    remainder = i % iterations_to_plot_image
    return remainder == 0

def calculate_loss(img_1, img_2):
    """Return the mean squared error between ``img_1`` and ``img_2``.

    Only MSE is used at the moment; an MSE/SSIM blend was tried before and is
    not active.
    """
    return torch.nn.functional.mse_loss(img_1, img_2)

# Create the SSIM metric (inputs are expected in the [0, 1] range, hence data_range=1.0).
# NOTE: calculate_loss currently uses MSE only; the SSIM term is commented out there.
ssim = StructuralSimilarityIndexMeasure(data_range=1.0).to(device)

# def render_image_training(H, W, focal, pose, testimage, model):
#     # We obtain the rays to query the model
#     rays_o, rays_d = get_rays(H, W, focal, pose)
#     rgb, depth, acc = render_rays(model, rays_o, rays_d, near=2., far=6., N_samples=config_loader.N_samples, chunk_size=config_loader.chunk_size)

#     # Calculate loss between expected image and obtained
#     loss = torch.nn.functional.mse_loss(rgb, testimage)

#     # Peak signal-to-noise-ratio metric for comparison
#     psnr = -10. * torch.log10(loss)

#     # Store data
#     training_monitor.store_snapshot(i, loss.item(), psnr.item())
    
#     psnrs.append(psnr.item())
#     iternums.append(i)
  
#     plt.figure(figsize=(10,4))
#     plt.subplot(121)
#     plt.imshow(rgb.detach().cpu().numpy())
#     plt.title(f'Iteration: {i}')
#     plt.subplot(122)
#     plt.plot(iternums, psnrs)
#     plt.title('PSNR')
#     plt.show()

# def predict_testing_image_and_save_info(H, W, focal, pose, testimage, model, training_monitor):
#     # We obtain the rays to query the model
#     rays_o, rays_d = get_rays(H, W, focal, pose)
#     rgb, depth, acc = render_rays(model, rays_o, rays_d, near=2., far=6., N_samples=config_loader.N_samples, chunk_size=config_loader.chunk_size)

#     # Calculate loss between expected image and obtained
#     loss = torch.nn.functional.mse_loss(rgb, testimage)

#     # Peak signal-to-noise-ratio metric for comparison
#     psnr = -10. * torch.log10(loss)

#     # Store data
#     training_monitor.store_snapshot(i, loss.item(), psnr.item())

def generate_checkpoint(model, iteration_number, loss):
    """Save the training state to ``/kaggle/working/<exp_name>_checkpoint.pt``.

    The checkpoint stores the iteration number (key ``'epoch'``), the model
    weights, the state of the module-level ``optimizer`` and the given loss.
    """
    checkpoint = {
        'epoch': iteration_number,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
    }
    destination = '/kaggle/working/'+exp_name+'_checkpoint.pt'
    torch.save(checkpoint, destination)

def load_checkpoint(model, path):
    """Restore model and optimizer state from the checkpoint at ``path``.

    Args:
        model: network whose weights are restored in place.
        path: location of a file written by ``generate_checkpoint``.

    Returns:
        The iteration number stored in the checkpoint, or ``0`` when ``path``
        does not exist (training then starts from scratch).

    The state of the module-level ``optimizer`` is restored as well and the
    model is left in training mode.
    """
    if not os.path.exists(path):
        print('No checkpoint to load, starting fresh')
        return 0

    # Map the stored tensors onto the device the model currently lives on, so
    # a checkpoint written on a GPU can also be resumed on a CPU-only machine.
    first_parameter = next(model.parameters(), None)
    target_device = None if first_parameter is None else first_parameter.device
    checkpoint = torch.load(path, map_location=target_device, weights_only=True)

    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    iteration_number = checkpoint['epoch']

    # Training continues after loading, so keep the model in train mode
    model.train()
    print('Checkpoint found at '+str(iteration_number)+" iterations")
    return iteration_number

def calculate_and_save_training_metrics(rgb, test_image, iteration_number):
    """Compute loss and PSNR of ``rgb`` against ``test_image`` and record them.

    The values are stored in the module-level ``training_monitor`` both as a
    snapshot row and in its in-memory loss / PSNR / iteration series.
    """
    loss = calculate_loss(rgb,test_image)

    # Peak signal-to-noise ratio derived from the loss
    psnr = -10. * torch.log10(loss)

    loss_value, psnr_value = loss.item(), psnr.item()
    training_monitor.store_snapshot(iteration_number, loss_value, psnr_value)
    training_monitor.save_loss(loss_value)
    training_monitor.save_psnr(psnr_value)
    training_monitor.save_iternum(iteration_number)

def plot_image_and_metrics(rgb, iteration_number):
    """Show the predicted image next to the PSNR and loss curves.

    The curves are read from the module-level ``training_monitor``.
    """
    fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(14, 4))

    # First panel: predicted image
    ax1.imshow(rgb.detach().cpu().numpy())
    ax1.set_title(f'Predicted Image (Iteration: {iteration_number})')

    # Second and third panels: PSNR and loss against the iteration number
    metric_panels = (
        (ax2, training_monitor.get_psnrs, 'PSNR'),
        (ax3, training_monitor.get_losses, 'Loss'),
    )
    for axis, read_series, metric_name in metric_panels:
        axis.plot(training_monitor.get_iternums(), read_series(), label=metric_name, color='blue')
        axis.set_title(metric_name)
        axis.set_xlabel('Iterations')
        axis.set_ylabel(metric_name)
        axis.legend()

    # Adjust the layout and display the figure
    plt.tight_layout()
    plt.show()


def render_image(H, W, focal, pose, test_image, model, 
                 iteration_number):
    """Render the view seen from ``pose`` and plot it with the metric curves.

    ``test_image`` is accepted but not used by the current implementation.
    """
    origins, directions = get_rays(H, W, focal, pose)
    rendered, _depth, _acc = render_rays(
        model, origins, directions,
        near=2., far=6.,
        N_samples=config_loader.N_samples,
        chunk_size=config_loader.chunk_size,
    )
    plot_image_and_metrics(rendered, iteration_number)

    # Estos dos listas las pasamos al TrainingMonitor
    # psnrs.append(psnr.item())
    # iternums.append(i)

    # fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 4))

    # # First subplot
    # ax1.imshow(rgb.detach().cpu().numpy())
    # ax1.set_title(f'Predicted Image (Iteration: {iteration_number})')
    
    # # Second subplot
    # ax2.plot(training_monitor.get_iternums(), training_monitor.get_psnrs(), label='PSNR')
    # # ax2.plot(training_monitor.get_iternums(), training_monitor.get_losses(), label='Loss')
    # ax2.set_title('PSNR and Loss')
    
    # # Adjust layout and display the figure
    # plt.tight_layout()
    # plt.show()
  
    # plt.figure(figsize=(10,4))
    # plt.subplot(121)
    # plt.imshow(rgb.detach().cpu().numpy())
    # plt.title(f'Predicted Image (Iteration: {iteration_number})')
    # plt.subplot(122)
    # plt.plot(training_monitor.get_iternums(), training_monitor.get_psnrs())
    # plt.plot(training_monitor.get_iternums(), training_monitor.get_losses())
    # plt.title('PSNR')
    # plt.show()

# def render_image(H, W, focal, pose, model):
#     # We obtain the rays to query the model
#     rays_o, rays_d = get_rays(H, W, focal, pose)
#     rgb, depth, acc = render_rays(model, rays_o, rays_d, near=2., far=6., N_samples=config_loader.N_samples, chunk_size=config_loader.chunk_size)

#     # We print the rgb map
#     plt.figure(figsize=(10,4))
#     plt.subplot(121)
#     plt.imshow(rgb)
#     plt.show()

def get_rays(H, W, focal, c2w):
    """Build one ray (origin and direction) per pixel of an ``H`` x ``W`` image.

    Args:
        H: image height in pixels.
        W: image width in pixels.
        focal: focal length, in pixels.
        c2w: camera-to-world pose; only ``c2w[:3, :3]`` (rotation) and
            ``c2w[:3, -1]`` (camera position) are read.

    Returns:
        ``(rays_o, rays_d)``, both of shape ``(H, W, 3)``: the ray origins
        (the camera position repeated for every pixel) and the ray directions
        in world coordinates. Directions are not normalised.
    """
    # Pixel grids of shape (H, W): `i` is the column index, `j` the row index.
    # indexing='xy' yields this layout directly (and avoids the deprecation
    # warning torch.meshgrid emits when `indexing` is omitted).
    i, j = torch.meshgrid(
        torch.arange(W, dtype=c2w.dtype, device=c2w.device),
        torch.arange(H, dtype=c2w.dtype, device=c2w.device),
        indexing='xy',
    )

    # Centre the grid on the image and divide by the focal length
    x_cord = (i-W*.5)/focal
    y_cord = -(j-H*.5)/focal # Image rows grow downwards, the camera y axis upwards
    z_cord = -torch.ones_like(i) # The camera looks along -z

    # Direction of every pixel in camera coordinates, shape (H, W, 3)
    dirs = torch.stack([x_cord, y_cord, z_cord], -1)

    # Rotate the directions into world coordinates: for each pixel this is
    # the matrix-vector product c2w[:3, :3] @ dir, written with broadcasting.
    tf_matrix = c2w[:3,:3]
    rays_d = torch.sum(dirs[..., None, :] * tf_matrix, -1)

    # Every ray starts at the camera position
    cam_origin = c2w[:3,-1]
    rays_o = cam_origin.expand(rays_d.shape)
    return rays_o, rays_d

def volume_render(raw, ray_o, depth_values):
    """Composite the raw network outputs along each ray.

    Args:
        raw: network output whose last dimension holds 4 values per sample
            (the first three are turned into a colour, the fourth into a
            density) and whose second-to-last dimension indexes the samples
            along a ray.
        ray_o: ray origins; only their dtype and device are used here.
        depth_values: depth of every sample, last dimension = samples.

    Returns:
        ``(rgb_map, depth_map, acc_map)``:
        * ``rgb_map`` - weighted sum of the sample colours of each ray
          (last dimension of size 3),
        * ``depth_map`` - weighted sum of the sample depths of each ray,
        * ``acc_map`` - sum of the weights of each ray.
    """
    # Non-negative density and colour squashed into (0, 1)
    density = torch.nn.functional.relu(raw[...,3])
    color = torch.sigmoid(raw[...,:3])

    # Distance between consecutive samples; the last sample gets a huge
    # distance (1e10) so that it absorbs whatever is left of the ray.
    gaps = depth_values[..., 1:] - depth_values[..., :-1]
    far_gap = torch.tensor([1e10], dtype=ray_o.dtype, device=ray_o.device)
    far_gap = far_gap.expand(depth_values[..., :1].shape)
    dists = torch.cat((gaps, far_gap), dim=-1)

    # Opacity of each sample, and the fraction of light reaching it
    alpha = 1. - torch.exp(-density * dists)
    transmittance = cumprod_exclusive(1. - alpha + 1e-10)
    weights = alpha * transmittance

    rgb_map = (weights[..., None] * color).sum(dim=-2)
    depth_map = (weights * depth_values).sum(dim=-1)
    acc_map = weights.sum(-1)
    return rgb_map, depth_map, acc_map

def render_rays(network_fn, rays_o, rays_d, near, far, N_samples, chunk_size, rand=False):
    """Render a bundle of rays with ``network_fn``.

    Args:
        network_fn: callable (the NeRF model) applied to the encoded points.
        rays_o: ray origins, last dimension of size 3.
        rays_d: ray directions, same shape as ``rays_o``.
        near: depth of the first sample along every ray.
        far: depth of the last sample along every ray.
        N_samples: number of samples per ray.
        chunk_size: maximum number of encoded points sent through
            ``network_fn`` at once (lower it to reduce memory usage).
        rand: when True, every sample depth is shifted by uniform noise in
            ``[0, (far - near) / N_samples)``.

    Returns:
        ``(rgb_map, depth_map, acc_map)`` as produced by ``volume_render``.
    """
    # Sample depths, evenly spaced between near and far
    z_vals = torch.linspace(near, far, N_samples).to(rays_o)
    if rand:
        # One independent offset per ray and per sample
        noise_shape = list(rays_o.shape[:-1]) + [N_samples]
        z_vals = z_vals \
        + torch.rand(noise_shape).to(rays_o) * (far
            - near) / N_samples

    # 3D query points: origin + depth * direction, shape (..., N_samples, 3)
    pts = rays_o[...,None,:] + rays_d[...,None,:] * z_vals[...,:,None]

    # Flatten and encode the points
    pts_flat = pts.reshape((-1,3))
    pts_flat = config_loader.embed_fn(pts_flat)

    # Run the network chunk by chunk and concatenate the results (to avoid
    # out-of-memory issues). The network passed as `network_fn` is used -
    # previously the global `model` was called and the argument was ignored.
    predictions = [network_fn(batch) for batch in torch.split(pts_flat, chunk_size)]
    raw = torch.cat(predictions, dim=0)

    # Restore the (..., N_samples, 4) layout
    raw_shape = list(pts.shape[:-1]) + [4]
    raw = torch.reshape(raw, raw_shape)

    # Do volume rendering based on opacities and color of sampled points
    rgb_map, depth_map, acc_map = volume_render(raw, rays_o, z_vals)

    return rgb_map, depth_map, acc_map

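Here we optimize the model. The holdout view is rendered and its loss and PSNR are recorded at every iteration; the rendered view and the metric curves are plotted every `iterations_to_plot_image` iterations.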

In [None]:
def training_stage(H, W, focal, model):
    """Run one optimisation step on a single training image.

    A training image is selected according to
    ``config_loader.image_selection_mode``, rendered from its pose with
    jittered sample depths, and the loss against the image is back-propagated
    through the module-level ``optimizer``.

    Raises:
        ValueError: if the configured image selection mode is not supported.
    """
    selection_mode = config_loader.image_selection_mode
    if selection_mode == 'random':
        image_index = get_image_random_index(images)
    elif selection_mode == 'angle':
        # Placeholder for a future strategy: behaves like 'random' for now.
        image_index = get_image_random_index(images)
    else:
        raise ValueError('Unknown image_selection_mode: ' + repr(selection_mode))

    # `.to()` returns the moved tensor instead of modifying it in place, so
    # its result has to be kept (the bare calls used before had no effect).
    target_image = get_image_by_index(image_index).to(device)
    target_image_pose = get_pose_by_index(image_index).to(device)
    rays_o, rays_d = get_rays(H, W, focal, target_image_pose)

    rgb, depth, acc = render_rays(model, rays_o, rays_d, near=2., far=6., N_samples=config_loader.N_samples, chunk_size=config_loader.chunk_size, rand=True)

    loss = calculate_loss(rgb,target_image)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

def testing_stage(H, W, focal, model, testimg, testpose, iteration_number, start_time):
    """Evaluate the model on the holdout view and log the results.

    Metrics are recorded at every call. Every
    ``config_loader.iterations_to_plot_image`` iterations the timing and loss
    are printed, a checkpoint is written and the rendered view is plotted.

    Returns:
        The start time to use for the next timing window: a fresh timestamp
        when a report was printed, otherwise ``start_time`` unchanged.
        (Rebinding the parameter alone never reached the caller, so the
        reported seconds per iteration were measured from the wrong moment.)
    """
    # Evaluation only: no gradients are needed, which saves memory and time
    with torch.no_grad():
        rays_o, rays_d = get_rays(H, W, focal, testpose)
        rgb, depth, acc = render_rays(model, rays_o, rays_d, near=2., far=6., N_samples=config_loader.N_samples, chunk_size=config_loader.chunk_size)

        calculate_and_save_training_metrics(rgb, testimg, iteration_number)

        if not has_to_plot(iteration_number, config_loader.iterations_to_plot_image):
            return start_time

        loss = calculate_loss(rgb,testimg)

    print(iteration_number, (time.time() - start_time) / config_loader.iterations_to_plot_image, 'secs per iter ')
    print('Loss ', loss.item())
    generate_checkpoint(model, iteration_number, loss.item())
    plot_image_and_metrics(rgb, iteration_number)
    return time.time()

If you have a checkpoint for this specific model and wish to continue training, specify the path to its .pt file (leave it empty to use the default checkpoint in the notebook's output folder).

In [None]:
# Ask where the checkpoint lives; an empty answer selects the default
# location in the notebook's output folder.
checkpoint_path = input('Path to checkpoint.pt')

if checkpoint_path:
    print("Checkpoint path setted to ",checkpoint_path)
else:
    checkpoint_path = '/kaggle/working/'+exp_name+'_checkpoint.pt'
    print("Default path ",checkpoint_path)

## Train model

In [None]:
# Create model.
# L_embed must be passed by keyword: the first positional parameter of
# TinyNerfModel is the depth D, so the previous positional call built a
# network with L_embed layers and always used the default encoding size.
# NOTE: checkpoints written by that shallower network do not match this one.
model = TinyNerfModel(L_embed=config_loader.L_embed)
model.to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=config_loader.learning_rate)

# Load checkpoint if it exists (returns 0 when starting from scratch)
iteration_number_checkpoint = load_checkpoint(model,checkpoint_path)

# PSNR, loss and iteration numbers are kept by the monitor
training_monitor = TrainingMonitor(exp_name)
training_monitor.save_exp_configuration(config_loader)
start_time = time.time()


# Iterations 0..N_iters inclusive, so the final iteration is also reported
for i in range(config_loader.N_iters+1):

    training_stage(H, W, focal, model)

    # testing_stage hands back the start of the next timing window; keep the
    # current one if nothing is returned.
    updated_start_time = testing_stage(H, W, focal, model, testimg, testpose, iteration_number_checkpoint + i, start_time)
    if updated_start_time is not None:
        start_time = updated_start_time

    # img_i = np.random.randint(images.shape[0])

    # # esto lo pusimos así por las dudas, para el futuro.
    # if config_loader.image_selection_mode == 'random':
    #     image_index = get_image_random_index(images)
    # elif config_loader.image_selection_mode == 'angle':
    #     image_index =  get_image_random_index(images) # <<placeholder, está igual que el then.

    # target_image = get_image_by_index(image_index)
    # target_image_pose = get_pose_by_index(image_index)
    # target_image.to(device)
    # target_image_pose.to(device)
    # rays_o, rays_d = get_rays(H, W, focal, target_image_pose)
    
    # target = images[img_i].to(device)
    # pose = poses[img_i].to(device)
    # rays_o, rays_d = get_rays(H, W, focal, pose)
    
    # rgb, depth, acc = render_rays(model, rays_o, rays_d, near=2., far=6., N_samples=config_loader.N_samples, chunk_size=config_loader.chunk_size, rand=True)
    # loss = torch.nn.functional.mse_loss(rgb, target_image)
    # loss.backward()
    # optimizer.step()
    # optimizer.zero_grad()

    # # testing with holdout image
    # rays_o, rays_d = get_rays(H, W, focal, testpose)
    # rgb, depth, acc = render_rays(model, rays_o, rays_d, near=2., far=6., N_samples=config_loader.N_samples, chunk_size=config_loader.chunk_size)

    # calculate_and_save_training_metrics(rgb, testimg, i)

    # # cambiamos esto para tener una única función render_image
    # if has_to_plot(i, config_loader.iterations_to_plot_image):
    #     print(i, (time.time() - start_time) / config_loader.iterations_to_plot_image, 'secs per iter ')
    #     print('Loss ',loss.item())
    #     start_time = time.time()
    #     render_image(H, W, focal, testpose, testimg, model, i, in_training = True)

    


    # Peak signal-to-noise-ratio metric for comparison
    # psnr = -10. * torch.log10(loss)

    # store information with the monitor.
    # training_monitor.store_snapshot(i, loss.item(), psnr.item())
    
    # if has_to_plot(i, config_loader.iterations_to_plot_image):
        # Render the holdout view for logging
        # render_image_training(H,W, focal, testpose, testimg, model)
        
print('Done')

# Interactive Visualization

In [None]:
%matplotlib inline
from ipywidgets import interactive, widgets


def trans_t(t):
    """4x4 matrix translating by ``t`` along the z axis."""
    return torch.tensor([
        [1,0,0,0],
        [0,1,0,0],
        [0,0,1,t],
        [0,0,0,1],
    ], dtype=torch.float32, device=device)


def rot_phi(phi):
    """4x4 rotation by the angle ``phi`` (a tensor, in radians) about the x axis."""
    cos_phi, sin_phi = torch.cos(phi), torch.sin(phi)
    return torch.as_tensor([
        [1,0,0,0],
        [0,cos_phi,-sin_phi,0],
        [0,sin_phi, cos_phi,0],
        [0,0,0,1],
    ], dtype=torch.float32, device=device)


def rot_theta(th):
    """4x4 rotation by the angle ``th`` (a tensor, in radians) about the y axis."""
    cos_th, sin_th = torch.cos(th), torch.sin(th)
    return torch.as_tensor([
        [cos_th,0,-sin_th,0],
        [0,1,0,0],
        [sin_th,0, cos_th,0],
        [0,0,0,1],
    ], dtype=torch.float32, device=device)


def pose_spherical(theta, phi, radius):
    """Build a 4x4 camera-to-world pose from two angles (in degrees) and a radius.

    The pose is the product ``swap @ rot_theta @ rot_phi @ trans_t(radius)``,
    where ``swap`` negates the x axis and exchanges the y and z axes.
    """
    phi_rad = torch.tensor(phi)/180.*np.pi
    theta_rad = torch.tensor(theta)/180.*np.pi

    pose = trans_t(radius)
    pose = rot_phi(phi_rad) @ pose
    pose = rot_theta(theta_rad) @ pose

    axis_swap = torch.tensor([[-1, 0, 0, 0],
                              [0, 0, 1, 0],
                              [0, 1, 0, 0],
                              [0, 0, 0, 1]], dtype=torch.float32, device=device)
    return axis_swap @ pose


def f(theta, phi, radius):
    """Widget callback: render and show the view for the given spherical pose."""
    pose = pose_spherical(theta, phi, radius)
    origins, directions = get_rays(H, W, focal, pose[:3,:4])
    rendered, _depth, _acc = render_rays(
        model, origins, directions,
        near=2., far=6.,
        N_samples=config_loader.N_samples,
        chunk_size=config_loader.chunk_size,
    )

    plt.figure(2, figsize=(20,6))
    plt.imshow(rendered.detach().cpu())
    plt.show()
    

def sldr(v, mi, ma):
    """Float slider starting at ``v`` and ranging over [``mi``, ``ma``] in steps of 0.01."""
    return widgets.FloatSlider(value=v, min=mi, max=ma, step=.01)


# One entry per argument of `f`: [name, [initial value, minimum, maximum]]
names = [
    ['theta', [100., 0., 360]],
    ['phi', [-30., -90, 0]],
    ['radius', [4., 3., 5.]],
]

# Bind one slider to each argument of `f`
interactive_plot = interactive(f, **{label: sldr(*bounds) for label, bounds in names})
# The last child of the widget is its output area; give it a fixed height
output = interactive_plot.children[-1]
output.layout.height = '350px'
interactive_plot

# Comparing model prediction vs Ground truth

In [None]:
from IPython import display

def update_image(rgb_map, pose_index):
    """Refresh the side-by-side comparison figure.

    Draws ``rgb_map`` on the module-level axis ``ax1`` and the image stored
    at ``pose_index`` on ``ax2``, then re-displays the module-level ``fig``.
    """
    ax1.imshow(rgb_map.cpu().detach())
    ax2.imshow(get_image_by_index(pose_index).cpu().detach())
    # Short pause (time_between_images seconds) between frames
    plt.pause(time_between_images)
    display.display(fig)
    # wait=True defers clearing until new output is available
    display.clear_output(wait=True)

def plot_image_by_pose_index(pose_index):
    """Render the view of pose ``pose_index`` and show it beside the stored image."""
    pose = get_pose_by_index(pose_index)
    origins, directions = get_rays(H, W, focal, pose[:3,:4])
    predicted, _depth, _acc = render_rays(
        model, origins, directions,
        near=2., far=6.,
        N_samples=config_loader.N_samples,
        chunk_size=config_loader.chunk_size,
    )
    update_image(predicted, pose_index)

# Pause between two consecutive comparisons (passed to plt.pause)
time_between_images = 0.01
# Left axis: model prediction, right axis: stored image
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 6))
ax1.set_title('Predicted')
ax2.set_title('Ground truth')
# NOTE(review): the loop starts at 1, so pose 0 is never shown - confirm this is intended.
for i in range(1,len(poses)):
    plot_image_by_pose_index(i)

# Render 360 Video

In [None]:
frames = []
# 120 views on a full turn around the scene (theta from 0 to 360 degrees,
# phi = -30 degrees, radius 4)
for th in tqdm(np.linspace(0., 360., 120, endpoint=False)):
    # Pure inference: disable gradient tracking to save memory
    with torch.no_grad():
        c2w = pose_spherical(th, -30., 4.)
        rays_o, rays_d = get_rays(H, W, focal, c2w[:3,:4])
        rgb, depth, acc = render_rays(model, rays_o, rays_d, near=2., far=6., N_samples=config_loader.N_samples, chunk_size=config_loader.chunk_size)
    rgb = rgb.detach().cpu().numpy()
    # Clamp to [0, 1] and convert to 8-bit pixel values
    parsed_img = (255*np.clip(rgb,0,1)).astype(np.uint8)

    # OpenCV writes frames in BGR order, so convert for the colors to appear correctly
    frame = cv2.cvtColor(parsed_img, cv2.COLOR_RGB2BGR)

    frames.append(frame)

import imageio as iio
# Do not call this variable `f`: that name is the callback of the interactive
# widget defined earlier, and rebinding it to a string would break that cell
# when it is run again.
video_path = 'video.mp4'

# Video encoding
fourcc = cv2.VideoWriter_fourcc(*'vp09') # We can also use mp4v but it doesn't show on the cell below
video = cv2.VideoWriter(video_path, fourcc, 30, (W,H))
for frame in frames:
    video.write(frame)
video.release()


Display Video

In [None]:
from IPython.display import HTML
from base64 import b64encode

# Read the encoded video through a context manager so the file handle is
# closed right away (the bare open(...).read() left it open).
with open('video.mp4','rb') as video_file:
    mp4 = video_file.read()

# Embed the video in the page as a base64 data URL
data_url = "data:video/mp4;base64," + b64encode(mp4).decode()
HTML("""
<video width=400 controls autoplay loop>
      <source src="%s" type="video/mp4">
</video>
""" % data_url)