In [1]:
import numpy as np
import pandas as pd

# ML
import tensorflow as tf
from tensorflow import keras

import torch

import flax
from flax import linen  # keras but for flax essentially. More literally, torch.nn but for flax
import jax
import jaxlib
import optax

import sklearn

# RL
import gymnasium as gym  # for environment interaction: states/observations and actions, as well as their associated spaces and the interaction of actions 
import stable_baselines3
import skrl
#import tf_agents

# Statistics
#import tensorflow_probability as tfp

# Data Visualization
import matplotlib.pyplot as plt
import seaborn as sns

In [2]:
# Check versions (in case of conflicts)
# Not always required, but in this case since I imported a million different things with dependencies yeah it's good for debugging
print(f"numpy {np.__version__}")
print(f"pandas {pd.__version__}")
print(f"tensorflow {tf.__version__}")
print(f"torch {torch.__version__}")
print(f"flax {flax.__version__}")
print(f"JAX {jax.__version__}")
print(f"optax {optax.__version__}")
print(f"sklearn {sklearn.__version__}")
print(f"gymnasium {gym.__version__}")
print(f"stable baselines3 {stable_baselines3.__version__}")
print(f"skrl {skrl.__version__}")
print(f"seaborn {sns.__version__}")

numpy 1.25.2
pandas 1.5.3
tensorflow 2.10.0
torch 2.0.1
flax 0.7.2
JAX 0.4.14
optax 0.1.7
sklearn 1.3.0
gymnasium 0.28.1
stable baselines3 2.0.0
skrl 1.0.0rc2
seaborn 0.12.2


In [3]:
# Let's play with JAX (mostly just numpy but jaxified)
import jax.numpy as jnp

jnp.array([1, 2, 3, 4])

No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


Array([1, 2, 3, 4], dtype=int32)

# Taxi Example

### Goal
The taxi will start at a random square. Given the location of a passenger and their destination, its goal is to get to move itself to pick up a passenger (each movement = separate action), and then move to its desired destination to drop off the passenger. Picking up and dropping off are also separate actions. It wants to achieve the goal, and as fast as possible (it's a Taxi for crying out loud!)

### Reward Function
The reward function:
- (-1) per step if no reward; encourages finding the fastest route
- (+20) for delivering the passenger correctly
- (-10) for doing things incorrectly. That is, executing "pick up" and "drop off" actions illegally 

### Episodic Task
Ends when terminated or truncated
- Terminated if: task successfully completed (dropped off the passenger correctly)
- Truncated if: using the `time_limit` wrapper and length of the episode is 200 

Link: https://gymnasium.farama.org/environments/toy_text/taxi/

In [4]:
# State Space Dicts
# (or Observation Space, but in this case since we know all that is going on in the environment, we can use the term state)
PASSENGER_LOCATIONS = ["Red", "Green", "Yellow", "Blue", "In Taxi"]
PASSENGER_DESTINATIONS = ["Red", "Green", "Yellow", "Blue"]
POSSIBLE_TAXI_POSITIONS = set(range(25))  # unnecessary line, but added for clarification that there are 25 possible taxi positions

# Action Space Dict
ACTIONS = ["DOWN", "UP", "RIGHT", "LEFT", "PICK UP", "DROP OFF"]

In [5]:
# we are impatient and want things fast, so let's use JAX (not recommended for beginners, but the code is mostly the same)
from skrl.envs.wrappers.jax import wrap_env as wrap_env_jax


# Taxi Environment
gymnasium_env = gym.make("Taxi-v3", render_mode="human")
env = wrap_env_jax(gymnasium_env, wrapper="gymnasium")  # JAXification!

env.__dict__

[38;20m[skrl:INFO] Environment class: gymnasium.core.Wrapper, gymnasium.utils.record_constructor.RecordConstructorArgs[0m
[38;20m[skrl:INFO] Environment wrapper: gymnasium[0m


{'_jax': False,
 '_env': <TimeLimit<OrderEnforcing<PassiveEnvChecker<TaxiEnv<Taxi-v3>>>>>,
 'device': CpuDevice(id=0),
 '_vectorized': False}

In [6]:
# It might look like the wrapped env doesn't have the attributes and methods the non-wrapped env has, but it does have access to it!
# The point of the wrapped environment is to work with the skrl framework, which provides interopability between multiple environmental frameworks such as Gymnasium, DeepMind, NVIDIA, and others
dir(env)

['__class__',
 '__delattr__',
 '__dict__',
 '__dir__',
 '__doc__',
 '__eq__',
 '__format__',
 '__ge__',
 '__getattr__',
 '__getattribute__',
 '__gt__',
 '__hash__',
 '__init__',
 '__init_subclass__',
 '__le__',
 '__lt__',
 '__module__',
 '__ne__',
 '__new__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__setattr__',
 '__sizeof__',
 '__str__',
 '__subclasshook__',
 '__weakref__',
 '_env',
 '_jax',
 '_observation_to_tensor',
 '_tensor_to_action',
 '_vectorized',
 'action_space',
 'close',
 'device',
 'num_agents',
 'num_envs',
 'observation_space',
 'render',
 'reset',
 'state_space',
 'step']

In [7]:
env.observation_space

Discrete(500)

In [8]:
env.action_space

Discrete(6)

In [31]:
import flax.linen as nn
from skrl.models.jax.base import Model as Model_jax
from skrl.models.jax.categorical import CategoricalMixin as CategoricalMixin_jax
from skrl.models.jax.deterministic import DeterministicMixin as DeterministicMixin_jax


# Let's get the policy prototype(s) straight (creating a skeleton for the policy; this is essentially its mechanical intellectual capacity)
# Policy is the function mapping situations to actions

# The prefix 'Categorical' refers to the OUTPUT of the policy
# The output space is discrete, since the action space is also discrete!
# We will be using one that outputs using a Categorical Distribution (or Discrete Probability Distribution). This should make exploration more robust, but the other reason is that skrl is only supporting this when using JAX with discrete spaces
# This'll be quite the wild ride so strap on in!
# This class will essentially be a prototype for the categorical models/policies used by the agent 
class CategoricalPolicyModel(CategoricalMixin_jax, Model_jax):
    def __init__(self, observation_space, action_space, device=None, unnormalized_log_prob=True, **kwargs):
        Model_jax.__init__(self, observation_space, action_space, device, **kwargs)
        CategoricalMixin_jax.__init__(self, unnormalized_log_prob)
    
    @nn.compact  # marks the given module method allowing inlined submodules; basically this means I can define nn.Dense(), etc. within this function
    def __call__(self, inputs, role):
        #print(inputs)
        # Maps state/observation to action
        x = nn.elu(nn.Dense(32)(inputs["states"]))
        x = nn.elu(nn.Dense(16)(x))
        x = nn.Dense(self.num_actions)(x)  # 6 neurons in this case
        
        return x, {}

# Deterministic and Continuous OUTPUT which is for the expected long-term value estimation
# This model will be used for expected long-term value estimation from the Critic to guide the Actor's exploration'
# Actor's exploration will be guided by the Critic's value, creating a hybrid of policy-based and value-based methods
# Exploration from the Actor will start out as it just doing random things and seeing what sticks/leads it to the most rewards.
# But as exploration goes on, the Critic's long-term value estimation will become increasingly accurate
# And now the Actor's exploration won't be as random anymore; it'll start to explore in ways the Critic values highly
# This creates a more focused/targeted/productive way to explore that doesn't boil down to pure stochasticity, and helps the Actor learn its policy by learning the long-term importances of its actions in different situations, thereby helping to deal with the Credit Assignment Problem.
# This differs from curriculum learning because its a characteristic to the model, not the environment/game it plays. This allows for more robustness. Although, combining both techniques isn't a bad idea by any means.
class DeterministicValueModel(DeterministicMixin_jax, Model_jax):
    def __init__(self, observation_space, action_space, device=None, clip_actions=False, **kwargs):
        Model_jax.__init__(self, observation_space, action_space, device, **kwargs)
        DeterministicMixin_jax.__init__(self, clip_actions)
        
    @nn.compact # marks the given module method allowing inlined submodules; basically this means I can define nn.Dense(), etc. within this function
    def __call__(self, inputs, role):
        # MLP: Nonlinear Regression
        # This Critic will evaluate the actions taken based on the state/observation
        
        print(inputs)
        
        # Flatten them; since the states and actions are 1D anyway this is OK
        states_flattened = jnp.ravel(inputs["states"])
        #taken_actions_flattened = jnp.ravel(inputs["taken_actions"])
        
        x = states_flattened
        #x = jnp.concatenate([states_flattened, taken_actions_flattened], axis=-1)
        
        #try:
        #    taken_actions_flattened = jnp.ravel(inputs["taken_actions"])
        
        #    x = jnp.concatenate([states_flattened, taken_actions_flattened], axis=-1)
        #except:
        #    x = states_flattened
        
        
        x = jnp.reshape(x, (-1, 1))
        
        hl1 = nn.relu(nn.Dense(64)(x))
        hl2 = nn.relu(nn.Dense(32)(hl1))
        
        out = nn.Dense(1)(hl2)
        
        return out, {}

In [32]:
from skrl.agents.jax.ppo import PPO as PPO_jax
from skrl.agents.jax.ppo import PPO_DEFAULT_CONFIG as PPO_DEFAULT_CONFIG_JAX
from skrl.memories.jax import RandomMemory as RandomMemory_jax

# Choosing an agent/model/policy/action architecture
# Let's be a basic bitch and use PPO

# models that the agent will use
ppo_models = {}
ppo_models["policy"] = CategoricalPolicyModel(env.observation_space, env.action_space)  # Actor: Policy
ppo_models["value"] = DeterministicValueModel(env.observation_space, env.action_space)  # Critic: Expected Cumulative Value Estimation; only required during training

# instantiate models' state dict
# Why can't this be done in the PPO constructor?
for (role, model) in ppo_models.items():
    model.init_state_dict(role)

# agent = policy/model + method of interaction/learning algorithm
# wants to optimize for long-term reward
# memory is only required during training of PPO
agent = PPO_jax(ppo_models,
                memory=RandomMemory_jax(memory_size=1024, num_envs=env.num_envs, device=env.device, replacement=True),  # sampling with replacement
                observation_space=env.observation_space, action_space=env.action_space, 
                cfg=PPO_DEFAULT_CONFIG_JAX.copy())  # configuration dict: hyperparameters, preprocessors, learning rate schedulers, etc.

{'states': array([[90]], dtype=int64), 'taken_actions': 2}


In [33]:
# Training the Agent
from skrl.trainers.jax import SequentialTrainer as SequentialTrainer_jax
from skrl.trainers.jax import ManualTrainer as ManualTrainer_jax


# https://skrl.readthedocs.io/en/latest/intro/getting_started.html#trainers

trainer = SequentialTrainer_jax(env=env, agents=[agent])  # default cfg

# train the agent
trainer.train()

# evaluate the agent
trainer.eval()

  0%|                                                                                                                                       | 1/100000 [00:00<15:01:04,  1.85it/s]

{'states': Traced<ShapedArray(int32[1,1])>with<DynamicJaxprTrace(level=1/0)>}


  0%|                                                                                                                                      | 15/100000 [00:07<13:54:41,  2.00it/s]

{'states': Traced<ShapedArray(float32[512,1])>with<DynamicJaxprTrace(level=4/0)>}


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100000/100000 [13:56:44<00:00,  1.99it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100000/100000 [13:56:53<00:00,  1.99it/s]


In [34]:
DEFAULT_MAX_TIMESTEPS = 50_000

def play_skrl_episode(env: gym.Env, agent):
    (observation, info) = env.reset()  # doesn't make sense for terminated or truncated to be here yet!
    #print(observation)
    #print(info)
    
    t = 0  # timestep
    (terminated, truncated) = (False, False)  # Episode End = terminated or truncated
    while not (terminated or truncated):
        (action, value_estimate, act_info) = agent.act(observation, t, DEFAULT_MAX_TIMESTEPS)  # KEY POINT OF RL #
        
        print(f"{observation} => {action}: {value_estimate}; ...{act_info}")
        
        (observation, reward, terminated, truncated, info) = env.step(action)
        
        #print(observation)
        #print(reward)
        #print(terminated)
        #print(truncated)
        #print(info)
        
        env.render()
        #frame = env.render()
        #print(frame)
        #plt.imshow(frame)
        
        t += 1

In [35]:
play_skrl_episode(env, agent)

[[221]] => [[0]]: [[nan]]; ...{'net_output': Array([[nan, nan, nan, nan, nan, nan]], dtype=float32), 'stddev': Array([[nan]], dtype=float32)}
[[321]] => [[0]]: [[nan]]; ...{'net_output': Array([[nan, nan, nan, nan, nan, nan]], dtype=float32), 'stddev': Array([[nan]], dtype=float32)}
[[421]] => [[0]]: [[nan]]; ...{'net_output': Array([[nan, nan, nan, nan, nan, nan]], dtype=float32), 'stddev': Array([[nan]], dtype=float32)}
[[421]] => [[0]]: [[nan]]; ...{'net_output': Array([[nan, nan, nan, nan, nan, nan]], dtype=float32), 'stddev': Array([[nan]], dtype=float32)}
[[421]] => [[0]]: [[nan]]; ...{'net_output': Array([[nan, nan, nan, nan, nan, nan]], dtype=float32), 'stddev': Array([[nan]], dtype=float32)}
[[421]] => [[0]]: [[nan]]; ...{'net_output': Array([[nan, nan, nan, nan, nan, nan]], dtype=float32), 'stddev': Array([[nan]], dtype=float32)}
[[421]] => [[0]]: [[nan]]; ...{'net_output': Array([[nan, nan, nan, nan, nan, nan]], dtype=float32), 'stddev': Array([[nan]], dtype=float32)}
[[421]

KeyboardInterrupt: 

In [21]:
# Recommended practice is to save the agents rather than the models
runs = [
    "23-08-17_22-42-00-923229_PPO",
    "23-08-17_23-17-18-255038_PPO",
    "23-08-18_02-31-14-419606_PPO"
] 

agent.load(f"./runs/{runs[0]}/checkpoints/best_agent.pickle")