## 🌀 unravel kloppy to trained graph neural network!

First run `pip install unravelsports` if you haven't already!


-----


In [None]:
%pip install unravelsports --quiet

### 1. Introduction

This notebook shows how to use this package to convert [Kloppy](https://github.com/PySport/kloppy) tracking data format into Graphs. These Graphs can subsequently be used to train a Graph Neural Network with the [Spektral](https://graphneural.network/) library as discussed in [A Graph Neural Network Deep-dive into Successful Counterattacks {A. Sahasrabudhe & J. Bekkers}](https://github.com/USSoccerFederation/ussf_ssac_23_soccer_gnn/tree/main).

This example follows these steps:
- [2. Imports](#2-imports)
- [3. Open SkillCorner Data](#3-open-skillcorner-data)
- [4. Graph Converter](#4-graph-converter)
- [5. Load Kloppy Data, Convert & Store](#5-load-kloppy-data-convert-and-store)
- [6. Creating a Custom Graph Dataset](#6-creating-a-custom-graph-dataset)
- [7. Prepare for Training](#7-prepare-for-training)
- [8. GNN Training + Prediction](#8-training)


### 2. Imports

We import `GraphConverter` to help us convert from Kloppy positional tracking frames to graphs.

Due to the power of **Kloppy** we can also load data from many other providers by importing `metrica`, `sportec`, `tracab`, `secondspectrum`, or `statsperform` from `kloppy`.

In [None]:
from unravel.soccer import GraphConverter

from kloppy import skillcorner

-------
### 3. Open SkillCorner Data

The `GraphConverter` class allows for the conversion of every tracking data provider supported by [PySport Kloppy](https://github.com/PySport/kloppy), namely:
- Sportec
- Tracab
- SecondSpectrum
- SkillCorner
- StatsPerform
- Metrica

In this example we're going to use tracking data frames from 4 matches of [Open SkillCorner Data](https://github.com/SkillCorner/opendata). 

All we need to know for now is that this data is from the following matches:

|  id | date_time           | home_team   | away_team   |
|---:|:---------------------:|:-----------------------|:-----------------------|
|  4039 | 2020-07-02T19:15:00Z | Manchester City        | Liverpool              |
|  3749 | 2020-05-26T16:30:00Z | Dortmund               | Bayern Munchen         |
|  3518 | 2020-03-08T19:45:00Z | Juventus               | Inter                  |
|  3442 | 2020-03-01T20:00:00Z | Real Madrid            | FC Barcelona           |

-------
### 4. Graph Converter

To get started with the `GraphConverter` we need to pass one _required_ parameter:
- `dataset` (of type `TrackingDataset` (Kloppy)) 

And one parameter that's required when we're converting for training purposes (more on this later):
- `labels` (a dictionary with `frame_id`s as keys and a value of `{True, False, 1 or 0}`).
```python
{83340: True, 83341: False, etc..} = {83340: 1, 83341: 0, etc..}
```
⚠️ You will need to create your own labels! In this example we'll use `dummy_labels(dataset)` to generate a fake label for each frame.
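
As a minimal sketch of what such a `labels` dictionary could look like, the snippet below builds one from a list of frame ids. The frame ids and the `positive_frames` set are made-up placeholders; in practice the labels would come from your own annotation or event data.

```python
# Hypothetical frame ids; real ones come from the Kloppy TrackingDataset.
frame_ids = [83340, 83341, 83342, 83343]

# Hypothetical set of frames we consider "positive" (placeholder values).
positive_frames = {83340, 83343}

# One boolean label per frame id.
labels = {frame_id: frame_id in positive_frames for frame_id in frame_ids}

# Booleans and 0/1 integers are interchangeable as binary labels.
labels_as_int = {frame_id: int(value) for frame_id, value in labels.items()}

print(labels)
print(labels_as_int)
```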


#### Graph Identifier(s):
When training a GNN it's highly recommended to split data into test/train(/validation) on match level, or sequence/possession level such that all values from one match/sequence/possession all end up in the same test, train or validation set. This should be done to avoid leaking information between test, train and validation sets.

To make this simple we have two _optional_ parameters we can pass to `GraphConverter`, namely:
- `graph_id`. This is a single identifier (str or int) for a whole match, for example the unique match id.
- `graph_ids`. This is a dictionary with the same keys as `labels`, but the values are now the unique identifiers. This option can be used if we want to split by sequence or possession id. For example: `{frame_id: 'matchId-sequenceId1', frame_id: 'matchId-sequenceId2'}` etc. You will need to create your own possession/sequence ids. Note that if `labels` and `graph_ids` don't have the exact same keys, an error is raised. In this example we'll use `graph_id=match_id` as the unique identifier, but feel free to change that to `graph_ids=dummy_graph_ids(dataset)` to test out that behavior.

Correctly splitting the final dataset into train, test and validation sets is incorporated into `CustomSpektralDataset` (see section 7 for more information).
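
To illustrate the `graph_ids` option, here is a small sketch that builds 'matchId-sequenceId' identifiers with the same keys as `labels`. The frame ids, the `sequence_of_frame` mapping and the way sequences are assigned are all hypothetical; you would derive them from your own possession or sequence logic.

```python
# Hypothetical labels and sequence ids per frame; both are placeholders.
labels = {83340: True, 83341: False, 83342: False, 83343: True}
sequence_of_frame = {83340: 1, 83341: 1, 83342: 2, 83343: 2}
match_id = 4039

# Build 'matchId-sequenceId' identifiers with exactly the same keys as `labels`.
graph_ids = {
    frame_id: f"{match_id}-{sequence_of_frame[frame_id]}" for frame_id in labels
}

# Both dictionaries need identical key sets, so it is worth checking up front.
assert graph_ids.keys() == labels.keys()
print(graph_ids)
```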


#### Graph Converter Settings:

<details>
    <summary><b><i> 🌀 Expand for a full table of additional <u>optional</u> GraphConverter parameters </i></b></summary><br>

| Parameter | Type | Description | Default |
|-----------|------|-------------|---------|
| `ball_carrier_threshold` | float | The distance threshold to determine the ball carrier in meters. If no ball carrier within ball_carrier_threshold, we skip the frame. | 25.0 |
| `max_player_speed` | float | The maximum speed of a player in meters per second. Used for normalizing node features. | 12.0 |
| `max_ball_speed` | float | The maximum speed of the ball in meters per second. Used for normalizing node features. | 28.0 |
| `boundary_correction` | float | A correction factor for boundary calculations, used to correct out-of-bounds positions as a percentage (used as 1+boundary_correction, e.g., 0.05). Not setting this might lead to players outside the pitch markings having values that fall slightly outside of our normalization range. When we set boundary_correction, any players outside the pitch will be moved to the closest line. | None |
| `self_loop_ball` | bool | Flag to indicate if the ball node should have a self-loop, aka be connected with itself and not only player(s) | True |
| `adjacency_matrix_connect_type` | str | The type of connection used in the adjacency matrix, typically related to the ball. Choose from 'ball', 'ball_carrier' or 'no_connection' | 'ball' |
| `adjacency_matrix_type` | str | The type of adjacency matrix, indicating how connections are structured, such as split by team. Choose from 'delaunay', 'split_by_team', 'dense', 'dense_ap' or 'dense_dp' | 'split_by_team' |
| `infer_ball_ownership` | bool | Infers 'attacking_team' if no 'ball_owning_team' exist (in Kloppy TrackingDataset) by finding the player closest to the ball using ball xyz, uses 'ball_carrier_threshold' as a cut-off. | True |
| `infer_goalkeepers` | bool | Set True if no GK label is provided, set False for incomplete (broadcast tracking) data that might not have a GK in every frame. | True |
| `defending_team_node_value` | float | Value for the node feature when player is on defending team. Should be between 0 and 1, inclusive. | 0.1 |
| `non_potential_receiver_node_value` | float | Value for the node feature when player is NOT a potential receiver of a pass (when on opposing team or in possession of the ball). Should be between 0 and 1, inclusive. | 0.1 |
| `label_type` | str | The type of prediction label used. Currently only supports 'binary' | 'binary' |
| `random_seed` | int, bool | When a random_seed is given, it will randomly shuffle an individual Graph without changing the underlying structure. When set to True, it will shuffle every frame differently; False won't shuffle. Advised to set True when creating an actual dataset to support Permutation Invariance. | False |
| `pad` | bool | True pads to a total amount of 22 players and ball (so 23x23 adjacency matrix). It dynamically changes the edge feature padding size based on the combination of AdjacencyMatrixConnectType and AdjacencyMatrixType, and self_loop_ball. No need to set padding because smaller and larger graphs can all be used in the same dataset. | False |
| `verbose` | bool | The converter logs warnings / error messages when specific frames have no coordinates, or other missing information. False mutes all of these warnings. | False |

</details>
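
The `random_seed` setting shuffles the node order of a graph without changing its structure. The sketch below shows the general idea on a toy graph in plain Python; it is an illustration under that assumption, not the converter's actual implementation, and the tiny matrices are made up.

```python
import random

# A tiny 3-node graph: nodes 0 and 1 are connected, node 2 only has a self-loop.
adjacency = [
    [1, 1, 0],
    [1, 1, 0],
    [0, 0, 1],
]
node_features = [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]

rng = random.Random(42)
order = list(range(len(adjacency)))
rng.shuffle(order)

# Apply the same permutation to node feature rows and to adjacency rows AND columns.
shuffled_features = [node_features[i] for i in order]
shuffled_adjacency = [[adjacency[i][j] for j in order] for i in order]

print(order)
print(shuffled_adjacency)
```

Because rows and columns are permuted together, every node keeps the same neighbours; only its position in the matrices changes.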

#### 4.1 What is a Graph?

<details>
    <summary> <b><i>🌀 Expand for a short explanation of Graphs</i></b> </summary>
<div style="display: flex; align-items: flex-start;">
<div style="flex: 1; padding-right: 20px;">

Before we continue it might be good to briefly explain what a Graph even is!

A Graph is a data structure consisting of:
- Nodes: Individual elements in the graph
- Edges: Connections between nodes

The graph is typically represented by:
- [Adjacency matrix](https://en.wikipedia.org/wiki/Adjacency_matrix): Shows connections between nodes
- Node features: Attributes or properties of each node
- Edge features: Attributes of the connections between nodes

The image on the right represents a stylized version of a frame of tracking data in soccer.

In section 6.1 we can see what this looks like in Python.

</div>
<div style="flex: 1;">

![Graph representation](https://github.com/UnravelSports/unravelsports.github.io/blob/main/imgs/what-is-a-graph-4.png?raw=true)

</div>
</div>
</details>
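
To make the three components concrete, here is a toy graph in plain Python lists. All values are made up, and the single edge feature (a Manhattan distance between node features) is chosen only for illustration; the real converter produces NumPy arrays with the features listed in section 5.

```python
# Three nodes (two players and the ball), each with two made-up features.
node_features = [
    [0.10, 0.20],  # player 1
    [0.40, 0.30],  # player 2
    [0.25, 0.25],  # ball
]

# Adjacency matrix: 1 where two nodes are connected, 0 otherwise.
adjacency = [
    [0, 1, 1],
    [1, 0, 1],
    [1, 1, 0],
]

# One (row, column) pair per non-zero entry, in row-major order.
edges = [
    (i, j)
    for i, row in enumerate(adjacency)
    for j, connected in enumerate(row)
    if connected
]

# Made-up single edge feature: Manhattan distance between the two nodes' features.
edge_features = [
    [sum(abs(a - b) for a, b in zip(node_features[i], node_features[j]))]
    for i, j in edges
]

print(len(node_features), len(edges), len(edge_features))
```

Note that there is exactly one row of edge features per non-zero entry in the adjacency matrix.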

-------
### 5. Load Kloppy Data, Convert and Store

Here we loop over 4 SkillCorner matches and convert the first 500 frames.

Important things to note:
- We import `dummy_labels` to randomly generate binary labels.
- We import `dummy_graph_ids` to generate fake graph ids.
- Our `GraphConverter` uses the Kloppy `DatasetTransformer` under the hood, which will take care of setting up playing orientation and coordinate system correctly. Technically setting the coordinate system does not matter, because the `DatasetTransformer` transforms everything to `coordinates="secondspectrum"`, but setting it already will speed up parsing a bit.
- In this example we don't have any _actual_ labels for our tracking data frames, you are going to have to create your own. In this example we use `dummy_labels(dataset)` to randomly generate binary labels for each frame. Training with these random labels will not create a good model.
- We will end up with fewer than 2,000 frames even though we set `limit=500` for each of the 4 matches, because we set `include_empty_frames=False` and all frames without ball coordinates are automatically omitted.
- When using different providers always set `include_empty_frames=False` or `only_alive=True`.
- We store the data as individual compressed pickle files, one file per match. The data that gets stored in the pickle is a list of dictionaries, one dictionary per frame. Each dictionary has keys for the adjacency matrix, node features, edge features, label and graph id.

<details>
    <summary> <b><i> 🌀 Expand for a full list of features </i></b></summary>
    <div style="border: 2px solid #ddd; border-radius: 5px; padding: 10px; background-color: ##282C34;">
    <ul>
  <li>'a' (adjacency matrix) [np.array of shape (players+ball, players+ball)]</li>
  <li>'x' (node features) [np.array of shape (n_nodes, n_node_features)]. The currently implemented node features (in order) are:
    <ul>
      <li>normalized x-coordinate</li>
      <li>normalized y-coordinate</li>
      <li>x component of the velocity unit vector</li>
      <li>y component of the velocity unit vector</li>
      <li>normalized speed</li>
      <li>normalized angle of velocity vector</li>
      <li>normalized distance to goal</li>
      <li>normalized angle to goal</li>
      <li>normalized distance to ball</li>
      <li>normalized angle to ball</li>
      <li>attacking (1) or defending team (`defending_team_node_value`)</li>
      <li>potential receiver (1) else `non_potential_receiver_node_value`</li>
    </ul>
  </li>
  <li>'e' (edge features) [np.array of shape (np.count_nonzero(a), n_edge_features)]. The currently implemented edge features (in order) are:
    <ul>
      <li>normalized inter-player distance</li>
      <li>normalized inter-player speed difference</li>
      <li>inter-player angle cosine</li>
      <li>inter-player angle sine</li>
      <li>inter-player velocity vector cosine</li>
      <li>inter-player velocity vector sine</li>
      <li>optional: 1 if two players are connected else 0 according to delaunay adjacency matrix. Only if adjacency_matrix_type is NOT 'delaunay'</li>
    </ul>
  </li>
  <li>'y' (label) [np.array]</li>
  <li>'id' (graph id) [int, str, None]</li>
</ul>
</div>

</details>
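
For orientation, the sketch below writes and reads a list of per-frame dictionaries as a gzip-compressed pickle using only the standard library. It assumes gzip compression (as the `.pickle.gz` extension suggests) and uses small placeholder lists where the real files hold NumPy arrays; the file name `example.pickle.gz` is hypothetical.

```python
import gzip
import os
import pickle
import tempfile

# Placeholder frames mimicking the stored structure: one dictionary per frame.
frames = [
    {"a": [[1, 1], [1, 1]], "x": [[0.1], [0.2]], "e": [[0.0]] * 4, "y": [1], "id": 4039},
    {"a": [[1, 0], [0, 1]], "x": [[0.3], [0.4]], "e": [[0.0]] * 2, "y": [0], "id": 4039},
]

with tempfile.TemporaryDirectory() as folder:
    file_path = os.path.join(folder, "example.pickle.gz")

    # Write the list of dictionaries as a gzip-compressed pickle.
    with gzip.open(file_path, "wb") as f:
        pickle.dump(frames, f)

    # Read it back again.
    with gzip.open(file_path, "rb") as f:
        loaded = pickle.load(f)

print(len(loaded), sorted(loaded[0].keys()))
```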

In [None]:
from os.path import exists

from unravel.utils import dummy_labels, dummy_graph_ids

match_ids = [4039, 3749, 3518, 3442]
pickle_folder = "pickles"
compressed_pickle_file_path = "{pickle_folder}/{match_id}.pickle.gz"

for match_id in match_ids:
    match_pickle_file_path = compressed_pickle_file_path.format(
        pickle_folder=pickle_folder, match_id=match_id
    )
    # if the output file already exists, skip this whole step
    if not exists(match_pickle_file_path):

        # Load Kloppy dataset
        dataset = skillcorner.load_open_data(
            match_id=match_id,
            coordinates="secondspectrum",
            include_empty_frames=False,
            limit=500,  # limit to 500 frames in this example
        )

        # Initialize the Graph Converter, with dataset, labels and settings
        converter = GraphConverter(
            dataset=dataset,
            # create fake labels
            labels=dummy_labels(dataset),
            graph_id=match_id,
            # graph_ids=dummy_graph_ids(dataset),
            # settings
            ball_carrier_treshold=25.0,
            max_player_speed=12.0,
            max_ball_speed=28.0,
            boundary_correction=None,
            self_loop_ball=True,
            adjacency_matrix_connect_type="ball",
            adjacency_matrix_type="split_by_team",
            label_type="binary",
            infer_ball_ownership=True,
            infer_goalkeepers=True,
            defending_team_node_value=0.1,
            non_potential_receiver_node_value=0.1,
            random_seed=False,
            pad=True,
            verbose=False,
        )
        # Compute the graphs and directly store them as a pickle file
        converter.to_pickle(file_path=match_pickle_file_path)

-------
### 6. Creating a Custom Graph Dataset

- `CustomSpektralDataset` is a [`spektral.data.Dataset`](https://graphneural.network/creating-dataset/). 
This type of dataset makes it very easy to properly load, train and predict with a Spektral GNN.
- The `CustomSpektralDataset` has an option to load from a folder of compressed pickle files, all we have to do is pass the pickle_folder location.

In [None]:
from unravel.utils import CustomSpektralDataset

dataset = CustomSpektralDataset(pickle_folder=pickle_folder)

#### 6.1 Graphs in the CustomSpektralDataset

<details>
    <summary><b><i> 🌀 Expand for a short explanation on CustomSpektralDataset</i></b></summary><br>


##### CustomSpektralDataset
Let's have a look at the internals of our `CustomSpektralDataset`. 

The first item in our dataset has 23 nodes, 12 features per node and 7 features per edge.

<div style="border: 2px solid #ddd; border-radius: 5px; padding: 10px; background-color: ##282C34;">

```python
dataset.graphs[0]

>>> Graph(n_nodes=23, n_node_features=12, n_edge_features=7, n_labels=1)
```
</div>
<br>
</details>
<br>
<details>
    <summary><b><i> 🌀 Expand for a short explanation on the representation of adjacency matrix </i></b></summary><br>

##### Adjacency Matrix
The **adjacency matrix** is represented as a [compressed sparse row matrix](https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.csr_matrix.html#scipy.sparse.csr_matrix), as required by Spektral. A 'normal' (dense) version of this same matrix would be of shape 23x23, filled with zeros except for ones in places where two players (or a player and the ball) are connected.

Because we set `adjacency_matrix_type='split_by_team'` and `adjacency_matrix_connect_type="ball"` this results in a total of 287 connections (ones), namely between:
- `adjacency_matrix_type='split_by_team'`:
    - All players on team A (11 * 11) 
    - All players on team B (11 * 11)
    - Ball connected to ball (1)
- `adjacency_matrix_connect_type="ball"`
    - All players and the ball (22) 
    - The ball and all players (22)
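
The count of 287 can be checked with a few lines of plain Python. This sketch rebuilds a dense 23x23 matrix from the rules listed above; it assumes 11 players per team and the ball in the last row, and is not the converter's own code.

```python
n_per_team = 11
n_nodes = 2 * n_per_team + 1  # 22 players plus the ball
ball = n_nodes - 1  # ball in the last row/column

# Team index per node: 0 for team A, 1 for team B, None for the ball.
team = [0] * n_per_team + [1] * n_per_team + [None]

adjacency = [[0] * n_nodes for _ in range(n_nodes)]
for i in range(n_nodes):
    for j in range(n_nodes):
        same_team = team[i] is not None and team[i] == team[j]
        ball_self_loop = i == ball and j == ball
        ball_to_player = (i == ball) != (j == ball)  # exactly one of the two is the ball
        if same_team or ball_self_loop or ball_to_player:
            adjacency[i][j] = 1

n_connections = sum(sum(row) for row in adjacency)
print(n_connections)  # 121 + 121 + 1 + 22 + 22 = 287
```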

<div style="border: 2px solid #ddd; border-radius: 5px; padding: 10px; background-color: ##282C34;">

```python
dataset.graphs[0].a
>>> <Compressed Sparse Row sparse matrix of dtype 'float64'
	    with 287 stored elements and shape (23, 23)>
```
</div>
<br>
</details>
<br>
<details>
    <summary><b><i> 🌀 Expand for a short explanation on the representation of node feature matrix </i></b></summary><br>

##### Node Features
The **node features** are described using a regular Numpy array. Each column represents one feature and every row represents one player. 

The ball is represented in the last row, unless we set `random_seed=True`, in which case every Graph gets randomly shuffled (while leaving connections intact).

See the bullet points in **5. Load Kloppy Data, Convert and Store** to learn which column represents which feature.

The rows filled with zeros are 'empty' players created because we set `pad=True`. Graph Neural Networks are flexible enough to deal with all sorts of different graph shapes in the same dataset, so normally it's not necessary to add these empty players, even for incomplete data with only a couple of players in frame.

<div style="border: 2px solid #ddd; border-radius: 5px; padding: 10px; background-color: ##282C34;">

```python
dataset.graphs[0].x
>>> [[-0.163 -0.135  0.245 -0.97   0.007  0.289  0.959  0.191  0.059  0.376  1.     1.   ]
 [-0.332  0.011 -0.061  0.998  0.02   0.76   1.015  0.177  0.029  0.009  1.     0.1  ]
 [ 0.021 -0.072  0.987 -0.162  0.017  0.474  0.88   0.203  0.121  0.468  1.     1.   ]
 [-0.144  0.232  0.343  0.939  0.024  0.694  0.924  0.186  0.077  0.638  1.     1.   ]
 [-0.252  0.302  0.99   0.141  0.032  0.523  0.964  0.176  0.078  0.741  1.     1.   ]
 [ 0.012  0.573  0.834 -0.551  0.035  0.407  0.842  0.191  0.19   0.646  1.     1.   ]
 [-0.293  0.686  0.999 -0.045  0.044  0.493  0.966  0.163  0.182  0.761  1.     1.   ]
 [ 0.     0.     0.     0.     0.     0.     0.     0.     0.     0.     0.     0.   ]
 [ 0.     0.     0.     0.     0.     0.     0.     0.     0.     0.     0.     0.   ]
 [ 0.     0.     0.     0.     0.     0.     0.     0.     0.     0.     0.     0.   ]
 ...
 [ 0.202  0.124 -0.874  0.486  0.024  0.919  0.791  0.214  0.197  0.524  0.1    0.1  ]
 [ 0.404  0.143 -0.997  0.08   0.029  0.987  0.709  0.23   0.281  0.519  0.1    0.1  ]
 [ 0.195 -0.391  0.48  -0.877  0.014  0.33   0.847  0.218  0.222  0.417  0.1    0.1  ]
 [ 0.212 -0.063  0.982 -0.187  0.009  0.47   0.804  0.217  0.2    0.483  0.1    0.1  ]
 [-0.03   0.248 -0.996  0.091  0.021  0.986  0.876  0.194  0.116  0.591  0.1    0.1  ]
 [ 0.     0.     0.     0.     0.     0.     0.     0.     0.     0.     0.     0.   ]
 [ 0.     0.     0.     0.     0.     0.     0.     0.     0.     0.     0.     0.   ]
 [ 0.     0.     0.     0.     0.     0.     0.     0.     0.     0.     0.     0.   ]
 [ 0.     0.     0.     0.     0.     0.     0.     0.     0.     0.     0.     0.   ]
 [-0.262  0.016  0.937 -0.35   0.036  0.443  0.986  0.044  0.     0.     0.     0.   ]]

 
dataset.graphs[0].x.shape
>>> (23, 12)
```
</div>
<br>
</details>
<br>
<details>
    <summary><b><i> 🌀 Expand for a short explanation on the representation of edge feature matrix </i></b></summary><br>

##### Edge Features
The **edge features** are also represented in a regular Numpy array. Again, each column represents one feature, and every row describes the connection between two players, or player and ball.

We saw before how the **adjacency matrix** was represented as a sparse row matrix with 287 stored elements. It is no coincidence that this lines up perfectly with the 287 rows of the **edge feature matrix**.

<div style="border: 2px solid #ddd; border-radius: 5px; padding: 10px; background-color: ##282C34;">

```python
dataset.graphs[0].e
>>> [[ 0.     0.     1.     0.5    0.5    1.     0.   ]
 [ 0.081  0.006  0.936  0.255  0.21   0.907  1.   ]
 [ 0.079  0.004  0.012  0.391  0.     0.515  1.   ]
 [ 0.1    0.007  0.46   0.002  0.005  0.571  1.   ]
 [ 0.125  0.011  0.65   0.023  0.474  0.999  0.   ]
 [ 0.206  0.012  0.322  0.033  0.535  0.999  0.   ]
 [ 0.23   0.016  0.619  0.014  0.567  0.996  0.   ]
 [ 0.     0.     0.     0.     0.     0.     0.   ]
 [ 0.     0.     0.     0.     0.     0.     0.   ]
 [ 0.     0.     0.     0.     0.     0.     0.   ]
 ...
 [ 0.197 -0.025  0.005  0.426  0.929  0.757  1.   ]
 [ 0.281 -0.023  0.004  0.439  0.959  0.699  1.   ]
 [ 0.222 -0.03   0.067  0.75   0.979  0.643  1.   ]
 [ 0.2   -0.032  0.003  0.554  0.982  0.633  1.   ]
 [ 0.116 -0.026  0.08   0.229  0.82   0.884  1.   ]
 [ 0.     0.     0.     0.     0.     0.     1.   ]
 [ 0.     0.     0.     0.     0.     0.     1.   ]
 [ 0.     0.     0.     0.     0.     0.     1.   ]
 [ 0.     0.     0.     0.     0.     0.     1.   ]
 [ 0.     0.     1.     0.5    0.5    1.     1.   ]]

dataset.graphs[0].e.shape
>>> (287, 7)
```
</div>
<br>
</details>



---------
### 7. Prepare for Training

Now that we have all the data converted as Graphs inside our `CustomSpektralDataset` object, we can prepare to train the GNN model.


#### 7.1 Split Dataset

Our `dataset` object has two custom methods to help split the data into train, test and validation sets.
Either use `dataset.split_test_train()` if we don't need a validation set, or `dataset.split_test_train_validation()` if we do also require a validation set.

We can split our data amongst these subsets 'by_graph_id' if we have provided Graph Ids in our `GraphConverter` using the 'graph_id' or 'graph_ids' parameter.
The 'split_train', 'split_test' and 'split_validation' parameters can be given as ratios, percentages or sizes relative to the total.

Note: Because we are splitting by only 4 different graph_ids here (the 4 match_ids), the ratios aren't perfectly 4 to 1 to 1. If you change the `graph_id=match_id` parameter in the `GraphConverter` to `graph_ids=dummy_graph_ids(dataset)` you'll see that it's easier to get close to the correct ratios, simply because we have a lot more graph_ids to split across.
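
To see why whole-group splitting distorts the ratios, consider the sketch below. `split_by_graph_id` is a hypothetical helper written for this illustration (it is not how `CustomSpektralDataset` splits internally), and the frame counts per match are made up.

```python
import random
from collections import Counter

# Hypothetical number of frames per match id; the counts are illustrative only.
graph_id_per_frame = [4039] * 480 + [3749] * 450 + [3518] * 470 + [3442] * 460


def split_by_graph_id(graph_ids, ratios, seed=42):
    """Assign whole groups to subsets, approximating the requested ratios."""
    counts = Counter(graph_ids)
    groups = sorted(counts)
    random.Random(seed).shuffle(groups)
    total = sum(counts.values())
    targets = [total * r / sum(ratios) for r in ratios]
    subsets = [[] for _ in ratios]
    sizes = [0] * len(ratios)
    for group in groups:
        # Put the group in the subset that is relatively furthest below its target.
        k = max(range(len(ratios)), key=lambda s: (targets[s] - sizes[s]) / targets[s])
        subsets[k].append(group)
        sizes[k] += counts[group]
    return subsets, sizes


(train_ids, test_ids, val_ids), sizes = split_by_graph_id(
    graph_id_per_frame, ratios=(4, 1, 1)
)
print(train_ids, test_ids, val_ids, sizes)
```

With only four groups, the closest this sketch can get to 4:1:1 is two matches for training and one each for test and validation; no match is ever shared between subsets.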

In [None]:
train, test, val = dataset.split_test_train_validation(
    split_train=4, split_test=1, split_validation=1, by_graph_id=True, random_seed=42
)
print("Train:", train)
print("Test:", test)
print("Validation:", val)

#### 7.2 Model Configurations

In [None]:
learning_rate = 1e-3
epochs = 5  # Increase for actual training
batch_size = 32
channels = 128
n_layers = 3  # Number of CrystalConv layers

#### 7.3 Build GNN Model

This GNN Model has the same architecture as described in [A Graph Neural Network Deep-dive into Successful Counterattacks {A. Sahasrabudhe & J. Bekkers}](https://github.com/USSoccerFederation/ussf_ssac_23_soccer_gnn/tree/main)

This exact model can also simply be loaded as:

`from unravel.classifiers import CrystalGraphClassifier`

Below we show the exact same code to make it easier to adjust.

In [None]:
from spektral.layers import GlobalAvgPool, CrystalConv
from tensorflow.keras.layers import Dense, Dropout
from tensorflow.keras.models import Model


class CrystalGraphClassifier(Model):
    def __init__(
        self,
        n_layers: int = 3,
        channels: int = 128,
        drop_out: float = 0.5,
        n_out: int = 1,
        **kwargs
    ):
        super().__init__(**kwargs)

        self.n_layers = n_layers
        self.channels = channels
        self.drop_out = drop_out
        self.n_out = n_out

        self.conv1 = CrystalConv()
        self.convs = [CrystalConv() for _ in range(1, self.n_layers)]
        self.pool = GlobalAvgPool()
        self.dense1 = Dense(self.channels, activation="relu")
        self.dropout = Dropout(self.drop_out)
        self.dense2 = Dense(self.channels, activation="relu")
        self.dense3 = Dense(self.n_out, activation="sigmoid")

    def call(self, inputs):
        x, a, e, i = inputs
        x = self.conv1([x, a, e])
        for conv in self.convs:
            x = conv([x, a, e])
        x = self.pool([x, i])
        x = self.dense1(x)
        x = self.dropout(x)
        x = self.dense2(x)
        x = self.dropout(x)
        return self.dense3(x)

#### 7.4 Create DataLoaders

Create a Spektral [`DisjointLoader`](https://graphneural.network/loaders/#disjointloader). This DisjointLoader will help us to load batches of Disjoint Graphs for training purposes.

Note that these Spektral `Loaders` return a generator, so if we want to retrain the model, we need to reload these loaders.
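
In disjoint mode, the graphs in a batch are stacked into one large graph and an index vector records which graph each node belongs to; this is the `i` that the model unpacks in `x, a, e, i = inputs` and passes to the pooling layer. The plain-Python sketch below mimics that idea with made-up node features; it is a conceptual illustration, not Spektral's implementation.

```python
# Two tiny graphs with made-up node features (one feature per node).
graph_1_x = [[0.1], [0.2], [0.3]]  # 3 nodes
graph_2_x = [[0.4], [0.5]]  # 2 nodes

batch_x = []  # stacked node features of all graphs in the batch
batch_i = []  # for every node, the index of the graph it belongs to
for graph_index, graph_x in enumerate([graph_1_x, graph_2_x]):
    batch_x.extend(graph_x)
    batch_i.extend([graph_index] * len(graph_x))

# Global average pooling per graph, using the index vector to group nodes.
n_graphs = max(batch_i) + 1
pooled = []
for g in range(n_graphs):
    values = [x[0] for x, i in zip(batch_x, batch_i) if i == g]
    pooled.append(sum(values) / len(values))

print(batch_i)
print(pooled)
```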

In [None]:
from spektral.data import DisjointLoader

loader_tr = DisjointLoader(train, batch_size=batch_size, epochs=epochs)
loader_va = DisjointLoader(val, epochs=1, shuffle=False, batch_size=batch_size)

--------
### 8. Training

Below we outline how to train the model, make predictions and add the predicted values back to the Kloppy dataframe.

#### 8.1 Compile Model

1. Initialize the `CrystalGraphClassifier` (or create your own Graph Classifier).
2. Compile the model with a loss function, optimizer and your preferred metrics.

In [None]:
from tensorflow.keras.metrics import AUC, BinaryAccuracy
from tensorflow.keras.losses import BinaryCrossentropy
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import EarlyStopping

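# Use the configuration from section 7.2 so that changing it actually affects the model
model = CrystalGraphClassifier(n_layers=n_layers, channels=channels)

model.compile(
    loss=BinaryCrossentropy(),
    optimizer=Adam(learning_rate=learning_rate),
    metrics=[AUC(), BinaryAccuracy()],
)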

#### 8.2 Fit Model

1. We have a [`DisjointLoader`](https://graphneural.network/loaders/#disjointloader) for training and validation sets.
2. Fit the model. 
3. We add `EarlyStopping` and a `validation_data` dataset to monitor performance, and set `use_multiprocessing=True` to improve training speed.

In [None]:
model.fit(
    loader_tr.load(),
    steps_per_epoch=loader_tr.steps_per_epoch,
    epochs=epochs,  # match the number of epochs the training loader generates data for
    use_multiprocessing=True,
    validation_data=loader_va.load(),
    callbacks=[EarlyStopping(monitor="loss", patience=5, restore_best_weights=True)],
)

#### 8.3 Save & Load Model

This step is solely included to show how to restore a model.

In [None]:
from tensorflow.keras.models import load_model

model_path = "models/my-first-graph-classifier"
model.save(model_path)
loaded_model = load_model(model_path)

#### 8.4 Evaluate Model

1. Create another `DisjointLoader`, this time for the test set.
2. Evaluate model performance on the test set. This evaluation function uses the `metrics` passed to `model.compile`.

Note: Our performance is really bad because we're using random labels, very few epochs and a small dataset.
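
As a reference point for what "really bad" means: a model that learns nothing and always predicts 0.5 has a binary cross-entropy of ln(2) ≈ 0.693, regardless of the labels. The sketch below computes this with the standard library; the labels are made up and the helper is written for this illustration, not taken from Keras.

```python
import math


def binary_crossentropy(y_true, y_pred):
    """Mean binary cross-entropy over labels and predicted probabilities."""
    losses = [
        -(y * math.log(p) + (1 - y) * math.log(1 - p))
        for y, p in zip(y_true, y_pred)
    ]
    return sum(losses) / len(losses)


# Made-up balanced labels and an uninformative model that always predicts 0.5.
y_true = [1, 0, 1, 0]
y_pred = [0.5, 0.5, 0.5, 0.5]

baseline_loss = binary_crossentropy(y_true, y_pred)
print(round(baseline_loss, 3))
```

A test loss close to this value, together with an AUC near 0.5, is what we should expect when training on random labels.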

In [None]:
loader_te = DisjointLoader(test, epochs=1, shuffle=False, batch_size=batch_size)
results = model.evaluate(loader_te.load())

#### 8.5 Predict on New Data

1. Load new, unseen data from the SkillCorner dataset.
2. Convert this data, making sure we use the exact same settings as in section 5.
3. If we set `prediction=True` we do not have to supply labels to the `GraphConverter`.

In [None]:
kloppy_dataset = skillcorner.load_open_data(
    match_id=2068,
    include_empty_frames=False,
    limit=500,
)

preds_converter = GraphConverter(
    dataset=kloppy_dataset,
    prediction=True,
    ball_carrier_treshold=25.0,
    max_player_speed=12.0,
    max_ball_speed=28.0,
    boundary_correction=None,
    self_loop_ball=True,
    adjacency_matrix_connect_type="ball",
    adjacency_matrix_type="split_by_team",
    label_type="binary",
    infer_ball_ownership=True,
    infer_goalkeepers=True,
    defending_team_node_value=0.1,
    non_potential_receiver_node_value=0.1,
    random_seed=False,
    pad=True,
    verbose=False,
)

4. Make a prediction on all the frames of this dataset using `model.predict`.

In [None]:
# Compute the graphs and add them to the CustomSpektralDataset
pred_dataset = CustomSpektralDataset(graphs=preds_converter.to_spektral_graphs())

loader_pred = DisjointLoader(
    pred_dataset, batch_size=batch_size, epochs=1, shuffle=False
)
preds = model.predict(loader_pred.load(), use_multiprocessing=True)

5. Convert the Kloppy dataset to a dataframe and merge back the predictions using the frame_ids.

Note: Not all frames have a prediction because of missing (ball) data.

In [None]:
import pandas as pd

kloppy_df = kloppy_dataset.to_df()

preds_df = pd.DataFrame(
    {"frame_id": [x.id for x in pred_dataset], "y": preds.flatten()}
)

kloppy_df = pd.merge(kloppy_df, preds_df, on="frame_id", how="left")