In [103]:
%pip install stable-baselines3 numpy torch supersuit pettingzoo pymunk scipy gymnasium matplotlib einops tensorboard wandb imageio 
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import random
import time
from collections import namedtuple, deque
from typing import Dict, Any

import gymnasium as gym
from pettingzoo.mpe import simple_reference_v3
from torch.utils.tensorboard import SummaryWriter

Note: you may need to restart the kernel to use updated packages.


In [104]:
# --------------------------------------------------------
# PARAMETRY
# --------------------------------------------------------
timestamps = 7000
class Args:
    seed = 42
    total_timesteps = timestamps
    learning_starts = timestamps // 10
    batch_size = 128
    gamma = 0.99
    tau = 0.005
    q_lr = 1e-3
    policy_lr = 1e-3
    alpha = 0.2
    autotune = False        # Możesz ustawić True, jeśli chcesz stroić alpha
    target_entropy = -3.0   # docelowa entropia (w autotune)
    buffer_size = 1000000
    update_every = timestamps // 1000       # co ile kroków robimy update
    policy_frequency = 2    # (co ile update'ów krytyka aktualizujemy aktora)
    max_cycles = 25         # max_cycles dla środowiska
    continuous_actions = True  
    # Reszta:
    run_name = "SAC_QMIX_EXAMPLE"


args = Args()

In [105]:
# --------------------------------------------------------
# PRZYGOTOWANIE ŚRODOWISKA
# --------------------------------------------------------
def make_env():
    env = simple_reference_v3.parallel_env(
        # N=2, 
        max_cycles=args.max_cycles,
        continuous_actions=args.continuous_actions,
        render_mode="human"

    )
    env.reset(seed=args.seed)
    
    return env


In [106]:
# --------------------------------------------------------
# REPLAY BUFFER
# --------------------------------------------------------
Experience = namedtuple(
    "Experience",
    [
        "obs",       # Dict[agent_id -> np.ndarray]  (lokalne obserwacje)
        "actions",   # Dict[agent_id -> np.ndarray]
        "rewards",   # Dict[agent_id -> float]
        "next_obs",  # Dict[agent_id -> np.ndarray]
        "terminations", # bool
        "truncations",  # bool
        "global_state",      # np.ndarray (globalny stan)
        "next_global_state", # np.ndarray
        "log_probs"  # Dict[agent_id -> float], użyte do obliczeń pi
    ]
)


class MAReplayBuffer:
    """
    Wieloagentowy Replay Buffer. Przechowuje krotki (obs, actions, rewards, next_obs, done, global_state).
    """
    def __init__(self, max_size=100000):
        self.memory = deque(maxlen=max_size)
    
    def add(self, experience: Experience):
        self.memory.append(experience)
    
    def __len__(self):
        return len(self.memory)
    
    def sample(self, batch_size):
        """
        Zwraca listę 'Experience' o długości batch_size (z losowym samplingiem).
        """
        indices = np.random.choice(len(self.memory), size=batch_size, replace=False)
        batch = [self.memory[i] for i in indices]
        return batch


In [107]:
# --------------------------------------------------------
# POMOCNICZE FUNKCJE
# --------------------------------------------------------
def agent_id_to_float(agent_id: str) -> float:
    """
    Zamienia np. 'adversary_0' -> 0, 'agent_0' -> 1, 'agent_1' -> 2, itp.
    Zależy od tego, jak ponazywałeś agentów. Ustal tutaj swój mapping.
    """
    # Prosty heurystyczny mapping:
    # 'adversary_0' => 0
    # 'agent_0' => 1
    # 'agent_1' => 2
    # itp. W razie potrzeby zrób słownik z definicją.
    if "adversary" in agent_id:
        # zakładamy, że mamy tylko 'adversary_0'
        return 0.0
    elif "agent_0" in agent_id:
        return 1.0
    elif "agent_1" in agent_id:
        return 2.0
    else:
        return 99.0  # fallback


In [108]:
# --------------------------------------------------------
# MODELE: AKTOR i KRTYTK (PARAMETER SHARING)
# --------------------------------------------------------
class SharedActor(nn.Module):
    """
    Wspólny aktor dla wszystkich agentów. 
    Dla agent_id wklejamy numeric ID do wejścia sieci.
    """
    def __init__(self, obs_dim: int, action_dim: int, hidden_dim=256):
        super().__init__()
        # Zakładamy, że do obs dolejemy 1 float z ID
        self.fc1 = nn.Linear(obs_dim + 1, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.mean_head = nn.Linear(hidden_dim, action_dim)
        self.logstd_head = nn.Linear(hidden_dim, action_dim)
        # Zakres logstd:
        self.logstd_min = -10
        self.logstd_max = 2
    
    def forward(self, obs_with_id: torch.Tensor) -> (torch.Tensor, torch.Tensor):
        """
        Zwraca (mean, log_std) akcji.
        obs_with_id ma shape (batch, obs_dim + 1).
        """
        x = F.relu(self.fc1(obs_with_id))
        x = F.relu(self.fc2(x))
        mean = self.mean_head(x)
        log_std = self.logstd_head(x)
        # Clamp log std
        log_std = torch.clamp(log_std, self.logstd_min, self.logstd_max)
        return mean, log_std
    
    def get_action(self, obs_with_id: torch.Tensor) -> (torch.Tensor, torch.Tensor):
        """
        Zwraca (action, log_prob). Używa reparametrization trick.
        Action jest w zakresie (-1, 1), bo przez tanh.
        """
        mean, log_std = self(obs_with_id)
        std = log_std.exp()

        # Reparam trick
        dist = torch.distributions.Normal(mean, std)
        z = dist.rsample()  # (batch, action_dim)
        action = torch.tanh(z)
        
        # log_prob
        log_prob = dist.log_prob(z).sum(dim=-1, keepdim=True)
        # Odejmujemy log(1 - tanh(z)^2) -> reguła wyprowadzenia
        log_prob -= torch.log(1 - action.pow(2) + 1e-7).sum(dim=-1, keepdim=True)
        
        return action, log_prob


class SharedCritic(nn.Module):
    """
    Wspólny lokalny krytyk Q, zwraca Q_a(o_a, a_a) dla agent_id.
    W standardowym SAC mamy 2 takie sieci (qf1, qf2).
    """
    def __init__(self, obs_dim: int, action_dim: int, hidden_dim=256):
        super().__init__()
        # Wejście: obs_dim + 1 (agent_id) + action_dim
        self.fc1 = nn.Linear(obs_dim + 1 + action_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.q_head = nn.Linear(hidden_dim, 1)
    
    def forward(self, obs_with_id: torch.Tensor, action: torch.Tensor) -> torch.Tensor:
        """
        obs_with_id: (batch, obs_dim+1)
        action: (batch, action_dim)
        """
        x = torch.cat([obs_with_id, action], dim=-1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        q = self.q_head(x)
        return q  # shape (batch, 1)


In [109]:
# --------------------------------------------------------
# QMIX - MIXING NETWORK
# --------------------------------------------------------
class MixingNetwork(nn.Module):
    """
    Łączy [Q1, Q2, ..., Qn] w jedną wartość Q_tot(s, a_1..a_n).
    Zależy od global_state (np. shape = state_dim).
    
    W uproszczonej wersji poniżej jest standardowa architektura z artykułu QMIX:
    - hyper_w1, hyper_b1, hyper_w2, hyper_b2
    - monotonic mixing (relu)
    """
    def __init__(self, num_agents: int, state_dim: int, mixing_hidden_dim=64):
        super().__init__()
        self.num_agents = num_agents
        self.mixing_hidden_dim = mixing_hidden_dim

        # Pierwsza warstwa
        self.hyper_w1 = nn.Sequential(
            nn.Linear(state_dim, mixing_hidden_dim),
            nn.ReLU(),
            nn.Linear(mixing_hidden_dim, num_agents * mixing_hidden_dim)
        )
        self.hyper_b1 = nn.Linear(state_dim, mixing_hidden_dim)
        
        # Druga warstwa
        self.hyper_w2 = nn.Sequential(
            nn.Linear(state_dim, mixing_hidden_dim),
            nn.ReLU(),
            nn.Linear(mixing_hidden_dim, mixing_hidden_dim)
        )
        self.hyper_b2 = nn.Sequential(
            nn.Linear(state_dim, mixing_hidden_dim),
            nn.ReLU(),
            nn.Linear(mixing_hidden_dim, 1)
        )

        self.relu = nn.ReLU()

    def forward(self, q_local: torch.Tensor, global_state: torch.Tensor) -> torch.Tensor:
        """
        q_local ma shape (batch_size, num_agents) - to [Q1, Q2, ..., Qn].
        global_state ma shape (batch_size, state_dim).

        Zwraca shape (batch_size, 1) = Q_tot.
        """
        bs = q_local.shape[0]

        # W1 i b1
        w1 = self.hyper_w1(global_state)  # (bs, num_agents*mixing_hidden_dim)
        w1 = w1.view(bs, self.num_agents, self.mixing_hidden_dim)
        b1 = self.hyper_b1(global_state).view(bs, 1, self.mixing_hidden_dim)

        # (bs, 1, num_agents) x (bs, num_agents, mixing_hidden_dim) -> (bs, 1, mixing_hidden_dim)
        q_local = q_local.unsqueeze(1)  # (bs, 1, num_agents)
        hidden = torch.bmm(q_local, w1) + b1  # (bs, 1, mixing_hidden_dim)
        hidden = self.relu(hidden)

        # W2 i b2
        w2 = self.hyper_w2(global_state)  # (bs, mixing_hidden_dim * 1) = (bs, mixing_hidden_dim)
        w2 = w2.view(bs, self.mixing_hidden_dim, 1)
        b2 = self.hyper_b2(global_state).view(bs, 1, 1)

        # final
        q_tot = torch.bmm(hidden, w2) + b2  # (bs, 1, 1)
        return q_tot.view(bs, 1)



In [None]:
# --------------------------------------------------------
# GŁÓWNA FUNKCJA 
# --------------------------------------------------------
writer = SummaryWriter(f"runs/{args.run_name}_{int(time.time())}")

env = make_env()
agent_ids = env.possible_agents
num_agents = len(agent_ids)

single_action_space = env.action_space(agent_ids[0])
single_obs_space = env.observation_space(agent_ids[0])

obs_dim = single_obs_space.shape[0]

if isinstance(single_action_space, gym.spaces.Discrete):
    action_dim = single_action_space.n  
    is_discrete = True
else:
    action_dim = single_action_space.shape[0]
    is_discrete = False

global_state_shape = env.state().shape
state_dim = global_state_shape[0]

shared_actor = SharedActor(obs_dim=obs_dim, action_dim=action_dim).to("cpu")
shared_critic1 = SharedCritic(obs_dim=obs_dim, action_dim=action_dim).to("cpu")
shared_critic2 = SharedCritic(obs_dim=obs_dim, action_dim=action_dim).to("cpu")
mixing_network = MixingNetwork(num_agents=num_agents, state_dim=state_dim).to("cpu")

target_critic1 = SharedCritic(obs_dim=obs_dim, action_dim=action_dim).to("cpu")
target_critic2 = SharedCritic(obs_dim=obs_dim, action_dim=action_dim).to("cpu")
target_critic1.load_state_dict(shared_critic1.state_dict())
target_critic2.load_state_dict(shared_critic2.state_dict())

q_params = list(shared_critic1.parameters()) + list(shared_critic2.parameters()) + list(mixing_network.parameters())
critic_optimizer = optim.Adam(q_params, lr=args.q_lr)
actor_optimizer = optim.Adam(shared_actor.parameters(), lr=args.policy_lr)

if args.autotune:
    log_alpha = torch.zeros(1, requires_grad=True)
    alpha_optimizer = optim.Adam([log_alpha], lr=args.q_lr)
    alpha = log_alpha.exp().item()
else:
    alpha = args.alpha

rb = MAReplayBuffer(max_size=args.buffer_size)

torch.manual_seed(args.seed)
np.random.seed(args.seed)
random.seed(args.seed)

obs_dict, info = env.reset(seed=args.seed)

start_time = time.time()
global_step = 0

global_state = env.state()
done = False
episode_return = 0
episode_length = 0


# -------------------------------------------
# PĘTLA GŁÓWNA
# -------------------------------------------
while global_step < args.total_timesteps:
    # Zbieramy akcje
    actions = {}
    log_probs = {}
    with torch.no_grad():
        for aid in agent_ids:
            # Lokalna obserwacja
            obs_a = obs_dict[aid]  # shape (obs_dim,)
            # Sklejamy z ID
            agent_id_float = agent_id_to_float(aid)
            obs_with_id = np.concatenate([obs_a, [agent_id_float]], axis=0)
            obs_with_id_t = torch.tensor(obs_with_id, dtype=torch.float32).unsqueeze(0)

            if global_step < args.learning_starts:
                # Exploracja
                if is_discrete:
                    # Wybieramy losowo z [0..action_dim-1]
                    act = np.random.randint(action_dim)
                    log_p = 0.0
                    actions[aid] = act
                    log_probs[aid] = log_p
                else:
                    # continuous
                    act = np.random.uniform(low=-1, high=1, size=(action_dim,))
                    log_p = 0.0
                    actions[aid] = act
                    log_probs[aid] = log_p
            else:
                # normalnie z aktora
                act_t, log_p_t = shared_actor.get_action(obs_with_id_t)
                act = act_t.squeeze(0).cpu().numpy()
                lp = log_p_t.item()
                if is_discrete:
                    # konwersja z (-1,1) do dyskretnego 0..action_dim-1
                    # np. bierzemy argmax po logitach = interpretacja? 
                    # (Tutaj w 100% SAC dla dyskretnych wymaga innej obsługi, np. Gumbel-Softmax.)
                    # Dla uproszczenia: 
                    # weźmy argmax (co jest dość "hackowe" jak na SAC).
                    # lepiej wziąć softmax i wylosować. 
                    # Ale to już zależy od ciebie.
                    # Róbmy 'argmax':
                    # scale do 0..1
                    act_in_01 = (act + 1)/2
                    # weźmy argmax
                    a_discrete = np.argmax(act_in_01)
                    actions[aid] = a_discrete
                    log_probs[aid] = float(lp) # w sumie to i tak troche "niepoprawne"
                else:
                    actions[aid] = act
                    log_probs[aid] = float(lp)
    
    # Krok środowiska
    next_obs, rews, terms, truncs, infos = env.step(actions)

    global_step += 1
    episode_length += 1
    episode_return += sum(rews.values())

    # Zapis do buffer
    exp = Experience(
        obs=obs_dict,
        actions=actions,
        rewards=rews,
        next_obs=next_obs,
        terminations=terms,
        truncations=truncs,
        global_state=global_state,
        next_global_state=env.state(),
        log_probs=log_probs
    )
    rb.add(exp)

    # Przejście do next
    obs_dict = next_obs
    global_state = env.state()

    # Czy koniec epizodu
    done = any(terms.values()) or any(truncs.values())
    if done:
        # log do TensorBoard
        writer.add_scalar("charts/episode_return", episode_return, global_step)
        writer.add_scalar("charts/episode_length", episode_length, global_step)

        # reset
        obs_dict, info = env.reset(seed=args.seed)
        global_state = env.state()
        episode_return = 0
        episode_length = 0

    # -------------------------------------------
    # Trening
    # -------------------------------------------
    if (global_step > args.learning_starts) and (global_step % args.update_every == 0):
        for _ in range(args.update_every):
            if len(rb) < args.batch_size:
                break
            batch = rb.sample(args.batch_size)
            # batch to lista Experience

            # Przerabiamy na tensory
            # 1) global_state
            state_batch = []
            next_state_batch = []
            rew_batch = []
            done_batch = []
            
            # W QMIX potrzeba local_q_vals (per agent).
            # Ale obliczymy to wewnątrz pętli agentów.

            # Zbudujmy mini-listy dict[agent_id -> obs], dict[agent_id -> actions]
            # bo trzeba spiąć w tensory.
            obs_batch_dict = {aid: [] for aid in agent_ids}
            act_batch_dict = {aid: [] for aid in agent_ids}
            next_obs_batch_dict = {aid: [] for aid in agent_ids}
            logp_batch_dict = {aid: [] for aid in agent_ids}

            for ex in batch:
                state_batch.append(ex.global_state)
                next_state_batch.append(ex.next_global_state)
                # sum rewards? Albo w QMIX można sumować
                # Tutaj w stylu "centralnego" rewardu dajmy sum(rews)
                rew_batch.append( sum(ex.rewards.values()) )

                # w MPE done = any(terminations), tu uprośćmy:
                done_flag = any(ex.terminations.values()) or any(ex.truncations.values())
                done_batch.append(float(done_flag))

                for aid in agent_ids:
                    obs_a = ex.obs[aid]
                    agent_id_float = agent_id_to_float(aid)
                    obs_with_id = np.concatenate([obs_a, [agent_id_float]], axis=0)
                    obs_batch_dict[aid].append(obs_with_id)

                    if is_discrete:
                        # action = int
                        act_val = ex.actions[aid]
                        # przerób na one-hot lub coś w tym stylu:
                        a_oh = np.zeros(action_dim, dtype=np.float32)
                        a_oh[act_val] = 1.0
                        act_batch_dict[aid].append(a_oh)
                    else:
                        act_batch_dict[aid].append(ex.actions[aid])

                    # next obs
                    next_obs_a = ex.next_obs[aid]
                    obs_with_id_next = np.concatenate([next_obs_a, [agent_id_float]], axis=0)
                    next_obs_batch_dict[aid].append(obs_with_id_next)

                    # logp
                    logp_batch_dict[aid].append(ex.log_probs[aid])

            # Konwersja do tensora
            state_t = torch.tensor(state_batch, dtype=torch.float32)            # (batch, state_dim)
            next_state_t = torch.tensor(next_state_batch, dtype=torch.float32)  # (batch, state_dim)
            rew_t = torch.tensor(rew_batch, dtype=torch.float32).unsqueeze(-1)  # (batch, 1)
            done_t = torch.tensor(done_batch, dtype=torch.float32).unsqueeze(-1)

            # obs_t_dict[aid] shape => (batch, obs_dim+1)
            obs_t_dict = {}
            act_t_dict = {}
            next_obs_t_dict = {}
            logp_t_dict = {}
            for aid in agent_ids:
                obs_t_dict[aid] = torch.tensor(obs_batch_dict[aid], dtype=torch.float32)
                act_t_dict[aid] = torch.tensor(act_batch_dict[aid], dtype=torch.float32)
                next_obs_t_dict[aid] = torch.tensor(next_obs_batch_dict[aid], dtype=torch.float32)
                logp_t_dict[aid] = torch.tensor(logp_batch_dict[aid], dtype=torch.float32).unsqueeze(-1)

            # ---- KRYTYK / QMIX update ----
            # 1) obliczamy Q_a^1, Q_a^2 => (batch, num_agents)
            q1_locals = []
            q2_locals = []
            with torch.no_grad():
                # Najpierw policz next actions i logp do targetu
                next_actions = []
                next_logps = []
                for aid in agent_ids:
                    a_mean, a_logp = shared_actor.get_action(next_obs_t_dict[aid])
                    next_actions.append(a_mean)   # shape (batch, action_dim)
                    next_logps.append(a_logp)     # shape (batch, 1)
                
                # Składamy sumę logp
                # shape (batch, num_agents, 1) => sum -> (batch, 1)
                sum_next_logp = torch.stack(next_logps, dim=1).sum(dim=1)  # (batch, 1)

                # Obliczamy Q'_a( next_obs, next_action )
                # i potem zrobimy mixing
                q1_next_locals = []
                q2_next_locals = []
                for i, aid in enumerate(agent_ids):
                    q1_val = target_critic1(next_obs_t_dict[aid], next_actions[i])
                    q2_val = target_critic2(next_obs_t_dict[aid], next_actions[i])
                    # q1_next_locals.append(q1_val.squeeze(-1))
                    # q2_next_locals.append(q2_val.squeeze(-1))

                    # weź min => local next Q
                    # W oryginalnym QMIX mamy jedną sieć, ale tu mamy 2 do soft update
                    # Mimo wszystko trzymajmy się min do stylu SAC:
                    q_min = torch.min(q1_val, q2_val)
                    q1_next_locals.append(q_min.squeeze(-1))
                    q2_next_locals.append(q_min.squeeze(-1))

                # Tak naprawdę w standardowym 2-krytykowym QMIX
                # należałoby mieć 2 mixing networks. Tu upraszczamy.
                # Bierzemy q_min i liczymy JEDEN mix.
                q1_next_locals_stack = torch.stack(q1_next_locals, dim=1)  # (batch, num_agents)
                # mixing
                q_tot_next = mixing_network(q1_next_locals_stack, next_state_t) # (batch,1)
                # W stylu SAC:
                q_target = rew_t + (1 - done_t) * args.gamma * (q_tot_next - alpha * sum_next_logp)
            
            # Bez gradientów => 'q_target' to nasz y
            # 2) Obliczamy Q_a(obs, actions) => mixing => MSE do q_target
            current_q1_locals = []
            current_q2_locals = []
            for aid in agent_ids:
                q1_val = shared_critic1(obs_t_dict[aid], act_t_dict[aid])  # (batch, 1)
                q2_val = shared_critic2(obs_t_dict[aid], act_t_dict[aid])  # (batch, 1)
                current_q1_locals.append(q1_val.squeeze(-1))
                current_q2_locals.append(q2_val.squeeze(-1))

            q1_locals_stack = torch.stack(current_q1_locals, dim=1)  # (batch, num_agents)
            q2_locals_stack = torch.stack(current_q2_locals, dim=1)  # (batch, num_agents)

            q_tot_1 = mixing_network(q1_locals_stack, state_t)  # (batch,1)
            q_tot_2 = mixing_network(q2_locals_stack, state_t)  # (batch,1)

            critic_loss_1 = F.mse_loss(q_tot_1, q_target)
            critic_loss_2 = F.mse_loss(q_tot_2, q_target)
            critic_loss = critic_loss_1 + critic_loss_2

            critic_optimizer.zero_grad()
            critic_loss.backward()
            critic_optimizer.step()

            # ---- ACTOR update ----
            # Co 'policy_frequency' kroków
            if _ % args.policy_frequency == 0:
                # Re-obliczamy local Q i mix
                # 1) pobierz pi(a|obs)
                # new_actions = []
                # logps = []
                # for aid in agent_ids:
                #     act_t, logp_t = shared_actor.get_action(obs_t_dict[aid])
                #     new_actions.append(act_t)
                #     logps.append(logp_t)

                # sum_logp = torch.stack(logps, dim=1).sum(dim=1)  # (batch,1)
                new_actions = []
                new_logps = []
                for aid in agent_ids:
                    sampled_action, logp_a = shared_actor.get_action(obs_t_dict[aid])
                    new_actions.append(sampled_action)
                    new_logps.append(logp_a)
            
                sum_logp = torch.stack(new_logps, dim=1).sum(dim=1)  # (batch,1)

                # 2) Q1 i Q2 => mixing => min => Q_tot
                # new_q1_locals = []
                # new_q2_locals = []
                # for i, aid in enumerate(agent_ids):
                #     q1_val = shared_critic1(obs_t_dict[aid], new_actions[i])
                #     q2_val = shared_critic2(obs_t_dict[aid], new_actions[i])
                #     q_min = torch.min(q1_val, q2_val)
                #     new_q1_locals.append(q_min.squeeze(-1))
                #     new_q2_locals.append(q_min.squeeze(-1))

                # q1_locals_stack = torch.stack(new_q1_locals, dim=1)
                # q2_locals_stack = torch.stack(new_q2_locals, dim=1)
                # q_tot_1_pi = mixing_network(q1_locals_stack, state_t)
                # q_tot_2_pi = mixing_network(q2_locals_stack, state_t)
                # q_tot_pi = torch.min(q_tot_1_pi, q_tot_2_pi)

                # # Strata actor => maximize Q_tot - alpha * sum_logp
                # # => minimize -(Q_tot - alpha * sum_logp)
                # actor_loss = (alpha * sum_logp - q_tot_pi).mean()

                # actor_optimizer.zero_grad()
                # actor_loss.backward()
                # actor_optimizer.step()
                # Q = min(Q1,Q2) => mixing => actor_loss
                new_q_locals = []
                for i, aid in enumerate(agent_ids):
                    q1_val = shared_critic1(obs_t_dict[aid], new_actions[i])
                    q2_val = shared_critic2(obs_t_dict[aid], new_actions[i])
                    q_min  = torch.min(q1_val, q2_val).squeeze(-1)
                    new_q_locals.append(q_min)
                
                new_q_locals_stack = torch.stack(new_q_locals, dim=1)
                q_tot_pi = mixing_network(new_q_locals_stack, state_t)

                actor_loss = (alpha * sum_logp - q_tot_pi).mean()

                actor_optimizer.zero_grad()
                actor_loss.backward()
                actor_optimizer.step()
                # Autotune alpha (opcjonalnie)
                if args.autotune:
                    # standardowo: alpha_loss = -(log_alpha * (logp + target_entropy))
                    # tu w multi-agent: sum logp. 
                    # Bierzmy average:
                    alpha_loss = -(log_alpha * (sum_logp + args.target_entropy).detach()).mean()
                    alpha_optimizer.zero_grad()
                    alpha_loss.backward()
                    alpha_optimizer.step()
                    alpha = log_alpha.exp().item()

            # ---- Soft update targetów ----
            with torch.no_grad():
                for param, tparam in zip(shared_critic1.parameters(), target_critic1.parameters()):
                    tparam.data.copy_(args.tau * param.data + (1 - args.tau)*tparam.data)
                for param, tparam in zip(shared_critic2.parameters(), target_critic2.parameters()):
                    tparam.data.copy_(args.tau * param.data + (1 - args.tau)*tparam.data)

        writer.add_scalar("losses/critic_loss", critic_loss.item(), global_step)
        writer.add_scalar("charts/SPS", global_step / (time.time() - start_time), global_step)
        if args.autotune:
            writer.add_scalar("charts/alpha", alpha, global_step)

env.close()
writer.close()
print("Trening zakończony.")


 -0.88383278  0.73235229  0.20223002  0.41614516 -0.95883101  0.9398197
  0.66488528 -0.57532178 -0.63635007] that was outside action space Box(0.0, 1.0, (15,), float32). Environment is clipping to space
 -0.72101228 -0.4157107  -0.26727631 -0.08786003  0.57035192 -0.60065244
  0.02846888  0.18482914 -0.90709917] that was outside action space Box(0.0, 1.0, (15,), float32). Environment is clipping to space
 -0.39077246 -0.80465577  0.36846605 -0.11969501 -0.75592353 -0.00964618
 -0.93122296  0.8186408  -0.48244004] that was outside action space Box(0.0, 1.0, (15,), float32). Environment is clipping to space
  0.55026565  0.87899788  0.7896547   0.19579996  0.84374847 -0.823015
 -0.60803428 -0.90954542 -0.34933934] that was outside action space Box(0.0, 1.0, (15,), float32). Environment is clipping to space
 -0.71815155  0.60439396 -0.85089871  0.97377387  0.54448954 -0.60256864
 -0.98895577  0.63092286  0.41371469] that was outside action space Box(0.0, 1.0, (15,), float32). Environment

: 