In [19]:
import importlib
import math

import numpy as np
import torch as T
import torch.nn as nn
import torch.nn.functional as F

In [20]:
# Load the HandyRL model module and the Hungry Geese environment module
# by dotted path (same order as before: model first, then environment).
model_module, env_module = (
    importlib.import_module(dotted_path)
    for dotted_path in ("handyrl.model", "handyrl.envs.kaggle.hungry_geese")
)

In [21]:
# Instantiate the Hungry Geese environment with its default constructor
# arguments, then reset it before taking an observation.
# NOTE(review): reset() presumably starts a fresh episode — confirm in env_module.
e = env_module.Environment()
e.reset()

In [22]:
# Fetch the environment's current observation and display its shape so the
# network input dimensions can be checked against it.
obs = e.observation()
obs.shape

(3, 16, 16)

In [23]:
# Build the ViT-based network from the environment module; the environment
# instance is handed to the constructor.
# NOTE(review): how GeeseNetViT uses `e` is not visible here — confirm in env_module.
net = env_module.GeeseNetViT(e)

In [24]:
# Dummy batch for a forward-pass smoke test: a single sample with the same
# (3, 16, 16) layout as the observation shape shown earlier in this notebook.
# The values come from a dedicated, seeded generator so the tensor is identical
# on every "Restart & Run All" and the global torch RNG state is left untouched.
input_rng = T.Generator().manual_seed(0)
input_ = T.randn(1, 3, 16, 16, generator=input_rng, dtype=T.float32)
print(f"size: {input_.size()}, type: {input_.dtype}")

size: torch.Size([1, 3, 16, 16]), type: torch.float32


In [25]:
# Run the dummy batch through the network. The result is a dict; its 'policy'
# entry is reported below by size and dtype.
# NOTE: autograd tracking is active here, so the returned tensor carries a grad_fn.
out = net(input_)
print(f"size: {out['policy'].size()}, type: {out['policy'].dtype}")

size: torch.Size([1, 4]), type: torch.float32


In [26]:
# Show the raw network output dict (only a 'policy' entry appears in the output).
out

{'policy': tensor([[-0.1326,  0.1415, -1.6444, -0.2824]], grad_fn=<AddmmBackward>)}

In [27]:
# Display the module tree of the network to inspect its layers.
net

GeeseNetViT(
  (vit): ViT(
    (to_patch_embedding): Sequential(
      (0): Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=2, p2=2)
      (1): Linear(in_features=12, out_features=256, bias=True)
    )
    (dropout): Dropout(p=0.0, inplace=False)
    (transformer): Transformer(
      (layers): ModuleList(
        (0): ModuleList(
          (0): Residual(
            (fn): PreNorm(
              (norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
              (fn): Attention(
                (to_qkv): Linear(in_features=256, out_features=768, bias=False)
                (to_out): Sequential(
                  (0): Linear(in_features=256, out_features=256, bias=True)
                  (1): Dropout(p=0.0, inplace=False)
                )
              )
            )
          )
          (1): Residual(
            (fn): PreNorm(
              (norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
              (fn): FeedForward(
                (net): Sequential(

In [28]:
# Total parameter count: add up the element count of every parameter tensor
# in the network, whether or not it requires gradients.
params = sum(map(T.numel, net.parameters()))
print(f"{params:,}")

416,772


In [29]:
# Trainable parameter count: same sum as the total, restricted to parameter
# tensors whose requires_grad flag is set.
params = sum(map(T.numel, filter(lambda w: w.requires_grad, net.parameters())))
print(f"{params:,}")

416,772
