In [5]:
import pandas as pd
from ast import literal_eval
import pytz
from imitation.data.types import Trajectory, TransitionsMinimal
import numpy as np

data = pd.read_csv('data/obs_acs.csv')
data['datetime'] = pd.to_datetime(
    data['timestamp'], unit='ms', utc=True).dt.tz_convert(pytz.timezone('US/Mountain'))
data.set_index(['datetime'], inplace=True)
data.drop(['timestamp'], axis=1, inplace=True)
data[data.columns] = data[data.columns].applymap(literal_eval, na_action='ignore')
dataset = pd.DataFrame()

for col in data.columns:
  col_data = data[col]
  dataset[f'{col}_x'] = col_data.apply(lambda x: x if not x == x else x[0])
  dataset[f'{col}_y'] = col_data.apply(lambda x: x if not x == x else x[1])
  dataset[f'{col}_z'] = col_data.apply(lambda x: x if not x == x else x[2])
  if col.endswith('rot'):
    dataset[f'{col}_w'] = col_data.apply(lambda x: x if not x == x else x[3])

prev_idx = dataset.index[0]
batches = []

for i in dataset[dataset['act_pos_x'].isna()].index:
  if len(batches) == 0:
    mask = (dataset.index >= prev_idx) & (dataset.index <= i)
  else:
    mask = (dataset.index > prev_idx) & (dataset.index <= i)
  prev_idx = i
  obs = dataset[mask].values[:,:7].astype(np.float32)
  acs = dataset[mask].values[:,7:].astype(np.float32)[:-1,:]
  batches.append(Trajectory(obs, acs, None, False))

transition_minimal = TransitionsMinimal(dataset.dropna().values[:,:7], dataset.dropna().values[:,7:], np.zeros(shape=(dataset.dropna().values.shape[0],)))


In [6]:
import gym
from gym import spaces
import numpy as np
from scipy.spatial.transform import Rotation as R

class RPMEnv(gym.Env):
  def __init__(self, rng):
    super().__init__()
    self.observation_space = spaces.Box(-1, 1, shape=(7,), dtype=np.float32)
    self.action_space = spaces.Box(-1, 1, shape=(7,), dtype=np.float32)
    self.rng = rng
    self.counter = 0

  def get_matrix(self, flat: np.array):
    pos = flat[:3]
    rot = flat[3:]
    rot_matrix = R.from_quat(rot).as_matrix()
    res = np.zeros(shape=(4, 4), dtype=np.float32)
    res[:3, :3] = rot_matrix
    res[:3, 3] = pos
    res[3, 3] = 1
    return res

  def destruct_matrix(self, matrix: np.array):
    res = np.zeros(shape=(7,), dtype=np.float32)
    res[:3] = matrix[:3, 3]
    rot = R.from_matrix(matrix[:3, :3]).as_quat();
    res[3:] = rot
    return res

  def reset(self):
    self._state = (self.rng.random(7).astype('f') -0.5) * 2
    self.counter = 0
    self._info = self.get_matrix(self._state)

    return self._state


  def step(self, action):
    act_matrix = self.get_matrix(action)
    self.counter += 1
    try:
      inversed = np.linalg.inv(act_matrix)
      self._info = np.matmul(self._info, inversed)
      self._state = self.destruct_matrix(self._info)
      return self._state, 0.0, self.counter > 10, {}
    except:
      print("singular matrix")
      return self._state, 0.0, True, {}


In [11]:
from stable_baselines3.common.vec_env import SubprocVecEnv, DummyVecEnv
from stable_baselines3.common.monitor import Monitor
rng = np.random.default_rng(0)
try:
  venv.close()
except:
  pass
if False:
  venv = DummyVecEnv([lambda: Monitor(RPMEnv(rng))]*1)
else:
  venv = SubprocVecEnv([lambda: Monitor(RPMEnv(rng))]*1)


In [8]:
from imitation.algorithms import bc
from stable_baselines3 import PPO
from stable_baselines3.ppo import MlpPolicy

learner = PPO(env=venv, policy=MlpPolicy)
bc_trainer = bc.BC(
    observation_space=venv.observation_space,
    action_space=venv.action_space,
    demonstrations=batches[:],
    rng=rng,
    policy=learner.policy,
)


In [None]:
bc_trainer.train(n_epochs=100)


In [None]:
max_batch_len = 0
total_len = 0
for i in batches:
  max_batch_len = max(max_batch_len, len(i))
  total_len += len(i)

In [9]:
from imitation.algorithms.adversarial.gail import GAIL
from imitation.rewards.reward_nets import BasicRewardNet

reward_net = BasicRewardNet(
    venv.observation_space,
    venv.action_space,
)
gail_trainer = GAIL(
    demonstrations=transition_minimal,
    demo_batch_size=32,
    venv=venv,
    gen_algo=learner,
    reward_net=reward_net
)


In [10]:
gail_trainer.train(2048)


round:   0%|          | 0/1 [00:02<?, ?it/s]


EOFError: 

In [None]:
import torch as th


class OnnxablePolicy(th.nn.Module):
  def __init__(self, extractor, action_net, value_net):
    super().__init__()
    self.extractor = extractor
    self.action_net = action_net
    self.value_net = value_net

  def forward(self, observation):
    # NOTE: You may have to process (normalize) observation in the correct
    #       way before using this. See `common.preprocessing.preprocess_obs`
    action_hidden, value_hidden = self.extractor(observation)
    return self.action_net(action_hidden), self.value_net(value_hidden)


onnxable_model = OnnxablePolicy(
    bc_trainer.policy.mlp_extractor, bc_trainer.policy.action_net, bc_trainer.policy.value_net
)

observation_size = bc_trainer.observation_space.shape
dummy_input = th.randn(*observation_size,)
# dummy_input.device = bc_trainer.policy.device
th.onnx.export(
    onnxable_model,
    th.ones(dummy_input.shape, dtype=th.float32, device=bc_trainer.policy.device),
    "test.onnx",
    opset_version=15,
    input_names=["input"]
)


In [None]:
import onnx
import onnxruntime as ort
import numpy as np

onnx_path = "test.onnx"
onnx_model = onnx.load(onnx_path)
onnx.checker.check_model(onnx_model)

ort_sess = ort.InferenceSession(onnx_path)
for idx, i in enumerate(batches[0].obs):
  if idx == len(batches[0].obs) - 1:
    continue
  observations = i.astype(np.float32)
  action, value = ort_sess.run(None, {"input": observations})
  print(np.linalg.norm(action - batches[0].acts[idx]))
  print('-'*100)
