# Training with RLlib PPO
[RLlib](https://docs.ray.io/en/latest/rllib/index.html) is a high-performance, distributed
reinforcement learning library. It is preferable to other RL libraries (e.g. Stable Baselines
3) for `bsk_rl` environments because it steps environment copies asynchronously; because 
of the variable step lengths, variable episode step counts, and long episode reset times, 
stepping each environment independently can increase step throughput by 2-5 times.

<div class="alert alert-warning">

**Warning:** RLlib had a bug that resulted in an undesirable timeout which stopped
training. It has since been resolved: https://github.com/ray-project/ray/pull/45147

</div>


## Define the Environment
A nadir-scanning environment is created, similar to the one used in [this paper](https://hanspeterschaub.info/Papers/Stephenson2024.pdf). 
The satellite has to collect data while managing the data buffer level and battery level.

First, the satellite class is defined. A custom dynamics model is created that defines
a few additional properties to use in the state.

In [None]:
import numpy as np
from bsk_rl import act, data, obs, sats, scene
from bsk_rl.sim import dyn, fsw

class ScanningDownlinkDynModel(dyn.ContinuousImagingDynModel, dyn.GroundStationDynModel):
    """Dynamics model that adds two pointing-error properties for use in the state.

    Both properties return an angle in radians computed with ``np.arccos``. The
    cosine is clipped to ``[-1, 1]`` first, because floating-point round-off in
    the dot product of two (nominally) unit vectors can land marginally outside
    that interval, which would otherwise make ``np.arccos`` return NaN.
    """

    @property
    def instrument_pointing_error(self) -> float:
        """Angle between ``-r_BN_P`` (normalized) and ``fsw.c_hat_P``, in radians.

        NOTE(review): presumably ``-r_BN_P`` is the nadir direction and
        ``c_hat_P`` the unit instrument boresight in the same frame - confirm.
        """
        r_BN_P_unit = self.r_BN_P / np.linalg.norm(self.r_BN_P)
        c_hat_P = self.satellite.fsw.c_hat_P
        cos_angle = np.dot(-r_BN_P_unit, c_hat_P)
        return np.arccos(np.clip(cos_angle, -1.0, 1.0))

    @property
    def solar_pointing_error(self) -> float:
        """Angle between the ``nHat_B`` body vector and the sun direction, in radians.

        NOTE(review): ``nHat_B`` is presumably the unit solar panel normal in the
        body frame - confirm against the satellite argument definitions.
        """
        # Position vector read from the planet state message at the sun's index.
        a = self.world.gravFactory.spiceObject.planetStateOutMsgs[
            self.world.sun_index
        ].read().PositionVector
        a_hat_N = a / np.linalg.norm(a)
        nHat_B = self.satellite.sat_args["nHat_B"]
        # Transpose of BN is applied to move nHat_B out of the body frame.
        NB = np.transpose(self.BN)
        nHat_N = NB @ nHat_B
        cos_angle = np.dot(nHat_N, a_hat_N)
        return np.arccos(np.clip(cos_angle, -1.0, 1.0))

class ScanningSatellite(sats.AccessSatellite):
    """Satellite definition for the nadir-scanning task.

    Declares the observation elements, the four available actions, and the
    dynamics/flight-software model types used by this satellite.
    """

    # Observation: satellite properties, the next ground station opportunity,
    # eclipse information, and time.
    observation_spec = [
        obs.SatProperties(
            dict(prop="storage_level_fraction"),
            dict(prop="battery_charge_fraction"),
            dict(prop="wheel_speeds_fraction"),
            # The two pointing errors are the custom properties defined on
            # ScanningDownlinkDynModel; they are arccos results, so norm=np.pi
            # presumably scales them to [0, 1] - confirm norm semantics.
            dict(prop="instrument_pointing_error", norm=np.pi),
            dict(prop="solar_pointing_error", norm=np.pi)
        ),
        obs.OpportunityProperties(
            # NOTE(review): 5700 looks like one orbit period in seconds
            # (the episode duration is defined as 5 * 5700.0, "about 5 orbits").
            dict(prop="opportunity_open", norm=5700),
            dict(prop="opportunity_close", norm=5700),
            type="ground_station",
            n_ahead_observe=1,
        ),
        obs.Eclipse(norm=5700),
        obs.Time(),
    ]
    # Actions, each with a fixed duration (presumably seconds - confirm).
    action_spec = [
        act.Scan(duration=180.0),
        act.Charge(duration=120.0),
        act.Downlink(duration=60.0),
        act.Desat(duration=60.0),
    ]
    # Custom dynamics model (adds the pointing-error properties observed above).
    dyn_type = ScanningDownlinkDynModel
    fsw_type = fsw.ContinuousImagingFSWModel

Next, parameters are set. Since this scenario is focused on maintaining acceptable data
and power levels, these are tuned to create a sufficiently interesting mission.

In [None]:
# Instantiate the satellite. Arguments given as lambdas are callables that draw
# a random value each time they are invoked (presumably re-sampled by bsk_rl on
# every environment reset - confirm).
sat = ScanningSatellite(
    "Scanner-1",
    sat_args=dict(
        # Data
        dataStorageCapacity=5000 * 8e6,  # bits
        # Initial storage: random 0-80% of the storage capacity above.
        storageInit=lambda: np.random.uniform(0.0, 0.8) * 5000 * 8e6,
        instrumentBaudRate=0.5 * 8e6,
        # Negative rate: presumably data leaving the buffer - confirm sign convention.
        transmitterBaudRate=-50 * 8e6,
        # Power
        batteryStorageCapacity=200 * 3600,  # W*s
        # Initial charge: random 30-100% of the battery capacity above.
        storedCharge_Init=lambda: np.random.uniform(0.3, 1.0) * 200 * 3600,
        basePowerDraw=-10.0,  # W
        instrumentPowerDraw=-30.0,  # W
        transmitterPowerDraw=-25.0,  # W
        thrusterPowerDraw=-80.0,  # W
        panelArea=0.25,
        # Attitude
        imageAttErrorRequirement=0.1,
        imageRateErrorRequirement=0.1,
        disturbance_vector=lambda: np.random.normal(scale=0.0001, size=3),  # N*m
        maxWheelSpeed=6000.0,  # RPM
        # Initial wheel speeds: three values drawn uniformly from +/-3000
        # (half of maxWheelSpeed).
        wheelSpeeds=lambda: np.random.uniform(-3000, 3000, 3),
        desatAttitude="nadir",
    )
)

Finally, the environment arguments are set. Stepping through this environment is 
demonstrated at the bottom of the page.

In [None]:
# Episode length; also used as the environment time limit below.
duration = 5 * 5700.0  # About 5 orbits
# Keyword arguments for the environment constructor; reused for both training
# and the manual rollout at the bottom of the page.
env_args = dict(
    satellite=sat,
    # value_per_second is 1/duration, so presumably scanning for the whole
    # episode would yield a total reward of about 1 - confirm.
    scenario=scene.UniformNadirScanning(value_per_second=1/duration),
    rewarder=data.ScanningTimeReward(),
    time_limit=duration,
    failure_penalty=-1.0,
    terminate_on_time_limit=True,
)

## Configure Ray and PPO

The `bsk_rl` package supplies a utility to make logging information at the end of episodes
easier. This is useful to see how an agent's policy is changing over time, using a
monitoring program such as [TensorBoard](https://www.tensorflow.org/tensorboard).

In [None]:
from bsk_rl.utils.rllib import EpisodeDataCallbacks

class CustomDataCallbacks(EpisodeDataCallbacks):
    """Callbacks that report per-episode environment metrics for logging."""

    def pull_env_metrics(self, env):
        """Collect end-of-episode metrics from ``env``.

        Args:
            env: Environment exposing ``rewarder``, ``simulator`` and ``satellite``.

        Returns:
            dict mapping metric names to numeric values. ``reward_per_orbit`` is
            only included when some simulation time has elapsed, and
            ``orbits_complete_partial_only`` only when the satellite is not alive.
        """
        cumulative = env.rewarder.cum_reward
        # Average the cumulative reward over all entries in the reward dict.
        mean_reward = sum(cumulative.values()) / len(cumulative)
        # Elapsed simulation time expressed in 95-minute orbits.
        n_orbits = env.simulator.sim_time / (95 * 60)

        metrics = {
            "reward": mean_reward,
            # Are satellites dying, and how and when?
            "alive": float(env.satellite.is_alive()),
            "rw_status_valid": float(env.satellite.dynamics.rw_speeds_valid()),
            "battery_status_valid": float(env.satellite.dynamics.battery_valid()),
            "orbits_complete": n_orbits,
        }

        if n_orbits > 0:
            metrics["reward_per_orbit"] = mean_reward / n_orbits

        if not env.satellite.is_alive():
            metrics["orbits_complete_partial_only"] = n_orbits

        return metrics

Then, PPO (or some other algorithm) can be configured. Of particular importance
are setting `sample_timeout_s` and `metrics_episode_collection_timeout_s` to appropriately
high values for this environment.

In [None]:
from bsk_rl import SatelliteTasking
from bsk_rl.utils.rllib import unpack_config
from ray.rllib.algorithms.ppo import PPOConfig

# Total CPUs made available to Ray; also determines the number of env runners.
N_CPUS = 3

# Keyword arguments forwarded to PPOConfig.training().
training_args = dict(
    lr=0.00003,
    gamma=0.999,
    train_batch_size=250,  # usually a larger number, like 2500
    num_sgd_iter=10,
    model=dict(fcnet_hiddens=[512, 512], vf_share_layers=False),
    lambda_=0.95,
    use_kl_loss=False,
    clip_param=0.1,
    grad_clip=0.5,
)

# Build the PPO algorithm configuration.
config = (
    PPOConfig()
    .training(**training_args)
    # One fewer env runner than CPUs (presumably leaving one CPU for the
    # driver/learner - confirm); a long sample timeout is set for this environment.
    .env_runners(num_env_runners=N_CPUS-1, sample_timeout_s=1000.0)
    .environment(
        env=unpack_config(SatelliteTasking),
        env_config=env_args,
    )
    # Report the custom end-of-episode metrics defined in CustomDataCallbacks.
    .callbacks(CustomDataCallbacks)
    .reporting(
        metrics_num_episodes_for_smoothing=1,
        metrics_episode_collection_timeout_s=180,
    )
    .checkpointing(export_native_model_files=True)
    .framework(framework="tf2")
    # Uncomment these lines if using the new RLlib API stack
    # .framework(framework="torch")
    # .api_stack(
    #     enable_rl_module_and_learner=True,
    #     enable_env_runner_and_connector_v2=True,
    # )

)

Once the PPO configuration has been set, `ray` can be started and the agent can be
trained.

Training on a reasonably modern machine, we can achieve 5M steps over 32 processors in 6
to 18 hours, depending on specific environment configurations.

In [None]:
import ray
from ray import tune

# Start Ray with the same CPU budget used to size the env runners.
ray.init(
    ignore_reinit_error=True,
    num_cpus=N_CPUS,
    object_store_memory=2_000_000_000,  # 2 GB
)

# Run the training
tune.run(
    "PPO",
    # tune.run is given the configuration as a plain dictionary.
    config=config.to_dict(),
    stop={"training_iteration": 10},  # Adjust the number of iterations as needed
    # Checkpoint every 10 iterations and once more when training stops.
    checkpoint_freq=10,
    checkpoint_at_end=True
)

# Shutdown Ray
ray.shutdown()

## Loading the Policy Network

The policy network can be found in the `model` subdirectory of the checkpoint output.

## Stepping Through the Environment

The environment is stepped through with random actions to give a sense of how it acts.

In [None]:
# Roll out a single episode using randomly sampled actions.
env = SatelliteTasking(**env_args, log_level="INFO")
env.reset()
terminated = truncated = False
# End the episode on either flag: with terminate_on_time_limit=True the time
# limit is reported as termination, but checking truncation as well keeps the
# loop finite if that setting is changed.
while not (terminated or truncated):
    action = env.action_space.sample()
    observation, reward, terminated, truncated, info = env.step(action)