In [12]:
from torchrl.envs import ExplorationType

In [13]:
import torch
import time

from torchrl.envs import GymEnv, StepCounter, TransformedEnv
from tensordict.nn import TensorDictModule as Mod, TensorDictSequential as Seq
from torchrl.modules import EGreedyModule, MLP, QValueModule
from torchrl.collectors import SyncDataCollector
from torchrl.data import LazyTensorStorage, ReplayBuffer
from torch.optim import Adam
from torchrl.objectives import DQNLoss, SoftUpdate
from torchrl._utils import logger as torchrl_logger
from torchrl.record import CSVLogger, VideoRecorder

torch.manual_seed(0)


<torch._C.Generator at 0x7fce9e3300b0>

In [14]:
# Define the environment
env = TransformedEnv(GymEnv("CartPole-v1"), StepCounter())
env.set_seed(0)



795726461

In [15]:
env.action_spec

OneHotDiscreteTensorSpec(
    shape=torch.Size([2]),
    space=DiscreteBox(n=2),
    device=cpu,
    dtype=torch.int64,
    domain=discrete)

In [16]:
env.specs["input_spec", "full_action_spec", "action"].space

DiscreteBox(n=2)

In [17]:
input_shape = env.observation_spec["observation"].shape
env_specs = env.specs
num_outputs = env_specs["input_spec", "full_action_spec", "action"].space.n
action_spec = env_specs["input_spec", "full_action_spec", "action"]



In [18]:
# Define the network for DQN values
value_mlp = MLP(out_features=env.action_spec.shape[-1], 
                num_cells=[64, 64])
value_net = Mod(value_mlp, 
                in_keys=["observation"], 
                out_keys=["action_value"])

# Define the policy. QValueModule adds the argmax step to the Q-values
policy = Seq(value_net, 
             QValueModule(spec=env.action_spec))

# Define the exploration step (e-greedy policy)
exploration_module = EGreedyModule(
    env.action_spec, 
    annealing_num_steps=100_000, 
    eps_init=0.5
)
policy_explore = Seq(policy, 
                     exploration_module)



In [23]:
# Define how to collect the data (experiences)
init_rand_steps = 5000 # warm-up steps
frames_per_batch = 10
optim_steps = 10
replay_capacity = 100_000

# NOTE: collector will gather rollouts continously
# If the current trajectory ends, it will start a new one
# NOTE: the rollout gotten from the collector is a dictionary
# that defines the sate and next state as a tensor with a batch dimension in the begining
# for example a rollout of 10 steps will have a tensor of observation of 10 in the batch dimension
# and the next will also have 10 which are all the tensors of the next state
# Practically, next is as you will shift the tensor of observation by one step
collector = SyncDataCollector(
    env,
    policy,
    frames_per_batch=frames_per_batch,
    total_frames=500_100,
    init_random_frames=init_rand_steps,
)
rb = ReplayBuffer(storage=LazyTensorStorage(replay_capacity))

In [20]:
loss = DQNLoss(value_network=policy, 
               action_space=env.action_spec, 
               delay_value=True) # delay_value=True means we will use a target network
optim = Adam(loss.parameters(), lr=0.02)

# eps: will be used to update the target network as 
# \theta_t = \theta_{t-1} * \epsilon + \theta_t * (1-\epsilon)
# where eps = 1 is hard update
updater = SoftUpdate(loss, eps=0.99)

In [21]:
# Define the recording and logging
path = "./training_loop"
logger = CSVLogger(exp_name="dqn", log_dir=path, video_format="mp4")
video_recorder = VideoRecorder(logger, tag="video")
record_env = TransformedEnv(
    GymEnv("CartPole-v1", from_pixels=True, pixels_only=False), video_recorder
)



In [7]:
test_interval = 30
frames_per_batch = 10
current_frames = 0
for i in range(150):
    current_frames += frames_per_batch
    prev_test_frame = ((i - 1) * frames_per_batch) // test_interval
    cur_test_frame = (i * frames_per_batch) // test_interval

    print("first", prev_test_frame, cur_test_frame)
    print("second", current_frames, current_frames % test_interval)


first -1 0
second 10 10
first 0 0
second 20 20
first 0 0
second 30 0
first 0 1
second 40 10
first 1 1
second 50 20
first 1 1
second 60 0
first 1 2
second 70 10
first 2 2
second 80 20
first 2 2
second 90 0
first 2 3
second 100 10
first 3 3
second 110 20
first 3 3
second 120 0
first 3 4
second 130 10
first 4 4
second 140 20
first 4 4
second 150 0
first 4 5
second 160 10
first 5 5
second 170 20
first 5 5
second 180 0
first 5 6
second 190 10
first 6 6
second 200 20
first 6 6
second 210 0
first 6 7
second 220 10
first 7 7
second 230 20
first 7 7
second 240 0
first 7 8
second 250 10
first 8 8
second 260 20
first 8 8
second 270 0
first 8 9
second 280 10
first 9 9
second 290 20
first 9 9
second 300 0
first 9 10
second 310 10
first 10 10
second 320 20
first 10 10
second 330 0
first 10 11
second 340 10
first 11 11
second 350 20
first 11 11
second 360 0
first 11 12
second 370 10
first 12 12
second 380 20
first 12 12
second 390 0
first 12 13
second 400 10
first 13 13
second 410 20
first 13 13
seco

In [None]:
collector = SyncDataCollector(
    env,
    policy,
    frames_per_batch=frames_per_batch,
    total_frames=500_100,
    init_random_frames=init_rand_steps,
)

In [32]:
collector = SyncDataCollector(
    create_env_fn=env,
    policy=policy_explore,
    frames_per_batch=10,
    total_frames=5_100,
    device="cpu",
    storing_device="cpu",
    max_frames_per_traj=-1,
    init_random_frames=10_0,
)

In [4]:
import datetime

current_date = datetime.datetime.now()
date_str = current_date.strftime("%Y_%m_%d-%H_%M_%S")  # Includes date and time
date_str

'2024_07_18-18_37_34'

In [21]:
total_count = 0
total_episodes = 0
t0 = time.time()
for i, data in enumerate(collector):
    # Write data in replay buffer
    rb.extend(data)
    max_length = rb[:]["next", "step_count"].max() # From all the next steps get the max step count
    if len(rb) > init_rand_steps: # wam-up steps
        # Optim loop (we do several optim steps
        # per batch collected for efficiency)
        for _ in range(optim_steps):
            sample = rb.sample(128) # sample a batch of 128 (repetition is allowed)
            loss_vals = loss(sample)
            loss_vals["loss"].backward()
            optim.step()
            optim.zero_grad()
            # Update exploration factor
            # NOTE: Why I am updating the exploration factor here? 
            # I'm considering practically that I did 100 (or n) iteractions in the environment time optim_steps
            exploration_module.step(data.numel()) # data.numel() returns the number of elements in the data
            # Update target params each optimisation step
            updater.step()
            if i % 10:
                torchrl_logger.info(f"Max num steps: {max_length}, rb length {len(rb)}")
            total_count += data.numel()
            total_episodes += data["next", "done"].sum() # sum the number of done episodes
    if max_length > 200:
        break

t1 = time.time()

torchrl_logger.info(
    f"solved after {total_count} steps, {total_episodes} episodes and in {t1-t0}s."
)

2024-07-17 17:48:41,856 [torchrl][INFO] solved after 1000 steps, 10 episodes and in 2365.736939430237s.


In [8]:
record_env.rollout(max_steps=1000, policy=policy)
video_recorder.dump()

In [5]:
import random

# Generate and print 10 random seeds
random_seeds = [random.randint(0, 1000000) for _ in range(10)]
print(random_seeds)

[118398, 676190, 786456, 171936, 887739, 919409, 711872, 442081, 189061, 117840]
