In [1]:
import numpy as np
import torch
from tsGaussian.torch_tsgaussian import TangentSpaceGaussian
from stable_baselines_utils import TangentSpaceGaussian as TSG
# from pytorch3d.transforms.so3 import (
#     so3_exp_map,
#     so3_relative_angle,
# )

In [2]:
tg = TangentSpaceGaussian(None)

# Test torch_tsgaussian sample

In [3]:
R_mu = torch.eye(3).reshape((1,3,3))
sigma = torch.ones(3).reshape((1,3))

In [4]:
R_quat, R_x = tg.rsample(R_mu, sigma)

In [5]:
torch.bmm(torch.transpose(R_x, 1, 2), R_x)

tensor([[[ 1.0000e+00, -1.4901e-08, -1.4901e-08],
         [-1.4901e-08,  1.0000e+00,  5.9605e-08],
         [-1.4901e-08,  5.9605e-08,  1.0000e+00]]])

# Test torch_tsgaussian normal_term

In [6]:
sigma = torch.ones(3).reshape((1,3))
sigma

tensor([[1., 1., 1.]])

In [7]:
tg.normal_term(sigma)

tensor([15.7496])

# Test torch_tsgaussian log_map

In [8]:
R_1 = torch.eye(3).reshape((1, 3, 3))
R_2 = torch.eye(3).reshape((1, 3, 3))

In [9]:
tg.log_map(R_1, R_2)

tensor([[0., 0., 0.]])

# Test torch_tsgaussian log_probs

In [3]:
tg = TangentSpaceGaussian(None)

In [16]:
# Inputs for the log_probs test: an identity sample and an identity mean,
# each carrying a leading batch dimension of 1, plus unit sigma.
R_x = torch.eye(3).unsqueeze(0)
R_mu = torch.eye(3).unsqueeze(0)
# Uncomment to repeat the same inputs as a batch of 5:
# R_x = R_x.repeat(5, 1, 1)
# R_mu = R_mu.repeat(5, 1, 1)
sigma = torch.ones(1, 3)
# Diagonal matrix with sigma on its diagonal, shape (1, 3, 3).
sigma_mat = torch.diag_embed(sigma)

In [17]:
R_x1 = torch.eye(3).reshape((1,3,3))
_, R_x2 = tg.rsample(R_mu, sigma)
print(R_x2.squeeze().T @ R_x2.squeeze())

tensor([[1.0000e+00, 1.4901e-08, 1.1409e-08],
        [1.4901e-08, 1.0000e+00, 8.3819e-08],
        [1.1409e-08, 8.3819e-08, 1.0000e+00]])


In [18]:
lp_1 = tg.log_probs(R_x1, R_mu, sigma)
lp_2 = tg.log_probs(R_x2, R_mu, sigma)

In [19]:
# Sanity check: the log-probability at the mean (lp_1) should never be lower
# than the log-probability of a random sample (lp_2).
for trial in range(1000):
    # Draw a fresh rotation on every trial. Without resampling, the loop would
    # evaluate the exact same (R_x2, R_mu, sigma) triple 1000 times.
    _, R_x2 = tg.rsample(R_mu, sigma)
    lp_2 = tg.log_probs(R_x2, R_mu, sigma)
    if lp_1.item() < lp_2.item():
        print('Wrong!')
        break

tensor([2.7568]) tensor([-0.3911])


In [20]:
print(np.e ** (lp_1), np.e ** (lp_2))

tensor([15.7496]) tensor([0.6763])


In [21]:
log_1 = tg.log_map(R_mu, R_x1)
log_2 = tg.log_map(R_mu, R_x2)

In [10]:
torch.bmm(torch.bmm(log_1.reshape((1, 1, 3)), torch.linalg.inv(sigma_mat)), \
                    log_1.reshape(1, 3, 1))

tensor([[[0.]]])

In [11]:
torch.bmm(torch.bmm(log_2.reshape((1, 1, 3)), torch.linalg.inv(sigma_mat)), \
                    log_2.reshape(1, 3, 1))

tensor([[[2.9278]]])

All code for torch_tsgaussian now runs; its correctness still needs to be checked, and it needs to be converted into a batched version.

# Test TangentSpaceGaussian actions_from_params

In [None]:
tsg = TSG(None)

In [None]:
print(tsg.distribution)

In [None]:
tsg

In [None]:
tsg.actions_from_params(torch.eye(3).reshape((1,3,3)), torch.ones(3).reshape((1,3)))

# Test TangentSpaceGaussian log_prob_from_params

In [None]:
torch.eye(3).repeat(2,1,1).size()

In [None]:
torch.ones(3).repeat(2,1).size()

In [None]:
# tsg.log_prob_from_params(torch.eye(3).repeat(2,1,1), torch.ones(3))

In [None]:
x = torch.randn(2, 4, 4)
y = torch.linalg.inv(x)
y

Again, the code runs, but its correctness still needs to be checked.

# Try to run training

In [1]:
import torch
from absl import app, flags
from stable_baselines3 import SAC, PPO
from envs.wahba import Wahba
from stable_baselines_utils import CustomSACPolicy, \
    CustomCNN

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def main(argv):
    """Train a SAC model with the custom policy on the Wahba environment.

    Creates the environment, assembles the policy keyword arguments
    (CustomCNN feature extractor with features_dim = 256, n_critics = 1,
    share_features_extractor = False) and calls ``learn`` for 500 timesteps
    on the CPU device.

    Args:
        argv: Not used by the body; presumably kept so the function can be
            handed to ``absl.app.run`` -- TODO confirm.
    """
    env = Wahba()
    policy_kwargs = dict(
        features_extractor_class = CustomCNN,
        features_extractor_kwargs = dict(features_dim = 256),
        n_critics = 1,
        share_features_extractor = False)
    model = SAC(CustomSACPolicy, env, verbose = 1, ent_coef = 'auto_0.1',
                policy_kwargs = policy_kwargs, device = torch.device('cpu'))
    model.learn(total_timesteps = 500, eval_freq = 100, n_eval_episodes = 100)
    

In [4]:
from torch import autograd

# Run the whole SAC setup and training loop inside autograd's anomaly
# detection context. `env` and `model` are left as notebook globals because
# the cells further down use them.
with autograd.detect_anomaly():
    env = Wahba()
    # Use the first CUDA device when one is available, otherwise the CPU.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    policy_kwargs = dict(
        features_extractor_class = CustomCNN,
        features_extractor_kwargs = dict(features_dim = 256),
        n_critics = 1,
        share_features_extractor = False)
    policy = CustomSACPolicy
    model = SAC(policy, env, verbose = 1, ent_coef = 'auto_0.1',
                policy_kwargs = policy_kwargs, device = device)
    model.learn(total_timesteps = 50000, eval_freq = 100, n_eval_episodes = 100)

  with autograd.detect_anomaly():


Using cuda:0 device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 1        |
|    ep_rew_mean     | -7.3     |
| time/              |          |
|    episodes        | 4        |
|    fps             | 789      |
|    time_elapsed    | 0        |
|    total_timesteps | 4        |
---------------------------------
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 1        |
|    ep_rew_mean     | -6.89    |
| time/              |          |
|    episodes        | 8        |
|    fps             | 855      |
|    time_elapsed    | 0        |
|    total_timesteps | 8        |
---------------------------------
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 1        |
|    ep_rew_mean     | -6.28    |
| time/              |          |
|    episodes        | 12       |
|    fps           

---------------------------------
| rollout/           |          |
|    ep_len_mean     | 1        |
|    ep_rew_mean     | -5.84    |
| time/              |          |
|    episodes        | 100      |
|    fps             | 774      |
|    time_elapsed    | 0        |
|    total_timesteps | 100      |
---------------------------------
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 1        |
|    ep_rew_mean     | -5.68    |
| time/              |          |
|    episodes        | 104      |
|    fps             | 166      |
|    time_elapsed    | 0        |
|    total_timesteps | 104      |
| train/             |          |
|    actor_loss      | 0.263    |
|    critic_loss     | 18.8     |
|    ent_coef        | 0.0999   |
|    ent_coef_loss   | -5.03    |
|    learning_rate   | 0.0003   |
|    n_updates       | 3        |
---------------------------------
---------------------------------
| rollout/           |          |
|    ep_len_me

---------------------------------
| rollout/           |          |
|    ep_len_mean     | 1        |
|    ep_rew_mean     | -5.77    |
| time/              |          |
|    episodes        | 160      |
|    fps             | 16       |
|    time_elapsed    | 9        |
|    total_timesteps | 160      |
| train/             |          |
|    actor_loss      | 5.15     |
|    critic_loss     | 2.32     |
|    ent_coef        | 0.0979   |
|    ent_coef_loss   | -35.5    |
|    learning_rate   | 0.0003   |
|    n_updates       | 59       |
---------------------------------
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 1        |
|    ep_rew_mean     | -5.7     |
| time/              |          |
|    episodes        | 164      |
|    fps             | 15       |
|    time_elapsed    | 10       |
|    total_timesteps | 164      |
| train/             |          |
|    actor_loss      | 4.78     |
|    critic_loss     | 2.43     |
|    ent_coef 

---------------------------------
| rollout/           |          |
|    ep_len_mean     | 1        |
|    ep_rew_mean     | -5.75    |
| time/              |          |
|    episodes        | 220      |
|    fps             | 11       |
|    time_elapsed    | 19       |
|    total_timesteps | 220      |
| train/             |          |
|    actor_loss      | 4.21     |
|    critic_loss     | 1.97     |
|    ent_coef        | 0.0955   |
|    ent_coef_loss   | -50.9    |
|    learning_rate   | 0.0003   |
|    n_updates       | 119      |
---------------------------------
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 1        |
|    ep_rew_mean     | -5.79    |
| time/              |          |
|    episodes        | 224      |
|    fps             | 11       |
|    time_elapsed    | 20       |
|    total_timesteps | 224      |
| train/             |          |
|    actor_loss      | 4.2      |
|    critic_loss     | 2.35     |
|    ent_coef 

---------------------------------
| rollout/           |          |
|    ep_len_mean     | 1        |
|    ep_rew_mean     | -5.76    |
| time/              |          |
|    episodes        | 280      |
|    fps             | 9        |
|    time_elapsed    | 29       |
|    total_timesteps | 280      |
| train/             |          |
|    actor_loss      | 3.87     |
|    critic_loss     | 2.52     |
|    ent_coef        | 0.0932   |
|    ent_coef_loss   | -61.2    |
|    learning_rate   | 0.0003   |
|    n_updates       | 179      |
---------------------------------
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 1        |
|    ep_rew_mean     | -5.71    |
| time/              |          |
|    episodes        | 284      |
|    fps             | 9        |
|    time_elapsed    | 29       |
|    total_timesteps | 284      |
| train/             |          |
|    actor_loss      | 4.15     |
|    critic_loss     | 2.64     |
|    ent_coef 

---------------------------------
| rollout/           |          |
|    ep_len_mean     | 1        |
|    ep_rew_mean     | -5.66    |
| time/              |          |
|    episodes        | 340      |
|    fps             | 8        |
|    time_elapsed    | 38       |
|    total_timesteps | 340      |
| train/             |          |
|    actor_loss      | 3.53     |
|    critic_loss     | 2.4      |
|    ent_coef        | 0.091    |
|    ent_coef_loss   | -68.7    |
|    learning_rate   | 0.0003   |
|    n_updates       | 239      |
---------------------------------
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 1        |
|    ep_rew_mean     | -5.46    |
| time/              |          |
|    episodes        | 344      |
|    fps             | 8        |
|    time_elapsed    | 39       |
|    total_timesteps | 344      |
| train/             |          |
|    actor_loss      | 3.66     |
|    critic_loss     | 2.75     |
|    ent_coef 

---------------------------------
| rollout/           |          |
|    ep_len_mean     | 1        |
|    ep_rew_mean     | -5.31    |
| time/              |          |
|    episodes        | 400      |
|    fps             | 8        |
|    time_elapsed    | 47       |
|    total_timesteps | 400      |
| train/             |          |
|    actor_loss      | 3.1      |
|    critic_loss     | 2.72     |
|    ent_coef        | 0.0889   |
|    ent_coef_loss   | -74.8    |
|    learning_rate   | 0.0003   |
|    n_updates       | 299      |
---------------------------------
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 1        |
|    ep_rew_mean     | -5.27    |
| time/              |          |
|    episodes        | 404      |
|    fps             | 8        |
|    time_elapsed    | 48       |
|    total_timesteps | 404      |
| train/             |          |
|    actor_loss      | 3.48     |
|    critic_loss     | 2.33     |
|    ent_coef 

---------------------------------
| rollout/           |          |
|    ep_len_mean     | 1        |
|    ep_rew_mean     | -5.43    |
| time/              |          |
|    episodes        | 460      |
|    fps             | 8        |
|    time_elapsed    | 57       |
|    total_timesteps | 460      |
| train/             |          |
|    actor_loss      | 3.19     |
|    critic_loss     | 2.22     |
|    ent_coef        | 0.0869   |
|    ent_coef_loss   | -79.8    |
|    learning_rate   | 0.0003   |
|    n_updates       | 359      |
---------------------------------
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 1        |
|    ep_rew_mean     | -5.52    |
| time/              |          |
|    episodes        | 464      |
|    fps             | 8        |
|    time_elapsed    | 57       |
|    total_timesteps | 464      |
| train/             |          |
|    actor_loss      | 3.11     |
|    critic_loss     | 2.47     |
|    ent_coef 

---------------------------------
| rollout/           |          |
|    ep_len_mean     | 1        |
|    ep_rew_mean     | -5.46    |
| time/              |          |
|    episodes        | 520      |
|    fps             | 7        |
|    time_elapsed    | 66       |
|    total_timesteps | 520      |
| train/             |          |
|    actor_loss      | 2.99     |
|    critic_loss     | 2.38     |
|    ent_coef        | 0.0849   |
|    ent_coef_loss   | -84.4    |
|    learning_rate   | 0.0003   |
|    n_updates       | 419      |
---------------------------------
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 1        |
|    ep_rew_mean     | -5.48    |
| time/              |          |
|    episodes        | 524      |
|    fps             | 7        |
|    time_elapsed    | 67       |
|    total_timesteps | 524      |
| train/             |          |
|    actor_loss      | 3.2      |
|    critic_loss     | 2.21     |
|    ent_coef 

---------------------------------
| rollout/           |          |
|    ep_len_mean     | 1        |
|    ep_rew_mean     | -5.57    |
| time/              |          |
|    episodes        | 580      |
|    fps             | 7        |
|    time_elapsed    | 75       |
|    total_timesteps | 580      |
| train/             |          |
|    actor_loss      | 3.45     |
|    critic_loss     | 2.31     |
|    ent_coef        | 0.083    |
|    ent_coef_loss   | -88.4    |
|    learning_rate   | 0.0003   |
|    n_updates       | 479      |
---------------------------------
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 1        |
|    ep_rew_mean     | -5.64    |
| time/              |          |
|    episodes        | 584      |
|    fps             | 7        |
|    time_elapsed    | 76       |
|    total_timesteps | 584      |
| train/             |          |
|    actor_loss      | 2.59     |
|    critic_loss     | 2.4      |
|    ent_coef 

---------------------------------
| rollout/           |          |
|    ep_len_mean     | 1        |
|    ep_rew_mean     | -5.83    |
| time/              |          |
|    episodes        | 640      |
|    fps             | 7        |
|    time_elapsed    | 85       |
|    total_timesteps | 640      |
| train/             |          |
|    actor_loss      | 2.74     |
|    critic_loss     | 2.21     |
|    ent_coef        | 0.0812   |
|    ent_coef_loss   | -92      |
|    learning_rate   | 0.0003   |
|    n_updates       | 539      |
---------------------------------
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 1        |
|    ep_rew_mean     | -5.82    |
| time/              |          |
|    episodes        | 644      |
|    fps             | 7        |
|    time_elapsed    | 85       |
|    total_timesteps | 644      |
| train/             |          |
|    actor_loss      | 3.27     |
|    critic_loss     | 2.12     |
|    ent_coef 

---------------------------------
| rollout/           |          |
|    ep_len_mean     | 1        |
|    ep_rew_mean     | -5.81    |
| time/              |          |
|    episodes        | 700      |
|    fps             | 7        |
|    time_elapsed    | 94       |
|    total_timesteps | 700      |
| train/             |          |
|    actor_loss      | 3        |
|    critic_loss     | 2.17     |
|    ent_coef        | 0.0794   |
|    ent_coef_loss   | -95.4    |
|    learning_rate   | 0.0003   |
|    n_updates       | 599      |
---------------------------------
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 1        |
|    ep_rew_mean     | -5.86    |
| time/              |          |
|    episodes        | 704      |
|    fps             | 7        |
|    time_elapsed    | 95       |
|    total_timesteps | 704      |
| train/             |          |
|    actor_loss      | 2.75     |
|    critic_loss     | 2.34     |
|    ent_coef 

---------------------------------
| rollout/           |          |
|    ep_len_mean     | 1        |
|    ep_rew_mean     | -6.02    |
| time/              |          |
|    episodes        | 760      |
|    fps             | 7        |
|    time_elapsed    | 103      |
|    total_timesteps | 760      |
| train/             |          |
|    actor_loss      | 2.84     |
|    critic_loss     | 2.06     |
|    ent_coef        | 0.0777   |
|    ent_coef_loss   | -98.8    |
|    learning_rate   | 0.0003   |
|    n_updates       | 659      |
---------------------------------
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 1        |
|    ep_rew_mean     | -6.14    |
| time/              |          |
|    episodes        | 764      |
|    fps             | 7        |
|    time_elapsed    | 104      |
|    total_timesteps | 764      |
| train/             |          |
|    actor_loss      | 3.37     |
|    critic_loss     | 2.38     |
|    ent_coef 

---------------------------------
| rollout/           |          |
|    ep_len_mean     | 1        |
|    ep_rew_mean     | -6.04    |
| time/              |          |
|    episodes        | 820      |
|    fps             | 7        |
|    time_elapsed    | 113      |
|    total_timesteps | 820      |
| train/             |          |
|    actor_loss      | 2.94     |
|    critic_loss     | 2.12     |
|    ent_coef        | 0.076    |
|    ent_coef_loss   | -102     |
|    learning_rate   | 0.0003   |
|    n_updates       | 719      |
---------------------------------
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 1        |
|    ep_rew_mean     | -6.04    |
| time/              |          |
|    episodes        | 824      |
|    fps             | 7        |
|    time_elapsed    | 114      |
|    total_timesteps | 824      |
| train/             |          |
|    actor_loss      | 2.94     |
|    critic_loss     | 1.97     |
|    ent_coef 

---------------------------------
| rollout/           |          |
|    ep_len_mean     | 1        |
|    ep_rew_mean     | -5.85    |
| time/              |          |
|    episodes        | 880      |
|    fps             | 7        |
|    time_elapsed    | 123      |
|    total_timesteps | 880      |
| train/             |          |
|    actor_loss      | 2.82     |
|    critic_loss     | 2.09     |
|    ent_coef        | 0.0744   |
|    ent_coef_loss   | -105     |
|    learning_rate   | 0.0003   |
|    n_updates       | 779      |
---------------------------------
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 1        |
|    ep_rew_mean     | -5.72    |
| time/              |          |
|    episodes        | 884      |
|    fps             | 7        |
|    time_elapsed    | 124      |
|    total_timesteps | 884      |
| train/             |          |
|    actor_loss      | 3.54     |
|    critic_loss     | 2.12     |
|    ent_coef 

KeyboardInterrupt: 

In [None]:
model

In [None]:
random_action = env.action_space.sample()
random_action

In [None]:
obs = env.step(random_action)[0]
obs.shape

In [None]:
act = model.predict(obs)

In [None]:
act

In [None]:
%load_ext tensorboard
import tensorflow as tf
import numpy as np
import datetime
import matplotlib.pyplot as plt
from tensorboard.backend.event_processing import event_accumulator

In [None]:
%tensorboard --logdir ./sac

# Experiments for batch operations

In [None]:
sigma = torch.ones(1, 3)
omiga = torch.normal(torch.zeros(1, 3), sigma)
omiga

In [None]:
def transfer(omiga):
    """Build the 3x3 skew-symmetric ("hat") matrix of a 3-vector.

    For omiga = (w0, w1, w2) the result is

        [[  0, -w2,  w1],
         [ w2,   0, -w0],
         [-w1,  w0,   0]]

    The matrix is assembled with torch.stack instead of torch.tensor([...]).
    torch.tensor copies the element values into a new leaf tensor, which
    drops the autograd graph and cannot be traced by vmap.

    Args:
        omiga: Indexable with at least three elements (tensor or sequence);
            only elements 0, 1 and 2 are read.

    Returns:
        Tensor of shape (3, 3) with the dtype of ``omiga``.
    """
    omiga = torch.as_tensor(omiga)
    omiga_0, omiga_1, omiga_2 = omiga[0], omiga[1], omiga[2]
    # Zero entry with matching dtype/device so torch.stack accepts every row.
    zero = torch.zeros_like(omiga_0)
    omiga_hat = torch.stack([
        torch.stack([zero, -omiga_2, omiga_1]),
        torch.stack([omiga_2, zero, -omiga_0]),
        torch.stack([-omiga_1, omiga_0, zero]),
    ])
    return omiga_hat

In [None]:
from functorch import vmap
batch_transfer = vmap(transfer)
batch_transfer(omiga)

In [None]:
from liegroups.torch import SO3
C = SO3.exp(torch.Tensor([[1,2,3],
                          [0,0,0]]))
print(torch.Tensor([[1,2,3],
                          [0,0,0]]).size())
SO3.log(C)

In [None]:
np.log(1)

# Question to ask: the original Wahba problem uses actions of shape (4,), whereas in our case the actions have shape (3, 3).

In [1]:
import torch
import numpy as np
# from liegroups.torch import SO3
# from scipy.spatial.transform import Rotation as R
from pytorch3d.transforms.so3 import so3_exp_map, so3_log_map
from pytorch3d.transforms import matrix_to_quaternion, quaternion_to_matrix

In [2]:
def normal_term(sigma):
        """ Normalization term in the pdf of the tangent space Gaussian.

            Evaluates sqrt((2 * pi)^3 * det(sigma)).
            sigma is a square matrix or a batch of them (it is passed to torch.det).
            Return one value per matrix in the batch.
        """
        two_pi_cubed = (2 * np.pi) ** 3
        determinant = torch.det(sigma)
        return torch.sqrt(two_pi_cubed * determinant)
    
def log_map(R_1, R_2):
        """ Log map term in pdf of tangent space Gaussian.

            Forms the batched product R_1^T R_2 and passes it to so3_log_map.
            R_1, R_2: batches of matrices shaped (batch, 3, 3), as required by torch.bmm.
            Return the result of so3_log_map (one 3d vector per batch element).
        """
        # Compute the relative rotation once and reuse it; it used to be built
        # twice, with the first copy left unused.
        rot_mat = torch.bmm(torch.transpose(R_1, 1, 2), R_2)
        return so3_log_map(rot_mat, eps = 0.0001)

def log_probs(R_x, R_mu, sigma):
        """ Log-density of R_x under a tangent space Gaussian with mean R_mu.

            Evaluates
                -0.5 * w^T Sigma^{-1} w - log(normal_term(Sigma))
            with w = log_map(R_mu, R_x) and Sigma = diag_embed(sigma).

            R_x, R_mu: batches of matrices shaped (batch, 3, 3).
            sigma: diagonal entries of the covariance matrix (variances, not
                   standard deviations), shaped (batch, 3) or broadcastable to it.
            Return one log-density value per batch element.
        """
        batch_size = R_x.shape[0]
        log_term = log_map(R_mu, R_x).reshape((batch_size, 3))
        sigma_mat = torch.diag_embed(sigma)
        # Sigma is diagonal, so w^T Sigma^{-1} w is just sum_i w_i^2 / sigma_i;
        # no explicit matrix inverse is needed.
        mahalanobis_sq = (log_term ** 2 / sigma).sum(dim = -1)
        # The Gaussian exponent is -1/2 * w^T Sigma^{-1} w. Adding the quadratic
        # term with a positive sign would make the density grow with distance
        # from the mean.
        log_prob = -0.5 * mahalanobis_sq - torch.log(normal_term(sigma_mat))
        return log_prob

In [3]:
R_x = torch.eye(3).reshape((1,3,3))
R_mu = torch.eye(3).reshape((1,3,3))
sigma = torch.Tensor([1e-4, 1e-4, 1e-4]).reshape((1,3))

In [4]:
lp = log_probs(R_x, R_mu, sigma)
print(lp)
torch.exp(lp)

tensor([11.0587])


tensor([63493.6289])

In [5]:
log_map(R_x, R_mu)

tensor([[0., 0., 0.]])

In [6]:
so3_log_map(torch.eye(3).reshape((1,3,3)))

tensor([[0., 0., 0.]])

In [7]:
sigma_mat = torch.diag_embed(sigma)
sigma_mat

tensor([[[1.0000e-04, 0.0000e+00, 0.0000e+00],
         [0.0000e+00, 1.0000e-04, 0.0000e+00],
         [0.0000e+00, 0.0000e+00, 1.0000e-04]]])

In [8]:
normal_term(sigma_mat)

tensor([1.5750e-05])

In [9]:
torch.det(sigma_mat)

tensor([1.0000e-12])

In [10]:
torch.linalg.inv(sigma_mat)

tensor([[[10000.,     0.,    -0.],
         [    0., 10000.,    -0.],
         [    0.,     0., 10000.]]])

In [11]:
log_term = torch.Tensor([[0., 0., 0.]])
batch_size = 1
torch.bmm(torch.bmm(log_term.reshape((batch_size, 1, 3)), torch.linalg.inv(sigma_mat)), \
                    log_term.reshape(batch_size, 3, 1)).reshape((batch_size,))

tensor([0.])

In [12]:
-np.log((1e-2)*np.sqrt(2 * np.pi)) 

3.6862316527834187

In [13]:
np.exp(3.6862)

39.89296529676823

In [14]:
1 / normal_term(sigma_mat)

tensor([63493.6406])