Import libraries

In [4]:
import torch
import torch.nn as nn

NeRF network class

In [13]:
class NeRF(torch.nn.Module):
    """MLP that maps an encoded 3-D position and direction to a colour and a density.

    Both inputs are expanded with a sin/cos positional encoding. The encoded
    position goes through ``block1``; its output is concatenated with the
    encoded position again (skip connection) and fed to ``block2``, which
    emits ``hidden_dim`` features plus one extra channel used as the density.
    ``block3`` combines the features with the encoded direction and produces
    three sigmoid-activated colour channels.

    Args:
        positional_encoding_pos: Number of frequency bands used to encode ``pos``.
        positional_encoding_dir: Number of frequency bands used to encode ``dir``.
        hidden_dim: Width of the hidden layers in ``block1`` and ``block2``.
    """

    def __init__(self, positional_encoding_pos=10, positional_encoding_dir=4, hidden_dim=256):
        super(NeRF, self).__init__()

        self.positional_encoding_pos = positional_encoding_pos
        self.positional_encoding_dir = positional_encoding_dir

        # 3 input coordinates x (sin, cos) per frequency band -> 6 features per band.
        pos_features = positional_encoding_pos * 6
        dir_features = positional_encoding_dir * 6

        # NOTE: layer creation order and Sequential indices are kept stable so
        # that seeded initialisation and saved state_dicts remain compatible.
        self.block1 = torch.nn.Sequential(
            nn.Linear(in_features=pos_features, out_features=hidden_dim),
            nn.ReLU(),
            nn.Linear(in_features=hidden_dim, out_features=hidden_dim),
            nn.ReLU(),
            nn.Linear(in_features=hidden_dim, out_features=hidden_dim),
            nn.ReLU(),
            nn.Linear(in_features=hidden_dim, out_features=hidden_dim),
            nn.ReLU(),
            nn.Linear(in_features=hidden_dim, out_features=hidden_dim),
            nn.ReLU(),
        )

        # Input is block1's output concatenated with the encoded position;
        # the final layer emits hidden_dim features + 1 raw density channel.
        self.block2 = torch.nn.Sequential(
            nn.Linear(in_features=hidden_dim + pos_features, out_features=hidden_dim),
            nn.ReLU(),
            nn.Linear(in_features=hidden_dim, out_features=hidden_dim),
            nn.ReLU(),
            nn.Linear(in_features=hidden_dim, out_features=hidden_dim),
            nn.ReLU(),
            nn.Linear(in_features=hidden_dim, out_features=hidden_dim + 1),
        )

        # Colour head: features + encoded direction -> 3 channels in [0, 1].
        self.block3 = torch.nn.Sequential(
            nn.Linear(in_features=hidden_dim + dir_features, out_features=128),
            nn.ReLU(),
            nn.Linear(in_features=128, out_features=3),
            nn.Sigmoid(),
        )

        # Applied to the density channel so that sigma is non-negative.
        self.relu = nn.ReLU()

    def positional_encoding(self, x, L):
        """Return the sin/cos encoding of ``x`` over ``L`` frequency bands.

        For each band ``i`` in ``range(L)`` the tensors ``sin(2**i * pi * x)``
        and ``cos(2**i * pi * x)`` are appended, in that order, and everything
        is concatenated along the last dimension. The result therefore has a
        last dimension of ``2 * L * x.shape[-1]``.
        """
        bands = []
        for i in range(L):
            angle = 2 ** i * torch.pi * x
            bands.append(torch.sin(angle))
            bands.append(torch.cos(angle))
        return torch.cat(bands, dim=-1)

    def forward(self, pos, dir):
        """Evaluate the network.

        Args:
            pos: Tensor of shape ``(..., 3)`` holding positions.
            dir: Tensor of shape ``(..., 3)`` holding directions; its leading
                dimensions must match those of ``pos``.

        Returns:
            Tuple ``(c, sigma)`` where ``c`` has shape ``(..., 3)`` with values
            in ``[0, 1]`` and ``sigma`` has shape ``(...)`` with values ``>= 0``.
        """
        pos_enc_pos = self.positional_encoding(pos, self.positional_encoding_pos)
        pos_enc_dir = self.positional_encoding(dir, self.positional_encoding_dir)

        x = self.block1(pos_enc_pos)
        x = self.block2(torch.cat((x, pos_enc_pos), dim=-1))

        # Ellipsis indexing splits along the channel axis regardless of how
        # many leading batch dimensions the inputs carry (e.g. rays x samples).
        sigma = self.relu(x[..., -1])
        features = x[..., :-1]

        c = self.block3(torch.cat((features, pos_enc_dir), dim=-1))
        return c, sigma

Forward pass sanity check

In [15]:
# Smoke test: push 64 random 3-vectors for position and direction through a
# default-configured network and print the shapes of the two outputs.
x, d = torch.rand(64, 3), torch.rand(64, 3)
net = NeRF()
c, sigma = net(x, d)
print(c.shape, sigma.shape, sep="\n")

torch.Size([64, 3])
torch.Size([64])
