In [3]:
import torch

torch.manual_seed(0)

import time

from torchrl.envs import GymEnv, StepCounter, TransformedEnv

env = TransformedEnv(GymEnv("CartPole-v1"), StepCounter())
env.set_seed(0)

from tensordict.nn import TensorDictModule as Mod, TensorDictSequential as Seq

# Policy

from torchrl.modules import EGreedyModule, MLP, QValueModule

value_mlp = MLP(out_features=env.action_spec.shape[-1], num_cells=[64, 64])
value_net = Mod(value_mlp, in_keys=["observation"], out_keys=["action_value"])
policy = Seq(value_net, QValueModule(env.action_spec))
exploration_module = EGreedyModule(env.action_spec, annealing_num_steps=100_000, eps_init=0.5)
policy_explore = Seq(policy, exploration_module)

# Data collector
from torchrl.collectors import SyncDataCollector
from torchrl.data import LazyTensorStorage, ReplayBuffer

init_rand_steps = 5000
frames_per_batch = 100
optim_steps = 10
collector = SyncDataCollector(
    env,
    policy,
    frames_per_batch=frames_per_batch,
    total_frames=-1,
    init_random_frames=init_rand_steps
)
rb = ReplayBuffer(storage=LazyTensorStorage(100_000))

from torch.optim import Adam

# Optimization & Loss Module

from torchrl.objectives import DQNLoss, SoftUpdate

loss = DQNLoss(value_network=policy, action_space=env.action_spec, delay_value=True)
optim = Adam(loss.parameters(), lr=0.02)
updater = SoftUpdate(loss, eps=0.99)

# Logger

from torchrl._utils import logger as torchrl_logger
from torchrl.record import CSVLogger, VideoRecorder

path = "./training_loop"
logger = CSVLogger(exp_name="dqn", log_dir=path, video_format="mp4")
video_recorder = VideoRecorder(logger, tag="video")
record_env = TransformedEnv(GymEnv("CartPole-v1", from_pixels=True, pixels_only=False), video_recorder)

# The main loop
total_count = 0
total_episodes = 0
t0 = time.time()
print('starting loop')
for i, data in enumerate(collector):
    
    # Write data in rb
    rb.extend(data)
    max_length = rb[:]['next', 'step_count'].max()
    print(total_count)
    if len(rb) > init_rand_steps:
        print(total_count)
        # Optim loop (we do several optim steps per batch collected for efficiency
        for _ in range(optim_steps):
            sample = rb.sample(128)
            loss_vals = loss(sample)
            loss_vals["loss"].backward()
            optim.step()
            optim.zero_grad()
            # Update the exploration factor
            exploration_module.step(data.numel())
            # Update target params
            updater.step()
            if i % 10:
                torchrl_logger.info(f"Max num steps: {max_length}, rb length {len(rb)}")
            total_count += data.numel()
            total_episodes += data["next", "done"].sum()
    if max_length > 200:
        break # truncate it

t1 = time.time()

torchrl_logger.info(
    f"solved after {total_count} steps, {total_episodes} episodes and in {t1-t0}s."
)

# Save reconding
record_env.rollout(max_steps=1000, policy=policy)
# video_recorder.dump()

  logger.warn(
  logger.warn(
  logger.warn(
  logger.warn(
Actual - Expected keys={('collector', 'traj_ids')}.
2024-03-27 12:45:40,627 [torchrl][INFO] Max num steps: 100, rb length 5200
2024-03-27 12:45:40,632 [torchrl][INFO] Max num steps: 100, rb length 5200
2024-03-27 12:45:40,638 [torchrl][INFO] Max num steps: 100, rb length 5200
2024-03-27 12:45:40,644 [torchrl][INFO] Max num steps: 100, rb length 5200
2024-03-27 12:45:40,649 [torchrl][INFO] Max num steps: 100, rb length 5200
2024-03-27 12:45:40,656 [torchrl][INFO] Max num steps: 100, rb length 5200
2024-03-27 12:45:40,662 [torchrl][INFO] Max num steps: 100, rb length 5200
2024-03-27 12:45:40,668 [torchrl][INFO] Max num steps: 100, rb length 5200
2024-03-27 12:45:40,674 [torchrl][INFO] Max num steps: 100, rb length 5200
2024-03-27 12:45:40,680 [torchrl][INFO] Max num steps: 100, rb length 5200
2024-03-27 12:45:40,728 [torchrl][INFO] Max num steps: 100, rb length 5300
2024-03-27 12:45:40,734 [torchrl][INFO] Max num steps: 100, rb 