# In [None]:

import os, json
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

# ---------------------------------------------------------------------------
# Run configuration
# ---------------------------------------------------------------------------

DATASET_PATH = "/kaggle/input/scene"            # folder with images/ + transforms_train.json

# Directory that receives the periodic validation renders.
VAL_IMG_DIR = "val_images"

# Resolution every training image is resized to.
H_RES, W_RES = 800, 800
# Number of frequency bands used by posenc() for positions / view directions.
# NOTE: the NeRF class defaults to input widths 63 and 27, i.e. 3 + 3*2*L for
# L = 10 and L = 4, so changing these requires matching network input sizes.
L_POS, L_DIR = 10, 4
# Composite RGBA inputs and rendered rays over a white background when True.
USE_WHITE_BG = True

N_ITERS = 50000
BATCH_SIZE = 2048          # rays per optimisation step
LR = 5e-4

# Samples per ray for the coarse pass and extra importance samples for the
# fine pass.
N_COARSE = 64
N_FINE = 128

VAL_INTERVAL = 200         # iterations between full validation renders
SAVE_INTERVAL = 1000       # iterations between checkpoints
# The density noise std decays linearly from 1 to 0 over this many iterations.
NOISE_DECAY_ITERS = 2000

# `torch.backends.mps.is_available()` is the long-standing MPS probe;
# `torch.mps.is_available()` is missing on older PyTorch builds and would
# raise AttributeError on machines without CUDA.
DEVICE = torch.device(
    "cuda" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)

print("Using device:", DEVICE)
print("GPU count:", torch.cuda.device_count())


os.makedirs(VAL_IMG_DIR, exist_ok=True)



def posenc(x, L):
    """Frequency-encode ``x`` along its last dimension.

    The result is ``x`` followed by ``sin(2**k * pi * x)`` and
    ``cos(2**k * pi * x)`` for ``k = 0 .. L-1``, concatenated on the last
    axis, so the last dimension grows from ``D`` to ``D * (1 + 2 * L)``.
    """
    encoded = [x]
    for octave in range(L):
        angle = (2 ** octave) * np.pi * x
        encoded.extend((torch.sin(angle), torch.cos(angle)))
    return torch.cat(encoded, dim=-1)



def load_data(basedir, H, W):
    """Load images, camera poses and focal length from ``transforms_train.json``.

    Args:
        basedir: directory containing ``transforms_train.json``; each frame's
            ``file_path`` is resolved relative to it.
        H, W: height and width every image is resized to.

    Returns:
        images: float32 tensor on ``DEVICE`` with values in [0, 1], one
            ``(H, W, 3)`` image per loaded frame. RGBA inputs are composited
            over white (``USE_WHITE_BG``) or black.
        poses: float32 tensor on ``DEVICE`` stacking each frame's
            ``transform_matrix`` (presumably 4x4 camera-to-world — confirm
            against the dataset), with Y/Z axes flipped, camera centres
            re-centred on their mean and scaled to a mean distance of 2.
        focal: focal length in pixels derived from ``camera_angle_x`` and ``W``.

    Raises:
        FileNotFoundError: if none of the listed frames exists on disk.
    """
    with open(os.path.join(basedir, "transforms_train.json")) as f:
        meta = json.load(f)

    images, poses = [], []

    for frame in meta["frames"]:
        fp = frame["file_path"]
        # Strip only a leading "./"; a blanket replace would also corrupt
        # paths such as "../images/x.png".
        if fp.startswith("./"):
            fp = fp[2:]

        img_path = os.path.join(basedir, fp)
        # Some transforms files omit the extension; fall back to ".png" only
        # when the path as written does not exist.
        if not os.path.exists(img_path) and os.path.exists(img_path + ".png"):
            img_path += ".png"
        if not os.path.exists(img_path):
            print("skipping missing image:", img_path)
            continue

        with Image.open(img_path) as im:
            # Normalise palette / grayscale inputs so the array always has
            # 3 or 4 channels.
            if im.mode not in ("RGB", "RGBA"):
                im = im.convert("RGBA" if "A" in im.getbands() else "RGB")
            img = np.array(im.resize((W, H), Image.LANCZOS)).astype(np.float32) / 255.0

        if img.shape[-1] == 4:
            alpha = img[..., 3:4]
            if USE_WHITE_BG:
                img = img[..., :3] * alpha + (1 - alpha)
            else:
                img = img[..., :3] * alpha

        images.append(img)
        poses.append(np.array(frame["transform_matrix"], dtype=np.float32))

    if not images:
        # Fail here with a clear message instead of a cryptic indexing error
        # on an empty pose tensor further down.
        raise FileNotFoundError(
            f"none of the {len(meta['frames'])} frames listed in "
            f"transforms_train.json were found under {basedir!r}"
        )

    images = torch.tensor(np.array(images), device=DEVICE)
    poses = torch.tensor(np.array(poses), device=DEVICE)

    print("poses shape:", poses.shape)

    # Negate the camera Y and Z axes (rotation columns 1 and 2).
    # NOTE(review): get_rays() already builds directions with -y / -z, so
    # confirm the dataset's camera convention really needs this extra flip.
    poses[:, :3, 1] *= -1   # flip Y axis
    poses[:, :3, 2] *= -1   # flip Z axis

    # Center + scale (LLFF style): move the mean camera position to the
    # origin, then rescale so the mean camera distance is 2.
    poses[:, :3, 3] -= poses[:, :3, 3].mean(0)
    mean_radius = torch.norm(poses[:, :3, 3], dim=1).mean()
    if mean_radius > 0:
        # Skipped for co-located cameras, where the scale would be infinite
        # and turn every translation into NaN.
        poses[:, :3, 3] *= 2.0 / mean_radius

    focal = 0.5 * W / np.tan(0.5 * meta["camera_angle_x"])

    return images, poses, focal

def get_rays(H, W, focal, c2w):
    """Build one ray per pixel for a pinhole camera.

    Args:
        H, W: image height and width in pixels.
        focal: focal length in pixels.
        c2w: camera-to-world matrix; only ``c2w[:3, :3]`` (rotation) and
            ``c2w[:3, 3]`` (camera origin) are read.

    Returns:
        ``(rays_o, rays_d)``, both of shape ``(H, W, 3)``. ``rays_o`` repeats
        the camera origin; ``rays_d`` holds the rotated, un-normalised
        directions ``(x, -y, -1)`` with x, y measured from the image centre
        in units of ``focal``.
    """
    # Take device and dtype from the pose so the function does not depend on
    # a module-level device, and build the grid directly as floats.
    device, dtype = c2w.device, c2w.dtype
    cols, rows = torch.meshgrid(
        torch.arange(W, device=device, dtype=dtype),
        torch.arange(H, device=device, dtype=dtype),
        indexing="xy",
    )
    dirs = torch.stack([
        (cols - W * 0.5) / focal,
        -(rows - H * 0.5) / focal,
        -torch.ones_like(cols),
    ], -1)

    # Row-wise dot product with the rotation, i.e. R @ dir for every pixel.
    rays_d = (dirs[..., None, :] * c2w[:3, :3]).sum(-1)
    rays_o = c2w[:3, 3].expand_as(rays_d)
    return rays_o, rays_d


def sample_pdf(bins, weights, N):
    """Draw ``N`` samples per row by inverse-transform sampling a histogram.

    Args:
        bins: ``(..., M)`` bin edges along the last dimension.
        weights: ``(..., M - 1)`` non-negative weight of each bin.
        N: number of samples to draw per row.

    Returns:
        ``(..., N)`` tensor of sample positions lying within the range
        spanned by ``bins``.
    """
    # Small constant keeps the pdf well defined when all weights are zero.
    weights = weights + 1e-5
    pdf = weights / torch.sum(weights, -1, keepdim=True)
    cdf = torch.cumsum(pdf, -1)
    cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], -1)

    # Draw on the same device / dtype as the inputs rather than a global
    # device, so the function works for any tensor placement and precision.
    u = torch.rand(list(cdf.shape[:-1]) + [N], device=cdf.device, dtype=cdf.dtype)
    inds = torch.searchsorted(cdf, u, right=True)

    # Indices of the cdf entries bracketing each draw, clamped to valid range.
    below = torch.clamp(inds - 1, min=0)
    above = torch.clamp(inds, max=cdf.shape[-1] - 1)
    inds = torch.stack([below, above], -1)

    cdf_g = torch.gather(cdf.unsqueeze(-2).expand(*inds.shape[:-1], cdf.shape[-1]), -1, inds)
    bins_g = torch.gather(bins.unsqueeze(-2).expand(*inds.shape[:-1], bins.shape[-1]), -1, inds)

    # Linear interpolation inside the selected bin; the epsilon avoids a
    # division by zero when both bracketing cdf values coincide.
    t = (u - cdf_g[..., 0]) / (cdf_g[..., 1] - cdf_g[..., 0] + 1e-5)
    return bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0])

def raw2outputs(raw, z_vals, rays_d, noise_std):
    """Alpha-composite per-sample network outputs into one colour per ray.

    Args:
        raw: network output whose last dimension holds three colour logits
            followed by one density logit.
        z_vals: sample depths along each ray (last dimension = samples).
        rays_d: ray directions; their norm converts depth gaps to distances.
        noise_std: when > 0, Gaussian noise with this std is added to the
            density logit before the softplus.

    Returns:
        ``(rgb_map, weights)``: the composited colour per ray and the
        compositing weight of every sample. With ``USE_WHITE_BG`` the
        remaining ``1 - sum(weights)`` is added to the colour as white.
    """
    # Distance covered by each sample; the last one gets a very large gap.
    gaps = z_vals[..., 1:] - z_vals[..., :-1]
    tail = 1e10 * torch.ones_like(gaps[..., :1])
    ray_len = torch.norm(rays_d[..., None, :], dim=-1)
    dists = torch.cat([gaps, tail], -1) * ray_len

    colour = torch.sigmoid(raw[..., :3])

    density_in = raw[..., 3]
    if noise_std > 0:
        density_in = density_in + torch.randn_like(density_in) * noise_std
    sigma = F.softplus(density_in)

    alpha = 1. - torch.exp(-sigma * dists)
    # Exclusive cumulative product: light surviving up to (not including)
    # each sample.
    survive = torch.cat([torch.ones_like(alpha[..., :1]), 1. - alpha + 1e-10], -1)
    transmittance = torch.cumprod(survive, -1)[..., :-1]

    weights = alpha * transmittance
    rgb_map = (weights[..., None] * colour).sum(-2)

    if USE_WHITE_BG:
        rgb_map = rgb_map + (1 - weights.sum(-1, keepdim=True))

    return rgb_map, weights


class NeRF(nn.Module):
    """MLP mapping encoded position + view direction to (rgb, sigma) logits.

    Args:
        W: width of every hidden layer.
        in_pos: size of the encoded position input. The default 63 equals
            ``3 + 3 * 2 * 10``, i.e. posenc() with 10 frequency bands.
        in_dir: size of the encoded view-direction input. The default 27
            equals ``3 + 3 * 2 * 4``, i.e. posenc() with 4 frequency bands.

    The parameter layout (``fc``, ``sigma``, ``rgb``) is independent of how
    the input sizes are supplied, so state dicts saved with the default
    sizes keep loading.
    """

    def __init__(self, W=256, in_pos=63, in_dir=27):
        super().__init__()
        # Remembered so forward() can split the concatenated input.
        self.in_pos = in_pos
        self.in_dir = in_dir
        # Eight-layer trunk that sees only the encoded position.
        self.fc = nn.ModuleList(
            [nn.Linear(in_pos, W)] + [nn.Linear(W, W) for _ in range(7)]
        )
        # Density is view independent: predicted from the trunk alone.
        self.sigma = nn.Linear(W, 1)
        # Colour head additionally receives the encoded view direction.
        self.rgb = nn.Sequential(
            nn.Linear(W + in_dir, W),
            nn.ReLU(),
            nn.Linear(W, 3)
        )

    def forward(self, x):
        """Return ``(..., 4)`` raw outputs: three colour logits, then density.

        ``x`` must carry ``in_pos + in_dir`` features on its last dimension,
        position encoding first.
        """
        pts, views = torch.split(x, [self.in_pos, self.in_dir], -1)
        h = pts
        for layer in self.fc:
            h = F.relu(layer(h))
        sigma = self.sigma(h)
        rgb = self.rgb(torch.cat([h, views], -1))
        return torch.cat([rgb, sigma], -1)




def render_rays(rays_o, rays_d, model_c, model_f, near, far, noise_std):
    """Render a batch of rays with coarse-to-fine (hierarchical) sampling.

    Args:
        rays_o: ``(N, 3)`` ray origins.
        rays_d: ``(N, 3)`` ray directions (need not be unit length).
        model_c: network evaluated at the stratified coarse samples.
        model_f: network evaluated at the coarse plus importance samples.
        near, far: depth range sampled along every ray.
        noise_std: std of the noise raw2outputs() adds to the density logit.

    Returns:
        ``(N, 3)`` colours from the fine pass.

    NOTE(review): the coarse colour is computed but not returned and the
    importance samples are not detached, so ``model_c`` only receives
    gradients through the fine sample positions. The reference NeRF also
    supervises the coarse colour — confirm this is intended.
    """
    N = rays_o.shape[0]
    device = rays_o.device

    # Stratified sampling: split [near, far] into N_COARSE bins bounded by
    # the midpoints of an even grid and draw one uniform depth per bin.
    # This yields exactly N_COARSE depths that are sorted and stay inside
    # [near, far]; jittering the midpoints and appending `far` could place a
    # sample beyond `far`, producing negative distances in raw2outputs().
    edges = torch.linspace(near, far, N_COARSE, device=device).expand(N, N_COARSE)
    mids = 0.5 * (edges[:, 1:] + edges[:, :-1])
    lower = torch.cat([edges[:, :1], mids], -1)
    upper = torch.cat([mids, edges[:, -1:]], -1)
    z_vals = lower + (upper - lower) * torch.rand_like(lower)

    pts = rays_o[:, None] + rays_d[:, None] * z_vals[..., None]
    viewdirs = rays_d / torch.norm(rays_d, dim=-1, keepdim=True)
    viewdirs = viewdirs[:, None].expand_as(pts)

    raw_c = model_c(torch.cat([
        posenc(pts.reshape(-1, 3), L_POS),
        posenc(viewdirs.reshape(-1, 3), L_DIR)
    ], -1)).reshape(N, -1, 4)

    rgb_c, weights = raw2outputs(raw_c, z_vals, rays_d, noise_std)

    # Importance sampling: draw N_FINE extra depths from the piecewise
    # constant pdf defined by the interior coarse weights.
    z_mid = 0.5 * (z_vals[:, 1:] + z_vals[:, :-1])
    z_fine = sample_pdf(z_mid, weights[:, 1:-1], N_FINE)
    z_vals = torch.sort(torch.cat([z_vals, z_fine], -1), -1)[0]

    pts = rays_o[:, None] + rays_d[:, None] * z_vals[..., None]
    viewdirs = viewdirs[:, :1].expand_as(pts)

    raw_f = model_f(torch.cat([
        posenc(pts.reshape(-1, 3), L_POS),
        posenc(viewdirs.reshape(-1, 3), L_DIR)
    ], -1)).reshape(N, -1, 4)

    rgb_f, _ = raw2outputs(raw_f, z_vals, rays_d, noise_std)
    return rgb_f

@torch.no_grad()
def render_full_image(H, W, focal, pose, model_c, model_f, near, far):
    """Render a complete ``(H, W, 3)`` image for ``pose`` without gradients.

    All pixel rays are generated with get_rays(), flattened, and pushed
    through render_rays() in slices of 4096 rays with density noise
    disabled (``noise_std = 0.0``).
    """
    step = 4096
    origins, directions = (r.reshape(-1, 3) for r in get_rays(H, W, focal, pose))

    pieces = [
        render_rays(
            origins[start:start + step],
            directions[start:start + step],
            model_c, model_f,
            near, far, 0.0,
        )
        for start in range(0, origins.shape[0], step)
    ]

    return torch.cat(pieces).reshape(H, W, 3)

def save_checkpoint(iteration, model_c, model_f, optimizer):
    """Write both networks and the optimizer state to ``ckpt_<iteration>.pth``.

    Models wrapped in ``nn.DataParallel`` are unwrapped first, so the saved
    state-dict keys carry no ``module.`` prefix. The file is written to the
    current working directory.
    """
    def _plain_state(net):
        # Reach through the DataParallel wrapper when present.
        core = net.module if isinstance(net, nn.DataParallel) else net
        return core.state_dict()

    payload = {
        "iter": iteration,
        "model_c": _plain_state(model_c),
        "model_f": _plain_state(model_f),
        "optimizer": optimizer.state_dict(),
    }
    torch.save(payload, f"ckpt_{iteration:05d}.pth")


# ---------------------------------------------------------------------------
# Data, validation view and depth bounds
# ---------------------------------------------------------------------------
images, poses, focal = load_data(DATASET_PATH, H_RES, W_RES)

# The camera closest to the origin is used as the validation view.
# NOTE: it stays in the pool the training loop samples from, so the
# validation PSNR is measured on a view that is also trained on.
cam_dists = torch.norm(poses[:, :3, 3], dim=1)
val_idx = torch.argmin(cam_dists)

val_img = images[val_idx]
val_pose = poses[val_idx]

# Depth range along each ray, derived from the camera distances.
near = cam_dists.min().item() * 0.8
far = cam_dists.max().item() * 1.2

print(near)
print(far)

# ---------------------------------------------------------------------------
# Models and optimizer
# ---------------------------------------------------------------------------
model_c = NeRF().to(DEVICE)
model_f = NeRF().to(DEVICE)

if torch.cuda.device_count() > 1:
    model_c = nn.DataParallel(model_c)
    model_f = nn.DataParallel(model_f)

optimizer = torch.optim.Adam(
    list(model_c.parameters()) + list(model_f.parameters()), lr=LR
)

# ---------------------------------------------------------------------------
# Training loop
# ---------------------------------------------------------------------------
for i in range(1, N_ITERS + 1):
    # Density noise decays linearly to zero over NOISE_DECAY_ITERS steps.
    noise_std = max(0.0, 1.0 - i / NOISE_DECAY_ITERS)

    img_i = np.random.randint(images.shape[0])
    rays_o, rays_d = get_rays(H_RES, W_RES, focal, poses[img_i])

    # Sample rows and columns from their own ranges; drawing both from
    # [0, H_RES) is only correct for square images.
    rows = torch.randint(0, H_RES, (BATCH_SIZE,), device=DEVICE)
    cols = torch.randint(0, W_RES, (BATCH_SIZE,), device=DEVICE)

    rgb = render_rays(
        rays_o[rows, cols],
        rays_d[rows, cols],
        model_c, model_f,
        near, far, noise_std
    )

    gt = images[img_i][rows, cols]
    loss = F.mse_loss(rgb, gt)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if i % 100 == 0:
        # Detach so logging does not touch the autograd graph.
        psnr = -10 * torch.log10(loss.detach())
        print(f"Iter {i:05d} | PSNR {psnr.item():.2f}")

    if i % VAL_INTERVAL == 0:
        rgb_val = render_full_image(
            H_RES, W_RES, focal, val_pose, model_c, model_f, near, far
        )
        val_psnr = -10 * torch.log10(F.mse_loss(rgb_val, val_img))
        print(f"[VAL] Iter {i:05d} | PSNR {val_psnr.item():.2f}")

        plt.imsave(
            f"{VAL_IMG_DIR}/val_{i:05d}.png",
            rgb_val.clamp(0, 1).cpu().numpy()
        )

    if i % SAVE_INTERVAL == 0:
        save_checkpoint(
            iteration=i,
            model_c=model_c,
            model_f=model_f,
            optimizer=optimizer,
        )

print("Training complete.")