In [None]:
# Allow importing from src (path is relative to the notebook's working directory).
import sys

SRC_DIR = '../src/'
# Guard the insert so re-running this cell does not stack duplicate entries on sys.path.
if SRC_DIR not in sys.path:
    sys.path.insert(0, SRC_DIR)

In [None]:
from nerf import LNeRF
from utils.rays import create_rays, render_rays
import matplotlib.pyplot as plt
import torch
from torch.nn import functional as F
from torch import Tensor
from torch.utils.data import DataLoader, TensorDataset

# Pick the compute backend: CUDA first, then Apple-silicon MPS, else CPU.
# `torch.backends.mps.is_available()` is the long-standing public check; the
# `torch.mps.is_available` shortcut may be missing in some torch releases.
if torch.cuda.is_available():
    COMPUTE_DEVICE = torch.device('cuda:0')
elif torch.backends.mps.is_available():
    COMPUTE_DEVICE = torch.device('mps')
else:
    COMPUTE_DEVICE = torch.device('cpu')
print(f"{COMPUTE_DEVICE=}")


def float_to_rad(f) -> Tensor:
    """Convert an angle in degrees to a float32 tensor in radians.

    Args:
        f: angle in degrees, as a Python number or a tensor.

    Returns:
        float32 tensor holding the angle in radians.
    """
    # as_tensor (unlike torch.tensor) does not warn when `f` is already a tensor.
    return torch.deg2rad(torch.as_tensor(f, dtype=torch.float32))

In [None]:
# Restore the trained model from its best-validation-PSNR checkpoint onto the compute device.
CHECKPOINT_PATH = "../lightning_logs/shurtape200x200/checkpoints/best_val_psnr_epoch=13.ckpt"
HPARAMS_PATH = "../lightning_logs/shurtape200x200/hparams.yaml"

model = LNeRF.load_from_checkpoint(
    CHECKPOINT_PATH,
    hparams_file=HPARAMS_PATH,
    map_location=COMPUTE_DEVICE,
)

# The focal length stored in the hyperparameters, as a float32 tensor for the ray helpers.
focal = torch.tensor(model.hparams.focal, dtype=torch.float32)

In [None]:
# In accordance with mitsuba's conventions
def look_at(radius: float, theta: Tensor, phi: Tensor, target: Tensor = None):
    """Build a 4x4 camera-to-world matrix for a camera on a sphere aimed at `target`.

    The camera sits at spherical coordinates (radius, theta, phi): theta is the
    polar angle from +z and phi the azimuth in the xy-plane from +x. The camera's
    "forward" column points from `target` back towards the camera, and world +z
    is the reference up direction.

    Args:
        radius: distance of the camera from the world origin.
        theta: polar angle in radians (tensor or Python number).
        phi: azimuth angle in radians (tensor or Python number).
        target: 3-vector the camera looks at; defaults to the world origin.

    Returns:
        float32 tensor of shape (4, 4) whose columns are
        [right, up, forward, origin] with a final row [0, 0, 0, 1].
    """
    theta = torch.as_tensor(theta, dtype=torch.float32)
    phi = torch.as_tensor(phi, dtype=torch.float32)
    # Build the default per call instead of sharing one tensor across calls.
    if target is None:
        target = torch.zeros(3, dtype=torch.float32)
    else:
        target = torch.as_tensor(target, dtype=torch.float32)

    sin_theta = torch.sin(theta)
    origin = torch.stack([
        radius * sin_theta * torch.cos(phi),
        radius * sin_theta * torch.sin(phi),
        radius * torch.cos(theta),
    ])

    forward = F.normalize(origin - target, dim=0)
    world_up = torch.tensor([0.0, 0.0, 1.0])
    right = torch.cross(world_up, forward, dim=0)
    if torch.linalg.vector_norm(right) < 1e-12:
        # Viewing exactly along +/-z: cross(up, forward) vanishes and the basis
        # would collapse to zeros. Use the theta -> 0+ limit of `right` instead.
        right = torch.stack([-torch.sin(phi), torch.cos(phi), torch.zeros_like(phi)])
    right = F.normalize(right, dim=0)
    up = F.normalize(torch.cross(forward, right, dim=0), dim=0)

    top = torch.stack([right, up, forward, origin], dim=1)  # (3, 4): basis + translation
    bottom = torch.tensor([[0.0, 0.0, 0.0, 1.0]])
    return torch.cat([top, bottom], dim=0)

look_at(radius=4, theta=float_to_rad(45), phi=float_to_rad(10))  # example pose: 45° polar, 10° azimuth

In [None]:
def intrinsic(focal: Tensor, size: float, height: float = None):
    """Build the 3x3 pinhole intrinsic matrix K.

    Args:
        focal: focal length (in pixels), as a one-element tensor or a Python
            number; used for both fx and fy.
        size: image width in pixels; also used as the height when `height`
            is not given (the original square-image behaviour).
        height: optional image height in pixels, for non-square images.

    Returns:
        float32 tensor [[f, 0, size // 2], [0, f, height // 2], [0, 0, 1]].
    """
    if height is None:
        height = size
    f = float(focal)  # same as .item() for one-element tensors; also accepts plain numbers
    return torch.tensor([
        [f, 0, size // 2],
        [0, f, height // 2],
        [0, 0, 1],
    ], dtype=torch.float32)

In [None]:
# Sanity check: render one 200x200 view from (almost) directly above the scene.
# theta is nudged off 0 so the view direction is not exactly parallel to world +z.
c2w = look_at(4, float_to_rad(1e-4), float_to_rad(0)).to(COMPUTE_DEVICE)
with torch.no_grad():  # inference only; no autograd graph is needed
    img = model.render_image(200, 200, c2w, focal.to(COMPUTE_DEVICE))
plt.imshow(img.cpu().clamp(0, 1))
plt.axis('off')
plt.show()

In [None]:
# Render six views of the scene (top, four sides, bottom) from a camera at radius 4.
# The polar angles stay just off 0° / 180°, where the look-at basis degenerates.
img_shape = (200, 200)
thetas = [1e-4, 90, 90, 90, 90, 180 - 1e-4]
phis = [0, 0, 90, 180, 270, 0]
camera_radius = 4
ray_batch_size = 2**11

# intrinsic() takes a single `size`, so it can only describe square images.
assert img_shape[0] == img_shape[1], "intrinsic() assumes square images"
K = intrinsic(focal, img_shape[0])  # identical for every view, so build it once

origins, directions = [], []
for theta, phi in zip(thetas, phis):
    view_c2w = look_at(camera_radius, float_to_rad(theta), float_to_rad(phi))
    o, d = create_rays(img_shape[0], img_shape[1], K, view_c2w)
    # Merge the first two (presumably pixel-grid) axes so each row is one ray.
    origins.append(o.flatten(0, 1))
    directions.append(d.flatten(0, 1))

origins = torch.cat(origins, dim=0)
directions = torch.cat(directions, dim=0)

# DataLoader does not shuffle by default, so result order matches ray order.
od_loader = DataLoader(TensorDataset(origins, directions), batch_size=ray_batch_size)

coarse_rgbs, coarse_depths, fine_rgbs, fine_depths = [], [], [], []
with torch.no_grad():  # pure inference: skip autograd bookkeeping for the whole sweep
    for o, d in od_loader:
        cc, cd, fc, fd = model.compute_along_rays(
            o.to(COMPUTE_DEVICE), d.to(COMPUTE_DEVICE), coarse_samples=64, fine_samples=128
        )
        coarse_rgbs.append(cc)
        coarse_depths.append(cd)
        fine_rgbs.append(fc)
        fine_depths.append(fd)

coarse_rgbs = torch.cat(coarse_rgbs, dim=0)
coarse_depths = torch.cat(coarse_depths, dim=0)
fine_rgbs = torch.cat(fine_rgbs, dim=0)
fine_depths = torch.cat(fine_depths, dim=0)
assert fine_depths.shape[0] == origins.shape[0], "expected one result row per ray"

In [None]:
# Composite the fine samples along each ray into colour, depth and `acc`
# (presumably accumulated opacity), plus per-sample alpha and weights.
rgb, depth, acc, alpha, weights = render_rays(fine_rgbs, fine_depths, far=model.hparams.far)

# Min-max rescale `acc` into [0, 1] to use it as the alpha channel.
# clamp_min guards the division when `acc` is constant, which would otherwise yield NaNs.
acc_range = (acc.max() - acc.min()).clamp_min(1e-12)
rescaled_acc = (acc - acc.min()) / acc_range

rgba = torch.cat([rgb.cpu(), rescaled_acc.cpu()], dim=-1).clamp(0, 1)
images = rgba.reshape((-1,) + img_shape + (4,))
for img, theta, phi in zip(images, thetas, phis):
    plt.imshow(img)
    plt.title(f"theta={theta}°, phi={phi}°")
    plt.axis('off')
    plt.show()

In [None]:
# Back-project confidently-opaque pixels into a coloured 3-D point cloud.
OPACITY_THRESHOLD = 0.98  # minimum rescaled alpha for a ray to contribute a point
VIEW_THETA, VIEW_PHI = 70, 110  # viewing direction in degrees (look_at's convention)

mask = (depth > model.hparams.near).cpu() & (rgba[..., -1] > OPACITY_THRESHOLD)

colors = rgba[mask]
# point = ray origin + depth * ray direction
points = origins.cpu()[mask] + depth.cpu()[mask].unsqueeze(-1) * directions.cpu()[mask]

# Order points far-to-near as seen from the viewpoint so nearer points are drawn on top.
og = look_at(2, float_to_rad(VIEW_THETA), float_to_rad(VIEW_PHI))[:3, -1]
sorting = torch.argsort(torch.linalg.vector_norm(og - points, dim=-1), descending=True)
sorted_points, sorted_colors = points[sorting], colors[sorting]

fig, ax = plt.subplots(1, 1, figsize=(8, 8), subplot_kw={"projection": "3d"})
# A single scatter call instead of one Line2D artist per point.
ax.scatter(
    sorted_points[:, 0].numpy(),
    sorted_points[:, 1].numpy(),
    sorted_points[:, 2].numpy(),
    c=sorted_colors.numpy(),
    s=1,
    marker='o',
    linewidths=0,
    depthshade=False,  # keep the rendered colours exact
)
# matplotlib's elevation is measured from the xy-plane, theta from +z.
ax.view_init(90 - VIEW_THETA, VIEW_PHI, 0)
ax.set_xlim(-1, 1)
ax.set_ylim(-1, 1)
ax.set_zlim(-1, 1)

points.shape

In [None]:
# Inspect the samples along one ray of the first view: channel 3 of the coarse/fine
# outputs, per-sample alpha and compositing weights, with the rendered depth in red.
probe_row, probe_col = 40, 40
# Rays are flattened row-major per (H, W) image, so one row advances by the width.
idx = probe_row * img_shape[1] + probe_col
DEPTH_WINDOW = (3, 3.5)  # hand-picked x-range along the ray for this probe pixel

fine_d = fine_depths[idx].cpu()
fig, axs = plt.subplots(2, 2, figsize=(10, 10))
panels = [
    (axs[0, 0], "Coarses", coarse_depths[idx].cpu(), coarse_rgbs[idx, ..., 3].cpu()),
    (axs[0, 1], "Fines", fine_d, fine_rgbs[idx, ..., 3].cpu()),
    (axs[1, 0], "Alphas", fine_d, alpha[idx].cpu()),
    (axs[1, 1], "Weights", fine_d, weights[idx].cpu()),
]

ray_depth = depth[idx].item()
for ax, title, xs, ys in panels:
    ax.set_title(title)
    ax.plot(xs, ys)
    ax.axvline(ray_depth, color="red")
    ax.set_xlim(*DEPTH_WINDOW)
    ax.set_xlabel("depth")

print(depth[idx])