In [3]:
import gymnasium as gym
import functools
import random
from copy import copy, deepcopy

import numpy as np
from gymnasium.spaces import Discrete, MultiDiscrete, Box, Dict

from pettingzoo import ParallelEnv


## torch学習部分
from ray.rllib.models import ModelCatalog
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
import torch
from torch import nn


## ray部分
import ray
from ray.rllib.algorithms.impala import ImpalaConfig, Impala
from ray import air
from ray import tune

from ray.air.integrations.wandb import WandbLoggerCallback
from ray.rllib.algorithms.callbacks import DefaultCallbacks
from ray.tune.logger import CSVLoggerCallback
from ray.tune.registry import register_env, get_trainable_cls
from ray.rllib.env.wrappers.pettingzoo_env import ParallelPettingZooEnv
from ray.rllib.policy.policy import Policy

In [4]:
PROJ_NAME = "test_PvP_cartpole_zero_cum_rew"

# 実行設定
# IS_COLAB = True
USE_WANDB = True
NUM_GPUS = 1 if torch.cuda.is_available() else 0
CHECKPOINT_DIR = "/home/s2430014/research/pvp-cartpole/log"

# ディレクトリ関係
COMMON_CONFIG_PATH = "/home/s2430014/research/common_config.ini"
CALLBACK_PATH = "/home/s2430014/research/pvp-cartpole/callback"
MODEL_SAVE_FREQ = 100

# 環境関連
BASE_ALIVE_TIME= 25

# モデル関連
HIDDEN_DIM = 128
HIDDEN_DEPTH = 8
TRAINING_ITER = 500
LR = 1e-3

# self-play関連
WIN_RATE_THRESHOLD = 0.58
ALGO = "IMPALA" #???
FRAMEWORK = "torch" #???
NUM_ENV_RUNNERS = 2
STOP_TIMESTEPS = 2000
STOP_ITERS = 10000000

# INITIAL_EPSILON = 1.0
# FINAL_EPSILON = 0.1

#モデルを訓練するかどうか
TRAIN = True

In [5]:
# 時間情報を取得
from datetime import datetime, timezone, timedelta
from zoneinfo import ZoneInfo

now = datetime.now(ZoneInfo("Asia/Tokyo"))
time_code = now.strftime("%Y%m%d_%H:%M:%S")
print(time_code)

# wandbにlogin
import configparser

config_ini = configparser.ConfigParser()
config_ini.read(COMMON_CONFIG_PATH, encoding='utf-8')
api_key = config_ini['WANDB']['api_key']

20241005_12:05:18


In [32]:
class CustomTorchModel(TorchModelV2, nn.Module):
    def __init__(self, obs_space, action_space, num_outputs, model_config, name):
        TorchModelV2.__init__(self, obs_space, action_space, num_outputs, model_config, name)
        nn.Module.__init__(self)

        # print(f'num_outputs = {num_outputs}')
        # assert num_outputs == 2, f'Assertion failed: num_outputs must be 2 but num_outputs={num_outputs}'
        self.num_outputs = num_outputs
        self._num_objects = obs_space.shape[0]
        self._num_actions = num_outputs

        self.hidden_depth = HIDDEN_DEPTH

        layers = [nn.Linear(self._num_objects, HIDDEN_DIM), nn.ReLU()]
        for i in range(HIDDEN_DEPTH):
            layers.append(nn.Linear(HIDDEN_DIM, HIDDEN_DIM))
            layers.append(nn.ReLU())
        self.layers = nn.Sequential(*layers)

        # value function
        self.vf_head = nn.Linear(HIDDEN_DIM, 1)

        # action logits（ソフトマックスは適用しない）
        self.ac_head = nn.Linear(HIDDEN_DIM, num_outputs)

    def forward(self, input_dict, state, seq_lens):
        x = input_dict["obs"]
        # assert not isinstance(x, collections.OrderedDict) , f'input is orderdict {x}'
        x = self.layers(x)
        logits = self.ac_head(x)
        self.value_out = self.vf_head(x)
        return logits, []

    def value_function(self):
        return torch.reshape(self.value_out, (-1,))


In [91]:
checkpoint_dir = "main_v380.pt"

In [92]:
env = gym.make('CartPole-v1')

In [93]:
ge = gym.make('CartPole-v1')
action_space = ge.action_space
observation_space = Box(
            high = np.concatenate([ge.observation_space.high, [1.0], [1.0]]),
            low = np.concatenate([ge.observation_space.low, [0.0], [0.0]]),
            shape = (6,),
            dtype = np.float32
)

In [94]:
def load_model(checkpoint_path):
    model = CustomTorchModel(
        obs_space = observation_space,
        action_space = action_space,
        num_outputs = 2,
        model_config = {},
        name = "temp"
    )
    model.load_state_dict(torch.load(checkpoint_dir))
    model.eval()  # 評価モードに設定
    return model

In [95]:
m = load_model(checkpoint_dir)

In [96]:
def select_action(model, obs, role):
    # observationにroleを追加する部分はselect actionで吸収
    obs = np.concatenate([obs, [role, 1]]) # 次が自分のターンであることを表す1

    with torch.no_grad():

        obs_tensor = torch.tensor(obs, dtype=torch.float32)
        input_dict = {"obs": obs_tensor}
        action = model(input_dict, None, None)
        return torch.argmax(action[0]).item()

In [100]:
m = load_model(checkpoint_dir)

def experiment(role):
  li = []
  action_list = []
  for j in range(100):
    obs, info = env.reset()

    act_cnt  = 0
    for i in range(50):

      res = select_action(m, obs, role)
      action_list.append(res)
        
      obs, v, terminated, _, _ = env.step(res)
      act_cnt += 1
      if terminated:
        break
    li.append(act_cnt)
  return li, action_list


def random_experiment():
  li = []
  action_list = []
  for j in range(100):
    obs, info = env.reset()

    act_cnt  = 0
    for i in range(50):
      rng = np.random.default_rng()

      res = np.argmax(rng.random(2))
      action_list.append(res)
        
      obs, v, terminated, _, _ = env.step(res)
      act_cnt += 1
      if terminated:
        break
    li.append(act_cnt)
  return li, action_list

def all1_experiment():
  li = []
  action_list = []
  
  for j in range(100):
    obs, info = env.reset()

    act_cnt  = 0
    for i in range(50):
      res = 1
      obs, v, terminated, _, _ = env.step(res)
      action_list.append(res)
      act_cnt += 1
      if terminated:
        break
    li.append(act_cnt)
  return li, action_list

In [101]:
li_0, a0 = experiment(0)
li_1, a1 = experiment(1)
li_r, ar = random_experiment()
li_all1, aall1 = all1_experiment()

In [107]:
print(f"act as stabilizer..... mean game len {np.array(li_0).mean()}")
print(f"act as disturber...... mean game len {np.array(li_1).mean()}")
print(f"act randomly.......... mean game len {np.array(li_r).mean()}")
print(f"choose 1 every time... mean game len {np.array(li_1).mean()}")

act as stabilizer..... mean game len 44.62
act as disturber...... mean game len 10.05
act randomly.......... mean game len 23.97
choose 1 every time... mean game len 10.05
