## Testing the Implemented Classes and Environment
-----

In [21]:
import tensorflow as tf
# Fail fast: the rest of the notebook assumes TensorFlow can see at least one GPU.
assert tf.config.list_physical_devices('GPU'), "No GPU."

In [22]:
from importlib import reload

from src.data.states import get_body_state, get_astrotime_now

import src.components as cpts; cpts = reload(cpts)
import src.environment as envmt; envmt = reload(envmt)
import src.plot as cplot; cplot = reload(cplot)

In [23]:
# Body whose current state seeds the walker, and the epoch at which it is sampled.
walker_name = "mars"
time = get_astrotime_now()

# get_body_state returns a mapping; its values are unpacked in order
# (assumed to be position, then velocity -- TODO confirm in src.data.states).
walker_position, walker_velocity = (
    get_body_state(walker_name, time).values()
)

#### Validating the `WalkerSystemEnv`

-----

In [24]:
from tf_agents.environments import validate_py_environment
from tf_agents.environments import tf_environment, tf_py_environment

In [25]:
# Bodies of the simulated system (Earth and Venus, plus the Sun via add_sun=True).
system = cpts.SunSystem(["earth", "venus"], add_sun=True)

# The agent-controlled body, starting from the sampled state of `walker_name`.
walker = cpts.Walker(
    walker_position,
    walker_velocity,
    mass=1.,
    name=f"walker ({walker_name})",
)
solver = cpts.Solver()

# Target built from Jupiter with a tilt of 45 (presumably degrees -- confirm in src.components).
# Alternative: target = cpts.FixedTarget([0, 0, 1])
target = cpts.OrbitTarget("jupiter", tilt_angle=45)

In [26]:
class Env(
    envmt.WalkerSystemEnv,
    envmt.ContinuousAction,         # alternatives: DiscreteAction, OneDimDiscreteAction
    envmt.StateAndDiffObservation,  # alternatives: GravityObservation, AllPositionsObservation
    envmt.DistanceAndTargetReached,
):
    """Concrete environment composed from `src.environment` mixins.

    The order of the bases fixes the MRO: the core `WalkerSystemEnv` first,
    then the action-space mixin, the observation mixin and finally
    `DistanceAndTargetReached` (presumably reward/termination -- confirm).
    Swap a mixin for one of the listed alternatives to change that aspect.
    """

In [27]:
# Assemble the environment from the components above. The action spec printed
# further down shows boosts bounded by +/- max_boost.
env = Env(walker, system, solver, target, max_iters=1000, max_boost=1e-4)

# Optional (slow) sanity check: validate_py_environment(env, 1)

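Demonstration of `src.environment.tfpy_env_wrapper`. A plain `TFPyEnvironment` hides the attributes of the wrapped Python environment (such as `walker`), whereas the wrapper keeps them accessible: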

In [28]:
# A plain TFPyEnvironment is a valid TFEnvironment but does not expose `walker`.
tf_env_ = tf_py_environment.TFPyEnvironment(env)
print(
    isinstance(tf_env_, tf_environment.TFEnvironment),
    hasattr(tf_env_, "walker"),
    sep="\n",
)

True
False


In [29]:
# The project wrapper also yields a TFEnvironment, but `walker` stays reachable
# and is the very same object that was passed into `Env`.
tf_env = envmt.tfpy_env_wrapper(env)
print(
    isinstance(tf_env, tf_environment.TFEnvironment),
    tf_env.walker,
    tf_env.walker == walker,
    sep="\n",
)

True
<src.components.walkers.Walker object at 0x000001DD6B98B2B0>
True


Inspection of the time-step and action specs:

In [30]:
# Specs as seen through the TF wrapper (dtypes are converted to tf.float32/int32).
print(f"TimeStep Specs: {tf_env.time_step_spec()}")
print(f"Action Specs: {tf_env.action_spec()}")

TimeStep Specs: TimeStep(
{'discount': BoundedTensorSpec(shape=(), dtype=tf.float32, name='discount', minimum=array(0., dtype=float32), maximum=array(1., dtype=float32)),
 'observation': {'diff-to-target': TensorSpec(shape=(3,), dtype=tf.float32, name='target'),
                 'walker-state': TensorSpec(shape=(6,), dtype=tf.float32, name='walker-state')},
 'reward': TensorSpec(shape=(), dtype=tf.float32, name='reward'),
 'step_type': TensorSpec(shape=(), dtype=tf.int32, name='step_type')})
Action Specs: BoundedTensorSpec(shape=(3,), dtype=tf.float32, name='boost', minimum=array(-1.e-04, dtype=float32), maximum=array(1.e-04, dtype=float32))


#### Executing a Policy in the `WalkerSystemEnv`

In [31]:
from tf_agents.policies.random_tf_policy import RandomTFPolicy
from tf_agents.policies.fixed_policy import FixedPolicy
from tf_agents.drivers.dynamic_episode_driver import DynamicEpisodeDriver
from tf_agents.replay_buffers.tf_uniform_replay_buffer import TFUniformReplayBuffer

In [32]:
# Specs shared by the policies, replay buffers and drivers below.
time_step_spec, action_spec, obs_specs = (
    tf_env.time_step_spec(),
    tf_env.action_spec(),
    tf_env.observation_spec(),
)

In [33]:
# Baseline policy sampling random actions within the action spec's bounds.
random_tf_policy = RandomTFPolicy(action_spec=action_spec, time_step_spec=time_step_spec)

# Buffer receiving every batch of trajectories produced by the random driver.
random_replay_buffer = TFUniformReplayBuffer(
    max_length=10000,
    batch_size=tf_env.batch_size,
    data_spec=random_tf_policy.collect_data_spec,
)

# Each `run()` plays exactly one episode and feeds the buffer as it goes.
random_driver = DynamicEpisodeDriver(
    tf_env,
    random_tf_policy,
    num_episodes=1,
    observers=[random_replay_buffer.add_batch],
)

In [34]:
# Constant policy that always emits the action value 0.
fixed_policy = FixedPolicy(
    action_spec=action_spec,
    time_step_spec=time_step_spec,
    actions=0,
)

# Buffer receiving every batch of trajectories produced by the fixed driver.
fixed_replay_buffer = TFUniformReplayBuffer(
    max_length=10000,
    batch_size=tf_env.batch_size,
    data_spec=fixed_policy.collect_data_spec,
)

# Each `run()` plays exactly one episode and feeds the buffer as it goes.
fixed_driver = DynamicEpisodeDriver(
    tf_env,
    fixed_policy,
    num_episodes=1,
    observers=[fixed_replay_buffer.add_batch],
)

For a single complete episode, collected with a `DynamicEpisodeDriver`:

In [35]:
# Start from a fresh episode; the driver then plays it to completion.
initial_time_step = tf_env.reset()

# Both action types currently collect with the random policy
# (`fixed_driver.run()` is an alternative for discrete actions).
for action_kind in (envmt.ContinuousAction, envmt.DiscreteAction):
    if isinstance(env, action_kind):
        _ = random_driver.run()

In [36]:
# Empty both replay buffers so later collection starts from a clean slate.
for buffer in (fixed_replay_buffer, random_replay_buffer):
    buffer.clear()

In [41]:
# 3-D view of the environment; zrange=None presumably leaves z-limits to the plotter.
plotter = cplot.Plotter(dict(env=tf_env))
plotter.draw("3d", zrange=None)

For a given number of steps:

In [38]:
# Manually step the environment for up to 1000 steps, stopping early if the
# episode ends.
time_step = tf_env.reset()
for _ in range(1000):
    # Stepping a TF-Agents environment after a LAST time step silently resets it,
    # which would discard the trajectory plotted below -- stop at episode end.
    if time_step.is_last().numpy().any():
        break

    if isinstance(env, envmt.ContinuousAction):
        action = random_tf_policy.action(time_step).action
        # Custom alternative; note the leading batch dimension, shape (1, 3):
        # action = tf.reshape(tf.constant([0., 0., 3e-5], tf.float32), (1, -1))
    elif isinstance(env, envmt.DiscreteAction):
        # TFPyEnvironment expects a leading batch dimension; a bare Python int
        # (`action = 2`) cannot be unbatched by the wrapped Python environment.
        # NOTE(review): assumes a scalar discrete action spec -- confirm.
        action = tf.fill([tf_env.batch_size], tf.constant(2, dtype=action_spec.dtype))
    else:
        # Without this guard `action` would be unbound (NameError) or, worse,
        # silently reuse a stale `action` left over from another cell.
        raise TypeError(f"Unsupported action type for environment {type(env).__name__!r}")

    time_step = tf_env.step(action)

If a custom `action` (e.g. a `tf.constant`) is used with the continuous action spec, it must have shape $(1\times 3)$: `TFPyEnvironment`s expect a leading "batch" dimension on every action.

In [40]:
# 3-D view after manual stepping; zrange=None presumably leaves z-limits to the plotter.
plotter = cplot.Plotter(dict(env=tf_env))
plotter.draw("3d", zrange=None)