In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Data
import torch_geometric.nn as pyg_nn
import torch_geometric.nn.pool.glob as pool
from sports_gnn.common import *
from sports_gnn.model import SportsGNN

In [2]:
class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = pyg_nn.GCNConv(3, 8)
        self.conv2 = pyg_nn.GCNConv(8, 16)
        self.linear = nn.Linear(16, 2)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = self.conv2(x, edge_index)
        x = F.relu(x)
        print(x)
        x = pool.global_mean_pool(x, None)
        print(x)
        x = self.linear(x)
        x = F.softmax(x)
        return x
        

In [3]:
num_player = 22
num_feature = 3

x = torch.rand(num_player, num_feature)

edge_index = []

for i in range(num_player):
    for j in range(num_player):
        edge_index.append([i, j])

edge_index = torch.tensor(edge_index).T

edge_attr = torch.rand(484, 2)

graph_data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr)
game_state = torch.rand(4)

In [4]:
graph_data

Data(x=[22, 3], edge_index=[2, 484], edge_attr=[484, 2])

In [5]:
graph_data.edge_attr

tensor([[6.4764e-01, 7.3459e-02],
        [1.4356e-01, 1.9749e-01],
        [9.0070e-01, 6.6127e-01],
        [3.9562e-01, 2.7615e-01],
        [5.7837e-01, 2.1888e-01],
        [3.6177e-01, 1.0874e-02],
        [8.9401e-01, 9.9617e-01],
        [7.7090e-01, 6.6897e-01],
        [6.9250e-01, 2.3651e-01],
        [8.9655e-02, 5.5145e-01],
        [5.4173e-01, 4.8322e-01],
        [3.3486e-01, 5.0073e-01],
        [5.4949e-01, 1.3117e-01],
        [8.4076e-01, 5.5501e-01],
        [2.4687e-01, 4.2318e-01],
        [5.3804e-01, 4.4479e-01],
        [2.9272e-01, 4.9089e-01],
        [3.8005e-01, 2.8533e-01],
        [6.1651e-01, 4.4460e-01],
        [9.6365e-01, 5.5986e-01],
        [5.6640e-01, 7.7932e-01],
        [9.1116e-01, 3.8057e-01],
        [1.2015e-01, 1.3422e-01],
        [9.4932e-01, 2.4997e-01],
        [5.9452e-02, 4.2010e-01],
        [1.4194e-01, 2.4594e-01],
        [8.2081e-01, 4.9382e-01],
        [3.6636e-01, 7.2121e-01],
        [5.9447e-01, 5.4832e-01],
        [3.925

In [6]:
model = SportsGNN()
print(model)

hn, cn = model.init()

output, hn, cn = model(graph_data, game_state, hn, cn)

print(output)

SportsGNN(
  (gat_block): GATBlock(
    (convs): ModuleList(
      (0): BatchNorm1d(3, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (1): GATv2Conv(3, 16, heads=2)
      (2): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (3): GATv2Conv(32, 16, heads=2)
      (4): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): GATv2Conv(32, 8, heads=2)
    )
  )
  (sum_pool): SumPool(
    (linear1): Linear(in_features=16, out_features=64, bias=True)
    (linear2): Linear(in_features=64, out_features=16, bias=True)
  )
  (game_state_encoder): Linear(in_features=20, out_features=16, bias=True)
  (res_mlp): ResMLP(
    (linear1): Linear(in_features=16, out_features=16, bias=True)
    (linear2): Linear(in_features=16, out_features=16, bias=True)
  )
  (lstm_block): LSTMBlock(
    (lstm): LSTM(16, 16, num_layers=2)
    (fc): Linear(in_features=16, out_features=2, bias=True)
  )
  (softmax): Softmax(dim