In [13]:
from til_environment import gridworld

# Create an environment instance through the gridworld module's env() factory.
env = gridworld.env()

In [14]:
# Torch
import torch
import torch.nn as nn

# Tensordict modules
from tensordict.nn import set_composite_lp_aggregate, TensorDictModule, TensorDictSequential
from tensordict import  TensorDictBase
from torch import multiprocessing

# Data collection
from torchrl.collectors import SyncDataCollector
from torch.distributions import Categorical
from torchrl.data.replay_buffers import ReplayBuffer
from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
from torchrl.data.replay_buffers.storages import LazyTensorStorage

#Env
from torchrl.envs import RewardSum, TransformedEnv, PettingZooWrapper, Compose, DoubleToFloat, StepCounter, ParallelEnv, EnvCreator, ExplorationType, set_exploration_type

# Utils
from torchrl.envs.utils import check_env_specs

# Multi-agent network
from torchrl.modules import MultiAgentMLP, ProbabilisticActor, TanhNormal

# Loss
from torchrl.objectives import ClipPPOLoss, ValueEstimators

# Reproducibility: seed torch's global random number generator.
torch.manual_seed(0)
# Plotting and progress-bar helpers
from matplotlib import pyplot as plt
from tqdm import tqdm

In [15]:
from torchrl.envs import PettingZooWrapper

# Candidate agent grouping (the names suggest one scout and three guards).
# NOTE(review): group_map is never passed to PettingZooWrapper below, so it has
# no effect on how agents are grouped — confirm whether it should be supplied
# (doing so would change the group keys that later cells index by).
group_map = {
    "scout": ["player_0"],
    "guards": ["player_1", "player_2", "player_3"]
}

# One underlying gridworld environment instance, wrapped for TorchRL with
# masking enabled (use_mask=True).
raw_env = gridworld.env()
env = PettingZooWrapper(raw_env, use_mask=True)

In [16]:
# Smoke test: request a 5-step rollout (no policy is passed) and display the
# resulting TensorDict as the cell output.
env.rollout(5)

TensorDict(
    fields={
        done: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        next: TensorDict(
            fields={
                done: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                player_0: TensorDict(
                    fields={
                        done: Tensor(shape=torch.Size([5, 1, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                        mask: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                        observation: Tensor(shape=torch.Size([5, 1, 576]), device=cpu, dtype=torch.int64, is_shared=False),
                        reward: Tensor(shape=torch.Size([5, 1, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                        terminated: Tensor(shape=torch.Size([5, 1, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                        truncated: Tensor(shape=torch.Size([5, 1, 1]), devi

In [5]:
# --- Device -----------------------------------------------------------------
# Use CUDA device 0 only when CUDA is available and the multiprocessing start
# method is not "fork"; otherwise fall back to the CPU.
is_fork = multiprocessing.get_start_method() == "fork"
if torch.cuda.is_available() and not is_fork:
    device = torch.device(0)
else:
    device = torch.device("cpu")

# --- Environment ------------------------------------------------------------
n_parallel_envs = 2  # how many environments run side by side

# --- Data collection --------------------------------------------------------
frames_per_batch = 2_000  # team frames gathered per training iteration
total_frames = 200_000

# --- Optimisation -----------------------------------------------------------
num_epochs = 5  # optimisation passes per training iteration
minibatch_size = 400  # samples per mini-batch in each optimisation step
lr = 3e-4  # learning rate
max_grad_norm = 1.0  # upper bound on the gradient norm

# --- PPO --------------------------------------------------------------------
clip_epsilon = 0.2  # clipping range of the PPO objective
gamma = 0.99  # discount factor
lmbda = 0.9  # lambda for generalised advantage estimation
entropy_eps = 1e-4  # weight of the entropy term in the PPO loss

# Turn off aggregation of composite log-probabilities.
set_composite_lp_aggregate(False).set()

In [12]:
# Per-agent keys exposed by the wrapped environment. These must match the
# wrapper's keys exactly (lowercase "player_N"); a different casing such as
# "Player_0" is a different key.
agent_names = ["player_0", "player_1", "player_2", "player_3"]


def make_env_transforms():
    """Build a fresh transform stack for a single environment.

    The stack is, in order: one ``RewardSum`` per agent (accumulating
    ``(agent, "reward")`` into ``(agent, "episode_reward")``), then
    ``DoubleToFloat`` and ``StepCounter``.

    New transform instances are created on every call so that no transform
    object is shared between two ``Compose`` stacks or two environments.

    Returns:
        Compose: the newly built, not-yet-attached transform stack.
    """
    reward_sums = [
        RewardSum(
            in_keys=[(agent, "reward")],
            out_keys=[(agent, "episode_reward")],
        )
        for agent in agent_names
    ]
    return Compose(*reward_sums, DoubleToFloat(), StepCounter())


def make_single_env():
    """Create one independent, transformed gridworld environment.

    A new ``gridworld.env()`` is constructed on every call rather than
    re-wrapping a single pre-built instance, so each environment created
    through this factory owns its own underlying environment state.

    Returns:
        TransformedEnv: a ``PettingZooWrapper`` (``use_mask=True``) around a
        fresh gridworld environment, with the transform stack applied.
    """
    return TransformedEnv(
        PettingZooWrapper(gridworld.env(), use_mask=True),
        make_env_transforms(),
    )


# Environment factory handed to the batched environment below.
make_env = EnvCreator(make_single_env)

# Create parallel environments
env = ParallelEnv(n_parallel_envs, make_env, serial_for_single=True)

KeyError: 'Player_0'

In [9]:
# Print each key list / spec of the environment next to its label, in a fixed
# order, then run TorchRL's spec consistency check.
env_report = (
    ("action_keys:", "action_keys"),
    ("reward_keys:", "reward_keys"),
    ("done_keys:", "done_keys"),
    ("Action Spec:", "action_spec"),
    ("Observation Spec:", "observation_spec"),
    ("Reward Spec:", "reward_spec"),
    ("Done Spec:", "done_spec"),
)
for label, attr_name in env_report:
    print(label, getattr(env, attr_name))

check_env_specs(env)

action_keys: [('player_0', 'action'), ('player_1', 'action'), ('player_2', 'action'), ('player_3', 'action')]
reward_keys: [('player_0', 'reward'), ('player_1', 'reward'), ('player_2', 'reward'), ('player_3', 'reward')]
done_keys: ['done', 'terminated', 'truncated', ('player_0', 'done'), ('player_0', 'terminated'), ('player_0', 'truncated'), ('player_1', 'done'), ('player_1', 'terminated'), ('player_1', 'truncated'), ('player_2', 'done'), ('player_2', 'terminated'), ('player_2', 'truncated'), ('player_3', 'done'), ('player_3', 'terminated'), ('player_3', 'truncated')]
Action Spec: Composite(
    player_0: Composite(
        action: Categorical(
            shape=torch.Size([2, 1]),
            space=CategoricalBox(n=5),
            device=cpu,
            dtype=torch.int64,
            domain=discrete),
        device=cpu,
        shape=torch.Size([2, 1])),
    player_1: Composite(
        action: Categorical(
            shape=torch.Size([2, 1]),
            space=CategoricalBox(n=5),

  from pkg_resources import resource_stream, resource_exists
Process _ProcessNoWarn-1:
Traceback (most recent call last):
  File "/usr/lib/python3.12/multiprocessing/process.py", line 314, in _bootstrap
    self.run()
  File "/home/ubuntu/Desktop/real_learning/venv/lib/python3.12/site-packages/torchrl/_utils.py", line 734, in run
    return mp.Process.run(self, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.12/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/home/ubuntu/Desktop/real_learning/venv/lib/python3.12/site-packages/torchrl/envs/batched_envs.py", line 2163, in _run_worker_pipe_shared_mem
    env = env_fun(**env_fun_kwargs)
          ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ubuntu/Desktop/real_learning/venv/lib/python3.12/site-packages/torchrl/envs/env_creator.py", line 203, in __call__
    env = self.create_env_fn(**kwargs)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ubuntu/

EOFError: 