In [None]:
#!pip install -r requirements.txt

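## PPO Training

Train a CNN-based PPO agent on the project's `TorchRLMazeEnv` maze environment using TorchRL (`SyncDataCollector`, `GAE`, `ClipPPOLoss`).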

In [None]:
import warnings
warnings.filterwarnings("ignore")
from torch import multiprocessing

from collections import defaultdict

import matplotlib.pyplot as plt
import torch
from tensordict.nn import TensorDictModule
from tensordict.nn.distributions import NormalParamExtractor
from torch import nn
from torchrl.collectors import SyncDataCollector
from torchrl.data.replay_buffers import ReplayBuffer
from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
from torchrl.data.replay_buffers.storages import LazyTensorStorage
from torchrl.envs import (Compose, DoubleToFloat, ObservationNorm, StepCounter, TransformedEnv)
from torchrl.envs.libs.gym import GymEnv
from torchrl.envs.utils import check_env_specs, ExplorationType, set_exploration_type
from torchrl.envs.transforms import UnsqueezeTransform
from torchrl.modules import ProbabilisticActor, TanhNormal, ValueOperator
from torchrl.objectives import ClipPPOLoss
from torchrl.objectives.value import GAE
from tqdm import tqdm

In [None]:
# Device: use the first CUDA device when available, but stay on CPU when the
# multiprocessing start method is "fork" (CUDA does not support forked workers).
is_fork = multiprocessing.get_start_method() == "fork"
device = (
    torch.device(0)
    if torch.cuda.is_available() and not is_fork
    else torch.device("cpu")
)
# NOTE(review): currently unused -- the model cell hard-codes 512 hidden units.
num_cells = 256  # number of cells in each layer i.e. output dim.
lr = 3e-4  # Adam learning rate (starting value of the cosine LR schedule)
max_grad_norm = 1.0  # gradient-norm clipping threshold

frames_per_batch = 32  # frames gathered per collector iteration
# For a complete training, bring the number of frames up to 1M
total_frames = 50_000

# Minibatch size for the inner optimisation loop. The replay buffer stores only
# `frames_per_batch` frames and samples without replacement, so this must not
# exceed `frames_per_batch`.
sub_batch_size = 16  # cardinality of the sub-samples gathered from the current data in the inner loop
num_epochs = 10  # number of optimisation epochs per collected batch
clip_epsilon = 0.2  # PPO clip range epsilon for the probability ratio
gamma = 0.99  # discount factor (passed to GAE)
lmbda = 0.95  # GAE lambda
entropy_eps = 1e-4  # entropy-bonus coefficient (0 disables the bonus)

# Environment hyper-parameters
env_width = 64
env_height = 64

# Fail fast on an inconsistent configuration instead of silently drawing
# truncated (or no) minibatches later on.
assert 0 < sub_batch_size <= frames_per_batch, (
    f"sub_batch_size={sub_batch_size} must be in [1, frames_per_batch={frames_per_batch}]"
)

### Creating the environment

In [None]:
from Environment.MazeEnv import TorchRLMazeEnv

# Base maze environment; its size and device come from the configuration cell.
base_env = TorchRLMazeEnv(width=env_width, height=env_height, device=device)

# Wrap it with a StepCounter transform so every step carries a "step_count" entry.
env = TransformedEnv(base_env, Compose(StepCounter()))

# Verify that the declared specs agree with what reset/step actually produce.
check_env_specs(env)

In [None]:
# Smoke-test the wrapped environment. No policy is passed, so `env.rollout`
# steps it with random actions drawn from its action spec.
num_smoke_steps = 3
rollout = env.rollout(max_steps=num_smoke_steps)
print("rollout of three steps:", rollout)

### Creating the model

In [None]:
# Actor (policy) and critic (value) networks.
#
# Both networks use the same convolutional torso. It is built by one helper so
# the layer shapes of the two networks cannot drift apart.
# NOTE(review): assumes "observation" is a float, channels-first image with 3
# channels, i.e. (3, H, W) or (batch, 3, H, W) -- confirm against
# TorchRLMazeEnv's observation_spec.
num_actions = 6  # size of the actor's logit vector, one logit per binary action


def make_cnn_torso(device):
    """Build the shared CNN feature extractor (3 conv + ReLU layers, then flatten).

    ``nn.Flatten(start_dim=-3)`` flattens only the trailing (C, H, W) dims. The
    torso therefore accepts an unbatched (C, H, W) observation as well as a
    batched (B, C, H, W) one. The default ``nn.Flatten()`` (start_dim=1) would
    instead turn an unbatched (C', H', W') feature map into (C', H' * W').
    """
    return nn.Sequential(
        nn.Conv2d(3, 32, kernel_size=8, stride=4, device=device),
        nn.ReLU(),
        nn.Conv2d(32, 64, kernel_size=4, stride=2, device=device),
        nn.ReLU(),
        nn.Conv2d(64, 64, kernel_size=3, stride=1, device=device),
        nn.ReLU(),
        nn.Flatten(start_dim=-3),
    )


def independent_bernoulli(logits):
    """Joint distribution over ``logits.shape[-1]`` independent binary actions.

    ``Independent(..., 1)`` sums the per-action log-probabilities. This gives a
    single log-probability per step, which PPO's importance ratio needs.
    """
    return torch.distributions.Independent(
        torch.distributions.Bernoulli(logits=logits), 1
    )


# Actor network: CNN torso followed by an MLP head that outputs raw logits.
# There is no Sigmoid because Bernoulli(logits=...) applies it internally.
actor_net = nn.Sequential(
    make_cnn_torso(device),
    nn.LazyLinear(512, device=device),  # in_features inferred on the first forward
    nn.ReLU(),
    nn.Linear(512, num_actions, device=device),
)

# ClipPPOLoss needs a probabilistic actor, because it re-evaluates the
# log-probability of the collected actions. Those log-probs are recorded at
# collection time (return_log_prob=True).
# NOTE(review): Bernoulli samples are float 0./1. -- confirm this matches the
# dtype/shape of TorchRLMazeEnv's action_spec.
actor = ProbabilisticActor(
    module=TensorDictModule(actor_net, in_keys=["observation"], out_keys=["logits"]),
    spec=env.action_spec,
    in_keys=["logits"],
    distribution_class=independent_bernoulli,
    return_log_prob=True,
)

# Value network: same torso, with an MLP head producing a single state value.
value_net = nn.Sequential(
    make_cnn_torso(device),
    nn.LazyLinear(512, device=device),  # in_features inferred on the first forward
    nn.ReLU(),
    nn.Linear(512, 1, device=device),  # Output a single value
)

# Create the value operator (writes "state_value" by default).
value = ValueOperator(value_net, in_keys=["observation"])

# Materialise the LazyLinear layers with one forward pass. This ensures the loss
# module and optimizer built in the next cell see fully initialised parameters.
with torch.no_grad():
    init_obs = env.reset()["observation"]
    actor_net(init_obs)
    value_net(init_obs)

In [None]:
# Data collection: steps `env` with `actor` and yields batches of
# `frames_per_batch` frames until `total_frames` frames have been gathered.
collector = SyncDataCollector(
    env,
    actor,
    frames_per_batch=frames_per_batch,
    total_frames=total_frames,
    split_trajs=False,
    device=device,
)

# On-policy buffer: capacity is a single collected batch. Sampling without
# replacement means minibatches drawn within one pass over the buffer do not
# repeat frames.
replay_buffer = ReplayBuffer(
    storage=LazyTensorStorage(max_size=frames_per_batch),
    sampler=SamplerWithoutReplacement(),
)

# Generalized Advantage Estimation driven by the critic `value`.
# average_gae=True standardises the computed advantages.
advantage_module = GAE(
    gamma=gamma, lmbda=lmbda, value_network=value, average_gae=True, device=device,
)

loss_module = ClipPPOLoss(
    actor_network=actor,
    critic_network=value,
    clip_epsilon=clip_epsilon,
    # The entropy bonus is enabled only when its coefficient is non-zero.
    entropy_bonus=bool(entropy_eps),
    entropy_coef=entropy_eps,
    critic_coef=1.0,
    loss_critic_type="smooth_l1",
)

optim = torch.optim.Adam(loss_module.parameters(), lr)
# The LR is annealed from `lr` to 0 over one scheduler step per collected batch.
# scheduler.step() should therefore be called once per batch, not per gradient step.
num_batches = total_frames // frames_per_batch
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optim, T_max=num_batches, eta_min=0.0
)

In [None]:
# Train with collector
#
# For every batch yielded by the collector:
#   1. run `num_epochs` optimisation epochs. Each epoch recomputes GAE advantages
#      with the current critic, refills the on-policy buffer, and takes
#      `num_minibatches` clipped-PPO gradient steps on minibatches drawn without
#      replacement;
#   2. record training metrics in `logs` (read by the plotting cell);
#   3. every `eval_interval` batches, roll out the greedy policy;
#   4. step the LR scheduler once (its T_max counts batches, not gradient steps).
logs = defaultdict(list)
# The buffer never holds more than one batch, so clamp the minibatch size to it.
minibatch_size = min(sub_batch_size, frames_per_batch)
num_minibatches = max(1, frames_per_batch // minibatch_size)
eval_interval = 10  # evaluate every N collected batches
eval_max_steps = 1000  # upper bound on the length of an evaluation rollout

pbar = tqdm(total=total_frames)
for i, tensordict_data in enumerate(collector):
    last_loss = float("nan")
    for _ in range(num_epochs):
        # Writes the advantage / value-target entries into `tensordict_data` in place.
        advantage_module(tensordict_data)
        replay_buffer.extend(tensordict_data.reshape(-1))
        for _ in range(num_minibatches):
            # The buffer storage may live on a different device than the networks.
            subdata = replay_buffer.sample(minibatch_size).to(device)
            loss_vals = loss_module(subdata)
            # ClipPPOLoss returns a TensorDict, not a scalar. It can also contain
            # non-loss diagnostics, so the objective is the sum of the "loss_*" terms.
            loss_value = sum(
                term for key, term in loss_vals.items() if key.startswith("loss_")
            )
            optim.zero_grad(set_to_none=True)
            loss_value.backward()
            nn.utils.clip_grad_norm_(loss_module.parameters(), max_grad_norm)
            optim.step()
            last_loss = loss_value.item()

    logs["reward"].append(tensordict_data["next", "reward"].mean().item())
    logs["step_count"].append(tensordict_data["step_count"].max().item())
    logs["lr"].append(optim.param_groups[0]["lr"])
    pbar.update(tensordict_data.numel())
    pbar.set_description(
        f"avg reward={logs['reward'][-1]: 4.4f}, "
        f"max steps={logs['step_count'][-1]}, "
        f"loss={last_loss: 4.4f}, "
        f"lr={logs['lr'][-1]: 4.4e}"
    )

    if i % eval_interval == 0:
        # Greedy evaluation: act with the mode of the action distribution.
        # NOTE(review): `env` is also the collector's environment. If the collector
        # steps it directly (not a copy), this rollout's reset interrupts the
        # in-progress training trajectory -- consider a dedicated eval env.
        with set_exploration_type(ExplorationType.MODE), torch.no_grad():
            eval_rollout = env.rollout(eval_max_steps, actor)
        logs["eval reward (sum)"].append(eval_rollout["next", "reward"].sum().item())
        logs["eval step_count"].append(eval_rollout["step_count"].max().item())
        del eval_rollout

    # One cosine-schedule step per collected batch (T_max = number of batches).
    scheduler.step()
pbar.close()

In [None]:
# Training curves (top row) and evaluation curves (bottom row) from `logs`.
# Each panel plots one logged series against the index of its entries.
panels = [
    ("reward", "training rewards (average)", "logged batch"),
    ("step_count", "Max step count (training)", "logged batch"),
    ("eval reward (sum)", "Return (test)", "evaluation run"),
    ("eval step_count", "Max step count (test)", "evaluation run"),
]
fig, axes = plt.subplots(2, 2, figsize=(10, 10))
for ax, (key, title, xlabel) in zip(axes.flat, panels):
    ax.plot(logs[key])
    ax.set(title=title, xlabel=xlabel, ylabel=key)
fig.tight_layout()
plt.show()