In [1]:
import numpy as np
import dataclasses
import enact

def normalize_angle(theta: np.ndarray) -> np.ndarray:
  """Wrap angle(s) in radians into the half-open interval [-pi, pi).

  Works element-wise on arrays as well as on scalars.
  """
  full_turn = 2 * np.pi
  # Shift so that -pi maps to 0, wrap into [0, 2*pi), then shift back.
  shifted = theta + np.pi
  return shifted % full_turn - np.pi

@enact.register
@dataclasses.dataclass
class Action(enact.Resource):
  """A (possibly batched) agent action.

  The last axis of `array` holds two components: index 0 is the torque and
  index 1 is the thrust. Leading axes, if any, are batch dimensions.
  """

  # A factory gives each instance its own array. A plain ndarray default would
  # be one object shared by every default-constructed Action, and Python 3.11+
  # rejects unhashable dataclass defaults such as ndarrays.
  array: np.ndarray = dataclasses.field(
      default_factory=lambda: np.zeros((1, 2)))

  @property
  def torque(self) -> np.ndarray:
    """Torque component(s), i.e. `array[..., 0]`."""
    return self.array[..., 0]

  @property
  def thrust(self) -> np.ndarray:
    """Thrust component(s), i.e. `array[..., 1]`."""
    return self.array[..., 1]

@enact.register
@dataclasses.dataclass
class State(enact.Resource):
  """Represents a game state.

  The state is stored in `array`, whose last axis has length `ARRAY_SIZE`
  (12); any leading axes are batch dimensions, so one State object may hold an
  arbitrary batch of states. Components are accessed through getters/setters:
  * Agent position
  * Agent orientation in [-pi, pi) (positive is counter-clockwise)
  * Agent velocity
  * Agent angular velocity
  * Goal 1 position and whether the agent has been at goal 1
  * Goal 2 position and whether the agent has been at goal 2
  """

  # Indices and slices into the last axis of `array`.
  POSITION_SLICE = slice(0, 2)
  ROTATION_INDEX = 2
  VELOCITY_SLICE = slice(3, 5)
  ANGULAR_VELOCITY_INDEX = 5
  GOAL1_POSITION_SLICE = slice(6, 8)
  HAS_BEEN_AT_GOAL1_INDEX = 8
  GOAL2_POSITION_SLICE = slice(9, 11)
  HAS_BEEN_AT_GOAL2_INDEX = 11
  # Total dimensionality of the state space array.
  ARRAY_SIZE = 12

  # Game configuration.
  BOARD_SIZE = 25
  REACHED_EPSILON = 1
  TORQUE_FORCE = 0.5
  THRUST_FORCE = 0.5
  MAX_SPEED = 2
  FRICTION_COEFFICIENT = 0.99

  # A factory gives each instance its own zero array. A plain ndarray default
  # would be one object shared by every default-constructed State (so
  # `randomize` on one of them would mutate all of them), and Python 3.11+
  # rejects unhashable dataclass defaults such as ndarrays.
  array: np.ndarray = dataclasses.field(
      default_factory=lambda: np.zeros((1, State.ARRAY_SIZE)))

  def randomize(self):
    """Resample agent position, rotation and both goals for every batch entry.

    Positions are uniform on [0, BOARD_SIZE) per coordinate, rotation is
    uniform on [-pi, pi), and both goal-visit flags are reset to 0. Velocity and
    angular velocity are left unchanged.
    """
    batch_shape = self.array.shape[:-1]
    self.position = np.random.uniform(0, State.BOARD_SIZE, batch_shape + (2,))
    self.rotation = np.random.uniform(-np.pi, np.pi, batch_shape)
    self.goal1_position = np.random.uniform(
      0, State.BOARD_SIZE, batch_shape + (2,))
    self.goal2_position = np.random.uniform(
      0, State.BOARD_SIZE, batch_shape + (2,))
    # Reset the visit flags stored in `array` (previously this assigned an
    # unrelated `success` attribute that was not part of the state array).
    self.has_been_at_goal1 = np.zeros(batch_shape)
    self.has_been_at_goal2 = np.zeros(batch_shape)

  @property
  def position(self) -> np.ndarray:
    """Return the position of the agent."""
    return self.array[..., State.POSITION_SLICE]

  @position.setter
  def position(self, value: np.ndarray):
    """Set the position of the agent."""
    self.array[..., State.POSITION_SLICE] = value

  @property
  def rotation(self) -> np.ndarray:
    """Return the rotation of the agent. Positive is counter-clockwise."""
    return self.array[..., State.ROTATION_INDEX]

  @rotation.setter
  def rotation(self, value: np.ndarray):
    """Set the rotation of the agent (wrapped into [-pi, pi))."""
    self.array[..., State.ROTATION_INDEX] = normalize_angle(value)

  @property
  def velocity(self) -> np.ndarray:
    """Two dimensional velocity vector of the agent."""
    return self.array[..., State.VELOCITY_SLICE]

  @velocity.setter
  def velocity(self, value: np.ndarray):
    """Set the velocity."""
    self.array[..., State.VELOCITY_SLICE] = value

  @property
  def angular_velocity(self) -> np.ndarray:
    """The angular velocity of the agent."""
    return self.array[..., State.ANGULAR_VELOCITY_INDEX]

  @angular_velocity.setter
  def angular_velocity(self, value: np.ndarray):
    """Set the angular velocity of the agent."""
    self.array[..., State.ANGULAR_VELOCITY_INDEX] = value

  @property
  def goal1_position(self) -> np.ndarray:
    """The position of goal 1."""
    return self.array[..., State.GOAL1_POSITION_SLICE]

  @goal1_position.setter
  def goal1_position(self, value: np.ndarray):
    """Set the position of goal 1."""
    self.array[..., State.GOAL1_POSITION_SLICE] = value

  @property
  def has_been_at_goal1(self) -> np.ndarray:
    """Whether the agent has been at goal 1, stored as a float flag."""
    return self.array[..., State.HAS_BEEN_AT_GOAL1_INDEX]

  @has_been_at_goal1.setter
  def has_been_at_goal1(self, value: np.ndarray):
    """Set whether the agent has been at goal 1."""
    self.array[..., State.HAS_BEEN_AT_GOAL1_INDEX] = value

  @property
  def goal2_position(self) -> np.ndarray:
    """The position of goal 2."""
    return self.array[..., State.GOAL2_POSITION_SLICE]

  @goal2_position.setter
  def goal2_position(self, value: np.ndarray):
    """Set the position of goal 2."""
    self.array[..., State.GOAL2_POSITION_SLICE] = value

  @property
  def has_been_at_goal2(self) -> np.ndarray:
    """Whether the agent has been at goal 2, stored as a float flag."""
    return self.array[..., State.HAS_BEEN_AT_GOAL2_INDEX]

  @has_been_at_goal2.setter
  def has_been_at_goal2(self, value: np.ndarray):
    """Set whether the agent has been at goal 2."""
    self.array[..., State.HAS_BEEN_AT_GOAL2_INDEX] = value

  def forward(self, offset: float=0) -> np.ndarray:
    """The unit vector pointing forward relative to the agent.

    `offset` (radians) is added to the agent's rotation before projecting.
    """
    return np.stack([np.cos(self.rotation + offset),
                     np.sin(self.rotation + offset)], axis=-1)

  def right(self, offset: float=0) -> np.ndarray:
    """The unit vector pointing right relative to the agent."""
    # Rotation is counter-clockwise positive, so "right" is a -pi/2 turn.
    return self.forward(offset - np.pi / 2)

  def _within_reach(self, goal_position: np.ndarray) -> np.ndarray:
    """Whether the agent is closer than REACHED_EPSILON to `goal_position`."""
    return np.linalg.norm(
      self.position - goal_position, axis=-1) < State.REACHED_EPSILON

  def at_goal1(self) -> np.ndarray:
    """Whether the agent is currently at goal 1."""
    return self._within_reach(self.goal1_position)

  def at_goal2(self) -> np.ndarray:
    """Whether the agent is currently at goal 2."""
    return self._within_reach(self.goal2_position)

In [2]:
import abc
import asyncio

import enact

class GameAPI(abc.ABC):
  """Interface the policy uses to talk to the game.

  Both hooks are coroutines; implementations decide how observations and
  actions are exchanged.
  """

  @abc.abstractmethod
  async def observations(self):
    """Return the observations the policy should act on next."""

  @abc.abstractmethod
  async def step_game(self, actions):
    """Hand `actions` to the game so that it can advance one step."""

   
class PolicyAPI(abc.ABC):
  """API for the game to interact with the policy.

  `actions` and `step_policy` are coroutines. `initialize` and `end` are plain
  synchronous methods: `Game.run` calls them without `await`, so an `async`
  implementation would produce a coroutine that never runs.
  """

  @abc.abstractmethod
  async def actions(self):
    """Used by the game to read actions from the policy."""

  @abc.abstractmethod
  async def step_policy(self, observations):
    """Step the policy with the indicated observations."""

  @abc.abstractmethod
  def initialize(self, observations):
    """Provide the initial observations before the first game step."""

  @abc.abstractmethod
  def end(self):
    """Indicates the game has ended."""
 
@enact.register
class GameEnded(enact.ExceptionResource):
  """Exception resource intended to signal that the game is over.

  Not raised anywhere in this cell; the policy currently detects the end of
  the game by receiving `None` observations instead.
  """

class JointAPI(GameAPI, PolicyAPI):
  """In-process implementation of both `GameAPI` and `PolicyAPI`.

  Meant for a game and a policy running as tasks on one asyncio event loop.
  Control is handed back and forth through two `asyncio.Event`s:
  * `_observations_event` is set when observations are published (by
    `initialize`, `step_policy` and `end`) and cleared by `step_game`.
  * `_actions_event` is set when actions are published (by `step_game`) and
    cleared by `step_policy`.
  `end` publishes `None` as the observation to tell the policy the game is over.
  """
  def __init__(self):
    # Latest value published by each side; None until first published.
    self._observations = None
    self._actions = None
    self._observations_event = asyncio.Event()
    self._actions_event = asyncio.Event()
  
  async def observations(self):
    """Used by the policy to read observations."""
    # Does not clear the event; the policy's next `step_game` call does that.
    await self._observations_event.wait()
    return self._observations
 
  async def actions(self):
    """Used by the game to read actions."""
    # Does not clear the event; the game's next `step_policy` call does that.
    await self._actions_event.wait()
    return self._actions
 
  async def step_policy(self, observations):
    """Publish new observations once the pending actions have been produced.

    Waits for actions, marks them consumed, then stores `observations` and
    wakes anything waiting in `observations()`.
    """
    await self._actions_event.wait()
    self._actions_event.clear()
    self._observations = observations
    self._observations_event.set()
 
  async def step_game(self, action_arr):
    """Publish new actions once the pending observations have been produced.

    Waits for observations, marks them consumed, then stores `action_arr` and
    wakes anything waiting in `actions()`.
    """
    await self._observations_event.wait()
    self._observations_event.clear()
    self._actions = action_arr
    self._actions_event.set()
 
  def initialize(self, observations):
    """Publish the initial observations (synchronous; called without await)."""
    self._observations = observations
    self._observations_event.set()

  def end(self):
    """Signal the end of the game by publishing `None` as the observation."""
    # NOTE(review): this overwrites an observation published just before (e.g.
    # by the game's final `step_policy`) if the policy has not read it yet.
    self._observations = None
    self._observations_event.set()


class Game:
  """Simulates the game and exchanges states/actions with a policy via an API."""

  def __init__(self, api: PolicyAPI, observation: State):
    """Create a game.

    Args:
      api: Channel used to read actions from, and send states to, the policy.
      observation: Initial game state.
    """
    self._api = api
    self._observation = observation

  async def dynamics(self, actions: Action) -> State:
    """Return the state reached by applying `actions` for one time step.

    The current observation is not modified; a copy is updated and returned.
    """
    dt = 0.1
    next_state = State(np.copy(self._observation.array))
    # Clamp controls: thrust is forward-only, torque is symmetric.
    thrust = np.clip(actions.thrust, 0, 1)
    torque = np.clip(actions.torque, -1, 1)
    # Update velocity by applying acceleration forces along the current heading.
    next_state.velocity += (
        dt * State.THRUST_FORCE *
        self._observation.forward() * np.expand_dims(thrust, axis=-1))
    next_state.angular_velocity += dt * State.TORQUE_FORCE * torque

    # Clamp max speed and apply friction. The lower bound on `speed` avoids a
    # division by zero when the agent is at rest. `np.inf` is used because the
    # `np.infty` alias was removed in NumPy 2.0.
    speed = np.clip(
        np.expand_dims(np.linalg.norm(next_state.velocity, axis=-1), -1),
        1e-10, np.inf)

    at_maxed_out_speed = (next_state.velocity / speed) * State.MAX_SPEED
    next_state.velocity = np.where(
        speed > State.MAX_SPEED, at_maxed_out_speed, next_state.velocity)

    next_state.velocity *= State.FRICTION_COEFFICIENT
    next_state.angular_velocity *= State.FRICTION_COEFFICIENT

    # Update position by applying velocity.
    next_state.position += dt * next_state.velocity
    next_state.rotation += dt * next_state.angular_velocity

    # Latch the goal-visit flags: once reached they stay at 1.
    next_state.has_been_at_goal1 += next_state.at_goal1()
    next_state.has_been_at_goal1 = np.clip(next_state.has_been_at_goal1, 0, 1)

    next_state.has_been_at_goal2 += next_state.at_goal2()
    next_state.has_been_at_goal2 = np.clip(next_state.has_been_at_goal2, 0, 1)

    return next_state

  async def run(self, num_frames: int = 100):
    """Run the game for `num_frames` steps, then notify the policy it ended.

    `end()` is called in a `finally` block so the policy is always notified,
    even if a step raises or the task is cancelled.
    """
    self._api.initialize(self._observation)
    try:
      for _ in range(num_frames):
        actions = await self._api.actions()
        # Dynamics update
        self._observation = await self.dynamics(actions)
        await self._api.step_policy(self._observation)
    finally:
      self._api.end()
    

@enact.typed_invokable(Action, enact.NoneResource)
class TrackedActions(enact.AsyncInvokable):
  """Invokable that forwards an `Action` to the game via `GameAPI.step_game`.

  Wrapping the call as an enact invokable presumably lets each forwarded
  action be recorded as an invocation; it produces no output.
  """
  def __init__(self, api: GameAPI):
    self._api = api

  async def call(self, actions: Action):
    """Send `actions` to the game."""
    game_api = self._api
    await game_api.step_game(actions)

  
@enact.typed_invokable(enact.NoneResource, enact.NoneResource)
class Policy(enact.AsyncInvokable):
  """Demo policy alternating untracked and tracked batches of actions.

  Each round sends 10 `[[1., 1.]]` actions directly through the game API, then
  10 `[[0., 0.]]` actions through the `TrackedActions` invokable. Rounds repeat
  until the game signals its end by publishing `None` observations.
  """

  def __init__(self, game_api: GameAPI):
    self._api = game_api
    # Use the constructor argument, not a notebook-global API object.
    self._tracked_actions = TrackedActions(game_api)

  async def call(self):
    idx = 0
    while True:
      print(f'Loop index: {idx}')
      for _ in range(10):
        # `None` observations mean the game has ended.
        if await self._api.observations() is None:
          return  # TODO: consider raising GameEnded instead.
        await self._api.step_game(Action(np.array([[1., 1.]])))
        print('Stepped game #1')
      print('Finished Loop #1')
      for _ in range(10):
        print('In Loop #2')
        if await self._api.observations() is None:
          return
        print('  Valid observations.')
        await self._tracked_actions(Action(np.array([[0., 0.]])))
        print('Stepped game #2')
      # Keep looping: returning here would leave the game waiting forever for
      # the next action, so only the end-of-game check above exits.
      idx += 1


# Wire the game and the policy together through a single JointAPI.
api = JointAPI()
game = Game(api, State())
policy = Policy(api)

# Top-level `await` below relies on the notebook's already-running event loop;
# `get_running_loop` raises if no loop is running.
loop = asyncio.get_running_loop()
# The game task is created outside the store context and the policy task inside
# it. Tasks copy the current contextvars context when created, so presumably
# only the policy's invocations are recorded in `store` — verify against enact.
g = loop.create_task(game.run())
with enact.InMemoryStore() as store:
  p = loop.create_task(policy.invoke())
  # Wait for both tasks; `invocation` is the result of `policy.invoke()`.
  _, invocation = await asyncio.gather(g, p)
  print(store.get_current())
print(store.get_current())
with store:
  enact.pprint(invocation)

Loop index: 0
Stepped game #1
Stepped game #1
Stepped game #1
Stepped game #1
Stepped game #1
Stepped game #1
Stepped game #1
Stepped game #1
Stepped game #1
Stepped game #1
Finished Loop #1
In Loop #2
  Valid observations.
