In [None]:
# !apt-get update
# !apt-get install -y swig python3-dev

In [None]:
#!pip install -r requirements.txt

## PPO Training

In [None]:
import warnings
warnings.filterwarnings("ignore")
from torch import multiprocessing

from collections import defaultdict

# Information
from IPython.display import clear_output, display
import matplotlib.pyplot as plt
from tqdm import tqdm

# Torch
import torch
from torch import nn
from torch.distributions import OneHotCategorical

# TorchRL
from torchrl.envs.transforms import (
    TransformedEnv, Compose, ToTensorImage, ObservationNorm, StepCounter, DoubleToFloat, GrayScale, CatFrames, UnsqueezeTransform
    )
from torchrl.collectors import SyncDataCollector
from torchrl.data.replay_buffers import ReplayBuffer
from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
from torchrl.data.replay_buffers.storages import LazyTensorStorage
from torchrl.modules import ProbabilisticActor, SafeModule, ValueOperator
from torchrl.envs.utils import ExplorationType, set_exploration_type
from torchrl.objectives import ClipPPOLoss
from torchrl.objectives.value import GAE

# Environment
from torchrl.envs.libs.gym import GymEnv

# Other
import uuid
import os

### Hyper-parameters

In [None]:
# Device selection: use the first CUDA device only when CUDA is available AND
# the multiprocessing start method is not "fork"; otherwise stay on the CPU.
# (Presumably because CUDA cannot be used from forked workers — confirm.)
is_fork = multiprocessing.get_start_method() == "fork"
if not torch.cuda.is_available() or is_fork:
    device = torch.device("cpu")
else:
    device = torch.device(0)

print(f"Using device: {device}")

# --- Collector -------------------------------------------------------------
# Frames gathered per collector batch, and how many such batches make a run.
frames_per_batch = 2048
num_iterations = 4096

# Overall frame budget handed to the collector.
total_frames = num_iterations * frames_per_batch

# --- PPO -------------------------------------------------------------------
# Size of each minibatch drawn from the freshly collected batch.
sub_batch_size = 64
# Number of minibatch gradient steps taken per collected batch (the training
# loop draws exactly one minibatch per step).
num_epochs = 8

# --- Checkpointing ---------------------------------------------------------
# A checkpoint (and an evaluation rollout) happens every `checkpoint_interval`
# iterations; `filename` is a placeholder that the training cell overwrites.
checkpoint_interval = 16
filename = ""


### Creating the environment

In [None]:
# Base Gym environment: CarRacing-v3 with the discrete action set, rendering to
# RGB arrays, placed on the notebook-level `device`.
base_env = GymEnv("CarRacing-v3", continuous=False, render_mode="rgb_array", device=device)

# Wrap the base env in a TransformedEnv; the transforms run in the order listed.
# NOTE(review): UnsqueezeTransform(-4) is given no in_keys — confirm which
# entries (if any) it acts on, and that the "pixels" shape reaching CatFrames
# and the network is the intended [4, 96, 96].
env = TransformedEnv(base_env,
    Compose(
        DoubleToFloat(),
        ToTensorImage(),
        GrayScale(),
        UnsqueezeTransform(-4),
        CatFrames(dim=-3, N=4),              # stack N=4 frames along dim -3
        ObservationNorm(in_keys=["pixels"]),  # stats are initialised below
        StepCounter()
    )
)

# Initialise the ObservationNorm statistics. Index -2 is the second-to-last
# entry of the Compose above (ObservationNorm); keep it in sync if the list of
# transforms changes.
# NOTE(review): presumably this estimates loc/scale from 128 environment steps,
# reducing over dim 0 — confirm against the TorchRL documentation.
env.transform[-2].init_stats(num_iter=128, reduce_dim=0, cat_dim=0)

### Creating the model

In [None]:
class CarRacingBackbone(nn.Module):
    """Convolutional feature extractor with a linear head.

    Three convolutions (8x8/4, 4x4/2, 3x3/1) feed a 512-unit hidden layer and a
    final linear layer with ``n_actions`` outputs. The same class serves as the
    policy head (``n_actions`` logits) and as the critic (``n_actions=1``).

    All parameters are created on the notebook-level ``device``.

    Args:
        n_actions: Size of the output vector.
        n_frames: Number of input channels (stacked frames).
        img_size: ``(height, width)`` of the input frames; used only to size
            the first linear layer.
    """

    def __init__(self, n_actions: int, n_frames: int = 4, img_size: tuple = (96, 96)):
        super().__init__()
        self.conv1 = nn.Conv2d(n_frames, 32, kernel_size=8, stride=4, device=device)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2, device=device)
        self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1, device=device)

        # Derive the flattened feature size from a dummy pass so that it always
        # matches n_frames / img_size instead of being hard-coded.
        H, W = img_size
        with torch.no_grad():
            dummy = torch.zeros(1, n_frames, H, W, device=device)  # [1, n_frames, H, W]
            o = self.conv3(self.conv2(self.conv1(dummy)))
            self.flatten_size = o.view(1, -1).shape[1]  # 4096 for 96x96 inputs

        self.fc1 = nn.Linear(self.flatten_size, 512, device=device)
        self.logits = nn.Linear(512, n_actions, device=device)

    def forward(self, obs: torch.Tensor):
        """Map frames of shape ``[*batch, n_frames, H, W]`` to ``[*batch, n_actions]``.

        Any number of leading batch dimensions (including none) is accepted and
        reproduced on the output, so the result always has the same batch shape
        as the input.

        Raises:
            ValueError: If ``obs`` has fewer than three dimensions.
        """
        if obs.dim() < 3:
            raise ValueError(
                f"expected input of shape [*batch, C, H, W], got {tuple(obs.shape)}"
            )

        # Fold every leading dimension into a single batch axis for the convs.
        batch_shape = obs.shape[:-3]
        flat_batch = batch_shape.numel()            # 1 when there is no batch dim
        x = obs.reshape(flat_batch, *obs.shape[-3:])  # → [B, n_frames, H, W]

        x = torch.relu(self.conv1(x))               # 96x96 input → [B, 32, 23, 23]
        x = torch.relu(self.conv2(x))               # → [B, 64, 10, 10]
        x = torch.relu(self.conv3(x))               # → [B, 64, 8, 8]
        x = x.reshape(flat_batch, self.flatten_size)  # → [B, flatten_size]
        x = torch.relu(self.fc1(x))                 # → [B, 512]
        out = self.logits(x)                        # → [B, n_actions]

        # Restore the caller's batch dimensions.
        return out.reshape(*batch_shape, out.shape[-1])

# Policy backbone: a CarRacingBackbone with one output per action, wrapped so
# that it reads "pixels" from the tensordict and writes "logits".
backbone_net = SafeModule(
    module=CarRacingBackbone(env.action_spec.shape.numel()),
    in_keys=["pixels"],     # expects obs under key "pixels"
    out_keys=["logits"],    # produces a "logits" tensor
)

# Stochastic policy: builds a OneHotCategorical from the logits and samples an
# action from it.
discrete_actor = ProbabilisticActor(
    module=backbone_net,
    spec=env.action_spec,               # action spec taken from the env (presumably a one-hot spec — confirm)
    in_keys=["logits"],                 # read logits from that key
    distribution_class=OneHotCategorical,     # torch.distributions.OneHotCategorical: samples a one-hot vector of size n_actions
    out_keys=["action"],                # writes a one-hot action into "action"
    return_log_prob=True,               # also record the sampled action's log-probability (key name depends on the TorchRL version — confirm)
)

# Critic: a separate CarRacingBackbone instance with a single output, so no
# weights are shared with the actor. It reads "pixels" and writes "state_value".
critic_net = SafeModule(
    module=CarRacingBackbone(1),
    in_keys=["pixels"],
    out_keys=["state_value"],
)
# NOTE(review): critic_net is already a tensordict module with the same keys;
# confirm that the extra ValueOperator wrapper is intended.
critic = ValueOperator(
    module=critic_net,
    in_keys=["pixels"],
    out_keys=["state_value"],
)

### Data Collector

In [None]:
# Buffer that holds one collected batch at a time (capacity == frames_per_batch).
# The training loop fills it, draws minibatches of `sub_batch_size` without
# replacement, and empties it again before the next batch.
replay_buffer = ReplayBuffer(
    storage=LazyTensorStorage(max_size=frames_per_batch, device=device),
    sampler=SamplerWithoutReplacement(),
    batch_size=sub_batch_size,
)

# Collector that runs `discrete_actor` in `env` and yields batches of
# `frames_per_batch` frames until `total_frames` have been produced.
collector = SyncDataCollector(
    env,
    policy=discrete_actor,
    frames_per_batch=frames_per_batch,
    total_frames=total_frames,
    device=device,
    split_trajs=False,
)

# Generalised Advantage Estimation using the critic as the value network.
# The training loop calls it on every collected batch and then reads the
# "advantage" entry it produces.
gae_module = GAE(
    value_network=critic,
    gamma=0.99,
    lmbda=0.95,
)

def get_entropy_coef(iteration, total_iters):
    """Return the entropy-bonus coefficient for a given training iteration.

    The schedule is piecewise constant: 0.05 while ``iteration`` is below 20%
    of ``total_iters``, 0.02 while it is below 50%, and 0.01 from then on.
    """
    # (fraction of the run, coefficient used before that fraction is reached)
    stages = ((0.2, 0.05), (0.5, 0.02))
    for fraction, coef in stages:
        if iteration < fraction * total_iters:
            return coef
    return 0.01

# Clipped-surrogate PPO loss over the actor and critic defined above.
ppo_loss = ClipPPOLoss(
    actor_network=discrete_actor,
    critic_network=critic,
    clip_epsilon=0.2,
    loss_critic_type="smooth_l1",
    # `normalize_advantage` is not passed here: the training loop normalises
    # the advantages by hand before the data reaches the loss.
    # Initial entropy coefficient; the training loop overwrites it every
    # iteration with get_entropy_coef(i, num_iterations).
    entropy_coef=get_entropy_coef(0, num_iterations),
)

optimizer = torch.optim.AdamW(ppo_loss.parameters(), lr=3e-4)

# Learning-rate schedule, stepped once per training iteration:
#   1. linear warm-up from 1% of the base LR over the first 10% of iterations,
#   2. cosine decay down to eta_min over the remaining iterations.
# SequentialLR runs the two phases one after the other. ChainedScheduler would
# instead step both schedulers on every call, so the cosine decay would overlap
# the warm-up and, once past its T_max, start raising the LR again.
warmup_iters = max(1, int(0.1 * num_iterations))
scheduler = torch.optim.lr_scheduler.SequentialLR(
    optimizer,
    schedulers=[
        torch.optim.lr_scheduler.LinearLR(
            optimizer, start_factor=1e-2, total_iters=warmup_iters
        ),
        torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=max(1, num_iterations - warmup_iters), eta_min=1e-8
        ),
    ],
    milestones=[warmup_iters],
)

### Training

In [None]:
def plot(logs):
    """Replace the cell output with a fresh two-panel figure of the reward logs.

    Args:
        logs: Mapping holding the series ``"train_reward"`` and
            ``"eval_reward"``; each series is drawn in its own panel.
    """
    # Wipe the previous figure once the new one is ready to be shown.
    clear_output(wait=True)

    # (panel title, key in `logs`, line colour) — one entry per subplot.
    panels = (
        ("Avg Reward (Train)", "train_reward", "blue"),
        ("Avg Reward (Eval)", "eval_reward", "green"),
    )

    fig, axes = plt.subplots(1, 2, figsize=(6, 3))
    for ax, (title, key, color) in zip(axes.flatten(), panels):
        ax.plot(logs[key], color=color)
        ax.set_title(title)
        ax.relim()
        ax.autoscale_view()

    plt.tight_layout()

    # Show the figure explicitly, then close it so figures do not pile up
    # across training iterations.
    display(fig)
    plt.close(fig)


In [None]:
def _save_checkpoint(path):
    """Write the actor's current weights to ``path`` as ``{'model_state_dict': ...}``."""
    torch.save({'model_state_dict': discrete_actor.module.state_dict()}, path)


pbar = tqdm(total=num_iterations)
pbar.set_description("Training ")
logs = defaultdict(list)

# Every run writes into its own uniquely named directory. makedirs (rather than
# mkdir) also creates the parent ./checkpoints directory when it is missing.
runDir = f'./checkpoints/{uuid.uuid4()}'
os.makedirs(runDir, exist_ok=True)

for i, td in enumerate(collector):
    # Compute advantage + value targets for the freshly collected batch
    td = gae_module(td)

    # Normalize advantages over the whole batch (epsilon guards against a zero std)
    adv = td["advantage"]
    td.set("advantage", (adv - adv.mean()) / (adv.std() + 1e-8))

    # Update the loss's entropy coefficient in place according to the schedule
    with torch.no_grad():
        ppo_loss.entropy_coef.copy_(
            torch.tensor(get_entropy_coef(i, num_iterations), device=ppo_loss.entropy_coef.device)
        )

    # Load the batch into the buffer so minibatches can be drawn from it
    replay_buffer.extend(td)

    # One minibatch of `sub_batch_size` frames per step, drawn without
    # replacement. NOTE(review): with the notebook's settings this touches only
    # num_epochs * sub_batch_size = 512 of the 2048 collected frames per batch —
    # confirm that is intended.
    for _ in range(num_epochs):
        sample_td = replay_buffer.sample()

        # Total PPO loss = policy objective + critic loss + entropy term
        loss_vals = ppo_loss(sample_td)
        loss_value = (
            loss_vals["loss_objective"]
            + loss_vals["loss_critic"]
            + loss_vals["loss_entropy"]
        )

        # Backpropagate, clip the gradient norm to 1.0, and take an optimizer step
        loss_value.backward()
        torch.nn.utils.clip_grad_norm_(ppo_loss.parameters(), 1.0)
        optimizer.step()
        optimizer.zero_grad()

    # Drop the batch so the next iteration starts from an empty buffer
    replay_buffer.empty()

    logs["train_reward"].append(td["next", "reward"].mean().item())

    if i % checkpoint_interval == 0:
        # Evaluate with deterministic actions and without building a graph.
        # NOTE(review): this rollout runs on the same `env` object the collector
        # is stepping — confirm it does not disturb the collector's in-progress
        # episode (a dedicated evaluation env would avoid the question).
        with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad():
            eval_rollout = env.rollout(1024, discrete_actor)

        logs["eval_reward"].append(eval_rollout["next", "reward"].mean().item())
        del eval_rollout

        # Save a checkpoint named after the iteration index
        filename = f'{runDir}/{i}.ch'
        _save_checkpoint(filename)

    plot(logs)

    pbar.update(1)

    # The LR schedule advances once per collected batch
    scheduler.step()
pbar.close()

# Final save of the model
filename = f'{runDir}/_final.ch'
_save_checkpoint(filename)