In [19]:
%pip install stable-baselines3 numpy torch supersuit pettingzoo pymunk scipy gymnasium matplotlib einops tensorboard wandb imageio 

^C
Note: you may need to restart the kernel to use updated packages.


In [10]:
import datetime
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import random
from collections import deque
from typing import List, Dict, NamedTuple
import gymnasium as gym
from pettingzoo.mpe import simple_reference_v3
import os
from torch.utils.tensorboard import SummaryWriter
from argparse import Namespace
import time
import sys
import json


class WritterUtil:
    """Thin logging helper around a TensorBoard ``SummaryWriter``.

    Every value is written immediately under ``raw/<tag>``. Values are also
    buffered per tag and periodically flushed as a mean (``mean/<tag>``) and
    a histogram (``<tag>``), roughly once every ``args.log_every`` calls to
    :meth:`set_step`.
    """

    def __init__(self, writter: "SummaryWriter", args):
        self.writter = writter
        self.log_every = args.log_every
        # tag -> [number of buffered values, step of last flush, buffer tensor]
        self.scalars = {}
        self.step = 0      # number of set_step() calls so far
        self.log_step = 0  # x-axis value handed to the writer

    def set_step(self, step):
        """Advance the internal call counter and set the logging x-axis value."""
        self.step += 1
        self.log_step = step

    def WriteScalar(self, tag, value):
        """Log ``value`` under ``tag`` and flush the tag's buffer when due.

        The buffer is flushed when at least ``log_every`` set_step() calls
        have passed since the last flush, or when the buffer is full. The
        second condition prevents an IndexError when a tag is written more
        than ``log_every`` times without the step advancing far enough.
        """
        step = self.step
        log_step = self.log_step
        self.writter.add_scalar('raw/' + tag, value, log_step)

        if tag not in self.scalars:
            # At least one slot so that log_every <= 0 cannot break indexing.
            capacity = max(int(self.log_every), 1)
            self.scalars[tag] = [0, step - 1, torch.zeros(capacity)]

        n, last_step, buffer = self.scalars[tag]

        buffer[n] = value
        n += 1
        if step - last_step >= self.log_every or n >= buffer.numel():
            window = buffer[0:n]
            self.writter.add_scalar('mean/' + tag, torch.mean(window), log_step)
            self.writter.add_histogram(tag, window, log_step)
            last_step = step
            n = 0

        self.scalars[tag] = [n, last_step, buffer]

class Controller:
    """Selects per-agent actions with a frozen copy of the learner's actor.

    The copy (``sys_agent``) is re-synchronised from ``sys_agent_src`` at the
    start of every episode. The controller also carries the recurrent hidden
    state and the previous joint action (one-hot) between calls, because both
    are part of the actor's input.
    """

    def __init__(self, sys_agent, args):
        self.n_agents = args.n_agents
        self.n_actions = args.n_actions
        self.agent_id = args.agent_id
        self.last_action = args.last_action
        self.device = args.device
        self.sys_agent_src = sys_agent
        # A new module of the same class, placed on the controller's device so
        # that its weights and the inputs built in get_actions() agree.
        self.sys_agent = type(sys_agent)(args).to(self.device)
        self.sys_agent.requires_grad_(False)
        self.episode = 0

    def new_episode(self):
        """Copy the latest learner weights and reset recurrent/last-action state."""
        self.sys_agent.load_state_dict(self.sys_agent_src.state_dict())
        self.hiddens = self.sys_agent.init_hiddens(1)
        self.last_actions = torch.zeros(1, self.n_agents, self.n_actions, device=self.device)
        self.episode += 1

    def get_actions(self, states, avail_actions, explore=False):
        """Return one action index per agent as a numpy array.

        Args:
            states: per-agent observations, shape (n_agents, obs_dim) or
                (1, n_agents, obs_dim).
            avail_actions: 0/1 availability mask, shape (n_agents, n_actions)
                or (1, n_agents, n_actions).
            explore: sample from the policy if True, otherwise take the argmax.
        """
        # np.asarray first: building a tensor directly from a list of ndarrays is slow.
        states = torch.as_tensor(np.asarray(states, dtype=np.float32), device=self.device)
        if states.dim() == 2:
            states = states.unsqueeze(0)

        avail_actions = torch.as_tensor(np.asarray(avail_actions), dtype=torch.float32, device=self.device)
        if avail_actions.dim() == 2:
            avail_actions = avail_actions.unsqueeze(0)

        if self.agent_id:
            agent_ids = torch.eye(self.n_agents, device=self.device)
            agent_ids = agent_ids.unsqueeze(0).expand(states.size(0), -1, -1)
            states = torch.cat([states, agent_ids], dim=-1)

        if self.last_action:
            states = torch.cat([states, self.last_actions], -1)

        with torch.no_grad():
            ps, hs_next = self.sys_agent.forward(states, avail_actions, self.hiddens)
        self.hiddens = hs_next

        if explore:
            agent_idx = torch.arange(self.n_agents, device=self.device)
            # Resample until every chosen action is available (normally the
            # first draw, since unavailable actions get probability 0).
            while True:
                actions = torch.multinomial(ps[0], 1).squeeze(-1)
                chosen_avail = avail_actions[0][agent_idx, actions]
                if not (chosen_avail == 0).any():
                    break
        else:
            actions = torch.argmax(ps[0], -1)

        self.last_actions = self.one_hot(actions, self.n_actions).unsqueeze(0)
        return actions.cpu().numpy()

    def one_hot(self, tensor, n_classes):
        """Float32 one-hot encoding of an integer tensor."""
        return F.one_hot(tensor.to(dtype=torch.int64), n_classes).to(dtype=torch.float32)



class Actor(nn.Module):
    """Recurrent policy for a single agent.

    Pipeline: Linear -> ReLU -> GRUCell -> Linear -> masked softmax.
    The submodule names ``fc1``, ``rnn`` and ``fc2`` are the state_dict keys.
    """

    def __init__(self, args):
        super().__init__()
        self.args = args

        self.input_dim = args.input_dim
        self.n_actions = args.n_actions
        self.n_agents = args.n_agents
        self.rnn_hidden_dim = args.rnn_hidden_dim

        width = self.rnn_hidden_dim
        self.fc1 = nn.Linear(self.input_dim, width)
        self.rnn = nn.GRUCell(width, width)
        self.fc2 = nn.Linear(width, self.n_actions)

    def init_hidden(self, n_batch):
        """All-zero hidden state (n_batch, rnn_hidden_dim) on the weights' device and dtype."""
        return self.fc1.weight.new_zeros(n_batch, self.rnn_hidden_dim)

    def forward(self, inputs, avail_actions, h_last):
        """Return (action probabilities, next hidden state).

        Logits of actions whose ``avail_actions`` entry is 0 are set to -inf,
        so those actions receive probability exactly 0 after the softmax.
        """
        embedded = F.relu(self.fc1(inputs))
        h_next = self.rnn(embedded, h_last)
        logits = self.fc2(h_next).masked_fill(avail_actions == 0, -float('inf'))
        return torch.softmax(logits, -1), h_next
    
class Actors(nn.Module):
    """Runs one shared ``Actor`` for every agent by merging the agent axis into the batch axis."""

    def __init__(self, args):
        super().__init__()
        self.n_agents = args.n_agents
        self.n_actions = args.n_actions
        self.agent = Actor(args)

    def init_hiddens(self, n_batch):
        """Zero hidden states shaped (n_batch, n_agents, hidden)."""
        flat = self.agent.init_hidden(n_batch * self.n_agents)
        return flat.reshape(n_batch, self.n_agents, -1)

    def forward(self, states, avail_actions, hiddens):
        """Evaluate the shared actor on (n_batch, n_agents, ...) inputs.

        Returns:
            (probabilities, next hidden states), both reshaped to
            (n_batch, n_agents, -1).
        """
        n_batch = states.shape[0]

        def _merge(t):
            # (n_batch, n_agents, d) -> (n_batch * n_agents, d)
            return t.reshape(-1, t.shape[-1])

        def _split(t):
            return t.reshape(n_batch, self.n_agents, -1)

        probs, h_next = self.agent.forward(_merge(states), _merge(avail_actions), _merge(hiddens))
        return _split(probs), _split(h_next)

class Critic(nn.Module):
    """Recurrent per-agent Q network: Linear -> ReLU -> GRUCell -> Linear.

    The submodule names ``fc1``, ``rnn`` and ``fc2`` are the state_dict keys.
    """

    def __init__(self, args):
        super().__init__()
        self.args = args

        self.input_dim = args.input_dim
        self.n_actions = args.n_actions
        self.n_agents = args.n_agents
        self.rnn_hidden_dim = args.rnn_hidden_dim

        width = self.rnn_hidden_dim
        self.fc1 = nn.Linear(self.input_dim, width)
        self.rnn = nn.GRUCell(width, width)
        self.fc2 = nn.Linear(width, self.n_actions)

    def init_hidden(self, n_batch):
        """All-zero hidden state (n_batch, rnn_hidden_dim) on the weights' device and dtype."""
        return self.fc1.weight.new_zeros(n_batch, self.rnn_hidden_dim)

    def forward(self, inputs, avail_actions, h_last):
        """Return (Q value per action, next hidden state).

        Unavailable actions get the large finite value -1e38 rather than -inf.
        Multiplying such an entry by a zero probability therefore gives 0, not NaN.
        """
        embedded = F.relu(self.fc1(inputs))
        h_next = self.rnn(embedded, h_last)
        q = self.fc2(h_next).masked_fill(avail_actions == 0, -1e38)
        return q, h_next

class Critics(nn.Module):
    """Runs one shared ``Critic`` for every agent by merging the agent axis into the batch axis."""

    def __init__(self, args):
        super().__init__()
        self.n_agents = args.n_agents
        self.n_actions = args.n_actions
        self.agent = Critic(args)

    def init_hiddens(self, n_batch):
        """Zero hidden states shaped (n_batch, n_agents, hidden)."""
        return self.agent.init_hidden(n_batch * self.n_agents).reshape(n_batch, self.n_agents, -1)

    def forward(self, states, avail_actions, hiddens):
        """Evaluate the shared critic on (n_batch, n_agents, ...) inputs.

        Returns:
            (Q values, next hidden states), both reshaped to (n_batch, n_agents, -1).
        """
        n_batch = states.shape[0]
        flat_states, flat_avail, flat_hidden = (
            t.reshape(-1, t.shape[-1]) for t in (states, avail_actions, hiddens)
        )
        q_flat, h_flat = self.agent.forward(flat_states, flat_avail, flat_hidden)
        return (q_flat.reshape(n_batch, self.n_agents, -1),
                h_flat.reshape(n_batch, self.n_agents, -1))



class EpisodeBuffer:
    """Fixed-capacity FIFO replay buffer of whole, zero-padded episodes.

    Each scheme key gets one preallocated tensor shaped
    (buffer_size, episode_limit, *scheme[key]['shape']). Logical episode i
    (0 = oldest) lives in physical slot (index_st + i) % buffer_size. Once the
    buffer is full, adding an episode overwrites the oldest one.
    """

    def __init__(self, scheme, args):
        self.scheme = scheme.copy()
        self.buffer_size = args.buffer_size
        self.max_seq_length = args.episode_limit
        # Storage is pinned to the CPU regardless of args.device.
        self.device = 'cpu'
        self._setup_data()

    def _setup_data(self):
        """Reset the ring pointers and allocate zeroed storage for every key."""
        self.index_st = 0
        self.n_sample = 0
        self.data = {
            key: torch.zeros((self.buffer_size, self.max_seq_length) + spec['shape'],
                             dtype=spec['dtype'], device=self.device)
            for key, spec in self.scheme.items()
        }

    def sample(self, batch_size, top_n = 0):
        """Return a dict of ``batch_size`` distinct stored episodes, or None.

        ``batch_size - top_n`` episodes are drawn uniformly without replacement
        from all but the newest ``top_n``. The newest ``top_n`` are then
        appended, newest first. Returns None while fewer than ``batch_size``
        episodes are stored.
        """
        if self.n_sample < batch_size:
            return None

        drawn = np.random.choice(self.n_sample - top_n, batch_size - top_n, replace=False)
        newest = [self.n_sample - 1 - k for k in range(top_n)]
        slots = [(int(i) + self.index_st) % self.buffer_size for i in list(drawn) + newest]
        return {key: tensor[slots] for key, tensor in self.data.items()}

    def add_episode(self, data):
        """Store one episode; each value is a sequence no longer than episode_limit."""
        if self.n_sample < self.buffer_size:
            self.n_sample += 1
        else:
            # Full: move the start pointer forward; the slot it leaves is reused below.
            self.index_st = (self.index_st + 1) % self.buffer_size

        slot = (self.index_st + self.n_sample - 1) % self.buffer_size
        for key, values in data.items():
            row = self.data[key][slot]
            row.zero_()  # clear leftovers from the episode previously in this slot
            row[:len(values)] = torch.as_tensor(values, dtype=self.scheme[key]['dtype'], device=self.device)

    def clear(self):
        """Forget all episodes and zero the storage in place."""
        self.index_st = 0
        self.n_sample = 0
        for tensor in self.data.values():
            tensor.zero_()

    def state_dict(self):
        """Ring pointers plus the storage dict (shared, not copied)."""
        return {'index_st': self.index_st, 'n_sample': self.n_sample, 'data': self.data}

    def load_state_dict(self, state_dict):
        """Adopt the pointers and storage produced by :meth:`state_dict`."""
        self.index_st = state_dict['index_st']
        self.n_sample = state_dict['n_sample']
        self.data = state_dict['data']

class QMixNet(nn.Module):
    """QMIX mixing network: combines per-agent values into one joint value.

    Hypernetworks map the global state to the weights and biases of a small
    two-layer mixer. Absolute values are taken on the mixer weights, so the
    joint value is monotonically non-decreasing in every agent's value. Each
    hypernetwork outputs a flat vector, which is reshaped into the required
    weight matrix.
    """

    def __init__(self, args):
        super().__init__()
        self.args = args

        state_dim = args.state_dim
        mix_dim = args.qmix_hidden_dim
        self.input_dim = state_dim

        # First mixer layer weights: flat (n_agents * mix_dim), viewed as (n_agents, mix_dim).
        self.hyper_w1 = nn.Linear(state_dim, args.n_agents * mix_dim)
        # Second mixer layer weights: (mix_dim, 1).
        self.hyper_w2 = nn.Linear(state_dim, mix_dim * 1)
        # First mixer layer bias: (mix_dim,).
        self.hyper_b1 = nn.Linear(state_dim, mix_dim)
        # Second mixer layer bias: a scalar, produced by a small MLP.
        self.hyper_b2 = nn.Sequential(nn.Linear(state_dim, mix_dim),
                                      nn.ReLU(),
                                      nn.Linear(mix_dim, 1))

    def forward(self, q_values, states):
        """Mix per-agent values conditioned on the global state.

        Args:
            q_values: per-agent values; viewed as (-1, 1, n_agents), so the
                trailing dimension must be n_agents (e.g. (n_batch, T, n_agents)).
            states: global states; reshaped to (-1, state_dim) and must hold the
                same number of rows as q_values.

        Returns:
            Joint values shaped (q_values.size(0), -1, 1).
        """
        n_agents = self.args.n_agents
        mix_dim = self.args.qmix_hidden_dim
        leading = q_values.size(0)

        agent_qs = q_values.view(-1, 1, n_agents)      # (rows, 1, n_agents)
        flat_states = states.reshape(-1, self.input_dim)  # (rows, state_dim)

        w1 = torch.abs(self.hyper_w1(flat_states)).view(-1, n_agents, mix_dim)
        b1 = self.hyper_b1(flat_states).view(-1, 1, mix_dim)
        hidden = F.elu(torch.bmm(agent_qs, w1) + b1)    # (rows, 1, mix_dim)

        w2 = torch.abs(self.hyper_w2(flat_states)).view(-1, mix_dim, 1)
        b2 = self.hyper_b2(flat_states).view(-1, 1, 1)
        joint = torch.bmm(hidden, w2) + b2              # (rows, 1, 1)
        return joint.view(leading, -1, 1)

class Learner:
    """Trainer for a shared recurrent actor with twin critics mixed by QMIX.

    Components:
      - ``sys_actor``: shared per-agent policy (Actors).
      - ``sys_critic1/2``: two per-agent Q networks (Critics), each with a
        target copy ``*_tar``.
      - ``mix_net1/2``: QMIX mixers for the two critics, each with a target copy.
      - ``log_alpha``: learnable entropy temperature. It is a scalar when
        ``args.shared_alpha`` is true, otherwise one entry per agent.

    The design is soft actor-critic style: an entropy bonus weighted by a
    learned temperature, and the minimum over twin critics. ``train`` does one
    actor, one critic and one temperature update per call.
    """
    def __init__(self, w_util, args):

        self.w_util = w_util
        self.device = args.device
        self.critic_attn = args.critic_attn

        # NOTE: critic_attn is only stored here; this class always builds the
        # Critics class defined in this file (the alternative import was disabled).

        self.sys_actor = Actors(args)
        self.sys_critic1 = Critics(args)
        self.sys_critic2 = Critics(args)
        self.sys_critic1_tar = Critics(args)
        self.sys_critic2_tar = Critics(args)

        self.mix_net1 = QMixNet(args)
        self.mix_net2 = QMixNet(args)
        self.mix_net1_tar = QMixNet(args)
        self.mix_net2_tar = QMixNet(args)

        if self.device != 'cpu':
            self.sys_actor.cuda(self.device)
            self.sys_critic1.cuda(self.device)
            self.sys_critic2.cuda(self.device)
            self.sys_critic1_tar.cuda(self.device)
            self.sys_critic2_tar.cuda(self.device)
            self.mix_net1.cuda(self.device)
            self.mix_net2.cuda(self.device)
            self.mix_net1_tar.cuda(self.device)
            self.mix_net2_tar.cuda(self.device)

        # Target networks are only ever updated by copying (see update_target),
        # never by gradients.
        self.sys_critic1_tar.requires_grad_(False)
        self.sys_critic2_tar.requires_grad_(False)
        self.mix_net1_tar.requires_grad_(False)
        self.mix_net2_tar.requires_grad_(False)



        # Start the targets as exact copies of the online networks.
        self._sync_target()


        self.n_agents = args.n_agents
        self.n_actions = args.n_actions
        self.gamma = args.gamma
        self.entropy_tar = args.entropy_tar
        self.lr = args.lr
        self.lr_actor = args.lr_actor
        self.lr_alpha = args.lr_alpha
        self.l2 = args.l2
        self.target_update = args.target_update
        self.step = 0
        self.agent_id = args.agent_id
        self.last_action = args.last_action
        self.log_alpha_st = args.log_alpha_st
        self.shared_alpha = args.shared_alpha
        self.maximum_entropy = args.maximum_entropy
        self.args = args


        # Temperature is optimised in log space.
        if self.shared_alpha:
            self.log_alpha = torch.tensor(args.log_alpha_st, dtype=torch.float32, requires_grad=True, device = self.device)
        else:
            self.log_alpha = torch.tensor([args.log_alpha_st]*self.n_agents, dtype=torch.float32, requires_grad=True, device = self.device)

        # One optimiser covers both critics and both mixers.
        params_critic = list(self.sys_critic1.parameters()) + list(self.sys_critic2.parameters()) + list(self.mix_net1.parameters()) + list(self.mix_net2.parameters())

        self.params_critic = params_critic
        self.params_actor = list(self.sys_actor.parameters())


        self.optim_actor = torch.optim.Adam(self.sys_actor.parameters(), lr = self.lr_actor, weight_decay=self.l2)
        self.optim_critic = torch.optim.Adam(params_critic, lr = self.lr, weight_decay=self.l2)
        self.optim_alpha = torch.optim.Adam([self.log_alpha], lr=self.lr_alpha)

    def train(self, data):
        """Run one actor, one critic and one temperature update on a batch.

        Args:
            data: dict of tensors with keys 'state', 'obs', 'actions',
                'reward', 'valid', 'avail_actions'. Dim 0 is the episode batch
                and dim 1 is time (see n_batch / T below).
                NOTE(review): per-key trailing shapes are assumed to follow the
                buffer scheme built by the caller — confirm if it changes.

        Returns:
            float: half of the summed MSE losses of the two mixed critics.
        """

        self.step += 1
        w_util = self.w_util

        state = data['state'].to(device=self.device, non_blocking=True)
        obs = data['obs'].to(device=self.device, non_blocking=True)
        actions = data['actions'].to(device=self.device, non_blocking=True)
        reward = data['reward'].to(device=self.device, non_blocking=True)
        valid = data['valid'].to(device=self.device, non_blocking=True)
        avail_actions = data['avail_actions'].to(device=self.device, non_blocking=True)
        actions_onehot = self.one_hot(actions,self.n_actions)

        n_batch = obs.shape[0]
        T = obs.shape[1]
        # Temperature treated as a constant in the actor/critic losses;
        # log_alpha has its own loss further below.
        alpha = torch.exp(self.log_alpha.detach())
        # Fraction of valid (non-padded) entries. The means below run over all
        # entries (padded ones are zeroed), so dividing by this renormalises
        # them to an average over valid steps only.
        valid_rate = torch.mean(valid.float())

        # Build the same per-agent input as the controller: observation,
        # optional one-hot agent id, optional previous one-hot action.
        if self.agent_id:
            agent_ids = torch.eye(self.n_agents,device=obs.device)
            agent_ids = agent_ids.reshape((1,)*(obs.ndim-2)+agent_ids.shape)
            agent_ids = agent_ids.expand(obs.shape[:-2]+(-1,-1))
            obs = torch.cat([obs,agent_ids],-1)
        if self.last_action:
            # Previous step's action; zeros at t = 0.
            last_actions = torch.zeros_like(actions_onehot)
            last_actions[:,1:] = actions_onehot[:,:-1]
            obs = torch.cat([obs,last_actions],-1)

        hiddens1 = self.sys_critic1.init_hiddens(n_batch)
        hiddens1_tar = self.sys_critic1_tar.init_hiddens(n_batch)
        hiddens2 = self.sys_critic2.init_hiddens(n_batch)
        hiddens2_tar = self.sys_critic2_tar.init_hiddens(n_batch)
        hiddens_actor = self.sys_actor.init_hiddens(n_batch)


        Q1s = []
        Q2s = []
        Q1s_tar = []
        Q2s_tar = []
        ps = []



        # Unroll every recurrent network over the T time steps.
        for i in range(T):

            Q1, hiddens1 = self.sys_critic1.forward(obs[:,i], avail_actions[:,i], hiddens1)
            Q2, hiddens2 = self.sys_critic2.forward(obs[:,i], avail_actions[:,i], hiddens2)
            p, hiddens_actor = self.sys_actor.forward(obs[:,i], avail_actions[:,i], hiddens_actor)

            # Target critics have requires_grad(False) (set in __init__), so
            # no autograd graph is recorded for them.

            Q1_tar, hiddens1_tar = self.sys_critic1_tar.forward(obs[:,i], avail_actions[:,i], hiddens1_tar)
            Q2_tar, hiddens2_tar = self.sys_critic2_tar.forward(obs[:,i], avail_actions[:,i], hiddens2_tar)

            Q1s.append(Q1)
            Q2s.append(Q2)
            Q1s_tar.append(Q1_tar)
            Q2s_tar.append(Q2_tar)
            ps.append(p)

        # Stack along dim 1 -> (n_batch, T, n_agents, n_actions).
        Q1s = torch.stack(Q1s,1)
        Q2s = torch.stack(Q2s,1)
        Q1s_tar = torch.stack(Q1s_tar,1)
        Q2s_tar = torch.stack(Q2s_tar,1)

        ps = torch.stack(ps,1)
        # Padded steps contribute nothing.
        ps[valid == 0] = 0

        # The 1e-38 offset keeps log() finite for masked actions, whose
        # probability is exactly 0.
        log_ps = torch.log(ps + 1e-38)
        log_ps[valid == 0] = 0
        # Per-agent policy entropy, (n_batch, T, n_agents).
        entropy = -torch.sum(ps*log_ps, -1)


        Q1s[valid == 0] = 0
        Q2s[valid == 0] = 0
        Q1s_tar[valid == 0] = 0
        Q2s_tar[valid == 0] = 0

        # Q value of the action actually taken, (n_batch, T, n_agents).
        q1s = self.gather_end(Q1s,actions)
        q2s = self.gather_end(Q2s,actions)

        q1s_tot = self.mix_net1(q1s, state)
        q2s_tot = self.mix_net2(q2s, state)

        q1s_tot[valid == 0] = 0
        q2s_tot[valid == 0] = 0

        # Expected Q under the current policy. For the actor objective, Q is
        # detached so that only the policy receives gradients.
        V1s = torch.sum(ps * Q1s.detach(), -1)
        V2s = torch.sum(ps * Q2s.detach(), -1)

        # Bootstrap values from the target critics; the policy is detached.
        V1s_tar = torch.sum(ps.detach() * Q1s_tar, -1)
        V2s_tar = torch.sum(ps.detach() * Q2s_tar, -1)

        # Freeze the mixers while mixing the policy values so the actor loss
        # produces no gradients for mixer weights.
        self.mix_net1.requires_grad_(False)
        self.mix_net2.requires_grad_(False)

        V1s_tot = self.mix_net1.forward(V1s, state)
        V2s_tot = self.mix_net2.forward(V2s, state)

        self.mix_net1.requires_grad_(True)
        self.mix_net2.requires_grad_(True)

        V1s_tot_tar = self.mix_net1_tar.forward(V1s_tar, state)
        V2s_tot_tar = self.mix_net2_tar.forward(V2s_tar, state)

        # Elementwise minimum over the twin critics.
        Vs_tot = torch.min(torch.stack([V1s_tot,V2s_tot],-1), -1)[0]
        Vs_tot_tar = torch.min(torch.stack([V1s_tot_tar,V2s_tot_tar],-1), -1)[0]

        Vs_tot[valid == 0] = 0
        Vs_tot_tar[valid == 0] = 0



        # Team entropy bonus: sum over agents of alpha * entropy, kept as a
        # trailing singleton dim to match the mixer outputs.
        alpha_entropy = torch.sum(alpha * entropy, -1, keepdim=True)
        Ves_tot = Vs_tot + alpha_entropy
        Ves_tot_tar = Vs_tot_tar + alpha_entropy.detach()

        # Actor update: maximise the mixed value plus the entropy bonus.
        loss_actor = - torch.mean(Ves_tot)/valid_rate
        self.optim_actor.zero_grad()
        loss_actor.backward()
        torch.nn.utils.clip_grad_norm_(self.params_actor, self.args.grad_norm_clip)
        self.optim_actor.step()



        # Critic target: r_t + gamma * V_tar(t+1) for t < T-1. The last stored
        # step uses only its reward (no bootstrap beyond the sequence).
        qs_star = torch.zeros_like(q1s_tot)
        qs_star += reward.unsqueeze(-1)

        if self.maximum_entropy:
            sys_v_tar = Ves_tot_tar
        else:
            sys_v_tar = Vs_tot_tar

        qs_star[:,:-1] += self.gamma * (sys_v_tar[:,1:])

        loss1 = F.mse_loss(q1s_tot,qs_star)/valid_rate
        loss2 = F.mse_loss(q2s_tot,qs_star)/valid_rate
        loss = loss1+loss2
        self.optim_critic.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.params_critic, self.args.grad_norm_clip)
        self.optim_critic.step()
        self.update_target()

        # Temperature update. Minimising log_alpha * (H - entropy_tar) lowers
        # alpha when entropy exceeds the target and raises it otherwise.
        loss_alpha = self.log_alpha*(entropy.detach()-self.entropy_tar)
        loss_alpha[valid ==0] = 0
        loss_alpha = torch.mean(loss_alpha)/valid_rate

        self.optim_alpha.zero_grad()
        loss_alpha.backward()
        self.optim_alpha.step()

        # Diagnostics (means renormalised to valid steps via valid_rate).
        m_loss = loss/2
        m_alpha = alpha.mean()
        m_v = torch.mean(V1s.detach())/valid_rate
        m_v_total = torch.mean(Vs_tot.detach())/valid_rate
        m_entropy = torch.mean(entropy.detach())/valid_rate
        m_max_p = torch.mean(torch.max(ps.detach(),-1)[0])/valid_rate

        w_util.WriteScalar('train/loss', m_loss.item())
        w_util.WriteScalar('train/v', m_v.item())
        w_util.WriteScalar('train/v_total', m_v_total.item())
        w_util.WriteScalar('train/entropy', m_entropy.item())
        w_util.WriteScalar('train/alpha', m_alpha.item())
        w_util.WriteScalar('train/max_p', m_max_p.item())

        return m_loss.item()

    def state_dict(self):
        """Collect weights of all networks, the temperature and optimiser states."""
        state_dict = {}
        state_dict['critic1'] = self.sys_critic1.state_dict()
        state_dict['critic2'] = self.sys_critic2.state_dict()
        state_dict['critic1_tar'] = self.sys_critic1_tar.state_dict()
        state_dict['critic2_tar'] = self.sys_critic2_tar.state_dict()
        state_dict['mix_net1'] = self.mix_net1.state_dict()
        state_dict['mix_net2'] = self.mix_net2.state_dict()
        state_dict['mix_net1_tar'] = self.mix_net1_tar.state_dict()
        state_dict['mix_net2_tar'] = self.mix_net2_tar.state_dict()
        state_dict['actor'] = self.sys_actor.state_dict()
        state_dict['log_alpha'] = self.log_alpha.detach()
        state_dict['optim_critic'] = self.optim_critic.state_dict()
        state_dict['optim_actor'] = self.optim_actor.state_dict()
        state_dict['optim_alpha'] = self.optim_alpha.state_dict()
        return state_dict

    def load_state_dict(self, state_dict):
        """Restore everything produced by :meth:`state_dict`.

        log_alpha is copied in place so the alpha optimiser keeps referencing it.
        """
        self.sys_critic1.load_state_dict(state_dict['critic1'])
        self.sys_critic2.load_state_dict(state_dict['critic2'])
        self.sys_critic1_tar.load_state_dict(state_dict['critic1_tar'])
        self.sys_critic2_tar.load_state_dict(state_dict['critic2_tar'])
        self.mix_net1.load_state_dict(state_dict['mix_net1'])
        self.mix_net2.load_state_dict(state_dict['mix_net2'])
        self.mix_net1_tar.load_state_dict(state_dict['mix_net1_tar'])
        self.mix_net2_tar.load_state_dict(state_dict['mix_net2_tar'])
        self.sys_actor.load_state_dict(state_dict['actor'])
        with torch.no_grad():
            self.log_alpha.copy_(state_dict['log_alpha'])
        self.optim_critic.load_state_dict(state_dict['optim_critic'])
        self.optim_actor.load_state_dict(state_dict['optim_actor'])
        self.optim_alpha.load_state_dict(state_dict['optim_alpha'])


    def gather_end(self, input, index):
        """Pick input[..., index] along the last dim; index has input's shape minus that dim."""
        index = torch.unsqueeze(index,-1).to(dtype=torch.int64)
        return torch.gather(input, input.ndim -1, index).squeeze(-1)

    def one_hot(self, tensor, n_classes):
        """Float32 one-hot encoding of an integer tensor."""
        return F.one_hot(tensor.to(dtype=torch.int64), n_classes).to(dtype=torch.float32)

    def update_target(self):
        """Update target networks.

        If ``target_update >= 1``: hard copy every ``target_update`` train steps.
        Otherwise: soft (Polyak) update, tar += target_update * (cur - tar),
        applied in place to the target's state tensors.
        """
        if self.target_update >=1:
            if self.step % self.target_update == 0:
                self._sync_target()
        else:
            def soft_update(src,tar):
                cur_state = src.state_dict()
                tar_state = tar.state_dict()
                for key in tar_state:
                    v_tar = tar_state[key]
                    v_cur = cur_state[key]
                    # state_dict tensors share storage with the module, so this
                    # in-place add updates the target network itself.
                    v_tar += self.target_update*(v_cur-v_tar).detach()
            soft_update(self.sys_critic1,self.sys_critic1_tar)
            soft_update(self.sys_critic2,self.sys_critic2_tar)
            soft_update(self.mix_net1,self.mix_net1_tar)
            soft_update(self.mix_net2,self.mix_net2_tar)

    def _sync_target(self):
        """Hard-copy online critic and mixer weights into their targets."""
        self.sys_critic1_tar.load_state_dict(self.sys_critic1.state_dict())
        self.sys_critic2_tar.load_state_dict(self.sys_critic2.state_dict())
        self.mix_net1_tar.load_state_dict(self.mix_net1.state_dict())
        self.mix_net2_tar.load_state_dict(self.mix_net2.state_dict())

    def _valid_mask(self, input, valid):
        """Multiply ``input`` in place by ``valid``, broadcast over input's extra trailing dims.

        NOTE(review): not called anywhere in the visible code.
        """
        input *= valid.view(list(valid.shape) + [1] * (input.ndim - valid.ndim))
        return input
        
class Experiment:
    """One training run: environment, learner, controller, runner and replay buffer.

    It also handles checkpointing and periodic evaluation.
    Usage: ``exp = Experiment(args); exp.start(); exp.run()``.
    """

    def __init__(self, args):
        self.args = args

    def save(self):
        """Write a resumable checkpoint to ``self.path_checkpt``.

        The checkpoint holds the counters, best win rate, learner state and
        buffer state.
        """
        run_state = {}
        run_state['episode'] = self.e
        run_state['step'] = self.step
        run_state['best_win_rate'] = self.best_win_rate
        run_state['learner'] = self.learner.state_dict()
        run_state['buffer'] = self.buffer.state_dict()
        torch.save(run_state, self.path_checkpt)

    def load(self):
        """Restore the state written by :meth:`save`, plus the results array if one exists."""
        run_state = torch.load(self.path_checkpt)
        self.e = run_state['episode']
        self.step = run_state['step']
        self.best_win_rate = run_state['best_win_rate']
        self.learner.load_state_dict(run_state['learner'])
        self.buffer.load_state_dict(run_state['buffer'])
        # The results file is only written by test_model(). A checkpoint saved
        # before the first evaluation therefore has no companion file; keep
        # the freshly allocated array in that case.
        if os.path.exists(self.path_result):
            self.result = np.load(self.path_result)

    def start(self):
        """Create output folders, the environment and all training components.

        Fills in the derived ``args`` fields (n_agents, n_actions, obs_dim,
        input_dim, state_dim, episode_limit) before building the networks.
        Resumes from the checkpoint when ``args.continue_run`` is set.
        """
        args = self.args

        checkpt_dir, result_dir, model_dir = 'checkpoints', 'results', 'models'
        for folder in (checkpt_dir, result_dir, model_dir):
            os.makedirs(folder, exist_ok=True)
        path_checkpt = os.path.join(checkpt_dir, args.run_name + '.tar')
        path_result = os.path.join(result_dir, args.run_name + '.npy')
        path_model = os.path.join(model_dir, args.run_name + '.tar')

        env = simple_reference_v3.parallel_env(local_ratio=args.local_ratio, max_cycles=args.max_cycles, continuous_actions=args.continuous_actions)
        print(env)
        env.reset()

        # NOTE(review): these dimensions are hard-coded rather than read from
        # the env (env.action_space(agent), env.observation_space(agent),
        # env.state_space). The Runner also hard-codes 5 available actions per
        # agent, so both places must change together. Verify against the
        # installed PettingZoo version of simple_reference.
        env_info = {
            "n_agents": args.N,
            "n_actions": 5,  # no_action, move_left, move_right, move_down, move_up
            "obs_shape": 21,
            "state_shape": 42,
            "episode_limit": args.max_cycles
        }

        args.n_agents = env_info["n_agents"]
        args.n_actions = env_info["n_actions"]
        args.obs_dim = env_info['obs_shape']
        # Actor/critic input = observation [+ one-hot agent id] [+ one-hot last action].
        args.input_dim = args.obs_dim
        if args.agent_id:
            args.input_dim += args.n_agents
        if args.last_action:
            args.input_dim += args.n_actions
        args.state_dim = env_info['state_shape']
        args.episode_limit = env_info['episode_limit']
        print(f"Calculated input_dim: {args.input_dim}")

        writter = SummaryWriter('runs/' + args.run_name + '/' + datetime.datetime.now().strftime('%Y-%m-%d,%H%M%S'))
        w_util = WritterUtil(writter, args)
        learner = Learner(w_util, args)
        ctrler = Controller(learner.sys_actor, args)
        runner = Runner(env, ctrler, args)

        n_agents = args.n_agents
        scheme = {
            'obs': {'shape': (n_agents, args.obs_dim), 'dtype': torch.float32},
            'valid': {'shape': (), 'dtype': torch.int32},
            'actions': {'shape': (n_agents,), 'dtype': torch.int32},
            'avail_actions': {'shape': (n_agents, args.n_actions), 'dtype': torch.int32},
            'reward': {'shape': (), 'dtype': torch.float32},
            'state': {'shape': (args.state_dim,), 'dtype': torch.float32},
        }

        buffer = EpisodeBuffer(scheme, args)
        # Rows: episode, step, win rate; one column per evaluation.
        result = np.zeros((3, args.n_steps // args.test_every_step))

        self.e = 0
        self.step = 0
        self.best_win_rate = 0
        self.path_checkpt = path_checkpt
        self.path_result = path_result
        self.path_model = path_model
        self.env = env
        self.writter = writter
        self.w_util = w_util
        self.ctrler = ctrler
        self.runner = runner
        self.learner = learner
        self.buffer = buffer
        self.result = result

        if args.continue_run:
            self.load()

    def run(self):
        """Train until ``args.n_steps`` environment steps have been collected.

        Each loop iteration collects one episode, stores it, and does one
        learner update once the buffer can supply a batch. It evaluates
        whenever the step count crosses a multiple of ``args.test_every_step``.
        """
        args = self.args
        buffer = self.buffer
        runner = self.runner
        learner = self.learner
        w_util = self.w_util

        while self.step < args.n_steps:
            self.e += 1
            data, episode_reward, win_tag, step = runner.run()
            old_step = self.step
            self.step += step
            w_util.set_step(self.step)
            w_util.WriteScalar('train/reward', episode_reward)
            print("Episode {}, step {}, win = {}, reward = {}".format(self.e, self.step, win_tag, episode_reward))
            buffer.add_episode(data)
            data = buffer.sample(args.n_batch, args.top_n)
            if data:
                learner.train(data)

            if self.step // args.test_every_step != old_step // args.test_every_step:
                self.test_model()

            if args.save_every and self.e % args.save_every == 0:
                self.save()

        self.env.close()

    def test_model(self):
        """Evaluate greedily for ``args.test_count`` episodes and record the result.

        Stores [episode, step, win_rate] in the results column for the current
        step bucket and saves the array. When the win rate matches or beats
        the best so far, the actor weights are saved to ``self.path_model``.
        """
        args = self.args
        w_util = self.w_util
        runner = self.runner

        win_count = 0
        reward_avg = 0
        for _ in range(args.test_count):
            _, episode_reward, win_tag, _ = runner.run(test_mode=True)
            if win_tag:
                win_count += 1
            reward_avg += episode_reward
        win_rate = win_count / args.test_count
        reward_avg /= args.test_count
        w_util.WriteScalar('test/reward', reward_avg)
        w_util.WriteScalar('test/win_rate', win_rate)

        col = self.step // args.test_every_step - 1
        # The final episode can overshoot n_steps, so the column may lie past
        # the preallocated width. Grow the array instead of raising IndexError.
        if col >= self.result.shape[1]:
            extra = col + 1 - self.result.shape[1]
            self.result = np.concatenate(
                [self.result, np.zeros((self.result.shape[0], extra))], axis=1)
        self.result[:, col] = [self.e, self.step, win_rate]
        np.save(self.path_result, self.result)

        if win_rate >= self.best_win_rate:
            self.best_win_rate = win_rate
            torch.save(self.learner.sys_actor.state_dict(), self.path_model)
        print('Test reward = {}, win_rate = {}'.format(reward_avg, win_rate))
        
class Runner:
    """Collect single-episode rollouts from a PettingZoo parallel environment."""

    def __init__(self, env, controller, args):
        """
        Initialize the Runner.
        Args:
            env: The PettingZoo environment.
            controller: The controller that decides agent actions.
            args: Additional arguments, including max_cycles and number of agents (N).
                An optional ``n_actions`` sets the width of the all-ones action
                mask (defaults to 5).
        """
        self.controller = controller
        self.env = env
        self.args = args

    def run(self, test_mode=False):
        """
        Run a single episode in the PettingZoo environment.

        The episode stops after ``args.max_cycles`` steps, once every agent is
        terminated *or* truncated, or when ``env.step`` raises. In the last
        case the error is printed and the partial episode is returned.

        Args:
            test_mode (bool): If True, disables exploration. Defaults to False.

        Returns:
            tuple: (data, episode_reward, win_tag, steps), where ``steps`` is
            the number of completed transitions stored in ``data``.
        """
        print("Resetting environment...")
        observations, infos = self.env.reset()

        self.controller.new_episode()
        data = {
            'state': [],
            'obs': [],
            'valid': [],
            'actions': [],
            'avail_actions': [],
            'reward': []
        }
        episode_reward = 0
        steps = 0

        explore = not test_mode
        # Every action is always available. The mask width follows
        # args.n_actions when it is configured; the Controller reads the same
        # attribute.
        n_actions = getattr(self.args, 'n_actions', 5)

        while steps < self.args.max_cycles:
            # Ensure environment state is valid
            assert hasattr(self.env, 'state'), "Environment does not have a 'state()' method."
            state = self.env.state()

            # Validate observations dictionary
            assert isinstance(observations, dict), "Observations must be a dictionary."
            assert len(observations) == len(self.env.agents), "Mismatch between agents and observations."

            # Convert observations dict to list in consistent order
            obs_list = [observations[agent_id] for agent_id in self.env.agents]
            avail_actions_list = [[1] * n_actions for _ in range(self.args.N)]

            # Get actions from controller
            actions_array = self.controller.get_actions(obs_list, avail_actions_list, explore)
            assert len(actions_array) == len(self.env.agents), "Mismatch between actions and agents."

            # Convert actions array to dict for PettingZoo environment
            actions_dict = dict(zip(self.env.agents, actions_array))

            try:
                observations, rewards, terminations, truncations, infos = self.env.step(actions_dict)
            except Exception as e:
                # Best effort: report the error and return what was collected.
                # `steps` is only incremented after a successful step, so it
                # still equals the number of stored transitions.
                print(f"Error during environment step: {e}")
                break

            # Validate rewards dictionary
            assert isinstance(rewards, dict), "Rewards must be a dictionary."

            # Total reward for this step, summed over agents
            step_reward = sum(rewards.values())
            episode_reward += step_reward

            # Store transition data
            data['state'].append(state)
            data['obs'].append(obs_list)
            data['valid'].append(1)  # Assume all steps are valid
            data['actions'].append(actions_array)
            data['avail_actions'].append(avail_actions_list)
            data['reward'].append(step_reward)
            steps += 1

            if self._all_done(terminations, truncations):
                break

        # Define win condition placeholder
        win_tag = episode_reward > 0  # Modify based on environment-specific criteria

        return data, episode_reward, win_tag, steps

    @staticmethod
    def _all_done(terminations, truncations):
        """Return True when every agent in either dict is terminated or truncated.

        Unlike ``all(terminations) or all(truncations)``, this also ends the
        episode when some agents terminated while the others were truncated.
        """
        agent_ids = set(terminations) | set(truncations)
        return all(terminations.get(a, False) or truncations.get(a, False) for a in agent_ids)

class Args:
    """Plain attribute container; the entry point fills it from the JSON config via ``__dict__.update``."""


if __name__ == "__main__":
    # Entry point: load the JSON config, optionally resume a run (-c), then train.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    config_path = 'config_reference.json'

    with open(config_path, 'r', encoding='utf-8') as f:
        config = json.load(f)
    args = Args()
    args.__dict__.update(config)
    # The Controller reads args.device. Fall back to the auto-detected device
    # when the config does not specify one; an explicit config value wins.
    if not hasattr(args, 'device'):
        args.device = device

    # Passing `-c` as the first command-line argument continues a previous run.
    argv = sys.argv
    args.continue_run = len(argv) > 1 and argv[1] == '-c'

    experiment = Experiment(args)
    experiment.start()
    experiment.run()
    



simple_reference_v3
Calculated input_dim: 23
Resetting environment...
Episode 1, step 75, win = False, reward = -288.01432849850306
Resetting environment...
Episode 2, step 150, win = False, reward = -142.1535379728103
Resetting environment...
Episode 3, step 225, win = False, reward = -135.1270850043955
Resetting environment...
Episode 4, step 300, win = False, reward = -145.24494312781715
Resetting environment...
Episode 5, step 375, win = False, reward = -190.4201828004041
Resetting environment...
Episode 6, step 450, win = False, reward = -302.21583289685685
Resetting environment...
Episode 7, step 525, win = False, reward = -218.045845509648
Resetting environment...
Episode 8, step 600, win = False, reward = -206.38626844102592
Resetting environment...
Episode 9, step 675, win = False, reward = -180.96033648925334
Resetting environment...
Episode 10, step 750, win = False, reward = -214.77717370651763
Resetting environment...
Episode 11, step 825, win = False, reward = -249.814848

In [None]:
import torch
import imageio
import numpy as np
from pettingzoo.mpe import simple_spread_v3

# Actor model
class Actor(torch.nn.Module):
    """Recurrent policy: Linear -> ReLU -> GRUCell -> Linear -> masked softmax.

    The default sizes reproduce the original hard-coded ones (23 inputs,
    64 hidden units, 5 actions). The submodule names ``fc1``/``rnn``/``fc2``
    define the ``state_dict`` keys and must not change.
    """

    def __init__(self, input_dim=23, hidden_dim=64, n_actions=5):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.n_actions = n_actions
        self.fc1 = torch.nn.Linear(self.input_dim, self.hidden_dim)
        self.rnn = torch.nn.GRUCell(self.hidden_dim, self.hidden_dim)
        self.fc2 = torch.nn.Linear(self.hidden_dim, self.n_actions)

    def init_hidden(self, batch_size):
        """Return a zero hidden state of shape (batch_size, hidden_dim) on the module's device."""
        return torch.zeros(batch_size, self.hidden_dim, device=self.fc1.weight.device)

    def forward(self, inputs, avail_actions, h_last):
        """Compute action probabilities and the next hidden state.

        Args:
            inputs: observation batch, shape (batch, input_dim).
            avail_actions: mask broadcastable to (batch, n_actions); 0 marks an
                unavailable action.
            h_last: previous hidden state, shape (batch, hidden_dim).

        Returns:
            (probabilities, h): softmax over actions and the new hidden state.
        """
        x = torch.relu(self.fc1(inputs))
        h = self.rnn(x, h_last)
        q_values = self.fc2(h)
        # Unavailable actions get -inf logits, so softmax assigns them probability 0.
        q_values = q_values.masked_fill(avail_actions == 0, float('-inf'))
        probabilities = torch.nn.functional.softmax(q_values, dim=-1)
        return probabilities, h

# Observation augmentation
def augment_observation(obs, agent_id, n_agents):
    """Return ``obs`` followed by a length-``n_agents`` one-hot vector marking ``agent_id``."""
    identity_row = np.zeros(n_agents)
    identity_row[agent_id] = 1.0
    return np.concatenate([obs, identity_row])

# Key-prefix stripping for checkpoints
def remove_prefix_from_state_dict(state_dict, prefix="agent."):
    """Return a new dict whose keys have a leading ``prefix`` removed.

    Keys that do not start with ``prefix`` are copied unchanged, and the
    insertion order of ``state_dict`` is preserved.
    """
    stripped = {}
    cut = len(prefix)
    for name, tensor in state_dict.items():
        new_name = name[cut:] if name.startswith(prefix) else name
        stripped[new_name] = tensor
    return stripped

# Visualizer
class AgentVisualizer:
    """Play a greedy episode with a shared-parameter recurrent actor and record frames.

    Every agent uses the same ``actor``. Each agent's observation is extended
    with a one-hot agent id via ``augment_observation``, and each agent keeps
    its own recurrent hidden state across time steps.
    """

    def __init__(self, actor, env, device='cpu'):
        self.device = torch.device(device)
        self.env = env
        self.actor = actor.to(self.device)
        self.actor.eval()
        self.n_agents = len(env.possible_agents)

    def run_episode(self, save_gif=True, gif_path='trained_agent_demo.gif', seed=42, fps=60):
        """Run one episode with argmax actions and optionally save it as a GIF.

        Args:
            save_gif: write the collected frames to ``gif_path``.
            gif_path: output path of the GIF.
            seed: seed passed to ``env.reset``.
            fps: frame rate passed to ``imageio.mimsave``.

        Returns:
            list: the frames returned by ``env.render()`` after each step.
        """
        frames = []
        obs, infos = self.env.reset(seed=seed)
        # Agent ids follow the fixed possible_agents order, so the one-hot
        # encoding stays stable even if agents drop out of env.agents.
        agent_index = {agent_id: idx for idx, agent_id in enumerate(self.env.possible_agents)}
        # One hidden state per agent. Threading a single hidden tensor through
        # the agent loop would feed one agent's memory into the next agent's
        # decision, which does not match per-agent recurrent execution.
        hidden_states = {
            agent_id: self.actor.init_hidden(1).to(self.device)
            for agent_id in self.env.possible_agents
        }
        avail_actions = torch.ones((1, self.actor.n_actions), device=self.device)
        done = False

        while not done:
            actions = {}
            with torch.no_grad():
                for agent_id in self.env.agents:
                    obs_vec = augment_observation(
                        obs[agent_id], agent_id=agent_index[agent_id], n_agents=self.n_agents
                    )
                    obs_tensor = torch.tensor(obs_vec, dtype=torch.float32).unsqueeze(0).to(self.device)
                    probabilities, hidden_states[agent_id] = self.actor(
                        obs_tensor, avail_actions, hidden_states[agent_id]
                    )
                    actions[agent_id] = torch.argmax(probabilities, dim=-1).item()

            next_obs, rewards, terminations, truncations, infos = self.env.step(actions)
            # Finished once every agent is terminated or truncated (mixed
            # outcomes included), or no agents remain.
            agent_ids = set(terminations) | set(truncations)
            done = (not self.env.agents) or all(
                terminations.get(a, False) or truncations.get(a, False) for a in agent_ids
            )
            frames.append(self.env.render())

            obs = next_obs

        if save_gif:
            imageio.mimsave(gif_path, frames, fps=fps)
            print(f"Zapisano epizod jako GIF: {gif_path}")  # "Saved episode as GIF"
        return frames

# Load the trained model
import pickle
# This cell otherwise relies on an earlier cell having imported simple_reference_v3.
from pettingzoo.mpe import simple_reference_v3

actor_path = "results/A_Final_simple_referenece_3e-4.tar"
try:
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    state_dict = torch.load(actor_path, map_location='cpu')
except pickle.UnpicklingError as err:
    raise pickle.UnpicklingError(
        f"Could not unpickle checkpoint {actor_path!r}: the file may be truncated "
        "(e.g. training was interrupted while it was being saved) or may not have "
        "been written by torch.save()."
    ) from err
state_dict = remove_prefix_from_state_dict(state_dict)

actor = Actor()
actor.load_state_dict(state_dict)

# Environment
env = simple_reference_v3.parallel_env(
    local_ratio=0.5,
    max_cycles=75,
    continuous_actions=False,
    render_mode="rgb_array"
)

# Visualization; always release the environment afterwards.
visualizer = AgentVisualizer(actor=actor, env=env, device='cuda' if torch.cuda.is_available() else 'cpu')
try:
    visualizer.run_episode(save_gif=True, gif_path='trained_agent_reference.gif')
finally:
    env.close()


UnpicklingError: unpickling stack underflow