In [2]:
import torch
import plotly.express as px
import numpy as np

##
# Sampling
def generate_random_spheres(n):
    loc = torch.randn((n, 3))
    radius = torch.rand((n, 1))
    base_color = torch.rand((n, 3)) * 0.5 + 0.25
    return torch.concat((loc, radius, base_color), dim=1)

def uniform_sample_directions(n):
    theta = 2 * torch.pi * torch.rand(n)
    phi = torch.acos(1 - 2 * torch.rand(n))
    
    x = torch.sin(phi) * torch.cos(theta)
    y = torch.sin(phi) * torch.sin(theta)
    z = torch.cos(phi)

    sample = torch.stack([x, y, z], dim=1)
    return sample

def uniform_sample_on_spheres(spheres, n):
    centers = spheres[:, 0:3]  # (N, 3)
    radii = spheres[:, 3:4]  # (N, 1)
    
    areas = 4 * torch.pi * radii ** 2
    total_area = areas.sum()
    
    probabilities = areas / total_area
    sphere_indices = torch.multinomial(probabilities.squeeze(1), n, replacement=True)
    
    sampled_radii = radii[sphere_indices]
    sampled_centers = centers[sphere_indices]

    samples = sampled_radii * uniform_sample_directions(n) + sampled_centers
    return samples, spheres[sphere_indices]

##
# Orthonormal basis
def dot(vec1, vec2):
    return (vec1 * vec2).sum(dim=-1, keepdim=True)

def norm(vec):
    return torch.norm(vec, p=2, dim=-1, keepdim=True)

def to_world(w, normals):
    # https://graphics.pixar.com/library/OrthonormalB/paper.pdf
    x, y, z = normals[:, 0:1], normals[:, 1:2], normals[:, 2:3]
    sign = torch.where(z >= 0.0, 1.0, -1.0)
    a = -1.0 / (sign + z)
    b = x * y * a

    b1 = torch.concat((1.0 + sign * x * x * a, sign * b, -sign * x), dim=1)
    b2 = torch.concat((b, sign + y * y * a, -y), dim=1)

    w_world = w[:, 0:1] * b1 + w[:, 1:2] * b2 + w[:, 2:3] * normals
    return w_world

def to_local(d, normals):
    x, y, z = normals[:, 0:1], normals[:, 1:2], normals[:, 2:3]
    sign = torch.where(z >= 0.0, 1.0, -1.0)
    a = -1.0 / (sign + z)
    b = x * y * a

    b1 = torch.concat((1.0 + sign * x * x * a, sign * b, -sign * x), dim=1)
    b2 = torch.concat((b, sign + y * y * a, -y), dim=1)

    w_local = torch.concat((dot(d, b1), dot(d, b2), dot(d, normals)), dim=1)
    return w_local

##
# Ray tracing
def w_to_rays(x, w, spheres_at_x):
    """
    w의 3번째 coordinates가 normal 축이라고 가정
    """
    centers = spheres_at_x[:, :3]  # (N, 3)
    normals = torch.nn.functional.normalize(x - centers, dim=1)
    d = to_world(w, normals)  # (TODO)

    rays = torch.concat((x, d), dim=1)
    return rays

def rays_spheres_intersection(rays, spheres):
    n_rays = len(rays)
    n_spheres = len(spheres)
    x = rays[:, 0:3].unsqueeze(1).expand(n_rays, n_spheres, 3)
    d = rays[:, 3:6].unsqueeze(1).expand(n_rays, n_spheres, 3)
    centers = spheres[:, 0:3].unsqueeze(0).expand(n_rays, n_spheres, 3)
    radii = spheres[:, 3:4].unsqueeze(0).expand(n_rays, n_spheres, 1)

    c_x = centers - x
    c_x_dot_d = dot(c_x, d)
    D = norm(c_x_dot_d * d - c_x)

    intersect = radii > D
    r2_d2 = radii * radii - D * D

    t1 = torch.where(intersect, c_x_dot_d + torch.sqrt(r2_d2), torch.inf)
    t2 = torch.where(intersect, c_x_dot_d - torch.sqrt(r2_d2), torch.inf)

    t = torch.where(t2 > 0, t2, torch.where(t1 > 0, t1, torch.inf))
    min_t, min_indices = torch.min(t, dim=1)

    x_prime = rays[:, 0:3] + rays[:, 3:6] * min_t
    spheres_at_x_prime = spheres[min_indices.squeeze(1)]
    no_hit = (min_t == torch.inf).squeeze(1)
    return x_prime, spheres_at_x_prime, no_hit

# Diffuse BSDF
def eval_cosined_diffuse_bsdf(w_out, w_in, base_color):
    """
    w의 3번째 coordinates가 normal 축이라고 가정
    """
    cos_theta_o = w_out[:, 2:3]
    spec = base_color / torch.pi * cos_theta_o
    return spec

# Monte Carlo integration of reflected radiance
def sample_and_eval_L_r(x, w_out, spheres, spheres_at_x):
    batch_size = len(x)

    w_in = uniform_sample_directions(batch_size)
    w_in[:, 2] = w_in[:, 2].abs()
    rays = w_to_rays(x, w_in, spheres_at_x)

    x_prime, _, no_hit = rays_spheres_intersection(rays, spheres)
    x_prime[no_hit, :] = 0.0

    L_in = scene["caches"](torch.concat((x_prime, -rays[:,3:6]), dim=1))
    L_in[no_hit, :] = scene["env_light"]
    
    f_s = eval_cosined_diffuse_bsdf(w_out, w_in, spheres_at_x[:, 4:7])

    rhs = L_in * f_s * (2 * torch.pi)

    return rhs

##
# Training neural caches
def train(scene, n_iters=100, lr=0.001, batch_size=16):
    caches = scene["caches"]
    spheres = scene["spheres"]
    optimizer = torch.optim.Adam(caches.parameters(), lr=lr)

    for iter in range(n_iters):
        optimizer.zero_grad()

        x, spheres_at_x = uniform_sample_on_spheres(spheres, batch_size)
        w_out = uniform_sample_directions(batch_size)
        w_out[:, 2] = w_out[:, 2].abs()  # uniform hemispherical sampling
        normals = torch.nn.functional.normalize(x - spheres_at_x[:, :3], dim=1)
        lhs = caches(torch.concat((x, to_world(w_out, normals)), dim=1))

        with torch.no_grad():
            rhs = sample_and_eval_L_r(x, w_out, spheres, spheres_at_x)
        
        loss = (((lhs - rhs)**2).sum(dim=-1) / ((lhs.detach()**2).sum(dim=-1) + 0.01)).mean()
        loss.backward()
        optimizer.step()

        if iter % 100 == 0:
            print(f"Iteration {iter}/{n_iters}, Loss: {loss.item()}")
    pass

##
# Rendering
def tone_map(color, limit=1.5, gamma=2.2):
    color = torch.clip(color, 0, None)
    luminance = 0.2126 * color[:,:,0] + 0.7152 * color[:,:,1] + 0.0722 * color[:,:,2]
    result = color / (1.0 + luminance.unsqueeze(2) / limit)
    result = torch.clip(result ** (1 / gamma), 0.0, 1.0)
    return result

def render(scene, width=1024, spp=4):
    camera = scene["camera"]

    loc = camera["loc"]
    lookat = camera["lookat"]
    up = camera["up"]
    fov = camera["fov"]
    focal = camera["focal"]

    # camera basis in world coordinate
    w = (loc - lookat)
    w = w / torch.norm(w)

    u = torch.cross(up, w)
    u = u / torch.norm(u)

    v = torch.cross(w, u)

    # pixel location (i.e., ray's head) in pixel coordinate
    theta = torch.deg2rad(fov / 2.0)
    half_width = torch.tan(theta) * focal

    interval_width = (2 * half_width) / width
    lin = torch.linspace(-half_width + interval_width / 2, half_width - interval_width / 2, width)
    
    y, x = torch.meshgrid(lin, lin)  # plotly convention
    y = -y

    jitter_x = (torch.rand((width, width, spp)) - 0.5) * interval_width
    jitter_y = (torch.rand((width, width, spp)) - 0.5) * interval_width
    
    pixel_x = x.unsqueeze(-1) + jitter_x
    pixel_y = y.unsqueeze(-1) + jitter_y

    pixel_x = pixel_x.reshape(width * width * spp, 1)
    pixel_y = pixel_y.reshape(width * width * spp, 1)

    # pixel location to world coordinate
    w = w.unsqueeze(0)
    u = u.unsqueeze(0)
    v = v.unsqueeze(0)

    d = pixel_x * u + pixel_y * v - focal * w

    # ray construction
    d = d / torch.norm(d, dim=-1, keepdim=True)
    d = d.reshape(-1, 3)

    rays = torch.concat((loc.view(1, 3).expand(len(d), 3), d), dim=1)

    # camera ray tracing
    x_prime, spheres_at_x_prime, no_hit = rays_spheres_intersection(rays, scene["spheres"])
    x_prime[no_hit, :] = 0.0

    # LHS plotting
    color = scene["caches"](torch.concat((x_prime, -rays[:,3:6]), dim=1))
    color[no_hit] = scene["env_light"] 
    color = color.reshape(width, width, spp, 3)
    color = color.mean(2)

    image = tone_map(color).cpu().detach().numpy()
    image = (255.0 * image).astype(np.uint8)
    fig = px.imshow(image, title="Outgoing radiance")
    fig.show()

    # RHS plotting
    normals = torch.nn.functional.normalize(x_prime - spheres_at_x_prime[:, :3], dim=1)
    color_rhs = sample_and_eval_L_r(x_prime, to_local(-rays[:,3:6], normals), scene["spheres"], spheres_at_x_prime)
    color_rhs[no_hit] = scene["env_light"] 
    color_rhs = color_rhs.reshape(width, width, spp, 3)
    color_rhs = color_rhs.mean(2)

    image_rhs = tone_map(color_rhs).cpu().detach().numpy()
    image_rhs = (255.0 * image_rhs).astype(np.uint8)
    fig = px.imshow(image_rhs, title="Emitted radiance + reflected radiance")
    fig.show()

    return color

##
# Neural caches
torch.set_default_device('cuda')
class MLP(torch.nn.Module):
    def __init__(self, input_dim=6, hidden_dim=128, output_dim=3):
        super(MLP, self).__init__()
        self.fc1 = torch.nn.Linear(input_dim, hidden_dim)
        self.relu1 = torch.nn.ReLU()
        self.fc2 = torch.nn.Linear(hidden_dim, hidden_dim)
        self.relu2 = torch.nn.ReLU()
        self.fc3 = torch.nn.Linear(hidden_dim, output_dim)
        self.relu3 = torch.nn.Softplus()
    
    def forward(self, x):
        x = self.fc1(x)
        x = self.relu1(x)
        x = self.fc2(x)
        x = self.relu2(x)
        x = self.fc3(x)
        x = self.relu3(x)
        return x


##
# Scene definition
scene = {
    "env_light": 3.0,
    "spheres": generate_random_spheres(10),
    "caches": MLP(input_dim=6, output_dim=3),
    "camera": {
        "loc": torch.tensor([5.0, 0.0, 0.0]),
        "lookat": torch.tensor([0.0, 0.0, 0.0]),
        "up": torch.tensor([0.0, 0.0, 1.0]),
        "fov": torch.tensor(60.0), # degree
        "focal":  torch.tensor(0.1)
    }
}

##
# Test runs
_ = render(scene, width=512, spp=16)
train(scene, lr=1e-4, batch_size=2**14, n_iters=3000)
_ = render(scene)

Iteration 0/3000, Loss: 1.292201280593872
Iteration 100/3000, Loss: 0.9665079116821289
Iteration 200/3000, Loss: 0.9099463224411011
Iteration 300/3000, Loss: 0.9453305006027222
Iteration 400/3000, Loss: 1.0878119468688965
Iteration 500/3000, Loss: 1.1914409399032593
Iteration 600/3000, Loss: 1.2285919189453125
Iteration 700/3000, Loss: 1.2435500621795654
Iteration 800/3000, Loss: 1.2930560111999512
Iteration 900/3000, Loss: 1.3128411769866943
Iteration 1000/3000, Loss: 1.2749030590057373
Iteration 1100/3000, Loss: 1.3460972309112549
Iteration 1200/3000, Loss: 1.4152617454528809
Iteration 1300/3000, Loss: 1.3463869094848633
Iteration 1400/3000, Loss: 1.2506160736083984
Iteration 1500/3000, Loss: 1.25923490524292
Iteration 1600/3000, Loss: 1.2645939588546753
Iteration 1700/3000, Loss: 1.3405787944793701
Iteration 1800/3000, Loss: 1.2420933246612549
Iteration 1900/3000, Loss: 1.1930444240570068
Iteration 2000/3000, Loss: 1.327399492263794
Iteration 2100/3000, Loss: 1.2790014743804932
Iter