In [53]:
import math
import numpy as np
import torch
from pyquaternion import Quaternion
from pytorch3d.transforms import matrix_to_quaternion, quaternion_to_axis_angle, axis_angle_to_matrix, Rotate

from bc.se3 import flow2pose

In [25]:
# Ground-truth rotation for this check: a quarter turn (pi/2 rad) about the unit axis
# (0, 1/sqrt(2), 1/sqrt(2)).
axis = [0] + 2 * [math.sqrt(2) / 2]
angle = math.pi / 2

In [26]:
# Quaternion for the ground-truth rotation (`axis`, `angle` defined above).
quat = Quaternion(axis=axis, angle=angle)
# 10 x 10 x 10 lattice over [-1, 1]^3, flattened into a (1000, 3) array of xyz rows.
points = np.array(np.meshgrid(*[np.linspace(-1, 1, 10)] * 3)).reshape(3, -1).T

In [27]:
# Ground-truth rotation as a 3x3 matrix; pyquaternion's matrix acts on column vectors (R @ p).
quat.rotation_matrix

array([[ 1.11022302e-16, -7.07106781e-01,  7.07106781e-01],
       [ 7.07106781e-01,  5.00000000e-01,  5.00000000e-01],
       [-7.07106781e-01,  5.00000000e-01,  5.00000000e-01]])

In [28]:
# Ground-truth flow: displacement of every lattice point under `quat`, i.e. flow = R p - p.
# Vectorised replacement for calling `quat.rotate` point by point: for row vectors,
# R @ p is p @ R.T. `np.asarray` also keeps this cell re-runnable after `points` has
# been converted to a torch tensor below (the original `.astype` call failed on re-run).
points_np = np.asarray(points, dtype=np.float64)
flow_np = points_np @ quat.rotation_matrix.T - points_np

points = torch.from_numpy(points_np.astype(np.float32))
flow = torch.from_numpy(flow_np.astype(np.float32))

In [29]:
# Recover the transform explaining `flow` with `flow2pose`, asking for a single
# transform object (`return_transform3d=True`) instead of separate rotation/translation tensors.
trfm = flow2pose(
    xyz=points.unsqueeze(0),  # add a leading batch dimension of size 1
    flow=flow.unsqueeze(0),
    weights=None,  # NOTE(review): presumably uniform weighting — confirm in bc.se3
    return_transform3d=True,
    return_quaternions=False,
)

In [30]:
# Flow implied by the recovered transform; should match the ground-truth flow.
pred_flow = torch.sub(trfm.transform_points(points).squeeze(0), points)
torch.allclose(flow, pred_flow, atol=1e-4)

True

In [31]:
# Same fit, but returned as raw (rotation matrices, translation) tensors and with
# `world_frameify` disabled.
rot_matrices, trans = flow2pose(
    xyz=points.unsqueeze(0),
    flow=flow.unsqueeze(0),
    weights=None,
    return_transform3d=False,
    return_quaternions=False,
    world_frameify=False,
)

In [32]:
# Recovered rotation. It is applied to row vectors (points @ R, see the bmm check below),
# so it is the transpose of the column-vector `quat.rotation_matrix` displayed earlier.
rot_matrices

tensor([[[ 0.0000,  0.7071, -0.7071],
         [-0.7071,  0.5000,  0.5000],
         [ 0.7071,  0.5000,  0.5000]]])

In [33]:
# `rot_matrices` multiplies row vectors (points @ R), whereas pytorch3d's conversions
# expect the column-vector form, hence the transpose of the last two axes first.
quats = matrix_to_quaternion(matrix=rot_matrices.transpose(-1, -2))
axis_ang = quaternion_to_axis_angle(quaternions=quats)

In [34]:
# Split the axis-angle vectors into unit axes and angles (radians).
# Norms are taken per row (dim=-1) so a batch of several rotations is handled correctly;
# a norm over the whole tensor would mix the rows together. For a batch of one the values
# are unchanged, but `pred_angle` now has shape (batch,) instead of being 0-d.
# NOTE(review): a zero rotation gives a zero norm and therefore NaN axes.
pred_angle = torch.linalg.norm(axis_ang, dim=-1)
pred_axis = axis_ang / pred_angle.unsqueeze(-1)

In [35]:
# Show the recovered axis and angle, one per line.
print(pred_axis, pred_angle, sep="\n")

tensor([[5.9605e-08, 7.0711e-01, 7.0711e-01]])
tensor(1.5708)


In [36]:
# Apply the recovered rotation to row vectors (points @ R) and compare with the true flow.
# NOTE(review): `trans` is not applied here; presumably ~0 for this origin-centred lattice — confirm.
pred_flow = torch.bmm(points[None], rot_matrices)[0] - points
torch.allclose(flow, pred_flow, atol=1e-4)

True

In [37]:
# Shape of the recovered translation.
trans.size()

torch.Size([1, 3])

In [38]:
# Shape of the mean flow, for comparison with the translation's shape above.
# Uses torch's own `dim`/`keepdim` keywords rather than the NumPy-style `axis`/`keepdims` aliases.
flow.mean(dim=0, keepdim=True).shape

torch.Size([1, 3])

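Dense Transform Testing
=======================

Rebuild the rotation directly from the ground-truth axis-angle with pytorch3d. Check that it reproduces the flow both via `torch.bmm` and via a `Rotate` transform.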

In [61]:
# Build the rotation straight from the ground-truth axis-angle vector (direction = axis,
# magnitude = angle). Then transpose the last two axes into the row-vector convention
# used by the bmm check below (points @ R).
axis_ang = angle * torch.tensor([axis])
rot_matrix = axis_angle_to_matrix(axis_ang).transpose(-1, -2)

In [62]:
# Rotating the points with the ground-truth matrix should reproduce the flow exactly (up to float32).
pred_flow = torch.bmm(points[None], rot_matrix)[0] - points
torch.allclose(flow, pred_flow, atol=1e-4)

True

In [63]:
# The same check through pytorch3d's `Rotate`, which takes `rot_matrix` in its
# row-vector (points @ R) form, the same layout used by the bmm check above.
trfm = Rotate(rot_matrix)
pred_flow = torch.sub(trfm.transform_points(points).squeeze(0), points)
torch.allclose(flow, pred_flow, atol=1e-4)

True