In [1]:
import os
import argparse
from typing import Optional, Tuple

from pettingzoo.mpe import simple_hmpe_v3


import torch
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import rl_utils

from ppo import PPO

In [2]:
from torch.utils.tensorboard import SummaryWriter

# ========= TensorBoard logger =========
# Event files go under <cwd>/hmpe/ppo; the checkpoint cells further down
# also save the actor/critic weights into this same directory.
run_dir_parts = ('hmpe', 'ppo')
log_path = os.path.join(os.getcwd(), *run_dir_parts)
logger = SummaryWriter(log_path)

In [3]:
# ========= Run configuration =========
render_mode = 'rgb_array'  # 'human' | 'rgb_array'

# Optimisation
actor_lr = 1e-3    # learning rate passed to PPO for the actor
critic_lr = 1e-2   # learning rate passed to PPO for the critic
num_episodes = 100
hidden_dim = 128
max_cycles = 60    # episode length limit handed to the environment
gamma = 0.98       # discount factor
lmbda = 0.95       # presumably the GAE lambda inside PPO -- confirm in ppo.py
epochs = 5         # PPO update epochs per batch (as passed to PPO)
eps = 0.2          # presumably the PPO clipping range -- confirm in ppo.py
seed = 4
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ========= Environment & seeding =========
env = simple_hmpe_v3.env(max_cycles=max_cycles, render_mode=render_mode)
env.reset(seed=seed)
np.random.seed(seed)
torch.manual_seed(seed)  # must precede PPO(...) so network init is seeded

# ========= Agent =========
# A single PPO learner is sized from agent_0's spaces.
# NOTE(review): this assumes every agent shares agent_0's observation and
# action spaces -- confirm for this scenario.
probe_agent = 'agent_0'
state_dim = env.observation_space(probe_agent).shape[0]
action_dim = env.action_space(probe_agent).n
agent = PPO(
    state_dim, hidden_dim, action_dim,
    actor_lr, critic_lr,
    lmbda, epochs, eps, gamma,
    device,
)

In [4]:

# Train the PPO agent; `logger` is handed to the training loop, which
# presumably writes its scalars to TensorBoard through it.
loss, results = rl_utils.train_on_policy_agent(env, agent, num_episodes, logger)
# SummaryWriter buffers events and only writes them to disk periodically
# (every `flush_secs`, 120 s by default). Flush now so TensorBoard shows
# every episode logged above, not just those written before the last flush.
logger.flush()


  state = torch.tensor([state], dtype=torch.float).to(self.device)
Iteration 0: 100%|██████████| 10/10 [00:03<00:00,  3.15it/s, episode=10, return=-409.063]
Iteration 1: 100%|██████████| 10/10 [00:03<00:00,  3.32it/s, episode=20, return=-362.414]
Iteration 2: 100%|██████████| 10/10 [00:02<00:00,  3.38it/s, episode=30, return=-676.204]
Iteration 3: 100%|██████████| 10/10 [00:03<00:00,  3.28it/s, episode=40, return=-938.767]
Iteration 4: 100%|██████████| 10/10 [00:03<00:00,  3.25it/s, episode=50, return=-971.349]
Iteration 5: 100%|██████████| 10/10 [00:03<00:00,  3.27it/s, episode=60, return=-945.458]
Iteration 6: 100%|██████████| 10/10 [00:02<00:00,  3.43it/s, episode=70, return=-1042.098]
Iteration 7: 100%|██████████| 10/10 [00:02<00:00,  3.40it/s, episode=80, return=-1005.200]
Iteration 8: 100%|██████████| 10/10 [00:02<00:00,  3.41it/s, episode=90, return=-941.568]
Iteration 9: 100%|██████████| 10/10 [00:02<00:00,  3.38it/s, episode=100, return=-1039.585]


In [6]:
# First recorded entry for agent_0 (its displayed output is a dict, e.g. with
# an 'obs' list of observation arrays). This displays the whole entry.
agent0_first_record = results['agent_0'][0]
agent0_first_record

{'obs': [array([ 0.        ,  0.        , -0.8       , -0.8       ,  0.7965882 ,
          0.78979176,  0.80036193,  1.5337317 ,  0.55989987,  0.35803384,
          0.8353392 ,  1.0258734 ,  1.5025713 ,  0.9312253 ,  0.42853343,
          1.4876395 ,  0.2       ,  0.        ,  0.4       ,  0.        ,
          0.        ,  0.        ,  0.        ,  0.        ,  1.        ],
        dtype=float32),
  array([-0.        , -0.5       , -0.8       , -0.8       ,  0.7965882 ,
          0.78979176,  0.80036193,  1.5337317 ,  0.55989987,  0.35803384,
          0.8353392 ,  1.0258734 ,  1.5025713 ,  0.9312253 ,  0.42853343,
          1.4876395 ,  0.2       ,  0.        ,  0.4       ,  0.        ,
          0.        ,  0.        ,  0.        ,  0.        ,  1.        ],
        dtype=float32),
  array([-0.        ,  0.125     , -0.8       , -0.85      ,  0.7965882 ,
          0.8397917 ,  0.80036193,  1.5837317 ,  0.55989987,  0.40803385,
          0.8353392 ,  1.0758735 ,  1.5025713 ,  0.9812

In [6]:
# Checkpoints are written next to the TensorBoard event files.
actor_model_save_path, critic_model_save_path = (
    os.path.join(log_path, file_name)
    for file_name in ('actor_policy.pth', 'critic_policy.pth')
)


In [None]:
# Persist only the parameter dictionaries (state_dict) of both networks.
for net, save_path in ((agent.actor, actor_model_save_path),
                       (agent.critic, critic_model_save_path)):
    torch.save(net.state_dict(), save_path)

In [10]:
def test_policy(policy, test_env, test_episode=1):
    # policy.eval()
    result = []
    for i in range(test_episode):
        # print(1)
        test_env.reset(seed=4)
        for agent in env.agent_iter():
            # print(2)
            obs, rew, termination, truncation, info = env.last()
            # print(env.reward(agent, env.world))
            if termination or truncation:
                action = None
            else:
                # this is where you would insert your policy
                action = policy.take_action(obs)
            env.step(action)
            next_obs = test_env.observe(agent)
            result.append({'obs': obs, 'next_obs': next_obs, 'rew': rew})
    print(result)
    return result

In [20]:
# Separate evaluation environment, built with the same settings as the training env.
test_env2 = simple_hmpe_v3.env(render_mode=render_mode, max_cycles=max_cycles)


In [21]:
# Reset the evaluation env with the configured `seed` rather than a repeated
# literal, so changing the config cell keeps training and evaluation in sync.
test_env2.reset(seed=seed)


In [27]:
# Manually pick one action for agent_0 in the evaluation env.
# NOTE(review): assigning agent_selection directly bypasses the env's own
# turn order. On a wrapped PettingZoo env it may also only set the attribute
# on the outer wrapper -- confirm this really targets agent_0.
test_env2.agent_selection = 'agent_0'
last_step = test_env2.last()
obs, rew, termination, truncation, info = last_step
action = agent.take_action(obs)

In [32]:
# Ask the scenario directly for the reward of the first world agent.
eval_world = test_env2.world
test_env2.scenario.reward(eval_world.agents[0], eval_world)

0.0

In [35]:
# Advance the evaluation env with the `action` computed in the manual-decision cell above.
test_env2.step(action)

In [36]:
# Re-read the last() tuple after the manual step.
(obs, rew, termination, truncation,
 info) = test_env2.last()

In [39]:
# Query the first world agent's scenario reward again, after the manual step.
eval_world = test_env2.world
test_env2.scenario.reward(eval_world.agents[0], eval_world)

0.0

In [11]:
# Evaluate the trained agent for one episode; test_policy resets test_env2 first.
result = test_policy(policy=agent, test_env=test_env2, test_episode=1)

1
[]
