## Creating an environment

TorchRL does not provide environments of its own. Instead, it offers wrappers around the libraries that implement the simulators. The `torchrl.envs` module can be seen both as a generic environment API and as a central hub for simulation backends such as gym, which is the one we'll use here. Creating an environment is typically as straightforward as the underlying backend's API allows.

```python
from torchrl.envs import GymEnv, set_gym_backend

# Build the environment using the "gym" backend.
with set_gym_backend("gym"):
    env = GymEnv("Pendulum-v1")
```
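Conceptually, a wrapper's job is to translate whatever the backend returns into named entries. The dependency-free sketch below assumes a gym-style `step()` that returns `(observation, reward, terminated, truncated, info)`, which is the convention of gym 0.26+ and gymnasium. The helper `wrap_step_output` is hypothetical. It only mirrors the entry names that TorchRL environments expose, including `done`, which is set when an episode either terminates or is truncated.

```python
# Hypothetical helper, not part of TorchRL: maps a gym-style step() tuple
# onto the named entries that TorchRL environments expose.
def wrap_step_output(step_output):
    observation, reward, terminated, truncated, _info = step_output
    return {
        "observation": observation,
        "reward": reward,
        "terminated": terminated,
        "truncated": truncated,
        # An episode is "done" if it ended naturally or was cut short.
        "done": terminated or truncated,
    }


fields = wrap_step_output(([0.1, 0.9, -0.3], -1.2, False, True, {}))
print(fields["done"])  # True, because the episode was truncated
```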

## Running an environment

Environments in TorchRL have two crucial methods: `torchrl.envs.EnvBase.reset`, which starts an episode, and `torchrl.envs.EnvBase.step`, which executes an action selected by the actor.

In TorchRL, environment methods read and write `tensordict.TensorDict` instances. Essentially, `tensordict.TensorDict` is a generic key-based data carrier for tensors.

The benefit of TensorDicts is that they let us handle simple and complex data structures in the same way. Because environment methods take and return a single TensorDict, their signatures stay generic and do not need to accommodate different data formats. In simpler terms, the same simple, user-facing API works for both simple and highly complex environments.
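To make the "key-based data carrier" idea concrete, here is a dependency-free sketch using plain nested dictionaries. The `get_nested` helper is hypothetical. It imitates how a TensorDict resolves a tuple key such as `("next", "reward")` by descending one nesting level per element.

```python
# Hypothetical stand-in for TensorDict's nested-key lookup, using plain dicts.
def get_nested(data, key):
    """Resolve a string key or a tuple of keys, one nesting level per element."""
    if isinstance(key, str):
        key = (key,)
    for part in key:
        data = data[part]
    return data


data = {
    "observation": [0.1, 0.9, -0.3],
    "next": {"observation": [0.2, 0.8, -0.1], "reward": -1.2},
}

print(get_nested(data, "observation"))  # [0.1, 0.9, -0.3]
print(get_nested(data, ("next", "reward")))  # -1.2
```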

Let's put the environment into action and see what a tensordict instance looks like:

```python
reset = env.reset()
print(reset)
```

```
TensorDict(
    fields={
        done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
        observation: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
        terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
        truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
```


Now let's take a random action in the action space. First, sample the action:

```python
reset_with_action = env.rand_action(reset)
print(reset_with_action)
```

```
TensorDict(
    fields={
        action: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),
        done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
        observation: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
        terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
        truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
```


This tensordict has the same structure as the one obtained from `torchrl.envs.EnvBase.reset`, with an additional `"action"` entry. You can access the action just as you would with a regular dictionary:

```python
print(reset_with_action["action"])
```

```
tensor([1.7638])
```
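For Pendulum-v1, the action is a single torque bounded to [-2, 2], so the value printed above falls within that range. For a bounded continuous space, a random action amounts to a uniform sample between the bounds. The sketch below uses only the standard library. `sample_box_action` and the hard-coded bounds are illustrative assumptions, not TorchRL's implementation. In TorchRL, the bounds come from the environment's `env.action_spec`.

```python
import random

# Hypothetical sketch: uniform sampling inside a box-shaped action space.
# The bounds below match Pendulum-v1's torque (one dimension, [-2, 2]).
LOW, HIGH = [-2.0], [2.0]


def sample_box_action(low, high, rng=random):
    """Draw one value per dimension, uniformly between its lower and upper bound."""
    return [rng.uniform(lo, hi) for lo, hi in zip(low, high)]


data = {"observation": [0.1, 0.9, -0.3]}
data["action"] = sample_box_action(LOW, HIGH)  # add an "action" entry next to the others
print(data["action"])
```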


We now need to pass this action to the environment. We pass the entire tensordict to the `step` method, because in more advanced cases, such as multi-agent RL or stateless environments, the environment may need to read more than one tensor:

```python
stepped_data = env.step(reset_with_action)
print(stepped_data)
```

```
TensorDict(
    fields={
        action: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),
        done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
        next: TensorDict(
            fields={
                done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
                observation: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
                reward: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),
                terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
                truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)},
            batch_size=torch.Size([]),
            device=None,
            is_shared=False),
        observation: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
        terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
        truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
```
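The layout above can be mimicked with plain dictionaries: the step results go under `"next"`, and the inputs stay untouched at the root. In this hypothetical sketch, `toy_step` and its dynamics are made up. Only the placement of the entries reflects what `env.step` returns.

```python
# Hypothetical toy step: the dynamics are invented, only the data layout matters.
def toy_step(data):
    action = data["action"][0]
    observation = [x + 0.1 * action for x in data["observation"]]
    data["next"] = {
        "observation": observation,
        "reward": [-abs(action)],
        "terminated": [False],
        "truncated": [False],
        "done": [False],
    }
    return data  # the root entries (current observation, action) are kept as-is


stepped = toy_step({"observation": [0.0, 1.0, 0.0], "action": [1.0], "done": [False]})
print(sorted(stepped))  # ['action', 'done', 'next', 'observation']
print(stepped["next"]["reward"])  # [-1.0]
```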

The last piece of information you need to run a rollout is how to bring that `"next"` entry to the root so that you can perform the next step. TorchRL provides a dedicated `torchrl.envs.utils.step_mdp` function that does just that: it filters out the information you won't need and returns a data structure corresponding to your observation after one step in the Markov Decision Process (MDP):

```python
from torchrl.envs import step_mdp

data = step_mdp(stepped_data)
print(data)
```

```
TensorDict(
    fields={
        done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
        observation: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
        terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
        truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
```
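Here, `step_mdp` promoted the `"next"` entries to the root and dropped those that belong to the previous transition, namely `"action"` and `"reward"`. The dependency-free sketch below reproduces that behaviour for this particular example only. `toy_step_mdp` is hypothetical. The real function accepts options controlling which entries are kept, and this sketch ignores them.

```python
# Hypothetical, simplified stand-in for step_mdp on plain nested dicts.
def toy_step_mdp(data, drop=("action", "reward")):
    # Start from root entries that are neither "next" nor listed in `drop`...
    new_root = {k: v for k, v in data.items() if k != "next" and k not in drop}
    # ...then overwrite them with the post-step values, minus the dropped keys.
    new_root.update({k: v for k, v in data["next"].items() if k not in drop})
    return new_root


stepped = {
    "action": [1.0],
    "observation": [0.0, 1.0, 0.0],
    "done": [False],
    "next": {"observation": [0.1, 1.1, 0.1], "reward": [-1.0], "done": [False]},
}
data = toy_step_mdp(stepped)
print(sorted(data))  # ['done', 'observation']
print(data["observation"])  # [0.1, 1.1, 0.1]
```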


Writing out those three steps (computing an action, making a step, moving in the MDP) by hand can be tedious and repetitive. Fortunately, TorchRL provides a `torchrl.envs.EnvBase.rollout` method that runs them in a closed loop. When no policy is passed, it takes random actions:

```python
rollout = env.rollout(max_steps=10)
print(rollout)
```

```
TensorDict(
    fields={
        action: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.float32, is_shared=False),
        done: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        next: TensorDict(
            fields={
                done: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                observation: Tensor(shape=torch.Size([10, 3]), device=cpu, dtype=torch.float32, is_shared=False),
                reward: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                terminated: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                truncated: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
            batch_size=torch.Size([10]),
            device=None,
            is_shared=False),
        observation: Tensor(shape=torch.Size([10, 3]), device=cpu, dtype=torch.float32, is_shared=False),
        terminated: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        truncated: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
    batch_size=torch.Size([10]),
    device=None,
    is_shared=False)
```
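Under the hood, a rollout is simply the reset/act/step/advance loop written out once. The sketch below runs that loop on a made-up, dependency-free environment. `ToyEnv`, its dynamics, and its time limit are hypothetical. The loop collects one record per step and stops early once `done` is set, which is also what TorchRL's `rollout` does by default.

```python
import random


class ToyEnv:
    """Hypothetical environment: random-walk observation, done after `limit` steps."""

    def __init__(self, limit=4):
        self.limit = limit

    def reset(self):
        self.t = 0
        return {"observation": 0.0, "done": False}

    def step(self, data):
        self.t += 1
        data["next"] = {
            "observation": data["observation"] + data["action"],
            "reward": -abs(data["action"]),
            "done": self.t >= self.limit,
        }
        return data


def toy_rollout(env, max_steps, rng=random):
    records = []
    data = env.reset()
    for _ in range(max_steps):
        data["action"] = rng.uniform(-2.0, 2.0)  # random policy
        data = env.step(data)
        records.append(data)
        if data["next"]["done"]:
            break  # stop at the end of the episode
        # Move to the next time step: promote "next" to the root, drop the reward.
        data = {k: v for k, v in data["next"].items() if k != "reward"}
    return records


trajectory = toy_rollout(ToyEnv(limit=4), max_steps=10)
print(len(trajectory))  # 4: the episode ended before max_steps was reached
```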

This data looks much like `stepped_data` above, except for its batch size, which now equals the number of steps we requested through the `max_steps` argument. The magic of tensordict doesn't end there: if you're interested in a single transition of this environment, you can index the tensordict as you would index a tensor:

```python
transition = rollout[3]
print(transition)
```

```
TensorDict(
    fields={
        action: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),
        done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
        next: TensorDict(
            fields={
                done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
                observation: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
                reward: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),
                terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
                truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)},
            batch_size=torch.Size([]),
            device=None,
            is_shared=False),
        observation: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
        terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
        truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
```
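Indexing works because the rollout is stored column-wise. Every entry holds one row per time step, and `rollout[3]` picks row 3 (the fourth transition) from each entry, recursing into nested entries such as `"next"`. Here is a dependency-free sketch of that idea, using a hypothetical `index_columns` helper and made-up values:

```python
# Hypothetical helper: select time step `i` from every (possibly nested) column.
def index_columns(columns, i):
    return {
        key: index_columns(value, i) if isinstance(value, dict) else value[i]
        for key, value in columns.items()
    }


rollout = {
    "action": [[0.5], [-1.0], [1.5], [0.2]],
    "observation": [[0.0], [0.5], [-0.5], [1.0]],
    "next": {
        "observation": [[0.5], [-0.5], [1.0], [1.2]],
        "reward": [[-0.5], [-1.0], [-1.5], [-0.2]],
    },
}

transition = index_columns(rollout, 3)
print(transition)
# {'action': [0.2], 'observation': [1.0], 'next': {'observation': [1.2], 'reward': [-0.2]}}
```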

Most of the time, you'll want to modify the output of the environment to better suit your requirements. For example, you might want to track the number of steps executed since the last reset, resize images, or stack consecutive observations together. In this section, we'll examine a simple transform, `torchrl.envs.transforms.StepCounter`, which is attached to the environment through a `torchrl.envs.transforms.TransformedEnv`:

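```python
from torchrl.envs import StepCounter, TransformedEnv

# Count the steps since the last reset and truncate episodes after 10 steps.
transformed_env = TransformedEnv(env, StepCounter(max_steps=10))
rollout = transformed_env.rollout(max_steps=100)
print(rollout)
```

```
TensorDict(
    fields={
        action: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.float32, is_shared=False),
        done: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        next: TensorDict(
            fields={
                done: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                observation: Tensor(shape=torch.Size([10, 3]), device=cpu, dtype=torch.float32, is_shared=False),
                reward: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                step_count: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.int64, is_shared=False),
                terminated: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                truncated: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
            batch_size=torch.Size([10]),
            device=None,
            is_shared=False),
        observation: Tensor(shape=torch.Size([10, 3]), device=cpu, dtype=torch.float32, is_shared=False),
        step_count: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.int64, is_shared=False),
        terminated: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        truncated: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
    batch_size=torch.Size([10]),
    device=None,
    is_shared=False)
```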
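The bookkeeping behind a step counter is small. Reset the count to zero, increment it after each step, and raise `truncated` (and therefore `done`) once a limit is reached. The hypothetical, dependency-free sketch below applies that logic to plain dictionaries. It illustrates the idea only and is not TorchRL's `StepCounter` implementation.

```python
# Hypothetical sketch of step-counting logic on plain dicts.
class ToyStepCounter:
    def __init__(self, max_steps=None):
        self.max_steps = max_steps

    def on_reset(self, data):
        data["step_count"] = 0
        return data

    def on_step(self, data):
        nxt = data["next"]
        nxt["step_count"] = data["step_count"] + 1
        if self.max_steps is not None and nxt["step_count"] >= self.max_steps:
            nxt["truncated"] = True
        # Keep "done" consistent with "terminated" and "truncated".
        nxt["done"] = nxt.get("terminated", False) or nxt.get("truncated", False)
        return data


counter = ToyStepCounter(max_steps=3)
data = counter.on_reset({"observation": 0.0})
flags = []
for _ in range(3):
    data["next"] = {"observation": 0.0, "terminated": False, "truncated": False}
    data = counter.on_step(data)
    flags.append(data["next"]["truncated"])
    data = dict(data["next"])  # move to the next time step
print(flags)  # [False, False, True]
```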

As you can see, our environment now has one more entry, `"step_count"`, which tracks the number of steps since the last reset. We passed the optional argument `max_steps=10` to the transform constructor, so the trajectory was truncated after 10 steps. Because `rollout` stops as soon as the episode is done, we did not get the full 100 steps requested in the `rollout` call. We can confirm the truncation by looking at the `"truncated"` entry:

```python
print(rollout["next", "truncated"])
```

```
tensor([[False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [ True]])
```
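Because `truncated` holds one flag per step, the length of the first episode in a rollout is the position of the first flagged step plus one. Here is a dependency-free sketch that uses a Python list mirroring the printed tensor (`episode_length` is a hypothetical helper):

```python
# Hypothetical helper: length of the first episode, given per-step [flag] rows.
def episode_length(flags):
    for i, (flag,) in enumerate(flags):
        if flag:
            return i + 1
    return len(flags)  # no flag set: the episode did not end within the rollout


truncated = [[False] for _ in range(9)] + [[True]]  # mirrors rollout["next", "truncated"]
print(episode_length(truncated))  # 10
```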
