In [1]:
"""
Learning Notes

I messed up a whole bunch of things in this code vs https://github.com/openai/spinningup/blob/master/spinup/examples/pytorch/pg_math/1_simple_pg.py:
1. Reward normalization
2. Didn't use "Categorical" distribution for the policy
3. Loss is something negative and huge...
4. Sampled from softmax using argmax - should have used torch.distributions.Categorical or jax.random.categorical -- which uses gumbel softmax trick to make it differentiable
5. Model was huge - i used 512x512x2 vs 32x2
6. LR was too small - 1e-4 vs 1e-2
7. Batch size was too small - 32 vs 5000 -- also samples are each time step, not each episode
8. INCORRECT OBSERVATIONS USED IN THE EXPERIENCE BUFFER - I used the newly sampled ones after running env.step
   but this was incorrect, since I needed use the previous ones...
 
"""

'\nLearning Notes\n\nI messed up a whole bunch of things in this code vs https://github.com/openai/spinningup/blob/master/spinup/examples/pytorch/pg_math/1_simple_pg.py:\n1. Reward normalization\n2. Didn\'t use "Categorical" distribution for the policy\n3. Loss is something negative and huge...\n4. Sampled from softmax using argmax - should have used torch.distributions.Categorical or jax.random.categorical -- which uses gumbel softmax trick to make it differentiable\n5. Model was huge - i used 512x512x2 vs 32x2\n6. LR was too small - 1e-4 vs 1e-2\n7. Batch size was too small - 32 vs 5000 -- also samples are each time step, not each episode\n8. INCORRECT OBSERVATIONS USED IN THE EXPERIENCE BUFFER - I used the newly sampled ones after running env.step\n   but this was incorrect, since I needed use the previous ones...\n \n'

In [2]:
import gymnasium as gym
import optax
import os
import numpy as np

from gymnasium import wrappers
from tensorflow_probability.substrates import jax as tfp

In [3]:
# # For inference:
# # The environment is described in the following link:
# # https://gymnasium.farama.org/environments/classic_control/cart_pole/
# env = gym.make("CartPole-v1", render_mode='rgb_array')

# # Record videos
# trigger = lambda _: True
# env = wrappers.RecordVideo(env, video_folder="./save_videos2", episode_trigger=trigger, disable_logger=True, video_length=1000)
# env = wrappers.RecordEpisodeStatistics(env)

In [28]:
import dataclasses
import enum


class PolicyGradientWeight(enum.Enum):
    NONE = 0
    TRAJECTORY_REWARDS = 1
    REWARDS_TO_GO = 2


class PolicyGradientWeightBaseline(enum.Enum):
    NONE = 0
    AVERAGE_TRAJECTORY_REWARDS = 1


@dataclasses.dataclass
class Params:
    learning_rate: float = 1e-2
    batch_size: int = 4096  # Use a large batch size for policy iteration!
    max_steps: int = 100
    obs_dim: int = 4

    # model related
    hidden_dims: list[int] = dataclasses.field(default_factory=lambda: [32])

    # Algorithm hypers
    policy_gradient_weight: PolicyGradientWeight = PolicyGradientWeight.TRAJECTORY_REWARDS
    policy_gradient_weight_baseline: PolicyGradientWeightBaseline = PolicyGradientWeightBaseline.NONE

    # initial exploration rate
    epsilon: int = 0.15
    # epsilon is decayed over time with a cosine schedule, starting at `epsilon`
    # and ending at `epsilon_min` over `max_steps` steps.
    epsilon_min: float = 0.1  # always explore at least 10% of the time

    def get_epsilon(self, step: int) -> float:
        if step > self.max_steps:
            raise ValueError("step must be less than max_steps")
        return self.epsilon_min + (self.epsilon - self.epsilon_min) * (1 + np.cos(np.pi * step / self.max_steps) ) / 2 


params = Params()

In [5]:
#@title Define the policy
import equinox as eqx
import jax

class MLP(eqx.Module):
    layers: list
    output_proj: eqx.nn.Linear

    def __init__(
        self, in_size: int, hidden_dims: list[int], out_size: int, key: jax.random.PRNGKey):
        self.layers = []
        prev_dim = in_size
        for dim in hidden_dims:
            key, layer_key = jax.random.split(key)
            self.layers.append(
                eqx.nn.Linear(
                    in_features=prev_dim,
                    out_features=dim,
                    use_bias=True,
                    key=layer_key
                )
            )
            prev_dim = dim
        self.output_proj = eqx.nn.Linear(
            in_features=prev_dim,
            out_features=out_size,
            use_bias=True,
            key=key
        )

    def __call__(self, x):
        for layer in self.layers:
            x = layer(x)
            x = jax.nn.relu(x)
        # log probs (log(softmax(x))
        x = self.output_proj(x)
        return x
        # return x - jax.scipy.special.logsumexp(x)

In [32]:
@eqx.filter_jit
def policy(model: MLP, obs: np.ndarray, epsilon: float, key: jax.random.PRNGKey):
    del epsilon  # unused
    # Find best action based on policy model
    logits = jax.vmap(model)(obs)
    # unnormalized log-probabilities are represented by the model output logits -- treat as categorical distribution
    model_actions = tfp.distributions.Categorical(logits=logits).sample(seed=key)

    # # maybe explore based on epsilon-greedy strategy
    # p_explore = jax.random.uniform(key1, (len(obs),))
    # random_actions = jax.random.randint(key2, (len(obs),), minval=0, maxval=2)

    # final_actions = jax.numpy.where(p_explore < epsilon, random_actions, model_actions)
    final_actions = model_actions
    return final_actions, model_actions

@dataclasses.dataclass
class SampleOutput:
    observations: np.ndarray = dataclasses.field(default_factory=lambda: np.array([[]]))
    actions: np.ndarray = dataclasses.field(default_factory=lambda: np.array([]))
    rewards: np.ndarray = dataclasses.field(default_factory=lambda: np.array([]))
    rewards_to_go: np.ndarray = dataclasses.field(default_factory=lambda: np.array([]))
    trajectory_rewards: np.ndarray = dataclasses.field(default_factory=lambda: np.array([]))
    valid: np.ndarray = dataclasses.field(default_factory=lambda: np.array([]))
    avg_expected_reward: float = 0.0
    avg_action: float = 0.0
    avg_episode_length: float = 0.0
    num_episodes: int = 0

    def num_samples(self):
        return len(self.observations)
    
    def __add__(self, other):
        def _weighted_sum(attr):
            return (getattr(self, attr) * self.num_samples() + getattr(other, attr) * other.num_samples()) / (self.num_samples() + other.num_samples())
        return SampleOutput(
            observations=np.concatenate([self.observations, other.observations]),
            actions=np.concatenate([self.actions, other.actions]),
            rewards=np.concatenate([self.rewards, other.rewards]),
            rewards_to_go=np.concatenate([self.rewards_to_go, other.rewards_to_go]),
            trajectory_rewards=np.concatenate([self.trajectory_rewards, other.trajectory_rewards]),
            valid=np.concatenate([self.valid, other.valid]),
            num_episodes=self.num_episodes + other.num_episodes,
            avg_expected_reward=_weighted_sum('avg_expected_reward'),
            avg_action=_weighted_sum('avg_action'),
            avg_episode_length=_weighted_sum('avg_episode_length')
        )

def sample_from_policy(envs: gym.vector.VectorEnv, model: MLP, epsilon: float, key: jax.random.PRNGKey):
    """Run all envs in parallel, return the trajectory of each env."""
    obs, _ = envs.reset()
    episode_over = np.array([False] * envs.num_envs)

    # Track the actions, rewards, and gradients for policy optimization
    trajectory_observations = []
    trajectory_rewards = []
    trajectory_actions = []
    trajectory_valid = []

    # iterate through the vectorized environments
    while np.any(~episode_over):
        trajectory_observations.append(obs)

        key_used, key = jax.random.split(key)
        actions, _ = policy(model, obs, epsilon, key_used)
        obs, rewards, terminated, truncated, _ = envs.step(np.array(actions))
        episode_over = episode_over | terminated | truncated

        # Mask out everything for episodes that are over
        mask = ~episode_over
        trajectory_rewards.append(rewards)
        trajectory_actions.append(actions)
        trajectory_valid.append(mask)

    # [B, T, D]
    trajectory_observations = np.einsum('Tbd->bTd', np.stack(trajectory_observations))
    # [B, T]
    trajectory_actions = np.stack(trajectory_actions).T
    trajectory_rewards = np.stack(trajectory_rewards).T
    trajectory_valid = np.stack(trajectory_valid).T

    # compute the rewards to go [B, T]
    rewards_to_go = np.where(trajectory_valid, trajectory_rewards, 0)[:, ::-1].cumsum(axis=-1)[:, ::-1]
    # Trajectory rewards [B, T]
    rewards_per_trajectory = rewards_to_go[:, :1].repeat(rewards_to_go.shape[-1], axis=-1)

    avg_expected_reward = np.mean(rewards_to_go[:, 0])  # average expected rewards for trajectories
    avg_action = np.mean(trajectory_actions)
    avg_episode_length = np.sum(trajectory_valid, axis=-1).mean()
    num_episodes = len(trajectory_observations)

    # Flatten all the (state, action, reward) tuples across episodes 
    # it doesn't matter which episode they came from (as long as we keep a track of the rewards!)
    trajectory_valid = trajectory_valid.ravel()  # (T*B, )

    return SampleOutput(
        # [B * T, D]
        observations=trajectory_observations.reshape(-1, trajectory_observations.shape[-1])[trajectory_valid],
        # [B * T,]
        actions=trajectory_actions.ravel()[trajectory_valid],
        rewards=trajectory_rewards.ravel()[trajectory_valid],
        rewards_to_go=rewards_to_go.ravel()[trajectory_valid],
        trajectory_rewards=rewards_per_trajectory.ravel()[trajectory_valid],
        valid=trajectory_valid,
        # []
        avg_expected_reward=avg_expected_reward,
        avg_action=avg_action,
        avg_episode_length=avg_episode_length,
        num_episodes=num_episodes
    )

In [20]:
def loss_fn(model, obs, actions, weights):
    logits = jax.vmap(model)(obs)
    log_probs = tfp.distributions.Categorical(logits=logits).log_prob(actions)

    # the loss is technically the log likelihood of the actions weighted by the rewards-to-go
    return -jax.numpy.sum(log_probs * weights) / len(obs)

@eqx.filter_jit
def update_model(model, obs, actions, weights, optimizer, opt_state):
    loss, grad = eqx.filter_value_and_grad(loss_fn)(model, obs, actions, weights)
    updates, opt_state = optimizer.update(grad, opt_state, eqx.filter(model, eqx.is_array))
    model = eqx.apply_updates(model, updates)
    return model, opt_state, loss

In [33]:
@dataclasses.dataclass
class SampleAndUpdatePolicyOutputs:
    model: MLP
    opt_state: optax.OptState
    loss: float
    avg_expected_reward: float
    avg_action: float
    avg_episode_length: float
    num_samples: int


def sample_batch_and_update_policy(
    envs: gym.vector.VectorEnv,
    model: MLP, optimizer: optax.GradientTransformation,
    opt_state: optax.OptState,
    params: Params,
    step: int,
    key: jax.random.PRNGKey):
    # Sample at least "batch_size" time steps (state,action,reward)s from the policy
    samples = None
    while True:
        key_used, key = jax.random.split(key)
        new_samples = sample_from_policy(envs, model, params.get_epsilon(step), key_used)
        if samples is None:
            samples = new_samples
        else:
            samples = samples + new_samples
        if samples.num_samples() >= params.batch_size:
            break
    # print(f"Sampled {samples.num_samples()} samples from {samples.num_episodes} episodes")

    # Compute the weights for the policy gradient
    if params.policy_gradient_weight == PolicyGradientWeight.REWARDS_TO_GO:
        weights = samples.rewards_to_go
    elif params.policy_gradient_weight == PolicyGradientWeight.TRAJECTORY_REWARDS:
        weights = samples.trajectory_rewards
    else:
        raise ValueError
    
    if params.policy_gradient_weight_baseline == PolicyGradientWeightBaseline.AVERAGE_TRAJECTORY_REWARDS:
        weights -= samples.avg_expected_reward

    # Compute the loss and the gradient
    model, opt_state, loss = update_model(
        model=model,
        obs=samples.observations,
        actions=samples.actions,
        weights=weights,
        optimizer=optimizer,
        opt_state=opt_state)
    return SampleAndUpdatePolicyOutputs(
        model=model, opt_state=opt_state,
        loss=loss,
        avg_expected_reward=samples.avg_expected_reward,
        avg_action=samples.avg_action,
        avg_episode_length=samples.avg_episode_length,
        num_samples=samples.num_samples(),
    )


def train(envs, model, optimizer, params, key: jax.random.PRNGKey):
    opt_state = optimizer.init(eqx.filter(model, eqx.is_array))

    for step in range(params.max_steps):
        key_used, key = jax.random.split(key)
        outputs = sample_batch_and_update_policy(
            envs, model, optimizer, opt_state, params, step, key_used)
        model, opt_state = outputs.model, outputs.opt_state
        if step % 1 == 0:
            print(f"Step {step}: Loss {outputs.loss} Avg Reward: {outputs.avg_expected_reward} Avg Episode Len: {outputs.avg_episode_length}")
            print(f"Batch size: {outputs.num_samples} Avg action: {outputs.avg_action}")
            # print(f"New epsilon: {params.get_epsilon(step)}")
    return model, opt_state

In [34]:
params = Params(
    policy_gradient_weight=PolicyGradientWeight.TRAJECTORY_REWARDS,
    policy_gradient_weight_baseline=PolicyGradientWeightBaseline.NONE,
)
# Vectorize the enviroment such that we can sample multiple trajectories concurrently
envs = gym.make_vec("CartPole-v1", num_envs=32, vectorization_mode='vector_entry_point')

optim  = optax.adam(params.learning_rate)
model = MLP(
    in_size=4,
    hidden_dims=params.hidden_dims + [2],
    out_size=2,
    key=jax.random.PRNGKey(12513)
)

model, opt_state = train(envs, model, optim, params, jax.random.PRNGKey(812342))

Step 0: Loss 16.650728225708008 Avg Reward: 19.61086665527455 Avg Episode Len: 19.61086665527455
Batch size: 4389 Avg action: 0.5666989693481507
Step 1: Loss 19.992151260375977 Avg Reward: 22.439758329422805 Avg Episode Len: 22.439758329422805
Batch size: 4262 Avg action: 0.5551792231824751
Step 2: Loss 22.345544815063477 Avg Reward: 21.881024368231046 Avg Episode Len: 21.881024368231046
Batch size: 4155 Avg action: 0.5419483606281343
Step 3: Loss 20.842893600463867 Avg Reward: 23.478734294368408 Avg Episode Len: 23.478734294368408
Batch size: 4457 Avg action: 0.5350342698876588
Step 4: Loss 23.622314453125 Avg Reward: 25.822790189412338 Avg Episode Len: 25.822790189412338
Batch size: 4118 Avg action: 0.5314920562482143
Step 5: Loss 26.265615463256836 Avg Reward: 27.89850113122172 Avg Episode Len: 27.89850113122172
Batch size: 4420 Avg action: 0.5288746647456954
Step 6: Loss 25.197097778320312 Avg Reward: 28.822456538377434 Avg Episode Len: 28.822456538377434
Batch size: 4573 Avg actio

In [30]:
params = Params(
    policy_gradient_weight=PolicyGradientWeight.REWARDS_TO_GO,
    policy_gradient_weight_baseline=PolicyGradientWeightBaseline.NONE,
)
# Vectorize the enviroment such that we can sample multiple trajectories concurrently
envs = gym.make_vec("CartPole-v1", num_envs=32, vectorization_mode='vector_entry_point')

optim  = optax.adam(params.learning_rate)
model = MLP(
    in_size=4,
    hidden_dims=params.hidden_dims + [2],
    out_size=2,
    key=jax.random.PRNGKey(12513)
)

model, opt_state = train(envs, model, optim, params, jax.random.PRNGKey(812342))

Step 0: Loss 9.154587745666504 Avg Reward: 20.24598471422242 Avg Episode Len: 20.24598471422242
Batch size: 4514 Avg action: 0.563946311968656
Step 1: Loss 11.030065536499023 Avg Reward: 23.486860236220473 Avg Episode Len: 23.486860236220473
Batch size: 4445 Avg action: 0.547225669338542
Step 2: Loss 9.436394691467285 Avg Reward: 21.035809798887463 Avg Episode Len: 21.035809798887463
Batch size: 4674 Avg action: 0.5385031539967431
Step 3: Loss 10.236491203308105 Avg Reward: 23.110586851766865 Avg Episode Len: 23.110586851766865
Batch size: 4358 Avg action: 0.5308460870968873
Step 4: Loss 10.814528465270996 Avg Reward: 24.17913465700065 Avg Episode Len: 24.17913465700065
Batch size: 4621 Avg action: 0.5243235797630953
Step 5: Loss 12.96762752532959 Avg Reward: 26.387987884936738 Avg Episode Len: 26.387987884936738
Batch size: 4189 Avg action: 0.5227539700891977
Step 6: Loss 13.924625396728516 Avg Reward: 29.52271029035013 Avg Episode Len: 29.52271029035013
Batch size: 4684 Avg action: 0

In [35]:
params = Params(
    policy_gradient_weight=PolicyGradientWeight.REWARDS_TO_GO,
    policy_gradient_weight_baseline=PolicyGradientWeightBaseline.AVERAGE_TRAJECTORY_REWARDS,
)
# Vectorize the enviroment such that we can sample multiple trajectories concurrently
envs = gym.make_vec("CartPole-v1", num_envs=32, vectorization_mode='vector_entry_point')

optim  = optax.adam(params.learning_rate)
model = MLP(
    in_size=4,
    hidden_dims=params.hidden_dims + [2],
    out_size=2,
    key=jax.random.PRNGKey(12513)
)

model, opt_state = train(envs, model, optim, params, jax.random.PRNGKey(812342))

Step 0: Loss -4.427526950836182 Avg Reward: 20.148605715871255 Avg Episode Len: 20.148605715871255
Batch size: 4505 Avg action: 0.5653456902638425
Step 1: Loss -5.111074447631836 Avg Reward: 22.36881723573433 Avg Episode Len: 22.36881723573433
Batch size: 4276 Avg action: 0.5536621040797353
Step 2: Loss -4.525051593780518 Avg Reward: 22.238398336086853 Avg Episode Len: 22.238398336086853
Batch size: 4237 Avg action: 0.5425636746869297
Step 3: Loss -4.53015661239624 Avg Reward: 23.206105899638338 Avg Episode Len: 23.206105899638338
Batch size: 4424 Avg action: 0.53840505332949
Step 4: Loss -5.857970714569092 Avg Reward: 26.782580835962143 Avg Episode Len: 26.782580835962143
Batch size: 5072 Avg action: 0.5284229532850959
Step 5: Loss -6.337006092071533 Avg Reward: 27.58201694139194 Avg Episode Len: 27.58201694139194
Batch size: 4368 Avg action: 0.527062237653042
Step 6: Loss -5.943201541900635 Avg Reward: 30.148056046125078 Avg Episode Len: 30.148056046125078
Batch size: 4813 Avg action

In [12]:
sorted(os.listdir("./save_videos2"))

['rl-video-episode-0.mp4',
 'rl-video-episode-1.mp4',
 'rl-video-episode-10.mp4',
 'rl-video-episode-11.mp4',
 'rl-video-episode-12.mp4',
 'rl-video-episode-13.mp4',
 'rl-video-episode-14.mp4',
 'rl-video-episode-2.mp4',
 'rl-video-episode-3.mp4',
 'rl-video-episode-4.mp4',
 'rl-video-episode-5.mp4',
 'rl-video-episode-6.mp4',
 'rl-video-episode-7.mp4',
 'rl-video-episode-8.mp4',
 'rl-video-episode-9.mp4']