# NeRF: Collision Handling in Instant Neural Graphics Primitives
Federico Montagna (fedemonti00@gmail.com)

## Install requirements

In [None]:
# %pip install -r requirements.txt
# %load_ext autoreload
# %autoreload 2

## Imports

In [None]:
from functions import *
from utils import *
from models import *
from datasets import *

In [None]:
# input = torch.randn(5, 3).requires_grad_(True)
# indices = torch.tensor([0., 0., 0., 0., 0.]).requires_grad_(True)
# print("Input:", input, input.shape)
# print("Indices:", indices, indices.shape)

# output = DifferentiableIndexing().apply(input, indices, Custom_Indexing())  # Call the function
# print("Output:", output, output.shape)

# # Assuming output is used in a computation...
# loss = output.sum()
# loss.backward()

# # Gradients will be available in input.grad and custom_index.indices.grad
# print(indices.grad) 

## Load configuration

In [None]:
# Load the experiment configuration from config.yaml.
# The file is decoded as UTF-8 explicitly, so parsing does not depend on the
# platform's default locale encoding. The handle is closed before the
# project-specific rule pass (config_apply_rules) post-processes the dict.
with open("config.yaml", "r", encoding="utf-8") as f:
    config = yaml.safe_load(f)
config = config_apply_rules(config)


In [None]:
# Seed the random number generators through the project's helper, using the
# seed configured under train_params.random_seed.
set_random_seed(
    seed=config["train_params"]["random_seed"],
)

## Load data

In [None]:
# Build the image dataset.
# Every key under train_params.dataset is forwarded as a keyword argument.
# The level-usage switch lives in the learnable_hash_model section of the
# config, so it is passed explicitly alongside the logging flag and device.
images_dataset = ImageDataset(
    **config["train_params"]["dataset"],
    device=device,
    should_use_all_levels=config["train_params"]["learnable_hash_model"]["should_use_all_levels"],
    should_log=config["log_flags"]["dataset"],
)


## Load models

In [None]:
# Hyper-parameters shared by both models.
# hash_table_size is stored as a string in config.yaml and evaluated with
# eval(), presumably so arithmetic such as powers of two can be written
# (confirm). The config file must therefore be trusted input.
_num_levels = config["train_params"]["dataset"]["num_levels"]
_hash_table_size = eval(config["train_params"]["dataset"]["hash_table_size"])

# Learnable hash function.
# should_use_all_levels (and every other key of the learnable_hash_model
# config section) reaches the constructor through the ** unpacking below.
learnableHashFunctionModel = LearnableHashFunctionModel(
    num_levels=_num_levels,
    input_size=images_dataset.get_num_dimensions(),
    hash_table_size=_hash_table_size,
    should_static_hash=config["flags"]["should_fast_hash"],
    should_log=config["log_flags"]["learnable_hash_model"],
    device=device,
    **config["train_params"]["learnable_hash_model"],
)

# Main model.
# It wraps the hash-function model and is sized to the whole dataset
# (batch_size = number of dataset items).
gngfModel = GNGFModel(
    learnable_hash_function_model=learnableHashFunctionModel,
    batch_size=len(images_dataset),
    num_levels=_num_levels,
    hash_table_size=_hash_table_size,
    should_learn_images=config["flags"]["should_learn_images"],
    should_log=config["log_flags"]["gngf_model"],
    device=device,
    **config["train_params"]["gngf_model"],
)
log(gngfModel, True)

# Count only parameters that will receive gradients.
_num_trainable_params = sum(
    param.numel() for param in gngfModel.parameters() if param.requires_grad
)
log(("Number of parameters:", _num_trainable_params), True, color=bcolors.WARNING)

In [None]:
# Loss used by every epoch of the training loop.
# hash_table_size, like the other numeric hyper-parameters below, is stored
# as a string in config.yaml and evaluated with eval(). The config file must
# therefore be trusted input.
loss_fn = Loss(
    hash_table_size=eval(config["train_params"]["dataset"]["hash_table_size"]),
    should_use_all_levels=config["train_params"]["learnable_hash_model"]["should_use_all_levels"],
    should_log=config["log_flags"]["loss"],
    device=device,
    **config["train_params"]["loss"],
)


# One parameter-group spec per trainable component.
# The keys ("NeRF", "learnable_hash") are the names later used to fetch each
# optimizer from `optimizers`. `opts` is appended in the same branches, and
# so in the same order, as `models_parameters` is filled.
models_parameters = {}
opts = []
if config["flags"]["should_learn_images"]:
    # GNGF hash-table features and MLP, each with its own lr / weight decay.
    models_parameters["NeRF"] = {
        "each": [ # TODO change name
            {
                "param": gngfModel.hash_tables.parameters(),
                "lr": eval(config["train_params"]["gngf_model"]["features"]["lr"]),
                "weight_decay": eval(config["train_params"]["gngf_model"]["features"]["weight_decay"]),
            },
            {
                "param": gngfModel.mlp.parameters(),
                "lr": eval(config["train_params"]["gngf_model"]["mlp"]["lr"]),
                "weight_decay": eval(config["train_params"]["gngf_model"]["mlp"]["weight_decay"]),
            },
        ],
        "betas": config["train_params"]["gngf_model"]["betas"],
        "eps": eval(config["train_params"]["gngf_model"]["eps"]),
    }
    # The optimizer is an expression evaluated with only `torch` in scope
    # (presumably something like "torch.optim.Adam" -- confirm in config.yaml).
    opts.append(eval(config["train_params"]["gngf_model"]["optimizer"], {"torch": torch}))

if not config["flags"]["should_fast_hash"]:
    # The hash function is only trained when the static (fast) hash is off.
    models_parameters["learnable_hash"] = {
        "each": [ # TODO change name
            {
                "param": gngfModel.learnable_hash_function_model.parameters(),
                "lr": eval(config["train_params"]["learnable_hash_model"]["lr"]),
                "weight_decay": eval(config["train_params"]["learnable_hash_model"]["weight_decay"]),
            },
        ],
        "betas": config["train_params"]["learnable_hash_model"]["betas"],
        "eps": eval(config["train_params"]["learnable_hash_model"]["eps"]),
    }
    opts.append(eval(config["train_params"]["learnable_hash_model"]["optimizer"], {"torch": torch}))

optimizers = get_optimizer(
    models_parameters=models_parameters,
    optimizers=opts
)

# The LR scheduler drives the learnable-hash optimizer only.
# That optimizer is registered above only when should_fast_hash is off, so
# the flag is checked here too. Without that check, a configured scheduler
# would look up a missing "learnable_hash" entry when fast hashing is on.
# A missing or null `scheduler` section, or a null scheduler name, means
# "no scheduler".
scheduler = None
scheduler_config = config["train_params"]["learnable_hash_model"].get("scheduler")
if (
    not config["flags"]["should_fast_hash"]
    and scheduler_config is not None
    and scheduler_config["name"] is not None
):
    scheduler = get_scheduler(
        scheduler=scheduler_config["name"],
        optimizer=optimizers["learnable_hash"],
        step_size=scheduler_config["step_size"],
        gamma_eta_min=eval(scheduler_config["gamma_eta_min"]),
    )


# Early stopping on the training loss (tolerance / min_delta from config).
early_stopper = EarlyStopper(
    tolerance=config["train_params"]["early_stopper"]["tolerance"],
    min_delta=eval(config["train_params"]["early_stopper"]["min_delta"]),
)


## Load checkpoints

In [None]:
# Optionally restore state from train_params.load_weights_path.
# The flags choose between continuing a previous training run and starting
# from pretrained weights. load_checkpoint hands back the config, model,
# optimizers and scheduler, which replace the current objects (presumably
# updated when a checkpoint was loaded -- see load_checkpoint).
config, gngfModel, optimizers, scheduler = load_checkpoint(
    weights_path=config["train_params"]["load_weights_path"],
    should_use_pretrained=config["flags"]["should_use_pretrained"],
    should_continue_training=config["flags"]["should_continue_training"],
    should_log=config["log_flags"]["load_checkpoint"],
    config=config,
    model=gngfModel,
    optimizers=optimizers,
    scheduler=scheduler,
)

## Weights & Biases (wandb) initialization

In [None]:
# Initialise wandb via the project helper, passing the full (possibly
# checkpoint-updated) configuration.
wandb_init(config)

## Training loop

In [None]:
# Main training loop: one `run` call per epoch, followed by checkpointing,
# wandb logging, LR scheduling and early stopping.
plt.ioff()
should_draw = False

# Resume from `start_epoch` when the (possibly checkpoint-updated) config
# provides one; otherwise start from epoch 0.
start_epoch = config["train_params"].get("start_epoch", 0)

# Loop-invariant: the gradient-clip value only applies when the hash function
# is learned (should_fast_hash off). Evaluate it once, not every epoch.
grad_clip_value = (
    eval(config["train_params"]["learnable_hash_model"]["gradient_clip"])
    if not config["flags"]["should_fast_hash"]
    else None
)

pbar = tqdm(range(start_epoch, config["train_params"]["epochs"]))
best_loss = np.inf
best_kl_div_loss = np.inf
best_psnr = 0.0

print_allocated_memory(config["log_flags"]["allocated_memory"])

for e in pbar:
    # Draw on the final epoch, every `drawing_rate` epochs, and on the extra
    # epoch that runs after early stopping triggers.
    # Never draw when drawing_rate is None.
    should_draw = (
        (e == config["train_params"]["epochs"] - 1)
        or (e % config["train_params"]["drawing_rate"] == 0)
        or early_stopper.early_stop
    ) if config["train_params"]["drawing_rate"] is not None else False

    # With the static (fast) hash, only the first epoch is drawn.
    if config["flags"]["should_fast_hash"] and e > 0:
        should_draw = False

    results = run(
        data=images_dataset[-1],
        model=gngfModel,
        loss_fn=loss_fn,
        optimizers=optimizers,
        scheduler=scheduler,
        gradient_clip=grad_clip_value,
        should_use_all_levels=config["train_params"]["learnable_hash_model"]["should_use_all_levels"],
        should_draw=should_draw,
        should_log=config["log_flags"],
        device=device,
    )

    # All epoch-level metrics are read from the last entry of `results`
    # (insertion order). Look it up once instead of per use.
    last_key = list(results)[-1]
    last_result = results[last_key]

    # save_checkpoint returns the (possibly updated) best values to carry forward.
    best_loss, best_kl_div_loss, best_psnr = save_checkpoint(
        epoch=e,
        run_name=config["wandb"]["name"],
        model=gngfModel,
        config=config,
        best_loss=best_loss,
        best_kl_div_loss=best_kl_div_loss,
        best_psnr=best_psnr,
        loss=last_result["loss"],
        kl_div_loss=np.mean(last_result["kl_div_losses"]),
        psnr=np.mean(last_result["images_psnr"]) if config["flags"]["should_learn_images"] else None,
        NeRF_optimizer=optimizers["NeRF"] if config["flags"]["should_learn_images"] else None,
        learnable_hash_optimizer=optimizers["learnable_hash"] if not config["flags"]["should_fast_hash"] else None,
        save_weights_rate=config["train_params"]["save_weights_rate"],
        save_weights_path=config["train_params"]["save_weights_path"],
        should_early_stop=early_stopper.early_stop,
        should_log=config["log_flags"]["save_checkpoint"],
    )

    wandb_log(
        e=e,
        batch_size=len(config["train_params"]["dataset"]["images_paths"]),
        results=results,
        lr=(
            scheduler.get_last_lr()[0]
            if scheduler is not None
            else None
        ),
        should_log=config["log_flags"]["wandb"],
        should_wandb=config["flags"]["should_wandb"],
    )

    # The scheduler only steps until the configured stop_epoch.
    if (
        scheduler is not None
        and config["train_params"]["learnable_hash_model"]["scheduler"]["stop_epoch"] > e
    ):
        scheduler.step()

    # The stopper is checked *before* it is updated. Once it triggers, one
    # more epoch runs (drawn and saved with should_early_stop=True) before
    # the loop exits.
    if early_stopper.early_stop:
        print("!!! Stopping at epoch:", e, "!!!")

        del results, last_result
        break

    early_stopper(last_result["loss"])

    # Build the one-line progress description directly. Stripping newlines
    # and indentation from a triple-quoted string would also mangle runs of
    # spaces inside the value reprs.
    pbar.set_description(
        f"Epoch {e} - {last_key} loss: {last_result['loss']:.5f}, "
        f"Collisions: {last_result['collisions']}, "
        f"PSNR: {last_result['images_psnr']}"
    )

    print_allocated_memory(config["log_flags"]["allocated_memory"])
    # Drop both references so the epoch's results can be freed before the next run.
    del results, last_result

In [None]:
# Tear-down: close the wandb run when wandb logging was enabled. Then release
# PyTorch's cached CUDA allocations, printing the allocated-memory report
# before and after the release.
if config["flags"]["should_wandb"]:
    wandb.finish()

_should_log_memory = config["log_flags"]["allocated_memory"]
print_allocated_memory(_should_log_memory)
torch.cuda.empty_cache()
print_allocated_memory(_should_log_memory)

In [None]:
# Inspection cell: list the shape of each per-level entry of
# "unique_grids_per_level" for the last dataset item. It is indexed by level
# number 0..len-1, matching how the entries are accessed elsewhere.
a = images_dataset[-1]["unique_grids_per_level"]
[a[level_idx].shape for level_idx in range(len(a))]