In [19]:
%pip install stable-baselines3 numpy torch supersuit pettingzoo pymunk scipy gymnasium matplotlib einops tensorboard wandb imageio 

Note: you may need to restart the kernel to use updated packages.


In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import random
from collections import deque
from typing import List, Dict, NamedTuple
import gymnasium as gym
from pettingzoo.mpe import simple_spread_v3

from torch.utils.tensorboard import SummaryWriter
from argparse import Namespace
import time

# =====================================================================
# 1. Parsing argumentów (Namespace) - przykładowa konfiguracja
# =====================================================================
def get_args():
    args = Namespace()
    args.seed = 1
    # args.total_episodes = 1000000
    args.total_timesteps = 1000000
    args.max_cycles = 75 
    args.n_agents = 2          # N=3 w simple_spread
    args.buffer_size = 5000
    args.batch_size = 32
    args.gamma = 0.99
    args.tau = 0.01
    args.alpha = 0.2
    args.target_entropy = -3.0
    args.lr_actor = 5e-4 
    args.lr_critic = 3e-4
    args.lr_mixer = 3e-4
    args.lr_alpha = 1e-5       
    args.l2 = 1e-5             
    args.learning_starts = 400
    args.train_freq = 1        # co ile kroków/epizodów robić update
    args.eval_freq = 50        # co ile epizodów robić ewaluację
    args.num_eval_episodes = 4 # ile epizodów w ewaluacji
    args.exp_name = "SAC_QMIX_spread_discrete"
    args.epsilon_start = 1.0
    args.epsilon_end = 0.05
    args.epsilon_decay_steps = 50000
    args.target_update_interval = 200
    return args

# =====================================================================
# 2. Definicja Współdzielonej Sieci Aktora
# =====================================================================
class SharedActor(nn.Module):
    def __init__(self, state_dim, agent_id_dim, act_dim, hidden_dim=64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim + agent_id_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, act_dim)
        )
    
    def forward(self, state_with_id):
        logits = self.net(state_with_id)
        return logits
    
    def sample(self, state_with_id):
        logits = self.forward(state_with_id)
        dist = torch.distributions.Categorical(logits=logits)
        action = dist.sample()
        log_prob = dist.log_prob(action).unsqueeze(1)
        entropy = dist.entropy().unsqueeze(1)
        return action, log_prob, entropy


# =====================================================================
# 3. Definicja Współdzielonej Sieci Krytyka
# =====================================================================
class SharedCritic(nn.Module):
    def __init__(self, state_dim, act_dim, hidden_dim=64, gru_hidden_dim=64):
        super().__init__()
        self.input_dim = state_dim + act_dim
        self.gru_hidden_dim = gru_hidden_dim
        self.mlp = nn.Sequential(
            nn.Linear(self.input_dim, hidden_dim),
            nn.ReLU(),
        )
        self.gru = nn.GRU(hidden_dim, gru_hidden_dim, batch_first=True)
        self.output = nn.Sequential(
            nn.Linear(gru_hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)
        )
        
    def init_hidden(self, batch_size, device):
        """Inicjalizuje stan ukryty"""
        return torch.zeros(1, batch_size, self.gru_hidden_dim, device=device)
    
    def forward(self, state, action, hidden_state=None):
        """
        state: (B, state_dim)
        action: (B, act_dim)
        hidden_state: (1, B, gru_hidden_dim) or None
        Returns:
            q_value: (B, 1)
            next_hidden: (1, B, gru_hidden_dim)
        """
        x = torch.cat([state, action], dim=-1)
        x = self.mlp(x)
        x = x.unsqueeze(1)  # (B, 1, hidden_dim)
        
        if hidden_state is None:
            batch_size = x.size(0)
            hidden_state = torch.zeros(1, batch_size, self.gru_hidden_dim, device=x.device)
        
        hidden_state = hidden_state.detach()
        
        output, new_hidden = self.gru(x, hidden_state)
        q_value = self.output(output.squeeze(1))
        return q_value, new_hidden


# =====================================================================
# 3. Lokalny Krytyk Q dla agenta i (DRQCritic -> critic1[i], critic2[i])
# =====================================================================
class DRQCritic(nn.Module):
    """
    Lokalny krytyk Q z rekurencyjną warstwą GRU.
    Input: local_obs + last_action
    Output: scalar Q^i
    """
    def __init__(self, obs_dim, act_dim, hidden_dim=256, gru_hidden_dim=128):
        super().__init__()
        self.input_dim = obs_dim + act_dim
        self.gru_hidden_dim = gru_hidden_dim
        self.mlp = nn.Sequential(
            nn.Linear(self.input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
        )
        self.gru = nn.GRU(hidden_dim, gru_hidden_dim, batch_first=True)
        self.output = nn.Sequential(
            nn.Linear(gru_hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)
        )
        
    def init_hidden(self, batch_size, device):
        """Inicjalizuje stan ukryty"""
        return torch.zeros(1, batch_size, self.gru_hidden_dim, device=device)
    
    def forward(self, obs, action, hidden_state=None):
        """
        obs: (B, obs_dim)
        action: (B, act_dim)
        hidden_state: (1, B, gru_hidden_dim) or None
        Returns:
            q_value: (B, 1)
            next_hidden: (1, B, gru_hidden_dim)
        """
        x = torch.cat([obs, action], dim=-1)
        x = self.mlp(x)
        x = x.unsqueeze(1)  # (B, 1, hidden_dim)
        
        # Inicjalizacja stanu ukrytego jeśli nie został podany
        if hidden_state is None:
            batch_size = x.size(0)
            hidden_state = torch.zeros(1, batch_size, self.gru_hidden_dim, 
                                    device=x.device)
        
        # Tworzenie kopii stanu ukrytego
        hidden_state = hidden_state.detach()
        
        output, new_hidden = self.gru(x, hidden_state)
        q_value = self.output(output.squeeze(1))
        return q_value, new_hidden


# =====================================================================
# 3. Krytyk Q dla agenta i (Twin Q -> critic1[i], critic2[i])
# =====================================================================
class QCritic(nn.Module):
    """
    Input: cat(global_state, joint_action)
    Output: scalar Q^i
    """
    def __init__(self, input_dim, hidden_dim=256):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)
        )

    def forward(self, x):
        return self.net(x)


# =====================================================================
# 4. Definicja Sieci Mieszającej (QMIXMixingNetwork)
# =====================================================================
class QMIXMixingNetwork(nn.Module):
    def __init__(self, n_agents, state_dim, mixing_hidden_dim=8):
        super().__init__()
        self.n_agents = n_agents
        self.state_dim = state_dim

        # Hypernetwork for weights
        self.hyper_w = nn.Sequential(
            nn.Linear(state_dim, mixing_hidden_dim),
            nn.ReLU(),
            nn.Linear(mixing_hidden_dim, n_agents * 1),  # Output: n_agents * 1
        )
        
        # Hypernetwork for bias
        self.hyper_b = nn.Sequential(
            nn.Linear(state_dim, 1),
            nn.ReLU()
        )

    def forward(self, q_values, state):
        """
        q_values: (B, N)
        state: (B, state_dim)
        Zwraca: (B, 1)
        """
        bs = q_values.size(0)

        # Generate weights and bias
        w = self.hyper_w(state).view(bs, self.n_agents, 1)  # (B, N, 1)
        b = self.hyper_b(state).view(bs, 1, 1)            # (B, 1, 1)

        # Expand q_values and compute Q_tot
        q_values = q_values.unsqueeze(-1)  # (B, N, 1)
        q_tot = torch.sum(q_values * w, dim=1) + b.squeeze(1)  # (B, 1)

        q_tot = F.elu(q_tot)  # Zastosowanie ELU

        return q_tot

# =====================================================================
# 5. Funkcja pomocnicza do konwersji akcji na one-hot
# =====================================================================
def actions_to_one_hot(actions, act_dim_each):
    """
    actions: tensor o rozmiarze (B, N)
    zwraca tensor o rozmiarze (B, N * act_dim_each)
    """
    one_hot = F.one_hot(actions, num_classes=act_dim_each).float()
    return one_hot.view(actions.size(0), -1)  # (B, N * act_dim_each)
# =====================================================================
# 5.1. Funkcja do generowania One-Hot Encoding ID Agenta
# =====================================================================
def get_agent_id(agent_index, n_agents):
    agent_id = np.zeros(n_agents)
    agent_id[agent_index] = 1
    return agent_id

# =====================================================================
# 5.2. Funkcja pomocnicza do konwersji akcji na one-hot
# =====================================================================
def actions_to_one_hot(actions, act_dim_each):
    """
    actions: tensor o rozmiarze (B, N)
    zwraca tensor o rozmiarze (B, N * act_dim_each)
    """
    one_hot = F.one_hot(actions, num_classes=act_dim_each).float()
    return one_hot.view(actions.size(0), -1)  # (B, N * act_dim_each)
def get_state_with_id(state, agent_id, batch_size=None):
    """
    Properly reshapes and concatenates state with agent_id for both single and batch inputs.
    
    Args:
        state: numpy array of shape (state_dim,) or (batch_size, state_dim)
        agent_id: numpy array of shape (n_agents,)
        batch_size: int or None
    
    Returns:
        Combined state and agent_id tensor with proper dimensions
    """
    if batch_size is None:
        # Single state case
        state = np.array(state).reshape(1, -1)  # (1, state_dim)
        agent_id = agent_id.reshape(1, -1)      # (1, n_agents)
    else:
        # Batch case
        state = np.array(state)                 # (batch_size, state_dim)
        agent_id = np.tile(agent_id, (batch_size, 1))  # (batch_size, n_agents)
    
    return np.concatenate([state, agent_id], axis=1)  # (batch_size, state_dim + n_agents)

# =====================================================================
# 6. Replay buffer
# =====================================================================
class Transition(NamedTuple):
    state: np.ndarray
    obs: List[np.ndarray]  # local obs for each agent
    action: List[int]       # Zmienione na List[int] dla akcji dyskretnych
    reward: float
    next_state: np.ndarray
    next_obs: List[np.ndarray]
    done: bool

class ReplayBuffer:
    def __init__(self, max_size=100000):
        self.buffer = deque(maxlen=max_size)

    def add(self, *args):
        self.buffer.append(Transition(*args))

    def __len__(self):
        return len(self.buffer)

    def sample(self, batch_size):
        if len(self.buffer) < batch_size:
            return None
        batch = random.sample(self.buffer, batch_size)
        return list(zip(*batch))
        # Zwracamy listy/tuple T= (state, obs, action, reward, next_state, next_obs, done)

# =====================================================================
# 6. Główna klasa/wydzielona pętla treningowa - zmodyfikowana
# =====================================================================
def run_qmix_sac(args):
    torch.autograd.set_detect_anomaly(True)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
    # 6.1 Przygotowanie środowiska
    env = simple_spread_v3.parallel_env(
        N=args.n_agents,
        local_ratio=0.2,
        max_cycles=args.max_cycles,
        continuous_actions=False
    )
    env.reset(seed=args.seed)
    
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    # 6.2 Ustalenie wymiarów
    agents = env.possible_agents  # np. [agent_0, agent_1]
    
    n_agents = len(agents)
    assert n_agents == args.n_agents, "Nie zgadza się liczba agentów!"
    act_dim_each = env.action_space(agents[0]).n  # Liczba akcji dyskretnych
    state_dim = np.prod(env.state().shape)       # Globalny stan
    obs_dim_each = np.prod(env.observation_space(agents[0]).shape)

    print(env.action_space(agents[0])) 
    print(f"State dim: {state_dim}, Obs dim each: {obs_dim_each}, Act dim each: {act_dim_each}, Total act dim: {act_dim_each * n_agents}")

    # 6.3 Inicjalizacja współdzielonych aktora i krytyka
    agent_id_dim = args.n_agents  # One-hot encoding ID agenta

    shared_actor = SharedActor(state_dim, agent_id_dim, act_dim_each).to(device)
    shared_critic = SharedCritic(state_dim, act_dim_each).to(device)

    target_critic = SharedCritic(state_dim, act_dim_each).to(device)
    target_critic.load_state_dict(shared_critic.state_dict())

    mixer = QMIXMixingNetwork(n_agents, state_dim).to(device)
    target_mixer = QMIXMixingNetwork(n_agents, state_dim).to(device)
    target_mixer.load_state_dict(mixer.state_dict())

    # 6.4 Optymalizatory
    actor_opt = optim.RMSprop(shared_actor.parameters(), lr=args.lr_actor, weight_decay=args.l2)
    critic_opt = optim.RMSprop(shared_critic.parameters(), lr=args.lr_critic, weight_decay=args.l2)
    mixer_opt = optim.RMSprop(mixer.parameters(), lr=args.lr_mixer, weight_decay=args.l2)
    
    # Dynamiczne alpha
    log_alpha = nn.Parameter(torch.log(torch.tensor(0.2, device=device)), requires_grad=True)
    alpha_optim = optim.RMSprop([log_alpha], lr=args.lr_alpha)
    with torch.no_grad():
        alpha = torch.clamp(log_alpha.exp(), min=1e-3, max=1.0)

    # 6.5 Bufor replay
    replay = ReplayBuffer(max_size=args.buffer_size)

    # 6.6 TensorBoard
    writer = SummaryWriter(comment=f"_{args.exp_name}")
    episode_rewards = []

    # Epsilon-Greedy
    class EpsilonGreedy:
        def __init__(self, start=1.0, end=0.05, decay_steps=50000):
            self.start = start
            self.end = end
            self.decay_steps = decay_steps
            self.step = 0
        
        def get_epsilon(self):
            epsilon = self.end + (self.start - self.end) * max(0, (self.decay_steps - self.step)) / self.decay_steps
            self.step += 1
            return epsilon

    epsilon_greedy = EpsilonGreedy(start=args.epsilon_start, end=args.epsilon_end, decay_steps=args.epsilon_decay_steps)

    def soft_update(source_net, target_net, tau):
        for p, tp in zip(source_net.parameters(), target_net.parameters()):
            tp.data.copy_(tau * p.data + (1.0 - tau) * tp.data)

    def evaluate_policy(episode_idx):
        """
        Uruchamia kilka epizodów w trybie testowym (bez noise), liczy średni zwrot.
        Loguje do TB.
        """
        eval_episodes = args.num_eval_episodes
        returns = []
        for _ in range(eval_episodes):
            obs_dict, _ = env.reset()
            done_dict = {ag: False for ag in agents}
            ep_ret = 0.0
            state_np = env.state()
            while not all(done_dict.values()):
                actions_dict = {}
                for i, ag in enumerate(agents):
                    agent_id = get_agent_id(i, args.n_agents)
                    # state_with_id = np.concatenate([state_np, agent_id])
                    # state_tensor = torch.FloatTensor(state_with_id).unsqueeze(0).to(device)
                    state_with_id = get_state_with_id(state_np, agent_id)
                    state_tensor = torch.FloatTensor(state_with_id).to(device)
                    with torch.no_grad():
                        logits = shared_actor.forward(state_tensor)
                        action = torch.argmax(logits, dim=-1).item()
                    actions_dict[ag] = action
                    writer.add_scalar(f"debug/eval_action_{ag}", action, episode_idx)
                next_obs_dict, rews, done_dict, _, _ = env.step(actions_dict)
                ep_ret += sum(rews.values())
                state_np = env.state()
            returns.append(ep_ret)
        avg_ret = np.mean(returns)
        writer.add_scalar("evaluate/avg_return", avg_ret, episode_idx)
        print(f"[EVAL] Episode {episode_idx}, avg_return={avg_ret:.3f}")

    start_time = time.time() 

    # =====================================================================
    # 6.7 Główna pętla treningowa
    # =====================================================================
    for timestep in range(args.total_timesteps):
        obs_dict, _ = env.reset()
        done_dict = {ag: False for ag in agents}
        state_np = env.state()
        ep_ret = 0.0
        step = 0
        
        while not all(done_dict.values()) and step < args.max_cycles:
            actions_dict = {}
            local_obs_list = []
            for i, ag in enumerate(agents):
                agent_id = get_agent_id(i, args.n_agents)
                state_with_id = np.concatenate([state_np, agent_id])
                state_tensor = torch.FloatTensor(state_with_id).unsqueeze(0).to(device)
                
                epsilon = epsilon_greedy.get_epsilon()
                if random.random() < epsilon:
                    action_i = random.randint(0, act_dim_each - 1)
                else:
                    with torch.no_grad():
                        logits = shared_actor.forward(state_tensor)
                        action_i = torch.argmax(logits, dim=-1).item()
                
                actions_dict[ag] = action_i
                local_obs_list.append(obs_dict[ag])

            next_obs_dict, rew_dict, done_dict, _, _ = env.step(actions_dict)
            next_state_np = env.state()
            reward = sum(rew_dict.values())
            done = any(done_dict.values())  # ep. done

            # Zapis do bufora
            act_list = [actions_dict[ag] for ag in agents]
            replay.add(
                state_np,
                local_obs_list,
                act_list,
                reward,
                next_state_np,
                [next_obs_dict[ag] for ag in agents],
                done
            )

            ep_ret += reward
            state_np = next_state_np
            obs_dict = next_obs_dict
            step += 1
            timestep += 1

            # =====================================================================
            # 6.7 Główna pętla treningowa - część: Trening
            # =====================================================================
            if (timestep > args.learning_starts) and (timestep % args.train_freq == 0):
                # Sample batch
                batch = replay.sample(args.batch_size)
                if batch is None:
                    continue
                (states_b, obs_b, acts_b, rews_b, next_states_b, next_obs_b, dones_b) = batch
                
                # Konwersja danych do tensorów
                batch_states = torch.FloatTensor(np.array(states_b)).to(device)                    # (B, state_dim)
                batch_next_states = torch.FloatTensor(np.array(next_states_b)).to(device)          # (B, state_dim)
                batch_rewards = torch.FloatTensor(np.array(rews_b)).unsqueeze(1).to(device)        # (B, 1)
                batch_dones = torch.FloatTensor(np.array(dones_b)).unsqueeze(1).to(device)          # (B, 1)
                
                # Konwersja akcji na tensor (B, N)
                batch_acts = torch.LongTensor(np.array(acts_b)).to(device)                         # (B, N)
                batch_acts_one_hot = actions_to_one_hot(batch_acts, act_dim_each)                  # (B, N * act_dim_each)
                
                # ---------- Obliczanie target Q_tot ----------
                # 1. Update Critics and Mixer
                with torch.no_grad():
                    # Calculate next actions and Q-values
                    all_next_actions = []
                    for i in range(n_agents):
                        agent_id = get_agent_id(i, n_agents)
                        state_with_id = get_state_with_id(next_states_b, agent_id, batch_size=args.batch_size)
                        state_with_id_tensor = torch.FloatTensor(state_with_id).to(device)
                        logits = shared_actor.forward(state_with_id_tensor)
                        next_action = torch.argmax(logits, dim=-1)
                        all_next_actions.append(next_action)
                    
                    # Compute target Q-values
                    q_i_next = []
                    for i in range(n_agents):
                        q_val, _ = target_critic.forward(
                            batch_next_states, 
                            F.one_hot(all_next_actions[i], num_classes=act_dim_each).float()
                        )
                        q_i_next.append(q_val)
                    q_i_cat_next = torch.cat(q_i_next, dim=1).detach()
                    q_tot_next = target_mixer(q_i_cat_next, batch_next_states)
                    
                    # Compute targets
                    y = batch_rewards + (1 - batch_dones) * args.gamma * q_tot_next

                # ---------- Obliczanie obecnych Q_tot ----------
                q_i_current = []
                for i in range(n_agents):
                    q_val, _ = shared_critic.forward(batch_states, F.one_hot(batch_acts[:, i], num_classes=act_dim_each).float())
                    
                    action_one_hot = F.one_hot(batch_acts[:, i], num_classes=act_dim_each).float().requires_grad_(True)
                    q_i_min = torch.min(q_val, q_val)  # Twin krytycy
                    q_i_current.append(q_i_min)
                q_i_cat_current = torch.cat(q_i_current, dim=1)  # (B, N)
                q_tot_current = mixer(q_i_cat_current, batch_states)  # (B, 1)
                
                # Strata dla krytyków i miksującej
                # critic_loss = F.mse_loss(q_tot_current, y).detach()

                critic_loss = F.mse_loss(q_tot_current, y)
                
                # Backpropagation dla krytyków i miksującej
                critic_opt.zero_grad()
                mixer_opt.zero_grad()
                critic_loss.backward()
                critic_opt.step()
                mixer_opt.step()
                
                # ---------- Update Actor Policy ----------
                total_actor_loss = 0
                total_entropy = 0
                
                # ---------- Aktualizacja Polityki Aktora (SAC) ------------
                for i in range(n_agents):
                    agent_id = get_agent_id(i, n_agents)
                    state_with_id = get_state_with_id(states_b, agent_id, batch_size=args.batch_size)
                    state_with_id_tensor = torch.FloatTensor(state_with_id).to(device)
                    
                    # Get action distribution
                    logits = shared_actor.forward(state_with_id_tensor)
                    dist = torch.distributions.Categorical(logits=logits)
                    new_actions = dist.sample()
                    log_probs = dist.log_prob(new_actions).unsqueeze(1)
                    entropy = dist.entropy().unsqueeze(1)
                    
                    # Compute Q-values for new actions
                    q_i_new = []
                    for j in range(n_agents):
                        if j == i:
                            action_one_hot = F.one_hot(new_actions, num_classes=act_dim_each).float()
                        else:
                            action_one_hot = F.one_hot(batch_acts[:, j], num_classes=act_dim_each).float()
                        q_val, _ = shared_critic.forward(batch_states, action_one_hot)
                        q_val = q_val.detach()  # Detach Q-values
                        q_i_new.append(q_val)
                    
                    q_tot_new = mixer(torch.cat(q_i_new, dim=1), batch_states)
                    
                    # Actor loss for this agent
                    actor_loss = -(q_tot_new - args.alpha * entropy).mean()
                    
                    # Update actor for this agent
                    actor_opt.zero_grad()
                    actor_loss.backward(retain_graph=(i < n_agents - 1))  # retain_graph for all but last agent
                    torch.nn.utils.clip_grad_norm_(shared_actor.parameters(), 0.5)
                    actor_opt.step()
                    
                    # Accumulate losses and entropy for logging
                    total_actor_loss += actor_loss.item()
                    total_entropy += entropy.mean().item()
                    
                    # Log individual agent metrics
                    writer.add_scalar(f"loss/actor_loss_agent_{i}", actor_loss.item(), timestep)
                    writer.add_scalar(f"entropy/agent_{i}", entropy.mean().item(), timestep)

                # ---------- Update Alpha ----------
                alpha = log_alpha.exp()
                alpha_loss = -(log_alpha * (total_entropy/n_agents + args.target_entropy)).mean()
                

                # Update alpha
                alpha_optim.zero_grad()
                alpha_loss.backward()
                alpha_optim.step()

                # Clamp alpha
                with torch.no_grad():
                    alpha = torch.clamp(log_alpha.exp(), min=1e-3, max=1.0)

                # Log losses
                writer.add_scalar("loss/critic_loss", critic_loss.item(), timestep)
                writer.add_scalar("loss/alpha_loss", alpha_loss.item(), timestep)
                writer.add_scalar("alpha/value", alpha.item(), timestep)

                # ---------- Soft update targetów ------------
                if timestep % args.target_update_interval == 0:
                    soft_update(shared_critic, target_critic, args.tau)
                    soft_update(mixer, target_mixer, args.tau)

                # ---------- Logowanie do TB -----------
                writer.add_scalar("q_values/q_tot_current_mean", q_tot_current.mean().item(), timestep)
                writer.add_scalar("q_values/q_tot_next_mean", q_tot_next.mean().item(), timestep)
                writer.add_scalar("q_values/y_mean", y.mean().item(), timestep)
                writer.add_scalar("q_values/q_tot", q_tot_current.mean().item(), timestep)

        # Logowanie zwrotów epizodu
        episode_rewards.append(ep_ret)
        writer.add_scalar("charts/episodic_return", ep_ret, timestep)
        writer.add_scalar("charts/average_return", np.mean(episode_rewards[-100:]), timestep)

        # Ewaluacja co określoną liczbę epizodów
        if (timestep + 1) % args.eval_freq == 0:
            evaluate_policy(timestep + 1)

        # Monitorowanie postępu
        if timestep % 1000 == 0:
            print(f"Timestep {timestep}, Avg Return: {np.mean(episode_rewards[-100:]):.2f}")

    env.close()
    writer.close()

    return episode_rewards



# =====================================================================
# 7. Uruchamianie
# =====================================================================
if __name__ == "__main__":
    args = get_args()
    print(args)
    rewards = run_qmix_sac(args)
    print("Done. Last 10 episodes avg return:", np.mean(rewards[-10:]))


Namespace(seed=1, total_timesteps=1000000, max_cycles=75, n_agents=2, buffer_size=5000, batch_size=32, gamma=0.99, tau=0.01, alpha=0.2, target_entropy=-3.0, lr_actor=0.0005, lr_critic=0.0003, lr_mixer=0.0003, lr_alpha=1e-05, l2=1e-05, learning_starts=400, train_freq=1, eval_freq=50, num_eval_episodes=4, exp_name='SAC_QMIX_spread_discrete', epsilon_start=1.0, epsilon_end=0.05, epsilon_decay_steps=50000, target_update_interval=200)
Discrete(5)
State dim: 24, Obs dim each: 12, Act dim each: 5, Total act dim: 10
[EVAL] Episode 100, avg_return=-811.720
[EVAL] Episode 150, avg_return=-510.639
[EVAL] Episode 200, avg_return=-1168.641
[EVAL] Episode 250, avg_return=-989.703
[EVAL] Episode 300, avg_return=-859.092
[EVAL] Episode 350, avg_return=-620.414
[EVAL] Episode 400, avg_return=-1257.633
[EVAL] Episode 450, avg_return=-1617.266
[EVAL] Episode 500, avg_return=-1641.024
[EVAL] Episode 550, avg_return=-1550.577
[EVAL] Episode 600, avg_return=-1655.698
[EVAL] Episode 650, avg_return=-1597.602