In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Data
import torch_geometric.nn as pyg_nn
import torch_geometric.nn.pool.glob as pool
from sports_gnn.common import *
from sports_gnn.model import SportsGNN

In [2]:
class Model(nn.Module):
    """Small two-layer GCN that maps a graph to a class-probability vector.

    Pipeline: GCNConv -> ReLU -> GCNConv -> ReLU -> global mean pool
    -> Linear -> softmax.

    Args:
        in_channels: Number of input node features (default 3, matching the
            toy data built in this notebook).
        hidden_channels: Output width of the first GCN layer.
        embed_channels: Output width of the second GCN layer, i.e. the size
            of the pooled graph embedding.
        num_classes: Number of output classes.
    """

    def __init__(self, in_channels: int = 3, hidden_channels: int = 8,
                 embed_channels: int = 16, num_classes: int = 2):
        super().__init__()
        self.conv1 = pyg_nn.GCNConv(in_channels, hidden_channels)
        self.conv2 = pyg_nn.GCNConv(hidden_channels, embed_channels)
        self.linear = nn.Linear(embed_channels, num_classes)

    def forward(self, data):
        """Return softmax class probabilities, one row per graph.

        Args:
            data: A PyG ``Data``/``Batch`` object exposing ``x`` and
                ``edge_index``. If it also carries a ``batch`` vector, nodes
                are pooled per graph; otherwise all nodes are treated as a
                single graph.

        Returns:
            Tensor of shape ``[num_graphs, num_classes]`` whose rows sum to 1.
        """
        x, edge_index = data.x, data.edge_index

        x = F.relu(self.conv1(x, edge_index))
        x = F.relu(self.conv2(x, edge_index))

        # Use the batch assignment vector when present so a mini-batch of
        # graphs is pooled per graph instead of being averaged together.
        # With batch=None all nodes collapse into one [1, embed_channels] row.
        batch = getattr(data, "batch", None)
        x = pool.global_mean_pool(x, batch)

        x = self.linear(x)
        # Explicit dim: implicit-dim softmax is deprecated and warns. Classes
        # live on the last axis of the [num_graphs, num_classes] logits.
        return F.softmax(x, dim=-1)
        

In [3]:
# Toy input: one fully connected graph of `num_player` nodes, each with
# `num_feature` random features drawn uniformly from [0, 1).
num_player = 22
num_feature = 3

# Seed so the random features (and every downstream output) are identical
# on each Restart & Run All.
torch.manual_seed(0)
x = torch.rand(num_player, num_feature)

# Every ordered (i, j) pair, self-loops included, as an int64 tensor of shape
# [num_player**2, 2] in the same row-major order as a nested i/j loop.
player_ids = torch.arange(num_player)
edge_index = torch.cartesian_prod(player_ids, player_ids)

# PyG expects edge_index in COO layout with shape [2, num_edges].
data = Data(x=x, edge_index=edge_index.t().contiguous())

In [4]:
# Display the graph summary. With the toy graph above this shows x=[22, 3]
# and edge_index=[2, 484], i.e. all 22 * 22 ordered node pairs.
data

Data(x=[22, 3], edge_index=[2, 484])

In [6]:
# Smoke-test the project model on the single toy graph built above:
# print its architecture, run one forward pass, and print the result.
model = SportsGNN()
print(model)

# `init()` returns a pair of initial states; presumably the LSTM hidden/cell
# states (the printed architecture contains an LSTMBlock) — TODO confirm.
hn, cn = model.init()

# NOTE(review): the meaning of the second positional argument (passed as None
# here) is not visible from this notebook — confirm against SportsGNN.forward.
# NOTE(review): the model is left in training mode (no model.eval()) and the
# call runs with autograd enabled. The printed BatchNorm1d layers track running
# stats, so this pass likely updates them. Consider model.eval() together with
# torch.no_grad() if this cell is meant purely for inspection.
output, hn, cn = model(data, None, hn, cn)

print(output)

SportsGNN(
  (gat_block): GATBlock(
    (convs): ModuleList(
      (0): BatchNorm1d(3, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (1): GATv2Conv(3, 16, heads=2)
      (2): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (3): GATv2Conv(32, 16, heads=2)
      (4): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): GATv2Conv(32, 8, heads=2)
    )
  )
  (sum_pool): SumPool(
    (linear1): Linear(in_features=16, out_features=32, bias=True)
    (linear2): Linear(in_features=32, out_features=16, bias=True)
  )
  (game_state_encoder): Linear(in_features=16, out_features=16, bias=True)
  (res_mlp): ResMLP(
    (linear1): Linear(in_features=16, out_features=16, bias=True)
    (linear2): Linear(in_features=16, out_features=16, bias=True)
  )
  (lstm_block): LSTMBlock(
    (lstm): LSTM(16, 16, num_layers=2)
    (fc): Linear(in_features=16, out_features=2, bias=True)
  )
  (softmax): Softmax(dim