In [1]:
import torch
import plotly.express as px
import numpy as np
import random

##
# Sampling
def generate_random_spheres(n):
    loc = torch.rand((n, 3)) * 2 - 1.0  # [-1, 1]
    radius = torch.rand((n, 1)) * 0.3 + 0.25
    base_color = torch.rand((n, 3))
    return torch.concat((loc, radius, base_color), dim=1)

def uniform_sample_directions(n):
    theta = 2 * torch.pi * torch.rand(n)
    phi = torch.acos(1 - 2 * torch.rand(n))
    
    x = torch.sin(phi) * torch.cos(theta)
    y = torch.sin(phi) * torch.sin(theta)
    z = torch.cos(phi)

    sample = torch.stack([x, y, z], dim=1)
    return sample

def uniform_sample_on_spheres(spheres, n):
    centers = spheres[:, 0:3]  # (N, 3)
    radii = spheres[:, 3:4]  # (N, 1)
    
    areas = 4 * torch.pi * radii ** 2
    total_area = areas.sum()
    
    probabilities = areas / total_area
    sphere_indices = torch.multinomial(probabilities.squeeze(1), n, replacement=True)
    
    sampled_radii = radii[sphere_indices]
    sampled_centers = centers[sphere_indices]

    samples = sampled_radii * uniform_sample_directions(n) + sampled_centers
    return samples, spheres[sphere_indices]

##
# Orthonormal basis
def dot(vec1, vec2):
    return (vec1 * vec2).sum(dim=-1, keepdim=True)

def norm(vec):
    return torch.norm(vec, p=2, dim=-1, keepdim=True)

def to_world(w, normals):
    # https://graphics.pixar.com/library/OrthonormalB/paper.pdf
    x, y, z = normals[:, 0:1], normals[:, 1:2], normals[:, 2:3]
    sign = torch.where(z >= 0.0, 1.0, -1.0)
    a = -1.0 / (sign + z)
    b = x * y * a

    b1 = torch.concat((1.0 + sign * x * x * a, sign * b, -sign * x), dim=1)
    b2 = torch.concat((b, sign + y * y * a, -y), dim=1)

    w_world = w[:, 0:1] * b1 + w[:, 1:2] * b2 + w[:, 2:3] * normals
    return w_world

def to_local(d, normals):
    x, y, z = normals[:, 0:1], normals[:, 1:2], normals[:, 2:3]
    sign = torch.where(z >= 0.0, 1.0, -1.0)
    a = -1.0 / (sign + z)
    b = x * y * a

    b1 = torch.concat((1.0 + sign * x * x * a, sign * b, -sign * x), dim=1)
    b2 = torch.concat((b, sign + y * y * a, -y), dim=1)

    w_local = torch.concat((dot(d, b1), dot(d, b2), dot(d, normals)), dim=1)
    return w_local

##
# Ray tracing
def w_to_rays(x, w, spheres_at_x):
    centers = spheres_at_x[:, :3]  # (N, 3)
    normals = torch.nn.functional.normalize(x - centers, dim=1)
    d = to_world(w, normals)

    rays = torch.concat((x, d), dim=1)
    return rays

def rays_spheres_intersection(rays, spheres):
    eps_t = 0.001
    n_rays = len(rays)
    n_spheres = len(spheres)
    x = rays[:, 0:3].unsqueeze(1).expand(n_rays, n_spheres, 3)
    d = rays[:, 3:6].unsqueeze(1).expand(n_rays, n_spheres, 3)
    centers = spheres[:, 0:3].unsqueeze(0).expand(n_rays, n_spheres, 3)
    radii = spheres[:, 3:4].unsqueeze(0).expand(n_rays, n_spheres, 1)

    c_x = centers - x
    c_x_dot_d = dot(c_x, d)
    D = norm(c_x_dot_d * d - c_x)

    intersect = radii > D
    r2_d2 = radii * radii - D * D

    t1 = torch.where(intersect, c_x_dot_d + torch.sqrt(r2_d2), torch.inf)
    t2 = torch.where(intersect, c_x_dot_d - torch.sqrt(r2_d2), torch.inf)

    t = torch.where(t2 > eps_t, t2, torch.where(t1 > eps_t, t1, torch.inf))
    min_t, min_indices = torch.min(t, dim=1)

    x_prime = rays[:, 0:3] + rays[:, 3:6] * min_t
    spheres_at_x_prime = spheres[min_indices.squeeze(1)]
    no_hit = (min_t == torch.inf).squeeze(1)
    return x_prime, spheres_at_x_prime, no_hit

# Diffuse BSDF
def eval_cosined_diffuse_bsdf(w_out, w_in, base_color):
    cos_theta_o = w_out[:, 2:3]
    spec = base_color / torch.pi * cos_theta_o
    return spec

# Monte Carlo integration of reflected radiance
def sample_and_eval_L_r(x, w_out, spheres, spheres_at_x):
    batch_size = len(x)

    w_in = uniform_sample_directions(batch_size)
    w_in[:, 2] = w_in[:, 2].abs()
    
    rays = w_to_rays(x, w_in, spheres_at_x)

    x_p, spheres_at_x_p, no_hit_p = rays_spheres_intersection(rays, spheres)
    x_p[no_hit_p, :] = 0.0

    normals_p = torch.nn.functional.normalize(x_p - spheres_at_x_p[:, :3], dim=1)
    albedo_p = spheres_at_x_p[:, 4:7]
    w_out_p = to_local(-rays[:,3:6], normals_p)
    query = torch.concat((x_p, w_out_p, normals_p, albedo_p), dim=1)

    L_in = scene["caches"](query)
    L_in[no_hit_p, :] = scene["env_light"]
    
    f_s = eval_cosined_diffuse_bsdf(w_out, w_in, spheres_at_x[:, 4:7])

    rhs = L_in * f_s * (2 * torch.pi)

    return rhs

##
# Training neural caches
def relMSE(lhs, rhs):
    numer = ((lhs - rhs)**2).sum(dim=-1)
    denom = (lhs.detach()**2).sum(dim=-1) + 0.01
    error = (numer / denom).mean()
    return error

def train(scene, n_iters=100, lr=0.001, batch_size=16):
    caches = scene["caches"]
    spheres = scene["spheres"]
    optimizer = torch.optim.Adam(caches.parameters(), lr=lr)

    for iter in range(n_iters):
        optimizer.zero_grad()

        x, spheres_at_x = uniform_sample_on_spheres(spheres, batch_size)
        w_out = uniform_sample_directions(batch_size)
        w_out[:, 2] = w_out[:, 2].abs()  # uniform hemispherical sampling
        normals = torch.nn.functional.normalize(x - spheres_at_x[:, :3], dim=1)
        albedo = spheres_at_x[:, 4:7]

        query = torch.concat((x, w_out, normals, albedo), dim=1)
        lhs = caches(query)

        with torch.no_grad():
            rhs = sample_and_eval_L_r(x, w_out, spheres, spheres_at_x)
        
        loss = relMSE(lhs, rhs)
        loss.backward()
        optimizer.step()

        if iter % 100 == 0:
            print(f"Iteration {iter}/{n_iters}, Loss: {loss.item()}")
    pass

##
# Rendering
def tone_map(color, limit=1.5, gamma=2.2):
    color = torch.clip(color, 0, None)
    luminance = 0.2126 * color[:,:,0] + 0.7152 * color[:,:,1] + 0.0722 * color[:,:,2]
    result = color / (1.0 + luminance.unsqueeze(2) / limit)
    result = torch.clip(result ** (1 / gamma), 0.0, 1.0)
    return result

def generate_camera_rays(scene, width=1024, spp=4):
    camera = scene["camera"]

    loc = camera["loc"]
    lookat = camera["lookat"]
    up = camera["up"]
    fov = camera["fov"]
    focal = camera["focal"]

    # camera basis in world coordinate
    w = (loc - lookat)
    w = w / torch.norm(w)

    u = torch.cross(up, w, dim=0)
    u = u / torch.norm(u)

    v = torch.cross(w, u, dim=0)

    # pixel location (i.e., ray's head) in pixel coordinate
    theta = torch.deg2rad(fov / 2.0)
    half_width = torch.tan(theta) * focal

    interval_width = (2 * half_width) / width
    lin = torch.linspace(-half_width + interval_width / 2, half_width - interval_width / 2, width)
    
    x, y = torch.meshgrid(lin, lin, indexing="xy")  # plotly convention
    y = -y

    jitter_x = (torch.rand((width, width, spp)) - 0.5) * interval_width
    jitter_y = (torch.rand((width, width, spp)) - 0.5) * interval_width
    
    pixel_x = x.unsqueeze(-1) + jitter_x
    pixel_y = y.unsqueeze(-1) + jitter_y

    pixel_x = pixel_x.reshape(width * width * spp, 1)
    pixel_y = pixel_y.reshape(width * width * spp, 1)

    # pixel location to world coordinate
    w = w.unsqueeze(0)
    u = u.unsqueeze(0)
    v = v.unsqueeze(0)

    d = pixel_x * u + pixel_y * v - focal * w

    # ray construction
    d = d / torch.norm(d, dim=-1, keepdim=True)
    d = d.reshape(-1, 3)

    rays = torch.concat((loc.view(1, 3).expand(len(d), 3), d), dim=1)

    return rays

def show_image(color, width, spp, title):
    color = color.reshape(width, width, spp, 3)
    color = color.mean(2)

    image = tone_map(color).cpu().detach().numpy()
    image = (255.0 * image).astype(np.uint8)
    fig = px.imshow(image, title=title)
    fig.show()

def render(scene, width=1024, spp=4):
    # camera ray tracing
    rays = generate_camera_rays(scene, width, spp)

    x, spheres_at_x, no_hit = rays_spheres_intersection(rays, scene["spheres"])
    x[no_hit, :] = 0.0

    # LHS plotting
    normals = torch.nn.functional.normalize(x - spheres_at_x[:, :3], dim=1)
    albedo = spheres_at_x[:, 4:7]
    w_out = to_local(-rays[:,3:6], normals)
    query = torch.concat((x, w_out, normals, albedo), dim=1)

    with torch.no_grad():
        color = scene["caches"](query)
    color[no_hit] = scene["env_light"]

    show_image(color, width, spp, "Neural Radiosity: Left-hand side")

    # RHS plotting
    with torch.no_grad():
        color_rhs = sample_and_eval_L_r(x, to_local(-rays[:,3:6], normals), scene["spheres"], spheres_at_x)
    color_rhs[no_hit] = scene["env_light"]

    show_image(color_rhs, width, spp, "Neural Radiosity: Right-hand side")
    torch.cuda.empty_cache()

def render_pt(scene, width=1024, spp=4, max_depth=6):
    batch_size = width * width * spp
    L_o = torch.zeros((batch_size, 3))
    atten = torch.ones((batch_size, 3))
    active = torch.ones((batch_size)) == 1

    rays = generate_camera_rays(scene, width, spp)

    for _ in range(max_depth):
        x, spheres_at_x, no_hit = rays_spheres_intersection(rays, scene["spheres"])
        x[active & no_hit, :] = 0
        L_o[active & no_hit, :] += scene["env_light"] * atten[active & no_hit, :]
        active &= ~no_hit

        if not active.any():
            break

        w_in = uniform_sample_directions(batch_size)
        w_in[:, 2] = w_in[:, 2].abs()

        normals = torch.nn.functional.normalize(x - spheres_at_x[:, :3], dim=1)
        w_out = to_local(-rays[:, 3:6], normals)
        f_s = eval_cosined_diffuse_bsdf(w_out, w_in, spheres_at_x[:, 4:7])
        atten[active, :] *= f_s[active, :] * (2 * torch.pi)

        rays = w_to_rays(x, w_in, spheres_at_x)
    
    show_image(L_o, width, spp, "Path tracing")
    torch.cuda.empty_cache()

##
# Neural caches
torch.set_default_device("cuda")

seed = 404
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Networks
import sys
sys.path.append("/home/in-young/pytorch-path-tracing/")
from src.utils.embeddings import Cube, PositionalEmbedding, Triplane

class MLP(torch.nn.Module):
    def __init__(self, inp_d=3, hid_d=128, out_d=3, hid_n=1):
        super(MLP, self).__init__()
        # self.embed = Cube(res=30, feature_dim=10)
        self.embed = Triplane(res=64, feature_dim=64)
        layers = [
            torch.nn.Linear(inp_d + self.embed.feature_dim, hid_d, False), 
            torch.nn.ReLU(),
            *([torch.nn.Linear(hid_d, hid_d, False), 
               torch.nn.ReLU()] * hid_n),
            torch.nn.Linear(hid_d, out_d, False), 
            torch.nn.Softplus()
        ]
        self.model = torch.nn.Sequential(*layers)
        self._initialize_weights()

    def _initialize_weights(self):
        for module in self.model:
            if isinstance(module, torch.nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
    
    def forward(self, x):
        x = torch.concat((x, self.embed(x[:, :3])), dim=1)
        return self.model(x)

##
# Scene definition
scene = {
    "env_light": 2.0,
    "spheres": generate_random_spheres(15),
    "camera": {
        "loc": torch.tensor([3.0, 0.0, 0.0]),
        "lookat": torch.tensor([0.0, 0.0, 0.0]),
        "up": torch.tensor([0.0, 0.0, 1.0]),
        "fov": torch.tensor(60.0), # degree
        "focal":  torch.tensor(0.1)
    },
    "caches": MLP(inp_d=12, out_d=3, hid_d=64, hid_n=0),
}

##
# Test runs
render_pt(scene, width=512, spp=32, max_depth=6)
train(scene, lr=5e-4, batch_size=2**14, n_iters=15000)
render(scene, width=512, spp=4)


Default grid_sample and affine_grid behavior has changed to align_corners=False since 1.3.0. Please specify align_corners=True if the old behavior is desired. See the documentation of grid_sample for details.



Iteration 0/15000, Loss: 2.408158779144287
Iteration 100/15000, Loss: 0.8399250507354736
Iteration 200/15000, Loss: 0.643166720867157
Iteration 300/15000, Loss: 0.5938358306884766
Iteration 400/15000, Loss: 0.49890708923339844
Iteration 500/15000, Loss: 0.471171498298645
Iteration 600/15000, Loss: 0.42294344305992126
Iteration 700/15000, Loss: 0.41852205991744995
Iteration 800/15000, Loss: 0.4484485983848572
Iteration 900/15000, Loss: 0.391868531703949
Iteration 1000/15000, Loss: 0.408453106880188
Iteration 1100/15000, Loss: 0.38064080476760864
Iteration 1200/15000, Loss: 0.41990455985069275
Iteration 1300/15000, Loss: 0.38957107067108154
Iteration 1400/15000, Loss: 0.3384588360786438
Iteration 1500/15000, Loss: 0.40063905715942383
Iteration 1600/15000, Loss: 0.330056756734848
Iteration 1700/15000, Loss: 0.3324451446533203
Iteration 1800/15000, Loss: 0.32813185453414917
Iteration 1900/15000, Loss: 0.3097460865974426
Iteration 2000/15000, Loss: 0.28731459379196167
Iteration 2100/15000, 