In [6]:
import drjit as dr
import mitsuba as mi

mi.set_variant("cuda_ad_rgb")

import sys

import matplotlib.pyplot as plt
import numpy as np
import torch
from mitsuba.python.ad.integrators.common import ADIntegrator, mis_weight
from tqdm import tqdm
import math

import tinycudann as tcnn 

In [None]:
# Harder scene: the Cornell box with both boxes removed and a glass sphere added.
# Entries are inserted in order, so the "glass" BSDF is declared before the
# "ball" shape that references it by id.
scene_dict = mi.cornell_box()
for box_name in ("small-box", "large-box"):
    scene_dict.pop(box_name)

scene_dict["glass"] = {"type": "dielectric"}
ball_to_world = mi.ScalarTransform4f.scale([0.4, 0.4, 0.4]).translate([0.5, 0, 0.5])
scene_dict["ball"] = {
    "type": "sphere",
    "to_world": ball_to_world,
    "bsdf": {"type": "ref", "id": "glass"},
}

scene = mi.load_dict(scene_dict)

film = scene.sensors()[0].film()
print(f"Image resolution: {film.size()}")

In [8]:
"""
In relation to the Hierarchical Light Sampling paper by AMD, each light is represented as a pair of an isotropic Gaussian and a vMF distribution:
    * Isotropic Gaussian, approximates the distribution of light positions
        μ - mean (centre of the Gaussian)
        σ² - variance (spread of the distribution)
    * vMF, approximates the directional distribution of radiant intensity (a normalized spherical Gaussian)
        κ - sharpness
        ν - axis
        α - amplitude

So final struct that we have:
[
    vec3  mean
    float variance
    float sharpness
    vec3  axis
    vec3  amplitude
]

We call each Gaussian–vMF pair a Virtual Anisotropic Point Light (VAPL).

"""


class vapl_grid(torch.nn.Module):
    """Hash-grid field that predicts VAPL (Gaussian + vMF) parameters at scene positions.

    Two tiny-cuda-nn HashGrid encodings share one configuration and differ only
    in the number of features per level: `gaussian_grid` uses
    `num_param_per_gaussian` and `vmf_grid` uses `num_param_per_vmf`.
    `num_gaussian_in_mixture` is passed as the number of hash-grid levels.

    Args:
        bb_min, bb_max: corners of the scene bounding box, used to map
            positions into the unit cube before encoding.
        num_gaussian_in_mixture: number of hash-grid levels.
        num_param_per_gaussian: features per level of the Gaussian grid (>= 4 used).
        num_param_per_vmf: features per level of the vMF grid (>= 7 used).

    NOTE(review): `forward` post-processes only the first 4 Gaussian and the
    first 7 vMF output columns, i.e. it effectively assumes a single level
    (`num_gaussian_in_mixture == 1`) — TODO confirm for larger mixtures.
    """

    def __init__(self, bb_min, bb_max, num_gaussian_in_mixture, num_param_per_gaussian, num_param_per_vmf):
        super().__init__()

        self.bb_min = bb_min
        self.bb_max = bb_max

        # tiny-cuda-nn config for the hash grids.
        config = {
            "encoding": {
                "otype": "HashGrid",
                "base_resolution": 16,
                "n_levels": num_gaussian_in_mixture,
                "n_features_per_level": num_param_per_gaussian,
                "log2_hashmap_size": 22,
            },
        }
        n_input_dims = 3
        self.gaussian_grid = tcnn.Encoding(n_input_dims, config["encoding"])

        # Same encoding, only the number of features per level differs.
        config["encoding"]["n_features_per_level"] = num_param_per_vmf
        self.vmf_grid = tcnn.Encoding(n_input_dims, config["encoding"])

    def forward(self, si):
        """Evaluate the Gaussian and vMF parameters at the intersection points of `si`.

        Only `si.p` is read. Positions are normalized into the unit cube of the
        scene bounding box before being encoded.

        Returns:
            (gaussians, vmf): torch tensors with one row per intersection point
            (assumes `.torch()` yields one row per point — TODO confirm).
            Column layout after post-processing:
              gaussians[:, 0:3] mean, gaussians[:, 3] variance (>= 0);
              vmf[:, 0] sharpness (> 0), vmf[:, 1:4] unit axis,
              vmf[:, 4:7] amplitude (>= 0).
            Any further columns are passed through unchanged.
        """
        # Position of the ray-scene intersection, normalized to the scene bbox.
        X = ((si.p - self.bb_min) / (self.bb_max - self.bb_min)).torch()

        # Tensors are assembled with torch.cat instead of being written in place,
        # so autograd can still backpropagate through every activation.
        gaussians_raw = self.gaussian_grid(X)
        eps = 1e-2
        # NOTE(review): raw / eps * 0.5 + 0.5 scales raw features by 50 around
        # 0.5 — TODO confirm the intended range/space of the mean.
        mean = gaussians_raw[:, :3] / eps * 0.5 + 0.5
        variance = torch.relu(gaussians_raw[:, 3:4])
        gaussians = torch.cat([mean, variance, gaussians_raw[:, 4:]], dim=1)

        vmf_raw = self.vmf_grid(X)
        sharpness = torch.exp(vmf_raw[:, 0:1])
        axis = vmf_raw[:, 1:4]
        # Normalize every row's axis independently; the clamp guards zero vectors.
        axis_norm = torch.linalg.vector_norm(axis, dim=1, keepdim=True)
        axis = axis / torch.clamp(axis_norm, min=torch.finfo(axis.dtype).tiny)
        amplitude = torch.relu(vmf_raw[:, 4:7])
        vmf = torch.cat([sharpness, axis, amplitude, vmf_raw[:, 7:]], dim=1)

        return gaussians, vmf

def get_camera_first_bounce(
    scene,
    cam_origin=(0, 1, 3),
    cam_dir=(0, -0.5, -1),
    cam_width=2.0,
    cam_height=2.0,
    image_res=(4, 4),
):
    """Trace one ray per pixel of a small parallel-ray camera and return the first hits.

    Ray origins lie on a `cam_width` x `cam_height` plane centred at
    `cam_origin` and perpendicular to `cam_dir`, sampled on an
    `image_res[0]` x `image_res[1]` grid. All rays travel along the normalized
    `cam_dir`. The defaults reproduce the previously hard-coded camera.

    Returns:
        (si, image_res): the surface interactions from `scene.ray_intersect`
        and the resolution as a list `[res_x, res_y]`.
    """
    origin = mi.Point3f(*cam_origin)
    direction = dr.normalize(mi.Vector3f(*cam_dir))
    res_x, res_y = image_res

    x, y = dr.meshgrid(
        dr.linspace(mi.Float, -cam_width / 2, cam_width / 2, res_x),
        dr.linspace(mi.Float, -cam_height / 2, cam_height / 2, res_y),
    )
    # Offsets in the camera plane, mapped to world space with a frame whose
    # normal is the viewing direction.
    ray_origin = mi.Frame3f(direction).to_world(mi.Vector3f(x, y, 0)) + origin
    ray = mi.Ray3f(o=ray_origin, d=direction)
    si = scene.ray_intersect(ray)

    return si, [res_x, res_y]
    

In [9]:
class sg_lobe:
    """A spherical Gaussian lobe described by its axis, sharpness and log-amplitude.

    Note: the constructor argument is called `amplitude`, but it is stored as
    `log_amplitude`. Callers (e.g. `sg_product`) pass the logarithm of the
    amplitude.
    """

    def __init__(self, axis, sharpness, amplitude) -> None:
        # Attributes are assigned in a fixed order: axis, sharpness, log-amplitude.
        self.axis = axis
        self.sharpness = sharpness
        self.log_amplitude = amplitude

    
# product of 2 spherical gaussians
def sg_product(axis1, sharpness1, axis2, sharpness2):
    """Product of two spherical Gaussians (SGs), returned as a single SG lobe.

    For SGs G_i(v) = exp(s_i * (dot(v, a_i) - 1)), the product is an SG with
      axis          = (s1*a1 + s2*a2) / |s1*a1 + s2*a2|
      sharpness     = |s1*a1 + s2*a2|
      log-amplitude = |s1*a1 + s2*a2| - s1 - s2.
    The log-amplitude is computed in the cancellation-free form
    -s1*s2*|a1 - a2|^2 / (sharpness + s1 + s2), which holds for unit axes.

    Args:
        axis1, axis2: unit axes as 1-D torch tensors (presumably length 3).
        sharpness1, sharpness2: sharpness values (scalars or 0-d tensors).

    Returns:
        sg_lobe(axis, sharpness, log_amplitude).
    """
    axis = axis1 * sharpness1 + axis2 * sharpness2
    sharpness = torch.linalg.vector_norm(axis)
    # FLT_MIN-style floor for the divisions below, valid in the tensor's own
    # dtype (sys.float_info.min underflows to 0 in float32).
    tiny = torch.finfo(axis.dtype).tiny

    d = axis1 - axis2
    len2 = torch.dot(d, d)
    denominator = torch.clamp(sharpness + sharpness1 + sharpness2, min=tiny)
    log_amplitude = -sharpness1 * sharpness2 * len2 / denominator

    return sg_lobe(axis / torch.clamp(sharpness, min=tiny), sharpness, log_amplitude)

# [Tokuyoshi et al. 2024 "Hierarchical Light Sampling with Accurate Spherical Gaussian Lighting (Supplementary Document)" Listing. 5]
def upper_sg_clamp_cosine_integral_over_two_pi(sharpness):
    """Integral of an SG times the clamped cosine, SG axis aligned with the normal, over 2*pi.

    Evaluates (expm1(-s) + s) / s^2 for s = sharpness. For s <= 0.5 it uses the
    Taylor series instead, for numerical stability. Works element-wise on
    tensors and also accepts Python scalars; the result is always a tensor.

    [Tokuyoshi et al. 2024 "Hierarchical Light Sampling with Accurate Spherical
    Gaussian Lighting (Supplementary Document)" Listing. 5]
    """
    s = torch.as_tensor(sharpness)
    if not s.is_floating_point():
        s = s.to(torch.get_default_dtype())
    small = s <= 0.5

    # Taylor-series approximation for the numerical stability (Horner form).
    taylor = (((((((-1.0 / 362880.0) * s + 1.0 / 40320.0) * s - 1.0 / 5040.0) * s + 1.0 / 720.0) * s - 1.0 / 120.0) * s + 1.0 / 24.0) * s - 1.0 / 6.0) * s + 0.5

    # Where the Taylor branch is selected, feed the exact branch a harmless value
    # so it never divides by zero (keeps gradients through torch.where finite).
    s_exact = torch.where(small, torch.ones_like(s), s)
    exact = (torch.expm1(-s_exact) + s_exact) / (s_exact * s_exact)

    return torch.where(small, taylor, exact)

# [Tokuyoshi et al. 2024 "Hierarchical Light Sampling with Accurate Spherical Gaussian Lighting (Supplementary Document)" Listing. 6]
def lower_sg_clamp_cosine_integral_over_two_pi(sharpness):
    """Integral of an SG times the clamped cosine, SG axis opposite the normal, over 2*pi.

    Evaluates e * (-expm1(-s) - s*e) / s^2 with e = exp(-s) and s = sharpness.
    For s <= 0.5 it uses e times the Taylor series instead, for numerical
    stability. Works element-wise on tensors and also accepts Python scalars;
    the result is always a tensor.

    [Tokuyoshi et al. 2024 "Hierarchical Light Sampling with Accurate Spherical
    Gaussian Lighting (Supplementary Document)" Listing. 6]
    """
    s = torch.as_tensor(sharpness)
    if not s.is_floating_point():
        s = s.to(torch.get_default_dtype())
    e = torch.exp(-s)
    small = s <= 0.5

    # Taylor-series approximation for the numerical stability (Horner form).
    taylor = e * (((((((((1.0 / 403200.0) * s - 1.0 / 45360.0) * s + 1.0 / 5760.0) * s - 1.0 / 840.0) * s + 1.0 / 144.0) * s - 1.0 / 30.0) * s + 1.0 / 8.0) * s - 1.0 / 3.0) * s + 0.5)

    # Safe input for the exact branch where the Taylor branch is selected.
    s_exact = torch.where(small, torch.ones_like(s), s)
    e_exact = torch.exp(-s_exact)
    exact = e_exact * (-torch.expm1(-s_exact) - s_exact * e_exact) / (s_exact * s_exact)

    return torch.where(small, taylor, exact)

# Approximate product integral of an SG and clamped cosine / pi.
# [Tokuyoshi et al. 2024 "Hierarchical Light Sampling with Accurate Spherical Gaussian Lighting (Supplementary Document)" Listing. 7]
def sg_clamp_cosine_product_integral_over_pi(cosine, sharpness):
    """Approximate product integral of an SG and the clamped cosine, divided by pi.

    Interpolates between the SG-axis-opposite-normal ("lower") and
    SG-axis-aligned-with-normal ("upper") integrals. The interpolation factor is
    an erfc-based fit in the cosine between the SG axis and the normal.

    Args:
        cosine: cosine between the SG axis and the surface normal (tensor or float).
        sharpness: SG sharpness (tensor or float, > 0).

    [Tokuyoshi et al. 2024 "Hierarchical Light Sampling with Accurate Spherical
    Gaussian Lighting (Supplementary Document)" Listing. 7]
    """
    A = 2.7360831611272558028247203765204
    B = 17.02129778174187535455530451145
    C = 4.0100826728510421403939290030394
    D = 15.219156263147210594866010069381
    E = 76.087896272360737270901154261082

    cosine = torch.as_tensor(cosine)
    sharpness = torch.as_tensor(sharpness)

    # Fitted t(sharpness). The numerator is (s + A) * s + B, so that
    # t -> sqrt(s / 2) for large s. That is the Gaussian limit of an SG with
    # variance 1/s, where the lerp factor below reduces to z*Phi + phi terms.
    t = sharpness * torch.sqrt(0.5 * ((sharpness + A) * sharpness + B) / (((sharpness + C) * sharpness + D) * sharpness + E))
    tz = t * cosine
    # Guard the division by t (t == 0 only when sharpness == 0, where the term -> 0).
    t_safe = torch.clamp(t, min=torch.finfo(t.dtype).tiny)

    INV_SQRTPI = 0.56418958354775628694807945156077  # 1 / sqrt(pi)
    # Conservative clamp (machine epsilon / 2) of the lerp factor, following the reference.
    CLAMPING_THRESHOLD = 0.5 * np.finfo(np.float32).eps
    # saturate(max(x, threshold)) == clamp(x, threshold, 1) because 0 < threshold < 1.
    lerp_factor = torch.clamp(
        0.5 * (cosine * torch.erfc(-tz) + torch.erfc(t))
        - 0.5 * INV_SQRTPI * torch.exp(-tz * tz) * torch.expm1(t * t * (cosine * cosine - 1.0)) / t_safe,
        min=CLAMPING_THRESHOLD,
        max=1.0,
    )

    # Interpolation between the lower and upper hemispherical integrals.
    lower_integral = lower_sg_clamp_cosine_integral_over_two_pi(sharpness)
    upper_integral = upper_sg_clamp_cosine_integral_over_two_pi(sharpness)

    return 2.0 * torch.lerp(lower_integral, upper_integral, lerp_factor)

def sggx(m, roughness_mat):
    """Evaluate the SGGX normal distribution D(m) for a 2x2 projected roughness matrix.

    D(m) = 1 / (pi * sqrt(det R) * l^2), with
    l = m_xy^T adj(R) m_xy / det(R) + m_z^2, where adj(R) is the adjugate of R.

    Args:
        m: microfacet normal as a 1-D tensor (x, y, z) in the tangent frame.
        roughness_mat: 2x2 roughness matrix R.

    Returns:
        D(m) as a 0-d tensor. Differentiable w.r.t. both inputs.
    """
    # Clamp the determinant away from zero (fixed floor of 1e-7, not FLT_MIN).
    det = torch.det(roughness_mat).clamp(min=1e-7)
    # Adjugate built with torch.stack (not torch.tensor) so gradients reach roughness_mat.
    roughness_mat_adj = torch.stack([
        torch.stack([roughness_mat[1, 1], -roughness_mat[0, 1]]),
        torch.stack([-roughness_mat[1, 0], roughness_mat[0, 0]]),
    ]).to(dtype=m.dtype, device=m.device)

    m_xy = m[:2]
    length2 = (m_xy @ roughness_mat_adj @ m_xy) / det + m[2] ** 2

    return 1.0 / (math.pi * torch.sqrt(det) * (length2 ** 2))


# Approximate the reflection lobe with an SG lobe for microfacet BRDFs.
# [Wang et al. 2009 "All-Frequency Rendering with Dynamic, Spatially-Varying Reflectance"]
def sgg_reflection_pdf(wi, m, roughness_mat):
    """PDF of the reflected direction for an SGGX microfacet lobe with half vector `m`.

    Computes D(m) / (4 * sqrt(wi_xy^T R wi_xy + wi_z^2)), where D is `sggx` and
    R = `roughness_mat`. The square root is the projected area of the
    microfacet distribution seen from `wi`.

    NOTE(review): this is the visible-normal reflection PDF form. Verify it
    against the reference (Tokuyoshi et al. 2024), since the original call to
    torch.dot had no arguments.

    Args:
        wi: incident direction as a 1-D tensor in the tangent frame.
        m: half vector as a 1-D tensor in the tangent frame.
        roughness_mat: 2x2 roughness matrix.
    """
    wi_xy = wi[:2]
    projected_area2 = wi_xy @ roughness_mat @ wi_xy + wi[2] ** 2
    return sggx(m, roughness_mat) / (4.0 * torch.sqrt(projected_area2))

# Approximate hemispherical integral for a vMF distribution (i.e. normalized SG).
# The parameter "cosine" is the cosine of the angle between the SG axis and the pole axis of the hemisphere.
# [Tokuyoshi et al. 2024 "Hierarchical Light Sampling with Accurate Spherical Gaussian Lighting (Supplementary Document)" Listing. 4]
def vmf_hemispherical_integral(cosine, sharpness):
    """Approximate hemispherical integral of a vMF distribution (normalized SG).

    Interpolates between the lower (e / (e + 1)) and upper (1 / (e + 1))
    hemispherical integrals, with e = exp(-sharpness), using an erf-based factor.

    Args:
        cosine: cosine between the vMF axis and the hemisphere's pole axis
            (tensor or float; clamped to [-1, 1]).
        sharpness: vMF sharpness (tensor or float, >= 0).

    Returns:
        A tensor in [0, 1]. Sharpness 0 yields 0.5 (uniform distribution).

    [Tokuyoshi et al. 2024 "Hierarchical Light Sampling with Accurate Spherical
    Gaussian Lighting (Supplementary Document)" Listing. 4]
    """
    # Interpolation factor [Tokuyoshi 2022].
    A = 0.6517328826907056171791055021459
    B = 1.3418280033141287699294252888649
    C = 7.2216687798956709087860872386955

    sharpness = torch.as_tensor(sharpness)
    cosine = torch.as_tensor(cosine)

    steepness = sharpness * torch.sqrt((0.5 * sharpness + A) / ((sharpness + B) * sharpness + C))
    clamped_cosine = torch.clamp(cosine, -1.0, 1.0)
    # erf(k*c) / erf(k) -> c as k -> 0; use that limit instead of 0/0 at k == 0.
    positive = steepness > 0
    steepness_safe = torch.where(positive, steepness, torch.ones_like(steepness))
    ratio = torch.where(positive, torch.erf(steepness_safe * clamped_cosine) / torch.erf(steepness_safe), clamped_cosine)
    lerp_factor = torch.clamp(0.5 + 0.5 * ratio, 0.0, 1.0)

    # Interpolation between the lower and upper hemispherical integrals.
    e = torch.exp(-sharpness)
    return (e + (1.0 - e) * lerp_factor) / (e + 1.0)

def luminance(color):
    """Relative luminance of an RGB colour exposing `.x`, `.y`, `.z` as R, G, B.

    Uses the Rec. 709 / sRGB weights 0.2126, 0.7152 and 0.0722.
    """
    red, green, blue = color.x, color.y, color.z
    return red * 0.2126 + green * 0.7152 + blue * 0.0722

def compute_jacobian(wi):
    """Compute J J^T for the Jacobian J of the reflection operator, used for NDF filtering.

    J = Rot(v) @ diag(0.5, 0.5 / wi_z), where v is the normalized xy part of
    `wi` (or (1, 0) when wi_xy is zero).

    Args:
        wi: incident direction as a 1-D torch tensor (x, y, z) in the tangent frame.

    Returns:
        2x2 tensor J J^T with wi's dtype and device, differentiable w.r.t. `wi`.
    """
    xy = wi[:2]
    vlen = torch.linalg.norm(xy)
    if vlen != 0.0:
        v = xy / vlen
    else:
        v = torch.tensor([1.0, 0.0], device=wi.device, dtype=wi.dtype)

    # Built with torch.stack/diag (not torch.tensor) to keep dtype, device and gradients.
    rot_mat = torch.stack([
        torch.stack([v[0], -v[1]]),
        torch.stack([v[1], v[0]]),
    ])
    scale_mat = torch.diag(torch.stack([torch.full_like(wi[2], 0.5), 0.5 / wi[2]]))

    jacobian_mat = rot_mat @ scale_mat
    return jacobian_mat @ jacobian_mat.T

def isotropic_ndf_filtering(si):
    """Isotropic NDF filtering of the BSDF roughness at the surface interaction `si`.

    Computes a kernel roughness^2 from the normal derivatives `si.dn_du` and
    `si.dn_dv`, clamps it to KAPPA, adds it to the squared BSDF roughness,
    clamps the sum to [0, 1] and returns its square root.

    NOTE(review): unverified assumptions, TODO confirm before relying on this:
      * `si.bsdf().roughness` does not appear to be a standard Mitsuba BSDF
        attribute. Confirm how roughness should be obtained (e.g. via `mi.traverse`).
      * `torch.dot` is applied to `si.dn_du` / `si.dn_dv`, which are presumably
        Dr.Jit vectors rather than torch tensors. They likely need `.torch()`
        conversion, or `dr.dot` should be used.
      * NDF-filtering kernels are usually defined from screen-space normal
        derivatives, whereas `dn_du` / `dn_dv` are derivatives w.r.t. surface UV.
    """
    SIGMA2 = 0.15915494 # Variance of pixel filter kernel (1/(2pi))
    KAPPA = 0.18 # User-specified clamping threshold
    dndu = si.dn_du
    dndv = si.dn_dv
    roughness = si.bsdf().roughness

    kernel_roughness2 = SIGMA2 * (torch.dot(dndu, dndu) + torch.dot(dndv, dndv)) # Eq. 14 in the paper
    clamped_kernel_roughness2 = torch.clamp(kernel_roughness2, max=KAPPA)
    filtered_roughness2 = torch.clamp(roughness**2 + clamped_kernel_roughness2, min=0.0, max=1.0)

    return torch.sqrt(filtered_roughness2)

def compute_filtered_roughness_mat(filtered_proj_roughness_mat, tr, det):
    """Map a filtered projected-roughness matrix P back to a roughness matrix, stably.

    Returns (P + det * I) / (1 + tr + det) when the denominator is finite.
    Otherwise it returns the diagonal fallback diag(P00 / (P00 + 1), P11 / (P11 + 1)).
    Intermediate values are clamped to the float32 maximum (FLT_MAX).

    Args:
        filtered_proj_roughness_mat: 2x2 tensor P.
        tr: trace of P (tensor or float).
        det: determinant of P, computed by the caller in a stable form (tensor or float).

    Returns:
        2x2 tensor, differentiable w.r.t. all inputs.
    """
    P = filtered_proj_roughness_mat
    FLT_MAX = torch.finfo(torch.float32).max
    tr = torch.as_tensor(tr, dtype=P.dtype, device=P.device)
    det = torch.as_tensor(det, dtype=P.dtype, device=P.device)

    denom = 1.0 + tr + det
    is_finite = torch.isfinite(denom)

    # diag(det, det) with exact zeros off the diagonal (no inf * 0 = nan when det is inf).
    mat1 = torch.clamp(P + torch.diag(torch.stack([det, det])), max=FLT_MAX) / denom

    # Per-axis fallback used when 1 + tr + det overflows.
    diag_fallback = torch.stack([
        torch.clamp(P[0, 0], max=FLT_MAX) / torch.clamp(P[0, 0] + 1.0, max=FLT_MAX),
        torch.clamp(P[1, 1], max=FLT_MAX) / torch.clamp(P[1, 1] + 1.0, max=FLT_MAX),
    ])
    mat2 = torch.diag(diag_fallback)

    return torch.where(is_finite, mat1, mat2)

# (exp(x) - 1)/x with cancellation of rounding errors.
# [Nicholas J. Higham "Accuracy and Stability of Numerical Algorithms", Section 1.14.1, p. 19]
def expm1_over_x(x):
    """Compute (exp(x) - 1) / x with cancellation of rounding errors.

    Uses Higham's trick: for |x| < 1 the quotient (u - 1) / log(u), with
    u = exp(x), is accurate because the rounding errors of u cancel. Returns 1
    at x == 0 (the limit) and +inf when exp(x) overflows.

    [Nicholas J. Higham "Accuracy and Stability of Numerical Algorithms",
    Section 1.14.1, p. 19]
    """
    try:
        u = math.exp(x)
    except OverflowError:
        # exp(x) overflows only for large positive x, where (e^x - 1) / x -> +inf.
        return math.inf
    if u == 1.0:
        return 1.0
    y = u - 1.0
    if abs(x) < 1.0:
        return y / math.log(u)

    return y / x
          
def sg_integral(sharpness):
    """Integral over the whole sphere of a unit-amplitude SG exp(s * (cos - 1)).

    Equals 2*pi * (1 - exp(-2s)) / s, written as 4*pi * (exp(-2s) - 1) / (-2s)
    so that `expm1_over_x` stays accurate for small sharpness.
    """
    scaled = -2.0 * sharpness
    return 4.0 * torch.pi * expm1_over_x(scaled)

In [10]:
class vapl:
    """Virtual Anisotropic Point Light: an isotropic Gaussian over positions
    paired with a vMF lobe over emitted directions.

    Args:
        gaussian: indexable with at least 4 entries: mean (0:3), variance (3).
        vmf: indexable with at least 7 entries: sharpness (0), axis (1:4),
            amplitude (4:7).
    """

    def __init__(self, gaussian, vmf):
        self.mean      = gaussian[:3]
        self.variance  = gaussian[3]
        self.sharpness = vmf[0]
        self.axis      = vmf[1:4]
        self.amplitude = vmf[4:7]

    def __repr__(self):
        return (f"vapl(\n"
                f"  mean={self.mean},\n"
                f"  variance={self.variance},\n"
                f"  sharpness={self.sharpness},\n"
                f"  axis={self.axis},\n"
                f"  amplitude={self.amplitude}\n"
                f")")

    def convolve_with_bsdf(self, si):
        """Estimate the luminance this VAPL contributes at the shading point `si`.

        Follows the SG lighting of Tokuyoshi et al. 2024 (diffuse: Section 4,
        glossy: Section 5). The VAPL seen from `si.p` is turned into an SG light
        lobe, which is then integrated against a clamped cosine (diffuse) and a
        filtered SGGX reflection lobe (glossy).

        NOTE(review): work in progress. The method mixes Dr.Jit values (`si.p`,
        `si.n`, `si.wi`, BSDF evaluations) with torch tensors (`self.mean`,
        `self.axis`, ...) without converting between them. Helpers such as
        `compute_jacobian` index `wi` like a torch tensor. This likely fails
        as-is — TODO convert `si` fields with `.torch()`.

        Args:
            si: Mitsuba surface interaction at the shading point.

        Returns:
            Luminance of emissive * (diffuse + glossy illumination).
        """
        SGLIGHT_SHARPNESS_MAX = float.fromhex("0x1.0p41")  # 2^41
        FLT_MIN = torch.finfo(torch.float32).tiny

        position  = si.p
        normal    = si.n
        tangent   = si.dp_du
        bitangent = si.dp_dv

        tangent_frame = mi.Frame3f(tangent, bitangent, normal)

        light_vec = self.mean - position
        # Squared distance from the shading point to the VAPL mean.
        squared_distance = torch.dot(light_vec, light_vec)
        light_dir = light_vec * torch.rsqrt(squared_distance)

        # Clamp the variance for numerical stability.
        variance = torch.maximum(self.variance, squared_distance / SGLIGHT_SHARPNESS_MAX)

        # Maximum emissive radiance of the VAPL.
        emissive = self.amplitude / variance

        # VAPL sharpness for the light distribution viewed from the shading point.
        light_sharpness = squared_distance / variance

        # Light lobe: product of the light distribution viewed from the shading
        # point and the directional distribution of the VAPL.
        light_lobe = sg_product(self.axis, self.sharpness, light_dir, light_sharpness)

        # BSDF of the current intersection.
        bsdf = si.bsdf()

        # NOTE(review): unverified — Mitsuba 3's BSDFContext may not accept a
        # `flags` keyword (it exposes `type_mask`). Also, wo = -si.wi points
        # below the surface, so reflection lobes likely evaluate to zero. TODO confirm.
        ctx_diffuse = mi.BSDFContext(flags=mi.BSDFFlags.DiffuseReflection)
        ctx_specular = mi.BSDFContext(flags=mi.BSDFFlags.GlossyReflection)

        ray = mi.Ray3f(si.p, si.to_world(si.wi))
        wo = si.to_local(-ray.d)

        diffuse = bsdf.eval(ctx_diffuse, si, wo)
        specular = bsdf.eval(ctx_specular, si, wo)

        # Diffuse SG lighting.
        # [Tokuyoshi et al. 2024 "Hierarchical Light Sampling with Accurate Spherical Gaussian Lighting", Section 4]
        amplitude = torch.exp(light_lobe.log_amplitude)
        cosine = torch.clamp(torch.dot(light_lobe.axis, si.n), -1.0, 1.0)
        diffuse_illumination = amplitude * sg_clamp_cosine_product_integral_over_pi(cosine, light_lobe.sharpness)

        # Compute JJ^T for NDF filtering.
        wi = tangent_frame.to_local(si.wi)
        jj_mat = compute_jacobian(wi)

        # 4 * det(JJ^T) = 4 * (0.25 / wi.z)^2 = 1 / (4 wi.z^2).
        det_jj4 = 1.0 / (4.0 * wi.z * wi.z)

        roughness = isotropic_ndf_filtering(si)
        roughness2 = roughness**2
        # Floors use float32 FLT_MIN and work element-wise on tensors.
        proj_roughness2 = roughness2 / torch.clamp(1.0 - roughness2, min=FLT_MIN)
        roughness_max2 = torch.maximum(roughness2[0], roughness2[1])
        reflect_sharpness = (1.0 - roughness_max2) / torch.clamp(2.0 * roughness_max2, min=FLT_MIN)
        # NOTE(review): si.wi is in the local frame while `normal` is in world
        # space — TODO confirm the frame for the reflection.
        reflect_vec = mi.reflect(si.wi, normal) * reflect_sharpness

        # Glossy SG lighting.
        # [Tokuyoshi et al. 2024 "Hierarchical Light Sampling with Accurate Spherical Gaussian Lighting", Section 5]
        prod_vec = reflect_vec + light_lobe.axis * light_lobe.sharpness
        prod_sharpness = torch.linalg.norm(prod_vec)
        prod_dir = prod_vec / prod_sharpness
        light_lobe_variance = 1.0 / light_lobe.sharpness
        filtered_proj_roughness_mat = torch.diag(proj_roughness2) + 2.0 * light_lobe_variance * jj_mat

        # Determinant of filtered_proj_roughness_mat in a numerically stable manner.
        # det(diag(p) + c*M) = p0*p1 + c*(p0*M11 + p1*M00) + c^2 * det(M), with c = 2 * variance.
        # See the supplementary document (Section 5.2) of the paper for the derivation.
        det = (proj_roughness2[0] * proj_roughness2[1]
               + 2.0 * light_lobe_variance * (proj_roughness2[0] * jj_mat[1, 1] + proj_roughness2[1] * jj_mat[0, 0])
               + light_lobe_variance * light_lobe_variance * det_jj4)

        # NDF filtering in a numerically stable manner.
        # See the supplementary document (Section 5.2) of the paper for the derivation.
        tr = filtered_proj_roughness_mat[0, 0] + filtered_proj_roughness_mat[1, 1]
        filtered_roughness_mat = compute_filtered_roughness_mat(filtered_proj_roughness_mat, tr, det)

        # Visibility of the SG light in the upper hemisphere.
        visibility = vmf_hemispherical_integral(torch.dot(prod_dir, normal), prod_sharpness)

        # Evaluate the filtered reflection lobe at the half vector between wi and
        # the light lobe axis (computed before the PDF that needs it).
        half_vec_unnormalized = wi + tangent_frame.to_local(light_lobe.axis)
        half_vec = half_vec_unnormalized / torch.clamp(torch.norm(half_vec_unnormalized), min=torch.finfo(torch.float32).eps)
        lobe = sgg_reflection_pdf(wi, half_vec, filtered_roughness_mat)

        specular_illumination = amplitude * visibility * lobe * sg_integral(light_lobe.sharpness)

        # NOTE(review): open question from the author — is this the correct convolution?
        result = emissive * (diffuse * diffuse_illumination + specular * specular_illumination)
        return luminance(result)


class vapl_mixture:
    """A list of VAPLs plus their normalized importance weights at a shading point.

    NOTE(review): the constructor zips over the first dimension of `gaussians`
    and `vmfs`. With the output of `vapl_grid.forward` that dimension presumably
    indexes intersection points, not mixture components — TODO confirm.
    """

    def __init__(self, gaussians, vmfs):
        """Build one `vapl` per (gaussian, vmf) pair taken from the two inputs."""
        self.mixture = [vapl(gaussian, vmf) for gaussian, vmf in zip(gaussians, vmfs)]
        # One weight slot per VAPL, filled by get_normalized_vapl_weights().
        self.normalized_vapl_weights = [0.0] * len(self.mixture)

    def __repr__(self):
        mixture_str = "\n".join(repr(light) for light in self.mixture)
        return (f"vapl_mixture(\n"
                f"  Number of vapls: {len(self.mixture)},\n"
                f"  vapls=\n{mixture_str}\n"
                f")")

    def get_vapl(self, index):
        """Return the VAPL at position `index`."""
        return self.mixture[index]

    def get_normalized_vapl_weights(self, si):
        """Weight every VAPL by `convolve_with_bsdf(si)` and normalize the weights to sum to 1.

        The result is stored in `self.normalized_vapl_weights` and returned.
        An empty mixture yields an empty list.
        """
        raw_weights = [light.convolve_with_bsdf(si) for light in self.mixture]
        total_weight = sum(raw_weights, 0.0)
        self.normalized_vapl_weights = [weight / total_weight for weight in raw_weights]
        return self.normalized_vapl_weights

In [None]:
# Smoke test: build the VAPL field, query it at a few camera hits and wrap the
# predicted parameters in a mixture.
scene_bbox = scene.bbox()
field = vapl_grid(
    bb_min=scene_bbox.min,
    bb_max=scene_bbox.max,
    num_gaussian_in_mixture=1,
    num_param_per_gaussian=4,
    num_param_per_vmf=8,
).cuda()
si, image_res = get_camera_first_bounce(scene)

gaussians, vmfs = field(si)
mixture = vapl_mixture(gaussians, vmfs)
print(mixture)

In [12]:
# It is possible to just use render_rhs and RHSIntegrator from 
# https://github.com/krafton-ai/neural-radiosity-tutorial-mitsuba3/blob/main/neural_radiosity.ipynb

def first_non_specular_or_null_si(scene, si, sampler):
    """Find the first non-specular or null surface interaction.

    Starting from `si`, repeatedly samples the BSDF and traces the sampled ray
    while the current surface is valid, is not a null face, and has a BSDF
    without a `Smooth` component (i.e. a delta/specular-only surface). This
    happens for at most `max_depth` (6) bounces.

    Args:
        scene (mi.Scene): Scene object.
        si (mi.SurfaceInteraction3f): Surface interaction.
        sampler (mi.Sampler): Sampler object.

    Returns:
        tuple: A tuple containing three values:
            - si (mi.SurfaceInteraction3f): First non-specular or null surface interaction.
            - β (mi.Spectrum): The product of the weights of all previous BSDFs.
            - null_face (bool): A boolean mask indicating whether the surface is a null face or not.
    """
    # NOTE(review): the original note reads "Instead of `bsdf.flags()`, based on
    # `bsdf_sample.sampled_type`", but the checks below use `bsdf.flags()` —
    # TODO confirm which criterion is intended.
    with dr.suspend_grad():
        bsdf_ctx = mi.BSDFContext()

        depth = mi.UInt32(0)
        β = mi.Spectrum(1)
        bsdf = si.bsdf()

        # Null face: hit from below (wi.z < 0) on a BSDF without a BackSide component.
        null_face = ~mi.has_flag(si.bsdf().flags(), mi.BSDFFlags.BackSide) & (
            si.wi.z < 0
        )
        active = si.is_valid() & ~null_face  # non-null surface
        active &= ~mi.has_flag(si.bsdf().flags(), mi.BSDFFlags.Smooth)  # Delta surface

        # The state lambda lists every variable updated inside the loop body.
        loop = mi.Loop(
            name="first_non_specular_or_null_si",
            state=lambda: (sampler, depth, β, active, null_face, si, bsdf),
        )
        max_depth = 6
        loop.set_max_iterations(max_depth)

        while loop(active):
            # loop invariant: si is located at non-null and Delta surface
            # if si is located at null or Smooth surface, end loop
            bsdf_sample, bsdf_weight = bsdf.sample(
                bsdf_ctx, si, sampler.next_1d(), sampler.next_2d(), active
            )
            ray = si.spawn_ray(si.to_world(bsdf_sample.wo))
            si = scene.ray_intersect(
                ray, ray_flags=mi.RayFlags.All, coherent=dr.eq(depth, 0)
            )
            bsdf = si.bsdf(ray)

            # Accumulate the throughput of the specular chain.
            β *= bsdf_weight
            depth[si.is_valid()] += 1

            null_face &= ~mi.has_flag(bsdf.flags(), mi.BSDFFlags.BackSide) & (
                si.wi.z < 0
            )
            active &= si.is_valid() & ~null_face & (depth < max_depth)
            active &= ~mi.has_flag(bsdf.flags(), mi.BSDFFlags.Smooth)

    # return si at the first non-specular bounce or null face
    return si, β, null_face


def render_rhs(scene, model, si, sampler):
    """One-bounce estimate of the rendering-equation right-hand side at `si`.

    The remaining (indirect) radiance is queried from `model`. Adapted from
    `render_rhs` in the krafton-ai neural-radiosity Mitsuba 3 tutorial. Steps:
      1. emission at `si`;
      2. one emitter sample, MIS-weighted against the BSDF (only where the BSDF
         has a Smooth component);
      3. one BSDF sample, followed through any chain of specular surfaces, with
         MIS-weighted emission added at the hit;
      4. the model prediction at that hit, weighted by the path throughput.

    NOTE(review): VAPL-guided sampling is work in progress. The mixture weights
    are computed but not yet used. `model(si)` now returns a
    `(gaussians, vmfs)` tuple, which `mi.Spectrum(out)` below does not accept —
    TODO.

    Returns:
        (L, Le, out, w_nr, active_nr): the full estimate; the part not coming
        from the model; the raw model output; the weight applied to the model
        output; and the mask of lanes where the model term is used.
    """
    with dr.suspend_grad():
        bsdf_ctx = mi.BSDFContext()

        L = mi.Spectrum(0)  # accumulated radiance
        β = mi.Spectrum(1)  # path throughput

        # VAPL mixture for this intersection.
        gaussians, vmfs = model(si)
        mixture = vapl_mixture(gaussians, vmfs)
        weights = mixture.get_normalized_vapl_weights(si)
        # TODO: use `weights` to sample a new direction.

        # Emission at the current intersection.
        Le = β * si.emitter(scene).eval(si)

        # BSDF at the current intersection, used by both sampling strategies below.
        bsdf = si.bsdf()

        # emitter sampling
        active_next = si.is_valid()
        active_em = active_next & mi.has_flag(bsdf.flags(), mi.BSDFFlags.Smooth)

        ds, em_weight = scene.sample_emitter_direction(
            si, sampler.next_2d(), True, active_em
        )
        active_em &= dr.neq(ds.pdf, 0.0)

        wo = si.to_local(ds.d)
        bsdf_value_em, bsdf_pdf_em = bsdf.eval_pdf(bsdf_ctx, si, wo, active_em)
        mis_em = dr.select(ds.delta, 1, mis_weight(ds.pdf, bsdf_pdf_em))
        Lr_dir = β * mis_em * bsdf_value_em * em_weight

        # bsdf sampling
        bsdf_sample, bsdf_weight = bsdf.sample(
            bsdf_ctx, si, sampler.next_1d(), sampler.next_2d(), active_next
        )

        # update
        L = L + Le + Lr_dir
        ray = si.spawn_ray(si.to_world(bsdf_sample.wo))
        β *= bsdf_weight

        prev_si = dr.detach(si, True)
        prev_bsdf_pdf = bsdf_sample.pdf
        prev_bsdf_delta = mi.has_flag(bsdf_sample.sampled_type, mi.BSDFFlags.Delta)

        si = scene.ray_intersect(ray, ray_flags=mi.RayFlags.All, coherent=True)

        ds = mi.DirectionSample3f(scene, si=si, ref=prev_si)

        mis = mis_weight(
            prev_bsdf_pdf,
            scene.pdf_emitter_direction(prev_si, ds, ~prev_bsdf_delta),
        )

        # Skip over specular surfaces to the first non-specular (or null) hit.
        si, β2, null_face = first_non_specular_or_null_si(scene, si, sampler)
        β *= β2

        L += β * mis * si.emitter(scene).eval(si)

        out = model(si)
        # The model term is used only on valid, non-null, non-emissive hits.
        active_nr = (
            si.is_valid()
            & ~null_face
            & dr.eq(si.emitter(scene).eval(si), mi.Spectrum(0))
        )

        Le = L
        w_nr = β * mis
        L = Le + dr.select(active_nr, w_nr * mi.Spectrum(out), 0)

    return L, Le, out, w_nr, active_nr

In [None]:
class RHSIntegrator(ADIntegrator):
    """Mitsuba integrator that renders the RHS estimate of `render_rhs` with `model`."""

    def __init__(self, model):
        super().__init__()
        self.model = model

    def sample(self, mode, scene, sampler, ray, depth, reparam, active, **kwargs,):
        """Render one sample per ray: skip specular chains, then evaluate `render_rhs`.

        The model runs in eval mode under `torch.no_grad()`. Training mode is
        restored and the CUDA cache emptied afterwards, even if rendering raises.

        Returns:
            (β * L, si.is_valid(), None): radiance scaled by the specular-chain
            throughput, the validity mask, and no extra state.
        """
        self.model.eval()
        try:
            with torch.no_grad():
                ray = mi.Ray3f(dr.detach(ray))
                si = scene.ray_intersect(
                    ray, ray_flags=mi.RayFlags.All, coherent=dr.eq(depth, 0)
                )

                # update si and bsdf with the first non-specular ones
                si, β, _ = first_non_specular_or_null_si(scene, si, sampler)
                L, _, _, _, _ = render_rhs(scene, self.model, si, sampler)
        finally:
            # Never leave the model stuck in eval mode in the notebook kernel.
            self.model.train()
            torch.cuda.empty_cache()

        return β * L, si.is_valid(), None

rhs_integrator = RHSIntegrator(field)
rhs_image = mi.render(scene, spp=1, integrator=rhs_integrator)

# Gamma-correct (2.2) and clamp to [0, 1] for display.
display_image = np.clip(rhs_image ** (1.0 / 2.2), 0, 1)

fig, ax = plt.subplots()
fig.patch.set_visible(False)  # hide the figure background
ax.set_axis_off()             # no axes or ticks around the image
fig.tight_layout()            # trim surrounding whitespace
ax.imshow(display_image)