In [None]:
import torch
import torch.nn as nn
from sit_tiny import SietTiny

class TinySieTPolicy(nn.Module):
    def __init__(self, in_ch, num_actions, h, w,
                 img_size=64, embed_dim=32, depth=1, num_heads=8,
                 patch_size=8, patch_size_local=8, qkv_bias=True):
        super().__init__()
        assert h == img_size and w == img_size, "SietTiny img_size must match observation H and W"
        self.backbone = SietTiny(
            img_size=img_size,
            patch_size=patch_size,
            patch_size_local=patch_size_local,
            in_chans=in_ch,
            embed_dim=embed_dim,
            depth=depth,
            num_heads=num_heads,
            qkv_bias=qkv_bias
        )
        with torch.no_grad():
            dummy = torch.zeros(1, in_ch, h, w)
            z = self.backbone(dummy)
            feat_dim = z.shape[-1]
        self.pi = nn.Linear(feat_dim, num_actions)
        self.v = nn.Linear(feat_dim, 1)

    def forward(self, x):
        z = self.backbone(x)
        return self.pi(z), self.v(z).squeeze(-1)


In [14]:
import os, math, random, time
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from procgen import ProcgenEnv

class OrthogonalInit:
    def __call__(self, m, gain=1.0):
        if isinstance(m, nn.Linear):
            nn.init.orthogonal_(m.weight, gain=gain)
            nn.init.constant_(m.bias, 0.0)
        if isinstance(m, nn.Conv2d):
            nn.init.orthogonal_(m.weight, gain=gain)
            nn.init.constant_(m.bias, 0.0)

class CNNPolicy(nn.Module):
    def __init__(self, obs_shape, n_actions):
        super().__init__()
        c = obs_shape[0]
        self.net = nn.Sequential(
            nn.Conv2d(c, 32, 3, 2, 1),
            nn.ReLU(),
            nn.Conv2d(32, 64, 3, 2, 1),
            nn.ReLU(),
            nn.Conv2d(64, 64, 3, 2, 1),
            nn.ReLU(),
            nn.Flatten()
        )
        with torch.no_grad():
            dummy = torch.zeros(1, *obs_shape)
            feat_dim = self.net(dummy).shape[1]
        self.mlp = nn.Sequential(
            nn.Linear(feat_dim, 256),
            nn.ReLU()
        )
        self.policy = nn.Linear(256, n_actions)
        self.value = nn.Linear(256, 1)
        OrthogonalInit()(self.net)
        OrthogonalInit()(self.mlp, gain=1.0)
        OrthogonalInit()(self.policy, gain=0.01)
        OrthogonalInit()(self.value, gain=1.0)

    def forward(self, x):
        x = x.float() / 255.0
        z = self.mlp(self.net(x))
        return self.policy(z), self.value(z)

    def act(self, x):
        logits, v = self(x)
        dist = torch.distributions.Categorical(logits=logits)
        a = dist.sample()
        return a, dist.log_prob(a), dist.entropy(), v.squeeze(-1)

def make_env(num_envs, env_name="coinrun", num_levels=0, start_level=0, distribution_mode="easy"):
    env = ProcgenEnv(num_envs=num_envs, env_name=env_name, num_levels=num_levels, start_level=start_level, distribution_mode=distribution_mode)
    return env

def to_torch(x, device):
    if isinstance(x, dict):
        x = x["rgb"]
    x = torch.from_numpy(x).permute(0, 3, 1, 2).to(device)
    return x

def ppo_train(env_name="coinrun", total_steps=2_000_000, num_envs=64, rollout_len=256, gamma=0.999, lam=0.95, clip_eps=0.2, vf_coef=0.5, ent_coef=0.01, lr=2.5e-4, minibatch_size=8192, update_epochs=3, device="cuda"):
    env = make_env(num_envs=num_envs, env_name=env_name)
    obs = env.reset()
    obs_t = to_torch(obs, device)
    n_actions = env.num_actions
    obs_shape = obs_t.shape[1:]
    policy = CNNPolicy(obs_shape, n_actions).to(device)
    opt = optim.Adam(policy.parameters(), lr=lr, eps=1e-5)
    steps = 0
    rollout_obs = torch.zeros(rollout_len, num_envs, *obs_shape, device=device, dtype=torch.uint8)
    rollout_actions = torch.zeros(rollout_len, num_envs, device=device, dtype=torch.long)
    rollout_logprobs = torch.zeros(rollout_len, num_envs, device=device)
    rollout_values = torch.zeros(rollout_len, num_envs, device=device)
    rollout_rewards = torch.zeros(rollout_len, num_envs, device=device)
    rollout_dones = torch.zeros(rollout_len, num_envs, device=device)

    while steps < total_steps:
        for t in range(rollout_len):
            rollout_obs[t] = to_torch(obs, device)
            with torch.no_grad():
                a, lp, ent, v = policy.act(rollout_obs[t])
            rollout_actions[t] = a
            rollout_logprobs[t] = lp
            rollout_values[t] = v
            obs, rew, done, info = env.step(a.cpu().numpy())
            rollout_rewards[t] = torch.from_numpy(rew).to(device)
            rollout_dones[t] = torch.from_numpy(done.astype(np.float32)).to(device)
        with torch.no_grad():
            next_value = policy(rollout_obs[-1].float() / 255.0)[1].squeeze(-1)
        adv = torch.zeros(rollout_len, num_envs, device=device)
        gae = torch.zeros(num_envs, device=device)
        returns = torch.zeros(rollout_len, num_envs, device=device)
        for t in reversed(range(rollout_len)):
            nonterminal = 1.0 - rollout_dones[t]
            delta = rollout_rewards[t] + gamma * (next_value if t == rollout_len - 1 else rollout_values[t + 1]) * nonterminal - rollout_values[t]
            gae = delta + gamma * lam * nonterminal * gae
            adv[t] = gae
            returns[t] = adv[t] + rollout_values[t]
        b_obs = rollout_obs.reshape(rollout_len * num_envs, *obs_shape).float() / 255.0
        b_actions = rollout_actions.reshape(-1)
        b_logprobs = rollout_logprobs.reshape(-1)
        b_values = rollout_values.reshape(-1)
        b_adv = adv.reshape(-1)
        b_returns = returns.reshape(-1)
        b_adv = (b_adv - b_adv.mean()) / (b_adv.std(unbiased=False) + 1e-8)
        num_batch = b_obs.shape[0]
        for _ in range(update_epochs):
            idx = torch.randperm(num_batch, device=device)
            for start in range(0, num_batch, minibatch_size):
                end = start + minibatch_size
                mb = idx[start:end]
                logits, v = policy(b_obs[mb])
                dist = torch.distributions.Categorical(logits=logits)
                lp = dist.log_prob(b_actions[mb])
                ratio = torch.exp(lp - b_logprobs[mb])
                surr1 = ratio * b_adv[mb]
                surr2 = torch.clamp(ratio, 1.0 - clip_eps, 1.0 + clip_eps) * b_adv[mb]
                pg_loss = -torch.min(surr1, surr2).mean()
                v_pred = v.squeeze(-1)
                v_clipped = b_values[mb] + torch.clamp(v_pred - b_values[mb], -clip_eps, clip_eps)
                vf_loss1 = (v_pred - b_returns[mb]).pow(2)
                vf_loss2 = (v_clipped - b_returns[mb]).pow(2)
                vf_loss = 0.5 * torch.max(vf_loss1, vf_loss2).mean()
                ent = dist.entropy().mean()
                loss = pg_loss + vf_coef * vf_loss - ent_coef * ent
                opt.zero_grad()
                loss.backward()
                nn.utils.clip_grad_norm_(policy.parameters(), 0.5)
                opt.step()
        steps += rollout_len * num_envs
        obs = env.reset()

def main():
    device = "cuda" if torch.cuda.is_available() else "cpu"
    ppo_train(env_name="coinrun", total_steps=1_000_000, num_envs=64, rollout_len=256, device=device)

if __name__ == "__main__":
    main()


AttributeError: 'ToBaselinesVecEnv' object has no attribute 'num_actions'