In [None]:
# Make the project's `src` directory importable from this notebook.
# The path is relative, so it resolves against the kernel's current working directory.
import sys
sys.path.insert(0, '../src/')

In [None]:
from nerf import LNeRF
from utils.data import load_npz
from utils.rays import create_rays
import matplotlib.pyplot as plt
import torch
from torch.nn import functional as F

# Choose the compute device in order of preference: CUDA, then Apple MPS, then CPU.
if torch.cuda.is_available():
    COMPUTE_DEVICE = torch.device('cuda:0')
elif torch.backends.mps.is_available():
    # `torch.backends.mps.is_available()` is the documented availability check;
    # `torch.mps.is_available()` is missing from older PyTorch releases and
    # raises AttributeError there.
    COMPUTE_DEVICE = torch.device('mps')
else:
    COMPUTE_DEVICE = torch.device('cpu')
print(f"{COMPUTE_DEVICE=}")

In [None]:
# Locations of the trained checkpoint and the hyper-parameter file that goes with it.
CHECKPOINT_PATH = "../lightning_logs/bober200x200/checkpoints/best_val_psnr_epoch=16.ckpt"
HPARAMS_PATH = "../lightning_logs/bober200x200/hparams.yaml"

# Restore the model, mapping the stored tensors onto the selected compute device.
model = LNeRF.load_from_checkpoint(CHECKPOINT_PATH, map_location=COMPUTE_DEVICE, hparams_file=HPARAMS_PATH)

# Only the second and third values returned by the loader are used in this notebook
# (presumably camera-to-world poses and the focal length -- verify against `load_npz`).
_, c2ws, focal = load_npz("../data/BOBER.npz")

In [None]:
# Override the batch size stored in the loaded hyper-parameters (2**6 == 64).
model.hparams.batch_size = 64

In [None]:
# Render pose #3 at 400x400 with twice the loaded focal length.
# NOTE(review): the x2 factor presumably compensates for rendering at twice the
# training resolution -- confirm against the dataset size.
preview_pose = c2ws[3].to(COMPUTE_DEVICE)
preview_focal = focal.to(COMPUTE_DEVICE) * 2
img = model.render_image(400, 400, preview_pose, focal=preview_focal)

# Values are clamped to [0, 1] before display.
plt.imshow(img.cpu().clamp(0, 1))
plt.axis('off')

In [None]:
# Indices of pixels whose 4th channel lies within 0.1 of zero.
# The zero reference is built from the channel itself so it always shares the
# image's device and dtype (a bare `torch.tensor(0.0)` is a float32 CPU tensor).
fourth_channel = img[..., 3]
torch.where(torch.isclose(fourth_channel, torch.zeros_like(fourth_channel), atol=1e-1))

In [None]:
# Build a 3x3 pinhole intrinsic matrix for a 100x100 image: half the loaded focal
# length on the diagonal and the principal point at the image centre.
ray_image_side = 100
half_focal = focal.item() / 2
principal_point = ray_image_side // 2

intrinsic = torch.tensor(
    [
        [half_focal, 0, principal_point],
        [0, half_focal, principal_point],
        [0, 0, 1],
    ],
    dtype=torch.float32,
)

# Generate one ray per pixel for pose #32.
origins, directions = create_rays(ray_image_side, ray_image_side, intrinsic, c2ws[32])

In [None]:
# Evaluate the model on the 30x30 pixel patch [20:50, 20:50], flattened to a
# list of rays, without tracking gradients.
with torch.no_grad():
    patch_origins = origins[20:50, 20:50].flatten(0, 1).to(COMPUTE_DEVICE)
    patch_directions = directions[20:50, 20:50].flatten(0, 1).to(COMPUTE_DEVICE)
    cc, cd, fc, fd = model.compute_along_rays(
        patch_origins,
        patch_directions,
        coarse_samples=200,
        fine_samples=200,
    )

In [None]:
def render_rays(rgbs, depths, far=None):
    """Alpha-composite per-sample values along each ray.

    Args:
        rgbs: Tensor of shape (..., S, C) with C >= 4. Channels 0-2 are blended
            as colour; channel 3 is passed through ReLU and used as density.
        depths: Tensor of shape (..., S) holding the position of each sample
            along its ray (same leading dimensions as ``rgbs``).
        far: Far bound that closes the interval of the last sample. When
            ``None`` (default) the notebook-level ``model.hparams.far`` is used.

    Returns:
        Tuple ``(rgb, depth, acc, alpha, weights)``:
            rgb     -- (..., 3) weighted sum of the colour channels,
            depth   -- (...,)   weighted sum of the sample depths,
            acc     -- (..., 1) sum of the weights,
            alpha   -- (..., S) per-sample opacity,
            weights -- (..., S) per-sample compositing weight.
    """
    if far is None:
        far = model.hparams.far

    # Width of the interval covered by each sample. The last interval runs from
    # the last sample to `far`; ReLU clamps it to zero if that sample already
    # lies beyond `far`.
    distances = depths[..., 1:] - depths[..., :-1]
    distances = torch.cat([distances, F.relu(far - depths[..., -1:])], -1)
    # NOTE(review): this treats the depth differences as world-space distances,
    # which presumably relies on the ray directions being unit length -- confirm
    # against the ray generation code.

    alpha = 1.0 - torch.exp(-F.relu(rgbs[..., 3]) * distances)

    # Exclusive cumulative product of (1 - alpha): the fraction of light that
    # reaches each sample. A leading column of ones shifts the product by one
    # sample; eps keeps the factors strictly positive. `ones_like` and the
    # ellipsis indexing keep this valid for any number of leading dimensions.
    transmittance = torch.cumprod(
        torch.cat([torch.ones_like(alpha[..., :1]), 1. - alpha + torch.finfo(rgbs.dtype).eps], -1), -1
    )[..., :-1]
    weights = alpha * transmittance

    rgb = torch.sum(weights[..., None] * rgbs[..., :3], dim=-2)
    depth = torch.sum(weights * depths, dim=-1)
    acc = torch.sum(weights, dim=-1).unsqueeze(-1)

    return rgb, depth, acc, alpha, weights

# Composite the fine samples of the 30x30 patch.
rgb, depth, acc, alpha, weights = render_rays(fc, fd)

# Min-max normalise the accumulated weights to [0, 1]. The divisor is clamped to
# the smallest positive normal float so a constant `acc` yields zeros instead of
# NaN (0 / 0).
acc_shifted = acc - acc.min()
acc_span = acc_shifted.max().clamp(min=torch.finfo(acc.dtype).tiny)
rescaled_acc = acc_shifted / acc_span

# Show the patch with the normalised accumulation appended as an extra (4th) channel.
plt.imshow(torch.cat([rgb.reshape(30, 30, -1).cpu(), rescaled_acc.reshape(30, 30, -1).cpu()], dim=-1).clamp(0,1))

torch.min(rescaled_acc), torch.max(rescaled_acc), torch.quantile(rescaled_acc, 0.4)

In [None]:
# Four diagnostic curves for a single ray of the flattened patch.
fig, axs = plt.subplots(2, 2, figsize=(10, 10))
idx = 20 * 20

fine_depths_cpu = fd.cpu()[idx, ...]

# (axes, title, x values, y values) for each panel, drawn in this order.
panels = (
    (axs[0, 0], "Coarses", cd.cpu()[idx, ...], cc.cpu()[idx, ..., 3]),
    (axs[0, 1], "Fines", fine_depths_cpu, fc.cpu()[idx, ..., 3]),
    (axs[1, 0], "Alphas", fine_depths_cpu, alpha.cpu()[idx, ...]),
    (axs[1, 1], "Weights", fine_depths_cpu, weights.cpu()[idx, ...]),
)
for panel_ax, panel_title, xs, ys in panels:
    panel_ax.set_title(panel_title)
    panel_ax.plot(xs, ys)

# Mark the composited depth of this ray on every panel and zoom the x-axis
# to the fixed window [3.1, 3.4].
ray_depth = depth[idx].item()
for ax in axs.flatten():
    ax.axvline(ray_depth, color="red")
    ax.set_xlim(3.1, 3.4)

print(depth[idx])

In [None]:
# Select the rays whose composited depth is smaller than the near bound.
# NOTE(review): `<` keeps only rays that ended up in front of the near plane --
# confirm this is the intended selection rather than `>`.
mask = (depth < model.hparams.near).cpu()

# Flattened origins and directions of the same 30x30 patch.
patch_origins = origins[20:50, 20:50].flatten(0, 1)
patch_directions = directions[20:50, 20:50].flatten(0, 1)

# Point on each selected ray at its composited depth: o + t * d.
selected_depths = depth.unsqueeze(-1).cpu()[mask]
points = patch_origins[mask] + selected_depths * patch_directions[mask]

# Scatter the points in 3-D (markers only, no connecting line).
fig, ax = plt.subplots(figsize=(8, 8), subplot_kw={"projection": "3d"})
ax.plot(
    points[:, 0],
    points[:, 1],
    points[:, 2],
    linewidth=0,
    markersize=2,
    marker='o',
    color="#000B",
)
ax.view_init(0, 180, 0)

points.shape