In [None]:
# Make the project's source folder importable by putting it first on the module search path
import sys
sys.path[:0] = ['../src/']

## Imports

In [None]:
import numpy as np
import torch
from torch import nn, Tensor
from torch.utils.data import TensorDataset, DataLoader, WeightedRandomSampler
import torch.nn.functional as F
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm

# Pick the fastest available accelerator: CUDA first, then Apple-silicon MPS, otherwise CPU.
# torch.backends.mps.is_available() is the long-standing public check (torch.mps.is_available is not
# present in many PyTorch releases).
if torch.cuda.is_available():
    COMPUTE_DEVICE = torch.device('cuda:0')
elif torch.backends.mps.is_available():
    COMPUTE_DEVICE = torch.device('mps')
else:
    COMPUTE_DEVICE = torch.device('cpu')
print(f"{COMPUTE_DEVICE=}")

## Example data

In [None]:
# Load the tiny NeRF dataset: stacked images, camera-to-world poses and a shared focal length
data = np.load("../data/tiny_nerf_data.npz")
images = data["images"]
c2ws = data["poses"]
focal = data["focal"]

print("Shapes:", f"{images.shape=}", f"{c2ws.shape=}", f"{focal.shape=}", sep='\n  ')

# Peek at one example view and its pose
plt.imshow(images[2])
plt.title("Image at index 2")

print("C2W transform at index 2: \n", c2ws[2])

print(f"Focal length: {focal:.4f}")

# Everything downstream works on torch tensors
images, c2ws, focal = (torch.from_numpy(array) for array in (images, c2ws, focal))

## Utility functions

Helper functions below are meant to be exported into `ROOT_DIR/src/utils/`.

In [None]:
# Lightning conversion proposal: STANDALONE
def create_rays(height: int, width: int, intrinsic: Tensor, c2w: Tensor):
    """Generate one world-space ray per pixel of a pinhole camera.

    Camera convention encoded below: x to the right, y up, camera looking along -z.

    Args:
        height: image height in pixels.
        width: image width in pixels.
        intrinsic: 3x3 intrinsic matrix; fx = intrinsic[0, 0], fy = intrinsic[1, 1] and the principal
            point is (intrinsic[0, 2], intrinsic[1, 2]).
        c2w: camera-to-world matrix (3x4 or 4x4); the top-left 3x3 block is used as the rotation and the
            last column as the camera position.

    Returns:
        (ray_origins, ray_directions), both of shape (height, width, 3), on ``c2w``'s device and in
        ``c2w``'s floating dtype (float32 otherwise). Directions have unit length.
    """
    device = c2w.device
    dtype = c2w.dtype if c2w.is_floating_point() else torch.float32

    focal_x, focal_y = intrinsic[0, 0], intrinsic[1, 1]
    # cx and cy handle the misalignment of the principal point with the centre of the image
    cx, cy = intrinsic[0, 2], intrinsic[1, 2]

    # Pixel grid built directly on c2w's device: i indexes columns (x), j indexes rows (y);
    # with indexing='xy' both have shape (height, width).
    i, j = torch.meshgrid(
        torch.arange(width, dtype=dtype, device=device),
        torch.arange(height, dtype=dtype, device=device),
        indexing='xy',
    )
    directions = torch.stack((
        (i - cx) / focal_x,
        -(j - cy) / focal_y,   # image rows grow downwards while camera y grows upwards
        -torch.ones_like(i),   # -1 since the ray is cast away from the camera (along -z)
    ), -1)

    # Rotate ray directions into world space; origins are the camera position broadcast to every pixel
    rotation = c2w[:3, :3].to(dtype)
    ray_directions = F.normalize(directions @ rotation.T, p=2.0, dim=-1)
    ray_origins = torch.broadcast_to(c2w[:3, -1].to(dtype), ray_directions.shape)

    return ray_origins, ray_directions


# Test on real data
ex_index = 2
ex_img = images[ex_index]
ex_intr = torch.tensor([
    [focal.item(), 0, ex_img.shape[1] // 2],
    [0, focal.item(), ex_img.shape[0] // 2],
    [0, 0, 1],
], dtype=torch.float32)

origins, directions = create_rays(ex_img.shape[0], ex_img.shape[1], ex_intr, c2ws[ex_index])
assert torch.allclose(directions.norm(dim=-1), torch.ones(directions.shape[:-1]))

# Sanity check on a tiny synthetic camera (previously kept as a dead string literal):
# identity rotation, camera placed at (0, 0, 1), principal point at the centre pixel of a 5x5 image.
toy_intr = torch.tensor([
    [4, 0, 5 // 2],
    [0, 4, 5 // 2],
    [0, 0, 1],
], dtype=torch.float32)

toy_c2w = torch.tensor([
    [1, 0, 0, 0],
    [0, 1, 0, 0],
    [0, 0, 1, 1],
    [0, 0, 0, 1],
], dtype=torch.float32)

toy_origins, toy_directions = create_rays(5, 5, toy_intr, toy_c2w)
assert toy_origins.shape == toy_directions.shape == (5, 5, 3)
assert torch.allclose(toy_origins, torch.tensor([0.0, 0.0, 1.0]).expand_as(toy_origins))
assert torch.allclose(toy_directions[2, 2], torch.tensor([0.0, 0.0, -1.0]))  # centre pixel looks down -z

print(f"{origins.shape=} | {directions.shape=}")

In [None]:
# Lightning conversion proposal: DATAMODULE
def compute_near_far_planes(c2ws: Tensor, scene_bounds_min=(-1.0, -1.0, -1.0), scene_bounds_max=(1.0, 1.0, 1.0)):
    """Estimate near/far plane distances enclosing an axis-aligned scene box from every camera.

    For each camera, the distance from its centre (last column of the camera-to-world matrix) to each of
    the 8 box corners is computed. The closest corner over ALL cameras, shrunk by 10%, gives ``near``. The
    farthest corner, grown by 10%, gives ``far``.

    Note: a box face can be closer to a camera than its closest corner; the 10% margin only partly covers
    this.

    Args:
        c2ws: stack of camera-to-world matrices, indexed as (N, rows, 4) with rows >= 3.
        scene_bounds_min: minimum (x, y, z) corner of the scene box (default -1 on every axis).
        scene_bounds_max: maximum (x, y, z) corner of the scene box (default +1 on every axis).

    Returns:
        (near, far) as Python floats.
    """
    lo = torch.as_tensor(scene_bounds_min, dtype=torch.float32)
    hi = torch.as_tensor(scene_bounds_max, dtype=torch.float32)
    # All 8 corners: every min/max combination per axis, shape (8, 3)
    corners = torch.cartesian_prod(*torch.stack([lo, hi], dim=1))

    camera_centres = c2ws[:, :3, -1].to(torch.float32)  # (N, 3)
    # Camera-to-corner Euclidean distances, shape (N, 8). Rotations preserve length, so these equal the
    # corner norms expressed in each camera's coordinate frame.
    distances = torch.linalg.norm(corners[None, :, :] - camera_centres[:, None, :], dim=-1)

    near_plane = distances.min() * 0.9  # Slightly smaller than the closest point over all cameras
    far_plane = distances.max() * 1.1  # Slightly larger than the farthest point over all cameras

    return near_plane.item(), far_plane.item()

# Near/far planes shared by every ray of the dataset
near, far = compute_near_far_planes(c2ws=c2ws)
print(f"Near and far planes: {near=} | {far=}")

In [None]:
# Lightning conversion proposal: STANDALONE
@torch.no_grad()
def sobel_filter(images: Tensor):
    """Per-pixel gradient magnitude of a batch of images.

    Despite the name, the kernels are the Scharr operator (3-10-3 weights). Scharr has better rotational
    symmetry than the classic Sobel-Feldman 1-2-1 kernels.

    Args:
        images: tensor indexed as (N, H, W, C). Channels are averaged into one grey channel, computed in
            float32, before filtering.

    Returns:
        float32 tensor of shape (N, H, W) holding sqrt(gx**2 + gy**2). Borders are zero-padded, so pixels on
        the image boundary respond to the jump to zero.
    """
    device = images.device
    gx = torch.tensor([
        [3.0, 0.0, -3.0],
        [10.0, 0.0, -10.0],
        [3.0, 0.0, -3.0],
    ], dtype=torch.float32, device=device)
    gy = gx.T  # the vertical kernel is the transpose of the horizontal one
    kernels = torch.stack([gx, gy], 0).unsqueeze(1)  # (out=2, in=1, 3, 3)

    grey = images.to(torch.float32).mean(dim=-1).unsqueeze(1)  # (N, 1, H, W)
    gradients = F.conv2d(grey, kernels, padding=1)  # (N, 2, H, W), zero padding keeps H x W
    return gradients.pow(2).sum(dim=1).sqrt()

edges = sobel_filter(images)
# Distribution of non-zero edge strengths over the whole dataset
plt.hist(edges[edges != 0], bins=30)
plt.show()

plt.imshow(edges[5], cmap="gray")

In [None]:
# Lightning conversion proposal: DATAMODULE
def create_nerf_data(images: Tensor, c2ws: Tensor, focal: Tensor, weight_epsilon: float = 0.33):
    """Flatten a set of posed images into per-pixel ray training data.

    Args:
        images: stacked images indexed as (N, H, W, C); all images share one size.
        c2ws: one camera-to-world matrix per image.
        focal: focal length in pixels (0-d tensor or Python number), shared by all images.
        weight_epsilon: constant added to every pixel's edge strength, so flat regions keep a non-zero
            sampling weight.

    Returns:
        (origins, directions, colors, pixel_weights), all flattened in the same (image, row, column) order:
        origins/directions (N*H*W, 3), colors (N*H*W, C), pixel_weights (N*H*W,).

    Raises:
        ValueError: if the number of images and poses differ (zip would silently truncate otherwise).
    """
    if len(images) != len(c2ws):
        raise ValueError(f"got {len(images)} images but {len(c2ws)} camera poses")

    height, width = images.shape[1], images.shape[2]
    focal = float(focal)
    # Every image has the same size and focal length, so the intrinsic matrix is built once
    intrinsic = torch.tensor([
        [focal, 0, width // 2],
        [0, focal, height // 2],
        [0, 0, 1],
    ], dtype=torch.float32)

    # Collecting to list then concat for ease
    origins, directions, colors = [], [], []
    for image, c2w in zip(images, c2ws):
        o, d = create_rays(height, width, intrinsic, c2w)
        origins.append(o.flatten(0, 1))
        directions.append(d.flatten(0, 1))
        colors.append(image.flatten(0, 1))

    # sobel_filter returns (N, H, W), so flattening matches the per-image concatenation order above
    pixel_weights = sobel_filter(images).flatten() + weight_epsilon
    return torch.cat(origins, dim=0), torch.cat(directions, dim=0), torch.cat(colors, dim=0), pixel_weights


origins, directions, colors, pixel_weights = create_nerf_data(images=images, c2ws=c2ws, focal=focal)

# Report tensor shapes and their memory footprint in MiB
print(f"{origins.shape=} | {directions.shape=} | {colors.shape=}")
print(f"Training origins' size:    {origins.nelement() * origins.element_size() / 1024 ** 2:.2f} MB")
print(f"Training directions' size: {directions.nelement() * directions.element_size() / 1024 ** 2:.2f} MB")
print(f"Training images' size:     {colors.nelement() * colors.element_size() / 1024 ** 2:.2f} MB")
print(f"Pixel weights' size:       {pixel_weights.nelement() * pixel_weights.element_size() / 1024 ** 2:.2f} MB")

In [None]:
# Lightning conversion proposal: STANDALONE, used inside DATAMODULE
class ImportantPixelSampler(WeightedRandomSampler):
    """Weighted index sampler that blends static importance weights into error-driven weights.

    Sampling weights start as ``pixel_weights`` (the given weights normalised by their maximum). As passes
    over the sampler accumulate, they move linearly towards a running estimate of each index's error,
    reported through ``update_errors``. After ``swap_strategy_iter`` passes, only the errors are used.

    Args:
        weights: one non-negative weight per sample index.
        num_samples: number of indices drawn per pass (the DataLoader "length").
        replacement: whether indices are drawn with replacement.
        swap_strategy_iter: number of passes over which the blend moves fully to error-driven weights.
    """

    def __init__(self, weights: Tensor, num_samples: int, replacement: bool = True, swap_strategy_iter: int = 100):
        super().__init__(weights=weights, num_samples=num_samples, replacement=replacement, generator=None)
        weights = torch.as_tensor(weights, dtype=torch.float32)
        # Static importance weights (max normalised to 1); never modified afterwards
        self.pixel_weights = weights / weights.max()
        # Live sampling weights read by WeightedRandomSampler.__iter__. This must be a separate tensor:
        # aliasing pixel_weights would let the in-place updates in update_errors overwrite the static weights.
        self.weights = self.pixel_weights.clone()
        self.swap_strategy_iter = swap_strategy_iter
        # Number of passes started over this sampler
        self.num_iters = 0
        # Running (discounted) per-index error estimate, initialised to 1 everywhere
        self.squared_errors = torch.ones(self.pixel_weights.shape, dtype=torch.float32)

    def __iter__(self):
        # Counted when the pass actually starts (first index requested)
        self.num_iters += 1
        yield from super().__iter__()

    def update_errors(self, idxs: Tensor, errors: Tensor):
        """Fold new errors into the running estimate and refresh the sampling weights of ``idxs``.

        Args:
            idxs: sample indices as yielded by this sampler (any device).
            errors: error values matching ``idxs`` (or broadcastable to them); detached before use.
        """
        idxs = torch.as_tensor(idxs).cpu()
        errors = errors.detach().to(device="cpu", dtype=torch.float32)
        # Discounted update; for an index whose error stays at 0 the floor settles at 4e-3 / 0.8 = 5e-3
        self.squared_errors[idxs] = self.squared_errors[idxs] * 0.2 + errors * 0.8 + 4e-3
        # Blend factor: 1 -> only static pixel weights, 0 -> only error-driven weights
        pxw = min(max(1.0 - self.num_iters / self.swap_strategy_iter, 0.0), 1.0)
        self.weights[idxs] = self.pixel_weights[idxs] * pxw + self.squared_errors[idxs] * (1 - pxw)


# Samplers with DataLoaders: Sampler num_samples decides dataloader "length" (like how Dataset length usually works)
# DataLoader batch size decides the batch size as per usual
sampler = ImportantPixelSampler(pixel_weights, num_samples=16, replacement=False)
for _ in range(10):
    list(sampler)  # consume one full pass so the pass counter advances
    sampler.update_errors(torch.arange(origins.shape[0]), torch.tensor([2.0], dtype=torch.float32))

for sample in sampler:
    print(f"Sample: {sample:07d} | Weight: {sampler.weights[sample]:5.2f}")

ex_data = TensorDataset(origins, directions, colors)
ex_loader = DataLoader(ex_data, batch_size=4, sampler=sampler)
for o, d, c in ex_loader:
    print(o.shape, d.shape, c.shape)

sampler.num_iters

In [None]:
# Lightning conversion proposal: STANDALONE
class PositionalEncoding(nn.Module):
    """Appended sinusoidal positional encoding.

    Maps x of shape (..., M) to (..., M + 2*M*max_freq): [x, sin(x * 2^k), cos(x * 2^k)] for
    k = 0 .. max_freq - 1. Within the sin and cos parts, the frequencies of one input coordinate are
    contiguous.
    """

    def __init__(self, max_freq: int):
        super().__init__()
        self._max_freq = max_freq
        freq_bands = 2.0 ** torch.linspace(0.0, max_freq - 1, steps=max_freq, dtype=torch.float32)
        # register_buffer exists in every PyTorch version (nn.parameter.Buffer is recent). The buffer follows
        # .to(device) and is saved in state_dict under the same "_freq_bands" key.
        self.register_buffer("_freq_bands", freq_bands)

    def forward(self, x: Tensor):
        encs = (x[..., None] * self._freq_bands).flatten(-2, -1)
        # Encoding to (x, sin parts, cos parts) of shape (N, M+M*max_freq*2) if x is of shape (N, M)
        return torch.cat([x, encs.sin(), encs.cos()], dim=-1)

    def get_out_dim(self, in_dim: int):
        """Size of the last axis after encoding an input whose last axis has ``in_dim`` entries."""
        return in_dim + in_dim * self._max_freq * 2
    
    
enc = PositionalEncoding(max_freq=10).to(COMPUTE_DEVICE)
print(f"Output dimension: {enc.get_out_dim(in_dim=3)}")
enc(origins[:1000].to(COMPUTE_DEVICE)).shape

In [None]:
# Lightning conversion proposal: STANDALONE
def sample_ray_uniformally(origins: Tensor, directions: Tensor, near: float, far: float, num_samples: int, perturb=True):
    """Place ``num_samples`` points along each ray at evenly spaced depths in [near, far].

    Args:
        origins: ray origins, one per row (the first axis indexes rays).
        directions: ray directions matching ``origins``.
        near, far: depth range along each ray.
        num_samples: samples per ray.
        perturb: if True, jitter each depth by at most (far - near) / (4 * num_samples). That is less than
            half the spacing, so depths stay sorted; they are also clamped to [near, far].

    Returns:
        (points, directions, depths): points (rays, num_samples, 3), directions expanded to the same shape,
        and depths (rays, num_samples).
    """
    dev = origins.device
    n_rays = origins.shape[0]
    evenly_spaced = torch.linspace(near, far, num_samples, dtype=torch.float32, device=dev)
    depths = evenly_spaced.expand(n_rays, -1)

    if perturb:
        jitter = (torch.rand(depths.shape, device=dev) - 0.5) * (far - near) / num_samples / 2
        depths = torch.clamp(depths + jitter, near, far)

    points = origins.unsqueeze(-2) + directions.unsqueeze(-2) * depths.unsqueeze(-1)
    # Each sample carries its ray's direction as the NeRF view input
    return points, directions.unsqueeze(-2).expand(points.shape), depths


# Lightning conversion proposal: STANDALONE
def plot_ray_sampling(points: Tensor, origin: Tensor, cartesian_direction: Tensor, title: str):
    """Show sampled ray points in 3D from two viewpoints oriented along a given viewing direction.

    Args:
        points: sampled positions indexed as (rays, samples, 3).
        origin: camera position (3 values); it is prepended to every ray so each one starts at the camera.
        cartesian_direction: 3-vector (e.g. a mean ray direction) used to orient both 3D views.
        title: figure title.
    """
    points = points.cpu()
    origin = origin.cpu()
    cartesian_direction = cartesian_direction.cpu()

    fig, axes = plt.subplots(1, 2, figsize=(16,8), subplot_kw={"projection": "3d"})
    axes = axes.flatten()
    fig.suptitle(title)
    plt.tight_layout()  # NOTE(review): called before anything is drawn, so it may have little effect
    # Adding the origin so it always starts from the camera position -> shape (rays, samples + 1, 3)
    points = torch.cat([origin.expand((points.shape[0], 1, -1)), points], 1)

    # Convert the view direction to spherical angles (degrees) for view_init
    X, Y, Z = -cartesian_direction  # Taking the negative as view_init specifies direction outward
    R = torch.sqrt(X**2 + Y**2 + Z**2)
    X, Y, Z = X/R, Y/R, Z/R  # normalization
    azim = torch.rad2deg(torch.atan2(Y, X))
    elev = torch.rad2deg(torch.arcsin(Z))
    # Multiple angles to understand better: the base view, then one offset by -10 deg elevation / +60 deg azimuth
    for ax, (mod_elev, mod_azim) in zip(axes, [[0,0],[-10, 60]]):
        ax.view_init(elev + mod_elev, azim + mod_azim, 0)
        # NOTE(review): x/y/z are passed as 2D (rays, samples + 1) arrays; confirm matplotlib's 3D plot draws
        # one polyline per ray here (2D pyplot plotting draws one line per column, i.e. per sample index).
        ax.plot(points[:, :, 0], points[:, :, 1], points[:, :, 2], linewidth=0.2, markersize=2, marker='o')
    plt.show()


# Every 71st ray of the first 10000 (all from the first image), 7 jittered samples each
points, dirs, depths = sample_ray_uniformally(
    origins=origins[:10000:71].to(COMPUTE_DEVICE),
    directions=directions[:10000:71].to(COMPUTE_DEVICE),
    near=near, far=far, num_samples=7, perturb=True,
)

print(f"{points.shape=} | {dirs.shape=} | {depths.shape=}")
plot_ray_sampling(points, origins[0], dirs[:, 0].mean(dim=0), "Example of perturbed uniform samples along ray")

In [None]:
# Lightning conversion proposal: STANDALONE
def sample_pdf(bins: Tensor, weights: Tensor, num_samples: int, deterministic: bool = False):
    """Draw samples from the piecewise-constant PDF given by ``weights`` over ``bins`` (inverse-CDF sampling).

    Args:
        bins: bin edges of shape (..., K + 1), sorted along the last axis. Leading dims may be broadcast
            against those of ``weights`` (e.g. one shared row of edges).
        weights: unnormalised bin weights of shape (..., K). Negative entries are clamped to 0 because a
            density cannot be negative. A small epsilon keeps all-zero rows well defined (they fall back to
            uniform).
        num_samples: number of samples per row.
        deterministic: if True, use evenly spaced quantiles in [0, 1], giving sorted samples. Otherwise draw
            uniform random quantiles, giving unsorted samples.

    Returns:
        Tensor of shape (..., num_samples) with positions inside [bins[..., 0], bins[..., -1]].
    """
    device = weights.device

    # Clamp before adding epsilon: negative weights would make the CDF non-monotonic and break searchsorted
    weights = weights.clamp_min(0.0) + 1e-5
    pdf = weights / torch.sum(weights, -1, keepdim=True)  # Normalize PDF
    cdf = torch.cumsum(pdf, -1)
    cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], dim=-1)  # Prepend 0 so the cdf spans [0, 1]
    bins = bins.expand(cdf.shape)  # share edges across rows if bins was given with fewer leading dims

    sample_shape = list(cdf.shape[:-1]) + [num_samples]
    if deterministic:
        u = torch.linspace(0.0, 1.0, steps=num_samples, device=device).expand(sample_shape)
    else:
        u = torch.rand(sample_shape, device=device)
    u = u.contiguous()  # searchsorted needs a contiguous memory layout

    # Inverting the CDF: find the bin each quantile falls in, clamped to valid edge indices
    indexes = torch.searchsorted(cdf, u, right=True)
    below = (indexes - 1).clamp_min(0)
    above = indexes.clamp_max(cdf.shape[-1] - 1)

    # Gather CDF values and bin edges at both ends of each selected bin
    cdf_below = torch.gather(cdf, -1, below)
    cdf_above = torch.gather(cdf, -1, above)
    bins_below = torch.gather(bins, -1, below)
    bins_above = torch.gather(bins, -1, above)

    # Denominator is the probability mass of the bin; near-empty bins use 1 to avoid dividing by ~0
    denominator = cdf_above - cdf_below
    denominator = torch.where(denominator < 1e-5, torch.ones_like(denominator), denominator)
    # t gives the relative position inside the bin
    t = (u - cdf_below) / denominator

    return bins_below + t * (bins_above - bins_below)


# 10 equal intervals on [0, 1] with an arbitrary bell-shaped set of weights
bins = torch.linspace(0, 1, steps=11)
weights = torch.tensor([0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.3, 0.25, 0.2, 0.1])
weights = weights / weights.sum(-1)

# Add a batch dimension: bins -> [1, 11], weights -> [1, 10]
bins = bins.unsqueeze(0).to(COMPUTE_DEVICE)
weights = weights.unsqueeze(0).to(COMPUTE_DEVICE)
samples = sample_pdf(bins, weights, num_samples=5)

print(f"{samples=}")

In [None]:
# Lightning conversion proposal: STANDALONE
def sample_ray_hierarchically(origins: Tensor, directions: Tensor, num_samples: int, bins: Tensor, weights: Tensor, deterministic: bool = False):
    """Place ``num_samples`` points per ray at depths drawn from the PDF given by ``weights`` over ``bins``.

    Returns:
        (points, directions, depths): points (rays, num_samples, 3), directions expanded to the same shape,
        and the sampled depths (rays, num_samples).
    """
    sampled_depths = sample_pdf(bins, weights, num_samples, deterministic=deterministic)
    points = origins.unsqueeze(-2) + directions.unsqueeze(-2) * sampled_depths.unsqueeze(-1)
    # Each sample carries its ray's direction as the NeRF view input
    expanded_directions = directions.unsqueeze(-2).expand(points.shape)
    return points, expanded_directions, sampled_depths

# One row of 65 bin edges (64 bins) per sampled ray; in the pipeline these come from uniform sampling
bins = torch.linspace(0, 1, 65).expand(10000 // 41 + 1, 65)
# Stand-in for NeRF sigma values
weights = torch.randn(10000 // 41 + 1, 64)
points, dirs, depths = sample_ray_hierarchically(
    origins=origins[:10000:41].to(COMPUTE_DEVICE),
    directions=directions[:10000:41].to(COMPUTE_DEVICE),
    num_samples=32,
    bins=bins.to(COMPUTE_DEVICE),
    weights=weights.to(COMPUTE_DEVICE),
)

print(f"{points.shape=} | {dirs.shape=} | {depths.shape=}")
plot_ray_sampling(points, origins[0], dirs[:, 0].mean(dim=0), "Example of hierarchical samples along ray")

In [None]:
# Lightning conversion proposal: NeRF() STANDALONE, compute_along_rays LIGHTNING MODULE
class NeRF(nn.Module):
    """NeRF MLP mapping (3D point, view direction) to (RGB, raw density sigma).

    Positionally encoded coordinates go through ``num_layers`` ReLU layers of width ``hidden_size``. The
    encoded coordinates are re-concatenated before every layer index listed in ``skips``. A linear head
    predicts sigma. A small branch combines the features with the encoded view direction and predicts RGB
    through a sigmoid. Sigma is returned unactivated; ``render_rays`` applies ReLU to it.
    """

    def __init__(self, num_layers: int = 8, hidden_size: int = 256, in_coordinates: int = 3, in_directions: int = 3,
                 skips: list[int] = [4], coord_encode_freq: int = 10, dir_encode_freq: int = 4):
        super().__init__()
        self.in_coordinates = in_coordinates
        self.in_directions = in_directions
        self.skips = tuple(skips)  # copied, so the shared default list is never mutated

        self.coordinate_encoder = PositionalEncoding(coord_encode_freq)
        self.direction_encoder = PositionalEncoding(dir_encode_freq)

        coord_dim = self.coordinate_encoder.get_out_dim(self.in_coordinates)
        # Layer 0 stays a bare Linear so state_dict keys remain "feature_mlp.0.*"; its ReLU is applied in forward()
        self.feature_mlp = nn.ModuleList([nn.Linear(coord_dim, hidden_size)])
        for layer_idx in range(1, num_layers):
            self.feature_mlp.append(nn.Sequential(
                # Layers listed in skips also receive the encoded coordinates again
                nn.Linear(hidden_size + (coord_dim if layer_idx in self.skips else 0), hidden_size),
                nn.ReLU(inplace=True),
            ))

        self.sigma_fc = nn.Linear(hidden_size, 1)

        self.color_preproc = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(inplace=True),
        )
        dir_dim = self.direction_encoder.get_out_dim(self.in_directions)
        self.rgb_mlp = nn.Sequential(
            nn.Linear(hidden_size + dir_dim, hidden_size // 2),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_size // 2, 3),
            nn.Sigmoid(),
        )

    def forward(self, coordinates: Tensor, directions: Tensor, skip_colors: bool = False):
        """Evaluate the field.

        Args:
            coordinates: (..., in_coordinates) positions.
            directions: (..., in_directions) view directions; unused when ``skip_colors`` is True.
            skip_colors: return only the density column.

        Returns:
            (..., 4) tensor [r, g, b, sigma], or (..., 1) sigma when ``skip_colors`` is True.
        """
        encoded_coords = self.coordinate_encoder(coordinates)
        features = encoded_coords
        for i, fc in enumerate(self.feature_mlp):
            if i in self.skips:
                features = torch.cat([features, encoded_coords], -1)
            features = fc(features)
            if i == 0:
                # Without this the first Linear would collapse into the next one (no nonlinearity between them)
                features = F.relu(features)

        sigma = self.sigma_fc(features)

        if skip_colors:
            return sigma

        encoded_dirs = self.direction_encoder(directions)
        features = self.color_preproc(features)
        rgb = self.rgb_mlp(torch.cat([features, encoded_dirs], dim=-1))

        return torch.cat([rgb, sigma], dim=-1)

    def compute_along_rays(self, origins: Tensor, directions: Tensor, near: float, far: float,
                           coarse_samples: int, fine_samples: int, deterministic: bool = True):
        """Run coarse (jittered uniform) and fine (importance) sampling along a batch of rays.

        This deviates from the original NeRF paper: the coarse and fine samples are processed by the same
        model, and the fine pass evaluates only the fine samples.

        Args:
            origins, directions: ray origins / directions, one per row.
            near, far: depth range along each ray.
            coarse_samples: number of uniform samples per ray.
            fine_samples: number of importance samples per ray.
            deterministic: evenly spaced PDF quantiles (sorted) if True, random quantiles (then sorted) otherwise.

        Returns:
            (coarse_out, coarse_depths, fine_out, fine_depths): network outputs (rays, samples, 4) and their
            depths (rays, samples), with depths sorted ascending.
        """
        points, expanded_directions, coarse_depths = sample_ray_uniformally(origins, directions, near, far, coarse_samples)
        coarse_out = self(points, expanded_directions)

        # Bin bounds are halfway between sampled depths, closed by the near and far planes
        bins = torch.cat([
            torch.full_like(coarse_depths[..., :1], near),
            (coarse_depths[..., 1:] + coarse_depths[..., :-1]) / 2,
            torch.full_like(coarse_depths[..., :1], far),
        ], -1)

        # Densities are rectified, as render_rays does; negative weights would make the sampling CDF
        # non-monotonic. They are detached so fine-sample placement does not back-propagate into the
        # coarse densities, as in the original NeRF.
        sample_weights = F.relu(coarse_out[..., -1]).detach()
        points, expanded_directions, fine_depths = sample_ray_hierarchically(origins, directions, fine_samples, bins,
                                                                             sample_weights, deterministic=deterministic)
        fine_out = self(points, expanded_directions)

        # Deterministic quantiles already give depth-sorted output; otherwise sort, since volume rendering needs order
        if not deterministic:
            fine_depths, order = torch.sort(fine_depths, dim=-1)
            fine_out = torch.gather(fine_out, -2, order.unsqueeze(-1).expand_as(fine_out))

        return coarse_out, coarse_depths, fine_out, fine_depths
    

model = NeRF().to(COMPUTE_DEVICE)
print(f"Output shape for origins + directions: {model(coordinates=origins[:50].to(COMPUTE_DEVICE), directions=directions[:50].to(COMPUTE_DEVICE)).shape}")

# 500 rays, 64 coarse samples and 5 random (then sorted) fine samples each
coarse_rgbs, coarse_depths, fine_rgbs, fine_depths = model.compute_along_rays(
    origins=origins[:500].to(COMPUTE_DEVICE),
    directions=directions[:500].to(COMPUTE_DEVICE),
    near=near, far=far, coarse_samples=64, fine_samples=5, deterministic=False,
)
print(f"Coarse output shape for hierarchical sampling: {coarse_rgbs.shape=} | {coarse_depths.shape=}")
print(f"Fine output shape for hierarchical sampling:   {fine_rgbs.shape=} | {fine_depths.shape=}")

In [None]:
# Lightning conversion proposal: STANDALONE
def render_rays(rgbs: Tensor, depths: Tensor):
    """Alpha-composite per-sample colours along each ray (NeRF volume rendering).

    Args:
        rgbs: (..., S, 4) tensor of [r, g, b, sigma] per sample; sigma is raw and is rectified with ReLU here.
        depths: (..., S) sample depths along each ray, sorted ascending.

    Returns:
        (rgb, depth): composited colour (..., 3) and expected depth (...,).
    """
    distances = depths[..., 1:] - depths[..., :-1]
    # The last sample gets an effectively infinite segment (1e10), so any remaining transmittance is absorbed
    # there. Its shape is taken from depths so single-sample rays also work.
    distances = torch.cat([distances, torch.full_like(depths[..., :1], 1e10)], -1)
    # directions already normalized at ray calculation, so distances correspond to world already

    alpha = 1.0 - torch.exp(-F.relu(rgbs[..., 3]) * distances)
    # Transmittance T_i = prod_{j<i} (1 - alpha_j); the 1e-10 keeps the cumulative product from hitting exactly 0
    transmittance = torch.cumprod(torch.cat([torch.ones_like(alpha[..., :1]), 1.0 - alpha + 1e-10], -1), -1)[..., :-1]
    weights = alpha * transmittance
    rgb = torch.sum(weights[..., None] * rgbs[..., :3], dim=-2)
    depth = torch.sum(weights * depths, dim=-1)

    return rgb, depth

rendered_rgb, rendered_depth = render_rays(rgbs=coarse_rgbs, depths=coarse_depths)
print(f"Color and depth shapes after render: {rendered_rgb.shape=} | {rendered_depth.shape=}")

In [None]:
# Lightning conversion proposal: LIGHTNING MODULE
@torch.no_grad()
def render_image(model: NeRF, height: int, width: int, c2w: Tensor, focal: Tensor,
                 near: float, far: float, batch_size=512, coarse_samples: int = 64, fine_samples: int = 64):
    """Render a full ``height`` x ``width`` image from camera pose ``c2w`` with the fine NeRF samples.

    Args:
        model: trained NeRF; rays are moved to the device of its parameters.
        height, width: output resolution in pixels (principal point at the image centre).
        c2w: camera-to-world matrix of the view to render.
        focal: focal length in pixels (0-d tensor or Python number).
        near, far: depth range along each ray.
        batch_size: number of rays evaluated per forward pass.
        coarse_samples, fine_samples: samples per ray for the coarse and fine stages.

    Returns:
        (height, width, 3) tensor of rendered colours on the model's device.
    """
    device = next(model.parameters()).device

    focal = float(focal)
    intrinsic = torch.tensor([
        [focal, 0, width // 2],
        [0, focal, height // 2],
        [0, 0, 1],
    ], dtype=torch.float32)
    origins, directions = create_rays(height, width, intrinsic, c2w)

    origins = origins.to(device).flatten(0, -2)
    directions = directions.to(device).flatten(0, -2)

    # Render in chunks of batch_size rays to bound memory; chunks stay in row-major pixel order
    image = []
    for o, d in zip(origins.split(batch_size), directions.split(batch_size)):
        _, _, rgbs, depths = model.compute_along_rays(o, d, near, far, coarse_samples, fine_samples)
        rgb, _ = render_rays(rgbs, depths)
        image.append(rgb)

    return torch.cat(image, 0).reshape(height, width, -1)

# Untrained model: the render only demonstrates the pipeline end-to-end
render = render_image(
    model=NeRF().to(COMPUTE_DEVICE), height=50, width=50, c2w=c2ws[-1], focal=focal, near=near, far=far,
).cpu()
plt.imshow(render)

## Model training

In [None]:
data = np.load("../data/tiny_nerf_data.npz")
images, c2ws, focal = torch.from_numpy(data["images"]), torch.from_numpy(data["poses"]), torch.from_numpy(data["focal"])
near, far = compute_near_far_planes(c2ws)

# Hold out a single view (index -5) for visual evaluation; every other view is used for training
test_img, test_c2w = images[-5].unsqueeze(0), c2ws[-5].unsqueeze(0)
train_imgs = torch.cat([images[:-5], images[-4:]], 0)
train_c2ws = torch.cat([c2ws[:-5], c2ws[-4:]], 0)  # poses (not images) must be paired with train_imgs
assert train_imgs.shape[0] == train_c2ws.shape[0] == images.shape[0] - 1
# Only the training views feed the ray dataset, so the held-out image never leaks into training
train_origins, train_directions, train_colors, pixel_weights = create_nerf_data(train_imgs, train_c2ws, focal, weight_epsilon=0.33)

print(f"{train_origins.shape=} | {train_directions.shape=} | {train_colors.shape=} | {pixel_weights.shape=}")

plt.title("Testing image")
plt.imshow(test_img[0])

batch_size = 1024

# Batch size determines iteration per epoch, swap strategy is based on epochs
train_sampler = ImportantPixelSampler(pixel_weights, batch_size * 20, False, 400)
# The first tensor carries each ray's index so per-ray losses can be reported back to the sampler
train_data = TensorDataset(torch.arange(train_origins.shape[0]), train_origins, train_directions, train_colors)
train_loader = DataLoader(train_data, batch_size, sampler=train_sampler, num_workers=2, pin_memory=True, persistent_workers=True)

In [None]:
model: NeRF = NeRF().to(COMPUTE_DEVICE)
print(f"Trainable params count: {sum(p.numel() for p in model.parameters() if p.requires_grad):_}")

coarse_samples = 64
fine_samples = 64
epochs = 5000

optim = torch.optim.Adam(model.parameters(), lr=5e-5)
# Unreduced loss so each ray's error can be fed back to the importance sampler
loss_func = torch.nn.MSELoss(reduction='none')

train_losses = []

# RTX 3050Ti Max-Q, Ryzen 5 5600H => ~4.3s/it at batch_size=1024, num_samples=batch_size*20
with tqdm(range(1, epochs + 1), desc="Epoch", position=0, leave=True) as epoch_progress:
    for epoch in epoch_progress:

        epoch_loss_sum = 0.0
        # Batch tensors get their own names so the loop does not overwrite the notebook-level
        # `origins` / `directions` / `colors` from earlier cells
        for batch_idxs, batch_origins, batch_directions, batch_colors in train_loader:
            batch_origins = batch_origins.to(COMPUTE_DEVICE)
            batch_directions = batch_directions.to(COMPUTE_DEVICE)
            batch_colors = batch_colors.to(COMPUTE_DEVICE)

            optim.zero_grad()
            coarse_rgbs, coarse_depths, fine_rgbs, fine_depths = model.compute_along_rays(
                batch_origins, batch_directions, near, far, coarse_samples, fine_samples)

            coarse_colors, _ = render_rays(coarse_rgbs, coarse_depths)
            fine_colors, _ = render_rays(fine_rgbs, fine_depths)

            # Per-ray loss: mean over RGB of the coarse + fine squared errors
            per_ray_loss = (loss_func(coarse_colors, batch_colors) + loss_func(fine_colors, batch_colors)).mean(-1)
            # The sampler only needs the values, not the autograd graph
            train_sampler.update_errors(batch_idxs, per_ray_loss.detach())
            loss = per_ray_loss.mean()
            loss.backward()
            optim.step()

            epoch_loss_sum += loss.item() * batch_colors.shape[0]

        train_losses.append(epoch_loss_sum / train_sampler.num_samples)
        epoch_progress.set_postfix(loss=train_losses[-1])

        if epoch % 50 == 0:
            # Render the held-out view at its native resolution
            render = render_image(model, test_img.shape[1], test_img.shape[2], test_c2w.squeeze(0), focal, near, far)

            fig, axes = plt.subplots(1, 2, figsize=(12,4), width_ratios=(0.4, 0.6))
            axes = axes.flatten()

            axes[0].set_title(f"Render at epoch {epoch}")
            axes[0].imshow(render.cpu())

            axes[1].set_title(f"Losses up to epoch {epoch}")
            axes[1].plot(train_losses)
            axes[1].set_ylim(bottom=0.0, top=0.35)

            plt.show()

In [None]:
# Persist only the learned parameters and buffers (reload with NeRF().load_state_dict(...))
torch.save(obj=model.state_dict(), f="tiny_nerf2.pth")

In [None]:
# torch.cuda.memory_summary needs a CUDA device; this notebook can also run on MPS or CPU
if torch.cuda.is_available():
    print(torch.cuda.memory_summary())
else:
    print("CUDA is not available; no CUDA memory summary to report.")