<a href="https://colab.research.google.com/github/Josh-Slagle/Intro_to_MARL_TorchRL/blob/main/Introduction_to_MARL_TorchRL.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Dependencies

In [1]:
%%capture
!pip install \
  antlr4-python3-runtime==4.9.3 \
  certifi==2024.7.4 \
  charset-normalizer==3.3.2 \
  click==8.1.7 \
  cloudpickle==3.0.0 \
  docker-pycreds==0.4.0 \
  filelock==3.15.4 \
  fsspec==2024.6.1 \
  gitdb==4.0.11 \
  GitPython==3.1.43 \
  gym==0.26.2 \
  gym-notices==0.0.8 \
  hydra-core==1.3.2 \
  idna==3.7 \
  Jinja2==3.1.4 \
  MarkupSafe==2.1.5 \
  mpmath==1.3.0 \
  networkx==3.3 \
  numpy==2.0.0 \
  nvidia-cublas-cu12==12.1.3.1 \
  nvidia-cuda-cupti-cu12==12.1.105 \
  nvidia-cuda-nvrtc-cu12==12.1.105 \
  nvidia-cuda-runtime-cu12==12.1.105 \
  nvidia-cudnn-cu12==8.9.2.26 \
  nvidia-cufft-cu12==11.0.2.54 \
  nvidia-curand-cu12==10.3.2.106 \
  nvidia-cusolver-cu12==11.4.5.107 \
  nvidia-cusparse-cu12==12.1.0.106 \
  nvidia-nccl-cu12==2.20.5 \
  nvidia-nvjitlink-cu12==12.5.82 \
  nvidia-nvtx-cu12==12.1.105 \
  omegaconf==2.3.0 \
  packaging==24.1 \
  platformdirs==4.2.2 \
  protobuf==5.27.2 \
  psutil==6.0.0 \
  pyglet==1.5.27 \
  PyYAML==6.0.1 \
  requests==2.32.3 \
  sentry-sdk==2.10.0 \
  setproctitle==1.3.3 \
  six==1.16.0 \
  smmap==5.0.1 \
  sympy==1.13.0 \
  tensordict==0.4.0 \
  torch==2.3.1 \
  torchrl==0.4.0 \
  tqdm==4.66.4 \
  triton==2.3.1 \
  typing_extensions==4.12.2 \
  urllib3==2.2.2 \
  vmas==1.4.2  \
  wandb==0.17.4

# Itinerary

## Learning Objective
The goal of this lightning talk is to introduce Multi-Agent Reinforcement Learning (MARL) while giving a quick overview of some of the modern tools used to develop in this field.


# Material

## Super Quick Overview of RL

Reinforcement Learning (RL) is a family of algorithms that learn by trial and error, exploring a defined environment.
The key components of RL are:
* Environment
* Agent

The goal in RL is to select a *policy* which maximizes *expected return* when the agent acts according to it. [OpenAI Spinning Up](https://spinningup.openai.com/en/latest/spinningup/rl_intro.html)
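The *return* of a trajectory is usually the sum of its rewards, each weighted by a power of a discount factor γ (the `gamma` used later in this notebook's config). A minimal plain-Python sketch, assuming a finite list of hypothetical rewards:

```python
def discounted_return(rewards, gamma=0.99):
    """Return sum_t gamma**t * r_t for one finite trajectory."""
    total = 0.0
    # Walk backwards so each step folds in the discounted future: G_t = r_t + gamma * G_{t+1}
    for r in reversed(rewards):
        total = r + gamma * total
    return total

# Hypothetical 3-step episode: 1 + 0.5 * 1 + 0.25 * 1
print(discounted_return([1.0, 1.0, 1.0], gamma=0.5))  # 1.75
```

The policy an RL algorithm searches for is the one that maximizes the *expected* value of this quantity over the trajectories it produces.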



Lightning_Talk_MARL_Intro_Diagrams.drawio.svg


## What is MARL & Why is it interesting?


Multi-Agent Reinforcement Learning (MARL) extends traditional Reinforcement Learning (RL) to environments where multiple agents interact and learn simultaneously. Each agent in MARL aims to optimize its own policy while considering the presence and potential actions of other agents.
* Cooperative, competitive, or mixed interactions
* Shared & Individual Policy/Value nets (IPPO vs MAPPO)

In [MAPPO](https://arxiv.org/abs/2103.01955) the critic is **centralised**: its input is the global state or the concatenation of all agents' observations.
In [IPPO](https://arxiv.org/abs/2011.09533) each agent's critic sees only that agent's own observation (just like the policy network). This enables *decentralised* training using only local information.
[TorchRL's Documentation](https://github.com/pytorch/rl/tree/v0.5.0)
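The difference comes down to what each critic receives as input. The sketch below uses plain Python lists (not TorchRL) and hypothetical observations to build each agent's critic input under both schemes; in TorchRL the same choice is made through the `centralised` flag of `MultiAgentMLP`, which this notebook sets from its config.

```python
def critic_inputs(observations, centralised):
    """Build the input each agent's critic receives.

    observations: one observation (list of floats) per agent.
    centralised=True  -> MAPPO-style: every critic sees all agents' observations concatenated.
    centralised=False -> IPPO-style: each critic sees only its own agent's observation.
    """
    if centralised:
        joint = [x for obs in observations for x in obs]  # concatenate all observations
        return [list(joint) for _ in observations]  # same joint input for every agent
    return [list(obs) for obs in observations]

obs = [[0.1, 0.2], [0.3, 0.4]]  # hypothetical: 2 agents, 2 features each
print(critic_inputs(obs, centralised=True))   # [[0.1, 0.2, 0.3, 0.4], [0.1, 0.2, 0.3, 0.4]]
print(critic_inputs(obs, centralised=False))  # [[0.1, 0.2], [0.3, 0.4]]
```

The centralised input grows with the number of agents, which is the price MAPPO pays for giving the critic a global view.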



## Tools used in this talk


* [Weights & Biases](https://wandb.ai/site) - Logging & Model Management
* [Hydra](https://hydra.cc/) - Config File Management
* [TorchRL](https://github.com/pytorch/rl/tree/v0.5.0) - PyTorch based RL Library
* [VMAS](https://vmas.readthedocs.io/en/latest/) - Vectorized Multi-Agent Simulation Library

## TorchRL's view on RL


* TorchRL is designed to be highly modular, so individual components can be swapped out and explored independently.
* Data is passed from component to component using *TensorDicts*.


This example's modular structure:

Lightning_Talk_MARL_Intro_Diagrams-TorchRL Modules.drawio.svg

TensorDict Evolution step-by-step

In [2]:
from torchrl.envs.libs.vmas import VmasEnv
env = VmasEnv(
    scenario="discovery",
    num_envs=1,
    continuous_actions=True,  # VMAS supports both continuous and discrete actions
    max_steps=10,
    n_agents=2,  # These are custom kwargs that change for each VMAS scenario, see the VMAS repo to know more.
)



#### Environment Reset

In [3]:
reset = env.reset()
print(reset)



TensorDict(
    fields={
        agents: TensorDict(
            fields={
                info: TensorDict(
                    fields={
                        collision_rew: Tensor(shape=torch.Size([1, 2, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                        covering_reward: Tensor(shape=torch.Size([1, 2, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                        targets_covered: Tensor(shape=torch.Size([1, 2, 1]), device=cpu, dtype=torch.float32, is_shared=False)},
                    batch_size=torch.Size([1, 2]),
                    device=cpu,
                    is_shared=False),
                observation: Tensor(shape=torch.Size([1, 2, 21]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([1, 2]),
            device=cpu,
            is_shared=False),
        done: Tensor(shape=torch.Size([1, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        terminated: Tensor(shape=torch.Size([1, 1

#### Random Action
Note the additional 'action' field!

In [4]:
reset_with_action = env.rand_action(reset)
print(reset_with_action)

TensorDict(
    fields={
        agents: TensorDict(
            fields={
                action: Tensor(shape=torch.Size([1, 2, 2]), device=cpu, dtype=torch.float32, is_shared=False),
                info: TensorDict(
                    fields={
                        collision_rew: Tensor(shape=torch.Size([1, 2, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                        covering_reward: Tensor(shape=torch.Size([1, 2, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                        targets_covered: Tensor(shape=torch.Size([1, 2, 1]), device=cpu, dtype=torch.float32, is_shared=False)},
                    batch_size=torch.Size([1, 2]),
                    device=cpu,
                    is_shared=False),
                observation: Tensor(shape=torch.Size([1, 2, 21]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([1, 2]),
            device=cpu,
            is_shared=False),
        done: Tensor(shape=torch.

A look into the action taken:

(Note: In single-agent RL, the 'agents' key would not be needed)

In [5]:
print(reset_with_action["agents", "action"])

tensor([[[-0.1222, -0.8262],
         [ 0.5191, -0.0434]]])


#### Take an Action within the Environment
Note the new 'next' field, which holds the 'reward' and 'observation' of each agent after the step.
In this environment, 'done' and 'terminated' apply to the environment as a whole (they are global, not per-agent).
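Because rewards are per agent (shape `[num_envs, n_agents, 1]`) while 'done' is per environment (shape `[num_envs, 1]`), the training loop later in this notebook unsqueezes and expands the global flag so that every agent receives a copy. A plain-Python sketch of that broadcast, with hypothetical values:

```python
def expand_done(done, n_agents):
    """Broadcast per-env flags of shape [num_envs, 1] to [num_envs, n_agents, 1]."""
    # done[e] is a one-element list [flag]; every agent in env e receives the same flag
    return [[[env_flag[0]] for _ in range(n_agents)] for env_flag in done]

done = [[True], [False]]  # hypothetical: 2 envs, the first one has finished
print(expand_done(done, n_agents=3))
# [[[True], [True], [True]], [[False], [False], [False]]]
```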

In [6]:
stepped_data = env.step(reset_with_action)
print(stepped_data)

TensorDict(
    fields={
        agents: TensorDict(
            fields={
                action: Tensor(shape=torch.Size([1, 2, 2]), device=cpu, dtype=torch.float32, is_shared=False),
                info: TensorDict(
                    fields={
                        collision_rew: Tensor(shape=torch.Size([1, 2, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                        covering_reward: Tensor(shape=torch.Size([1, 2, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                        targets_covered: Tensor(shape=torch.Size([1, 2, 1]), device=cpu, dtype=torch.float32, is_shared=False)},
                    batch_size=torch.Size([1, 2]),
                    device=cpu,
                    is_shared=False),
                observation: Tensor(shape=torch.Size([1, 2, 21]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([1, 2]),
            device=cpu,
            is_shared=False),
        done: Tensor(shape=torch.

#### Rollout
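Note the shape of the returned tensordict: the rollout adds a time dimension right after the environment batch dimension, so the agents' entries now have batch size `[1, 5, 2]` (envs, steps, agents).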

In [7]:
n_rollout_steps = 5
rollout = env.rollout(n_rollout_steps)
print(rollout)

TensorDict(
    fields={
        agents: TensorDict(
            fields={
                action: Tensor(shape=torch.Size([1, 5, 2, 2]), device=cpu, dtype=torch.float32, is_shared=False),
                info: TensorDict(
                    fields={
                        collision_rew: Tensor(shape=torch.Size([1, 5, 2, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                        covering_reward: Tensor(shape=torch.Size([1, 5, 2, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                        targets_covered: Tensor(shape=torch.Size([1, 5, 2, 1]), device=cpu, dtype=torch.float32, is_shared=False)},
                    batch_size=torch.Size([1, 5, 2]),
                    device=cpu,
                    is_shared=False),
                observation: Tensor(shape=torch.Size([1, 5, 2, 21]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([1, 5, 2]),
            device=cpu,
            is_shared=False),
        done

Lightning_Talk_MARL_Intro_Diagrams-TorchRL Intro - Stepping Through.drawio.svg

#### A look into a field's Tensor

In [8]:
print(rollout['agents','info','collision_rew'])

tensor([[[[0.],
          [0.]],

         [[0.],
          [0.]],

         [[0.],
          [0.]],

         [[0.],
          [0.]],

         [[0.],
          [0.]]]])


# Example Project

## Environment Description

We'll be trying to solve the 'navigation' environment from VMAS (Vectorized Multi-Agent Simulator).

The navigation environment **rewards** agents for moving towards their respective colored goal and penalizes them for bumping into each other.

Each agent **observes** its position, its velocity, its vector relative to its goal, and how far its sensors have been intruded upon.

The environment is considered 'done' (terminated) once every agent's distance to its respective goal is less than its radius.

[VMAS Navigation Implementation](https://github.com/proroklab/VectorizedMultiAgentSimulator/blob/main/vmas/scenarios/navigation.py)
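A minimal sketch of the termination rule described above, assuming 2-D positions and one shared radius; the VMAS scenario code linked above is the authoritative version, and the names here are hypothetical.

```python
import math

def all_on_goal(agent_positions, goal_positions, radius):
    """True once every agent is closer to its own goal than `radius`."""
    return all(
        math.dist(agent, goal) < radius
        for agent, goal in zip(agent_positions, goal_positions)
    )

agents = [(0.0, 0.0), (1.0, 1.0)]
goals = [(0.01, 0.0), (1.5, 1.0)]
print(all_on_goal(agents, goals, radius=0.05))  # False: the second agent is 0.5 away
```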



The environment can be configured to change the observation space (e.g., so that every agent sees all goals), but this project focuses on a one-to-one mapping of agents to goals.

## Algorithm Description

This project uses MAPPO (Multi-Agent Proximal Policy Optimization), although TorchRL makes switching to IPPO as easy as changing one configuration parameter (`alg.marl.mappo`). Try it!

[PPO Spinning Up Link](https://spinningup.openai.com/en/latest/algorithms/ppo.html)
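At the core of PPO is the clipped surrogate objective: the ratio between the new and old policy's probability of an action is clipped to [1 - ε, 1 + ε], so a single update cannot move the policy too far. A per-sample, plain-Python sketch; `clip_epsilon` corresponds to `alg.clip_epsilon` in the config below, and TorchRL's `ClipPPOLoss` computes a batched version of this (negated, so it can be minimized).

```python
import math

def clipped_surrogate(log_prob_new, log_prob_old, advantage, clip_epsilon=0.2):
    """PPO-Clip objective for one sample (to be maximized)."""
    ratio = math.exp(log_prob_new - log_prob_old)  # pi_new(a|s) / pi_old(a|s)
    clipped_ratio = min(max(ratio, 1.0 - clip_epsilon), 1.0 + clip_epsilon)
    # Taking the minimum makes the objective a pessimistic (lower) bound
    return min(ratio * advantage, clipped_ratio * advantage)

# Positive advantage: the gain is capped once the ratio exceeds 1 + epsilon
print(clipped_surrogate(math.log(1.5), 0.0, advantage=1.0))   # about 1.2
# Negative advantage: no extra credit for pushing the ratio below 1 - epsilon
print(clipped_surrogate(math.log(0.5), 0.0, advantage=-1.0))  # about -0.8
```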

## Hydra Config

 **Q**: What is Hydra?

 **A**: Hydra-core simplifies managing and configuring complex applications by allowing dynamic and hierarchical composition of configuration files, command-line arguments, and environment variables.

It is YAML-based and makes the configuration accessible throughout your application. Hydra is built on top of [OmegaConf](https://omegaconf.readthedocs.io/en/2.3_branch/) (its config objects are OmegaConf `DictConfig`s), but this example project does not use the OmegaConf API directly.

In [9]:
%%writefile config.yaml
device: "cpu"
alg:
  seed: 3437
  # Learning Parameters
  lr: 3e-4
  lr_scheduler: False
  max_grad_norm: 1.0
  # Neural Net
  neural_network:
    representation: 'MLP' # Options: {MLP, KAN}
    mlp:
      width: 256 # number of cells in each layer i.e. output dim.
      num_layers: 5
      activation_fn: Tanh
    kan:
      width: 2
      num_layers: 5
      grid: 3
      k: 3
  # Continual RL Alg Additional Parameters
  l2_init: False
  crelu: False

  # PPO Parameters
  sub_batch_size: 200  # cardinality of the sub-samples gathered from the current data in the inner loop
  num_epochs: 5  # optimisation steps per batch of data collected
  clip_epsilon: 0.2  # clip range for the PPO clipped surrogate loss (probability ratio clipped to [1 - eps, 1 + eps])
  gamma: 0.99
  lmbda: 0.95
  entropy_eps: 0.0
  ent_coef: 0.0
  vf_coef: 0.5

  load_pretrained: False

  marl:
    share_policy_params: True
    share_critic_params: True
    mappo: True # True: MAPPO (centralised critic); False: IPPO (decentralised critic)
    centralized_policy: False

env:
  type: 'vmas'
  #name: "discovery"
  name: "navigation"
  max_episode_steps: 1_000
  num_envs: 8
  num_agents: 5

eval:
  evaluation_frequency: 10 # Number of batches until evaluated
  evaluation_horizon_steps: 1_000

data_collection:
  frames_per_batch: 2_000
  num_iters: 10 # 100
  total_frames: 20_000 # 800_000
  split_trajs: False

logging:
  todo: ${hydra:runtime.cwd}/logs

model_paths:
  base_path: ${hydra:runtime.cwd}/models/
  policy_model: "policy.pt"
  value_model: "value.pt"
  initial_policy_model: "init_policy.pt"
  initial_value_model: "init_value.pt"

wab_config:
  use_wab: False
  user: 'TBD-USER'
  org: 'TBD-ORG'
  project: 'marl-discovery'
  initial_policy_model:
    name: 'initial_policy_model'
    load_artifact_tag: 'latest'
  initial_value_model:
    name: 'initial_value_model'
    load_artifact_tag: 'latest'
  policy_model:
    name: 'policy_model'
    load_artifact_tag: 'latest'
  value_model:
    name: 'value_model'
    load_artifact_tag: 'latest'


Writing config.yaml
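A few quantities follow directly from this config. The arithmetic below mirrors how the training loop later uses these values (one collector batch per iteration, `num_epochs` passes per batch, `frames_per_batch // sub_batch_size` sub-batches per pass), assuming each vectorized environment step counts as `num_envs` frames.

```python
# Values copied from config.yaml above
frames_per_batch = 2_000
total_frames = 20_000
num_envs = 8
num_epochs = 5
sub_batch_size = 200

num_batches = total_frames // frames_per_batch              # collector iterations
steps_per_env_per_batch = frames_per_batch // num_envs      # time steps per vectorized env in one batch
minibatches_per_epoch = frames_per_batch // sub_batch_size  # sub-batches sampled per epoch
gradient_steps = num_batches * num_epochs * minibatches_per_epoch

print(num_batches, steps_per_env_per_batch, minibatches_per_epoch, gradient_steps)
# 10 250 10 500
```

With 250 steps per environment per batch and `max_episode_steps: 1_000`, an episode in which the agents never reach their goals spans four collector batches.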


In [10]:
!mkdir models

In [11]:
!ls

config.yaml  models  sample_data


In [12]:
import os
print(os.getcwd())

/content


## Imports

In [13]:
# Imports
from collections import defaultdict
import time

## Torch
import torch
from torch import nn

# Tensordict
from tensordict.nn import TensorDictModule
from tensordict.nn.distributions import NormalParamExtractor

# Data Collection
from torchrl.collectors import SyncDataCollector
from torchrl.data.replay_buffers import ReplayBuffer
from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
from torchrl.data.replay_buffers.storages import LazyTensorStorage

# Env
from torchrl.envs import RewardSum, TransformedEnv
from torchrl.envs.libs.vmas import VmasEnv
from torchrl.envs.utils import (
    check_env_specs,
    ExplorationType,
    set_exploration_type,
)

# Multi Agent
from torchrl.modules import MultiAgentMLP, ProbabilisticActor, TanhNormal

# Loss
from torchrl.objectives import ClipPPOLoss, ValueEstimators
from torchrl.objectives.value import GAE

import vmas
from tqdm import tqdm

# For Configuration Management
import hydra

# For Logging and Saving Models
import wandb


## Utils

These utility functions help with mapping configuration file strings to python objects.

In [14]:
from torch import nn
def getActivationFunction(name, **kwargs):
    try:
        activation_class = getattr(nn, name)
        return activation_class(**kwargs)
    except AttributeError:
        raise ValueError(f"Activation function '{name}' is not found in torch.nn.")

def getActivationClass(name):
    try:
        activation_class = getattr(nn, name)
        return activation_class
    except AttributeError:
        raise ValueError(f"Activation function '{name}' is not found in torch.nn.")





The `filter_argv` function works around running hydra-core inside Google Colab, which appends a `-f` argument to `sys.argv` that Hydra would otherwise try to parse. It is not needed when running as a regular script.

In [15]:
import sys
def filter_argv():
  print(sys.argv)
  # Keep only arguments before the '-f' argument which is added by Colab
  if '-f' in sys.argv:
    sys.argv = sys.argv[:sys.argv.index('-f')]

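# Device Selection
# Use the first GPU when available (and the multiprocessing start method is not "fork"); otherwise use the CPU
is_fork = torch.multiprocessing.get_start_method() == "fork"
device = (
    torch.device(0)
    if torch.cuda.is_available() and not is_fork
    else torch.device("cpu")
)
# Override: this demo always runs on the CPU
device = torch.device("cpu")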
vmas_device = device

## Environment

Here we create the environment. Note that keys must be mapped explicitly in multi-agent environments: out of the box, TorchRL modules assume a single agent and therefore will not look inside the nested 'agents' field.

For example, here the rewards are summed over the entire episode, and the transform is told to write the result *within* the agents sub-tensordict, at `("agents", "episode_reward")`.

[More information about multi-agent environments](https://pytorch.org/rl/stable/reference/envs.html#multi-agent-environments)
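Conceptually, `RewardSum` keeps a running total of reward within each episode and restarts it when a new episode begins. A single-agent, plain-Python sketch of that bookkeeping (an illustration, not TorchRL's implementation, which restarts the sum when the environment is reset):

```python
def running_episode_reward(rewards, dones):
    """Accumulate reward within an episode, restarting after each done step."""
    totals, running = [], 0.0
    for reward, done in zip(rewards, dones):
        running += reward
        totals.append(running)  # the value recorded at this step includes its own reward
        if done:
            running = 0.0  # the next step belongs to a new episode
    return totals

print(running_episode_reward([1.0, 2.0, 3.0, 4.0], [False, True, False, False]))
# [1.0, 3.0, 3.0, 7.0]
```

Reading the running total at steps where 'done' is set gives the full episode reward, which is how the training loop below logs it.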

In [16]:
# Environment
def createEnvironment(cfg):
    if cfg.env.type == 'vmas':
        env = VmasEnv(
            scenario=cfg.env.name,
            num_envs=cfg.env.num_envs,
            continuous_actions=True,
            max_steps=cfg.env.max_episode_steps,
            device=vmas_device,
            n_agents = cfg.env.num_agents
        )
    else:
        raise ValueError(f"Unsupported env type in createEnvironment: {cfg.env.type}")


    env = TransformedEnv(
        env,
        RewardSum(in_keys=[env.reward_key], out_keys=[("agents", "episode_reward")]),
    )

    print("Setting Seed for Env.")
    env.set_seed(cfg.alg.seed) # Set Env Seed
    check_env_specs(env)

    return env


## Actor/Policy Network Module

Without going into much detail: MAPPO is a variant of PPO, which uses two networks to learn how to make 'good' decisions. The first, the *policy* (also called the 'actor') network, defines the behavior of each agent.

In [17]:
# Policy
def createPolicy(cfg, env, load_pretrained=False):
    activation_cls = getActivationClass(cfg.alg.neural_network.mlp.activation_fn)
    policy_net = torch.nn.Sequential(
        MultiAgentMLP(
            n_agent_inputs=env.observation_spec["agents", "observation"].shape[
                -1
            ],  # n_obs_per_agent
            n_agent_outputs=2 * env.action_spec.shape[-1],  # 2 * n_actions_per_agents
            n_agents=cfg.env.num_agents,
            centralised=cfg.alg.marl.centralized_policy,  # False: decentralised policies (each agent acts from its own observation)
            share_params=cfg.alg.marl.share_policy_params,
            device=device,
            depth=cfg.alg.neural_network.mlp.num_layers,
            num_cells=cfg.alg.neural_network.mlp.width,
            activation_class=activation_cls,
        ),
        NormalParamExtractor(),  # this will just separate the last dimension into two outputs: a loc and a non-negative scale
    )

    policy_module = TensorDictModule(
        policy_net,
        in_keys=[("agents", "observation")],
        out_keys=[("agents", "loc"), ("agents", "scale")],
    )

    policy_module = ProbabilisticActor(
        module=policy_module,
        spec=env.unbatched_action_spec,
        in_keys=[("agents", "loc"), ("agents", "scale")],
        out_keys=[env.action_key],
        distribution_class=TanhNormal,
        distribution_kwargs={
            "min": env.unbatched_action_spec[env.action_key].space.low,
            "max": env.unbatched_action_spec[env.action_key].space.high,
        },
        return_log_prob=True,
        log_prob_key=("agents", "sample_log_prob"),
    )  # we'll need the log-prob for the PPO loss

    print("Running policy:", policy_module(env.reset()))
    if load_pretrained and cfg.wab_config.use_wab:
        policy_artifact = wandb.run.use_artifact("/".join([cfg.wab_config.org, cfg.wab_config.project, cfg.wab_config.policy_model.name]) + f":{cfg.wab_config.policy_model.load_artifact_tag}" , type='model')
        policy_artifact_dir = policy_artifact.download()
        policy_module.load_state_dict(torch.load(policy_artifact_dir + "/" + cfg.model_paths.policy_model))
    return policy_module
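The last layer of the policy outputs `2 * n_actions` numbers per agent. `NormalParamExtractor` splits them into a mean (`loc`) and a positive standard deviation (`scale`), and `TanhNormal` samples from that Gaussian and squashes the result into the action bounds. A plain-Python sketch of this pipeline for one agent; the softplus used to make `scale` positive and the affine rescaling to `[low, high]` are assumptions, and TorchRL's exact transforms may differ.

```python
import math
import random

def split_loc_scale(net_output):
    """Split the last dimension in half: first half -> loc, second half -> positive scale."""
    half = len(net_output) // 2
    loc = net_output[:half]
    # Softplus keeps the scale strictly positive (assumed mapping)
    scale = [math.log1p(math.exp(x)) for x in net_output[half:]]
    return loc, scale

def sample_tanh_normal(loc, scale, low, high, rng=random):
    """Sample a Gaussian per dimension, squash with tanh, then rescale to [low, high]."""
    actions = []
    for mu, sigma, lo, hi in zip(loc, scale, low, high):
        squashed = math.tanh(rng.gauss(mu, sigma))  # in [-1, 1]
        actions.append(lo + (squashed + 1.0) * 0.5 * (hi - lo))
    return actions

# Hypothetical network output for one agent with 2 action dimensions
loc, scale = split_loc_scale([0.0, 0.5, -1.0, 2.0])
action = sample_tanh_normal(loc, scale, low=[-1.0, -1.0], high=[1.0, 1.0], rng=random.Random(0))
print(len(loc), len(scale), action)
```

Squashing with tanh keeps every sampled action inside the environment's bounds, which is why the actor is given the action spec's `low` and `high` values.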


## Critic/Value Network Module

The second of the two networks PPO uses is the *value* network. It estimates how 'good' the agent's current situation is by learning the 'value' (expected return) of the states the agent finds itself in. Because of this 'critiquing' role, it is also often referred to as the 'critic'.

In [18]:
# Critic
def createValue(cfg, env, load_pretrained=False):
    activation_cls = getActivationClass(cfg.alg.neural_network.mlp.activation_fn)
    critic_net = MultiAgentMLP(
        n_agent_inputs=env.observation_spec["agents", "observation"].shape[-1],
        n_agent_outputs=1,  # 1 value per agent
        n_agents=cfg.env.num_agents,
        centralised=cfg.alg.marl.mappo,  # True: MAPPO (centralised critic); False: IPPO
        share_params=cfg.alg.marl.share_critic_params,
        device=device,
        depth=cfg.alg.neural_network.mlp.num_layers,
        num_cells=cfg.alg.neural_network.mlp.width,
        activation_class=activation_cls,
    )

    value_module = TensorDictModule(
        module=critic_net,
        in_keys=[("agents", "observation")],
        out_keys=[("agents", "state_value")],
    )

    print("Running value:", value_module(env.reset())) # TODO - if not needed, remove passed in env
    if load_pretrained and cfg.wab_config.use_wab:
        value_artifact = wandb.run.use_artifact("/".join([cfg.wab_config.org, cfg.wab_config.project, cfg.wab_config.value_model.name]) + f":{cfg.wab_config.value_model.load_artifact_tag}" , type='model')
        value_artifact_dir = value_artifact.download()
        value_module.load_state_dict(torch.load(value_artifact_dir + "/" + cfg.model_paths.value_model))
    return value_module


## Data Collection / Storage

This code shows one way to implement the collector and replay buffer from the module diagram near the top of this notebook.

Note that because PPO is *on-policy*, a replay buffer is not strictly needed; it is used here for teaching purposes and as a convenient way to sample shuffled sub-batches.
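Sampling *without replacement* means one pass over the buffer visits every stored frame exactly once, in shuffled order. A plain-Python sketch of that idea (not TorchRL's `SamplerWithoutReplacement` itself), using this notebook's `frames_per_batch` and `sub_batch_size` values:

```python
import random

def minibatch_indices(n_frames, batch_size, rng=random):
    """Shuffle all frame indices once, then cut them into consecutive sub-batches."""
    indices = list(range(n_frames))
    rng.shuffle(indices)
    return [indices[i:i + batch_size] for i in range(0, n_frames, batch_size)]

batches = minibatch_indices(n_frames=2_000, batch_size=200, rng=random.Random(0))
print(len(batches))  # 10 sub-batches of 200 frames each
```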

In [19]:
# Collector
def createCollector(cfg, env, policy_module):
    collector = SyncDataCollector(
        env,
        policy_module,
        device=vmas_device,
        storing_device=device,
        frames_per_batch=cfg.data_collection.frames_per_batch,
        total_frames=cfg.data_collection.total_frames,
        split_trajs=cfg.data_collection.split_trajs,
    )

    return collector

# Replay Buffer
def createReplayBuffer(cfg ):
    replay_buffer = ReplayBuffer(
        storage=LazyTensorStorage(max_size=cfg.data_collection.frames_per_batch, device=device),
        sampler=SamplerWithoutReplacement(),
        batch_size=cfg.alg.sub_batch_size
    )

    return replay_buffer


## Loss Function

Because we're learning, we need a measure of how well the networks are performing with respect to their experience. Just like in supervised ML, where we steer predictions towards the labeled values, the loss module steers actions toward those *expected* to produce higher cumulative rewards.

In [20]:
# Advantage Module
def createAdvantageModule(cfg, loss_module):
    loss_module.make_value_estimator(
        ValueEstimators.GAE,
        gamma=cfg.alg.gamma,
        lmbda=cfg.alg.lmbda
    )  # We build GAE
    advantage_module = loss_module.value_estimator

    return advantage_module

# Loss Module
def createLossModule(cfg, policy_module, value_module, env):
    loss_module = ClipPPOLoss(
        actor_network=policy_module,
        critic_network=value_module,
        clip_epsilon=cfg.alg.clip_epsilon,
        entropy_bonus=bool(cfg.alg.entropy_eps),
        entropy_coef=cfg.alg.entropy_eps,
        # Critic loss weight and type, set explicitly rather than relying on the defaults
        critic_coef=cfg.alg.vf_coef,
        loss_critic_type="smooth_l1",
        normalize_advantage=False, # IMPORTANT for MARL: We do not want to be normalizing across the agent dimensions
    )

    loss_module.set_keys(  # We have to tell the loss where to find the keys
        reward=env.reward_key,
        action=env.action_key,
        sample_log_prob=("agents", "sample_log_prob"),
        value=("agents", "state_value"),
        # These last 2 keys will be expanded to match the reward shape
        done=("agents", "done"),
        terminated=("agents", "terminated"),
    )

    optim = torch.optim.Adam(loss_module.parameters(), cfg.alg.lr)

    if cfg.alg.lr_scheduler:
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optim, cfg.data_collection.total_frames // cfg.data_collection.frames_per_batch, 0.0
        )
    else:
        scheduler = None

    return loss_module, optim, scheduler
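`ValueEstimators.GAE` computes Generalized Advantage Estimation. A single-trajectory, plain-Python sketch of the backward recursion, using `gamma` and `lmbda` as in the config. For simplicity it treats every `done` as a true termination (no bootstrapping past it), whereas TorchRL's estimator also distinguishes truncation from termination.

```python
def generalized_advantage_estimate(rewards, values, next_values, dones, gamma=0.99, lmbda=0.95):
    """Return (advantages, value_targets) for one trajectory, computed backwards in time."""
    advantages = [0.0] * len(rewards)
    running = 0.0
    for t in reversed(range(len(rewards))):
        not_done = 0.0 if dones[t] else 1.0
        # One-step TD error: r_t + gamma * V(s_{t+1}) - V(s_t), no bootstrap past a terminal step
        delta = rewards[t] + gamma * next_values[t] * not_done - values[t]
        # Exponentially weighted sum of future TD errors, cut at episode boundaries
        running = delta + gamma * lmbda * not_done * running
        advantages[t] = running
    value_targets = [adv + v for adv, v in zip(advantages, values)]
    return advantages, value_targets

adv, targets = generalized_advantage_estimate(
    rewards=[1.0, 1.0], values=[0.0, 0.0], next_values=[0.0, 0.0],
    dones=[False, True], gamma=1.0, lmbda=1.0,
)
print(adv, targets)  # [2.0, 1.0] [2.0, 1.0]
```

Because the advantages depend on the value network, the training loop recomputes them at the start of every epoch.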

## Visualization

In [None]:
%%capture
!apt-get update
!apt-get install -y x11-utils
!apt-get install -y xvfb
!pip install pyvirtualdisplay
!pip install pyopengl
!apt-get install -y python3-opengl

In [None]:
%%capture
!apt-get install -y python-opengl
!apt-get install -y libglu1-mesa libgl1-mesa-glx libosmesa6
!pip install PyOpenGL PyOpenGL_accelerate


In [None]:
@hydra.main(config_path="/content", config_name="config.yaml", version_base=None)
def renderEnvIndependent(cfg, policy=None):
  import pyvirtualdisplay
  display = pyvirtualdisplay.Display(visible=False, size=(1400, 900))
  display.start()
  from PIL import Image

  print("Creating Env...")
  env = createEnvironment(cfg)

  print("Rendering Callback...")
  def rendering_callback(env, td):
      env.frames.append(Image.fromarray(env.render(mode="rgb_array")))

  env.frames = []
  print("Rolling Out...")
  with torch.no_grad():
    if policy is None:
      print("Policy Omitted")
      filename="before_training.gif"
      env.reset()
      env.rollout(
          #max_steps=cfg.env.max_episode_steps,
          max_steps=100,
          callback=rendering_callback,
          auto_cast_to_device=True,
          break_when_any_done=False,
      )
    else:
      print("Policy Provided")
      filename="after_training.gif"
      env.rollout(
          max_steps=100,
          policy=policy,
          callback=rendering_callback,
          auto_cast_to_device=True,
          break_when_any_done=False,
      )

  print("Saving...")
  env.frames[0].save(
      filename,
      save_all=True,
      append_images=env.frames[1:],
      duration=3,
      loop=0,
  )

def renderEnvTraining(env, policy):
  import pyvirtualdisplay
  display = pyvirtualdisplay.Display(visible=False, size=(1400, 900))
  display.start()
  from PIL import Image

  print("Rendering Callback...")
  def rendering_callback(env, td):
      env.frames.append(Image.fromarray(env.render(mode="rgb_array")))

  env.frames = []
  print("Rolling Out...")
  with torch.no_grad():
      env.rollout(
          max_steps=100,
          policy=policy,
          callback=rendering_callback,
          auto_cast_to_device=True,
          break_when_any_done=False,
      )

  print("Saving...")
  filename="after_training.gif"
  env.frames[0].save(
      filename,
      save_all=True,
      append_images=env.frames[1:],
      duration=3,
      loop=0,
  )



In [None]:
if __name__ == "__main__":
    filter_argv()
    renderEnvIndependent()
    print("Rendering")

In [None]:
from IPython.display import Image
Image(open(f"before_training.gif", "rb").read())


## Training

Below is the main training loop, where the components created earlier are put to use.

The first thing we do is save the initial, untrained models when W&B logging is enabled (I like to do this even when setting seeds, to compare against later).

During the training loop itself, we:
* Collect trajectories from the collector (the policy acting in the environment, i.e., rollouts).
* Expand the global 'done' and 'terminated' flags to a per-agent shape, as the loss module expects.
* Train for a number of epochs, where in each epoch:
  * The advantage (GAE) is recomputed with the current value network.
  * The flattened batch is placed into the replay buffer.
  * Sub-batches are sampled from the replay buffer.
  * The loss module computes the loss function components.
  * The loss is optimized with the usual PyTorch steps (backward pass, gradient clipping, optimizer step, zeroing gradients).
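The gradient-clipping step rescales all gradients by one common factor when their combined L2 norm exceeds `max_grad_norm` (1.0 in the config), which is the behavior of `torch.nn.utils.clip_grad_norm_`. A plain-Python sketch over a flat list of hypothetical gradient values:

```python
import math

def clip_by_global_norm(grads, max_norm):
    """Scale all gradients by the same factor so their global L2 norm is at most max_norm."""
    total_norm = math.sqrt(sum(g * g for g in grads))
    if total_norm <= max_norm:
        return list(grads), total_norm  # already small enough: leave unchanged
    scale = max_norm / total_norm  # PyTorch also adds a tiny epsilon to the denominator
    return [g * scale for g in grads], total_norm

clipped, norm = clip_by_global_norm([3.0, 4.0], max_norm=1.0)
print(norm, clipped)  # norm 5.0; gradients scaled to roughly [0.6, 0.8]
```

Scaling every gradient by the same factor preserves the update direction while bounding its size.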


Everything after the optimization is for logging, mid-training evaluation, artifact storage, and visualization only.

In [None]:
@hydra.main(config_path="/content", config_name="config.yaml", version_base=None)
def train(cfg):
    print(cfg)
    print("Setting Seed for all but Env (done in env module)")
    torch.manual_seed(cfg.alg.seed)
    torch.cuda.manual_seed_all(cfg.alg.seed)

    if cfg.wab_config.use_wab:
        wandb.init(
            # set the wandb project where this run will be logged
            project=cfg.wab_config.project,
            # track hyperparameters and run metadata
            config=dict(cfg)
        )
    env = createEnvironment(cfg)
    policy_module = createPolicy(cfg, env, load_pretrained=cfg.alg.load_pretrained)
    value_module = createValue(cfg=cfg, env=env, load_pretrained=cfg.alg.load_pretrained)
    collector = createCollector(cfg=cfg, env=env, policy_module=policy_module)
    replay_buffer = createReplayBuffer(cfg)
    loss_module, optim, scheduler = createLossModule(cfg=cfg, policy_module=policy_module, value_module=value_module, env=env)
    advantage_module = createAdvantageModule(cfg=cfg, loss_module=loss_module)

    # Set up training data collection
    logs = defaultdict(list)
    pbar = tqdm(total=cfg.data_collection.total_frames)
    eval_str = ""

    # Save models before training
    with set_exploration_type(ExplorationType.MEAN), torch.no_grad():
        policy_module.eval()

        if cfg.wab_config.use_wab:
            # Save Actor / Policy Network
            ## Save Model Locally
            torch.save(policy_module.state_dict(), cfg.model_paths.base_path + cfg.model_paths.initial_policy_model)
            ## Push to WandB
            policy_model_path = cfg.model_paths.base_path + cfg.model_paths.initial_policy_model
            policy_model_artifact = wandb.Artifact(cfg.wab_config.initial_policy_model.name, type='model')
            policy_model_artifact.add_file(local_path=policy_model_path)
            wandb.run.log_artifact(policy_model_artifact)

            # Save Critic / Value Network
            ## Save Model Locally
            torch.save(value_module.state_dict(), cfg.model_paths.base_path + cfg.model_paths.initial_value_model)
            ## Push to WandB
            value_model_path = cfg.model_paths.base_path + cfg.model_paths.initial_value_model
            value_model_artifact = wandb.Artifact(cfg.wab_config.initial_value_model.name, type='model')
            value_model_artifact.add_file(local_path=value_model_path)
            wandb.run.log_artifact(value_model_artifact)



    # Iterate over collector until the total number of frames it was created with is reached:
    for i, tensordict_data in enumerate(collector):
        training_wandb_log = defaultdict(list)
        eval_wandb_log = defaultdict(list)
        # Expand the global done/terminated flags to a per-agent shape (matching the reward) for the loss module
        tensordict_data.set(
            ("next", "agents", "done"),
            tensordict_data.get(("next", "done"))
            .unsqueeze(-1)
            .expand(tensordict_data.get_item_shape(("next", env.reward_key))),
        )
        tensordict_data.set(
            ("next", "agents", "terminated"),
            tensordict_data.get(("next", "terminated"))
            .unsqueeze(-1)
            .expand(tensordict_data.get_item_shape(("next", env.reward_key))),
        )

        # TRAINING SECTION
        # For each batch (designated at collector creation), learn:
        for _ in range(cfg.alg.num_epochs):
            # We recompute the advantage signal for PPO at each epoch since its value depends on the
            #   value network, which is constantly updated.
            with torch.no_grad():
                advantage_module(
                    tensordict_data,
                    params=loss_module.critic_network_params,
                    target_params=loss_module.target_critic_network_params,
                )  # Compute GAE and add it to the data
            data_view = tensordict_data.reshape(-1) # Flatten the tensor

            replay_buffer.extend(data_view.cpu()) # Push this flattened batch of data into the replay buffer
            # Iterate over the number of sub-batches for PPO within each batch
            for _ in range(cfg.data_collection.frames_per_batch // cfg.alg.sub_batch_size):
                subdata = replay_buffer.sample(cfg.alg.sub_batch_size)
                loss_values = loss_module(subdata.to(device))
                loss_value = (
                    loss_values["loss_objective"]
                    + loss_values["loss_critic"]
                    #+ loss_values["loss_entropy"]
                )

                # Optimization Steps
                ## Compute the gradient
                loss_value.backward()
                ## Clip the gradient just calculated
                torch.nn.utils.clip_grad_norm_(loss_module.parameters(), cfg.alg.max_grad_norm)
                ## Perform Optimization
                optim.step()
                ## Zero out the gradient for next iteration
                optim.zero_grad()

        # Now we're out of the epoch loop for this batch, and we can update our progress
        ## Log the mean episode reward over the episodes that finished in this batch
        done = tensordict_data.get(("next", "agents", "done"))
        logs["reward"].append(tensordict_data["next", "agents", "episode_reward"][done].mean().item())
        ## Update the progress bar with the number of elements (numel) in this batch's tensordict
        pbar.update(tensordict_data.numel())
        # Update the cumulative reward string that is used for displaying progress.
        #   pbar is updated after evaluation section with relevant information.
        cum_reward_str = (
            f"average reward={logs['reward'][-1]: 4.4f} (init={logs['reward'][0]: 4.4f})"
        )
        # Record the current learning rate (from the optimizer's first param group) in the logs
        logs["lr"].append(optim.param_groups[0]["lr"])
        # Update the learning-rate string that is used for displaying progress.
        #   pbar is updated after evaluation section with relevant information.
        lr_str = f"lr policy: {logs['lr'][-1]: 4.4f}"

        training_wandb_log["average_reward"] = logs['reward'][-1]
        training_wandb_log["lr"] = logs['lr'][-1]


        # MID-TRAINING EVALUATION SECTION
        ## Evaluate
        ### i is the index of the current batch; evaluate every `evaluation_frequency` batches
        if i % cfg.eval.evaluation_frequency == 0:
            with set_exploration_type(ExplorationType.MEAN), torch.no_grad():
                # Execute rollout with current, trained policy
                evaluation_rollout = env.rollout(cfg.eval.evaluation_horizon_steps, policy_module)
                logs["eval reward"].append(evaluation_rollout["next", "agents", "episode_reward"].mean().item())
                logs["eval reward (sum)"].append(evaluation_rollout["next", "agents", "episode_reward"].sum().item())

                # From these, construct the evaluation string used for progress bar updates
                eval_str = (
                    f"eval cumulative reward: {logs['eval reward (sum)'][-1]: 4.4f} "
                    f"(init: {logs['eval reward (sum)'][0]: 4.4f})"
                )
                eval_wandb_log['eval reward'] = logs["eval reward"][-1]
                eval_wandb_log['eval reward (sum)'] = logs["eval reward (sum)"][-1]

                del evaluation_rollout # Clear the evaluated rollout - we no longer need it.

        if cfg.wab_config.use_wab:
            wandb.log({
                "batch": i,
                "training": training_wandb_log,
                "eval": eval_wandb_log
            })

        del training_wandb_log
        del eval_wandb_log

        pbar.set_description(", ".join([eval_str, cum_reward_str, lr_str]))

        # Take a step in the learning rate scheduler.
        # Not required in PPO but can often lead to better performance.
        if cfg.alg.lr_scheduler:
            scheduler.step()

    # When WandB is enabled, save the trained networks' weights (PyTorch state dicts) locally and upload them as WandB artifacts
    with set_exploration_type(ExplorationType.MEAN), torch.no_grad():
        policy_module.eval()

        if cfg.wab_config.use_wab:
            # Save Actor / Policy Network
            ## Save Model Locally
            policy_model_path = cfg.model_paths.base_path + cfg.model_paths.policy_model
            torch.save(policy_module.state_dict(), policy_model_path)
            ## Push to WandB
            policy_model_artifact = wandb.Artifact(cfg.wab_config.policy_model.name, type='model')
            policy_model_artifact.add_file(local_path=policy_model_path)
            wandb.run.log_artifact(policy_model_artifact)

            # Save Critic / Value Network
            ## Save Model Locally
            value_model_path = cfg.model_paths.base_path + cfg.model_paths.value_model
            torch.save(value_module.state_dict(), value_model_path)
            ## Push to WandB
            value_model_artifact = wandb.Artifact(cfg.wab_config.value_model.name, type='model')
            value_model_artifact.add_file(local_path=value_model_path)
            wandb.run.log_artifact(value_model_artifact)

    # Render after Training
    renderEnvTraining(env, policy_module)

if __name__ == "__main__":
    print("Training...")
    start_time = time.perf_counter()
    filter_argv()
    train()
    end_time = time.perf_counter()
    print(f"Training finished in {end_time - start_time:.1f} s")
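The training loop copies the environment-level `done` and `terminated` flags into the `agents` sub-tensordict so that they match the per-agent reward shape, and later uses that per-agent `done` mask to average `episode_reward` only over steps where an episode ended. The plain-PyTorch sketch below illustrates both operations on small made-up tensors. The shapes (`[n_envs, T, 1]` for the flag, `[n_envs, T, n_agents, 1]` for rewards) are an assumption modeled on a batched VMAS environment, and the helper names are hypothetical.

```python
import torch


def expand_to_agents(flag: torch.Tensor, reward_shape: torch.Size) -> torch.Tensor:
    """Broadcast an env-level flag of shape [..., 1] to the per-agent shape [..., n_agents, 1].

    Hypothetical helper mirroring `.unsqueeze(-1).expand(...)` in the training loop.
    """
    return flag.unsqueeze(-1).expand(reward_shape)


def mean_finished_episode_reward(episode_reward: torch.Tensor, done: torch.Tensor) -> float:
    """Average the accumulated episode reward only where `done` is True."""
    return episode_reward[done].mean().item()


# Assumed toy shapes: 2 parallel envs, 4 time steps, 3 agents
n_envs, T, n_agents = 2, 4, 3
reward_shape = torch.Size([n_envs, T, n_agents, 1])

done = torch.zeros(n_envs, T, 1, dtype=torch.bool)  # env-level flag
done[0, 3] = True  # env 0 finishes an episode at step 3
done[1, 1] = True  # env 1 finishes an episode at step 1

agent_done = expand_to_agents(done, reward_shape)  # one copy of the flag per agent
episode_reward = torch.arange(n_envs * T * n_agents, dtype=torch.float32).reshape(reward_shape)

print(agent_done.shape)  # torch.Size([2, 4, 3, 1])
print(mean_finished_episode_reward(episode_reward, agent_done))  # 13.0
```

Note that if no episode finishes within a batch, the mask selects nothing and `.mean()` returns `nan`, which would then appear in `logs["reward"]`.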

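The `advantage_module` call in the training loop computes Generalized Advantage Estimation (GAE) with TorchRL's estimator, configured earlier in the notebook. As a rough illustration of what GAE computes, here is a plain-PyTorch sketch for a single trajectory laid out along dim 0. It is not TorchRL's implementation; the function name, argument layout, and default `gamma`/`lmbda` values are assumptions for illustration.

```python
import torch


def compute_gae(rewards, values, next_values, terminated, done, gamma=0.99, lmbda=0.95):
    """Generalized Advantage Estimation along the time dimension (dim 0).

    All tensors share one shape, e.g. [T, n_agents, 1]; `terminated` and `done` are bool.
    - `terminated` stops bootstrapping from V(s_{t+1}) at true episode ends.
    - `done` (terminated or truncated) stops the recursion from crossing episode boundaries.
    Returns (advantages, value_targets).
    """
    not_terminated = (~terminated).float()
    not_done = (~done).float()
    # One-step TD errors: delta_t = r_t + gamma * V(s_{t+1}) - V(s_t)
    deltas = rewards + gamma * next_values * not_terminated - values

    advantages = torch.zeros_like(rewards)
    running = torch.zeros_like(rewards[0])
    # Backward recursion: A_t = delta_t + gamma * lambda * A_{t+1}
    for t in reversed(range(rewards.shape[0])):
        running = deltas[t] + gamma * lmbda * not_done[t] * running
        advantages[t] = running

    # Commonly used as the critic's regression target
    value_targets = advantages + values
    return advantages, value_targets


# Toy trajectory: 3 steps, reward 1 each, critic predicts 0 everywhere
rewards = torch.ones(3, 1)
values = torch.zeros(3, 1)
no_flags = torch.zeros(3, 1, dtype=torch.bool)
adv, targets = compute_gae(rewards, values, values, no_flags, no_flags, gamma=0.5, lmbda=1.0)
print(adv.squeeze(-1))  # tensor([1.7500, 1.5000, 1.0000])
```

Setting `lmbda=0` reduces each advantage to the one-step TD error, while `lmbda=1` gives the discounted return minus the value baseline, as in the toy example.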

## After Training Visualization

In [None]:
Image(open("after_training.gif", "rb").read())
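When `cfg.wab_config.use_wab` is enabled, `train()` stores the policy and critic weights with `torch.save(module.state_dict(), path)`. Reusing them later, for example to re-render the environment without retraining, means rebuilding a module with the same architecture and calling `load_state_dict`. The sketch below shows that round trip on a small stand-in `nn.Sequential`; the real `policy_module` is a TorchRL module built elsewhere in the notebook, and the `make_policy` helper and temporary file path are hypothetical.

```python
import os
import tempfile

import torch
from torch import nn


def make_policy() -> nn.Module:
    """Hypothetical stand-in for the notebook's policy network (same architecture each call)."""
    return nn.Sequential(nn.Linear(4, 8), nn.Tanh(), nn.Linear(8, 2))


trained = make_policy()
path = os.path.join(tempfile.mkdtemp(), "policy.pt")  # assumed location
torch.save(trained.state_dict(), path)  # same call pattern as in train()

restored = make_policy()  # freshly initialized weights
restored.load_state_dict(torch.load(path, weights_only=True))  # overwrite with the saved weights
restored.eval()  # inference mode, matching policy_module.eval() before saving

x = torch.randn(5, 4)
with torch.no_grad():
    print(torch.allclose(trained(x), restored(x)))  # True
```

Saving the state dict rather than the whole module keeps the file independent of the Python class definitions, at the cost of having to rebuild the architecture before loading.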