In [None]:
import torch
import numpy as np

from raytorch.structures import learnable_meshes
from raytorch.ops.ray_intersect import ray_triangle_intersect_iter
from raytorch.LiDAR import LiDAR_base
from raytorch.visualization import visualize_LiDAR, visualize_point_clouds

from pytorch3d.utils import ico_sphere
from pytorch3d.vis.plotly_vis import plot_batch_individually, plot_scene
from pytorch3d.transforms import Translate

In [None]:
# Base geometry: a pytorch3d icosphere at subdivision level 1, wrapped so its
# vertices can be optimised by the loop further down.
ball = ico_sphere(level=1)
obj = learnable_meshes(ball)

# Rigid transform that places the sphere 10 units along +y (in front of the
# sensor), plus its inverse for mapping hit points back into the sphere's frame.
translate = Translate(x=0, y=10, z=0)
inverse_translate = translate.inverse()

In [None]:
# Sensor placed at the world origin; the azimuth sweep is restricted to
# [0, 180] (units presumably degrees — TODO confirm against LiDAR_base).
lidar = LiDAR_base(torch.zeros(3), azi_range=[0, 180])

In [None]:
# Scan the translated sphere using the default intersection method.
intersection = lidar.scan_triangles(obj.get_deformed_meshes(translate))
print(intersection)

# Overlay the emitted rays and the resulting hit points in one plotly scene.
plot_scene(
    {"test": {"ray": visualize_LiDAR(lidar), "points": visualize_point_clouds(intersection)}},
    raybundle_ray_line_width=2.0,
    pointcloud_marker_size=2.0,
)

In [None]:
# Repeat the scan, this time forcing the "single_ray" intersection path.
intersection = lidar.scan_triangles(
    obj.get_deformed_meshes(translate),
    method="single_ray",
)
# Presumably an [N, 3] tensor of hit points (N = number of hits) — TODO confirm.
print(intersection)
# Undo `translate` so the hits are expressed in the sphere's own frame.
print(inverse_translate.transform_points(intersection))

In [None]:
# Gradient-descent settings: step size and number of iterations.
eta = 0.1
iters = 10

# NOTE: this cell mutates `obj` in place, so re-running it continues the
# optimisation from the current state instead of restarting from scratch.
for i in range(iters):
    # Re-scan the translated, currently-deformed sphere. The scan is presumably
    # differentiable w.r.t. the mesh parameters (the backward() below relies on it).
    intersection = lidar.scan_triangles(obj.get_deformed_meshes(translate))

    # Express the hits in the sphere's local frame. The loss is the summed
    # Euclidean distance of every hit point from the local origin.
    residual = inverse_translate.transform_points(intersection)
    l2_loss = torch.linalg.vector_norm(residual, dim=-1).sum()

    parameter = obj.get_parameters()
    # Clear gradients left over from the previous iteration before back-propagating.
    if parameter.grad is not None:
        parameter.grad.zero_()
    l2_loss.backward()

    # Plain gradient-descent step.
    # NOTE(review): `parameter - eta * grad` is built with autograd enabled;
    # presumably `update_parameters` detaches / re-leafs it — confirm.
    grad = obj.get_gradient()
    obj.update_parameters(parameter - eta * grad)

    # `.item()` prints the scalar value rather than a tensor repr with grad_fn.
    print(f"[{i + 1}/{iters}] l2 loss: {l2_loss.item():.6f}")

# Compare the mesh before and after optimisation side by side.
# NOTE(review): unlike the scans above, `get_deformed_meshes` is called without
# `translate` here — presumably to show the deformation in the object frame; confirm.
plot_scene({
    "original": {"mesh": obj.get_meshes()},
    "deformed": {"mesh": obj.get_deformed_meshes()},
}, ncols=2)