<a href="https://colab.research.google.com/github/AiProcess/CartPole_RL/blob/main/cart_pole.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install torchrl
!pip install tensordict

In [None]:
!pip install av
!apt-get update
!apt-get install -y ffmpeg

In [3]:
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Default device is {device}")

Default device is cuda


# Environment

In [4]:
from torchrl.envs import GymEnv
from torchrl.envs.transforms import StepCounter, TransformedEnv
from torchrl.record import VideoRecorder, CSVLogger

## Main Env

In [None]:
env = TransformedEnv(
    env=GymEnv(env_name="CartPole-v1"),
    transform=StepCounter(),
).to(device)

## Test Env

In [None]:
csv_logger = CSVLogger(exp_name="test", log_dir="cart_pole", video_format="mp4")
video_recorder = VideoRecorder(logger=csv_logger, tag="video")

test_env = TransformedEnv(
    env=GymEnv(
        env_name="CartPole-v1",
        from_pixels=True,
        pixels_only=False
    ),
    transform=video_recorder
).to(device)

In [None]:
# env.rollout(max_steps=5)

# Agent

In [8]:
from torch import nn
import torch.nn.functional as F
from tensordict.nn import TensorDictModule, TensorDictSequential
from torchrl.modules import EGreedyModule, QValueModule

## Policy

In [74]:
class MLP_Model(nn.Module):
    def __init__(self, in_features, out_features) -> None:
        super().__init__()
        self.layer1 = nn.Linear(in_features=in_features, out_features=32)
        self.layer2 = nn.Linear(in_features=32, out_features=64)
        self.layer3 = nn.Linear(in_features=64, out_features=64)
        self.layer4 = nn.Linear(in_features=64, out_features=64)
        self.layer5 = nn.Linear(in_features=64, out_features=32)
        self.layer6 = nn.Linear(in_features=32, out_features=out_features)

    def forward(self, x):
        y = F.relu(self.layer1(x))
        y = F.relu(self.layer2(y))
        y = F.relu(self.layer3(y))
        y = F.relu(self.layer4(y))
        y = F.relu(self.layer5(y))
        y = F.relu(self.layer6(y))

        return y

In [75]:
value_mlp = MLP_Model(
    in_features=env.observation_spec['observation'].shape[0],
    out_features=env.action_spec.shape[0]
)

value_net = TensorDictModule(
    module=value_mlp,
    in_keys=["observation"],
    out_keys=["action_value"]
)

policy = TensorDictSequential(value_net, QValueModule(spec=env.action_spec))

exploration_module = EGreedyModule(
    env.action_spec, annealing_num_steps=100_000, eps_init=0.5
)

policy_explore = TensorDictSequential(policy,exploration_module).to(device)

## Test with random policy

In [76]:
test_env.rollout(max_steps=1000, policy=policy_explore)
video_recorder.dump()

## Data Collection and Replay Buffer

In [77]:
from torchrl.collectors import SyncDataCollector
from torchrl.data import LazyTensorStorage, ReplayBuffer

In [78]:
init_rand_steps = 5000

In [79]:
collector = SyncDataCollector(
    create_env_fn=env,
    policy=policy_explore,
    frames_per_batch=128,
    total_frames=-1,
    init_random_frames=init_rand_steps
)

In [80]:
collector

SyncDataCollector(
    env=TransformedEnv(
        env=GymEnv(env=CartPole-v1, batch_size=torch.Size([]), device=cuda:0),
        transform=StepCounter(keys=[])),
    policy=TensorDictSequential(
        module=ModuleList(
          (0): TensorDictSequential(
              module=ModuleList(
                (0): TensorDictModule(
                    module=MLP_Model(
                      (layer1): Linear(in_features=4, out_features=32, bias=True)
                      (layer2): Linear(in_features=32, out_features=64, bias=True)
                      (layer3): Linear(in_features=64, out_features=64, bias=True)
                      (layer4): Linear(in_features=64, out_features=64, bias=True)
                      (layer5): Linear(in_features=64, out_features=32, bias=True)
                      (layer6): Linear(in_features=32, out_features=2, bias=True)
                    ),
                    device=cuda:0,
                    in_keys=['observation'],
                    out_keys=['

In [81]:
rb = ReplayBuffer(storage=LazyTensorStorage(max_size=100_000, device=device))

In [82]:
rb

ReplayBuffer(
    storage=LazyTensorStorage(
        data=<empty>, 
        shape=None, 
        len=0, 
        max_size=100000), 
    sampler=RandomSampler(), 
    writer=RoundRobinWriter(cursor=0, full_storage=False), 
    batch_size=None, 
    collate_fn=<function _collate_id at 0x7c38e9f02f80>)

## Loss module and Optimizer

In [83]:
from torch.optim import Adam
from torchrl.objectives import DQNLoss, SoftUpdate

In [84]:
dqn_loss = DQNLoss(
    value_network=policy_explore,
    action_space=env.action_spec,
    delay_value=True
)

In [85]:
optim = Adam(params=dqn_loss.parameters(), lr=0.02)

In [86]:
updater = SoftUpdate(dqn_loss, eps=0.99)

# Training Loop

In [87]:
import time
from torchrl._utils import logger as torchrl_logger

In [88]:
optim_steps = 10

In [89]:
total_count = 0
total_episodes = 0
t0 = time.time()
for i, data in enumerate(collector):
    # Write data in replay buffer
    rb.extend(data)
    max_length = rb[:]["next", "step_count"].max()
    if len(rb) > init_rand_steps:
        # Optim loop (we do several optim steps
        # per batch collected for efficiency)
        for _ in range(optim_steps):
            sample = rb.sample(128)
            loss_vals = dqn_loss(sample)
            loss_vals["loss"].backward()
            optim.step()
            optim.zero_grad()
            # Update exploration factor
            exploration_module.step(data.numel())
            # Update target params
            updater.step()
            if i % 10:
                torchrl_logger.info(f"Max num steps: {max_length}, rb length {len(rb)}")
            total_count += data.numel()
            total_episodes += data["next", "done"].sum()
    if max_length > 400:
        break

t1 = time.time()

torchrl_logger.info(
    f"solved after {total_count} steps, {total_episodes} episodes and in {t1-t0}s."
)

2024-11-16 21:04:12,183 [torchrl][INFO] Max num steps: 123, rb length 5120
2024-11-16 21:04:12,202 [torchrl][INFO] Max num steps: 123, rb length 5120
2024-11-16 21:04:12,221 [torchrl][INFO] Max num steps: 123, rb length 5120
2024-11-16 21:04:12,239 [torchrl][INFO] Max num steps: 123, rb length 5120
2024-11-16 21:04:12,257 [torchrl][INFO] Max num steps: 123, rb length 5120
2024-11-16 21:04:12,275 [torchrl][INFO] Max num steps: 123, rb length 5120
2024-11-16 21:04:12,293 [torchrl][INFO] Max num steps: 123, rb length 5120
2024-11-16 21:04:12,311 [torchrl][INFO] Max num steps: 123, rb length 5120
2024-11-16 21:04:12,329 [torchrl][INFO] Max num steps: 123, rb length 5120
2024-11-16 21:04:12,349 [torchrl][INFO] Max num steps: 123, rb length 5120
2024-11-16 21:04:12,997 [torchrl][INFO] Max num steps: 123, rb length 5376
2024-11-16 21:04:13,017 [torchrl][INFO] Max num steps: 123, rb length 5376
2024-11-16 21:04:13,035 [torchrl][INFO] Max num steps: 123, rb length 5376
2024-11-16 21:04:13,053 [

# Results

In [90]:
test_env.rollout(max_steps=1000, policy=policy_explore)
video_recorder.dump()