# PPO-based PySC2 Agent

Modular Design
sc2_ppo_project/

├── main.py             # Entry-point to run training

├── config.py           # Hyperparameters and logging config

├── environment.py      # SC2 environment wrapper

├── model.py            # Actor-Critic neural network

├── utils.py            # Observation preprocessing and action utilities

└── ppo.py              # PPO training algorithm implementation



In [1]:
# config.py

In [20]:
import torch
import logging
import os

# ─── Hyperparameters ─────────────────────────────
# Shared settings read by the environment wrapper, the model and the PPO loops.
MAP_NAME     = "DefeatZerglingsAndBanelings"   # default map_name of SC2EnvsMulti
# MAP_NAME     = "Simple64"

SCREEN_SIZE  = 84      # side length of the square screen tensors fed to the network
MINIMAP_SIZE = 64      # NOTE(review): only referenced by the commented-out SC2Envs wrapper
STEP_MUL     = 16      # NOTE(review): SC2EnvsMulti uses its own step_mul default, not this value
NB_ACTORS    = 1       # number of environments stepped during a rollout
T            = 128     # rollout length: steps collected per environment per iteration
K            = 10      # passes over the collected rollout per iteration
BATCH_SIZE   = 256     # minibatch size of the DataLoader used in the update
GAMMA        = 0.99    # discount factor in the GAE recursion
GAE_LAMBDA   = 0.95    # lambda of the GAE recursion
LR           = 2.5e-4  # initial Adam learning rate (decayed linearly by the scheduler)
ENT_COEF     = 0.01    # weight of the entropy term in the loss
VF_COEF      = 1.0     # weight of the value loss in the loss
MAX_ITERS    = 1000    # number of rollout/update iterations; also the LR-decay horizon
DEVICE       = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ─── Replay and Dataset Directories ─────────────
REPLAY_DIR   = os.path.join("replays")             # Where to save .SC2Replay files
DATASET_PATH = os.path.join("dataset.pkl")         # Where to save (obs, action) dataset
# NOTE(review): DATASET_PATH is not read anywhere in the visible code.

# ─── Logging Configuration ──────────────────────
# Configures the root logger at import time: INFO level, "HH:MM:SS [LEVEL] msg".
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s",
    datefmt="%H:%M:%S"
)
logger = logging.getLogger(__name__)

# Ensure replay directory exists
# (relative path, so it is created under the current working directory)
os.makedirs(REPLAY_DIR, exist_ok=True)


In [21]:
# environment.py

In [22]:
# from pysc2.env import sc2_env
# from pysc2.lib import actions, features
# # from config import MAP_NAME, SCREEN_SIZE, MINIMAP_SIZE, STEP_MUL, logger

# class SC2Envs:
#     def __init__(self, nb_actor):
#         logger.info("Initializing %d SC2 env(s)...", nb_actor)
#         self.nb   = nb_actor
#         self.envs = [self._make_env() for _ in range(nb_actor)]
#         self.obs  = [None]*nb_actor
#         self.done = [False]*nb_actor
#         self._init_all()
#         logger.info("All SC2 env(s) ready.")

#     def _make_env(self):
#         return sc2_env.SC2Env(
#             map_name=MAP_NAME,
#             players=[sc2_env.Agent(sc2_env.Race.terran)],
#             agent_interface_format=features.AgentInterfaceFormat(
#                 feature_dimensions=features.Dimensions(
#                     screen=SCREEN_SIZE, minimap=MINIMAP_SIZE),
#                 use_feature_units=True,
#                 use_raw_units=False,
#                 use_camera_position=True,
#                 action_space=actions.ActionSpace.FEATURES
#             ),
#             step_mul=STEP_MUL,
#             game_steps_per_episode=0,
#             visualize=False,
#         )

#     def _init_all(self):
#         for i, e in enumerate(self.envs):
#             ts = e.reset()[0]
#             self.obs[i], self.done[i] = ts, False

#     def reset(self, i):
#         ts = self.envs[i].reset()[0]
#         self.obs[i], self.done[i] = ts, False
#         return ts

#     def step(self, i, fc):
#         ts = self.envs[i].step([fc])[0]
#         self.obs[i], self.done[i] = ts, ts.last()
#         return ts

#     def close(self):
#         for e in self.envs:
#             e.close()


In [39]:
import os
from pysc2.env import sc2_env
from pysc2.lib import actions, features
import logging

logger = logging.getLogger(__name__)

class SC2EnvsMulti:
    """Hold ``nb_actor`` PySC2 environments together with their latest state.

    Attributes:
        nb: number of environments.
        envs: the ``sc2_env.SC2Env`` instances, indexed by actor.
        obs: latest timestep of each environment (first agent only).
        done: whether the latest timestep of each environment ended an episode.
        replay_dir: absolute directory handed to PySC2 for replay saving.
        replay_prefix: file-name prefix handed to PySC2 for replay saving.
    """

    def __init__(
        self,
        nb_actor,
        replay_dir="replays",
        replay_prefix="run",
        map_name=MAP_NAME,
        screen_size=84,
        minimap_size=64,
        step_mul=16,
    ):
        """Launch and reset every environment.

        Args:
            nb_actor: number of environments to create.
            replay_dir: replay folder, resolved against the current working
                directory and created if missing.
            replay_prefix: prefix for saved replay files.
            map_name: PySC2 map to load.
            screen_size: feature-screen resolution.
            minimap_size: feature-minimap resolution.
            step_mul: value forwarded to ``SC2Env(step_mul=...)``.

        Raises:
            Whatever PySC2 raises while launching or resetting; environments
            that were already started are closed before the error propagates.
        """
        self.nb = nb_actor

        # Resolve against the working directory so replays land in the project
        # folder rather than wherever PySC2 would put a relative path.
        self.replay_dir = os.path.join(os.getcwd(), replay_dir)
        os.makedirs(self.replay_dir, exist_ok=True)

        self.replay_prefix = replay_prefix

        logger.info("Initializing %d SC2 env(s)…", self.nb)
        self.envs = []
        try:
            for _ in range(self.nb):
                self.envs.append(
                    self._make_env(map_name, screen_size, minimap_size, step_mul)
                )
            self.obs = [env.reset()[0] for env in self.envs]
        except BaseException:
            # Each env owns a running SC2 process; do not leave the ones that
            # did start behind when a later launch/reset fails.
            self.close()
            raise
        self.done = [False] * self.nb

    def _make_env(self, map_name, screen_size, minimap_size, step_mul):
        """Create one agent-vs-very-easy-bot environment that saves every replay."""
        # NOTE(review): an Agent + Bot pair needs a two-player map; confirm the
        # default MAP_NAME supports a second player.
        return sc2_env.SC2Env(
            map_name=map_name,
            players=[
                sc2_env.Agent(sc2_env.Race.terran),
                sc2_env.Bot(sc2_env.Race.terran, sc2_env.Difficulty.very_easy),
            ],
            agent_interface_format=features.AgentInterfaceFormat(
                feature_dimensions=features.Dimensions(
                    screen=screen_size, minimap=minimap_size
                ),
                use_feature_units=True,
            ),
            step_mul=step_mul,
            visualize=False,
            save_replay_episodes=1,
            replay_dir=self.replay_dir,
            replay_prefix=self.replay_prefix,
        )

    def step(self, i, action):
        """Apply ``action`` in environment ``i`` and return the new timestep."""
        ts = self.envs[i].step([action])[0]
        self.obs[i]  = ts
        self.done[i] = ts.last()
        return ts

    def reset(self, i):
        """Reset environment ``i`` and return its first timestep."""
        ts = self.envs[i].reset()[0]
        self.obs[i]  = ts
        self.done[i] = False
        return ts

    def close(self):
        """Close every environment that was created."""
        for env in self.envs:
            env.close()


In [40]:
# model.py

In [41]:
import torch
import torch.nn as nn
# from config import SCREEN_SIZE, DEVICE  # ✅ Import shared config

class ActorCritic(nn.Module):
    """Convolutional actor-critic with a shared trunk and two linear heads.

    Args:
        in_channels: number of channels of the square input image.
        nb_actions: size of the discrete action space (actor output width).
        screen_size: side length of the input image; defaults to the
            module-level ``SCREEN_SIZE`` when omitted.
    """

    def __init__(self, in_channels, nb_actions, screen_size=None):
        super().__init__()
        if screen_size is None:
            screen_size = SCREEN_SIZE

        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, 16, 8, stride=4), nn.Tanh(),
            nn.Conv2d(16, 32, 4, stride=2), nn.Tanh(),
            nn.Flatten(),
        )
        # Probe the flattened feature size with a dummy forward pass. The
        # layers were just created on the CPU, so the probe must stay on the
        # CPU too; moving it to a CUDA device here raises a device mismatch.
        with torch.no_grad():
            dummy = torch.zeros(1, in_channels, screen_size, screen_size)
            conv_out = self.conv(dummy).shape[-1]

        self.fc     = nn.Sequential(nn.Linear(conv_out, 256), nn.Tanh())
        self.actor  = nn.Linear(256, nb_actions)
        self.critic = nn.Linear(256, 1)

    def forward(self, x):
        """Return ``(logits, value)`` for a batch ``x``.

        ``logits`` has one row of ``nb_actions`` scores per sample; ``value``
        has the trailing singleton dimension of the critic output removed.
        """
        h = self.conv(x)
        h = self.fc(h)
        return self.actor(h), self.critic(h).squeeze(-1)


In [42]:
# utils.py

In [43]:
# import torch
# import numpy as np
# import random
# from pysc2.lib import actions, features
# # from config import DEVICE

# _PLAYER_RELATIVE = features.SCREEN_FEATURES.player_relative.index
# _UNIT_TYPE       = features.SCREEN_FEATURES.unit_type.index

# ACTION_LIST = ['do_nothing', 'select_idle', 'build_refinery', 'harvest']
# FUNC_ID = {
#     'do_nothing': actions.FUNCTIONS.no_op.id,
#     'select_idle': actions.FUNCTIONS.select_idle_worker.id,
#     'build_refinery': actions.FUNCTIONS.Build_Refinery_screen.id,
#     'harvest': actions.FUNCTIONS.Harvest_Gather_screen.id,
# }

# def preprocess(ts):
#     fs = ts.observation.feature_screen
#     pr = fs[_PLAYER_RELATIVE].astype(np.float32) / 4.0
#     ut = fs[_UNIT_TYPE].astype(np.float32) / fs[_UNIT_TYPE].max()
#     stacked = np.stack([pr, ut], axis=0)
#     return torch.from_numpy(stacked).unsqueeze(0).float().to(DEVICE)

# def legal_actions(ts):
#     avail = set(ts.observation.available_actions)
#     fus   = ts.observation.feature_units
#     legal = [0]
#     if FUNC_ID['select_idle'] in avail: legal.append(1)
#     if FUNC_ID['build_refinery'] in avail and any(u.unit_type==342 for u in fus): legal.append(2)
#     if FUNC_ID['harvest'] in avail and any(u.unit_type==341 for u in fus): legal.append(3)
#     return legal

# def make_pysc2_call(action_idx, ts):
#     name, fid = ACTION_LIST[action_idx], FUNC_ID[ACTION_LIST[action_idx]]
#     if name == 'select_idle':
#         return actions.FunctionCall(fid, [[2]])
#     if name in ('build_refinery','harvest'):
#         fus = ts.observation.feature_units
#         cand = [u for u in fus if (u.unit_type==342 if name=='build_refinery' else u.unit_type==341)]
#         if not cand:
#             return actions.FunctionCall(actions.FUNCTIONS.no_op.id, [])
#         u = random.choice(cand)
#         return actions.FunctionCall(fid, [[0],[u.x,u.y]])
#     return actions.FunctionCall(fid, [])


In [44]:
import torch
import numpy as np
import random
from pysc2.lib import actions, features

# ─── Constants ────────────────────────────────────────────────────────────────
# Channel indices of the two feature-screen layers that preprocess() stacks.
_PLAYER_RELATIVE = features.SCREEN_FEATURES.player_relative.index
_UNIT_TYPE       = features.SCREEN_FEATURES.unit_type.index

# High-level action vocabulary of the policy. The position of a name in this
# list is the integer the policy emits; make_pysc2_call() turns it into a
# concrete PySC2 function call.
ACTION_LIST  = ['select', 'do_nothing','build', 'gather', 'move', 'attack','train', 'upgrade']
# Reverse lookup: action name -> index in ACTION_LIST.
ACTION_INDEX = {name: idx for idx, name in enumerate(ACTION_LIST)}

# Side length of the square feature screen; also the default clamp bound below.
SCREEN_SIZE = 84

# NOTE(review): not referenced elsewhere in the visible code; presumably
# unit-type ids of Terran structures — verify before relying on it.
TERRAN_STRUCTURE_TYPES = [
    18, 20, 21, 22, 23, 24, 25, 27, 28, 29, 30,
    130, 131, 132, 133
]

# ─── Observation Preprocessing ───────────────────────────────────────────────
def safe_coords(x, y, screen_size=SCREEN_SIZE):
    """Clamp a screen point into ``[0, screen_size - 1]`` on both axes.

    Returns the clamped point as a two-element list ``[x, y]``.
    """
    upper = screen_size - 1
    return [max(0, min(upper, coord)) for coord in (x, y)]


def preprocess(ts):
    """Turn a timestep into a ``(1, 2, H, W)`` float32 CPU tensor.

    Channel 0 is the ``player_relative`` screen layer divided by 4.0
    (presumably the largest player-relative id — confirm against PySC2).
    Channel 1 is the ``unit_type`` layer scaled by its own maximum.

    When the ``unit_type`` layer is all zeros (nothing visible) it is left
    as zeros; dividing by the zero maximum would fill the channel with NaNs
    and poison the network.

    The tensor is not moved to any device; callers do that.
    """
    screen = ts.observation.feature_screen
    relative = screen[_PLAYER_RELATIVE].astype(np.float32) / 4.0

    unit_type = screen[_UNIT_TYPE].astype(np.float32)
    peak = unit_type.max()
    if peak > 0:
        unit_type = unit_type / peak

    stacked = np.stack([relative, unit_type], axis=0)
    return torch.from_numpy(stacked).unsqueeze(0).float()

def legal_actions(ts):
    """List the ``ACTION_INDEX`` values that are currently selectable.

    ``do_nothing`` is always included. Every other entry depends on which
    PySC2 functions the observation reports as available. The result is
    ordered: select, do_nothing, move, attack, build, gather, upgrade, train.
    """
    observation = ts.observation
    available = set(observation.available_actions)
    units = observation.feature_units
    fns = actions.FUNCTIONS

    # Names of all available functions, resolved once for the keyword checks.
    names = [fns[fn_id].name for fn_id in available]

    choices = []

    # Selecting a unit by point.
    if fns.select_point.id in available:
        choices.append(ACTION_INDEX['select'])

    # no_op is the unconditional fallback.
    choices.append(ACTION_INDEX['do_nothing'])

    # Movement and combat map one-to-one onto a screen function.
    for key, fn in (('move', fns.Move_screen), ('attack', fns.Attack_screen)):
        if fn.id in available:
            choices.append(ACTION_INDEX[key])

    # Building: any targeted (non-"_quick") Build function.
    if any('Build' in name and not name.endswith('_quick') for name in names):
        choices.append(ACTION_INDEX['build'])

    # Gathering additionally needs a unit of type 341 on screen.
    if fns.Harvest_Gather_screen.id in available:
        if any(unit.unit_type == 341 for unit in units):
            choices.append(ACTION_INDEX['gather'])

    # Research and unit production are matched by name fragment.
    if any('Research' in name for name in names):
        choices.append(ACTION_INDEX['upgrade'])
    if any('Train' in name for name in names):
        choices.append(ACTION_INDEX['train'])

    return choices



# ─── PySC2 Action Execution Wrapper ─────────────────────────────────────────
def make_pysc2_call(action_idx, ts, pending=None):
    """Translate a high-level action into a PySC2 ``FunctionCall``.

    Every action except ``do_nothing`` takes two environment steps: first a
    friendly unit is selected, then the actual command is issued. The second
    step is described by the returned ``new_pending`` dict, which the caller
    passes back as ``pending`` on the next call.

    Args:
        action_idx: index into ``ACTION_LIST``; ignored when ``pending`` is set.
        ts: current timestep (its observation supplies the available actions,
            feature units and feature screen).
        pending: follow-up produced by a previous call, or ``None``.

    Returns:
        ``(FunctionCall, new_pending)``. ``new_pending`` is ``None`` or a dict
        with keys ``'action_fn'`` (function id) and ``'args'``.
    """
    fns   = actions.FUNCTIONS
    obs   = ts.observation
    fus   = obs.feature_units
    avail = set(obs.available_actions)

    def _no_op():
        return actions.FunctionCall(fns.no_op.id, [])

    def _pad(fn_id, args):
        # Append [0] for every argument the function expects but was not given.
        missing = len(fns[fn_id].args) - len(args)
        if missing > 0:
            return args + [[0] for _ in range(missing)]
        return args

    def _targeted(fn_id, x, y):
        # Follow-up for a "<name>_screen" function: not queued, clamped target.
        return {'action_fn': fn_id, 'args': [[0], safe_coords(x, y)]}

    def _quick(fragment):
        # Follow-up for a random available function whose name contains fragment.
        opts = [a for a in avail if fragment in fns[a].name]
        if not opts:
            return None
        return {'action_fn': random.choice(opts), 'args': [[0]]}

    # 1) Fire pending follow-up (or give up if it is no longer available)
    if pending:
        fn_id = pending['action_fn']
        if fn_id in avail:
            return actions.FunctionCall(fn_id, _pad(fn_id, pending['args'])), None
        return _no_op(), None

    # 2) Handle explicit no_op
    if action_idx == ACTION_INDEX['do_nothing']:
        return _no_op(), None

    # 3) Pick best select primitive
    if fns.select_point.id in avail:
        select_fn = fns.select_point.id
    elif fns.select_unit.id in avail:
        select_fn = fns.select_unit.id
    elif fns.select_rect.id in avail:
        select_fn = fns.select_rect.id
    else:
        return _no_op(), None

    # choose a random friendly unit
    self_units = [u for u in fus if u.alliance == features.PlayerRelative.SELF]
    if not self_units:
        return _no_op(), None

    unit = random.choice(self_units)
    x, y = safe_coords(unit.x, unit.y)

    # Each primitive has its own argument layout; reusing the select_point
    # layout for the fallbacks produces calls PySC2 rejects.
    if select_fn == fns.select_point.id:
        select_args = [[0], [x, y]]
    elif select_fn == fns.select_unit.id:
        # (action, index into the current multi-selection)
        select_args = [[0], [0]]
    else:
        # select_rect: (add, corner, opposite corner) — a small box on the unit
        select_args = [[0], safe_coords(x - 1, y - 1), safe_coords(x + 1, y + 1)]
    select = actions.FunctionCall(select_fn, _pad(select_fn, select_args))

    # 4) Prepare any follow-up action (with clamped targets)
    pending = None
    if action_idx == ACTION_INDEX['train']:
        pending = _quick('Train')

    elif action_idx == ACTION_INDEX['move'] and fns.Move_screen.id in avail:
        pending = _targeted(
            fns.Move_screen.id,
            random.randrange(0, SCREEN_SIZE),
            random.randrange(0, SCREEN_SIZE),
        )

    elif action_idx == ACTION_INDEX['attack'] and fns.Attack_screen.id in avail:
        enemies = [u for u in fus if u.alliance == features.PlayerRelative.ENEMY]
        if enemies:
            pending = _targeted(fns.Attack_screen.id, enemies[0].x, enemies[0].y)

    elif action_idx == ACTION_INDEX['gather'] and fns.Harvest_Gather_screen.id in avail:
        minerals = [u for u in fus if u.unit_type == 341]
        if minerals:
            pending = _targeted(
                fns.Harvest_Gather_screen.id, minerals[0].x, minerals[0].y
            )

    elif action_idx == ACTION_INDEX['build']:
        opts = [
            a for a in avail
            if 'Build' in fns[a].name and not fns[a].name.endswith('_quick')
        ]
        if opts:
            fn_id     = random.choice(opts)
            buildable = np.argwhere(obs.feature_screen.buildable == 1)
            if len(buildable):
                # Index explicitly: random.choice() truth-tests its argument
                # on Python 3.11, which raises for a 2-D numpy array.
                by, bx = buildable[random.randrange(len(buildable))]
                pending = _targeted(fn_id, int(bx), int(by))

    elif action_idx == ACTION_INDEX['upgrade']:
        pending = _quick('Research')

    return select, pending

In [45]:
# PPO training LOOP

In [46]:
import torch
import torch.nn as nn
from torch.distributions.categorical import Categorical
from torch.utils.data import DataLoader, TensorDataset
import matplotlib.pyplot as plt
import pickle
import os

from pysc2.lib import actions
from config import *
# from utils import preprocess, legal_actions, make_pysc2_call, extract_enemy_units, infer_enemy_action

def PPO(envs, model):
    """Train ``model`` with clipped PPO on the environments held by ``envs``.

    Each of the ``MAX_ITERS`` iterations collects ``T`` agent steps per
    environment, computes GAE advantages, and runs ``K`` shuffled passes of
    minibatch updates. One agent step executes the selected high-level action
    completely: the select call returned by ``make_pysc2_call`` plus, when
    there is one, its follow-up command. Rewards of both are summed.

    After training, one replay per environment is saved under ``REPLAY_DIR``,
    the per-episode returns are plotted to ``learning_curve.png``, and the
    environments are closed.

    Args:
        envs: environment container exposing ``nb``, ``obs``, ``envs``,
            ``step(i, call)``, ``reset(i)`` and ``close()``.
        model: module returning ``(logits, value)``; expected on ``DEVICE``.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=LR)
    scheduler = torch.optim.lr_scheduler.LinearLR(
        optimizer, start_factor=1.0, end_factor=0.0, total_iters=MAX_ITERS
    )

    ep_rewards     = []
    # Return of the episode in progress, per env. Kept across iterations
    # because an episode can span several rollouts.
    running_return = [0.0] * envs.nb

    def step_or_noop(i, call):
        """Step env ``i``; fall back to no_op when PySC2 rejects the call."""
        try:
            return envs.step(i, call)
        except ValueError:
            return envs.step(i, actions.FunctionCall(actions.FUNCTIONS.no_op.id, []))

    logger.info("▶️  Starting PPO for %d iterations", MAX_ITERS)
    for it in range(MAX_ITERS):
        if it % 1000 == 0:
            logger.info("🔄 Iter %d / %d", it, MAX_ITERS)

        # storage buffers
        obs_buf   = torch.zeros(envs.nb, T, 2, SCREEN_SIZE, SCREEN_SIZE, device=DEVICE)
        act_buf   = torch.zeros(envs.nb, T,      dtype=torch.long, device=DEVICE)
        logp_buf  = torch.zeros(envs.nb, T,                     device=DEVICE)
        val_buf   = torch.zeros(envs.nb, T+1,                   device=DEVICE)
        rew_buf   = torch.zeros(envs.nb, T,                     device=DEVICE)
        done_buf  = torch.zeros(envs.nb, T,                     device=DEVICE)
        adv_buf   = torch.zeros(envs.nb, T,                     device=DEVICE)
        legal_buf = None  # allocated once the logits width is known

        # ─── Rollout ─────────────────────────────────────────────────────────
        with torch.no_grad():
            for t in range(T):
                for i in range(envs.nb):
                    ts    = envs.obs[i]
                    # preprocess() returns a CPU tensor; the model lives on DEVICE.
                    state = preprocess(ts).to(DEVICE)
                    logits, value = model(state)

                    # mask illegal
                    LA   = legal_actions(ts)
                    mask = torch.full_like(logits, float('-inf'))
                    mask[0, LA] = 0.0
                    dist = Categorical(logits=logits + mask)

                    action = dist.sample()
                    logp   = dist.log_prob(action)

                    # make_pysc2_call returns (call, follow_up); run both so
                    # the sampled action actually takes effect.
                    call, follow_up = make_pysc2_call(action.item(), ts)
                    ts2 = step_or_noop(i, call)
                    r   = ts2.reward
                    if follow_up is not None and not ts2.last():
                        call, _ = make_pysc2_call(None, ts2, follow_up)
                        ts2 = step_or_noop(i, call)
                        r  += ts2.reward
                    d = float(ts2.last())

                    if legal_buf is None:
                        legal_buf = torch.zeros(
                            envs.nb, T, logits.shape[-1],
                            dtype=torch.bool, device=DEVICE,
                        )

                    obs_buf[i,t]  = state
                    act_buf[i,t]  = action
                    logp_buf[i,t] = logp
                    val_buf[i,t]  = value
                    rew_buf[i,t]  = r
                    done_buf[i,t] = d
                    legal_buf[i, t, LA] = True

                    running_return[i] += float(r)
                    if d:
                        ep_rewards.append(running_return[i])
                        running_return[i] = 0.0
                        envs.reset(i)

            # bootstrap value of the state following the last stored step
            for i in range(envs.nb):
                val_buf[i,T] = model(preprocess(envs.obs[i]).to(DEVICE))[1]

        # ─── GAE & flatten ────────────────────────────────────────────────────
        for i in range(envs.nb):
            gae = 0
            for t in reversed(range(T)):
                mask  = 1.0 - done_buf[i,t]
                delta = rew_buf[i,t] + GAMMA*val_buf[i,t+1]*mask - val_buf[i,t]
                gae   = delta + GAMMA*GAE_LAMBDA*mask*gae
                adv_buf[i,t] = gae

        b_s  = obs_buf.reshape(-1,2,SCREEN_SIZE,SCREEN_SIZE)
        b_a  = act_buf.reshape(-1)
        b_lp = logp_buf.reshape(-1)
        b_v  = val_buf[:,:T].reshape(-1)
        b_ad = adv_buf.reshape(-1)
        b_ok = legal_buf.reshape(-1, legal_buf.shape[-1])

        # ─── PPO updates ─────────────────────────────────────────────────────
        clip   = 0.1 * (1 - it/MAX_ITERS)
        ds     = TensorDataset(b_s,b_a,b_lp,b_v,b_ad,b_ok)
        loader = DataLoader(ds, batch_size=BATCH_SIZE, shuffle=True)
        for _ in range(K):
            for st, ac, old_lp, old_v, adv, ok in loader:
                logits, val = model(st)
                # Re-apply the rollout's action mask so the new log-prob is
                # taken under the same support as old_lp. A large finite
                # negative keeps gradients finite.
                dist        = Categorical(logits=logits.masked_fill(~ok, -1e9))
                lp          = dist.log_prob(ac)
                ratio       = torch.exp(lp - old_lp)

                obj1   = adv * ratio
                obj2   = adv * torch.clamp(ratio, 1-clip, 1+clip)
                p_loss = -torch.min(obj1,obj2).mean()

                ret     = adv + old_v
                v1      = (val - ret).pow(2)
                v2      = (torch.clamp(val,old_v-clip,old_v+clip)-ret).pow(2)
                v_loss  = 0.5 * torch.max(v1,v2).mean()

                entropy = dist.entropy().mean()
                loss    = p_loss + VF_COEF*v_loss - ENT_COEF*entropy

                optimizer.zero_grad()
                loss.backward()
                nn.utils.clip_grad_norm_(model.parameters(),0.5)
                optimizer.step()

        scheduler.step()

    # ─── Save replays only at end ────────────────────────────────────────────
    # SC2Env exposes save_replay(replay_dir, prefix). An absolute directory
    # keeps the file in the project folder.
    replay_root = os.path.abspath(REPLAY_DIR)
    for i in range(envs.nb):
        envs.envs[i].save_replay(replay_root, prefix=f"ppo_final_{i}")
    logger.info("💾 Saved final replay(s) to %s", REPLAY_DIR)

    # ─── Plot learning curve ─────────────────────────────────────────────────
    plt.figure(figsize=(10,5))
    plt.plot(ep_rewards, label="episode reward")
    plt.title("Environment Reward per Episode")
    plt.xlabel("Episode")
    plt.ylabel("Reward")
    plt.grid(True)
    plt.legend()
    plt.savefig("learning_curve.png")
    plt.show()

    envs.close()
    logger.info("✅ Training complete")
    logger.info(f"Saved learning_curve.png over {len(ep_rewards)} episodes")


In [47]:
# main.py

In [48]:
# from absl import app
# # from environment import SC2Envs
# # from model import ActorCritic
# # from ppo import PPO
# # from config import NB_ACTORS, DEVICE
# # from utils import ACTION_LIST

# def main(_):
#     envs = SC2Envs(NB_ACTORS)
#     model = ActorCritic(2, len(ACTION_LIST)).to(DEVICE)
#     PPO(envs, model)

# if __name__ == "__main__":
#     import sys
#     sys.argv = sys.argv[:1]  # Remove extra flags passed by Jupyter or IPython
#     app.run(main)


In [49]:
# from rich.live import Live
# from rich.table import Table
# from rich.console import Console
# from collections import deque
# import matplotlib.pyplot as plt
# import random
# import sys
# from absl import flags

# flags.FLAGS(sys.argv)  # fix required by pysc2
# # from util import preprocess, legal_actions, make_pysc2_call
# # from env import SC2Envs

# console = Console()
# envs = SC2Envs(nb_actor=1)
# pending_action = [None] * envs.nb

# MAX_ROWS = 20
# recent_rows = deque(maxlen=MAX_ROWS)

# # For tracking per-episode scores
# episode_score = [0] * envs.nb
# scores = []

# def generate_table():
#     table = Table(title=f"SC2 Agent Actions (Last {MAX_ROWS} Steps)", expand=True)
#     table.add_column("Step", justify="right")
#     table.add_column("Function ID", justify="right")
#     table.add_column("Args", justify="left")
#     for row in recent_rows:
#         table.add_row(*row)
#     return table

# with Live(generate_table(), refresh_per_second=10, console=console, transient=True) as live:
#     for step in range(MAX_ITERS):
#         for i in range(envs.nb):
#             ts = envs.obs[i]

#             if pending_action[i]:
#                 action, pending_action[i] = make_pysc2_call(None, ts, pending_action[i])
#             else:
#                 legal = legal_actions(ts)
#                 action_idx = random.choice(legal)
#                 action, pending_action[i] = make_pysc2_call(action_idx, ts)

#             recent_rows.append((str(step), str(action.function), str(action.arguments)))
#             live.update(generate_table())

#             ts = envs.step(i, action)
#             episode_score[i] += ts.reward

#             if ts.last():
#                 scores.append(episode_score[i])
#                 episode_score[i] = 0  # reset
#                 envs.reset(i)

# envs.close()

# # Plot episode scores
# plt.figure(figsize=(10, 4))
# plt.plot(scores, label="Episode Score", marker='o', linewidth=1.5)
# plt.xlabel("Episode")
# plt.ylabel("Total Score")
# plt.title("Agent Score per Episode")
# plt.grid(True)
# plt.legend()
# plt.tight_layout()
# plt.show()


In [50]:
# Campaign Maps

In [51]:
import os
import torch
import torch.nn as nn
from torch.distributions.categorical import Categorical
from torch.utils.data import DataLoader, TensorDataset
import matplotlib.pyplot as plt
import logging

from pysc2.lib import actions
# from config import (
#     NB_ACTORS, T, K, BATCH_SIZE,
#     GAMMA, GAE_LAMBDA, LR, ENT_COEF, VF_COEF,
#     MAX_ITERS, DEVICE, SCREEN_SIZE, REPLAY_DIR
# )
# from utils import preprocess, legal_actions, make_pysc2_call

logger = logging.getLogger(__name__)

def train_PPO(envs, model):
    """Train ``model`` with clipped PPO on the environments held by ``envs``.

    Each of the ``MAX_ITERS`` iterations collects ``T`` agent steps per
    environment, computes GAE advantages, and runs ``K`` shuffled passes of
    minibatch updates. One agent step executes the selected high-level action
    completely: the select call returned by ``make_pysc2_call`` plus, when
    there is one, its follow-up command. Rewards of both are summed.

    After training, one replay per environment is saved under ``REPLAY_DIR``,
    the per-episode returns are plotted to ``learning_curve.png``, and the
    environments are closed.

    Args:
        envs: environment container exposing ``nb``, ``obs``, ``envs``,
            ``step(i, call)``, ``reset(i)`` and ``close()``.
        model: module returning ``(logits, value)``; expected on ``DEVICE``.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=LR)
    scheduler = torch.optim.lr_scheduler.LinearLR(
        optimizer, start_factor=1.0, end_factor=0.0, total_iters=MAX_ITERS
    )

    # Size everything from the container itself, not the global NB_ACTORS,
    # so buffers always match the environments actually passed in.
    nb_envs = envs.nb

    episode_rewards = []
    # Return of the episode in progress, per env. Kept across iterations
    # because an episode can span several rollouts.
    running_return  = [0.0] * nb_envs

    def step_or_noop(i, call):
        """Step env ``i``; fall back to no_op when PySC2 rejects the call."""
        try:
            return envs.step(i, call)
        except ValueError:
            return envs.step(
                i, actions.FunctionCall(actions.FUNCTIONS.no_op.id, [])
            )

    logger.info("▶️  Starting PPO for %d iterations", MAX_ITERS)
    for it in range(1, MAX_ITERS + 1):
        # storage buffers
        obs_buf   = torch.zeros(nb_envs, T, 2, SCREEN_SIZE, SCREEN_SIZE, device=DEVICE)
        act_buf   = torch.zeros(nb_envs, T,      dtype=torch.long, device=DEVICE)
        logp_buf  = torch.zeros(nb_envs, T,                     device=DEVICE)
        val_buf   = torch.zeros(nb_envs, T+1,                   device=DEVICE)
        rew_buf   = torch.zeros(nb_envs, T,                     device=DEVICE)
        done_buf  = torch.zeros(nb_envs, T,                     device=DEVICE)
        adv_buf   = torch.zeros(nb_envs, T,                     device=DEVICE)
        legal_buf = None  # allocated once the logits width is known

        # ─── Rollout ─────────────────────────────────────────────────────────────
        with torch.no_grad():
            for t in range(T):
                for i in range(nb_envs):
                    ts = envs.obs[i]
                    # preprocess() returns a CPU tensor; the model lives on DEVICE.
                    state = preprocess(ts).to(DEVICE)
                    logits, value = model(state)

                    # mask illegal actions
                    legal = legal_actions(ts)
                    mask = torch.full_like(logits, float('-inf'))
                    mask[0, legal] = 0
                    dist = Categorical(logits=logits + mask)

                    action = dist.sample()
                    logp   = dist.log_prob(action)

                    # Run the select call and then its follow-up, so the
                    # sampled action actually takes effect.
                    fn_call, follow_up = make_pysc2_call(action.item(), ts)
                    ts2 = step_or_noop(i, fn_call)
                    r   = ts2.reward
                    if follow_up is not None and not ts2.last():
                        fn_call, _ = make_pysc2_call(None, ts2, follow_up)
                        ts2 = step_or_noop(i, fn_call)
                        r  += ts2.reward
                    d = float(ts2.last())

                    if legal_buf is None:
                        legal_buf = torch.zeros(
                            nb_envs, T, logits.shape[-1],
                            dtype=torch.bool, device=DEVICE,
                        )

                    obs_buf[i,t]  = state
                    act_buf[i,t]  = action
                    logp_buf[i,t] = logp
                    val_buf[i,t]  = value
                    rew_buf[i,t]  = r
                    done_buf[i,t] = d
                    legal_buf[i, t, legal] = True

                    running_return[i] += float(r)
                    if d:
                        episode_rewards.append(running_return[i])
                        running_return[i] = 0.0
                        envs.reset(i)

            # bootstrap value for last state
            for i in range(nb_envs):
                last_state = preprocess(envs.obs[i]).to(DEVICE)
                val_buf[i,T] = model(last_state)[1]

        # ─── Compute GAE and advantages ──────────────────────────────────────────
        for i in range(nb_envs):
            gae = 0
            for t in reversed(range(T)):
                mask = 1.0 - done_buf[i,t]
                delta = rew_buf[i,t] + GAMMA * val_buf[i,t+1] * mask - val_buf[i,t]
                gae = delta + GAMMA * GAE_LAMBDA * mask * gae
                adv_buf[i,t] = gae

        # flatten batches
        b_s  = obs_buf.reshape(-1, 2, SCREEN_SIZE, SCREEN_SIZE)
        b_a  = act_buf.reshape(-1)
        b_lp = logp_buf.reshape(-1)
        b_v  = val_buf[:,:T].reshape(-1)
        b_ad = adv_buf.reshape(-1)
        b_ok = legal_buf.reshape(-1, legal_buf.shape[-1])

        # ─── PPO update ─────────────────────────────────────────────────────────
        # `it` is 1-based; anneal on (it - 1) so the clip range starts at 0.1
        # and never collapses to exactly 0 on the final iteration.
        clip    = 0.1 * (1 - (it - 1) / MAX_ITERS)
        dataset = TensorDataset(b_s, b_a, b_lp, b_v, b_ad, b_ok)
        loader  = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
        for _ in range(K):
            for st, ac, old_lp, old_v, adv, ok in loader:
                logits, value = model(st)
                # Re-apply the rollout's action mask so the new log-prob is
                # taken under the same support as old_lp. A large finite
                # negative keeps gradients finite.
                dist = Categorical(logits=logits.masked_fill(~ok, -1e9))

                lp = dist.log_prob(ac)
                ratio = torch.exp(lp - old_lp)

                obj1 = adv * ratio
                obj2 = adv * torch.clamp(ratio, 1-clip, 1+clip)
                p_loss = -torch.min(obj1, obj2).mean()

                ret    = adv + old_v
                v1     = (value - ret).pow(2)
                v2     = (torch.clamp(value, old_v-clip, old_v+clip) - ret).pow(2)
                v_loss = 0.5 * torch.max(v1, v2).mean()

                entropy = dist.entropy().mean()
                loss = p_loss + VF_COEF * v_loss - ENT_COEF * entropy

                optimizer.zero_grad()
                loss.backward()
                nn.utils.clip_grad_norm_(model.parameters(), 0.5)
                optimizer.step()

        scheduler.step()

    # ─── Save final replays ────────────────────────────────────────────────────
    # SC2Env exposes save_replay(replay_dir, prefix). An absolute directory
    # keeps the file in the project folder.
    os.makedirs(REPLAY_DIR, exist_ok=True)
    replay_root = os.path.abspath(REPLAY_DIR)
    for i in range(nb_envs):
        envs.envs[i].save_replay(replay_root, prefix=f"ppo_final_{i}")
    logger.info("💾 Saved final replay(s) to %s", REPLAY_DIR)

    # ─── Plot learning curve ───────────────────────────────────────────────────
    plt.figure(figsize=(10,5))
    plt.plot(episode_rewards, label="Episode reward")
    plt.xlabel("Episode")
    plt.ylabel("Reward")
    plt.title("PPO Learning Curve")
    plt.grid(True)
    plt.legend()
    plt.savefig("learning_curve.png")
    plt.show()

    envs.close()
    logger.info("✅ PPO training complete over %d episodes", len(episode_rewards))


In [None]:
import os
import random
import sys
import torch
from collections import deque

from rich.live import Live
from rich.table import Table
from rich.console import Console
import matplotlib.pyplot as plt
from absl import flags
from pysc2.lib import actions as sc2_actions

# from environment import SC2EnvsMulti
# from utils      import preprocess, legal_actions, make_pysc2_call, safe_coords
# from model      import ActorCritic        # <— your model class
# from config     import (
#     NUM_EPISODES, NB_ACTORS, SCREEN_SIZE,
#     LR, REPLAY_DIR, REPLAY_PREFIX
# )

# Run-level settings for this cell. NB_ACTORS and REPLAY_DIR rebind the names
# defined in the config cell.
NUM_EPISODES  = 200          # last episode number of the loop below
NB_ACTORS     = 1            # number of SC2 environments to launch
REPLAY_DIR    = "replays"    # forwarded to SC2EnvsMulti(replay_dir=...)
REPLAY_PREFIX = "pysc2_run"  # forwarded to SC2EnvsMulti(replay_prefix=...)


# ─── Fix flags ────────────────────────────────────────────────────────────────
# Parse absl flags up front; known_only=True tolerates the extra arguments a
# notebook kernel puts in sys.argv.
flags.FLAGS(sys.argv, known_only=True)

# ─── Lookup Function ID → Name ───────────────────────────────────────────────
# Used only to print a readable action name in the actions table.
FUNC_ID_TO_NAME = {f.id: f.name for f in sc2_actions.FUNCTIONS}

# ─── Checkpoint setup ─────────────────────────────────────────────────────────
CHECKPOINT_PATH = "ppo_checkpoint.pth"

NUM_ACTIONS = len(ACTION_LIST)  # = 8 if you added 'select' + 7 others

model = ActorCritic(in_channels=2, nb_actions=NUM_ACTIONS).to(DEVICE)
optimizer = torch.optim.Adam(model.parameters(), lr=LR)

# Resume from an existing checkpoint: restore model and optimizer state and
# continue with the episode after the one that was saved.
if os.path.exists(CHECKPOINT_PATH):
    print("▶️  Loading checkpoint…")
    # Tensors are loaded onto the CPU first.
    ckpt = torch.load(CHECKPOINT_PATH, map_location="cpu")
    model.load_state_dict(ckpt["model_state"])
    optimizer.load_state_dict(ckpt["opt_state"])
    start_ep = ckpt["episode"] + 1
else:
    start_ep = 1

# ─── Environment & UI setup ──────────────────────────────────────────────────
console        = Console()
# Constructing the container launches and resets the SC2 environments.
envs           = SC2EnvsMulti(
    nb_actor=NB_ACTORS,
    replay_dir=REPLAY_DIR,
    replay_prefix=REPLAY_PREFIX,
)
pending_action = [None] * NB_ACTORS   # per-env follow-up from make_pysc2_call
recent_rows    = deque(maxlen=20)     # rolling log of the most recent actions
scores         = []                   # episode scores for the final plot

def generate_table():
    """Build a rich ``Table`` listing the rows currently held in ``recent_rows``."""
    layout = (
        ("Step",        "right"),
        ("Func ID",     "right"),
        ("Action Name", "left"),
        ("Args",        "left"),
    )
    shown = len(recent_rows)
    table = Table(title=f"SC2 Agent Actions (Last {shown} Steps)", expand=True)
    for heading, side in layout:
        table.add_column(heading, justify=side)
    for entry in recent_rows:
        table.add_row(*entry)
    return table

from tqdm import tqdm

# ─── Main Episode Loop with TQDM & Static Table ─────────────────────────────
# Plays NUM_EPISODES episodes with uniformly random legal actions and logs the
# last 20 calls of each episode. `model` and `optimizer` are not updated here;
# they are only written back to the checkpoint after every episode.
recent_rows = deque(maxlen=20)
scores      = []

try:
    for ep in tqdm(range(start_ep, NUM_EPISODES + 1), desc="Episodes", unit="ep"):
        # reset envs & trackers
        for i in range(NB_ACTORS):
            envs.reset(i)
        # A follow-up planned in the previous episode is meaningless now.
        pending_action = [None] * NB_ACTORS
        episode_score  = [0] * NB_ACTORS
        step = 0
        console.log(f"[blue]=== Episode {ep} ===[/blue]")

        # clear the last-20 buffer
        recent_rows.clear()

        # Run until every env has reported the end of its episode.
        while not all(envs.done):
            for i in range(NB_ACTORS):
                if envs.done[i]:
                    # Stepping a finished env would start its next episode.
                    continue
                ts = envs.obs[i]

                # 1) Fire pending or sample new
                if pending_action[i]:
                    fn_call, pending_action[i] = make_pysc2_call(None, ts, pending_action[i])
                else:
                    idx = random.choice(legal_actions(ts))
                    fn_call, pending_action[i] = make_pysc2_call(idx, ts)

                # 2) Step env & collect reward
                ts2 = envs.step(i, fn_call)
                episode_score[i] += ts2.reward

                # 3) Log into your deque
                fid      = str(fn_call.function)
                fname    = FUNC_ID_TO_NAME.get(fn_call.function, "UNKNOWN")
                args_str = str(fn_call.arguments)
                recent_rows.append((str(step), fid, fname, args_str))

            step += 1

        # One score per environment for the plot below.
        scores.extend(episode_score)

        # 4) Print your last-20 actions table (static)
        console.print(generate_table())

        # 5) Save checkpoint (once per episode)
        torch.save({
            "episode":     ep,
            "model_state": model.state_dict(),
            "opt_state":   optimizer.state_dict(),
        }, CHECKPOINT_PATH)
        console.log(f"[yellow]Checkpoint saved at episode {ep}[/yellow]")
finally:
    # ─── Cleanup ─────────────────────────────────────────────────────────────
    # Always shut the SC2 processes down, even on an error or interrupt.
    envs.close()

# ─── Plot ───────────────────────────────────────────────────────────────────
plt.figure(figsize=(8,4))
plt.plot(scores, marker="o")
plt.xlabel("Episode")
plt.ylabel("Score")
plt.title("Agent Performance per Episode")
plt.grid(True)
plt.tight_layout()
plt.show()


22:38:06 [INFO] Initializing 1 SC2 env(s)…
22:38:06 [INFO] Launching SC2: D:\Games\StarCraft II\Versions/Base94137\SC2_x64.exe -listen 127.0.0.1 -port 54188 -dataDir D:\Games\StarCraft II\ -tempDir C:\Users\svarp\AppData\Local\Temp\sc-pquu030c\ -displayMode 0 -windowwidth 640 -windowheight 480 -windowx 50 -windowy 50
22:38:06 [INFO] Connecting to: ws://127.0.0.1:54188/sc2api, attempt: 0, running: True


▶️  Loading checkpoint…


22:38:09 [INFO] Connecting to: ws://127.0.0.1:54188/sc2api, attempt: 1, running: True
