In [1]:
import torch
import numpy as np
from liegroups.torch import SO3 as SO3_torch


In [2]:
    def _gen_sim_data_fast(N_rotations, N_matches_per_rotation, sigma,
                           max_rotation_angle=None, dtype=torch.double):
        axis = torch.randn(N_rotations, 3, dtype=dtype)
        axis = axis / axis.norm(dim=1, keepdim=True)
        if max_rotation_angle:
            max_angle = max_rotation_angle * np.pi / 180.
        else:
            max_angle = np.pi
        angle = max_angle * torch.rand(N_rotations, 1)
        C = SO3_torch.exp(angle * axis).as_matrix()
        if N_rotations == 1:
            C = C.unsqueeze(dim=0)
        x_1 = torch.randn(N_rotations, 3, N_matches_per_rotation, dtype=dtype)
        x_1 = x_1 / x_1.norm(dim=1, keepdim=True)
        noise = sigma * torch.randn_like(x_1)
        noise = 0
        x_2 = C.bmm(x_1) + noise
        return C, x_1, x_2


In [3]:
  C_train, x_1_train, x_2_train = _gen_sim_data_fast(
            1, 100, 1e-2, max_rotation_angle=180)

In [4]:
C_train.shape

torch.Size([1, 3, 3])

In [5]:
x_1_train

tensor([[[ 0.3053,  0.6244,  0.2937, -0.9063, -0.1695, -0.4656, -0.5402,
          -0.0780, -0.8359, -0.6967,  0.1997,  0.0347, -0.9701, -0.9556,
           0.4213, -0.0215, -0.4273,  0.1677,  0.6987, -0.3830, -0.2244,
           0.4382,  0.8924,  0.6496, -0.3979,  0.5987, -0.9194, -0.5785,
          -0.8368,  0.4039, -0.4326,  0.7687,  0.4843, -0.9241,  0.9616,
           0.2663, -0.8944, -0.2188,  0.4762,  0.3523,  0.2355,  0.2505,
          -0.8710,  0.5697, -0.5144,  0.6730,  0.5800, -0.1740,  0.8164,
           0.4041,  0.8586, -0.4566, -0.3879,  0.7196,  0.0103, -0.4749,
          -0.0356, -0.7182, -0.4690, -0.8559, -0.9762,  0.1618, -0.8765,
          -0.4276, -0.2244, -0.1034,  0.0316, -0.0194,  0.1685, -0.6864,
          -0.5981, -0.6690, -0.9581, -0.4430, -0.4956,  0.0242,  0.7041,
          -0.3003,  0.8333,  0.3636, -0.4876,  0.0204, -0.4209, -0.9446,
          -0.2941, -0.3419, -0.8138,  0.8123, -0.4954, -0.1863,  0.9688,
           0.6794,  0.8479,  0.7337, -0.3770, -0.96

In [6]:
C_train[0].T @ C_train[0]

tensor([[ 1.0000e+00,  2.0817e-17, -1.3878e-17],
        [ 2.0817e-17,  1.0000e+00,  0.0000e+00],
        [-1.3878e-17,  0.0000e+00,  1.0000e+00]], dtype=torch.float64)

In [8]:
C_train @ x_1_train == x_2_train

tensor([[[False, False, False, False, False, False, False, False, False, False,
          False, False, False, False, False, False, False, False, False, False,
          False, False, False, False, False, False, False, False, False, False,
          False, False, False, False, False, False, False, False, False, False,
          False, False, False, False, False, False, False, False, False, False,
          False, False, False, False, False, False, False, False, False, False,
          False, False, False, False, False, False, False, False, False, False,
          False, False, False, False, False, False, False, False, False, False,
          False, False, False, False, False, False, False, False, False, False,
          False, False, False, False, False, False, False, False, False, False],
         [False, False, False, False, False, False, False, False, False, False,
          False, False, False, False, False, False, False, False, False, False,
          False, False, False, False, F