In [None]:
from pistonball_CleanRL import *
import numpy as np
import torch
import torch.nn as nn
from torch.distributions.categorical import Categorical
from supersuit import color_reduction_v0, frame_stack_v1, resize_v1
from pettingzoo.butterfly import pistonball_v6

In [None]:
class Agent(nn.Module):
    """Convolutional actor-critic with a shared trunk and two linear heads.

    The trunk is three (3x3 conv -> 2x2 max-pool -> ReLU) stages followed by a
    flatten and a 512-unit linear layer. The first convolution takes 4 input
    channels, and the ``128 * 8 * 8`` flatten size means the three poolings
    must leave an 8x8 map, i.e. the trunk expects 64x64 spatial inputs.

    Args:
        num_actions: Number of discrete actions (width of the actor head).
    """

    def __init__(self, num_actions):
        super().__init__()

        # Build the trunk stage by stage; layers are created and initialised
        # in the same order as they appear in the final Sequential.
        stages = []
        for in_channels, out_channels in ((4, 32), (32, 64), (64, 128)):
            stages.extend(
                [
                    self._layer_init(nn.Conv2d(in_channels, out_channels, 3, padding=1)),
                    nn.MaxPool2d(2),
                    nn.ReLU(),
                ]
            )
        stages.extend(
            [
                nn.Flatten(),
                self._layer_init(nn.Linear(128 * 8 * 8, 512)),
                nn.ReLU(),
            ]
        )
        self.network = nn.Sequential(*stages)

        # The policy head uses a much smaller orthogonal gain than the value head.
        self.actor = self._layer_init(nn.Linear(512, num_actions), std=0.01)
        self.critic = self._layer_init(nn.Linear(512, 1))

    def _layer_init(self, layer, std=np.sqrt(2), bias_const=0.0):
        """Orthogonally initialise ``layer.weight`` (gain ``std``), fill the bias, return the layer."""
        nn.init.orthogonal_(layer.weight, gain=std)
        nn.init.constant_(layer.bias, val=bias_const)
        return layer

    def get_value(self, x):
        """Return the critic output for ``x``; the input is divided by 255 before the trunk."""
        features = self.network(x / 255.0)
        return self.critic(features)

    def forward(self, x, action=None):
        """Sample (or score) an action and evaluate the critic.

        Args:
            x: Batch of observations; divided by 255 before entering the trunk
                (presumably 0-255 pixel values -- confirm with the caller).
            action: Optional actions to score. When ``None`` an action is
                sampled from the categorical policy.

        Returns:
            Tuple ``(action, log_prob, entropy, value)``.
        """
        features = self.network(x / 255.0)
        policy = Categorical(logits=self.actor(features))
        chosen = policy.sample() if action is None else action
        return chosen, policy.log_prob(chosen), policy.entropy(), self.critic(features)


In [None]:
class Agent(nn.Module):
    """Actor-critic whose heads are conditioned on two other agents' actions.

    A shared CNN trunk (three conv/pool/ReLU stages, flatten, 512-unit linear
    layer; 4 input channels and, given the ``128 * 8 * 8`` flatten size,
    64x64 spatial inputs) encodes one observation. The 512 features are
    concatenated with learned embeddings of two "dependency" actions before
    being passed to the actor and critic heads.

    In ``forward`` the agents are processed from the last index to the first,
    and agent ``i`` is conditioned on the actions already chosen for agents
    ``i + 1`` and ``i + 2``. Agents past the end are represented by the padding
    action index ``1``, so ``num_actions`` must be at least 2.

    Args:
        num_actions: Number of discrete actions.
        embedding_dim: Width of each dependency-action embedding.
    """

    def __init__(self, num_actions, embedding_dim=64):
        super().__init__()

        self.network = nn.Sequential(
            self._layer_init(nn.Conv2d(4, 32, 3, padding=1)),
            nn.MaxPool2d(2),
            nn.ReLU(),
            self._layer_init(nn.Conv2d(32, 64, 3, padding=1)),
            nn.MaxPool2d(2),
            nn.ReLU(),
            self._layer_init(nn.Conv2d(64, 128, 3, padding=1)),
            nn.MaxPool2d(2),
            nn.ReLU(),
            nn.Flatten(),
            self._layer_init(nn.Linear(128 * 8 * 8, 512)),
            nn.ReLU(),
        )

        # One embedding table per dependency slot.
        self.action_embedding1 = nn.Embedding(num_actions, embedding_dim)
        self.action_embedding2 = nn.Embedding(num_actions, embedding_dim)

        head_in = 512 + 2 * embedding_dim
        # Small gain on the policy head, matching the plain Agent defined earlier
        # in this notebook, so the initial policy is close to uniform.
        self.actor = self._layer_init(nn.Linear(head_in, num_actions), std=0.01)
        self.critic = self._layer_init(nn.Linear(head_in, 1))

    def _layer_init(self, layer, std=np.sqrt(2), bias_const=0.0):
        """Orthogonally initialise ``layer.weight`` (gain ``std``), fill the bias, return the layer."""
        torch.nn.init.orthogonal_(layer.weight, std)
        torch.nn.init.constant_(layer.bias, bias_const)
        return layer

    def _joint_features(self, x, actions):
        """Concatenate trunk features of ``x`` with the two dependency-action embeddings.

        ``actions`` unpacks into ``(action1, action2)``; each may be a scalar
        tensor (broadcast over the batch) or hold one index per batch row.
        """
        action1, action2 = actions
        hidden = self.network(x / 255.0)
        batch = hidden.shape[0]
        emb1 = self.action_embedding1(action1)
        emb2 = self.action_embedding2(action2)
        # Normalise to (batch, embedding_dim) so scalar and batched indices both work.
        emb1 = emb1.reshape(-1, emb1.shape[-1]).expand(batch, -1)
        emb2 = emb2.reshape(-1, emb2.shape[-1]).expand(batch, -1)
        return torch.cat((hidden, emb1, emb2), dim=1)

    def get_value(self, x, actions=None):
        """Return critic values of shape ``(batch, 1)``.

        The critic consumes trunk features *and* dependency-action embeddings,
        so it cannot be evaluated on the trunk output alone.

        Args:
            x: Observations, one row per agent.
            actions: Optional ``(action1, action2)`` dependency actions. When
                omitted, the sequential pass in ``forward`` is run (which
                samples the dependency actions) and its values are returned.
        """
        if actions is None:
            return self.forward(x)[3].unsqueeze(-1)
        return self.critic(self._joint_features(x, actions))

    def get_action_and_value(self, x, actions):
        """Sample an action for ``x`` given the dependency ``actions``.

        Returns:
            Tuple ``(action, log_prob, entropy, value)``.
        """
        hidden = self._joint_features(x, actions)
        probs = Categorical(logits=self.actor(hidden))
        action = probs.sample()
        return action, probs.log_prob(action), probs.entropy(), self.critic(hidden)

    def forward(self, x):
        """Sample actions for all agents, from the last agent to the first.

        Args:
            x: Observations with one row per agent along dimension 0.

        Returns:
            Tuple ``(actions, log_probs, entropies, values)``, each of length
            ``num_agents``. ``actions`` has dtype ``torch.int``.
        """
        num_agents = x.shape[0]
        device = x.device
        x = x.unsqueeze(1)  # x[ind] becomes a batch of one observation
        # Two trailing padding slots (action index 1) stand in for the missing
        # successors of the last two agents. Buffers live on the input's device
        # so the module also works after ``.to("cuda")``.
        actions = torch.ones(num_agents + 2, dtype=torch.int, device=device)
        log_probs = torch.zeros(num_agents, device=device)
        entropies = torch.zeros(num_agents, device=device)
        values = torch.zeros(num_agents, device=device)

        for ind in range(num_agents - 1, -1, -1):
            # Agent ``ind`` depends on agents ``ind + 1`` and ``ind + 2``. Clone so
            # the in-place write to ``actions`` below does not touch the index
            # tensor the embedding lookup keeps for its backward pass.
            depend_actions = actions[ind + 1:ind + 3].clone()
            action, log_prob, entropy, value = self.get_action_and_value(x[ind], depend_actions)
            actions[ind] = action.squeeze()
            log_probs[ind] = log_prob.squeeze()
            entropies[ind] = entropy.squeeze()
            values[ind] = value.squeeze()
        return actions[:-2], log_probs, entropies, values
        


In [None]:
# Use the GPU when one is visible to PyTorch, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
frame_size = (64, 64)
stack_size = 4

# Discrete-action Pistonball parallel env, wrapped (innermost first) with
# colour reduction, a resize to `frame_size`, and a stack of `stack_size` frames.
env = frame_stack_v1(
    resize_v1(
        color_reduction_v0(
            pistonball_v6.parallel_env(render_mode="rgb_array", continuous=False)
        ),
        *frame_size,
    ),
    stack_size=stack_size,
)

# seed=None leaves the environment unseeded, so every reset differs.
next_obs, info = env.reset(seed=None)
# batchify_obs comes from the star-import above; presumably it stacks the
# per-agent observations into a single tensor on `device` -- confirm there.
obs: torch.Tensor = batchify_obs(next_obs, device)

# Shape checks: one agent's raw observation, the batched tensor, the render
# size, and the batched tensor with an extra axis inserted at position 1.
print(next_obs["piston_0"].shape)
print(obs.shape)
print(env.unwrapped.screen_width, env.unwrapped.screen_height)
print(obs.unsqueeze(1).shape)

In [None]:
# The action-space size is read from the first entry of env.possible_agents.
num_actions = env.action_space(env.possible_agents[0]).n

# NOTE(review): Agent_ADG is not defined in this notebook (the classes above are
# both named `Agent`); presumably it is provided by the star-import -- confirm.
agent = Agent_ADG(num_actions=num_actions)
agent = agent.to(device)

# Smoke test: one call on the batched observations.
agent(obs)
# Earlier single-observation experiment, kept for reference:
# obs = obs[0].unsqueeze(0)
# action, log_prob, entropy, value = agent(obs) #(20, 128, 8, 8)

In [None]:
# Scratch check: writing a CPU tensor into a slice of a tensor that lives on
# `device` keeps the destination's device.
num_agents = 20
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Buffer of ones with two extra trailing slots, allocated directly on `device`.
action_base = torch.ones(num_agents + 2, dtype=torch.int, device=device)
# Created without a device argument, so this one stays on the CPU.
action = torch.zeros(num_agents, dtype=torch.int)

# Overwrite everything except the two trailing slots.
action_base[:num_agents] = action
action_base.device

In [None]:
import torch
import torch.nn as nn

# Toy problem: a 3 -> 2 linear layer, a random mini-batch of 4 inputs, and
# random regression targets of matching shape.
linear = nn.Linear(3, 2)
x = torch.randn(4, 3)
target = torch.randn(4, 2)

# Forward pass.
y = linear(x)

# reduction='sum' adds up the squared errors of every element
# (note: the MSELoss default is 'mean', not 'sum').
criterion = nn.MSELoss(reduction='sum')
loss = criterion(y, target)
loss.backward()

# With a summed loss, the weight gradient is the sum of the per-sample
# gradients over the mini-batch.
print(linear.weight.grad)

In [None]:
import torch.nn as nn
import torch

# Embedding table with 3 rows of width 64; looking up ids 0, 1, 2 returns the
# table's rows in order, one per id.
action_embedding = nn.Embedding(num_embeddings=3, embedding_dim=64)
out = action_embedding(torch.tensor([0, 1, 2], dtype=torch.long))
out