In [1]:
import torch
import tinycudann as tcnn
import numpy as np
import matplotlib.pyplot as plt
from process_data import process_radar_data

In [2]:
def forward_model(f_est, H):
    '''Map a field estimate back to the measurements it was back-projected from.

    The field estimate is taken to be the adjoint back-projection

        field_estimate = -H^H @ g        (H^H = conjugate transpose of H)

    so the measurements are recovered by the least-squares inverse of that
    relation:

        g = -pinv(H^H) @ field_estimate

    This reproduces g exactly when H^H has full column rank (H has full row
    rank); otherwise it is the minimum-norm least-squares solution. Note that
    this is NOT the same as -H @ field_estimate.

    Parameters
    ----------
    f_est : torch.Tensor
        Field estimate; its last dimension must equal the number of columns
        of H (the number of rows of H^H).
    H : torch.Tensor
        Measurement matrix of shape (..., M, N). Leading batch dimensions are
        supported.

    Returns
    -------
    torch.Tensor
        Estimated measurements g (length M per batch element).
    '''
    # `.mH` conjugate-transposes only the last two dims; unlike `.conj().T`
    # it is valid (not deprecated) for batched, >2-D tensors.
    H_adj_pinv = torch.linalg.pinv(H.mH)
    return -H_adj_pinv @ f_est


class InstantNGPFieldRepresentation(torch.nn.Module):
    '''Complex-valued neural field: multiresolution hash encoding + small MLP.

    Input coordinates go through a tiny-cuda-nn HashGrid encoding and then a
    tiny-cuda-nn MLP. The MLP's two real outputs become the real and
    imaginary parts of one complex value per input point.

    NOTE(review): tiny-cuda-nn grid encodings are designed for coordinates in
    the unit cube [0, 1]^input_dim; presumably callers normalise positions
    before evaluating the model — confirm.
    '''

    # Defaults reproduce the originally hard-coded tiny-cuda-nn configuration.
    DEFAULT_ENCODING_CONFIG = {
        "otype": "HashGrid",
        "n_levels": 16,
        "n_features_per_level": 2,
        "log2_hashmap_size": 19,
        "base_resolution": 16,
        "per_level_scale": 1.5,
    }
    DEFAULT_NETWORK_CONFIG = {
        "otype": "CutlassMLP",
        "activation": "Sine",
        "output_activation": "None",
        "n_hidden_layers": 2,
    }

    def __init__(self, input_dim=3, hidden_dim=128, output_dim=2,
                 encoding_config=None, network_config=None):
        '''Build the encoding and the MLP.

        Args:
            input_dim: number of coordinate dimensions per point.
            hidden_dim: neurons per hidden MLP layer (tcnn "n_neurons").
            output_dim: number of raw MLP outputs. Must be 2, because
                forward() packs exactly two channels into one complex number.
            encoding_config: optional dict of tiny-cuda-nn encoding keys that
                override DEFAULT_ENCODING_CONFIG.
            network_config: optional dict of tiny-cuda-nn network keys that
                override DEFAULT_NETWORK_CONFIG. A "n_neurons" entry here
                takes precedence over hidden_dim.

        Raises:
            ValueError: if output_dim is not 2.
        '''
        super().__init__()
        # Fail fast: any other width would either crash in forward() or
        # silently discard channels.
        if output_dim != 2:
            raise ValueError(
                f"output_dim must be 2 (real and imaginary parts), got {output_dim}"
            )

        enc_cfg = {**self.DEFAULT_ENCODING_CONFIG, **(encoding_config or {})}
        net_cfg = {
            **self.DEFAULT_NETWORK_CONFIG,
            "n_neurons": hidden_dim,
            **(network_config or {}),
        }

        self.encoding = tcnn.Encoding(
            n_input_dims=input_dim,
            encoding_config=enc_cfg,
        )
        self.network = tcnn.Network(
            n_input_dims=self.encoding.n_output_dims,
            n_output_dims=output_dim,
            network_config=net_cfg,
        )

    def forward(self, x):
        '''Evaluate the complex field at a batch of points.

        Args:
            x: coordinates of shape (batch, input_dim).

        Returns:
            Complex tensor of shape (batch,). Its real part is MLP output
            channel 0 and its imaginary part is channel 1.
        '''
        features = self.encoding(x)
        # tcnn modules run in half precision (dtype=torch.float16 in the
        # printed model), and torch.complex needs float32/float64 parts.
        raw = self.network(features).float()
        return torch.complex(raw[:, 0], raw[:, 1])

In [3]:
# tiny-cuda-nn modules only run on CUDA, so place the whole model on the GPU.
model = InstantNGPFieldRepresentation(
    input_dim=3,
    hidden_dim=128,
    output_dim=2,
).to("cuda")
print(model)

InstantNGPFieldRepresentation(
  (encoding): Encoding(n_input_dims=3, n_output_dims=32, seed=1337, dtype=torch.float16, hyperparams={'base_resolution': 16, 'hash': 'CoherentPrime', 'interpolation': 'Linear', 'log2_hashmap_size': 19, 'n_features_per_level': 2, 'n_levels': 16, 'otype': 'Grid', 'per_level_scale': 1.5, 'type': 'Hash'})
  (network): Network(n_input_dims=32, n_output_dims=2, seed=1337, dtype=torch.float16, hyperparams={'encoding': {'offset': 0.0, 'otype': 'Identity', 'scale': 1.0}, 'network': {'activation': 'Sine', 'n_hidden_layers': 2, 'n_neurons': 128, 'otype': 'CutlassMLP', 'output_activation': 'None'}, 'otype': 'NetworkWithInputEncoding'})
)


In [5]:
# Smoke-test the freshly initialised model on a few random points.
# tiny-cuda-nn grid encodings are designed for coordinates in the unit cube,
# so sample uniformly from [0, 1)^3 (torch.rand) instead of torch.randn,
# whose samples fall largely outside that domain.
torch.manual_seed(0)  # reproducible test points
x = torch.rand(10, 3, device="cuda")
y = model(x)
# One complex value per input point.
assert y.shape == (10,) and y.is_complex()
y

tensor([ 3.3736e-04-1.4126e-04j, -1.0467e-04+3.4738e-04j,
         6.4969e-05-1.9133e-05j,  5.6028e-05-4.6253e-05j,
         2.8729e-04-1.8823e-04j,  2.2805e-04+1.4877e-04j,
        -6.7353e-05+1.4460e-04j, -5.9187e-05-1.7262e-04j,
         1.0592e-04+8.7619e-05j, -4.1771e-04+5.1069e-04j], device='cuda:0',
       grad_fn=<ComplexBackward0>)
