In [4]:
import drjit as dr
import mitsuba as mi

mi.set_variant("cuda_ad_rgb")

import os

import matplotlib.pyplot as plt
import numpy as np
import torch
from mitsuba.python.ad.integrators.common import ADIntegrator, mis_weight
from tqdm import tqdm

import tinycudann as tcnn 

In [5]:
# Harder scene: the Cornell box without its two boxes, plus a glass sphere.
scene_dict = mi.cornell_box()
del scene_dict["small-box"], scene_dict["large-box"]

# The "glass" BSDF entry is inserted ahead of the "ball" shape that refers to
# it through {"type": "ref", "id": "glass"}.
scene_dict.update(
    {
        "glass": {"type": "dielectric"},
        "ball": {
            "type": "sphere",
            "to_world": mi.ScalarTransform4f.scale([0.4, 0.4, 0.4]).translate(
                [0.5, 0, 0.5]
            ),
            "bsdf": {"type": "ref", "id": "glass"},
        },
    }
)

scene = mi.load_dict(scene_dict)

In [None]:
"""
Following the Hierarchical Light Sampling paper by AMD, each light is represented as a pair of Gaussian + vMF
    * Isotropic Gaussian, approximates the distribution of light positions
        μ - mean (centre of the Gaussian)
        σ² - variance (spread of the distribution)
    * vMF (von Mises-Fisher), approximates the directional distribution of radiant intensity (normalized spherical Gaussian)
        κ - sharpness
        ν - axis
        α - amplitude

So final struct that we have:
[
    vec3  mean
    float variance
    float sharpness
    vec3  axis
    vec3  amplitude
]

"""

class GaussianGrid(torch.nn.Module):
    """Two hash-grid encodings that map a surface point to Gaussian and vMF parameters.

    The intersection position is normalized by the bounding box and fed to two
    tiny-cuda-nn ``HashGrid`` encodings. The raw features are then constrained:

    * ``gaussians[:, 0:3]`` - mean, remapped with ``x * 0.5 + 0.5``
    * ``gaussians[:, 3]``   - variance, clamped to be non-negative (ReLU)
    * ``vmf[:, 0]``         - sharpness, made positive with ``exp``
    * ``vmf[:, 1:4]``       - axis, normalized to unit length **per point**
    * ``vmf[:, 4:7]``       - amplitude, clamped to be non-negative (ReLU)

    Any further columns produced by the encodings are returned unchanged.

    Args:
        bb_min: Lower corner of the bounding box used to normalize positions.
        bb_max: Upper corner of the bounding box used to normalize positions.
        num_gaussian_in_mixture: Passed to both encodings as ``n_levels``.
        num_param_per_gaussian: ``n_features_per_level`` of the Gaussian grid.
        num_param_per_vmf: ``n_features_per_level`` of the vMF grid.
    """

    def __init__(self, bb_min, bb_max, num_gaussian_in_mixture, num_param_per_gaussian, num_param_per_vmf):
        super().__init__()

        self.bb_min = bb_min
        self.bb_max = bb_max

        n_input_dims = 3

        # tiny-cuda-nn settings shared by both hash grids; each grid gets its
        # own dict so neither call observes a mutated configuration.
        shared_encoding = {
            "otype": "HashGrid",
            "base_resolution": 16,
            "n_levels": num_gaussian_in_mixture,
            "log2_hashmap_size": 22,
        }
        self.gaussian_grid = tcnn.Encoding(
            n_input_dims,
            {**shared_encoding, "n_features_per_level": num_param_per_gaussian},
        )
        self.vmf_grid = tcnn.Encoding(
            n_input_dims,
            {**shared_encoding, "n_features_per_level": num_param_per_vmf},
        )

    def forward(self, si):
        """Evaluate both grids at the intersection points of ``si``.

        Args:
            si: Object whose ``p`` attribute holds the intersection positions;
                ``(p - bb_min) / (bb_max - bb_min)`` must expose ``.torch()``.

        Returns:
            tuple: ``(gaussians, vmf)`` tensors with the same shape and dtype
            as the raw encoder outputs, with the constraints listed in the
            class docstring applied.
        """
        # Position of the ray-scene intersection relative to the scene bounding box.
        X = ((si.p - self.bb_min) / (self.bb_max - self.bb_min)).torch()

        # The outputs are assembled from new tensors rather than written in
        # place: in-place writes into the encoder output invalidate values
        # that autograd saved for the backward pass.
        raw_gaussians = self.gaussian_grid(X)
        mean = raw_gaussians[:, :3] * 0.5 + 0.5
        variance = torch.relu(raw_gaussians[:, 3:4])
        gaussians = torch.cat([mean, variance, raw_gaussians[:, 4:]], dim=1)

        raw_vmf = self.vmf_grid(X)
        sharpness = torch.exp(raw_vmf[:, 0:1])
        # Normalize every row on its own (a single norm over the whole batch
        # would not give unit axes). The computation runs in float32 because
        # the encoder output may be half precision - TODO confirm - and
        # `normalize` guards against zero-length axes via its epsilon.
        axis = torch.nn.functional.normalize(raw_vmf[:, 1:4].float(), dim=1).to(
            raw_vmf.dtype
        )
        amplitude = torch.relu(raw_vmf[:, 4:7])
        vmf = torch.cat([sharpness, axis, amplitude, raw_vmf[:, 7:]], dim=1)

        return gaussians, vmf

        

def get_camera_first_bounce(scene):
    """Shoot one parallel ray per pixel and return the first intersection.

    Ray origins lie on a 256 x 256 grid covering a 2 x 2 rectangle expressed
    in the local frame built from the viewing direction and shifted to
    (0, 1, 3). All rays share the normalized direction (0, -0.5, -1).

    Args:
        scene: Object providing ``ray_intersect(ray)``.

    Returns:
        tuple: ``(si, image_res)`` where ``si`` is the result of
        ``scene.ray_intersect`` and ``image_res`` is ``[256, 256]``.
    """
    image_res = [256, 256]
    plane_width = 2.0
    plane_height = 2.0
    eye = mi.Point3f(0, 1, 3)
    view_dir = dr.normalize(mi.Vector3f(0, -0.5, -1))

    # Sample coordinates along each axis of the origin plane.
    xs = dr.linspace(mi.Float, -plane_width / 2, plane_width / 2, image_res[0])
    ys = dr.linspace(mi.Float, -plane_height / 2, plane_height / 2, image_res[1])
    grid_x, grid_y = dr.meshgrid(xs, ys)

    # Move the planar offsets into world space and anchor them at the eye point.
    local_offsets = mi.Vector3f(grid_x, grid_y, 0)
    origins = mi.Frame3f(view_dir).to_world(local_offsets) + eye

    first_hit = scene.ray_intersect(mi.Ray3f(o=origins, d=view_dir))
    return first_hit, image_res

# Build the field on the GPU from the scene bounding box: n_levels = 1,
# 4 features per level for the Gaussian grid and 8 for the vMF grid.
field = GaussianGrid(
    bb_min=scene.bbox().min,
    bb_max=scene.bbox().max,
    num_gaussian_in_mixture=1,
    num_param_per_gaussian=4,
    num_param_per_vmf=8,
).cuda()

# Evaluate the field once at the first camera hits; the result is the cell output.
si, image_res = get_camera_first_bounce(scene)
field(si)
    

In [3]:
# It is possible to just use render_rhs and RHSIntegrator from 
# https://github.com/krafton-ai/neural-radiosity-tutorial-mitsuba3/blob/main/neural_radiosity.ipynb

def first_non_specular_or_null_si(scene, si, sampler):
    """Find the first non-specular or null surface interaction.

    Starting from ``si``, BSDF samples are followed for as long as the current
    hit is valid, is not a null face, and its BSDF does not carry the
    ``Smooth`` flag. At most 6 bounces are traced.

    A "null face" is a hit whose BSDF lacks the ``BackSide`` flag while the
    incident direction has ``wi.z < 0``.

    Args:
        scene (mi.Scene): Scene object.
        si (mi.SurfaceInteraction3f): Surface interaction to start from.
        sampler (mi.Sampler): Sampler object; one 1D and one 2D sample are
            drawn per traced bounce.

    Returns:
        tuple: A tuple containing three values:
            - si (mi.SurfaceInteraction3f): First non-specular or null surface interaction.
            - β (mi.Spectrum): The product of the weights of all previous BSDFs.
            - null_face (bool): A boolean mask indicating whether the surface is a null face or not.
    """
    # NOTE(review): an earlier comment here read "Instead of `bsdf.flags()`,
    # based on `bsdf_sample.sampled_type`", but the code below classifies
    # surfaces with `bsdf.flags()` - confirm which one is intended.
    with dr.suspend_grad():
        bsdf_ctx = mi.BSDFContext()

        depth = mi.UInt32(0)
        β = mi.Spectrum(1)
        bsdf = si.bsdf()

        null_face = ~mi.has_flag(bsdf.flags(), mi.BSDFFlags.BackSide) & (
            si.wi.z < 0
        )
        active = si.is_valid() & ~null_face  # non-null surface
        active &= ~mi.has_flag(bsdf.flags(), mi.BSDFFlags.Smooth)  # Delta surface

        loop = mi.Loop(
            name="first_non_specular_or_null_si",
            state=lambda: (sampler, depth, β, active, null_face, si, bsdf),
        )
        max_depth = 6
        loop.set_max_iterations(max_depth)

        while loop(active):
            # loop invariant: si is located at non-null and Delta surface
            # if si is located at null or Smooth surface, end loop
            bsdf_sample, bsdf_weight = bsdf.sample(
                bsdf_ctx, si, sampler.next_1d(), sampler.next_2d(), active
            )
            ray = si.spawn_ray(si.to_world(bsdf_sample.wo))
            si = scene.ray_intersect(
                ray, ray_flags=mi.RayFlags.All, coherent=dr.eq(depth, 0)
            )
            bsdf = si.bsdf(ray)

            β *= bsdf_weight
            depth[si.is_valid()] += 1

            # Lanes only enter the loop with null_face == False, so the flag
            # must be OR-ed in; an AND would leave it False forever and null
            # faces reached after specular bounces would never be reported.
            null_face |= ~mi.has_flag(bsdf.flags(), mi.BSDFFlags.BackSide) & (
                si.wi.z < 0
            )
            active &= si.is_valid() & ~null_face & (depth < max_depth)
            active &= ~mi.has_flag(bsdf.flags(), mi.BSDFFlags.Smooth)

    # return si at the first non-specular bounce or null face
    return si, β, null_face

def render_rhs(scene, model, si, sampler):
    """Estimate the radiance leaving ``si`` with one bounce plus a model query.

    Steps, all performed with gradients suspended:

    1. Add the emission evaluated at ``si``.
    2. Sample an emitter direction and add the weighted direct contribution.
    3. Sample the BSDF, trace the resulting ray, and advance to the first
       non-specular or null hit via ``first_non_specular_or_null_si``.
    4. Add the weighted emission at that hit.
    5. Query ``model`` at that hit and add its weighted output where the hit
       is valid, not a null face, and has zero emission.

    Args:
        scene: Scene used for ray tracing and emitter sampling.
        model: Callable evaluated as ``model(si)``; its result is wrapped in
            ``mi.Spectrum``.
        si: Surface interaction at which the estimate starts.
        sampler: Source of random numbers. The order of the ``next_1d`` /
            ``next_2d`` calls below determines which sample feeds which
            decision, so statements must not be reordered.

    Returns:
        tuple: ``(L, Le, out, w_nr, active_nr)`` where
            - ``L`` is the full estimate including the model term,
            - ``Le`` is the estimate without the model term,
            - ``out`` is the raw return value of ``model(si)``,
            - ``w_nr`` is the weight applied to the model output,
            - ``active_nr`` is the mask of lanes that use the model output.
    """
    with dr.suspend_grad():
        bsdf_ctx = mi.BSDFContext()

        # NOTE(review): `depth` is never read in this function.
        depth = mi.UInt32(0)
        L = mi.Spectrum(0)
        β = mi.Spectrum(1)
        # NOTE(review): η is updated below but does not feed into any returned value.
        η = mi.Float(1)
        # These three placeholders are overwritten after the BSDF sample below.
        prev_si = dr.zeros(mi.SurfaceInteraction3f)
        prev_bsdf_pdf = mi.Float(1.0)
        prev_bsdf_delta = mi.Bool(True)

        bsdf = si.bsdf()
        # Emission at the starting hit (β is still 1 here).
        Le = β * si.emitter(scene).eval(si)

        # emitter sampling: only for valid hits whose BSDF has the Smooth flag
        active_next = si.is_valid()
        active_em = active_next & mi.has_flag(bsdf.flags(), mi.BSDFFlags.Smooth)

        ds, em_weight = scene.sample_emitter_direction(
            si, sampler.next_2d(), True, active_em
        )
        # Drop lanes whose emitter sample has zero pdf.
        active_em &= dr.neq(ds.pdf, 0.0)

        wo = si.to_local(ds.d)
        bsdf_value_em, bsdf_pdf_em = bsdf.eval_pdf(bsdf_ctx, si, wo, active_em)
        # Delta emitter samples get weight 1; others are weighted by mis_weight.
        mis_em = dr.select(ds.delta, 1, mis_weight(ds.pdf, bsdf_pdf_em))
        Lr_dir = β * mis_em * bsdf_value_em * em_weight

        # bsdf sampling
        bsdf_sample, bsdf_weight = bsdf.sample(
            bsdf_ctx, si, sampler.next_1d(), sampler.next_2d(), active_next
        )

        # update: accumulate emission + direct light, then follow the BSDF sample
        L = L + Le + Lr_dir
        ray = si.spawn_ray(si.to_world(bsdf_sample.wo))
        η *= bsdf_sample.eta
        β *= bsdf_weight

        # Remember the current vertex; it is the reference point for the
        # emitter pdf of the next hit.
        prev_si = dr.detach(si, True)
        prev_bsdf_pdf = bsdf_sample.pdf
        prev_bsdf_delta = mi.has_flag(bsdf_sample.sampled_type, mi.BSDFFlags.Delta)

        si = scene.ray_intersect(ray, ray_flags=mi.RayFlags.All, coherent=True)

        ds = mi.DirectionSample3f(scene, si=si, ref=prev_si)

        # Weight for emission found by BSDF sampling; the emitter pdf query is
        # masked out where the previous BSDF sample was a Delta sample.
        mis = mis_weight(
            prev_bsdf_pdf,
            scene.pdf_emitter_direction(prev_si, ds, ~prev_bsdf_delta),
        )
        
        # Advance to the first non-specular / null hit and fold in the
        # throughput accumulated on the way.
        # NOTE(review): `mis` was computed for the hit before this call;
        # confirm it is meant to be reused for the advanced hit.
        si, β2, null_face = first_non_specular_or_null_si(scene, si, sampler)
        β *= β2

        L += β * mis * si.emitter(scene).eval(si)

        # NOTE(review): `out` is passed to mi.Spectrum(...) below, while
        # GaussianGrid.forward in this file returns a tuple of two tensors -
        # confirm the model supplied here returns something convertible.
        out = model(si)
        # The model term is used only for valid, non-null hits with zero emission.
        active_nr = (
            si.is_valid()
            & ~null_face
            & dr.eq(si.emitter(scene).eval(si), mi.Spectrum(0))
        )

        # From here on `Le` holds everything accumulated so far (without the model term).
        Le = L
        w_nr = β * mis
        L = Le + dr.select(active_nr, w_nr * mi.Spectrum(out), 0)

    return L, Le, out, w_nr, active_nr

In [None]:
class RHSIntegrator(ADIntegrator):
    """Integrator that renders with ``render_rhs`` using a supplied model.

    Args:
        model: Object passed to ``render_rhs`` as the model; it must provide
            ``eval()`` and ``train()``.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    def sample(self, mode, scene, sampler, ray,depth, reparam, active, **kwargs,):
        """Return the radiance estimate for ``ray``.

        The ray is intersected with the scene, advanced to the first
        non-specular or null hit, and ``render_rhs`` is evaluated there with
        the model in eval mode and PyTorch gradients disabled.

        Returns:
            tuple: ``(β * L, si.is_valid(), None)`` where ``β`` is the
            throughput up to the first non-specular hit and ``L`` is the
            first value returned by ``render_rhs``.
        """
        # Remember the current mode so it can be restored afterwards. Objects
        # without a `training` attribute are put back into train mode, which
        # is what this method always did before.
        was_training = getattr(self.model, "training", True)
        self.model.eval()
        try:
            with torch.no_grad():
                ray = mi.Ray3f(dr.detach(ray))
                si = scene.ray_intersect(
                    ray, ray_flags=mi.RayFlags.All, coherent=dr.eq(depth, 0)
                )

                # update si and bsdf with the first non-specular ones
                si, β, _ = first_non_specular_or_null_si(scene, si, sampler)
                L, _, _, _, _ = render_rhs(scene, self.model, si, sampler)
        finally:
            # Restore the mode even if rendering raised.
            self.model.train(was_training)
            torch.cuda.empty_cache()

        return β * L, si.is_valid(), None

# Render the scene through the RHS integrator backed by `field`.
# NOTE(review): `M` (samples per pixel) is not defined in any cell visible
# here - confirm it is set before this cell runs.
rhs_integrator = RHSIntegrator(model=field)
rhs_image = mi.render(scene, integrator=rhs_integrator, spp=M)

fig, ax = plt.subplots()
ax.axis('off')  # No axes around the image
fig.patch.set_visible(False)  # Transparent figure background
fig.tight_layout()  # Trim the surrounding white space
# Apply the 1/2.2 exponent, then clamp to [0, 1] for display.
ax.imshow(np.clip(rhs_image ** (1.0 / 2.2), 0, 1))