In [1]:
import torch
import numpy as np

from raytorch.structures import learnable_meshes
from raytorch.ops.ray_intersect import ray_triangle_intersect_iter
from raytorch.LiDAR import LiDAR_base
from raytorch.visualization import visualize_LiDAR, visualize_point_clouds

from pytorch3d.utils import ico_sphere
from pytorch3d.vis.plotly_vis import plot_batch_individually, plot_scene
from pytorch3d.transforms import Translate

In [2]:
ball = ico_sphere(level = 1)
translate = Translate(x=0, y=10, z=0)
inverse_translate = translate.inverse()
obj = learnable_meshes(ball)

In [3]:
lidar = LiDAR_base(torch.tensor([0.0, 0.0, 0.0]),
                    azi_range = [0, 180],
                    polar_range = [80, 100],
                    res = [2.5, 2.5],
                    min_range = 1.5,
                    max_range = 20)

In [4]:
intersection = lidar.scan_triangles(obj.get_deformed_meshes(translate))
print(intersection)
plot_scene({
    "test": {"ray": visualize_LiDAR(lidar),
             "points": visualize_point_clouds(intersection)},
},
    raybundle_ray_line_width=2.0,
    pointcloud_marker_size=2.0
)

tensor([[ 8.4904e-01,  9.7046e+00,  4.2533e-01],
        [ 8.2652e-01,  9.4471e+00,  5.8068e-16],
        [ 8.4904e-01,  9.7046e+00, -4.2533e-01],
        [ 4.2492e-01,  9.7322e+00,  8.5227e-01],
        [ 4.0136e-01,  9.1927e+00,  4.0174e-01],
        [ 3.9699e-01,  9.0926e+00,  5.5729e-16],
        [ 4.0136e-01,  9.1927e+00, -4.0174e-01],
        [ 4.2492e-01,  9.7322e+00, -8.5226e-01],
        [ 5.7846e-16,  9.4470e+00,  8.2650e-01],
        [ 5.5723e-16,  9.1003e+00,  3.9733e-01],
        [ 5.5109e-16,  9.0000e+00,  5.5109e-16],
        [ 5.5723e-16,  9.1003e+00, -3.9733e-01],
        [ 5.7846e-16,  9.4470e+00, -8.2650e-01],
        [-4.2492e-01,  9.7322e+00,  8.5226e-01],
        [-4.0136e-01,  9.1927e+00,  4.0174e-01],
        [-3.9699e-01,  9.0926e+00,  5.5729e-16],
        [-4.0136e-01,  9.1927e+00, -4.0174e-01],
        [-4.2492e-01,  9.7322e+00, -8.5227e-01],
        [-8.4904e-01,  9.7046e+00,  4.2533e-01],
        [-8.2652e-01,  9.4471e+00,  5.8068e-16],
        [-8.4904e-01

In [5]:
intersection = lidar.scan_triangles(obj.get_deformed_meshes(translate),
                                               method="single_ray")
print(intersection) # [N, 3]
print(inverse_translate.transform_points(intersection))

tensor([[ 8.4904e-01,  9.7046e+00,  4.2533e-01],
        [ 8.2652e-01,  9.4471e+00,  5.8068e-16],
        [ 8.4904e-01,  9.7046e+00, -4.2533e-01],
        [ 4.2492e-01,  9.7322e+00,  8.5227e-01],
        [ 4.0136e-01,  9.1927e+00,  4.0174e-01],
        [ 3.9699e-01,  9.0926e+00,  5.5729e-16],
        [ 4.0136e-01,  9.1927e+00, -4.0174e-01],
        [ 4.2492e-01,  9.7322e+00, -8.5226e-01],
        [ 5.7846e-16,  9.4470e+00,  8.2650e-01],
        [ 5.5723e-16,  9.1003e+00,  3.9733e-01],
        [ 5.5109e-16,  9.0000e+00,  5.5109e-16],
        [ 5.5723e-16,  9.1003e+00, -3.9733e-01],
        [ 5.7846e-16,  9.4470e+00, -8.2650e-01],
        [-4.2492e-01,  9.7322e+00,  8.5226e-01],
        [-4.0136e-01,  9.1927e+00,  4.0174e-01],
        [-3.9699e-01,  9.0926e+00,  5.5729e-16],
        [-4.0136e-01,  9.1927e+00, -4.0174e-01],
        [-4.2492e-01,  9.7322e+00, -8.5227e-01],
        [-8.4904e-01,  9.7046e+00,  4.2533e-01],
        [-8.2652e-01,  9.4471e+00,  5.8068e-16],
        [-8.4904e-01

In [6]:
eta = 0.1
iters = 10

for i in range(iters):
    intersection = lidar.scan_triangles(obj.get_deformed_meshes(translate))
    l2_loss = inverse_translate.transform_points(
        intersection) - torch.zeros_like(intersection)
    l2_loss = torch.norm(l2_loss, dim=-1).sum()

    parameter = obj.get_parameters()
    if parameter.grad is not None:
        parameter.grad.zero_()
    l2_loss.backward()

    grad = obj.get_gradient()
    obj.update_parameters(parameter - eta * grad)
    print(f"l2 loss:{l2_loss}")
    
plot_scene({
    "original": {"mesh": obj.get_meshes()},
    "deformed": {"mesh": obj.get_deformed_meshes()},
}, ncols=2)

l2 loss:20.808931350708008
l2 loss:20.02924156188965
l2 loss:19.51750946044922
l2 loss:19.007099151611328
l2 loss:18.45610809326172
l2 loss:18.071189880371094
l2 loss:17.718372344970703
l2 loss:17.503564834594727
l2 loss:17.22599983215332
l2 loss:16.897844314575195
