In [1]:
import os, math
import numpy as np
import torch

# Load the Lego dataset. For an .npz archive np.load returns a lazy NpzFile that
# keeps the underlying file open, so every array is read inside a context manager
# and the archive is closed as soon as the arrays are in memory.
with np.load("lego_200x200.npz") as data:
    # Images are divided by 255.0 to map them into [0, 1] (this yields float arrays).
    images_train = data["images_train"] / 255.0
    c2ws_train = data["c2ws_train"]
    images_val = data["images_val"] / 255.0
    c2ws_val = data["c2ws_val"]
    c2ws_test = data["c2ws_test"]
    focal = data["focal"]

# Image geometry is taken from the training stack: (N, H, W, channels).
N_train, H, W, _ = images_train.shape

# Pinhole intrinsics: one focal length for both axes, principal point at the image centre.
K = np.array([
    [focal,    0.0,  W/2.0],
    [0.0,      focal, H/2.0],
    [0.0,      0.0,  1.0]
], dtype=np.float32)

print("Train set:", images_train.shape, "K[0,0] (fx) =", K[0,0])

Train set: (100, 200, 200, 3) K[0,0] (fx) = 277.77777


In [3]:
def _to_hom(x):
    x = np.asarray(x, dtype=np.float32)
    if x.ndim == 1:
        x = x[None, :]  
    ones = np.ones((*x.shape[:-1], 1), dtype=x.dtype)
    return np.concatenate([x, ones], axis=-1)


def transform(T_4x4, x):
    """Apply a 4x4 homogeneous transform to 3-D points.

    Args:
        T_4x4: (4, 4) array; applied as `x_h @ T.T`, i.e. points are row vectors.
        x: (..., 3) points. A single (3,) point is treated as a batch of one.

    Returns:
        (..., 3) array of transformed points after dividing by the homogeneous
        coordinate (a (3,) input yields shape (1, 3)).
    """
    xh = _to_hom(x)
    y = xh @ T_4x4.T
    w = y[..., 3:4]
    # Keep |w| away from zero without discarding its sign: a plain
    # clip(w, 1e-8, None) would map every negative w to +1e-8 and blow the
    # result up. For affine matrices w == 1, so results are unchanged.
    w_safe = np.copysign(np.clip(np.abs(w), 1e-8, None), w)
    return y[..., :3] / w_safe

def pixel_to_camera(K, uv, s):
    """Back-project pixel coordinates to camera-space points at depth `s`.

    Args:
        K: (3, 3) pinhole intrinsics; fx, fy, cx, cy are read from
           K[0,0], K[1,1], K[0,2], K[1,2].
        uv: (..., 2) pixel coordinates (u along x, v along y).
        s: depth along the camera z axis; a scalar or any array that broadcasts
           against the leading shape of `uv`.

    Returns:
        float32 array of shape (*broadcast(uv.shape[:-1], s.shape), 3) holding
        (X, Y, Z) with Z == s.
    """
    uv = np.asarray(uv, dtype=np.float32)
    s = np.asarray(s, dtype=np.float32)

    fx, fy = K[0, 0], K[1, 1]
    cx, cy = K[0, 2], K[1, 2]

    u = uv[..., 0]
    v = uv[..., 1]

    # Ordinary NumPy broadcasting handles scalar and array depths alike, so no
    # scalar/array special case is needed.
    X = (u - cx) / fx * s
    Y = (v - cy) / fy * s
    # Z must have the same (broadcast) shape as X and Y before stacking.
    Z = np.broadcast_to(s, X.shape)

    return np.stack([X, Y, Z], axis=-1).astype(np.float32)


def pixel_to_ray(K, c2w, uv):
    """Turn pixel coordinates into world-space rays.

    Each pixel is back-projected to depth 1 in the camera frame, moved to world
    space with `c2w`, and the unit vector from the camera centre to that point
    becomes the ray direction.

    Args:
        K: (3, 3) intrinsics.
        c2w: (4, 4) camera-to-world matrix; its last column holds the camera centre.
        uv: (..., 2) pixel coordinates.

    Returns:
        (origins, directions), both float32 with the same shape; every origin is
        the camera centre and every direction has unit length.
    """
    cam_center = c2w[:3, 3].astype(np.float32)

    # Depth-1 camera-space points carried into world space.
    world_pts = transform(c2w, pixel_to_camera(K, uv, s=1.0))

    offsets = world_pts - cam_center
    lengths = np.linalg.norm(offsets, axis=-1, keepdims=True)
    # Clamp the length so a degenerate zero offset cannot divide by zero.
    directions = (offsets / np.clip(lengths, 1e-8, None)).astype(np.float32)

    origins = np.broadcast_to(cam_center, directions.shape).astype(np.float32)
    return origins, directions


In [4]:


# Sanity check: camera -> world -> camera should reproduce the input points.
# Pick one training pose at random and invert it to get the world-to-camera matrix.
idx = np.random.randint(0, N_train)
c2w = c2ws_train[idx]
w2c = np.linalg.inv(c2w)

# Five random camera-space points, pushed through c2w and back through w2c.
pts_c = np.random.randn(5, 3).astype(np.float32)
pts_w = transform(c2w, pts_c)
pts_c_back = transform(w2c, pts_w)
# The largest absolute round-trip error should be at float32 rounding level.
print("Max |x - inv(inv(x))|:", np.max(np.abs(pts_c - pts_c_back)))


Max |x - inv(inv(x))|: 2.0438067682704286e-07


In [5]:
class RaysData:
    """Flat, per-pixel ray dataset built from a stack of posed images.

    All per-pixel arrays are laid out image-major, then row-major inside each
    image, so flat index `i * H * W + y * W + x` refers to pixel (x, y) of image i.

    Attributes:
        images: (N, H, W, 3) float32 copy of the input images.
        K: (3, 3) float32 intrinsics.
        c2ws: (N, 4, 4) float32 camera-to-world matrices.
        uvs: (N*H*W, 2) int32 integer pixel coordinates (x, y).
        pixels: (N*H*W, 3) colours, a reshape of the `images` argument (keeps its dtype).
        img_ids: (N*H*W,) index of the image each pixel belongs to.
        rays_o, rays_d: (N*H*W, 3) float32 ray origins / unit directions through
            the pixel centres (integer coordinate + 0.5).
        total: number of rays, N*H*W.
    """

    def __init__(self, images, K, c2ws):
        """
        images: (N, H, W, 3)
        K: (3, 3)
        c2ws: (N, 4, 4)
        """
        self.images = images.astype(np.float32)
        self.K = K.astype(np.float32)
        self.c2ws = c2ws.astype(np.float32)

        self.N, self.H, self.W, _ = images.shape
        n = self.H * self.W  # rays per image

        xs = np.arange(self.W, dtype=np.int32)
        ys = np.arange(self.H, dtype=np.int32)
        grid_x, grid_y = np.meshgrid(xs, ys, indexing="xy")

        # Integer (x, y) grid of one image, flattened row-major, then repeated per image.
        grid_uvs = np.stack([grid_x, grid_y], axis=-1).reshape(-1, 2)
        self.uvs = np.tile(grid_uvs, (self.N, 1))

        # Deliberately reshaped from the *argument* (not self.images) so the
        # colours keep the caller's dtype and compare exactly with the input.
        self.pixels = images.reshape(-1, 3)

        self.img_ids = np.repeat(np.arange(self.N), n)

        self.total = len(self.uvs)

        self.rays_o = np.empty((self.total, 3), dtype=np.float32)
        self.rays_d = np.empty((self.total, 3), dtype=np.float32)

        # Pixel centres are identical for every image, so build them once
        # instead of materialising an (N, H*W, 2) float copy.
        pixel_centers = grid_uvs.astype(np.float32) + 0.5

        for i in range(self.N):
            o, d = pixel_to_ray(self.K, self.c2ws[i], pixel_centers)
            self.rays_o[i * n:(i + 1) * n] = o
            self.rays_d[i * n:(i + 1) * n] = d

    def sample_rays(self, B):
        """Draw `B` rays uniformly at random (with replacement) over all images.

        The rays are looked up in the cache filled by `__init__` rather than
        recomputed one at a time; the cache holds the same pixel-centre rays.

        Args:
            B: number of rays to draw.

        Returns:
            rays_o: (B, 3) float32 ray origins.
            rays_d: (B, 3) float32 unit ray directions.
            pixels: (B, 3) target colours of the sampled pixels.
            uv_int: (B, 2) integer pixel coordinates of the sampled pixels.
        """
        idxs = np.random.randint(0, self.total, size=(B,))
        # Fancy indexing returns fresh copies, so callers may modify the results freely.
        return self.rays_o[idxs], self.rays_d[idxs], self.pixels[idxs], self.uvs[idxs]

def sample_along_rays(
    rays_o,
    rays_d,
    n_samples=64,
    near=2.0,
    far=6.0,
    perturb=True,
    rng=None,
    random=None,
):
    """
    Place `n_samples` points on every ray between depths `near` and `far`.

    Two calling conventions are supported and selected by `random`:
    - `random is None`: `perturb` decides whether sampling is stratified, and the
      function returns the tuple (points, t_vals).
    - `random` given (True/False): its truth value replaces `perturb`, and only
      `points` is returned.

    Stratified sampling splits [near, far] into `n_samples` equal bins and draws
    one uniform depth per bin and ray. Otherwise every ray uses the same
    `n_samples` evenly spaced depths from `near` to `far` inclusive.

    Args:
      rays_o: (B,3) ray origins
      rays_d: (B,3) ray directions (expected to be normalized)
      n_samples, near, far: sampling configuration
      perturb: stratified sampling switch used when `random` is None
      rng: np.random.Generator for the stratified draw; a fresh default
           generator is created when omitted
      random: optional override of `perturb` that also switches the return type

    Returns:
      - `random is None`: (points (B,n,3), t_vals (B,n)), both float32
      - otherwise: points (B,n,3), float32
    """
    num_rays = rays_o.shape[0]
    points_only = random is not None
    stratified = bool(random) if points_only else perturb

    if not stratified:
        depths = np.linspace(near, far, n_samples, dtype=np.float32)
        t_vals = np.tile(depths, (num_rays, 1))
    else:
        edges = np.linspace(near, far, n_samples + 1, dtype=np.float32)
        generator = np.random.default_rng() if rng is None else rng
        bin_lo, bin_hi = edges[:-1], edges[1:]
        # One uniform draw in [0, 1) per ray and bin, scaled into that bin.
        jitter = generator.random((num_rays, n_samples), dtype=np.float32)
        t_vals = bin_lo[None, :] + jitter * (bin_hi - bin_lo)[None, :]

    origins = rays_o[:, None, :].astype(np.float32)
    directions = rays_d[:, None, :].astype(np.float32)
    steps = t_vals[..., None].astype(np.float32)
    points = (origins + steps * directions).astype(np.float32)

    if points_only:
        return points
    return points, t_vals.astype(np.float32)


In [6]:
# Build the ray dataset over the training images and show its integer pixel grid.
dataset = RaysData(images_train, K, c2ws_train)
print(dataset.uvs)
# Draw a small random batch of rays; uv_int holds the sampled integer pixel coordinates.
B = 100
rays_o, rays_d, pixels, uv_int = dataset.sample_rays(B)  



# `random` is left at None, so this call returns both the sample points and their depths.
points, t_vals = sample_along_rays(rays_o, rays_d, n_samples=64, near=2.0, far=6.0, perturb=True)

print("rays_o:", rays_o.shape, "rays_d:", rays_d.shape, "points:", points.shape)


[[  0   0]
 [  1   0]
 [  2   0]
 ...
 [197 199]
 [198 199]
 [199 199]]
rays_o: (100, 3) rays_d: (100, 3) points: (100, 64, 3)


In [7]:
import viser
import time


# Visualise the training cameras, the rays sampled in the previous cell and the
# sample points along them in a viser scene.
server = viser.ViserServer(share=True)
print("0000000000000")

# One frustum per training camera, textured with its image.
for i, (image, c2w) in enumerate(zip(images_train, c2ws_train)):
    print(i)
    server.add_camera_frustum(
        f"/cameras/{i}",
        # Field of view derived from the image height and the focal length K[0, 0].
        fov=2 * np.arctan2(H / 2, K[0, 0]),
        aspect=W / H,
        scale=0.15,
        wxyz=viser.transforms.SO3.from_matrix(c2w[:3, :3]).wxyz,
        position=c2w[:3, 3],
        image=image,
    )
print("1")

# Each ray is drawn as a two-point spline from its origin to 6.0 units along its direction.
for i, (o, d) in enumerate(zip(rays_o, rays_d)):
    print(i)
    server.add_spline_catmull_rom(
        f"/rays/{i}", positions=np.stack((o, o + d * 6.0), axis=0),
    )


# All sample points as one point cloud; the zero colour array makes every point black.
server.add_point_cloud(
    "/samples",
    colors=np.zeros_like(points, dtype=np.float32).reshape(-1, 3),
    points=points.reshape(-1, 3),
    point_size=0.02,
)

print("Viser server running. Interrupt the cell to stop.")
# Block forever so the server stays up; the cell only ends when it is interrupted.
while True:
    time.sleep(0.1)


0000000000000
0


  server.add_camera_frustum(


1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
1
0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
Viser server running. Interrupt the cell to stop.


  server.add_spline_catmull_rom(
  server.add_point_cloud(


KeyboardInterrupt: 

In [8]:
import viser, time
import numpy as np


# Rebuild the ray dataset, check its pixel layout, then visualise rays drawn
# from one region of the first training image.
dataset = RaysData(images_train, K, c2ws_train)


# The first 40_000 flat entries are expected to be exactly the first image
# (40_000 = 200 * 200 for the 200x200 images loaded above).
uvs_start = 0
uvs_end = 40_000
sample_uvs = dataset.uvs[uvs_start:uvs_end] 

# uvs are stored as (x, y) while images are indexed [row=y, col=x]; this assert
# verifies that dataset.pixels follows the same ordering as dataset.uvs.
assert np.all(images_train[0, sample_uvs[:,1], sample_uvs[:,0]] == dataset.pixels[uvs_start:uvs_end])





# 100 random pixels with x in [100, 200) and y in [0, 100), converted to flat
# indices using a row stride of 200 (presumably the image width W — confirm).
indices_x = np.random.randint(low=100, high=200, size=100)
indices_y = np.random.randint(low=0, high=100, size=100)
indices = indices_x + (indices_y * 200)

# NOTE: this rebinds the global name `data`, which the first cell used for the loaded .npz archive.
data = {"rays_o": dataset.rays_o[indices], "rays_d": dataset.rays_d[indices]}
# Passing `random=` makes sample_along_rays return only the points array.
points = sample_along_rays(data["rays_o"], data["rays_d"], random=True)


server = viser.ViserServer(share=True)
# One frustum per training camera, textured with its image.
for i, (image, c2w) in enumerate(zip(images_train, c2ws_train)):
  print(i)
  server.add_camera_frustum(
    f"/cameras/{i}",
    fov=2 * np.arctan2(H / 2, K[0, 0]),
    aspect=W / H,
    scale=0.15,
    wxyz=viser.transforms.SO3.from_matrix(c2w[:3, :3]).wxyz,
    position=c2w[:3, 3],
    image=image
  )
# Each selected ray is drawn from its origin to 6.0 units along its direction.
for i, (o, d) in enumerate(zip(data["rays_o"], data["rays_d"])):
  positions = np.stack((o, o + d * 6.0))
  server.add_spline_catmull_rom(
      f"/rays/{i}", positions=positions,
  )
# All sample points as one point cloud; the zero colour array makes every point black.
server.add_point_cloud(
    f"/samples",
    colors=np.zeros_like(points).reshape(-1, 3),
    points=points.reshape(-1, 3),
    point_size=0.03,
)

# Block forever so the server stays up; the cell only ends when it is interrupted.
while True:
    time.sleep(0.1)  

0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99


  server.add_camera_frustum(
  server.add_spline_catmull_rom(
  server.add_point_cloud(


KeyboardInterrupt: 

In [15]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F


class PosEnc(nn.Module):
    """
    Sinusoidal positional encoding of D-dimensional inputs.

    The output is the input itself followed by sin(2^k * pi * x) for k = 0..L-1
    and then cos(2^k * pi * x) for k = 0..L-1. Within each sin/cos group the
    features are ordered frequency-major, then by input dimension.
    """
    def __init__(self, L: int, D: int):
        super().__init__()
        self.L = L
        self.D = D

        # Angular frequencies pi * 2^k, kept as a buffer so they follow the module's device.
        freqs = math.pi * (2.0 ** torch.arange(L))
        self.register_buffer("freqs", freqs.float())

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        x: (..., D)
        returns: (..., D + 2*D*L); with L == 0 the input is returned unchanged
        """
        assert x.shape[-1] == self.D, f"Expected last dim {self.D}, got {x.shape[-1]}"
        if self.L == 0:
            return x

        lead = x.shape[:-1]
        # (..., 1, D) * (L, 1) -> (..., L, D): every input dimension at every frequency.
        angles = x[..., None, :] * self.freqs[:, None]
        encoded = [fn(angles).reshape(*lead, -1) for fn in (torch.sin, torch.cos)]
        return torch.cat([x, *encoded], dim=-1)

def pe_dim(D: int, L: int) -> int:
    """Feature count produced by the positional encoding of a D-dim input with L frequencies."""
    # D raw coordinates plus a sin and a cos feature per coordinate and frequency.
    sin_cos_features = 2 * D * L
    return D + sin_cos_features



class NeRFMLP(nn.Module):
    """
    NeRF-style MLP mapping a 3-D position and a view direction to colour and density.

    Layout:
      - trunk1: the first depth // 2 Linear+ReLU layers on the encoded position.
      - trunk2: the remaining layers, fed with trunk1's output concatenated with
        the encoded position (skip connection).
      - sigma_head: Linear(width -> 1) + ReLU.
      - color_head: Linear(width -> width) feature layer, concatenated with the
        encoded direction, then Linear -> ReLU -> Linear(128 -> 3) -> Sigmoid.
    """
    def __init__(self,
                 Lx: int = 10,
                 Ld: int = 4,
                 width: int = 256,
                 depth: int = 8
                 ):
        super().__init__()
        self.Lx, self.Ld = Lx, Ld
        self.width, self.depth = width, depth

        self.pe_x = PosEnc(L=Lx, D=3)
        self.pe_d = PosEnc(L=Ld, D=3)
        pos_features = pe_dim(3, Lx)
        dir_features = pe_dim(3, Ld)

        def linear_relu_stack(in_features, n_layers):
            # n_layers of Linear+ReLU; the first maps in_features -> width, the rest width -> width.
            sizes = [in_features] + [width] * n_layers
            modules = []
            for fan_in, fan_out in zip(sizes[:-1], sizes[1:]):
                modules.append(nn.Linear(fan_in, fan_out))
                modules.append(nn.ReLU(inplace=True))
            return nn.Sequential(*modules)

        first_half = depth // 2
        self.trunk1 = linear_relu_stack(pos_features, first_half)
        # The second half also sees the encoded position again (skip connection).
        self.trunk2 = linear_relu_stack(width + pos_features, depth - first_half)

        self.sigma_head = nn.Sequential(
            nn.Linear(width, 1),
            nn.ReLU(inplace=True),
        )

        self.feature_256 = nn.Linear(width, width)

        self.color_head = nn.Sequential(
            nn.Linear(width + dir_features, 128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 3),
            nn.Sigmoid(),
        )

    def forward(self, x_w: torch.Tensor, d_w: torch.Tensor):
        """
        x_w: (..., 3) world-space positions
        d_w: (..., 3) view directions; normalized here before encoding
        Returns:
           rgb:   (..., 3), sigmoid output
           sigma: (..., 1), ReLU output
        """
        view_dir = F.normalize(d_w, dim=-1)

        pos_code = self.pe_x(x_w)
        dir_code = self.pe_d(view_dir)

        hidden = self.trunk1(pos_code)
        hidden = self.trunk2(torch.cat([hidden, pos_code], dim=-1))

        sigma = self.sigma_head(hidden)

        colour_in = torch.cat([self.feature_256(hidden), dir_code], dim=-1)
        rgb = self.color_head(colour_in)

        return rgb, sigma


In [16]:
# Smoke test: run the network on random inputs and check the output shapes.
nerf = NeRFMLP(Lx=10, Ld=4, width=256, depth=8)
# B rays with S samples each.
B, S = 32, 64
x = torch.randn(B, S, 3)         
d = torch.randn(B, S, 3)         
rgb, sigma = nerf(x, d)
# Expected: rgb (B, S, 3) and sigma (B, S, 1).
print(rgb.shape, sigma.shape)    


torch.Size([32, 64, 3]) torch.Size([32, 64, 1])


In [24]:
import torch

def volrend(sigmas: torch.Tensor, rgbs: torch.Tensor, step_size: float):
    """
    Composite per-sample colours along each ray with a fixed step length.

    For sample i: alpha_i = 1 - exp(-sigma_i * step_size), the transmittance is
    T_i = exp(-sum_{j<i} sigma_j * step_size), and the rendered colour is
    sum_i T_i * alpha_i * c_i.

    Args:
      sigmas: (B, S, 1) densities
      rgbs:   (B, S, 3) colours
      step_size: distance between consecutive samples

    Returns:
      (B, 3) rendered colour per ray
    """
    optical_depth = sigmas * step_size                      # (B, S, 1)
    alpha = 1.0 - torch.exp(-optical_depth)

    # Exclusive running sum over samples: shift the inclusive cumsum right by
    # one slot and start from zero, so sample i only sees samples before it.
    running = torch.cumsum(optical_depth.squeeze(-1), dim=1)  # (B, S)
    leading_zero = torch.zeros((running.size(0), 1), dtype=running.dtype, device=running.device)
    preceding = torch.cat([leading_zero, running[:, :-1]], dim=1)
    transmittance = torch.exp(-preceding).unsqueeze(-1)     # (B, S, 1)

    contribution = transmittance * alpha
    return (contribution * rgbs).sum(dim=1)


In [26]:
# Reference check for volrend: with a fixed seed the rendered colours must match
# the hard-coded expected values below.
torch.manual_seed(42)
sigmas = torch.rand((10, 64, 1))
rgbs = torch.rand((10, 64, 3))
# 64 steps covering the depth range [2.0, 6.0].
step_size = (6.0 - 2.0) / 64
rendered_colors = volrend(sigmas, rgbs, step_size)

# Expected colour for each of the 10 rays.
correct = torch.tensor([
    [0.5006, 0.3728, 0.4728],
    [0.4322, 0.3559, 0.4134],
    [0.4027, 0.4394, 0.4610],
    [0.4514, 0.3829, 0.4196],
    [0.4002, 0.4599, 0.4103],
    [0.4471, 0.4044, 0.4069],
    [0.4285, 0.4072, 0.3777],
    [0.4152, 0.4190, 0.4361],
    [0.4051, 0.3651, 0.3969],
    [0.3253, 0.3587, 0.4215]
])
assert torch.allclose(rendered_colors, correct, rtol=1e-4, atol=1e-4)


In [35]:
def render_rays(model, rays_o, rays_d, n_samples=64, near=2.0, far=6.0, device="cuda", white_bg=True):
    """
    Render one colour per ray by querying `model` at fixed depths and compositing.

    Every ray is sampled at the same `n_samples` evenly spaced depths from `near`
    to `far` inclusive (no random perturbation), and the samples are composited
    with `volrend` using a step of (far - near) / n_samples.

    Args:
      model: callable taking (points (B,S,3), dirs (B,S,3)) and returning (rgb, sigma)
      rays_o, rays_d: (B,3) numpy arrays or torch tensors; directions are
        normalized here
      n_samples, near, far: sampling configuration
      device: device the rays and samples are moved to before the model call
      white_bg: accepted for interface compatibility but currently not used

    Returns:
      C: (B,3) torch float32 on device
    """
    if not torch.is_tensor(rays_o):
        rays_o = torch.from_numpy(rays_o)
    if not torch.is_tensor(rays_d):
        rays_d = torch.from_numpy(rays_d)
    rays_o = rays_o.to(device).float()
    rays_d = rays_d.to(device).float()

    t_vals = torch.linspace(near, far, n_samples, device=device).view(1, n_samples)
    step_size = (far - near) / n_samples

    # Normalize once and reuse the result for both the sample positions and the
    # view directions handed to the model.
    unit_d = F.normalize(rays_d, dim=-1)[:, None, :]        # (B, 1, 3)

    pts = rays_o[:, None, :] + t_vals[..., None] * unit_d   # (B, S, 3)
    dirs = unit_d.expand_as(pts)                            # (B, S, 3)

    rgb, sigma = model(pts, dirs)

    C = volrend(sigma, rgb, step_size=step_size)
    return C


In [36]:
def render_image(model, K, c2w, H, W, n_samples=64, near=2.0, far=6.0, chunk=8192, device="cuda"):
    """
    Render a full H x W image from one camera pose.

    Rays through every pixel centre are rendered in chunks of `chunk` rays to
    bound memory use. Rendering runs under `torch.no_grad()` inside this
    function: the result is converted to a NumPy array, which fails for tensors
    that still require grad, so callers do not need their own no-grad context.

    Args:
      model: network passed through to `render_rays`
      K: (3, 3) intrinsics
      c2w: (4, 4) camera-to-world matrix
      H, W: output image size in pixels
      n_samples, near, far: sampling configuration forwarded to `render_rays`
      chunk: number of rays rendered per `render_rays` call
      device: device used for rendering

    Returns:
      (H, W, 3) uint8 image; colours are clamped to [0, 1], scaled by 255 and
      truncated to integers.
    """
    # Pixel-centre coordinates in row-major order, matching the final view(H, W, 3).
    xs = np.arange(W, dtype=np.float32)
    ys = np.arange(H, dtype=np.float32)
    grid_x, grid_y = np.meshgrid(xs, ys, indexing="xy")
    uvs = np.stack([grid_x + 0.5, grid_y + 0.5], axis=-1).reshape(-1, 2)

    o_all, d_all = pixel_to_ray(K, c2w, uvs)

    colors = []
    with torch.no_grad():
        for i in range(0, H*W, chunk):
            co = render_rays(model,
                             o_all[i:i+chunk],
                             d_all[i:i+chunk],
                             n_samples=n_samples, near=near, far=far, device=device)
            colors.append(co)
    C = torch.cat(colors, dim=0).clamp(0, 1)
    img = (C.view(H, W, 3).cpu().numpy() * 255.0).astype(np.uint8)
    return img
def mse2psnr(mse: float) -> float:
    """Convert a mean squared error to PSNR in dB, assuming a peak signal value of 1."""
    # Floor the error at 1e-10 so a perfect match reports 100 dB instead of infinity.
    floored = max(mse, 1e-10)
    return -10.0 * np.log10(floored)


In [40]:
import time, os
from PIL import Image
import matplotlib.pyplot as plt

def train_nerf(
    model,
    dataset: RaysData,              
    images_val, c2ws_val,           
    K,
    iters=1000,
    rays_per_batch=10_000,
    n_samples=64,
    near=2.0,
    far=6.0,
    lr=5e-4,
    snaps=(0,100,250,500,750,1000),
    outdir="nerf_out",
    device="cuda",
):
    """
    Train `model` on random ray batches with Adam and an MSE colour loss.

    Runs `iters + 1` optimisation steps (iterations 0..iters inclusive). While
    training it also:
      - every 10 iterations: prints the training MSE/PSNR, renders up to the
        first 3 validation views, records their mean PSNR and rewrites
        `val_psnr_curve.png` in `outdir`;
      - every 100 iterations (except 0): saves a checkpoint as both
        `ckpt_{it:06d}.pth` and `ckpt_latest.pth`;
      - at every iteration listed in `snaps`: saves a render of validation
        view 0 as `val0_it{it}.png`.

    Args:
      model: network to train; moved to `device`.
      dataset: RaysData providing `sample_rays` and the image size (H, W).
      images_val: validation images, indexed per view and scaled by 255 to
        build the uint8 ground truth — presumably (N, H, W, 3) in [0, 1]; confirm.
      c2ws_val: per-view camera-to-world matrices for the validation images.
      K: (3, 3) intrinsics used for validation renders.
      iters: index of the last training iteration.
      rays_per_batch: rays sampled per optimisation step.
      n_samples, near, far: ray sampling configuration.
      lr: Adam learning rate.
      snaps: iterations at which a validation render is written to disk.
      outdir: output directory, created if missing.
      device: device used for training and rendering.

    Returns:
      List of (iteration, mean validation PSNR) tuples, one per evaluation.
    """
    os.makedirs(outdir, exist_ok=True)
    model = model.to(device)
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    mse_fn = torch.nn.MSELoss()

    H, W = dataset.H, dataset.W
    t0 = time.time()
    val_psnr_hist = []

    # range(iters + 1) so that iteration `iters` itself is trained, logged and snapshotted.
    for it in range(iters + 1):

        # --- one optimisation step on a fresh random batch of rays ---
        rays_o_np, rays_d_np, rgbs_np, _ = dataset.sample_rays(rays_per_batch)  
        target = torch.from_numpy(rgbs_np).to(device).float()  
        
        pred = render_rays(model, rays_o_np, rays_d_np,
                           n_samples=n_samples, near=near, far=far, device=device)
        loss = mse_fn(pred, target)

        opt.zero_grad(set_to_none=True)
        loss.backward()
        opt.step()

        # --- periodic logging and validation ---
        if it % 10 == 0:
            psnr = mse2psnr(float(loss.item()))
            print(f"[it {it:04d}] train MSE={float(loss):.6f} PSNR={psnr:.2f} dB")

            # Validation PSNR is measured on 8-bit images: both prediction and
            # ground truth are quantised to uint8 before being compared.
            val_psnrs = []
            with torch.no_grad():
                for i in range(min(len(images_val), 3)):  
                    gt = (images_val[i] * 255.0).astype(np.uint8)
                    pred_img = render_image(model, K, c2ws_val[i], H, W,
                                            n_samples=n_samples, near=near, far=far, device=device)
                    mse = np.mean(
                        (pred_img.astype(np.float32)/255.0 - gt.astype(np.float32)/255.0) ** 2
                    )
                    val_psnrs.append(mse2psnr(mse))

            mean_psnr = float(np.mean(val_psnrs))
            val_psnr_hist.append((it, mean_psnr))

            # Redraw the whole validation curve and overwrite the previous plot file.
            xs = [t for t,_ in val_psnr_hist]
            ys = [p for _,p in val_psnr_hist]
            plt.figure(figsize=(5,3))
            plt.plot(xs, ys, marker='o')
            plt.xlabel("iter"); plt.ylabel("Val PSNR (dB)")
            plt.tight_layout()
            plt.savefig(os.path.join(outdir, "val_psnr_curve.png"), dpi=150)
            # Close the figure so figures do not accumulate across evaluations.
            plt.close()

        # --- periodic checkpoint: one numbered file plus a rolling "latest" copy ---
        if it % 100 == 0 and it != 0:
            torch.save({
                "it": it,
                "model": model.state_dict(),
                "opt": opt.state_dict(),
                "val_psnr_hist": val_psnr_hist,
            }, os.path.join(outdir, f"ckpt_{it:06d}.pth"))
            torch.save({
                "it": it,
                "model": model.state_dict(),
                "opt": opt.state_dict(),
                "val_psnr_hist": val_psnr_hist,
            }, os.path.join(outdir, f"ckpt_latest.pth"))

        # --- snapshot render of validation view 0 at the requested iterations ---
        if it in snaps:
            with torch.no_grad():
                img_pred = render_image(model, K, c2ws_val[0], H, W,
                                        n_samples=n_samples, near=near, far=far, device=device)
            Image.fromarray(img_pred).save(os.path.join(outdir, f"val0_it{it}.png"))
    print(f"Done. time={time.time()-t0:.1f}s")
    return val_psnr_hist

In [38]:
def visualize_training_step_viser(dataset, model=None, K=None, n_rays=100, n_samples=64, near=2.0, far=6.0):
    """
    Show the cameras, a random batch of rays and their sample points in viser.

    Starts a viser server, adds one image-textured frustum per camera in
    `dataset`, draws `n_rays` randomly sampled rays and the `n_samples` evenly
    spaced points between `near` and `far` on each of them, then blocks until
    interrupted.

    Args:
      dataset: RaysData providing images, poses, intrinsics and `sample_rays`.
      model, K: accepted for interface compatibility but not used; the
        intrinsics are read from `dataset.K`.
      n_rays: number of rays to draw.
      n_samples, near, far: depths of the sample points on each ray.
    """
    import time
    import viser
    server = viser.ViserServer(share=True)

    # One frustum per camera; the field of view follows from the image height and focal length.
    for i, (image, c2w) in enumerate(zip(dataset.images, dataset.c2ws)):
        server.add_camera_frustum(
            f"/cameras/{i}",
            fov=2 * np.arctan2(dataset.H / 2, dataset.K[0, 0]),
            aspect=dataset.W / dataset.H,
            scale=0.15,
            wxyz=viser.transforms.SO3.from_matrix(c2w[:3,:3]).wxyz,
            position=c2w[:3,3],
            image=image,
        )

    rays_o, rays_d, _, _ = dataset.sample_rays(n_rays)

    t_vals = np.linspace(near, far, n_samples, dtype=np.float32)
    # Normalize on the (n_rays, 3) array *before* inserting the sample axis.
    # Dividing rays_d[:, None, :] by the (n_rays, 1) norms would broadcast to
    # (n_rays, n_rays, 3) and fail whenever n_rays != n_samples.
    unit_d = rays_d / np.linalg.norm(rays_d, axis=-1, keepdims=True)
    points = rays_o[:, None, :] + t_vals[None, :, None] * unit_d[:, None, :]  # (n_rays, n_samples, 3)

    # Each ray is drawn from its origin to 6.0 units along its direction.
    for i, (o, d) in enumerate(zip(rays_o, rays_d)):
        server.add_spline_catmull_rom(f"/rays/{i}", positions=np.stack([o, o + d * 6.0], axis=0))
    server.add_point_cloud("/samples", points=points.reshape(-1,3), colors=np.zeros_like(points.reshape(-1,3)), point_size=0.02)

    print("Viser running; Ctrl/Cmd+C to stop.")
    # Keep the server alive until the caller interrupts.
    while True:
        time.sleep(0.1)


In [32]:
import imageio.v2 as imageio

@torch.no_grad()
def render_test_sweep(model, K, c2ws_test, H, W, out_path="lego_sweep.gif", fps=12, n_samples=64, near=2.0, far=6.0, device="cuda"):
    """
    Render one frame per pose in `c2ws_test` and write them out as an animation.

    Args:
      model: network passed through to `render_image`.
      K: (3, 3) intrinsics.
      c2ws_test: iterable of camera-to-world matrices, one per output frame.
      H, W: frame size in pixels.
      out_path: file the frames are written to with `imageio.mimsave`.
      fps: frame rate passed to `imageio.mimsave`.
      n_samples, near, far, device: forwarded to `render_image`.
    """
    render_opts = dict(n_samples=n_samples, near=near, far=far, device=device)
    frames = [render_image(model, K, pose, H, W, **render_opts) for pose in c2ws_test]
    imageio.mimsave(out_path, frames, fps=fps)
    print(f"Saved sweep to {out_path}")


In [42]:
# End-to-end run: build the dataset, train the network, then render a final
# validation view and the test-pose sweep.
device = "cpu"


dataset = RaysData(images_train, K, c2ws_train)


nerf = NeRFMLP(Lx=10, Ld=4, width=256, depth=8)


# `snaps` lists the iterations at which train_nerf writes a render of validation view 0.
val_hist = train_nerf(
    model=nerf,
    dataset=dataset,
    images_val=images_val,
    c2ws_val=c2ws_val,
    K=K,
    iters=1000,
    rays_per_batch=10_000,  
    n_samples=64,
    near=2.0, far=6.0,
    lr=5e-4,
    snaps=(0,5,10,15,50,100,200,300,400,500,600,700,800,900,1000),
    outdir="nerf_out",
    device=device
)





H, W = dataset.H, dataset.W
# NOTE(review): unlike the render_image calls inside train_nerf, this call is not
# wrapped in torch.no_grad() — confirm render_image handles gradient tracking itself.
pred0 = render_image(nerf, K, c2ws_val[0], H, W, n_samples=64, device=device)
Image.fromarray(pred0).save("nerf_out/val0_final.png")

render_test_sweep(nerf, K, c2ws_test, H, W, out_path="nerf_out/lego_sweep.gif", fps=12, n_samples=64, device=device)


0
[it 0000] train MSE=0.085812 PSNR=10.66dB
1
2
3
4
5
6
7
8
9
10
[it 0010] train MSE=0.067687 PSNR=11.69dB
11
12
13
14
15
16
17
18
19
20
[it 0020] train MSE=0.064556 PSNR=11.90dB
21
22
23
24
25
26
27
28
29
30
[it 0030] train MSE=0.061440 PSNR=12.12dB
31
32
33
34
35
36
37
38
39
40
[it 0040] train MSE=0.054455 PSNR=12.64dB
41
42
43
44
45
46
47
48
49
50
[it 0050] train MSE=0.049760 PSNR=13.03dB
51
52
53
54
55
56
57
58
59
60
[it 0060] train MSE=0.041847 PSNR=13.78dB
61
62
63
64
65
66
67
68
69
70
[it 0070] train MSE=0.030773 PSNR=15.12dB
71
72
73
74
75
76
77
78
79
80
[it 0080] train MSE=0.025348 PSNR=15.96dB
81
82
83
84
85
86
87
88
89
90
[it 0090] train MSE=0.022865 PSNR=16.41dB
91
92
93
94
95
96
97
98
99
100
[it 0100] train MSE=0.019940 PSNR=17.00dB
101
102
103
104
105
106
107
108
109
110
[it 0110] train MSE=0.018684 PSNR=17.29dB
111
112
113
114
115
116
117
118
119
120
[it 0120] train MSE=0.016642 PSNR=17.79dB
121
122
123
124
125
126
127
128
129
130
[it 0130] train MSE=0.016450 PSNR=17.84d