In [1]:
import numpy as np
import torch
from tsGaussian.torch_tsgaussian import TangentSpaceGaussian
from stable_baselines_utils import TangentSpaceGaussian as TSG
# from pytorch3d.transforms.so3 import (
#     so3_exp_map,
#     so3_relative_angle,
# )

In [2]:
# Instantiate the torch tangent-space Gaussian. `None` is its only constructor
# argument here; what that argument configures is not visible in this notebook.
tg = TangentSpaceGaussian(None)

# Test liegroups.torch

In [3]:
from liegroups.torch import SO3

In [4]:
# Exponential map applied to a batch of two 3-vectors. The all-zero vector maps
# to the identity rotation (second matrix in the output).
C = SO3.exp(torch.Tensor([[1,2,3],
                         [0,0,0]]))
C

<liegroups.torch.so3.SO3Matrix>
| tensor([[[-0.6949,  0.7135,  0.0893],
|          [-0.1920, -0.3038,  0.9332],
|          [ 0.6930,  0.6313,  0.3481]],
| 
|         [[ 1.0000,  0.0000,  0.0000],
|          [ 0.0000,  1.0000,  0.0000],
|          [ 0.0000,  0.0000,  1.0000]]])

# Test torch_tsgaussian sample

In [5]:
# Mean rotation: the 3x3 identity with a leading batch dimension of 1.
R_mu = torch.eye(3).unsqueeze(0)
# One row of ones, shape (1, 3); passed as `sigma` to the sampler below.
sigma = torch.ones(1, 3)

In [6]:
# Draw one sample around R_mu with spread sigma. The call returns a pair;
# judging by the variable names the first element is a quaternion and the second
# a rotation matrix — TODO confirm against torch_tsgaussian.
# The call also prints `sigma` and a tensor size as a side effect (see output).
R_quat, R_x = tg.rsample(R_mu, sigma)

sigma:  tensor([[1., 1., 1.]])
torch.Size([1, 3, 3])


In [7]:
# Orthogonality check: for a valid rotation matrix R_x^T R_x is the identity
# (up to floating-point error).
R_x.transpose(1, 2) @ R_x

tensor([[[ 1.0000e+00, -8.9407e-08, -2.2352e-08],
         [-8.9407e-08,  1.0000e+00, -3.7253e-08],
         [-2.2352e-08, -3.7253e-08,  1.0000e+00]]])

# Test torch_tsgaussian normal_term

In [8]:
# Fresh (1, 3) tensor of ones used as input to normal_term below.
sigma = torch.ones(1, 3)
sigma

tensor([[1., 1., 1.]])

In [9]:
# For sigma = 1 on all three axes the result is 15.7496, which equals
# (2*pi)**1.5 — consistent with the normalising constant of a 3-D Gaussian.
# NOTE(review): confirm the exact formula in torch_tsgaussian.
tg.normal_term(sigma)

tensor([15.7496])

# Test torch_tsgaussian log_map

In [10]:
# Two identical identity rotations, each with a leading batch dimension of 1.
R_1 = torch.eye(3)[None, :, :]
R_2 = torch.eye(3)[None, :, :]

In [11]:
# Log map between two identical rotations; the zero vector in the output is the
# expected result for a zero relative rotation.
# NOTE(review): the result has shape (3,), so the batch dimension of the
# (1, 3, 3) inputs is not preserved.
tg.log_map(R_1, R_2)

tensor([0., 0., 0.])

# Test torch_tsgaussian log_probs

In [12]:
# Fixtures for the log_probs test: a batch of 5 sample rotations and 5 mean
# rotations. Both must be valid rotation matrices. An all-zeros matrix is not an
# element of SO(3) (it is not orthogonal and its determinant is 0), so it cannot
# serve as the mean; the identity is used for both instead.
R_x = torch.eye(3).repeat(5, 1, 1)
R_mu = torch.eye(3).repeat(5, 1, 1)
# sigma keeps a batch dimension of 1, shape (1, 3).
sigma = torch.ones(3).reshape((1,3))

In [13]:
# tg.log_probs(R_x, R_mu, sigma)

In [14]:
# -2.7568 is approximately -1.5 * ln(2*pi), i.e. the log-density of a
# unit-variance 3-D Gaussian at its mean (presumably the value log_probs
# returned — TODO confirm). Exponentiating gives about 0.0635 = 1 / 15.7496,
# the reciprocal of the normal_term(sigma) value computed earlier.
np.e ** (-2.7568)

0.06349462641817973

All of the torch_tsgaussian code runs now; we still need to check its correctness and turn it into a batched version.

# Test TangentSpaceGaussian actions_from_params

In [15]:
# Instantiate the stable_baselines_utils wrapper (imported as TSG). `None` is
# its sole constructor argument; its meaning is not shown in this notebook.
tsg = TSG(None)

In [16]:
# The wrapper exposes an inner object as `.distribution`; the printed repr shows
# it is a tsGaussian.torch_tsgaussian.TangentSpaceGaussian.
print(tsg.distribution)

<tsGaussian.torch_tsgaussian.TangentSpaceGaussian object at 0x7fb1fc697160>


In [17]:
# Show the wrapper object itself; it is a different class and a different
# object from tsg.distribution (compare the two printed reprs).
tsg

<stable_baselines_utils.TangentSpaceGaussian at 0x7fb1fc697370>

In [18]:
# Sample an action for an identity mean and a sigma of ones. The result is a
# pair: a (1, 4) tensor with unit norm followed by a (1, 3, 3) tensor (see
# output) — presumably a quaternion and the matching rotation matrix; TODO confirm.
tsg.actions_from_params(torch.eye(3).reshape((1,3,3)), torch.ones(3).reshape((1,3)))

sigma:  tensor([[1., 1., 1.]])
torch.Size([1, 3, 3])


(tensor([[ 0.5602, -0.3354,  0.1288,  0.7464]]),
 tensor([[[ 0.7418, -0.5681, -0.3563],
          [-0.1835,  0.3391, -0.9227],
          [ 0.6450,  0.7498,  0.1473]]]))

# Test TangentSpaceGaussian log_prob_from_params

In [19]:
# Shape check: repeat(2, 1, 1) tiles the 3x3 identity into a batch of two -> (2, 3, 3).
torch.eye(3).repeat(2,1,1).size()

torch.Size([2, 3, 3])

In [20]:
# Shape check: repeat(2, 1) tiles the length-3 vector of ones into two rows -> (2, 3).
torch.ones(3).repeat(2,1).size()

torch.Size([2, 3])

In [21]:
# tsg.log_prob_from_params(torch.eye(3).repeat(2,1,1), torch.ones(3))

In [22]:
# Batched inverse: torch.linalg.inv inverts each of the two 4x4 matrices in the
# batch independently. No seed is set, so the printed values change on every run.
x = torch.randn(2, 4, 4)
y = torch.linalg.inv(x)
y

tensor([[[ 1.0942,  0.3810, -1.5905,  0.6362],
         [-1.7363, -0.7394,  0.2128, -0.2090],
         [ 2.6925,  2.3252, -1.4282,  0.6790],
         [ 5.7760,  4.4176, -1.8317,  1.6719]],

        [[ 0.2663,  0.2239, -1.1285, -0.3589],
         [ 0.2848, -0.1602, -0.4384, -1.0846],
         [-0.0658,  0.4471,  0.0916, -0.2542],
         [ 0.5830,  0.1490, -0.7578, -0.3658]]])

Again, the code runs, but its correctness still needs to be checked.

# Try to run training

In [23]:
import torch
from absl import app, flags
from stable_baselines3 import SAC, PPO
from envs.wahba import Wahba
from stable_baselines_utils import CustomSACPolicy, \
    CustomCNN

In [24]:
def main(argv):
    """Build a SAC agent on the Wahba environment and train it for 500 steps.

    The agent runs on the CPU, uses ``CustomSACPolicy`` with a ``CustomCNN``
    feature extractor (256 output features), a single critic, and separate
    feature extractors for actor and critic. The entropy coefficient is fixed
    at 0.0.

    Args:
        argv: Unused. Presumably kept so the function can be handed to
            ``absl.app.run`` — TODO confirm.
    """
    policy_kwargs = {
        'features_extractor_class': CustomCNN,
        'features_extractor_kwargs': {'features_dim': 256},
        'n_critics': 1,
        'share_features_extractor': False,
    }
    model = SAC(
        CustomSACPolicy,
        Wahba(),
        verbose=1,
        ent_coef=0.0,
        policy_kwargs=policy_kwargs,
        device=torch.device('cpu'),
    )
    model.learn(total_timesteps=500, eval_freq=100, n_eval_episodes=100)

In [32]:
from torch import autograd
# detect_anomaly makes backward passes raise on NaN gradients and reports the
# forward operation responsible. It slows execution, so it is a debugging aid
# rather than something to leave on for real training runs.
with autograd.detect_anomaly():
    main(None)

Using cpu device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 1        |
|    ep_rew_mean     | -6.18    |
| time/              |          |
|    episodes        | 4        |
|    fps             | 428      |
|    time_elapsed    | 0        |
|    total_timesteps | 4        |
---------------------------------
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 1        |
|    ep_rew_mean     | -6.36    |
| time/              |          |
|    episodes        | 8        |
|    fps             | 484      |
|    time_elapsed    | 0        |
|    total_timesteps | 8        |
---------------------------------
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 1        |
|    ep_rew_mean     | -5.67    |
| time/              |          |
|    episodes        | 12       |
|    fps             |

---------------------------------
| rollout/           |          |
|    ep_len_mean     | 1        |
|    ep_rew_mean     | -5.88    |
| time/              |          |
|    episodes        | 100      |
|    fps             | 805      |
|    time_elapsed    | 0        |
|    total_timesteps | 100      |
---------------------------------


  with autograd.detect_anomaly():


tensor([[-0.0461, -0.0411,  0.0463,  0.0229,  0.0401, -0.0423,  0.0588, -0.0149,
         -0.0440, -0.0059, -0.0116,  0.0519]])
sigma:  tensor([[0.0461, 0.0411, 0.0463]])
torch.Size([1, 3, 3])
tensor([[-0.0462, -0.0409,  0.0467,  ..., -0.0061, -0.0115,  0.0517],
        [-0.0461, -0.0406,  0.0467,  ..., -0.0058, -0.0116,  0.0519],
        [-0.0456, -0.0410,  0.0467,  ..., -0.0058, -0.0115,  0.0520],
        ...,
        [-0.0451, -0.0410,  0.0465,  ..., -0.0053, -0.0115,  0.0524],
        [-0.0455, -0.0410,  0.0463,  ..., -0.0060, -0.0116,  0.0520],
        [-0.0458, -0.0412,  0.0463,  ..., -0.0063, -0.0117,  0.0517]],
       grad_fn=<AddmmBackward0>)
sigma:  tensor([[0.0462, 0.0409, 0.0467],
        [0.0461, 0.0406, 0.0467],
        [0.0456, 0.0410, 0.0467],
        [0.0462, 0.0407, 0.0466],
        [0.0461, 0.0412, 0.0462],
        [0.0458, 0.0412, 0.0463],
        [0.0458, 0.0413, 0.0464],
        [0.0458, 0.0408, 0.0465],
        [0.0452, 0.0409, 0.0465],
        [0.0458, 0.0412, 0

tensor([[-0.0462, -0.0409,  0.0467,  ..., -0.0061, -0.0115,  0.0517],
        [-0.0461, -0.0406,  0.0467,  ..., -0.0058, -0.0116,  0.0519],
        [-0.0456, -0.0410,  0.0467,  ..., -0.0058, -0.0115,  0.0520],
        ...,
        [-0.0451, -0.0410,  0.0465,  ..., -0.0053, -0.0115,  0.0524],
        [-0.0455, -0.0410,  0.0463,  ..., -0.0060, -0.0116,  0.0520],
        [-0.0458, -0.0412,  0.0463,  ..., -0.0063, -0.0117,  0.0517]])
sigma:  tensor([[0.0462, 0.0409, 0.0467],
        [0.0461, 0.0406, 0.0467],
        [0.0456, 0.0410, 0.0467],
        [0.0462, 0.0407, 0.0466],
        [0.0461, 0.0412, 0.0462],
        [0.0458, 0.0412, 0.0463],
        [0.0458, 0.0413, 0.0464],
        [0.0458, 0.0408, 0.0465],
        [0.0452, 0.0409, 0.0465],
        [0.0458, 0.0412, 0.0466],
        [0.0461, 0.0408, 0.0464],
        [0.0461, 0.0409, 0.0464],
        [0.0460, 0.0410, 0.0464],
        [0.0456, 0.0409, 0.0464],
        [0.0460, 0.0411, 0.0462],
        [0.0461, 0.0411, 0.0463],
        [0.045

tensor([[-0.0462, -0.0409,  0.0467,  0.0229,  0.0401, -0.0420,  0.0587, -0.0150,
         -0.0440, -0.0059, -0.0115,  0.0520]])
sigma:  tensor([[0.0462, 0.0409, 0.0467]])
torch.Size([1, 3, 3])
tensor([[-0.0461, -0.0409,  0.0464,  ..., -0.0058, -0.0117,  0.0519],
        [-0.0456, -0.0406,  0.0465,  ..., -0.0054, -0.0117,  0.0519],
        [-0.0462, -0.0407,  0.0468,  ..., -0.0060, -0.0115,  0.0518],
        ...,
        [-0.0456, -0.0413,  0.0463,  ..., -0.0061, -0.0116,  0.0518],
        [-0.0462, -0.0409,  0.0467,  ..., -0.0061, -0.0115,  0.0517],
        [-0.0457, -0.0413,  0.0466,  ..., -0.0057, -0.0114,  0.0523]],
       grad_fn=<AddmmBackward0>)
sigma:  tensor([[0.0461, 0.0409, 0.0464],
        [0.0456, 0.0406, 0.0465],
        [0.0462, 0.0407, 0.0468],
        [0.0460, 0.0409, 0.0465],
        [0.0461, 0.0409, 0.0464],
        [0.0461, 0.0408, 0.0464],
        [0.0455, 0.0410, 0.0466],
        [0.0457, 0.0407, 0.0465],
        [0.0461, 0.0409, 0.0465],
        [0.0461, 0.0408, 0

tensor([[-0.0461, -0.0409,  0.0464,  ..., -0.0058, -0.0117,  0.0519],
        [-0.0456, -0.0406,  0.0465,  ..., -0.0054, -0.0117,  0.0519],
        [-0.0462, -0.0407,  0.0468,  ..., -0.0060, -0.0115,  0.0518],
        ...,
        [-0.0456, -0.0413,  0.0463,  ..., -0.0061, -0.0116,  0.0518],
        [-0.0462, -0.0409,  0.0467,  ..., -0.0061, -0.0115,  0.0517],
        [-0.0457, -0.0413,  0.0466,  ..., -0.0057, -0.0114,  0.0523]])
sigma:  tensor([[0.0461, 0.0409, 0.0464],
        [0.0456, 0.0406, 0.0465],
        [0.0462, 0.0407, 0.0468],
        [0.0460, 0.0409, 0.0465],
        [0.0461, 0.0409, 0.0464],
        [0.0461, 0.0408, 0.0464],
        [0.0455, 0.0410, 0.0466],
        [0.0457, 0.0407, 0.0465],
        [0.0461, 0.0409, 0.0465],
        [0.0461, 0.0408, 0.0464],
        [0.0454, 0.0412, 0.0467],
        [0.0461, 0.0408, 0.0464],
        [0.0458, 0.0408, 0.0465],
        [0.0456, 0.0406, 0.0470],
        [0.0454, 0.0412, 0.0469],
        [0.0461, 0.0406, 0.0467],
        [0.045

KeyboardInterrupt: 

In [None]:
%load_ext tensorboard
import tensorflow as tf
import numpy as np
import datetime
import matplotlib.pyplot as plt
from tensorboard.backend.event_processing import event_accumulator

In [None]:
%tensorboard --logdir ./sac

# Experiments for batch operations

In [None]:
# Draw one zero-mean Gaussian 3-vector whose per-component standard deviation
# is given by sigma (all ones here).
sigma = torch.ones((1, 3))
omiga = torch.normal(mean=torch.zeros_like(sigma), std=sigma)
omiga

In [None]:
def transfer(omiga):
    """Return the skew-symmetric ("hat") matrix of a 3-vector.

    For ``omiga = (w0, w1, w2)`` the result is::

        [[  0, -w2,  w1],
         [ w2,   0, -w0],
         [-w1,  w0,   0]]

    The matrix is assembled with ``torch.stack`` rather than ``torch.tensor``.
    ``torch.tensor`` copies the element values into a brand-new leaf tensor,
    which drops the autograd graph and does not work on the batched tensors
    that ``vmap`` passes in.

    Args:
        omiga: Tensor whose last dimension holds the three components. Any
            leading dimensions are treated as batch dimensions.

    Returns:
        Tensor of shape ``omiga.shape[:-1] + (3, 3)`` with the dtype and
        device of ``omiga``.
    """
    omiga_0, omiga_1, omiga_2 = omiga[..., 0], omiga[..., 1], omiga[..., 2]
    # Zero entries share shape, dtype and device with the components so that
    # they can be stacked alongside them.
    zero = torch.zeros_like(omiga_0)
    omiga_hat = torch.stack((
        torch.stack((zero, -omiga_2, omiga_1), dim=-1),
        torch.stack((omiga_2, zero, -omiga_0), dim=-1),
        torch.stack((-omiga_1, omiga_0, zero), dim=-1),
    ), dim=-2)
    return omiga_hat

In [None]:
from functorch import vmap
# vmap maps `transfer` over dim 0 of its input by default, so with omiga of
# shape (1, 3) each call of `transfer` receives a single (3,) vector.
batch_transfer = vmap(transfer)
batch_transfer(omiga)

In [None]:
from liegroups.torch import SO3
# Round trip exp -> log on a batch of two rotation vectors of shape (2, 3).
# Note that ||(1, 2, 3)|| = sqrt(14) ~ 3.742 > pi. Assuming SO3.log returns the
# principal value (angle in [0, pi]), it cannot give back (1, 2, 3) itself;
# expect the equivalent rotation vector of angle 2*pi - sqrt(14) ~ 2.541 about
# the opposite axis. The zero vector should round-trip to zero.
C = SO3.exp(torch.Tensor([[1,2,3],
                          [0,0,0]]))
print(torch.Tensor([[1,2,3],
                          [0,0,0]]).size())
SO3.log(C)

In [None]:
# Scalar sanity check: the natural log of 1 is exactly 0.
np.log(1)

# Question to ask: in the original Wahba problem the action has shape (4,), whereas in our case the actions have shape (3, 3).