In [59]:
import pandas as pd
from ast import literal_eval
import pytz
from imitation.data.types import Trajectory
import numpy as np

# --- Load the recorded observation/action log --------------------------------
# The millisecond `timestamp` column becomes a US/Mountain `datetime` index and
# every remaining cell is parsed from its Python-literal string form; missing
# cells are skipped by `na_action='ignore'`.
data = pd.read_csv('data/obs_acs.csv')
local_tz = pytz.timezone('US/Mountain')
stamps_utc = pd.to_datetime(data['timestamp'], unit='ms', utc=True)
data['datetime'] = stamps_utc.dt.tz_convert(local_tz)
data.set_index(['datetime'], inplace=True)
data.drop(['timestamp'], axis=1, inplace=True)
data[data.columns] = data[data.columns].applymap(literal_eval, na_action='ignore')


def _component_getter(position):
  """Return a function that picks element `position` out of a parsed cell.

  Cells that compare unequal to themselves (NaN) are handed back unchanged.
  """
  def getter(cell):
    if cell == cell:
      return cell[position]
    return cell
  return getter


# --- Flatten every vector column into one scalar column per component --------
dataset = pd.DataFrame()
for col in data.columns:
  col_data = data[col]
  # Columns whose name ends in 'rot' carry a fourth (w) component.
  axes = 'xyzw' if col.endswith('rot') else 'xyz'
  for position, axis in enumerate(axes):
    dataset[f'{col}_{axis}'] = col_data.apply(_component_getter(position))

# --- Cut the log into trajectories -------------------------------------------
# A row with a missing `act_pos_x` closes a window: it contributes an
# observation but its action row is dropped.
batches = []
prev_idx = dataset.index[0]
closing_rows = dataset[dataset['act_pos_x'].isna()].index
for i in closing_rows:
  # Only the first window includes its starting row; later windows begin
  # strictly after the previous closing row.
  lower_bound = (dataset.index > prev_idx) if batches else (dataset.index >= prev_idx)
  mask = lower_bound & (dataset.index <= i)
  prev_idx = i
  window = dataset[mask].values
  # First 7 flattened columns are observations, the rest are actions.
  obs = window[:, :7].astype(np.float32)
  acs = window[:, 7:].astype(np.float32)[:-1, :]
  for shape in (obs.shape, acs.shape):
    print(shape)
  batches.append(Trajectory(obs, acs, None, False))


(67, 7)
(66, 7)
(41, 7)
(40, 7)
(52, 7)
(51, 7)


In [24]:
import gym
from gym import spaces
import numpy as np
from scipy.spatial.transform import Rotation as R

class CustomEnv(gym.Env):
  """Pose environment whose state and action are flat 7-vectors.

  A vector is laid out as position (first 3 entries) followed by a
  quaternion (last 4 entries, in the order `Rotation.from_quat` expects).
  Internally the current pose is also kept as a 4x4 homogeneous matrix in
  `self._info`. Every step returns a reward of 0.0.
  """

  def __init__(self, rng):
    """Declare the spaces and keep `rng` for sampling start states.

    Args:
      rng: object whose `random(n)` returns `n` floats in [0, 1), such as a
        `numpy.random.Generator`.
    """
    super().__init__()
    self.observation_space = spaces.Box(-1, 1, shape=(7,), dtype=np.float32)
    self.action_space = spaces.Box(-1, 1, shape=(7,), dtype=np.float32)
    self.rng = rng

  def get_matrix(self, flat: np.ndarray) -> np.ndarray:
    """Convert a flat 7-vector into a 4x4 float32 homogeneous transform.

    Args:
      flat: position in `flat[:3]`, quaternion in `flat[3:]`.

    Returns:
      Matrix with the rotation in the upper-left 3x3 block, the position in
      the last column and 1 in the bottom-right corner.
    """
    pos = flat[:3]
    rot = flat[3:]
    rot_matrix = R.from_quat(rot).as_matrix()
    res = np.zeros(shape=(4, 4), dtype=np.float32)
    res[:3, :3] = rot_matrix
    res[:3, 3] = pos
    res[3, 3] = 1
    return res

  def destruct_matrix(self, matrix: np.ndarray) -> np.ndarray:
    """Convert a 4x4 homogeneous transform back into a flat float32 7-vector.

    Args:
      matrix: transform whose last column holds the position and whose
        upper-left 3x3 block holds the rotation.

    Returns:
      Position followed by the quaternion produced by `Rotation.as_quat`.
    """
    res = np.zeros(shape=(7,), dtype=np.float32)
    res[:3] = matrix[:3, 3]
    res[3:] = R.from_matrix(matrix[:3, :3]).as_quat()
    return res

  def reset(self):
    """Sample a new start state uniformly from [-1, 1) and return it.

    NOTE(review): the returned vector carries the raw sampled quaternion,
    while `_info` is built through `Rotation.from_quat`; if that call
    normalises its input (as the SciPy docs describe), the two only agree up
    to quaternion scale until the first `step` -- confirm this is intended.
    """
    self._state = (self.rng.random(7).astype('f') -0.5) * 2

    self._info = self.get_matrix(self._state)

    return self._state

  def step(self, action):
    """Right-multiply the current pose by the inverse of the action's transform.

    Args:
      action: flat 7-vector interpreted by `get_matrix`.

    Returns:
      `(state, 0.0, done, {})`. `done` is True only when the action matrix
      cannot be inverted, in which case the state is left unchanged.
    """
    act_matrix = self.get_matrix(action)
    try:
      inversed = np.linalg.inv(act_matrix)
    except np.linalg.LinAlgError:
      # Catch only the failure the message describes; anything else (including
      # KeyboardInterrupt) must propagate instead of being reported as singular.
      print("singular matrix")
      return self._state, 0.0, True, {}
    self._info = np.matmul(self._info, inversed)
    self._state = self.destruct_matrix(self._info)
    return self._state, 0.0, False, {}


In [33]:
from stable_baselines3.common.vec_env import SubprocVecEnv
from stable_baselines3.common.monitor import Monitor
# Seeded generator; the environment factory below closes over it.
rng = np.random.default_rng(12345)


def _make_monitored_env():
  """Build one `CustomEnv` wrapped in a `Monitor`."""
  return Monitor(CustomEnv(rng))


# A single-worker vectorised environment built from that factory.
venv = SubprocVecEnv([_make_monitored_env])


In [67]:
from imitation.algorithms import bc
from stable_baselines3 import PPO
from stable_baselines3.ppo import MlpPolicy

# Build a PPO learner on `venv`; its policy network is the one handed to
# behaviour cloning below.
learner = PPO(policy=MlpPolicy, env=venv)

# Behaviour-cloning trainer over a shallow copy of the recorded trajectories.
bc_trainer = bc.BC(
  demonstrations=list(batches),
  policy=learner.policy,
  observation_space=venv.observation_space,
  action_space=venv.action_space,
  rng=rng,
)


In [69]:
# Number of behaviour-cloning passes over the demonstrations.
# NOTE(review): in the log printed below, `neglogp` turns negative and
# `entropy` keeps falling as training proceeds -- check whether this many
# epochs is actually wanted.
N_EPOCHS = 1000
bc_trainer.train(n_epochs=N_EPOCHS)


0batch [00:00, ?batch/s]

--------------------------------
| batch_size        | 32       |
| bc/               |          |
|    batch          | 0        |
|    ent_loss       | -0.00711 |
|    entropy        | 7.11     |
|    epoch          | 0        |
|    l2_loss        | 0        |
|    l2_norm        | 147      |
|    loss           | 3.6      |
|    neglogp        | 3.61     |
|    prob_true_act  | 0.0271   |
|    samples_so_far | 32       |
--------------------------------


494batch [00:01, 256.34batch/s]

--------------------------------
| batch_size        | 32       |
| bc/               |          |
|    batch          | 500      |
|    ent_loss       | -0.0036  |
|    entropy        | 3.6      |
|    epoch          | 125      |
|    l2_loss        | 0        |
|    l2_norm        | 150      |
|    loss           | 0.106    |
|    neglogp        | 0.109    |
|    prob_true_act  | 0.897    |
|    samples_so_far | 16032    |
--------------------------------


978batch [00:03, 287.72batch/s]

---------------------------------
| batch_size        | 32        |
| bc/               |           |
|    batch          | 1000      |
|    ent_loss       | -0.000106 |
|    entropy        | 0.106     |
|    epoch          | 250       |
|    l2_loss        | 0         |
|    l2_norm        | 158       |
|    loss           | -3.38     |
|    neglogp        | -3.38     |
|    prob_true_act  | 29.3      |
|    samples_so_far | 32032     |
---------------------------------


1481batch [00:05, 296.99batch/s]

--------------------------------
| batch_size        | 32       |
| bc/               |          |
|    batch          | 1500     |
|    ent_loss       | 0.00339  |
|    entropy        | -3.39    |
|    epoch          | 375      |
|    l2_loss        | 0        |
|    l2_norm        | 165      |
|    loss           | -6.86    |
|    neglogp        | -6.86    |
|    prob_true_act  | 958      |
|    samples_so_far | 48032    |
--------------------------------


1985batch [00:07, 235.86batch/s]

--------------------------------
| batch_size        | 32       |
| bc/               |          |
|    batch          | 2000     |
|    ent_loss       | 0.00686  |
|    entropy        | -6.86    |
|    epoch          | 500      |
|    l2_loss        | 0        |
|    l2_norm        | 178      |
|    loss           | -10.2    |
|    neglogp        | -10.2    |
|    prob_true_act  | 2.96e+04 |
|    samples_so_far | 64032    |
--------------------------------


2479batch [00:09, 308.27batch/s]

--------------------------------
| batch_size        | 32       |
| bc/               |          |
|    batch          | 2500     |
|    ent_loss       | 0.0103   |
|    entropy        | -10.3    |
|    epoch          | 625      |
|    l2_loss        | 0        |
|    l2_norm        | 193      |
|    loss           | -13.7    |
|    neglogp        | -13.7    |
|    prob_true_act  | 9.42e+05 |
|    samples_so_far | 80032    |
--------------------------------


2972batch [00:10, 276.20batch/s]

--------------------------------
| batch_size        | 32       |
| bc/               |          |
|    batch          | 3000     |
|    ent_loss       | 0.0137   |
|    entropy        | -13.7    |
|    epoch          | 750      |
|    l2_loss        | 0        |
|    l2_norm        | 206      |
|    loss           | -17      |
|    neglogp        | -17.1    |
|    prob_true_act  | 2.64e+07 |
|    samples_so_far | 96032    |
--------------------------------


3494batch [00:12, 305.85batch/s]

--------------------------------
| batch_size        | 32       |
| bc/               |          |
|    batch          | 3500     |
|    ent_loss       | 0.0169   |
|    entropy        | -16.9    |
|    epoch          | 875      |
|    l2_loss        | 0        |
|    l2_norm        | 220      |
|    loss           | -19.2    |
|    neglogp        | -19.2    |
|    prob_true_act  | 5.65e+08 |
|    samples_so_far | 112032   |
--------------------------------


4000batch [00:14, 276.80batch/s]


In [70]:
import torch as th


class OnnxablePolicy(th.nn.Module):
  """Single module chaining a feature extractor into an action and a value head.

  Args:
    extractor: callable returning an `(action_features, value_features)` pair.
    action_net: head applied to the action features.
    value_net: head applied to the value features.
  """

  def __init__(self, extractor, action_net, value_net):
    super().__init__()
    self.extractor, self.action_net, self.value_net = extractor, action_net, value_net

  def forward(self, observation):
    """Return `(action_net output, value_net output)` for `observation`."""
    # NOTE: the observation is consumed as given. Any required processing
    #       (normalisation) has to happen before the call; see
    #       `common.preprocessing.preprocess_obs`.
    latent_pi, latent_vf = self.extractor(observation)
    action_out = self.action_net(latent_pi)
    value_out = self.value_net(latent_vf)
    return action_out, value_out


# Wrap the trained policy's sub-networks so they can be exported as one module.
trained_policy = bc_trainer.policy
onnxable_model = OnnxablePolicy(
  trained_policy.mlp_extractor,
  trained_policy.action_net,
  trained_policy.value_net,
)

# `dummy_input` only supplies the shape; the tensor actually passed to the
# exporter is all ones, float32, on the policy's device.
observation_size = bc_trainer.observation_space.shape
dummy_input = th.randn(*observation_size,)
example_obs = th.ones(
  dummy_input.shape,
  dtype=th.float32,
  device=bc_trainer.policy.device,
)
th.onnx.export(
  onnxable_model,
  example_obs,
  "test.onnx",
  input_names=["input"],
  opset_version=9,
)


In [30]:
import onnx
import onnxruntime as ort
import numpy as np

# Validate the exported graph, then open it with ONNX Runtime.
onnx_path = "test.onnx"
onnx_model = onnx.load(onnx_path)
onnx.checker.check_model(onnx_model)
ort_sess = ort.InferenceSession(onnx_path)

# Compare the exported model with the first demonstration: print the L2
# distance between its action output and the recorded action for every
# observation except the last one, which is skipped.
first_demo = batches[0]
separator = '-' * 100
for idx, i in enumerate(first_demo.obs[:-1]):
  observations = i.astype(np.float32)
  action, value = ort_sess.run(None, {"input": observations})
  print(np.linalg.norm(action - first_demo.acts[idx]))
  print(separator)


0.03027010770383833
----------------------------------------------------------------------------------------------------
0.012832886637383979
----------------------------------------------------------------------------------------------------
0.02168983446296189
----------------------------------------------------------------------------------------------------
0.02759699900934911
----------------------------------------------------------------------------------------------------
0.02558587882249591
----------------------------------------------------------------------------------------------------
0.025872564846956318
----------------------------------------------------------------------------------------------------
0.017833979950414646
----------------------------------------------------------------------------------------------------
0.011649364508439349
----------------------------------------------------------------------------------------------------
0.014342902146064125
-------