In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import torch
import jax.numpy as jnp
import numpy as np
import time
import matplotlib.pyplot as plt
import jax
from jax import lax
from gym.envs.classic_control import PendulumEnv
from deluca import Agent
from deluca.envs import BalloonLung, DelayLung, LearnedLung
from deluca.agents import PID
from deluca.envs.core import Env
import pickle


In [3]:
def loop(context, x):
    """Advance the closed loop by one step; usable as a ``lax.scan`` body.

    Args:
        context: carry tuple ``(env, (agent_in, agent_out))``.
        x: per-step scan input. Not used; present only because ``lax.scan``
            passes one element of ``xs`` on every step.

    Returns:
        ``((env, (agent_in, agent_out)), reward)``: the very same carry objects
        that were passed in, plus the reward reported by ``env.step``.
    """
    env, (ctrl_in, ctrl_out) = context
    # Both controllers are driven by the same tracking error.
    tracking_error = env.observation['target'] - env.observation['measured']
    action = (ctrl_in(tracking_error), ctrl_out(tracking_error))
    _obs, step_reward, _done, _info = env.step(action)
    return (env, (ctrl_in, ctrl_out)), step_reward

In [5]:
# BalloonLung env.
# NOTE(review): this cell, the DelayLung cell and the LearnedLung cell all bind
# `lung`; whichever of them ran last is the env the experiment cell uses.
lung = BalloonLung(
    leak=False,
    peep_valve=5.0,
    PC=40.0,
    P0=0.0,
    C=10.0,
    R=15.0,
    dt=0.03,
    waveform=None,
    reward_fn=None,
)


In [112]:
# DelayLung env.
# NOTE(review): this cell rebinds `lung`, as do the BalloonLung and LearnedLung
# cells. Its saved execution count (112) is later than the experiment cell's
# (10), so the saved experiment output was not produced with this env.
lung = DelayLung(
    min_volume=1.5,
    R_lung=10,
    C_lung=6,
    delay=25,
    inertia=0.995,
    control_gain=0.02,
    dt=0.03,
    waveform=None,
    reward_fn=None,
)

In [8]:
# LearnedLung env, restored from a file given by a relative path.
# NOTE(review): the ".pkl" extension suggests this deserializes a pickle, which
# can run arbitrary code -- only load files from a trusted source (confirm
# against LearnedLung.from_torch).
lung = LearnedLung.from_torch(
    "learned_lung_C20_R20_PEEP10.pkl"
)

In [10]:
# Roll out the same closed loop twice -- once with a Python for loop, once with
# `lax.scan` -- and check that both give the same total reward.
T = 10  # number of control steps
PID_GAINS = (3.0, 4.0, 0.0)  # gains passed to both PID controllers


def make_agents():
    """Return a fresh ``(agent_in, agent_out)`` pair of PID controllers.

    Each call builds new ``PID`` instances, each given its own copy of
    ``PID_GAINS`` as a list, so both rollouts start from newly built controllers.
    """
    return PID(list(PID_GAINS)), PID(list(PID_GAINS))


xs = jnp.arange(T)  # per-step scan inputs; `loop` does not use them

# for loop version
agent_in, agent_out = make_agents()
print(lung.reset())
reward = 0
for i in range(T):
    # `loop` ignores its second argument, so a placeholder 0 is passed here.
    (lung, (agent_in, agent_out)), r = loop((lung, (agent_in, agent_out)), 0)
    reward += r
reward_forloop = reward

# scan version
agent_in, agent_out = make_agents()
print(lung.reset())
_, reward_scan = lax.scan(loop, (lung, (agent_in, agent_out)), xs)
reward_scan_total = jnp.sum(reward_scan)

# correctness test: report both totals, then fail loudly if they disagree
print('reward_forloop = ' + str(reward_forloop))
print('reward_scan sum = ' + str(reward_scan_total))
np.testing.assert_allclose(reward_forloop, reward_scan_total)

{'measured': 0, 'target': DeviceArray(5., dtype=float64), 'dt': 0.03, 'phase': DeviceArray(1, dtype=int64)}
{'measured': 0, 'target': DeviceArray(5., dtype=float64), 'dt': 0.03, 'phase': DeviceArray(1, dtype=int64)}
reward_forloop = -127.64067316490946
reward_scan sum = -127.64067316490946
