In [12]:
from typing import Optional

import gym
import torch as th
import torch.nn
from tensordict import TensorDictBase, TensorDict
from torchrl.envs import EnvBase, step_mdp, check_env_specs
from torchrl.data import UnboundedContinuousTensorSpec, CompositeSpec, BoundedTensorSpec, BinaryDiscreteTensorSpec

"""
# Info for CartPole-v1 test
    | Num | Action                 |
    |-----|------------------------|
    | 0   | Push cart to the left  |
    | 1   | Push cart to the right |

    | Num | Observation           | Min                 | Max               |
    |-----|-----------------------|---------------------|-------------------|
    | 0   | Cart Position         | -4.8                | 4.8               |
    | 1   | Cart Velocity         | -Inf                | Inf               |
    | 2   | Pole Angle            | ~ -0.418 rad (-24°) | ~ 0.418 rad (24°) |
    | 3   | Pole Angular Velocity | -Inf                | Inf               |
"""


class CartPoleWrapper(EnvBase):
    def __init__(self, env):
        super().__init__(device="cpu")
        self.env = env  # The underlying environment
        # Define the observation and action specs according to the wrapped environment
        self.observation_spec = CompositeSpec({"observation": BoundedTensorSpec(
            low=th.tensor([-4.8, -th.inf, -0.418, -th.inf], dtype=th.float64),
            high=th.tensor([4.8, th.inf, 0.418, th.inf], dtype=th.float64))})
        self.action_spec = CompositeSpec({"action": BinaryDiscreteTensorSpec(1)})
        self.reward_spec = CompositeSpec({"reward": BoundedTensorSpec(th.tensor([0]), th.inf)})

    def _step(self, tensordict: TensorDictBase) -> TensorDictBase:
        # Extract the action from the tensordict
        action = tensordict.get("action").detach().cpu().numpy()
        # Step the underlying environment with the extracted action
        obs, reward, term, trunc, info = self.env.step(int(action))
        done = term
        # print(f"obs is {obs}")
        # Create a TensorDict to return, including the new observation, reward, and done flag
        out = TensorDict({
            "done": th.tensor([done], dtype=th.bool),
            "observation": th.tensor(obs, dtype=th.float32),
            "reward": th.tensor([reward], dtype=th.float32),
        }, batch_size=[])
        return out

    def _reset(self, tensordict: Optional[TensorDictBase] = None, **kwargs) -> TensorDictBase:
        # Reset the underlying environment and get the initial observation
        obs = self.env.reset()[0]
        # Create a TensorDict for the initial state
        out = TensorDict({
            "observation": th.tensor(obs, dtype=th.float32),
            "action": th.tensor([0], dtype=th.float32),  # Placeholder action
            "done": th.tensor([False], dtype=th.bool),
        }, batch_size=[])
        return out

    def _set_seed(self, seed: Optional[int]):  # for reproduction of same results
        pass
        # self.env.seed(seed)  # Assuming the underlying env has a seed method

Testing the env

In [13]:
gym_env = gym.make('CartPole-v1')
env = CartPoleWrapper(gym_env)

reset_td = env.reset()
print(reset_td.items())

step_td = env.step(reset_td)
# rollout_td = env.rollout(3)
# print(rollout_td)
# print('finished rollout')
# reset_with_action = env.rand_action(reset_td)
# print(reset_with_action["action"])
# 
# data = step_mdp(step_td)
# print(data)

check_env_specs(env)

  obs, reward, term, trunc, info = self.env.step(int(action))
2024-03-26 18:29:39,726 [torchrl][INFO] check_env_specs succeeded!


dict_items([('observation', tensor([-0.0322, -0.0097,  0.0372, -0.0105])), ('action', tensor([0.])), ('done', tensor([False])), ('terminated', tensor([False]))])


Now the transformation and training

In [17]:
from torchrl.envs import StepCounter, TransformedEnv

transformed_env = TransformedEnv(env, StepCounter(max_steps=10))  # after 10 steps in truncated the env
rollout = transformed_env.rollout(max_steps=100)
print(rollout)

TensorDict(
    fields={
        action: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.int8, is_shared=False),
        done: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        next: TensorDict(
            fields={
                done: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                observation: Tensor(shape=torch.Size([10, 4]), device=cpu, dtype=torch.float32, is_shared=False),
                reward: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                step_count: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.int64, is_shared=False),
                terminated: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                truncated: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
            batch_size=torch.Size([10]),
            device=None,
          

  obs, reward, term, trunc, info = self.env.step(int(action))


In [18]:
print(rollout["next", "truncated"])

tensor([[False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [ True]])


### Now working with Modules

In [30]:
from tensordict.nn import TensorDictModule
module = th.nn.LazyLinear(env.action_spec.shape[-1])
policy = TensorDictModule(
    module, 
    in_keys=["observation"],
    out_keys=["actions"],
)

# todo why do I get a number for 8 to 10 when i run this? does it mean that it loses at 8 steps?
rollout = env.rollout(max_steps=10, policy=policy) 
rollout

from torchrl.modules import Actor

"""
    Args:
        module (nn.Module): a :class:`~torch.nn.Module` used to map the input to
            the output parameter space.
        in_keys (iterable of str, optional): keys to be read from input
            tensordict and passed to the module. If it
            contains more than one element, the values will be passed in the
            order given by the in_keys iterable.
            Defaults to ``["observation"]``.
        out_keys (iterable of str): keys to be written to the input tensordict.
            The length of out_keys must match the
            number of tensors returned by the embedded module. Using ``"_"`` as a
            key avoid writing tensor to output.
            Defaults to ``["action"]``.
"""

policy = Actor(module)
rollout = env.rollout(max_steps=10, policy=policy)
print(rollout)
from torchrl.modules import MLP

module = MLP(
    out_features=env.action_spec.shape[-1],
    num_cells=[32, 64],
    activation_class=th.nn.Tanh,
)
policy = Actor(module)
rollout = env.rollout(max_steps=10, policy=policy)
rollout
from tensordict.nn.distributions import NormalParamExtractor
from torch.distributions import Normal
from torchrl.modules import ProbabilisticActor

backbone = MLP(in_features=4, out_features=2)
extractor = NormalParamExtractor()
module = th.nn.Sequential(backbone, extractor)
td_module = TensorDictModule(module, in_keys=["observation"], out_keys=["loc", "scale"])
policy = ProbabilisticActor(
    td_module,
    in_keys=["loc", "scale"],
    out_keys=["action"],
    distribution_class=Normal,
    return_log_prob=True
)
rollout = env.rollout(max_steps=10, policy=policy)
rollout

TensorDict(
    fields={
        action: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.float32, is_shared=False),
        actions: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.float32, is_shared=False),
        done: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        next: TensorDict(
            fields={
                done: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                observation: Tensor(shape=torch.Size([10, 4]), device=cpu, dtype=torch.float32, is_shared=False),
                reward: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                terminated: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
            batch_size=torch.Size([10]),
            device=None,
            is_shared=False),
        observation: Tensor(shape=torch.Size([10, 4]), device=cpu, dtype=torch.float32, is_shar

  obs, reward, term, trunc, info = self.env.step(int(action))


In [None]:
# Now doing Module testing


module = th.nn.LazyLinear(env.action_spec.shape[-1])
policy = TensorDictModule(
    module,
    in_keys=["observation"],
    out_keys=["actions"],
)

rollout = env.rollout(max_steps=10, policy=policy)
print(rollout)

policy = Actor(module)
rollout = env.rollout(max_steps=10, policy=policy)
print(rollout)

module = MLP(
    out_features=env.action_spec.shape[-1],
    num_cells=[32, 64],
    activation_class=th.nn.Tanh,
)
policy = Actor(module)
rollout = env.rollout(max_steps=10, policy=policy)
print(rollout)

backbone = MLP(in_features=4, out_features=2)
extractor = NormalParamExtractor()
module = th.nn.Sequential(backbone, extractor)
td_module = TensorDictModule(module, in_keys=["observation"], out_keys=["loc", "scale"])
policy = ProbabilisticActor(
    td_module,
    in_keys=["loc", "scale"],
    out_keys=["action"],
    distribution_class=Normal,
    return_log_prob=True
)
rollout = env.rollout(max_steps=10, policy=policy)
print(rollout)
