In [None]:
%pip install stable-baselines3 numpy torch supersuit pettingzoo pymunk scipy gymnasium matplotlib einops tensorboard wandb imageio 
from __future__ import annotations
import numpy as np
import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
from stable_baselines3 import SAC
import glob
import os
import time
from datetime import datetime
import supersuit as ss

import os
import random
import time
from distutils.util import strtobool
from typing import Dict
import matplotlib.pyplot as plt
import imageio
import einops
import gymnasium as gym
from pettingzoo import ParallelEnv
from pettingzoo.mpe import simple_spread_v3
from pettingzoo.mpe import simple_v3
from pettingzoo.mpe import simple_adversary_v3

from pettingzoo.butterfly import knights_archers_zombies_v10

from pettingzoo.utils.env import AgentID, ObsType
from torch.utils.tensorboard import SummaryWriter


import collections


from MASAC.masac.utils import extract_agent_id


# from MASAC.masac.ma_buffer import MAReplayBuffer, Experience
# from MASAC.masac.masac import concat_id
from argparse import Namespace


In [None]:

total_timesteps = 10000
time_now = datetime.now()
# NOTE(review): the name says "simple_adversary" but the env created below is simple_spread_v3;
# rename when switching environments so output folders / runs are not mislabelled.
exp_name = 'MASAC_QMIX_Circle_simple_adversary'
os.makedirs(os.path.join('output', exp_name), exist_ok=True)

# Experiment hyperparameters (CleanRL-style ``args`` object).
args = Namespace(
    exp_name=exp_name,                      # experiment name
    seed=1,                                 # RNG seed
    torch_deterministic=True,               # use deterministic cuDNN kernels
    cuda=True,                              # use CUDA (GPU) when available
    track=False,                            # log the run to Weights & Biases
    # W&B project names may not contain '/', '\\', '#', '?', '%' or ':' -> use '-' between hour and minute.
    wandb_project_name="Project_" + exp_name + str(time_now.hour) + "-" + str(time_now.minute),
    wandb_entity='Entity_' + exp_name,      # NOTE(review): must be an existing W&B user/team when track=True
    capture_video=False,                    # capture videos of episodes
    total_timesteps=total_timesteps,        # total number of environment steps
    buffer_size=1000000,                    # replay buffer capacity
    gamma=0.98,                             # discount factor
    tau=0.01,                               # Polyak coefficient for target-network updates
    batch_size=128,                         # mini-batch size
    learning_starts=total_timesteps // 10,  # int: steps of random actions before learning starts
    policy_lr=5e-4,                         # actor learning rate
    q_lr=5e-4,                              # critic learning rate
    policy_frequency=1,                     # actor update frequency (TD3-style delay)
    target_network_frequency=1,             # target-network update frequency
    alpha=0.2,                              # entropy weight (used when autotune=False)
    autotune=True,                          # learn alpha automatically
    save_frequency=total_timesteps // 10,   # int, so `step % save_frequency == 0` is exact
    actor_path='output/'+exp_name+"/",      # directory for saved actor weights
)

# env setup (alternative MPE scenarios kept for quick switching)
# env = simple_v3.parallel_env(render_mode=None,  continuous_actions=True)
env = simple_spread_v3.parallel_env(render_mode=None, N=2, local_ratio=0.5, max_cycles=40, continuous_actions=True)
# env = simple_adversary_v3.parallel_env(render_mode=None, N=2, max_cycles=20, continuous_actions=True)
env.reset(seed=args.seed)

In [11]:
class SoftQNetwork(nn.Module):
    """Centralised critic: maps (global state, joint action) to one scalar Q-value per row.

    Args:
        state_dim: size of the flattened global state.
        action_dim: size of the flattened joint action of all agents.
    """

    def __init__(self, state_dim, action_dim):
        super().__init__()
        width = 256
        self.fc1 = nn.Linear(state_dim + action_dim, width)
        self.fc2 = nn.Linear(width, width)
        self.fc3 = nn.Linear(width, 1)

    def forward(self, x, a):
        """Return a (batch, 1) tensor of Q(x, a); ``x`` and ``a`` are concatenated along dim 1."""
        hidden = torch.cat((x, a), dim=1)
        for layer in (self.fc1, self.fc2):
            hidden = F.relu(layer(hidden))
        return self.fc3(hidden)


In [12]:
LOG_STD_MAX = 2
LOG_STD_MIN = -5
def concat_id(observation, agent_id):
    """Append the agent's numeric id to its observation vector.

    The id is the integer after the last '_' in ``agent_id`` (e.g. "agent_3" -> 3),
    stored as one float32 feature; ``int`` raises ValueError if that suffix is not numeric.
    The result dtype follows numpy's promotion of ``observation`` with float32.
    """
    *_, suffix = agent_id.split('_')
    id_feature = np.array([int(suffix)], dtype=np.float32)
    return np.concatenate([observation, id_feature])

class Actor(nn.Module):
    """Squashed-Gaussian SAC policy for a single agent.

    The network input is the flattened local observation with the numeric agent id
    appended (see ``concat_id``), i.e. ``prod(observation_space_shape) + 1`` features.

    Args:
        observation_space_shape: shape of one agent's local observation.
        action_space_shape: shape of one agent's action.
        action_low, action_high: optional action bounds (scalar or array-like with
            ``prod(action_space_shape)`` entries). When both are given, sampled actions
            are mapped from tanh's (-1, 1) into (low, high). When omitted, the placeholder
            scaling (scale 1, bias 0) is kept, so actions lie in (-1, 1); the
            ``action_scale`` / ``action_bias`` buffers can still be overwritten later.

    Raises:
        ValueError: if only one of ``action_low`` / ``action_high`` is given.
    """

    def __init__(self, observation_space_shape, action_space_shape, action_low=None, action_high=None):
        super().__init__()
        if (action_low is None) != (action_high is None):
            raise ValueError("action_low and action_high must be given together")

        self.input_dim = int(np.prod(observation_space_shape)) + 1  # +1 for the agent-id feature
        self.action_dim = int(np.prod(action_space_shape))

        self.fc1 = nn.Linear(self.input_dim, 256)
        self.fc2 = nn.Linear(256, 256)
        self.fc_mean = nn.Linear(256, self.action_dim)
        self.fc_logstd = nn.Linear(256, self.action_dim)

        # Action rescaling: tanh output in (-1, 1) -> action = y * scale + bias
        if action_low is None:
            action_scale = torch.tensor(1.0)
            action_bias = torch.tensor(0.0)
        else:
            low = torch.as_tensor(np.asarray(action_low, dtype=np.float32).reshape(-1))
            high = torch.as_tensor(np.asarray(action_high, dtype=np.float32).reshape(-1))
            action_scale = (high - low) / 2.0
            action_bias = (high + low) / 2.0
        self.register_buffer("action_scale", action_scale)
        self.register_buffer("action_bias", action_bias)

    def forward(self, x):
        """Return the Gaussian mean and log standard deviation for input rows ``x``.

        ``log_std`` is squashed with tanh and affinely mapped into
        [LOG_STD_MIN, LOG_STD_MAX] so the std stays numerically sane.
        """
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        mean = self.fc_mean(x)
        log_std = torch.tanh(self.fc_logstd(x))
        log_std = LOG_STD_MIN + 0.5 * (LOG_STD_MAX - LOG_STD_MIN) * (log_std + 1)
        return mean, log_std

    def get_action(self, x):
        """Sample a reparameterised action for each row of ``x``.

        Returns:
            action: stochastic action, ``tanh(sample) * action_scale + action_bias``.
            log_prob: (batch, 1) log-probability of ``action``, including the tanh /
                rescaling change-of-variables correction.
            mean: deterministic action, ``tanh(mean) * action_scale + action_bias``
                (useful for evaluation; callers in this notebook ignore it).
        """
        mean, log_std = self(x)
        std = log_std.exp()
        normal = torch.distributions.Normal(mean, std)
        x_t = normal.rsample()  # reparameterisation trick keeps the sample differentiable
        y_t = torch.tanh(x_t)
        action = y_t * self.action_scale + self.action_bias
        log_prob = normal.log_prob(x_t)
        # Jacobian of the tanh squashing and the affine rescale; 1e-6 avoids log(0).
        log_prob -= torch.log(self.action_scale * (1 - y_t.pow(2)) + 1e-6)
        log_prob = log_prob.sum(1, keepdim=True)
        deterministic = torch.tanh(mean) * self.action_scale + self.action_bias
        return action, log_prob, deterministic



Replay Buffer

In [13]:
# One stored transition / sampled batch for the centralised-critic learner.
# ``q_values`` (per-agent Q-values) and ``log_probs`` (per-agent action log-probabilities)
# are the values recorded when the transition was collected.
Experience = collections.namedtuple(
    "Experience",
    "global_obs local_obs joint_actions rewards "
    "next_global_obs next_local_obs terminateds "
    "q_values log_probs",
)

class MAReplayBuffer:
    """Fixed-size circular replay buffer for centralised-critic multi-agent training.

    Each transition stores the global state, one local observation per agent id, the
    flattened joint action, a scalar reward, the next global state / local observations,
    a termination flag, and per-agent ``q_values`` / ``log_probs`` (``num_agents`` each).
    Once ``max_size`` transitions are stored, the oldest ones are overwritten.
    """

    def __init__(
        self,
        global_obs_shape,
        local_obs_shapes: Dict[str, tuple],
        action_dim,
        num_agents=1,
        max_size=100000,
        obs_dtype=np.float32,
        action_dtype=np.float32,
    ):
        self.max_size = max_size
        self.ptr, self.size = 0, 0
        self.obs_type = obs_dtype
        self.action_type = action_dtype
        global_obs_shape = tuple(global_obs_shape)
        self.global_obs = np.zeros((max_size,) + global_obs_shape, dtype=obs_dtype)
        self.local_obs = {
            agent_id: np.zeros((max_size,) + tuple(shape), dtype=obs_dtype)
            for agent_id, shape in local_obs_shapes.items()
        }
        self.next_global_obs = np.zeros((max_size,) + global_obs_shape, dtype=obs_dtype)
        self.next_local_obs = {
            agent_id: np.zeros((max_size,) + tuple(shape), dtype=obs_dtype)
            for agent_id, shape in local_obs_shapes.items()
        }
        self.joint_actions = np.zeros((max_size, action_dim), dtype=action_dtype)

        self.rewards = np.zeros((max_size,), dtype=np.float32)
        self.terminateds = np.zeros((max_size, 1), dtype=np.float32)
        self.q_values = np.zeros((max_size, num_agents), dtype=np.float32)
        self.log_probs = np.zeros((max_size, num_agents), dtype=np.float32)

    def add(
        self,
        global_obs: np.ndarray,
        local_obs: Dict[str, np.ndarray],
        joint_actions: np.ndarray,
        reward: float,
        next_global_obs: np.ndarray,
        next_local_obs: Dict[str, np.ndarray],
        terminated: bool,
        q_values: np.ndarray,
        log_probs: np.ndarray,
    ):
        """Store one transition at the write pointer, overwriting the oldest entry when full.

        Assigning into the preallocated arrays already copies the data, so callers may
        reuse their input arrays afterwards.
        """
        i = self.ptr
        self.global_obs[i] = np.asarray(global_obs, dtype=self.obs_type)
        self.next_global_obs[i] = np.asarray(next_global_obs, dtype=self.obs_type)
        for agent_id, obs in local_obs.items():
            self.local_obs[agent_id][i] = np.asarray(obs, dtype=self.obs_type)
        for agent_id, obs in next_local_obs.items():
            self.next_local_obs[agent_id][i] = np.asarray(obs, dtype=self.obs_type)
        self.joint_actions[i] = np.asarray(joint_actions, dtype=self.action_type).reshape(-1)
        self.rewards[i] = np.asarray(reward, dtype=np.float32)
        self.terminateds[i] = np.asarray(terminated, dtype=np.float32)
        self.q_values[i] = np.asarray(q_values, dtype=np.float32)
        self.log_probs[i] = np.asarray(log_probs, dtype=np.float32)
        self.ptr = (i + 1) % self.max_size
        self.size = min(self.size + 1, self.max_size)

    @staticmethod
    def _agent_id_feature(agent_id) -> float:
        """Numeric agent id used as an extra observation feature (same encoding as ``concat_id``)."""
        return float(int(agent_id.split('_')[-1]))

    def sample(self, batch_size, replace=True, use_cer=False, to_tensor=False, add_id_to_local_obs=False, device=None):
        """Sample a batch of experiences from the buffer.

        Args:
            batch_size: Batch size
            replace: Whether to sample with replacement
            use_cer: Whether to use CER (the most recent transition is always in the batch)
            to_tensor: Whether to convert the data to float PyTorch tensors on ``device``
            add_id_to_local_obs: Whether to append each agent's numeric id (the integer
                after the last '_' in its id) as an extra last feature of its local
                observations; the observations are flattened per row when this is set
            device: Device to use when ``to_tensor`` is True

        Returns:
            An ``Experience`` tuple:
                global_obs: Global observations (batch_size, *global_obs_shape)
                local_obs: dict agent_id -> (batch_size, *local_obs_shape), or
                    (batch_size, prod(local_obs_shape) + 1) with ``add_id_to_local_obs``
                joint_actions: Actions of all agents (batch_size, action_dim)
                rewards: Rewards (batch_size,)
                next_global_obs: Next global observations (batch_size, *global_obs_shape)
                next_local_obs: same layout as ``local_obs``
                terminateds: Whether the episode is terminated or not (batch_size, 1)
                q_values: Stored per-agent Q-values (batch_size, num_agents)
                log_probs: Stored per-agent log-probabilities (batch_size, num_agents)

        Raises:
            ValueError: if the buffer is empty.
        """
        if self.size == 0:
            raise ValueError("cannot sample from an empty replay buffer")
        inds = np.random.choice(self.size, batch_size, replace=replace)
        if use_cer:
            inds[0] = self.ptr - 1  # always use last experience (-1 wraps to the end when ptr == 0)

        def convert(array):
            return torch.tensor(array).to(device) if to_tensor else array

        def gather_local_obs(local_obs_dict):
            # One entry per agent id, rows selected by ``inds``.
            batch = {}
            for agent_id, obs_array in local_obs_dict.items():
                selected = obs_array[inds]
                if add_id_to_local_obs:
                    flat = selected.reshape(len(inds), -1)
                    id_column = np.full((len(inds), 1), self._agent_id_feature(agent_id), dtype=flat.dtype)
                    selected = np.concatenate([flat, id_column], axis=1)
                if to_tensor:
                    batch[agent_id] = torch.tensor(selected, dtype=torch.float32).to(device)
                else:
                    batch[agent_id] = selected
            return batch

        return Experience(
            global_obs=convert(self.global_obs[inds]),
            local_obs=gather_local_obs(self.local_obs),
            joint_actions=convert(self.joint_actions[inds]),
            rewards=convert(self.rewards[inds]),
            next_global_obs=convert(self.next_global_obs[inds]),
            next_local_obs=gather_local_obs(self.next_local_obs),
            terminateds=convert(self.terminateds[inds]),
            q_values=convert(self.q_values[inds]),
            log_probs=convert(self.log_probs[inds]),
        )

    def __len__(self):
        """Get the size of the buffer."""
        return self.size

QMIX

In [14]:
class MixingNetwork(nn.Module):
    """QMIX mixing network: combines per-agent Q-values into a joint Q_tot.

    Hypernetworks conditioned on the global state generate the mixing weights and biases.
    The generated weights go through ``torch.abs`` so Q_tot is monotonically
    non-decreasing in every agent's Q-value — the QMIX constraint that keeps each
    agent's greedy choice consistent with the joint argmax.

    Args:
        num_agents: number of per-agent Q-values mixed (width of ``q_values``).
        state_dim: size of the flattened global state.
        mixing_hidden_dim: width of the mixing layer and the hypernetwork hidden layers.
    """

    def __init__(self, num_agents, state_dim, mixing_hidden_dim=32):
        super().__init__()
        self.num_agents = num_agents
        state_dim = int(state_dim)

        # Hypernetworks for the first mixing layer's weights and biases
        self.hyper_w1 = nn.Sequential(
            nn.Linear(state_dim, mixing_hidden_dim),
            nn.ReLU(),
            nn.Linear(mixing_hidden_dim, num_agents * mixing_hidden_dim)
        )
        self.hyper_b1 = nn.Linear(state_dim, mixing_hidden_dim)

        # Hypernetworks for the second (output) mixing layer
        self.hyper_w2 = nn.Sequential(
            nn.Linear(state_dim, mixing_hidden_dim),
            nn.ReLU(),
            nn.Linear(mixing_hidden_dim, mixing_hidden_dim)
        )
        self.hyper_b2 = nn.Sequential(
            nn.Linear(state_dim, mixing_hidden_dim),
            nn.ReLU(),
            nn.Linear(mixing_hidden_dim, 1)
        )

        # Monotone non-linearity between the two mixing layers
        self.non_linear = nn.ReLU()

    def forward(self, q_values, state):
        """Compute Q_tot.

        Args:
            q_values: (batch, num_agents) per-agent Q-values.
            state: (batch, state_dim) global state.

        Returns:
            (batch, 1) tensor of Q_tot.
        """
        batch_size = q_values.size(0)

        # First layer: non-negative weights (monotonicity), unconstrained bias
        w1 = torch.abs(self.hyper_w1(state)).view(batch_size, self.num_agents, -1)
        b1 = self.hyper_b1(state).view(batch_size, 1, -1)
        hidden = self.non_linear(torch.bmm(q_values.unsqueeze(1), w1) + b1)

        # Second layer: non-negative weights again, state-dependent scalar bias
        w2 = torch.abs(self.hyper_w2(state)).view(batch_size, -1, 1)
        b2 = self.hyper_b2(state).view(batch_size, 1, 1)

        q_tot = torch.bmm(hidden, w2) + b2
        return q_tot.squeeze(-1)


In [15]:
# Collects episode returns and plots them.
class Plotter:
    """Records per-episode returns for a trained and a random agent and plots them together."""

    def __init__(self):
        self.returns_trained = []
        self.episodes_trained = []
        self.returns_random = []
        self.episodes_random = []

    def add_return(self, episode, return_value, agent_type='trained'):
        """Record ``return_value`` for ``episode``; any ``agent_type`` other than 'trained'/'random' is ignored."""
        if agent_type == 'trained':
            episodes, returns = self.episodes_trained, self.returns_trained
        elif agent_type == 'random':
            episodes, returns = self.episodes_random, self.returns_random
        else:
            return
        episodes.append(episode)
        returns.append(return_value)

    def plot_returns(self):
        """Draw both return curves on a single figure and show it."""
        plt.figure(figsize=(12,6))
        curves = (
            (self.episodes_trained, self.returns_trained, 'Wytrenowany Agent', 'o'),
            (self.episodes_random, self.returns_random, 'Agent Losowy', 'x'),
        )
        for episodes, returns, label, marker in curves:
            plt.plot(episodes, returns, label=label, marker=marker)
        plt.xlabel('Epizod')
        plt.ylabel('Zwrot')
        plt.title('Porównanie Zwrotów: Wytrenowany Agent vs Agent Losowy')
        plt.legend()
        plt.grid(True)
        plt.show()

# Visualises how a trained (or random) agent behaves in the environment.
class AgentVisualizer:
    """Runs episodes with trained or random agents, records their returns and optionally renders frames."""

    def __init__(self, actor_path, device='cpu', env=None):
        """
        Set up the visualiser and load the trained actor weights.

        Args:
            actor_path (str | dict[str, str]): path to a saved ``Actor`` state dict. A single
                path loads one actor shared by every agent; a dict maps each agent id to its own
                checkpoint (matching the per-agent actors used during training).
            device (str, optional): 'cpu' or 'cuda'; falls back to 'cpu' when CUDA is unavailable.
            env (optional): PettingZoo parallel env to run in. Defaults to the notebook-level
                ``env``, which is created with ``render_mode=None``; pass an env created with
                ``render_mode="rgb_array"`` to capture frames.
        """
        self.device = torch.device(device if torch.cuda.is_available() and device == 'cuda' else 'cpu')
        # The parameter shadows the notebook global, hence the explicit globals() lookup.
        self.env = env if env is not None else globals()['env']

        if isinstance(actor_path, dict):
            paths = {agent_id: actor_path[agent_id] for agent_id in self.env.possible_agents}
        else:
            paths = {agent_id: actor_path for agent_id in self.env.possible_agents}
        loaded = {}
        self.actors = {}
        for agent_id, path in paths.items():
            if path not in loaded:  # agents sharing a checkpoint share one actor instance
                loaded[path] = self._load_actor(path, agent_id)
            self.actors[agent_id] = loaded[path]
        # Backward-compatible attribute: the first agent's actor (the shared one for a single path).
        self.actor = self.actors[self.env.possible_agents[0]]
        self.plotter = Plotter()

        # Report which render modes the environment supports
        available_render_modes = self.env.metadata.get('render_modes', None)
        print("Dostępne tryby renderowania:", available_render_modes)

    def _load_actor(self, path, agent_id):
        """Build an ``Actor`` sized for ``agent_id``'s spaces and load the weights stored at ``path``."""
        actor = Actor(
            observation_space_shape=self.env.observation_space(agent_id).shape,
            action_space_shape=self.env.action_space(agent_id).shape,
        )
        # NOTE: torch.load unpickles the file — only load checkpoints you trust.
        actor.load_state_dict(torch.load(path, map_location=self.device))
        actor.to(self.device)
        actor.eval()
        return actor

    def _select_actions(self, obs, agent_type):
        """Return one action per live agent, from the loaded actor(s) or sampled at random."""
        actions = {}
        for agent_id in self.env.agents:
            action_space = self.env.action_space(agent_id)
            if agent_type == 'random':
                actions[agent_id] = action_space.sample()
                continue
            # The actor was trained on observations with the agent id appended.
            obs_with_id = concat_id(obs[agent_id], agent_id)
            obs_tensor = torch.as_tensor(obs_with_id, dtype=torch.float32, device=self.device).unsqueeze(0)
            with torch.no_grad():
                action, _, _ = self.actors[agent_id].get_action(obs_tensor)
            # Clip in case the checkpoint's action scaling does not match the env's bounds.
            actions[agent_id] = np.clip(
                action.cpu().numpy().flatten(),
                np.asarray(action_space.low).reshape(-1),
                np.asarray(action_space.high).reshape(-1),
            )
        return actions

    def _capture_frame(self, frames):
        """Append the current rendered frame to ``frames`` when rendering succeeds."""
        try:
            frame = self.env.render()
        except TypeError as e:
            print(f"Nieudane renderowanie z mode='rgb_array': {e}")
            return
        if frame is None:
            print("Renderowanie zwróciło None.")
        else:
            frames.append(frame)

    def run_episodes(self, num_episodes=10, agent_type='trained', render=False, save_gif=False, gif_path='agent_demo.gif'):
        """
        Run a number of episodes with the trained or a random agent.

        Args:
            num_episodes (int, optional): number of episodes to run. Default 10.
            agent_type (str, optional): 'trained' or 'random'. Default 'trained'.
            render (bool, optional): whether to render episodes. Default False.
            save_gif (bool, optional): whether to save rendered frames as a GIF. Default False.
            gif_path (str, optional): where to write the GIF. Default 'agent_demo.gif'.

        Raises:
            ValueError: if ``agent_type`` is neither 'trained' nor 'random'.
        """
        if agent_type not in ('trained', 'random'):
            raise ValueError("Nieznany typ agenta. Użyj 'trained' lub 'random'.")

        frames = []
        for episode in range(1, num_episodes + 1):
            obs, info = self.env.reset(seed=42 + episode)
            global_return = 0.0
            done = False
            while not done:
                actions = self._select_actions(obs, agent_type)

                next_obs, rewards, terminations, truncations, infos = self.env.step(actions)
                done = any(terminations.values()) or any(truncations.values())

                if render:
                    self._capture_frame(frames)

                # Team return: sum of all agents' rewards
                global_return += sum(rewards.values())
                obs = next_obs

            self.plotter.add_return(episode, global_return, agent_type=agent_type)
            print(f"Epizod {episode} ({agent_type}): Zwrot = {global_return}")

        if save_gif and frames:
            try:
                imageio.mimsave(gif_path, frames, fps=10)
                print(f"Zapisano wideo jako '{gif_path}'")
            except Exception as e:
                print(f"Nie udało się zapisać GIF: {e}")

        if render and frames:
            # Show the first few frames as a preview
            num_frames_to_show = min(5, len(frames))
            for i in range(num_frames_to_show):
                plt.figure(figsize=(5,5))
                plt.imshow(frames[i])
                plt.axis('off')
                plt.title(f'Klatka {i+1}')
                plt.show()

    def plot_returns(self):
        """
        Plot the per-episode returns recorded for both agent types.
        """
        self.plotter.plot_returns()

In [None]:
run_name = f"{args.exp_name}__{args.seed}__{int(time.time())}"
if args.track:
    import wandb

    wandb.init(
        project=args.wandb_project_name,
        entity=args.wandb_entity,
        sync_tensorboard=True,
        config=vars(args),
        name=run_name,
        monitor_gym=False,
        save_code=True,
    )
writer = SummaryWriter(f"runs/{run_name}")
writer.add_text(
    "hyperparameters",
    "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
)

# TRY NOT TO MODIFY: seeding
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
torch.backends.cudnn.deterministic = args.torch_deterministic

device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu")
# device = torch.device("mps") if torch.backends.mps.is_available() else device


env.reset(seed=args.seed)
single_action_space = env.action_space(env.unwrapped.agents[0])
single_observation_space = env.observation_space(env.unwrapped.agents[0])
assert isinstance(single_action_space, gym.spaces.Box), "only continuous action space is supported"

max_action = float(single_action_space.high[0])
num_agents = len(env.possible_agents)

# Seed the action spaces as well, so the random warm-up actions are reproducible.
for agent_index, agent_id in enumerate(env.possible_agents):
    env.action_space(agent_id).seed(args.seed + agent_index)


def _set_action_bounds(actor, action_space):
    """Map ``actor``'s tanh output onto ``action_space``'s [low, high] range.

    ``Actor`` starts with placeholder scaling (scale 1, bias 0 -> actions in [-1, 1]),
    which does not match e.g. MPE's continuous Box(0, 1). Uniform bounds are written
    into the existing scalar buffers so saved state dicts keep their shape; per-dimension
    bounds replace the buffers with vectors. Unbounded spaces are left untouched.
    """
    low = np.asarray(action_space.low, dtype=np.float32).reshape(-1)
    high = np.asarray(action_space.high, dtype=np.float32).reshape(-1)
    if not (np.all(np.isfinite(low)) and np.all(np.isfinite(high))):
        return
    scale = (high - low) / 2.0
    bias = (high + low) / 2.0
    if np.all(scale == scale[0]) and np.all(bias == bias[0]):
        actor.action_scale.fill_(float(scale[0]))
        actor.action_bias.fill_(float(bias[0]))
    else:
        buffer_device = actor.action_scale.device
        actor.action_scale = torch.as_tensor(scale, device=buffer_device)
        actor.action_bias = torch.as_tensor(bias, device=buffer_device)


def _with_agent_id(local_obs_batch, agent_id):
    """Flatten a (batch, *obs_shape) tensor and append the agent-id column (batched ``concat_id``)."""
    flat = local_obs_batch.reshape(local_obs_batch.shape[0], -1).float()
    id_column = torch.full(
        (flat.shape[0], 1), float(int(agent_id.split("_")[-1])), dtype=flat.dtype, device=flat.device
    )
    return torch.cat([flat, id_column], dim=1)


def _joint_policy(local_obs_by_agent):
    """Sample every agent's action from its own actor.

    Args:
        local_obs_by_agent: dict agent_id -> (batch, *obs_shape) tensor (no agent id appended).

    Returns:
        joint_actions: (batch, total action dim), agents concatenated in ``env.possible_agents``
            order — the same order used for ``joint_actions`` stored in the replay buffer.
        joint_log_prob: (batch, 1) sum of the per-agent log-probabilities.
    """
    agent_actions, agent_log_probs = [], []
    for agent_id in env.possible_agents:
        act, log_prob, _ = actors[agent_id].get_action(_with_agent_id(local_obs_by_agent[agent_id], agent_id))
        agent_actions.append(act)
        agent_log_probs.append(log_prob)
    return torch.cat(agent_actions, dim=1), torch.cat(agent_log_probs, dim=1).sum(dim=1, keepdim=True)


# One actor per agent; each sees its local observation plus its agent id.
actors = {}
for agent_id in env.possible_agents:
    obs_shape = env.observation_space(agent_id).shape
    act_shape = env.action_space(agent_id).shape
    actors[agent_id] = Actor(observation_space_shape=obs_shape, action_space_shape=act_shape).to(device)
    _set_action_bounds(actors[agent_id], env.action_space(agent_id))

total_action_dim = sum([np.prod(env.action_space(agent).shape) for agent in env.possible_agents])
state_dim = np.prod(env.state().shape)
action_dim = total_action_dim

# Centralised twin critics over (global state, joint action)
qf1 = SoftQNetwork(state_dim=state_dim, action_dim=action_dim).to(device)
qf2 = SoftQNetwork(state_dim=state_dim, action_dim=action_dim).to(device)

qf1_target = SoftQNetwork(state_dim=state_dim, action_dim=action_dim).to(device)
qf2_target = SoftQNetwork(state_dim=state_dim, action_dim=action_dim).to(device)

mixing_network = MixingNetwork(
    num_agents=env.max_num_agents,
    state_dim=state_dim
).to(device)

# Target critics must start as exact copies of the online critics; otherwise the early
# bootstrapped targets come from unrelated, randomly initialised networks.
qf1_target.load_state_dict(qf1.state_dict())
qf2_target.load_state_dict(qf2.state_dict())
q_optimizer = optim.Adam(list(qf1.parameters()) + list(qf2.parameters()), lr=args.q_lr)
actor_optimizers = {
    agent_id: optim.Adam(actors[agent_id].parameters(), lr=args.policy_lr)
    for agent_id in env.possible_agents
}

mixing_optimizer = optim.Adam(mixing_network.parameters(), lr=args.q_lr)

# Automatic entropy tuning
if args.autotune:
    # NOTE(review): log_pi below is the joint (summed over agents) log-prob, while this
    # target uses a single agent's action dimension — confirm which one is intended.
    target_entropy = -torch.prod(torch.Tensor(single_action_space.shape).to(device)).item()
    log_alpha = torch.zeros(1, requires_grad=True, device=device)
    alpha = log_alpha.exp().item()
    a_optimizer = optim.Adam([log_alpha], lr=args.q_lr)
else:
    alpha = args.alpha

single_observation_space.dtype = np.float32

local_obs_shapes = {
    agent_id: env.observation_space(agent_id).shape
    for agent_id in env.possible_agents
}

for agent_id in env.possible_agents:
    print(f"Agent: {agent_id}, Action space shape: {env.action_space(agent_id).shape}")
print(f"Total action dimension: {action_dim}")


rb = MAReplayBuffer(
    global_obs_shape=env.state().shape,
    local_obs_shapes=local_obs_shapes,
    action_dim=total_action_dim,
    num_agents=num_agents,
    # Honour args.buffer_size, but never allocate more rows than there will be steps.
    max_size=int(min(args.buffer_size, args.total_timesteps + 1)),
)

start_time = time.time()

# TRY NOT TO MODIFY: start the game
obs, info = env.reset(seed=args.seed)
global_return = 0.0
global_obs: np.ndarray = env.state()
for global_step in range(args.total_timesteps + 1):
    # PettingZoo empties ``env.agents`` once every agent is done; start a new episode then,
    # otherwise ``obs`` is an empty dict and the per-agent lookups below fail.
    if not env.agents:
        writer.add_scalar("charts/episodic_return", global_return, global_step)
        obs, info = env.reset()
        global_obs = env.state()
        global_return = 0.0

    # ALGO LOGIC: put action logic here
    actions: Dict[str, np.ndarray]
    if global_step < args.learning_starts:
        actions = {agent: env.action_space(agent).sample() for agent in env.possible_agents}
    else:
        actions = {}
        with torch.no_grad():
            for agent_id in env.possible_agents:
                obs_tensor = torch.as_tensor(
                    concat_id(obs[agent_id], agent_id), dtype=torch.float32, device=device
                ).unsqueeze(0)
                act, _, _ = actors[agent_id].get_action(obs_tensor)
                actions[agent_id] = act.cpu().numpy().flatten()

    # TRY NOT TO MODIFY: execute the game and log data.
    next_obs, rewards, terminateds, truncateds, infos = env.step(actions)

    terminated: bool = any(terminateds.values())
    truncated: bool = any(truncateds.values())
    next_global_obs = env.state()

    # TRY NOT TO MODIFY: save data to replay buffer; handle `final_observation`
    real_next_obs = next_obs
    # TODO PZ doesn't have that yet
    # if truncated:
    #     real_next_obs = infos["final_observation"].copy()

    # Joint action in env.possible_agents order (the order the critics are trained with).
    joint_action = np.concatenate(
        [np.asarray(actions[agent_id], dtype=np.float32).reshape(-1) for agent_id in env.possible_agents]
    )

    # Per-agent extras stored in the buffer; computed without building an autograd graph.
    with torch.no_grad():
        global_obs_tensor = torch.as_tensor(global_obs, dtype=torch.float32, device=device).unsqueeze(0)
        joint_action_tensor = torch.as_tensor(joint_action, device=device).unsqueeze(0)
        # min(Q1, Q2) of the joint action; there is one joint critic, so every agent gets the same value.
        joint_q = torch.min(
            qf1(global_obs_tensor, joint_action_tensor),
            qf2(global_obs_tensor, joint_action_tensor),
        ).item()
        q_values = [joint_q] * num_agents
        # NOTE(review): log-prob of a freshly sampled action, not of the executed one.
        log_probs = []
        for agent_id in env.possible_agents:
            obs_with_id = torch.as_tensor(concat_id(obs[agent_id], agent_id), dtype=torch.float32, device=device)
            _, log_prob, _ = actors[agent_id].get_action(obs_with_id.unsqueeze(0))
            log_probs.append(log_prob.item())

    rb.add(
        global_obs=global_obs,
        local_obs=obs,
        joint_actions=joint_action,
        reward=float(sum(rewards.values())),
        next_global_obs=next_global_obs,
        next_local_obs=real_next_obs,
        terminated=terminated,
        q_values=np.array(q_values),
        log_probs=np.array(log_probs),
    )

    # TRY NOT TO MODIFY: CRUCIAL step easy to overlook
    obs = next_obs
    global_return += sum(rewards.values())
    global_obs = next_global_obs

    # ALGO LOGIC: training.
    if global_step > args.learning_starts:
        # Agent ids are appended by _with_agent_id, so ask the buffer for raw local observations.
        data: Experience = rb.sample(args.batch_size, to_tensor=True, device=device, add_id_to_local_obs=False)
        with torch.no_grad():
            next_joint_actions, next_state_log_pi = _joint_policy(data.next_local_obs)

            # SAC Bellman equation
            qf1_next_target = qf1_target(data.next_global_obs, next_joint_actions)
            qf2_next_target = qf2_target(data.next_global_obs, next_joint_actions)
            min_qf_next_target = torch.min(qf1_next_target, qf2_next_target) - alpha * next_state_log_pi
            next_q_value = data.rewards.flatten() + (1 - data.terminateds.flatten()) * args.gamma * (
                min_qf_next_target
            ).view(-1)

        # Computes q loss
        qf1_a_values = qf1(data.global_obs, data.joint_actions).view(-1)
        qf2_a_values = qf2(data.global_obs, data.joint_actions).view(-1)
        qf1_loss = F.mse_loss(qf1_a_values, next_q_value)
        qf2_loss = F.mse_loss(qf2_a_values, next_q_value)
        qf_loss = qf1_loss + qf2_loss

        q_optimizer.zero_grad()
        qf_loss.backward()
        q_optimizer.step()

        # Mixing network: regress Q_tot(stored per-agent Qs, state) onto the SAC target
        q_tot = mixing_network(q_values=data.q_values, state=data.global_obs)
        mixing_loss = F.mse_loss(q_tot, next_q_value.unsqueeze(1))

        mixing_optimizer.zero_grad()
        mixing_loss.backward()
        mixing_optimizer.step()

        if global_step % args.policy_frequency == 0:  # TD 3 Delayed update support
            for _ in range(
                args.policy_frequency
            ):  # compensate for the delay by doing 'actor_update_interval' instead of 1
                pi, log_pi = _joint_policy(data.local_obs)

                # SAC pi update
                qf1_pi = qf1(data.global_obs, pi)
                qf2_pi = qf2(data.global_obs, pi)
                min_qf_pi = torch.min(qf1_pi, qf2_pi).view(-1)
                actor_loss = ((alpha * log_pi.view(-1)) - min_qf_pi).mean()

                # One joint loss; every agent's actor is updated from it.
                for agent_optimizer in actor_optimizers.values():
                    agent_optimizer.zero_grad()
                actor_loss.backward()
                for agent_optimizer in actor_optimizers.values():
                    agent_optimizer.step()

                if args.autotune:
                    with torch.no_grad():
                        _, log_pi = _joint_policy(data.local_obs)
                    alpha_loss = (-log_alpha * (log_pi + target_entropy)).mean()

                    a_optimizer.zero_grad()
                    alpha_loss.backward()
                    a_optimizer.step()
                    alpha = log_alpha.exp().item()

        # update the target networks
        if global_step % args.target_network_frequency == 0:
            for param, target_param in zip(qf1.parameters(), qf1_target.parameters()):
                target_param.data.copy_(args.tau * param.data + (1 - args.tau) * target_param.data)
            for param, target_param in zip(qf2.parameters(), qf2_target.parameters()):
                target_param.data.copy_(args.tau * param.data + (1 - args.tau) * target_param.data)

        if global_step % 100 == 0:
            writer.add_scalar("losses/qf1_values", qf1_a_values.mean().item(), global_step)
            writer.add_scalar("losses/qf2_values", qf2_a_values.mean().item(), global_step)
            writer.add_scalar("losses/qf1_loss", qf1_loss.item(), global_step)
            writer.add_scalar("losses/qf2_loss", qf2_loss.item(), global_step)
            writer.add_scalar("losses/qf_loss", qf_loss.item() / 2.0, global_step)
            writer.add_scalar("losses/actor_loss", actor_loss.item(), global_step)
            writer.add_scalar("losses/alpha", alpha, global_step)
            writer.add_scalar("losses/mixing_loss", mixing_loss.item(), global_step)  # NEW: Log mixing loss
            writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step)
            if args.autotune:
                writer.add_scalar("losses/alpha_loss", alpha_loss.item(), global_step)

    if global_step % args.save_frequency == 0:
        for agent_id, actor in actors.items():
            torch.save(actor.state_dict(), f"{args.actor_path}/actor_{agent_id}_latest.pth")

        torch.save(qf1.state_dict(), f"{args.actor_path}/qf1_latest.pth")
        torch.save(qf2.state_dict(), f"{args.actor_path}/qf2_latest.pth")
        if args.autotune:
            torch.save(log_alpha, f"{args.actor_path}/log_alpha_latest.pth")
        print(f"Saved latest models at step {global_step}")

       
    if terminated or truncated:
        obs, info = env.reset()
        writer.add_scalar("charts/return", global_return, global_step)
        global_return = 0.0
        global_obs = env.state()

# Final checkpoint: persist the trained actor weights so the evaluation cell can
# reload them from `save_point`.
# `args.actor_path` is used as a *directory* by the periodic saves in the training
# loop (f"{args.actor_path}/..."), so join the path instead of concatenating it:
# plain concatenation wrote the file next to that directory whenever actor_path
# had no trailing separator.
os.makedirs(args.actor_path, exist_ok=True)
save_point = os.path.join(args.actor_path, f"{global_step}.pth")

try:
    # Saves the trained actor for execution
    torch.save(actors.state_dict(), save_point)
    print(f"Training finished. Model saved at {save_point}")
finally:
    # Always release the environment and flush TensorBoard logs, even if saving fails.
    env.close()
    writer.close()



In [None]:
# %%
# Agent visualizer setup: reload the checkpoint written at the end of training.
print(f"Loading model from: {save_point}")
render = True
# Alternative scenarios (swap in when evaluating a policy trained on another MPE task):
# env = simple_v3.parallel_env(render_mode="rgb_array" if render else None,  continuous_actions=True)
# env = simple_spread_v3.parallel_env(render_mode="rgb_array" if render else None, N=2, local_ratio=0.5, max_cycles=20, continuous_actions=True)
env = simple_adversary_v3.parallel_env(render_mode="rgb_array" if render else None, N=2, max_cycles=20, continuous_actions=True)
env.reset(seed=args.seed)  # Ensure the environment is reset
# Pick the device from config + hardware instead of hard-coding 'cuda', which
# crashes on CPU-only machines. A distinct name avoids clobbering any global `device`.
viz_device = "cuda" if args.cuda and torch.cuda.is_available() else "cpu"
visualizer = AgentVisualizer(actor_path=save_point, device=viz_device)

# %%
# Check which render modes the visualizer's environment supports
print("Available render modes:", visualizer.env.metadata.get('render_modes', 'No render modes available'))

num_episodes = 20
num_episodess = num_episodes  # alias for the original misspelled name, in case a later cell uses it

# %%
# Run episodes with the trained policy, then with a random baseline, saving one GIF each.
# GIFs go inside the actor_path directory, matching how checkpoints are stored there.
for agent_type, banner in (
    ("trained", "\n--- Uruchomienie epizodów z wytrenowanym agentem ---"),
    ("random", "\n--- Uruchomienie epizodów z agentem losowym ---"),
):
    print(banner)
    visualizer.run_episodes(
        num_episodes=num_episodes,
        agent_type=agent_type,
        render=render,
        save_gif=True,
        gif_path=os.path.join(args.actor_path, f"{global_step}_{agent_type.capitalize()}.gif"),
    )

# %%
# Plot the returns of the trained and random agents side by side
print("\n--- Wykres porównujący zwroty wytrenowanego agenta z agentem losowym ---")
visualizer.plot_returns()

