In [101]:
import torch
from torch import nn
import numpy as np
from einops import rearrange
from torch.distributions import Categorical

In [102]:
class StackedCNNAgent(nn.Module):
    def __init__(self, n_acts=18, ctx_len=4):
        super().__init__()
        self.network = nn.Sequential(
            layer_init(nn.Conv2d(4, 32, 8, stride=4)),
            nn.ReLU(),
            layer_init(nn.Conv2d(32, 64, 4, stride=2)),
            nn.ReLU(),
            layer_init(nn.Conv2d(64, 64, 3, stride=1)),
            nn.ReLU(),
            nn.Flatten(),
            layer_init(nn.Linear(64 * 7 * 7, 512)),
            nn.ReLU(),
        )
        self.actor = layer_init(nn.Linear(512, n_acts), std=0.01)
        self.critic = layer_init(nn.Linear(512, 1), std=1)
        self.train_per_token = False
        self.ctx_len = ctx_len

    def forward(self, done, obs, act, rew):
        obs = rearrange(obs, "b t c h w -> b (t c) h w")
        hidden = self.network(obs / 255.0)
        logits, val = self.actor(hidden), self.critic(hidden)[:, 0]
        logits, val = rearrange(logits, "b a -> b 1 a"), rearrange(val, "b -> b 1")
        return logits, val

In [103]:
def layer_init(layer, std=np.sqrt(2), bias_const=0.0):
    torch.nn.init.orthogonal_(layer.weight, std)
    torch.nn.init.constant_(layer.bias, bias_const)
    return layer
class Agent(nn.Module):
    def __init__(self):
        super().__init__()
        self.network = nn.Sequential(
            layer_init(nn.Conv2d(4, 32, 8, stride=4)),
            nn.ReLU(),
            layer_init(nn.Conv2d(32, 64, 4, stride=2)),
            nn.ReLU(),
            layer_init(nn.Conv2d(64, 64, 3, stride=1)),
            nn.ReLU(),
            nn.Flatten(),
            layer_init(nn.Linear(64 * 7 * 7, 512)),
            nn.ReLU(),
        )
        self.actor = layer_init(nn.Linear(512, 18), std=0.01)
        self.critic = layer_init(nn.Linear(512, 1), std=1)

    def forward(self, x, action=None):
        hidden = self.network(x / 255.0)
        logits = self.actor(hidden)
        probs = Categorical(logits=logits)
        if action is None:
            action = probs.sample()
        return logits, self.critic(hidden)

In [104]:
device = 'cuda'

In [105]:
agent1 = Agent().to(device)
agent2 = StackedCNNAgent().to(device)

In [106]:
%%timeit
x = torch.randn(256, 4, 84, 84, device=device)
l, _ = agent1(x);
l.mean().item();

1.65 ms ± 1.3 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [107]:
%%timeit
x = torch.randn(256, 4, 1, 84, 84, device=device)
l, _ = agent2(None, x, None, None);
l.mean().item();

1.24 ms ± 828 ns per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
