In [10]:
import os
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import imageio
from tqdm import tqdm

# ---------------------------------------------------------------------------
# Paths — edit these for your environment.
# ---------------------------------------------------------------------------
# Directory holding the dataset's transforms_train.json (read for the FOV).
DATASET_DIR = "/kaggle/input/scene/"
TRANSFORMS_PATH = os.path.join(DATASET_DIR, "transforms_train.json")
# Passed straight to torch.load, so this must be the checkpoint *file*
# (e.g. ".../ckpt_03000.pth"), not the directory containing it.
MODEL_PATH = "/kaggle/input/model"
OUTPUT_VIDEO = "orbit.mp4"

# ---------------------------------------------------------------------------
# Render settings.
# ---------------------------------------------------------------------------
H = W = 200              # output resolution; kept small to bound time/memory
N_FRAMES = 60            # frames per full 360-degree orbit
CHUNK = 1024             # rays per batch; each ray costs N_SAMPLES model queries
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Compositing background: "white" adds the unoccupied ray mass as white;
# any other value (e.g. "black") leaves it black.
BACKGROUND = "white"

# ---------------------------------------------------------------------------
# Encoding / sampling settings (presumably must match the training run).
# ---------------------------------------------------------------------------
L_POS = 10               # frequency bands for positions -> 3 + 3*2*10 = 63 features
L_DIR = 4                # frequency bands for view dirs -> 3 + 3*2*4 = 27 features
N_SAMPLES = 64           # uniform samples per ray; no hierarchical (fine) pass
NEAR, FAR = 2.0, 6.0     # ray-parameter bounds for sampling


def posenc(x, L):
    """Sinusoidal positional encoding.

    Concatenates, along the last axis, the raw input followed by
    ``sin(2^k * pi * x)`` and ``cos(2^k * pi * x)`` for ``k = 0 .. L-1``.
    For a last-axis width D the result has width ``D * (1 + 2 * L)``.
    """
    features = [x]
    for freq in (2 ** k * np.pi for k in range(L)):
        scaled = freq * x
        features += [torch.sin(scaled), torch.cos(scaled)]
    return torch.cat(features, dim=-1)


class NeRF(nn.Module):
    """Plain NeRF MLP (no skip connection).

    A stack of ``D`` Linear+ReLU layers over the encoded position feeds a
    density head (ReLU, so sigma >= 0) and a colour head that also sees the
    encoded view direction (sigmoid, so rgb is in [0, 1]).

    ``forward`` expects the last axis to be ``[pos_enc, dir_enc]`` with widths
    ``in_pos`` and ``in_dir``. The defaults (63, 27) correspond to ``posenc``
    on 3-D inputs with L=10 and L=4 (width 3 + 3*2*L).

    Submodule names (``fc``, ``sigma``, ``rgb``) determine the state_dict keys;
    keep them unchanged so existing checkpoints still load.

    Args:
        W: hidden width.
        D: number of trunk layers (default 8).
        in_pos: width of the encoded-position features.
        in_dir: width of the encoded-view-direction features.

    Raises:
        ValueError: if ``D < 1``.
    """

    def __init__(self, W=256, D=8, in_pos=63, in_dir=27):
        super().__init__()
        if D < 1:
            raise ValueError(f"D must be >= 1, got {D}")
        # Plain ints, not parameters: they do not appear in the state_dict.
        self.in_pos = in_pos
        self.in_dir = in_dir
        self.fc = nn.ModuleList(
            [nn.Linear(in_pos, W)] + [nn.Linear(W, W) for _ in range(D - 1)]
        )
        self.sigma = nn.Linear(W, 1)
        self.rgb = nn.Sequential(
            nn.Linear(W + in_dir, W),
            nn.ReLU(),
            nn.Linear(W, 3),
        )

    def forward(self, x):
        """Return ``(rgb, sigma)`` for inputs whose last axis is in_pos + in_dir wide."""
        pts, views = torch.split(x, [self.in_pos, self.in_dir], -1)
        h = pts
        for layer in self.fc:
            h = F.relu(layer(h))
        sigma = F.relu(self.sigma(h))
        rgb = torch.sigmoid(self.rgb(torch.cat([h, views], -1)))
        return rgb, sigma


def get_rays(H, W, focal, c2w):
    """Generate one ray per pixel for a pinhole camera.

    Camera-space convention (from the direction vectors below): +x right,
    +y up, camera looking down -z. Pixel (row j, column i) maps to the
    direction ``((i - W/2)/focal, -(j - H/2)/focal, -1)``; no half-pixel
    offset is applied.

    Args:
        H: image height in pixels.
        W: image width in pixels.
        focal: focal length in pixels.
        c2w: camera-to-world matrix; only the top 3x4 part is read. Its
            device and dtype determine those of the returned rays.

    Returns:
        ``(rays_o, rays_d)``, each of shape (H, W, 3). ``rays_d`` is not
        normalised (its camera-space z component is -1); ``rays_o`` is an
        expanded view of the camera position.
    """
    # Build the grid on c2w's device/dtype so the product below never mixes
    # devices, and use float coordinates rather than relying on int->float
    # promotion inside torch.stack.
    xs = torch.arange(W, device=c2w.device, dtype=c2w.dtype)
    ys = torch.arange(H, device=c2w.device, dtype=c2w.dtype)
    i, j = torch.meshgrid(xs, ys, indexing="xy")  # both (H, W)
    dirs = torch.stack([
        (i - W * 0.5) / focal,
        -(j - H * 0.5) / focal,
        -torch.ones_like(i),
    ], -1)

    # Rotate camera-space directions into world space (dirs @ R^T).
    rays_d = (dirs[..., None, :] * c2w[:3, :3]).sum(-1)
    rays_o = c2w[:3, 3].expand_as(rays_d)
    return rays_o, rays_d


def render_rays(model, rays_o, rays_d):
    """Volume-render a batch of rays using N_SAMPLES uniform samples in [NEAR, FAR].

    Args:
        model: callable mapping concatenated ``[posenc(pts), posenc(dirs)]``
            features to ``(rgb, sigma)``.
        rays_o: ray origins, indexed as (n_rays, 3).
        rays_d: ray directions (not necessarily unit length), same shape.

    Returns:
        (n_rays, 3) composited colour. When BACKGROUND == "white" the
        leftover transmittance ``1 - sum(weights)`` is added as white.
    """
    n_rays = rays_o.shape[0]
    t = torch.linspace(NEAR, FAR, N_SAMPLES, device=DEVICE).expand(n_rays, N_SAMPLES)

    # Sample points o + t * d -> (n_rays, N_SAMPLES, 3).
    samples = rays_o[:, None] + rays_d[:, None] * t[..., None]
    unit_dirs = rays_d / torch.norm(rays_d, dim=-1, keepdim=True)
    unit_dirs = unit_dirs[:, None].expand_as(samples)

    encoded_pts = posenc(samples.reshape(-1, 3), L_POS)
    encoded_dirs = posenc(unit_dirs.reshape(-1, 3), L_DIR)
    color, density = model(torch.cat([encoded_pts, encoded_dirs], -1))
    color = color.reshape(-1, N_SAMPLES, 3)
    density = density.reshape(-1, N_SAMPLES)

    # Spacing between consecutive samples; the last interval is effectively
    # infinite so remaining density is fully absorbed.
    # NOTE(review): reference NeRF also multiplies these by ||rays_d||; this
    # code does not. Kept as-is because it must match how the checkpoint was
    # trained — confirm against the training script.
    deltas = t[:, 1:] - t[:, :-1]
    deltas = torch.cat([deltas, 1e10 * torch.ones_like(deltas[:, :1])], -1)

    opacity = 1.0 - torch.exp(-density * deltas)
    # Exclusive cumulative product = transmittance reaching each sample;
    # the 1e-10 guards against the product collapsing to exactly zero.
    survive = 1 - opacity + 1e-10
    transmittance = torch.cumprod(
        torch.cat([torch.ones_like(opacity[:, :1]), survive], -1),
        -1,
    )[:, :-1]
    w = opacity * transmittance

    pixel = (w[..., None] * color).sum(dim=1)
    return pixel + (1 - w.sum(dim=1, keepdim=True)) if BACKGROUND == "white" else pixel



@torch.no_grad()
def render_image(model, rays_o, rays_d):
    """Render a grid of rays in CHUNK-sized batches without tracking gradients.

    Args:
        model: network passed through to ``render_rays``.
        rays_o: ray origins; the leading two axes are taken as (H, W) and the
            last axis as 3.
        rays_d: ray directions with the same layout.

    Returns:
        (H, W, 3) tensor on the CPU (each chunk is moved off-device as soon
        as it is rendered, keeping device memory bounded by CHUNK).
    """
    height, width = rays_o.shape[:2]
    flat_o = rays_o.reshape(-1, 3)
    flat_d = rays_d.reshape(-1, 3)

    pieces = [
        render_rays(model, flat_o[start:start + CHUNK], flat_d[start:start + CHUNK]).cpu()
        for start in range(0, flat_o.shape[0], CHUNK)
    ]
    return torch.cat(pieces, 0).reshape(height, width, 3)



def pose_spherical(theta_deg, radius, device, phi_deg=0.0):
    """Camera-to-world matrix for a camera on a sphere, looking at the origin.

    The camera sits at azimuth ``theta_deg`` (rotation about +y, measured from
    +z toward +x), elevation ``phi_deg`` above the x-z plane, at distance
    ``radius`` from the origin. World up is taken to be +y.

    NOTE(review): many Blender-style NeRF datasets are z-up; if the orbit looks
    rolled or sideways, the scene's up axis probably differs — confirm
    against the training poses.

    The returned columns are (right, up, -forward, position), so the camera
    looks down its local -z axis, matching ``get_rays``.

    Args:
        theta_deg: azimuth in degrees.
        radius: distance from the origin (must be non-zero).
        device: torch device for the returned tensor.
        phi_deg: elevation in degrees (default 0.0, the original horizontal orbit).

    Returns:
        (4, 4) float32 tensor on ``device``.

    Raises:
        ValueError: if the view direction is (nearly) parallel to +y or
            undefined (radius 0), so no right vector can be formed.
    """
    theta = np.deg2rad(theta_deg)
    phi = np.deg2rad(phi_deg)

    cam_pos = torch.tensor(
        [
            radius * np.cos(phi) * np.sin(theta),
            radius * np.sin(phi),
            radius * np.cos(phi) * np.cos(theta),
        ],
        device=device,
        dtype=torch.float32,
    )

    target = torch.zeros(3, device=device, dtype=torch.float32)
    world_up = torch.tensor([0.0, 1.0, 0.0], device=device, dtype=torch.float32)

    forward = target - cam_pos
    forward = forward / torch.norm(forward)

    # dim is explicit: calling torch.cross without dim is deprecated.
    right = torch.cross(forward, world_up, dim=-1)
    right_norm = torch.norm(right)
    # Written as "not >" so a NaN norm (radius == 0) is rejected too.
    if not bool(right_norm > 1e-6):
        raise ValueError(
            f"degenerate camera pose (theta={theta_deg}, phi={phi_deg}, "
            f"radius={radius}): view direction parallel to up or undefined"
        )
    right = right / right_norm

    up = torch.cross(right, forward, dim=-1)

    c2w = torch.eye(4, device=device, dtype=torch.float32)
    c2w[:3, 0] = right
    c2w[:3, 1] = up
    c2w[:3, 2] = -forward
    c2w[:3, 3] = cam_pos

    return c2w
   


def _load_weights(model, ckpt):
    """Load an already-deserialised checkpoint into *model*, dispatching on its layout.

    Layouts handled:
      * dict with "model_f"    -> weights of the fine network from a full checkpoint
      * dict with "state_dict" -> weights nested under "state_dict"
      * anything else          -> treated as a raw state_dict
    """
    if isinstance(ckpt, dict) and "model_f" in ckpt:
        print("Detected full checkpoint → loading fine model")
        model.load_state_dict(ckpt["model_f"])
    elif isinstance(ckpt, dict) and "state_dict" in ckpt:
        print("Detected nested state_dict → loading ckpt['state_dict']")
        model.load_state_dict(ckpt["state_dict"])
    else:
        print("Detected raw state_dict → loading directly")
        model.load_state_dict(ckpt)


def main():
    """Render a 360-degree orbit of the trained NeRF and write it to OUTPUT_VIDEO."""
    with open(TRANSFORMS_PATH, encoding="utf-8") as f:
        meta = json.load(f)

    # Focal length in pixels for the *render* width W, derived from the
    # dataset's horizontal field of view.
    focal = 0.5 * W / np.tan(0.5 * meta["camera_angle_x"])

    print(f"Loading checkpoint from {MODEL_PATH}...")
    # torch.load unpickles the file: only load checkpoints you trust.
    ckpt = torch.load(MODEL_PATH, map_location=DEVICE)

    model = NeRF().to(DEVICE)
    print(sum(p.numel() for p in model.parameters()) / 1e6, "M parameters")
    _load_weights(model, ckpt)
    model.eval()

    print("Rendering orbit...")
    frames = []
    for th in tqdm(np.linspace(0, 360, N_FRAMES, endpoint=False)):
        # Build the pose on DEVICE (previously hard-coded "cuda", which
        # crashed on CPU-only machines where DEVICE falls back to "cpu").
        c2w = pose_spherical(th, radius=4.0, device=DEVICE)
        rays_o, rays_d = get_rays(H, W, focal, c2w)
        img = render_image(model, rays_o, rays_d)
        frames.append((img.numpy().clip(0, 1) * 255).astype(np.uint8))

    imageio.mimwrite(OUTPUT_VIDEO, frames, fps=12)
    print(f"Saved {OUTPUT_VIDEO}")


if __name__ == "__main__":
    main()

Loading checkpoint from /kaggle/input/model-3k/pytorch/default/1/ckpt_03000.pth...
0.55066 M parameters
Detected full checkpoint → loading fine model
Rendering orbit...


100%|██████████| 60/60 [01:10<00:00,  1.17s/it]


Saved orbit.mp4
