In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from collections import deque
from torch.distributions import Categorical, Normal
from tensordict.nn import TensorDictModule
from torchrl.envs.libs.vmas import VmasEnv
from torchrl.envs import RewardSum, TransformedEnv
from torchrl.envs.utils import check_env_specs
from tensordict import TensorDict  # Add this import at the top of your script


class ActorCritic(nn.Module):
    def __init__(self, input_dim, action_dim):
        super(ActorCritic, self).__init__()
        self.actor = nn.Sequential(
            nn.Linear(input_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Linear(128, action_dim)
        )
        self.critic = nn.Sequential(
            nn.Linear(input_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Linear(128, 1)
        )

    def forward(self, x):
        value = self.critic(x)
        policy_dist = torch.tanh(self.actor(x))
        rounded_policy_dist = self.round_actions(policy_dist, decimal_places=2)
        return rounded_policy_dist, value

    def round_actions(self, actions, decimal_places=3):
        scale_factor = 10 ** decimal_places
        return torch.round(actions * scale_factor) / scale_factor

class ProblemSolver:
    def __init__(self, scenario_name, n_agents, max_steps, frames_per_batch, state_dim, action_dim, device, alpha=0.1, gamma=0.99, epsilon=0.2, K_epochs=4, batch_size=64, communication_weight=0.5):
        self.alpha = alpha  # Learning rate
        self.gamma = gamma  # Discount factor
        self.epsilon = epsilon  # Exploration rate
        self.K_epoch = K_epochs
        self.batch_size = batch_size
        self.q_table = {}
        self.device = device
        self.communication_weight = communication_weight  # Weight parameter for incorporating messages

        # Setup environment
        self.env = self.setup_environment(scenario_name, n_agents, max_steps, frames_per_batch, device)

        # Initialize actor-critic networks
        self.policy = ActorCritic(state_dim, action_dim).to(device)
        self.optimizer = optim.Adam(self.policy.parameters(), lr=alpha)
        self.policy_old = ActorCritic(state_dim, action_dim).to(device)
        self.policy_old.load_state_dict(self.policy.state_dict())
        self.memory = deque(maxlen=10000)

    def setup_environment(self, scenario_name, n_agents, max_steps, frames_per_batch, device):
        num_vmas_envs = frames_per_batch // max_steps
        print(f"number of environments: {num_vmas_envs}")
        env = VmasEnv(
            scenario=scenario_name,
            num_envs=num_vmas_envs,
            continuous_actions=True,
            max_steps=max_steps,
            device=device,
            n_agents=n_agents,
        )
        env = TransformedEnv(
            env,
            RewardSum(in_keys=[env.reward_key], out_keys=[("agents", "episode_reward")])
        )
        check_env_specs(env)
        return env

    def get_action_discrete(self, agent, env, agent_id, agent_obs):
        # Implementation here...
        pass

    def update_action_discrete(self, agent, env, agent_id, obs, action, reward, next_obs):
        # Implementation here...
        pass

    def get_action_continuous(self, agent_obs):
        if isinstance(agent_obs, TensorDict):  
            # Extract the 'observation' tensor for the specific agent
            state = agent_obs.get(("agents", "observation"))  
        else:
            state = torch.tensor(agent_obs[:6], dtype=torch.float32) if not isinstance(agent_obs[:6], torch.Tensor) else agent_obs[:6].float()

        state = state.to(self.device)
        
        with torch.no_grad():
            policy_dist, _ = self.policy_old(state)
        action = policy_dist.cpu().numpy()
        return action





    def update_action_continuous(self, agent_obs, action, reward, next_obs, done):
        agent_obs = torch.tensor(agent_obs[:6], dtype=torch.float32).to(self.device)
        next_obs = torch.tensor(next_obs[:6], dtype=torch.float32).to(self.device)
        action = torch.tensor(action, dtype=torch.float32).to(self.device)

        # Calculate advantage
        _, value = self.policy(agent_obs)
        _, next_value = self.policy(next_obs)
        target_value = reward + self.gamma * next_value * (~done)
        advantage = (target_value - value).detach()

        # Calculate the log probability of the action under the current policy
        mu, sigma = self.policy(agent_obs)
        sigma = torch.exp(sigma)
        dist = Normal(mu, sigma)
        log_prob = dist.log_prob(action).sum()

        # Calculate the log probability of the action under the old policy
        with torch.no_grad():
            mu_old, sigma_old = self.policy_old(agent_obs)
            sigma_old = torch.exp(sigma_old)
            dist_old = Normal(mu_old, sigma_old)
            old_log_prob = dist_old.log_prob(action).sum()

        # Policy ratio
        ratio = torch.exp(log_prob - old_log_prob)

        # PPO objective with clipping
        surrogate1 = ratio * advantage
        surrogate2 = torch.clamp(ratio, 1.0 - self.epsilon, 1.0 + self.epsilon) * advantage
        policy_loss = -torch.min(surrogate1, surrogate2).mean()

        # Value function loss
        value_loss = nn.MSELoss()(value, target_value.detach())

        # Total loss
        loss = policy_loss + 0.5 * value_loss

        # Update the actor-critic model
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

    def print_model_parameters(self):
        print(f"Model parameters for Agent:")
        for name, param in self.policy.named_parameters():
            print(f"  {name}: {param.data.numpy()}")
        print(f"End of model parameters\n")

def main():
    # Parameters
    scenario_name = "navigation"
    n_agents = 2
    max_steps = 100
    frames_per_batch = 1000
    state_dim = 6  # Example state dimension
    action_dim = 2  # Example action dimension
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Initialize the PPO problem solver
    problem_solver = ProblemSolver(scenario_name, n_agents, max_steps, frames_per_batch, state_dim, action_dim, device)

    # Example loop
    for episode in range(10):  # Example for 10 episodes
        agent_obs = problem_solver.env.reset()
        done = False

        while not done:
            actions = []
            for agent_id in range(n_agents):
                action = problem_solver.get_action_continuous(agent_obs[agent_id])
                actions.append(action)

            next_obs, rewards, done, _ = problem_solver.env.step(actions)

            for agent_id in range(n_agents):
                problem_solver.update_action_continuous(agent_obs[agent_id], actions[agent_id], rewards[agent_id], next_obs[agent_id], done)
            
            agent_obs = next_obs

        # Optionally print model parameters
        problem_solver.print_model_parameters()

if __name__ == "__main__":
    main()


number of environments: 10


2024-08-16 19:00:47,794 [torchrl][INFO] check_env_specs succeeded!


donesnya: tensor([False, False, False, False, False, False, False, False, False, False],
       device='cuda:0')
shape of donesnya: torch.Size([10])
shape of self step nya: torch.Size([10])
observation detected in environment: tensor([[ 0.0413, -0.5019,  0.0000,  0.0000,  0.8860,  0.4495,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [-0.8318,  0.3951,  0.0000,  0.0000, -1.6072, -0.2036,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.4802, -0.7557,  0.0000,  0.0000, -0.3290, -1.6683,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.9596, -0.3084,  0.0000,  0.0000,  0.3680,  0.4978,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [-0.6

RuntimeError: mat1 and mat2 shapes cannot be multiplied (2x18 and 6x128)