In [43]:
import read_write_model as colmap_reader
import numpy as np
import torch
from torch import nn
import imageio.v3 as iio
import matplotlib.pyplot as plt
from datautils import read_data
from nerfhelpers import PositionalEncoding, get_rays, batched_get_rays
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [8]:
# Inputs: the image folder plus the COLMAP binary camera / image files.
imagedir = './images'
camFile = 'cameras.bin'
imgFile = 'images.bin'

# read_data (datautils) yields the images, the c2w poses and H, W, F.
# NOTE(review): H, W, F presumably image height, width and focal length — confirm in datautils.
images, c2w, H, W, F = read_data(camFile, imgFile, imagedir)
# Convert both arrays to float32 tensors, images first, then poses.
images, c2w = (torch.tensor(arr, dtype=torch.float32) for arr in (images, c2w))

Loading Poses...
Done
Reading Images...
Done


In [42]:
# Ray origins/directions for the first pose in c2w (see nerfhelpers.get_rays for the layout).
rays_o, rays_d = get_rays(H, W, F, c2w[0, ...])

In [84]:
# Smoke test: push a small random batch through the network and check the
# output layout (3 rgb channels + 1 alpha channel per sample).
# NOTE(review): this cell uses NeRF, whose defining cell appears further down;
# on a fresh kernel (Restart & Run All) it raises NameError — move it below
# the class definition.
torch.manual_seed(0)  # reproducible weights and inputs
test_o = torch.rand((10, 3))
test_d = torch.rand((10, 3))

model = NeRF()
# Call the module itself rather than .forward() so nn.Module hooks run.
t = model(test_o, test_d)
assert t.shape == (10, 4), t.shape

In [82]:
class NeRF(nn.Module):
    """MLP mapping an encoded 3-D position and view direction to (rgb, alpha).

    Both inputs are expanded with a sinusoidal encoding (see ``encode``).
    The position branch is an input layer followed by 7 hidden 256-wide ReLU
    layers; ``alpha`` is read off that trunk. The trunk features are then
    concatenated with the encoded view direction and passed through a
    128-wide ReLU layer to produce ``rgb``.

    Outputs are raw: no activation is applied to rgb or alpha here, so any
    squashing (e.g. ``self.sigmoid``) is left to the caller.

    Args:
        L_pos: number of frequency bands in the position encoding.
        L_dir: number of frequency bands in the view-direction encoding.
    """

    def __init__(self, L_pos = 10, L_dir = 4):
        super().__init__()
        self.L_pos = L_pos
        self.L_dir = L_dir
        self.relu = nn.ReLU()
        # Must be an instance: storing the class meant self.sigmoid(x) would
        # construct a new Sigmoid module instead of applying it. Has no
        # parameters, so the state_dict is unchanged.
        self.sigmoid = nn.Sigmoid()

        # Encoded widths: raw input (3) + sin & cos (2) per band, per axis (3).
        pos_dim = 3 + self.L_pos * 2 * 3
        dir_dim = 3 + self.L_dir * 2 * 3

        # Layers are created in the original order so weight init is unchanged.
        self.pos_input = nn.Linear(pos_dim, 256)
        self.posnet = nn.ModuleList([nn.Linear(256, 256) for _ in range(7)])
        self.alpha_output = nn.Linear(256, 1)

        self.view_input = nn.Linear(256 + dir_dim, 128)
        self.rgb_output = nn.Linear(128, 3)

    def forward(self, pos, view):
        """Evaluate the network.

        Args:
            pos: positions with last dimension 3.
            view: view directions with last dimension 3; leading dimensions
                must match ``pos`` so the features can be concatenated.

        Returns:
            Tensor with last dimension 4: raw rgb in ``[..., :3]`` and raw
            alpha in ``[..., 3:]``.
        """
        pos, view = self.encode(pos, view)

        out = self.relu(self.pos_input(pos))
        for lin in self.posnet:
            out = self.relu(lin(out))

        alpha = self.alpha_output(out)
        # View dependence enters only after alpha has been predicted.
        out = torch.cat((out, view), dim=-1)

        out = self.relu(self.view_input(out))
        rgb = self.rgb_output(out)

        return torch.cat((rgb, alpha), dim=-1)

    @staticmethod
    def _frequency_encode(x, n_freqs):
        """Return cat([x, sin(pi*2^0*x), cos(pi*2^0*x), ..., sin(pi*2^(n-1)*x), cos(pi*2^(n-1)*x)], -1)."""
        feats = [x]  # keep the raw input alongside the sinusoids
        for k in range(n_freqs):
            scaled = x * np.pi * (2 ** k)
            feats.append(torch.sin(scaled))
            feats.append(torch.cos(scaled))
        return torch.cat(feats, -1)

    def encode(self, pos, vdir):
        """Sinusoidally encode positions (``L_pos`` bands) and directions (``L_dir`` bands).

        Each band index uses its own frequency for both inputs. Previously the
        direction loop reused the position loop's index, so every direction
        band had frequency 2**(L_pos-1), and it raised NameError when
        L_pos == 0.

        Returns:
            (posenc, direnc), with last dimensions 3*(1+2*L_pos) and
            3*(1+2*L_dir) for 3-component inputs.
        """
        posenc = self._frequency_encode(pos, self.L_pos)
        direnc = self._frequency_encode(vdir, self.L_dir)
        return posenc, direnc