In [None]:
import mitsuba as mi
mi.set_variant("cuda_ad_rgb")
import numpy as np

import wandb

import nbimporter
nbimporter.options['only_defs'] = False

import importlib

import testing_scene as ts
importlib.reload(ts)

import vapl_utils as utils
importlib.reload(utils)

import torch

In [None]:
# Ask for the W&B run name; an empty answer falls back to a fixed default.
run_name = input("Enter run name: ").strip()
if not run_name:
    run_name = "custom-loss-function-run"

# Ask for the epoch count; anything that is not a plain non-negative integer
# falls back to 100.
# str.isdecimal() is used rather than str.isdigit(): isdigit() also accepts
# characters such as "²" that int() cannot parse, which would raise ValueError
# here instead of selecting the default.
num_epochs_input = input("Enter number of epoch (default 100): ").strip()
num_epochs = int(num_epochs_input) if num_epochs_input.isdecimal() else 100

# Start the W&B run and attach the experiment configuration to it.
# NOTE(review): apart from num_epochs, these values are literals recorded for
# bookkeeping; nothing visible in this notebook reads them back, so confirm
# they match what vapl_utils and the training loop actually use.
wandb.init(
    project="vapls-training",
    name=run_name,
    config={
        "learning_rate": 0.001,
        "max_depth": 4,
        "spp": 1,
        "num_epochs": num_epochs,
        "scene": "Cornell Box",
        "optimizer": "Adam",
    },
)

In [None]:
def weighted_loss(real, predicted, weight):
    """Mean of squared residuals, each normalised by a weighted prediction term.

    Every squared residual ``(real - predicted) ** 2`` is divided by
    ``weight * (predicted ** 2 + eps) + eps`` with ``eps = 0.01``. The squared
    prediction in the denominator is detached, so gradients reach
    ``predicted`` only through the numerator. ``eps`` is added twice: once
    inside the weighted term and once to the finished denominator, which keeps
    the denominator at ``eps`` when ``weight`` is zero.

    Args:
        real: Reference values.
        predicted: Predicted values; must provide ``.detach()`` (presumably a
            ``torch.Tensor`` -- confirm against the caller).
        weight: Multiplier applied to the detached normalisation term.

    Returns:
        The ``.mean()`` of the normalised squared residuals.
    """
    eps = 0.01
    squared_error = (real - predicted) ** 2
    # Detached: the normaliser is treated as a constant during backpropagation.
    predicted_energy = (predicted ** 2).detach()
    denominator = weight * (predicted_energy + eps) + eps
    return (squared_error / denominator).mean()

# Wrap the custom loss in the project's Loss type (see vapl_utils).
loss_function = utils.Loss(weighted_loss)

# Build the grid from the scene's bounding-box corners and call .cuda() on it.
# NOTE(review): the meaning of the positional arguments 1, 4, 8 is defined by
# utils.vapl_grid and is not visible here -- confirm and consider naming them.
scene_bounds = ts.scene.bbox()
field = utils.vapl_grid(scene_bounds.min, scene_bounds.max, 1, 4, 8).cuda()

# Integrator constructed from the field and the loss; passed to mi.render below.
rhs_integrator = utils.RHSIntegrator(field, loss_function)

def should_render(epoch):
    """Decide whether a preview image should be logged for ``epoch``.

    Previews become sparser as training progresses:

    * epochs below 50: every 5th epoch
    * epochs below 500: every 20th epoch
    * epochs below 2000: every 100th epoch
    * afterwards: every 250th epoch

    Args:
        epoch: Zero-based epoch index.

    Returns:
        True when ``epoch`` is a multiple of the interval for its range.
    """
    # (exclusive upper bound, interval) pairs, checked in ascending order.
    schedule = ((50, 5), (500, 20), (2000, 100))
    for upper_bound, interval in schedule:
        if epoch < upper_bound:
            return epoch % interval == 0
    return epoch % 250 == 0

# Training loop: one render per epoch through the RHS integrator, then the most
# recent entry of ``rhs_integrator.losses`` is logged.
# NOTE(review): no optimizer step is visible here; presumably the update
# happens inside utils.RHSIntegrator during mi.render -- confirm.
for epoch in range(num_epochs):
    rhs_image = mi.render(ts.scene, spp=1, integrator=rhs_integrator)

    # Gather everything for this epoch into one payload and log it once.
    # Each wandb.log() call advances W&B's step counter, so logging the loss
    # and the preview separately placed them on different steps and made the
    # step drift away from the epoch number on preview epochs.
    metrics = {"loss": rhs_integrator.losses[-1].item(), "epoch": epoch}
    if should_render(epoch):
        # Clamp to [0, 1] before the gamma curve: a negative pixel raised to a
        # fractional power is NaN, and clipping afterwards does not remove NaN.
        # For non-negative pixels this matches clipping after the gamma curve.
        preview = np.clip(rhs_image, 0, 1) ** (1.0 / 2.2)
        metrics["vapl training"] = wandb.Image(preview)
    wandb.log(metrics)

wandb.finish()