In [5]:
import importlib

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

%matplotlib inline

In [6]:
model_module = importlib.import_module("handyrl.model")
env_module = importlib.import_module("handyrl.envs.kaggle.hungry_geese")

In [7]:
e = env_module.Environment()
e.reset()

In [8]:
obs = e.observation()
obs = torch.from_numpy(obs.reshape(1, obs.shape[0], obs.shape[1], obs.shape[2])).clone()
obs_ = torch.cat([obs, obs], dim=0)
print(f"size: {obs_.size()}, type: {obs_.dtype}")

size: torch.Size([2, 17, 7, 11]), type: torch.float32


In [9]:
# plt.imshow(obs[0][0], cmap="gray")
# plt.show()

In [10]:
input_ = torch.randn(2, obs.shape[1], obs.shape[2], obs.shape[3])
print(f"size: {input_.size()}, type: {input_.dtype}")

size: torch.Size([2, 17, 7, 11]), type: torch.float32


In [11]:
net = env_module.GeeseNet()

In [12]:
# Disable trainable for pretrained weights.
# for param in net.geese_net.parameters():
#     param.requires_grad = False

In [13]:
out = net(obs_)
# print(f"size: {out.size()}, type: {out.dtype}")
print(f"size: {out['policy'].size()}, type: {out['policy'].dtype}")
print(f"size: {out['value'].size()}, type: {out['value'].dtype}")

size: torch.Size([2, 4]), type: torch.float32
size: torch.Size([2, 1]), type: torch.float32


In [14]:
out

{'policy': tensor([[-0.4537, -0.1722,  0.9496, -0.7086],
         [-0.4537, -0.1722,  0.9496, -0.7086]], grad_fn=<MmBackward>),
 'value': tensor([[0.1847],
         [0.1847]], grad_fn=<TanhBackward>),
 'h_head_p': tensor([[0.8498, 0.1778, 0.4624, 2.2744, 0.0000, 2.3729, 0.0000, 0.0000, 0.0000,
          0.9268, 0.5602, 0.9807, 0.5468, 0.7175, 0.0000, 0.3600, 1.9266, 2.9467,
          0.0000, 0.1890, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 2.4094, 2.4689,
          3.2419, 0.5955, 0.0000, 0.0000, 0.7698],
         [0.8498, 0.1778, 0.4624, 2.2744, 0.0000, 2.3729, 0.0000, 0.0000, 0.0000,
          0.9268, 0.5602, 0.9807, 0.5468, 0.7175, 0.0000, 0.3600, 1.9266, 2.9467,
          0.0000, 0.1890, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 2.4094, 2.4689,
          3.2419, 0.5955, 0.0000, 0.0000, 0.7698]], grad_fn=<SumBackward1>),
 'h_head_v': tensor([[0.8794, 0.0000, 0.0000, 0.0000, 1.1075, 0.0000, 0.4641, 0.7000, 1.5866,
          1.3743, 0.0000, 4.0468, 0.0000, 0.7730, 0.0000, 0.2433, 0.2200,

In [15]:
# pytorch_total_params
params = sum(p.numel() for p in net.parameters())
print(f"{params:,}")

137,568


In [16]:
# pytorch_total_params (trainable)
params = sum(p.numel() for p in net.parameters() if p.requires_grad)
print(f"{params:,}")

137,568


In [17]:
net

GeeseNet(
  (conv0): TorusConv2d(
    (conv): Conv2d(17, 32, kernel_size=(3, 3), stride=(1, 1))
    (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (blocks): ModuleList(
    (0): TorusConv2d(
      (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1))
      (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): TorusConv2d(
      (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1))
      (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (2): TorusConv2d(
      (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1))
      (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (3): TorusConv2d(
      (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1))
      (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (4): TorusConv2d(
      (conv): Conv2d(32, 32,