In [1]:
import torch

# Tensordict modules
from tensordict.nn import TensorDictModule
from tensordict.nn.distributions import NormalParamExtractor
from torch import multiprocessing

# Data collection for training
from torchrl.collectors import SyncDataCollector
from torchrl.data.replay_buffers import ReplayBuffer
from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
from torchrl.data.replay_buffers.storages import LazyTensorStorage

# Environment
from torchrl.envs import RewardSum, TransformedEnv
from torchrl.envs.libs.vmas import VmasEnv
from torchrl.envs.utils import check_env_specs

# Multi-Agent Network
from torchrl.modules import MultiAgentMLP, ProbabilisticActor, TanhNormal

# Objectives
from torchrl.objectives import ClipPPOLoss, ValueEstimators

# Utils
from matplotlib import pyplot as plt
from rich.console import Console
from rich.progress import Progress

In [2]:
console = Console()

In [3]:
# Set the seed
torch.manual_seed(0)

# Define Hyperparameters
is_fork = multiprocessing.get_start_method() == "fork"
device = (
    torch.device(0)
    if torch.cuda.is_available() and not is_fork
    else torch.device("cpu")
)
vmas_device = device

In [4]:
# Sampling config
frames_per_batch = 1_000
n_iters = 10
total_frames = frames_per_batch * n_iters

# Training config
n_epochs = 30
minibatch_size = 400
lr = 1e-4  # Learning rate
max_grad_norm = 1.0  # Max norm for gradients

# PPO config
clip_epsilon = 0.2
gamma = 0.99
lmbda = 0.9
entropy_eps = 1e-4

In [5]:
# Create VMAS Navigation env
max_steps = 100
n_vmas_envs = frames_per_batch // max_steps
scenario_name = "navigation"
n_agents = 3

env = VmasEnv(
    scenario=scenario_name,
    num_envs=n_vmas_envs,
    max_steps=max_steps,
    continuous_actions=True,
    device=vmas_device,
    # Custom args for navigation env
    n_agents=n_agents,
)

In [6]:
console.print("action_spec:", env.full_action_spec)
console.print("reward_spec:", env.full_reward_spec)
console.print("done_spec:", env.full_done_spec)
console.print("observation_spec:", env.observation_spec)

In [7]:
console.print(f"Action keys: {env.action_keys}")
console.print(f"Reward keys: {env.reward_keys}")
console.print(f"Done key: {env.done_keys}")

In [8]:
env = TransformedEnv(
    env=env,
    transform=RewardSum(
        in_keys=[env.reward_key], out_keys=[("agents", "episode_reward")]
    )
)

In [9]:
check_env_specs(env=env)

2024-11-17 13:07:24,799 [torchrl][INFO] check_env_specs succeeded!


In [10]:
n_steps = 5
rollout = env.rollout(max_steps=n_steps)

console.print(f"Rollout of {n_steps} steps: {rollout}")
console.print(f"Shape of rollout TensorDict = {rollout.shape}")

In [11]:
share_params = True
policy_net = torch.nn.Sequential(
    MultiAgentMLP(
            n_agent_inputs=env.observation_spec["agents", "observation"].shape[-1],
            n_agent_outputs=env.n_agents,
            n_agents=env.n_agents,
            centralized=False,
            share_params=share_params,
            device=device,
            depth=2,
            num_cells=256,
            activation_class=torch.nn.Tanh,
    ),
    NormalParamExtractor()
)

In [12]:
policy_module = TensorDictModule(
    policy_net,
    in_keys=[("agents", "observation")],
    out_keys=[("agents", "loc"), ("agents", "scale")]
)

In [13]:
policy = ProbabilisticActor(
    module=policy_module,
    spec=env.unbatched_action_spec,
    in_keys=[("agents", "loc"), ("agents", "scale")],
    out_keys=[env.action_key],
    distribution_class=TanhNormal,
    distribution_kwargs={
        "low": env.unbatched_action_spec[env.action_key].space.low,
        "high": env.unbatched_action_spec[env.action_key].space.high,
    },
    return_log_prob=True,
    log_prob_key=("agents", "sample_log_prob")  # Log proba required for PPO loss
)

In [24]:
share_critic_params = True
mappo = True

critic_net = MultiAgentMLP(

    n_agent_inputs=env.observation_spec["agents", "observation"].shape[-1],
    n_agent_outputs=1,
    n_agents=env.n_agents,
    centralized=mappo,
    share_params=share_critic_params,
    device=device,
    depth=2,
    num_cells=256,
    activation_class=torch.nn.Tanh
)

critic = TensorDictModule(
    module=critic_net,
    in_keys=[("agents", "observation")],
    out_keys=[("agents", "state_value")],
)

In [25]:
console.print(f"Running Policy: {policy(env.reset())}")
console.print(f"Running Critic: {critic(env.reset())}")

In [26]:
collector = SyncDataCollector(
    env,
    policy,
    device=vmas_device,
    storing_device=device,
    frames_per_batch=frames_per_batch,
    total_frames=total_frames
)

In [27]:
replay_buffer = ReplayBuffer(
    storage=LazyTensorStorage(
        frames_per_batch,
        device=device
    ),
    sampler=SamplerWithoutReplacement(),
    batch_size=minibatch_size
)

In [28]:
loss_module = ClipPPOLoss(
    actor_network=policy,
    critic_network=critic,
    clip_epsilon=clip_epsilon,
    entropy_coef=entropy_eps,
    normalize_advantage=False
)
loss_module.set_keys(
    reward=env.reward_key,
    action=env.action_key,
    sample_log_prob=("agents", "state_value"),
    done=("agents", "done"),
    terminated=("agents", "terminated"),
)

loss_module.make_value_estimator(
    ValueEstimators.GAE, gamma=gamma, lmbda=lmbda
)
GAE = loss_module.value_estimator
optim = torch.optim.Adam(loss_module.parameters(), lr)

KeyError: "value key 'state_value' not found in value network out_keys [('agents', 'state_value')]"