In [3]:
%pip install stable-baselines3 numpy torch supersuit pettingzoo pymunk scipy gymnasium matplotlib einops tensorboard wandb imageio 

Note: you may need to restart the kernel to use updated packages.


In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import random
from collections import deque
from typing import List, Dict, NamedTuple
import gymnasium as gym
from pettingzoo.mpe import simple_spread_v3

from torch.utils.tensorboard import SummaryWriter
from argparse import Namespace
import time

# =====================================================================
# 1. Parsing argumentów (Namespace) - przykładowa konfiguracja
# =====================================================================
def get_args():
    args = Namespace()
    args.seed = 42
    args.total_episodes = 1000000
    args.max_cycles = 25       # maksymalna liczba kroków w epizodzie (w MPE: max_cycles)
    args.n_agents = 3          # N=3 w simple_spread
    args.buffer_size = 2000000
    args.batch_size = 128 
    args.gamma = 0.98
    args.tau = 0.05
    args.alpha = 0.005           # stała entropii w SAC
    args.lr_actor = 1e-4
    args.lr_critic = 1e-4
    args.lr_mixer = 1e-4
    args.learning_starts = 200 # musi byc wiekszy niz batch_size
    # args.learning_starts = args.total_episodes // 10
    args.train_freq = 1        # co ile kroków/epizodów robić update
    args.eval_freq = 50 # co ile epizodów robić ewaluację
    # args.eval_freq = args.total_episodes // 100       # co ile epizodów robić ewaluację
    args.num_eval_episodes = 4 # ile epizodów w ewaluacji
    args.exp_name = "SAC_QMIX_spread"
    return args


# =====================================================================
# 2. Sieć aktora - do akcji ciągłych (SAC)
# =====================================================================
class Actor(nn.Module):
    def __init__(self, obs_dim, act_dim, hidden_dim=256):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(obs_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
        )
        self.mean_head = nn.Linear(hidden_dim, act_dim)
        self.logstd_head = nn.Linear(hidden_dim, act_dim)
        self.LOG_STD_MIN = -5
        self.LOG_STD_MAX = 2

    def forward(self, obs):
        x = self.net(obs)
        mean = self.mean_head(x)
        log_std = self.logstd_head(x)
        # ograniczamy zakres log_std
        log_std = torch.clamp(log_std, min=self.LOG_STD_MIN, max=self.LOG_STD_MAX)
        return mean, log_std

    def sample(self, obs):
        """
        Zwraca (action, log_prob), z:
        - action w (-1,1)
        - log_prob łączny (po wymiarach akcji)
        """
        mean, log_std = self.forward(obs)
        std = log_std.exp()
        dist = torch.distributions.Normal(mean, std)
        z = dist.rsample()  # reparametrization
        action = torch.tanh(z) # od -1 do 1
        action = .5 * (action + 1)  # od 0 do 1

        # log_prob: korekta za tanh
        # log_prob = dist.log_prob(z) - torch.log(1 - action.pow(2) + 1e-7)
        log_prob = dist.log_prob(z) - torch.log(torch.clamp(1 - action.pow(2), min=1e-7))
        log_prob = log_prob.sum(dim=-1, keepdim=True)
        return action, log_prob


# =====================================================================
# 3. Krytyk Q dla agenta i (Twin Q -> critic1[i], critic2[i])
# =====================================================================
class QCritic(nn.Module):
    """
    Input: cat(global_state, joint_action)
    Output: scalar Q^i
    """
    def __init__(self, input_dim, hidden_dim=256):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)
        )

    def forward(self, x):
        return self.net(x)


# =====================================================================
# 4. Mixing Network (QMIX)
# =====================================================================
class MixingNetwork(nn.Module):
    def __init__(self, n_agents, state_dim, hidden_dim=32):
        super().__init__()
        self.n_agents = n_agents
        self.state_dim = state_dim

        self.hyper_w1 = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, n_agents * hidden_dim)
        )
        self.hyper_b1 = nn.Linear(state_dim, hidden_dim)

        self.hyper_w2 = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )
        self.hyper_b2 = nn.Linear(state_dim, 1)

        self.relu = nn.ReLU()

    def forward(self, q_values, state):
        """
        q_values: (batch_size, n_agents)
        state: (batch_size, state_dim)
        Zwraca (batch_size, 1)
        """
        bs = q_values.size(0)
        # 1 warstwa
        w1 = self.hyper_w1(state).view(bs, self.n_agents, -1)
        b1 = self.hyper_b1(state).view(bs, 1, -1)

        q_values = q_values.unsqueeze(1)  # (bs, 1, n_agents)
        hidden = torch.bmm(q_values, w1) + b1  # => (bs, 1, hidden_dim)
        hidden = self.relu(hidden)

        # 2 warstwa
        w2 = self.hyper_w2(state).view(bs, -1, 1)
        b2 = self.hyper_b2(state).view(bs, 1, 1)

        q_tot = torch.bmm(hidden, w2) + b2  # => (bs,1,1)
        return q_tot.view(bs, 1)


# =====================================================================
# 5. Replay buffer
# =====================================================================
class Transition(NamedTuple):
    state: np.ndarray
    obs: List[np.ndarray]  # local obs for each agent
    action: List[np.ndarray]
    reward: float
    next_state: np.ndarray
    next_obs: List[np.ndarray]
    done: bool

class ReplayBuffer:
    def __init__(self, max_size=100000):
        self.buffer = deque(maxlen=max_size)

    def add(self, *args):
        self.buffer.append(Transition(*args))

    def __len__(self):
        return len(self.buffer)

    def sample(self, batch_size):
        # print("Sampling from buffer")
        # print("Buffer size:", len(self.buffer))
        # print("Batch size:", batch_size)

        batch = random.sample(self.buffer, batch_size)
        return list(zip(*batch))
        # Zwracamy listy/tuple T= (state, obs, action, reward, next_state, next_obs, done)


# =====================================================================
# 6. Główna klasa/wydzielona pętla treningowa
# =====================================================================
def run_sac_qmix(args):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # 6.1 Przygotowanie środowiska
    env = simple_spread_v3.parallel_env(
        N=args.n_agents,
        local_ratio=0.8,
        max_cycles=args.max_cycles,
        continuous_actions=True
    )
    env.reset(seed=args.seed)
    
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    # 6.2 Ustalenie wymiarów
    agents = env.possible_agents  # np. [agent_0, agent_1, agent_2]
    
    n_agents = len(agents)
    # Sprawdzamy, czy n_agents == args.n_agents
    assert n_agents == args.n_agents, "Nie zgadza się liczba agentów!"
    act_dim_each = env.action_space(agents[0]).shape[0]  # np. 2
    state_dim = np.prod(env.state().shape)               # global state
    # local obs
    obs_dim_each = np.prod(env.observation_space(agents[0]).shape)

    print(env.action_space(agents[0])) 
    # Joint action dimension
    act_dim_total = act_dim_each * n_agents

    # 6.3 Inicjalizacja aktorów i krytyków (plus targety)
    #  - actor[i], critic1[i], critic2[i], target_critic1[i], target_critic2[i]
    #  - mixing network
    #  - target mixing network
    actors = []
    critics1 = []
    critics2 = []
    target_critics1 = []
    target_critics2 = []

    for i in range(n_agents):
        actor = Actor(obs_dim_each, act_dim_each).to(device)
        actors.append(actor)

        c1 = QCritic(state_dim + act_dim_total).to(device)
        c2 = QCritic(state_dim + act_dim_total).to(device)
        tc1 = QCritic(state_dim + act_dim_total).to(device)
        tc1.load_state_dict(c1.state_dict())
        tc2 = QCritic(state_dim + act_dim_total).to(device)
        tc2.load_state_dict(c2.state_dict())

        critics1.append(c1)
        critics2.append(c2)
        target_critics1.append(tc1)
        target_critics2.append(tc2)

    mixer = MixingNetwork(n_agents, state_dim).to(device)
    target_mixer = MixingNetwork(n_agents, state_dim).to(device)
    target_mixer.load_state_dict(mixer.state_dict())

    # 6.4 Optimizers
    actor_opts = [optim.Adam(actors[i].parameters(), lr=args.lr_actor) for i in range(n_agents)]
    critic_opts = []
    for i in range(n_agents):
        params = list(critics1[i].parameters()) + list(critics2[i].parameters())
        critic_opts.append(optim.Adam(params, lr=args.lr_critic))
    mixer_opt = optim.Adam(mixer.parameters(), lr=args.lr_mixer)

    # 6.5 Bufor replay
    replay = ReplayBuffer(max_size=args.buffer_size)

    # 6.6 TensorBoard
    writer = SummaryWriter(comment=f"_{args.exp_name}")
    global_step = 0
    episode_rewards = []

    def soft_update(source_net, target_net, tau):
        for p, tp in zip(source_net.parameters(), target_net.parameters()):
            tp.data.copy_(tau*p.data + (1.0 - tau)*tp.data)

    def evaluate_policy(episode_idx):
        """
        Uruchamia kilka epizodów w trybie testowym (bez noise), liczy średni zwrot.
        Loguje do TB.
        """
        eval_episodes = args.num_eval_episodes
        returns = []
        for _ in range(eval_episodes):
            obs_dict, _ = env.reset()
            done_dict = {ag: False for ag in agents}
            ep_ret = 0.0
            while not all(done_dict.values()):
                actions_dict = {}
                for i, ag in enumerate(agents):
                    # Deterministycznie: bierzemy mean sieci
                    obs_i = torch.FloatTensor(obs_dict[ag]).unsqueeze(0)
                    with torch.no_grad():
                        mean, _ = actors[i].forward(obs_i)
                        action = torch.tanh(mean) # od -1 do 1
                        action = .5 * (action + 1)  # od 0 do 1
                    actions_dict[ag] = action.cpu().numpy().flatten()

                next_obs, rews, done_dict, _, _ = env.step(actions_dict)
                ep_ret += sum(rews.values())
                obs_dict = next_obs
            returns.append(ep_ret)
        avg_ret = np.mean(returns)
        writer.add_scalar("evaluate/avg_return", avg_ret, episode_idx)
        print(f"[EVAL] Episode {episode_idx}, avg_return={avg_ret:.3f}")

    start_time = time.time()

    # 6.7 Główna pętla treningowa
    for ep in range(args.total_episodes):
        obs_dict, _ = env.reset()
        done_dict = {ag: False for ag in agents}
        state_np = env.state()
        ep_ret = 0.0
        step = 0

        while not all(done_dict.values()) and step < args.max_cycles:
            actions_dict = {}
            local_obs_list = []
            for i, ag in enumerate(agents):
                obs_i = torch.FloatTensor(obs_dict[ag]).unsqueeze(0)
                with torch.no_grad():
                    act_i, _ = actors[i].sample(obs_i)
                actions_dict[ag] = act_i.cpu().numpy().flatten()
                local_obs_list.append(obs_dict[ag])

            next_obs_dict, rew_dict, done_dict, _, _ = env.step(actions_dict)
            next_state_np = env.state()
            reward = sum(rew_dict.values())
            done = any(done_dict.values())  # ep. done

            # Zapis do bufora
            act_list = [actions_dict[ag] for ag in agents]
            next_local_obs_list = [next_obs_dict[ag] for ag in agents]

            replay.add(
                state_np,
                local_obs_list,
                act_list,
                reward,
                next_state_np,
                next_local_obs_list,
                done
            )

            ep_ret += reward
            state_np = next_state_np
            obs_dict = next_obs_dict
            step += 1
            global_step += 1

            # Trening
            if (global_step > args.learning_starts) and (global_step % args.train_freq == 0):
                # sample
                (states_b, obs_b, acts_b, rews_b, next_states_b, next_obs_b, dones_b) = replay.sample(args.batch_size)
                # states_b: tuple (batch_size, ), musimy skonwertować
                batch_states = torch.FloatTensor(np.array(states_b))
                batch_next_states = torch.FloatTensor(np.array(next_states_b))
                batch_acts = torch.FloatTensor(np.array(acts_b))  # shape (B, N, act_dim_each)
                batch_acts = batch_acts.view(args.batch_size, -1)  # (B, act_dim_total)
                batch_rewards = torch.FloatTensor(np.array(rews_b)).unsqueeze(1)  # (B,1)
                batch_dones = torch.FloatTensor(np.array(dones_b)).unsqueeze(1)   # (B,1)

                # ---------- Oblicz target -----------
                with torch.no_grad():
                    # next akcje (z actorów)
                    all_next_actions = []
                    sum_log_prob_next = torch.zeros((args.batch_size,1))
                    for i in range(n_agents):
                        # Z next_obs_b
                        agent_next_obs = []
                        for b_idx in range(args.batch_size):
                            agent_next_obs.append(next_obs_b[b_idx][i])
                        agent_next_obs = torch.FloatTensor(np.array(agent_next_obs))
                        # sample
                        a_next_i, logp_next_i = actors[i].sample(agent_next_obs)
                        all_next_actions.append(a_next_i)
                        sum_log_prob_next += logp_next_i

                    # joint action
                    next_joint_actions = torch.cat(all_next_actions, dim=1)
                    # Q^i z targetów
                    all_qi_next = []
                    for i in range(n_agents):
                        inp = torch.cat([batch_next_states, next_joint_actions], dim=1)
                        q1_val = target_critics1[i](inp)
                        q2_val = target_critics2[i](inp)
                        qi_next = torch.min(q1_val, q2_val)  # (B,1)
                        all_qi_next.append(qi_next)
                    q_i_cat_next = torch.cat(all_qi_next, dim=1)  # (B,N)
                    q_tot_next = target_mixer(q_i_cat_next, batch_next_states)  # (B,1)

                    # Odejmiemy alpha * sum_log_prob
                    q_tot_next = q_tot_next - args.alpha * sum_log_prob_next

                    # Bellman
                    y = batch_rewards + (1 - batch_dones) * args.gamma * q_tot_next

                # ---------- Oblicz Q^i current i Q_tot ----------
                # critics1[i], critics2[i]
                all_qi_1 = []
                all_qi_2 = []
                for i in range(n_agents):
                    inp = torch.cat([batch_states, batch_acts], dim=1)
                    q1_val = critics1[i](inp)
                    q2_val = critics2[i](inp)
                    all_qi_1.append(q1_val)
                    all_qi_2.append(q2_val)

                q1_cat = torch.cat(all_qi_1, dim=1)  # (B,N)
                q2_cat = torch.cat(all_qi_2, dim=1)  # (B,N)
                q_tot_1 = mixer(q1_cat, batch_states)  # (B,1)
                q_tot_2 = mixer(q2_cat, batch_states)  # (B,1)
                q_tot_current = torch.min(q_tot_1, q_tot_2)

                critic_loss = F.mse_loss(q_tot_current, y)

                # ---------- Update krytyków + mixera ----------
                for opt in critic_opts:
                    opt.zero_grad()
                mixer_opt.zero_grad()
                critic_loss.backward()
                for opt in critic_opts:
                    opt.step()
                mixer_opt.step()

                # ---------- Update aktorów (SAC) ------------
                # Podobnie jak w pseudo-kodzie: agent i
                for i in range(n_agents):
                    # "Nową" akcję daje actor i, reszta agentów => stara
                    old_actions = batch_acts.clone()  # shape (B, N*act_dim_each)

                    # Wyciągamy batch local_obs i
                    agent_obs = []
                    for b_idx in range(args.batch_size):
                        agent_obs.append(obs_b[b_idx][i])
                    agent_obs = torch.FloatTensor(np.array(agent_obs))

                    new_action_i, new_logp_i = actors[i].sample(agent_obs)
                    # Podmieniamy w old_actions
                    # reshape -> (B, N, act_dim_each)
                    old_actions_resh = old_actions.view(args.batch_size, n_agents, act_dim_each)
                    old_actions_resh[:, i, :] = new_action_i
                    new_joint_action = old_actions_resh.view(args.batch_size, -1)

                    # Obliczamy Q^i z critics1[i], critics2[i]
                    input_new = torch.cat([batch_states, new_joint_action], dim=1)
                    q1_val_i = critics1[i](input_new)
                    q2_val_i = critics2[i](input_new)
                    q_val_i_min = torch.min(q1_val_i, q2_val_i)

                    # Potrzebujemy Q^j dla j != i, bierzemy ze starych akcji?
                    # Dla uproszczenia: liczymy on-the-fly
                    # final wektor [Q^1, Q^2, ..., Q^N]
                    q_current_list = []
                    for j in range(n_agents):
                        if j == i:
                            q_current_list.append(q_val_i_min)  # (B,1)
                        else:
                            # stara akcja (B, N, act_dim_each)
                            # w oryg. artykule QMIX jest robione w pętli,
                            # by gradient nie przepływał przez aktor j
                            # => ok, bierzemy "no_grad"?
                            with torch.no_grad():
                                inp_j = torch.cat([batch_states, batch_acts], dim=1)
                                q1_j = critics1[j](inp_j)
                                q2_j = critics2[j](inp_j)
                                q_j_min = torch.min(q1_j, q2_j)
                            q_current_list.append(q_j_min)
                    # cat => (B, N)
                    q_cat_actor = torch.cat(q_current_list, dim=1)
                    q_tot_actor = mixer(q_cat_actor, batch_states)

                    # entropia -> alpha * logp_i
                    actor_loss = -(q_tot_actor - args.alpha * new_logp_i).mean()

                    actor_opts[i].zero_grad()
                    actor_loss.backward()
                    actor_opts[i].step()
                    
                    writer.add_scalar(f"loss/actor_loss_agent_{i}", actor_loss.item(), global_step)

                # ---------- Soft update targetów ------------
                for i in range(n_agents):
                    soft_update(critics1[i], target_critics1[i], args.tau)
                    soft_update(critics2[i], target_critics2[i], args.tau)
                soft_update(mixer, target_mixer, args.tau)

                # ---------- logi do TB -----------
                writer.add_scalar("loss/critic_loss", critic_loss.item(), global_step)
                writer.add_scalar("loss/qmix_loss", critic_loss.item(), global_step)
                writer.add_scalar("q_values/q_tot", q_tot_current.mean().item(), global_step)
                for i in range(n_agents):
                    writer.add_scalar(f"q_values/q_agent_{i}", q1_cat[:, i].mean().item(), global_step)

                writer.add_scalar("charts/SPS", global_step / (time.time() - start_time), global_step)

        # koniec epizodu
        episode_rewards.append(ep_ret)
        writer.add_scalar("charts/episodic_return", ep_ret, ep)
        writer.add_scalar("charts/average_return", np.mean(episode_rewards[-100:]), ep)

        # ewaluacja co X epizodów
        if (ep+1) % args.eval_freq == 0:
            evaluate_policy(ep+1)

    env.close()
    writer.close()

    return episode_rewards


# =====================================================================
# 7. Uruchamianie
# =====================================================================
if __name__ == "__main__":
    args = get_args()
    print(args)
    rewards = run_sac_qmix(args)
    print("Done. Last 10 episodes avg return:", np.mean(rewards[-10:]))


Namespace(seed=42, total_episodes=1000000, max_cycles=25, n_agents=3, buffer_size=2000000, batch_size=256, gamma=0.98, tau=0.05, alpha=0.005, lr_actor=0.0001, lr_critic=0.0001, lr_mixer=0.0001, learning_starts=200, train_freq=1, eval_freq=50, num_eval_episodes=4, exp_name='SAC_QMIX_spread')
Box(0.0, 1.0, (5,), float32)


ValueError: Sample larger than population or is negative