In [1]:
#!pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu128
#!pip install -r requirements.txt

## PPO Training

In [2]:
import warnings
warnings.filterwarnings("ignore")
from torch import multiprocessing

from collections import defaultdict

import matplotlib.pyplot as plt
import torch
from tensordict.nn import TensorDictModule
from torch import nn
from torchrl.collectors import SyncDataCollector
from tensordict import TensorDict
from torchrl.data.replay_buffers import ReplayBuffer
from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
from torchrl.data.replay_buffers.storages import LazyTensorStorage
from torchrl.envs import ParallelEnv, TransformedEnv, Compose, StepCounter 
from torchrl.envs.transforms import Transform, RemoveEmptySpecs
from torchrl.envs.utils import check_env_specs, ExplorationType, set_exploration_type
from torchrl.modules import ProbabilisticActor, ValueOperator

from torch.distributions import Bernoulli

from torchrl.objectives import ClipPPOLoss, ValueEstimators
from torchrl.objectives.value import GAE
from tqdm import tqdm

from Environment.MazeEnv import TorchRLMazeEnv

import uuid
import os
import pickle
import json

In [3]:
is_fork = multiprocessing.get_start_method() == "fork"
device = (
    torch.device(0)
    if torch.cuda.is_available() and not is_fork
    else torch.device("cpu")
)
num_cells = 256  # number of cells in each layer i.e. output dim.
lr = 3e-4
max_grad_norm = 1.0

batches = 1
size = 4

frames_per_batch = batches * size
# For a complete training, bring the number of frames up to 1M
total_frames = frames_per_batch * 256

sub_batch_size = 32  # cardinality of the sub-samples gathered from the current data in the inner loop
num_epochs = 10# optimization steps per batch of data collected
clip_epsilon = (
    0.2  # clip value for PPO loss: see the equation in the intro for more context.
)
gamma = 0.99
lmbda = 0.95
entropy_eps = 1e-4

# Environment hyper-parameters
env_width = 128
env_height = 72

filename = ""

runID = uuid.uuid4()

runDir = f'./checkpoints/{runID}'

os.mkdir(runDir)

### Creating the environment

In [4]:
env = TorchRLMazeEnv(width=env_width, height=env_height, batches=batches, size=size, device=device, show_windows=True)

# Add transforms in the correct order
env = TransformedEnv(env, Compose( RemoveEmptySpecs(), StepCounter() ))

# check_env_specs(env)

### Creating the model

In [5]:
class ActorNet(nn.Module):
    def __init__(self, num_actions, num_cells=256):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=8, stride=4, device=device),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2, device=device),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1, device=device),
            nn.ReLU(),
            nn.Flatten(start_dim=1),
        )
        
        # Calculate the flattened size properly
        with torch.no_grad():
            dummy = torch.zeros(1, 3, env_height, env_width, device=device)  # Add batch dimension
            n_flat = self.features(dummy).shape[1]  # Get the flattened size

        self.head = nn.Sequential(
            nn.Linear(n_flat, num_cells, device=device),
            nn.ReLU(),
            nn.Linear(num_cells, num_actions, device=device),
        )

    def forward(self, x):
        if len(x.shape) == 3:
            x = x.unsqueeze(0)
        x = self.features(x)
        probs = self.head(x)
        if len(probs.shape) == 3:
            probs = probs.squeeze(0)
        return probs

# Create the actor and value operators as before
actor = TensorDictModule(
    module=ActorNet(6),
    in_keys=["observation"],
    out_keys=["logits"],
)

actor = ProbabilisticActor(
    module=actor,
    spec=env.action_spec,
    in_keys=["logits"],
    distribution_class=Bernoulli,
    distribution_kwargs={},
    return_log_prob=True,
    out_keys=["action"],
)

value = ValueOperator(
    module=ActorNet(1),
    in_keys=["observation"],
)

In [6]:
collector = SyncDataCollector(
    env,
    actor,
    frames_per_batch=128,
    total_frames=total_frames,
    device=device,
)

replay_buffer = ReplayBuffer(
    storage=LazyTensorStorage(max_size=frames_per_batch, device=device),
    sampler=SamplerWithoutReplacement(),
    batch_size=torch.Size([frames_per_batch]),
)

advantage_module = GAE(
    gamma=gamma, lmbda=lmbda, value_network=value, average_gae=True, device=device,
)

loss_module = ClipPPOLoss(
    actor_network=actor,
    critic_network=value,
    clip_epsilon=clip_epsilon,
    entropy_bonus=bool(entropy_eps),
    entropy_coef=entropy_eps,
    critic_coef=1.0,
    loss_critic_type="smooth_l1",
    device=device,
)

optim = torch.optim.Adam(loss_module.parameters(), lr)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optim, total_frames // frames_per_batch, 0.0
)

collector.rollout()

TensorDict(
    fields={
        action: Tensor(shape=torch.Size([4, 32, 6]), device=cuda:0, dtype=torch.float32, is_shared=True),
        collector: TensorDict(
            fields={
                traj_ids: Tensor(shape=torch.Size([4, 32]), device=cuda:0, dtype=torch.int64, is_shared=True)},
            batch_size=torch.Size([4, 32]),
            device=cuda:0,
            is_shared=True),
        done: Tensor(shape=torch.Size([4, 32, 1]), device=cuda:0, dtype=torch.bool, is_shared=True),
        logits: Tensor(shape=torch.Size([4, 32, 6]), device=cuda:0, dtype=torch.float32, is_shared=True),
        next: TensorDict(
            fields={
                done: Tensor(shape=torch.Size([4, 32, 1]), device=cuda:0, dtype=torch.bool, is_shared=True),
                observation: Tensor(shape=torch.Size([4, 32, 3, 72, 128]), device=cuda:0, dtype=torch.float32, is_shared=True),
                reward: Tensor(shape=torch.Size([4, 32, 1]), device=cuda:0, dtype=torch.float32, is_shared=True),


### Training

In [7]:
logs = defaultdict(list)
pbar = tqdm(total=total_frames)
eval_str = ""

for i, tensordict_data in enumerate(collector):

    tensordict_data = tensordict_data.squeeze(1) # Remove extra added dimension

    # we now have a batch of data to work with. Let's learn something from it.
    for _ in range(num_epochs):
        # We'll need an "advantage" signal to make PPO work.
        # We re-compute it at each epoch as its value depends on the value
        # network which is updated in the inner loop.
        advantage_module(tensordict_data)

        flat_bs = tensordict_data.batch_size[0]                                 # e.g. 64

        # 2) squeeze away any extra dims on obs first
        #    (you said your obs was already [batch, C, H, W] here)
        obs = tensordict_data["observation"]             # [T, N, C, H, W]
        obs = obs.reshape(flat_bs, *obs.shape[1:])       # [flat_bs, C, H, W]

        # 3) grab event-shapes for the others
        action_shape    = tensordict_data["action"].shape[1:]          # e.g. [6]
        logprob_shape   = tensordict_data["sample_log_prob"].shape[1:] # e.g. [6]
        # state_value, value_target, advantage are scalar => shape []

        # 4) rebuild
        flat_td = TensorDict(
            {
            "observation":    obs,
            "action":         tensordict_data["action"]
                                    .reshape(flat_bs, *action_shape),
            "sample_log_prob": tensordict_data["sample_log_prob"]
                                    .reshape(flat_bs, *logprob_shape),
            "state_value":    tensordict_data["state_value"]
                                    .reshape(flat_bs, 1),
            "value_target":   tensordict_data["value_target"]
                                    .reshape(flat_bs, 1),
            "advantage":      tensordict_data["advantage"]
                                    .reshape(flat_bs),      # scalar event
            },
            batch_size=torch.Size([flat_bs]),
            device=tensordict_data.device,
        )
        replay_buffer.extend(flat_td.cpu())

        for _ in range(frames_per_batch // sub_batch_size):
            subdata = replay_buffer.sample(sub_batch_size).to(device)
            loss_vals = loss_module(subdata)
            loss_value = (
                loss_vals["loss_objective"]
                + loss_vals["loss_critic"]
                + loss_vals["loss_entropy"]
            )

            # Optimization: backward, grad clipping and optimization step
            loss_value.backward()
            # this is not strictly mandatory but it's good practice to keep
            # your gradient norm bounded
            torch.nn.utils.clip_grad_norm_(loss_module.parameters(), max_grad_norm)
            optim.step()
            optim.zero_grad()

    logs["reward"].append(tensordict_data["next", "reward"].mean().item())
    pbar.update(tensordict_data.numel())
    cum_reward_str = (
        f"average reward={logs['reward'][-1]: 4.4f} (init={logs['reward'][0]: 4.4f})"
    )
    logs["step_count"].append(tensordict_data["step_count"].max().item())
    stepcount_str = f"step count (max): {logs['step_count'][-1]}"
    logs["lr"].append(optim.param_groups[0]["lr"])
    lr_str = f"lr policy: {logs['lr'][-1]: 4.4f}"
    if i % 10 == 0:
        # We evaluate the policy once every 10 batches of data.
        # Evaluation is rather simple: execute the policy without exploration
        # (take the expected value of the action distribution) for a given
        # number of steps (1000, which is our ``env`` horizon).
        # The ``rollout`` method of the ``env`` can take a policy as argument:
        # it will then execute this policy at each step.
        with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad():
            # execute a rollout with the trained policy
            eval_rollout = env.rollout(256, actor)
            logs["eval reward"].append(eval_rollout["next", "reward"].mean().item())
            logs["eval reward (sum)"].append(
                eval_rollout["next", "reward"].sum().item()
            )
            logs["eval step_count"].append(eval_rollout["step_count"].max().item())
            eval_str = (
                f"eval cumulative reward: {logs['eval reward (sum)'][-1]: 4.4f} "
                f"(init: {logs['eval reward (sum)'][0]: 4.4f}), "
                f"eval step-count: {logs['eval step_count'][-1]}"
            )
            del eval_rollout
    pbar.set_description(", ".join([eval_str, cum_reward_str, stepcount_str, lr_str]))

    if i % (frames_per_batch * 8) == 0:
        filename = f'{runDir}/{i}.ch'
        checkpoint = {
            'model_state_dict': actor.module.state_dict(),
            'reward': logs["eval reward"][-1],  # Assuming you have an optimizer
            'epoch': i,  # Assuming you have an epoch counter
        }

        torch.save(checkpoint, filename)
        
The above exception was the direct cause of the following exception:

    # We're also using a learning rate scheduler. Like the gradient clipping,
    # this is a nice-to-have but nothing necessary for PPO to work.
    scheduler.step()

filename = f'{runDir}/_final.ch'
checkpoint = {
    'model_state_dict': actor.module.state_dict(),
    'reward': logs["eval reward"][-1],  # Assuming you have an optimizer
    'epoch': i,  # Assuming you have an epoch counter
}

torch.save(checkpoint, filename)

  0%|          | 0/1024 [00:00<?, ?it/s]

RuntimeError: Expected 3D (unbatched) or 4D (batched) input to conv2d, but got input of size: [4, 32, 3, 72, 128]

In [None]:
with open(f'{runDir}/logs.json', 'w') as f:
    f.write(json.dumps(logs))

plt.figure(figsize=(10, 10))
plt.subplot(2, 2, 1)
plt.plot(logs["reward"])
plt.title("training rewards (average)")
plt.subplot(2, 2, 2)
plt.plot(logs["step_count"])
plt.title("Max step count (training)")
plt.subplot(2, 2, 3)
plt.plot(logs["eval reward (sum)"])
plt.title("Return (test)")
plt.subplot(2, 2, 4)
plt.plot(logs["eval step_count"])
plt.title("Max step count (test)")
plt.show()


In [None]:
# Load model for testing
if filename == "":
    filename = "checkpoints/47d6adb5-3abd-4640-8b31-50a1b174142d/_final.ch"

l_checkpoint = torch.load(filename)

actor.module.state_dict(l_checkpoint["model_state_dict"])

env = TorchRLMazeEnv(width=env_width, height=env_height, batches=1, size=1, device=device, show_windows=True)
env = TransformedEnv(env, Compose( RemoveEmptySpecs(), StepCounter() ))

env.rollout(512, actor)