# Reinforcement Learning with TorchRL (PPO example)

My main reason for joining this competition is to explore RL further, even if it is not the best approach to the problem. In this notebook, I'll try to make the RL version of the problem approachable. I've opted for the TorchRL library, hoping to gain access to common RL infrastructure such as replay buffers and implementations of the most common algorithms.

Unfortunately, I haven't been able to install `torchrl` in the Kaggle environment, so I'm uploading a notebook that was run on another machine.

I'm assuming there is an existing environment with PyTorch already installed.

## Utils for competition data

In [None]:
from ast import literal_eval
import numpy as np
import pandas as pd
from pathlib import Path
from pprint import pprint
from sympy.combinatorics import Permutation


In [None]:
DATA_DIR = Path("/kaggle/input/santa-2023")
PUZZLES = DATA_DIR / "puzzles.csv"
PUZZLES_INFO = DATA_DIR / "puzzle_info.csv"
SAMPLE_SUBMISSION = DATA_DIR / "sample_submission.csv"

In [None]:
def load_puzzles_info():
    puzzle_info = pd.read_csv(PUZZLES_INFO, index_col="puzzle_type")
    puzzle_info["allowed_moves"] = puzzle_info["allowed_moves"].apply(literal_eval)
    return puzzle_info


def load_puzzles():
    puzzles = pd.read_csv(PUZZLES, index_col="id")
    puzzles["id"] = puzzles.index
    # Parse color states
    puzzles = puzzles.assign(
        initial_state=lambda df: df["initial_state"].str.split(";"),
        solution_state=lambda df: df["solution_state"].str.split(";"),
    )
    puzzles.loc[:, "solution_state"] = puzzles["solution_state"].apply(
        lambda x: tuple(x)
    )
    puzzles.loc[:, "initial_state"] = puzzles["initial_state"].apply(lambda x: tuple(x))
    return puzzles


def load_sample_submission():
    sample_submission = pd.read_csv(SAMPLE_SUBMISSION)
    # sample_submission["id"] = sample_submission.index
    return sample_submission


def get_puzzles(
    puzzle_type="cube_2/2/2",
    puzzles=None,
    puzzle_info=None,
    sample_submission=None,
):
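    # `df or default` raises ValueError for DataFrames, so test for None explicitly
    puzzles = load_puzzles() if puzzles is None else puzzles
    puzzle_info = load_puzzles_info() if puzzle_info is None else puzzle_info
    sample_submission = (
        load_sample_submission() if sample_submission is None else sample_submission
    )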
    if puzzle_type:
        puzzle_info = puzzle_info.loc[puzzle_type]
        puzzles = puzzles.query(f"puzzle_type == '{puzzle_type}'")
        sample_submission = sample_submission.loc[puzzles.index]
    return puzzle_info.allowed_moves, puzzles, sample_submission

In [None]:
moves, puzzles, sample_sub = get_puzzles("cube_2/2/2")
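
To make the data concrete: each entry of `moves` maps a move name to a permutation written as a list of source indices, and the environment below applies it as `new_state[i] = state[perm[i]]`. The sketch below uses a made-up 4-element puzzle to show how a move and its inverse (named `-name` in the environment below) relate. The names `TOY_MOVES`, `apply_move` and `invert` are hypothetical and not part of the competition data.

```python
# Toy illustration: a move is a permutation given as a list of source indices.
TOY_MOVES = {"r": [1, 2, 3, 0]}  # hypothetical move, not from puzzle_info.csv


def apply_move(state, perm):
    """Return the state after the move: new_state[i] = state[perm[i]]."""
    return [state[src] for src in perm]


def invert(perm):
    """Inverse permutation, the pure-Python counterpart of an argsort."""
    return sorted(range(len(perm)), key=perm.__getitem__)


state = ["A", "B", "C", "D"]
moved = apply_move(state, TOY_MOVES["r"])
restored = apply_move(moved, invert(TOY_MOVES["r"]))
print(moved)
print(restored)
```

Applying a move followed by its inverse returns the original state, which is why the environment can obtain all backward moves with a single `argsort` over the forward ones.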

## RL environment

TorchRL has a concept of environments, similar to Gymnasium (formerly OpenAI Gym). To implement an environment, we subclass `EnvBase`. We have to define **`observation_spec`**, **`action_spec`** and **`reward_spec`**, which describe the shape of the tensors and the kind of data they hold. We also need to implement the methods **`_reset`**, **`_step`** and **`_set_seed`**.

TorchRL doesn't use plain tensors for input/output, but introduces **TensorDicts**. A TensorDict acts like a dict of tensors that all share a batch size, and many common operations can be applied to the TensorDict as a whole, affecting every tensor within.

For the Santa challenge, we will have an independent environment instance for every puzzle type. Note that for every `puzzle_type` in the data there are several puzzle complexities, for example with a different number of distinct colors. We treat those as different environments here.

Note that it is also possible to implement the environment in Gymnasium and wrap it with TorchRL's `GymWrapper` to make it available in TorchRL. Here we explore the TorchRL-native route.

In [None]:
from typing import Optional
import torch
from torchrl.envs import EnvBase
from tensordict import TensorDict
from torchrl.data import (
    OneHotDiscreteTensorSpec,
    DiscreteTensorSpec,
    CompositeSpec,
    UnboundedContinuousTensorSpec,
)

In [None]:
# Hardcoded initial and solution states, for simplicity
INITIAL_STATE_COLORS = (
    "D", "E", "D", "A",
    "E", "B", "A", "B",
    "C", "A", "C", "A",
    "D", "C", "D", "F",
    "F", "F", "E", "E",
    "B", "F", "B", "C",
)
SOLUTION_STATE_COLORS = (
    "A", "A", "A", "A",
    "B", "B", "B", "B",
    "C", "C", "C", "C",
    "D", "D", "D", "D",
    "E", "E", "E", "E",
    "F", "F", "F", "F",
)

class PuzzleEnv(EnvBase):
    def __init__(
            self,
            state_size: int,
            colors: list,
            moves: dict,
            solution_state: list = SOLUTION_STATE_COLORS,
            **kwargs
    ):
        """
        Args:
            state_size: number of "squares" in the puzzle
            colors: all possible colors in this problem
            moves: moves dict from `puzzle_info.csv` for this puzzle type
            solution_state: final solution state
        """
        super().__init__(**kwargs)
        self.state_size = state_size
        # Consider all forward and backward moves
        self.move_names = list(moves.keys()) + [f"-{name}" for name in moves.keys()]
        moves_fwd = torch.tensor(list(moves.values()))
        moves_bwd = torch.argsort(moves_fwd)
        self.moves = torch.vstack((moves_fwd, moves_bwd))

        self.colors = colors
        self.colormap = {c: i for i, c in enumerate(colors)}
        self.n_colors = len(colors)

        self.solution_state = solution_state

        # TorchRL specs - these need to be initialized
        # actions from puzzle_info
        self.action_spec = OneHotDiscreteTensorSpec(
            2 * len(moves), shape=torch.Size([2 * len(moves)]), dtype=torch.int32
        )
        # Note that observation spec needs to be CompositeSpec.
        observation_spec = DiscreteTensorSpec(
            n=self.n_colors,
            shape=torch.Size([self.state_size]),
            dtype=torch.float,
        )
        self.observation_spec = CompositeSpec(observation=observation_spec, shape=None)
        # Rewards.
        self.reward_spec = UnboundedContinuousTensorSpec(shape=torch.Size([1]))

    def _reset(self, tensordict: Optional[TensorDict], **kwargs):
        """Basic implementation, sets hardcoded INITIAL_STATE.
        Also sets "num_wildcards", which we carry as the unofficial part of the state.
        """
        if tensordict is not None:
            return tensordict.clone()
        # initial_state = [self.colormap[color] for color in INITIAL_STATE_COLORS]
        initial_state = self._state_to_tensor(INITIAL_STATE_COLORS)
        out_tensordict = TensorDict({}, batch_size=torch.Size())
        self.state = initial_state
        out_tensordict.set("observation", self.state)
        out_tensordict.set("num_wildcards", torch.tensor(0))
        return out_tensordict

    def _step(self, tensordict):
        """
        Args:
            tensordict: We expect an "action" field matching action_spec,
                as well as "num_wildcards".

        """
        solution_state = self._state_to_tensor(self.solution_state)
        num_wildcards = tensordict["num_wildcards"]
        actions_onehot = tensordict["action"]
        actions_idx = actions_onehot.argmax(dim=-1).tolist()
        # Permutations to apply, one row per batch
        moves = self.moves[actions_idx]
        self.state = torch.gather(self.state, -1, moves)
        errors = torch.sum(~(self.state == solution_state), dim=-1)
        final = ~(torch.relu(errors - num_wildcards).bool())
        # rewards are 0 in final state, -1 otherwise
        rewards = final * 1.0 - 1.0
        out_tensordict = TensorDict(
            {
                "observation": self.state,
                "reward": rewards,
                "done": final,
            },
            batch_size=self.batch_size,
        )
        return out_tensordict

    def _set_seed(self, seed):
        """No need to do stuff here.

        super().seed() already sets the torch seed, and then calls this.
        """
        pass

    def _state_to_tensor(self, state):
        return torch.tensor([self.colormap[color] for color in state]).float()
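
The reward rule in `_step` is easy to check by hand. The following is a minimal pure-Python sketch of the same logic, without TorchRL: a state counts as solved when the number of mismatched positions does not exceed `num_wildcards`, and the reward is 0 for a solved state and -1 otherwise. The function name `reward_and_done` and the 4-element states are made up for illustration.

```python
def reward_and_done(state, solution, num_wildcards=0):
    """Mirror the rule used in `_step`: solved if mismatches <= num_wildcards."""
    errors = sum(s != t for s, t in zip(state, solution))
    done = errors <= num_wildcards
    reward = 0.0 if done else -1.0
    return reward, done


solution = ["A", "A", "B", "B"]
print(reward_and_done(["A", "A", "B", "B"], solution))     # solved
print(reward_and_done(["A", "B", "A", "B"], solution))     # two mismatches
print(reward_and_done(["A", "B", "A", "B"], solution, 2))  # within wildcards
```

Because every non-final step costs -1, maximizing the return is the same as reaching the solved state in as few moves as possible.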


Let's create an instance of our environment.

In [None]:
env = PuzzleEnv(
    # cube_2/2/2 has 24 "squares"
    state_size=24,
    colors=["A", "B", "C", "D", "E", "F"],
    moves=moves,
)

We can use `fake_tensordict()` to create a tensordict with all zeros, but valid according to all specs for this environment.

This is a good sanity check that our specs are OK.

In [None]:
env.fake_tensordict()

We can call `rollout` on the environment. If we don't provide a policy, it will use a random policy by default. 

In [None]:
env.rollout(3)

## Transforms

We can add transforms to our environment to make training easier or to add extra information. In Gymnasium this is achieved with wrappers, but TorchRL takes an approach similar to other PyTorch libraries and offers the `TransformedEnv` class, which takes an environment and applies transforms to it.

We can normalize observation inputs by applying the `ObservationNorm` transform. Note that this transform needs to be initialized to work correctly, so we call `init_stats`, which runs some iterations in the background for us.

We can also apply casting operations, such as `DoubleToFloat`, if our model expects tensors of a different type than the environment offers.

The `StepCounter` transform counts the steps taken before the environment terminates. We can use this as a supplementary measure of performance.

In [None]:
from torchrl.envs import (
    Compose,
    DoubleToFloat,
    ObservationNorm,
    StepCounter,
    TransformedEnv,
)

In [None]:
transformed_env = TransformedEnv(
    env,
    Compose(
        # normalize observations
        ObservationNorm(in_keys=["observation"]),
        # Our observations are already float.
        #DoubleToFloat(
        #    in_keys=["observation"],
        #),
        StepCounter(),
    ),
)

# Initialize ObservationNorm
transformed_env.transform[0].init_stats(num_iter=1000, reduce_dim=0, cat_dim=0)
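
Conceptually, `init_stats` gathers observations from rollouts and derives a location and a scale per feature, which the transform then uses to standardize each observation. The sketch below shows that idea in plain Python on made-up data. It illustrates the statistics involved and is not TorchRL's implementation; `fit_stats` and `normalize` are hypothetical helpers.

```python
from statistics import fmean, pstdev


def fit_stats(observations, eps=1e-6):
    """Per-feature mean and standard deviation over a list of observations."""
    columns = list(zip(*observations))
    loc = [fmean(col) for col in columns]
    # eps guards against a zero scale for features that never change
    scale = [max(pstdev(col), eps) for col in columns]
    return loc, scale


def normalize(obs, loc, scale):
    """Standardize one observation feature by feature."""
    return [(x - m) / s for x, m, s in zip(obs, loc, scale)]


# Made-up observations: color indices at three positions
samples = [[0.0, 5.0, 2.0], [2.0, 3.0, 2.0], [4.0, 1.0, 2.0]]
loc, scale = fit_stats(samples)
print(normalize(samples[0], loc, scale))
```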

## Models

PPO is an actor-critic algorithm, so we need both an actor model that directly outputs the policy action and a critic model that is used to guide the actor.

In [None]:
from tensordict.nn import TensorDictModule
from tensordict.nn.distributions import NormalParamExtractor
from torch import nn
from torchrl.modules import ProbabilisticActor, TanhNormal, ValueOperator

Here we'll add the actor model and TorchRL's `ProbabilisticActor`, which can be used as a policy within environments. `ProbabilisticActor` samples actions from a distribution, so for every action dimension our model needs to output the parameters of a probability distribution over potential values. This can be achieved by having the model output `2 * action_spec.shape[-1]` values, followed by a `NormalParamExtractor` module. `NormalParamExtractor` splits them into a location and a scale, which are stored under the keys "loc" and "scale" expected by `ProbabilisticActor`.

Note that our `nn.Module` is further wrapped in a `TensorDictModule`, which carries the expected `in_keys` and `out_keys`.

In [None]:
actor_net = nn.Sequential(
    nn.Linear(24, 128), nn.Tanh(),
    nn.Linear(128, 128), nn.Tanh(),
    nn.Linear(128, 128), nn.Tanh(),
    nn.LazyLinear(2 * env.action_spec.shape[-1]),
    NormalParamExtractor(),  # extract "loc" and "scale"
)
policy_module = TensorDictModule(
    actor_net, in_keys=["observation"], out_keys=["loc", "scale"]
)

actor = ProbabilisticActor(
    policy_module,
    spec=env.action_spec,
    in_keys=["loc", "scale"],
    distribution_class=TanhNormal,
    distribution_kwargs={"min": 0.0, "max": 1.0},
    return_log_prob=True
)
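
To see what the last two layers do: the final linear layer emits twice as many numbers as there are actions, and the extractor splits them into a location half and a scale half, mapping the scale to a positive value. A plain-Python sketch follows. Using softplus as the positive mapping is an assumption made for illustration (the library's default mapping may differ in detail), and `split_loc_scale` is a hypothetical name.

```python
import math


def softplus(x):
    """Smooth positive mapping log(1 + exp(x)), computed stably."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def split_loc_scale(raw):
    """Split a flat network output into (loc, scale) halves."""
    half = len(raw) // 2
    loc = raw[:half]
    scale = [softplus(v) for v in raw[half:]]
    return loc, scale


# Made-up network output for a 3-action problem (2 * 3 values)
loc, scale = split_loc_scale([0.2, -1.0, 0.5, 0.0, 2.0, -3.0])
print(loc, scale)
```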

Similarly, for the critic model we will use TorchRL's `ValueOperator`, which outputs a scalar: our estimate of the state value, i.e. the expected discounted return from that state.

In [None]:
critic = ValueOperator(
    nn.Sequential(
        # alternatively use nn.LazyLinear
        nn.Linear(24, 128), nn.Tanh(),
        nn.Linear(128, 128), nn.Tanh(),
        nn.Linear(128, 128), nn.Tanh(),
        nn.Linear(128, 1),
    ),
    in_keys=["observation"],
)
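
With the reward defined in the environment above (-1 for every step that does not reach the solved state, 0 for the step that does), the value a critic should ideally learn for a state that is `n` moves from the goal under a perfect policy is a short geometric sum. A small sketch, where `ideal_value` is a hypothetical helper and `gamma=0.99` matches the discount used later in this notebook:

```python
def ideal_value(n_moves, gamma=0.99):
    """Discounted return when the goal is reached in exactly n_moves steps.

    The first n_moves - 1 steps each earn -1; the final step earns 0.
    """
    return -sum(gamma**k for k in range(n_moves - 1))


for n in (1, 2, 3, 10):
    print(n, round(ideal_value(n), 4))
```

States closer to the goal therefore have values closer to 0, which is the signal the critic passes on to the actor.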

## Data collector

TorchRL provides a set of `DataCollector` classes. They can be used to collect data using our environment and the current iteration of the model. For simplicity, we will use `SyncDataCollector`.

We need to set `total_frames` and `frames_per_batch`. "Frame" is a term that might be more appropriate in video game simulations, but it is effectively one observation that our agent is exposed to. `total_frames` is the total number of observations that will be collected, and `frames_per_batch` is the number of frames given to the agent for each round of training updates.


In [None]:
from torchrl.collectors import SyncDataCollector

In [None]:
total_frames = 10000
frames_per_batch = 100

In [None]:
collector = SyncDataCollector(
    transformed_env,
    actor,
    frames_per_batch=frames_per_batch,
    total_frames=total_frames,
)

In [None]:
# We can test that our collector works with our env and actor like this:
collector.rollout()
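
The relationship between the two settings is simple arithmetic: with the values above, the collector yields `total_frames // frames_per_batch` batches, and each batch triggers one round of training updates. A toy stand-in for the collector makes the counts explicit. The generator `toy_collector` is hypothetical and only yields frame indices, not real transitions.

```python
def toy_collector(total_frames, frames_per_batch):
    """Yield consecutive batches of frame indices until total_frames is reached."""
    for start in range(0, total_frames, frames_per_batch):
        yield list(range(start, min(start + frames_per_batch, total_frames)))


batches = list(toy_collector(total_frames=10000, frames_per_batch=100))
print(len(batches), len(batches[0]))
```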

## Replay buffers

TorchRL provides ready implementations of replay buffers, with a choice of storage mechanisms and sampling strategies.

For an on-policy algorithm such as PPO, we can size the buffer's storage to `frames_per_batch`. The buffer is then refilled every time a new batch of data is collected.

In [None]:
from torchrl.data.replay_buffers import TensorDictReplayBuffer, \
    LazyTensorStorage, SamplerWithoutReplacement

In [None]:
# In on-policy contexts, a replay buffer is refilled every time a batch of data is collected,
# and its data is repeatedly consumed for a certain number of epochs.
buffer = TensorDictReplayBuffer(
    storage=LazyTensorStorage(frames_per_batch),
    sampler=SamplerWithoutReplacement()
)
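
Sampling without replacement means that, within one pass over the buffer, every stored frame lands in exactly one mini-batch. A plain-Python sketch of that behaviour follows. The function `minibatches` is hypothetical, and the sizes mirror `frames_per_batch = 100` with mini-batches of 25, the sub-batch size used in the training loop.

```python
import random


def minibatches(n_items, batch_size, rng):
    """Shuffle indices once, then cut them into consecutive mini-batches."""
    indices = list(range(n_items))
    rng.shuffle(indices)
    return [indices[i:i + batch_size] for i in range(0, n_items, batch_size)]


rng = random.Random(0)
batches = minibatches(100, 25, rng)
seen = sorted(i for batch in batches for i in batch)
print(len(batches), seen == list(range(100)))
```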

## Loss function

The PPO loss can be directly imported from torchrl for convenience using the `ClipPPOLoss` class.

PPO requires some “advantage estimation” to be computed. In short, an advantage is a value that reflects an expectancy over the return value while dealing with the bias / variance tradeoff. To compute the advantage, one just needs to (1) build the advantage module, which utilizes our value operator (critic), and (2) pass each batch of data through it before each epoch. For this advantage function, we will use the GAE module. It will update the input `TensorDict` with new `"advantage"` and `"value_target"` entries. The `"value_target"` is a gradient-free tensor that represents the empirical value that the value network should represent with the input observation. Both of these will be used by `ClipPPOLoss` to return the policy and value losses.
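
As a rough illustration of what the GAE module computes, here is a plain-Python sketch for a single trajectory. It assumes the usual recursion, with TD error `delta_t = r_t + gamma * V(s_{t+1}) * (1 - done_t) - V(s_t)` and advantage `A_t = delta_t + gamma * lmbda * (1 - done_t) * A_{t+1}`. The function `gae` and the sample numbers are made up, and TorchRL's module additionally handles batching and, optionally, advantage standardization.

```python
def gae(rewards, values, next_values, dones, gamma=0.99, lmbda=0.95):
    """Return (advantages, value_targets) for one trajectory."""
    advantages = [0.0] * len(rewards)
    running = 0.0
    for t in reversed(range(len(rewards))):
        not_done = 0.0 if dones[t] else 1.0
        delta = rewards[t] + gamma * next_values[t] * not_done - values[t]
        running = delta + gamma * lmbda * not_done * running
        advantages[t] = running
    value_targets = [a + v for a, v in zip(advantages, values)]
    return advantages, value_targets


# Made-up 3-step episode that reaches the solved state on the last step
rewards = [-1.0, -1.0, 0.0]
values = [-1.5, -0.8, 0.0]      # critic estimates V(s_t)
next_values = [-0.8, 0.0, 0.0]  # critic estimates V(s_{t+1})
dones = [False, False, True]
adv, targets = gae(rewards, values, next_values, dones)
print(adv, targets)
```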

In [None]:
from torchrl.objectives import ClipPPOLoss
from torchrl.objectives.value import GAE

In [None]:
max_grad_norm = 1.0

# PPO params
sub_batch_size = 25  # size of the mini-batches sampled from the current batch in the inner loop
num_epochs = 10  # number of passes over each collected batch
clip_epsilon = 0.2  # clip range for the PPO surrogate objective
gamma = 0.99
lmbda = 0.95
entropy_eps = 1e-4

In [None]:
adv_fn = GAE(
    value_network=critic,
    gamma=gamma,
    lmbda=lmbda,
    average_gae=True,
)

loss_fn = ClipPPOLoss(
    actor,
    critic,
    gamma=gamma,
    clip_epsilon=clip_epsilon,
    entropy_bonus=bool(entropy_eps),
    entropy_coef=entropy_eps,
    # these keys match by default but we set this for completeness
    value_target_key=adv_fn.value_target_key,
    critic_coef=1.0,
    loss_critic_type="smooth_l1",
)

# Also define the optimizer and the learning-rate scheduler.
optim = torch.optim.Adam(loss_fn.parameters(), lr=2e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optim, total_frames // frames_per_batch, 0.0
)
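
For reference, the clipped objective that `clip_epsilon` controls can be written for a single sample as `min(ratio * A, clip(ratio, 1 - eps, 1 + eps) * A)`, where `ratio` is the probability of the action under the current policy divided by its probability under the policy that collected the data. A minimal sketch follows. The helper `clipped_objective` is hypothetical; the loss minimized in practice is the negative of this quantity, averaged over the mini-batch.

```python
import math


def clipped_objective(log_prob_new, log_prob_old, advantage, clip_epsilon=0.2):
    """PPO clipped surrogate for one sample (to be maximized)."""
    ratio = math.exp(log_prob_new - log_prob_old)
    clipped_ratio = min(max(ratio, 1.0 - clip_epsilon), 1.0 + clip_epsilon)
    return min(ratio * advantage, clipped_ratio * advantage)


# Positive advantage: gains are capped once the ratio exceeds 1 + eps
gain = clipped_objective(math.log(1.5), 0.0, advantage=1.0)
# Negative advantage: the unclipped (more pessimistic) term is kept
penalty = clipped_objective(math.log(1.5), 0.0, advantage=-1.0)
print(gain, penalty)
```

The clipping removes the incentive to move the policy far from the one that collected the data within a single batch.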

## Training loop

We now have all the pieces needed to code our training loop.



In [None]:
from collections import defaultdict
from tqdm.notebook import tqdm

from torchrl.envs.utils import ExplorationType, set_exploration_type

In [None]:

logs = defaultdict(list)
pbar = tqdm(total=total_frames)
eval_str = ""

# We iterate over the collector until it reaches the total number of frames it was
# designed to collect:
for i, tensordict_data in enumerate(collector):
    # we now have a batch of data to work with. Let's learn something from it.
    for _ in range(num_epochs):
        # We'll need an "advantage" signal to make PPO work.
        # We re-compute it at each epoch as its value depends on the value
        # network which is updated in the inner loop.
        with torch.no_grad():
            # Updates tensordict in-place with "advantage" and "value_target" fields.
            adv_fn(tensordict_data)

        data_view = tensordict_data.reshape(-1)
        buffer.extend(data_view.cpu())

        for _ in range(frames_per_batch // sub_batch_size):
            sample = buffer.sample(sub_batch_size)  # mini-batch
            loss_vals = loss_fn(sample)
            loss_val = (
                loss_vals["loss_objective"]
                + loss_vals["loss_critic"]
                + loss_vals["loss_entropy"]
            )

            # Optimization: backward, grad clipping and optim step
            loss_val.backward()
            # this is not strictly mandatory but it's good practice to keep
            # your gradient norm bounded
            torch.nn.utils.clip_grad_norm_(loss_fn.parameters(), max_grad_norm)
            optim.step()
            optim.zero_grad()

    print(f"avg reward: {tensordict_data['next', 'reward'].mean().item(): 4.4f}")
    logs["reward"].append(tensordict_data["next", "reward"].mean().item())
    pbar.update(tensordict_data.numel())
    cum_reward_str = (
        f"average reward={logs['reward'][-1]: 4.4f} (init={logs['reward'][0]: 4.4f})"
    )
    # logs["step_count"].append(tensordict_data["step_count"].max().item())
    # stepcount_str = f"step count (max): {logs['step_count'][-1]}"
    logs["lr"].append(optim.param_groups[0]["lr"])
    lr_str = f"lr policy: {logs['lr'][-1]: 4.4f}"
    if i % 10 == 0:
        # We evaluate the policy once every 10 batches of data.
        # Evaluation is rather simple: execute the policy without exploration
        # (take the expected value of the action distribution) for at most
        # 500 steps.
        # The ``rollout`` method of the env can take a policy as argument:
        # it will then execute this policy at each step.
        with set_exploration_type(ExplorationType.MEAN), torch.no_grad():
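            # execute a rollout with the trained policy on the transformed env,
            # so the actor sees the same normalized observations as in training
            eval_rollout = transformed_env.rollout(500, actor)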
            logs["eval reward"].append(eval_rollout["next", "reward"].mean().item())
            logs["eval reward (sum)"].append(
                eval_rollout["next", "reward"].sum().item()
            )
            # logs["eval step_count"].append(eval_rollout["step_count"].max().item())
            eval_str = (
                f"eval cumulative reward: {logs['eval reward (sum)'][-1]: 4.4f} "
                f"(init: {logs['eval reward (sum)'][0]: 4.4f}), "
                # f"eval step-count: {logs['eval step_count'][-1]}"
            )
            del eval_rollout
    pbar.set_description(", ".join([eval_str, cum_reward_str, lr_str]))

    # We're also using a learning rate scheduler. Like the gradient clipping,
    # this is a nice-to-have but nothing necessary for PPO to work.
    scheduler.step()
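
The scheduler used here follows the cosine annealing formula `lr_t = eta_min + (lr_0 - eta_min) * (1 + cos(pi * t / T_max)) / 2`, stepped once per collected batch. A plain-Python sketch with this notebook's settings (`lr_0 = 2e-4`, `T_max = total_frames // frames_per_batch = 100`, `eta_min = 0`) follows; `cosine_lr` is a hypothetical helper, not part of PyTorch.

```python
import math


def cosine_lr(step, base_lr=2e-4, t_max=100, eta_min=0.0):
    """Closed-form cosine annealing learning rate after `step` scheduler steps."""
    return eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * step / t_max)) / 2


for step in (0, 50, 100):
    print(step, f"{cosine_lr(step):.2e}")
```

The learning rate starts at its full value, passes through half of it midway, and decays to `eta_min` by the last batch.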



## Acknowledgements

* This guide is heavily based on TorchRL's PPO guide: https://pytorch.org/rl/tutorials/coding_ppo.html
* Maxime Szymanski's notebook: https://www.kaggle.com/code/maximeszymanski/ppo-deep-reinforcement-learning