In [25]:
%load_ext autoreload
%autoreload 2

from nerfstudio.model_components.nesf_components import FieldTransformerConfig, TranformerEncoderModelConfig
import torch
import lovely_tensors as lt
import time
from nerfstudio.utils.nesf_utils import visualize_point_batch
lt.monkey_patch()

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [37]:
imput_size = 96
DEVICE = "cuda:0"
model_config = FieldTransformerConfig(
    knn=128,
    transformer=TranformerEncoderModelConfig(
        num_layers=2,
        num_heads=2,
    )
)
model = model_config.setup(input_size=imput_size)
model = model.to(DEVICE)

# parameter count
print("Parameter count: ", sum(p.numel() for p in model.parameters() if p.requires_grad))

Parameter count:  88387


In [39]:
N = 10000

neural_points = torch.rand(N, 3, device=DEVICE)
neural_features = torch.rand(N, imput_size, device=DEVICE)
query_points = torch.rand(20000, 3, device=DEVICE)

model.config.knn = 64
time1 = time.time()
K = 1
for i in range(K):
    outs = model(query_points, neural_features, neural_points)

time2 = time.time()

print("Max memory allocated: ", torch.cuda.max_memory_allocated() / 1024 / 1024, "MB")
print("Time: ", (time2 - time1) / K, "s")
print(outs.shape)
print(outs)

# reset torch stats
torch.cuda.reset_peak_memory_stats()

Max memory allocated:  12002.15380859375 MB
Time:  0.38204002380371094 s
torch.Size([20000, 96])
tensor[20000, 96] n=1920000 x∈[-2.572, 2.309] μ=-0.003 σ=0.579 grad AddmmBackward0 cuda:0


In [4]:
points = torch.arange(0, 1, 0.01, device=DEVICE)
points = torch.stack(torch.meshgrid(points, points, points), dim=-1).reshape(-1, 3)

ind = model.get_k_closest_points(torch.tensor([0.5, 0.5, 0.5], device=DEVICE).unsqueeze(0), points)


closest_points = points[ind]
print(closest_points)
print(closest_points.p)


tensor[1, 128, 3] [38;2;127;127;127mall_zeros[0m cuda:0
tensor([[[0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
