In [8]:
# Enable IPython autoreload so edits to imported modules are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [9]:
import jax.numpy as jnp
import matplotlib.pyplot as plt
import jax
from jax import lax
from deluca.envs import BalloonLung
from deluca.agents import PID

In [10]:
def loop(context, x):
    """Run one closed-loop control step; usable as the body of ``lax.scan``.

    Args:
        context: Carry tuple ``(env, (agent_in, agent_out))``. ``env`` must
            expose an ``observation`` mapping with ``'target'`` and
            ``'measured'`` entries and a ``step(action)`` method returning a
            4-tuple. Each agent is a callable mapping the tracking error to a
            control value.
        x: Per-step scan input. It is ignored; it only exists to satisfy the
            ``lax.scan`` body signature.

    Returns:
        ``((env, (agent_in, agent_out)), reward)``, where ``reward`` is the
        second element of the tuple returned by ``env.step``.
    """
    lung_env, controllers = context
    ctrl_in, ctrl_out = controllers

    obs = lung_env.observation
    err = obs['target'] - obs['measured']

    # Both controllers see the same error; the action is (inlet, outlet) order.
    action = (ctrl_in(err), ctrl_out(err))
    _, step_reward, _, _ = lung_env.step(action)

    return (lung_env, (ctrl_in, ctrl_out)), step_reward

In [11]:
# BalloonLung env: a single environment instance shared by the for-loop and
# lax.scan rollouts in the next cell, so both see identical dynamics.
LUNG_PARAMS = dict(
    leak=False,
    peep_valve=5.0,
    PC=40.0,
    P0=0.0,
    C=10.0,
    R=15.0,
    dt=0.03,
    waveform=None,
    reward_fn=None,
)
lung = BalloonLung(**LUNG_PARAMS)


In [12]:
# Compare a plain Python for-loop rollout with an equivalent `lax.scan`
# rollout. Both start from `lung.reset()` with freshly constructed PID
# controllers, so their total rewards must agree up to float round-off.
T = 10
PID_GAINS = [3.0, 4.0, 0.0]
xs = jnp.arange(T)  # scan inputs; `loop` ignores them, they only fix the length


def make_agents():
    """Return a fresh ``(agent_in, agent_out)`` pair of PID controllers.

    Each controller gets its own copy of ``PID_GAINS`` so the two never share
    a mutable gains list.
    """
    return PID(list(PID_GAINS)), PID(list(PID_GAINS))


# for loop version
agent_in, agent_out = make_agents()
print(lung.reset())
reward = 0
for i in range(T):
    # Feed the same per-step input the scan version receives, for parity.
    (lung, (agent_in, agent_out)), r = loop((lung, (agent_in, agent_out)), xs[i])
    reward += r
reward_forloop = reward

# scan version
agent_in, agent_out = make_agents()
print(lung.reset())
_, reward_scan = lax.scan(loop, (lung, (agent_in, agent_out)), xs)

# correctness test: the two rollouts differ only in summation order, so compare
# with a tolerance rather than exact equality (last digits can differ).
print('reward_forloop = ' + str(reward_forloop))
print('reward_scan sum = ' + str(jnp.sum(reward_scan)))
assert jnp.allclose(reward_forloop, jnp.sum(reward_scan)), (
    "for-loop and lax.scan rollouts produced different total rewards"
)

{'measured': DeviceArray(0.05487419, dtype=float64), 'target': DeviceArray(5., dtype=float64), 'dt': 0.03, 'phase': DeviceArray(1, dtype=int64)}
{'measured': DeviceArray(0.05487419, dtype=float64), 'target': DeviceArray(5., dtype=float64), 'dt': 0.03, 'phase': DeviceArray(1, dtype=int64)}
reward_forloop = -170.7004493227582
reward_scan sum = -170.7004493227583
