In [None]:
import mitsuba as mi
mi.set_variant("cuda_ad_rgb")

import numpy as np
import matplotlib.pyplot as plt

import import_ipynb
import importlib

import testing_scene as ts
importlib.reload(ts)

import vapl_utils as utils
importlib.reload(utils)


In [None]:
def plot_vapl_means(scene, pos, variance, amplitude, h, w, ax,
                    marker_size=10.0, size_by_variance=False):
    """Scatter the mean positions of the Virtual Anisotropic Point Lights (VAPL) onto an axis.

    World-space means are projected with ``utils.world_to_ndc`` and then mapped to
    pixel coordinates with ``utils.ndc_to_pixel``. Points are coloured by their
    amplitude, normalised by the largest amplitude value.

    Args:
        scene (mi.Scene): Mitsuba scene used for the world -> NDC projection.
        pos (torch.Tensor): Mean positions in world space (presumably shape (N, 3)).
        variance (torch.Tensor): Per-light variance, flattened to 1-D. It only affects
            the plot when ``size_by_variance`` is True.
        amplitude (torch.Tensor): Per-light amplitude. A 2-D array with 3 or 4 columns
            is drawn as RGB(A) colours, clipped to [0, 1]. Any other shape is flattened
            and mapped through the "coolwarm" colormap.
        h (int): Image height in pixels.
        w (int): Image width in pixels.
        ax (matplotlib.axes.Axes): Axis to draw on.
        marker_size (float): Marker area passed to ``ax.scatter`` as ``s``.
        size_by_variance (bool): If True, the marker area is ``marker_size * variance``.
    """
    positions = pos.cpu().detach().numpy()
    variance = variance.cpu().detach().numpy().flatten()
    amplitude = amplitude.cpu().detach().numpy()

    # Project world-space means to NDC, then NDC to pixel coordinates.
    means_ndc = utils.world_to_ndc(scene, positions)
    means_pix = utils.ndc_to_pixel(means_ndc, h, w)

    # Normalise by the global peak. Guard against empty input (max() of an empty
    # array raises) and against non-positive peaks (division would blow up or flip
    # the sign of the values).
    peak = amplitude.max() if amplitude.size > 0 else 0.0
    amplitude_norm = amplitude / peak if peak > 0 else amplitude

    sizes = marker_size * variance if size_by_variance else marker_size

    if amplitude_norm.ndim == 2 and amplitude_norm.shape[1] in (3, 4):
        # Explicit RGB(A) colours: matplotlib requires components in [0, 1] and
        # ignores `cmap` for them, so clip the values and omit the colormap.
        colors = np.clip(amplitude_norm, 0.0, 1.0)
        ax.scatter(means_pix.x, means_pix.y, c=colors, marker='o', s=sizes)
    else:
        # Scalar amplitudes: the colormap maps each value to a colour.
        ax.scatter(means_pix.x, means_pix.y, c=amplitude_norm.ravel(),
                   cmap="coolwarm", marker='o', s=sizes)

In [None]:
def weighted_loss(real, predicted, weight, eps=0.01):
    """Relative squared error between a reference and a prediction, averaged over all elements.

    Each squared residual ``(real - predicted) ** 2`` is divided by
    ``weight * (predicted.detach() ** 2 + eps) + eps``. The normaliser uses a detached
    copy of ``predicted``, so gradients flow only through the numerator.

    Args:
        real: Target tensor.
        predicted: Prediction tensor. It must support ``.detach()`` (a torch tensor).
        weight: Multiplier of the normaliser, broadcast against ``predicted``.
            NOTE(review): its meaning is decided by the caller (``utils.Loss``) — confirm.
        eps: Stabiliser, applied twice. Inside the normaliser it keeps the denominator
            away from zero when the prediction is near zero. Added once more outside,
            it keeps the denominator non-zero when ``weight`` is 0. Defaults to 0.01,
            the value that was previously hard-coded.

    Returns:
        Scalar tensor: the mean of the normalised squared errors.
    """
    squared_error = (real - predicted) ** 2
    # Detached so the normaliser is treated as a constant during backprop.
    norm_factor = weight * (predicted.detach() ** 2 + eps)
    return (squared_error / (norm_factor + eps)).mean()

import torch

# Loss object wrapping the relative squared-error function defined above.
loss_function = utils.Loss(weighted_loss)

# VAPL grid spanning the scene's bounding box, moved to CUDA (the notebook runs
# Mitsuba's cuda_ad_rgb variant).
field = utils.vapl_grid(
    ts.scene.bbox().min,
    ts.scene.bbox().max,
    1, 4, 8,  # NOTE(review): positional meaning is defined in vapl_utils.vapl_grid — name these
).cuda()

# Integrator that trains `field` against `loss_function` while rendering.
rhs_integrator = utils.RHSIntegrator(
    field,
    loss_function,
)

def should_render(epoch):
    """Return True when `epoch` should produce a preview figure.

    Previews are dense early in training and sparse later:
    every 5 epochs below 50, every 20 below 500, every 100 below 2000,
    and every 250 from then on.
    """
    schedule = ((50, 5), (500, 20), (2000, 100))
    interval = next((step for limit, step in schedule if epoch < limit), 250)
    return epoch % interval == 0

for epoch in range(3001):
    rhs_integrator.epoch = epoch

    # Vary the sampler seed per epoch. mi.render defaults to seed=0, which would
    # replay the identical Monte Carlo sample pattern on every training epoch.
    rhs_image = mi.render(ts.scene, spp=1, integrator=rhs_integrator, seed=epoch)

    if not should_render(epoch):
        continue

    gaussians, vmfs = utils.get_all_gaussians(field)
    mean = gaussians[:, :3]      # presumably world-space xyz of each light — confirm in vapl_utils
    variance = gaussians[:, 3]
    amplitude = vmfs[:, 4:7]     # NOTE(review): assumed to be the RGB amplitude columns — confirm layout

    h, w = rhs_image.shape[0], rhs_image.shape[1]

    # Apply a 1/2.2 display gamma and clamp to [0, 1]; shared by both panels.
    preview = np.clip(rhs_image ** (1.0 / 2.2), 0, 1)

    fig, ax = plt.subplots(1, 2, figsize=(12, 6))

    ax[0].imshow(preview)  # gamma-corrected render
    ax[0].axis("off")
    ax[0].set_title(f"RHS Image - epoch:{epoch}")

    plot_vapl_means(ts.scene, mean, variance, amplitude, h, w, ax[1])
    ax[1].imshow(preview)  # render underlay beneath the VAPL means
    ax[1].axis("off")
    ax[1].set_title(f"RHS + VAPL - epoch:{epoch}")

    plt.show()
    # Release the figure so a long run does not accumulate open figures.
    plt.close(fig)
