In [1]:
%load_ext autoreload
%autoreload 2

from tqdm import tqdm
from nerfstudio.model_components.nesf_components import FieldTransformerConfig, TranformerEncoderModelConfig
import torch
import lovely_tensors as lt
import time
from nerfstudio.utils.nesf_utils import visualize_point_batch
lt.monkey_patch()

  warn(f"Failed to load image Python extension: {e}")


In [2]:
imput_size = 48
DEVICE = "cuda:0"
model_config = FieldTransformerConfig(
    knn=64,
    mode="transformer",
    transformer=TranformerEncoderModelConfig(
        num_layers=2,
        num_heads=4,
        dim_feed_forward=64,
    )
)
model = model_config.setup(input_size=imput_size)
model = model.to(DEVICE)

# parameter count
print("Parameter count: ", sum(p.numel() for p in model.parameters() if p.requires_grad))
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)


Parameter count:  82147


In [3]:
N=16384
Q=16384

neural_points = torch.rand(N, 3, device=DEVICE)
neural_features = torch.rand(N, imput_size, device=DEVICE, requires_grad=False)
query_points = torch.rand(Q, 3, device=DEVICE)
print("Max memory allocated: ", torch.cuda.max_memory_allocated() / 1024 / 1024, "MB")
torch.cuda.reset_peak_memory_stats()

Max memory allocated:  3.6953125 MB


In [4]:
# warmup
for _ in tqdm(range(50)):
    with torch.no_grad():
        outs = model(query_points, neural_features, neural_points)

total_time = 0
repetitions = 10
for _ in tqdm(range(repetitions)):
    with torch.no_grad():
        starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
        starter.record()
        timea = time.time()
        outs = model(query_points, neural_features, neural_points)
        ender.record()
        torch.cuda.synchronize()  # synchronize CUDA operations
        timeb = time.time()
        curr_time = starter.elapsed_time(ender)
        time.sleep(0.05)
        total_time += curr_time

print("CUDA: Forward - feature transformer: ", total_time/repetitions)
print("Max memory allocated: ", torch.cuda.max_memory_allocated() / 1024 / 1024, "MB")

# reset torch stats
torch.cuda.reset_peak_memory_stats()

  0%|          | 0/50 [00:00<?, ?it/s]

100%|██████████| 50/50 [00:08<00:00,  6.05it/s]
100%|██████████| 10/10 [00:02<00:00,  4.81it/s]

CUDA: Forward - feature transformer:  131.27601928710936
Max memory allocated:  4379.2578125 MB





In [14]:
print("CUDA: Forward - feature transformer: ", total_time/repetitions)


CUDA: Forward - feature transformer:  183.86016326904297


In [25]:
try:
    del neural_features
except:
    pass
try:
    del outs
except:
    pass
model.zero_grad()
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()

In [8]:
torch.cuda.empty_cache()

In [None]:
points = torch.arange(0, 1, 0.01, device=DEVICE)
points = torch.stack(torch.meshgrid(points, points, points), dim=-1).reshape(-1, 3)

ind = model.get_k_closest_points(torch.tensor([0.5, 0.5, 0.5], device=DEVICE).unsqueeze(0), points)


closest_points = points[ind]
print(closest_points)
print(closest_points.p)


tensor[1, 128, 3] [38;2;127;127;127mall_zeros[0m cuda:0
tensor([[[0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
