In [None]:
!pip install -r train_requirements.txt

## PPO Training

In [None]:
# Standard library
import json
import os
import warnings
from collections import defaultdict

# Silence library warnings before the heavier imports below can emit them
warnings.filterwarnings("ignore")

# Information
from IPython.display import clear_output, display
import matplotlib.pyplot as plt
from tqdm import tqdm

# Torch
import torch
from tensordict.nn import TensorDictModule
from torch import multiprocessing
from torch import nn
from torch.distributions import Categorical

# TorchRL
from torchrl.collectors import SyncDataCollector
from torchrl.data.replay_buffers import ReplayBuffer
from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
from torchrl.data.replay_buffers.storages import LazyTensorStorage
from torchrl.envs import (Compose, DoubleToFloat, ObservationNorm, StepCounter, TransformedEnv)
from torchrl.envs.utils import ExplorationType, set_exploration_type
from torchrl.modules import (
    ActorValueOperator,
    NormalParamExtractor,
    OneHotCategorical,
    ProbabilisticActor,
    SafeModule,
    TanhNormal,
    ValueOperator,
)
from torchrl.objectives import ClipPPOLoss
from torchrl.objectives.value import GAE

# Environment
from torchrl.envs.libs.gym import GymEnv

In [None]:
# --- Device selection ---------------------------------------------------
# Use the first CUDA device only when CUDA is available and the
# multiprocessing start method is not "fork"; otherwise stay on the CPU.
is_fork = multiprocessing.get_start_method() == "fork"
if torch.cuda.is_available() and not is_fork:
    device = torch.device(0)
else:
    device = torch.device("cpu")

# --- Optimiser ----------------------------------------------------------
lr = 3e-4            # Adam learning rate
max_grad_norm = 1.0  # gradient-norm clipping threshold

# --- Data collection ----------------------------------------------------
frames_per_batch = 128  # frames gathered per collector batch (1024 was also tried)
total_frames = 65536    # total number of frames to collect

# --- PPO ----------------------------------------------------------------
sub_batch_size = 64   # size of each sub-sample drawn in the inner loop
num_epochs = 8        # optimisation passes over every collected batch
clip_epsilon = 0.2    # clip value for the PPO loss
gamma = 0.99          # passed to GAE as `gamma`
lmbda = 0.95          # passed to GAE as `lmbda`
entropy_eps = 1e-4    # entropy coefficient of the PPO loss

# --- Checkpointing ------------------------------------------------------
checkpoint_interval = 8  # an evaluation rollout runs every this many batches
filename = ""            # checkpoint to load; "" means "use the latest run"


### Creating the environment

In [None]:
# Build the discrete-action CarRacing environment on the selected device.
env = GymEnv(
    "CarRacing-v3",
    continuous=False,
    render_mode="rgb_array",
    device=device,
)

# Show every spec so the tensordict keys and shapes the networks must
# consume/produce can be checked by eye.
spec_report = (
    ("observation_spec:", env.observation_spec),
    ("reward_spec:", env.reward_spec),
    ("input_spec:", env.input_spec),
    ("action_spec (as defined by input_spec):", env.action_spec),
)
for label, spec in spec_report:
    print(label, spec)

### Creating the model

In [None]:
class RaceCNN(nn.Module):
    """Small convolutional network mapping RGB frames to a flat output vector.

    Three convolution + ReLU stages feed a two-layer fully connected head.
    The head's input width is measured once at construction time with a
    dummy 3x96x96 frame, so the network expects 96x96 RGB images.

    All layers are created on the notebook-level ``device``.

    Args:
        num_actions: width of the output vector (number of logits for the
            policy, or 1 for a value head).
    """

    def __init__(self, num_actions):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=8, stride=4, device=device),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2, device=device),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1, device=device),
            nn.ReLU(),
            nn.Flatten(),
        )

        # Measure the flattened feature size with a dummy batch of one frame.
        with torch.no_grad():
            dummy = torch.zeros((1, 3, 96, 96), device=device)
            n_flat = self.features(dummy).shape[1]

        self.head = nn.Sequential(
            nn.Linear(n_flat, 512, device=device),
            nn.ReLU(),
            nn.Linear(512, num_actions, device=device),
        )

    def forward(self, x):
        """Run the network on one frame or on any batch of frames.

        Args:
            x: image tensor shaped ``(*batch, H, W, 3)`` or
                ``(*batch, 3, H, W)`` with zero or more leading batch
                dimensions. ``uint8`` input is rescaled to ``[0, 1]`` floats.

        Returns:
            Tensor of shape ``(*batch, num_actions)``.

        Raises:
            ValueError: if ``x`` has fewer than three dimensions.
        """
        if x.ndim < 3:
            raise ValueError(
                f"expected an image tensor with at least 3 dims, got shape {tuple(x.shape)}"
            )

        if x.dtype == torch.uint8:
            x = x.float() / 255.0

        # Channels-last input is moved to channels-first. A bare 3-D tensor
        # whose first dim is 3 is treated as already being (C, H, W).
        already_chw = x.ndim == 3 and x.shape[0] == 3
        if x.shape[-1] == 3 and not already_chw:
            x = x.movedim(-1, -3)  # (..., H, W, C) -> (..., C, H, W)

        # nn.Flatten keeps dim 0, so an unbatched frame would have its channel
        # axis mistaken for the batch. Collapse all leading dims into a single
        # batch dim (adding one if absent) and restore them on the way out.
        batch_shape = x.shape[:-3]
        x = x.reshape(-1, *x.shape[-3:])
        out = self.head(self.features(x))
        return out.reshape(*batch_shape, out.shape[-1])

# Key under which the environment publishes its image observation.
# NOTE(review): GymEnv appears to expose a plain (non-dict) Gym observation as
# "observation" unless it is built with from_pixels=True, in which case the key
# is "pixels". Resolving the key from the spec works in both cases.
obs_key = "pixels" if "pixels" in env.observation_spec.keys() else "observation"

# Policy network: image observation -> one logit per discrete action.
actor_net = TensorDictModule(
    module=RaceCNN(env.action_spec.shape[-1]),
    in_keys=[obs_key],
    out_keys=["logits"],
)

# Wrap the logits in a distribution and sample the action from it.
# GymEnv encodes discrete actions one-hot by default, so the sampled action
# must be one-hot as well; torch's Categorical would emit integer indices.
actor = ProbabilisticActor(
    module=actor_net,
    spec=env.action_spec,
    in_keys=["logits"],
    distribution_class=OneHotCategorical,
    return_log_prob=True,
    out_keys=["action"],
)

# Critic: same architecture with a single scalar output.
value = ValueOperator(
    module=RaceCNN(1),
    in_keys=[obs_key],
)

# The env and policy are passed positionally (the first parameter is named
# `create_env_fn`, not `env`). `total_frames` bounds the collection; without
# it the collector would iterate forever.
collector = SyncDataCollector(
    env,
    actor,
    frames_per_batch=frames_per_batch,
    total_frames=total_frames,
    split_trajs=False,
    device=device,
)

# The buffer only ever holds the current batch; it is refilled every epoch and
# sampled without replacement so each frame is seen once per epoch.
replay_buffer = ReplayBuffer(
    storage=LazyTensorStorage(max_size=frames_per_batch),
    sampler=SamplerWithoutReplacement(),
)

advantage_module = GAE(
    gamma=gamma, lmbda=lmbda, value_network=value, average_gae=True, device=device,
)

loss_module = ClipPPOLoss(
    actor_network=actor,
    critic_network=value,
    clip_epsilon=clip_epsilon,
    # the entropy term is enabled only when its coefficient is non-zero
    entropy_bonus=bool(entropy_eps),
    entropy_coef=entropy_eps,
    critic_coef=1.0,
    loss_critic_type="smooth_l1",
)

optim = torch.optim.Adam(loss_module.parameters(), lr)
# One scheduler step per collected batch, annealing the learning rate to 0.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optim, total_frames // frames_per_batch, 0.0
)

### Training

In [None]:
def _plot_training_curves(logs):
    """Replace the cell output with up-to-date train/eval reward curves.

    Args:
        logs: mapping with "train_reward" and "eval_reward" lists of floats.
    """
    clear_output(wait=True)

    # Rebuild the figure from scratch on every call
    fig, axes = plt.subplots(1, 2, figsize=(6, 3))
    panels = [
        ("Avg Reward (Train)", logs["train_reward"], "blue"),
        ("Avg Reward (Eval)", logs["eval_reward"], "green"),
    ]
    for ax, (title, series, color) in zip(axes.flatten(), panels):
        ax.plot(series, color=color)
        ax.set_title(title)

    plt.tight_layout()

    # Display the new figure, then free it so figures do not pile up in memory
    display(fig)
    plt.close(fig)


pbar = tqdm(total=total_frames)
pbar.set_description("Training ")

logs = defaultdict(list)

# # Save model
# num_runs = len(os.listdir("checkpoints"))
# runDir = f'./checkpoints/{num_runs}'
# os.mkdir(runDir)

try:
    # We get a batch of data from the collector
    for i, tensordict_data in enumerate(collector):

        for _ in range(num_epochs):
            # The advantage is recomputed every epoch because the critic
            # it relies on is updated in the inner loop.
            advantage_module(tensordict_data)

            data_view = tensordict_data.reshape(-1)

            # Add the data to the replay buffer
            replay_buffer.extend(data_view.cpu())

            # Sample sub-batches from the replay buffer
            for _ in range(frames_per_batch // sub_batch_size):
                subdata = replay_buffer.sample(sub_batch_size)

                # Compute the loss
                loss_vals = loss_module(subdata.to(device))
                loss_value = (
                    loss_vals["loss_objective"]
                    + loss_vals["loss_critic"]
                    + loss_vals["loss_entropy"]
                )

                # Backpropagate and optimize
                loss_value.backward()
                torch.nn.utils.clip_grad_norm_(loss_module.parameters(), max_grad_norm)
                optim.step()
                optim.zero_grad()

        logs["train_reward"].append(tensordict_data["next", "reward"].mean().item())

        # Execute the policy without exploration
        if i % checkpoint_interval == 0:
            pbar.set_description("Evaluation ")
            with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad():
                # NOTE(review): this reuses the env the collector is stepping,
                # and a rollout resets it - confirm that is acceptable or
                # switch to a dedicated evaluation env.
                eval_rollout = env.rollout(512, actor)

                # Save the evaluation data
                logs["eval_reward"].append(eval_rollout["next", "reward"].mean().item())
                del eval_rollout

                # # Save a checkpoint
                # filename = f'{runDir}/{i}.ch'
                # checkpoint = {
                #     'model_state_dict': actor.module.state_dict(),
                # }

                # torch.save(checkpoint, filename)

        pbar.set_description("Training ")

        # Update plot data
        _plot_training_curves(logs)

        # Learning rate scheduling
        scheduler.step()

        # Update the progress bar
        pbar.update(tensordict_data.numel())
finally:
    # Release the progress bar even when training is interrupted
    pbar.close()

# # Final save of the model
# filename = f'{runDir}/_final.ch'
# checkpoint = {
#     'model_state_dict': actor.module.state_dict(),
# }

# torch.save(checkpoint, filename)

In [None]:
# Load model for testing
if filename == "":
    # No checkpoint was chosen or produced in this session:
    # fall back to the final checkpoint of the most recent run.
    num_runs = len(os.listdir("checkpoints"))
    if num_runs == 0:
        raise FileNotFoundError(
            "no runs found under ./checkpoints; enable checkpoint saving "
            "in the training cell or set `filename` in the config cell"
        )
    runDir = f'./checkpoints/{num_runs-1}'
    checkpoint_path = f'{runDir}/_final.ch'
else:
    # `filename` points at a checkpoint: store the logs next to it.
    checkpoint_path = filename
    runDir = os.path.dirname(filename) or "."
    with open(f'{runDir}/logs.json', 'w') as f:
        f.write(json.dumps(logs))

# `filename` is deliberately left untouched so that re-running this cell takes
# the same branch again and does not overwrite an earlier run's logs.json.

# map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine.
l_checkpoint = torch.load(checkpoint_path, map_location=device)

# load_state_dict copies the saved weights into the model; state_dict() only
# reads the current weights and would leave the model unchanged.
actor.module.load_state_dict(l_checkpoint["model_state_dict"])

env.rollout(512, actor)

In [None]:
# Reload the metrics stored alongside the selected run.
with open(f'{runDir}/logs.json', 'r') as f:
    logs = json.load(f)

# One panel per metric, side by side.
fig, axes = plt.subplots(1, 2, figsize=(6, 3))
axes = axes.flatten()

# Panel captions and the (series, line colour) pair drawn in each panel.
# Step-count panels are currently disabled.
titles = ["Avg Reward (Train)", "Avg Reward (Eval)"]
data = [
    (logs["train_reward"], "blue"),
    (logs["eval_reward"], "green"),
]

n_panels = min(len(axes), len(titles), len(data))
for idx in range(n_panels):
    panel = axes[idx]
    series, line_color = data[idx]
    panel.plot(series, color=line_color)
    panel.set_title(titles[idx])
    # Recompute the data limits and rescale the view to fit the curve
    panel.relim()
    panel.autoscale_view()

plt.tight_layout()