import libraries

In [2]:
import torch
import torch.nn as nn

NeRF network class

In [3]:
class NeRF(torch.nn.Module):
    """MLP that maps an encoded position and direction to a color and a density.

    The network is made of three ``nn.Sequential`` blocks:

    * ``block1`` consumes the encoded position.
    * ``block2`` consumes ``block1``'s output concatenated with the encoded
      position again (a skip connection) and emits ``hidden_dim + 1`` values;
      the last one becomes the density, the rest is a feature vector.
    * ``block3`` consumes that feature vector concatenated with the encoded
      direction and emits 3 values squashed by a sigmoid.

    Args:
        positional_encoding_pos: number of frequency bands used to encode ``pos``.
        positional_encoding_dir: number of frequency bands used to encode ``dir``.
        hidden_dim: width of the hidden layers in ``block1`` and ``block2``.

    Note:
        Input layer sizes are ``bands * 6``, i.e. they assume 3-component
        inputs (3 components x sin/cos per band).
    """

    def __init__(self, positional_encoding_pos=10, positional_encoding_dir=4, hidden_dim=256):
        super(NeRF, self).__init__()

        self.positional_encoding_pos = positional_encoding_pos
        self.positional_encoding_dir = positional_encoding_dir

        # Encoded position -> hidden features (5 Linear+ReLU pairs).
        self.block1 = torch.nn.Sequential(
            nn.Linear(in_features=positional_encoding_pos * 6,out_features=hidden_dim),
            nn.ReLU(),
            nn.Linear(in_features=hidden_dim,out_features=hidden_dim),
            nn.ReLU(),
            nn.Linear(in_features=hidden_dim,out_features=hidden_dim),
            nn.ReLU(),
            nn.Linear(in_features=hidden_dim,out_features=hidden_dim),
            nn.ReLU(),
            nn.Linear(in_features=hidden_dim,out_features=hidden_dim),
            nn.ReLU()
        )

        # Skip connection: block1 output is concatenated with the encoded
        # position. The final layer has no activation and one extra output
        # channel, which forward() turns into the density.
        self.block2 = torch.nn.Sequential(
            nn.Linear(in_features=hidden_dim + positional_encoding_pos * 6,out_features=hidden_dim),
            nn.ReLU(),
            nn.Linear(in_features=hidden_dim,out_features=hidden_dim),
            nn.ReLU(),
            nn.Linear(in_features=hidden_dim,out_features=hidden_dim),
            nn.ReLU(),
            nn.Linear(in_features=hidden_dim,out_features=hidden_dim + 1)
        )

        # Color head: features + encoded direction -> 3 values in (0, 1).
        self.block3 = torch.nn.Sequential(
            nn.Linear(in_features=hidden_dim + positional_encoding_dir * 6,out_features=128),
            nn.ReLU(),
            nn.Linear(in_features=128,out_features=3),
            nn.Sigmoid()
        )

        # Applied to the density channel so it is never negative.
        self.relu = nn.ReLU()

    def positional_encoding(self, x, L):
        """Encode ``x`` with ``L`` sine/cosine frequency bands.

        For ``i`` in ``0 .. L-1`` the terms ``sin(2**i * pi * x)`` and
        ``cos(2**i * pi * x)`` are concatenated along the last axis, so the
        last dimension grows by a factor of ``2 * L``.

        Args:
            x: tensor to encode.
            L: number of frequency bands (``torch.cat`` fails if this is 0).

        Returns:
            Tensor shaped like ``x`` except for a last dimension of
            ``2 * L * x.shape[-1]``.
        """
        out = []
        for i in range(L):
            out.append(torch.sin(2 ** i * torch.pi * x))
            out.append(torch.cos(2 ** i * torch.pi * x))
        return torch.cat(out, dim=-1)

    def forward(self, pos, dir):
        """Evaluate the network for a batch of positions and directions.

        Args:
            pos: tensor of shape ``(..., 3)``.
            dir: tensor of shape ``(..., 3)`` with the same leading shape as
                ``pos``. (The name shadows the ``dir`` builtin; it is kept so
                existing keyword calls keep working.)

        Returns:
            Tuple ``(c, sigma)`` where ``c`` has shape ``(..., 3)`` with values
            in (0, 1) and ``sigma`` has shape ``(...)`` with values >= 0.
        """
        pos_enc_pos = self.positional_encoding(pos, self.positional_encoding_pos)
        pos_enc_dir = self.positional_encoding(dir, self.positional_encoding_dir)
        x = self.block1(pos_enc_pos)
        x = self.block2(torch.cat((x, pos_enc_pos), dim=-1))
        # Ellipsis indexing (instead of ``x[:, ...]``) keeps this working for
        # any number of leading batch dimensions, not just 2-D ``(N, 3)`` input.
        sigma = self.relu(x[..., -1])
        c = self.block3(torch.cat((x[..., :-1], pos_enc_dir), dim=-1))
        return c, sigma

Forward pass sanity check

In [4]:
# Run 64 random 3-component positions/directions through a freshly
# initialised network and print the shapes of both outputs.
x, d = torch.rand(64, 3), torch.rand(64, 3)
net = NeRF()
c, sigma = net(x, d)
for out in (c, sigma):
    print(out.shape)

torch.Size([64, 3])
torch.Size([64])


Cumulative Transmittance

In [8]:
def calculate_T(sigdel):
    """Exclusive cumulative product of ``sigdel`` along its last axis.

    Entry ``i`` of the result is the product of entries ``0 .. i-1`` of the
    input, so entry 0 is always 1.

    NOTE(review): for this to be a transmittance, each input entry presumably
    has to be the per-sample survival factor (e.g. ``exp(-sigma * delta)`` or
    ``1 - alpha``), not the raw ``sigma * delta`` the name suggests. Confirm
    against the caller.

    Args:
        sigdel: tensor of shape ``(..., n_samples)``.

    Returns:
        Tensor with the same shape, dtype and device as ``sigdel``.
    """
    T = torch.cumprod(sigdel, dim=-1)
    # ones_like inherits dtype and device from T, so there is no CPU
    # allocation followed by a copy and no dtype promotion in the concat.
    ones = torch.ones_like(T[..., :1])
    # Shift right by one: prepend 1 and drop the product over all samples.
    return torch.cat((ones, T[..., :-1]), dim=-1)

In [12]:
# Smoke test: run calculate_T on a random (2, 5) tensor.
T = calculate_T(torch.rand(2, 5))

In [11]:
def calculate_C(T, alphas, colors):
    """Weighted sum of ``colors`` over the sample axis.

    Each sample's weight is ``T * alphas``; the weights are broadcast over the
    trailing channel axis of ``colors`` and the products are summed over the
    sample axis.

    Args:
        T: tensor of shape ``(..., n_samples)``.
        alphas: tensor broadcastable with ``T``.
        colors: tensor of shape ``(..., n_samples, n_channels)``.

    Returns:
        Tensor of shape ``(..., n_channels)``.
    """
    weights = T * alphas
    # Negative axes (instead of the hard-coded 2 and 1) give the same result
    # for (n_rays, n_samples) weights and also support a single ray or extra
    # leading batch dimensions.
    weights = weights.unsqueeze(-1)
    C = (weights * colors).sum(-2)
    return C

In [14]:
# Smoke test: combine the T from above with random alphas and colors.
C = calculate_C(T, alphas=torch.rand(2, 5), colors=torch.rand(2, 5, 3))

In [15]:
# Show the calculate_C result from the previous cell.
C

tensor([[0.9653, 0.6755, 1.0447],
        [0.5151, 0.3976, 0.7528]])

In [31]:
def create_t_samples(hn, hf, n_bins, n_rays, device):
    """Draw one jittered sample per bin for every ray.

    The interval ``[hn, hf]`` is split into ``n_bins`` equal bins and, for
    each ray independently, one value is drawn uniformly inside every bin.

    Args:
        hn: lower bound of the interval.
        hf: upper bound of the interval.
        n_bins: number of bins (and samples per ray); must be non-zero.
        n_rays: number of rays (rows) to generate.
        device: torch device on which the samples are created.

    Returns:
        Tensor of shape ``(n_rays, n_bins)``; column ``i`` lies in
        ``[hn + i * w, hn + (i + 1) * w)`` with ``w = (hf - hn) / n_bins``.
    """
    bin_width = (hf - hn) / n_bins
    edges = torch.linspace(hn, hf, n_bins + 1, device=device).expand(n_rays, n_bins + 1)
    # The jitter must live on the same device as the edges; a bare
    # torch.rand(shape) is allocated on the default device and breaks the
    # addition below whenever `device` is not that default.
    # The (n_rays, n_bins + 1) draw shape is kept on purpose so results under
    # a fixed seed are unchanged; the extra last column is discarded.
    u = torch.rand(edges.shape, device=edges.device, dtype=edges.dtype)
    t = edges + u * bin_width
    return t[:, :-1]


In [32]:
# Smoke test: 4 rays, 10 bins over [0.0, 0.5], on the same device as T.
create_t_samples(hn=0.0, hf=0.5, n_bins=10, n_rays=4, device=T.device)

tensor([[0.0306, 0.0689, 0.1096, 0.1600, 0.2407, 0.2562, 0.3108, 0.3849, 0.4138,
         0.4718],
        [0.0103, 0.0802, 0.1071, 0.1927, 0.2112, 0.2782, 0.3496, 0.3914, 0.4244,
         0.4724],
        [0.0349, 0.0564, 0.1247, 0.1633, 0.2183, 0.2981, 0.3358, 0.3936, 0.4411,
         0.4550],
        [0.0469, 0.0642, 0.1196, 0.1875, 0.2226, 0.2639, 0.3171, 0.3715, 0.4320,
         0.4860]])

In [None]:
def create_ray_points(ray_origins, ray_directions, t):
    """Return the points ``ray_origins + t * ray_directions``.

    The three arguments are combined with ordinary broadcasting; no axes are
    added or reshaped here.

    NOTE(review): a ``(n_rays, n_bins)`` ``t`` (as returned by
    ``create_t_samples``) does not broadcast against ``(n_rays, 3)`` origins
    and directions in general. Callers presumably need to pass
    ``t[..., None]`` together with ``origins[:, None, :]`` and
    ``directions[:, None, :]`` - confirm the intended shapes.
    """
    offsets = ray_directions * t
    return offsets + ray_origins