In [12]:
from torchrl.envs import ExplorationType

In [13]:
import torch
import time

from torchrl.envs import GymEnv, StepCounter, TransformedEnv
from tensordict.nn import TensorDictModule as Mod, TensorDictSequential as Seq
from torchrl.modules import EGreedyModule, MLP, QValueModule
from torchrl.collectors import SyncDataCollector
from torchrl.data import LazyTensorStorage, ReplayBuffer
from torch.optim import Adam
from torchrl.objectives import DQNLoss, SoftUpdate
from torchrl._utils import logger as torchrl_logger
from torchrl.record import CSVLogger, VideoRecorder

torch.manual_seed(0)


<torch._C.Generator at 0x7fce9e3300b0>

In [14]:
# Define the environment
env = TransformedEnv(GymEnv("CartPole-v1"), StepCounter())
env.set_seed(0)



795726461

In [15]:
env.action_spec

OneHotDiscreteTensorSpec(
    shape=torch.Size([2]),
    space=DiscreteBox(n=2),
    device=cpu,
    dtype=torch.int64,
    domain=discrete)

In [16]:
env.specs["input_spec", "full_action_spec", "action"].space

DiscreteBox(n=2)

In [17]:
input_shape = env.observation_spec["observation"].shape
env_specs = env.specs
num_outputs = env_specs["input_spec", "full_action_spec", "action"].space.n
action_spec = env_specs["input_spec", "full_action_spec", "action"]



In [18]:
# Define the network for DQN values
value_mlp = MLP(out_features=env.action_spec.shape[-1], 
                num_cells=[64, 64])
value_net = Mod(value_mlp, 
                in_keys=["observation"], 
                out_keys=["action_value"])

# Define the policy. QValueModule adds the argmax step to the Q-values
policy = Seq(value_net, 
             QValueModule(spec=env.action_spec))

# Define the exploration step (e-greedy policy)
exploration_module = EGreedyModule(
    env.action_spec, 
    annealing_num_steps=100_000, 
    eps_init=0.5
)
policy_explore = Seq(policy, 
                     exploration_module)



In [23]:
# Define how to collect the data (experiences)
init_rand_steps = 5000 # warm-up steps
frames_per_batch = 10
optim_steps = 10
replay_capacity = 100_000

# NOTE: collector will gather rollouts continously
# If the current trajectory ends, it will start a new one
# NOTE: the rollout gotten from the collector is a dictionary
# that defines the sate and next state as a tensor with a batch dimension in the begining
# for example a rollout of 10 steps will have a tensor of observation of 10 in the batch dimension
# and the next will also have 10 which are all the tensors of the next state
# Practically, next is as you will shift the tensor of observation by one step
collector = SyncDataCollector(
    env,
    policy,
    frames_per_batch=frames_per_batch,
    total_frames=500_100,
    init_random_frames=init_rand_steps,
)
rb = ReplayBuffer(storage=LazyTensorStorage(replay_capacity))

In [20]:
loss = DQNLoss(value_network=policy, 
               action_space=env.action_spec, 
               delay_value=True) # delay_value=True means we will use a target network
optim = Adam(loss.parameters(), lr=0.02)

# eps: will be used to update the target network as 
# \theta_t = \theta_{t-1} * \epsilon + \theta_t * (1-\epsilon)
# where eps = 1 is hard update
updater = SoftUpdate(loss, eps=0.99)

In [21]:
# Define the recording and logging
path = "./training_loop"
logger = CSVLogger(exp_name="dqn", log_dir=path, video_format="mp4")
video_recorder = VideoRecorder(logger, tag="video")
record_env = TransformedEnv(
    GymEnv("CartPole-v1", from_pixels=True, pixels_only=False), video_recorder
)



In [28]:
500_100 // 10

50010

In [None]:
collector = SyncDataCollector(
    env,
    policy,
    frames_per_batch=frames_per_batch,
    total_frames=500_100,
    init_random_frames=init_rand_steps,
)

In [32]:
collector = SyncDataCollector(
    create_env_fn=env,
    policy=policy_explore,
    frames_per_batch=10,
    total_frames=5_100,
    device="cpu",
    storing_device="cpu",
    max_frames_per_traj=-1,
    init_random_frames=10_0,
)

In [33]:
for i, data in enumerate(collector):
    print(data.numel())

print(i)

10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
1

In [21]:
total_count = 0
total_episodes = 0
t0 = time.time()
for i, data in enumerate(collector):
    # Write data in replay buffer
    rb.extend(data)
    max_length = rb[:]["next", "step_count"].max() # From all the next steps get the max step count
    if len(rb) > init_rand_steps: # wam-up steps
        # Optim loop (we do several optim steps
        # per batch collected for efficiency)
        for _ in range(optim_steps):
            sample = rb.sample(128) # sample a batch of 128 (repetition is allowed)
            loss_vals = loss(sample)
            loss_vals["loss"].backward()
            optim.step()
            optim.zero_grad()
            # Update exploration factor
            # NOTE: Why I am updating the exploration factor here? 
            # I'm considering practically that I did 100 (or n) iteractions in the environment time optim_steps
            exploration_module.step(data.numel()) # data.numel() returns the number of elements in the data
            # Update target params each optimisation step
            updater.step()
            if i % 10:
                torchrl_logger.info(f"Max num steps: {max_length}, rb length {len(rb)}")
            total_count += data.numel()
            total_episodes += data["next", "done"].sum() # sum the number of done episodes
    if max_length > 200:
        break

t1 = time.time()

torchrl_logger.info(
    f"solved after {total_count} steps, {total_episodes} episodes and in {t1-t0}s."
)

2024-07-17 17:48:41,856 [torchrl][INFO] solved after 1000 steps, 10 episodes and in 2365.736939430237s.


In [8]:
record_env.rollout(max_steps=1000, policy=policy)
video_recorder.dump()