In [183]:
%load_ext autoreload
%autoreload 2

import os

import pickle
import numpy as np
import cvxpy as cp
import scipy.linalg
import time

import matplotlib.pyplot as plt 
%matplotlib inline

import open3d as o3d
import pypose as pp
import torch
from torch import nn

from IPython.display import clear_output

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [184]:
coord_mesh = o3d.geometry.TriangleMesh.create_coordinate_frame(size=0.1)

In [206]:
mesh = o3d.geometry.TriangleMesh.create_sphere(radius=0.1, resolution=20)
mesh.compute_vertex_normals()
o3d.visualization.draw_geometries([mesh])

In [207]:
mesh_pts = np.asarray(mesh.vertices)
mesh_tri = np.asarray(mesh.triangles)

In [208]:
edges01 = np.stack([mesh_tri[:, 0], mesh_tri[:, 1]]).T
edges12 = np.stack([mesh_tri[:, 1], mesh_tri[:, 2]]).T
edges20 = np.stack([mesh_tri[:, 2], mesh_tri[:, 0]]).T

mesh_edges = np.concatenate([edges01, np.flip(edges01, axis=1), 
                             edges12, np.flip(edges12, axis=1),
                             edges20, np.flip(edges20, axis=1)], axis=0)

In [199]:
class NodeGraph(nn.Module):
    def __init__(self, rest_pts, edges) -> None:
        super().__init__()
        
        self.num_pts = rest_pts.shape[0]
        self.rest_pts = torch.tensor(rest_pts, dtype=torch.double)
        self.num_edges = edges.shape[0]
        self.edges = torch.tensor(edges, dtype=torch.long)

        self.R_tsr = pp.Parameter(pp.identity_SO3(self.num_pts, dtype=torch.double))
        self.t_tsr = nn.Parameter(torch.zeros((self.num_pts, 3), dtype=torch.double))
        
        self.w_reg = 1.0
        self.w_handle = 1e3
    
    def get_deform_pts(self):
        deform_pts = self.R_tsr.Act(self.rest_pts) + self.t_tsr
        return deform_pts

    def forward(self, handle_idx, handle_tgt_pts):
        res_handle = torch.zeros((handle_idx.shape[0], 3)).double()
        handle_rot = self.R_tsr[handle_idx, :]
        handle_t = self.t_tsr[handle_idx, :]
        handle_src_pts = self.rest_pts[handle_idx, :]
        res_handle = handle_tgt_pts - handle_rot.Act(handle_src_pts) - handle_t

        res_reg = torch.zeros((self.num_edges, 3)).double()

        # read points
        edge0_pts = self.rest_pts[self.edges[:, 0], :]
        edge1_pts = self.rest_pts[self.edges[:, 1], :]
        # read rotations
        edge0_rot = self.R_tsr[self.edges[:, 0], :]
        # read translations
        edge0_t = self.t_tsr[self.edges[:, 0], :]
        edge1_t = self.t_tsr[self.edges[:, 1], :]

        res_reg = edge0_rot.Act(edge0_pts-edge1_pts) - (edge0_pts+edge0_t) + (edge1_pts+edge1_t)
        return torch.cat([res_handle, res_reg])

In [210]:
node_graph = NodeGraph(mesh_pts, mesh_edges)

In [211]:
handle_idx = torch.tensor([0, 41], dtype=torch.long)
handle_tgt_pts = torch.tensor([mesh_pts[0, :]+0.01, mesh_pts[41, :]+0.01], dtype=torch.double)

ball_lst = []
for i in range(handle_idx.shape[0]):
    ball = o3d.geometry.TriangleMesh.create_sphere(radius=0.01, resolution=20)
    ball.translate(handle_tgt_pts[i, :])
    ball.compute_vertex_normals()
    ball.paint_uniform_color([1, 0, 0])

    ball_lst.append(ball)
o3d.visualization.draw_geometries([coord_mesh, mesh]+ball_lst)

vis_pcd = o3d.geometry.PointCloud()

optimizer = pp.optim.GaussNewton(node_graph)
for i in range(10):
    err = optimizer.step((handle_idx, handle_tgt_pts))
    print(f"step {i}:", err)
    
    deform_pts = node_graph.get_deform_pts().detach().numpy()

    update_ball_lst = []
    for i in range(handle_idx.shape[0]):
        ball = o3d.geometry.TriangleMesh.create_sphere(radius=0.01, resolution=20)
        ball.translate(deform_pts[handle_idx[i], :])
        ball.compute_vertex_normals()
        ball.paint_uniform_color([0, 1, 0])

        update_ball_lst.append(ball)

    vis_pcd.points = o3d.utility.Vector3dVector(deform_pts)
    line_set = o3d.geometry.LineSet.create_from_point_cloud_correspondences(
        vis_pcd, vis_pcd, o3d.utility.Vector2iVector(mesh_edges))
    o3d.visualization.draw_geometries([coord_mesh, vis_pcd, line_set]+ball_lst+update_ball_lst)


: 

: 

In [80]:
res = ng(None)

In [81]:
res

tensor([[ 0.0000,  0.0000,  0.0000],
        [ 0.9493,  0.4947, -1.1476],
        [-0.9493, -0.4947,  1.1476]], dtype=torch.float64,
       grad_fn=<SubBackward0>)

In [60]:
class TestNet(nn.Module):
    def __init__(self, n, dtype=torch.double):
        super().__init__()
        self.weight = pp.Parameter(pp.randn_SE3(n, dtype=dtype))

    def forward(self, src_pts):
        return self.weight.Act(src_pts)

In [61]:
tn = TestNet(6)

In [62]:
optimizer = pp.optim.GaussNewton(tn)

In [65]:
tn.weight.Act(src_pts)

tensor([[ 1.1396, -0.0567,  0.4078],
        [ 1.7431,  0.9849, -0.1012],
        [-0.0565, -0.3222, -0.3441],
        [-0.3964, -0.3177,  0.9354],
        [-0.0024,  0.6332,  0.3844],
        [ 0.6178, -0.7093,  2.3403]], dtype=torch.float64,
       grad_fn=<ViewBackward0>)

In [66]:
tgt_pts

tensor([[ 1.1396, -0.0567,  0.4078],
        [ 1.7431,  0.9849, -0.1012],
        [-0.0565, -0.3222, -0.3441],
        [-0.3964, -0.3177,  0.9354],
        [-0.0024,  0.6332,  0.3844],
        [ 0.6178, -0.7093,  2.3403]], dtype=torch.float64)

In [64]:
for i in range(100):
    err = optimizer.step(src_pts, target = tgt_pts)
    print(f"step {i}:", err)

step 0: tensor(0.0722, dtype=torch.float64)
step 1: tensor(3.5399e-05, dtype=torch.float64)
step 2: tensor(2.2810e-11, dtype=torch.float64)
step 3: tensor(9.8528e-24, dtype=torch.float64)
step 4: tensor(0., dtype=torch.float64)
step 5: tensor(0., dtype=torch.float64)
step 6: tensor(0., dtype=torch.float64)
step 7: tensor(0., dtype=torch.float64)
step 8: tensor(0., dtype=torch.float64)
step 9: tensor(0., dtype=torch.float64)
step 10: tensor(0., dtype=torch.float64)
step 11: tensor(0., dtype=torch.float64)
step 12: tensor(0., dtype=torch.float64)
step 13: tensor(0., dtype=torch.float64)
step 14: tensor(0., dtype=torch.float64)
step 15: tensor(0., dtype=torch.float64)
step 16: tensor(0., dtype=torch.float64)
step 17: tensor(0., dtype=torch.float64)
step 18: tensor(0., dtype=torch.float64)
step 19: tensor(0., dtype=torch.float64)
step 20: tensor(0., dtype=torch.float64)
step 21: tensor(0., dtype=torch.float64)
step 22: tensor(0., dtype=torch.float64)
step 23: tensor(0., dtype=torch.float64

In [41]:
optimizer = torch.optim.SGD([X], lr = 0.2, momentum=0.9)
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[2,4], gamma=0.5)

for i in range(100):
    optimizer.zero_grad()
    outputs = tgt_pts - X.Act(src_pts)
    loss = torch.norm(outputs, p=2)
    loss.backward()
    optimizer.step()
    scheduler.step()
    print(f"iteration {i}", loss.item())

iteration 0 0.7971451096979688
iteration 1 0.37133094360597485
iteration 2 0.48577792425153027
iteration 3 0.22135161734040165
iteration 4 0.528213574882948
iteration 5 0.5793738113395689
iteration 6 0.40886933898723066
iteration 7 0.21180537428187404
iteration 8 0.363594062162387
iteration 9 0.3675338429540494
iteration 10 0.15612600501735485
iteration 11 0.307457456231587
iteration 12 0.45184362093047553
iteration 13 0.3540842145489769
iteration 14 0.13217830431881675
iteration 15 0.30537883520598974
iteration 16 0.3243795211338801
iteration 17 0.11460742623143173
iteration 18 0.3314173399889008
iteration 19 0.4637362487490478
iteration 20 0.34114497235551383
iteration 21 0.10772936172793006
iteration 22 0.28204802109237187
iteration 23 0.2503406174388113
iteration 24 0.06420061813792652
iteration 25 0.17605387358126579
iteration 26 0.08038880027515018
iteration 27 0.21979147416215458
iteration 28 0.22313042250372908
iteration 29 0.07836775695118217
iteration 30 0.16530797774314646
i