In [1]:
import torch
from tensordict.nn import TensorDictModule
from tensordict.nn.distributions import NormalParamExtractor
from torch import multiprocessing
from torchrl.collectors import SyncDataCollector
from torchrl.data.replay_buffers import ReplayBuffer
from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
from torchrl.data.replay_buffers.storages import LazyTensorStorage
from torchrl.envs import RewardSum, TransformedEnv
from torchrl.envs.libs.vmas import VmasEnv
from torchrl.envs.utils import check_env_specs
from torchrl.modules import MultiAgentMLP, ProbabilisticActor, TanhNormal
from torchrl.objectives import ClipPPOLoss, ValueEstimators
from matplotlib import pyplot as plt
from tqdm import tqdm
import io
import numpy as np
from IPython.display import display, Image
import os
import PIL.Image as PILImage

# Seed torch's global RNG so repeated runs start from the same state.
torch.manual_seed(0)

# Pick the compute device: the first CUDA device when CUDA is available and the
# multiprocessing start method is not "fork", otherwise the CPU.
# NOTE(review): the fork exclusion is presumably because CUDA cannot be used
# safely in forked subprocesses -- confirm against the PyTorch docs.
is_fork = multiprocessing.get_start_method() == "fork"
if torch.cuda.is_available() and not is_fork:
    device = torch.device(0)
else:
    device = torch.device("cpu")

# The simulator runs on the same device as the networks.
vmas_device = device

def setup_environment(scenario_name, n_agents, max_steps, frames_per_batch, vmas_device):
    """Build a vectorized VMAS environment that also tracks episode returns.

    Args:
        scenario_name: Name of the VMAS scenario, passed as ``scenario``.
        n_agents: Number of agents, forwarded to ``VmasEnv``.
        max_steps: Episode horizon passed to ``VmasEnv``; must be positive.
        frames_per_batch: Frames gathered per training iteration. The number
            of parallel sub-environments is ``frames_per_batch // max_steps``.
        vmas_device: Device handed to ``VmasEnv``.

    Returns:
        A ``TransformedEnv`` wrapping the ``VmasEnv`` with a ``RewardSum``
        transform that writes to ``("agents", "episode_reward")``. The specs
        of the returned env have been validated with ``check_env_specs``.

    Raises:
        ValueError: If ``max_steps`` is not positive, or if
            ``frames_per_batch`` is smaller than ``max_steps`` (which would
            request zero sub-environments).
    """
    # Validate up front: a zero horizon would divide by zero below and a
    # too-small batch would silently request an empty vectorized env.
    if max_steps <= 0:
        raise ValueError(f"max_steps must be positive, got {max_steps}")

    num_vmas_envs = frames_per_batch // max_steps
    if num_vmas_envs < 1:
        raise ValueError(
            f"frames_per_batch ({frames_per_batch}) must be at least "
            f"max_steps ({max_steps}) to create one sub-environment"
        )

    env = VmasEnv(
        scenario=scenario_name,
        num_envs=num_vmas_envs,
        continuous_actions=True,
        max_steps=max_steps,
        device=vmas_device,
        n_agents=n_agents,
    )

    # Accumulate the env's reward entry into a per-agent running episode sum.
    env = TransformedEnv(
        env,
        RewardSum(in_keys=[env.reward_key], out_keys=[("agents", "episode_reward")]),
    )

    check_env_specs(env)

    return env

def setup_policy_network(env, device, share_parameters_policy=True):
    """Create the stochastic multi-agent actor used for collection and training.

    Args:
        env: Environment whose observation/action specs size the network.
        device: Device on which the MLP parameters are created.
        share_parameters_policy: Forwarded to ``MultiAgentMLP`` as
            ``share_params``.

    Returns:
        A ``ProbabilisticActor`` that reads ``("agents", "observation")``,
        writes the action under ``env.action_key`` and stores the sample's
        log-probability under ``("agents", "sample_log_prob")``.
    """
    obs_dim = env.observation_spec["agents", "observation"].shape[-1]
    act_dim = env.action_spec.shape[-1]
    param_keys = [("agents", "loc"), ("agents", "scale")]

    # Decentralised MLP emitting two values per action dimension, which the
    # extractor then turns into the "loc" and "scale" outputs.
    backbone = MultiAgentMLP(
        n_agent_inputs=obs_dim,
        n_agent_outputs=2 * act_dim,
        n_agents=env.n_agents,
        centralised=False,
        share_params=share_parameters_policy,
        device=device,
        depth=2,
        num_cells=256,
        activation_class=torch.nn.Tanh,
    )
    param_module = TensorDictModule(
        torch.nn.Sequential(backbone, NormalParamExtractor()),
        in_keys=[("agents", "observation")],
        out_keys=list(param_keys),
    )

    # The TanhNormal distribution is bounded by the limits of the action spec.
    single_action_spec = env.unbatched_action_spec
    action_space = single_action_spec[env.action_key].space
    bounds = {"min": action_space.low, "max": action_space.high}

    return ProbabilisticActor(
        module=param_module,
        spec=single_action_spec,
        in_keys=list(param_keys),
        out_keys=[env.action_key],
        distribution_class=TanhNormal,
        distribution_kwargs=bounds,
        return_log_prob=True,
        log_prob_key=("agents", "sample_log_prob"),
    )

def setup_value_network(env, device, mappo=True, share_parameters_critic=True):
    """Create the critic mapping agent observations to per-agent state values.

    Args:
        env: Environment whose observation spec and agent count size the net.
        device: Device on which the MLP parameters are created.
        mappo: Forwarded to ``MultiAgentMLP`` as ``centralised``.
        share_parameters_critic: Forwarded to ``MultiAgentMLP`` as
            ``share_params``.

    Returns:
        A ``TensorDictModule`` reading ``("agents", "observation")`` and
        writing ``("agents", "state_value")``.
    """
    obs_dim = env.observation_spec["agents", "observation"].shape[-1]

    mlp_config = dict(
        n_agent_inputs=obs_dim,
        n_agent_outputs=1,  # one scalar value per agent
        n_agents=env.n_agents,
        centralised=mappo,
        share_params=share_parameters_critic,
        device=device,
        depth=2,
        num_cells=256,
        activation_class=torch.nn.Tanh,
    )

    return TensorDictModule(
        module=MultiAgentMLP(**mlp_config),
        in_keys=[("agents", "observation")],
        out_keys=[("agents", "state_value")],
    )

def setup_ppo_loss(policy, critic, env, clip_epsilon=0.2, entropy_eps=1e-4, gamma=0.9, lmbda=0.9):
    """Assemble the clipped-PPO loss and attach a GAE value estimator.

    Args:
        policy: Actor network given to ``ClipPPOLoss``.
        critic: Critic network given to ``ClipPPOLoss``.
        env: Environment providing the reward and action keys.
        clip_epsilon: Clipping threshold passed to ``ClipPPOLoss``.
        entropy_eps: Passed to ``ClipPPOLoss`` as ``entropy_coef``.
        gamma: ``gamma`` argument of the GAE estimator.
        lmbda: ``lmbda`` argument of the GAE estimator.

    Returns:
        The configured ``ClipPPOLoss`` module.
    """
    ppo_loss = ClipPPOLoss(
        actor_network=policy,
        critic_network=critic,
        clip_epsilon=clip_epsilon,
        entropy_coef=entropy_eps,
        normalize_advantage=False,
    )

    # Point the loss at the entries it should read/write: env-provided reward
    # and action keys, plus the per-agent entries under "agents".
    tensordict_keys = {
        "reward": env.reward_key,
        "action": env.action_key,
        "sample_log_prob": ("agents", "sample_log_prob"),
        "value": ("agents", "state_value"),
        "done": ("agents", "done"),
        "terminated": ("agents", "terminated"),
    }
    ppo_loss.set_keys(**tensordict_keys)

    ppo_loss.make_value_estimator(ValueEstimators.GAE, gamma=gamma, lmbda=lmbda)

    return ppo_loss

def train_ppo(env, policy, critic, loss_module, n_iters, frames_per_batch, minibatch_size, num_epochs, max_grad_norm, lr, device):
    """Train with clipped PPO, plot the reward curve and return its values.

    Args:
        env: Environment to collect data from.
        policy: Actor used by the collector to gather rollouts.
        critic: Not referenced directly; the optimiser works on
            ``loss_module.parameters()``. Kept so existing callers still work.
        loss_module: PPO loss exposing ``value_estimator``,
            ``critic_network_params`` and ``target_critic_network_params``.
        n_iters: Number of collection/optimisation iterations.
        frames_per_batch: Frames collected per iteration; also the replay
            buffer capacity.
        minibatch_size: Number of frames per optimisation step.
        num_epochs: Passes over each collected batch.
        max_grad_norm: Threshold for gradient-norm clipping.
        lr: Adam learning rate.
        device: Device where collected data and the replay buffer are stored.

    Returns:
        list[float]: One mean episode reward per training iteration (the same
        values that are plotted).
    """
    collector = SyncDataCollector(
        env,
        policy,
        # Use the env's own device instead of relying on a module-level global.
        device=env.device,
        storing_device=device,
        frames_per_batch=frames_per_batch,
        total_frames=frames_per_batch * n_iters,
    )

    replay_buffer = ReplayBuffer(
        storage=LazyTensorStorage(frames_per_batch, device=device),
        sampler=SamplerWithoutReplacement(),
        batch_size=minibatch_size,
    )

    optim = torch.optim.Adam(loss_module.parameters(), lr)

    minibatches_per_epoch = frames_per_batch // minibatch_size
    episode_reward_mean_list = []

    # The context manager guarantees the progress bar is closed, even if
    # training raises part-way through.
    with tqdm(total=n_iters, desc="episode_reward_mean = 0") as pbar:
        for tensordict_data in collector:
            # The loss reads done/terminated under "agents", so expand the
            # env-level flags to the shape of the reward entry.
            reward_shape = tensordict_data.get_item_shape(("next", env.reward_key))
            for flag in ("done", "terminated"):
                tensordict_data.set(
                    ("next", "agents", flag),
                    tensordict_data.get(("next", flag))
                    .unsqueeze(-1)
                    .expand(reward_shape),
                )

            # Advantage/value targets are computed once per batch, without
            # tracking gradients.
            with torch.no_grad():
                loss_module.value_estimator(
                    tensordict_data,
                    params=loss_module.critic_network_params,
                    target_params=loss_module.target_critic_network_params,
                )

            # Flatten the batch dimensions so minibatches are sampled over
            # individual entries.
            replay_buffer.extend(tensordict_data.reshape(-1))

            for _ in range(num_epochs):
                for _ in range(minibatches_per_epoch):
                    subdata = replay_buffer.sample()
                    loss_vals = loss_module(subdata)

                    loss_value = (
                        loss_vals["loss_objective"]
                        + loss_vals["loss_critic"]
                        + loss_vals["loss_entropy"]
                    )

                    loss_value.backward()
                    torch.nn.utils.clip_grad_norm_(
                        loss_module.parameters(), max_grad_norm
                    )
                    optim.step()
                    optim.zero_grad()

            collector.update_policy_weights_()

            # Logging: average the accumulated reward over entries flagged as
            # done. If nothing is flagged in this batch the mean is NaN.
            done = tensordict_data.get(("next", "agents", "done"))
            episode_reward_mean = (
                tensordict_data.get(("next", "agents", "episode_reward"))[done]
                .mean()
                .item()
            )
            episode_reward_mean_list.append(episode_reward_mean)
            pbar.set_description(
                f"episode_reward_mean = {episode_reward_mean}", refresh=False
            )
            pbar.update()

    plt.plot(episode_reward_mean_list)
    plt.xlabel("Training iterations")
    plt.ylabel("Reward")
    plt.title("Episode reward mean")
    plt.show()

    return episode_reward_mean_list

def evaluate(env, policy, max_steps, scenario_name):
    """Roll out the policy, save the rendered frames as a GIF and display it.

    Args:
        env: Environment to roll out; rendered frames are stored on it as
            ``env.frames``.
        policy: Policy used to select actions during the rollout.
        max_steps: Maximum number of rollout steps.
        scenario_name: Stem of the output file, ``"<scenario_name>.gif"``.
    """
    env.frames = []

    def _capture_frame(rollout_env, _td):
        # Called by the rollout with the env and the current tensordict.
        rollout_env.frames.append(rollout_env.render(mode="rgb_array"))

    with torch.no_grad():
        env.rollout(
            max_steps=max_steps,
            policy=policy,
            callback=_capture_frame,
            auto_cast_to_device=True,
            break_when_any_done=False,
        )

    # Convert raw arrays to PIL images; anything else is kept unchanged.
    converted = []
    for raw in env.frames:
        if isinstance(raw, np.ndarray):
            converted.append(PILImage.fromarray(raw))
        else:
            converted.append(raw)
    env.frames = converted

    frames = env.frames
    gif_path = f"{scenario_name}.gif"
    first_frame = frames[0]
    remaining_frames = frames[1:]
    first_frame.save(
        gif_path,
        save_all=True,
        append_images=remaining_frames,
        duration=300,
        loop=0,
    )
    display(Image(gif_path))

# Entry point: configure the run, train with PPO, then evaluate (scenario_name is passed to evaluate).
def main():
    """Run the pipeline: build the env and networks, train, then evaluate."""
    # Scenario and rollout sizing
    scenario_name, n_agents, max_steps = "navigation", 3, 100
    frames_per_batch, n_iters = 6_000, 10
    # Optimisation settings
    minibatch_size, num_epochs = 400, 1
    lr, max_grad_norm = 0.0003, 1.0

    env = setup_environment(
        scenario_name=scenario_name,
        n_agents=n_agents,
        max_steps=max_steps,
        frames_per_batch=frames_per_batch,
        vmas_device=vmas_device,
    )

    # Actor and critic, both with their default settings
    policy = setup_policy_network(env=env, device=device)
    critic = setup_value_network(env=env, device=device)

    loss_module = setup_ppo_loss(policy=policy, critic=critic, env=env)

    train_ppo(
        env=env,
        policy=policy,
        critic=critic,
        loss_module=loss_module,
        n_iters=n_iters,
        frames_per_batch=frames_per_batch,
        minibatch_size=minibatch_size,
        num_epochs=num_epochs,
        max_grad_norm=max_grad_norm,
        lr=lr,
        device=device,
    )

    # Render a rollout of the trained policy to "<scenario_name>.gif"
    evaluate(env=env, policy=policy, max_steps=max_steps, scenario_name=scenario_name)

# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()



  comms.append(torch.tensor(other_agent.state.c).unsqueeze(0).unsqueeze(0))  # Ensure it's 3D


RuntimeError: Tensors must have same number of dimensions: got 4 and 5