# In [None]:
import os, json
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image


# ---------------------------------------------------------------------------
# Run configuration
# ---------------------------------------------------------------------------

# Where the dataset is read from and where validation renders are written.
DATASET_PATH = "/kaggle/input/scene/"
VAL_IMG_DIR = "val_images"

# Background used when compositing: True = white, False = black.
WHITE_BG = True

# Resolution the images are resized to (height, width).
H_RES = 800
W_RES = 800

# Model / sampling settings.
L_POS = 6        # number of positional-encoding frequencies (Tiny-NeRF uses L=6)
N_SAMPLES = 64   # samples per ray; single sampling pass only

# Optimisation schedule.
N_ITERS = 50_000
BATCH_SIZE = 2048
LR = 5e-4
VAL_INTERVAL = 100
SAVE_INTERVAL = 1000

# Pick the first available backend in the order CUDA, MPS, CPU.
if torch.cuda.is_available():
    _device_name = "cuda"
elif torch.backends.mps.is_available():
    _device_name = "mps"
else:
    _device_name = "cpu"
DEVICE = torch.device(_device_name)

# Make sure the validation output directory exists before training starts.
os.makedirs(VAL_IMG_DIR, exist_ok=True)

print("Device:", DEVICE)

def posenc(x, L):
    """Positionally encode ``x`` with ``L`` sine/cosine frequency bands.

    The result is the concatenation, along the last dimension, of ``x``
    itself followed by ``sin(2**k * x)`` and ``cos(2**k * x)`` for
    ``k = 0 .. L-1`` (in that order), so the last dimension grows by a
    factor of ``1 + 2 * L``.
    """
    features = [x]
    for k in range(L):
        scaled = (2 ** k) * x
        features.extend((torch.sin(scaled), torch.cos(scaled)))
    return torch.cat(features, dim=-1)



def load_data(basedir, H, W):
    """Load training images and camera poses listed in ``transforms_train.json``.

    Args:
        basedir: Directory containing ``transforms_train.json`` and the
            image files referenced by each frame's ``file_path``.
        H: Target image height in pixels.
        W: Target image width in pixels.

    Returns:
        A tuple ``(images, poses, focal)``:
        * ``images`` - float32 tensor on ``DEVICE`` with one ``(H, W, 3)``
          image per loaded frame, values in ``[0, 1]``; any alpha channel
          is composited onto a white or black background per ``WHITE_BG``.
        * ``poses`` - float32 tensor on ``DEVICE`` stacking each frame's
          ``transform_matrix`` after the axis flip and the centring /
          rescaling applied below.
        * ``focal`` - focal length in pixels derived from
          ``camera_angle_x`` and ``W``.

    Raises:
        FileNotFoundError: If none of the frames' image files exist.
    """
    with open(os.path.join(basedir, "transforms_train.json")) as f:
        meta = json.load(f)

    images, poses = [], []

    for frame in meta["frames"]:
        img_path = os.path.join(basedir, frame["file_path"])
        # Some exports store file_path without an extension; try ".png"
        # before giving up on the frame.
        if not os.path.exists(img_path) and not os.path.splitext(img_path)[1]:
            img_path += ".png"
        if not os.path.exists(img_path):
            # Best effort: frames whose image is missing are skipped.
            continue

        # The context manager releases the file handle once pixels are read.
        with Image.open(img_path) as raw:
            # Normalise the mode so every frame ends up with 3 channels
            # (grayscale / palette images would otherwise break stacking).
            mode = "RGBA" if "A" in raw.getbands() else "RGB"
            resized = raw.convert(mode).resize((W, H), Image.LANCZOS)
            img = np.asarray(resized, dtype=np.float32) / 255.0

        if img.shape[-1] == 4:
            alpha = img[..., 3:4]
            if WHITE_BG:
                # Alpha-composite onto white.
                img = img[..., :3] * alpha + (1 - alpha)
            else:
                # Alpha-composite onto black.
                img = img[..., :3] * alpha

        images.append(img)
        poses.append(np.array(frame["transform_matrix"], dtype=np.float32))

    if not images:
        raise FileNotFoundError(
            f"No image files referenced by transforms_train.json were found "
            f"under {basedir!r}"
        )

    images = torch.tensor(np.array(images), device=DEVICE)
    poses = torch.tensor(np.array(poses), device=DEVICE)

    # Negate the camera y and z axes (rotation columns 1 and 2).
    # NOTE(review): presumably converts the dataset's camera convention to
    # the one get_rays expects - confirm against the dataset.
    poses[:, :3, 1] *= -1
    poses[:, :3, 2] *= -1

    # --- Center + scale (LLFF) ---
    # Shift camera positions so their mean is the origin, then rescale so
    # the mean camera distance from the origin is 2.
    centers = poses[:, :3, 3]
    center = centers.mean(0)
    poses[:, :3, 3] -= center
    scale = 2.0 / torch.norm(poses[:, :3, 3], dim=1).mean()
    poses[:, :3, 3] *= scale

    # Pinhole focal length (pixels) from the horizontal field of view.
    focal = 0.5 * W / np.tan(0.5 * meta["camera_angle_x"])

    return images, poses, focal


def get_rays(H, W, focal, c2w):
    """Generate one ray per pixel for a pinhole camera.

    Args:
        H: Image height in pixels.
        W: Image width in pixels.
        focal: Focal length in pixels.
        c2w: Camera-to-world matrix; rows/columns ``[:3, :3]`` are used as
            the rotation and ``[:3, 3]`` as the camera origin.

    Returns:
        ``(rays_o, rays_d)``, each of shape ``(H, W, 3)``: the ray origins
        (the camera position repeated for every pixel) and the
        un-normalised ray directions in world space.
    """
    # Build the pixel grid on the pose's own device and dtype so the
    # function works wherever the pose lives and the stacked direction
    # components already share one floating dtype.
    cols = torch.arange(W, device=c2w.device, dtype=c2w.dtype)
    rows = torch.arange(H, device=c2w.device, dtype=c2w.dtype)
    # indexing="xy" yields (H, W) grids: i holds column index, j row index.
    i, j = torch.meshgrid(cols, rows, indexing="xy")

    # Camera-space directions: x right, y up, looking down -z.
    dirs = torch.stack([
        (i - W * 0.5) / focal,
        -(j - H * 0.5) / focal,
        -torch.ones_like(i)
    ], -1)

    # Rotate directions into world space (per-pixel matrix-vector product).
    rays_d = (dirs[..., None, :] * c2w[:3, :3]).sum(-1)
    rays_o = c2w[:3, 3].expand_as(rays_d)
    return rays_o, rays_d


def raw2outputs(rgb, sigma, z_vals, rays_d, white_bg=True):
    """Alpha-composite per-sample colours and densities into one colour per ray.

    Args:
        rgb: Per-sample colours, last two dims ``(samples, 3)``.
        sigma: Per-sample densities, last dim ``samples``.
        z_vals: Per-sample depths along each ray, last dim ``samples``.
        rays_d: Ray directions, last dim 3; their norm scales the
            sample spacing.
        white_bg: When true, the light not absorbed along the ray is
            filled with white; otherwise it contributes nothing.

    Returns:
        The composited colour for each ray (last dim 3).
    """
    # Spacing between consecutive samples; the final sample gets a very
    # large (1e10) interval.
    deltas = z_vals[..., 1:] - z_vals[..., :-1]
    tail = 1e10 * torch.ones_like(deltas[..., :1])
    deltas = torch.cat([deltas, tail], dim=-1)
    # Scale by the direction length because rays_d is not unit-normalised.
    deltas = deltas * torch.norm(rays_d[..., None, :], dim=-1)

    # Opacity contributed by each sample.
    opacity = 1.0 - torch.exp(-sigma * deltas)

    # Transmittance reaching each sample: exclusive cumulative product of
    # (1 - opacity), with a 1e-10 offset added to every factor.
    leading_one = torch.ones_like(opacity[..., :1])
    pass_through = 1. - opacity + 1e-10
    transmittance = torch.cumprod(
        torch.cat([leading_one, pass_through], dim=-1), dim=-1
    )[..., :-1]

    weights = opacity * transmittance
    colour = (weights[..., None] * rgb).sum(dim=-2)

    if not white_bg:
        return colour

    # Total opacity along the ray; the remainder is filled with white.
    accumulated = weights.sum(dim=-1)
    return colour + (1.0 - accumulated)[..., None]

class TinyNeRF(nn.Module):
    """Fully connected network mapping encoded points to colour and density.

    Args:
        D: Number of hidden linear layers.
        W: Width of every hidden layer.
        input_ch: Size of the encoded input (e.g. 39 = 3 * (1 + 2 * 6)
            for 3-D points with 6 encoding frequencies).
        skips: Layer indices at which the encoded input is concatenated
            to the hidden activation before the layer is applied. Indices
            should lie in ``1 .. D-1``.
    """

    def __init__(self, D=8, W=256, input_ch=39, skips=(4,)):
        super().__init__()
        self.D = D
        self.W = W
        self.input_ch = input_ch
        # Stored as a tuple so the module never shares (or mutates) a list
        # owned by the caller or by the default argument.
        self.skips = tuple(skips)

        # First layer consumes the encoded input; skip layers consume the
        # hidden activation concatenated with the encoded input.
        layers = [nn.Linear(input_ch, W)]
        for i in range(1, D):
            in_features = W + input_ch if i in self.skips else W
            layers.append(nn.Linear(in_features, W))
        self.pts_linears = nn.ModuleList(layers)

        self.output_linear = nn.Linear(W, 4)  # RGB + sigma

    def forward(self, x):
        """Return ``(rgb, sigma)`` for encoded inputs ``x``.

        ``rgb`` has last dim 3 and is squashed to ``[0, 1]`` by a sigmoid;
        ``sigma`` drops the last dim and is made non-negative by a ReLU.
        """
        h = x
        for i, l in enumerate(self.pts_linears):
            if i in self.skips:
                # Re-inject the input: width becomes input_ch + W.
                h = torch.cat([x, h], dim=-1)
            h = F.relu(l(h))

        out = self.output_linear(h)
        rgb = torch.sigmoid(out[..., :3])
        sigma = F.relu(out[..., 3])
        return rgb, sigma


def render_rays(rays_o, rays_d, model, near, far):
    """Render a batch of rays with ``N_SAMPLES`` evenly spaced depths.

    Args:
        rays_o: Ray origins, shape ``(N, 3)``.
        rays_d: Ray directions, shape ``(N, 3)``.
        model: Callable mapping encoded points to ``(rgb, sigma)``.
        near: Depth of the first sample along each ray.
        far: Depth of the last sample along each ray.

    Returns:
        The composited colour per ray as produced by ``raw2outputs``.
    """
    num_rays = rays_o.shape[0]

    # Same evenly spaced depths for every ray in the batch.
    depths = torch.linspace(near, far, N_SAMPLES, device=DEVICE)
    depths = depths.expand(num_rays, N_SAMPLES)

    # 3-D sample positions: origin + depth * direction, per ray and sample.
    origins = rays_o[:, None]
    directions = rays_d[:, None]
    sample_pts = origins + directions * depths[..., None]

    # Flatten to a list of points, encode, and query the network.
    encoded = posenc(sample_pts.reshape(-1, 3), L_POS)
    colour, density = model(encoded)

    return raw2outputs(
        colour.view(num_rays, N_SAMPLES, 3),
        density.view(num_rays, N_SAMPLES),
        depths,
        rays_d,
        white_bg=WHITE_BG,
    )


@torch.no_grad()
def render_full_image(H, W, focal, pose, model, near, far, chunk=2048):
    """Render a whole ``H x W`` image from ``pose`` without tracking gradients.

    Rays are rendered ``chunk`` at a time and each chunk's result is moved
    to the CPU before the next one is processed.

    Returns:
        A CPU tensor of shape ``(H, W, 3)``.
    """
    origins, directions = (
        rays.reshape(-1, 3) for rays in get_rays(H, W, focal, pose)
    )
    total_rays = origins.shape[0]

    pieces = [
        render_rays(
            origins[start:start + chunk],
            directions[start:start + chunk],
            model, near, far
        ).cpu()
        for start in range(0, total_rays, chunk)
    ]

    return torch.cat(pieces).reshape(H, W, 3)


def save_checkpoint(i, model, optimizer):
    """Write the iteration number, model weights and optimizer state to disk.

    The file is saved in the current working directory and named after
    the zero-padded iteration number.
    """
    checkpoint_path = f"tiny_nerf_{i:05d}.pth"
    state = {
        "iter": i,
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
    }
    torch.save(state, checkpoint_path)


# ---------------------------------------------------------------------------
# Training script: load data, build the model, optimise, validate, checkpoint.
# ---------------------------------------------------------------------------
images, poses, focal = load_data(DATASET_PATH, H_RES, W_RES)

# Depth range sampled along every ray, taken from the closest and farthest
# camera distance to the origin.
# NOTE(review): samples therefore only cover depths between these two
# distances - confirm this range actually spans the scene.
cam_dists = torch.norm(poses[:, :3, 3], dim=1)
near = cam_dists.min().item()
far = cam_dists.max().item()

# The view closest to the origin is used for periodic validation renders.
# NOTE(review): this frame is also part of the training set.
val_idx = torch.argmin(cam_dists)
val_img = images[val_idx]
val_pose = poses[val_idx]
# render_full_image returns a CPU tensor, so keep a CPU copy of the target
# instead of assuming a CUDA device is present.
val_img_cpu = val_img.cpu()

model = TinyNeRF().to(DEVICE)
optimizer = torch.optim.Adam(model.parameters(), lr=LR)

for i in range(1, N_ITERS + 1):
    # One random training view per iteration.
    img_i = np.random.randint(images.shape[0])

    rays_o, rays_d = get_rays(H_RES, W_RES, focal, poses[img_i])

    # Sample pixel rows and columns from their own ranges so non-square
    # resolutions are indexed correctly (ray grids are laid out (H, W)).
    rows = torch.randint(0, H_RES, (BATCH_SIZE,), device=DEVICE)
    cols = torch.randint(0, W_RES, (BATCH_SIZE,), device=DEVICE)

    rgb = render_rays(
        rays_o[rows, cols],
        rays_d[rows, cols],
        model, near, far
    )

    gt = images[img_i][rows, cols]
    loss = F.mse_loss(rgb, gt)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if i % 100 == 0:
        # Detach so logging does not extend the autograd graph.
        psnr = -10. * torch.log10(loss.detach())
        print(f"Iter {i:05d} | PSNR {psnr.item():.2f}")

    if i % VAL_INTERVAL == 0:
        rgb_val = render_full_image(
            H_RES, W_RES, focal, val_pose, model, near, far
        )

        # Both tensors are on the CPU here, whichever device trains.
        val_psnr = -10. * torch.log10(F.mse_loss(rgb_val, val_img_cpu))
        print(f"[VAL] Iter {i:05d} | PSNR {val_psnr.item():.2f}")

        plt.imsave(
            f"{VAL_IMG_DIR}/val_{i:05d}.png",
            rgb_val.clamp(0, 1).cpu().numpy()
        )

    if i % SAVE_INTERVAL == 0:
        save_checkpoint(i, model, optimizer)

print("Tiny-NeRF training complete.")