In [2]:
%pip install stable-baselines3 numpy torch supersuit pettingzoo pymunk scipy gymnasium matplotlib einops tensorboard wandb imageio 
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import random
import time
from collections import namedtuple, deque
from typing import Dict, Any

import gymnasium as gym
from pettingzoo.mpe import simple_reference_v3
from torch.utils.tensorboard import SummaryWriter

Note: you may need to restart the kernel to use updated packages.


In [5]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions import Normal

from pettingzoo.mpe import simple_spread_v3
# from pettingzoo.sisl import multiwalker_v9
# from pettingzoo.butterfly import pistonball_v6
# =====================================================================
# 1. PARAMETRY HIPER
# =====================================================================
TIMESTAMPS = 1000
SEED = 82
LEARNING_STARTS = TIMESTAMPS // 10
LR_ACTOR = 1e-3
LR_CRITIC = 1e-3
LR_MIXER = 1e-3
GAMMA = 0.99
BATCH_SIZE = 256
ALPHA = 0.2            # początkowe alpha (entropia w SAC) - może być uczone
TAU = 0.005            # do aktualizacji wag "soft update"
REPLAY_SIZE = 100000
MAX_EPISODES = 1000
MAX_STEPS = 50         # max liczba kroków na epizod
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# =====================================================================
# 2. SIEC AKTORA (SAC)
# =====================================================================
class Actor(nn.Module):
    def __init__(self, obs_dim, act_dim, hidden_dim=128):
        super(Actor, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(obs_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU()
        )
        # Mu i log_sigma dla ciągłej polityki
        self.mu_layer = nn.Linear(hidden_dim, act_dim)
        self.log_sigma_layer = nn.Linear(hidden_dim, act_dim)

    def forward(self, obs):
        x = self.net(obs)
        mu = self.mu_layer(x)
        log_sigma = self.log_sigma_layer(x)
        log_sigma = torch.clamp(log_sigma, min=-20, max=2)  # ograniczenie zakresu
        return mu, log_sigma

    def sample(self, obs):
        """Zwraca akcję i log_prob dla SAC"""
        mu, log_sigma = self.forward(obs)
        sigma = log_sigma.exp()
        dist = Normal(mu, sigma)
        z = dist.rsample()  # reparametrization trick
        action = torch.tanh(z)  # ograniczamy akcję do (-1,1)
        log_prob = dist.log_prob(z) - torch.log(1 - action.pow(2) + 1e-7)
        return action, log_prob.sum(dim=-1, keepdim=True)


# =====================================================================
# 3. SIECI KRYTYKA (SAC) - TWIN Q
# =====================================================================
class QCritic(nn.Module):
    def __init__(self, obs_dim, act_dim, hidden_dim=128):
        super(QCritic, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(obs_dim + act_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)
        )

    def forward(self, obs, act):
        obs = obs.view(obs.size(0), -1)  # Flatten the observations
        x = torch.cat([obs, act], dim=-1)
        return self.net(x)


# =====================================================================
# 4. QMIX - Mixing Network
# =====================================================================
class MixingNetwork(nn.Module):
    """
    Łączy lokalne Q_a (dla każdego agenta) w Q_tot w sposób monotoniczny.
    Tutaj używamy prostej sieci i hiper-sieci generującej wagi i bias.
    """
    def __init__(self, n_agents, state_dim, hidden_dim=32):
        super(MixingNetwork, self).__init__()
        self.n_agents = n_agents
        self.state_dim = state_dim
        self.hidden_dim = hidden_dim

        # Hiper-sieci generujące wagi i biasy
        self.hyper_w_1 = nn.Sequential(
            nn.Linear(state_dim, hidden_dim * n_agents),
            nn.ReLU(),
            nn.Linear(hidden_dim * n_agents, n_agents * hidden_dim)
        )
        self.hyper_b_1 = nn.Linear(state_dim, hidden_dim)

        self.hyper_w_2 = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )
        self.hyper_b_2 = nn.Linear(state_dim, 1)

    def forward(self, q_values, state):
        """
        q_values: (batch_size, n_agents)
        state: (batch_size, state_dim)
        """
        bs = q_values.size(0)

        # layer 1
        w1 = self.hyper_w_1(state)  # (bs, n_agents * hidden_dim)
        b1 = self.hyper_b_1(state)  # (bs, hidden_dim)
        w1 = w1.view(bs, self.n_agents, self.hidden_dim)

        # (bs, n_agents) -> (bs, n_agents, 1)
        q_values = q_values.unsqueeze(-1)

        # Mnożenie i dodawanie bias (monotoniczność zależy m.in. od konstrukcji w1 >= 0)
        # Aby zagwarantować monotoniczność, można np. wymusić ReLU na w1:
        w1 = torch.abs(w1)  # prosty hack, by zachować >= 0
        
        hidden = torch.bmm(q_values.transpose(1,2), w1).squeeze(1) + b1  # (bs, hidden_dim)

        # layer 2
        w2 = self.hyper_w_2(state)  # (bs, hidden_dim)
        w2 = torch.abs(w2)          # wymuszenie >= 0
        b2 = self.hyper_b_2(state)  # (bs, 1)

        # (bs, hidden_dim) * (bs, hidden_dim) -> (bs, 1)
        q_tot = torch.sum(hidden * w2, dim=1, keepdim=True) + b2
        return q_tot


# =====================================================================
# 5. BUFOR REPLAY
# =====================================================================
class MultiAgentReplayBuffer:
    def __init__(self, max_size=REPLAY_SIZE, n_agents=1):
        self.max_size = max_size
        self.n_agents = n_agents
        self.ptr = 0
        self.size = 0
        self.states = []
        self.obs = []
        self.actions = []
        self.rewards = []
        self.next_states = []
        self.next_obs = []
        self.dones = []

    def add(self, state, obs, action, reward, next_state, next_obs, done):
        # Konwersja danych do formatu NumPy
        state = np.array(state)
        obs = [np.array(o) for o in obs]
        action = np.array(action)
        reward = np.array(reward)
        next_state = np.array(next_state)
        next_obs = [np.array(o) for o in next_obs]
        done = np.array(done)

        if self.size < self.max_size:
            self.states.append(state)
            self.obs.append(obs)
            self.actions.append(action)
            self.rewards.append(reward)
            self.next_states.append(next_state)
            self.next_obs.append(next_obs)
            self.dones.append(done)
            self.size += 1
        else:
            idx = self.ptr
            self.states[idx] = state
            self.obs[idx] = obs
            self.actions[idx] = action
            self.rewards[idx] = reward
            self.next_states[idx] = next_state
            self.next_obs[idx] = next_obs
            self.dones[idx] = done
        self.ptr = (self.ptr + 1) % self.max_size

    def sample(self, batch_size=BATCH_SIZE):
        idxs = np.random.randint(0, self.size, size=batch_size)
        
        # Konwersja i spłaszczanie danych w momencie próbkowania
        batch_state = torch.FloatTensor(np.array([self.states[i] for i in idxs])).to(DEVICE)
        batch_obs = torch.FloatTensor(
            np.array([np.stack(self.obs[i]).reshape(self.n_agents, -1) for i in idxs])
        ).to(DEVICE)

        batch_action = torch.FloatTensor(np.array([self.actions[i] for i in idxs])).to(DEVICE)
        batch_reward = torch.FloatTensor(np.array([self.rewards[i] for i in idxs])).to(DEVICE)
        batch_next_state = torch.FloatTensor(np.array([self.next_states[i] for i in idxs])).to(DEVICE)
        batch_next_obs = torch.FloatTensor(np.array([np.concatenate(self.next_obs[i]) for i in idxs])).to(DEVICE)
        batch_next_obs = torch.FloatTensor(
            np.array([np.stack(self.next_obs[i]).reshape(self.n_agents, -1) for i in idxs])
        ).to(DEVICE)
        batch_done = torch.FloatTensor(np.array([self.dones[i] for i in idxs])).to(DEVICE)
        
        return batch_state, batch_obs, batch_action, batch_reward, batch_next_state, batch_next_obs, batch_done



# =====================================================================
# 6. FUNKCJE POMOCNICZE (soft update)
# =====================================================================
def soft_update(net, target_net, tau=TAU):
    for param, target_param in zip(net.parameters(), target_net.parameters()):
        target_param.data.copy_(
            tau * param.data + (1 - tau) * target_param.data
        )


# =====================================================================
# 7. TRENER
# =====================================================================
class QMIX_SAC_Agents:
    def __init__(self, env):
        self.env = env
        self.agents = env.agents
        self.n_agents = len(self.agents)
        
        # Zakładamy, że obs_dim i act_dim takie same dla wszystkich agentów
        # example_obs, _ = env.reset()[self.agents[0]]
        obs_dict, _ = env.reset(seed=SEED)
        example_obs = obs_dict[self.agents[0]]
        print(f"Original observation shape: {example_obs.shape}")
        self.obs_dim = np.prod(example_obs.shape)  # Flattened observation dimension
        print(f"Calculated obs_dim: {self.obs_dim}")

        # obs_dim = np.prod(self.env.observation_space[self.agents[0]].shape)
        # self.obs_dim = len(example_obs)
        # self.obs_dim = np.prod(self.env.observation_space(self.agents[0]).shape)  # Flattened observation
        self.act_dim = self.env.action_space(self.agents[0]).shape[0]            # Action dimension

        # Dla ciągłych akcji: act_dim (np. 2D w MPE?)
        # Dla dyskretnych - trzeba zmodyfikować!
        # self.act_dim = env.action_space[self.agents[0]].shape[0]
        # self.act_dim = 1  # np. 1, 2, 5, itp., w zależności od środowiska
        
        # Tworzymy Actor i TwinQCritic dla każdego agenta
        self.actors = []
        self.critic1 = []
        self.critic2 = []
        self.target_critic1 = []
        self.target_critic2 = []

        self.actor_opt = []
        self.critic1_opt = []
        self.critic2_opt = []

        for _ in range(self.n_agents):
            actor = Actor(self.obs_dim, self.act_dim).to(DEVICE)
            critic1 = QCritic(self.obs_dim, self.act_dim).to(DEVICE)
            critic2 = QCritic(self.obs_dim, self.act_dim).to(DEVICE)
            t_critic1 = QCritic(self.obs_dim, self.act_dim).to(DEVICE)
            t_critic2 = QCritic(self.obs_dim, self.act_dim).to(DEVICE)
            t_critic1.load_state_dict(critic1.state_dict())
            t_critic2.load_state_dict(critic2.state_dict())

            self.actors.append(actor)
            self.critic1.append(critic1)
            self.critic2.append(critic2)
            self.target_critic1.append(t_critic1)
            self.target_critic2.append(t_critic2)

            self.actor_opt.append(optim.Adam(actor.parameters(), lr=LR_ACTOR))
            self.critic1_opt.append(optim.Adam(critic1.parameters(), lr=LR_CRITIC))
            self.critic2_opt.append(optim.Adam(critic2.parameters(), lr=LR_CRITIC))

        # Mixing Network
        # Globalny stan ma dimension: np. sumaryczna obserwacja (zależy od środowiska)
        # W "simple_spread" np. stan globalny to złożenie pozycji agentów, landmarków itp.
        # Tu dla uproszczenia załóżmy, że bierzemy po prostu concatenation lokalnych obserwacji (n_agents * obs_dim).
        self.global_state_dim = self.n_agents * self.obs_dim
        self.mixer = MixingNetwork(self.n_agents, self.global_state_dim).to(DEVICE)
        self.mixer_opt = optim.Adam(self.mixer.parameters(), lr=LR_MIXER)

        # Stała entropii w SAC (może być uczona, tu uproszczenie)
        self.alpha = ALPHA

        # self.replay_buffer = MultiAgentReplayBuffer(REPLAY_SIZE)
        self.replay_buffer = MultiAgentReplayBuffer(max_size=REPLAY_SIZE, n_agents=self.n_agents)
        self.target_mixer = MixingNetwork(self.n_agents, self.global_state_dim).to(DEVICE)
        self.target_mixer.load_state_dict(self.mixer.state_dict())
    

    def select_actions(self, obs_n, evaluate=False):
        """
        obs_n: lista/tuple obserwacji (po jednej dla każdego agenta),
               lub dict {agent_name: obs}
        Zwraca listę akcji (torch niekoniecznie, bo tu step w env).
        """
        actions = []
        for i, obs in enumerate(obs_n):
            obs_t = torch.FloatTensor(obs.flatten()).unsqueeze(0).to(DEVICE)
            # obs_t = torch.FloatTensor(obs).unsqueeze(0).to(DEVICE)
            with torch.no_grad():
                if not evaluate:
                    action, _ = self.actors[i].sample(obs_t)
                else:
                    # deterministycznie (mu zamiast sample)
                    mu, _ = self.actors[i](obs_t)
                    # action = torch.sigmoid(mu) # dla akcji w zakresie (0,1)
                    action = torch.tanh(mu)  # dla akcji w zakresie (-1,1)
                    
            # action = action.clip(-1, 1).cpu().numpy().flatten()
            # actions.append(action)
            actions.append(action.cpu().numpy().flatten())
        return actions

    def update(self):
        # Sprawdzamy czy mamy wystarczająco próbek
        if self.replay_buffer.size < BATCH_SIZE:
            return

        (batch_state,  # shape: (B, n_agents * obs_dim)
         batch_obs,    # shape: (B, n_agents, obs_dim)
         batch_action, # shape: (B, n_agents, act_dim)
         batch_reward, # shape: (B, n_agents)
         batch_next_state, 
         batch_next_obs,
         batch_done) = self.replay_buffer.sample(BATCH_SIZE)
        # --- DODAJEMY PRINTY DO DIAGNOZY ---
        # print("batch_state shape:", batch_state.shape)       # powinno być (B, n_agents*obs_dim)
        # print("batch_obs shape:", batch_obs.shape)           # (B, n_agents, obs_dim)
        # print("batch_action shape:", batch_action.shape)     # (B, n_agents, act_dim)
        # print("batch_reward shape:", batch_reward.shape)     # (B, n_agents)
        # print("batch_done shape:", batch_done.shape)         # (B, n_agents)

        # =========================
        # 1. Oblicz lokalne Q_i
        # =========================
        # batch_obs: (B, n_agents, obs_dim)
        # musimy pętlować agentów i skleić
        q_values1 = []
        q_values2 = []
        next_q_values1 = []
        next_q_values2 = []

        all_next_actions = []
        all_next_log_probs = []

        for i in range(self.n_agents):
            obs_i = batch_obs[:, i, :].reshape(BATCH_SIZE, -1)  # Spłaszczenie obserwacji
            obs_i = batch_obs[:, i, :].view(BATCH_SIZE, -1)  
            act_i = batch_action[:, i, :].view(BATCH_SIZE, -1)  # Rozszerzenie wymiaru akcji

            # obs_i = batch_obs[:, i, :]
            # act_i = batch_action[:, i, :]
            # Obliczamy Q_i(obs_i, act_i)
            q1_i = self.critic1[i](obs_i, act_i)
            q2_i = self.critic2[i](obs_i, act_i)
            q_values1.append(q1_i)
            q_values2.append(q2_i)

            # Dla targetów
            with torch.no_grad():
                next_obs_i = batch_next_obs[:, i, :]
                next_action_i, next_log_prob_i = self.actors[i].sample(next_obs_i)
                nq1_i = self.target_critic1[i](next_obs_i, next_action_i)
                nq2_i = self.target_critic2[i](next_obs_i, next_action_i)
                next_q_values1.append(nq1_i)
                next_q_values2.append(nq2_i)
                all_next_actions.append(next_action_i)
                all_next_log_probs.append(next_log_prob_i)

        # stack i uzyskujemy shape (B, n_agents)
        q_values1 = torch.cat(q_values1, dim=1)  # (B, n_agents)
        q_values2 = torch.cat(q_values2, dim=1)  # (B, n_agents)

        next_q_values1 = torch.cat(next_q_values1, dim=1)  # (B, n_agents)
        next_q_values2 = torch.cat(next_q_values2, dim=1)  # (B, n_agents)

        # =========================
        # 2. QMIX: Q_tot i target Q_tot
        # =========================
        with torch.no_grad():
            # Mix values from both critics using target mixer
            next_Q_tot1 = self.target_mixer(next_q_values1, batch_next_state)
            next_Q_tot2 = self.target_mixer(next_q_values2, batch_next_state)
            next_Q_tot = torch.min(next_Q_tot1, next_Q_tot2)
            
            # Compute entropy term for all agents
            next_log_probs = torch.cat(all_next_log_probs, dim=1)
            entropy_term = self.alpha * torch.sum(next_log_probs, dim=1, keepdim=True)
            
            # Apply entropy after mixing
            next_Q_tot_sac = next_Q_tot - entropy_term
            
            # Compute target with rewards and discounting
            sum_rewards = torch.sum(batch_reward, dim=1, keepdim=True)
            done_mask = torch.mean(batch_done, dim=1, keepdim=True)
            target_Q_tot = sum_rewards + GAMMA * (1 - done_mask) * next_Q_tot_sac

        # Obliczamy Q_tot z obecnych Q_values
        # min_q_current = torch.min(q_values1, q_values2)
        # Q_tot_current = self.mixer(min_q_current, batch_state)  # (B,1)


        Q_tot1 = self.mixer(q_values1, batch_state)
        Q_tot2 = self.mixer(q_values2, batch_state)
        Q_tot_current = torch.min(Q_tot1, Q_tot2)

        # =================================================================
        # 2.a. Loss mixer (krok uczący QMIX i krytyków)
        # =================================================================
        # MSE pomiędzy Q_tot a targetem
        td_error = (Q_tot_current - target_Q_tot.detach())
        mixer_loss = (td_error ** 2).mean()

        # Optymalizacja mixer i krytyków (łączymy, by propagować gradient)
        self.mixer_opt.zero_grad()
        for i in range(self.n_agents):
            self.critic1_opt[i].zero_grad()
            self.critic2_opt[i].zero_grad()

        mixer_loss.backward()
        self.mixer_opt.step()
        for i in range(self.n_agents):
            self.critic1_opt[i].step()
            self.critic2_opt[i].step()

        # =================================================================
        # 2.b. Loss dla aktorów (SAC-style, ale z Q_tot w pętli)
        # =================================================================
        # Polityka każdego agenta stara się maksymalizować Q_tot
        # Trzeba ponownie wygenerować akcje i obliczyć Q_tot
        actor_losses = []
        for i in range(self.n_agents):
            obs_i = batch_obs[:, i, :]
            action_i, log_prob_i = self.actors[i].sample(obs_i)
            # Oblicz Q_i dla akcji próbnych
            q1_i = self.critic1[i](obs_i, action_i)
            q2_i = self.critic2[i](obs_i, action_i)
            min_q_i = torch.min(q1_i, q2_i)

            # Musimy złożyć min_q_i dla wszystkich agentów (zastąpić tylko i-tego,
            # a pozostałe zostawić tak jak w batch_action).
            # Sposób uproszczony:
            all_Q = []
            for j in range(self.n_agents):
                if j == i:
                    all_Q.append(min_q_i)
                else:
                    # Q_j(obs_j, action_j z batch_action)
                    obs_j = batch_obs[:, j, :]
                    act_j = batch_action[:, j, :]
                    q1_j = self.critic1[j](obs_j, act_j)
                    q2_j = self.critic2[j](obs_j, act_j)
                    min_q_j = torch.min(q1_j, q2_j)
                    all_Q.append(min_q_j)
            # (B, n_agents)
            all_Q = torch.cat(all_Q, dim=1)
            # Entropia: log_prob_i -> shape (B,1), 
            # Możemy też sumować log_prob innych agentów, w zależności od projektu.
            # Dla uproszczenia – w SAC standardowo odejmujemy alpha * log_prob.
            # W wieloagentowym można by użyć również Q_mix -> zobaczyć jak log_prob i
            # wpływa na Q_tot. Wymaga to nieco zagnieżdżonych obliczeń.
            # Tutaj demonstrujemy prosty wariant: 
            # Q'_tot = mixer(all_Q - alpha * log_prob_i, state)

            # Tworzymy tensor all_Q_sac, gdzie i-ta kolumna = min_q_i - alpha * log_prob_i,
            # reszta kolumn = min_q_j (bez zmiany).
            alpha_term = torch.zeros_like(all_Q)
            alpha_term[:, i] = self.alpha * log_prob_i.squeeze(-1)
            all_Q_sac = all_Q - alpha_term

            Q_tot_actor = self.mixer(all_Q_sac, batch_state)
            # Maksymalizujemy Q_tot, czyli minimalizujemy -Q_tot
            actor_loss = (-Q_tot_actor).mean()

            self.actor_opt[i].zero_grad()
            actor_loss.backward()
            self.actor_opt[i].step()

            actor_losses.append(actor_loss.item())

        # =================================================================
        # 2.c. Soft update targetów
        # =================================================================
        for i in range(self.n_agents):
            soft_update(self.critic1[i], self.target_critic1[i], TAU)
            soft_update(self.critic2[i], self.target_critic2[i], TAU)
            soft_update(self.mixer, self.target_mixer, TAU)


    def train(self, total_episodes=MAX_EPISODES, max_steps=MAX_STEPS):
        ep_rewards = []
        episode = 0
        for ep in range(total_episodes):
            obs_dict, info_dict = env.reset(seed=SEED)
            done_dict = {agent: False for agent in self.agents}
            step = 0
            ep_reward = np.zeros(self.n_agents, dtype=np.float32)

            while not all(done_dict.values()) and step < max_steps:
                # Przygotowanie list (obs, actions)
                current_obs = []
                for agent_i in self.agents:
                    current_obs.append(obs_dict[agent_i])

                # Wybór akcji
                actions = self.select_actions(current_obs, evaluate=False if episode < LEARNING_STARTS else True)

                # Stworzenie słownika akcji do przekazania do env
                action_dict = {}
                for i, agent_i in enumerate(self.agents):
                    action_dict[agent_i] = actions[i]

                next_obs_dict, reward_dict, done_dict, _, _= self.env.step(action_dict)
                # observations, rewards, terminations, truncations, infos

                # Zapis do bufora
                # global_state = prosta konkatenacja, w realnym projekcie: pełen stan
                # global_state = np.concatenate(current_obs, axis=0)  # (n_agents*obs_dim,)
                global_state = np.concatenate(
                    [obs_dict[a].flatten() for a in self.agents], axis=0
                )
                next_global_state = np.concatenate(
                    [next_obs_dict[a].flatten() for a in self.agents], axis=0
                )

                # Konwertujemy do list
                reward_list = np.array([reward_dict[a] for a in self.agents])
                done_list = np.array([done_dict[a] for a in self.agents], dtype=np.float32)

                self.replay_buffer.add(
                    global_state,
                    current_obs,
                    actions,
                    reward_list,
                    next_global_state,
                    [next_obs_dict[a] for a in self.agents],
                    done_list
                )

                # Sumaryczna nagroda w epizodzie
                ep_reward += reward_list

                # update obserwacji
                obs_dict = next_obs_dict
                step += 1

                # Krok treningowy (off-policy)
                self.update()
            
            # Koniec epizodu
            # elavuate
            episode += 1

            ep_rewards.append(np.sum(ep_reward))
            if (ep+1) % 10 == 0:
                print(f"Epizod: {ep+1}, reward: {np.mean(ep_rewards[-10:])}")

        return ep_rewards

# =====================================================================
# 8. URUCHOMIENIE
# =====================================================================
if __name__ == "__main__":
    # Dzięki temu środowisko ma 3 agentów, każdy z dwuwymiarowymi akcjami ciągłymi w przedziale [−1,1]
    env = simple_spread_v3.parallel_env(render_mode=None, N=2, local_ratio=0.5, max_cycles=40, continuous_actions=True)
    env.reset(seed=SEED)
    trainer = QMIX_SAC_Agents(env)
    rewards = trainer.train(total_episodes=TIMESTAMPS, max_steps=MAX_STEPS)


Original observation shape: (12,)
Calculated obs_dim: 12


IndexError: tuple index out of range

In [None]:
# =====================================================================
# 9. TESTOWANIE WYUCZONEGO MODELU
# =====================================================================
def evaluate_trainer(trainer, env, episodes=10, max_steps=MAX_STEPS, render=False):
    """
    Funkcja do testowania wytrenowanego modelu (trainer) w środowisku wieloagentowym.
    
    Arguments:
    trainer -- obiekt wytrenowanego modelu (QMIX_SAC_Agents)
    env -- środowisko wieloagentowe
    episodes -- liczba epizodów do przetestowania
    max_steps -- maksymalna liczba kroków w każdym epizodzie
    render -- czy wyświetlać środowisko (jeśli obsługiwane)

    Returns:
    ep_rewards -- lista sumarycznych nagród dla każdego epizodu
    """
    ep_rewards = []

    for ep in range(episodes):
        obs_dict, info_dict = trainer.env.reset(seed=SEED)
        print(f"Epizod {ep+1}: trainer.env.agents = {trainer.env.agents}")
        if not trainer.env.agents:
            raise ValueError("Lista agentów `trainer.env.agents` jest pusta po resecie środowiska.")
        
        done_dict = {agent: False for agent in trainer.env.agents}
        step = 0

        # Inicjalizacja epizodycznej nagrody
        num_agents = len(trainer.env.agents)
        ep_reward = np.zeros(num_agents, dtype=np.float32)

        while not all(done_dict.values()) and step < max_steps:
            if render:
                trainer.env.render()

            # Sprawdzamy, czy są jeszcze agenci w środowisku
            if len(trainer.env.agents) == 0:
                print("Brak agentów w środowisku. Kończymy epizod.")
                break

            # Pobranie aktualnych obserwacji
            current_obs = [obs_dict[agent] for agent in trainer.env.agents]

            # Wybór akcji przez wytrenowanego trenera
            actions = trainer.select_actions(current_obs, evaluate=True)

            # Stworzenie słownika akcji do przekazania dotrainer.env 
            action_dict = {agent: actions[i] for i, agent in enumerate(trainer.env.agents)}

            # Krok środowiska
            next_obs_dict, reward_dict, done_dict, _, _ = trainer.env.step(action_dict)

            # Debugowanie: sprawdź `reward_dict` i `trainer.env.agents`
            print(f"reward_dict keys: {list(reward_dict.keys())}, trainer.env.agents: {trainer.env.agents}")

            # Zapis nagród dla tego epizodu
            try:
                ep_reward[:len(trainer.env.agents)] += np.array([reward_dict[agent] for agent in trainer.env.agents])
            except KeyError as e:
                raise ValueError(f"Klucz {e} nie istnieje w `reward_dict`. Sprawdź `trainer.env.agents` i `reward_dict`.") from e

            # Aktualizacja obserwacji
            obs_dict = next_obs_dict
            step += 1

        # Zapis sumarycznej nagrody dla epizodu
        ep_rewards.append(np.sum(ep_reward))
        print(f"Epizod testowy: {ep+1}, reward: {np.sum(ep_reward)}")

    if render:
        trainer.env.close()

    print(f"Średnia nagroda w testach: {np.mean(ep_rewards):.2f}")
    return ep_rewards



# =====================================================================
# 10. URUCHOMIENIE TESTÓW
# =====================================================================
print("Testowanie agenta...")
print(trainer.env)
print(env.agents)
test_rewards = evaluate_trainer(trainer, env, episodes=5, render=True)
