# Grid Encoding (like Instant NGP)

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import plotly.graph_objects as go

In [2]:
grid_size = 10
device = torch.device("mps")
grid = torch.ones((grid_size, grid_size, grid_size)).to(device)
print(f"Our grid shape is {grid.shape}")

Our grid shape is torch.Size([10, 10, 10])


## Creating a plot to visualize the grid

In [3]:
grid_np = grid.cpu().numpy()

x, y, z = np.indices(grid_np.shape)
values = grid_np.flatten()

fig = go.Figure(
    data=go.Scatter3d(
        x=x.flatten(),
        y=y.flatten(),
        z=z.flatten(),
        mode="markers",
        marker=dict(size=5, color=values, colorscale="Viridis", opacity=0.8),
    )
)

fig.update_layout(
    title="3D Random Grid Visualization",
    scene=dict(xaxis_title="X", yaxis_title="Y", zaxis_title="Z"),
)

fig.show()

### Visualizing a sample 3D point

In [4]:
sample_point = torch.rand(3) * grid_size
sample_point

tensor([2.2381, 5.5359, 5.5115])

In [5]:
fig.add_trace(go.Scatter3d(
    x=[sample_point[0]],
    y=[sample_point[1]],
    z=[sample_point[2]],
    mode='markers',
    marker=dict(
        size=3,
        color='red',
        opacity=0.5
    )
))

fig.show()

### Finding vertices of the nearest voxel grid

In [6]:
def return_corresponding_voxel_vertices(sample_points):
    sample_points = sample_points.to(device)
    lower_1 = torch.floor(sample_points)
    upper_1 = torch.ceil(sample_points)
    lower_2 = lower_1 + torch.tensor([0, 0, 1]).to(device)
    lower_3 = lower_1 + torch.tensor([0, 1, 0]).to(device)
    lower_4 = lower_1 + torch.tensor([1, 0, 0]).to(device)
    upper_2 = upper_1 - torch.tensor([0, 0, 1]).to(device)
    upper_3 = upper_1 - torch.tensor([0, 1, 0]).to(device)
    upper_4 = upper_1 - torch.tensor([1, 0, 0]).to(device)
    return torch.stack(
        [lower_1, upper_1, lower_2, lower_3, lower_4, upper_2, upper_3, upper_4], dim=1
    )

In [7]:
return_corresponding_voxel_vertices(torch.rand((4, 3)) * grid_size).shape

torch.Size([4, 8, 3])

### Visualizing a sample 3D point and its nearest voxel grid

In [8]:
sample_voxel_vertices = (
    return_corresponding_voxel_vertices(sample_point.unsqueeze(0)).cpu().numpy()
)

In [9]:
fig.add_trace(
    go.Scatter3d(
        x=sample_voxel_vertices[0, :, 0],
        y=sample_voxel_vertices[0, :, 1],
        z=sample_voxel_vertices[0, :, 2],
        mode="markers",
        marker=dict(size=5, color="blue", opacity=1),
    )
)
fig.show()

### Creating embeddings for each vertex of voxels in the 3D grid

In [10]:
Feature_dim = 4
embedding = nn.Embedding(grid_size * grid_size * grid_size, Feature_dim).to(device)
embedding.weight.shape

torch.Size([1000, 4])

### Interpolation function
(for getting the embeddings of any point in 3D space as sum of weighted embeddings of the nearest voxel grid)

In [11]:
def interpolate(sample_points, embedding, type="Linear"):
    sample_points = sample_points.to(device)
    batch_size = sample_points.shape[0]
    vertices_of_interest = return_corresponding_voxel_vertices(sample_points)
    flattened_indices = (
        (
            vertices_of_interest[:, :, 0]
            + grid_size * vertices_of_interest[:, :, 1]
            + grid_size * grid_size * vertices_of_interest[:, :, 2]
        )
        .long()
        .to(device)
    )
    embeddings = embedding(flattened_indices).squeeze()
    distances = torch.norm(vertices_of_interest - sample_points.unsqueeze(1), dim=2)
    if type == "Linear":
        weights = 1 / distances
    else:
        raise NotImplementedError

    weights = weights / torch.sum(weights, dim=1, keepdim=True)
    weights = weights.unsqueeze(2)
    return torch.sum(embeddings * weights, dim=1)

### A sample down-stream task (SDF prediction) using the embeddings

#### Creating a dataset of sphere SDFs

In [12]:
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.dataset import random_split


def sphere_sdf(x, y, z):
    return torch.sqrt((x - 2) ** 2 + (y - 2) ** 2 + (z - 2) ** 2) - 0.5


sample_size = 2**15
random_points_in_grid = torch.rand(sample_size, 3) * grid_size
gt_sdf_vals = sphere_sdf(
    random_points_in_grid[:, 0],
    random_points_in_grid[:, 1],
    random_points_in_grid[:, 2],
).to(device)
random_points_in_grid = random_points_in_grid.to(device)

points_dataset = TensorDataset(random_points_in_grid, gt_sdf_vals)
train_size = int(0.8 * sample_size)
val_size = sample_size - train_size
train_dataset, val_dataset = random_split(points_dataset, [train_size, val_size])

BATCH_SIZE = 64
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

#### Training loop

In [13]:
from tqdm import tqdm

model = nn.Sequential(
    nn.Linear(Feature_dim, 256),
    nn.ReLU(),
    nn.Linear(256, 256),
    nn.ReLU(),
    nn.Linear(256, 1),
).to(device)

optimizer = optim.Adam(list(model.parameters()) + list(embedding.parameters()), lr=1e-3)
EPOCHS = 100
pbar = tqdm(range(int(EPOCHS)))

  0%|          | 0/100 [00:00<?, ?it/s]

In [14]:
for epoch in pbar:
    pbar.set_description(f"Epoch {epoch}")
    for i, (x, sdf_val) in enumerate(train_loader):
        optimizer.zero_grad()
        y_pred = model(interpolate(x, embedding))
        loss = F.mse_loss(y_pred, sdf_val.unsqueeze(1))
        loss.backward()
        optimizer.step()
        pbar.set_postfix({"Loss": loss.item()})

    if epoch % 10 == 0:
        with torch.no_grad():
            val_loss = 0
            for i, (x, y) in enumerate(val_loader):
                y_pred = model(interpolate(x, embedding))
                val_loss += F.mse_loss(y_pred, y.unsqueeze(1))
            print(f"Epoch {epoch} - Val Loss: {val_loss / len(val_loader)}")

Epoch 1:   1%|          | 1/100 [00:05<08:38,  5.23s/it, Loss=3.06]

Epoch 0 - Val Loss: 3.597395896911621


Epoch 11:  11%|█         | 11/100 [00:44<05:50,  3.93s/it, Loss=0.199]

Epoch 10 - Val Loss: 0.2925286293029785


Epoch 21:  21%|██        | 21/100 [01:22<05:11,  3.95s/it, Loss=0.0921]

Epoch 20 - Val Loss: 0.14235101640224457


Epoch 31:  31%|███       | 31/100 [02:01<04:35,  3.99s/it, Loss=0.0463]

Epoch 30 - Val Loss: 0.08255120366811752


Epoch 41:  41%|████      | 41/100 [02:40<04:02,  4.11s/it, Loss=0.0449]

Epoch 40 - Val Loss: 0.06785129010677338


Epoch 51:  51%|█████     | 51/100 [03:21<03:19,  4.06s/it, Loss=0.0537]

Epoch 50 - Val Loss: 0.05200814828276634


Epoch 61:  61%|██████    | 61/100 [04:02<02:48,  4.31s/it, Loss=0.0232]

Epoch 60 - Val Loss: 0.04541991651058197


Epoch 71:  71%|███████   | 71/100 [04:45<02:07,  4.40s/it, Loss=0.026] 

Epoch 70 - Val Loss: 0.04362420365214348


Epoch 81:  81%|████████  | 81/100 [05:24<01:17,  4.08s/it, Loss=0.0228]

Epoch 80 - Val Loss: 0.04999928921461105


Epoch 91:  91%|█████████ | 91/100 [06:05<00:39,  4.37s/it, Loss=0.0171]

Epoch 90 - Val Loss: 0.04161970317363739


Epoch 99: 100%|██████████| 100/100 [06:43<00:00,  4.03s/it, Loss=0.0331]
