In [1]:
import drjit as dr
import mitsuba as mi
import matplotlib.pyplot as plt
import numpy as np

import os
import sys
ROOT_DIR = os.path.dirname(os.getcwd())
sys.path.append(ROOT_DIR)

from utils.utils import mse, image_to_bm

from utils.problem import MitsubaProblem
from utils.problems.bunny import BunniesProblem

mi.set_variant('cuda_ad_rgb')

In [2]:
single_bunny_pb = BunniesProblem(nb_bunnies=1, colored=False)

In [3]:
img_ref, bm_ref = single_bunny_pb.render(seed=0, spp=1024)
bm_ref

In [4]:
def get_forward_gradient(pb: MitsubaProblem, vector: np.ndarray, seed: int, spp: int, grad_flags = None):
    opt = mi.ad.Adam(lr=0.01)
    pb.set_params_from_vector(opt, vector)
    scene, params = pb.initialize_scene()
    pb.apply_transformations(params, opt)
    img = mi.render(scene, params, seed=seed, spp=spp)
    if grad_flags is None:
        dr.forward(opt["angle0"])
    else:
        dr.forward(opt["angle0"], grad_flags)
    grad_image = dr.grad(img)
    return grad_image

def get_backward_gradient(pb: MitsubaProblem, vector: np.ndarray, loss_fn, seed: int, spp: int, grad_flags = None):
    opt = mi.ad.Adam(lr=0.01)
    pb.set_params_from_vector(opt, vector)
    scene, params = pb.initialize_scene()
    pb.apply_transformations(params, opt)
    img = mi.render(scene, params, seed=seed, spp=spp)
    loss = loss_fn(img)
    if grad_flags is None:
        dr.backward(loss)
    else:
        dr.backward(loss, flags=grad_flags)
    return opt


In [5]:
forward_grad = get_forward_gradient(single_bunny_pb, np.array([0.0, 0.0, 0.1]), seed=0, spp=1024, grad_flags = dr.ADFlag.BackPropVar | dr.ADFlag.ClearVertices)
image_to_bm(forward_grad / (forward_grad.max_() - forward_grad.min_()))

In [6]:
vector = np.array([0.1, 0.1, 0.1])
loss_fn = lambda img: mse(img, bm_ref)
spp = 1024
key = "angle0"

grad = get_backward_gradient(single_bunny_pb, vector, loss_fn, seed=0, spp=spp, grad_flags = dr.ADFlag.BackPropGrad | dr.ADFlag.ClearVertices)
print(f"grad: {dr.grad(grad[key])}")

sq_grad_sum = get_backward_gradient(single_bunny_pb, vector, loss_fn, seed=0, spp=spp, grad_flags = dr.ADFlag.BackPropVar | dr.ADFlag.ClearVertices)
print(f"sq_grad_sum: {dr.grad(sq_grad_sum[key])}")

ones = get_backward_gradient(single_bunny_pb, vector, loss_fn, seed=0, spp=spp, grad_flags = dr.ADFlag.BackPropOnes | dr.ADFlag.ClearVertices)
print(f"ones: {dr.grad(ones[key])}")

grad: [-6.992592811584473]
sq_grad_sum: [481.3227233886719]


RuntimeError: drjit-autodiff: ad_traverse(): gradient propagation encountered variable a24977 ("fmadd") with an invalid gradient size (expected size 1, actual size 208349)!