In [33]:
import torch
import numpy as np
from liegroups.torch import SO3 as SO3_torch


In [34]:
    def _gen_sim_data_fast(N_rotations, N_matches_per_rotation, sigma,
                           max_rotation_angle=None, dtype=torch.double):
        axis = torch.randn(N_rotations, 3, dtype=dtype)
        axis = axis / axis.norm(dim=1, keepdim=True)
        if max_rotation_angle:
            max_angle = max_rotation_angle * np.pi / 180.
        else:
            max_angle = np.pi
        angle = max_angle * torch.rand(N_rotations, 1)
        C = SO3_torch.exp(angle * axis).as_matrix()
        if N_rotations == 1:
            C = C.unsqueeze(dim=0)
        x_1 = torch.randn(N_rotations, 3, N_matches_per_rotation, dtype=dtype)
        x_1 = x_1 / x_1.norm(dim=1, keepdim=True)
        noise = sigma * torch.randn_like(x_1)
        x_2 = C.bmm(x_1) + noise
        return C, x_1, x_2


In [35]:
  C_train, x_1_train, x_2_train = _gen_sim_data_fast(
            1, 100, 1e-2, max_rotation_angle=180)

In [39]:
C_train.shape

torch.Size([1, 3, 3])

In [37]:
x_1_train

tensor([[[ 0.3402, -0.9309,  0.6898,  0.5228, -0.9545, -0.4781, -0.9815,
           0.4148,  0.4894, -0.6773,  0.7755,  0.6138,  0.3074,  0.2942,
          -0.4360,  0.5251, -0.3710, -0.0499,  0.9509, -0.3317, -0.2121,
          -0.4438,  0.2411, -0.9650, -0.7968,  0.9108, -0.4822, -0.2444,
          -0.6511, -0.0523, -0.3095, -0.8972,  0.8859,  0.0371,  0.0175,
          -0.3477, -0.8084,  0.5250, -0.6148,  0.9740, -0.6856,  0.1311,
           0.9333, -0.0994,  0.4950, -0.8403, -0.5314,  0.5981,  0.6386,
           0.3567,  0.6086, -0.1647, -0.6566, -0.0921, -0.2016,  0.7307,
           0.8676, -0.7044,  0.4890,  0.7262,  0.2844, -0.4600,  0.1916,
          -0.7927,  0.9064, -0.0176,  0.1414, -0.4594, -0.8091,  0.5763,
          -0.7327,  0.1932, -0.0453, -0.0439, -0.0351,  0.5287, -0.2476,
          -0.9055, -0.2612,  0.2420,  0.8752,  0.2047, -0.9789,  0.9138,
          -0.0453,  0.0333, -0.3026, -0.5875, -0.0534, -0.6159, -0.8502,
           0.6325,  0.8676, -0.5422, -0.7465, -0.78

In [40]:
C_train[0].T @ C_train[0]

tensor([[ 1.0000e+00, -6.2200e-18,  6.5229e-18],
        [-6.2200e-18,  1.0000e+00,  5.1607e-17],
        [ 6.5229e-18,  5.1607e-17,  1.0000e+00]], dtype=torch.float64)