In [None]:
import pandas as pd
from ast import literal_eval
import pytz
from imitation.data.types import Trajectory

# Load the raw tracking log. `timestamp` is parsed as epoch milliseconds and
# converted to US/Mountain time, which then becomes the index.
data = pd.read_csv('data/data.csv')
data['datetime'] = pd.to_datetime(
    data['timestamp'], unit='ms', utc=True).dt.tz_convert(pytz.timezone('US/Mountain'))
data = data.drop(columns=['timestamp', 'scene name', 'Unnamed: 19']).set_index('datetime')

# Columns whose cells hold a stringified vector (parsed with `literal_eval`).
temp = ['gaze origin', 'target position']

# Parse column by column with `Series.map`; this is equivalent to the
# deprecated `DataFrame.applymap(..., na_action='ignore')` and leaves NaN alone.
for col in temp:
  data[col] = data[col].map(literal_eval, na_action='ignore')

# 100 ms grid. Same period as the old '0.1S' alias, but spelled with a
# frequency string that does not rely on the deprecated upper-case 'S'.
RESAMPLE_PERIOD = '100ms'


def _resample_component(series, component):
  """Extract one vector component and put it on the regular time grid.

  Args:
    series: time-indexed Series whose cells are sequences (or NaN).
    component: integer position of the component to extract.

  Returns:
    Series of per-bin means on the `RESAMPLE_PERIOD` grid, with gaps filled by
    time-weighted interpolation in both directions.
  """
  # `v == v` is False only for NaN, so missing samples pass through untouched.
  scalars = series.apply(lambda v: v[component] if v == v else v)
  return scalars.resample(RESAMPLE_PERIOD).mean().interpolate(
      'time', limit_direction='both', limit=len(series.index))


dataset = pd.DataFrame()
# The 'Is Eye Tracking Enabled and Valid' column is intentionally not carried
# over; it could be resampled with the same mean/interpolate pipeline if needed.
for col in temp:
  # Columns named '*rotation' carry a fourth (w) component.
  axes = 'xyzw' if col.endswith('rotation') else 'xyz'
  for component, axis in enumerate(axes):
    dataset[f'{col}_{axis}'] = _resample_component(data[col], component)


# Load the event log and index it by US/Mountain time, the same way as the
# tracking data, so events and samples can be compared directly.
events = pd.read_csv('data/events.csv', usecols=['timestamp', 'event'])
events['datetime'] = pd.to_datetime(
    events['timestamp'], unit='ms', utc=True).dt.tz_convert(pytz.timezone('US/Mountain'))
events.set_index(['datetime'], inplace=True)

collision_events = events[events['event'] == 'Left IndexTip']
# `na=False` treats rows with a missing event label as non-matching. Without
# it `.str.match` returns NaN for those rows and boolean indexing raises.
target_events = events[events['event'].str.match(r'^target', na=False)]
del events

# Split the target events into the two markers that delimit a visibility window.
target_found = target_events[target_events['event'] == 'target_found']
target_lost = target_events[target_events['event'] == 'target_lost']

# Build one demonstration trajectory per (target_found -> next target_lost) window.
#   obs  = all columns except the last three, with the first three made
#          relative to the last three (gaze origin minus target position)
#   acts = step-to-step change of the last three columns (target position)
final_res = []
skipped_windows = 0
# Sort once so "next lost event" means the chronologically earliest one,
# regardless of the row order in the events file.
lost_times = target_lost.index.sort_values()
for found in target_found.index:
  later_lost = lost_times[lost_times > found]
  if later_lost.empty:
    # No closing event (e.g. the recording ended while the target was
    # visible); previously this raised IndexError.
    skipped_windows += 1
    continue
  lost = later_lost[0]

  window = dataset[(dataset.index >= found) & (dataset.index <= lost)]
  if len(window) < 2:
    # A trajectory needs at least one action, i.e. two consecutive samples.
    skipped_windows += 1
    continue

  # Work on an explicit writable copy; `.values` may hand back a read-only
  # view when pandas copy-on-write is active.
  values = window.to_numpy(dtype=float, copy=True)
  target = values[:, -3:]
  obs = values[:, :-3]  # removing target position
  obs[:, :3] -= target  # gaze relative to the target
  acs = target[1:, :] - target[:-1, :]
  final_res.append(Trajectory(obs, acs, None, True))

if skipped_windows:
  print(f'Skipped {skipped_windows} target window(s) without a usable sample range')


In [None]:
import gym
from gym import spaces
import numpy as np


class CustomEnv(gym.Env):
  """Minimal 3-D environment whose state is a vector that actions subtract from.

  The state is drawn uniformly from [-1, 1) per component on `reset`, and each
  `step` subtracts the action from it. The reward is always 0.0 and the
  episode never terminates on its own.

  NOTE(review): the state is not clipped, so after enough steps observations
  can leave the declared [-1, 1] observation box.
  """

  def __init__(self, rng):
    """Create the environment.

    Args:
      rng: random generator exposing `.random(n)` (e.g. `numpy.random.Generator`),
        used to sample the initial state.
    """
    super().__init__()
    self.observation_space = spaces.Box(-1, 1, shape=(3,), dtype=np.float32)
    self.action_space = spaces.Box(-1, 1, shape=(3,), dtype=np.float32)
    self.rng = rng

  def reset(self):
    """Sample a new random state and return it as the first observation."""
    # Map uniform [0, 1) samples to [-1, 1), keeping float32 precision.
    self._state = (self.rng.random(3).astype('f') - 0.5) * 2

    # Return a copy so callers never hold a reference to the internal buffer,
    # which `step` mutates in place.
    return self._state.copy()

  def step(self, action):
    """Apply `action` by subtracting it from the state.

    Returns:
      Tuple `(observation, reward, done, info)` where reward is always 0.0,
      done is always False and info is an empty dict.
    """
    self._state[:3] -= action

    # No termination condition is implemented: `done` is always False.
    return self._state.copy(), 0.0, False, {}


In [None]:
from stable_baselines3.common.vec_env import SubprocVecEnv
from stable_baselines3.common.monitor import Monitor
SEED = 12345
N_ENVS = 16

# Shared generator, still used by later cells (e.g. the BC trainer).
rng = np.random.default_rng(SEED)


def _make_env(env_seed):
  """Return a factory that builds one monitored `CustomEnv` with its own RNG."""
  def _init():
    return Monitor(CustomEnv(np.random.default_rng(env_seed)))
  return _init


# Each worker gets an independent child seed. Previously one lambda closing
# over the single `rng` was repeated N_ENVS times, so every subprocess started
# from the same generator state and produced identical initial states.
venv = SubprocVecEnv(
    [_make_env(child) for child in np.random.SeedSequence(SEED).spawn(N_ENVS)])


In [None]:
from imitation.algorithms import bc
from stable_baselines3 import PPO
from stable_baselines3.ppo import MlpPolicy

# PPO is instantiated here to provide the policy network; behavioural cloning
# then trains that same policy object on the recorded demonstrations.
learner = PPO(policy=MlpPolicy, env=venv)

obs_space = venv.observation_space
act_space = venv.action_space
bc_trainer = bc.BC(
    demonstrations=final_res,
    policy=learner.policy,
    observation_space=obs_space,
    action_space=act_space,
    rng=rng,
)


In [None]:
# Number of passes over the demonstration set for behavioural cloning.
BC_EPOCHS = 1000
bc_trainer.train(n_epochs=BC_EPOCHS)


In [None]:
from imitation.algorithms.adversarial.gail import GAIL
from imitation.algorithms.adversarial.airl import AIRL

from imitation.rewards.reward_nets import BasicRewardNet

# Reward model over the same observation/action spaces as the environment.
reward_net = BasicRewardNet(
    observation_space=venv.observation_space,
    action_space=venv.action_space,
)

# NOTE: despite the variable name, this trainer is AIRL, not GAIL.
# It trains `learner` as the generator against the demonstrations.
gail_trainer = AIRL(
    venv=venv,
    gen_algo=learner,
    reward_net=reward_net,
    demonstrations=final_res,
    demo_batch_size=1024,
)


In [None]:
# Total generator timesteps, written as rounds x steps-per-round.
# presumably 32768 = 16 envs x 2048 PPO steps per env -- TODO confirm
STEPS_PER_ROUND = 32768
N_ROUNDS = 10
gail_trainer.train(STEPS_PER_ROUND * N_ROUNDS)


In [None]:
import torch as th


class OnnxablePolicy(th.nn.Module):
  """Bundle a policy's extractor and its two heads into one exportable module.

  Args:
    extractor: module mapping an observation to a pair
      `(action_latent, value_latent)`.
    action_net: head applied to the action latent.
    value_net: head applied to the value latent.
  """

  def __init__(self, extractor, action_net, value_net):
    super().__init__()
    self.extractor = extractor
    self.action_net = action_net
    self.value_net = value_net

  def forward(self, observation):
    """Return `(action_output, value_output)` for `observation`.

    The observation is passed to the extractor as-is; any preprocessing or
    normalisation the policy expects must be applied by the caller.
    """
    latent_pi, latent_vf = self.extractor(observation)
    action_out = self.action_net(latent_pi)
    value_out = self.value_net(latent_vf)
    return action_out, value_out


# Wrap the sub-networks of the BC-trained policy so they export as one graph.
trained_policy = bc_trainer.policy
onnxable_model = OnnxablePolicy(
    extractor=trained_policy.mlp_extractor,
    action_net=trained_policy.action_net,
    value_net=trained_policy.value_net,
)

observation_size = bc_trainer.observation_space.shape
dummy_input = th.randn(*observation_size,)
# The example input is an all-ones tensor with exactly the observation shape
# (no batch dimension), placed on the same device as the policy.
example_obs = th.ones(
    dummy_input.shape, dtype=th.float32, device=trained_policy.device)
th.onnx.export(
    onnxable_model,
    example_obs,
    "test.onnx",
    opset_version=9,
    input_names=["input"],
)


In [None]:
import onnx
import onnxruntime as ort
import numpy as np

# Validate the exported graph, then replay one demonstration through ONNX
# Runtime and print how far each predicted action is from the recorded one.
onnx_path = "test.onnx"
onnx_model = onnx.load(onnx_path)
onnx.checker.check_model(onnx_model)

ort_sess = ort.InferenceSession(onnx_path)
demo = final_res[3]
# A trajectory has one more observation than actions, so the final
# observation has no action to compare against and is left out.
for observation, expert_action in zip(demo.obs[:-1], demo.acts):
  observations = observation.astype(np.float32)
  # To compare with the PyTorch policy instead:
  #   learner.policy.predict(observations, deterministic=True)
  action, value = ort_sess.run(None, {"input": observations})
  print(action)
  print(np.linalg.norm(action - expert_action))
  print('-'*100)
