In [15]:
%pip install stable-baselines3 numpy torch supersuit pettingzoo pymunk scipy gymnasium matplotlib einops tensorboard wandb imageio 

Note: you may need to restart the kernel to use updated packages.


In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import random
from collections import deque
from typing import List, Dict, NamedTuple
import gymnasium as gym
from pettingzoo.mpe import simple_spread_v3

from torch.utils.tensorboard import SummaryWriter
from argparse import Namespace
import time

# =====================================================================
# 1. Parsing argumentów (Namespace) - przykładowa konfiguracja
# =====================================================================
def get_args():
    args = Namespace()
    args.seed = 1
    args.total_episodes = 1000000
    args.max_cycles = 75 
    args.n_agents = 2          # N=3 w simple_spread
    args.buffer_size = 50000
    args.batch_size = 64
    args.gamma = 0.99
    args.tau = 0.01
    args.alpha = 0.2
    args.target_entropy = -3.0
    args.lr_actor = 3e-4
    args.lr_critic = 3e-4
    args.lr_mixer = 3e-4
    args.lr_alpha = 1e-5       
    args.l2 = 1e-5             
    args.learning_starts = 2000
    args.train_freq = 1        # co ile kroków/epizodów robić update
    args.eval_freq = 50        # co ile epizodów robić ewaluację
    args.num_eval_episodes = 2 # ile epizodów w ewaluacji
    args.exp_name = "SAC_QMIX_spread_discrete"
    return args

# =====================================================================
# 2. Sieć aktora - do akcji dyskretnych
# =====================================================================
class Actor(nn.Module):
    def __init__(self, obs_dim, act_dim, hidden_dim=256):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(obs_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.ReLU(),
        )

        self.logits_head = nn.Linear(hidden_dim, act_dim)
    
    def forward(self, obs):
        x = self.net(obs)
        logits = self.logits_head(x)
        return logits
    
    def sample(self, obs):
        """
        Zwraca (action, log_prob, entropy), gdzie:
        - action jest indeksowanym wyborem z kategorii akcji
        - log_prob to log prawdopodobieństwa wybranej akcji
        - entropy to entropia rozkładu akcji
        """
        logits = self.forward(obs)
        dist = torch.distributions.Categorical(logits=logits)
        action = dist.sample()
        log_prob = dist.log_prob(action).unsqueeze(1)
        entropy = dist.entropy().unsqueeze(1)
        return action, log_prob, entropy



# =====================================================================
# 3. Krytyk Q dla agenta i (Twin Q -> critic1[i], critic2[i])
# =====================================================================
class QCritic(nn.Module):
    """
    Input: cat(global_state, joint_action)
    Output: scalar Q^i
    """
    def __init__(self, input_dim, hidden_dim=256):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)
        )

    def forward(self, x):
        return self.net(x)


# =====================================================================
# 4. Mixing Network (QMIX)
# =====================================================================
class MixingNetwork(nn.Module):
    def __init__(self, n_agents, state_dim, hidden_dim=32):
        super().__init__()
        self.n_agents = n_agents
        self.state_dim = state_dim

        self.hyper_w1 = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, n_agents * hidden_dim)
        )
        self.hyper_b1 = nn.Linear(state_dim, hidden_dim)

        self.hyper_w2 = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )
        self.hyper_b2 = nn.Linear(state_dim, 1)

        self.relu = nn.ReLU()

    def forward(self, q_values, state):
        """
        q_values: (batch_size, n_agents)
        state: (batch_size, state_dim)
        Zwraca (batch_size, 1)
        """
        bs = q_values.size(0)
        # 1 warstwa
        w1 = self.hyper_w1(state).view(bs, self.n_agents, -1)
        w1 = torch.relu(w1)
        b1 = self.hyper_b1(state).view(bs, 1, -1)

        q_values = q_values.unsqueeze(1)  # (bs, 1, n_agents)
        hidden = torch.bmm(q_values, w1) + b1  # => (bs, 1, hidden_dim)
        hidden = self.relu(hidden)

        # 2 warstwa
        w2 = self.hyper_w2(state).view(bs, -1, 1)
        w2 = torch.relu(w2)
        b2 = self.hyper_b2(state).view(bs, 1, 1)

        q_tot = torch.bmm(hidden, w2) + b2  # => (bs,1,1)
        return q_tot.view(bs, 1)

# =====================================================================
# 5. Funkcja pomocnicza do konwersji akcji na one-hot
# =====================================================================
def actions_to_one_hot(actions, act_dim_each):
    """
    actions: tensor o rozmiarze (B, N)
    zwraca tensor o rozmiarze (B, N * act_dim_each)
    """
    one_hot = F.one_hot(actions, num_classes=act_dim_each).float()
    return one_hot.view(actions.size(0), -1)  # (B, N * act_dim_each)

# =====================================================================
# 6. Replay buffer
# =====================================================================
class Transition(NamedTuple):
    state: np.ndarray
    obs: List[np.ndarray]  # local obs for each agent
    action: List[int]       # Zmienione na List[int] dla akcji dyskretnych
    reward: float
    next_state: np.ndarray
    next_obs: List[np.ndarray]
    done: bool

class ReplayBuffer:
    def __init__(self, max_size=100000):
        self.buffer = deque(maxlen=max_size)

    def add(self, *args):
        self.buffer.append(Transition(*args))

    def __len__(self):
        return len(self.buffer)

    def sample(self, batch_size):
        if len(self.buffer) < batch_size:
            return None
        batch = random.sample(self.buffer, batch_size)
        return list(zip(*batch))
        # Zwracamy listy/tuple T= (state, obs, action, reward, next_state, next_obs, done)

# =====================================================================
# 7. Główna klasa/wydzielona pętla treningowa
# =====================================================================
def run_sac_qmix(args):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # 6.1 Przygotowanie środowiska
    env = simple_spread_v3.parallel_env(
        N=args.n_agents,
        local_ratio=0.2,
        max_cycles=args.max_cycles,
        continuous_actions=False
    )
    env.reset(seed=args.seed)
    
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    # 6.2 Ustalenie wymiarów
    agents = env.possible_agents  # np. [agent_0, agent_1]
    
    n_agents = len(agents)
    # Sprawdzamy, czy n_agents == args.n_agents
    assert n_agents == args.n_agents, "Nie zgadza się liczba agentów!"
    act_dim_each = env.action_space(agents[0]).n  # Zmienione z .shape[0] na .n dla akcji dyskretnych
    state_dim = np.prod(env.state().shape)               # global state
    # local obs
    obs_dim_each = np.prod(env.observation_space(agents[0]).shape)

    print(env.action_space(agents[0])) 
    # Joint action dimension
    act_dim_total = act_dim_each * n_agents
    print(f"State dim: {state_dim}, Obs dim each: {obs_dim_each}, Act dim each: {act_dim_each}, Total act dim: {act_dim_total}")

    # 6.3 Inicjalizacja aktorów i krytyków (plus targety)
    actors = []
    critics1 = []
    critics2 = []
    target_critics1 = []
    target_critics2 = []

    for i in range(n_agents):
        actor = Actor(obs_dim_each, act_dim_each).to(device)
        actors.append(actor)

        c1 = QCritic(state_dim + act_dim_total).to(device)
        c2 = QCritic(state_dim + act_dim_total).to(device)
        tc1 = QCritic(state_dim + act_dim_total).to(device)
        tc1.load_state_dict(c1.state_dict())
        tc2 = QCritic(state_dim + act_dim_total).to(device)
        tc2.load_state_dict(c2.state_dict())

        critics1.append(c1)
        critics2.append(c2)
        target_critics1.append(tc1)
        target_critics2.append(tc2)

    mixer = MixingNetwork(n_agents, state_dim).to(device)
    target_mixer = MixingNetwork(n_agents, state_dim).to(device)
    target_mixer.load_state_dict(mixer.state_dict())

    # 6.4 Optimizers
    actor_opts = [
        optim.Adam(actors[i].parameters(), lr=args.lr_actor, weight_decay=args.l2) 
        for i in range(n_agents)
    ]
    critic_opts = []
    for i in range(n_agents):
        params = list(critics1[i].parameters()) + list(critics2[i].parameters())
        critic_opts.append(
            optim.Adam(params, lr=args.lr_critic, weight_decay=args.l2)
        )
    mixer_opt = optim.Adam(mixer.parameters(), lr=args.lr_mixer, weight_decay=args.l2)
    
    # DYNAMICZNE ALPHA
    args.target_entropy = -np.log(act_dim_each)  # ≈ -1.609 dla 5 akcji
    print("Target entropy:", args.target_entropy)
    log_alpha = nn.Parameter(torch.log(torch.tensor(0.2, device=device)))
    alpha_optim = optim.Adam([log_alpha], lr=args.lr_alpha)
    with torch.no_grad():
        alpha = torch.clamp(log_alpha.exp(), min=1e-3, max=1.0)  # Dodanie maksymalnej wartości

    # 6.5 Bufor replay
    replay = ReplayBuffer(max_size=args.buffer_size)

    # 6.6 TensorBoard
    writer = SummaryWriter(comment=f"_{args.exp_name}")
    episode_rewards = []

    def soft_update(source_net, target_net, tau):
        for p, tp in zip(source_net.parameters(), target_net.parameters()):
            tp.data.copy_(tau * p.data + (1.0 - tau) * tp.data)

    def evaluate_policy(episode_idx):
        """
        Uruchamia kilka epizodów w trybie testowym (bez noise), liczy średni zwrot.
        Loguje do TB.
        """
        eval_episodes = args.num_eval_episodes
        returns = []
        for _ in range(eval_episodes):
            obs_dict, _ = env.reset()
            done_dict = {ag: False for ag in agents}
            ep_ret = 0.0
            while not all(done_dict.values()):
                actions_dict = {}
                for i, ag in enumerate(agents):
                    actors[i].eval()  # Upewnij się, że aktor jest w trybie ewaluacji
                    with torch.no_grad():
                        obs_i = torch.FloatTensor(obs_dict[ag]).unsqueeze(0).to(device)
                        logits = actors[i].forward(obs_i)
                        dist = torch.distributions.Categorical(logits=logits)
                        action = dist.probs.argmax(dim=-1)  # Deterministyczny wybór akcji
                    actors[i].train()  # Przywróć tryb treningu
                    actions_dict[ag] = action.cpu().numpy().flatten()[0]  # action jako scalar (int)
                    writer.add_scalar(f"debug/eval_action_{ag}", actions_dict[ag], episode_idx)

                next_obs, rews, done_dict, _, _ = env.step(actions_dict)
                ep_ret += sum(rews.values())
                obs_dict = next_obs
            returns.append(ep_ret)
        avg_ret = np.mean(returns)
        writer.add_scalar("evaluate/avg_return", avg_ret, episode_idx)
        print(f"[EVAL] Episode {episode_idx}, avg_return={avg_ret:.3f}")

    
    start_time = time.time() 
    # =====================================================================
    # 6.7 Główna pętla treningowa - część: Próbkowanie Akcji
    # =====================================================================

    for ep in range(args.total_episodes):
        obs_dict, _ = env.reset()
        done_dict = {ag: False for ag in agents}
        state_np = env.state()
        ep_ret = 0.0
        step = 0

        while not all(done_dict.values()) and step < args.max_cycles:
            actions_dict = {}
            local_obs_list = []
            for i, ag in enumerate(agents):
                obs_i = torch.FloatTensor(obs_dict[ag]).unsqueeze(0).to(device)
                with torch.no_grad():
                    action_i, logp_i, entropy_i = actors[i].sample(obs_i)
                actions_dict[ag] = action_i.cpu().numpy().flatten()[0]  # action jako scalar (int)
                local_obs_list.append(obs_dict[ag])

            next_obs_dict, rew_dict, done_dict, _, _ = env.step(actions_dict)
            next_state_np = env.state()
            reward = sum(rew_dict.values())
            done = any(done_dict.values())  # ep. done

            # Zapis do bufora
            act_list = [actions_dict[ag] for ag in agents]  # Lista int
            replay.add(
                state_np,
                local_obs_list,
                act_list,
                reward,
                next_state_np,
                [next_obs_dict[ag] for ag in agents],
                done
            )

            ep_ret += reward
            state_np = next_state_np
            obs_dict = next_obs_dict
            step += 1
            
            # =====================================================================
            # 6.7 Główna pętla treningowa - część: Trening
            # =====================================================================
            if (ep > args.learning_starts) and (ep % args.train_freq == 0):
                # Sample
                batch = replay.sample(args.batch_size)
                if batch is None:
                    continue
                (states_b, obs_b, acts_b, rews_b, next_states_b, next_obs_b, dones_b) = batch
                
                # Konwersja akcji do tensorów
                batch_states = torch.FloatTensor(np.array(states_b)).to(device)                    # (B, state_dim)
                batch_next_states = torch.FloatTensor(np.array(next_states_b)).to(device)          # (B, state_dim)
                batch_rewards = torch.FloatTensor(np.array(rews_b)).unsqueeze(1).to(device)        # (B, 1)
                batch_dones = torch.FloatTensor(np.array(dones_b)).unsqueeze(1).to(device)          # (B, 1)
                
                # Konwersja akcji na tensor (B, N)
                batch_acts = torch.LongTensor(np.array(acts_b)).to(device)                         # (B, N)
                batch_acts_one_hot = actions_to_one_hot(batch_acts, act_dim_each)                  # (B, N * act_dim_each)
                
                # ---------- Oblicz target -----------
                with torch.no_grad():
                    # Próba akcji z actorów dla następnych stanów
                    all_next_actions = []
                    sum_entropy = torch.zeros((args.batch_size, 1), device=device)

                    for i in range(n_agents):
                        agent_next_obs = torch.FloatTensor([next_obs_b[b][i] for b in range(args.batch_size)]).to(device)
                        a_next_i, logp_next_i, entropy_i = actors[i].sample(agent_next_obs)
                        all_next_actions.append(a_next_i)
                        sum_entropy += entropy_i  # Sumowanie entropii

                        # Dodaj sprawdzenie kształtu (opcjonalne)
                        assert a_next_i.shape == (args.batch_size,), f"a_next_i shape: {a_next_i.shape}"
                        assert logp_next_i.shape == (args.batch_size, 1), f"logp_next_i shape: {logp_next_i.shape}"
                        assert entropy_i.shape == (args.batch_size, 1), f"entropy_i shape: {entropy_i.shape}"

                    # Konwersja do one-hot
                    next_joint_actions = actions_to_one_hot(torch.stack(all_next_actions, dim=1), act_dim_each)  # (B, N * act_dim_each)

                    # Obliczanie Q^i z targetów i miksowanie
                    all_qi_next = []
                    for i in range(n_agents):
                        inp = torch.cat([batch_next_states, next_joint_actions], dim=1)
                        q1_val = target_critics1[i](inp)
                        q2_val = target_critics2[i](inp)
                        qi_next = torch.min(q1_val, q2_val)
                        all_qi_next.append(qi_next)
                    q_i_cat_next = torch.cat(all_qi_next, dim=1)  # (B, N)
                    q_tot_next = target_mixer(q_i_cat_next, batch_next_states)  # (B, 1)

                    # Dodanie entropii
                    q_tot_next = q_tot_next + (alpha * sum_entropy)  # [B,1] + [B,1]

                    # Bellman target y = rewards + gamma * (1 - done) * q_tot_next
                    y = batch_rewards + (1 - batch_dones) * args.gamma * q_tot_next

                # ---------- Oblicz Q^i current i Q_tot ----------
                all_qi_1 = []
                all_qi_2 = []
                for i in range(n_agents):
                    inp = torch.cat([batch_states, batch_acts_one_hot], dim=1)
                    q1_val = critics1[i](inp)
                    q2_val = critics2[i](inp)
                    all_qi_1.append(q1_val)
                    all_qi_2.append(q2_val)

                q1_cat = torch.cat(all_qi_1, dim=1)  # (B, N)
                q2_cat = torch.cat(all_qi_2, dim=1)  # (B, N)
                q_tot_1 = mixer(q1_cat, batch_states)  # (B,1)
                q_tot_2 = mixer(q2_cat, batch_states)  # (B,1)
                q_tot_current = torch.min(q_tot_1, q_tot_2)  # (B,1)

                # Loss krytyków
                critic_loss = F.mse_loss(q_tot_current, y) + \
                            0.01 * sum(qi.pow(2).mean() for qi in all_qi_1) + \
                            0.01 * sum(qi.pow(2).mean() for qi in all_qi_2)

                # Zero_grad i clip dla krytyków
                for i in range(n_agents):
                    critic_opts[i].zero_grad()
                    torch.nn.utils.clip_grad_norm_(critics1[i].parameters(), 1.0)  # Zmniejszony clip
                    torch.nn.utils.clip_grad_norm_(critics2[i].parameters(), 1.0)  # Zmniejszony clip

                # Mixer
                mixer_opt.zero_grad()

                critic_loss.backward()

                # Step
                for i in range(n_agents):
                    critic_opts[i].step()
                mixer_opt.step()

                # ---------- Update aktorów (SAC) ------------
                for i in range(n_agents):
                    # Nowa akcja od aktora i
                    agent_obs = torch.FloatTensor([obs_b[b][i] for b in range(args.batch_size)]).to(device)
                    new_action_i, new_logp_i, entropy_i = actors[i].sample(agent_obs)  # (B,), (B,1), (B,1)

                    # Podmiana akcji w joint action
                    old_actions = batch_acts.clone()  # (B, N)
                    old_actions[:, i] = new_action_i  # Zastąpienie akcji dla agenta i
                    new_joint_actions = actions_to_one_hot(old_actions, act_dim_each)  # (B, N * act_dim_each)

                    # Obliczanie Q^i z krytyków
                    input_new = torch.cat([batch_states, new_joint_actions], dim=1)
                    q1_val_i = critics1[i](input_new)
                    q2_val_i = critics2[i](input_new)
                    q_val_i_min = torch.min(q1_val_i, q2_val_i)

                    # Obliczanie Q^j dla j != i
                    q_current_list = []
                    for j in range(n_agents):
                        if j == i:
                            q_current_list.append(q_val_i_min)  # (B,1)
                        else:
                            with torch.no_grad():
                                inp_j = torch.cat([batch_states, batch_acts_one_hot], dim=1)
                                q1_j = critics1[j](inp_j)
                                q2_j = critics2[j](inp_j)
                                q_j_min = torch.min(q1_j, q2_j)
                            q_current_list.append(q_j_min)

                    # Tworzenie q_tot_actor
                    q_cat_actor = torch.cat(q_current_list, dim=1)  # (B, N)
                    q_tot_actor = mixer(q_cat_actor, batch_states)    # (B, 1)

                    # Obliczanie actor_loss
                    actor_loss = -(q_tot_actor - alpha.detach() * entropy_i).mean()

                    # Aktualizacja aktora
                    actor_opts[i].zero_grad()
                    actor_loss.backward()
                    torch.nn.utils.clip_grad_norm_(actors[i].parameters(), 0.5)  # Zmniejszony clip
                    actor_opts[i].step()

                    # Logging
                    writer.add_scalar(f"loss/actor_loss_agent_{i}", actor_loss.item(), ep)
                    writer.add_scalar(f"entropy/agent_{i}", entropy_i.mean().item(), ep)
                
                # ---------- Update alpha ------------
                alpha = log_alpha.exp()
                alpha_loss = -(log_alpha * (sum_entropy.detach() - args.target_entropy)).mean()
                writer.add_scalar("loss/alpha_loss", alpha_loss.item(), ep)
                writer.add_scalar("alpha/value", alpha.item(), ep)

                # Aktualizacja alpha
                alpha_optim.zero_grad()
                alpha_loss.backward()
                alpha_optim.step()

                # Aktualizacja alpha po aktualizacji
                with torch.no_grad():
                    alpha = torch.clamp(log_alpha.exp(), min=1e-3, max=1.0)

                # ---------- Soft update targetów ------------
                for i in range(n_agents):
                    soft_update(critics1[i], target_critics1[i], args.tau)
                    soft_update(critics2[i], target_critics2[i], args.tau)
                soft_update(mixer, target_mixer, args.tau)


                # ---------- logi do TB -----------
                writer.add_scalar("loss/critic_loss", critic_loss.item(), ep)
                writer.add_scalar("loss/qmix_loss", critic_loss.item(), ep)
                writer.add_scalar("q_values/q_tot_current_mean", q_tot_current.mean().item(), ep)
                writer.add_scalar("q_values/q_tot_next_mean", q_tot_next.mean().item(), ep)
                writer.add_scalar("q_values/y_mean", y.mean().item(), ep)
                writer.add_scalar("q_values/q_tot", q_tot_current.mean().item(), ep)
                for i in range(n_agents):
                    writer.add_scalar(f"q_values/q_agent_{i}", q1_cat[:, i].mean().item(), ep)
                    writer.add_scalar(f"q_values/q_agent_{i}_critic1_mean", q1_cat[:, i].mean().item(), ep)
                    writer.add_scalar(f"q_values/q_agent_{i}_critic2_mean", q2_cat[:, i].mean().item(), ep)
                    critic_grad_norm = sum(
                        p.grad.norm().item() 
                        for p in critics1[i].parameters() 
                        if p.grad is not None
                    )
                    writer.add_scalar(f"debug/critic_{i}_grad_norm", critic_grad_norm, ep)

        writer.add_scalar("charts/SPS", ep / (time.time() - start_time), ep)

        # koniec epizodu
        episode_rewards.append(ep_ret)
        writer.add_scalar("charts/episodic_return", ep_ret, ep)
        writer.add_scalar("charts/average_return", np.mean(episode_rewards[-100:]), ep)

        # ewaluacja co X epizodów
        if (ep + 1) % args.eval_freq == 0:
            evaluate_policy(ep + 1)

    env.close()
    writer.close()

    return episode_rewards


# =====================================================================
# 7. Uruchamianie
# =====================================================================
if __name__ == "__main__":
    args = get_args()
    print(args)
    rewards = run_sac_qmix(args)
    print("Done. Last 10 episodes avg return:", np.mean(rewards[-10:]))


Namespace(seed=1, total_episodes=1000000, max_cycles=75, n_agents=2, buffer_size=50000, batch_size=64, gamma=0.99, tau=0.01, alpha=0.2, target_entropy=-3.0, lr_actor=0.0003, lr_critic=0.0003, lr_mixer=0.0003, lr_alpha=1e-05, l2=1e-05, learning_starts=2000, train_freq=1, eval_freq=50, num_eval_episodes=2, exp_name='SAC_QMIX_spread_discrete')
Discrete(5)
State dim: 24, Obs dim each: 12, Act dim each: 5, Total act dim: 10
Target entropy: -1.6094379124341003
[EVAL] Episode 50, avg_return=-380.656
[EVAL] Episode 100, avg_return=-525.279
[EVAL] Episode 150, avg_return=-341.472
[EVAL] Episode 200, avg_return=-420.556
[EVAL] Episode 250, avg_return=-268.284
[EVAL] Episode 300, avg_return=-263.906
[EVAL] Episode 350, avg_return=-250.621
[EVAL] Episode 400, avg_return=-438.052
[EVAL] Episode 450, avg_return=-505.277
[EVAL] Episode 500, avg_return=-287.696
[EVAL] Episode 550, avg_return=-425.913
[EVAL] Episode 600, avg_return=-277.215
[EVAL] Episode 650, avg_return=-445.825
[EVAL] Episode 700, av

KeyboardInterrupt: 