In [None]:
import pandas as pd
from ast import literal_eval
import pytz
from imitation.data.types import Trajectory
import numpy as np


data = pd.read_csv('data/data.csv')
data['datetime'] = pd.to_datetime(
    data['timestamp'], unit='ms', utc=True).dt.tz_convert(pytz.timezone('US/Mountain'))
data.drop(['timestamp', 'scene name', 'Unnamed: 19'], axis=1, inplace=True)
data.set_index(['datetime'], inplace=True)
temp = ['rIndex position', 'rIndex rotation', 'rIndex velocity',
        'rIndex angular velocity', 'lIndex position', 'lIndex rotation',
        'lIndex velocity', 'lIndex angular velocity',
        'gaze origin', 'gaze direction',
        'head movement direction', 'head velocity', 'target velocity',
        'target angular velocity', 'target position', 'target rotation']

data[temp] = data[temp].applymap(literal_eval, na_action='ignore')
dataset = pd.DataFrame()
# dataset['Is Eye Tracking Enabled and Valid'] = data['Is Eye Tracking Enabled and Valid'].resample('0.1S').mean().interpolate('time', limit_direction='both', limit=len(data['Is Eye Tracking Enabled and Valid'].index))
for col in temp:
  col_data = data[col]
  dataset[f'{col}_x'] = col_data.apply(lambda x: x if not x == x else x[0]).resample(
      '0.1S').mean().interpolate('time', limit_direction='both', limit=len(col_data.index))
  dataset[f'{col}_y'] = col_data.apply(lambda x: x if not x == x else x[1]).resample(
      '0.1S').mean().interpolate('time', limit_direction='both', limit=len(col_data.index))
  dataset[f'{col}_z'] = col_data.apply(lambda x: x if not x == x else x[2]).resample(
      '0.1S').mean().interpolate('time', limit_direction='both', limit=len(col_data.index))
  if col.endswith('rotation'):
    dataset[f'{col}_w'] = col_data.apply(lambda x: x if not x == x else x[3]).resample(
        '0.1S').mean().interpolate('time', limit_direction='both', limit=len(col_data.index))


events = pd.read_csv('data/events.csv', usecols=['timestamp', 'event'])
events['datetime'] = pd.to_datetime(
    events['timestamp'], unit='ms', utc=True).dt.tz_convert(pytz.timezone('US/Mountain'))
events.set_index(['datetime'], inplace=True)
collision_events = events[events['event'] == 'Left IndexTip']
target_events = events[events['event'].str.match(r'^target')]
target_events.tail()
del events

target_found = target_events[target_events['event'] == 'target_found']
target_lost = target_events[target_events['event'] == 'target_lost']

final_res = []
obs_concated = None
acs_concated = None
for found, row in target_found.iterrows():
  lost = target_lost[target_lost.index > found].iloc[0].name
  mask = ((dataset.index >= found) & (dataset.index <= lost))
  masked = dataset[mask]
  obs = masked.iloc[:, :-7].values
  obs_concated = obs if obs_concated is None else np.concatenate([obs_concated, obs])
  acs = masked.iloc[:, -7:].values  # - masked.iloc[:-1, -7:].values
  acs_concated = acs if acs_concated is None else np.concatenate([acs_concated, acs])
  # Transitions(obs, acs, None, False, None)
  final_res.append(Trajectory(obs, acs[1:], None, True))
  # final_res.append(TransitionsMinimal(obs, acs, np.zeros(shape=obs.shape[:1])))


In [None]:
from imitation.data.types import TransitionsMinimal

demonstrations = TransitionsMinimal(obs_concated, acs_concated, np.zeros_like(obs_concated))

In [None]:
import gym
from gym import spaces
import numpy as np


class CustomEnv(gym.Env):

  def __init__(self):
    super().__init__()
    self.observation_space = spaces.Box(-1, 1, shape=(44,), dtype=np.float32)
    self.action_space = spaces.Box(-1, 1, shape=(7,), dtype=np.float32)


  def reset(self):
    # We need the following line to seed self.np_random
    # super().reset()

    # Choose the agent's location uniformly at random
    self._state = np.random.rand(44).astype('f')

    return self._state

  # def _get_info(self):
  #   return None
  
  # def _get_obs(self):
  #   return self._state

  # def render(self):
  #   pass

  # def _render_frame(self):
  #   pass

  def step(self, action):
    self._state[-7:] = action
    # An episode is done iff the agent has reached the target

    return self._state, 0.0, False, {}


In [None]:
from stable_baselines3.common.vec_env import SubprocVecEnv
from stable_baselines3.common.monitor import Monitor

venv = SubprocVecEnv([lambda: Monitor(CustomEnv())]*4)
# venv = Monitor(CustomEnv())


In [None]:
from imitation.algorithms import bc
from stable_baselines3 import PPO
from stable_baselines3.ppo import MlpPolicy

rng = np.random.default_rng(0)
learner = PPO(env=venv, policy=MlpPolicy)
bc_trainer = bc.BC(
    observation_space=venv.observation_space,
    action_space=venv.action_space,
    demonstrations=demonstrations,
    rng=rng,
    policy=learner.policy,
    batch_size=1024
)


In [15]:
bc_trainer.train(n_epochs=20000)


15499batch [05:12, 45.28batch/s]

--------------------------------
| batch_size        | 1024     |
| bc/               |          |
|    batch          | 15500    |
|    ent_loss       | 0.0211   |
|    entropy        | -21.1    |
|    epoch          | 15500    |
|    l2_loss        | 0        |
|    l2_norm        | 1.74e+03 |
|    loss           | -21.2    |
|    neglogp        | -21.3    |
|    prob_true_act  | 2.46e+10 |
|    samples_so_far | 15873024 |
--------------------------------


15997batch [05:24, 46.18batch/s]

--------------------------------
| batch_size        | 1024     |
| bc/               |          |
|    batch          | 16000    |
|    ent_loss       | 0.0211   |
|    entropy        | -21.1    |
|    epoch          | 16000    |
|    l2_loss        | 0        |
|    l2_norm        | 1.78e+03 |
|    loss           | -21      |
|    neglogp        | -21      |
|    prob_true_act  | 2.59e+10 |
|    samples_so_far | 16385024 |
--------------------------------


16495batch [05:34, 50.80batch/s]

--------------------------------
| batch_size        | 1024     |
| bc/               |          |
|    batch          | 16500    |
|    ent_loss       | 0.0212   |
|    entropy        | -21.2    |
|    epoch          | 16500    |
|    l2_loss        | 0        |
|    l2_norm        | 1.81e+03 |
|    loss           | -21.1    |
|    neglogp        | -21.1    |
|    prob_true_act  | 2.75e+10 |
|    samples_so_far | 16897024 |
--------------------------------


16999batch [05:46, 46.63batch/s]

--------------------------------
| batch_size        | 1024     |
| bc/               |          |
|    batch          | 17000    |
|    ent_loss       | 0.0212   |
|    entropy        | -21.2    |
|    epoch          | 17000    |
|    l2_loss        | 0        |
|    l2_norm        | 1.85e+03 |
|    loss           | -21.3    |
|    neglogp        | -21.3    |
|    prob_true_act  | 2.95e+10 |
|    samples_so_far | 17409024 |
--------------------------------


17500batch [05:57, 47.12batch/s]

--------------------------------
| batch_size        | 1024     |
| bc/               |          |
|    batch          | 17500    |
|    ent_loss       | 0.0212   |
|    entropy        | -21.2    |
|    epoch          | 17500    |
|    l2_loss        | 0        |
|    l2_norm        | 1.88e+03 |
|    loss           | -21.3    |
|    neglogp        | -21.3    |
|    prob_true_act  | 3.02e+10 |
|    samples_so_far | 17921024 |
--------------------------------


18000batch [06:09, 45.68batch/s]

--------------------------------
| batch_size        | 1024     |
| bc/               |          |
|    batch          | 18000    |
|    ent_loss       | 0.0213   |
|    entropy        | -21.3    |
|    epoch          | 18000    |
|    l2_loss        | 0        |
|    l2_norm        | 1.91e+03 |
|    loss           | -21.3    |
|    neglogp        | -21.3    |
|    prob_true_act  | 3.21e+10 |
|    samples_so_far | 18433024 |
--------------------------------


18500batch [06:21, 44.48batch/s]

--------------------------------
| batch_size        | 1024     |
| bc/               |          |
|    batch          | 18500    |
|    ent_loss       | 0.0214   |
|    entropy        | -21.4    |
|    epoch          | 18500    |
|    l2_loss        | 0        |
|    l2_norm        | 1.94e+03 |
|    loss           | -21.3    |
|    neglogp        | -21.3    |
|    prob_true_act  | 2.83e+10 |
|    samples_so_far | 18945024 |
--------------------------------


18996batch [06:32, 47.24batch/s]

--------------------------------
| batch_size        | 1024     |
| bc/               |          |
|    batch          | 19000    |
|    ent_loss       | 0.0214   |
|    entropy        | -21.4    |
|    epoch          | 19000    |
|    l2_loss        | 0        |
|    l2_norm        | 1.97e+03 |
|    loss           | -21      |
|    neglogp        | -21      |
|    prob_true_act  | 2.56e+10 |
|    samples_so_far | 19457024 |
--------------------------------


19499batch [06:45, 29.78batch/s]

--------------------------------
| batch_size        | 1024     |
| bc/               |          |
|    batch          | 19500    |
|    ent_loss       | 0.0215   |
|    entropy        | -21.5    |
|    epoch          | 19500    |
|    l2_loss        | 0        |
|    l2_norm        | 2.01e+03 |
|    loss           | -21.4    |
|    neglogp        | -21.4    |
|    prob_true_act  | 3.53e+10 |
|    samples_so_far | 19969024 |
--------------------------------


20000batch [07:01, 47.45batch/s]


In [None]:
from imitation.algorithms.adversarial.gail import GAIL

from imitation.rewards.reward_nets import BasicRewardNet

reward_net = BasicRewardNet(
    venv.observation_space,
    venv.action_space,
)
gail_trainer = GAIL(
    demonstrations=final_res,
    demo_batch_size=1024,
    venv=venv,
    gen_algo=learner,
    reward_net=reward_net
)


In [None]:
gail_trainer.train(16384*10)


In [16]:
import torch as th


class OnnxablePolicy(th.nn.Module):
  def __init__(self, extractor, action_net, value_net):
    super().__init__()
    self.extractor = extractor
    self.action_net = action_net
    self.value_net = value_net

  def forward(self, observation):
    # NOTE: You may have to process (normalize) observation in the correct
    #       way before using this. See `common.preprocessing.preprocess_obs`
    action_hidden, value_hidden = self.extractor(observation)
    return self.action_net(action_hidden), self.value_net(value_hidden)


onnxable_model = OnnxablePolicy(
    bc_trainer.policy.mlp_extractor, bc_trainer.policy.action_net, bc_trainer.policy.value_net
)

observation_size = bc_trainer.observation_space.shape
dummy_input = th.randn(*observation_size,)
# dummy_input.device = bc_trainer.policy.device
th.onnx.export(
    onnxable_model,
    th.ones(dummy_input.shape, dtype=th.float32, device=bc_trainer.policy.device),
    "test.onnx",
    opset_version=9,
    input_names=["input"]
)


In [17]:
import onnx
import onnxruntime as ort
import numpy as np

onnx_path = "test.onnx"
onnx_model = onnx.load(onnx_path)
onnx.checker.check_model(onnx_model)

ort_sess = ort.InferenceSession(onnx_path)
for idx, i in enumerate(final_res[1].obs):
  if idx == len(final_res[1].obs) - 1:
    continue
  observations = i.astype(np.float32)
  action_p, _ = learner.policy.predict(observations, deterministic=True)
  action, value = ort_sess.run(None, {"input": observations})
  print(np.linalg.norm(action - final_res[1].acts[idx]))
  print(np.linalg.norm(action_p - final_res[1].acts[idx]))
  print('-'*100)
  # print(action - final_res[0].acts[idx])

0.01591785708543528
0.015917732723449995
----------------------------------------------------------------------------------------------------
0.01354331136069177
0.013543346388593365
----------------------------------------------------------------------------------------------------
0.013888103119320084
0.01388819170061951
----------------------------------------------------------------------------------------------------
0.015233893234105646
0.0152337880161363
----------------------------------------------------------------------------------------------------
0.043331612495682716
0.043331596431267945
----------------------------------------------------------------------------------------------------
0.04245482579786464
0.04245493286528565
----------------------------------------------------------------------------------------------------
0.0310392022230322
0.031039152642724457
----------------------------------------------------------------------------------------------------
0.024468