In [None]:
import os
import time
import numpy as np
import torch
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import matplotlib.animation as animation
from mpl_toolkits.mplot3d import Axes3D
from matplotlib.patches import FancyArrowPatch
from mpl_toolkits.mplot3d import proj3d
from IPython.display import HTML

In [None]:
# Choose the compute device: the current CUDA device when CUDA is usable, otherwise the CPU.
dev = f"cuda:{torch.cuda.current_device()}" if torch.cuda.is_available() else "cpu"
if dev != "cpu":
    # Report which GPU was picked up before echoing the device string.
    print(torch.cuda.get_device_name(torch.cuda.current_device()))
print(dev)

In [None]:
# Shared configuration used by the cells below.
loc_enc_dim = 4  # passed to Model as loc_enc_dim: number of grid resolutions in the hash encoding
N_samples1 = 64  # samples per ray in the first, evenly spaced pass
N_samples2 = 64  # additional samples per ray drawn in the second (resampling) pass
near = 2.0  # smallest ray parameter at which samples are placed
far = 6.0  # largest ray parameter at which samples are placed

In [None]:
class Arrow3D(FancyArrowPatch):
    """Arrow patch whose two endpoints are given in 3D data coordinates.

    The patch is created with placeholder 2D endpoints; the real 2D positions
    are computed by projecting the stored 3D endpoints with the projection
    matrix of the 3D axes the arrow has been added to.

    Args:
        xs, ys, zs: two-element sequences with the start and end coordinate
            of the arrow along each axis.
        *args, **kwargs: forwarded to ``FancyArrowPatch``.
    """

    def __init__(self, xs, ys, zs, *args, **kwargs):
        # (0, 0) -> (0, 0) is only a placeholder until the first projection.
        FancyArrowPatch.__init__(self, (0, 0), (0, 0), *args, **kwargs)
        self._verts3d = xs, ys, zs

    def do_3d_projection(self, renderer=None):
        """Project the 3D endpoints to 2D and return the smallest projected depth.

        Recent Matplotlib releases call this hook on every patch in a 3D axes
        (to depth-sort artists); without it drawing raises AttributeError.
        ``renderer`` is accepted for compatibility and is not used.
        """
        xs3d, ys3d, zs3d = self._verts3d
        # The projection matrix is read from the axes: recent renderers no
        # longer carry an ``M`` attribute.
        xs, ys, zs = proj3d.proj_transform(xs3d, ys3d, zs3d, self.axes.M)
        self.set_positions((xs[0], ys[0]), (xs[1], ys[1]))
        return min(zs)

    def draw(self, renderer):
        """Re-project the endpoints, then draw as a regular 2D arrow."""
        # Project here as well so Matplotlib versions that never call
        # do_3d_projection on artists still get correct positions.
        self.do_3d_projection(renderer)
        FancyArrowPatch.draw(self, renderer)

def plot_poses(poses):
    """Show camera poses in a 3D plot.

    Each pose contributes one scatter point at its translation (column 3 of
    rows 0-2) and three arrows obtained by rotating fixed camera-frame
    vectors with the pose's 3x3 rotation block.

    Args:
        poses: tensor of pose matrices; moved to the CPU and converted to NumPy.
    """
    # Camera-frame vectors and the colour of the arrow drawn for each of them.
    axis_colors = (([0, 0, -1], "r"), ([1, 0, 0], "g"), ([0, -1, 0], "b"))

    plt.close()
    figure = plt.figure()
    axes3d = figure.add_subplot(projection="3d")
    pose_array = poses.cpu().numpy()
    axes3d.scatter(pose_array[..., 0, 3], pose_array[..., 1, 3], pose_array[..., 2, 3], marker="o")
    for pose in pose_array:
        rotation = pose[0:3, 0:3]
        px, py, pz = pose[0, 3], pose[1, 3], pose[2, 3]
        for camera_axis, color in axis_colors:
            d = np.matmul(rotation, camera_axis)
            arrow = Arrow3D([px, px + d[0]], [py, py + d[1]], [pz, pz + d[2]],
                            mutation_scale=20, lw=1, arrowstyle="-|>", color=color)
            axes3d.add_artist(arrow)
    axes3d.set_xlabel("x")
    axes3d.set_ylabel("y")
    axes3d.set_zlabel("z")
    plt.show()
    plt.close()

In [None]:
def load_data(dev):
    """Load the tiny NeRF dataset from ``tiny_nerf_data.npz`` in the working directory.

    Prints the array shapes and the focal value, shows image 101 as a preview,
    and keeps the first 100 images (first three channels) and poses.

    Args:
        dev: torch device (or device string) the returned tensors are created on.

    Returns:
        (images, poses, focal, H, W): float32 image and pose tensors on ``dev``,
        the focal value as read from the file, and the image height and width.
    """
    # http://cseweb.ucsd.edu/~viscomp/projects/LF/papers/ECCV20/nerf/tiny_nerf_data.npz
    # The context manager closes the archive's file handle once the arrays
    # have been read into memory.
    with np.load('tiny_nerf_data.npz') as data:
        images = data['images']
        poses = data['poses']
        focal = data['focal']
    H, W = images.shape[1:3]
    print(images.shape, poses.shape, focal)

    # Preview image taken before the arrays are truncated to the first 100 views.
    testimg = images[101]
    images = images[:100, ..., :3]
    poses = poses[:100]

    plt.imshow(testimg)
    plt.show()
    images_ = torch.tensor(images, dtype=torch.float32, device=dev)
    poses_ = torch.tensor(poses, dtype=torch.float32, device=dev)
    return images_, poses_, focal, H, W

# Load the training images and poses onto `dev`, plot the camera poses,
# and print the CUDA memory currently allocated by tensors.
images, poses, focal, H, W = load_data(dev)
plot_poses(poses)
print(torch.cuda.memory_allocated())

In [None]:
def ray_origins_directions(H, W, focal, pose, rand=True):
    """Build one ray per pixel of an H x W image seen from ``pose``.

    All tensors are created on ``pose.device``, so the function does not
    depend on any module-level device variable.

    Args:
        H, W: image height and width in pixels.
        focal: focal length used to scale pixel offsets into directions.
        pose: tensor whose ``[:3, :3]`` block rotates camera-frame directions
            and whose ``[:3, -1]`` column is the ray origin.
        rand: when True, a single random sub-pixel offset (one draw for x, one
            for y, from ``np.random.uniform``) is shared by all pixels;
            when False the pixel centre (0.5, 0.5) is used.

    Returns:
        Float tensor of shape (H * W, 9) with columns
        x, y, z, dx, dy, dz, pix_x, pix_y, ray_idx, in row-major pixel order.
    """
    device = pose.device
    if rand:
        # Draw the x offset first, then y, matching the order of the grid axes.
        offset_x = np.random.uniform()
        offset_y = np.random.uniform()
    else:
        offset_x = offset_y = 0.5
    pix_x, pix_y = torch.meshgrid(
            torch.arange(W, dtype=torch.float32, device=device),
            torch.arange(H, dtype=torch.float32, device=device),
            indexing="xy")
    x = pix_x + offset_x
    y = pix_y + offset_y
    # Camera-frame directions: the y and z components are negated.
    dirs = [(x - W * 0.5) * (1 / focal), -(y - H * 0.5) * (1 / focal), -torch.ones_like(x)]
    dirs = torch.stack(dirs, -1)
    # Rotate into the world frame: out_i = sum_j dirs_j * R[i, j].
    dirs = torch.sum(dirs[..., None, :] * pose[:3,:3], -1)
    origs = pose[:3,-1].expand(dirs.shape)
    ray_idx = torch.arange(H * W, dtype=torch.float32, device=device)
    return torch.cat((origs.view(-1, 3), dirs.view(-1, 3), pix_x.reshape(-1, 1), pix_y.reshape(-1, 1), ray_idx.reshape(-1, 1)), 1)

def plot_origs_dirs(rays):
    """Plot ray origins as points and ray directions as red arrows in 3D.

    Columns 0-2 of ``rays`` are read as the origin and columns 3-5 as the
    direction. The axis limits are a 2-unit cube centred on the first origin.

    Args:
        rays: tensor of rays; detached and moved to the CPU for plotting.
    """
    plt.close()
    ray_array = rays.cpu().detach().numpy()
    figure = plt.figure()
    axes3d = figure.add_subplot(projection="3d")
    axes3d.scatter(ray_array[:, 0], ray_array[:, 1], ray_array[:, 2], marker="o")
    for ox, oy, oz, dx, dy, dz in ray_array[:, 0:6]:
        arrow = Arrow3D([ox, ox + dx], [oy, oy + dy], [oz, oz + dz],
                        mutation_scale=20, lw=1, arrowstyle="-|>", color="r")
        axes3d.add_artist(arrow)
    axes3d.set_xlabel("x")
    axes3d.set_ylabel("y")
    axes3d.set_zlabel("z")
    # Centre the view on the first ray origin.
    first = ray_array[0]
    axes3d.set_xlim3d(first[0] - 1, first[0] + 1)
    axes3d.set_ylim3d(first[1] - 1, first[1] + 1)
    axes3d.set_zlim3d(first[2] - 1, first[2] + 1)
    plt.show()
    plt.close()

# Smoke test: a hand-written pose whose rotation block swaps the x and z axes
# and whose translation is zero.
test_pose = torch.tensor([
    [0, 0, 1, 0],
    [0, 1, 0, 0],
    [1, 0, 0, 0],
    [0, 0, 0, 1]
], dtype=torch.float32, device=dev)
# Rays for a tiny 2 x 3 image with focal 5 (rand defaults to True, so offsets are jittered).
test_rays = ray_origins_directions(2, 3, 5, test_pose)
print(test_rays.shape)
print(test_rays)  # x, y, z, dx, dy, dz, pix_x, pix_y, ray_idx
plot_origs_dirs(test_rays)

In [None]:
def sample_rays(rays, z_vals):
    """Place sample points along each ray at the given depths.

    Args:
        rays: (n_rays, >=9) tensor; columns 0-2 are origins, 3-5 directions
            and column 8 the ray index.
        z_vals: (n_depths, n_rays, 1) tensor of depths along each ray.

    Returns:
        (n_rays * n_depths, 8) tensor with columns
        x, y, z, dx, dy, dz, z_val, ray_idx, grouped ray by ray.
    """
    n_depths = z_vals.shape[0]
    # Tile every per-ray column group once per depth: (n_depths, n_rays, k).
    origins, directions, ray_ids = (
        rays[:, lo:hi].repeat(n_depths, 1, 1) for lo, hi in ((0, 3), (3, 6), (8, 9))
    )
    positions = origins + directions * z_vals
    stacked = torch.cat((positions, directions, z_vals, ray_ids), 2)
    # Put the ray axis first so that each ray's samples are contiguous rows.
    return stacked.permute(1, 0, 2).reshape(-1, 8)

def plot_samples(samples):
    """Plot sample points in 3D with a red arrow for each sample's direction.

    Columns 0-2 of ``samples`` are read as the position and columns 3-5 as
    the direction.

    Args:
        samples: tensor of samples; detached and moved to the CPU for plotting.
    """
    plt.close()
    sample_array = samples.cpu().detach().numpy()
    figure = plt.figure()
    axes3d = figure.add_subplot(projection="3d")
    axes3d.scatter(sample_array[:, 0], sample_array[:, 1], sample_array[:, 2], marker="o")
    for px, py, pz, dx, dy, dz in sample_array[:, 0:6]:
        arrow = Arrow3D([px, px + dx], [py, py + dy], [pz, pz + dz],
                        mutation_scale=20, lw=1, arrowstyle="-|>", color="r")
        axes3d.add_artist(arrow)
    axes3d.set_xlabel("x")
    axes3d.set_ylabel("y")
    axes3d.set_zlabel("z")
    plt.show()
    plt.close()

# Smoke test: two depths per test ray, one at `near` and one at `far`.
test_N_samples = 2
# Arrange the depths as (test_N_samples, n_rays, 1), the layout sample_rays expects.
test_z_vals = torch.linspace(near, far, test_N_samples, dtype=torch.float32, device=dev).repeat(test_rays.shape[0], 1, 1).permute(2, 0, 1)
test_samples = sample_rays(test_rays, test_z_vals)
print(test_samples.shape)
print(test_samples)  # x, y, z, dx, dy, dz, z, ray_idx
plot_samples(test_samples)

In [None]:
def corners_hashes(samples, Ns, table_size, loc_enc_dim, dev):
    """Find the 8 grid-cell corners around each sample at every resolution and hash them.

    Args:
        samples: (n, >=3) tensor; columns 0-2 are the sample positions.
        Ns: (loc_enc_dim,) tensor of grid resolutions (cells per unit length).
        table_size: number of hash-table slots; hashes lie in [0, table_size).
        loc_enc_dim: number of resolutions; must match ``Ns.shape[0]``.
        dev: device on which the corner-offset table is created.

    Returns:
        (corners, hashes): corner positions in the original coordinate scale,
        shape (8, loc_enc_dim, n, 3), and int64 hashes of shape
        (8, loc_enc_dim, n). Corner k uses the binary digits of k as its
        (x, y, z) offset, so corner 0 is the lower and corner 7 the upper corner.
    """
    # The 8 unit-cube offsets in binary counting order (x is the slowest digit).
    offsets = torch.tensor(
        [[ox, oy, oz] for ox in (0, 1) for oy in (0, 1) for oz in (0, 1)],
        dtype=torch.int64, device=dev)
    # Positions expressed in grid units, one copy per resolution.
    scaled = samples[:, 0:3].repeat(loc_enc_dim, 1, 1) * Ns[:, None, None]
    corners = torch.floor(scaled).repeat(8, 1, 1, 1) + offsets[:, None, None, :]
    ix, iy, iz = (corners[..., axis].to(torch.int64) for axis in range(3))
    # XOR of the integer coordinates, y and z each multiplied by a large constant.
    hashes = torch.remainder(ix ^ (iy * 2654435761) ^ (iz * 805459861), table_size)
    return corners / Ns[:, None, None], hashes

def plot_locs_hashes(samples, locs, hashes):
    """Plot samples (triangles) and grid corners (circles), labelling each corner with its hash.

    Args:
        samples: tensor whose last-axis entries 0-2 are sample positions.
        locs: tensor of corner positions with a trailing axis of size 3.
        hashes: tensor with one hash per corner in ``locs`` (same order).
    """
    plt.close()
    samples_ = samples.cpu()
    # reshape (not view): callers pass slices such as ``corners[:, 0]`` that
    # are non-contiguous, and ``.cpu()`` returns the same tensor unchanged
    # when the data already lives on the CPU, so ``view`` would raise there.
    locs_ = locs.cpu().reshape(-1, 3)
    hashes_ = hashes.cpu().reshape(-1, 1)
    fig = plt.figure()
    ax = fig.add_subplot(projection="3d")
    ax.scatter(samples_[..., 0], samples_[..., 1], samples_[..., 2], marker="^")
    ax.scatter(locs_[..., 0], locs_[..., 1], locs_[..., 2], marker="o")
    for loc, h in zip(locs_, hashes_):
        ax.text(loc[0], loc[1], loc[2], f"{int(h)}", None)
    ax.set_xlabel("x")
    ax.set_ylabel("y")
    ax.set_zlabel("z")
    plt.show()
    plt.close()

# Smoke test: corners and hashes of the test samples at two resolutions (2 and 4)
# with a 1024-slot table.
test_all_corners, test_all_hashes = corners_hashes(test_samples.detach(), Ns=torch.tensor([2, 4], dtype=torch.float32, device=dev), table_size=1024, loc_enc_dim=2, dev=dev)
print(test_all_corners.shape)
print(test_all_hashes.shape)
# Axis 1 indexes the resolution: plot the first, then the second.
plot_locs_hashes(test_samples, test_all_corners[:, 0], test_all_hashes[:, 0])
plot_locs_hashes(test_samples, test_all_corners[:, 1], test_all_hashes[:, 1])

In [None]:
class Model(torch.nn.Module):
    """Hash-grid feature encoding followed by a two-layer MLP.

    A single learnable table of ``table_size`` rows with ``nf`` features each
    is shared by all ``loc_enc_dim`` grid resolutions. Resolution ``i`` is
    ``N * 1.5 ** i``. For every sample the features of the 8 surrounding grid
    corners are blended trilinearly per resolution; the blended features are
    concatenated with the sample's first six columns and passed through
    Linear -> ReLU -> Linear -> ReLU to give two non-negative outputs.

    Args:
        in_dim: MLP input width; must equal ``6 + nf * loc_enc_dim``.
        hidden_dim: width of the hidden layer.
        table_size: number of rows in the hash table.
        loc_enc_dim: number of grid resolutions.
        nf: features stored per table row.
        N: coarsest grid resolution.
    """

    def __init__(self, in_dim, hidden_dim, table_size, loc_enc_dim, nf=2, N=16):
        super().__init__()
        self.in_dim, self.table_size, self.loc_enc_dim = in_dim, table_size, loc_enc_dim
        self.nf, self.N = nf, N
        # Small uniform initial values in (-10e-4, 10e-4).
        initial_table = torch.empty((table_size, nf)).uniform_(-1, 1) * 10e-4
        self.hashtable = torch.nn.Parameter(initial_table)
        self.linear1 = torch.nn.Linear(in_dim, hidden_dim)
        self.activation1 = torch.nn.ReLU()
        self.linear2 = torch.nn.Linear(hidden_dim, 2)
        self.activation2 = torch.nn.ReLU()

    def forward(self, samples, dev):
        """Evaluate the model on ``samples`` of shape (n, >=6); returns an (n, 2) tensor.

        ``dev`` is the device on which the resolution tensor and corner
        offsets are created.
        """
        resolutions = torch.tensor(
            [self.N * (1.5 ** level) for level in range(self.loc_enc_dim)],
            dtype=torch.float32, device=dev)
        # Corner lookup is done on detached positions.
        corners, hashes = corners_hashes(samples.detach(), resolutions, self.table_size, self.loc_enc_dim, dev)
        features = self.hashtable[hashes]
        scale = resolutions[:, None, None]

        # Per axis: (distance to the upper corner, distance to the lower corner),
        # both in grid units. Index 0 is the weight of a lower-corner coordinate,
        # index 1 the weight of an upper-corner coordinate.
        axis_weights = []
        for axis in range(3):
            coord = samples[None, ..., axis:axis + 1]
            from_lower = (coord - corners[0, ..., axis:axis + 1]) * scale
            to_upper = (corners[-1, ..., axis:axis + 1] - coord) * scale
            axis_weights.append((to_upper, from_lower))

        # Trilinear blend, visiting corners in the same order corners_hashes emits them.
        blended = None
        corner = 0
        for bx in (0, 1):
            for by in (0, 1):
                for bz in (0, 1):
                    term = axis_weights[0][bx] * axis_weights[1][by] * axis_weights[2][bz] * features[corner]
                    blended = term if blended is None else blended + term
                    corner += 1

        # (levels, n, nf) -> (n, levels * nf)
        encoded = blended.permute(1, 0, 2).reshape(samples.shape[0], self.nf * self.loc_enc_dim)
        hidden = self.activation1(self.linear1(torch.cat((samples[:, 0:6], encoded), 1)))
        return self.activation2(self.linear2(hidden))

# Smoke test: a tiny model (hidden width 10, 32-row hash table) evaluated on the test samples.
# The input width is 6 sample columns plus 2 features per resolution.
test_model = Model(6 + 2 * loc_enc_dim, 10, 32, loc_enc_dim).to(dev)
test_output = test_model.forward(test_samples, dev)
print(test_output.shape)

In [None]:
def render_pixels(model_output, samples):
    """Composite per-sample model outputs along each ray into one luminance per ray.

    Args:
        model_output: (n_rays, n_samples, 2) tensor; channel 0 is the density
            and channel 1 the luminance of each sample.
        samples: (n_rays, n_samples, 8) tensor, sorted by depth along axis 1;
            column 6 is the depth and column 7 the ray index.

    Returns:
        (luminance, ray_idx): both of shape (n_rays,).
    """
    # Distance to the next sample; the last sample gets an effectively infinite extent.
    dists = torch.cat((samples[:, 1:, 6] - samples[:, :-1, 6], torch.ones_like(samples[:, 0:1, 6]) * 1e10), 1)
    # Opacity of each sample.
    alphas = 1. - torch.exp(-model_output[:, :, 0] * dists)
    # Transmittance reaching sample i is the product of (1 - alpha_j) over the
    # samples *in front of* it (j < i), i.e. an exclusive cumulative product.
    # An inclusive product would attenuate every sample by its own opacity and
    # drive the weight of a fully opaque sample to ~0.
    transmittance = torch.cumprod(1. - alphas + 1e-10, 1)
    transmittance = torch.cat((torch.ones_like(transmittance[:, 0:1]), transmittance[:, :-1]), 1)
    weights = alphas * transmittance
    ret = torch.sum(weights * model_output[:, :, 1], 1)
    return ret, samples[:, 0, 7]

# Smoke test: regroup the flat model outputs and samples per ray
# (test_N_samples entries each) and composite them into one value per ray.
test_luminances, test_ray_idxs = render_pixels(test_output.view(-1, test_N_samples, 2), test_samples.view(-1, test_N_samples, 8))
print(test_luminances)
print(test_luminances.shape)
print(test_ray_idxs)

In [None]:
def draw_image(pixels, ray_idxs, rays, H, W, dev):
    """Scatter per-ray pixel values into an H x W image.

    Args:
        pixels: 1-D tensor with one value per entry of ``ray_idxs``.
        ray_idxs: 1-D tensor of row indices into ``rays`` (cast to int64).
        rays: tensor whose column 6 is the pixel x (column) and column 7 the
            pixel y (row) of each ray.
        H, W: output image size.
        dev: device of the returned image.

    Returns:
        (H, W) float32 tensor; pixels not hit by any ray stay 0.
    """
    canvas = torch.zeros((H, W), dtype=torch.float32, device=dev)
    order = torch.arange(ray_idxs.shape[0], dtype=torch.int64, device=dev)
    hit_rays = rays[ray_idxs.long()]
    rows = hit_rays[:, 7].long()
    cols = hit_rays[:, 6].long()
    canvas[rows, cols] = pixels[order]
    return canvas

def pred_pixels_hierarchical(rays, near, far, N_samples1, N_samples2, loc_enc_dim, model, dev, rand=True):
    """Render one luminance per ray using two sampling passes.

    Pass 1 evaluates ``model`` at ``N_samples1`` evenly spaced depths in
    [near, far]. The resulting per-sample weights define a distribution from
    which ``N_samples2`` further depths are drawn by inverse-CDF sampling.
    Both sample sets are merged, sorted by depth and composited by
    ``render_pixels``.

    Args:
        rays: (n_rays, >=9) ray tensor as produced by ``ray_origins_directions``.
        near, far: depth range of the first pass.
        N_samples1, N_samples2: samples per ray in the first / second pass.
        loc_enc_dim: unused here; kept for interface compatibility.
        model: callable ``model(samples, dev)`` returning (n, 2) outputs.
        dev: device for newly created tensors.
        rand: draw the second-pass quantiles at random (sorted) when True,
            evenly spaced in [0, 1] when False.

    Returns:
        (pixels, ray_idxs): both of shape (n_rays,).
    """
    n_rays = rays.shape[0]

    # initial samples; z_vals1 has shape (N_samples1, n_rays, 1)
    z_vals1 = torch.linspace(near, far, N_samples1, dtype=torch.float32, device=dev).repeat(n_rays, 1, 1).permute(2, 0, 1)
    samples1 = sample_rays(rays, z_vals1)

    # first preds
    model_output1 = model(samples1, dev)
    model_output1 = model_output1.view(n_rays, N_samples1, 2)
    samples1 = samples1.view(n_rays, N_samples1, 8)

    # resampling distribution
    dists1 = torch.cat((samples1[:, 1:, 6] - samples1[:, :-1, 6], torch.ones_like(samples1[:, 0:1, 6]) * 1e10), 1)
    densities1 = 1. - torch.exp(-model_output1[:, :, 0] * dists1)
    # NOTE(review): this cumulative product is inclusive (it contains the
    # sample's own 1 - alpha), unlike the usual exclusive transmittance. It is
    # left unchanged because here it only shapes the resampling distribution;
    # confirm whether that is intended.
    weights1 = densities1 * torch.cumprod(1. - densities1 + 1e-10, 1)
    # Floor the weights so every interval keeps a non-zero sampling probability.
    weights1_plus = weights1 + torch.mean(weights1, 1)[:, None] * 0.1 + 0.01
    probs = torch.div(weights1_plus, torch.sum(weights1_plus, 1)[:, None])
    distribution = torch.cumsum(probs, 1)
    distribution = torch.cat((torch.zeros_like(distribution[:, 0:1]), distribution), 1)

    # resampling
    if rand:
        u = torch.empty((n_rays, N_samples2), dtype=torch.float32, device=dev).uniform_(0, 1)
        u, _ = torch.sort(u, 1)
    else:
        u = torch.linspace(0, 1, N_samples2, dtype=torch.float32, device=dev).repeat(n_rays, 1)
    ind = torch.searchsorted(distribution, u, right=True)
    last_idx = N_samples1 - 1
    low_idxs = torch.clamp(ind - 1, 0, last_idx)
    high_idxs = torch.clamp(ind, 0, last_idx)
    # Index the trailing unit axis explicitly: a bare squeeze() would also
    # drop the ray axis when n_rays == 1 and make the permute below fail.
    z_by_ray = z_vals1[..., 0].permute(1, 0)  # (n_rays, N_samples1)
    low_vals = torch.gather(z_by_ray, 1, low_idxs)
    high_vals = torch.gather(z_by_ray, 1, high_idxs)
    low_cdfs = torch.gather(distribution, 1, low_idxs)
    high_cdfs = torch.gather(distribution, 1, high_idxs)
    # Position of each quantile inside its interval, guarded against tiny denominators.
    proportions = torch.div(u - low_cdfs, torch.maximum(torch.ones_like(low_cdfs) * 0.01, high_cdfs - low_cdfs))
    proportions = torch.maximum(torch.zeros_like(proportions), torch.minimum(torch.ones_like(proportions), proportions))
    z_vals2 = (proportions * (high_vals - low_vals) + low_vals).permute(1, 0)[:, :, None]
    samples2 = sample_rays(rays, z_vals2)

    # second preds
    model_output2 = model(samples2, dev)
    model_output2 = model_output2.view(n_rays, N_samples2, 2)
    samples2 = samples2.view(n_rays, N_samples2, 8)

    # merge samples and sort them by depth along each ray
    model_outputs = torch.cat((model_output1, model_output2), 1)
    samples = torch.cat((samples1, samples2), 1)
    z_vals = torch.cat((z_vals1, z_vals2), 0)
    _, sort_idx = torch.sort(z_vals, 0)
    sort_idx = sort_idx[..., 0].permute(1, 0)  # (n_rays, N_samples1 + N_samples2)
    model_outputs = torch.gather(model_outputs, 1, sort_idx.unsqueeze(2).repeat(1, 1, 2))
    samples = torch.gather(samples, 1, sort_idx.unsqueeze(2).repeat(1, 1, 8))

    # render
    pixels, ray_idxs = render_pixels(model_outputs, samples)
    return pixels, ray_idxs

def pred_image_hierarchical(pose, H, W, focal, near, far, N_samples1, N_samples2, loc_enc_dim, model, dev, rand=True):
    """Render a full H x W image from ``pose`` with the two-pass sampler.

    Generates one ray per pixel, predicts a luminance for each ray and writes
    the values into an image. ``rand`` controls both the sub-pixel jitter of
    the rays and the randomness of the second sampling pass.

    Returns:
        (H, W) float32 tensor on ``dev``.
    """
    image_rays = ray_origins_directions(H, W, focal, pose, rand)
    luminances, ray_ids = pred_pixels_hierarchical(
        image_rays, near, far, N_samples1, N_samples2, loc_enc_dim, model, dev, rand)
    return draw_image(luminances, ray_ids, image_rays, H, W, dev)

# Smoke test: render a full-size image from the test pose with the untrained test model,
# first without randomness (rand=False) ...
test_pred = pred_image_hierarchical(test_pose, H, W, focal, near, far, N_samples1, N_samples2, loc_enc_dim, test_model, dev, rand=False)
plt.imshow(test_pred.cpu().detach().numpy(), cmap=cm.Greys_r)
plt.show()
# ... then with jittered rays and random resampling (rand=True).
test_pred = pred_image_hierarchical(test_pose, H, W, focal, near, far, N_samples1, N_samples2, loc_enc_dim, test_model, dev, rand=True)
plt.imshow(test_pred.cpu().detach().numpy(), cmap=cm.Greys_r)
plt.show()

In [None]:
# Drop every name created by the smoke-test cells above, then print the CUDA memory still allocated.
del test_rays, test_samples, test_all_corners, test_all_hashes, test_model, test_output, test_luminances, test_pose, test_pred, test_ray_idxs, test_z_vals, test_N_samples
print(torch.cuda.memory_allocated())

In [None]:
# Build the full-size model (hidden width 64, 16384-row hash table) and start an empty loss history.
model = Model(6 + 2 * loc_enc_dim, 64, 16384, loc_enc_dim).to(dev)
all_losses = []
# Resume from the checkpoint only when it exists, so a fresh "Restart & Run All"
# (before the save cell has ever run) does not fail. map_location lets a
# checkpoint saved on another device load onto `dev`.
if os.path.exists("./checkpoint_ingp_nerf.pt"):
    model.load_state_dict(torch.load("./checkpoint_ingp_nerf.pt", map_location=dev))
else:
    print("No checkpoint found at ./checkpoint_ingp_nerf.pt; starting from a freshly initialised model")
# Preview render of training pose 1; no gradients are needed for display.
with torch.no_grad():
    pred = pred_image_hierarchical(poses[1], H, W, focal, near, far, N_samples1, N_samples2, loc_enc_dim, model, dev)
plt.imshow(pred.cpu().detach().numpy(), cmap=cm.Greys_r)
plt.show()
del pred

In [None]:
def create_dataset(images, poses, H, W, focal, rand):
    """Build the per-ray training set for all images.

    For every image, one ray per pixel is generated from the matching pose and
    the pixel's target value (the maximum over the image's last axis) is
    appended as an extra column.

    Args:
        images: tensor of images indexed along axis 0; each image has the
            channel axis last.
        poses: poses indexed in step with ``images``.
        H, W, focal: forwarded to ``ray_origins_directions``.
        rand: jitter the rays and shuffle the rows of the result when True.

    Returns:
        (n_images * H * W, 10) tensor: the 9 ray columns followed by the target.
    """
    def rays_with_target(idx):
        # Column vector of targets in the same row-major order as the rays.
        target = torch.amax(images[idx], dim=2).view(-1, 1)
        rays = ray_origins_directions(H, W, focal, poses[idx], rand)
        return torch.cat((rays, target), 1)

    all_rays = torch.cat([rays_with_target(idx) for idx in range(images.shape[0])], 0)
    if not rand:
        return all_rays
    return all_rays[torch.randperm(all_rays.shape[0])]

def train(model, dev, epochs, images, poses, H, W, focal, near, far, N_samples1, N_samples2, loc_enc_dim, batch_size=2**14):
    """Train ``model`` with Adam on a mean-squared luminance error.

    Every epoch rebuilds a freshly jittered and shuffled ray dataset and walks
    through it in consecutive, non-overlapping batches of ``batch_size`` rays
    (a trailing partial batch is skipped). Prints one "." per batch and a
    summary line with the epoch number and duration per epoch.

    Args:
        model: the network to optimise in place.
        dev: device for newly created tensors.
        epochs: number of passes over the dataset.
        images, poses, H, W, focal: training data, forwarded to ``create_dataset``.
        near, far, N_samples1, N_samples2, loc_enc_dim: forwarded to
            ``pred_pixels_hierarchical``.
        batch_size: rays per optimisation step; must be a multiple of
            ``N_samples1 + N_samples2``.

    Returns:
        NumPy array with the mean batch loss of each epoch.
    """
    assert batch_size % (N_samples1 + N_samples2) == 0
    optimizer = torch.optim.Adam(model.parameters(), lr=10e-3, betas=[0.9, 0.99], eps=10e-15)
    avg_losses = torch.zeros(0, dtype=torch.float32, device=dev)
    for e in range(epochs):
        t_start = time.time()
        losses = torch.zeros(0, dtype=torch.float32, device=dev)
        dataset = create_dataset(images, poses, H, W, focal, rand=True)
        for idx in range(dataset.shape[0] // batch_size):
            # Step through the dataset one whole batch at a time. Slicing from
            # `idx` itself would shift the window by a single ray per step, so
            # consecutive batches would overlap almost completely.
            start = idx * batch_size
            rays = dataset[start:start + batch_size]
            optimizer.zero_grad()
            pred, _ = pred_pixels_hierarchical(rays, near, far, N_samples1, N_samples2, loc_enc_dim, model, dev)
            # Column 9 holds the target value appended by create_dataset.
            loss = torch.mean(torch.square(rays[:, 9] - pred))
            loss.backward()
            optimizer.step()
            # Log a detached copy so the history does not hold on to autograd graphs.
            losses = torch.cat((losses, loss.detach()[None]), 0)
            print(".", end="")
        avg_loss = torch.mean(losses)
        avg_losses = torch.cat((avg_losses, avg_loss[None]), 0)
        print(f" {e + 1} @ {time.time() - t_start:.1f}s")
    return avg_losses.cpu().detach().numpy()

# Train for `epochs` epochs and append the per-epoch mean losses to the running history.
epochs = 1
all_losses.extend(train(model, dev, epochs, images, poses, H, W, focal, near, far, N_samples1, N_samples2, loc_enc_dim))
# Negative log of the loss: higher is better.
plt.plot(-np.log(all_losses))
plt.show()
# Compare a render of training pose 1 with its target image. The render is only
# displayed, so it runs without autograd (as the video cell does) to save memory.
with torch.no_grad():
    pred = pred_image_hierarchical(poses[1], H, W, focal, near, far, N_samples1, N_samples2, loc_enc_dim, model, dev)
target = torch.amax(images[1], 2)
plt.imshow(pred.cpu().detach().numpy(), cmap=cm.Greys_r)
plt.show()
plt.imshow(target.cpu().detach().numpy(), cmap=cm.Greys_r)
plt.show()
del pred, target
print(torch.cuda.memory_allocated())

In [None]:
# Persist the model weights; the model-construction cell loads this same path.
torch.save(model.state_dict(), "./checkpoint_ingp_nerf.pt")

In [None]:
def generate_camera_trajectory(x0, y0, z0, dist, n_points):
    """Generate camera poses evenly spaced on a horizontal circle.

    The circle has radius ``dist``, lies at height ``z0`` and is centred on
    (x0, y0). Each pose is a 4x4 matrix whose rotation is built from three
    angles: a fixed quarter turn about x, no turn about y, and a turn about z
    that advances by one full revolution over the trajectory.

    Args:
        x0, y0, z0: centre of the circle.
        dist: radius of the circle.
        n_points: number of poses.

    Returns:
        NumPy array of shape (n_points, 4, 4).
    """
    # The x and y angles are the same for every pose.
    phi_x, phi_y = np.pi / 2, 0
    cx, sx = np.cos(phi_x), np.sin(phi_x)
    cy, sy = np.cos(phi_y), np.sin(phi_y)

    trajectory = []
    for step in range(n_points):
        x = x0 + dist * np.cos(step * np.pi * 2 / n_points)
        y = y0 + dist * np.sin(step * np.pi * 2 / n_points)
        z = z0 + 0
        phi_z = np.pi / 2 + np.pi * 2 * step / n_points
        cz, sz = np.cos(phi_z), np.sin(phi_z)
        rotation_rows = (
            [cz*cy, cz*sy*sx-sz*cx, cz*sy*cx+sz*sx],
            [sz*cy, sz*sy*sx+cz*cx, sz*sy*cx-cz*sx],
            [-sy,   cy*sx,          cy*cx],
        )
        # Append the translation as the fourth column, then the homogeneous row.
        pose = [row + [offset] for row, offset in zip(rotation_rows, (x, y, z))]
        pose.append([0, 0, 0, 1])
        trajectory.append(pose)
    return np.array(trajectory)

def generate_video(poses, model, dev, H, W, focal, near, far, N_samples1, N_samples2, loc_enc_dim, samples=1):
    """Render one frame per pose and save the animation as ``movie_ingp_nerf.mp4``.

    Each frame is the average of ``samples`` renders of the same pose: the
    first render is deterministic (rand=False), the remaining ones use
    jittered rays and random resampling. Prints a "." per frame and "Done"
    at the end.

    Args:
        poses: sequence of 4x4 pose matrices (converted to float32 tensors on ``dev``).
        model, dev, H, W, focal, near, far, N_samples1, N_samples2, loc_enc_dim:
            forwarded to ``pred_image_hierarchical``.
        samples: number of renders averaged per frame.
    """
    frame_artists = []
    fig = plt.figure()
    for frame_idx in range(len(poses)):
        pose = torch.tensor(poses[frame_idx], dtype=torch.float32, device=dev)
        accumulated = np.zeros((H, W))
        for render_idx in range(samples):
            # Only the first render of a frame is deterministic.
            use_rand = render_idx != 0
            render = pred_image_hierarchical(pose, H, W, focal, near, far, N_samples1, N_samples2, loc_enc_dim, model, dev, use_rand)
            accumulated += render.cpu().detach().numpy()
        artist = plt.imshow(accumulated / samples, cmap=cm.Greys_r, animated=True)
        frame_artists.append([artist])
        print(".", end="")
    ani = animation.ArtistAnimation(fig, frame_artists, interval=50, blit=True)
    ani.save("movie_ingp_nerf.mp4")
    plt.close()
    print("Done")

# Render the fly-around video without tracking gradients.
with torch.no_grad():
    # 100 poses on a circle of radius 4 at height 0.5 around the origin.
    video_poses = generate_camera_trajectory(x0=0, y0=0, z0=0.5, dist=4, n_points=100)
    plot_poses(torch.tensor(video_poses, dtype=torch.float32, device=dev))
    # Use twice as many second-pass samples as in training and average 10 renders per frame.
    generate_video(video_poses, model, dev, H, W, focal, near, far, N_samples1, N_samples2 * 2, loc_enc_dim, samples=10)
del video_poses
print(torch.cuda.memory_allocated())

In [None]:
# Embed the video file written by generate_video in the notebook output.
HTML("""
    <video alt="test" controls>
        <source src="movie_ingp_nerf.mp4" type="video/mp4">
    </video>
""")