### Tutorial Env: Defining Optimization Environments for Functions

The `function_interface.ipynb` notebook illustrated how to create new test functions for evaluating or training algorithms. In many practical settings, however, we cannot access the function directly during optimization. The optimized function is embedded in an arbitrarily complex environment that we iteratively query with inputs and that returns feedback in response. Optimization algorithms such as Bandits or Evolutionary Algorithms usually rely on pre-processing that hides much of this complexity from the agent. As environments grow more complex, manual pre-processing can become unsustainable, and more sophisticated methods may be needed.

This notebook illustrates a number of Environments that we have implemented using the `dm-env`/`jumanji` Environment API. Each wraps a `Function` instance in additional structure to reflect the complexities of real-world Black-Box Optimization. These environments are intended for benchmarking and testing existing or new Black-Box Optimization algorithms beyond the plain Bandit setting.
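To make the setting concrete, here is a minimal plain-Python sketch of a function hidden behind a `reset`/`step` interface. `ToyBandit`, `ToyTimeStep`, and `negative_sphere` are hypothetical names for illustration only; they are not part of `bbox`, `dm-env`, or `jumanji`, and the real environments used later in this notebook carry more structure.

```python
from dataclasses import dataclass
from typing import Callable, Sequence, Tuple


@dataclass(frozen=True)
class ToyTimeStep:
    """Hypothetical stand-in for the feedback an environment returns."""
    reward: float
    observation: float


class ToyBandit:
    """Hypothetical bandit: the agent only reaches the function through step."""

    def __init__(self, function: Callable[[Sequence[float]], float]):
        self._function = function

    def reset(self) -> Tuple[int, ToyTimeStep]:
        # The state is just a step counter in this sketch.
        return 0, ToyTimeStep(reward=0.0, observation=0.0)

    def step(self, state: int, action: Sequence[float]) -> Tuple[int, ToyTimeStep]:
        value = self._function(action)
        # The reward doubles as the observation, as in a plain bandit setting.
        return state + 1, ToyTimeStep(reward=value, observation=value)


def negative_sphere(x: Sequence[float]) -> float:
    """Illustrative test function: negated sum of squares."""
    return -sum(v * v for v in x)


toy_env = ToyBandit(negative_sphere)
toy_state, toy_step = toy_env.reset()
toy_state, toy_step = toy_env.step(toy_state, (0.5, -0.5))
print(toy_state, toy_step.reward)
```

The state is passed in and returned explicitly rather than stored on the object, which mirrors the `state, step = env.step(state, action)` calling convention used throughout this notebook.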

This optional `env` sub-module of `bbox` can be installed with the command:
```bash
python -m pip install bbox[env]
```
This installs `jumanji` along with the other dependencies.

In [2]:
import jax
from jax import numpy as jnp

import matplotlib.pyplot as plt

from bbox import procgen as px
from bbox import env


key = jax.random.PRNGKey(123)


# Environment instantiation closely follows the core Function API.
rbf = lambda a, b: jnp.exp(-0.5 * jnp.sum(jnp.square(a - b)))
my_env = env.as_bandit(
    base=px.real.GaussianProcessPrior.partial(kernel=rbf, resolution=10),
    wrappers=None,
    dummy_x=jnp.zeros(2)
)

print(repr(my_env))

Bandit(function=PartialGaussianProcessPrior,in_spec=BoundedArray(name=inputs,shape=(2,),dtype=float32,minimum=-1.0,maximum=1.0),out_spec=Array(name=outputs,shape=(),dtype=float32),use_reward_as_observation=True)


As shown below, the Environment API follows that of `dm-env`/ `jumanji` and enables a conventional agent-environment loop for Reinforcement Learning:

In [3]:
from jit_env._core import TimeStep
from jit_env import Action


def run_experiment(_key, _env) -> list[tuple[Action, TimeStep]]:
    state, step = _env.reset(_key)

    steps = [(jnp.zeros(2), step)]
    for _ in range(10):
        _key, key_policy = jax.random.split(_key)

        # Agent-Environment Interaction
        action = jax.random.uniform(key_policy, (2,), minval=-1, maxval=1)
        state, step = _env.step(state, action)

        steps.append((action, step))
        
    return steps
    
    
data = run_experiment(key, my_env)

print('action:', data[-1][0])
print('timestep:', jax.tree_util.tree_map(float, data[-1][1]))

# >> action: [ 0.4381199  -0.10669184]
# >> timestep: TimeStep(step_type=DeviceArray(1, dtype=int8), reward=DeviceArray(-0.17889404, dtype=float32), 
#  discount=DeviceArray(1., dtype=float32), observation=DeviceArray(-0.17889404, dtype=float32), extras={})

action: [ 0.4381199  -0.10669184]
timestep: TimeStep(step_type=1.0, reward=0.10252925753593445, discount=1.0, observation=0.10252925753593445, extras={})


### Environment State

The environment state follows a simple layout. It allows branching of a random key, as required by the `Generic[State]` protocol from `jumanji.State`, and it tracks the base system time count together with the wrapper and base data. New Wrappers should not modify `EnvState` explicitly. Instead, they can implement their own internal state data structure and register it in the `data` attribute of `EnvState` under the class instance's `hash` or `repr`.

In [4]:
print('Example Unwrapped Environment State:', repr(my_env))
state, step = my_env.reset(key)

for k, v in state.items():
    if k == 'data':
        print(k, type(v), end='\n\n')
        for wk, wv in v.items():
            print(wk, jax.tree_util.tree_map(jnp.shape, wv), end='\n\n')
    else:
        print(k, jax.tree_util.tree_map(jnp.shape, v))

Example Unwrapped Environment State: Bandit(function=PartialGaussianProcessPrior,in_spec=BoundedArray(name=inputs,shape=(2,),dtype=float32,minimum=-1.0,maximum=1.0),out_spec=Array(name=outputs,shape=(),dtype=float32),use_reward_as_observation=True)
key (2,)
time ()
data <class 'dict'>

Bandit(function=PartialGaussianProcessPrior) FunctionState(params={'PartialGaussianProcessPrior': {'bases': (10,), 'shift': (10, 2)}}, state=None)



In [5]:
batched_env = env.delay.BatchDelay(my_env, batch_size=2)

print('Example wrapped Environment State:', repr(batched_env))
state, step = batched_env.reset(key)

for k, v in state.items():
    if k == 'data':
        print(k, type(v), end='\n\n')
        for wk, wv in v.items():
            print(wk, jax.tree_util.tree_map(jnp.shape, wv), end='\n\n')
    else:
        print(k, jax.tree_util.tree_map(jnp.shape, v))

AttributeError: module 'bbox.env.delay' has no attribute 'BatchDelay'

The cell above raised an `AttributeError` in this run because `BatchDelay` was not found under `bbox.env.delay`, so the state printout is missing. The cell is intended to show that the BatchDelay wrapper registers a `DelayedState` data object within `state.data` under its `__repr__`. Even though the key is recursively defined, storing all data in a flat dictionary gives subsequent wrappers more direct access to the state data. It also ensures compatibility of unwrapped environments with the wrapped environment state, as they can simply ignore the unused entries in the `state.data` dictionary.
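The flat, `repr`-keyed layout can be sketched in plain Python as follows. `ToyBase` and `ToyWrapper` are hypothetical classes written for this illustration and do not reproduce the `bbox` implementation; the sketch only assumes that each layer stores its own entry in one shared dictionary under its recursive `repr`.

```python
class ToyBase:
    """Hypothetical base environment that registers its own state entry."""

    def __repr__(self) -> str:
        return "Bandit(function=f)"

    def reset(self) -> dict:
        return {repr(self): {"params": [0.0, 0.0]}}


class ToyWrapper:
    """Hypothetical wrapper whose repr recursively contains the wrapped env."""

    def __init__(self, name: str, base):
        self.name = name
        self.base = base

    def __repr__(self) -> str:
        return f"{self.name}({self.base!r})"

    def reset(self) -> dict:
        data = self.base.reset()
        # Register the wrapper state next to, not inside, the base entry.
        data[repr(self)] = {"buffer": []}
        return data


toy_base = ToyBase()
toy_wrapped = ToyWrapper("Outer", ToyWrapper("Inner", toy_base))
toy_data = toy_wrapped.reset()

for name in toy_data:
    print(name)

# The base environment can still find its own entry and ignore the rest.
base_entry = toy_data[repr(toy_base)]
```

Because every layer looks up only its own key, the base environment works unchanged on a state produced by any stack of wrappers.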

### Delay Wrappers

The `Delay` Wrappers work by wrapping any conventional `Environment` as a base and inducing asynchronous environment feedback through an `EventBuffer`. As described above, the `BatchDelay` Wrapper stores the `DelayedState` data structure with the `buffer` attribute. The core philosophy of this Wrapper Type is to deny the agent direct access to the underlying optimized function and to let it communicate only through the `EventBuffer` proxy. At any time `t` the agent can push an action to the wrapped environment, and the dynamics of the environment/wrapper govern when the agent receives feedback by postponing observations within the `EventBuffer`.

Typically, the Wrapper evaluates the given query immediately and writes the result to the `EventBuffer`, before reading this tape at the current `state.time`. Although this resembles popping a Buffer in a FIFO manner, like a Queue, note that empty events are typically not time-skipped. Skipping them is still possible with the `AwaitEvent` Wrapper, as we'll show later on.
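A minimal sketch of this write-then-read pattern, in plain Python, is given below. `ToyEventBuffer` is a hypothetical class, and the circular indexing (slot = time modulo size) is an assumption made for the illustration rather than a description of the actual `EventBuffer`.

```python
from typing import List, Optional


class ToyEventBuffer:
    """Hypothetical fixed-size tape; slot index = time modulo size (assumption)."""

    def __init__(self, size: int):
        self.slots: List[Optional[float]] = [None] * size

    def write(self, time: int, delay: int, value: float) -> None:
        # Schedule the value to become visible `delay` steps from now.
        self.slots[(time + delay) % len(self.slots)] = value

    def read(self, time: int) -> Optional[float]:
        # Read the slot for the current time and clear it; None means no event.
        index = time % len(self.slots)
        value, self.slots[index] = self.slots[index], None
        return value


tape = ToyEventBuffer(size=4)
readouts = []
for t in range(5):
    tape.write(t, delay=2, value=float(t * t))  # evaluate the query immediately
    readouts.append(tape.read(t))               # but read only what is due now

print(readouts)
```

The first two reads return `None` because nothing is due yet: empty events still consume a time-step instead of being skipped. The sketch requires `delay < size`, otherwise a pending slot would be overwritten.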

##### Agent-Environment Dimensionality
To preserve static Array shapes within a JAX-compatible Environment, the agent-environment dimensionality can change depending on which Delay Type is instantiated. Typically, the observations of the original Environment are transformed into batches of observations with shape `(buffer_size, *original_shape)`.
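As a small worked example of this shape rule, the helper below (a hypothetical function, not part of `bbox`) prepends the leading dimension to the original observation shape. For the bandit in this notebook, whose observation is a scalar of shape `()`, a leading dimension of 2 yields a vector of length 2.

```python
from typing import Tuple


def delayed_observation_shape(
    buffer_size: int, original_shape: Tuple[int, ...]
) -> Tuple[int, ...]:
    """Prepend the buffer dimension to the original observation shape."""
    return (buffer_size, *original_shape)


scalar_case = delayed_observation_shape(2, ())     # scalar observation
vector_case = delayed_observation_shape(10, (3,))  # hypothetical 3-vector
print(scalar_case, vector_case)
```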

##### TimeStep.Extras
The environment interface returns both an Environment `State` and a `TimeStep` at every call to `step` or `reset`. The `TimeStep` structure contains the observation/feedback to the agent. In addition, it contains an `extras` dictionary with meta-data from the underlying Environments. Typically this attribute is empty, and the agent does not have direct access to it. However, for logging, or for subsequent transforms in the form of other (`Delay`) Wrappers, this meta-data can be helpful. For example, when updating a resource budget or monitoring time statistics, we might not want this information to be delayed.

The Delay environments include two additional data structures within `TimeStep.extras`:
 1. The Unwrapped Environment's unmodified TimeStep.
 2. The explicit readout from the Delay's `EventBuffer` in the form of a `BufferInfo` structure.

The unmodified TimeStep can be used for monitoring purposes, whereas the `BufferInfo` can be used to identify exactly which delayed action has now received feedback and whether the given observation actually contains an event.

#### Batch Delays

In [None]:
batched_env = env.delay.BatchDelay(my_env, batch_size=2)
timed_batched_env = env.wrappers.TimeLimit(batched_env, max_episode_steps=10)

# The class __repr__ recursively tracks how base gets wrapped.
print("BatchEnv Hierarchy:", repr(batched_env))
print("TimedEnv Hierarchy:", repr(timed_batched_env))
print()

print('Canonical Observation shape:', my_env.observation_spec())
print('Batch Observation shape:', batched_env.observation_spec())

In [None]:
data = run_experiment(key, timed_batched_env)

In [None]:
# Investigate the new datastructures
act_ref, step_ref = data[-2]

# The action is unchanged
print('action:', act_ref)
# >> action: [ 0.4381199  -0.10669184]

# Remove the extras dictionary for readability.
extras = step_ref.__dict__.pop('extras')

# The observation is now a vector of batch-size = 2
print('timestep', step_ref)
# >> timestep TimeStep(step_type=DeviceArray(2, dtype=int8), reward=DeviceArray(0., dtype=float32), 
#  discount=DeviceArray(0., dtype=float32), observation=DeviceArray([0., 0.], dtype=float32), extras=None)

step_ref.__dict__['extras'] = extras

In [None]:
for k, v in extras.items():
    print(k)
    print('>', v)
    print()

#### BufferInfo

The `BufferInfo` object contains meta-data for the read-event within the Delay Wrapper. It tracks: the actual read-data, the Wrapper's internal time (which can deviate from `state.time`), the number of events that are contained within the buffer, and the ID of the system (default None; integer type in scheduled systems).
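The fields listed above can be pictured with the following sketch. `ToyBufferInfo` and its field names are hypothetical and only mirror the description in the text; the actual `env.BufferInfo` may name or type its attributes differently. The key helper follows the `'/'.join((repr(env), Class.__name__))` pattern used in this notebook.

```python
from dataclasses import dataclass, replace
from typing import Any, Optional


@dataclass(frozen=True)
class ToyBufferInfo:
    """Hypothetical mirror of the fields described above; not the bbox class."""
    readout: Any                     # the data read from the buffer
    time: int                        # the wrapper's internal time
    num_events: int                  # number of events currently in the buffer
    system_id: Optional[int] = None  # None unless a scheduled system is used


def extras_key(env_repr: str, info_type: type) -> str:
    """Build the extras key from the environment repr and the class name."""
    return "/".join((env_repr, info_type.__name__))


info = ToyBufferInfo(readout=[0.1, 0.0], time=3, num_events=1)
toy_extras = {extras_key("ToyDelay(ToyEnv)", ToyBufferInfo): info}

looked_up = toy_extras["ToyDelay(ToyEnv)/ToyBufferInfo"]
# dataclasses.replace returns a copy, so the original record is left untouched.
summary = replace(looked_up, readout=None)
print(summary)
```

Copying with `dataclasses.replace` is one way to print the miscellaneous fields without the bulky readout, assuming the record is a dataclass as in this sketch.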

In [None]:
buffer_info = extras['/'.join((repr(batched_env), env.BufferInfo.__name__))]

readout: env.EventBuffer = buffer_info.__dict__.pop('readout')

print("BufferInfo Misc. Information:", buffer_info.__dict__)
print()

buffer_info.__dict__['readout'] = readout
print(readout)

#### Recursive Delays: IID Delayed Batches

The different types of Delay Environment Wrappers can be freely composed recursively, as long as your computer memory allows.

This requires some careful calculations. For example, the SensoryDelay wrapper maintains a buffer of `buffer_size * buffer_size * shapes`, and the BatchDelay maintains a buffer of `batch_size * shapes`. In total, this recursion induces an effective buffer of `buffer_size * buffer_size * batch_size * shapes`, which in this example (`buffer_size=10`, `batch_size=2`) amounts to `200 * shapes`.
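The arithmetic can be checked with a short helper. The function below is a hypothetical illustration that simply multiplies the factors named in the text; it does not inspect any real wrapper.

```python
import math
from typing import Tuple


def effective_buffer_elements(
    buffer_size: int, batch_size: int, shape: Tuple[int, ...]
) -> int:
    """Number of stored elements for SensoryDelay wrapped around BatchDelay."""
    sensory_factor = buffer_size * buffer_size  # SensoryDelay contribution
    batch_factor = batch_size                   # BatchDelay contribution
    # math.prod of an empty shape is 1, which covers scalar observations.
    return sensory_factor * batch_factor * math.prod(shape)


scalar_total = effective_buffer_elements(10, 2, ())     # scalar observations
vector_total = effective_buffer_elements(10, 2, (3,))   # hypothetical 3-vector
print(scalar_total, vector_total)
```

Since the SensoryDelay factor is quadratic in `buffer_size`, doubling the buffer size quadruples the memory of the composed buffer.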

In [None]:
null_delay_env = env.delay.SensoryDelay(
    batched_env, delay_process=lambda *_: 0, buffer_size=10, synchronize=False)

delay_env = env.delay.SensoryDelay(
    batched_env, delay_process=lambda *_: 1, buffer_size=10, synchronize=False)

print("DelayedEnv Hierarchy:", repr(null_delay_env))
print("DelayedEnv Hierarchy:", repr(delay_env))

batched_data = run_experiment(key, batched_env)
null_delay_data = run_experiment(key, null_delay_env)
delay_data = run_experiment(key, delay_env)

In [None]:
print("Reward Readout/ Event Times:")
print("\tBatched \tNo Delay\tDelay")

for i, (a, b, c) in enumerate(zip(batched_data, null_delay_data, delay_data)):
    print(i, f'\t{int(a[1].reward != 0)}\t\t{int(b[1].reward != 0)}\t\t{int(c[1].reward != 0)}'.replace('0', '-'))

#### Synchronized Delays

Some Delay Types accept the keyword argument `synchronize`; this is the default behaviour for `BatchDelay`. If a Delay Wrapper wraps another Delay Wrapper, this allows the outer wrapper to check whether its current TimeStep is delayed or not. If it is, then with `synchronize=True` the outer wrapper neither records the observation to the `EventBuffer` nor pops the latest event from the Buffer to read. This reduces the memory footprint of the buffer because no empty events are written, so smaller buffer sizes suffice.

For example, this is useful when wrapping `BatchDelay(SensoryDelay(env))`: the BatchDelay will not update the batch-buffer until the received observation is non-null. In the cells below we show how this affects the event-timings of the previous cells.

In [None]:
null_delay_env = env.delay.SensoryDelay(
    batched_env, delay_process=lambda *_: 0, buffer_size=10, synchronize=True)

delay_env = env.delay.SensoryDelay(
    batched_env, delay_process=lambda *_: 1, buffer_size=10, synchronize=True)

print("DelayedEnv Hierarchy:", repr(null_delay_env))
print("DelayedEnv Hierarchy:", repr(delay_env))

batched_data = run_experiment(key, batched_env)
null_delay_data = run_experiment(key, null_delay_env)
delay_data = run_experiment(key, delay_env)

In [None]:
print("Reward Readout/ Event Times:")
print("\tBatched \tNo Delay\tDelay")

for i, (a, b, c) in enumerate(zip(batched_data, null_delay_data, delay_data)):
    print(i, f'\t{int(a[1].reward != 0)}\t\t{int(b[1].reward != 0)}\t\t{int(c[1].reward != 0)}'.replace('0', '-'))

Since the delay of `delay_env` is always `1`, which is odd, synchronization prevents a read-event from occurring until the feedback from the underlying event is also non-empty. In the previous example with `synchronize=False`, the events of `delay_env` occurred at odd time-steps (where the underlying BatchDelay Wrapper did not yield an event), whereas in the above example with `synchronize=True` the events occur at even time-steps (synchronously with any non-null event from BatchDelay).
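The timing pattern described above can be reproduced with a toy simulation in plain Python. This is a simplified model under stated assumptions, not the `bbox` implementation: the inner wrapper is assumed to yield an event on every second step, and a synchronized outer wrapper is assumed to freeze its internal clock whenever the inner step is null. `batch_events` and `delayed_events` are hypothetical helpers, and steps are numbered from 1 here.

```python
from typing import Dict, List


def batch_events(n_steps: int, batch_size: int = 2) -> List[int]:
    """Toy inner wrapper: yields an event (1) on every `batch_size`-th step."""
    return [int(t % batch_size == 0) for t in range(1, n_steps + 1)]


def delayed_events(inner: List[int], delay: int, synchronize: bool) -> List[int]:
    """Toy outer wrapper: write each inner result `delay` ticks ahead, then read."""
    tape: Dict[int, int] = {}
    clock = 0  # the outer wrapper's internal time
    out = []
    for event in inner:
        if synchronize and not event:
            # Null inner step: skip both the write and the read, freeze the clock.
            out.append(0)
            continue
        clock += 1
        tape[clock + delay] = event
        out.append(tape.pop(clock, 0))
    return out


inner = batch_events(10)
unsynced = delayed_events(inner, delay=1, synchronize=False)
synced = delayed_events(inner, delay=1, synchronize=True)

# Columns: step, inner batch event, unsynchronized delay, synchronized delay.
for t, row in enumerate(zip(inner, unsynced, synced), start=1):
    print(t, *("x" if e else "-" for e in row))
```

In this model the unsynchronized wrapper reports its events one step after each batch event, on odd steps, while the synchronized wrapper reports them together with the next batch event, on even steps. The synchronized tape also never stores a null entry, which is why a smaller buffer suffices.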

### Joint Optimization and System Scheduling

Certain Wrappers maintain multiple instances of the EventBuffer. The agent must then also consider the *throughput* of its actions by scheduling each action to a system of its choice. This also allows systems to terminate asynchronously through a `TimeLimit` or `BudgetConstraint` wrapper, etc.
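The idea of routing actions to separate systems that can terminate independently is sketched below in plain Python. `ToyScheduledAction` and `dispatch` are hypothetical and only imitate the `(action, system_id)` pairing; the per-system step limits stand in for a time or budget constraint.

```python
from typing import List, NamedTuple, Tuple


class ToyScheduledAction(NamedTuple):
    """Hypothetical pairing of an action with the system that should run it."""
    action: Tuple[float, ...]
    system_id: int


def dispatch(
    counts: List[int], limits: List[int], scheduled: ToyScheduledAction
) -> Tuple[List[int], List[bool]]:
    """Charge one step to the chosen system and report which systems are done."""
    counts = list(counts)  # copy, so the caller's list is not mutated
    counts[scheduled.system_id] += 1
    done = [count >= limit for count, limit in zip(counts, limits)]
    return counts, done


counts, limits = [0, 0], [2, 3]
for system_id in (0, 0, 1):
    counts, done = dispatch(counts, limits, ToyScheduledAction((1.0, 1.0), system_id))

print(counts, done)
```

After these three actions the first system has exhausted its limit while the second is still running, which is the kind of partial termination the agent has to plan around.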

#### EventAwaiter

In [None]:
import jax
from jax import numpy as jnp

import matplotlib.pyplot as plt

from bbox import procgen as px
from bbox import env


key = jax.random.PRNGKey(123)


# Environment instantiation closely follows the core Function API.
rbf = lambda a, b: jnp.exp(-0.5 * jnp.sum(jnp.square(a - b)))
my_env = env.as_bandit(
    base=px.real.GaussianProcessPrior.partial(kernel=rbf, resolution=10),
    wrappers=None,
    dummy_x=jnp.zeros(2)
)

print(repr(my_env))

In [None]:
from bbox.env import ScheduledAction


delay_env = env.delay.SensoryDelay(my_env, delay_process=lambda *_: 1, buffer_size=5)
await_env = env.scheduling.FlushDelay(delay_env, max_steps=10)

print('Canonical Action:', delay_env.action_spec().generate_value())
print()
print('AwaitEvent Wrapper Action:', await_env.action_spec().generate_value())


state, step = await_env.reset(key)

In [None]:
state, step = await_env.reset(key)

for _ in range(4):
    state, step = await_env.step(state, ScheduledAction(action=jnp.ones(2), system_id=0))
    print(state.time, step.reward)
    
    state, step = await_env.step(state, ScheduledAction(action=jnp.ones(2), system_id=1))
    print(state.time, step.reward)

In [None]:
state, step = await_env.reset(key)

for _ in range(4):
    state, step = await_env.step(state, ScheduledAction(action=jnp.ones(2), system_id=0))
    
    print(state.time, step.reward)

### Time/ Budget Constraints

In the Batch Delay example, we already showed how to instantiate a TimeLimit on an Environment. Unlike conventional implementations, our TimeLimit wrapper is sensitive to the order of wrapping, as it tracks the number of calls to the environment itself. This behaviour is optional and is selected with `track_self=True`; to use the environment's `EnvState.time` for the time-budget instead, set `track_self=False`. This is useful for System Scheduling Environments, where the Environment can have partially terminated sub-environments without terminating itself. The cell below illustrates this:

In [None]:
for b in [True, False]:
    timed_env = env.wrappers.TimeLimit(my_env, max_episode_steps=2, track_self=b, terminate=True)
    
    delay_env = env.industrial.SensoryDelay(timed_env, delay_process=lambda *_: 1, buffer_size=5)
    await_env = env.industrial.EventAwaiter(delay_env)
    
    state, step = await_env.reset(key)

    print(f"Timed track_self: {b}")
    for _ in range(3):
        state, step = await_env.step(state, ScheduledAction(action=jnp.ones(2), system_id=0))
        print(f't_outer={state.time}, t_wrapper={state.data[repr(timed_env)].cumulative}, done={step.last()}, r={step.reward:.4f}')

        state, step = await_env.step(state, ScheduledAction(action=jnp.ones(2), system_id=1))
        print(f't_outer={state.time}, t_wrapper={state.data[repr(timed_env)].cumulative}, done={step.last()}, r={step.reward:.4f}')
        
    print()

### Bakery: Combining all of the Above