In [23]:
"""Jittable abstract base class based on Gymnax Environments."""

import jax
import numpy as np
import jax.numpy as jnp
from gymnax.environments import environment, spaces
from typing import Tuple, Optional
import chex
from flax import struct
from jax import vmap
from params import EnvParams
from jax.tree_util import tree_flatten, tree_unflatten

In [24]:
@struct.dataclass
class EnvState:
    damage_state: jnp.array
    observation: jnp.array
    # belief: jnp.array
    base_travel_time: jnp.array
    capacity: jnp.array
    timestep: int

In [25]:
class RoadEnvironment(environment.Environment):

    """

    JAX implementation of the Road Environment.

    """

    def __init__(self):
        super().__init__()

    def _get_next(
        self, key: chex.PRNGKey, dam_state: int, action: int, table: jnp.array
    ) -> int:
        # sample
        next_dam_state = jax.random.choice(key, 4, p=table[action, dam_state])

        return next_dam_state

    def _vmap_get_next(self):
        # get next state or observation for all segments
        return vmap(self._get_next, in_axes=(0, 0, 0, None))

    def _get_maintenance_reward(
        self, dam_state: int, action: int, rewards_table: jnp.array
    ) -> float:
        return rewards_table[action, dam_state]

    def _vmap_get_maintenance_reward(self):
        return vmap(self._get_maintenance_reward, in_axes=(0, 0, None))

    def calculate_bpr_travel_time(volume: int, capacity: int, base_time: float, alpha: float, beta:int):
        return base_time * (1 + alpha * (volume / capacity)**beta)

    def _vmap_calculate_bpr_travel_time(self):
        return vmap(self.calculate_bpr_travel_time, in_axes=(0, 0, 0, None, None))

    def compute_edge_base_travel_time(self, state: EnvState):
        # map segments to edges, gather base travel times and sum over segments
        return self._gather(state.base_travel_time).sum(axis=0)

    def compute_edge_travel_time(self, state: EnvState, edge_volumes: jnp.array, params: EnvParams):

        # get edge base travel times
        edge_base_travel_time = self.compute_edge_base_travel_time(state)

        # get edge travel times
        edge_travel_times = self._vmap_calculate_bpr_travel_time()(
            edge_volumes, state.capacity, edge_base_travel_time, params.traffic_alpha, params.traffic_beta
        )

        return edge_travel_times

    def _get_shortest_paths(self, state, action, params):
        #! Cannot use igraph in JAX if we want to jit this function
        # update edge volumes
        pass

    def _get_total_travel_time(self, state, action, params):

        # 0.1 get edge volumes: Initialize with all-or-nothing assignment

        # 0.2 get edge travel times

        # repeat until convergence

            # 1. Recalculate travel times with current volumes

            # 2. Find the shortest paths using updated travel times
            #    (recalculates edge volumes)

            # 3. Check for convergence by comparing volume changes

        return 0.0

    def step_env(
        self, keys: chex.PRNGKey, state: EnvState, action: jnp.array, params: EnvParams
    ) -> Tuple[chex.Array, list, float, bool, dict]:
        # next state
        next_state = self._vmap_get_next()(
            keys, state.damage_state, action, params.deterioration_table
        )

        # observation
        obs = self._vmap_get_next()(keys, next_state, action, params.observation_table)

        # maintenance reward
        maintenance_reward = self._vmap_get_maintenance_reward()(
            state.damage_state, action, params.rewards_table
        ).sum()

        # TODO: belief update

        base_travel_time = params.btt_table[action, state.damage_state]
        capacity = params.capacity_table[action, state.damage_state]

        # TODO: traffic assignment (returning 0.0 for now)
        total_travel_time = self._get_total_travel_time(state, action, params)
        travel_time_reward = params.travel_time_reward_factor * total_travel_time

        # reward
        reward = maintenance_reward + travel_time_reward

        next_state = EnvState(
            damage_state=next_state,
            observation=obs,
            base_travel_time=base_travel_time,
            capacity=capacity,
            timestep=state.timestep + 1,
        )

        # done
        done = self.is_terminal(next_state, params)

        # info
        info = {
            "total_travel_time": total_travel_time,
            "maintenance_reward": maintenance_reward,
            "travel_time_reward": travel_time_reward,
        }

        return obs, reward, done, info, next_state

    def reset_env(self, key: chex.PRNGKey, params: EnvParams) -> Tuple[chex.Array, EnvState]:
        damage_state = [
            {"0": [0, 1]},
            {"1": [0, 3, 0]},
            {"2": [1, 1, 2]},
            {"3": [3, 1, 2]},
        ]

        # flatten pytree and convert to jnp.array
        damage_state = jnp.array(
            jax.tree_util.tree_leaves(damage_state), dtype=jnp.uint8
        )

        # initial base travel times (using pytree)
        initial_btt = jnp.ones(params.total_num_segments) * params.btt_table[0, 0]
        # initial capacity
        initial_capacity = (
            jnp.ones(params.total_num_segments) * params.capacity_table[0, 0]
        )

        env_state = EnvState(
            damage_state=damage_state,
            observation=damage_state,
            base_travel_time=initial_btt,
            capacity=initial_capacity,
            timestep=0,
        )

        return self.get_obs(env_state), env_state

    def _to_pytree(self, x: jnp.array):
        # example pytree
        py_tree = [
            {"0": np.array([0, 1], dtype=np.uint8)},
            {"1": np.array([0, 3, 0], dtype=np.uint8)}, 
            {"2": np.array([1, 1, 2], dtype=np.uint8)}, 
            {"3": np.array([3, 1, 2], dtype=np.uint8)}, 
        ]

        # flatten pytree
        _, treedef = tree_flatten(py_tree)

        # put x into pytree
        py_tree = tree_unflatten(treedef, x.tolist())

        return py_tree
    
    def _gather(self, x: jnp.array):

        # map segment indices to edge indices
        # and pad gathered values with 0.0
        # to make get equal number of columns
        # TODO: precompute idxs_map

        # map of segment indices to segment ids
        idxs_map = jnp.array([[100, 0, 1],
                            [2, 3, 4],
                            [5, 6, 7],
                            [8, 9, 10],
                            ])

        return jnp.take(x, idxs_map, fill_value=0.0)

    def get_obs(self, state: EnvState) -> chex.Array:
        return state.observation

    def is_terminal(self, state: EnvParams, params: EnvParams) -> bool:
        return state.timestep >= params.max_timesteps

    def action_space(self, params: Optional[EnvParams] = None) -> spaces.Discrete:
        pass

    def state_space(self, params: EnvParams):
        pass

    def observation_space(self, params: Optional[EnvParams] = None) -> spaces.Discrete:
        pass

    def belief_space(self, params: EnvParams):
        pass

    @property
    def name(self) -> str:
        pass

    @property
    def num_actions(self) -> int:
        pass

In [26]:
params = EnvParams()
env = RoadEnvironment()

_action = [{"0": [0, 0]}, {"1": [0, 0, 0]}, {"2": [0, 0, 0]}, {"3": [0, 0, 0]}]
__action = jax.tree_util.tree_leaves(_action)
action = jnp.array(__action, dtype=jnp.uint8)

# TODO: function to convert the above to the following

key = jax.random.PRNGKey(442)
keys = jax.random.split(key, params.total_num_segments)  # keys for all segments

# reset
obs, state = env.reset_env(key, params)

total_rewards = 0.0

for _ in range(params.max_timesteps):
    # step
    obs, reward, done, info, state = env.step_env(keys, state, action, params)

    # update total rewards
    total_rewards += reward

In [27]:
print(state)
print(obs)
print(reward)
print(total_rewards)

EnvState(damage_state=Array([0, 1, 0, 3, 0, 1, 3, 2, 3, 1, 2], dtype=int32), observation=Array([1, 1, 0, 3, 1, 2, 3, 1, 3, 0, 0], dtype=int32), base_travel_time=Array([150., 165., 150., 240., 150., 165., 240., 210., 240., 165., 210.],      dtype=float32), capacity=Array([500., 500., 500., 500., 500., 500., 500., 500., 500., 500., 500.],      dtype=float32), timestep=50)
[1 1 0 3 1 2 3 1 3 0 0]
-493.0
-24371.0


In [28]:
state.base_travel_time

Array([150., 165., 150., 240., 150., 165., 240., 210., 240., 165., 210.],      dtype=float32)

In [29]:
jit_step_env = jax.jit(env.step_env)

%timeit env.step_env(keys, state, action, params)

%timeit jit_step_env(keys, state, action, params)

8.58 ms ± 59.5 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
32 µs ± 470 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
