In [1]:
import os
import uuid

import torch
from torch import nn
from torchrl.collectors import MultiaSyncDataCollector, SyncDataCollector
from torchrl.data import LazyMemmapStorage, MultiStep, TensorDictReplayBuffer
from torchrl.envs import (
    EnvCreator,
    ExplorationType,
    ParallelEnv,
    RewardScaling,
    StepCounter,
)
from torchrl.envs.libs.gym import GymEnv
from torchrl.envs.transforms import (
    CatFrames,
    Compose,
    GrayScale,
    ObservationNorm,
    Resize,
    ToTensorImage,
    TransformedEnv,
)
from torchrl.modules import DuelingCnnDQNet, EGreedyModule, QValueActor
from tensordict.nn import TensorDictSequential


from torchrl.objectives import DQNLoss, SoftUpdate
from torchrl.record.loggers.csv import CSVLogger
from torchrl.trainers import (
    LogReward,
    Recorder,
    ReplayBufferTrainer,
    Trainer,
    UpdateWeights,
)


def is_notebook() -> bool:
    try:
        shell = get_ipython().__class__.__name__
        if shell == "ZMQInteractiveShell":
            return True  # Jupyter notebook or qtconsole
        elif shell == "TerminalInteractiveShell":
            return False  # Terminal running IPython
        else:
            return False  # Other type (?)
    except NameError:
        return False  # Probably standard Python interpreter

In [2]:
def make_env(
    parallel=False,
    obs_norm_sd=None,
    num_workers=1,
):
    if obs_norm_sd is None:
        obs_norm_sd = {"standard_normal": True}
    if parallel:

        def maker():
            return GymEnv(
                "CartPole-v1",
                from_pixels=True,
                pixels_only=True,
                device=device,
            )

        base_env = ParallelEnv(
            num_workers,
            EnvCreator(maker),
            # Don't create a sub-process if we have only one worker
            serial_for_single=True,
        )
    else:
        base_env = GymEnv(
            "CartPole-v1",
            from_pixels=True,
            pixels_only=True,
            device=device,
        )

    env = TransformedEnv(
        base_env,
        Compose(
            StepCounter(),  # to count the steps of each trajectory
            ToTensorImage(),
            RewardScaling(loc=0.0, scale=0.1),
            GrayScale(),
            Resize(64, 64),
            CatFrames(4, in_keys=["pixels"], dim=-3),
            ObservationNorm(in_keys=["pixels"], **obs_norm_sd),
        ),
    )
    return env

In [3]:
def get_norm_stats():
    test_env = make_env()
    test_env.transform[-1].init_stats(
        num_iter=1000, cat_dim=0, reduce_dim=[-1, -2, -4], keep_dims=(-1, -2)
    )
    obs_norm_sd = test_env.transform[-1].state_dict()
    # let's check that normalizing constants have a size of ``[C, 1, 1]`` where
    # ``C=4`` (because of :class:`~torchrl.envs.CatFrames`).
    print("state dict of the observation norm:", obs_norm_sd)
    test_env.close()
    del test_env
    return obs_norm_sd

In [4]:
def make_model(dummy_env):
    cnn_kwargs = {
        "num_cells": [32, 64, 64],
        "kernel_sizes": [6, 4, 3],
        "strides": [2, 2, 1],
        "activation_class": nn.ELU,
        # This can be used to reduce the size of the last layer of the CNN
        # "squeeze_output": True,
        # "aggregator_class": nn.AdaptiveAvgPool2d,
        # "aggregator_kwargs": {"output_size": (1, 1)},
    }
    mlp_kwargs = {
        "depth": 2,
        "num_cells": [
            64,
            64,
        ],
        "activation_class": nn.ELU,
    }
    net = DuelingCnnDQNet(
        dummy_env.action_spec.shape[-1], 1, cnn_kwargs, mlp_kwargs
    ).to(device)
    net.value[-1].bias.data.fill_(init_bias)
    
    print("net is: ", net)

    actor = QValueActor(net, in_keys=["pixels"], spec=dummy_env.action_spec).to(device)

    print("QValueActor is: ", actor, "\n\n")

    # init actor: because the model is composed of lazy conv/linear layers,
    # we must pass a fake batch of data through it to instantiate them.
    tensordict = dummy_env.fake_tensordict()
    actor(tensordict)

    # we join our actor with an EGreedyModule for data collection
    exploration_module = EGreedyModule(
        spec=dummy_env.action_spec,
        annealing_num_steps=total_frames,
        eps_init=eps_greedy_val,
        eps_end=eps_greedy_val_env,
    )

    print("actor is: ", actor)
    actor_explore = TensorDictSequential(actor, exploration_module)

    return actor, actor_explore

In [5]:
def get_replay_buffer(buffer_size, n_optim, batch_size):
    replay_buffer = TensorDictReplayBuffer(
        batch_size=batch_size,
        storage=LazyMemmapStorage(buffer_size),
        prefetch=n_optim,
    )
    return replay_buffer

In [6]:
def get_collector(
    stats,
    num_collectors,
    actor_explore,
    frames_per_batch,
    total_frames,
    device,
):
    # We can't use nested child processes with mp_start_method="fork"
   
    cls = SyncDataCollector
    env_arg = make_env(parallel=True, obs_norm_sd=stats, num_workers=num_workers)

    data_collector = cls(
        env_arg,
        policy=actor_explore,
        frames_per_batch=frames_per_batch,
        total_frames=total_frames,
        # this is the default behaviour: the collector runs in ``"random"`` (or explorative) mode
        exploration_type=ExplorationType.RANDOM,
        # We set the all the devices to be identical. Below is an example of
        # heterogeneous devices
        device=device,
        storing_device=device,
        split_trajs=False,
        postproc=MultiStep(gamma=gamma, n_steps=5),
    )
    return data_collector

In [7]:
def get_loss_module(actor, gamma):
    loss_module = DQNLoss(actor, delay_value=True)
    loss_module.make_value_estimator(gamma=gamma)
    target_updater = SoftUpdate(loss_module, eps=0.995)
    return loss_module, target_updater

In [8]:
device = (
    torch.device(0)
    if torch.cuda.is_available() and not is_fork
    else torch.device("cpu")
)

In [9]:
# the learning rate of the optimizer
lr = 2e-3
# weight decay
wd = 1e-5
# the beta parameters of Adam
betas = (0.9, 0.999)
# Optimization steps per batch collected (aka UPD or updates per data)
n_optim = 8

In [10]:
gamma = 0.99
tau = 0.02
total_frames = 5_000  # 500000
init_random_frames = 100  # 1000
frames_per_batch = 32  # 128
batch_size = 32  # 256
buffer_size = min(total_frames, 100000)
num_workers = 2  # 8
num_collectors = 2  # 4
eps_greedy_val = 0.1
eps_greedy_val_env = 0.005
init_bias = 2.0


In [11]:
#stats = get_norm_stats()
test_env = make_env(parallel=False)
# Get model
actor, actor_explore = make_model(test_env)
loss_module, target_net_updater = get_loss_module(actor, gamma)

  logger.warn(
  logger.warn(


net is:  DuelingCnnDQNet(
  (features): ConvNet(
    (0): LazyConv2d(0, 32, kernel_size=(6, 6), stride=(2, 2))
    (1): ELU(alpha=1.0)
    (2): Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2))
    (3): ELU(alpha=1.0)
    (4): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
    (5): ELU(alpha=1.0)
    (6): SquashDims()
  )
  (advantage): MLP(
    (0): LazyLinear(in_features=0, out_features=64, bias=True)
    (1): ELU(alpha=1.0)
    (2): Linear(in_features=64, out_features=64, bias=True)
    (3): ELU(alpha=1.0)
    (4): Linear(in_features=64, out_features=2, bias=True)
  )
  (value): MLP(
    (0): LazyLinear(in_features=0, out_features=64, bias=True)
    (1): ELU(alpha=1.0)
    (2): Linear(in_features=64, out_features=64, bias=True)
    (3): ELU(alpha=1.0)
    (4): Linear(in_features=64, out_features=1, bias=True)
  )
)
QValueActor is:  QValueActor(
    module=ModuleList(
      (0): TensorDictModule(
          module=DuelingCnnDQNet(
            (features): ConvNet(
              (0):



In [12]:
actor

QValueActor(
    module=ModuleList(
      (0): TensorDictModule(
          module=DuelingCnnDQNet(
            (features): ConvNet(
              (0): Conv2d(4, 32, kernel_size=(6, 6), stride=(2, 2))
              (1): ELU(alpha=1.0)
              (2): Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2))
              (3): ELU(alpha=1.0)
              (4): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
              (5): ELU(alpha=1.0)
              (6): SquashDims()
            )
            (advantage): MLP(
              (0): Linear(in_features=9216, out_features=64, bias=True)
              (1): ELU(alpha=1.0)
              (2): Linear(in_features=64, out_features=64, bias=True)
              (3): ELU(alpha=1.0)
              (4): Linear(in_features=64, out_features=2, bias=True)
            )
            (value): MLP(
              (0): Linear(in_features=9216, out_features=64, bias=True)
              (1): ELU(alpha=1.0)
              (2): Linear(in_features=64, out_features=64, 

In [13]:
actor_explore

TensorDictSequential(
    module=ModuleList(
      (0): QValueActor(
          module=ModuleList(
            (0): TensorDictModule(
                module=DuelingCnnDQNet(
                  (features): ConvNet(
                    (0): Conv2d(4, 32, kernel_size=(6, 6), stride=(2, 2))
                    (1): ELU(alpha=1.0)
                    (2): Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2))
                    (3): ELU(alpha=1.0)
                    (4): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
                    (5): ELU(alpha=1.0)
                    (6): SquashDims()
                  )
                  (advantage): MLP(
                    (0): Linear(in_features=9216, out_features=64, bias=True)
                    (1): ELU(alpha=1.0)
                    (2): Linear(in_features=64, out_features=64, bias=True)
                    (3): ELU(alpha=1.0)
                    (4): Linear(in_features=64, out_features=2, bias=True)
                  )
                  (va

In [14]:
loss_module

DQNLoss(
  (value_network_params): TensorDictParams(params=TensorDict(
      fields={
          module: TensorDict(
              fields={
                  0: TensorDict(
                      fields={
                          module: TensorDict(
                              fields={
                                  advantage: TensorDict(
                                      fields={
                                          0: TensorDict(
                                              fields={
                                                  bias: Parameter(shape=torch.Size([64]), device=cpu, dtype=torch.float32, is_shared=False),
                                                  weight: Parameter(shape=torch.Size([64, 9216]), device=cpu, dtype=torch.float32, is_shared=False)},
                                              batch_size=torch.Size([]),
                                              device=None,
                                              is_shared=False),
         

In [15]:
collector = get_collector(
    stats=test_env,
    num_collectors=num_collectors,
    actor_explore=actor_explore,
    frames_per_batch=frames_per_batch,
    total_frames=total_frames,
    device=device,
)
optimizer = torch.optim.Adam(
    loss_module.parameters(), lr=lr, weight_decay=wd, betas=betas
)
exp_name = f"dqn_exp_{uuid.uuid1()}"
tmpdir = tempfile.TemporaryDirectory()
logger = CSVLogger(exp_name=exp_name, log_dir=tmpdir.name)
warnings.warn(f"log dir: {logger.experiment.log_dir}")

  logger.warn(
  logger.warn(
  logger.warn(


TypeError: torchrl.envs.transforms.transforms.ObservationNorm() argument after ** must be a mapping, not TransformedEnv

In [None]:
!pip install gymnasium[classic-control]

Collecting pygame>=2.1.3 (from gymnasium[classic-control])
  Downloading pygame-2.6.0-cp311-cp311-win_amd64.whl.metadata (13 kB)
Downloading pygame-2.6.0-cp311-cp311-win_amd64.whl (10.8 MB)
   ---------------------------------------- 0.0/10.8 MB ? eta -:--:--
   ---------------------------------------- 0.0/10.8 MB ? eta -:--:--
   ---------------------------------------- 0.0/10.8 MB 487.6 kB/s eta 0:00:22
    --------------------------------------- 0.2/10.8 MB 1.4 MB/s eta 0:00:08
   -- ------------------------------------- 0.6/10.8 MB 3.6 MB/s eta 0:00:03
   --- ------------------------------------ 1.0/10.8 MB 4.3 MB/s eta 0:00:03
   ---- ----------------------------------- 1.3/10.8 MB 4.6 MB/s eta 0:00:03
   ----- ---------------------------------- 1.5/10.8 MB 4.9 MB/s eta 0:00:02
   ----- ---------------------------------- 1.5/10.8 MB 4.3 MB/s eta 0:00:03
   -------- ------------------------------- 2.2/10.8 MB 5.3 MB/s eta 0:00:02
   ---------- ----------------------------- 2.9/10.8