In [2]:
import torch
from pytorch3d.transforms import quaternion_invert, quaternion_apply, axis_angle_to_quaternion

In [4]:
# Build a rigid transform T(x) = R·x + t that maps point A onto point B.
# R aligns the *direction* of A with the direction of B. Because |A| != |B| in
# general, t absorbs the remaining offset, so T(A) == B holds by construction;
# the meaningful check is that R alone aligns the two directions.
A = torch.tensor([1.0, 5.0, 3.0])
B = torch.tensor([4.0, 3.0, 6.0])

print("A", A)
print("B", B)

normalized_A, normalized_B = A / torch.norm(A), B / torch.norm(B)

# Below this cross-product norm, A and B are treated as (anti-)parallel.
PARALLEL_EPS = 1e-6

# Calculate the axis of rotation (perpendicular to both directions).
# Pass `dim` explicitly: calling torch.cross without it is deprecated.
axis = torch.cross(normalized_A, normalized_B, dim=-1)
axis_norm = torch.norm(axis)
if axis_norm < PARALLEL_EPS:
    # A and B are (anti-)parallel: the cross product carries no usable axis and
    # dividing by its ~0 norm would give NaN. Any axis perpendicular to A works,
    # so cross A with the basis vector it is least aligned with.
    least_aligned = torch.zeros_like(normalized_A)
    least_aligned[torch.argmin(normalized_A.abs())] = 1.0
    axis = torch.cross(normalized_A, least_aligned, dim=-1)
    axis_norm = torch.norm(axis)
axis = axis / axis_norm

# The angle of rotation is the arccosine of the dot product of the unit vectors.
# Clamp first: rounding can push the dot product slightly outside [-1, 1],
# where acos returns NaN.
cos_angle = torch.clamp(torch.dot(normalized_A, normalized_B), -1.0, 1.0)
angle_of_rotation = torch.acos(cos_angle)

# The axis-angle representation is the unit axis scaled by the angle (radians, from acos)
axis_angle = axis * angle_of_rotation

# Convert the axis-angle representation to a quaternion
rotation = axis_angle_to_quaternion(axis_angle)

# Apply the rotation to point A
A_rotated = quaternion_apply(rotation, A)

# Non-trivial check: the rotation alone must point A along B's direction.
assert torch.allclose(A_rotated / torch.norm(A_rotated), normalized_B, atol=1e-5)

# Compute the translation vector
translation = B - A_rotated

# Apply the translation to point A (equals B by construction)
A_transformed = A_rotated + translation

print("A transformed into B:", A_transformed)

# Invert T(x) = R·x + t. The inverse is T^{-1}(y) = R^{-1}·(y - t), so the
# translation must be undone first and the rotation second (order matters).
inverse_rotation = quaternion_invert(rotation)
inverse_translation = -translation

# Apply the inverse transformation: from B, recover A.
# Build a new tensor rather than cloning B and mutating the clone in place.
B_transformed = quaternion_apply(inverse_rotation, B + inverse_translation)

print("B transformed into A:", B_transformed)

# The round trip must recover A (up to float32 rounding).
assert torch.allclose(B_transformed, A, atol=1e-5)

A tensor([1., 5., 3.])
B tensor([4., 3., 6.])
A transformed into B: tensor([4., 3., 6.])
B transformed into A: tensor([1.0000, 5.0000, 3.0000])
