In [1]:
%load_ext autoreload
%autoreload 2

from nerfstudio.model_components.nesf_components import FieldTransformerConfig, TranformerEncoderModelConfig
import torch
import lovely_tensors as lt
import time
from nerfstudio.utils.nesf_utils import visualize_point_batch
lt.monkey_patch()

  warn(f"Failed to load image Python extension: {e}")


In [2]:
FEATURE_DIM = 48
DEVICE = "cuda:0"
model_config = FieldTransformerConfig(
    knn=128,
    mode = "transformer",
    transformer=TranformerEncoderModelConfig(
        num_layers=2,
        num_heads=2,
    )
)
model = model_config.setup(input_size=FEATURE_DIM)
model = model.to(DEVICE)

# parameter count
print("Parameter count: ", sum(p.numel() for p in model.parameters() if p.requires_grad))

Parameter count:  82147


In [3]:
# NP=49152
NP=4915
# QP=24576
QP=2457
CHUNK_SIZE=4096

neural_points = torch.rand(NP, 3, device=DEVICE)
neural_features = torch.rand(NP, FEATURE_DIM, device=DEVICE, requires_grad=True)
query_points = torch.rand(QP, 3, device=DEVICE)

In [4]:
print("Max cuda memory allocated: ", torch.cuda.max_memory_allocated(DEVICE) / 1024 / 1024, "MB")
k_idx1, k_dist = model.get_k_closest_points_deprecated(query_points, neural_points)
print("Max cuda memory allocated: ", torch.cuda.max_memory_allocated(DEVICE) / 1024 / 1024, "MB")
k_idx2, k_dist2 = model.get_k_closest_points(query_points, neural_points)

print("Is identical: ", torch.all(k_idx1 == k_idx2))
print("Is identical: ", torch.allclose(k_dist, k_dist2))
print("Max cuda memory allocated: ", torch.cuda.max_memory_allocated(DEVICE) / 1024 / 1024, "MB")

Max cuda memory allocated:  1.3056640625 MB


Max cuda memory allocated:  52.96435546875 MB
Is identical:  tensor bool cuda:0 True
Is identical:  True
Max cuda memory allocated:  52.96435546875 MB


In [7]:


model.config.knn = 64
time1 = time.time()

#process in chunks
K=1
# for i in range(0, QP, CHUNK_SIZE):
#     K+=1
#     qp = query_points[i:i+CHUNK_SIZE]
outs = model(query_points, neural_features, neural_points)

time2 = time.time()

print("Max memory allocated: ", torch.cuda.max_memory_allocated() / 1024 / 1024, "MB")
print("Time: ", (time2 - time1) / K, "s")
print(outs.shape)
print(outs)

# reset torch stats
torch.cuda.reset_peak_memory_stats()

shapes torch.Size([24576, 3]) torch.Size([49152, 48]) torch.Size([49152, 3])


closest_ind torch.Size([24576, 64])
closest ind tensor[24576, 64] i64 [38;2;127;127;127mall_zeros[0m cuda:0
rel_pos_feat torch.Size([24576, 65, 51])
Max memory allocated:  26940.9140625 MB
Time:  0.02077317237854004 s
torch.Size([24576, 48])
tensor[24576, 48] n=1179648 x∈[-2.136, 1.849] μ=-0.129 σ=0.633 grad AddmmBackward0 cuda:0


In [4]:
points = torch.arange(0, 1, 0.01, device=DEVICE)
points = torch.stack(torch.meshgrid(points, points, points), dim=-1).reshape(-1, 3)

ind = model.get_k_closest_points(torch.tensor([0.5, 0.5, 0.5], device=DEVICE).unsqueeze(0), points)


closest_points = points[ind]
print(closest_points)
print(closest_points.p)


tensor[1, 128, 3] [38;2;127;127;127mall_zeros[0m cuda:0
tensor([[[0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
