# AgileRL Speaker-Listener with MATD3
https://docs.agilerl.com/en/latest/tutorials/pettingzoo/matd3.html

In [1]:
#!pip install --upgrade pip

In [2]:
#!python3 -m pip install playsound
#!python3 -m pip install PyObjC

In [3]:
#!pip install pettingzoo[mpe]
#!pip install agilerl
#!pip install imageio

In [4]:
#現在日時を取得
import datetime

dt_now = datetime.datetime.now()
str_dt_now = dt_now.strftime("%Y%m%d-%H%M")
print(str_dt_now)

20231226-0652


In [None]:
"""
This tutorial shows how to train an MATD3 agent on the simple speaker listener multi-particle environment.

Authors: Michael (https://github.com/mikepratt1), Nickua (https://github.com/nicku-a)
"""

import os
import pprint
import datetime

import numpy as np
import torch
from pettingzoo.mpe import simple_speaker_listener_v4
from tqdm import trange

from agilerl.components.multi_agent_replay_buffer import MultiAgentReplayBuffer
from agilerl.hpo.mutation import Mutations
from agilerl.hpo.tournament import TournamentSelection
from agilerl.utils.utils import initialPopulation

if __name__ == "__main__":

    #現在日時を取得
    dt_now = datetime.datetime.now()
    str_dt_now = dt_now.strftime("%Y%m%d-%H%M")
    print(str_dt_now)
    
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("===== AgileRL Online Multi-Agent Demo =====")

    # Define the network configuration
    # ネットワークコンフィグレーションの定義
    NET_CONFIG = {
        "arch": "mlp",  # Network architecture
        "h_size": [32, 32],  # Actor hidden size
    }

    # Define the initial hyperparameters
    # 初期ハイパーパラメータの定義
    INIT_HP = {
        "POPULATION_SIZE": 4,
        "ALGO": "MATD3",  # Algorithm
        # Swap image channels dimension from last to first [H, W, C] -> [C, H, W]
        "CHANNELS_LAST": False,
        "BATCH_SIZE": 32,  # Batch size
        "LR": 0.01,  # Learning rate
        "GAMMA": 0.95,  # Discount factor
        "MEMORY_SIZE": 100000,  # Max memory buffer size
        "LEARN_STEP": 5,  # Learning frequency
        "TAU": 0.01,  # For soft update of target parameters
        "POLICY_FREQ": 2,  # Policy frequnecy
    }

    # Define the simple speaker listener environment as a parallel environment
    # シンプル・スピーカー・リスナー環境（並列）の定義
    env = simple_speaker_listener_v4.parallel_env(continuous_actions=True)
    env.reset()

    # Configure the multi-agent algo input arguments
    # マルチ・エージェントアルゴへの入力引数の設定
    try:
        state_dim = [env.observation_space(agent).n for agent in env.agents]
        one_hot = True
    except Exception:
        state_dim = [env.observation_space(agent).shape for agent in env.agents]
        one_hot = False
    try:
        action_dim = [env.action_space(agent).n for agent in env.agents]
        INIT_HP["DISCRETE_ACTIONS"] = True
        INIT_HP["MAX_ACTION"] = None
        INIT_HP["MIN_ACTION"] = None
    except Exception:
        action_dim = [env.action_space(agent).shape[0] for agent in env.agents]
        INIT_HP["DISCRETE_ACTIONS"] = False
        INIT_HP["MAX_ACTION"] = [env.action_space(agent).high for agent in env.agents]
        INIT_HP["MIN_ACTION"] = [env.action_space(agent).low for agent in env.agents]

    
    if False: # デバッグ出力
        print("state_dim", state_dim)
        print("one_hot", one_hot)
        print( 'INIT_HP["DISCRETE_ACTIONS"]', INIT_HP["DISCRETE_ACTIONS"])
        print( 'INIT_HP["MAX_ACTION"]', INIT_HP["MAX_ACTION"])
        print( 'INIT_HP["MIN_ACTION"]', INIT_HP["MIN_ACTION"])
    

    # Not applicable to MPE environments, used when images are used for observations (Atari environments)
    # MPE環境には適用されない。観測のために画像が使用される場合に使う（アタリ環境）
    if INIT_HP["CHANNELS_LAST"]:
        state_dim = [
            (state_dim[2], state_dim[0], state_dim[1]) for state_dim in state_dim
        ]

    # Append number of agents and agent IDs to the initial hyperparameter dictionary
    # エージェントとエージェントIDの数を、ハイパーパラメータディクショナリの初期化のために加える
    INIT_HP["N_AGENTS"] = env.num_agents
    INIT_HP["AGENT_IDS"] = env.agents

    
    if True: # デバッグ出力
        print('state_dim', state_dim)
        print('action_dim', action_dim)
        print('one_hot', one_hot)
        print('NET_CONFIG')
        pprint.pprint(NET_CONFIG)
        print('INIT_HP')
        pprint.pprint(INIT_HP)
        print('device', device)
    
    # Create a population ready for evolutionary hyper-parameter optimisation
    # 進化的なハイパーパラメータ最適化のための母集団を作成する
    population = initialPopulation(
        INIT_HP["ALGO"],
        state_dim,
        action_dim,
        one_hot,
        NET_CONFIG,
        INIT_HP,
        population_size=INIT_HP["POPULATION_SIZE"],
        device=device,
    )

    # Configure the multi-agent replay buffer
    # マルチ・エージェント・リプレイバッファを設定する
    field_names = ["state", "action", "reward", "next_state", "done"]
    memory = MultiAgentReplayBuffer(
        INIT_HP["MEMORY_SIZE"],
        field_names=field_names,
        agent_ids=INIT_HP["AGENT_IDS"],
        device=device,
    )

    # Instantiate a tournament selection object (used for HPO)
    # トーナメント選択オブジェクトのインスタンス化（HPOで使用）
    tournament = TournamentSelection(
        tournament_size=2,  # Tournament selection size
        elitism=True,  # Elitism in tournament selection
        population_size=INIT_HP["POPULATION_SIZE"],  # Population size
        evo_step=1,
    )  # Evaluate using last N fitness scores

    # Instantiate a mutations object (used for HPO)
    # ミューテーション・オブジェクトのインスタンス化（HPOで使用）
    mutations = Mutations(
        algo=INIT_HP["ALGO"],
        no_mutation=0.2,  # Probability of no mutation
        architecture=0.2,  # Probability of architecture mutation
        new_layer_prob=0.2,  # Probability of new layer mutation
        parameters=0.2,  # Probability of parameter mutation
        activation=0,  # Probability of activation function mutation
        rl_hp=0.2,  # Probability of RL hyperparameter mutation
        rl_hp_selection=[
            "lr",
            "learn_step",
            "batch_size",
        ],  # RL hyperparams selected for mutation
        mutation_sd=0.1,  # Mutation strength
        agent_ids=INIT_HP["AGENT_IDS"],
        arch=NET_CONFIG["arch"],
        rand_seed=2,#1,
        device=device,
    )

    # Define training loop parameters
    # 学習ループ・パラメータを定義
    max_episodes = 6000 #500  # Total episodes (default: 6000)
    max_steps = 100 #25  # Maximum steps to take in each episode
    epsilon = 1.0  # Starting epsilon value
    eps_end = 0.1  # Final epsilon value
    eps_decay = 0.995  # Epsilon decay
    evo_epochs = 20  # Evolution frequency
    evo_loop = 1  # Number of evaluation episodes
    elite = population[0]  # Assign a placeholder "elite" agent

    # Training loop
    # 学習ループ
    for idx_epi in trange(max_episodes):
        
        for agent in population:  # Loop through population
            
            state, info = env.reset()  # Reset environment at start of episode
            agent_reward = {agent_id: 0 for agent_id in env.agents}
            if INIT_HP["CHANNELS_LAST"]:
                state = {
                    agent_id: np.moveaxis(np.expand_dims(s, 0), [-1], [-3])
                    for agent_id, s in state.items()
                }

            for _ in range(max_steps):
                agent_mask = info["agent_mask"] if "agent_mask" in info.keys() else None
                env_defined_actions = (
                    info["env_defined_actions"]
                    if "env_defined_actions" in info.keys()
                    else None
                )

                # Get next action from agent
                # エージェントから次の行動を取得する
                cont_actions, discrete_action = agent.getAction(
                    state, epsilon, agent_mask, env_defined_actions
                )
                if agent.discrete_actions:
                    action = discrete_action
                else:
                    action = cont_actions

                next_state, reward, termination, truncation, info = env.step(
                    action
                )  # Act in environment

                # Image processing if necessary for the environment
                # 環境に応じた画像処理を行う
                if INIT_HP["CHANNELS_LAST"]:
                    state = {agent_id: np.squeeze(s) for agent_id, s in state.items()}
                    next_state = {
                        agent_id: np.moveaxis(ns, [-1], [-3])
                        for agent_id, ns in next_state.items()
                    }

                # Save experiences to replay buffer
                # 経験をリプレイバッファに保存する
                memory.save2memory(state, cont_actions, reward, next_state, termination)

                # Collect the reward
                # 報酬を受け取る
                for agent_id, r in reward.items():
                    agent_reward[agent_id] += r

                # Learn according to learning frequency
                # 学習周期に合わせて学習する
                if (memory.counter % agent.learn_step == 0) and (
                    len(memory) >= agent.batch_size
                ):
                    experiences = memory.sample(
                        agent.batch_size
                    )  # Sample replay buffer
                    agent.learn(experiences)  # Learn according to agent's RL algorithm

                # Update the state
                # 状態を更新する
                if INIT_HP["CHANNELS_LAST"]:
                    next_state = {
                        agent_id: np.expand_dims(ns, 0)
                        for agent_id, ns in next_state.items()
                    }
                state = next_state

                # Stop episode if any agents have terminated
                # いずれかのエージェントが終了したならば、エピソードを停止する
                if any(truncation.values()) or any(termination.values()):
                    break

            # Save the total episode reward
            # エピソードの合計報酬を保存する
            score = sum(agent_reward.values())
            agent.scores.append(score)

        # Update epsilon for exploration
        # 探索用のイプシロンを更新
        epsilon = max(eps_end, epsilon * eps_decay)

        # Now evolve population if necessary
        # 必要であれば、母集団を進化させる
        if (idx_epi + 1) % evo_epochs == 0:
            # Evaluate population
            fitnesses = [
                agent.test(
                    env,
                    swap_channels=INIT_HP["CHANNELS_LAST"],
                    max_steps=max_steps,
                    loop=evo_loop,
                )
                for agent in population
            ]

            print(f"Episode {idx_epi + 1}/{max_episodes}")
            print(f'Fitnesses: {["%.2f" % fitness for fitness in fitnesses]}')
            print(
                f'100 fitness avgs: {["%.2f" % np.mean(agent.fitness[-100:]) for agent in population]}'
            )

            # Tournament selection and population mutation
            # トーナメント選択と母集団の変異
            elite, population = tournament.select(population)
            population = mutations.mutation(population)

    # Save the trained algorithm
    # 学習アルゴリズムを保存する
    path = "./models/MATD3"
    #path = "./"+str_dt_now+"/models/MATD3"
    filename = "MATD3_trained_agent.pt"
    filename = f"MATD3_trained_agent_{str_dt_now}.pt"
    os.makedirs(path, exist_ok=True)
    save_path = os.path.join(path, filename)
    elite.saveCheckpoint(save_path)

20231226-0652
===== AgileRL Online Multi-Agent Demo =====
state_dim [(3,), (11,)]
action_dim [3, 5]
one_hot False
NET_CONFIG
{'arch': 'mlp', 'h_size': [32, 32]}
INIT_HP
{'AGENT_IDS': ['speaker_0', 'listener_0'],
 'ALGO': 'MATD3',
 'BATCH_SIZE': 32,
 'CHANNELS_LAST': False,
 'DISCRETE_ACTIONS': False,
 'GAMMA': 0.95,
 'LEARN_STEP': 5,
 'LR': 0.01,
 'MAX_ACTION': [array([1., 1., 1.], dtype=float32),
                array([1., 1., 1., 1., 1.], dtype=float32)],
 'MEMORY_SIZE': 100000,
 'MIN_ACTION': [array([0., 0., 0.], dtype=float32),
                array([0., 0., 0., 0., 0.], dtype=float32)],
 'N_AGENTS': 2,
 'POLICY_FREQ': 2,
 'POPULATION_SIZE': 4,
 'TAU': 0.01}
device cuda


  0%|▍                                                                                                                       | 19/6000 [00:10<54:25,  1.83it/s]

Episode 20/6000
Fitnesses: ['-12.26', '-483.91', '-371.99', '-252.56']
100 fitness avgs: ['-12.26', '-483.91', '-371.99', '-252.56']


  1%|▊                                                                                                                       | 39/6000 [00:21<55:46,  1.78it/s]

Episode 40/6000
Fitnesses: ['-485.64', '-191.07', '-127.47', '-408.97']
100 fitness avgs: ['-248.95', '-221.82', '-249.73', '-210.62']


  1%|█▏                                                                                                                      | 59/6000 [00:33<57:13,  1.73it/s]

Episode 60/6000
Fitnesses: ['-86.86', '-107.03', '-62.36', '-76.29']
100 fitness avgs: ['-195.44', '-176.09', '-161.20', '-165.84']


  1%|█▌                                                                                                                    | 79/6000 [00:45<1:01:32,  1.60it/s]

Episode 80/6000
Fitnesses: ['-2.06', '-44.94', '-6.49', '-1.50']
100 fitness avgs: ['-121.41', '-135.62', '-126.00', '-121.27']


  2%|█▉                                                                                                                    | 99/6000 [00:58<1:02:36,  1.57it/s]

Episode 100/6000
Fitnesses: ['-0.24', '-31.62', '-46.62', '-44.57']
100 fitness avgs: ['-97.06', '-103.34', '-106.45', '-117.41']


  2%|██▎                                                                                                                  | 119/6000 [01:12<1:09:52,  1.40it/s]

Episode 120/6000
Fitnesses: ['-85.59', '-194.45', '-297.91', '-212.42']
100 fitness avgs: ['-95.15', '-113.29', '-130.54', '-124.11']


  2%|██▋                                                                                                                  | 139/6000 [01:27<1:11:39,  1.36it/s]

Episode 140/6000
Fitnesses: ['-41.68', '-13.87', '-184.13', '-286.03']
100 fitness avgs: ['-87.51', '-99.09', '-107.86', '-152.75']


  3%|███                                                                                                                  | 159/6000 [01:42<1:10:02,  1.39it/s]

Episode 160/6000
Fitnesses: ['-2.02', '-314.99', '-600.70', '-207.60']
100 fitness avgs: ['-86.96', '-115.95', '-169.47', '-112.66']


  3%|███▌                                                                                                                   | 179/6000 [01:54<57:12,  1.70it/s]

Episode 180/6000
Fitnesses: ['-62.80', '-115.66', '-31.85', '-2.87']
100 fitness avgs: ['-84.27', '-90.15', '-80.83', '-77.61']


  3%|███▉                                                                                                                 | 199/6000 [02:07<1:00:16,  1.60it/s]

Episode 200/6000
Fitnesses: ['-13.88', '-46.35', '-4.56', '-27.00']
100 fitness avgs: ['-71.24', '-80.48', '-73.21', '-75.45']


  4%|████▎                                                                                                                | 219/6000 [02:22<1:10:55,  1.36it/s]

Episode 220/6000
Fitnesses: ['-107.45', '-5.80', '-74.92', '-46.98']
100 fitness avgs: ['-76.32', '-65.29', '-71.58', '-69.04']


  4%|████▋                                                                                                                | 239/6000 [02:38<1:15:42,  1.27it/s]

Episode 240/6000
Fitnesses: ['-29.59', '-45.41', '-51.60', '-17.65']
100 fitness avgs: ['-62.32', '-63.64', '-64.15', '-67.08']


  4%|█████                                                                                                                | 259/6000 [02:54<1:16:13,  1.26it/s]

Episode 260/6000
Fitnesses: ['-15.55', '-29.56', '-4.11', '-85.11']
100 fitness avgs: ['-63.12', '-64.20', '-57.84', '-64.07']


  5%|█████▍                                                                                                               | 279/6000 [03:10<1:16:06,  1.25it/s]

Episode 280/6000
Fitnesses: ['-45.62', '-92.23', '-16.05', '-38.45']
100 fitness avgs: ['-56.97', '-60.30', '-59.76', '-56.45']


  5%|█████▊                                                                                                               | 299/6000 [03:27<1:17:13,  1.23it/s]

Episode 300/6000
Fitnesses: ['-10.72', '-23.73', '-32.86', '-17.03']
100 fitness avgs: ['-56.49', '-57.35', '-54.88', '-56.91']


  5%|██████▏                                                                                                              | 319/6000 [03:43<1:17:19,  1.22it/s]

Episode 320/6000
Fitnesses: ['-23.80', '-86.51', '-23.22', '-63.66']
100 fitness avgs: ['-54.44', '-58.36', '-54.41', '-56.93']


  6%|██████▌                                                                                                              | 339/6000 [04:00<1:17:23,  1.22it/s]

Episode 340/6000
Fitnesses: ['-4.34', '-20.30', '-66.24', '-5.07']
100 fitness avgs: ['-51.46', '-54.78', '-55.10', '-51.50']


  6%|███████                                                                                                              | 359/6000 [04:17<1:17:26,  1.21it/s]

Episode 360/6000
Fitnesses: ['-31.05', '-95.75', '-73.58', '-25.95']
100 fitness avgs: ['-50.33', '-53.92', '-52.73', '-50.04']


  6%|███████▍                                                                                                             | 379/6000 [04:33<1:17:04,  1.22it/s]

Episode 380/6000
Fitnesses: ['-16.10', '-11.57', '-114.94', '-57.25']
100 fitness avgs: ['-48.26', '-50.56', '-53.46', '-50.42']


  7%|███████▊                                                                                                             | 399/6000 [04:50<1:17:29,  1.20it/s]

Episode 400/6000
Fitnesses: ['-41.73', '-19.92', '-17.93', '-7.20']
100 fitness avgs: ['-50.12', '-46.84', '-48.93', '-48.26']


  7%|████████▏                                                                                                            | 419/6000 [05:09<1:26:30,  1.08it/s]

Episode 420/6000
Fitnesses: ['-14.53', '-102.39', '-16.04', '-8.02']
100 fitness avgs: ['-46.66', '-51.48', '-46.73', '-44.99']


  7%|████████▌                                                                                                            | 439/6000 [05:28<1:25:36,  1.08it/s]

Episode 440/6000
Fitnesses: ['-261.26', '-46.16', '-140.81', '-72.44']
100 fitness avgs: ['-54.82', '-51.24', '-50.94', '-47.83']


  8%|████████▉                                                                                                            | 459/6000 [05:47<1:25:39,  1.08it/s]

Episode 460/6000
Fitnesses: ['-23.52', '-165.19', '-8.39', '-100.69']
100 fitness avgs: ['-50.03', '-52.93', '-49.37', '-53.39']


  8%|█████████▎                                                                                                           | 479/6000 [06:03<1:12:16,  1.27it/s]

Episode 480/6000
Fitnesses: ['-72.39', '-96.82', '-6.32', '-17.76']
100 fitness avgs: ['-50.33', '-55.20', '-47.58', '-48.06']


  8%|█████████▋                                                                                                           | 499/6000 [06:20<1:16:04,  1.21it/s]

Episode 500/6000
Fitnesses: ['-21.92', '-9.42', '-5.32', '-34.45']
100 fitness avgs: ['-46.55', '-46.51', '-46.35', '-47.05']


  9%|██████████                                                                                                           | 519/6000 [06:37<1:15:07,  1.22it/s]

Episode 520/6000
Fitnesses: ['-32.70', '-3.66', '-27.59', '-2.17']
100 fitness avgs: ['-45.82', '-44.90', '-45.82', '-44.81']


  9%|██████████▌                                                                                                          | 539/6000 [06:54<1:14:26,  1.22it/s]

Episode 540/6000
Fitnesses: ['-51.93', '-32.53', '-15.68', '-62.26']
100 fitness avgs: ['-45.07', '-44.35', '-43.73', '-46.43']


  9%|██████████▉                                                                                                          | 559/6000 [07:10<1:13:54,  1.23it/s]

Episode 560/6000
Fitnesses: ['-40.78', '-19.59', '-31.86', '-3.47']
100 fitness avgs: ['-43.62', '-43.47', '-43.90', '-42.89']


 10%|███████████▎                                                                                                         | 579/6000 [07:27<1:13:27,  1.23it/s]

Episode 580/6000
Fitnesses: ['-38.79', '-21.13', '-23.41', '-12.60']
100 fitness avgs: ['-42.75', '-42.70', '-42.77', '-42.82']


 10%|███████████▋                                                                                                         | 599/6000 [07:43<1:12:51,  1.24it/s]

Episode 600/6000
Fitnesses: ['-43.57', '-45.76', '-28.20', '-40.30']
100 fitness avgs: ['-42.85', '-42.87', '-42.34', '-42.74']


 10%|████████████                                                                                                         | 619/6000 [08:02<1:22:10,  1.09it/s]

Episode 620/6000
Fitnesses: ['-32.07', '-54.13', '-15.43', '-26.73']
100 fitness avgs: ['-42.01', '-43.11', '-41.47', '-42.22']


 11%|████████████▍                                                                                                        | 639/6000 [08:19<1:15:46,  1.18it/s]

Episode 640/6000
Fitnesses: ['-13.88', '-73.84', '-12.11', '-24.51']
100 fitness avgs: ['-40.61', '-43.00', '-40.55', '-41.67']


 11%|████████████▊                                                                                                        | 659/6000 [08:36<1:11:13,  1.25it/s]

Episode 660/6000
Fitnesses: ['-44.22', '-3.49', '-28.72', '-18.12']
100 fitness avgs: ['-40.66', '-39.43', '-40.19', '-39.87']


 11%|█████████████▏                                                                                                       | 679/6000 [08:52<1:10:09,  1.26it/s]

Episode 680/6000
Fitnesses: ['-54.30', '-42.80', '-54.97', '-24.79']
100 fitness avgs: ['-39.87', '-39.53', '-40.63', '-39.00']


 12%|█████████████▋                                                                                                       | 699/6000 [09:09<1:10:42,  1.25it/s]

Episode 700/6000
Fitnesses: ['-13.95', '-16.39', '-26.70', '-68.57']
100 fitness avgs: ['-38.28', '-38.35', '-39.16', '-39.84']


 12%|██████████████                                                                                                       | 719/6000 [09:25<1:10:11,  1.25it/s]

Episode 720/6000
Fitnesses: ['-29.15', '-30.47', '-36.35', '-16.67']
100 fitness avgs: ['-38.03', '-38.13', '-38.23', '-37.75']


 12%|██████████████▍                                                                                                      | 739/6000 [09:42<1:09:23,  1.26it/s]

Episode 740/6000
Fitnesses: ['-32.96', '-1.81', '-64.20', '-27.63']
100 fitness avgs: ['-37.62', '-37.05', '-38.46', '-37.48']


 13%|██████████████▊                                                                                                      | 759/6000 [09:58<1:05:43,  1.33it/s]

Episode 760/6000
Fitnesses: ['-4.78', '-10.55', '-61.30', '-18.46']
100 fitness avgs: ['-36.20', '-36.77', '-37.69', '-36.98']


 13%|███████████████▏                                                                                                     | 779/6000 [10:16<1:18:48,  1.10it/s]

Episode 780/6000
Fitnesses: ['-15.65', '-43.57', '-37.62', '-13.82']
100 fitness avgs: ['-35.67', '-37.14', '-36.99', '-35.63']


 13%|███████████████▌                                                                                                     | 799/6000 [10:34<1:17:15,  1.12it/s]

Episode 800/6000
Fitnesses: ['-72.88', '-2.35', '-18.86', '-28.71']
100 fitness avgs: ['-36.56', '-36.13', '-35.21', '-35.50']


 14%|███████████████▉                                                                                                     | 819/6000 [10:52<1:14:06,  1.17it/s]

Episode 820/6000
Fitnesses: ['-22.47', '-135.10', '-149.12', '-14.75']
100 fitness avgs: ['-35.79', '-38.54', '-38.88', '-34.99']


 14%|████████████████▎                                                                                                    | 839/6000 [11:13<1:31:21,  1.06s/it]

Episode 840/6000
Fitnesses: ['-35.77', '-15.87', '-0.80', '-93.99']
100 fitness avgs: ['-35.01', '-34.54', '-34.18', '-39.86']


 14%|████████████████▊                                                                                                    | 859/6000 [11:35<1:31:06,  1.06s/it]

Episode 860/6000
Fitnesses: ['-62.00', '-4.62', '-70.29', '-91.03']
100 fitness avgs: ['-34.83', '-33.49', '-35.83', '-41.05']


 15%|█████████████████▏                                                                                                   | 879/6000 [11:59<1:42:24,  1.20s/it]

Episode 880/6000
Fitnesses: ['-81.84', '-6.72', '-31.54', '-27.38']
100 fitness avgs: ['-34.59', '-32.88', '-33.45', '-34.66']


 15%|█████████████████▌                                                                                                   | 899/6000 [12:23<1:42:17,  1.20s/it]

Episode 900/6000
Fitnesses: ['-17.04', '-91.52', '-62.12', '-83.50']
100 fitness avgs: ['-32.53', '-34.19', '-33.53', '-34.56']


 15%|█████████████████▉                                                                                                   | 919/6000 [12:47<1:40:35,  1.19s/it]

Episode 920/6000
Fitnesses: ['-21.76', '-35.11', '-53.68', '-58.82']
100 fitness avgs: ['-32.30', '-32.59', '-34.61', '-33.10']


 16%|██████████████████▎                                                                                                  | 939/6000 [13:13<1:52:25,  1.33s/it]

Episode 940/6000
Fitnesses: ['-16.24', '-15.26', '-68.94', '-39.49']
100 fitness avgs: ['-31.96', '-32.22', '-33.36', '-32.73']


 16%|██████████████████▋                                                                                                  | 959/6000 [13:43<2:09:07,  1.54s/it]

Episode 960/6000
Fitnesses: ['-14.77', '-23.58', '-5.83', '-17.80']
100 fitness avgs: ['-31.86', '-32.54', '-31.67', '-31.92']


 16%|███████████████████                                                                                                  | 979/6000 [14:17<2:18:41,  1.66s/it]

Episode 980/6000
Fitnesses: ['-57.73', '-1.67', '-10.01', '-4.49']
100 fitness avgs: ['-32.20', '-31.30', '-31.41', '-31.30']


 17%|███████████████████▍                                                                                                 | 999/6000 [14:50<2:19:04,  1.67s/it]

Episode 1000/6000
Fitnesses: ['-87.41', '-19.53', '-6.71', '-24.07']
100 fitness avgs: ['-32.42', '-31.06', '-30.81', '-31.16']


 17%|███████████████████▋                                                                                                | 1019/6000 [15:24<2:22:17,  1.71s/it]

Episode 1020/6000
Fitnesses: ['-12.62', '-16.13', '-12.06', '-37.64']
100 fitness avgs: ['-30.45', '-30.52', '-30.78', '-31.19']


 17%|████████████████████                                                                                                | 1039/6000 [15:59<2:23:41,  1.74s/it]

Episode 1040/6000
Fitnesses: ['-12.29', '-98.34', '-7.99', '-27.54']
100 fitness avgs: ['-30.43', '-32.08', '-30.34', '-30.72']


 18%|████████████████████▍                                                                                               | 1059/6000 [16:33<2:22:48,  1.73s/it]

Episode 1060/6000
Fitnesses: ['-15.34', '-68.81', '-12.55', '-28.06']
100 fitness avgs: ['-30.06', '-31.07', '-30.01', '-30.67']


 18%|████████████████████▊                                                                                               | 1079/6000 [17:05<2:15:10,  1.65s/it]

Episode 1080/6000
Fitnesses: ['-6.95', '-9.54', '-54.70', '-7.60']
100 fitness avgs: ['-29.58', '-29.63', '-30.52', '-29.59']


 18%|█████████████████████▏                                                                                              | 1099/6000 [17:37<2:13:23,  1.63s/it]

Episode 1100/6000
Fitnesses: ['-16.79', '-14.18', '-25.06', '-29.77']
100 fitness avgs: ['-29.35', '-29.30', '-30.42', '-29.60']


 19%|█████████████████████▋                                                                                              | 1119/6000 [18:10<2:13:31,  1.64s/it]

Episode 1120/6000
Fitnesses: ['-24.88', '-76.30', '-87.65', '-43.93']
100 fitness avgs: ['-29.22', '-30.14', '-31.44', '-30.66']


 19%|██████████████████████                                                                                              | 1139/6000 [18:41<2:03:37,  1.53s/it]

Episode 1140/6000
Fitnesses: ['-46.62', '-25.81', '-13.43', '-70.88']
100 fitness avgs: ['-29.53', '-30.06', '-30.36', '-29.95']


 19%|██████████████████████▍                                                                                             | 1159/6000 [19:20<2:41:13,  2.00s/it]

Episode 1160/6000
Fitnesses: ['-18.11', '-60.28', '-40.04', '-36.03']
100 fitness avgs: ['-30.15', '-30.87', '-30.52', '-29.64']


 20%|██████████████████████▊                                                                                             | 1179/6000 [20:09<3:16:40,  2.45s/it]

Episode 1180/6000
Fitnesses: ['-71.76', '-12.21', '-5.56', '-11.76']
100 fitness avgs: ['-30.85', '-30.21', '-29.73', '-29.83']


 20%|███████████████████████▏                                                                                            | 1199/6000 [20:57<3:21:39,  2.52s/it]

Episode 1200/6000
Fitnesses: ['-10.47', '-16.45', '-8.86', '-5.73']
100 fitness avgs: ['-29.41', '-29.51', '-29.48', '-30.43']


 20%|███████████████████████▌                                                                                            | 1219/6000 [21:45<3:09:49,  2.38s/it]

Episode 1220/6000
Fitnesses: ['-34.35', '-9.45', '-4.34', '-57.00']
100 fitness avgs: ['-30.50', '-29.18', '-29.07', '-29.86']


In [None]:
import os

import imageio
import numpy as np
import torch
from pettingzoo.mpe import simple_speaker_listener_v4
from PIL import Image, ImageDraw

from agilerl.algorithms.matd3 import MATD3


# Define function to return image
def _label_with_episode_number(frame, episode_num):
    im = Image.fromarray(frame)

    drawer = ImageDraw.Draw(im)

    if np.mean(frame) < 128:
        text_color = (255, 255, 255)
    else:
        text_color = (0, 0, 0)
    drawer.text(
        (im.size[0] / 20, im.size[1] / 18), f"Episode: {episode_num+1}", fill=text_color
    )

    return im


if __name__ == "__main__":
    
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Configure the environment
    env = simple_speaker_listener_v4.parallel_env(
        continuous_actions=True, render_mode="rgb_array"
    )
    env.reset()
    try:
        state_dim = [env.observation_space(agent).n for agent in env.agents]
        one_hot = True
    except Exception:
        state_dim = [env.observation_space(agent).shape for agent in env.agents]
        one_hot = False
    try:
        action_dim = [env.action_space(agent).n for agent in env.agents]
        discrete_actions = True
        max_action = None
        min_action = None
    except Exception:
        action_dim = [env.action_space(agent).shape[0] for agent in env.agents]
        discrete_actions = False
        max_action = [env.action_space(agent).high for agent in env.agents]
        min_action = [env.action_space(agent).low for agent in env.agents]

    # Append number of agents and agent IDs to the initial hyperparameter dictionary
    n_agents = env.num_agents
    agent_ids = env.agents

    # Instantiate an MADDPG object
    matd3 = MATD3(
        state_dim,
        action_dim,
        one_hot,
        n_agents,
        agent_ids,
        max_action,
        min_action,
        discrete_actions,
        device=device,
    )

    # Load the saved algorithm into the MADDPG object
    #path = "./models/MATD3/MATD3_trained_agent.pt"
    #path = "./"+str_dt_now+"/models/MATD3/MATD3_trained_agent.pt"
    path = f"./models/MATD3/{filename}"
    matd3.loadCheckpoint(path)

    # Define test loop parameters
    episodes = 10  # Number of episodes to test agent on
    max_steps = 25  # Max number of steps to take in the environment in each episode

    rewards = []  # List to collect total episodic reward
    frames = []  # List to collect frames
    indi_agent_rewards = {
        agent_id: [] for agent_id in agent_ids
    }  # Dictionary to collect inidivdual agent rewards
    

    # Test loop for inference
    for ep in range(episodes):
        state, info = env.reset()
        agent_reward = {agent_id: 0 for agent_id in agent_ids}
        score = 0
        for _ in range(max_steps):
            agent_mask = info["agent_mask"] if "agent_mask" in info.keys() else None
            env_defined_actions = (
                info["env_defined_actions"]
                if "env_defined_actions" in info.keys()
                else None
            )

            # Get next action from agent
            cont_actions, discrete_action = matd3.getAction(
                state,
                epsilon=0,
                agent_mask=agent_mask,
                env_defined_actions=env_defined_actions,
            )
            if matd3.discrete_actions:
                action = discrete_action
            else:
                action = cont_actions

            # Save the frame for this step and append to frames list
            frame = env.render()
            frames.append(_label_with_episode_number(frame, episode_num=ep))

            # Take action in environment
            state, reward, termination, truncation, info = env.step(action)

            # Save agent's reward for this step in this episode
            for agent_id, r in reward.items():
                agent_reward[agent_id] += r

            # Determine total score for the episode and then append to rewards list
            score = sum(agent_reward.values())

            # Stop episode if any agents have terminated
            if any(truncation.values()) or any(termination.values()):
                break

        rewards.append(score)

        # Record agent specific episodic reward
        for agent_id in agent_ids:
            indi_agent_rewards[agent_id].append(agent_reward[agent_id])

        print("-" * 15, f"Episode: {ep}", "-" * 15)
        print("Episodic Reward: ", rewards[-1])
        for agent_id, reward_list in indi_agent_rewards.items():
            print(f"{agent_id} reward: {reward_list[-1]}")
    env.close()

    # Save the gif to specified path
    gif_path = "./videos/"
    print(os.path.join(gif_path, f"speaker_listener_{str_dt_now}.gif"))
    os.makedirs(gif_path, exist_ok=True)
    imageio.mimwrite(
        #os.path.join("./videos/", "speaker_listener.gif"), frames, duration=10
        os.path.join(gif_path, f"speaker_listener_{str_dt_now}.gif"), frames, duration=10
    )

In [2]:
import os
import imageio
import numpy as np
import torch
import matplotlib.pyplot as plt
from pettingzoo.mpe import simple_speaker_listener_v4
from PIL import Image, ImageDraw
from agilerl.algorithms.matd3 import MATD3

# Define function to return image with episode number label
def _label_with_episode_number(frame, episode_num):
    im = Image.fromarray(frame)
    drawer = ImageDraw.Draw(im)

    if np.mean(frame) < 128:
        text_color = (255, 255, 255)
    else:
        text_color = (0, 0, 0)
    drawer.text(
        (im.size[0] / 20, im.size[1] / 18), f"Episode: {episode_num+1}", fill=text_color
    )

    return im

if __name__ == "__main__":
    
    #device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    device = torch.device("mps")

    # Configure the environment
    env = simple_speaker_listener_v4.parallel_env(
        continuous_actions=True, render_mode="rgb_array"
    )
    env.reset()

    try:
        state_dim = [env.observation_space(agent).n for agent in env.agents]
        one_hot = True
    except Exception:
        state_dim = [env.observation_space(agent).shape for agent in env.agents]
        one_hot = False
    try:
        action_dim = [env.action_space(agent).n for agent in env.agents]
        discrete_actions = True
        max_action = None
        min_action = None
    except Exception:
        action_dim = [env.action_space(agent).shape[0] for agent in env.agents]
        discrete_actions = False
        max_action = [env.action_space(agent).high for agent in env.agents]
        min_action = [env.action_space(agent).low for agent in env.agents]

    # Append number of agents and agent IDs to the initial hyperparameter dictionary
    n_agents = env.num_agents
    agent_ids = env.agents

    # Instantiate an MADDPG object
    matd3 = MATD3(
        state_dim,
        action_dim,
        one_hot,
        n_agents,
        agent_ids,
        max_action,
        min_action,
        discrete_actions,
        device=device,
    )

    # Load the saved algorithm
    #path = f"./models/MATD3/{filename}"
    path = "./models/MATD3/MATD3_trained_agent_20231225-1958.pt"
    matd3.loadCheckpoint(path)

    episodes = 10  # Number of episodes for testing
    max_steps = 25  # Max steps per episode

    rewards = []  # Total episodic rewards
    frames = []  # Frames for visualization
    indi_agent_rewards = {agent_id: [] for agent_id in agent_ids}  # Individual agent rewards

    speaker_0_rewards = []  # Rewards for speaker_0
    listener_0_rewards = []  # Rewards for listener_0

    # Test loop
    for ep in range(episodes):
        state, info = env.reset()
        agent_reward = {agent_id: 0 for agent_id in agent_ids}
        score = 0

        for _ in range(max_steps):
            # ... [existing action selection and environment interaction code] ...
            agent_mask = info["agent_mask"] if "agent_mask" in info.keys() else None
            env_defined_actions = (
                info["env_defined_actions"]
                if "env_defined_actions" in info.keys()
                else None
            )

            # Get next action from agent
            cont_actions, discrete_action = matd3.getAction(
                state,
                epsilon=0,
                agent_mask=agent_mask,
                env_defined_actions=env_defined_actions,
            )
            if matd3.discrete_actions:
                action = discrete_action
            else:
                action = cont_actions

            frame = env.render()
            frames.append(_label_with_episode_number(frame, episode_num=ep))

            # ... [existing environment step and reward processing code] ...
            # Take action in environment
            state, reward, termination, truncation, info = env.step(action)

            # Save agent's reward for this step in this episode
            for agent_id, r in reward.items():
                agent_reward[agent_id] += r

            # Determine total score for the episode and then append to rewards list
            score = sum(agent_reward.values())

            # Stop episode if any agents have terminated
            if any(truncation.values()) or any(termination.values()):
                break

        rewards.append(score)
        for agent_id in agent_ids:
            indi_agent_rewards[agent_id].append(agent_reward[agent_id])
            
        # Store rewards for each agent
        Episodic_rewards.append(indi_agent_rewards['speaker_0'][-1])

        print("-" * 15, f"Episode: {ep}", "-" * 15)
        print("Episodic Reward: ", rewards[-1])
        for agent_id, reward_list in indi_agent_rewards.items():
            print(f"{agent_id} reward: {reward_list[-1]}")

    env.close()

    # Save the visualization
    gif_path = "./videos/"
    os.makedirs(gif_path, exist_ok=True)
    imageio.mimwrite(os.path.join(gif_path, f"speaker_listener_{str_dt_now}.gif"), frames, duration=10)

    # Plotting rewards
    plt.figure(figsize=(12, 6))
    plt.plot(Episodic_rewards, label='Episodic_rewards')
    plt.xlabel('Episode')
    plt.ylabel('Reward')
    plt.title('Episodic Rewards for Speaker and Listener Agents')
    plt.legend()
    plt.show()



RuntimeError: Attempting to deserialize object on a CUDA device but torch.cuda.is_available() is False. If you are running on a CPU-only machine, please use torch.load with map_location=torch.device('cpu') to map your storages to the CPU.