In [None]:
%load_ext autoreload
%autoreload 2
import os

import torch
import torch.nn.functional as F
from torch.nn import Linear
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GCNConv, global_add_pool

from soccerai.config import build_cfg
from soccerai.data.dataset import WorldCup2022Dataset
from soccerai.trainer import Trainer

In [None]:
# Root directory of the dataset resources. Set SOCCERAI_DATA_DIR to point at a
# different location instead of editing this machine-specific default path.
DATA_DIR = os.environ.get(
    "SOCCERAI_DATA_DIR", "/home/aarcara/soccerai/soccerai/data/resources"
)

# The second argument is passed through unchanged; presumably it selects the
# graph connectivity scheme ("fully_connected") — confirm in WorldCup2022Dataset.
dataset = WorldCup2022Dataset(DATA_DIR, "fully_connected")
# NOTE(review): explicit `process()` call. If WorldCup2022Dataset is a PyG
# Dataset, its constructor may already process the raw files on first use —
# confirm whether forcing re-processing on every run is intended.
dataset.process()

In [None]:
class GCN(torch.nn.Module):
    """Graph-level model: three stacked GCNConv layers, sum pooling, and a linear head.

    Every graph in an input batch is reduced to a single vector by summing its
    node embeddings, then mapped to ``out_channels`` outputs.
    """

    def __init__(self, hidden_channels, in_channels=None, dropout=0.5, out_channels=1):
        """Build the layers.

        Args:
            hidden_channels: Width of every GCNConv layer.
            in_channels: Number of input node features. If ``None`` (the
                default, kept for backward compatibility) it is read from the
                notebook-level ``dataset.num_node_features``; pass it explicitly
                to avoid depending on that global.
            dropout: Dropout probability applied to the pooled graph embedding
                while the module is in training mode.
            out_channels: Number of outputs per graph produced by the final
                linear layer.
        """
        super().__init__()
        if in_channels is None:
            # Legacy behaviour: infer the feature count from the global dataset.
            in_channels = dataset.num_node_features
        self.dropout = dropout
        # Attribute names are unchanged so existing state_dict checkpoints still load.
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        self.conv3 = GCNConv(hidden_channels, hidden_channels)
        self.lin = Linear(hidden_channels, out_channels)

    def forward(self, data):
        """Return one row of ``out_channels`` values per graph in ``data``.

        ``data`` must expose ``x``, ``edge_index`` and ``batch`` attributes
        (as the batches yielded by the PyG DataLoader do).
        """
        x, edge_index, batch = data.x, data.edge_index, data.batch

        x = F.relu(self.conv1(x, edge_index))
        x = F.relu(self.conv2(x, edge_index))
        # No activation after the last conv: its output is pooled directly.
        x = self.conv3(x, edge_index)

        # Sum node embeddings per graph; `batch` maps each node to its graph.
        x = global_add_pool(x, batch)

        x = F.dropout(x, p=self.dropout, training=self.training)
        return self.lin(x)


# Pass the feature count explicitly rather than relying on the global fallback.
model = GCN(hidden_channels=64, in_channels=dataset.num_node_features)

In [None]:
# Experiment configuration, used below for the DataLoader batch size and by the
# Trainer. Set SOCCERAI_CONFIG to use another file instead of editing this
# machine-specific default path.
CONFIG_PATH = os.environ.get(
    "SOCCERAI_CONFIG", "/home/aarcara/soccerai/configs/example.yaml"
)
cfg = build_cfg(CONFIG_PATH)

In [None]:
# Leave one core for the main process, but never request a negative worker
# count: os.cpu_count() can return None, and on a single-core machine
# cpu_count() - 1 is 0.
num_workers = max((os.cpu_count() or 1) - 1, 0)

loader = DataLoader(
    dataset,
    batch_size=cfg.bs,
    num_workers=num_workers,
    shuffle=True,
    pin_memory=True,
    # persistent_workers=True with num_workers == 0 raises ValueError in torch.
    persistent_workers=num_workers > 0,
)

In [None]:
# Use the GPU when present; fall back to CPU instead of failing on machines
# without CUDA.
# NOTE(review): confirm Trainer supports a "cpu" device (e.g. no CUDA-only AMP).
device = "cuda" if torch.cuda.is_available() else "cpu"
trainer = Trainer(cfg, model, loader, device)
trainer.train("debug")