In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
from torchmetrics import Accuracy
import numpy as np

In [3]:
class NerfModel(nn.Module):
    """Small fully-convolutional network mapping per-pixel features to ``output_dim`` values in [0, 1].

    Architecture: ``num_layers - 1`` hidden blocks of ``Conv2d(3x3) -> ReLU -> BatchNorm2d``,
    followed by a 1x1 ``Conv2d`` output layer and a sigmoid.

    Args:
        input_shape: Shape of one input sample. Only ``input_shape[0]`` is read; it is used as
            the number of input channels, so a channels-first ``(C, H, W)`` layout is assumed.
        output_dim: Number of output channels.
        num_layers: Total number of conv layers including the 1x1 output layer. Must be >= 1.
        num_channels: Width of the hidden 3x3 conv layers.
        padding: Zero padding of each hidden 3x3 conv. The default ``0`` shrinks H and W by 2
            per hidden layer; ``padding=1`` keeps the spatial size unchanged.

    Raises:
        ValueError: if ``num_layers < 1``.
    """

    def __init__(self, input_shape, output_dim: int = 3, num_layers: int = 4, num_channels: int = 256,
                 padding: int = 0):
        super(NerfModel, self).__init__()
        if num_layers < 1:
            raise ValueError("num_layers must be >= 1 (got {})".format(num_layers))
        self.num_layers = num_layers
        self.input_channels = input_shape[0]
        # Flat list alternating [conv, bn, conv, bn, ...]. This layout is kept on purpose so
        # state_dict keys (conv_layers.0.*, conv_layers.1.*, ...) stay compatible with checkpoints.
        self.conv_layers = nn.ModuleList()

        in_channels = self.input_channels
        for _ in range(num_layers - 1):
            self.conv_layers.append(nn.Conv2d(in_channels, num_channels, kernel_size=3, padding=padding))
            self.conv_layers.append(nn.BatchNorm2d(num_channels))
            in_channels = num_channels

        # With num_layers == 1 there are no hidden layers, so the output layer must consume
        # the raw input channels instead of num_channels.
        self.output_layer = nn.Conv2d(in_channels, output_dim, kernel_size=1, padding=0)

    def forward(self, x):
        """Apply the hidden conv blocks, then the 1x1 output conv and a sigmoid.

        Args:
            x: Input tensor, presumably ``(B, C, H, W)`` with ``C == input_shape[0]``
               (standard Conv2d layout) — confirm against the data loader.

        Returns:
            Tensor with ``output_dim`` channels and values in [0, 1]. Spatial size shrinks by
            ``2 * (1 - padding)`` per hidden layer.
        """
        # conv_layers alternates conv, bn; walk it in (conv, bn) pairs.
        for conv, norm in zip(self.conv_layers[0::2], self.conv_layers[1::2]):
            x = norm(F.relu(conv(x)))
        return torch.sigmoid(self.output_layer(x))

In [4]:
class NerfTrainer(pl.LightningModule):
    """LightningModule training a ``NerfModel`` as a per-pixel regressor with an MSE loss.

    Each stage (train/val/test) logs ``<stage>_loss`` (MSE) and ``<stage>_psnr``. PSNR replaces
    the previous classification ``Accuracy`` metric: that metric is not defined for continuous
    targets, and ``Accuracy()`` without ``task`` fails on current torchmetrics.

    Args:
        input_shape: Forwarded to ``NerfModel`` (only ``input_shape[0]`` is used there).
        output_dim: Number of predicted channels.
        num_layers: Number of conv layers in ``NerfModel``.
        num_channels: Hidden width of ``NerfModel``.
        learning_rate: Adam learning rate.
    """

    def __init__(self, input_shape, output_dim: int = 3, num_layers: int = 4, num_channels: int = 256, learning_rate: float = 1e-3):
        super(NerfTrainer, self).__init__()
        # Store constructor args in self.hparams / checkpoints so load_from_checkpoint can rebuild the module.
        self.save_hyperparameters()
        self.model = NerfModel(input_shape, output_dim, num_layers, num_channels)
        self.loss_fn = torch.nn.MSELoss()
        self.learning_rate = learning_rate

    def forward(self, x):
        return self.model(x)

    @staticmethod
    def _psnr(mse):
        """PSNR in dB for predictions/targets in [0, 1] (peak signal 1): ``-10 * log10(mse)``."""
        # Clamp so a perfectly reconstructed batch yields a large finite value instead of inf.
        return -10.0 * torch.log10(torch.clamp(mse.detach(), min=1e-10))

    def _shared_step(self, batch, stage: str):
        """Compute MSE loss for ``batch``, log loss and PSNR under ``stage``, return the loss."""
        x, y = batch
        y_hat = self(x)
        loss = self.loss_fn(y_hat, y)
        self.log('{}_loss'.format(stage), loss, prog_bar=True)
        self.log('{}_psnr'.format(stage), self._psnr(loss), prog_bar=True)
        return loss

    def training_step(self, batch, batch_idx):
        return self._shared_step(batch, 'train')

    def validation_step(self, batch, batch_idx):
        return self._shared_step(batch, 'val')

    def test_step(self, batch, batch_idx):
        return self._shared_step(batch, 'test')

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.learning_rate)

In [6]:
# Fourier Encoded Features
class FourierEncoding(object):
    """Random Fourier feature encoding: Gaussian random projection followed by sin/cos.

    Each input vector ``v`` (last axis, length ``num_input_channels``) is mapped to
    ``[sin(2*pi * v @ B), cos(2*pi * v @ B)]``. ``B`` has shape
    ``(num_input_channels, mapping_size)`` and is drawn once at construction from a standard
    normal distribution, then multiplied by ``scale``.

    Args:
        num_input_channels: Expected size of the input's last axis.
        mapping_size: Number of random frequencies; the output has ``2 * mapping_size`` channels.
        scale: Multiplier (standard deviation) applied to the standard-normal frequencies.
        rng: Optional ``numpy.random.Generator`` used to draw ``B`` reproducibly. When ``None``
            (default), the legacy global NumPy RNG is used, so ``np.random.seed`` still applies.
    """

    def __init__(self, num_input_channels: int, mapping_size: int = 256, scale: float = 10, rng=None):
        super(FourierEncoding, self).__init__()
        self._num_input_channels = num_input_channels
        self._mapping_size = mapping_size
        if rng is None:
            frequencies = np.random.randn(num_input_channels, mapping_size)
        else:
            frequencies = rng.standard_normal((num_input_channels, mapping_size))
        self._B = frequencies * scale

    def __call__(self, x: np.ndarray) -> np.ndarray:
        """Encode a channels-last 4D array.

        Args:
            x: Array (or array-like) of shape ``(B, H, W, num_input_channels)``.

        Returns:
            float32 array of shape ``(B, H, W, 2 * mapping_size)``: sin features followed by
            cos features along the last axis.

        Raises:
            AssertionError: if ``x`` is not 4D or its last axis is not ``num_input_channels``.
        """
        x = np.asarray(x)
        assert len(x.shape) == 4, 'Expected 4D input (got {}D input)'.format(len(x.shape))
        batches, height, width, channels = x.shape
        assert channels == self._num_input_channels, \
            "Expected input to have {} channels (got {} channels)".format(self._num_input_channels, channels)
        # Flatten [B, H, W, C] -> [(B*H*W), C] so one matmul with _B projects every pixel.
        x = x.reshape(batches * height * width, channels)
        x = x @ self._B
        # Restore spatial layout: [(B*H*W), mapping_size] -> [B, H, W, mapping_size].
        x = x.reshape(batches, height, width, self._mapping_size)
        x = 2 * np.pi * x
        return np.concatenate((np.sin(x), np.cos(x)), axis=-1).astype('float32')

In [7]:
# Wavelet Encoded Features
class WaveletEncoding(object):
    """Placeholder for a wavelet-based feature encoding (not implemented yet).

    Calling an instance raises ``NotImplementedError`` instead of silently returning ``None``.
    A pipeline wired to this encoder therefore fails loudly rather than passing ``None`` on as
    features.
    """

    def __init__(self):
        super(WaveletEncoding, self).__init__()

    def __call__(self, x: np.ndarray) -> np.ndarray:
        """Encode ``x`` with wavelet features. TODO: implement."""
        raise NotImplementedError("WaveletEncoding is not implemented yet")

In [None]:
# DataLoader

In [None]:
# Metrics

In [None]:
# Output