<a href="https://colab.research.google.com/github/JasaZnidar/Predvidenje-zmagovalca-vaterpolo/blob/main/Diplomska_naloga.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Setup enviroment and imports

## Setup

In [1]:
!pip install pyg_lib torch_scatter torch_sparse torch_cluster torch_spline_conv -f https://data.pyg.org/whl/torch-2.2.2+cu121.html
!pip install torch torchvision torchaudio -f https://data.pyg.org/whl/torch-2.2.2+cu121.html
!pip install torch-geometric
!pip install torcheval
!pip install scikit-plot

Looking in links: https://data.pyg.org/whl/torch-2.2.2+cu121.html
Collecting pyg_lib
  Downloading https://data.pyg.org/whl/torch-2.2.0%2Bcu121/pyg_lib-0.4.0%2Bpt22cu121-cp310-cp310-linux_x86_64.whl (2.5 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.5/2.5 MB[0m [31m15.9 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting torch_scatter
  Downloading https://data.pyg.org/whl/torch-2.2.0%2Bcu121/torch_scatter-2.1.2%2Bpt22cu121-cp310-cp310-linux_x86_64.whl (10.9 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m10.9/10.9 MB[0m [31m54.7 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting torch_sparse
  Downloading https://data.pyg.org/whl/torch-2.2.0%2Bcu121/torch_sparse-0.6.18%2Bpt22cu121-cp310-cp310-linux_x86_64.whl (5.0 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m5.0/5.0 MB[0m [31m51.5 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting torch_cluster
  Downloading https://data.pyg.org/whl/torch-2.2.0%2Bcu121/torch_cluster-1.6.3%2Bp

## Imports

In [2]:
import json
import networkx as nx
import torch
import torch_geometric
from torch_geometric.utils.convert import from_networkx
from torch_geometric import nn, sampler
from torch_geometric.data import HeteroData
from torch_geometric import transforms as T
from torch_geometric import loader
from torcheval.metrics import R2Score, MeanSquaredError
import tqdm
from sklearn.metrics import roc_auc_score, roc_curve
import scikitplot as skplt
import matplotlib.pyplot as plt
import requests
from zipfile import ZipFile
from io import BytesIO
%matplotlib inline

## Other

In [3]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Analizing scraped data and create graph



## Get Zipped test.json file

In [4]:
with requests.get("https://github.com/JasaZnidar/totalwaterpolo-web-scraper/raw/master/test.zip", ) as r:
  ZipFile(BytesIO(r.content), "r").extractall()

## Get scraped data from github repository

In [5]:
# open raw data scraped from the website
with open("/content/test.json") as f:
    scraped_data = json.load(f)

## Data generating function
We will create a function that will create HeteroData from scraped data before a selected date. This will be used to create training data.

### player_in_match functions

In [17]:
def average(history: list[list[float]]) -> list[float]:
  ret = [0]*len(history[0])

  for match in history:
    for i in range(len(match)):
      ret[i] += match[i]

  for i in range(len(ret)):
    ret[i] /= len(history)

  return ret

### Create graph from json

In [18]:
def createData(data: dict, start: int=0, stop: int=-1) -> HeteroData:
  # data
  ret_data = HeteroData()

  #=============================================================================
  # Matrices that will define the graph
  #=============================================================================
  # player (player)
  player_dim = 5
  player_matrix = torch.empty(0, player_dim, dtype=torch.float32)

  # player in a match (played_in_match)
  playerInMatch_dim = 10
  playerInMatch_matrix = torch.empty(0, playerInMatch_dim, dtype=torch.float32)

  # team (team)
  team_dim = 0
  team_matrix = torch.empty(0, team_dim, dtype=torch.float32)

  # team in a match (team_in_match)
  teamInMatch_dim = 3
  teamInMatch_matrix = torch.empty(0, teamInMatch_dim, dtype=torch.float32)

  # player instance relation (player, player_instance, played_in_match)
  playerInstance_matrix = torch.empty(2, 0, dtype=torch.long)

  # played relation (played_in_match, played, team_in_match)
  played_matrix = torch.empty(2, 0, dtype=torch.long)
  played_attr = torch.empty(1, 0, dtype=torch.float32)

  # team participated in a match (team, team_instance, team_in_match)
  teamInstance_matrix = torch.empty(2, 0, dtype=torch.long)

  # match result (team_in_match, result, team_in_match)
  result_matrix = torch.empty(2, 0, dtype=torch.long)
  result_attr = torch.empty(1, 0, dtype=torch.float32)
  # reversed relations
  result_rev_attr = torch.empty(1, 0, dtype=torch.float32)

  #=============================================================================
  # Other data
  #=============================================================================
  player_match_history = {}

  #=============================================================================
  # Sort matches in order of date, and filter out the matches that happened
  # after the specified date
  #=============================================================================
  # filter out matches
  sorted_match_ids = []
  breakpoint_match_id = ("", -1)
  for match_id in data['matches']:
    if 'date' in data['matches'][match_id]:
      sorted_match_ids.append((match_id, data['matches'][match_id]['date']))

  # sort matches
  sorted_match_ids = sorted(sorted_match_ids, key=lambda t: t[1])


  #=============================================================================
  # Loop through the matches and fill out the matrices
  #=============================================================================
  players_id_index = {}
  cumulative_player_data = {}
  cumulative_team_data = {}

  for index in range(len(sorted_match_ids)):
    # exit loop if enough matches have been added
    if result_matrix.shape[1] == stop - start:
      break
    match_id, _ = sorted_match_ids[index]
    match_data = data['matches'][match_id]

    # check if match is relevant (at least 7 players in each team)
    if len(match_data['lineup']['home']) < 7 or len(match_data['lineup']['away']) < 7:
      continue

    # result of match
    if index >= start:
      result_matrix = torch.cat((result_matrix, torch.empty((2, 1))), dim=1)
      result_attr = torch.cat((result_attr, torch.empty((1, 1))), dim=1)
      result_rev_attr = torch.cat((result_rev_attr, torch.empty((1, 1))), dim=1)
      if match_data['result']['home'] > match_data['result']['away']:
        result_attr[0, -1] = 1.0
        result_rev_attr[0, -1] = 0.0
      elif match_data['result']['home'] < match_data['result']['away']:
        result_attr[0, -1] = 0.0
        result_rev_attr[0, -1] = 1.0
      else:
        result_attr[0, -1] = 0.5
        result_rev_attr[0, -1] = 0.5

    # go through the lineup
    for team in ['home', 'away']:
      # create teamInMatch
      team_key = match_data['name'][team]

      # if there is no instance of the team, we need a new team
      if not team_key in cumulative_team_data:
        cumulative_team_data[team_key] = {
            "data": [[0, 0, 0 if team == "home" else 1]],   # [[wins, matches, home/away]]
            "last index": teamInMatch_matrix.size(dim=0),
            "team index": team_matrix.size(dim=0)
        }

        # add new team
        if index >= start:
          team_matrix = torch.cat((team_matrix, torch.empty(1, team_dim)), dim=0)

      # update cumulative_team_data
      cumulative_team_data[team_key]['data'][0][2] = 0 if team == "home" else 1
      cumulative_team_data[team_key]['last index'] = teamInMatch_matrix.size(dim=0)
      # add new teamInMatch
      if index >= start:
        teamInMatch_matrix = torch.cat((teamInMatch_matrix, torch.Tensor(cumulative_team_data[team_key]['data'])), dim=0)

      # connect teamInMatch to enemy teamInMatch
      if index >= start:
        result_matrix[0 if team == "home" else 1, -1] = cumulative_team_data[team_key]['last index']

      # connect team to teamInMatch
      if index >= start:
        teamInstance_matrix = torch.cat((teamInstance_matrix, torch.Tensor([[cumulative_team_data[team_key]['team index']], [cumulative_team_data[team_key]['last index']]])), dim=1)

      # update cumulative_team_data
      cumulative_team_data[team_key]['data'][0][0] += 1 if match_data['result'][team] > match_data['result']['away' if team == "home" else 'home'] else 0
      cumulative_team_data[team_key]['data'][0][1] += 1


      # loop through the lineup
      for player_num in match_data['lineup'][team]:
        player_id = match_data['lineup'][team][player_num]['id']

        # check if player is in player matrix
        if not player_id in players_id_index:
          try:
            player = data['players'][player_id]
          except KeyError:
            player = {
                "position": '',
                "hand": '',
                "height": 0,
                "weight": 0,
                "birth": 0
            }
          players_id_index[player_id] = len(players_id_index.keys())

          # player attributes to int
          player_attr = torch.zeros((1, player_dim), dtype=torch.int32)
          player_attr[0, 0] = player['birth']
          player_attr[0, 1] = 1 if player['hand'] == "R" else -1 if player['hand'] == "L" else 0
          player_attr[0, 2] = player['height'] if player['height'] else 0
          match player['position']:
            case '':
              player_attr[0, 3] = 0
            case 'Goalkeeper':
              player_attr[0, 3] = 1
            case 'Driver':
              player_attr[0, 3] = 2
            case 'Left Driver':
              player_attr[0, 3] = 3
            case 'Right Driver':
              player_attr[0, 3] = 4
            case 'Central Defender':
              player_attr[0, 3] = 5
            case 'Left Winger':
              player_attr[0, 3] = 6
            case 'Right Winger':
              player_attr[0, 3] = 7
            case 'Center Forward':
              player_attr[0, 3] = 8
          player_attr[0, 4] = player['weight'] if player['weight'] else 0

          if index >= start:
            # add player to player matrix
            player_matrix = torch.cat((player_matrix, player_attr), dim=0)

          # add player to cumulate data
          cumulative_player_data[player_id] = {
              "data": [0] * (playerInMatch_dim + 2),  # [goals, shots, assists, blocks, saves, exclusions, penalties, suspensions, brutalities, sprints won] + [matches, sprints]
              "last index": playerInMatch_matrix.size(dim=0)
          }

        if player_id not in player_match_history:
          player_match_history[player_id] = []
        player_match_history[player_id].append([0] * (playerInMatch_dim + 2))

        if index >= start:
          # add new playerInMatch
          playerInMatch_matrix = torch.cat((playerInMatch_matrix, torch.Tensor([cumulative_player_data[player_id]['data'][:-2]])), dim=0)
          if not cumulative_player_data[player_id]['data'][-2] == 0:
            playerInMatch_matrix[-1, :-1] /= cumulative_player_data[player_id]['data'][-2]
          if not cumulative_player_data[player_id]['data'][-1] == 0:
            playerInMatch_matrix[-1, -1] /= cumulative_player_data[player_id]['data'][-1]

          # connect player to playerInMatch
          playerInstance_matrix = torch.cat((playerInstance_matrix, torch.Tensor([[players_id_index[player_id]], [cumulative_player_data[player_id]['last index']]])), dim=1)

          # connect playerInMatch to teamInMatch
          played_matrix = torch.cat((played_matrix, torch.Tensor([[cumulative_player_data[player_id]['last index']], [cumulative_team_data[team_key]['last index']]])), dim=1)
          played_attr = torch.cat((played_attr, torch.Tensor([[0 if player_num == "1" or player_num == "13" else 1]])), dim=1)

        # update cumulative_player_data
        cumulative_player_data[player_id]['data'][10] += 1  # played in a match

    # update data with data from this match
    # go throught ALL plays and update cumulative_player_data and cumulative_team_data
    for play in match_data['plays']:
      # check if a player was marked
      if play['player_1'] == 0:
        continue

      # find teams
      team_1 = play['team']
      team_2 = "home" if team_1 == "away" else "away"

      # find players who participated in the play
      try:
        id_1 = match_data['lineup'][team_1][str(play['player_1'])]['id']
      except Exception as ex:
        print(match_id, team_1, match_data['name'][team_1])
        #print(json.dumps(match_data['lineup'][team_1], sort_keys=True, indent=4))
        print(json.dumps(play, sort_keys=True, indent=4))
        raise ex
      id_2 = [0, 0] # depending on the play, the second player could be from the same team (first value is the opposing team)
      if not play['player_2'] == 0:
        if str(play['player_2']) in match_data['lineup'][team_2]:
          id_2[0] = match_data['lineup'][team_2][str(play['player_2'])]['id']
        if str(play['player_2']) in match_data['lineup'][team_1]:
          id_2[1] = match_data['lineup'][team_1][str(play['player_2'])]['id']

      # detect play type
      if "goal scored" in play['action']:
        cumulative_player_data[id_1]['data'][0] += 1        # goals
        player_match_history[id_1][-1][0] += 1
        cumulative_player_data[id_1]['data'][1] += 1        # shots
        player_match_history[id_1][-1][1] += 1
        if not id_2[1] == 0:
          cumulative_player_data[id_2]['data'][2] += 1      # assists
          player_match_history[id_1][-1][2] += 1
      elif "exclusion" in play['action']:
        cumulative_player_data[id_1]['data'][5] += 1        # exclusion
        player_match_history[id_1][-1][5] += 1
      elif "penalty foul" in play['action']:
        cumulative_player_data[id_1]['data'][6] += 1        # penalty
        player_match_history[id_1][-1][6] += 1
      elif "shot missed" in play['action']:
        cumulative_player_data[id_1]['data'][1] += 1        # shots
        player_match_history[id_1][-1][1] += 1
      elif "shot saved" in play['action']:
        cumulative_player_data[id_1]['data'][1] += 1        # shots
        player_match_history[id_1][-1][1] += 1
      elif "shot blocked" in play['action']:
        cumulative_player_data[id_1]['data'][1] += 1        # shots
        player_match_history[id_1][-1][1] += 1
        if not id_2[0] == 0:
          cumulative_player_data[id_2[0]]['data'][3] += 1   # blocks
          player_match_history[id_2][-1][3] += 1
      elif "suspention" in play['action']:
        cumulative_player_data[id_1]['data'][7] += 1        # suspensions
        player_match_history[id_1][-1][7] += 1
      elif "brutality" in play['action']:
        cumulative_player_data[id_1]['data'][8] += 1        # brutalities
        player_match_history[id_1][-1][8] += 1
      elif "sprint won" in play['action']:
        cumulative_player_data[id_1]['data'][9] += 1        # sprint won
        player_match_history[id_1][-1][9] += 1
        cumulative_player_data[id_1]['data'][11] += 1       # sprint
        player_match_history[id_1][-1][11] += 1
        if not id_2[0] == 0:
          cumulative_player_data[id_2[0]]['data'][11] += 1  # sprint
          player_match_history[id_1][-1][11] += 1

  # save data
  ret_data['player'].x = player_matrix
  ret_data['player_in_match'].x = playerInMatch_matrix
  ret_data['team'].x = team_matrix
  ret_data['team_in_match'].x = teamInMatch_matrix

  ret_data['player', 'player_instance', 'player_in_match'].edge_index = playerInstance_matrix.type(torch.long)
  ret_data['player_in_match', 'player_instance_rev', 'player'].edge_index = playerInstance_matrix.flip([0]).type(torch.long)

  ret_data['player_in_match', 'played', 'team_in_match'].edge_index = played_matrix.type(torch.long)
  ret_data['player_in_match', 'played', 'team_in_match'].edge_attr = played_attr
  ret_data['team_in_match', 'played_rev', 'player_in_match'].edge_index = played_matrix.flip([0]).type(torch.long)
  ret_data['team_in_match', 'played_rev', 'player_in_match'].edge_attr = played_attr

  ret_data['team', 'team_instance', 'team_in_match'].edge_index = teamInstance_matrix.type(torch.long)
  ret_data['team_in_match', 'team_instance_rev', 'team'].edge_index = teamInstance_matrix.flip([0]).type(torch.long)

  ret_data['team_in_match', 'result', 'team_in_match'].edge_index = result_matrix.type(torch.long)
  ret_data['team_in_match', 'result', 'team_in_match'].edge_attr = result_attr
  ret_data['team_in_match', 'result_rev', 'team_in_match'].edge_index = result_matrix.flip([0]).type(torch.long)
  ret_data['team_in_match', 'result_rev', 'team_in_match'].edge_attr = result_rev_attr

  P = '8440'
  print(playerInMatch_matrix[ cumulative_player_data[P]['last index'] ])
  print(player_match_history[P][-1])

  return ret_data

In [19]:
data = createData(scraped_data)
print(data)

a
tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])
[3, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
HeteroData(
  player={ x=[10356, 5] },
  player_in_match={ x=[110227, 10] },
  team={ x=[497, 0] },
  team_in_match={ x=[8768, 3] },
  (player, player_instance, player_in_match)={ edge_index=[2, 110227] },
  (player_in_match, player_instance_rev, player)={ edge_index=[2, 110227] },
  (player_in_match, played, team_in_match)={
    edge_index=[2, 110227],
    edge_attr=[1, 110227],
  },
  (team_in_match, played_rev, player_in_match)={
    edge_index=[2, 110227],
    edge_attr=[1, 110227],
  },
  (team, team_instance, team_in_match)={ edge_index=[2, 8768] },
  (team_in_match, team_instance_rev, team)={ edge_index=[2, 8768] },
  (team_in_match, result, team_in_match)={
    edge_index=[2, 4384],
    edge_attr=[1, 4384],
  },
  (team_in_match, result_rev, team_in_match)={
    edge_index=[2, 4384],
    edge_attr=[1, 4384],
  }
)


In [None]:
print(data['player'].x[:5, :5])

tensor([[4.5591e+04, 1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
        [4.5142e+04, 1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
        [4.8020e+04, 1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
        [4.5137e+04, 1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
        [4.5550e+04, 1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00]])


### Split data into training and validation

Values to define the scope of the training and validation.

In [None]:
training = -1
validate = 1000

If train or validate are not bigger then 0, their values are addapted acordingly to fully utilize the avaliable data.

In [None]:
if training <= 0 and validate <= 0:
  validate = 1
  training = data["team_in_match", "result", "team_in_match"].edge_index.size(dim=1) - 1
elif training <= 0:
  training = data["team_in_match", "result", "team_in_match"].edge_index.size(dim=1) - validate
elif validate <= 0:
  validate = data["team_in_match", "result", "team_in_match"].edge_index.size(dim=1) - training


In [None]:
train_data = createData(scraped_data, stop=training)
print(train_data)
val_data = createData(scraped_data, start=training, stop=(training + validate))
print("\n", val_data)

HeteroData(
  player={ x=[8941, 5] },
  player_in_match={ x=[84384, 10] },
  team={ x=[437, 0] },
  team_in_match={ x=[6768, 3] },
  (player, player_instance, player_in_match)={ edge_index=[2, 84384] },
  (player_in_match, player_instance_rev, player)={ edge_index=[2, 84384] },
  (player_in_match, played, team_in_match)={
    edge_index=[2, 84384],
    edge_attr=[1, 84384],
  },
  (team_in_match, played_rev, player_in_match)={
    edge_index=[2, 84384],
    edge_attr=[1, 84384],
  },
  (team, team_instance, team_in_match)={ edge_index=[2, 6768] },
  (team_in_match, team_instance_rev, team)={ edge_index=[2, 6768] },
  (team_in_match, result, team_in_match)={
    edge_index=[2, 3384],
    edge_attr=[1, 3384],
  },
  (team_in_match, result_rev, team_in_match)={
    edge_index=[2, 3384],
    edge_attr=[1, 3384],
  }
)

 HeteroData(
  player={ x=[1442, 5] },
  player_in_match={ x=[25919, 10] },
  team={ x=[60, 0] },
  team_in_match={ x=[2000, 3] },
  (player, player_instance, player_in_matc

# Machine learning

## Encoder and Decoder

In [None]:
class Encoder(torch.nn.Module):
  def __init__(self, in_channels, hidden_channels, out_channels, layers=2, layer=nn.GATConv, device='cpu'):
    super().__init__()

    self.convs = torch.nn.ModuleList()
    self.convs.append(layer(in_channels, hidden_channels))
    for _ in range(1, layers-1):
      self.convs.append(layer(hidden_channels, hidden_channels))
    self.convs.append(layer(hidden_channels, out_channels))

  def forward(self, x, edge_index):
    for conv in self.convs[:-1]:
      x = conv(x, edge_index)
      x = x.relu()
    x = self.convs[-1](x, edge_index)

    return x

  def forward(self, x, edge_index):
    for conv in self.convs[:-1]:
      x = conv(x, edge_index)
      x = x.relu()
    x = self.convs[-1](x, edge_index)

    return x

class Decoder(torch.nn.Module):
  def __init__(self, in_channels, hidden_channels, out_channels, layers=2):
    super().__init__()

    self.lins = torch.nn.ModuleList()
    self.lins.append(nn.Linear(in_channels, hidden_channels))
    for _ in range(1, layers-1):
      self.lins.append(nn.Linear(hidden_channels, hidden_channels))
    self.lins.append(nn.Linear(hidden_channels, out_channels))

  def forward(self, x_dict, edge_label_index):
    row, col = edge_label_index
    x = torch.cat([x_dict['team'][row], x_dict['team'][col]], dim=-1)

    for lin in self.lins[:-1]:
      x = lin(x)
      x = x.relu()
    x = self.lins[-1](x)

    return x

## Simple module

In [None]:
class basicModule(torch.nn.Module):
  def __init__(self, device='cpu'):
    super().__init__()
    self.device = device

    self.lin = torch.nn.Linear(2, 1).to(self.device)

  def forward(self, edge, edge_index):
    x = torch.zeros([edge_index.shape[1], 2])
    x[:, 0] = edge[edge_index[0], 2]
    x[:, 1] = edge[edge_index[1], 2]

    x = self.lin(x.to(self.device))

    return x

## GCN module

In [None]:
class GCN(torch.nn.Module):
  def __init__(self, in_channels, hidden_channels, out_channels, layers=2, device='cpu'):
    super().__init__()
    self.device = device

    if layers == 1:
      self.gcn = nn.GCNConv(in_channels, out_channels, add_self_loops=False)
    else:
      self.gcn = torch.nn.ModuleList()
      self.gcn.append(nn.GCNConv(in_channels, hidden_channels, add_self_loops=False))
      for _ in range(1, layers-1):
        self.gcn.append(nn.GCNConv(hidden_channels, hidden_channels, add_self_loops=False))
      self.gcn.append(nn.GCNConv(hidden_channels, out_channels, add_self_loops=False))

    #self.gcn = nn.to_hetero(self.gcn, train_data.metadata(), aggr='sum')

    self.gcn.to(self.device)

  def forward(self, x, edge_index):
    x.to(self.device)

    for module in self.gcn:
      x = module(x, edge_index)
      x = x.relu()

    return x

## SAGEConv module

In [None]:
class SAGE(torch.nn.Module):
  def __init__(self, in_channels, hidden_channels, out_channels, encode_layers=2, edecode_layers=2):
    super().__init__()
    self.encoder = Encoder(
        in_channels,
        hidden_channels,
        hidden_channels,
        encode_layers,
        lambda in_c, hidden_c: nn.SAGEConv(in_c, hidden_c, add_self_loops=False)
    )
    self.encoder = nn.to_hetero(self.encoder, data.metadata(), aggr='sum')

    self.decoder = Decoder(hidden_channels, hidden_channels, out_channels, edecode_layers)

  def forward(self, x, edge_index):
    # encode
    x = self.encoder(x, edge_index)

    # decode
    x = self.decoder(x)

    return x

## GAT Module

In [None]:
class GAT(torch.nn.Module):
  def __init__(self, in_channels, hidden_channels, out_channels, encode_layers=2, edecode_layers=2):
    super().__init__()
    self.encoder = Encoder(
        in_channels,
        hidden_channels,
        hidden_channels,
        encode_layers,
        lambda in_c, hidden_c: nn.GATConv(in_c, hidden_c, add_self_loops=False)
    )
    self.encoder = nn.to_hetero(self.encoder, data.metadata(), aggr='sum')

    self.decoder = Decoder(hidden_channels, hidden_channels, out_channels, edecode_layers)

  def forward(self, x, edge_index):
    # encode
    x = self.encoder(x, edge_index)

    # decode
    x = self.decoder(x)

    return x

## Training

In [None]:
epochs = 100
size = 64
iterations = 3

neighbors = {
    ('player', 'player_instance', 'player_in_match'): [20]*iterations,
    ('player_in_match', 'player_instance_rev', 'player'): [1]*iterations,
    ('player_in_match', 'played', 'team_in_match'): [1]*iterations,
    ('team_in_match', 'played_rev', 'player_in_match'): [13]*iterations,
    ('team', 'team_instance', 'team_in_match'): [20]*iterations,
    ('team_in_match', 'team_instance_rev', 'team'): [1]*iterations,
    ('team_in_match', 'result', 'team_in_match'): [1]*iterations,
    ('team_in_match', 'result_rev', 'team_in_match'): [1]*iterations
}

edgeIndex = tuple((link, train_data[link].edge_index) for link in train_data.metadata()[1])

matchLoader = loader.NeighborLoader(
    train_data.to(device),
    num_neighbors=neighbors,
    batch_size=size,
    input_nodes='team_in_match',
    # temporal_strategy='last'  # za izbiro zgodovino tekem?
    # is_sorted=True  # edge indexi so urejeni časovno
)

matchData = next(iter(matchLoader))
print(matchData)

AttributeError: module 'torch_geometric.loader' has no attribute 'HeteroSamplerOutput'

### Basic Module

In [None]:
basic = basicModule(device=device)

optimizer = torch.optim.Adam(basic.parameters(), lr=0.001)
r2 = R2Score().to(device)
mse = MeanSquaredError().to(device)

for epoch in range(epochs):
  total_loss = total_examples = 0
  optimizer.zero_grad()
  pred = basic(
      train_data['team_in_match'].x,
      train_data['team_in_match', 'result', 'team_in_match'].edge_index
  ).T.to(device)
  ground_truth = train_data['team_in_match', 'result', 'team_in_match'].edge_attr.to(device)
  loss = torch.nn.functional.binary_cross_entropy_with_logits(pred, ground_truth)
  loss.backward()
  optimizer.step()
  total_loss += float(loss) * pred.numel()
  total_examples += pred.numel()
  r2.update(pred.T, ground_truth.T)
  mse.update(pred.T, ground_truth.T)

  #print(f"Epoch: {epoch+1:03d}, Loss: {total_loss / total_examples:.4f}, R2: {r2.compute()}, MSE: {mse.compute()}")

# validate
pred = basic(
    val_data['team_in_match'].x,
    val_data['team_in_match', 'result', 'team_in_match'].edge_index
).T.to(device)
ground_truth = val_data['team_in_match', 'result', 'team_in_match'].edge_attr.to(device)
r2.update(pred.T, ground_truth.T)
mse.update(pred.T, ground_truth.T)
print(f"Validation\n\tR2: {r2.compute()}\n\tMSE: {mse.compute()}")

### GCN

In [None]:
GCNModule = GCN(64, 64, 1, device=device)

optimizer = torch.optim.Adam(GCNModule.parameters(), lr=0.001)

train_edge_index_dict = {edge: train_data[edge].edge_index for edge in train_data.metadata()[1]}

for epoch in range(epochs):
  total_loss = total_examples = 0
  optimizer.zero_grad()

  train_data.to(device)
  pred = GCNModule(
      train_data['team_in_match'].x,
      train_data['team_in_match', 'result', 'team_in_match'].edge_index
  ).T.to(device)
  ground_truth = train_data['team_in_match', 'result', 'team_in_match'].edge_attr.to(device)
  loss = torch.nn.functional.binary_cross_entropy_with_logits(pred, ground_truth)
  loss.backward()
  optimizer.step()
  total_loss += float(loss) * pred.numel()
  total_examples += pred.numel()
  r2.update(pred.T, ground_truth.T)
  mse.update(pred.T, ground_truth.T)

  #print(f"Epoch: {epoch+1:03d}, Loss: {total_loss / total_examples:.4f}, R2: {r2.compute()}, MSE: {mse.compute()}")

# validate
pred = GCNModule(
    val_data['team_in_match'].x,
    val_data['team_in_match', 'result', 'team_in_match'].edge_index
).T.to(device)
ground_truth = val_data['team_in_match', 'result', 'team_in_match'].edge_attr.to(device)
r2.update(pred.T, ground_truth.T)
mse.update(pred.T, ground_truth.T)
print(f"Validation\n\tR2: {r2.compute()}\n\tMSE: {mse.compute()}")

### GAT

In [None]:
GATModule = GAT(64, 64, 1).to(device)

optimizer = torch.optim.Adam(GATModule.parameters(), lr=0.001)

for epoch in range(epochs):
  for i in range(training):
    total_loss = total_examples = 0
    optimizer.zero_grad()
    train_data.to(device)
    pred = GATModule(train_data['team_in_match'].x, train_data['team_in_match', 'result', 'team_in_match'].edge_index[:, i])
    ground_truth = train_data['team_in_match', 'result', 'team_in_match'].edge_attr[:, i]
    loss = torch.nn.functional.binary_cross_entropy_with_logits(pred, ground_truth)
    loss.backward()
    optimizer.step()
    total_loss += float(loss) * pred.numel()
    total_examples += pred.numel()

  print(f"Epoch: {epoch+1:03d}, Loss: {total_loss / total_examples:.4f}")

## Validate

In [None]:
preds = []
ground_truths = []

for i in range(validate):
  with torch.no_grad():
    preds.append(model(val_data['team_in_match'].x, val_data['team_in_match', 'result', 'team_in_match'].edge_index[:, i]))
    ground_truths.append(val_data['team_in_match', 'result', 'team_in_match'].edge_attr[:, i])

preds = [torch.Tensor([0.0]) if p[0] < 0.33 else torch.Tensor([1.0]) for p in preds]
ground_truths = [torch.Tensor([0.0]) if gt[0] < 0.33 else torch.Tensor([1.0]) for gt in ground_truths]

auc = roc_auc_score(ground_truths, preds)
fpr, tpr, _ = roc_curve(ground_truths, preds)

plt.plot([0, 1], [0, 1], color="red", lw=2, linestyle="--")
plt.plot(fpr,tpr, color="navy")
plt.ylabel('True Positive Rate')
plt.xlabel('False Positive Rate')
plt.show()
print(f"Validation AUC: {auc:.4f}")