In [None]:
import torch
import numpy as np
import time

from raytorch.structures import learnable_meshes
from raytorch.ops.ray_intersect import ray_triangle_intersect_iter
from raytorch.LiDAR import LiDAR_base
from raytorch.visualization import visualize_LiDAR, visualize_point_clouds

from pytorch3d.utils import ico_sphere
from pytorch3d.vis.plotly_vis import plot_batch_individually, plot_scene
from pytorch3d.transforms import Translate

Create a learnable mesh

In [None]:
# Translation handed to obj.get_deformed_meshes(...) in the scan cell:
# a shift of +10 along y, all other axes unchanged.
translate = Translate(x=0, y=10, z=0)

# Level-2 icosphere from pytorch3d, wrapped in raytorch's learnable_meshes
# (presumably exposing the parameters whose gradient is read later via
# obj.get_gradient() — confirm against raytorch).
ball = ico_sphere(level=2)
obj = learnable_meshes(ball)

Create a LiDAR sensor

In [None]:
lidar = LiDAR_base(
    torch.tensor([0.0, 0.0, 0.0]),  # presumably the sensor position (origin) — confirm
    azi_range=[0, 180],             # azimuth sweep; units presumably degrees — confirm
)

Check that the different scanning methods agree

In [None]:
# Scan the same deformed mesh with every intersection backend, with and without
# the AABB pre-test, and keep each result so they can be cross-checked below
# (and back-propagated through in the gradient cell).
SCAN_METHODS = ["single_ray", "batch_ray", "iter"]
AABB_OPTIONS = [True, False]

res_intersection = []
scan_configs = []  # (method, aabb_test) matching each entry of res_intersection

for method in SCAN_METHODS:
    for aabb_test in AABB_OPTIONS:
        # perf_counter is monotonic and high-resolution; time.time is not.
        start_time = time.perf_counter()
        # The mesh is re-deformed for every scan rather than hoisted out of the
        # loop, so each result presumably gets its own autograd graph; the
        # gradient cell calls .backward() once on every entry.
        intersection = lidar.scan_triangles(obj.get_deformed_meshes(translate),
                                            method=method,
                                            aabb_test=aabb_test)
        elapsed = time.perf_counter() - start_time
        res_intersection.append(intersection)
        scan_configs.append((method, aabb_test))
        print(f"[method = {method}\t, aabb_test = {aabb_test}\t] time: {elapsed}\t point size: {intersection.size()}")

# Pairwise equivalence check.
#  * Shapes are compared first: torch.allclose broadcasts (a (1, 3) result would
#    "match" an (N, 3) one) and raises outright on non-broadcastable shapes
#    (e.g. two backends finding different hit counts).
#  * allclose is not exactly symmetric (rtol scales |other|), so both directions
#    are checked, but each unordered pair is visited only once.
error_flag = False
n_results = len(res_intersection)
for i in range(n_results):
    for j in range(i + 1, n_results):
        a, b = res_intersection[i], res_intersection[j]
        equivalent = (a.shape == b.shape
                      and torch.allclose(a, b)
                      and torch.allclose(b, a))
        if not equivalent:
            error_flag = True
            print(f"i = {i}\t, j = {j}\t is not equivalent. ({scan_configs[i]} vs {scan_configs[j]})")

assert not error_flag, "ray intersection test failed."

In [None]:
# Back-propagate the same random upstream gradient through every scan result and
# check that all backends yield the same gradient w.r.t. the learnable mesh.
torch.manual_seed(0)  # make grad_input (and thus this check) reproducible on re-run
grad_input = torch.rand_like(res_intersection[0])
grad_output = []

for i in range(len(res_intersection)):
    # Clear whatever the previous backward pass accumulated. get_gradient()
    # presumably returns the parameter's .grad tensor (the in-place zero_()
    # relies on that) — confirm against raytorch.
    if obj.get_gradient() is not None:
        obj.get_gradient().zero_()

    res_intersection[i].backward(grad_input)
    grad = obj.get_gradient()
    assert grad is not None, f"no gradient reached the mesh for result {i}"
    # Snapshot the values: the stored tensor is zeroed in place on the next
    # iteration, so appending it directly would make every entry alias the
    # same (final) gradient and the comparison below would pass trivially.
    grad_output.append(grad.detach().clone())

# Same pairwise scheme as the intersection check: shape first (allclose
# broadcasts / raises), both allclose directions, each unordered pair once.
error_flag = False
n_grads = len(grad_output)
for i in range(n_grads):
    for j in range(i + 1, n_grads):
        a, b = grad_output[i], grad_output[j]
        equivalent = (a.shape == b.shape
                      and torch.allclose(a, b)
                      and torch.allclose(b, a))
        if not equivalent:
            error_flag = True
            print(f"i = {i}\t, j = {j}\t gradient is not equivalent.")
assert not error_flag, "ray gradient test failed."