In [65]:
from torch import nn
from SiT_model import SietTiny


class SietBackbone(nn.Module):
    def __init__(self, in_chans=4, img_size=84, patch_size=6, patch_size_local=6,
                 embed_dim=32, num_heads=8, depth=1):
        super().__init__()
        self.in_chans = in_chans
        self.img_size = img_size
        self.core = SietTiny(
            img_size=img_size,
            patch_size=patch_size,
            patch_size_local=patch_size_local,
            in_chans=in_chans,
            embed_dim=embed_dim,
            depth=depth,
            num_heads=num_heads,
            qkv_bias=True
        )

    def forward(self, x):
        return self.core(x)


In [66]:
import torch
import random
import numpy as np
from tqdm import tqdm
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, TensorDataset

from procgen import ProcgenEnv
import cv2
from collections import deque


class ActorCritic(nn.Module):
    def __init__(self, nb_actions):
        super().__init__()
        self.head = nn.Sequential(
            nn.Conv2d(4, 16, 8, stride=4), nn.Tanh(),
            nn.Conv2d(16, 32, 4, stride=2), nn.Tanh(),
            nn.Flatten(), nn.Linear(2592, 256), nn.Tanh(),
        )
        self.actor = nn.Sequential(nn.Linear(256, nb_actions))
        self.critic = nn.Sequential(nn.Linear(256, 1),)

    def forward(self, x):
        h = self.head(x)
        return self.actor(h), self.critic(h)


class SimpleProcgenWrapper:
    def __init__(self, game_name="miner", num_envs=1, distribution_mode="easy",  skip=4, frame_stack=4, resize=(84, 84), grayscale=True, start_level=0):
        self.env = ProcgenEnv(num_envs=num_envs, env_name=game_name, distribution_mode=distribution_mode, start_level=start_level)
        self.env.reset()
        self.num_envs = num_envs
        self.skip = skip
        self.frame_stack = frame_stack
        self.resize = resize
        self.grayscale = grayscale

        # action_space.n equivalent
        self.n_actions = self.env.action_space.n

        # per env state
        self.frame_buffers = [deque(maxlen=self.frame_stack) for _ in range(num_envs)]
        self.last_obs = [None for _ in range(num_envs)]
        self.dones = [False for _ in range(num_envs)]
        self.total_rewards = [0.0 for _ in range(num_envs)]
        self.current_life = [1 for _ in range(num_envs)]  # emulate lives

        # initialize observations with zeros
        for i in range(num_envs):
            obs_i = self._reset_single(i)
            self.last_obs[i] = obs_i

    def _proc_obs(self, obs):
        # obs comes as HWC RGB uint8
        if self.grayscale:
            obs = cv2.cvtColor(obs, cv2.COLOR_RGB2GRAY)
        if self.resize is not None:
            if self.grayscale:
                obs = cv2.resize(obs, self.resize, interpolation=cv2.INTER_AREA)
            else:
                obs = cv2.resize(obs, self.resize, interpolation=cv2.INTER_AREA)
        if self.grayscale:
            # shape (H, W) -> add channel
            obs = np.expand_dims(obs, axis=0)
        else:
            # HWC -> CHW
            obs = np.transpose(obs, (2, 0, 1))
        return obs.astype(np.uint8)

    def _ensure_stack(self, env_id, obs_proc):
        # Fill frame stack with the current processed frame
        while len(self.frame_buffers[env_id]) < self.frame_stack:
            self.frame_buffers[env_id].append(obs_proc)

    def _get_stacked_obs(self, env_id):
        # Stack into (C=frame_stack, H, W)
        return np.stack(list(self.frame_buffers[env_id]), axis=0)

    def _reset_single(self, env_id):
        # Procgen reset returns batched obs when num_envs > 1
        obs = self.env.reset()
        # obs shape: (num_envs, H, W, C)
        obs_i = obs[env_id]
        obs_proc = self._proc_obs(obs_i)
        self.frame_buffers[env_id].clear()
        self._ensure_stack(env_id, obs_proc)
        return self._get_stacked_obs(env_id)

    def reset(self):
        # Reset all envs
        obs = self.env.reset()
        for i in range(self.num_envs):
            obs_i = obs[i]
            obs_proc = self._proc_obs(obs_i)
            self.frame_buffers[i].clear()
            self._ensure_stack(i, obs_proc)
            self.dones[i] = False
            self.total_rewards[i] = 0.0
            self.current_life[i] = 1
            self.last_obs[i] = self._get_stacked_obs(i)
        return [self.last_obs[i] for i in range(self.num_envs)]

    def step(self, env_id, action):
        # MaxAndSkip: take the same action skip times, accumulate reward, last obs
        total_reward = 0.0
        dead = False
        done = False
        info = {}

        a = int(action.item()) if hasattr(action, "item") else int(action)
        for _ in range(self.skip):
            obs, rewards, dones, infos = self.env.step(np.array([a] * self.num_envs))
            # pick env_id
            obs_i = obs[env_id]
            r_i = float(rewards[env_id])
            d_i = bool(dones[env_id])
            info_i = infos[env_id] if isinstance(infos, (list, tuple)) else {}

            total_reward += r_i
            done = d_i
            info = info_i

            obs_proc = self._proc_obs(obs_i)
            self.frame_buffers[env_id].append(obs_proc)

            if done:
                dead = True
                break

        # emulate lives info
        info['lives'] = 0 if done else 1

        next_obs = self._get_stacked_obs(env_id)
        self.total_rewards[env_id] += total_reward
        self.current_life[env_id] = info['lives']
        self.last_obs[env_id] = next_obs
        self.dones[env_id] = done

        return next_obs, total_reward, dead, done, info

    @property
    def action_space(self):
        class _A:
            def __init__(self, n): self.n = n
        return _A(self.n_actions)


class Environments():
    def __init__(self, nb_actor):
        # Choose a procgen game. Using "miner" to keep discrete small action space.
        self.envs = [SimpleProcgenWrapper(game_name="miner", num_envs=1, distribution_mode="easy", skip=4, frame_stack=4, resize=(84, 84), grayscale=True) for _ in range(nb_actor)]
        self.observations = [None for _ in range(nb_actor)]
        self.current_life = [None for _ in range(nb_actor)]
        self.done = [False for _ in range(nb_actor)]
        self.total_rewards = [0 for _ in range(nb_actor)]
        self.nb_actor = nb_actor

        for env_id in range(nb_actor):
            self.reset_env(env_id)

    def len(self):
        return self.nb_actor

    def reset_env(self, env_id):
        self.total_rewards[env_id] = 0
        self.envs[env_id].reset()

        # keep random noops by repeating a neutral action if exists
        # in procgen, action 0 is usually "noop"
        noop_action = 0
        for _ in range(random.randint(1, 30)):
            self.observations[env_id], reward, _, info_done, info = self.envs[env_id].step(noop_action)
            self.total_rewards[env_id] += reward
            self.current_life[env_id] = info['lives']
            self.done[env_id] = info_done

    def step(self, env_id, action):
        next_obs, reward, dead, done, info = self.envs[env_id].step(action)
        self.done[env_id] = done
        self.total_rewards[env_id] += reward
        self.current_life[env_id] = info['lives']
        self.observations[env_id] = next_obs
        return next_obs, reward, dead, done, info

    def get_env(self):
        # not used anymore, kept for compatibility
        return self.envs[0]


def PPO(envs, T=128, K=3, batch_size=32*8, gamma=0.99, device='cuda', gae_parameter=0.95,
        vf_coeff_c1=1, ent_coef_c2=0.01, nb_iterations=40_000):

        optimizer = torch.optim.Adam(actorcritic.parameters(), lr=2.5e-4)
        scheduler = torch.optim.lr_scheduler.LinearLR(
            optimizer, start_factor=1., end_factor=0.0, total_iters=nb_iterations)

        max_reward = 0
        total_rewards = [[] for _ in range(envs.len())]
        smoothed_rewards = [[] for _ in range(envs.len())]

        for iteration in tqdm(range(nb_iterations)):
            advantages = torch.zeros((envs.len(), T), dtype=torch.float32, device=device)
            buffer_states = torch.zeros((envs.len(), T, 4, 84, 84), dtype=torch.float32, device=device)
            buffer_actions = torch.zeros((envs.len(), T), dtype=torch.long, device=device)
            buffer_logprobs = torch.zeros((envs.len(), T), dtype=torch.float32, device=device)
            buffer_state_values = torch.zeros((envs.len(), T+1), dtype=torch.float32, device=device)
            buffer_rewards = torch.zeros((envs.len(), T), dtype=torch.float32, device=device)
            buffer_is_terminal = torch.zeros((envs.len(), T), dtype=torch.float16, device=device)

            for env_id in range(envs.len()):
                with torch.no_grad():
                    for t in range(T):
                        obs = torch.from_numpy(envs.observations[env_id] / 255.).unsqueeze(0).float().to(device)
                        logits, value = actorcritic(obs)
                        logits, value = logits.squeeze(0), value.squeeze(0)
                        m = torch.distributions.categorical.Categorical(logits=logits)

                        if envs.done[env_id]:
                            action = torch.tensor([0]).to(device)  # noop on done to advance reset path
                        else:
                            action = m.sample()

                        log_prob = m.log_prob(action)
                        _, reward, dead, done, _ = envs.step(env_id, action)
                        reward = np.sign(reward)

                        buffer_states[env_id, t] = obs
                        buffer_actions[env_id, t] = torch.tensor([action]).to(device)
                        buffer_logprobs[env_id, t] = log_prob
                        buffer_state_values[env_id, t] = value
                        buffer_rewards[env_id, t] = reward
                        buffer_is_terminal[env_id, t] = done

                        if dead:
                            if envs.total_rewards[env_id] > max_reward:
                                max_reward = envs.total_rewards[env_id]
                                torch.save(actorcritic.cpu(), f"actorcritic_{max_reward}")
                                actorcritic.to(device)

                            total_rewards[env_id].append(envs.total_rewards[env_id])
                            envs.reset_env(env_id)

                    buffer_state_values[env_id, T] = actorcritic(
                        torch.from_numpy(envs.observations[env_id] / 255.).unsqueeze(0).float().to(device))[1].squeeze(0)

                    for t in range(T-1, -1, -1):
                        next_non_terminal = 1.0 - buffer_is_terminal[env_id, t]
                        delta_t = buffer_rewards[env_id, t] + gamma * buffer_state_values[
                            env_id, t+1] * next_non_terminal - buffer_state_values[env_id, t]
                        if t == (T-1):
                            A_t = delta_t
                        else:
                            A_t = delta_t + gamma * gae_parameter * advantages[env_id, t+1] * next_non_terminal
                        advantages[env_id, t] = A_t

            if (iteration % 400 == 0) and iteration > 0:
                for env_id in range(envs.len()):
                    smoothed_rewards[env_id].append(np.mean(total_rewards[env_id]) if len(total_rewards[env_id]) > 0 else 0.0)
                    plt.plot(smoothed_rewards[env_id])
                total_rewards = [[] for _ in range(envs.len())]
                plt.title("Average Reward on Procgen")
                plt.xlabel("Training Epochs")
                plt.ylabel("Average Reward per Episode")
                plt.savefig('average_reward_on_procgen.png')
                plt.close()

            for epoch in range(K):
                advantages_data_loader = DataLoader(
                    TensorDataset(
                        advantages.reshape(advantages.shape[0] * advantages.shape[1]),
                        buffer_states.reshape(-1, buffer_states.shape[2], buffer_states.shape[3], buffer_states.shape[4]),
                        buffer_actions.reshape(-1),
                        buffer_logprobs.reshape(-1),
                        buffer_state_values[:, :T].reshape(-1),
                    ),
                    batch_size=batch_size, shuffle=True,
                )

                for batch_advantages in advantages_data_loader:
                    b_adv, obs, action_that_was_taken, old_log_prob, old_state_values = batch_advantages

                    logits, value = actorcritic(obs)
                    logits, value = logits.squeeze(0), value.squeeze(-1)
                    m = torch.distributions.categorical.Categorical(logits=logits)
                    log_prob = m.log_prob(action_that_was_taken)
                    ratio = torch.exp(log_prob - old_log_prob)
                    returns = b_adv + old_state_values

                    policy_loss_1 = b_adv * ratio
                    alpha = 1. - iteration / nb_iterations
                    clip_range = 0.1 * alpha
                    policy_loss_2 = b_adv * torch.clamp(ratio, 1 - clip_range, 1 + clip_range)
                    policy_loss = -torch.min(policy_loss_1, policy_loss_2).mean()

                    value_loss1 = F.mse_loss(returns, value, reduction='none')
                    value_loss2 = F.mse_loss(returns, torch.clamp(value, value - clip_range, value + clip_range), reduction='none')
                    value_loss = torch.max(value_loss1, value_loss2).mean()

                    loss = policy_loss + ent_coef_c2 * -(m.entropy()).mean() + vf_coeff_c1 * value_loss

                    optimizer.zero_grad()
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(actorcritic.parameters(), 0.5)
                    optimizer.step()
            scheduler.step()


if __name__ == "__main__":
    device = 'cuda'
    nb_actor = 8
    envs = Environments(nb_actor)
    actorcritic = ActorCritic(envs.envs[0].action_space.n).to(device)
    PPO(envs, device=device)


KeyError: 0

In [61]:

import torch

if __name__ == "__main__":
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    nb_actor = 8
    envs = Environments(nb_actor)

    # backbone = CNNSimple(in_chans=4, img_size=84)  # CNN
    backbone = SietBackbone(in_chans=4, img_size=84)   # or SietTiny

    nb_actions = envs.envs[0].action_space.n

    actorcritic = ActorCritic(backbone=backbone, action_dim=nb_actions).to(device)

    PPO(envs, actorcritic=actorcritic, device=device)

NameNotFound: Environment `procgen-coinrun` doesn't exist.