In [1]:
import importlib

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
model_module = importlib.import_module("handyrl.model")
env_module = importlib.import_module("handyrl.envs.kaggle.hungry_geese")

Loading environment football failed: No module named 'gfootball'


In [3]:
class TorusConv2d(nn.Module):
    def __init__(self, input_dim, output_dim, kernel_size, bn):
        super().__init__()
        self.edge_size = (kernel_size[0] // 2, kernel_size[1] // 2)
        self.conv = nn.Conv2d(input_dim, output_dim, kernel_size=kernel_size)
        self.bn = nn.BatchNorm2d(output_dim) if bn else None

    def forward(self, x):
        h = torch.cat([x[:, :, :, -self.edge_size[1] :], x, x[:, :, :, : self.edge_size[1]]], dim=3)
        h = torch.cat([h[:, :, -self.edge_size[0] :], h, h[:, :, : self.edge_size[0]]], dim=2)
        h = self.conv(h)
        h = self.bn(h) if self.bn is not None else h
        return h


class GeeseNet(model_module.BaseModel):
    def __init__(self, env, args={}):
        super().__init__(env, args)
        input_shape = env.observation().shape
        layers, filters = 12, 32
        self.conv0 = TorusConv2d(input_shape[0], filters, (3, 3), True)
        self.blocks = nn.ModuleList([TorusConv2d(filters, filters, (3, 3), True) for _ in range(layers)])

        self.conv_p = TorusConv2d(filters, filters, (3, 3), True)
        self.conv_v = TorusConv2d(filters, filters, (3, 3), True)

        self.head_p = nn.Linear(filters, 4, bias=False)
        self.head_v1 = nn.Linear(filters * 2, filters, bias=False)
        self.head_v2 = nn.Linear(filters, 1, bias=False)

    def forward(self, x, _=None):
        h = F.relu_(self.conv0(x))
        for block in self.blocks:
            h = F.relu_(h + block(h))

        h_p = F.relu_(self.conv_p(h))
        h_head_p = (h_p * x[:, :1]).view(h_p.size(0), h_p.size(1), -1).sum(-1)
        p = self.head_p(h_head_p)

        h_v = F.relu_(self.conv_v(h))
        h_head_v = (h_v * x[:, :1]).view(h_v.size(0), h_v.size(1), -1).sum(-1)
        h_avg_v = h_v.view(h_v.size(0), h_v.size(1), -1).mean(-1)

        h_v = F.relu_(self.head_v1(torch.cat([h_head_v, h_avg_v], 1)))
        v = torch.tanh(self.head_v2(h_v))

        return {"policy": p, "value": v}

In [4]:
e = env_module.Environment()
e.reset()

In [5]:
net = GeeseNet(e)

In [6]:
input_ = torch.randn(1, 17, 7, 11)
input_.size()

torch.Size([1, 17, 7, 11])

In [7]:
out = net(input_)

In [8]:
out

{'policy': tensor([[-10.2074,  -2.0925,  -1.0508,   4.3439]], grad_fn=<MmBackward>),
 'value': tensor([[-0.0705]], grad_fn=<TanhBackward>)}

In [9]:
net

GeeseNet(
  (conv0): TorusConv2d(
    (conv): Conv2d(17, 32, kernel_size=(3, 3), stride=(1, 1))
    (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (blocks): ModuleList(
    (0): TorusConv2d(
      (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1))
      (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): TorusConv2d(
      (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1))
      (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (2): TorusConv2d(
      (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1))
      (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (3): TorusConv2d(
      (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1))
      (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (4): TorusConv2d(
      (conv): Conv2d(32, 32,

In [10]:
# pytorch_total_params
params = sum(p.numel() for p in net.parameters())
print(f"{params:,}")

137,568


In [11]:
# pytorch_total_params (trainable)
params = sum(p.numel() for p in net.parameters() if p.requires_grad)
print(f"{params:,}")

137,568
