In [1]:
import numpy as np
import torch
from tsGaussian.torch_tsgaussian import TangentSpaceGaussian
from stable_baselines_utils import TangentSpaceGaussian as TSG
# from pytorch3d.transforms.so3 import (
#     so3_exp_map,
#     so3_relative_angle,
# )

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Build the torch TangentSpaceGaussian under test. Its single constructor
# argument is passed as None; what that argument controls is not visible
# in this notebook — TODO confirm against tsGaussian.torch_tsgaussian.
tg = TangentSpaceGaussian(None)

# Test liegroups.torch

In [3]:
from liegroups.torch import SO3

In [4]:
# Apply SO3.exp to a batch of two 3-vectors: [1, 2, 3] and the zero vector.
# In the recorded output the zero vector maps to the identity matrix.
C = SO3.exp(torch.Tensor([[1,2,3],
                         [0,0,0]]))
# Bare last expression so the notebook displays the resulting SO3Matrix.
C

<liegroups.torch.so3.SO3Matrix>
| tensor([[[-0.6949,  0.7135,  0.0893],
|          [-0.1920, -0.3038,  0.9332],
|          [ 0.6930,  0.6313,  0.3481]],
| 
|         [[ 1.0000,  0.0000,  0.0000],
|          [ 0.0000,  1.0000,  0.0000],
|          [ 0.0000,  0.0000,  1.0000]]])

# Test torch_tsgaussian sample

In [5]:
# Mean input for rsample: the 3x3 identity with a leading batch axis of size 1.
R_mu = torch.eye(3).unsqueeze(0)
# Sigma input for rsample: a single row of ones, shape (1, 3).
sigma = torch.ones(1, 3)

In [6]:
# Draw from the distribution parameterised by R_mu and sigma. rsample returns
# a pair; the following cell treats R_x as a batch of 3x3 matrices. R_quat is
# presumably the same sample in quaternion form — TODO confirm against rsample.
R_quat, R_x = tg.rsample(R_mu, sigma)

sigma:  tensor([[1., 1., 1.]])
torch.Size([1, 3, 3])


In [7]:
# Orthogonality check: for each matrix in the batch, R^T R should be
# (numerically) the identity.
R_x.transpose(1, 2).bmm(R_x)

tensor([[[1.0000e+00, 2.9802e-08, 0.0000e+00],
         [2.9802e-08, 1.0000e+00, 2.9802e-08],
         [0.0000e+00, 2.9802e-08, 1.0000e+00]]])

# Test torch_tsgaussian normal_term

In [8]:
# Reset sigma to a (1, 3) row of ones and display it.
sigma = torch.ones(1, 3)
sigma

tensor([[1., 1., 1.]])

In [9]:
# Evaluate normal_term for unit sigma. The recorded output, 15.7496, equals
# (2*pi)**1.5, i.e. the normalising constant of a 3-D unit-variance Gaussian.
tg.normal_term(sigma)

tensor([15.7496])

# Test torch_tsgaussian log_map

In [10]:
# Two identical inputs for log_map: the 3x3 identity with a batch axis of 1.
R_1 = torch.eye(3).unsqueeze(0)
R_2 = torch.eye(3).unsqueeze(0)

In [11]:
# log_map of two identical identity matrices; the recorded output is the
# zero 3-vector.
tg.log_map(R_1, R_2)

tensor([0., 0., 0.])

# Test torch_tsgaussian log_probs

In [12]:
# Inputs for the log_probs check, batched to five entries.
R_x = torch.eye(3).repeat(5, 1, 1)  # five identity matrices, shape (5, 3, 3)
# NOTE(review): R_mu is all zeros, which is not a rotation matrix — confirm
# this is the intended mean before relying on log_probs with it.
R_mu = torch.zeros(5, 3, 3)
# sigma stays a single (1, 3) row; it is not repeated to the batch size.
sigma = torch.ones(1, 3)

In [13]:
# tg.log_probs(R_x, R_mu, sigma)

In [14]:
# Convert the log-value -2.7568 back to linear scale. -2.7568 is -log(15.7496),
# the negative log of the normal_term output recorded above; presumably it is
# a log-probability returned by log_probs — TODO confirm.
np.e ** (-2.7568)

0.06349462641817973

All of the code for torch_tsgaussian runs now; we still need to check its correctness and turn it into a batched version.

# Test TangentSpaceGaussian actions_from_params

In [15]:
# Instantiate the stable_baselines_utils TangentSpaceGaussian (imported as TSG)
# with None as its only argument; what that argument controls is not visible
# here — TODO confirm.
tsg = TSG(None)

In [16]:
# Show which object the wrapper stores in its `distribution` attribute.
print(tsg.distribution)

<tsGaussian.torch_tsgaussian.TangentSpaceGaussian object at 0x7fb5f410d100>


In [17]:
# Display the wrapper's repr to see which class `tsg` is an instance of.
tsg

<stable_baselines_utils.TangentSpaceGaussian at 0x7fb5f410d0d0>

In [18]:
# Call actions_from_params with an identity matrix of shape (1, 3, 3) and a
# (1, 3) row of ones. The recorded output is a tuple: a (1, 4) tensor
# followed by a (1, 3, 3) tensor.
tsg.actions_from_params(torch.eye(3).reshape((1,3,3)), torch.ones(3).reshape((1,3)))

sigma:  tensor([[1., 1., 1.]])
torch.Size([1, 3, 3])


(tensor([[-0.2309, -0.1942,  0.2587,  0.9176]]),
 tensor([[[ 0.7908, -0.3850, -0.4759],
          [ 0.5644,  0.7595,  0.3233],
          [ 0.2370, -0.5243,  0.8179]]]))

# Test TangentSpaceGaussian log_prob_from_params

In [19]:
# Shape probe: repeating a (3, 3) identity with (2, 1, 1) yields (2, 3, 3).
torch.eye(3).repeat(2, 1, 1).shape

torch.Size([2, 3, 3])

In [20]:
# Shape probe: repeating a length-3 vector with (2, 1) yields (2, 3).
torch.ones(3).repeat(2, 1).shape

torch.Size([2, 3])

In [21]:
# tsg.log_prob_from_params(torch.eye(3).repeat(2,1,1), torch.ones(3))

In [22]:
# Invert a random batch of two 4x4 matrices with torch.linalg.inv.
# No seed is set in the visible cells, so x (and therefore y) differ per run.
x = torch.randn(2, 4, 4)
y = torch.linalg.inv(x)
# Bare last expression so the notebook displays the batch of inverses.
y

tensor([[[ 0.8972,  0.3311, -2.0802,  0.1954],
         [-0.4544,  0.0150, -0.1455,  0.1318],
         [-0.1425,  0.6359, -1.6084, -1.3160],
         [ 0.6918, -0.0042,  3.0271,  1.8034]],

        [[-0.2175, -0.6450, -0.9562,  0.3594],
         [ 2.2492,  0.1025,  2.0803, -0.1886],
         [ 1.8147, -0.1330,  2.0608, -0.5000],
         [-1.9430, -0.4256, -1.6181,  0.7988]]])

Again, the code runs, but its correctness still needs to be checked.

# Try to run training

In [23]:
import torch
from absl import app, flags
from stable_baselines3 import SAC, PPO
from envs.wahba import Wahba
from stable_baselines_utils import CustomSACPolicy, \
    CustomCNN

In [26]:
def main(argv):
    """Train a SAC agent with the custom policy on the Wahba environment.

    Creates a ``Wahba`` environment, configures ``CustomSACPolicy`` with a
    ``CustomCNN`` features extractor (``features_dim=256``), one critic and
    no shared features extractor, then calls ``learn`` for 500 timesteps on
    the CPU.

    Args:
        argv: Not used by the body; the notebook calls ``main(None)``.
    """
    environment = Wahba()
    # All policy options in one literal, in the same key order as before.
    policy_kwargs = {
        'features_extractor_class': CustomCNN,
        'features_extractor_kwargs': {'features_dim': 256},
        'n_critics': 1,
        'share_features_extractor': False,
    }
    model = SAC(
        CustomSACPolicy,
        environment,
        verbose=1,
        ent_coef='auto_0.1',
        policy_kwargs=policy_kwargs,
        device=torch.device('cpu'),
    )
    model.learn(total_timesteps=500, eval_freq=100, n_eval_episodes=100)

In [27]:
from torch import autograd
# Run training under autograd anomaly detection so that a backward function
# producing NaN raises immediately (see the RuntimeError recorded in this
# cell's output) instead of propagating silently.
with autograd.detect_anomaly():
    main(None)

Using cpu device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 1        |
|    ep_rew_mean     | -5.48    |
| time/              |          |
|    episodes        | 4        |
|    fps             | 800      |
|    time_elapsed    | 0        |
|    total_timesteps | 4        |
---------------------------------
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 1        |
|    ep_rew_mean     | -6.01    |
| time/              |          |
|    episodes        | 8        |
|    fps             | 826      |
|    time_elapsed    | 0        |
|    total_timesteps | 8        |
---------------------------------
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 1        |
|    ep_rew_mean     | -5.91    |
| time/              |          |
|    episodes        | 12       |
|    fps             |

---------------------------------
| rollout/           |          |
|    ep_len_mean     | 1        |
|    ep_rew_mean     | -6.09    |
| time/              |          |
|    episodes        | 100      |
|    fps             | 832      |
|    time_elapsed    | 0        |
|    total_timesteps | 100      |
---------------------------------


  with autograd.detect_anomaly():


tensor([[ 0.0090, -0.0270,  0.0588,  0.0068, -0.0629, -0.0564, -0.0135, -0.0556,
          0.0520, -0.0301,  0.0054,  0.0176]])
sigma:  tensor([[0.0090, 0.0270, 0.0588]])
torch.Size([1, 3, 3])
tensor([[ 0.0088, -0.0280,  0.0590,  ..., -0.0309,  0.0059,  0.0176],
        [ 0.0089, -0.0279,  0.0591,  ..., -0.0309,  0.0058,  0.0174],
        [ 0.0089, -0.0271,  0.0588,  ..., -0.0302,  0.0050,  0.0177],
        ...,
        [ 0.0088, -0.0282,  0.0590,  ..., -0.0306,  0.0055,  0.0179],
        [ 0.0091, -0.0271,  0.0587,  ..., -0.0302,  0.0053,  0.0176],
        [ 0.0089, -0.0275,  0.0588,  ..., -0.0307,  0.0052,  0.0173]],
       grad_fn=<AddmmBackward>)
sigma:  tensor([[0.0088, 0.0280, 0.0590],
        [0.0089, 0.0279, 0.0591],
        [0.0089, 0.0271, 0.0588],
        [0.0087, 0.0276, 0.0588],
        [0.0090, 0.0276, 0.0589],
        [0.0088, 0.0269, 0.0588],
        [0.0089, 0.0272, 0.0586],
        [0.0088, 0.0271, 0.0586],
        [0.0088, 0.0269, 0.0588],
        [0.0088, 0.0276, 0.

torch.Size([256, 3, 3])
actions:  tensor([[ 0.7409, -0.3898, -0.5165, -0.1801],
        [ 0.7498, -0.4128, -0.4879, -0.1713],
        [ 0.7716, -0.3797, -0.4790, -0.1760],
        ...,
        [ 0.7458, -0.3936, -0.5081, -0.1750],
        [ 0.7563, -0.3878, -0.4989, -0.1692],
        [ 0.7554, -0.3983, -0.4927, -0.1669]])
actions_mat:  tensor([[[ 0.1627, -0.7636, -0.6249],
         [-0.3915, -0.6313,  0.6695],
         [-0.9057,  0.1357, -0.4016]],

        [[ 0.1831, -0.7862, -0.5902],
         [-0.4519, -0.6004,  0.6597],
         [-0.8731,  0.1459, -0.4653]],

        [[ 0.2527, -0.7546, -0.6056],
         [-0.4174, -0.6497,  0.6354],
         [-0.8729,  0.0922, -0.4791]],

        ...,

        [[ 0.1738, -0.7650, -0.6202],
         [-0.4092, -0.6289,  0.6611],
         [-0.8957,  0.1389, -0.4223]],

        [[ 0.2014, -0.7554, -0.6235],
         [-0.4178, -0.6420,  0.6429],
         [-0.8860,  0.1310, -0.4449]],

        [[ 0.1971, -0.7663, -0.6115],
         [-0.4373, -0.6270,  0

actions_mat:  tensor([[[ 0.2200, -0.7603, -0.6112],
         [-0.4568, -0.6339,  0.6241],
         [-0.8619,  0.1419, -0.4867]],

        [[ 0.2458, -0.7443, -0.6210],
         [-0.3694, -0.6642,  0.6499],
         [-0.8962,  0.0696, -0.4382]],

        [[ 0.2004, -0.7695, -0.6064],
         [-0.4550, -0.6212,  0.6380],
         [-0.8676,  0.1481, -0.4746]],

        ...,

        [[ 0.2015, -0.7718, -0.6030],
         [-0.4160, -0.6248,  0.6607],
         [-0.8868,  0.1177, -0.4470]],

        [[ 0.2041, -0.7616, -0.6151],
         [-0.4195, -0.6358,  0.6479],
         [-0.8845,  0.1258, -0.4493]],

        [[ 0.3444, -0.7321, -0.5877],
         [-0.3619, -0.6811,  0.6364],
         [-0.8663, -0.0065, -0.4995]]])
tensor([[ 0.0279, -0.0343,  0.0634,  0.0071, -0.0624, -0.0568, -0.0171, -0.0547,
          0.0504, -0.0336,  0.0069,  0.0167]])
sigma:  tensor([[0.0279, 0.0343, 0.0634]])
torch.Size([1, 3, 3])
tensor([[ 0.0280, -0.0342,  0.0635,  ..., -0.0336,  0.0071,  0.0166],
        [ 0.0

tensor([[ 0.0430, -0.0443,  0.0707,  0.0073, -0.0617, -0.0576, -0.0209, -0.0541,
          0.0484, -0.0369,  0.0081,  0.0150]])
sigma:  tensor([[0.0430, 0.0443, 0.0707]])
torch.Size([1, 3, 3])
tensor([[ 0.0427, -0.0443,  0.0708,  ..., -0.0365,  0.0085,  0.0153],
        [ 0.0433, -0.0445,  0.0710,  ..., -0.0373,  0.0085,  0.0149],
        [ 0.0430, -0.0442,  0.0714,  ..., -0.0366,  0.0087,  0.0152],
        ...,
        [ 0.0432, -0.0440,  0.0710,  ..., -0.0367,  0.0087,  0.0151],
        [ 0.0428, -0.0443,  0.0712,  ..., -0.0371,  0.0086,  0.0151],
        [ 0.0428, -0.0443,  0.0709,  ..., -0.0366,  0.0084,  0.0153]],
       grad_fn=<AddmmBackward>)
sigma:  tensor([[0.0427, 0.0443, 0.0708],
        [0.0433, 0.0445, 0.0710],
        [0.0430, 0.0442, 0.0714],
        [0.0427, 0.0443, 0.0708],
        [0.0425, 0.0447, 0.0713],
        [0.0428, 0.0442, 0.0707],
        [0.0429, 0.0443, 0.0712],
        [0.0427, 0.0442, 0.0709],
        [0.0433, 0.0441, 0.0707],
        [0.0425, 0.0448, 0.

tensor([[ 0.0567, -0.0550,  0.0802,  0.0071, -0.0606, -0.0581, -0.0235, -0.0545,
          0.0456, -0.0399,  0.0097,  0.0139]])
sigma:  tensor([[0.0567, 0.0550, 0.0802]])
torch.Size([1, 3, 3])
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 1        |
|    ep_rew_mean     | -6.01    |
| time/              |          |
|    episodes        | 104      |
|    fps             | 41       |
|    time_elapsed    | 2        |
|    total_timesteps | 104      |
| train/             |          |
|    actor_loss      | 0.205    |
|    critic_loss     | 19.8     |
|    ent_coef        | 0.0999   |
|    ent_coef_loss   | -4.94    |
|    learning_rate   | 0.0003   |
|    n_updates       | 3        |
---------------------------------
tensor([[ 0.0570, -0.0551,  0.0809,  ..., -0.0401,  0.0096,  0.0140],
        [ 0.0567, -0.0551,  0.0806,  ..., -0.0401,  0.0097,  0.0140],
        [ 0.0573, -0.0550,  0.0808,  ..., -0.0406,  0.0098,  0.0138],
        ...,
      

tensor([[ 0.0702, -0.0656,  0.0906,  0.0071, -0.0610, -0.0583, -0.0264, -0.0538,
          0.0433, -0.0422,  0.0102,  0.0120]])
sigma:  tensor([[0.0702, 0.0656, 0.0906]])
torch.Size([1, 3, 3])
tensor([[ 0.0698, -0.0655,  0.0905,  ..., -0.0418,  0.0104,  0.0121],
        [ 0.0699, -0.0655,  0.0905,  ..., -0.0419,  0.0105,  0.0120],
        [ 0.0700, -0.0655,  0.0906,  ..., -0.0419,  0.0104,  0.0122],
        ...,
        [ 0.0698, -0.0659,  0.0913,  ..., -0.0427,  0.0105,  0.0117],
        [ 0.0702, -0.0655,  0.0907,  ..., -0.0420,  0.0106,  0.0122],
        [ 0.0703, -0.0659,  0.0915,  ..., -0.0425,  0.0106,  0.0118]],
       grad_fn=<AddmmBackward>)
sigma:  tensor([[0.0698, 0.0655, 0.0905],
        [0.0699, 0.0655, 0.0905],
        [0.0700, 0.0655, 0.0906],
        [0.0700, 0.0655, 0.0906],
        [0.0702, 0.0655, 0.0910],
        [0.0701, 0.0656, 0.0905],
        [0.0702, 0.0656, 0.0906],
        [0.0702, 0.0659, 0.0911],
        [0.0699, 0.0654, 0.0906],
        [0.0705, 0.0655, 0.

tensor([[ 0.0855, -0.0777,  0.1046,  0.0069, -0.0609, -0.0617, -0.0297, -0.0528,
          0.0402, -0.0442,  0.0108,  0.0089]])
sigma:  tensor([[0.0855, 0.0777, 0.1046]])
torch.Size([1, 3, 3])
tensor([[ 0.0842, -0.0779,  0.1044,  ..., -0.0438,  0.0107,  0.0096],
        [ 0.0857, -0.0779,  0.1049,  ..., -0.0444,  0.0107,  0.0088],
        [ 0.0835, -0.0773,  0.1033,  ..., -0.0432,  0.0110,  0.0100],
        ...,
        [ 0.0837, -0.0773,  0.1038,  ..., -0.0437,  0.0107,  0.0095],
        [ 0.0841, -0.0774,  0.1035,  ..., -0.0435,  0.0109,  0.0097],
        [ 0.0836, -0.0771,  0.1033,  ..., -0.0432,  0.0110,  0.0100]],
       grad_fn=<AddmmBackward>)
sigma:  tensor([[0.0842, 0.0779, 0.1044],
        [0.0857, 0.0779, 0.1049],
        [0.0835, 0.0773, 0.1033],
        [0.0840, 0.0773, 0.1038],
        [0.0849, 0.0779, 0.1051],
        [0.0855, 0.0780, 0.1052],
        [0.0854, 0.0780, 0.1052],
        [0.0855, 0.0778, 0.1047],
        [0.0844, 0.0776, 0.1039],
        [0.0835, 0.0773, 0.

tensor([[ 0.0987, -0.0912,  0.1179,  0.0062, -0.0616, -0.0625, -0.0328, -0.0523,
          0.0367, -0.0452,  0.0105,  0.0074]])
sigma:  tensor([[0.0987, 0.0912, 0.1179]])
torch.Size([1, 3, 3])
tensor([[ 0.1012, -0.0919,  0.1202,  ..., -0.0460,  0.0101,  0.0061],
        [ 0.1003, -0.0922,  0.1190,  ..., -0.0460,  0.0101,  0.0066],
        [ 0.0996, -0.0915,  0.1182,  ..., -0.0455,  0.0107,  0.0073],
        ...,
        [ 0.0988, -0.0911,  0.1179,  ..., -0.0452,  0.0105,  0.0075],
        [ 0.0989, -0.0914,  0.1182,  ..., -0.0452,  0.0103,  0.0074],
        [ 0.0998, -0.0917,  0.1191,  ..., -0.0458,  0.0107,  0.0069]],
       grad_fn=<AddmmBackward>)
sigma:  tensor([[0.1012, 0.0919, 0.1202],
        [0.1003, 0.0922, 0.1190],
        [0.0996, 0.0915, 0.1182],
        [0.0992, 0.0908, 0.1180],
        [0.0987, 0.0912, 0.1179],
        [0.0995, 0.0918, 0.1194],
        [0.0985, 0.0912, 0.1178],
        [0.0985, 0.0912, 0.1178],
        [0.0978, 0.0907, 0.1179],
        [0.0989, 0.0913, 0.

tensor([[ 0.1179, -0.1078,  0.1375,  0.0057, -0.0626, -0.0649, -0.0388, -0.0498,
          0.0323, -0.0463,  0.0101,  0.0037]])
sigma:  tensor([[0.1179, 0.1078, 0.1375]])
torch.Size([1, 3, 3])
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 1        |
|    ep_rew_mean     | -5.92    |
| time/              |          |
|    episodes        | 108      |
|    fps             | 19       |
|    time_elapsed    | 5        |
|    total_timesteps | 108      |
| train/             |          |
|    actor_loss      | 0.264    |
|    critic_loss     | 20.4     |
|    ent_coef        | 0.0998   |
|    ent_coef_loss   | -7       |
|    learning_rate   | 0.0003   |
|    n_updates       | 7        |
---------------------------------
tensor([[ 0.1166, -0.1074,  0.1360,  ..., -0.0461,  0.0104,  0.0044],
        [ 0.1157, -0.1070,  0.1354,  ..., -0.0461,  0.0102,  0.0045],
        [ 0.1163, -0.1073,  0.1357,  ..., -0.0461,  0.0104,  0.0045],
        ...,
      

tensor([[ 0.1357, -0.1251,  0.1563,  0.0056, -0.0630, -0.0664, -0.0415, -0.0484,
          0.0271, -0.0463,  0.0098,  0.0025]])
sigma:  tensor([[0.1357, 0.1251, 0.1563]])
torch.Size([1, 3, 3])
tensor([[ 0.1376, -0.1255,  0.1588,  ..., -0.0463,  0.0101,  0.0023],
        [ 0.1368, -0.1253,  0.1577,  ..., -0.0462,  0.0100,  0.0025],
        [ 0.1367, -0.1254,  0.1573,  ..., -0.0464,  0.0097,  0.0026],
        ...,
        [ 0.1384, -0.1265,  0.1590,  ..., -0.0465,  0.0096,  0.0018],
        [ 0.1361, -0.1252,  0.1567,  ..., -0.0463,  0.0099,  0.0026],
        [ 0.1376, -0.1255,  0.1588,  ..., -0.0463,  0.0101,  0.0023]],
       grad_fn=<AddmmBackward>)
sigma:  tensor([[0.1376, 0.1255, 0.1588],
        [0.1368, 0.1253, 0.1577],
        [0.1367, 0.1254, 0.1573],
        [0.1385, 0.1265, 0.1593],
        [0.1358, 0.1251, 0.1564],
        [0.1364, 0.1252, 0.1568],
        [0.1376, 0.1254, 0.1576],
        [0.1358, 0.1249, 0.1561],
        [0.1398, 0.1268, 0.1611],
        [0.1369, 0.1253, 0.

tensor([[ 0.1606, -0.1477,  0.1819,  0.0065, -0.0642, -0.0695, -0.0457, -0.0446,
          0.0203, -0.0465,  0.0089,  0.0004]])
sigma:  tensor([[0.1606, 0.1477, 0.1819]])
torch.Size([1, 3, 3])
tensor([[ 1.6177e-01, -1.4818e-01,  1.8343e-01,  ..., -4.6406e-02,
          8.6166e-03, -1.2286e-04],
        [ 1.6278e-01, -1.4874e-01,  1.8456e-01,  ..., -4.6302e-02,
          8.5501e-03, -2.5399e-04],
        [ 1.6021e-01, -1.4690e-01,  1.8160e-01,  ..., -4.7029e-02,
          8.4875e-03, -4.0079e-04],
        ...,
        [ 1.6155e-01, -1.4816e-01,  1.8263e-01,  ..., -4.6542e-02,
          8.7415e-03,  2.4999e-04],
        [ 1.6323e-01, -1.4870e-01,  1.8434e-01,  ..., -4.6566e-02,
          8.7697e-03, -9.5097e-05],
        [ 1.6019e-01, -1.4758e-01,  1.8164e-01,  ..., -4.6539e-02,
          8.8663e-03,  3.0803e-04]], grad_fn=<AddmmBackward>)
sigma:  tensor([[0.1618, 0.1482, 0.1834],
        [0.1628, 0.1487, 0.1846],
        [0.1602, 0.1469, 0.1816],
        [0.1652, 0.1503, 0.1876],
      

tensor([[ 0.1937, -0.1774,  0.2149,  0.0081, -0.0643, -0.0754, -0.0501, -0.0397,
          0.0105, -0.0466,  0.0069, -0.0020]])
sigma:  tensor([[0.1937, 0.1774, 0.2149]])
torch.Size([1, 3, 3])
tensor([[ 0.1915, -0.1758,  0.2135,  ..., -0.0464,  0.0073, -0.0018],
        [ 0.1906, -0.1753,  0.2119,  ..., -0.0465,  0.0072, -0.0015],
        [ 0.1900, -0.1746,  0.2127,  ..., -0.0466,  0.0069, -0.0027],
        ...,
        [ 0.1896, -0.1748,  0.2114,  ..., -0.0465,  0.0072, -0.0017],
        [ 0.1881, -0.1732,  0.2105,  ..., -0.0468,  0.0071, -0.0024],
        [ 0.1947, -0.1783,  0.2165,  ..., -0.0465,  0.0066, -0.0026]],
       grad_fn=<AddmmBackward>)
sigma:  tensor([[0.1915, 0.1758, 0.2135],
        [0.1906, 0.1753, 0.2119],
        [0.1900, 0.1746, 0.2127],
        [0.1905, 0.1753, 0.2123],
        [0.1901, 0.1751, 0.2115],
        [0.1960, 0.1789, 0.2179],
        [0.1905, 0.1753, 0.2118],
        [0.1900, 0.1750, 0.2113],
        [0.1904, 0.1752, 0.2129],
        [0.1902, 0.1751, 0.

tensor([[ 0.2307, -0.2120,  0.2517,  0.0082, -0.0654, -0.0817, -0.0532, -0.0354,
         -0.0013, -0.0464,  0.0039, -0.0045]])
sigma:  tensor([[0.2307, 0.2120, 0.2517]])
torch.Size([1, 3, 3])
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 1        |
|    ep_rew_mean     | -5.91    |
| time/              |          |
|    episodes        | 112      |
|    fps             | 12       |
|    time_elapsed    | 8        |
|    total_timesteps | 112      |
| train/             |          |
|    actor_loss      | 0.57     |
|    critic_loss     | 17.2     |
|    ent_coef        | 0.0997   |
|    ent_coef_loss   | -8.54    |
|    learning_rate   | 0.0003   |
|    n_updates       | 11       |
---------------------------------
tensor([[ 0.2335, -0.2141,  0.2554,  ..., -0.0459,  0.0038, -0.0055],
        [ 0.2286, -0.2110,  0.2492,  ..., -0.0465,  0.0045, -0.0042],
        [ 0.2371, -0.2169,  0.2591,  ..., -0.0458,  0.0034, -0.0061],
        ...,
      

tensor([[ 0.2810, -0.2597,  0.3000,  0.0077, -0.0679, -0.0922, -0.0580, -0.0309,
         -0.0154, -0.0463, -0.0008, -0.0090]])
sigma:  tensor([[0.2810, 0.2597, 0.3000]])
torch.Size([1, 3, 3])
tensor([[ 0.2819, -0.2605,  0.3014,  ..., -0.0462, -0.0010, -0.0093],
        [ 0.2775, -0.2574,  0.2965,  ..., -0.0466, -0.0004, -0.0082],
        [ 0.2780, -0.2578,  0.2970,  ..., -0.0465, -0.0005, -0.0083],
        ...,
        [ 0.2789, -0.2583,  0.2984,  ..., -0.0466, -0.0011, -0.0087],
        [ 0.2801, -0.2593,  0.2991,  ..., -0.0464, -0.0008, -0.0086],
        [ 0.2784, -0.2581,  0.2973,  ..., -0.0465, -0.0005, -0.0083]],
       grad_fn=<AddmmBackward>)
sigma:  tensor([[0.2819, 0.2605, 0.3014],
        [0.2775, 0.2574, 0.2965],
        [0.2780, 0.2578, 0.2970],
        [0.2773, 0.2572, 0.2977],
        [0.2815, 0.2602, 0.3002],
        [0.2803, 0.2590, 0.3008],
        [0.2856, 0.2636, 0.3056],
        [0.2847, 0.2630, 0.3047],
        [0.2819, 0.2600, 0.3015],
        [0.2796, 0.2592, 0.

tensor([[ 0.3399, -0.3159,  0.3570,  0.0050, -0.0694, -0.1051, -0.0617, -0.0268,
         -0.0310, -0.0463, -0.0061, -0.0138]])
sigma:  tensor([[0.3399, 0.3159, 0.3570]])
torch.Size([1, 3, 3])
tensor([[ 0.3422, -0.3176,  0.3595,  ..., -0.0461, -0.0064, -0.0142],
        [ 0.3450, -0.3194,  0.3638,  ..., -0.0461, -0.0068, -0.0158],
        [ 0.3360, -0.3119,  0.3543,  ..., -0.0470, -0.0062, -0.0142],
        ...,
        [ 0.3369, -0.3129,  0.3552,  ..., -0.0469, -0.0064, -0.0140],
        [ 0.3403, -0.3162,  0.3579,  ..., -0.0465, -0.0063, -0.0145],
        [ 0.3392, -0.3153,  0.3577,  ..., -0.0462, -0.0062, -0.0146]],
       grad_fn=<AddmmBackward>)
sigma:  tensor([[0.3422, 0.3176, 0.3595],
        [0.3450, 0.3194, 0.3638],
        [0.3360, 0.3119, 0.3543],
        [0.3373, 0.3134, 0.3551],
        [0.3417, 0.3173, 0.3598],
        [0.3431, 0.3183, 0.3602],
        [0.3529, 0.3261, 0.3718],
        [0.3399, 0.3159, 0.3573],
        [0.3484, 0.3223, 0.3670],
        [0.3405, 0.3164, 0.

tensor([[ 0.4327, -0.3995,  0.4475,  0.0032, -0.0683, -0.1280, -0.0676, -0.0216,
         -0.0505, -0.0476, -0.0116, -0.0236]])
sigma:  tensor([[0.4327, 0.3995, 0.4475]])
torch.Size([1, 3, 3])
tensor([[ 0.4207, -0.3901,  0.4357,  ..., -0.0477, -0.0105, -0.0222],
        [ 0.4185, -0.3884,  0.4340,  ..., -0.0477, -0.0103, -0.0222],
        [ 0.4264, -0.3941,  0.4430,  ..., -0.0475, -0.0112, -0.0242],
        ...,
        [ 0.4230, -0.3916,  0.4379,  ..., -0.0477, -0.0108, -0.0226],
        [ 0.4179, -0.3879,  0.4336,  ..., -0.0476, -0.0101, -0.0221],
        [ 0.4159, -0.3856,  0.4327,  ..., -0.0478, -0.0104, -0.0229]],
       grad_fn=<AddmmBackward>)
sigma:  tensor([[0.4207, 0.3901, 0.4357],
        [0.4185, 0.3884, 0.4340],
        [0.4264, 0.3941, 0.4430],
        [0.4227, 0.3917, 0.4377],
        [0.4262, 0.3938, 0.4418],
        [0.4159, 0.3857, 0.4319],
        [0.4208, 0.3902, 0.4362],
        [0.4249, 0.3933, 0.4400],
        [0.4370, 0.4030, 0.4533],
        [0.4216, 0.3900, 0.

tensor([[ 0.5344, -0.4906,  0.5455, -0.0027, -0.0707, -0.1474, -0.0698, -0.0108,
         -0.0756, -0.0535, -0.0227, -0.0280]])
sigma:  tensor([[0.5344, 0.4906, 0.5455]])
torch.Size([1, 3, 3])
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 1        |
|    ep_rew_mean     | -5.89    |
| time/              |          |
|    episodes        | 116      |
|    fps             | 9        |
|    time_elapsed    | 11       |
|    total_timesteps | 116      |
| train/             |          |
|    actor_loss      | 1.6      |
|    critic_loss     | 12.1     |
|    ent_coef        | 0.0996   |
|    ent_coef_loss   | -9.76    |
|    learning_rate   | 0.0003   |
|    n_updates       | 15       |
---------------------------------
tensor([[ 0.5254, -0.4839,  0.5356,  ..., -0.0532, -0.0217, -0.0261],
        [ 0.5426, -0.4979,  0.5532,  ..., -0.0535, -0.0233, -0.0282],
        [ 0.5319, -0.4884,  0.5421,  ..., -0.0534, -0.0224, -0.0272],
        ...,
      

tensor([[ 6.6275e-01, -6.0529e-01,  6.6857e-01, -1.1026e-02, -7.0894e-02,
         -1.6869e-01, -7.1276e-02,  3.8904e-04, -1.0612e-01, -6.3192e-02,
         -3.6722e-02, -3.0026e-02]])
sigma:  tensor([[0.6627, 0.6053, 0.6686]])
torch.Size([1, 3, 3])
tensor([[ 0.6496, -0.5936,  0.6564,  ..., -0.0636, -0.0361, -0.0295],
        [ 0.6856, -0.6243,  0.6916,  ..., -0.0641, -0.0393, -0.0323],
        [ 0.6691, -0.6103,  0.6745,  ..., -0.0634, -0.0374, -0.0307],
        ...,
        [ 0.6618, -0.6044,  0.6675,  ..., -0.0632, -0.0365, -0.0299],
        [ 0.6665, -0.6083,  0.6722,  ..., -0.0633, -0.0371, -0.0304],
        [ 0.6567, -0.6003,  0.6639,  ..., -0.0631, -0.0363, -0.0305]],
       grad_fn=<AddmmBackward>)
sigma:  tensor([[0.6496, 0.5936, 0.6564],
        [0.6856, 0.6243, 0.6916],
        [0.6691, 0.6103, 0.6745],
        [0.6693, 0.6100, 0.6751],
        [0.6749, 0.6153, 0.6807],
        [0.6635, 0.6048, 0.6695],
        [0.6886, 0.6268, 0.6950],
        [0.6665, 0.6081, 0.6729],
    

tensor([[ 0.8408, -0.7611,  0.8416, -0.0227, -0.0692, -0.1989, -0.0752,  0.0172,
         -0.1496, -0.0775, -0.0572, -0.0346]])
sigma:  tensor([[0.8408, 0.7611, 0.8416]])
torch.Size([1, 3, 3])
tensor([[ 0.8304, -0.7528,  0.8315,  ..., -0.0771, -0.0561, -0.0336],
        [ 0.8381, -0.7593,  0.8379,  ..., -0.0773, -0.0561, -0.0339],
        [ 0.8345, -0.7559,  0.8355,  ..., -0.0774, -0.0561, -0.0349],
        ...,
        [ 0.8393, -0.7605,  0.8394,  ..., -0.0773, -0.0562, -0.0339],
        [ 0.8425, -0.7631,  0.8431,  ..., -0.0773, -0.0570, -0.0340],
        [ 0.8669, -0.7835,  0.8668,  ..., -0.0784, -0.0596, -0.0361]],
       grad_fn=<AddmmBackward>)
sigma:  tensor([[0.8304, 0.7528, 0.8315],
        [0.8381, 0.7593, 0.8379],
        [0.8345, 0.7559, 0.8355],
        [0.8478, 0.7677, 0.8479],
        [0.8437, 0.7641, 0.8436],
        [0.8357, 0.7574, 0.8364],
        [0.8368, 0.7583, 0.8368],
        [0.8393, 0.7605, 0.8394],
        [0.8355, 0.7568, 0.8366],
        [0.8489, 0.7683, 0.

tensor([[ 1.0695, -0.9615,  1.0625, -0.0385, -0.0644, -0.2328, -0.0776,  0.0386,
         -0.2064, -0.0978, -0.0812, -0.0384]])
sigma:  tensor([[1.0695, 0.9615, 1.0625]])
torch.Size([1, 3, 3])
tensor([[ 1.0990, -0.9867,  1.0920,  ..., -0.0995, -0.0843, -0.0402],
        [ 1.0947, -0.9821,  1.0871,  ..., -0.0993, -0.0839, -0.0396],
        [ 1.0643, -0.9573,  1.0586,  ..., -0.0975, -0.0806, -0.0382],
        ...,
        [ 1.0644, -0.9573,  1.0583,  ..., -0.0979, -0.0811, -0.0382],
        [ 1.0818, -0.9719,  1.0754,  ..., -0.0984, -0.0822, -0.0390],
        [ 1.0670, -0.9595,  1.0603,  ..., -0.0976, -0.0807, -0.0383]],
       grad_fn=<AddmmBackward>)
sigma:  tensor([[1.0990, 0.9867, 1.0920],
        [1.0947, 0.9821, 1.0871],
        [1.0643, 0.9573, 1.0586],
        [1.0657, 0.9584, 1.0590],
        [1.0657, 0.9584, 1.0590],
        [1.0458, 0.9415, 1.0402],
        [1.0412, 0.9376, 1.0364],
        [1.0613, 0.9540, 1.0551],
        [1.0705, 0.9618, 1.0635],
        [1.0698, 0.9618, 1.

tensor([[ 1.3457, -1.2046,  1.3306, -0.0595, -0.0586, -0.2727, -0.0805,  0.0638,
         -0.2743, -0.1252, -0.1091, -0.0403]])
sigma:  tensor([[1.3457, 1.2046, 1.3306]])
torch.Size([1, 3, 3])
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 1        |
|    ep_rew_mean     | -5.94    |
| time/              |          |
|    episodes        | 120      |
|    fps             | 8        |
|    time_elapsed    | 14       |
|    total_timesteps | 120      |
| train/             |          |
|    actor_loss      | 4.65     |
|    critic_loss     | 4.84     |
|    ent_coef        | 0.0994   |
|    ent_coef_loss   | -8.64    |
|    learning_rate   | 0.0003   |
|    n_updates       | 19       |
---------------------------------
tensor([[ 1.3222, -1.1849,  1.3080,  ..., -0.1241, -0.1064, -0.0401],
        [ 1.3717, -1.2261,  1.3550,  ..., -0.1270, -0.1118, -0.0418],
        [ 1.3138, -1.1776,  1.2998,  ..., -0.1238, -0.1055, -0.0398],
        ...,
      

tensor([[ 1.7451, -1.5580,  1.7183, -0.0721, -0.0420, -0.3263, -0.0771,  0.0996,
         -0.3743, -0.1657, -0.1468, -0.0417]])
sigma:  tensor([[1.7451, 1.5580, 1.7183]])
torch.Size([1, 3, 3])
tensor([[ 1.6992, -1.5176,  1.6739,  ..., -0.1632, -0.1429, -0.0405],
        [ 1.6827, -1.5039,  1.6588,  ..., -0.1616, -0.1410, -0.0406],
        [ 1.7309, -1.5458,  1.7049,  ..., -0.1645, -0.1451, -0.0412],
        ...,
        [ 1.7278, -1.5427,  1.7024,  ..., -0.1645, -0.1451, -0.0416],
        [ 1.7368, -1.5507,  1.7102,  ..., -0.1651, -0.1458, -0.0415],
        [ 1.7278, -1.5427,  1.7024,  ..., -0.1645, -0.1451, -0.0416]],
       grad_fn=<AddmmBackward>)
sigma:  tensor([[1.6992, 1.5176, 1.6739],
        [1.6827, 1.5039, 1.6588],
        [1.7309, 1.5458, 1.7049],
        [1.7098, 1.5268, 1.6849],
        [1.7815, 1.5891, 1.7543],
        [1.7304, 1.5453, 1.7042],
        [1.7473, 1.5597, 1.7203],
        [1.6738, 1.4963, 1.6499],
        [1.7236, 1.5394, 1.6977],
        [1.7347, 1.5491, 1.

tensor([[ 2.2101, -1.9682,  2.1706, -0.0754, -0.0346, -0.4090, -0.0948,  0.1481,
         -0.4820, -0.2250, -0.1778, -0.0348]])
sigma:  tensor([[2.2101, 1.9682, 2.1706]])
torch.Size([1, 3, 3])
tensor([[ 2.2059, -1.9631,  2.1660,  ..., -0.2250, -0.1782, -0.0348],
        [ 2.2657, -2.0163,  2.2247,  ..., -0.2297, -0.1833, -0.0361],
        [ 2.2137, -1.9707,  2.1735,  ..., -0.2254, -0.1783, -0.0347],
        ...,
        [ 2.2316, -1.9851,  2.1909,  ..., -0.2269, -0.1805, -0.0354],
        [ 2.2192, -1.9755,  2.1794,  ..., -0.2260, -0.1787, -0.0353],
        [ 2.2052, -1.9635,  2.1664,  ..., -0.2245, -0.1776, -0.0351]],
       grad_fn=<AddmmBackward>)
sigma:  tensor([[2.2059, 1.9631, 2.1660],
        [2.2657, 2.0163, 2.2247],
        [2.2137, 1.9707, 2.1735],
        [2.1874, 1.9481, 2.1487],
        [2.2909, 2.0371, 2.2489],
        [2.1355, 1.9026, 2.0984],
        [2.2075, 1.9655, 2.1677],
        [2.2259, 1.9818, 2.1854],
        [2.2168, 1.9738, 2.1767],
        [2.2467, 1.9993, 2.

tensor([[ 2.8982, -2.5683,  2.8350, -0.0817, -0.0227, -0.5313, -0.1182,  0.2240,
         -0.6413, -0.3114, -0.2319, -0.0245]])
sigma:  tensor([[2.8982, 2.5683, 2.8350]])
torch.Size([1, 3, 3])
tensor([[ 2.8653, -2.5398,  2.8034,  ..., -0.3083, -0.2291, -0.0241],
        [ 2.8279, -2.5076,  2.7662,  ..., -0.3049, -0.2257, -0.0231],
        [ 2.8653, -2.5398,  2.8034,  ..., -0.3083, -0.2291, -0.0241],
        ...,
        [ 2.8349, -2.5137,  2.7729,  ..., -0.3056, -0.2264, -0.0230],
        [ 2.7232, -2.4162,  2.6656,  ..., -0.2957, -0.2170, -0.0229],
        [ 2.8331, -2.5116,  2.7716,  ..., -0.3055, -0.2261, -0.0235]],
       grad_fn=<AddmmBackward>)
sigma:  tensor([[2.8653, 2.5398, 2.8034],
        [2.8279, 2.5076, 2.7662],
        [2.8653, 2.5398, 2.8034],
        [2.7748, 2.4611, 2.7157],
        [2.8337, 2.5129, 2.7719],
        [2.7167, 2.4108, 2.6590],
        [2.8398, 2.5179, 2.7775],
        [2.9125, 2.5800, 2.8485],
        [2.8162, 2.4976, 2.7550],
        [2.7745, 2.4606, 2.

tensor([[ 3.5645, -3.1540,  3.4767, -0.0786, -0.0124, -0.6546, -0.1401,  0.3008,
         -0.7949, -0.4032, -0.2835, -0.0038]])
sigma:  tensor([[3.5645, 3.1540, 3.4767]])
torch.Size([1, 3, 3])
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 1        |
|    ep_rew_mean     | -6       |
| time/              |          |
|    episodes        | 124      |
|    fps             | 6        |
|    time_elapsed    | 17       |
|    total_timesteps | 124      |
| train/             |          |
|    actor_loss      | 8.32     |
|    critic_loss     | 5.29     |
|    ent_coef        | 0.0993   |
|    ent_coef_loss   | -15.6    |
|    learning_rate   | 0.0003   |
|    n_updates       | 23       |
---------------------------------
tensor([[ 3.4179e+00, -3.0248e+00,  3.3349e+00,  ..., -3.8887e-01,
         -2.7153e-01, -4.1345e-03],
        [ 3.5828e+00, -3.1700e+00,  3.4941e+00,  ..., -4.0487e-01,
         -2.8527e-01, -3.5089e-03],
        [ 3.5852e+00, -

tensor([[ 4.4869e+00, -3.9703e+00,  4.3630e+00, -7.3512e-02,  5.4182e-04,
         -8.2738e-01, -1.6943e-01,  4.0947e-01, -1.0156e+00, -5.2637e-01,
         -3.5921e-01,  2.8417e-02]])
sigma:  tensor([[4.4869, 3.9703, 4.3630]])
torch.Size([1, 3, 3])
tensor([[ 4.3379, -3.8398,  4.2200,  ..., -0.5105, -0.3468,  0.0271],
        [ 4.3591, -3.8581,  4.2399,  ..., -0.5126, -0.3484,  0.0275],
        [ 4.5166, -3.9947,  4.3911,  ..., -0.5300, -0.3623,  0.0283],
        ...,
        [ 4.4702, -3.9554,  4.3474,  ..., -0.5248, -0.3577,  0.0278],
        [ 4.6084, -4.0757,  4.4806,  ..., -0.5400, -0.3695,  0.0280],
        [ 4.2915, -3.7978,  4.1745,  ..., -0.5061, -0.3436,  0.0264]],
       grad_fn=<AddmmBackward>)
sigma:  tensor([[4.3379, 3.8398, 4.2200],
        [4.3591, 3.8581, 4.2399],
        [4.5166, 3.9947, 4.3911],
        [4.5556, 4.0289, 4.4293],
        [4.4575, 3.9440, 4.3345],
        [4.3557, 3.8544, 4.2363],
        [4.5092, 3.9899, 4.3846],
        [4.4935, 3.9761, 4.3694],
    

tensor([[ 5.6508, -5.0040,  5.4816, -0.0631,  0.0160, -1.0440, -0.2049,  0.5480,
         -1.2999, -0.6847, -0.4545,  0.0737]])
sigma:  tensor([[5.6508, 5.0040, 5.4816]])
torch.Size([1, 3, 3])
tensor([[ 5.5719, -4.9361,  5.4059,  ..., -0.6752, -0.4476,  0.0731],
        [ 5.4752, -4.8506,  5.3131,  ..., -0.6650, -0.4402,  0.0715],
        [ 5.4768, -4.8515,  5.3142,  ..., -0.6646, -0.4402,  0.0713],
        ...,
        [ 5.4739, -4.8497,  5.3121,  ..., -0.6646, -0.4400,  0.0717],
        [ 5.5721, -4.9362,  5.4060,  ..., -0.6752, -0.4476,  0.0731],
        [ 5.3371, -4.7293,  5.1800,  ..., -0.6488, -0.4285,  0.0696]],
       grad_fn=<AddmmBackward>)
sigma:  tensor([[5.5719, 4.9361, 5.4059],
        [5.4752, 4.8506, 5.3131],
        [5.4768, 4.8515, 5.3142],
        [5.5721, 4.9359, 5.4060],
        [5.5859, 4.9483, 5.4193],
        [5.5302, 4.8994, 5.3659],
        [5.5708, 4.9349, 5.4047],
        [5.3147, 4.7090, 5.1589],
        [5.5843, 4.9465, 5.4175],
        [5.5126, 4.8826, 5.

tensor([[ 6.8052, -6.0386,  6.5966, -0.0408,  0.0326, -1.2606, -0.2406,  0.6854,
         -1.5847, -0.8512, -0.5399,  0.1338]])
sigma:  tensor([[6.8052, 6.0386, 6.5966]])
torch.Size([1, 3, 3])
tensor([[ 6.8777, -6.1028,  6.6659,  ..., -0.8597, -0.5459,  0.1357],
        [ 6.9521, -6.1677,  6.7380,  ..., -0.8695, -0.5521,  0.1366],
        [ 6.8533, -6.0811,  6.6424,  ..., -0.8572, -0.5436,  0.1348],
        ...,
        [ 6.5253, -5.7912,  6.3272,  ..., -0.8184, -0.5175,  0.1275],
        [ 6.8418, -6.0700,  6.6317,  ..., -0.8565, -0.5434,  0.1340],
        [ 6.8577, -6.0851,  6.6466,  ..., -0.8574, -0.5443,  0.1353]],
       grad_fn=<AddmmBackward>)
sigma:  tensor([[6.8777, 6.1028, 6.6659],
        [6.9521, 6.1677, 6.7380],
        [6.8533, 6.0811, 6.6424],
        [6.8424, 6.0706, 6.6323],
        [6.8117, 6.0441, 6.6023],
        [6.7682, 6.0058, 6.5614],
        [6.7716, 6.0089, 6.5636],
        [6.8721, 6.0978, 6.6605],
        [6.9978, 6.2077, 6.7823],
        [6.7746, 6.0116, 6.

tensor([[ 8.2103, -7.3000,  7.9544, -0.0166,  0.0669, -1.5206, -0.2913,  0.8597,
         -1.9318, -1.0502, -0.6268,  0.2157]])
sigma:  tensor([[8.2103, 7.3000, 7.9544]])
torch.Size([1, 3, 3])
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 1        |
|    ep_rew_mean     | -6.01    |
| time/              |          |
|    episodes        | 128      |
|    fps             | 6        |
|    time_elapsed    | 20       |
|    total_timesteps | 128      |
| train/             |          |
|    actor_loss      | 5.45     |
|    critic_loss     | 2.11     |
|    ent_coef        | 0.0992   |
|    ent_coef_loss   | -20.9    |
|    learning_rate   | 0.0003   |
|    n_updates       | 27       |
---------------------------------
tensor([[ 8.3789, -7.4495,  8.1163,  ..., -1.0712, -0.6394,  0.2205],
        [ 8.3693, -7.4413,  8.1075,  ..., -1.0696, -0.6386,  0.2205],
        [ 8.3927, -7.4619,  8.1298,  ..., -1.0727, -0.6404,  0.2210],
        ...,
      

  File "/usr/lib/python3.8/runpy.py", line 194, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/usr/lib/python3.8/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/home/fantasticoven/.local/lib/python3.8/site-packages/ipykernel_launcher.py", line 17, in <module>
    app.launch_new_instance()
  File "/home/fantasticoven/.local/lib/python3.8/site-packages/traitlets/config/application.py", line 846, in launch_instance
    app.start()
  File "/home/fantasticoven/.local/lib/python3.8/site-packages/ipykernel/kernelapp.py", line 707, in start
    self.io_loop.start()
  File "/home/fantasticoven/.local/lib/python3.8/site-packages/tornado/platform/asyncio.py", line 199, in start
    self.asyncio_loop.run_forever()
  File "/usr/lib/python3.8/asyncio/base_events.py", line 570, in run_forever
    self._run_once()
  File "/usr/lib/python3.8/asyncio/base_events.py", line 1859, in _run_once
    handle._run()
  File "/usr/lib/python3.8/asyncio/events.

RuntimeError: Function 'MulBackward0' returned nan values in its 1th output.

In [None]:
%load_ext tensorboard
import tensorflow as tf
import numpy as np
import datetime
import matplotlib.pyplot as plt
from tensorboard.backend.event_processing import event_accumulator

In [None]:
# Open TensorBoard inline, pointed at the ./sac log directory.
# NOTE(review): main() does not pass tensorboard_log to SAC, so nothing in the
# visible cells writes to ./sac — confirm where these logs come from.
%tensorboard --logdir ./sac

# Experiments for batch operations

In [None]:
# Draw one (1, 3) sample from a zero-mean normal with per-component std sigma.
sigma = torch.ones((1, 3))
omiga = torch.normal(mean=torch.zeros((1, 3)), std=sigma)
omiga

In [None]:
def transfer(omiga):
    """Map 3-vector(s) to the matching 3x3 skew-symmetric ("hat") matrix.

    Args:
        omiga: Tensor whose last dimension has size 3. A plain ``(3,)``
            vector yields a ``(3, 3)`` matrix; leading batch dimensions are
            preserved, so ``(N, 3)`` yields ``(N, 3, 3)``.

    Returns:
        Tensor of shape ``omiga.shape[:-1] + (3, 3)`` laid out as
        ``[[0, -w2, w1], [w2, 0, -w0], [-w1, w0, 0]]``, with the same dtype
        and device as ``omiga`` and still connected to its autograd graph.
    """
    # Index the last axis so a single vector and a batch share one code path.
    omiga_0, omiga_1, omiga_2 = omiga[..., 0], omiga[..., 1], omiga[..., 2]
    zero = torch.zeros_like(omiga_0)
    # Assemble with torch.stack instead of torch.tensor([...]): torch.tensor
    # rebuilds the matrix from the element *values*, which detaches the result
    # from autograd and only works when every element is a scalar.
    rows = (
        torch.stack((zero, -omiga_2, omiga_1), dim=-1),
        torch.stack((omiga_2, zero, -omiga_0), dim=-1),
        torch.stack((-omiga_1, omiga_0, zero), dim=-1),
    )
    omiga_hat = torch.stack(rows, dim=-2)
    return omiga_hat

In [None]:
from functorch import vmap
# Wrap transfer with vmap so it is applied per entry along the leading
# dimension of its input (vmap's default in_dims is 0), then try it on omiga.
batch_transfer = vmap(transfer)
batch_transfer(omiga)

In [None]:
from liegroups.torch import SO3
# Round trip: exponentiate two 3-vectors, print the input shape, then take
# the log of the result.
tangent_vecs = torch.Tensor([[1, 2, 3],
                             [0, 0, 0]])
C = SO3.exp(tangent_vecs)
print(tangent_vecs.size())
SO3.log(C)

In [None]:
# Sanity check: the natural log of 1, which is 0.
np.log(1)

# Question to ask: in the original Wahba problem the action has shape (4,), whereas in our case the actions have shape (3,3).