In [1]:
import os, json, math, time
import numpy as np
import imageio.v2 as imageio
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import pandas as pd


def load_blender_data_raw(basedir):
    splits = ['train', 'val', 'test']
    metas = {s: json.load(open(os.path.join(basedir, f"transforms_{s}.json"), "r"))
             for s in splits}

    all_imgs, all_poses, counts = [], [], [0]
    for s in splits:
        meta = metas[s]
        imgs, poses = [], []
        for f in meta["frames"]:
            fname = os.path.join(basedir, f["file_path"] + ".png")
            imgs.append(imageio.imread(fname).astype(np.float32) / 255.0)
            poses.append(np.array(f["transform_matrix"], dtype=np.float32))
        imgs  = np.stack(imgs,  axis=0)
        poses = np.stack(poses, axis=0)
        counts.append(counts[-1] + imgs.shape[0])
        all_imgs.append(imgs)
        all_poses.append(poses)

    imgs  = np.concatenate(all_imgs,  axis=0)
    poses = np.concatenate(all_poses, axis=0)

    H, W = imgs[0].shape[:2]
    camera_angle_x = float(metas["train"]["camera_angle_x"])
    focal = 0.5 * W / np.tan(0.5 * camera_angle_x)
    i_split = [np.arange(counts[i], counts[i+1]) for i in range(3)]
    return imgs, poses, [H, W, focal], i_split


def get_rays(H, W, focal, c2w):
    device = c2w.device
    i, j = torch.meshgrid(
        torch.arange(W, dtype=torch.float32, device=device),
        torch.arange(H, dtype=torch.float32, device=device),
        indexing="xy"
    )
    dirs = torch.stack([
        (i - W * 0.5) / focal,
        -(j - H * 0.5) / focal,
        -torch.ones_like(i)
    ], dim=-1)

    R = c2w[:3, :3]
    t = c2w[:3, 3]
    rd = dirs @ R.T
    rd = F.normalize(rd, dim=-1)
    ro = t.expand_as(rd)
    return ro, rd


@torch.no_grad()
def build_ray_bank(imgs, poses, H, W, focal):
    N = imgs.shape[0]
    rays_o, rays_d = [], []
    for n in range(N):
        ro, rd = get_rays(H, W, focal, poses[n])
        rays_o.append(ro)
        rays_d.append(rd)
    rays_o = torch.stack(rays_o, 0).reshape(-1, 3)
    rays_d = torch.stack(rays_d, 0).reshape(-1, 3)
    rgb = imgs[..., :3].reshape(-1, 3)
    return rays_o, rays_d, rgb


def sample_along_rays(rays_o, rays_d, near, far, S=64, stratified=True):
    B = rays_o.shape[0]
    z_lin = torch.linspace(0., 1., S, device=rays_o.device)
    z = near * (1. - z_lin) + far * z_lin
    z = z.unsqueeze(0).expand(B, S)

    if stratified:
        mids = 0.5 * (z[:, :-1] + z[:, 1:])
        lows = torch.cat([z[:, :1], mids], dim=1)
        highs = torch.cat([mids, z[:, -1:]], dim=1)
        z = lows + (highs - lows) * torch.rand_like(z)

    pts = rays_o[:, None, :] + rays_d[:, None, :] * z[..., None]
    return z, pts


def volume_render(rgb, sigma, z, rays_d, white_bkgd=True):
    B, S, _ = rgb.shape
    deltas = (z[:, 1:] - z[:, :-1])
    deltas = torch.cat([deltas, torch.full((B, 1), 1e10, device=z.device)], dim=1)[..., None]

    dist = deltas * torch.linalg.norm(rays_d, dim=-1, keepdim=True).unsqueeze(1)
    alpha = 1. - torch.exp(-sigma * dist)

    T = torch.cumprod(torch.cat(
        [torch.ones(B, 1, 1, device=z.device), 1. - alpha + 1e-10], dim=1
    ), dim=1)[:, :-1]

    weights = alpha * T
    rgb_map = torch.sum(weights * rgb, dim=1)
    acc_map = torch.sum(weights, dim=1)

    if white_bkgd:
        rgb_map = rgb_map + (1. - acc_map)

    return rgb_map, acc_map, weights


def sample_pdf(bins, weights, N_importance, deterministic=False, eps=1e-5):
    weights = weights + eps
    pdf = weights / torch.sum(weights, dim=-1, keepdim=True)
    cdf = torch.cumsum(pdf, dim=-1)
    cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], dim=-1)

    if deterministic:
        u = torch.linspace(0., 1., N_importance, device=bins.device)
        u = u.unsqueeze(0).expand(bins.shape[0], N_importance)
    else:
        u = torch.rand(bins.shape[0], N_importance, device=bins.device)

    inds = torch.searchsorted(cdf, u, right=True)
    below = torch.clamp(inds - 1, min=0)
    above = torch.clamp(inds, max=cdf.shape[-1] - 1)

    cdf_b = torch.gather(cdf, 1, below)
    cdf_a = torch.gather(cdf, 1, above)
    bins_b = torch.gather(bins, 1, torch.clamp(below - 1, min=0))
    bins_a = torch.gather(bins, 1, torch.clamp(above - 1, min=0))

    denom = (cdf_a - cdf_b).clamp_min(eps)
    t = (u - cdf_b) / denom
    return bins_b + t * (bins_a - bins_b)


class PosEnc(nn.Module):
    def __init__(self, num_freqs):
        super().__init__()
        self.num_freqs = num_freqs
        self.register_buffer("freq_bands", (2. ** torch.arange(num_freqs).float()) * math.pi)
        self.register_buffer("freq_weights", torch.ones(num_freqs))

    @torch.no_grad()
    def set_window_by_iter(self, step, end_step):
        if end_step is None or end_step <= 0:
            w = torch.ones(self.num_freqs, device=self.freq_weights.device)
        else:
            p = max(0.0, min(1.0, float(step) / float(end_step)))
            k = torch.arange(self.num_freqs, device=self.freq_weights.device).float()
            w = (p * self.num_freqs - k).clamp_(0.0, 1.0)
        self.freq_weights.copy_(w)

    def forward(self, x):
        freqs = self.freq_bands.to(x.device)
        xb = x[..., None, :] * freqs[:, None]
        s = torch.sin(xb)
        c = torch.cos(xb)
        w = self.freq_weights.to(x.device)[:, None]
        s = s * w
        c = c * w
        enc = torch.cat([s, c], dim=-1)
        return enc.view(*x.shape[:-1], -1)


class TinyNeRF_PE(nn.Module):
    def __init__(self, Lx=10, Ld=4, hidden=256):
        super().__init__()
        in_xyz = 3 + 3 * 2 * Lx
        in_dir = 3 * 2 * Ld

        self.pe_xyz = PosEnc(Lx)
        self.pe_dir = PosEnc(Ld)

        self.fc1 = nn.Linear(in_xyz, hidden)
        self.fc2 = nn.Linear(hidden, hidden)
        self.fc3 = nn.Linear(hidden, hidden)
        self.fc4 = nn.Linear(hidden + in_xyz, hidden)

        self.fc_sigma = nn.Linear(hidden, 1)
        self.fc_feat  = nn.Linear(hidden, hidden)

        self.fc_dir = nn.Linear(hidden + in_dir, hidden // 2)
        self.fc_rgb = nn.Linear(hidden // 2, 3)

    def forward(self, pts, dirs=None):
        pe_xyz = self.pe_xyz(pts)
        x_in = torch.cat([pts, pe_xyz], dim=-1)

        h = F.relu(self.fc1(x_in))
        h = F.relu(self.fc2(h))
        h = F.relu(self.fc3(h))
        h = torch.cat([h, x_in], dim=-1)
        h = F.relu(self.fc4(h))

        sigma = F.softplus(self.fc_sigma(h))
        feat  = F.relu(self.fc_feat(h))

        if dirs is not None:
            pe_d = self.pe_dir(dirs)
            h_col = torch.cat([feat, pe_d], dim=-1)
        else:
            h_col = feat

        h_col = F.relu(self.fc_dir(h_col))
        rgb = torch.sigmoid(self.fc_rgb(h_col))

        return rgb, sigma


def mse2psnr(x):
    return -10.0 * torch.log10(x)


In [None]:
torch.set_float32_matmul_precision("high")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
basedir = "../data/nerf_synthetic/lego"

# Load dataset
imgs_np, poses_np, hwf, i_split = load_blender_data_raw(basedir)
H, W, focal = hwf

imgs_torch  = torch.from_numpy(imgs_np).float().to(device)
poses_torch = torch.from_numpy(poses_np).float().to(device)

# Handle alpha channel (RGBA → RGB on white)
if imgs_torch.shape[-1] == 4:
    rgb, a = imgs_torch[..., :3], imgs_torch[..., 3:4]
    imgs_torch = rgb * a + (1.0 - a)

# Downsample
down = 2
if down > 1:
    imgs_chw = imgs_torch.permute(0, 3, 1, 2)
    imgs_chw = F.interpolate(imgs_chw, size=(H // down, W // down), mode="area")
    imgs_torch = imgs_chw.permute(0, 2, 3, 1).contiguous()
    H //= down; W //= down; focal /= down

# Split train/test
i_train = i_split[0]
i_test  = i_split[2]

imgs_train, poses_train = imgs_torch[i_train], poses_torch[i_train]
imgs_test,  poses_test  = imgs_torch[i_test],  poses_torch[i_test]

# Ray banks
with torch.no_grad():
    rays_o_all, rays_d_all, rgb_all = build_ray_bank(imgs_train, poses_train, H, W, focal)
    rays_o_all = rays_o_all.to(device).contiguous()
    rays_d_all = rays_d_all.to(device).contiguous()
    rgb_all    = rgb_all.to(device).contiguous()
    rays_d_all_unit = F.normalize(rays_d_all, dim=-1)

    rays_o_test, rays_d_test, rgb_test = build_ray_bank(imgs_test, poses_test, H, W, focal)
    rays_o_test = rays_o_test.to(device).contiguous()
    rays_d_test = rays_d_test.to(device).contiguous()
    rgb_test    = rgb_test.to(device).contiguous()
    rays_d_test_unit = F.normalize(rays_d_test, dim=-1)

# Fixed evaluation subset
torch.manual_seed(0)
E = min(8192, rays_o_test.shape[0])
perm = torch.randperm(rays_o_test.shape[0], device=device)
eval_idx = perm[:E]
rays_o_eval     = rays_o_test[eval_idx].contiguous()
rays_d_eval     = rays_d_test[eval_idx].contiguous()
rays_d_eval_unit= rays_d_test_unit[eval_idx].contiguous()
rgb_eval        = rgb_test[eval_idx].contiguous()

# Model
model = TinyNeRF_PE(Lx=10, Ld=4, hidden=256).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=5e-4)

pe_pos_end = 8000
pe_dir_end = 4000
near, far = 2.0, 6.0
S_coarse, S_fine = 64, 128
B = 2048
iters = 25000

# Logging
hist_iter, hist_loss_f, hist_loss_c = [], [], []
hist_psnr_f, hist_psnr_c = [], []
hist_pos_on, hist_dir_on = [], []
hist_iter_eval, hist_test_loss_f, hist_test_psnr_f = [], [], []

def _log(it, loss_main, loss_aux, psnr_f, psnr_c, npos, ndir):
    hist_iter.append(it)
    hist_loss_f.append(float(loss_main.detach().cpu()))
    hist_loss_c.append(float(loss_aux.detach().cpu()))
    hist_psnr_f.append(float(psnr_f))
    hist_psnr_c.append(float(psnr_c))
    hist_pos_on.append(int(npos))
    hist_dir_on.append(int(ndir))

@torch.no_grad()
def eval_fine(model, rays_o, rays_d, rays_d_unit, target,
              near, far, S_coarse=32, S_fine=64, chunk=1024):

    model.eval()
    N = rays_o.shape[0]
    err_sum = 0.0

    for s in range(0, N, chunk):
        e = min(N, s + chunk)
        ro, rd = rays_o[s:e], rays_d[s:e]
        rdu, tgt = rays_d_unit[s:e], target[s:e]

        z_c, pts_c = sample_along_rays(ro, rd, near, far, S=S_coarse, stratified=False)
        dirs_c = rdu[:, None, :].expand(-1, S_coarse, -1)
        rgb_c, sigma_c = model(pts_c, dirs_c)
        _, _, w_c = volume_render(rgb_c, sigma_c, z_c, rd, white_bkgd=True)

        bins  = 0.5 * (z_c[:, 1:] + z_c[:, :-1])
        w_mid = w_c[:, 1:-1, 0]
        L = min(bins.shape[-1], w_mid.shape[-1])

        z_f = sample_pdf(bins[:, :L].contiguous(),
                         w_mid[:, :L].contiguous(),
                         N_importance=S_fine,
                         deterministic=True)

        z_all, _ = torch.sort(torch.cat([z_c, z_f], dim=-1), dim=-1)
        pts_all  = ro[:, None, :] + rd[:, None, :] * z_all[..., None]
        dirs_all = rdu[:, None, :].expand(-1, z_all.shape[1], -1)

        rgb_f, sigma_f = model(pts_all, dirs_all)
        rgb_map_f, _, _ = volume_render(rgb_f, sigma_f, z_all, rd, white_bkgd=True)

        err_sum += F.mse_loss(rgb_map_f, tgt, reduction='sum').item()

    model.train()
    mse = err_sum / N
    psnr = 10.0 * math.log10(1.0 / max(mse, 1e-12))
    return mse, psnr

# Training
model.train()
t0 = time.time()
total_rays = rays_o_all.shape[0]
test_every = 100

for it in range(1, iters + 1):

    model.pe_xyz.set_window_by_iter(it, pe_pos_end)
    model.pe_dir.set_window_by_iter(it, pe_dir_end)

    idx = torch.randint(0, total_rays, (B,), device=device)
    rays_o = rays_o_all[idx]
    rays_d = rays_d_all[idx]
    rays_d_unit = rays_d_all_unit[idx]
    target = rgb_all[idx]

    z_c, pts_c = sample_along_rays(rays_o, rays_d, near, far, S=S_coarse, stratified=True)
    dirs_c = rays_d_unit[:, None, :].expand(-1, S_coarse, -1)
    rgb_c, sigma_c = model(pts_c, dirs_c)
    rgb_map_c, _, w_c = volume_render(rgb_c, sigma_c, z_c, rays_d, white_bkgd=True)

    with torch.no_grad():
        bins  = 0.5 * (z_c[:, 1:] + z_c[:, :-1])
        w_mid = w_c[:, 1:-1, 0]
        L = min(bins.shape[-1], w_mid.shape[-1])
        z_f = sample_pdf(bins[:, :L].contiguous(),
                         w_mid[:, :L].contiguous(),
                         N_importance=S_fine,
                         deterministic=False)

    z_all, _ = torch.sort(torch.cat([z_c, z_f], dim=-1), dim=-1)
    pts_all  = rays_o[:, None, :] + rays_d[:, None, :] * z_all[..., None]
    dirs_all = rays_d_unit[:, None, :].expand(-1, z_all.shape[1], -1)
    rgb_f, sigma_f = model(pts_all, dirs_all)
    rgb_map_f, _, _ = volume_render(rgb_f, sigma_f, z_all, rays_d, white_bkgd=True)

    loss_main = F.mse_loss(rgb_map_f, target)
    loss_aux  = F.mse_loss(rgb_map_c, target)
    loss = loss_main + 0.1 * loss_aux

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if it <= 10 or it % 50 == 0:
        psnr_f = mse2psnr(loss_main).item()
        psnr_c = mse2psnr(loss_aux).item()
        npos = int(model.pe_xyz.freq_weights.gt(0.999).sum().item())
        ndir = int(model.pe_dir.freq_weights.gt(0.999).sum().item())
        _log(it, loss_main, loss_aux, psnr_f, psnr_c, npos, ndir)

        dt = time.time() - t0
        print(f"Iter {it:5d}  L={loss.item():.6f}  PSNR_f={psnr_f:.2f}  PSNR_c={psnr_c:.2f}  "
              f"[pos={npos}/{model.pe_xyz.num_freqs}, dir={ndir}/{model.pe_dir.num_freqs}]  ({dt:.1f}s)")
        t0 = time.time()

    if it % test_every == 0:
        test_loss, test_psnr = eval_fine(
            model,
            rays_o_eval, rays_d_eval, rays_d_eval_unit, rgb_eval,
            near=near, far=far, S_coarse=64, S_fine=128, chunk=2048
        )
        hist_iter_eval.append(it)
        hist_test_loss_f.append(test_loss)
        hist_test_psnr_f.append(test_psnr)
        print(f"[TEST] iter {it:5d} | fine: L={test_loss:.6f}, PSNR={test_psnr:.2f}")

print("Finished training.")

# Smooth
def _smooth(x, k=51):
    x = pd.Series(x, dtype=float)
    k = min(k, max(3, (len(x)//5)*2 + 1))
    if k % 2 == 0: k += 1
    return x.rolling(window=k, center=True, min_periods=1).mean().to_numpy()

# Export logs
df = pd.DataFrame({
    "iter": hist_iter,
    "loss_f": hist_loss_f,
    "loss_c": hist_loss_c,
    "psnr_f": hist_psnr_f,
    "psnr_c": hist_psnr_c,
    "pos_on": hist_pos_on,
    "dir_on": hist_dir_on,
})
df_eval = pd.DataFrame({
    "iter_eval": hist_iter_eval,
    "test_loss_f": hist_test_loss_f,
    "test_psnr_f": hist_test_psnr_f,
})
df.to_csv("training_metrics.csv", index=False)
df_eval.to_csv("training_metrics_eval.csv", index=False)

# Plots
plt.figure(figsize=(6,4))
plt.plot(hist_iter, _smooth(hist_loss_f), label="train loss_f")
plt.plot(hist_iter_eval, _smooth(hist_test_loss_f), label="test  loss_f")
plt.xlabel("iter"); plt.ylabel("MSE (fine)"); plt.legend(); plt.tight_layout()
plt.savefig("train_vs_test_loss_f.png", dpi=200); plt.close()

plt.figure(figsize=(6,4))
plt.plot(hist_iter, _smooth(hist_psnr_f), label="train psnr_f")
plt.plot(hist_iter_eval, _smooth(hist_test_psnr_f), label="test  psnr_f")
plt.xlabel("iter"); plt.ylabel("PSNR (dB)"); plt.legend(); plt.tight_layout()
plt.savefig("train_vs_test_psnr_f.png", dpi=200); plt.close()

plt.figure(figsize=(6,4))
plt.plot(hist_iter, _smooth(hist_loss_f), label="loss_f (train)")
plt.plot(hist_iter, _smooth(hist_loss_c), label="loss_c (train)")
plt.xlabel("iter"); plt.ylabel("loss"); plt.legend(); plt.tight_layout()
plt.savefig("train_loss.png", dpi=200); plt.close()

plt.figure(figsize=(6,4))
plt.plot(hist_iter, _smooth(hist_psnr_f), label="psnr_f (train)")
plt.plot(hist_iter, _smooth(hist_psnr_c), label="psnr_c (train)")
plt.xlabel("iter"); plt.ylabel("PSNR (dB)"); plt.legend(); plt.tight_layout()
plt.savefig("train_psnr.png", dpi=200); plt.close()

plt.figure(figsize=(6,4))
plt.plot(hist_iter, hist_pos_on, label="pos_on")
plt.plot(hist_iter, hist_dir_on, label="dir_on")
plt.xlabel("iter"); plt.ylabel("#active bands"); plt.legend(); plt.tight_layout()
plt.savefig("pe_active_bands.png", dpi=200); plt.close()

torch.save({"model": model.state_dict()}, "nerf_tiny_pe.pt")
if torch.cuda.is_available():
    peak_mem_gb = torch.cuda.max_memory_allocated() / (1024**3)
    print(f"Peak GPU memory: {peak_mem_gb:.2f} GB")

print("Saved:")
print("  training_metrics.csv, training_metrics_eval.csv")
print("  train_vs_test_loss_f.png, train_vs_test_psnr_f.png")
print("  train_loss.png, train_psnr.png, pe_active_bands.png")
print("  nerf_tiny_pe.pt")

Iter     1  L=0.216700  PSNR_f=7.06  PSNR_c=7.06  [pos=0/10, dir=0/4]  (0.2s)
Iter     2  L=0.217752  PSNR_f=7.03  PSNR_c=7.03  [pos=0/10, dir=0/4]  (0.1s)
Iter     3  L=0.209801  PSNR_f=7.20  PSNR_c=7.20  [pos=0/10, dir=0/4]  (0.1s)
Iter     4  L=0.215056  PSNR_f=7.09  PSNR_c=7.09  [pos=0/10, dir=0/4]  (0.1s)
Iter     5  L=0.208091  PSNR_f=7.23  PSNR_c=7.23  [pos=0/10, dir=0/4]  (0.1s)
Iter     6  L=0.206412  PSNR_f=7.27  PSNR_c=7.27  [pos=0/10, dir=0/4]  (0.1s)
Iter     7  L=0.208828  PSNR_f=7.22  PSNR_c=7.22  [pos=0/10, dir=0/4]  (0.1s)
Iter     8  L=0.201330  PSNR_f=7.37  PSNR_c=7.38  [pos=0/10, dir=0/4]  (0.1s)
Iter     9  L=0.198896  PSNR_f=7.43  PSNR_c=7.43  [pos=0/10, dir=0/4]  (0.1s)
Iter    10  L=0.199033  PSNR_f=7.42  PSNR_c=7.43  [pos=0/10, dir=0/4]  (0.1s)
Iter    50  L=0.093318  PSNR_f=10.71  PSNR_c=10.71  [pos=0/10, dir=0/4]  (3.0s)
Iter   100  L=0.058600  PSNR_f=12.74  PSNR_c=12.73  [pos=0/10, dir=0/4]  (3.8s)
[TEST] iter   100 | fine: L=0.163508, PSNR=7.86


  inds = torch.searchsorted(cdf, u, right=True)


Iter   150  L=0.039333  PSNR_f=14.47  PSNR_c=14.46  [pos=0/10, dir=0/4]  (3.8s)
Iter   200  L=0.026927  PSNR_f=16.11  PSNR_c=16.13  [pos=0/10, dir=0/4]  (3.8s)
[TEST] iter   200 | fine: L=0.073858, PSNR=11.32
Iter   250  L=0.023049  PSNR_f=16.79  PSNR_c=16.76  [pos=0/10, dir=0/4]  (3.8s)
Iter   300  L=0.020590  PSNR_f=17.28  PSNR_c=17.22  [pos=0/10, dir=0/4]  (3.8s)
[TEST] iter   300 | fine: L=0.057830, PSNR=12.38
Iter   350  L=0.019808  PSNR_f=17.45  PSNR_c=17.36  [pos=0/10, dir=0/4]  (3.8s)
Iter   400  L=0.017669  PSNR_f=17.95  PSNR_c=17.84  [pos=0/10, dir=0/4]  (3.8s)
[TEST] iter   400 | fine: L=0.049355, PSNR=13.07
Iter   450  L=0.016744  PSNR_f=18.19  PSNR_c=18.08  [pos=0/10, dir=0/4]  (3.8s)
Iter   500  L=0.016151  PSNR_f=18.35  PSNR_c=18.17  [pos=0/10, dir=0/4]  (3.8s)
[TEST] iter   500 | fine: L=0.046167, PSNR=13.36
Iter   550  L=0.015645  PSNR_f=18.47  PSNR_c=18.42  [pos=0/10, dir=0/4]  (3.8s)
Iter   600  L=0.014951  PSNR_f=18.68  PSNR_c=18.58  [pos=0/10, dir=0/4]  (3.8s)
[TES

In [7]:
ckpt_path = "nerf_lego_pe.pt"
torch.save({
    "model": model.state_dict(),
    "H": H, "W": W, "focal": focal,
    "near": near, "far": far,
}, ckpt_path)
print(f"Saved checkpoint to {ckpt_path}")

Saved checkpoint to nerf_lego_pe.pt


In [None]:
@torch.no_grad()
def render_single_view(
    basedir="../data/nerf_synthetic/lego",
    ckpt="nerf_lego_pe.pt",
    split="test",
    idx=0,
    down=2,
    S_coarse=128,
    S_fine=128,
    near=2.0,
    far=6.0,
    chunk=8192,
    white_bkgd=True,
    out_path="render_example.png"
):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    imgs_np, poses_np, hwf, i_split = load_blender_data_raw(basedir)
    H, W, focal = hwf

    imgs_torch  = torch.from_numpy(imgs_np).float()
    poses_torch = torch.from_numpy(poses_np).float()

    # Convert RGBA to RGB (white background)
    if imgs_torch.shape[-1] == 4:
        rgb = imgs_torch[..., :3]
        a   = imgs_torch[..., 3:4]
        imgs_torch = rgb * a + (1.0 - a)

    # Match training resolution
    if down > 1:
        imgs_chw = imgs_torch.permute(0, 3, 1, 2)
        imgs_chw = F.interpolate(imgs_chw, size=(H // down, W // down), mode="area")
        imgs_torch = imgs_chw.permute(0, 2, 3, 1).contiguous()
        H //= down
        W //= down
        focal /= down

    split_map = {"train": 0, "val": 1, "test": 2}
    ids = i_split[split_map[split]]
    vid = ids[idx % len(ids)]

    gt  = imgs_torch[vid]
    c2w = poses_torch[vid].to(device)

    # Load model
    model = TinyNeRF_PE(Lx=10, Ld=4, hidden=256).to(device)
    state = torch.load(ckpt, map_location=device)
    model.load_state_dict(state["model"])
    model.eval()

    # Rays
    ro, rd = get_rays(H, W, focal, c2w)
    ro = ro.reshape(-1, 3).to(device)
    rd = rd.reshape(-1, 3).to(device)

    # Rendering
    rgb_out = []
    for i in range(0, H * W, chunk):
        rays_o = ro[i:i+chunk]
        rays_d = rd[i:i+chunk]
        dirs_norm = F.normalize(rays_d, dim=-1)

        # Coarse
        z_c, pts_c = sample_along_rays(rays_o, rays_d, near, far, S=S_coarse, stratified=False)
        dirs_c = dirs_norm[:, None, :].expand(-1, S_coarse, -1)
        rgb_c, sigma_c = model(pts_c, dirs_c)
        rgb_map_c, _, w_c = volume_render(rgb_c, sigma_c, z_c, rays_d, white_bkgd=white_bkgd)

        # Fine
        bins  = 0.5 * (z_c[:, 1:] + z_c[:, :-1])
        w_mid = w_c[:, 1:-1, 0]
        L = min(bins.shape[-1], w_mid.shape[-1])

        z_f = sample_pdf(bins[:, :L], w_mid[:, :L],
                         N_importance=S_fine, deterministic=True)

        z_all, _ = torch.sort(torch.cat([z_c, z_f], dim=-1), dim=-1)
        pts_all  = rays_o[:, None, :] + rays_d[:, None, :] * z_all[..., None]
        dirs_all = dirs_norm[:, None, :].expand(-1, z_all.shape[1], -1)

        rgb_f, sigma_f = model(pts_all, dirs_all)
        rgb_map_f, _, _ = volume_render(rgb_f, sigma_f, z_all, rays_d, white_bkgd=white_bkgd)

        rgb_out.append(rgb_map_f.clamp(0, 1).cpu())

    rgb_img = torch.cat(rgb_out, dim=0).view(H, W, 3).cpu().numpy()

    import imageio.v2 as imageio
    imageio.imwrite(out_path, (rgb_img * 255.0).astype(np.uint8))
    print(f"Saved to {out_path}")

    # PSNR
    mse = F.mse_loss(torch.from_numpy(rgb_img).float(), gt.cpu()).item()
    psnr = -10.0 * math.log10(max(mse, 1e-12))
    print(f"PSNR (vs {split} #{idx}): {psnr:.2f} dB")


# Example
render_single_view(
    basedir="../data/nerf_synthetic/lego",
    ckpt="nerf_lego_pe.pt",
    split="test",
    idx=0,
    down=2,
    out_path="render_example.png"
)

  state = torch.load(ckpt, map_location=device)


Saved to render_example.png
PSNR (vs test #0): 28.98 dB


In [None]:
import torch
import imageio.v2 as imageio
import numpy as np
from torch.nn import functional as F

basedir = "../data/nerf_synthetic/lego"

imgs_np, poses_np, hwf, i_split = load_blender_data_raw(basedir)
H, W, focal = hwf

imgs_torch = torch.from_numpy(imgs_np).float()

# RGBA → RGB (white background)
if imgs_torch.shape[-1] == 4:
    rgb = imgs_torch[..., :3]
    a   = imgs_torch[..., 3:4]
    imgs_torch = rgb * a + (1.0 - a)

# downsample (match training)
down = 2
if down > 1:
    imgs_chw = imgs_torch.permute(0, 3, 1, 2)
    imgs_chw = F.interpolate(imgs_chw, size=(H // down, W // down), mode="area")
    imgs_torch = imgs_chw.permute(0, 2, 3, 1).contiguous()
    H //= down
    W //= down
    focal /= down

# pick first test image
i_test = i_split[2]
idx = i_test[0]
gt = imgs_torch[idx].cpu()

gt_np = (gt.numpy() * 255).astype(np.uint8)
imageio.imwrite("gt_example.png", gt_np)

print("Saved: gt_example.png")

[OK] Saved ground truth to gt_example.png


In [None]:
import math, numpy as np, torch, imageio.v2 as imageio
from torch.nn import functional as F

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def rodrigues(axis, theta):
    axis = axis / (np.linalg.norm(axis) + 1e-8)
    x, y, z = axis
    c, s = math.cos(theta), math.sin(theta)
    C = 1 - c
    return np.array([
        [c + x*x*C,     x*y*C - z*s, x*z*C + y*s],
        [y*x*C + z*s,   c + y*y*C,   y*z*C - x*s],
        [z*x*C - y*s,   z*y*C + x*s, c + z*z*C]
    ], dtype=np.float32)

def orbit_pitch_keep_pose(c2w_ref, center, pitch_deg, world_up=np.array([0,1,0], np.float32)):
    eye = c2w_ref[:3, 3].astype(np.float32)
    R_ref = c2w_ref[:3, :3].astype(np.float32)
    center = center.astype(np.float32)

    r = eye - center
    right = np.cross(world_up, r)
    if np.linalg.norm(right) < 1e-6:
        right = np.array([1, 0, 0], np.float32)
    right = right / (np.linalg.norm(right) + 1e-8)

    R_axis = rodrigues(right, math.radians(pitch_deg))
    t_new = center + (R_axis @ r)
    R_new = R_axis @ R_ref

    out = c2w_ref.copy().astype(np.float32)
    out[:3, :3] = R_new
    out[:3, 3] = t_new
    return out

@torch.no_grad()
def render_image_whitebg(model, H, W, focal, c2w_np,
                         near, far, n_samples=96, chunk=8192):
    dev = next(model.parameters()).device
    c2w = torch.from_numpy(c2w_np).to(dev).float()
    rays_o, rays_d = get_rays(H, W, focal, c2w)
    rays_o = rays_o.reshape(-1, 3)
    rays_d = rays_d.reshape(-1, 3)

    out_chunks = []
    for i in range(0, H * W, chunk):
        ro = rays_o[i:i+chunk]
        rd = rays_d[i:i+chunk]
        z, pts = sample_along_rays(ro, rd, near, far, S=n_samples, stratified=False)
        dirs = rd[:, None, :].expand(-1, pts.shape[1], -1)
        rgb, sigma = model(pts, dirs)
        comp, _, _ = volume_render(rgb, sigma, z, rd, white_bkgd=True)
        out_chunks.append(comp)

    img = torch.cat(out_chunks, 0).view(H, W, 3).clamp(0, 1).cpu().numpy()
    return img

CKPT_PATH = "nerf_lego_pe.pt"
BASEDIR = "../data/nerf_synthetic/lego"
REF_IDX = 0
PITCH_DEG = -15.0
NEAR, FAR = 2.0, 6.0
SAMPLES = 96
CHUNK = 8192

state = torch.load(CKPT_PATH, map_location=device)
model = TinyNeRF_PE(Lx=10, Ld=4, hidden=256).to(device)
model.load_state_dict(state["model"])
model.eval()

if hasattr(model, "pe_xyz"):
    model.pe_xyz.freq_weights.fill_(1.0)
if hasattr(model, "pe_dir"):
    model.pe_dir.freq_weights.fill_(1.0)

imgs_np, poses_np, hwf, i_split = load_blender_data_raw(BASEDIR)
H, W, focal = hwf

gt_ref = imgs_np[REF_IDX][..., :3]
c2w_ref = poses_np[REF_IDX].astype(np.float32)

center = poses_np[..., :3, 3].mean(axis=0).astype(np.float32)
c2w_top = orbit_pitch_keep_pose(c2w_ref, center, pitch_deg=PITCH_DEG)

novel_white = render_image_whitebg(
    model, H, W, focal, c2w_top,
    near=NEAR, far=FAR,
    n_samples=SAMPLES, chunk=CHUNK
)

pad = 40
panel = np.ones((H, W * 2 + pad, 3), dtype=np.float32)
panel[:, :W] = gt_ref
panel[:, W + pad:] = novel_white

imageio.imwrite("novel_pitch_up_white.png",
                (np.clip(novel_white, 0, 1) * 255).astype(np.uint8))
imageio.imwrite("novel_pitch_up_panel_white.png",
                (np.clip(panel, 0, 1) * 255).astype(np.uint8))

print("Saved: novel_pitch_up_white.png, novel_pitch_up_panel_white.png")


  state = torch.load(CKPT_PATH, map_location=device)


[OK] saved → novel_pitch_up_white.png / novel_pitch_up_panel_white.png


In [None]:
@torch.no_grad()
def nerf_extract_mesh(
    ckpt_path="nerf_lego_pe.pt",
    basedir="../data/nerf_synthetic/lego",
    grid_res=256,
    sigma_thresh=None,         
    aabb_mode="cams",           
    aabb_fixed=(-1.5, 1.5),
    down=2,
    chunk=65536,
    out_ply="mesh_example.ply",
    do_preview=True
):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    state = torch.load(ckpt_path, map_location=device)
    model = TinyNeRF_PE(Lx=10, Ld=4, hidden=256).to(device)
    model.load_state_dict(state["model"])
    model.eval()
    if hasattr(model, "pe_xyz"): model.pe_xyz.freq_weights.fill_(1.0)
    if hasattr(model, "pe_dir"): model.pe_dir.freq_weights.fill_(1.0)

    if aabb_mode == "cams":
        imgs_np, poses_np, hwf, _ = load_blender_data_raw(basedir)
        H, W, focal = hwf
        if down > 1:
            H//=down; W//=down; focal/=down
        cams = torch.from_numpy(poses_np).float()[..., :3, 3].to(device)  # [N,3]
        aabb_min = cams.min(0).values - 1.0
        aabb_max = cams.max(0).values + 1.0
    else:
        lo, hi = aabb_fixed
        aabb_min = torch.tensor([lo, lo, lo], device=device)
        aabb_max = torch.tensor([hi, hi, hi], device=device)

    xs = torch.linspace(aabb_min[0], aabb_max[0], grid_res, device=device)
    ys = torch.linspace(aabb_min[1], aabb_max[1], grid_res, device=device)
    zs = torch.linspace(aabb_min[2], aabb_max[2], grid_res, device=device)
    X, Y, Z = torch.meshgrid(xs, ys, zs, indexing="ij")
    grid = torch.stack([X, Y, Z], dim=-1).reshape(-1, 3)  

    def make_dummy_dirs(n):
        d = torch.zeros(n, 1, 3, device=device)
        d[..., 2] = -1.0
        return d

    sigmas = []
    for i in range(0, grid.shape[0], chunk):
        pts  = grid[i:i+chunk].unsqueeze(1)      # [B,1,3]
        dirs = make_dummy_dirs(pts.shape[0])     # [B,1,3]
        _, sigma = model(pts, dirs)              
        sigmas.append(sigma.squeeze(1).detach().cpu())
    sigma_grid = torch.cat(sigmas, 0).view(grid_res, grid_res, grid_res).numpy()

    chosen = sigma_thresh
    if sigma_thresh is None:
        scans = [1, 2, 4, 6, 8, 12, 16, 24, 32, 48, 64, 96, 128, 256]
        stats = [(th, (sigma_grid > th).mean()) for th in scans]
        if do_preview:
            print("[INFO] sigma occupancy scan:")
            for th, occ in stats:
                print(f"  th={th:<3}  occupied={occ*100:5.2f}%")
        target = 0.01
        chosen = min(stats, key=lambda t: abs(t[1] - target))[0]
        print(f"[INFO] auto-picked sigma_thresh={chosen} (target~1%)")

    from skimage.measure import marching_cubes
    verts, faces, norms, _ = marching_cubes(sigma_grid, level=chosen)

    scale = (aabb_max - aabb_min).detach().cpu().numpy() / (grid_res - 1)
    verts_world = verts * scale + aabb_min.detach().cpu().numpy()

    import trimesh
    mesh = trimesh.Trimesh(
        verts_world,
        faces,
        vertex_normals=norms,
        process=True   
    )
    mesh.export(out_ply)
    print(f"[OK] Mesh exported → {out_ply}")

    return mesh, chosen


In [None]:
mesh, th = nerf_extract_mesh(
    ckpt_path="nerf_lego_pe.pt",
    basedir="../data/nerf_synthetic/lego",
    aabb_mode="fixed", aabb_fixed=(-1.5, 1.5),   
    grid_res=128,
    sigma_thresh=None,                       
    down=2, chunk=65536, out_ply="mesh_example.ply",
    do_preview=True
)


  state = torch.load(ckpt_path, map_location=device)


[INFO] sigma occupancy scan:
  th=1    occupied= 2.77%
  th=2    occupied= 2.63%
  th=4    occupied= 2.47%
  th=6    occupied= 2.34%
  th=8    occupied= 2.24%
  th=12   occupied= 2.09%
  th=16   occupied= 1.97%
  th=24   occupied= 1.80%
  th=32   occupied= 1.67%
  th=48   occupied= 1.47%
  th=64   occupied= 1.28%
  th=96   occupied= 0.97%
  th=128  occupied= 0.72%
  th=256  occupied= 0.07%
[INFO] auto-picked sigma_thresh=96 (target~1%)
[OK] Mesh exported → mesh_example.ply


In [None]:
import open3d as o3d

def save_mesh_png(ply_path="mesh_example.ply", out_png="mesh_example.png",
                  w=1000, h=750, bg=[1,1,1], azim=35, elev=-15, zoom=0.85):
    mesh = o3d.io.read_triangle_mesh(ply_path)
    mesh.compute_vertex_normals()

    vis = o3d.visualization.Visualizer()
    vis.create_window(visible=False, width=w, height=h)
    vis.add_geometry(mesh)

    opt = vis.get_render_option()
    opt.background_color = bg
    opt.mesh_show_back_face = True
    opt.light_on = True

    ctr = vis.get_view_control()
    ctr.change_field_of_view(step=5.0)     
    ctr.rotate(azim, elev)                 
    ctr.set_zoom(zoom)

    vis.poll_events(); vis.update_renderer()
    vis.capture_screen_image(out_png)
    vis.destroy_window()
    print(f"[OK] saved → {out_png}")

save_mesh_png("mesh_example.ply", "mesh_example.png")

[OK] saved → mesh_example.png
