In [1]:
import importlib

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from gtrxl_torch.gtrxl_torch import GTrXL

In [2]:
# Resolve the two HandyRL modules by dotted path, in this order:
# the model module (source of BaseModel) and the Hungry Geese environment module.
model_module, env_module = map(
    importlib.import_module,
    ("handyrl.model", "handyrl.envs.kaggle.hungry_geese"),
)

Loading environment football failed: No module named 'gfootball'


In [3]:
class GeeseNet(model_module.BaseModel):
    """Policy/value network built on a GTrXL encoder.

    The encoder output is fed to two small bias-free two-layer heads:
    a policy head producing 4 raw logits and a value head producing a
    single tanh-squashed scalar.
    """

    def __init__(self, env, args=None):
        """Create the encoder and both output heads.

        Args:
            env: Environment instance, forwarded unchanged to the base class.
            args: Optional settings object forwarded to the base class.
                When omitted, a fresh empty dict is passed.
        """
        # A literal `{}` default is created once and shared by every call;
        # build a new dict per construction instead.
        if args is None:
            args = {}
        super().__init__(env, args)

        # 1232 features; the notebook feeds vectors of size 16 * 7 * 11.
        # Presumably 16 feature planes on a 7 x 11 board -- TODO confirm.
        d_model = 16 * 7 * 11
        filters = 64

        self.gtrxl = GTrXL(
            d_model=d_model,
            nheads=4,
            transformer_layers=1,
            hidden_dims=256,
            n_layers=1
        )

        # Policy head: d_model -> filters -> 4 outputs.
        self.head_p1 = nn.Linear(d_model, filters, bias=False)
        self.head_p2 = nn.Linear(filters, 4, bias=False)
        # Value head: d_model -> filters -> 1 output.
        self.head_v1 = nn.Linear(d_model, filters, bias=False)
        self.head_v2 = nn.Linear(filters, 1, bias=False)

    def forward(self, x, _=None):
        """Run the encoder and both heads.

        Args:
            x: Input tensor passed straight to the GTrXL encoder; the
                notebook calls this with a tensor of shape (1, 1232).
            _: Ignored. Presumably fills the hidden-state slot of the
                base-class calling convention -- TODO confirm.

        Returns:
            dict with two entries:
                "policy": raw outputs of the policy head (last dim 4,
                    no softmax applied).
                "value": tanh of the value head output (last dim 1).

        NOTE(review): in the recorded run a (1, 1232) input produced
        outputs of shape (1, 1, 4) and (1, 1, 1), so the encoder appears
        to add a dimension -- confirm this is what downstream code expects.
        """
        h = self.gtrxl(x)

        # relu_ works in place on the freshly computed linear outputs.
        h_p = F.relu_(self.head_p1(h))
        p = self.head_p2(h_p)

        h_v = F.relu_(self.head_v1(h))
        v = torch.tanh(self.head_v2(h_v))

        return {"policy": p, "value": v}

In [4]:
# Instantiate the environment from the Hungry Geese module loaded earlier
# and reset it before it is handed to the model constructor.
e = env_module.Environment()
e.reset()

In [5]:
# Build the network; the environment instance is passed through to the base class.
net = GeeseNet(e)

In [6]:
# Standard-normal dummy input: a single row of 16 * 7 * 11 = 1232 values.
input_ = torch.randn((1, 16 * 7 * 11))
input_.shape

torch.Size([1, 1232])

In [7]:
# Forward pass on the dummy input; yields a dict with "policy" and "value" entries.
out = net(input_)

In [8]:
# Show the raw policy and value tensors produced by the freshly constructed network.
out

{'policy': tensor([[[ 0.0777,  0.1091,  0.0838, -0.0897]]], grad_fn=<UnsafeViewBackward>),
 'value': tensor([[[-0.1764]]], grad_fn=<TanhBackward>)}

In [9]:
# Show the module tree to check the layer layout and feature sizes.
net

GeeseNet(
  (gtrxl): GTrXL(
    (embed): PositionalEncoding(
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (transfomer): TransformerEncoder(
      (layers): ModuleList(
        (0): TEL(
          (self_attn): MultiheadAttention(
            (out_proj): _LinearWithBias(in_features=1232, out_features=1232, bias=True)
          )
          (linear1): Linear(in_features=1232, out_features=256, bias=True)
          (dropout): Dropout(p=0, inplace=False)
          (linear2): Linear(in_features=256, out_features=1232, bias=True)
          (norm1): LayerNorm((1232,), eps=1e-05, elementwise_affine=True)
          (norm2): LayerNorm((1232,), eps=1e-05, elementwise_affine=True)
          (dropout1): Dropout(p=0, inplace=False)
          (dropout2): Dropout(p=0, inplace=False)
          (gru_1): GRU(1232, 1232, batch_first=True)
          (gru_2): GRU(1232, 1232, batch_first=True)
        )
      )
    )
  )
  (head_p1): Linear(in_features=1232, out_features=64, bias=False)
  (head_p2)