In [47]:
import os, json
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import imageio
from tqdm import tqdm


# --- Paths ------------------------------------------------------------------
DATASET_DIR = "/kaggle/input/scene"
# JSON with "camera_angle_x" and per-frame "transform_matrix" (read by main()).
TRANSFORMS_PATH = os.path.join(DATASET_DIR, "transforms_train.json")
# Checkpoint whose "model" entry is loaded into TinyNeRF.
MODEL_PATH = "/kaggle/working/tiny_nerf_01000.pth"
OUTPUT_VIDEO = "orbit.mp4"

# --- Rendering --------------------------------------------------------------
H = W = 200      # output frame size in pixels
N_FRAMES = 60    # frames in the 360-degree orbit
CHUNK = 1024     # rays per render_rays() call inside render_image()

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
BACKGROUND = "white"   # "white" or "black"

# --- Tiny-NeRF parameters ---------------------------------------------------
# L_POS = 6 yields the 3 * (1 + 2 * 6) = 39-dim encoding TinyNeRF expects by default.
L_POS = 6
N_SAMPLES = 64   # evenly spaced samples per ray

print("Using device:", DEVICE)

def posenc(x, L):
    """Sinusoidal positional encoding along the last dimension.

    Returns ``torch.cat([x, sin(2^0*pi*x), cos(2^0*pi*x), ...,
    sin(2^(L-1)*pi*x), cos(2^(L-1)*pi*x)], dim=-1)``, so a last dimension of
    size C grows to C * (1 + 2L) (3 -> 39 for L = 6).
    """
    features = [x]
    for octave in range(L):
        scaled = (2 ** octave) * np.pi * x
        features += [torch.sin(scaled), torch.cos(scaled)]
    return torch.cat(features, dim=-1)


class TinyNeRF(nn.Module):
    """Plain MLP mapping an encoded 3-D point to (rgb, sigma).

    Args:
        D: number of Linear layers in ``pts_linears``.
        W: hidden width.
        input_ch: size of the encoded input (39 = 3 * (1 + 2 * 6) for posenc(x, 6)).
        skips: layer indices whose input is the raw encoded input concatenated
            with the current hidden state (skip connection).

    The attribute names ``pts_linears`` / ``output_linear`` define the
    state_dict keys, so they must not change or saved checkpoints won't load.
    """

    def __init__(self, D=8, W=256, input_ch=39, skips=[4]):
        super().__init__()
        self.D = D
        self.W = W
        self.input_ch = input_ch
        self.skips = skips

        # Layer 0 reads the encoding; a skip layer reads [encoding, hidden],
        # i.e. input_ch + W features.
        layers = [nn.Linear(input_ch, W)]
        for idx in range(1, D):
            fan_in = W + input_ch if idx in skips else W
            layers.append(nn.Linear(fan_in, W))
        self.pts_linears = nn.ModuleList(layers)

        # 3 colour channels + 1 density.
        self.output_linear = nn.Linear(W, 4)

    def forward(self, x):
        """Return (sigmoid rgb of shape (..., 3), ReLU sigma of shape (...))."""
        feat = x
        for idx, layer in enumerate(self.pts_linears):
            if idx in self.skips:
                feat = torch.cat([x, feat], dim=-1)
            feat = F.relu(layer(feat))

        raw = self.output_linear(feat)
        rgb = torch.sigmoid(raw[..., :3])
        sigma = F.relu(raw[..., 3])
        return rgb, sigma

def get_rays(H, W, focal, c2w):
    """Generate one ray per pixel for a pinhole camera.

    Args:
        H, W: image height and width in pixels.
        focal: focal length in pixels.
        c2w: camera-to-world matrix; only ``c2w[:3, :3]`` (rotation) and
            ``c2w[:3, 3]`` (camera origin) are read, so 3x4 or 4x4 both work.

    Returns:
        ``(rays_o, rays_d)``, each of shape (H, W, 3) on ``c2w``'s device.
        ``rays_d`` is not normalised: its camera-space z component is -1.
    """
    # Build the pixel grid on c2w's device/dtype rather than the global DEVICE,
    # so CPU poses work too and the direction stack is not a float/int64 mix.
    xs, ys = torch.meshgrid(
        torch.arange(W, device=c2w.device, dtype=c2w.dtype),
        torch.arange(H, device=c2w.device, dtype=c2w.dtype),
        indexing="xy",
    )
    # Camera-space directions: +x right, +y up (image rows grow downwards),
    # camera looking down -z.
    dirs = torch.stack([
        (xs - W * 0.5) / focal,
        -(ys - H * 0.5) / focal,
        -torch.ones_like(xs),
    ], dim=-1)

    # Rotate into world space: equivalent to c2w[:3, :3] @ d for every pixel.
    rays_d = (dirs[..., None, :] * c2w[:3, :3]).sum(-1)
    # All rays start at the camera centre.
    rays_o = c2w[:3, 3].expand_as(rays_d)
    return rays_o, rays_d


def render_rays(model, rays_o, rays_d, near, far):
    """Volume-render a batch of rays and return one RGB colour per ray.

    Each ray is sampled at N_SAMPLES evenly spaced depths in [near, far]. The
    xyz sample positions are encoded with ``posenc(., L_POS)`` and passed to
    ``model``, whose (rgb, sigma) outputs are alpha-composited along the ray.
    When BACKGROUND == "white", the opacity the ray did not accumulate is
    filled with white.

    Args:
        model: callable returning (rgb, sigma) whose outputs can be viewed as
            (N_rays, N_SAMPLES, 3) and (N_rays, N_SAMPLES).
        rays_o, rays_d: ray origins and directions, shape (N_rays, 3).
        near, far: depth range, in multiples of ``rays_d``.

    Returns:
        Tensor of shape (N_rays, 3).
    """
    n_rays = rays_o.shape[0]
    depths = torch.linspace(near, far, N_SAMPLES, device=DEVICE).expand(n_rays, N_SAMPLES)

    # (N_rays, N_SAMPLES, 3) sample positions along every ray.
    samples = rays_o.unsqueeze(1) + rays_d.unsqueeze(1) * depths.unsqueeze(-1)

    # Tiny-NeRF conditions on position only, so no view-direction encoding.
    color, density = model(posenc(samples.reshape(-1, 3), L_POS))
    color = color.view(-1, N_SAMPLES, 3)
    density = density.view(-1, N_SAMPLES)

    # Segment lengths between consecutive samples. The final segment gets a huge
    # length (1e10), so any non-zero density there makes it opaque.
    # NOTE(review): lengths are depth differences only, not scaled by |rays_d|;
    # presumably this matches how the checkpoint was trained — confirm.
    gaps = depths[:, 1:] - depths[:, :-1]
    gaps = torch.cat([gaps, torch.full_like(gaps[:, :1], 1e10)], dim=-1)

    opacity = 1.0 - torch.exp(-density * gaps)

    # Exclusive cumulative product = transmittance reaching each sample;
    # the 1e-10 guards the product against exact zeros.
    survive = torch.cat([torch.ones_like(opacity[:, :1]), 1.0 - opacity + 1e-10], dim=-1)
    transmittance = torch.cumprod(survive, dim=-1)[:, :-1]

    weights = opacity * transmittance
    rgb_map = (weights.unsqueeze(-1) * color).sum(dim=1)

    if BACKGROUND == "white":
        rgb_map = rgb_map + (1.0 - weights.sum(dim=1, keepdim=True))

    return rgb_map


@torch.no_grad()
def render_image(model, rays_o, rays_d, near, far):
    """Render a whole (H, W) ray grid without autograd, CHUNK rays at a time.

    Args:
        model: network passed through to render_rays().
        rays_o, rays_d: ray origins/directions with shape (H, W, 3).
        near, far: depth range forwarded to render_rays().

    Returns:
        CPU tensor of shape (H, W, 3).
    """
    img_h, img_w = rays_o.shape[:2]
    flat_o = rays_o.reshape(-1, 3)
    flat_d = rays_d.reshape(-1, 3)

    # Each chunk is moved to the CPU right away so device memory only ever holds
    # one chunk's output.
    pieces = [
        render_rays(model, flat_o[start:start + CHUNK], flat_d[start:start + CHUNK], near, far).cpu()
        for start in range(0, flat_o.shape[0], CHUNK)
    ]
    return torch.cat(pieces).reshape(img_h, img_w, 3)




def pose_spherical(theta_deg, radius, device):
    """Camera-to-world matrix for a camera on a horizontal circle, looking at the origin.

    The camera sits at (radius*sin(theta), 0, radius*cos(theta)), with world +y
    as the up hint. The columns of the returned 4x4 float32 matrix are the
    camera's right, up and back (-forward) axes plus its position. The camera
    therefore looks down its local -z axis, which matches the -1 z direction
    component built in get_rays().

    Args:
        theta_deg: azimuth in degrees.
        radius: distance from the origin (must be non-zero).
        device: torch device for the result.

    Raises:
        ValueError: if radius is 0 (the viewing direction would be undefined).
    """
    if radius == 0:
        raise ValueError("radius must be non-zero")

    theta = np.deg2rad(theta_deg)
    eye = torch.tensor(
        [radius * np.sin(theta), 0.0, radius * np.cos(theta)],
        device=device,
        dtype=torch.float32,
    )
    target = torch.zeros(3, device=device, dtype=torch.float32)
    world_up = torch.tensor([0.0, 1.0, 0.0], device=device, dtype=torch.float32)

    forward = target - eye
    forward = forward / torch.norm(forward)

    # Explicit dim: torch.cross without `dim` is deprecated and emits a warning.
    right = torch.cross(forward, world_up, dim=-1)
    right = right / torch.norm(right)
    # Re-orthogonalised up; unit length because right and forward are orthonormal.
    up = torch.cross(right, forward, dim=-1)

    c2w = torch.eye(4, device=device, dtype=torch.float32)
    c2w[:3, 0] = right
    c2w[:3, 1] = up
    c2w[:3, 2] = -forward
    c2w[:3, 3] = eye
    return c2w
   

def main():
    """Render a 360-degree orbit video of the trained TinyNeRF and save it to OUTPUT_VIDEO.

    Reads the horizontal FOV and the training camera poses from TRANSFORMS_PATH
    and loads the checkpoint at MODEL_PATH. It then renders N_FRAMES views on a
    horizontal circle around the origin and writes them at 12 fps.

    Raises:
        ValueError: if the transforms file lists no camera frames.
    """
    with open(TRANSFORMS_PATH) as f:
        meta = json.load(f)
    if not meta["frames"]:
        raise ValueError(f"No camera frames in {TRANSFORMS_PATH}")

    # Horizontal FOV -> focal length in pixels for a W-pixel-wide image.
    focal = 0.5 * W / np.tan(0.5 * meta["camera_angle_x"])

    model = TinyNeRF().to(DEVICE)
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    ckpt = torch.load(MODEL_PATH, map_location=DEVICE)
    model.load_state_dict(ckpt["model"])
    model.eval()

    # Each training camera's distance from the world origin is the norm of the
    # translation column c2w[:3, 3]. Indexing transform_matrix[0][3] would read
    # only the x translation, so `near` collapsed towards 0 for any camera with x ~ 0.
    cam_dists = np.array([
        np.linalg.norm(np.asarray(frame["transform_matrix"], dtype=np.float64)[:3, 3])
        for frame in meta["frames"]
    ])
    # NOTE(review): assumes the scene is centred at the origin and lies within
    # +-20% of the camera distance — confirm against the training near/far.
    near = float(0.8 * cam_dists.min())
    far = float(1.2 * cam_dists.max())
    # Orbit at the mean training-camera distance so that [near, far] brackets
    # the origin for every rendered view.
    orbit_radius = float(cam_dists.mean())

    frames = []
    for th in tqdm(np.linspace(0, 360, N_FRAMES, endpoint=False)):
        c2w = pose_spherical(th, radius=orbit_radius, device=DEVICE)
        rays_o, rays_d = get_rays(H, W, focal, c2w)
        img = render_image(model, rays_o, rays_d, near, far)
        frames.append((img.numpy().clip(0, 1) * 255).astype(np.uint8))

    imageio.mimwrite(OUTPUT_VIDEO, frames, fps=12)
    print("Saved:", OUTPUT_VIDEO)

# A notebook kernel also runs with __name__ == "__main__", so executing this
# cell renders and saves the orbit video. The model created inside main() stays
# local and is not left in the kernel namespace.
if __name__ == "__main__":
    main()

Using device: cuda


100%|██████████| 60/60 [01:05<00:00,  1.09s/it]


Saved: orbit.mp4


In [48]:
# %pip (not !pip) installs into the kernel's own environment; versions are pinned
# to those this notebook was last run with.
%pip install -q mcubes==0.1.6 trimesh==4.10.1

Collecting mcubes
  Downloading mcubes-0.1.6-cp310-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (3.9 kB)
Downloading mcubes-0.1.6-cp310-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (293 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m293.3/293.3 kB[0m [31m5.1 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
[?25hInstalling collected packages: mcubes
Successfully installed mcubes-0.1.6
Collecting trimesh
  Downloading trimesh-4.10.1-py3-none-any.whl.metadata (13 kB)
Downloading trimesh-4.10.1-py3-none-any.whl (737 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m737.0/737.0 kB[0m [31m10.8 MB/s[0m eta [36m0:00:00[0m [36m0:00:01[0m
[?25hInstalling collected packages: trimesh
Successfully installed trimesh-4.10.1


In [None]:

import mcubes
from skimage import measure
import trimesh
from scipy.ndimage import gaussian_filter

OUTPUT_MESH = "nerf_mesh.ply"

GRID_RES = 200          # samples per axis of the density grid
SIGMA_THRESHOLD = 8.0   # density iso-level at which the surface is extracted
# Half-width of the sampled cube [-MESH_BOUND, MESH_BOUND]^3.
# NOTE(review): assumes the object lies inside this cube; widen it if the mesh is clipped.
MESH_BOUND = 0.6
# Points per forward pass. This is deliberately NOT called CHUNK: rebinding the
# global CHUNK (rays per render_rays call) would make a re-run of the video cell
# push 65536 rays * N_SAMPLES points through the model at once.
MESH_CHUNK = 65536

print("Generating density grid for mesh...")

# Reuse a NeRF already in the kernel if there is one; otherwise load the TinyNeRF
# checkpoint. main() keeps its model local, so a fresh kernel has no global `model`.
mesh_model = globals().get("model")
if mesh_model is None:
    mesh_model = TinyNeRF().to(DEVICE)
    mesh_ckpt = torch.load(MODEL_PATH, map_location=DEVICE)
    mesh_model.load_state_dict(mesh_ckpt["model"])
    mesh_model.eval()

lin = torch.linspace(-MESH_BOUND, MESH_BOUND, GRID_RES, device=DEVICE)
X, Y, Z = torch.meshgrid(lin, lin, lin, indexing="ij")
pts = torch.stack([X, Y, Z], dim=-1).reshape(-1, 3)

sigmas = []
with torch.no_grad():
    for i in tqdm(range(0, pts.shape[0], MESH_CHUNK)):
        p = pts[i:i + MESH_CHUNK]
        p_enc = posenc(p, L_POS)

        if "L_DIR" in globals():
            # Full-NeRF variant: it expects a view-direction encoding as well. Density
            # does not use the view direction, so a zero direction is fed. It is built
            # per chunk rather than as a full-grid tensor.
            v_enc = posenc(torch.zeros_like(p), L_DIR)
            _, sigma = mesh_model(torch.cat([p_enc, v_enc], dim=-1))
        else:
            _, sigma = mesh_model(p_enc)

        # reshape(-1), not squeeze(-1): a trailing chunk with a single point would
        # otherwise collapse to a 0-d tensor and break torch.cat below.
        sigmas.append(sigma.reshape(-1).cpu())

sigma_grid = torch.cat(sigmas).reshape(GRID_RES, GRID_RES, GRID_RES).numpy()

print("Density grid computed")

# Light Gaussian smoothing (sigma = 0.5 voxels) before iso-surface extraction.
sigma_grid = gaussian_filter(sigma_grid, sigma=0.5)

sigma_lo, sigma_hi = float(sigma_grid.min()), float(sigma_grid.max())
if not sigma_lo <= SIGMA_THRESHOLD <= sigma_hi:
    raise ValueError(
        f"SIGMA_THRESHOLD={SIGMA_THRESHOLD} lies outside the density range "
        f"[{sigma_lo:.3g}, {sigma_hi:.3g}]; choose a level inside it."
    )

verts, faces, normals, _ = measure.marching_cubes(
    sigma_grid,
    level=SIGMA_THRESHOLD
)

# marching_cubes returns vertices in voxel-index coordinates [0, GRID_RES - 1].
# Map them back onto the cube that was actually sampled. The old mapping used
# [-1.2, 1.2], twice the sampled extent, which made the mesh 2x too large.
verts = verts / (GRID_RES - 1) * (2 * MESH_BOUND) - MESH_BOUND

mesh = trimesh.Trimesh(
    vertices=verts,
    faces=faces,
    vertex_normals=normals,
    process=False
)

mesh.export(OUTPUT_MESH)

print(f"Mesh saved to {OUTPUT_MESH}")