In [15]:
import numpy as np
import torch
from pytorch3d.transforms import quaternion_invert, quaternion_apply, axis_angle_to_quaternion

# Two points only

In [16]:
# Worked example: build a rigid transform (rotation + translation) that maps
# the 3D point A onto the 3D point B, then invert it to recover A from B.
A = torch.tensor([1.0, 5.0, 3.0])
B = torch.tensor([4.0, 3.0, 6.0])

print("A", A)
print("B", B)

# Lower bound for the divisions below, so that a zero-length vector produces
# zeros instead of NaN.
eps = torch.finfo(A.dtype).eps

# Unit direction vectors of A and B
normalized_A = A / torch.clamp(torch.norm(A), min=eps)
normalized_B = B / torch.clamp(torch.norm(B), min=eps)

# Calculate the axis of rotation: the direction perpendicular to both vectors.
# `dim` is given explicitly instead of relying on torch.cross's implicit default.
axis = torch.cross(normalized_A, normalized_B, dim=-1)
# For (anti-)parallel directions the cross product is the zero vector; the
# guarded division keeps it at zero (i.e. "no rotation") rather than NaN.
axis = axis / torch.clamp(torch.norm(axis), min=eps)

# The angle of rotation is the arccosine of the dot product of the normalized vectors.
# The dot product is clamped because rounding can push it just outside [-1, 1],
# where acos returns NaN.
cos_angle = torch.clamp(torch.dot(normalized_A, normalized_B), -1.0, 1.0)
angle_of_rotation = torch.acos(cos_angle)

# The axis-angle representation is the axis of rotation scaled by the angle of rotation
axis_angle = axis * angle_of_rotation

# Convert the axis-angle representation to a quaternion
rotation = axis_angle_to_quaternion(axis_angle)

# Apply the rotation to point A
A_rotated = quaternion_apply(rotation, A)

# The rotation only aligns the directions; the translation closes the remaining gap
translation = B - A_rotated

# Apply the translation to point A
A_transformed = A_rotated + translation

print("A transformed into B:", A_transformed)

# Compute the inverse rotation and translation
inverse_rotation = quaternion_invert(rotation)
inverse_translation = -translation

# Apply the inverse transformation: from B, recover A.
# The forward map is rotate-then-translate, so the inverse must undo the
# translation first and the rotation second.
B_transformed = quaternion_apply(inverse_rotation, B + inverse_translation)

print("B transformed into A:", B_transformed)

A tensor([1., 5., 3.])
B tensor([4., 3., 6.])
A transformed into B: tensor([4., 3., 6.])
B transformed into A: tensor([1.0000, 5.0000, 3.0000])


In [17]:
# Sanity check: the translation computed above is a single 3-vector.
print(translation.size())

torch.Size([3])


# Mimicking the data structure

In [18]:
import numpy as np
import torch
from pytorch3d.transforms import quaternion_apply, axis_angle_to_quaternion

n_obs = 4
n_joints = 2

# Random integer-valued test data: one source point per observation (all_A)
# and n_joints target points per observation (all_B).
# NOTE: the generator is not seeded, so the values differ from run to run.
all_A = torch.from_numpy(np.random.randint(0, 10, size=(n_obs, 3))).float()
all_B = torch.from_numpy(np.random.randint(0, 10, size=(n_obs, n_joints, 3))).float()

print("A", all_A)
print("B", all_B)

# One 7-vector per (observation, joint): rotation quaternion followed by translation
transformations = torch.zeros((n_obs, n_joints, 7))

# Lower bound for the divisions below. The random data can contain zero vectors
# and parallel pairs, which would otherwise turn into NaN.
eps = torch.finfo(all_A.dtype).eps

for i in range(n_obs):
    A = all_A[i]
    # The direction of A depends only on the observation, so compute it once per i
    normalized_A = A / torch.clamp(torch.norm(A), min=eps)

    for j in range(n_joints):
        B = all_B[i, j]
        normalized_B = B / torch.clamp(torch.norm(B), min=eps)

        # Calculate the axis of rotation (explicit `dim` instead of the implicit default)
        axis = torch.cross(normalized_A, normalized_B, dim=-1)
        # A zero cross product (parallel directions) stays zero instead of becoming NaN
        axis = axis / torch.clamp(torch.norm(axis), min=eps)

        # The angle of rotation is the arccosine of the dot product of the normalized
        # vectors; clamp first because rounding can leave [-1, 1] and acos would give NaN
        cos_angle = torch.clamp(torch.dot(normalized_A, normalized_B), -1.0, 1.0)
        angle_of_rotation = torch.acos(cos_angle)

        # The axis-angle representation is the axis of rotation scaled by the angle of rotation
        axis_angle = axis * angle_of_rotation

        # Convert the axis-angle representation to a quaternion
        rotation = axis_angle_to_quaternion(axis_angle)

        # Apply the rotation to point A
        A_rotated = quaternion_apply(rotation, A)

        # Compute the translation vector
        translation = B - A_rotated

        # Apply the translation to point A
        A_transformed = A_rotated + translation

        print("A transformed into B:", A_transformed)

        transformations[i, j] = torch.cat((rotation, translation))

A tensor([[5., 9., 0.],
        [6., 8., 0.],
        [6., 5., 4.],
        [9., 0., 5.]])
B tensor([[[0., 3., 7.],
         [2., 2., 6.]],

        [[4., 7., 1.],
         [8., 8., 7.]],

        [[8., 3., 4.],
         [7., 8., 9.]],

        [[5., 8., 8.],
         [3., 1., 5.]]])
A transformed into B: tensor([0., 3., 7.])
A transformed into B: tensor([2., 2., 6.])
A transformed into B: tensor([4., 7., 1.])
A transformed into B: tensor([8., 8., 7.])
A transformed into B: tensor([8., 3., 4.])
A transformed into B: tensor([7., 8., 9.])
A transformed into B: tensor([5., 8., 8.])
A transformed into B: tensor([3., 1., 5.])


# Vectorising the operations

In [19]:
# Vectorised version of the per-pair computation.
# all_A has shape (n_obs, 3) and all_B has shape (n_obs, n_joints, 3).

# Lower bound for the divisions below, so zero vectors / parallel pairs give
# zeros instead of NaN.
eps = torch.finfo(all_A.dtype).eps

# Normalize A and B along the coordinate axis
normalized_A = all_A / torch.clamp(torch.norm(all_A, dim=-1, keepdim=True), min=eps)
normalized_B = all_B / torch.clamp(torch.norm(all_B, dim=-1, keepdim=True), min=eps)

# Repeat each observation's direction once per joint: (n_obs, n_joints, 3)
directions_A = normalized_A.unsqueeze(1).expand_as(normalized_B)

# Calculate the axis of rotation.
# `dim=-1` is required: without it torch.cross uses the FIRST axis of size 3,
# which would be the wrong axis whenever n_obs or n_joints happens to be 3.
axis = torch.cross(directions_A, normalized_B, dim=-1)
axis = axis / torch.clamp(torch.norm(axis, dim=-1, keepdim=True), min=eps)

# The angle of rotation is the arccosine of the dot product of the normalized vectors.
# Clamp first: rounding can push the dot product outside [-1, 1], where acos is NaN.
cos_angle = torch.clamp((directions_A * normalized_B).sum(-1), -1.0, 1.0)
angle_of_rotation = torch.acos(cos_angle)

# The axis-angle representation is the axis of rotation scaled by the angle of rotation
axis_angle = axis * angle_of_rotation.unsqueeze(-1)

# Convert the axis-angle representation to a quaternion
rotation = axis_angle_to_quaternion(axis_angle)

# Apply the rotation to point A (each source point repeated once per joint)
A_rotated = quaternion_apply(rotation, all_A.unsqueeze(1).expand_as(all_B))

# Compute the translation vector
translation = all_B - A_rotated

# Apply the translation to point A
A_transformed = A_rotated + translation

print("A transformed into B:")
print(A_transformed)

# Per (observation, joint): rotation quaternion followed by translation
transformations = torch.cat((rotation, translation), dim=-1)

A transformed into B:
tensor([[[0., 3., 7.],
         [2., 2., 6.]],

        [[4., 7., 1.],
         [8., 8., 7.]],

        [[8., 3., 4.],
         [7., 8., 9.]],

        [[5., 8., 8.],
         [3., 1., 5.]]])


In [20]:
# Sanity check: all_A holds one 3D point per observation.
print(all_A.size())

torch.Size([4, 3])
