In [None]:
import os
import time
import numpy as np
import torch
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import matplotlib.animation as animation
from mpl_toolkits.mplot3d import Axes3D
from matplotlib.patches import FancyArrowPatch
from mpl_toolkits.mplot3d import proj3d
from IPython.display import HTML

In [None]:
# Run on the current CUDA device when one is available; otherwise fall back to the CPU.
if not torch.cuda.is_available():
    dev = "cpu"
else:
    cuda_index = torch.cuda.current_device()
    print(torch.cuda.get_device_name(cuda_index))
    dev = f"cuda:{cuda_index}"
print(dev)

In [None]:
loc_enc_dim = 5   # number of sin/cos frequency bands applied to sample locations in encode_input
N_samples = 96    # depth samples taken along every ray
near = 2.0        # first depth of the sampled range along each ray
far = 6.0         # last (un-jittered) depth of the sampled range along each ray

# Ray/depth jitter (np.random, torch.uniform_) and weight init are random: seed both
# libraries so a Restart & Run All reproduces the same figures and losses.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
class Arrow3D(FancyArrowPatch):
    """A FancyArrowPatch whose two endpoints are specified in 3D data coordinates.

    The endpoints ``(xs[0], ys[0], zs[0])`` -> ``(xs[1], ys[1], zs[1])`` are projected
    to 2D with the owning 3D axes' projection matrix whenever the figure is drawn.
    Extra positional/keyword arguments are forwarded to ``FancyArrowPatch``.

    Recent Matplotlib releases (3.5+) call ``do_3d_projection`` on every patch of an
    Axes3D (using its return value for depth sorting) and no longer expose the
    projection matrix as ``renderer.M``; the matrix is therefore read from
    ``self.axes.M``, which is set by ``ax.add_artist``.
    """

    def __init__(self, xs, ys, zs, *args, **kwargs):
        super().__init__((0, 0), (0, 0), *args, **kwargs)
        self._verts3d = xs, ys, zs

    def _project(self):
        """Project the 3D endpoints, update the 2D arrow, and return the projected depths."""
        xs3d, ys3d, zs3d = self._verts3d
        xs, ys, zs = proj3d.proj_transform(xs3d, ys3d, zs3d, self.axes.M)
        self.set_positions((xs[0], ys[0]), (xs[1], ys[1]))
        return zs

    def do_3d_projection(self, renderer=None):
        # Axes3D sorts artists by this value; use the smaller projected depth of the endpoints.
        return np.min(self._project())

    def draw(self, renderer):
        self._project()
        super().draw(renderer)

def plot_poses(poses):
    """Scatter camera positions in 3D and draw each camera's local axes as arrows.

    ``poses`` is a tensor whose last two dims are 4x4 pose matrices: column 3 holds the
    position, the upper-left 3x3 block the rotation. For every pose three arrows start
    at the position: the rotation applied to (0, 0, -1) in red, to (1, 0, 0) in green
    and to (0, -1, 0) in blue.
    """
    plt.close()
    fig = plt.figure()
    ax = fig.add_subplot(projection="3d")
    pose_arr = poses.cpu().numpy()
    ax.scatter(pose_arr[..., 0, 3], pose_arr[..., 1, 3], pose_arr[..., 2, 3], marker="o")
    # (local axis, colour) pairs, in the order the arrows are added.
    axis_styles = (([0, 0, -1], "r"), ([1, 0, 0], "g"), ([0, -1, 0], "b"))
    for pose in pose_arr:
        rot = pose[0:3, 0:3]
        px, py, pz = pose[0, 3], pose[1, 3], pose[2, 3]
        for local_axis, colour in axis_styles:
            dx, dy, dz = np.matmul(rot, local_axis)
            ax.add_artist(Arrow3D([px, px + dx], [py, py + dy], [pz, pz + dz],
                                  mutation_scale=20, lw=1, arrowstyle="-|>", color=colour))
    for set_label, label in ((ax.set_xlabel, "x"), (ax.set_ylabel, "y"), (ax.set_zlabel, "z")):
        set_label(label)
    plt.show()
    plt.close()

In [None]:
def load_data(path='tiny_nerf_data.npz'):
    """Load the tiny NeRF dataset, show one held-aside image, and move training data to ``dev``.

    Args:
        path: location of the ``.npz`` archive with ``images``, ``poses`` and ``focal`` arrays.
            Source: http://cseweb.ucsd.edu/~viscomp/projects/LF/papers/ECCV20/nerf/tiny_nerf_data.npz

    Returns:
        images_: float32 tensor on ``dev`` with the first 100 images, first 3 channels kept.
        poses_: float32 tensor on ``dev`` with the matching first 100 poses.
        focal: the ``focal`` array exactly as stored in the archive.
        H, W: image height and width.
    """
    # Use the NpzFile as a context manager so its underlying file handle is closed;
    # indexing it reads each array fully into memory, so the arrays outlive the `with`.
    with np.load(path) as data:
        images = data['images']
        poses = data['poses']
        focal = data['focal']
    H, W = images.shape[1:3]
    print(images.shape, poses.shape, focal)

    # Image 101 is only displayed as a visual sanity check; the first 100 are used for training.
    testimg = images[101]
    images = images[:100, ..., :3]
    poses = poses[:100]

    plt.imshow(testimg)
    plt.show()
    images_ = torch.tensor(images, dtype=torch.float32, device=dev)
    poses_ = torch.tensor(poses, dtype=torch.float32, device=dev)
    return images_, poses_, focal, H, W
    
# Load the training images/poses onto `dev`, plot the camera poses, then print the
# number of bytes currently allocated by tensors on the CUDA device.
images, poses, focal, H, W = load_data()
plot_poses(poses)
print(torch.cuda.memory_allocated())

In [None]:
def ray_origins_directions(H, W, focal, pose, rand=True):
    """Compute one ray (origin, direction) per pixel for a pinhole camera.

    Args:
        H, W: image height and width in pixels.
        focal: focal length in pixels.
        pose: 4x4 camera-to-world tensor; ``pose[:3, :3]`` rotates camera-frame directions
            into world frame and ``pose[:3, -1]`` is the camera position. All new tensors are
            created on ``pose.device`` (no dependence on a global device variable).
        rand: if True, shift all pixel coordinates by one random sub-pixel offset per axis
            (shared by every pixel, drawn with ``np.random``); otherwise sample pixel centres.

    Returns:
        origs: (H, W, 3) camera position broadcast to every pixel (an expanded view).
        dirs: (H, W, 3) world-frame ray directions (not normalised).
    """
    device = pose.device
    if rand:
        # x offset is drawn before y offset.
        off_x = np.random.uniform()
        off_y = np.random.uniform()
    else:
        off_x = off_y = 0.5
    cols = torch.arange(W, dtype=torch.float32, device=device) + off_x
    rows = torch.arange(H, dtype=torch.float32, device=device) + off_y
    x, y = torch.meshgrid(cols, rows, indexing="xy")
    # Camera looks along -z; image rows grow downwards, hence the negated y term.
    dirs = torch.stack(
        [(x - W * 0.5) * (1 / focal), -(y - H * 0.5) * (1 / focal), -torch.ones_like(x)], -1)
    # Per-pixel matrix-vector product R @ d_cam.
    dirs = torch.sum(dirs[..., None, :] * pose[:3, :3], -1)
    origs = pose[:3, -1].expand(dirs.shape)
    return origs, dirs

def plot_origs_dirs(origs, dirs):
    """Scatter ray origins in 3D and draw each ray direction as a red arrow from its origin.

    Both inputs are flattened to (-1, 3). The view is a 2x2x2 box centred on the first
    origin (ray_origins_directions gives every ray of one pose the same origin).
    """
    plt.close()
    o_flat = origs.cpu().view(-1, 3)
    d_flat = dirs.cpu().view(-1, 3)
    fig = plt.figure()
    ax = fig.add_subplot(projection="3d")
    ax.scatter(o_flat[..., 0], o_flat[..., 1], o_flat[..., 2], marker="o")
    for start, step in zip(o_flat, d_flat):
        xs, ys, zs = ([start[k], start[k] + step[k]] for k in range(3))
        arrow = Arrow3D(xs, ys, zs, mutation_scale=20, lw=1, arrowstyle="-|>", color="r")
        ax.add_artist(arrow)
    ax.set_xlabel("x")
    ax.set_ylabel("y")
    ax.set_zlabel("z")
    first = o_flat[0]
    ax.set_xlim3d(first[0] - 1, first[0] + 1)
    ax.set_ylim3d(first[1] - 1, first[1] + 1)
    ax.set_zlim3d(first[2] - 1, first[2] + 1)
    plt.show()
    plt.close()

# Sanity check on a tiny 2x3 image (focal 5): a pose at the origin whose rotation block
# swaps the x and z axes. Print the shapes of the per-pixel origins/directions and plot them.
test_pose = torch.tensor([
    [0, 0, 1, 0],
    [0, 1, 0, 0],
    [1, 0, 0, 0],
    [0, 0, 0, 1]
], dtype=torch.float32, device=dev)
test_origs, test_dirs = ray_origins_directions(2, 3, 5, test_pose)
print(test_origs.shape)
print(test_dirs.shape)
plot_origs_dirs(test_origs, test_dirs)

In [None]:
def sample_rays(origs, dirs, near, far, N_samples, dev, rand=True):
    """Sample points at N_samples depths along every ray.

    Args:
        origs, dirs: ray origins and directions with matching shape (..., 3).
        near, far: depth range; depths start as ``linspace(near, far, N_samples)``.
        N_samples: number of depths per ray.
        dev: device for the depth tensor.
        rand: if True, add an independent U(0, 1) * (far - near) / N_samples jitter to each
            depth (shared by all rays). Note the last depth can exceed ``far`` by up to
            that amount.

    Returns:
        points: (N_samples, ..., 3) tensor, ``origs + dirs * z`` for each depth z.
        sample_dirs: (N_samples, ..., 3) copy of ``dirs`` for every depth.
        z_vals: (N_samples,) depths used.
    """
    z_vals = torch.linspace(near, far, N_samples, dtype=torch.float32, device=dev)
    if rand:
        z_vals += torch.empty(z_vals.shape, dtype=torch.float32, device=dev).uniform_(0, 1) * (far - near) / N_samples
    # Broadcast the depths over all ray dims at once instead of looping per depth in Python.
    ones = [1] * dirs.dim()
    z_b = z_vals.reshape(-1, *ones)
    points = origs.unsqueeze(0) + dirs.unsqueeze(0) * z_b
    # Rank-generic repeat (the original hard-coded three ray dims).
    sample_dirs = dirs.unsqueeze(0).repeat(N_samples, *ones)
    return points, sample_dirs, z_vals

def plot_rays(locs, dirs):
    """Scatter sample locations in 3D and draw each associated direction as a red arrow.

    Both inputs are flattened to (-1, 3) and paired element-wise.
    """
    plt.close()
    pts = locs.cpu().view(-1, 3)
    vecs = dirs.cpu().view(-1, 3)
    fig = plt.figure()
    ax = fig.add_subplot(projection="3d")
    ax.scatter(pts[..., 0], pts[..., 1], pts[..., 2], marker="o")
    for point, vec in zip(pts, vecs):
        xs, ys, zs = ([point[k], point[k] + vec[k]] for k in range(3))
        ax.add_artist(Arrow3D(xs, ys, zs, mutation_scale=20, lw=1, arrowstyle="-|>", color="r"))
    for set_label, label in ((ax.set_xlabel, "x"), (ax.set_ylabel, "y"), (ax.set_zlabel, "z")):
        set_label(label)
    plt.show()
    plt.close()

# Take two depth samples along each test ray, print the resulting shapes and plot the
# sample points with their (repeated) ray directions.
test_N_samples = 2
test_sample_locations, test_sample_dirs, test_z_vals = sample_rays(test_origs, test_dirs, near, far, test_N_samples, dev)
print(test_sample_locations.shape)
print(test_sample_dirs.shape)
print(test_z_vals.shape)
plot_rays(test_sample_locations, test_sample_dirs)

In [None]:
def encode_input(locations, dirs, L):
    """Build model features: raw directions, raw locations, then sin/cos positional encodings.

    Feature order along the last dim (kept stable for saved checkpoints):
    ``[dirs, locations, sin(2^0 x), cos(2^0 x), ..., sin(2^(L-1) x), cos(2^(L-1) x)]``
    where ``x = locations``. Only locations are frequency-encoded.

    Args:
        locations, dirs: tensors of identical shape (..., 3).
        L: number of frequency bands.

    Returns:
        tensor of shape (..., 6 + 6 * L).
    """
    # Collect all pieces and concatenate once (repeated torch.cat copies quadratically);
    # dim=-1 makes this work for any number of leading dims.
    features = [dirs, locations]
    for i in range(L):
        scaled = 2. ** i * locations
        features.append(torch.sin(scaled))
        features.append(torch.cos(scaled))
    return torch.cat(features, dim=-1)

# Encode the test samples; the last dimension should be 6 + 6 * loc_enc_dim.
test_model_input = encode_input(test_sample_locations, test_sample_dirs, loc_enc_dim)
print(test_model_input.shape)

In [None]:
class Model(torch.nn.Module):
    """Four-layer ReLU MLP mapping encoded features to 2 non-negative outputs per sample.

    ``render_pixels`` reads output channel 0 as density and channel 1 as luminance.
    """

    def __init__(self, in_dim, hidden_dim):
        super().__init__()
        self.in_dim = in_dim
        # Attribute names are the state_dict keys of saved checkpoints; keep them stable.
        self.linear1 = torch.nn.Linear(self.in_dim, hidden_dim)
        self.activation1 = torch.nn.ReLU()
        self.linear2 = torch.nn.Linear(hidden_dim, hidden_dim)
        self.activation2 = torch.nn.ReLU()
        self.linear3 = torch.nn.Linear(hidden_dim, hidden_dim)
        self.activation3 = torch.nn.ReLU()
        self.linear4 = torch.nn.Linear(hidden_dim, 2)
        # Final ReLU keeps both outputs >= 0.
        self.activation4 = torch.nn.ReLU()

    def forward(self, x):
        """Apply the MLP over the last dimension of ``x``.

        Args:
            x: tensor of shape (..., in_dim); any number of leading dims.

        Returns:
            tensor of shape (..., 2).

        Raises:
            ValueError: if the last dimension of ``x`` is not ``in_dim``.
        """
        if x.shape[-1] != self.in_dim:
            raise ValueError(f"expected last dim {self.in_dim}, got shape {tuple(x.shape)}")
        lead_shape = x.shape[:-1]
        x = x.reshape(-1, self.in_dim)
        x = self.activation1(self.linear1(x))
        x = self.activation2(self.linear2(x))
        x = self.activation3(self.linear3(x))
        x = self.activation4(self.linear4(x))
        # Restore the leading dims (previously hard-coded to exactly three).
        return x.reshape(*lead_shape, x.shape[-1])

# Run a small (hidden_dim=10) untrained model on the encoded test samples and print the
# output shape: the leading sample/pixel dims of the input with 2 output channels.
test_model = Model(6 + 6 * loc_enc_dim, 10).to(dev)
test_output = test_model(test_model_input)
print(test_output.shape)

In [None]:
def render_pixels(model_output, z_vals):
    """Alpha-composite per-sample (density, luminance) predictions along each ray.

    Args:
        model_output: tensor (N_samples, ..., 2); channel 0 is density, channel 1 luminance.
        z_vals: (N_samples,) depths shared by every ray.

    Returns:
        tensor (...) with one composited luminance per ray.

    Note: the transmittance product previously ran along dim -1 (image width) and was
    not exclusive; it now accumulates along the sample axis (dim 0), as the volume
    rendering equation requires. Checkpoints trained with the old renderer may need
    retraining.
    """
    sigma = model_output[..., 0]
    luminance = model_output[..., 1]
    # Interval length between consecutive samples; the last interval is effectively infinite.
    dists = torch.cat((z_vals[1:] - z_vals[:-1], torch.ones_like(z_vals[-1:]) * 1e10))
    dists = dists.reshape(-1, *([1] * (sigma.dim() - 1)))
    alphas = 1. - torch.exp(-sigma * dists)
    # Exclusive transmittance T_i = prod_{j<i} (1 - alpha_j) along the samples (dim 0).
    trans = torch.cumprod(1. - alphas + 1e-10, dim=0)
    trans = torch.cat((torch.ones_like(trans[:1]), trans[:-1]), dim=0)
    weights = alphas * trans
    return torch.sum(weights * luminance, dim=0)
    
# Composite the test model output along the sampled depths into one value per pixel.
test_luminances = render_pixels(test_output, test_z_vals)
print(test_luminances)
print(test_luminances.size())

In [None]:
def pred_image(pose, H, W, focal, near, far, N_samples, loc_enc_dim, model, dev, rand=True):
    """Render an H x W luminance image of ``model`` seen from camera ``pose``.

    Pipeline: per-pixel rays -> depth samples along each ray -> positional encoding ->
    model -> alpha compositing. ``rand`` switches on sub-pixel and depth jitter in both
    the ray generation and the depth sampling steps.
    """
    ray_o, ray_d = ray_origins_directions(H, W, focal, pose, rand)
    points, point_dirs, depths = sample_rays(ray_o, ray_d, near, far, N_samples, dev, rand)
    raw = model(encode_input(points, point_dirs, loc_enc_dim))
    return render_pixels(raw, depths)

# Render a full H x W image from the untrained test model with depth range [1, 5],
# print the pixel values and show them in grayscale.
# NOTE(review): runs outside torch.no_grad(), so an autograd graph is built (freed by the del cell).
test_pred = pred_image(test_pose, H, W, focal, 1, 5, N_samples, loc_enc_dim, test_model, dev)
print(test_pred.cpu().detach().numpy())
plt.imshow(test_pred.cpu().detach().numpy(), cmap=cm.Greys_r)
plt.show()

In [None]:
# Drop every test-only tensor/model before training, then print the CUDA memory still allocated.
del test_pose, test_origs, test_dirs, test_sample_locations, test_sample_dirs, test_z_vals, test_model_input, test_model, test_output, test_luminances, test_pred, test_N_samples
print(torch.cuda.memory_allocated())

In [None]:
# Build the full-size model and resume from a saved checkpoint when one exists, so the
# notebook also runs top-to-bottom on a fresh kernel before any checkpoint has been saved.
model = Model(6 + 6 * loc_enc_dim, 128).to(dev)
all_losses = []
checkpoint_path = "./checkpoint_nerf_simple.pt"
if os.path.exists(checkpoint_path):
    # map_location lets a checkpoint saved on GPU load on a CPU-only machine (and vice versa).
    # NOTE: torch.load unpickles the file; only load checkpoints from trusted sources.
    model.load_state_dict(torch.load(checkpoint_path, map_location=dev))
else:
    print(f"No checkpoint at {checkpoint_path}; starting from random initialisation.")
# Preview render only: no gradients needed.
with torch.no_grad():
    pred = pred_image(poses[1], H, W, focal, near, far, N_samples, loc_enc_dim, model, dev)
plt.imshow(pred.cpu().detach().numpy(), cmap=cm.Greys_r)
plt.show()
del pred

In [None]:
def train(model, dev, epochs, images, poses, H, W, focal, near, far, N_samples, loc_enc_dim):
    """Fit ``model`` to grayscale versions of ``images`` with Adam, one image per step.

    Args:
        model: network passed to ``pred_image``; updated in place.
        dev: device used for rendering and loss bookkeeping.
        epochs: number of passes over the images.
        images: tensor (N, H, W, C); the target is the per-pixel max over C.
        poses: tensor (N, 4, 4) of matching camera poses.
        H, W, focal, near, far, N_samples, loc_enc_dim: forwarded to ``pred_image``.

    Returns:
        float32 numpy array with the mean per-image MSE of each epoch (empty if epochs == 0).

    A fresh Adam optimizer is created on every call, so optimizer state does not carry
    over between calls.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    epoch_means = []
    for e in range(epochs):
        t_start = time.time()
        image_losses = []
        for idx in range(images.shape[0]):
            image = images[idx]
            pose = poses[idx]
            optimizer.zero_grad()
            pred = pred_image(pose, H, W, focal, near, far, N_samples, loc_enc_dim, model, dev)
            # Grayscale target: per-pixel max over the colour channels.
            target = torch.amax(image, dim=2)
            img_loss = torch.mean(torch.square(target - pred))
            img_loss.backward()
            optimizer.step()
            # Detach: bookkeeping must not keep autograd history; avoid per-step .item() GPU syncs.
            image_losses.append(img_loss.detach())
            print(".", end="")
        if image_losses:
            avg_loss = torch.stack(image_losses).mean()
        else:
            avg_loss = torch.tensor(float("nan"), device=dev)
        epoch_means.append(avg_loss)
        print(f" {e + 1} @ {time.time() - t_start:.1f}s")
    if not epoch_means:
        return np.zeros(0, dtype=np.float32)
    return torch.stack(epoch_means).cpu().numpy()

# Train for a few epochs, plot -log(mean MSE) over all epochs run so far, then compare a
# deterministic render of pose 1 against its grayscale target.
epochs = 1
all_losses.extend(train(model, dev, epochs, images, poses, H, W, focal, near, far, N_samples, loc_enc_dim))
plt.plot(-np.log(all_losses))
plt.xlabel("epoch")
plt.ylabel("-log(mean MSE)")
plt.show()
with torch.no_grad():
    pred = pred_image(poses[1], H, W, focal, near, far, N_samples, loc_enc_dim, model, dev, rand=False)
    # Same grayscale target as train() (per-pixel max over RGB), so the images are comparable.
    target = torch.amax(images[1], dim=2)
plt.imshow(pred.cpu().numpy(), cmap=cm.Greys_r)
plt.show()
plt.imshow(target.cpu().numpy(), cmap=cm.Greys_r)
plt.show()
del pred, target
print(torch.cuda.memory_allocated())

In [None]:
# Persist the trained weights (state_dict only) to the path the model-building cell loads from.
torch.save(model.state_dict(), "./checkpoint_nerf_simple.pt")

In [None]:
def generate_camera_trajectory(x0, y0, z0, dist, n_points):
    """Build n_points 4x4 poses evenly spaced on a horizontal circle around (x0, y0, z0).

    Each pose sits at radius ``dist`` and height ``z0``. Its rotation is Rz(phi_z) @ Ry(0) @ Rx(pi/2),
    with phi_z = pi/2 + angle on the circle, which turns the camera's -z axis towards the
    circle's centre.

    Returns:
        numpy array of shape (n_points, 4, 4).
    """
    # Angles that do not vary along the trajectory.
    cx, sx = np.cos(np.pi / 2), np.sin(np.pi / 2)
    cy, sy = np.cos(0), np.sin(0)

    def pose_at(step):
        theta = step * np.pi * 2 / n_points
        x = x0 + dist * np.cos(theta)
        y = y0 + dist * np.sin(theta)
        z = z0 + 0  # constant height
        cz, sz = np.cos(np.pi / 2 + theta), np.sin(np.pi / 2 + theta)
        return [
            [cz*cy, cz*sy*sx-sz*cx, cz*sy*cx+sz*sx, x],
            [sz*cy, sz*sy*sx+cz*cx, sz*sy*cx-cz*sx, y],
            [-sy,   cy*sx,          cy*cx,          z],
            [0,     0,              0,              1],
        ]

    return np.array([pose_at(step) for step in range(n_points)])

def generate_video(poses, model, dev, H, W, focal, near, far, N_samples, loc_enc_dim, samples=1,
                   filename="movie_nerf_simple.mp4"):
    """Render one grayscale frame per camera pose and save the frames as a movie.

    Args:
        poses: sequence of 4x4 camera-to-world matrices (e.g. from
            ``generate_camera_trajectory``); each is converted to a float32 tensor on ``dev``.
        model, dev, H, W, focal, near, far, N_samples, loc_enc_dim: forwarded to ``pred_image``.
        samples: number of renders averaged per frame. The first uses fixed pixel centres
            and depths (``rand=False``); the others are randomly jittered.
        filename: output path passed to ``ArtistAnimation.save``. NOTE: writing .mp4
            presumably needs an external Matplotlib movie writer (e.g. ffmpeg); verify locally.

    Raises:
        ValueError: if ``samples < 1`` (the per-frame average would divide by zero).

    Call under ``torch.no_grad()`` to avoid building autograd graphs while rendering.
    """
    if samples < 1:
        raise ValueError(f"samples must be >= 1, got {samples}")
    frames = []
    fig = plt.figure()
    ax = fig.add_subplot()
    try:
        for idx in range(len(poses)):
            pose = torch.tensor(poses[idx], dtype=torch.float32, device=dev)
            img = np.zeros((H, W))
            for s in range(samples):
                pred = pred_image(pose, H, W, focal, near, far, N_samples, loc_enc_dim, model, dev,
                                  rand=s > 0)
                img += pred.cpu().detach().numpy()
            frames.append([ax.imshow(img / samples, cmap=cm.Greys_r, animated=True)])
            print(".", end="")
        ani = animation.ArtistAnimation(fig, frames, interval=50, blit=True)
        ani.save(filename)
    finally:
        # Close this specific figure even if rendering or saving fails.
        plt.close(fig)
    print("Done")

# Without autograd: build a 100-pose orbit of radius 4 at height 0.5 around the origin,
# plot the poses, and render the movie averaging 10 renders per frame.
with torch.no_grad():
    video_poses = generate_camera_trajectory(x0=0, y0=0, z0=0.5, dist=4, n_points=100)
    plot_poses(torch.tensor(video_poses, dtype=torch.float32, device=dev))
    generate_video(video_poses, model, dev, H, W, focal, near, far, N_samples, loc_enc_dim, samples=10)
del video_poses
print(torch.cuda.memory_allocated())

In [None]:
# Embed the movie file written by generate_video as an HTML5 video player in the cell output.
HTML("""
    <video alt="test" controls>
        <source src="movie_nerf_simple.mp4" type="video/mp4">
    </video>
""")