In [1]:
import torch
from rendering import rendering

In [5]:
def compute_fisherrf_uncertainty(nerf_model, grid_points, rendering_fn, tn, tf, loss_fn=torch.nn.MSELoss(), device='cuda'):
    """
    Compute FisherRF-based point-wise uncertainty for a grid of points using a trained NeRF model.

    Args:
        nerf_model (nn.Module): The trained NeRF model instance.
        grid_points (torch.Tensor): Tensor of 3D points to query [N^3, 3].
        rendering_fn (callable): The rendering function to generate predictions.
        tn (float): Near plane distance.
        tf (float): Far plane distance.
        loss_fn (callable): The loss function used during training (default: MSE).
        device (str): Device for computation ('cuda' or 'cpu').

    Returns:
        torch.Tensor: Uncertainty values for each point in the grid [N^3].
    """
    nerf_model.eval()  # Set the model to evaluation mode
    grid_points.requires_grad = True  # Enable gradient computation for the grid points

    # Generate dummy ray origins and directions
    rays_o = torch.zeros_like(grid_points, device=device)  # [N^3, 3], assuming origin at (0, 0, 0)
    rays_d = torch.nn.functional.normalize(grid_points - rays_o, dim=-1)  # Normalized ray directions [N^3, 3]

    # Use the rendering function to compute predicted colors
    predicted_colors = rendering_fn(nerf_model, rays_o, rays_d, tn, tf, device=device)  # [N^3, 3]

    # Dummy ground truth for the loss (assume black background)
    ground_truth_colors = torch.zeros_like(predicted_colors)

    # Compute the loss for each point
    # Compute element-wise squared error
    losses = ((predicted_colors - ground_truth_colors) ** 2).mean(dim=1)  # Per-point loss [N^3]


    uncertainties = []

    for i in range(grid_points.size(0)):
        # Compute the gradient of the loss w.r.t. the model parameters for each point
        loss = losses[i]  # Individual loss for a single point
        grads = torch.autograd.grad(
            outputs=loss,
            inputs=nerf_model.parameters(),
            create_graph=True,
            retain_graph=True,
            allow_unused=True
        )

        # Compute Fisher Information (per point)
        fisher_info = 0
        for grad in grads:
            if grad is not None:
                fisher_info += torch.sum(grad ** 2)  # Sum of squared gradients

        uncertainties.append(fisher_info.item())  # Store uncertainty for this point

    uncertainties = torch.tensor(uncertainties, device=device)  # [N^3]
    return uncertainties


In [6]:
# Load the trained NeRF model
device = 'cuda'
pth_file = 'experiments/suzanne/set100/models/M0.pth'
nerf_model = torch.load(pth_file).to(device)

# Define the grid points
N = 10
scale = 1.5
x = torch.linspace(-scale, scale, N, device=device)
y = torch.linspace(-scale, scale, N, device=device)
z = torch.linspace(-scale, scale, N, device=device)
x, y, z = torch.meshgrid((x, y, z))
grid_points = torch.cat((x.reshape(-1, 1), y.reshape(-1, 1), z.reshape(-1, 1)), dim=1).to(device)  # [N^3, 3]

  nerf_model = torch.load(pth_file).to(device)


In [7]:
def rendering_fn(model, rays_o, rays_d, tn, tf, device='cuda'):
    t = torch.linspace(tn, tf, 100).to(device)
    delta = torch.cat((t[1:] - t[:-1], torch.tensor([1e10], device=device)))
    x = rays_o.unsqueeze(1) + t.unsqueeze(0).unsqueeze(-1) * rays_d.unsqueeze(1)  # [nb_rays, nb_bins, 3]
    colors, density = model.intersect(x.reshape(-1, 3), rays_d.expand(x.shape[1], x.shape[0], 3).transpose(0, 1).reshape(-1, 3))
    colors = colors.reshape((x.shape[0], 100, 3))  # [nb_rays, nb_bins, 3]
    density = density.reshape((x.shape[0], 100))
    alpha = 1 - torch.exp(-density * delta.unsqueeze(0))  # [nb_rays, nb_bins]
    weights = torch.cumprod(1 - alpha + 1e-10, dim=1) * alpha  # [nb_rays, nb_bins]
    return (weights.unsqueeze(-1) * colors).sum(1)  # Rendered color
# Compute uncertainties
uncertainties = compute_fisherrf_uncertainty(
    nerf_model,
    grid_points,
    rendering_fn,
    tn=2.0,  # Near plane
    tf=6.0,  # Far plane
    device=device
)

print("Point-wise uncertainties:", uncertainties)

Point-wise uncertainties: tensor([3.1778e-02, 3.4055e-02, 6.4941e-02, 3.5116e+00, 3.6575e+00, 8.0486e-01,
        3.3760e-01, 1.4387e-01, 4.1386e-02, 2.5016e-02, 1.9511e-02, 4.1798e-02,
        1.2863e-01, 5.9680e+00, 4.4475e+00, 6.2833e-01, 2.5494e-01, 1.1609e-01,
        3.8760e-02, 2.3718e-02, 1.2359e-02, 1.8760e-02, 4.6280e-01, 2.0988e+00,
        5.1935e-01, 4.8900e-01, 2.4326e-01, 9.0937e-02, 4.2602e-02, 2.7437e-02,
        8.9870e-03, 2.5878e-02, 1.2493e-01, 1.9307e-01, 5.0776e-01, 5.0966e-01,
        2.7463e-01, 9.0608e-02, 5.1866e-02, 3.1535e-02, 1.4036e-02, 1.7021e-01,
        1.2156e+00, 2.2033e+00, 5.7687e-01, 4.3933e-01, 2.1984e-01, 9.7455e-02,
        4.7324e-02, 2.1598e-02, 1.7828e-02, 2.5008e-01, 1.6689e+00, 5.2316e+00,
        2.7615e+00, 5.0480e-01, 2.0728e-01, 8.4930e-02, 4.1830e-02, 2.0848e-02,
        1.8024e-02, 2.7630e-02, 6.4993e-01, 1.1846e+00, 3.1882e+00, 5.5400e-01,
        2.3471e-01, 8.8975e-02, 5.8275e-02, 3.2752e-02, 3.3800e-02, 1.0455e-02,
        2.8756

In [8]:
import torch

# Define a dummy neural network
class NeuralNetwork(torch.nn.Module):
    def __init__(self, num_weights):
        super(NeuralNetwork, self).__init__()
        self.weights = torch.nn.Parameter(torch.randn(num_weights))
    
    def forward(self, x, d):
        # Dummy density and color predictions based on weights
        density = torch.dot(self.weights[:len(x)], x)
        color = torch.dot(self.weights[len(x):], d)
        return density, color

# Initialize the neural network
num_weights = 10  # Example: 5 for density, 5 for color
net = NeuralNetwork(num_weights)

# Example inputs
x = torch.randn(5)  # 3D spatial point with additional features
d = torch.randn(5)  # View direction
observed_y = torch.tensor(1.0)  # Observed pixel value
noise_variance = 0.1

# Forward pass
density, color = net(x, d)
predicted_y = density + color

# Compute the log-likelihood
log_likelihood = -0.5 / noise_variance * (observed_y - predicted_y) ** 2

# Compute the gradient of the log-likelihood with respect to weights
grad_log_likelihood = torch.autograd.grad(log_likelihood, net.weights, retain_graph=True)[0]

# Fisher Information Matrix (outer product of gradients)
fim = torch.outer(grad_log_likelihood, grad_log_likelihood)

# Print the Fisher Information Matrix
print("Fisher Information Matrix:\n", fim)


Fisher Information Matrix:
 tensor([[   14.2811,   -57.1937,   141.3810,   -18.5445,   119.4245,   -39.4169,
           -75.3514,    -8.7922,  -108.5234,    76.1082],
        [  -57.1937,   229.0522,  -566.2097,    74.2678,  -478.2774,   157.8590,
           301.7711,    35.2113,   434.6202,  -304.8021],
        [  141.3810,  -566.2097,  1399.6522,  -183.5876,  1182.2865,  -390.2223,
          -745.9684,   -87.0413, -1074.3673,   753.4611],
        [  -18.5445,    74.2678,  -183.5876,    24.0806,  -155.0765,    51.1841,
            97.8461,    11.4169,   140.9211,   -98.8289],
        [  119.4245,  -478.2774,  1182.2865,  -155.0765,   998.6776,  -329.6208,
          -630.1196,   -73.5238,  -907.5182,   636.4487],
        [  -39.4169,   157.8590,  -390.2223,    51.1841,  -329.6208,   108.7938,
           207.9756,    24.2671,   299.5331,  -210.0645],
        [  -75.3514,   301.7711,  -745.9684,    97.8461,  -630.1196,   207.9756,
           397.5765,    46.3901,   572.6022,  -401.5698],

In [9]:
import torch

# Example: Fisher Information Matrix and a neural network output
def compute_uncertainty(fim_inv, jacobian):
    """
    Compute uncertainty at a spatial location given FIM inverse and the Jacobian.
    
    Args:
        fim_inv (torch.Tensor): Inverse of the Fisher Information Matrix (dim: num_weights x num_weights).
        jacobian (torch.Tensor): Jacobian of the output with respect to weights (dim: num_weights).
    
    Returns:
        float: Uncertainty at the spatial location.
    """
    # Uncertainty = sqrt(J^T FIM^-1 J)
    uncertainty = torch.sqrt(jacobian @ fim_inv @ jacobian.T)
    return uncertainty.item()

# Example FIM inverse (for simplicity, diagonal here)
num_weights = 10
fim_inv = torch.diag(torch.ones(num_weights) * 0.1)  # Inverse of Fisher Information Matrix

# Example Jacobian (output sensitivity to weights)
jacobian = torch.randn(num_weights)  # Random Jacobian for illustration

# Compute uncertainty at the location
uncertainty = compute_uncertainty(fim_inv, jacobian)
print("Uncertainty at the spatial location:", uncertainty)


Uncertainty at the spatial location: 1.163894772529602


  uncertainty = torch.sqrt(jacobian @ fim_inv @ jacobian.T)
