In [None]:
# importing the necessary libraries
from utils import config
from utils.nerf_trainer import NeRF
from utils.nerf import get_nerf_model, render_rgb_depth
from utils.data import *
from utils.train_monitor import get_train_monitor
from rendering import render_videos



import tensorflow as tf
import numpy as np

# Seed every RNG this notebook may touch, for reproducibility.
# TensorFlow's global seed drives the tf.data shuffles in the pipeline cell.
# NumPy and Python's `random` are seeded too, in case the project helpers in
# `utils` draw from them (not visible from this notebook — seeding is harmless if unused).
import random

SEED = 42
random.seed(SEED)
np.random.seed(SEED)
tf.random.set_seed(SEED)

In [None]:
# Parse the per-split JSON descriptors (train, val, test — in that order).
json_train_data, json_val_data, json_test_data = (
    read_json(split_json_path)
    for split_json_path in (config.TRAIN_JSON, config.VAL_JSON, config.TEST_JSON)
)

In [None]:
def _load_split(json_data):
    """Load one dataset split from its parsed JSON descriptor.

    Args:
        json_data: the parsed JSON for the split, as returned by `read_json`.

    Returns:
        (image_paths, camera_to_world, images) where `image_paths` and `images`
        come straight from `get_image_c2w` / `GetImages`, and `camera_to_world`
        has been converted to a float32 tensor.
    """
    image_paths, camera_to_world = get_image_c2w(jsonData=json_data,
                                                 datasetPath=config.DATASET_PATH)
    # GetImages (project helper) loads the images at these paths from disk.
    images = GetImages(image_paths)
    # Convert the poses the same way for every split. Previously only the
    # validation split was cast, so train/test poses could reach `map_fn`
    # with a different dtype than val.
    camera_to_world = tf.cast(np.array(camera_to_world), tf.float32)
    return image_paths, camera_to_world, images


train_image_paths, train_camera_to_world, train_images = _load_split(json_train_data)
val_image_paths, val_camera_to_world, val_images = _load_split(json_val_data)
test_image_paths, test_camera_to_world, test_images = _load_split(json_test_data)


In [None]:
# Wrap each split's images and camera poses as tf.data datasets, sliced along
# the first axis (presumably one element per view — confirm against GetImages).
(train_image_datasets,
 val_image_datasets,
 test_image_datasets) = [
    tf.data.Dataset.from_tensor_slices(split_images)
    for split_images in (train_images, val_images, test_images)
]
(train_pose_dataset,
 val_pose_datasets,
 test_pose_datasets) = [
    tf.data.Dataset.from_tensor_slices(split_poses)
    for split_poses in (train_camera_to_world, val_camera_to_world, test_camera_to_world)
]

In [None]:
def _shuffle_batch_repeat(dataset):
    """Input pipeline shared by the train and val splits.

    Shuffles with a buffer of config.BATCH_SIZE, batches while dropping the
    final ragged batch, repeats the data twice per epoch, then prefetches.
    """
    shuffled = dataset.shuffle(config.BATCH_SIZE)
    batched = shuffled.batch(config.BATCH_SIZE, drop_remainder=True,
                             num_parallel_calls=config.AUTO)
    return batched.repeat(2).prefetch(config.AUTO)


# Generate rays from every camera pose via the project's map_fn.
train_rays_dataset, val_rays_dataset, test_rays_dataset = [
    pose_ds.map(map_fn, num_parallel_calls=config.AUTO)
    for pose_ds in (train_pose_dataset, val_pose_datasets, test_pose_datasets)
]

# Pair each image with the rays generated from its pose, then build the
# per-split pipelines. Train is built before val so the seeded shuffle ops are
# created in the same order as before.
train_dataset = _shuffle_batch_repeat(
    tf.data.Dataset.zip((train_image_datasets, train_rays_dataset))
)
val_dataset = _shuffle_batch_repeat(
    tf.data.Dataset.zip((val_image_datasets, val_rays_dataset))
)
# The test split is only batched (keeping any partial last batch) and prefetched.
test_dataset = (
    tf.data.Dataset.zip((test_image_datasets, test_rays_dataset))
    .batch(config.BATCH_SIZE)
    .prefetch(config.AUTO)
)

In [None]:
# Monitoring callback built on the test split with render_rgb_depth.
# Presumably writes RGB/depth renders under config.OUTPUT_IMAGE_PATH (see
# utils.train_monitor). It is not enabled in the training cell by default.
train_monitor_callback = get_train_monitor(
    test_dataset,
    render_rgb_depth=render_rgb_depth,
    OUTPUT_IMAGE_PATH=config.OUTPUT_IMAGE_PATH,
)

In [None]:
# Depth of the NeRF MLP passed to get_nerf_model.
NERF_NUM_LAYERS = 8

# Number of sampled positions per image: height * width * config.NUM_SAMPLES.
num_pos = config.IMAGE_HEIGHT * config.IMAGE_WIDTH * config.NUM_SAMPLES
nerf_model = get_nerf_model(num_layers=NERF_NUM_LAYERS, num_pos=num_pos)

# The NeRF trainer wraps the MLP (exposed later as `model.nerf_model`).
model = NeRF(nerf_model)
model.compile(
    optimizer=tf.keras.optimizers.Adam(),
    loss_fn=tf.keras.losses.MeanSquaredError(),
)

In [None]:
# Train in NUM_TRAINING_ROUNDS rounds of config.EPOCHS epochs each, overwriting
# the saved NeRF MLP at config.MODEL_PATH after every round.
NUM_TRAINING_ROUNDS = 10

# The monitoring callback is built earlier but off by default; set to True to use it.
USE_TRAIN_MONITOR = False
fit_callbacks = [train_monitor_callback] if USE_TRAIN_MONITOR else []

for round_idx in range(NUM_TRAINING_ROUNDS):
    model.fit(
        train_dataset,
        validation_data=val_dataset,
        # With initial_epoch, Keras treats `epochs` as the final epoch index.
        # Each round still trains config.EPOCHS epochs, but epoch numbering (logs,
        # History, epoch-indexed callbacks) continues instead of restarting at 1.
        initial_epoch=round_idx * config.EPOCHS,
        epochs=(round_idx + 1) * config.EPOCHS,
        callbacks=fit_callbacks,
    )
    model.nerf_model.save(config.MODEL_PATH, save_format="tf")
    

In [None]:
# Reload the NeRF MLP saved by the training loop. compile=False skips restoring
# any training configuration, since the model is only used for rendering here.
# NOTE: this rebinds `model` from the NeRF trainer wrapper to the bare MLP.
model = tf.keras.models.load_model(
    config.MODEL_PATH,
    compile=False,
)

In [None]:
# Render the output videos with the NeRF MLP currently bound to `model`.
render_videos(
    nerf_model=model,
)