In [1]:
import gym
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from procgen import ProcgenEnv
import matplotlib.pyplot as plt

# Reuse the SiT encoder (SietTiny) as the shared encoder for PPO.
# Assumes SietTiny is defined in SiT_model.py, matching the import below.
from SiT_model import SietTiny

class DummyAugment:
    """Placeholder augmentation that leaves observations untouched.

    Instances are callable so they can stand in wherever an augmentation
    callable is expected.
    """

    def __call__(self, obs: torch.Tensor) -> torch.Tensor:
        """Return ``obs`` itself (same tensor object, no copy)."""
        unchanged = obs
        return unchanged

class SiTPPO(nn.Module):
    """Actor-critic network for PPO: a shared SietTiny encoder with linear heads.

    The encoder output feeds two linear layers: ``pi`` (action logits) and
    ``v`` (scalar state value).

    Args:
        action_space_n: Number of discrete actions (output width of ``pi``).
        img_size: Image side length passed to the encoder and used for the
            dummy forward pass that infers the encoder output width.
        device: Requested device string. CPU is used whenever CUDA is not
            available, regardless of this value.
    """

    def __init__(self, action_space_n: int, img_size=64, device="cuda"):
        super().__init__()
        self.device = torch.device(device if torch.cuda.is_available() else "cpu")
        # NOTE(review): the recorded run of this notebook ends with
        # "RuntimeError: shape '[64, 16, 4, 8]' is invalid for input of size 8192".
        # That reshape is not in this file, so it is presumably raised inside
        # SietTiny -- confirm that embed_dim / num_heads / patch sizes below are
        # a combination the encoder supports.
        self.encoder = SietTiny(img_size=img_size, patch_size=8, patch_size_local=8, in_chans=3,
                                embed_dim=32, depth=1, num_heads=8)
        # Infer the encoder output width with a dummy forward pass (assumes the
        # encoder returns features whose last dimension is the embedding size).
        # Probe in eval mode so train-only layers (dropout, normalisation
        # statistics) are neither exercised nor updated by a fake zero image,
        # then restore whatever mode the encoder was in.
        was_training = self.encoder.training
        self.encoder.eval()
        with torch.no_grad():
            dummy = torch.zeros(1, 3, img_size, img_size)
            enc_dim = self.encoder(dummy).shape[-1]
        self.encoder.train(was_training)
        self.pi = nn.Linear(enc_dim, action_space_n)
        self.v = nn.Linear(enc_dim, 1)
        self.augment = DummyAugment()
        self.to(self.device)

    def encode(self, obs_np: np.ndarray) -> torch.Tensor:
        """Encode a batch of channel-last images given as a NumPy array.

        The array is moved to ``self.device``, scaled by 1/255, permuted from
        (batch, H, W, C) to (batch, C, H, W), passed through ``self.augment``
        and finally through the encoder.
        """
        # Transfer first, cast second: for uint8 frames this moves a quarter of
        # the bytes compared with casting to float32 on the host.
        x = torch.from_numpy(obs_np).to(self.device).float() / 255.0
        x = x.permute(0, 3, 1, 2)
        x = self.augment(x)
        return self.encoder(x)

    def policy_value(self, obs_np: np.ndarray):
        """Return ``(logits, value)`` for a batch of observations.

        ``logits`` comes from the policy head; ``value`` is the value head
        output with its trailing singleton dimension squeezed away.
        """
        z = self.encode(obs_np)
        logits = self.pi(z)
        value = self.v(z).squeeze(-1)
        return logits, value

def make_env(num_envs=8, env_name="coinrun", num_levels=0, start_level=0, render=False):
    """Build a vectorised Procgen environment in "easy" distribution mode.

    Args:
        num_envs: Number of parallel environments.
        env_name: Procgen game name.
        num_levels: Forwarded to ``ProcgenEnv`` as ``num_levels``.
        start_level: Forwarded to ``ProcgenEnv`` as ``start_level``.
        render: Accepted for interface compatibility; currently not used.

    Returns:
        The constructed ``ProcgenEnv`` instance.
    """
    env_kwargs = {
        "num_envs": num_envs,
        "env_name": env_name,
        "start_level": start_level,
        "num_levels": num_levels,
        "distribution_mode": "easy",
    }
    return ProcgenEnv(**env_kwargs)

def _extract_rgb(obs):
    """Return the image batch from an environment observation.

    ``ProcgenEnv`` reports observations as a dict keyed by ``"rgb"``; anything
    that is not a dict (e.g. an already-unwrapped array) is passed through.
    """
    return obs["rgb"] if isinstance(obs, dict) else obs


def _compute_gae(rew_buf, val_buf, done_buf, next_value, gamma, gae_lambda):
    """Compute GAE advantages and returns for one rollout.

    All buffers are indexed ``[step, env]``. ``done_buf[t]`` is the done flag
    produced by the action taken at step ``t``, so it masks the bootstrap from
    step ``t + 1``. ``next_value`` is the value estimate for the observation
    that follows the final step.

    Returns:
        ``(advantages, returns)`` with the same shape as ``rew_buf``.
    """
    num_steps = rew_buf.shape[0]
    adv_buf = np.zeros_like(rew_buf)
    lastgaelam = 0.0
    for t in reversed(range(num_steps)):
        next_nonterminal = 1.0 - done_buf[t].astype(np.float32)
        next_values = next_value if t == num_steps - 1 else val_buf[t + 1]
        delta = rew_buf[t] + gamma * next_values * next_nonterminal - val_buf[t]
        lastgaelam = delta + gamma * gae_lambda * next_nonterminal * lastgaelam
        adv_buf[t] = lastgaelam
    return adv_buf, adv_buf + val_buf


def _plot_training_curves(losses, logged_returns, window=50):
    """Plot per-update loss and a moving average of episode returns."""
    plt.figure(figsize=(8,4))
    plt.plot(losses, label="Training Loss")
    if len(logged_returns) > 0:
        # Clamp the window: with fewer episodes than `window`, a "valid"
        # convolution would silently swap operands and plot scaled partial sums.
        window = min(window, len(logged_returns))
        returns_smoothed = np.convolve(logged_returns, np.ones(window)/window, mode="valid")
        plt.plot(np.linspace(0, len(losses), len(returns_smoothed)), returns_smoothed, label=f"Return MA({window})")
    plt.xlabel("Update")
    plt.ylabel("Metric")
    plt.legend()
    plt.title("PPO with SiT on Procgen")
    plt.tight_layout()
    plt.show()


def ppo_train(*, num_envs=16, num_steps=256, total_updates=800, gamma=0.999,
              gae_lambda=0.95, clip_coef=0.2, ent_coef=0.01, vf_coef=0.5,
              max_grad_norm=0.5, lr=2.5e-4, minibatch_size=2048,
              env_name="coinrun", img_size=64):
    """Train a SiTPPO agent on Procgen with PPO and plot the training curves.

    Each update collects ``num_steps`` transitions from ``num_envs`` parallel
    environments, computes GAE advantages, and makes a single shuffled pass
    over the rollout in minibatches. Progress is printed every 10 updates.

    Args:
        num_envs: Number of parallel environments.
        num_steps: Rollout length per environment per update.
        total_updates: Number of rollout/optimisation cycles.
        gamma: Discount factor.
        gae_lambda: GAE lambda.
        clip_coef: PPO ratio clipping range.
        ent_coef: Entropy bonus coefficient.
        vf_coef: Value loss coefficient.
        max_grad_norm: Gradient-norm clipping threshold.
        lr: Adam learning rate.
        minibatch_size: Number of transitions per gradient step.
        env_name: Procgen game name.
        img_size: Observation side length; must match the frames the
            environment actually emits, since the rollout buffer is sized
            from it.

    Raises:
        TypeError: If the environment's action space is not discrete.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"

    env = make_env(num_envs=num_envs, env_name=env_name)
    logged_returns = []
    losses = []
    # try/finally so the environment is released even if training fails.
    try:
        act_space = env.action_space
        if not isinstance(act_space, gym.spaces.Discrete):
            raise TypeError(
                f"ppo_train requires a discrete action space, got {type(act_space).__name__}"
            )
        agent = SiTPPO(act_space.n, img_size=img_size, device=device)

        optimizer = optim.Adam(agent.parameters(), lr=lr, eps=1e-5)

        obs = _extract_rgb(env.reset())
        ep_returns = np.zeros(num_envs, dtype=np.float32)

        for update in range(total_updates):
            obs_buf = np.zeros((num_steps, num_envs, img_size, img_size, 3), dtype=np.uint8)
            act_buf = np.zeros((num_steps, num_envs), dtype=np.int64)
            logp_buf = np.zeros((num_steps, num_envs), dtype=np.float32)
            rew_buf = np.zeros((num_steps, num_envs), dtype=np.float32)
            val_buf = np.zeros((num_steps, num_envs), dtype=np.float32)
            done_buf = np.zeros((num_steps, num_envs), dtype=np.bool_)

            # ---- Rollout collection ----
            for t in range(num_steps):
                # No gradients are needed while acting; skipping graph
                # construction saves memory and time on every step.
                with torch.no_grad():
                    logits, value = agent.policy_value(obs)
                    dist = torch.distributions.Categorical(logits=logits)
                    action = dist.sample()
                    logp = dist.log_prob(action)
                action_np = action.cpu().numpy()

                obs_buf[t] = obs
                act_buf[t] = action_np
                logp_buf[t] = logp.cpu().numpy()
                val_buf[t] = value.cpu().numpy()

                obs, rew, done, _info = env.step(action_np)
                obs = _extract_rgb(obs)
                rew_buf[t] = rew
                done_buf[t] = done

                # Track undiscounted episode returns for logging only.
                ep_returns += rew
                for i in np.flatnonzero(done):
                    logged_returns.append(ep_returns[i])
                    ep_returns[i] = 0.0

            # ---- Advantage estimation ----
            with torch.no_grad():
                _, next_value = agent.policy_value(obs)
            next_value = next_value.cpu().numpy()

            adv_buf, ret_buf = _compute_gae(
                rew_buf, val_buf, done_buf, next_value, gamma, gae_lambda
            )
            adv_flat = adv_buf.reshape(-1)
            # Normalise advantages across the whole rollout batch.
            adv_flat = (adv_flat - adv_flat.mean()) / (adv_flat.std() + 1e-8)

            b_obs = obs_buf.reshape(num_steps * num_envs, img_size, img_size, 3)
            b_act = act_buf.reshape(-1)
            b_logp = logp_buf.reshape(-1)
            b_adv = adv_flat
            b_ret = ret_buf.reshape(-1)

            inds = np.arange(b_obs.shape[0])
            np.random.shuffle(inds)

            # ---- Single optimisation pass over the shuffled rollout ----
            total_loss_epoch = 0.0
            num_minibatches = 0
            for start in range(0, len(inds), minibatch_size):
                mb_inds = inds[start:start + minibatch_size]
                mb_obs = b_obs[mb_inds]
                mb_act = torch.from_numpy(b_act[mb_inds]).to(agent.device)
                mb_adv = torch.from_numpy(b_adv[mb_inds]).float().to(agent.device)
                mb_ret = torch.from_numpy(b_ret[mb_inds]).float().to(agent.device)
                mb_logp_old = torch.from_numpy(b_logp[mb_inds]).float().to(agent.device)

                logits, value = agent.policy_value(mb_obs)
                dist = torch.distributions.Categorical(logits=logits)
                logp = dist.log_prob(mb_act)
                entropy = dist.entropy().mean()

                # Clipped surrogate objective.
                ratio = torch.exp(logp - mb_logp_old)
                surr1 = ratio * mb_adv
                surr2 = torch.clamp(ratio, 1.0 - clip_coef, 1.0 + clip_coef) * mb_adv
                pg_loss = -torch.min(surr1, surr2).mean()

                v_loss = 0.5 * (mb_ret - value).pow(2).mean()
                loss = pg_loss + vf_coef * v_loss - ent_coef * entropy

                optimizer.zero_grad(set_to_none=True)
                loss.backward()
                nn.utils.clip_grad_norm_(agent.parameters(), max_grad_norm)
                optimizer.step()

                total_loss_epoch += loss.item()
                num_minibatches += 1

            # Divide by the number of minibatches actually run, which differs
            # from len(inds) // minibatch_size when the batch is not a multiple.
            avg_loss = total_loss_epoch / max(1, num_minibatches)
            losses.append(avg_loss)

            if (update + 1) % 10 == 0:
                print(f"Update {update+1}/{total_updates} | AvgLoss {avg_loss:.4f} | MeanReturn {np.mean(logged_returns[-100:]) if logged_returns else 0:.2f}")
    finally:
        env.close()

    _plot_training_curves(losses, logged_returns)

# Entry point: start PPO training only when this module is the main program,
# not when it is imported.
if __name__ == "__main__":
    ppo_train()


Gym has been unmaintained since 2022 and does not support NumPy 2.0 amongst other critical functionality.
Please upgrade to Gymnasium, the maintained drop-in replacement of Gym, or contact the authors of your software and request that they upgrade.
Users of this version of Gym should be able to simply replace 'import gym' with 'import gymnasium as gym' in the vast majority of cases.
See the migration guide at https://gymnasium.farama.org/introduction/migration_guide/ for additional information.


RuntimeError: shape '[64, 16, 4, 8]' is invalid for input of size 8192