# Step 1: Setting up the Colab environment

In [1]:
!pip install torch torchvision matplotlib numpy tqdm



# Step 2: Importing necessary libraries and setting up the environment

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Step 3: Implementing the NeRF model architecture

In [2]:
class PositionalEncoding(nn.Module):
    def __init__(self, num_frequencies, d_in=3):
        super().__init__()
        self.num_frequencies = num_frequencies
        self.d_in = d_in

    def forward(self, x):
        x_proj = (2.0 ** torch.arange(self.num_frequencies, device=x.device)).view(1, -1) * x.unsqueeze(-1)
        return torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1).flatten(start_dim=-2)

class NeRF(nn.Module):
    def __init__(self, d_in=3, d_hidden=256, d_out=4, num_layers=8, skip_layers=[4], num_frequencies=10):
        super().__init__()
        self.positional_encoding = PositionalEncoding(num_frequencies, d_in)
        d_in_encoded = d_in * num_frequencies * 2

        self.layers = nn.ModuleList()
        self.layers.append(nn.Linear(d_in_encoded, d_hidden))

        for i in range(1, num_layers):
            if i in skip_layers:
                self.layers.append(nn.Linear(d_hidden + d_in_encoded, d_hidden))
            else:
                self.layers.append(nn.Linear(d_hidden, d_hidden))

        self.output_layer = nn.Linear(d_hidden, d_out)
        self.skip_layers = skip_layers

    def forward(self, x):
        x_encoded = self.positional_encoding(x)
        h = x_encoded

        for i, layer in enumerate(self.layers):
            if i in self.skip_layers:
                h = torch.cat([h, x_encoded], dim=-1)
            h = F.relu(layer(h))

        output = self.output_layer(h)
        rgb = torch.sigmoid(output[..., :3])
        sigma = F.relu(output[..., 3])
        return rgb, sigma

# Step 4: Loading and preprocessing 2D image data

In [3]:
def generate_synthetic_data(num_images=100, image_size=32):
    images = []
    poses = []

    for _ in range(num_images):
        # Generate a simple sphere
        x = np.linspace(-1, 1, image_size)
        y = np.linspace(-1, 1, image_size)
        xx, yy = np.meshgrid(x, y)
        zz = np.sqrt(1 - xx**2 - yy**2)

        # Add some noise
        zz += np.random.normal(0, 0.1, zz.shape)

        # Normalize to [0, 1]
        zz = (zz - zz.min()) / (zz.max() - zz.min())

        images.append(zz)

        # Generate random camera poses
        pose = np.eye(4)
        pose[:3, 3] = np.random.uniform(-1, 1, 3)
        poses.append(pose)

    return np.array(images), np.array(poses)

# Generate synthetic data
images, poses = generate_synthetic_data()

# Convert to PyTorch tensors
images = torch.from_numpy(images).float().to(device)
poses = torch.from_numpy(poses).float().to(device)

  zz = np.sqrt(1 - xx**2 - yy**2)


# Step 5: Implementing the forward pass and ray sampling

In [4]:
def get_rays(H, W, focal, c2w):
    i, j = torch.meshgrid(torch.linspace(0, W-1, W), torch.linspace(0, H-1, H))
    i = i.t()
    j = j.t()
    dirs = torch.stack([(i-W*.5)/focal, -(j-H*.5)/focal, -torch.ones_like(i)], -1)
    rays_d = torch.sum(dirs[..., np.newaxis, :] * c2w[:3,:3], -1)
    rays_o = c2w[:3,-1].expand(rays_d.shape)
    return rays_o, rays_d

def sample_pdf(bins, weights, N_samples, det=False):
    weights = weights + 1e-5
    pdf = weights / torch.sum(weights, -1, keepdim=True)
    cdf = torch.cumsum(pdf, -1)
    cdf = torch.cat([torch.zeros_like(cdf[...,:1]), cdf], -1)

    if det:
        u = torch.linspace(0., 1., steps=N_samples)
        u = u.expand(list(cdf.shape[:-1]) + [N_samples])
    else:
        u = torch.rand(list(cdf.shape[:-1]) + [N_samples])

    u = u.to(weights.device)
    inds = torch.searchsorted(cdf, u, right=True)
    below = torch.max(torch.zeros_like(inds-1), inds-1)
    above = torch.min((cdf.shape[-1]-1) * torch.ones_like(inds), inds)
    inds_g = torch.stack([below, above], -1)

    matched_shape = [inds_g.shape[0], inds_g.shape[1], cdf.shape[-1]]
    cdf_g = torch.gather(cdf.unsqueeze(1).expand(matched_shape), 2, inds_g)
    bins_g = torch.gather(bins.unsqueeze(1).expand(matched_shape), 2, inds_g)

    denom = (cdf_g[...,1]-cdf_g[...,0])
    denom = torch.where(denom<1e-5, torch.ones_like(denom), denom)
    t = (u-cdf_g[...,0])/denom
    samples = bins_g[...,0] + t * (bins_g[...,1]-bins_g[...,0])

    return samples

# Step 6: Training code

In [5]:
def render_rays(ray_batch, network_fn, N_samples=64, perturb=0, N_importance=0, raw_noise_std=0):
    N_rays = ray_batch.shape[0]
    rays_o, rays_d = ray_batch[:,0:3], ray_batch[:,3:6]
    near, far = ray_batch[:,6:7], ray_batch[:,7:8]

    t_vals = torch.linspace(0., 1., steps=N_samples)
    z_vals = near * (1.-t_vals) + far * (t_vals)

    if perturb > 0.:
        mids = .5 * (z_vals[...,1:] + z_vals[...,:-1])
        upper = torch.cat([mids, z_vals[...,-1:]], -1)
        lower = torch.cat([z_vals[...,:1], mids], -1)
        t_rand = torch.rand(z_vals.shape)
        z_vals = lower + (upper - lower) * t_rand

    pts = rays_o[...,None,:] + rays_d[...,None,:] * z_vals[
    ...,:,None]

    raw = network_fn(pts)
    rgb, sigma = raw[...,:-1], raw[...,-1]

    dists = z_vals[...,1:] - z_vals[...,:-1]
    dists = torch.cat([dists, torch.Tensor([1e10]).expand(dists[...,:1].shape)], -1)

    alpha = 1. - torch.exp(-sigma * dists)
    weights = alpha * torch.cumprod(torch.cat([torch.ones((alpha.shape[0], 1)), 1.-alpha + 1e-10], -1), -1)[:, :-1]

    rgb_map = torch.sum(weights[...,None] * rgb, -2)
    depth_map = torch.sum(weights * z_vals, -1)
    acc_map = torch.sum(weights, -1)

    if N_importance > 0:
        z_vals_mid = .5 * (z_vals[...,1:] + z_vals[...,:-1])
        z_samples = sample_pdf(z_vals_mid, weights[...,1:-1], N_importance, det=(perturb==0.))
        z_samples = z_samples.detach()

        z_vals, _ = torch.sort(torch.cat([z_vals, z_samples], -1), -1)
        pts = rays_o[...,None,:] + rays_d[...,None,:] * z_vals[...,:,None]

        raw = network_fn(pts)
        rgb, sigma = raw[...,:-1], raw[...,-1]

        dists = z_vals[...,1:] - z_vals[...,:-1]
        dists = torch.cat([dists, torch.Tensor([1e10]).expand(dists[...,:1].shape)], -1)

        alpha = 1. - torch.exp(-sigma * dists)
        weights = alpha * torch.cumprod(torch.cat([torch.ones((alpha.shape[0], 1)), 1.-alpha + 1e-10], -1), -1)[:, :-1]

        rgb_map = torch.sum(weights[...,None] * rgb, -2)
        depth_map = torch.sum(weights * z_vals, -1)
        acc_map = torch.sum(weights, -1)

    return rgb_map, depth_map, acc_map

# Step 7: Evaluation and visualization

In [43]:
def train(images, poses, H, W, focal, num_epochs=100, batch_size=1024, lr=5e-4):
    nerf = NeRF().to(device)
    optimizer = torch.optim.Adam(nerf.parameters(), lr=lr)

    for epoch in tqdm(range(num_epochs)):
        for img_i in range(images.shape[0]):
            img = images[img_i]
            pose = poses[img_i]

            rays_o, rays_d = get_rays(H, W, focal, pose)
            rays_o = rays_o.reshape(-1, 3)
            rays_d = rays_d.reshape(-1, 3)

            select_inds = np.random.choice(rays_o.shape[0], size=[batch_size], replace=False)
            rays_o = rays_o[select_inds]
            rays_d = rays_d[select_inds]

            target_s = img[select_inds]

            rays_o = rays_o.to(device)
            rays_d = rays_d.to(device)
            target_s = target_s.to(device)

            near = 0. * torch.ones_like(rays_d[...,:1])
            far = 1. * torch.ones_like(rays_d[...,:1])

            rays = torch.cat([rays_o, rays_d, near, far], -1)

            rgb, depth, acc = render_rays(rays, nerf)

            optimizer.zero_grad()
            loss = F.mse_loss(rgb, target_s)
            loss.backward()
            optimizer.step()

        if (epoch + 1) % 10 == 0:
            print(f"Epoch {epoch + 1}, Loss: {loss.item()}")

    return nerf


In [44]:
# Train the model
H, W = images.shape[1], images.shape[2]
focal = 0.5 * W / np.tan(0.5 * np.pi / 3)  # assume 60° FOV
trained_nerf = train(images, poses, H, W, focal)

  0%|                                                   | 0/100 [00:00<?, ?it/s]


IndexError: index 546 is out of bounds for dimension 0 with size 32

In [8]:
def render_novel_view(nerf, H, W, focal, pose):
    rays_o, rays_d = get_rays(H, W, focal, pose)
    rays_o = rays_o.reshape(-1, 3)
    rays_d = rays_d.reshape(-1, 3)

    rays_o = rays_o.to(device)
    rays_d = rays_d.to(device)

    near = 0. * torch.ones_like(rays_d[...,:1])
    far = 1. * torch.ones_like(rays_d[...,:1])

    rays = torch.cat([rays_o, rays_d, near, far], -1)

    rgb, depth, acc = render_rays(rays, nerf)

    rgb = rgb.reshape(H, W, 3).cpu().detach().numpy()
    depth = depth.reshape(H, W).cpu().detach().numpy()

    return rgb, depth

# Generate a novel view
novel_pose = torch.eye(4).to(device)
novel_pose[:3, 3] = torch.tensor([0.5, 0.5, -2.0]).to(device)

rgb, depth = render_novel_view(trained_nerf, H, W, focal, novel_pose)

# Visualize the results
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6))
ax1.imshow(rgb)
ax1.set_title("RGB")
ax2.imshow(depth, cmap='viridis')
ax2.set_title("Depth")
plt.show()

NameError: name 'trained_nerf' is not defined