In [2]:
from typing import Optional

import gym
import torch as th
import torch.nn
from tensordict import TensorDictBase, TensorDict
from torchrl.envs import EnvBase, step_mdp, check_env_specs
from torchrl.data import UnboundedContinuousTensorSpec, CompositeSpec, BoundedTensorSpec, BinaryDiscreteTensorSpec, DiscreteTensorSpec

"""
# Info for CartPole-v1 test
    | Num | Action                 |
    |-----|------------------------|
    | 0   | Push cart to the left  |
    | 1   | Push cart to the right |

    | Num | Observation           | Min                 | Max               |
    |-----|-----------------------|---------------------|-------------------|
    | 0   | Cart Position         | -4.8                | 4.8               |
    | 1   | Cart Velocity         | -Inf                | Inf               |
    | 2   | Pole Angle            | ~ -0.418 rad (-24°) | ~ 0.418 rad (24°) |
    | 3   | Pole Angular Velocity | -Inf                | Inf               |
"""


class CartPoleWrapper(EnvBase):
    def __init__(self, env):
        super().__init__(device="cpu")
        self.env = env  # The underlying environment
        # Define the observation and action specs according to the wrapped environment
        self.observation_spec = CompositeSpec({"observation": BoundedTensorSpec(
            low=th.tensor([-4.8, -th.inf, -0.418, -th.inf], dtype=th.float64),
            high=th.tensor([4.8, th.inf, 0.418, th.inf], dtype=th.float64))})
        self.action_spec = CompositeSpec({"action": BinaryDiscreteTensorSpec(1)})
        self.reward_spec = CompositeSpec({"reward": BoundedTensorSpec(th.tensor([0]), th.inf)})

    def _step(self, tensordict: TensorDictBase) -> TensorDictBase:
        # Extract the action from the tensordict
        action = tensordict.get("action").detach().cpu().numpy()
        # Step the underlying environment with the extracted action
        obs, reward, term, trunc, info = self.env.step(int(action))
        done = term
        # print(f"obs is {obs}")
        # Create a TensorDict to return, including the new observation, reward, and done flag
        out = TensorDict({
            "done": th.tensor([done], dtype=th.bool),
            "observation": th.tensor(obs, dtype=th.float32),
            "reward": th.tensor([reward], dtype=th.float32),
        }, batch_size=[])
        return out

    def _reset(self, tensordict: Optional[TensorDictBase] = None, **kwargs) -> TensorDictBase:
        # Reset the underlying environment and get the initial observation
        obs = self.env.reset()[0]
        # Create a TensorDict for the initial state
        out = TensorDict({
            "observation": th.tensor(obs, dtype=th.float32),
            "action": th.tensor([0], dtype=th.float32),  # Placeholder action
            "done": th.tensor([False], dtype=th.bool),
        }, batch_size=[])
        return out

    def _set_seed(self, seed: Optional[int]):  # for reproduction of same results
        pass
        # self.env.seed(seed)  # Assuming the underlying env has a seed method

  register_pytree_node(


Testing the env

In [4]:
gym_env = gym.make('CartPole-v1')
env = CartPoleWrapper(gym_env)
print(f'action_spec: {env.action_spec}\n\n')

reset_td = env.reset()
print(reset_td.items())

step_td = env.step(reset_td)
# rollout_td = env.rollout(3)
# print(rollout_td)
# print('finished rollout')
# reset_with_action = env.rand_action(reset_td)
# print(reset_with_action["action"])
# 
# data = step_mdp(step_td)
# print(data)

print(check_env_specs(env))

  obs, reward, term, trunc, info = self.env.step(int(action))
2024-03-27 10:54:43,049 [torchrl][INFO] check_env_specs succeeded!


action_spec: BinaryDiscreteTensorSpec(
    shape=torch.Size([1]),
    space=DiscreteBox(n=2),
    device=cpu,
    dtype=torch.int8,
    domain=discrete)


dict_items([('observation', tensor([-0.0440,  0.0352, -0.0067,  0.0096])), ('action', tensor([0.])), ('done', tensor([False])), ('terminated', tensor([False]))])
None


Now the transformation and training

In [16]:
from torchrl.envs import StepCounter, TransformedEnv

transformed_env = TransformedEnv(env, StepCounter(max_steps=10))  # after 10 steps in truncated the env
rollout = transformed_env.rollout(max_steps=100)
print(rollout)

TensorDict(
    fields={
        action: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.int8, is_shared=False),
        done: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        next: TensorDict(
            fields={
                done: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                observation: Tensor(shape=torch.Size([10, 4]), device=cpu, dtype=torch.float32, is_shared=False),
                reward: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                step_count: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.int64, is_shared=False),
                terminated: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                truncated: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
            batch_size=torch.Size([10]),
            device=None,
          

  obs, reward, term, trunc, info = self.env.step(int(action))


In [17]:
print(rollout["next", "truncated"])

tensor([[False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [ True]])


### Now working with Modules

In [19]:
from tensordict.nn import TensorDictModule
from torchrl.envs.utils import ExplorationType, set_exploration_type
module = th.nn.LazyLinear(env.action_spec.shape[-1])
policy = TensorDictModule(
    module, 
    in_keys=["observation"],
    out_keys=["actions"],
)

# todo why do I get a number for 8 to 10 when i run this? does it mean that it loses at 8 steps?
rollout = env.rollout(max_steps=10, policy=policy) 
rollout

from torchrl.modules import Actor

"""
    Args:
        module (nn.Module): a :class:`~torch.nn.Module` used to map the input to
            the output parameter space.
        in_keys (iterable of str, optional): keys to be read from input
            tensordict and passed to the module. If it
            contains more than one element, the values will be passed in the
            order given by the in_keys iterable.
            Defaults to ``["observation"]``.
        out_keys (iterable of str): keys to be written to the input tensordict.
            The length of out_keys must match the
            number of tensors returned by the embedded module. Using ``"_"`` as a
            key avoid writing tensor to output.
            Defaults to ``["action"]``.
"""

policy = Actor(module)
rollout = env.rollout(max_steps=10, policy=policy)
print(rollout)
from torchrl.modules import MLP

module = MLP(
    out_features=env.action_spec.shape[-1],
    num_cells=[32, 64],
    activation_class=th.nn.Tanh,
)
policy = Actor(module)
rollout = env.rollout(max_steps=10, policy=policy)
rollout
from tensordict.nn.distributions import NormalParamExtractor
from torch.distributions import Normal
from torchrl.modules import ProbabilisticActor

backbone = MLP(in_features=4, out_features=2)
extractor = NormalParamExtractor()
module = th.nn.Sequential(backbone, extractor)
td_module = TensorDictModule(module, in_keys=["observation"], out_keys=["loc", "scale"])
policy = ProbabilisticActor(
    td_module,
    in_keys=["loc", "scale"],
    out_keys=["action"],
    distribution_class=Normal,
    return_log_prob=True
)

with set_exploration_type(ExplorationType.MEAN):
    # takes the mean as action
    rollout = env.rollout(max_steps=10, policy=policy)
rollout

TensorDict(
    fields={
        action: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.float32, is_shared=False),
        done: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        next: TensorDict(
            fields={
                done: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                observation: Tensor(shape=torch.Size([10, 4]), device=cpu, dtype=torch.float32, is_shared=False),
                reward: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                terminated: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
            batch_size=torch.Size([10]),
            device=None,
            is_shared=False),
        observation: Tensor(shape=torch.Size([10, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        terminated: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shar

  obs, reward, term, trunc, info = self.env.step(int(action))


TensorDict(
    fields={
        action: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.float32, is_shared=False),
        done: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        loc: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.float32, is_shared=False),
        next: TensorDict(
            fields={
                done: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                observation: Tensor(shape=torch.Size([10, 4]), device=cpu, dtype=torch.float32, is_shared=False),
                reward: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                terminated: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
            batch_size=torch.Size([10]),
            device=None,
            is_shared=False),
        observation: Tensor(shape=torch.Size([10, 4]), device=cpu, dtype=torch.float32, is_shared=F

# Exploration

In [27]:
from tensordict.nn import TensorDictSequential
from torchrl.modules import EGreedyModule

policy = Actor(MLP(4,1,num_cells=[32, 64]))
exploration_module = EGreedyModule(
    spec=env.action_spec, annealing_num_steps=1000, eps_init=0.5
)
exploration_policy = TensorDictSequential(policy, exploration_module)
with set_exploration_type(ExplorationType.MEAN):
    rollout = env.rollout(max_steps=10, policy=exploration_policy)

  obs, reward, term, trunc, info = self.env.step(int(action))


# Q-value net

In [28]:
from torchrl.modules import QValueModule

num_actions = 2
value_net = TensorDictModule(
    MLP(out_features=num_actions, num_cells=[32, 32]),
    in_keys=["observation"],
    out_keys=["action_value"],
)

# policy = TensorDictSequential(
#     value_net,
#     QValueModule(
#         action_space=env.action_spec
#     ),   # Reads the "action_value" entry by default
# )

rollout = env.rollout(max_steps=3, policy=policy)   # uses the one from prev cell
print(rollout)

TensorDict(
    fields={
        action: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.float32, is_shared=False),
        done: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        next: TensorDict(
            fields={
                done: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                observation: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
                reward: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                terminated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
            batch_size=torch.Size([3]),
            device=None,
            is_shared=False),
        observation: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        terminated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False)

  obs, reward, term, trunc, info = self.env.step(int(action))


In [32]:
"""
CLASSIC TRAIN LOOP
for i in range(n_collections):
    data = get_next_batch(env, policy)
    for j in range(n_optim):
        loss = loss_fn(data)
        loss.backward()
        optim.step()"""

from torchrl.modules import Actor, MLP, ValueOperator
from torchrl.objectives import DDPGLoss

n_obs = env.observation_spec["observation"].shape[-1]
n_act = env.action_spec.shape[-1]
actor = Actor(MLP(in_features=n_obs, out_features=n_act, num_cells=[32, 32]))
value_net = ValueOperator(
    MLP(in_features=n_obs + n_act, out_features=1, num_cells=[32, 32]),
    in_keys=["observation", "action"],
)
ddpg_loss = DDPGLoss(actor_network=actor, value_network=value_net)

rollout = env.rollout(max_steps=100, policy=actor)
loss_vals = ddpg_loss(rollout)
print(loss_vals)

TensorDict(
    fields={
        loss_actor: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        loss_value: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        pred_value: Tensor(shape=torch.Size([8]), device=cpu, dtype=torch.float32, is_shared=False),
        pred_value_max: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        target_value: Tensor(shape=torch.Size([8]), device=cpu, dtype=torch.float32, is_shared=False),
        target_value_max: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        td_error: Tensor(shape=torch.Size([8]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)


  obs, reward, term, trunc, info = self.env.step(int(action))


In [34]:
total_loss = 0
for key, val in loss_vals.items():
    if key.startswith("loss_"):
        total_loss += val
total_loss

tensor(1.0166, grad_fn=<AddBackward0>)

In [39]:
from torch.optim import Adam
optim = Adam(ddpg_loss.parameters())
total_loss.backward()
optim.step()
# optim.zero_grad()


RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.

In [37]:
print(ddpg_loss.parameters().__doc__)

None


# Collectors

In [44]:
import torch

torch.manual_seed(0)

from torchrl.collectors import SyncDataCollector
from torchrl.envs.utils import RandomPolicy

env.set_seed(0)
policy = RandomPolicy(env.action_spec)
collector = SyncDataCollector(env, policy, frames_per_batch=200, total_frames=-1)

for data in collector:
    print(data.shape)
    break
print(data["collector", 'traj_ids'])

torch.Size([200])
tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
        3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4,
        4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 6,
        6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
        7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
        7, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 9, 9, 9,
        9, 9, 9, 9, 9, 9, 9, 9])


  obs, reward, term, trunc, info = self.env.step(int(action))


## Replay Buffers
```
>>> for data in collector:
...     storage.store(data)
...     for i in range(n_optim):
...         sample = storage.sample()
...         loss_val = loss_fn(sample)
...         loss_val.backward()
...         optim.step() #
```

In [48]:
from torchrl.data.replay_buffers import LazyMemmapStorage, ReplayBuffer

buffer = ReplayBuffer(storage=LazyMemmapStorage(max_size=1000))
indices = buffer.extend(data)

sample = buffer.sample(batch_size=30)
sample

TensorDict(
    fields={
        action: Tensor(shape=torch.Size([30, 1]), device=cpu, dtype=torch.int8, is_shared=False),
        collector: TensorDict(
            fields={
                traj_ids: Tensor(shape=torch.Size([30]), device=cpu, dtype=torch.int64, is_shared=False)},
            batch_size=torch.Size([30]),
            device=cpu,
            is_shared=False),
        done: Tensor(shape=torch.Size([30, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        next: TensorDict(
            fields={
                done: Tensor(shape=torch.Size([30, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                observation: Tensor(shape=torch.Size([30, 4]), device=cpu, dtype=torch.float32, is_shared=False),
                reward: Tensor(shape=torch.Size([30, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                terminated: Tensor(shape=torch.Size([30, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
            batch_size=torch.Size([30]),
 

# Logging

In [49]:
from torchrl.record import CSVLogger
logger = CSVLogger(exp_name='csv_logger_exp')
logger.log_scalar("my_scalar", 0.4)