# Working with unbatchable environments

While all the environments in RL4CO are batched, it is not uncommon to find CO problems that are hard (or impossible) to write in a *batchable* way.

To overcome this issue, one can code the environment in an *unbatched* way, meaning that the environment logic is responsible for `reset`, `step`, reward computation and rendering of a **single** problem instance at a time.
Several copies of this unbatched environment can then be run in parallel to speed up data collection.

Fortunately, this can be easily done using TorchRL's features, and this tutorial will show you how.

## Unbatched TSP
To keep things simple, we use a well-known environment: the Travelling Salesman Problem (TSP).

We start by importing the needed packages. This step is similar to the one for batched environments.

In [1]:
from typing import Optional

import torch

from tensordict.tensordict import TensorDict
from torchrl.data import (
    BoundedTensorSpec,
    CompositeSpec,
    UnboundedContinuousTensorSpec,
    UnboundedDiscreteTensorSpec,
)

from rl4co.envs.common.base import RL4COEnvBase
from rl4co.utils.ops import gather_by_index, get_tour_length
from rl4co.utils.pylogger import get_pylogger

from rl4co.envs.routing.tsp.generator import TSPGenerator
from rl4co.envs.routing.tsp.render import render

log = get_pylogger(__name__)

Now let's define the environment class, which inherits from `RL4COEnvBase`. The definition is very similar to the batched one, but all the tensors have batch size 1.

Since we want to parallelize the environment using TorchRL, it is very important to define the environment specs.

Another important point: since the `reset` method of the env calls the `generator` object with `self.batch_size`, we need to set this attribute to `[1]`.
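
To see why the batch size matters here, the following sketch uses a hypothetical stand-in for the generator. `sample_locs` is not part of RL4CO, and uniform sampling in the unit square is an assumption made for illustration. It only shows how the requested batch size becomes the leading dimension of the sampled coordinates.

```python
import torch


def sample_locs(batch_size, num_loc=20):
    """Hypothetical stand-in for a TSP generator: sample city coordinates
    uniformly in the unit square, one set of cities per batch element."""
    return torch.rand(*batch_size, num_loc, 2)


# With batch_size=[1] we get a single instance with a leading dimension of 1,
# which is the layout the unbatched environment works with.
locs = sample_locs([1])
print(locs.shape)  # torch.Size([1, 20, 2])
```

With a batch size of `[1]`, every tensor in the environment describes exactly one problem instance.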

In [2]:
class UnbatchedTSPEnv(RL4COEnvBase):
    """Traveling Salesman Problem (TSP) environment
    At each step, the agent chooses a city to visit. The reward is 0 unless the agent visits all the cities.
    In that case, the reward is (-)length of the path: maximizing the reward is equivalent to minimizing the path length.

    Observations:
        - locations of each customer.
        - the current location of the vehicle.

    Constraints:
        - the tour must return to the starting customer.
        - each customer must be visited exactly once.

    Finish condition:
        - the agent has visited all customers and returned to the starting customer.

    Reward:
        - the negative length of the path.

    Args:
        generator: TSPGenerator instance as the data generator
        generator_params: parameters for the generator
    """

    name = "tsp"

    def __init__(
        self,
        generator: TSPGenerator = None,
        generator_params: dict = {},
        **kwargs,
    ):
        super().__init__(**kwargs)
        if generator is None:
            generator = TSPGenerator(**generator_params)
        self.generator = generator
        self.batch_size = [1] # needed for the reset method, that calls the generator using the self.batch_size
        self._make_spec(self.generator)

    @staticmethod
    def _step(td: TensorDict) -> TensorDict:
        current_node = td["action"]
        first_node = current_node if td["i"].all() == 0 else td["first_node"]

        # Set not visited to 0 (i.e., we visited the node)
        available = td["action_mask"].scatter(
            -1, current_node.unsqueeze(-1).expand_as(td["action_mask"]), 0
        )

        # We are done when there are no unvisited locations
        done = torch.sum(available, dim=-1) == 0

        # The reward is calculated outside via get_reward for efficiency, so we set it to 0 here
        reward = torch.zeros_like(done)

        td.update(
            {
                "first_node": first_node,
                "current_node": current_node,
                "i": td["i"] + 1,
                "action_mask": available,
                "reward": reward,
                "done": done,
            },
        )
        return td

    def _reset(self, td: Optional[TensorDict] = None, batch_size=None) -> TensorDict:
        # Initialize locations
        device = td.device
        init_locs = td["locs"]

        # We do not enforce loading from self for flexibility
        num_loc = init_locs.shape[-2]

        # Other variables
        current_node = torch.zeros((1,), dtype=torch.int64, device=device)
        available = torch.ones(
            (1, num_loc), dtype=torch.bool, device=device
        )  # 1 means not visited, i.e. action is allowed
        i = torch.zeros((1, 1), dtype=torch.int64, device=device)

        return TensorDict(
            {
                "locs": init_locs,
                "first_node": current_node,
                "current_node": current_node,
                "i": i,
                "action_mask": available,
                "reward": torch.zeros((1, 1), dtype=torch.float32),
            },
            batch_size=1,
        )

    def _make_spec(self, generator: TSPGenerator):
        self.observation_spec = CompositeSpec(
            locs=BoundedTensorSpec(
                low=generator.min_loc,
                high=generator.max_loc,
                shape=(1, generator.num_loc, 2),
                dtype=torch.float32,
            ),
            first_node=UnboundedDiscreteTensorSpec(
                shape=(1,),
                dtype=torch.int64,
            ),
            current_node=UnboundedDiscreteTensorSpec(
                shape=(1,),
                dtype=torch.int64,
            ),
            i=UnboundedDiscreteTensorSpec(
                shape=(1,),
                dtype=torch.int64,
            ),
            action_mask=UnboundedDiscreteTensorSpec(
                shape=(1, generator.num_loc),
                dtype=torch.bool,
            ),
            shape=(1,),
        )
        self.action_spec = BoundedTensorSpec(
            shape=(1,),
            dtype=torch.int64,
            low=0,
            high=generator.num_loc,
        )
        self.reward_spec = UnboundedContinuousTensorSpec(shape=(1,))
        self.done_spec = UnboundedDiscreteTensorSpec(shape=(1,), dtype=torch.bool)

    def _get_reward(self, td, actions) -> TensorDict:
        if self.check_solution:
            self.check_solution_validity(td, actions)

        # Gather locations in order of tour and return distance between them (i.e., -reward)
        locs_ordered = gather_by_index(td["locs"], actions)
        return -get_tour_length(locs_ordered)

    @staticmethod
    def check_solution_validity(td: TensorDict, actions: torch.Tensor):
        """Check that solution is valid: nodes are visited exactly once"""
        assert (
            torch.arange(actions.size(1), out=actions.data.new())
            .view(1, -1)
            .expand_as(actions)
            == actions.data.sort(1)[0]
        ).all(), "Invalid tour"

    @staticmethod
    def render(td: TensorDict, actions: torch.Tensor = None, ax=None):
        return render(td, actions, ax)
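
The mask update in `_step` can be tried in isolation. The sketch below uses plain PyTorch on a made-up 5-city instance (the number of cities and the chosen action are illustrative assumptions) and repeats the same `scatter` pattern and `done` check as the method above.

```python
import torch

# One instance with 5 cities; True means "not visited yet".
action_mask = torch.ones((1, 5), dtype=torch.bool)
action = torch.tensor([2])  # visit city 2

# Same pattern as in `_step`: write 0 (False) at the index of the chosen city.
index = action.unsqueeze(-1).expand_as(action_mask)
available = action_mask.scatter(-1, index, 0)
print(available)  # city 2 is now False, all other cities stay True

# The episode is done once no city is left to visit.
done = torch.sum(available, dim=-1) == 0
print(done)  # still False after a single step
```

Note that `scatter` is out-of-place, so the original `action_mask` is left untouched and the updated mask has to be written back into the TensorDict, as `_step` does.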

Let's check that our environment works properly by testing the `reset` and `step` methods.

In [3]:
tsp_env = UnbatchedTSPEnv()
td = tsp_env.reset()

print("Reset td:\n", td)

td["action"] = torch.tensor([0])
td = tsp_env.step(td)
print("\nStep td:\n", td)

Reset td:
 TensorDict(
    fields={
        action_mask: Tensor(shape=torch.Size([1, 20]), device=cpu, dtype=torch.bool, is_shared=False),
        current_node: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.int64, is_shared=False),
        done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
        first_node: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.int64, is_shared=False),
        i: Tensor(shape=torch.Size([1, 1]), device=cpu, dtype=torch.int64, is_shared=False),
        locs: Tensor(shape=torch.Size([1, 20, 2]), device=cpu, dtype=torch.float32, is_shared=False),
        reward: Tensor(shape=torch.Size([1, 1]), device=cpu, dtype=torch.float32, is_shared=False),
        terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)},
    batch_size=torch.Size([1]),
    device=None,
    is_shared=False)

Step td:
 {'next': TensorDict(
    fields={
        action: Tensor(shape=torch.Size([1]), device=cpu, dt

Let's also check that the action was correctly performed by checking the `action_mask`.

In [None]:
print("Action mask:\n", td["next"]["action_mask"])

As we can see, the first node was correctly masked out. As a final check, let's see what happens if we visit all the nodes.

In [None]:
td = td["next"]
num_locs = td["locs"].shape[-2]
for i in range(1, num_locs): # we have already visited the first node
    td["action"] = torch.tensor([i])
    td = tsp_env.step(td)
    td = td["next"]

print("Done:\n", td["done"])

We can see that the episode reached status `done` as we would expect. Let's check if the reward can be computed.

In [None]:
actions = torch.arange(num_locs, device=td.device).unsqueeze(0)
reward = tsp_env.get_reward(td, actions)
print(reward)
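
To make the reward more tangible, here is a minimal sketch of a closed-tour length computation in plain PyTorch. `closed_tour_length` is a hypothetical helper written for this example; it is not RL4CO's `get_tour_length` or `gather_by_index`, only an illustration of the idea described in `_get_reward`. The four cities on a unit square are made-up data.

```python
import torch


def closed_tour_length(ordered_locs):
    """Length of the closed tour through `ordered_locs` ([batch, n, 2]),
    including the edge from the last city back to the first."""
    next_locs = torch.roll(ordered_locs, shifts=-1, dims=-2)
    return (next_locs - ordered_locs).norm(p=2, dim=-1).sum(-1)


# Four cities on the corners of a unit square, visited in index order.
locs = torch.tensor([[[0.0, 0.0], [1.0, 0.0], [1.0, 1.0], [0.0, 1.0]]])
actions = torch.arange(4).unsqueeze(0)  # tour 0 -> 1 -> 2 -> 3 -> 0

# Reorder the cities according to the tour, then negate the length.
index = actions.unsqueeze(-1).expand(-1, -1, 2)
ordered = locs.gather(1, index)
reward = -closed_tour_length(ordered)
print(reward)  # tensor([-4.])
```

The tour walks the perimeter of the square, so its length is 4 and the reward is -4.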

Ok, the environment is working properly. Let's make it parallel.

## Parallelizing the environment

To parallelize the environment, we will use the `ParallelEnv` from TorchRL.

In [None]:
from torchrl.envs.batched_envs import ParallelEnv

First of all, following TorchRL best practices, we check if we defined the environment properly.

In [7]:
from torchrl.envs import check_env_specs
check_env_specs(tsp_env)

AttributeError: 'dict' object has no attribute '_get_str'

For clarity, we only use 2 parallel envs here, but you can set a higher number if your hardware allows it.

The `ParallelEnv` class requires 2 arguments: the number of parallel environments and a callable that returns an environment. We will provide a *lambda function* that creates it.
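
The reason for passing a callable rather than an environment instance is that every worker has to build its own independent copy. The sketch below mimics that pattern with the standard library only. `ToyEnv`, `make_env` and `step_all` are hypothetical names, and a thread pool stands in for the separate worker processes that TorchRL uses, so this is a conceptual illustration and not how `ParallelEnv` is implemented.

```python
from concurrent.futures import ThreadPoolExecutor


class ToyEnv:
    """Hypothetical single-instance environment: visit each of n cities once."""

    def __init__(self, num_loc=4):
        self.mask = [True] * num_loc  # True means "not visited yet"

    def step(self, action):
        self.mask[action] = False
        return {"action_mask": list(self.mask), "done": not any(self.mask)}


def make_env():
    # The factory is called once per worker, so every worker owns its own env.
    return ToyEnv()


def step_all(envs, actions, pool):
    # Step every env with its own action; results come back in env order.
    return list(pool.map(lambda pair: pair[0].step(pair[1]), zip(envs, actions)))


num_envs = 2
envs = [make_env() for _ in range(num_envs)]
with ThreadPoolExecutor(max_workers=num_envs) as pool:
    results = step_all(envs, [0, 1], pool)

# One result per environment: this list plays the role of the new batch dimension.
print([r["action_mask"] for r in results])
# [[False, True, True, True], [True, False, True, True]]
```

Each environment only ever sees its own action and state; the batching happens outside, when the per-environment results are collected.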

In [None]:
envs = ParallelEnv(2, lambda: UnbatchedTSPEnv())

Let's perform the same checks we did on the unbatched environment. We start by checking that the `reset` and `step` methods work properly.

In [None]:
td = envs.reset()
print("Reset td:\n", td)

td["action"] = torch.zeros((2, 1), dtype=torch.int64)
td = envs.step(td)
print("\nStep td:\n", td)

Notice that now the batch size is 2 and that all the shapes are correct.
We proceed with the action mask check.

In [None]:
print("Action mask:\n", td["next"]["action_mask"])

Ok, in both environments the first entry in the mask is correctly set.
Finally, let's see if we can run an entire episode and compute the rewards.

In [None]:
td = td["next"]
num_locs = td["locs"].shape[-2]
for i in range(1, num_locs): # we have already visited the first node
    td["action"] = torch.full((2, 1), i, dtype=torch.int64)  # one action per env, shape (2, 1)
    td = envs.step(td)  # step the parallel envs, not the single env
    td = td["next"]

print("Done:\n", td["done"])

In [None]:
actions = torch.arange(num_locs, device=td.device).repeat(2, 1)
reward = tsp_env.get_reward(td, actions)
print(reward)