# RL4CO Decoding Strategies Notebook

This notebook demonstrates how to use the decoding strategies available in `rl4co/models/nn/dec_strategies.py` during the different phases of model development (training, validation, and testing). We will also show how to evaluate a trained model on the test dataset with different decoding strategies.

<a href="https://colab.research.google.com/github/ai4co/rl4co/blob/main/notebooks/tutorials/7-decoding-strategies.ipynb"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"></a>


### Installation

In [None]:
## Uncomment the following line to install the package from PyPI
## You may need to restart the runtime in Colab after this
## Remember to choose a GPU runtime for faster training!

# !pip install rl4co

In [None]:
import torch

from rl4co.envs import TSPEnv
from rl4co.models.zoo import AttentionModel, AttentionModelPolicy
from rl4co.utils.trainer import RL4COTrainer

### Set up Policy and Environment

In [None]:
%%capture
# RL4CO env based on TorchRL
env = TSPEnv(num_loc=10) 

# Policy: neural network, in this case with encoder-decoder architecture.
# The decode type used in each phase (train/val/test) is stored on the policy.
policy = AttentionModelPolicy(env.name,
                              embedding_dim=128,
                              num_encoder_layers=3,
                              num_heads=8,
                              train_decode_type="sampling",
                              val_decode_type="greedy",
                              test_decode_type="beam_search",
                              )

# Model: default is AM with REINFORCE and greedy rollout baseline.
# We pass the policy defined above; `policy_kwargs` (which can also carry the
# decode types) is only used when the model builds its own default policy.
model = AttentionModel(env,
                       policy=policy,
                       baseline="rollout",
                       batch_size=128,
                       val_batch_size=512,
                       test_batch_size=512,
                       train_data_size=100,
                       val_data_size=1_000,
                       test_data_size=1_000,
                       optimizer_kwargs={"lr": 1e-4},
                       )
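The decode types differ in how the next node is chosen from the policy's probability distribution at each step. As a framework-free illustration (plain Python with made-up probabilities; the helper names `greedy_step` and `sampling_step` are hypothetical and this is not rl4co's implementation), greedy decoding takes the most probable node, while sampling draws a node in proportion to its probability:

```python
import random


def greedy_step(probs):
    """Greedy decoding: return the index with the highest probability."""
    return max(range(len(probs)), key=lambda i: probs[i])


def sampling_step(probs, rng):
    """Sampling decoding: draw an index with probability proportional to probs[i]."""
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]


# Made-up probabilities over three candidate nodes at a single decoding step
probs = [0.1, 0.6, 0.3]
rng = random.Random(0)  # seeded for reproducibility

print("greedy pick:", greedy_step(probs))
draws = [sampling_step(probs, rng) for _ in range(1000)]
# Empirical frequencies approach `probs` as the number of draws grows
print("sampling frequencies:", [draws.count(i) / len(draws) for i in range(len(probs))])
```

Sampling keeps training rollouts stochastic, which the REINFORCE gradient estimate relies on, while greedy decoding is deterministic and therefore convenient for validation.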

### Set up Trainer and train model

In [None]:
trainer = RL4COTrainer(
    max_epochs=2,
    logger=None,
)

trainer.fit(model)

### Test the model using the Trainer class

In [None]:
# here we evaluate the model on the test set with the beam search decoding strategy, as declared when setting up the model
trainer.test(model=model)

In [None]:
# we can simply change the decoding type of the current model instance
model.policy.test_decode_type = "greedy"
trainer.test(model=model)
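Changing `test_decode_type` works because the decode type is looked up per phase. The mechanism can be pictured with the hypothetical mock below (`ToyPolicy` and `resolve_decode_type` are illustrative names, not rl4co's code): the policy keeps one attribute per phase and reads the one matching the current phase, unless a `decode_type` is passed explicitly to the call.

```python
class ToyPolicy:
    """Hypothetical mock of a policy that stores one decode type per phase."""

    def __init__(self, train_decode_type="sampling", val_decode_type="greedy",
                 test_decode_type="greedy"):
        self.train_decode_type = train_decode_type
        self.val_decode_type = val_decode_type
        self.test_decode_type = test_decode_type

    def resolve_decode_type(self, phase, decode_type=None):
        """Return the explicit decode_type if given, else the one stored for `phase`."""
        if decode_type is not None:
            return decode_type
        return getattr(self, f"{phase}_decode_type")


toy = ToyPolicy(test_decode_type="beam_search")
print(toy.resolve_decode_type("test"))  # beam_search
toy.test_decode_type = "greedy"  # analogous to model.policy.test_decode_type = "greedy"
print(toy.resolve_decode_type("test"))  # greedy
print(toy.resolve_decode_type("test", decode_type="beam_search"))  # beam_search
```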

### Manual Test Loop

Let's compare beam search with greedy and multistart greedy decoding by manually looping over our test dataset:

In [None]:
model.eval()  # evaluation mode, e.g. normalization layers use (and do not update) their running statistics
bs_rewards = []
for batch in model.test_dataloader():
    td = env.reset(batch)
    with torch.no_grad():
        # in a manual loop we can dynamically specify the decode type
        out = model(td, decode_type="beam_search", beam_width=10)
    bs_rewards.append(out["reward"])
print("Average reward (beam search): %.4f" % torch.cat(bs_rewards).mean().item())
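Beam search keeps the `beam_width` best partial solutions at every step instead of only the single best one. The plain-Python sketch below illustrates the procedure on a tiny, made-up TSP instance. As an assumption for illustration, it ranks partial tours by their length so far, whereas rl4co's beam search ranks beams by the policy's log-probabilities; all function names are hypothetical. With `beam_width=1` it reduces to greedy (here: nearest-neighbour) decoding, and a brute-force search gives the optimum for reference.

```python
import itertools
import math


def tour_length(coords, tour):
    """Length of the closed tour that visits `tour` in order and returns to its start."""
    return sum(math.dist(coords[tour[i]], coords[tour[(i + 1) % len(tour)]])
               for i in range(len(tour)))


def beam_search_tsp(coords, beam_width, start=0):
    """Toy beam search: keep the `beam_width` shortest partial tours at every step."""
    n = len(coords)
    beams = [([start], 0.0)]  # (partial tour, length of its open path)
    for _ in range(n - 1):
        candidates = []
        for tour, length in beams:
            for node in range(n):
                if node not in tour:
                    step = math.dist(coords[tour[-1]], coords[node])
                    candidates.append((tour + [node], length + step))
        # Prune: keep only the best `beam_width` expansions (shortest first)
        candidates.sort(key=lambda c: c[1])
        beams = candidates[:beam_width]
    # Among the surviving complete tours, return the shortest closed tour
    return min((tour for tour, _ in beams), key=lambda t: tour_length(coords, t))


def brute_force_tsp(coords, start=0):
    """Exact optimum by enumerating every tour that begins at `start` (tiny n only)."""
    others = [i for i in range(len(coords)) if i != start]
    return min(([start, *p] for p in itertools.permutations(others)),
               key=lambda t: tour_length(coords, t))


coords = [(0.0, 0.0), (0.0, 1.0), (1.0, 1.0), (1.0, 0.0), (0.5, 0.2), (0.2, 0.6)]
for name, tour in [("greedy (beam_width=1)", beam_search_tsp(coords, beam_width=1)),
                   ("beam search (beam_width=3)", beam_search_tsp(coords, beam_width=3)),
                   ("brute-force optimum", brute_force_tsp(coords))]:
    print(f"{name}: {tour} length={tour_length(coords, tour):.3f}")
```

With a large enough width (for example `math.factorial(n - 1)`) nothing is ever pruned and the search becomes exhaustive; smaller widths prune more aggressively and are cheaper to run.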

In [None]:
greedy_rewards = []
for batch in model.test_dataloader():
    td = env.reset(batch)
    with torch.no_grad():
        out = model(td, decode_type="greedy")
    greedy_rewards.append(out["reward"])
print("Average reward (greedy): %.4f" % torch.cat(greedy_rewards).mean().item())

In [None]:
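ms_rewards = []
for batch in model.test_dataloader():
    td = env.reset(batch)
    bs = batch.batch_size[0]
    with torch.no_grad():
        # one greedy rollout per start node (num_starts=10 covers all 10 TSP nodes);
        # the outputs hold bs * num_starts entries, ordered start by start
        out = model(td, decode_type="multistart_greedy", num_starts=10)
        # reshape to [bs, num_starts] and keep the best rollout of each instance
        rewards = torch.stack(out["reward"].split(bs), 1).max(1).values
    ms_rewards.append(rewards)
print("Average reward (multistart greedy): %.4f" % torch.cat(ms_rewards).mean().item())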
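Multistart greedy forces each rollout to begin at a different node and keeps the best result per instance, which is what the `split`/`stack`/`max` lines compute. The plain-Python sketch below shows the same idea on a tiny, made-up instance, using "move to the nearest unvisited node" as a stand-in for the policy's greedy choice (an assumption for illustration; the helper names are hypothetical).

```python
import math


def tour_length(coords, tour):
    """Length of the closed tour that visits `tour` in order and returns to its start."""
    return sum(math.dist(coords[tour[i]], coords[tour[(i + 1) % len(tour)]])
               for i in range(len(tour)))


def greedy_tour(coords, start):
    """Toy greedy rollout from a fixed start node: always move to the nearest unvisited node."""
    tour = [start]
    unvisited = set(range(len(coords))) - {start}
    while unvisited:
        nxt = min(unvisited, key=lambda j: math.dist(coords[tour[-1]], coords[j]))
        tour.append(nxt)
        unvisited.remove(nxt)
    return tour


def multistart_greedy(coords, num_starts):
    """Run one greedy rollout per start node and keep the shortest tour."""
    tours = [greedy_tour(coords, start) for start in range(num_starts)]
    return min(tours, key=lambda t: tour_length(coords, t))


coords = [(0.0, 0.0), (0.0, 1.0), (1.0, 1.0), (1.0, 0.0), (0.5, 0.2), (0.2, 0.6)]
single = greedy_tour(coords, start=0)
best = multistart_greedy(coords, num_starts=len(coords))
print(f"single start: {single} length={tour_length(coords, single):.3f}")
print(f"multistart:   {best} length={tour_length(coords, best):.3f}")
```

Because the TSP reward is the negative tour length, the `.max(1)` over rewards in the cell above plays the role of the `min` over tour lengths here.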

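Comparing the printed averages, beam search typically finds better solutions than the greedy decoder. Since the TSP reward is the negative tour length, a higher (less negative) average reward means shorter tours.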

### Digging deeper into beam search solutions

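We can also analyze the individual solutions found by beam search by passing `select_best=False` to the forward pass of the policy. In this case the solutions are ordered beam by beam: the first solution of every instance in the batch comes first, followed by the second solution of every instance, and so on: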

- instance1_solution1
- instance2_solution1
- instance3_solution1
- instance1_solution2
- instance2_solution2
- instance3_solution2
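This ordering can be checked without a model. The plain-Python snippet below builds the flat list above for 3 instances and 2 beams and regroups it by instance, mirroring what `torch.stack(x.split(bs), 1)` does on tensors (the variable names are illustrative):

```python
n_instances, n_beams = 3, 2

# Flat output with select_best=False: all first solutions, then all second solutions
flat = [f"instance{i + 1}_solution{b + 1}" for b in range(n_beams) for i in range(n_instances)]

# Split into n_beams chunks of n_instances entries (like x.split(bs)),
# then transpose so that row i holds every solution of instance i (like stacking on dim 1)
chunks = [flat[k:k + n_instances] for k in range(0, len(flat), n_instances)]
stacked = [list(row) for row in zip(*chunks)]
for row in stacked:
    print(row)
```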

In [None]:
# reuse the last `batch` left over from the manual test loops above
td = env.reset(batch)

In [None]:
bs = batch.batch_size[0]

In [None]:
with torch.no_grad():
    out = model(td, decode_type="beam_search", beam_width=5, select_best=False, return_actions=True)

In [None]:
# out["actions"] and out["reward"] hold bs * beam_width entries ordered beam by beam.
# Split them into beam_width chunks of bs instances and stack along dim 1,
# giving tensors of shape [bs, beam_width, ...] indexed by (instance, beam).
actions_stacked = torch.stack(out["actions"].split(bs), 1)
rewards_stacked = torch.stack(out["reward"].split(bs), 1)

In [None]:
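import matplotlib.pyplot as plt

batch_instance = 0  # index of the instance within the batch to visualize
for i, actions in enumerate(actions_stacked[batch_instance].cpu()):
    reward = rewards_stacked[batch_instance, i]
    _, ax = plt.subplots()

    # render the same instance that these actions belong to
    env.render(td[batch_instance], actions, ax=ax)
    ax.set_title("Beam %d, reward: %.4f" % (i, reward.item()))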
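If a single solution per instance is needed from these stacked outputs, one option is to keep the highest-reward beam of each instance; on the tensors above that would be `rewards_stacked.argmax(1)`. This is shown only as an illustration and is not necessarily how `select_best=True` is implemented. The plain-Python sketch below applies the same per-instance argmax to made-up rewards (negative tour lengths) shaped `[batch_size, beam_width]`:

```python
# Made-up rewards (negative tour lengths): 2 instances x 3 beams
toy_rewards = [
    [-2.9, -3.1, -3.0],
    [-3.4, -3.2, -3.6],
]

# Index of the best (highest-reward) beam for each instance, and its reward
best_beam = [max(range(len(row)), key=row.__getitem__) for row in toy_rewards]
best_reward = [row[b] for row, b in zip(toy_rewards, best_beam)]
print(best_beam)    # [0, 1]
print(best_reward)  # [-2.9, -3.2]
```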