In [1]:
import numpy as np
import torch
from tsGaussian.torch_tsgaussian import TangentSpaceGaussian
from stable_baselines_utils import TangentSpaceGaussian as TSG

In [2]:
tg = TangentSpaceGaussian(None)

# Test torch_tsgaussian sample

In [3]:
R_mu = torch.Tensor([[1, 0, 0]])
sigma = torch.ones(3).reshape((1,3))

In [4]:
tg.rsample(R_mu, sigma)

tensor([[[ 0.0000,  0.4505,  0.7278],
         [-0.4505,  0.0000, -0.5170],
         [-0.7278,  0.5170,  0.0000]]])

# Test torch_tsgaussian normal_term

In [5]:
sigma = torch.ones(3).reshape((1,3))
sigma

tensor([[1., 1., 1.]])

In [6]:
tg.normal_term(sigma)

tensor([15.7496])

# Test torch_tsgaussian log_map

In [3]:
R_1 = torch.eye(3).reshape((1, 3, 3))
R_2 = torch.eye(3).reshape((1, 3, 3))

In [4]:
tg.log_map(R_1, R_2)

torch.Size([1, 3, 3])


tensor([0., 0., 0.])

In [5]:
# Not orthogonal matrices ?
R_1 = torch.randn(3,3)

In [6]:
tg.log_map(R_1, R_2)

IndexError: Dimension out of range (expected to be in range of [-2, 1], but got 2)

# Test torch_tsgaussian log_probs

In [3]:
R_x = torch.eye(3).reshape((1,3,3))
R_mu = torch.zeros(3,3).reshape((1,3,3))
R_x = R_x.repeat(5, 1, 1)
R_mu = R_mu.repeat(5, 1, 1)
sigma = torch.ones(3).reshape((1,3))

In [4]:
tg.log_probs(R_x, R_mu, sigma)

torch.Size([5, 3, 3])
log size:  torch.Size([5, 3])
5
torch.Size([5, 1, 3])
torch.Size([5, 3, 3])
torch.Size([5, 3])


tensor([-2.7568, -2.7568, -2.7568, -2.7568, -2.7568])

In [5]:
np.e ** (-2.7568)

0.06349462641817973

all codes run for torch_tsgaussian now, need to check it's correctness and make it into batch version.

# Test TangentSpaceGaussian actions_from_params

In [None]:
tsg = TSG(None)

In [None]:
print(tsg.distribution)

In [None]:
tsg

In [None]:
tsg.actions_from_params(torch.eye(3), torch.eye(3).reshape((1,3,3)))

# Test TangentSpaceGaussian log_prob_from_params

In [None]:
tsg.log_prob_from_params(torch.eye(3), torch.eye(3).reshape((1,3,3)))

Again, codes can run, but need to check correctness.

# Try to run training

In [None]:
import torch
from absl import app, flags
from stable_baselines3 import SAC, PPO
from envs.wahba import Wahba
from stable_baselines_utils import CustomSACPolicy, \
    CustomCNN

In [None]:
def main(argv):
    env = Wahba()
    device = torch.device('cpu')
    policy_kwargs = dict(
        features_extractor_class = CustomCNN,
        features_extractor_kwargs = dict(features_dim = 256))
    policy_kwargs['n_critics'] = 1
    policy_kwargs['share_features_extractor'] = False
    policy = CustomSACPolicy
    model = SAC(policy, env, verbose = 1, ent_coef = 'auto_0.1',
                policy_kwargs = policy_kwargs, device = device)
    model.learn(total_timesteps = 110, eval_freq = 5, n_eval_episodes = 5)

In [None]:
main(None)

# Experiments for batch operations

In [None]:
sigma = torch.ones(1, 3)
omiga = torch.normal(torch.zeros(1, 3), sigma)
omiga

In [None]:
def transfer(omiga):
    omiga_0, omiga_1, omiga_2 = omiga[0], omiga[1], omiga[2]
    omiga_hat = torch.tensor([[0, -omiga_2, omiga_1],
                                [omiga_2, 0, -omiga_0],
                                [-omiga_1, omiga_0, 0]])
    return omiga_hat

In [None]:
from functorch import vmap
batch_transfer = vmap(transfer)
batch_transfer(omiga)

In [5]:
from liegroups.torch import SO3
C = SO3.exp(torch.Tensor([[1,2,3],
                          [0,0,0]]))
print(torch.Tensor([[1,2,3],
                          [0,0,0]]).size())
SO3.log(C)

torch.Size([2, 3])


tensor([[-0.6793, -1.3585, -2.0378],
        [ 0.0000,  0.0000,  0.0000]])

In [12]:
np.log(1)

0.0