In [None]:
import os
import numpy as np
import torch
import matplotlib.pyplot as plt

In [None]:
loc_enc_dim = 5   # number of sin/cos frequency bands used to encode sample locations
N_samples = 15    # number of points sampled along each ray
near = 2.0        # smallest z value sampled along each ray
far = 6.0         # largest (un-jittered) z value sampled along each ray

# Seed both RNGs the notebook uses (np.random for ray jitter, torch for weight
# initialization) so Restart & Run All reproduces the same results.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
def load_data(path='tiny_nerf_data.npz'):
    """Load the tiny NeRF dataset, display one held-out view, and return the training split.

    Data source:
    http://cseweb.ucsd.edu/~viscomp/projects/LF/papers/ECCV20/nerf/tiny_nerf_data.npz

    Args:
        path: location of the .npz archive (defaults to the working directory).

    Returns:
        images: the first 100 images, restricted to their first three channels.
        poses: the first 100 poses, matching `images`.
        focal: the focal length stored in the archive.
        H, W: image height and width.
    """
    # The context manager closes the archive's file handle; indexing an NpzFile
    # reads each array fully into memory, so the arrays stay valid afterwards.
    with np.load(path) as data:
        images = data['images']
        poses = data['poses']
        focal = data['focal']
    H, W = images.shape[1:3]
    print(images.shape, poses.shape, focal)

    # Image 101 is only displayed; the first 100 views form the training set.
    testimg = images[101]
    images = images[:100, ..., :3]
    poses = poses[:100]

    plt.imshow(testimg)
    plt.show()
    print(poses[0])
    return images, poses, focal, H, W

images, poses, focal, H, W = load_data()

In [None]:
def ray_origins_directions(H, W, focal, pose):
    """Compute one ray (origin, direction) per pixel of a pinhole camera.

    Args:
        H, W: image height and width in pixels.
        focal: focal length in pixels.
        pose: camera-to-world matrix (at least 3x4): rotation in pose[:3, :3],
            camera position in pose[:3, 3].

    Returns:
        origs: (H, W, 3) array; every entry is the camera position pose[:3, 3].
        dirs: (H, W, 3) world-space ray directions (not normalized).
    """
    x, y = np.meshgrid(np.arange(W), np.arange(H))
    # Camera-space directions: +x right, +y up, camera looking down -z.
    dirs = np.stack([(x - W * .5) / focal, -(y - H * .5) / focal, -np.ones_like(x)], -1)
    # Rotate into world space: d_world = R @ d_cam. With directions stored as row
    # vectors that is d_cam @ R.T; multiplying by R itself applies the inverse rotation.
    dirs = dirs @ pose[:3, :3].T
    # Since pose[:3, 3] is used as the camera centre, pose is camera-to-world,
    # which is what makes R (not R.T) the correct direction rotation above.
    origs = np.broadcast_to(pose[:3, 3], dirs.shape).copy()
    return origs, dirs

test_origs, test_dirs = ray_origins_directions(2, 2, focal, poses[0])
print(test_origs)
print(test_dirs)
# One origin and one direction per pixel; every origin is the pose's translation column.
assert test_origs.shape == test_dirs.shape == (2, 2, 3)
assert np.allclose(test_origs, poses[0][:3, 3])

In [None]:
def sample_rays(origs, dirs, near, far, N_samples, rand=True):
    """Place N_samples points along every ray at z values between near and far.

    Args:
        origs: ray origins, broadcast-compatible with dirs.
        dirs: ray directions.
        near, far: bounds of the evenly spaced z values.
        N_samples: number of points per ray.
        rand: when True, each z value is independently jittered upward by a
            uniform amount in [0, (far - near) / N_samples). The same jittered
            z values are shared by all rays.

    Returns:
        sample_points: origs + dirs * z for every z, stacked on a new leading axis.
        sample_dirs: dirs repeated N_samples times along a new leading axis.
        z_vals: the (N_samples,) z values used.
    """
    z_vals = np.linspace(near, far, N_samples)
    if rand:
        # `size=` is required: a positional argument is taken as `low`, which
        # yields one shared shift in [1, N_samples) * step instead of per-sample jitter.
        z_vals = z_vals + np.random.uniform(size=z_vals.shape) * (far - near) / N_samples
    sample_points = np.stack([origs + dirs * z_val for z_val in z_vals])
    sample_dirs = np.repeat(np.asarray(dirs)[np.newaxis], N_samples, axis=0)
    return sample_points, sample_dirs, z_vals

test_sample_locations, test_sample_dirs, test_z_vals = sample_rays(test_origs, test_dirs, near, far, 3)
print(test_sample_locations)
print(test_sample_locations.shape)
print(test_sample_dirs)
print(test_sample_dirs.shape)
print(test_z_vals)
# Samples are stacked along a new leading axis: (3 samples,) + ray-grid shape.
assert test_sample_locations.shape == test_sample_dirs.shape == (3,) + test_origs.shape
assert test_z_vals.shape == (3,)

In [None]:
def encode_input(locations, dirs, L):
    """Build per-sample features: raw directions followed by sin/cos location encodings.

    For each frequency index i in range(L), sin(2**i * locations) and then
    cos(2**i * locations) are appended along the last axis. With 3-component
    locations and directions this gives 3 + 6 * L features.
    """
    bands = [fn(2. ** i * locations) for i in range(L) for fn in (np.sin, np.cos)]
    return np.concatenate([dirs, *bands], -1)

test_model_input = encode_input(test_sample_locations, test_sample_dirs, loc_enc_dim)
print(test_model_input[0])
print(test_model_input.shape)
# 3 direction components + sin/cos of 3 location components per frequency band.
assert test_model_input.shape == test_sample_locations.shape[:-1] + (3 + 6 * loc_enc_dim,)

In [None]:
class Model(torch.nn.Module):
    """Three-layer ReLU MLP mapping encoded sample features to 2 non-negative outputs.

    In this notebook, output channel 0 is read as density and channel 1 as
    luminance by the renderer.
    """

    def __init__(self, in_dim, hidden_dim):
        """Args:
            in_dim: size of the last axis of the input features.
            hidden_dim: width of the two hidden layers.
        """
        super().__init__()
        self.in_dim = in_dim
        self.linear1 = torch.nn.Linear(self.in_dim, hidden_dim)
        self.activation1 = torch.nn.ReLU()
        self.linear2 = torch.nn.Linear(hidden_dim, hidden_dim)
        self.activation2 = torch.nn.ReLU()
        self.linear3 = torch.nn.Linear(hidden_dim, 2)
        # Final ReLU keeps both outputs non-negative (density must not be negative).
        self.activation3 = torch.nn.ReLU()

    def forward(self, x):
        """Apply the MLP along the last axis of x.

        Args:
            x: tensor of shape (..., in_dim) with any number of leading axes.

        Returns:
            Tensor of shape (..., 2) with the same leading axes as x.
        """
        leading_shape = x.shape[:-1]
        x = x.reshape(-1, self.in_dim)
        x = self.activation1(self.linear1(x))
        x = self.activation2(self.linear2(x))
        x = self.activation3(self.linear3(x))
        return x.reshape(*leading_shape, -1)

# Small untrained model (hidden width 10) used only to check shapes end to end.
test_model = Model(3 + 6 * loc_enc_dim, 10)
test_output = test_model(torch.tensor(test_model_input, dtype=torch.float32))
print(test_output)
print(test_output.shape)
# One 2-channel output per sample; the final ReLU keeps both channels non-negative.
assert tuple(test_output.shape) == test_model_input.shape[:-1] + (2,)
assert bool((test_output >= 0).all())

In [None]:
def render_pixels(model_output, z_vals):
    """Alpha-composite per-sample model outputs into one luminance value per ray.

    Args:
        model_output: tensor of shape (N_samples, ..., 2). Samples run along
            axis 0; channel 0 is read as density and channel 1 as luminance.
        z_vals: the N_samples sample positions shared by every ray (array-like).

    Returns:
        Tensor of shape model_output.shape[1:-1] with the composited luminance.
    """
    # Interval to the next sample; the last sample gets an effectively infinite one.
    dists = torch.as_tensor(np.append(np.diff(z_vals), 1e10),
                            dtype=model_output.dtype, device=model_output.device)
    # (N_samples, 1, ..., 1) so the intervals broadcast over every ray axis.
    dists = dists.reshape((-1,) + (1,) * (model_output.dim() - 2))
    alphas = 1. - torch.exp(-model_output[..., 0] * dists)
    # Transmittance accumulates along the sample axis (0) and is exclusive:
    # T_i = prod_{j<i} (1 - alpha_j), so the first sample is fully visible.
    trans = torch.cumprod(1. - alphas + 1e-10, dim=0)
    trans = torch.cat([torch.ones_like(trans[:1]), trans[:-1]], dim=0)
    weights = alphas * trans
    return torch.sum(weights * model_output[..., 1], dim=0)
    
test_luminances = render_pixels(test_output, test_z_vals)
print(test_luminances)
print(test_luminances.size())
# Compositing sums over the sample axis, leaving one luminance value per pixel.
assert tuple(test_luminances.shape) == tuple(test_output.shape[1:-1])

In [None]:
def pred_image(pose, H, W, focal, near, far, N_samples, loc_enc_dim, model, rand=True):
    """Render one image for `pose` by running `model` on samples along every pixel ray.

    The pipeline is: build rays, sample points along them (jittered when `rand`),
    encode the samples, evaluate the model, and alpha-composite the outputs.
    """
    ray_o, ray_d = ray_origins_directions(H, W, focal, pose)
    points, point_dirs, depths = sample_rays(ray_o, ray_d, near, far, N_samples, rand)
    features = torch.tensor(encode_input(points, point_dirs, loc_enc_dim), dtype=torch.float32)
    raw = model(features)
    return render_pixels(raw, depths)
    
# Identity rotation with translation (1, 0, 0).
test_pose = np.array([
    [1, 0, 0, 1],
    [0, 1, 0, 0],
    [0, 0, 1, 0],
    [0, 0, 0, 1]
])
# NOTE(review): near=1, far=5 differ from the configured near/far (2.0, 6.0) — confirm intentional.
test_pred = pred_image(test_pose, H, W, focal, 1, 5, N_samples, loc_enc_dim, test_model)
test_pred_img = test_pred.detach().numpy()
# Compositing collapses the sample axis, leaving one value per pixel.
assert test_pred_img.shape == (H, W)
print(test_pred_img)
plt.imshow(test_pred_img)
plt.show()

In [None]:
# Set RESUME_FROM_CHECKPOINT = True to continue from weights previously saved to CHECKPOINT_PATH.
RESUME_FROM_CHECKPOINT = False
CHECKPOINT_PATH = "./checkpoint"

model = Model(3 + 6 * loc_enc_dim, 64)
if RESUME_FROM_CHECKPOINT:
    # torch.load unpickles the file; only load checkpoints you produced yourself.
    model.load_state_dict(torch.load(CHECKPOINT_PATH))

In [None]:
def train(model, epochs, images, poses, H, W, focal, near, far, N_samples, loc_enc_dim, lr=1e-4):
    """Fit `model` to grayscale versions of `images` with Adam and an L1 loss.

    Takes one optimizer step per image per epoch, prints each epoch's mean loss,
    and plots the loss curve in its own figure.

    Args:
        model: the network rendered through pred_image.
        epochs: number of passes over `images`.
        images: images indexed along axis 0; each is averaged over axis 2 to
            form a grayscale target.
        poses: one pose per image, in the same order.
        H, W, focal, near, far, N_samples, loc_enc_dim: forwarded to pred_image.
        lr: Adam learning rate (default 1e-4).

    Returns:
        List of per-epoch average losses.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = torch.nn.L1Loss()
    avg_losses = []
    for e in range(epochs):
        losses = []
        for image, pose in zip(images, poses):
            optimizer.zero_grad()
            pred = pred_image(pose, H, W, focal, near, far, N_samples, loc_enc_dim, model)
            # The renderer produces a single luminance channel, so supervise against grayscale.
            target = torch.tensor(np.average(image, axis=2), dtype=pred.dtype, device=pred.device)
            img_loss = criterion(pred, target)
            img_loss.backward()
            optimizer.step()
            losses.append(img_loss.item())
        avg_loss = np.average(losses)
        print(f"epoch {e+1} average loss {avg_loss}")
        avg_losses.append(avg_loss)
    # Use a dedicated figure and show it, so later plots are not drawn on top of it.
    fig, ax = plt.subplots()
    ax.plot(range(1, epochs + 1), avg_losses)
    ax.set(xlabel="epoch", ylabel="average L1 loss", title="Training loss")
    plt.show()
    return avg_losses
    
n_epochs = 10
train(model, n_epochs, images, poses, H, W, focal, near, far, N_samples, loc_enc_dim)

# Evaluate on training view 0 without autograd and without depth jitter,
# so the rendering is deterministic.
with torch.no_grad():
    pred = pred_image(poses[0], H, W, focal, near, far, N_samples, loc_enc_dim, model, rand=False)
target = np.average(images[0], axis=2)

# A fresh figure keeps these images from being drawn on top of any figure left open earlier.
fig, (ax_pred, ax_target) = plt.subplots(1, 2, figsize=(8, 4))
ax_pred.imshow(pred.numpy())
ax_pred.set_title("prediction (pose 0)")
ax_target.imshow(target)
ax_target.set_title("target (image 0, grayscale)")
plt.show()

In [None]:
# Persist only the learned parameters; restore with model.load_state_dict(torch.load("./checkpoint")).
torch.save(obj=model.state_dict(), f="./checkpoint")