In [1]:
import importlib

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
model_module = importlib.import_module("handyrl.model")
env_module = importlib.import_module("handyrl.envs.kaggle.hungry_geese")

Loading environment football failed: No module named 'gfootball'


In [3]:
pretrained_model_path = "ds/models/first_stage_2462.pth"
weights = torch.load(pretrained_model_path)

In [4]:
e = env_module.Environment()
e.reset()

In [5]:
obs = e.observation()
obs = torch.from_numpy(obs.reshape(1, 17, 7, 11)).clone()
print(f"size: {obs.size()}, type: {obs.dtype}")

size: torch.Size([1, 17, 7, 11]), type: torch.float32


In [6]:
e.obs_list

[[{'action': 'NORTH',
   'reward': 0,
   'info': {},
   'observation': {'remainingOverageTime': 60,
    'step': 0,
    'geese': [[61], [6], [4], [19]],
    'food': [67, 18],
    'index': 0},
   'status': 'ACTIVE'},
  {'action': 'NORTH',
   'reward': 0,
   'info': {},
   'observation': {'remainingOverageTime': 60, 'index': 1},
   'status': 'ACTIVE'},
  {'action': 'NORTH',
   'reward': 0,
   'info': {},
   'observation': {'remainingOverageTime': 60, 'index': 2},
   'status': 'ACTIVE'},
  {'action': 'NORTH',
   'reward': 0,
   'info': {},
   'observation': {'remainingOverageTime': 60, 'index': 3},
   'status': 'ACTIVE'}]]

In [7]:
input_ = torch.randn(2, 17, 7, 11)
print(f"size: {input_.size()}, type: {input_.dtype}")

size: torch.Size([2, 17, 7, 11]), type: torch.float32


In [8]:
net = env_module.GeeseNetIMO(e)

In [9]:
net.geese_net.load_state_dict(weights, strict=False)

<All keys matched successfully>

In [10]:
for param in net.geese_net.parameters():
    param.requires_grad = False

In [11]:
# pytorch_total_params
params = sum(p.numel() for p in net.parameters())
print(f"{params:,}")

279,009


In [12]:
# pytorch_total_params (trainable)
params = sum(p.numel() for p in net.parameters() if p.requires_grad)
print(f"{params:,}")

141,441


In [13]:
net

GeeseNetIMO(
  (geese_net): GeeseNet(
    (conv0): TorusConv2d(
      (conv): Conv2d(17, 32, kernel_size=(3, 3), stride=(1, 1))
      (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (blocks): ModuleList(
      (0): TorusConv2d(
        (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1))
        (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): TorusConv2d(
        (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1))
        (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (2): TorusConv2d(
        (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1))
        (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (3): TorusConv2d(
        (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1))
        (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_

In [14]:
out = net(input_)
print(f"size: {out['policy'].size()}, type: {out['policy'].dtype}")

size: torch.Size([2, 4]), type: torch.float32


In [15]:
out

{'policy': tensor([[-0.5088, -0.8240, -0.2078,  0.1730],
         [ 0.5088,  0.8240,  0.2078, -0.1730]], grad_fn=<MmBackward>),
 'value': tensor([[0.0545],
         [0.0087]], grad_fn=<TanhBackward>)}