In [1]:
import pandas as pd
import numpy as np 
import os

import torch
from torch.utils.data import Dataset
from torch_geometric.data import DataLoader

from preprocessing import drop_null_in_players
from graph import create_graph_from_dataset
from model_data import model_data, DiskDataset, check_data
from model import GNNRegression, train_nn


from torch.optim.lr_scheduler import ReduceLROnPlateau



In [2]:
# Source directory read by the preprocessing loop, and the directory where the
# serialized per-game datasets (*.pt) are written.
path_src_data = "../data/parquet/"
path_data_model = "../data/model_data/"

# Make sure the output directory exists; a no-op when it is already there.
os.makedirs(name=path_data_model, exist_ok=True)

In [7]:
# Names of the datasets already present in `path_data_model`, without their
# ".pt" extension. Only a trailing ".pt" is stripped: `str.replace` would also
# remove a ".pt" that appears in the middle of a file name.
files_output = [
    fname[:-len(".pt")] if fname.endswith(".pt") else fname
    for fname in os.listdir(path_data_model)
]

In [9]:
# Names of the datasets already present in `path_data_model`, without their
# ".pt" extension. Only a trailing ".pt" is stripped: `str.replace` would also
# remove a ".pt" that appears in the middle of a file name.
files_output = [
    fname[:-len(".pt")] if fname.endswith(".pt") else fname
    for fname in os.listdir(path_data_model)
]

for name_game in os.listdir(path_src_data):
    # Skip games whose dataset was already written by a previous run.
    if name_game in files_output:
        continue

    try:
        data = pd.read_parquet(f"{path_src_data}{name_game}")
        # Order rows by quarter (ascending) and game_clock (descending).
        data = data.sort_values(by=['quarter', 'game_clock'], ascending=[True, False])

        # Forward-fill the score columns and default a missing shot clock to 24.
        data[['awayscore', 'homescore']] = data[['awayscore', 'homescore']].ffill()
        data['shot_clock'] = data['shot_clock'].fillna(24)

        # NOTE(review): `demo_data` / `teams_list` are not used in this cell; kept
        # because other cells may read them. `iloc[0]` raises on an empty frame,
        # which makes the game get skipped below.
        demo_data = data.iloc[0]
        teams_list = [demo_data['player2_team'], demo_data['player9_team']]

        # Keep only rows with a known possession, then drop rows with missing
        # player data before building one graph per remaining row set.
        data_wo_empty_posessions = data[data['posession'].notna()]
        data_wo_null_values = drop_null_in_players(data=data_wo_empty_posessions)
        saved_graphs = create_graph_from_dataset(data=data_wo_null_values)

        # Target: y_home when the home team has possession, otherwise y_away.
        y = np.where(
            data_wo_null_values['posession'] == 'home',
            data_wo_null_values['y_home'],
            data_wo_null_values['y_away'],
        )

        data_list = model_data(saved_graphs=saved_graphs, y=y)
        check_data(data_list=data_list)

        save_dir = f'{path_data_model}{name_game}.pt'
        torch.save(data_list, save_dir)

        # Keep the in-memory cache in sync so later cells (e.g. `files_output[0]`)
        # also see datasets created during this run.
        files_output.append(name_game)

    except Exception as exc:
        # Best effort: one bad game must not stop the others, but report it
        # instead of silently swallowing the error. `Exception` (rather than a
        # bare `except`) lets KeyboardInterrupt stop the loop.
        print(f"Skipping {name_game}: {type(exc).__name__}: {exc}")


  self.add_edge(row[0], row[1], weight=row[2])


#### Create a graph between the players at a given moment

#### Parse it into a graph

#### Write to disk

In [23]:
batch_size = 32

# Load the first cached dataset listed in `files_output`.
# NOTE(review): torch.load unpickles the file, so only load datasets that were
# written by this notebook's own preprocessing step.
dataset = torch.load(f"{path_data_model}{files_output[0]}.pt")

# Use the `batch_size` variable so changing it above actually takes effect
# (the loader previously hard-coded 32).
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)




In [26]:
def train_nn(model, criterion, optimizer, loader, patience_early_stopping=7, patience_plateau=3, min_lr=1e-5, save_path='best_model.pth', max_epochs=100):
    """Train ``model`` on ``loader`` with plateau LR decay and early stopping.

    Note: this notebook-level definition shadows the ``train_nn`` imported from
    ``model`` at the top of the notebook.

    The monitored quantity is the mean *training* loss per epoch (no validation
    data is used). Whenever it improves, the model's ``state_dict`` is written to
    ``save_path``; the ``model`` object itself keeps the weights of the last
    epoch that ran.

    Args:
        model: ``torch.nn.Module`` called as ``model(batch)``.
        criterion: loss callable invoked as ``criterion(output, target)``.
        optimizer: optimizer over ``model``'s parameters.
        loader: sized iterable of batches; each batch must support
            ``.to(device)`` and expose the targets as ``.y``.
        patience_early_stopping: stop after this many consecutive epochs
            without improvement of the average loss.
        patience_plateau: ``patience`` passed to ``ReduceLROnPlateau``.
        min_lr: lower bound on the learning rate for the scheduler.
        save_path: file the best ``state_dict`` is saved to.
        max_epochs: maximum number of epochs to run (default 100, the value
            that used to be hard-coded).

    Returns:
        None.

    Raises:
        ValueError: if ``loader`` yields no batches (the average loss would
            otherwise be a division by zero).
    """
    num_batches = len(loader)
    if num_batches == 0:
        raise ValueError("loader yields no batches; cannot train on an empty dataset")

    # Use the GPU when one is available, otherwise fall back to the CPU.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print("Using :", device)

    model.to(device)

    # Halve the learning rate when the average loss stops decreasing.
    scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=patience_plateau, min_lr=min_lr)

    model.train()
    best_loss = float('inf')
    epochs_no_improve = 0

    for epoch in range(max_epochs):
        total_loss = 0.0
        for batch in loader:
            batch = batch.to(device)

            optimizer.zero_grad()
            output = model(batch)

            # Reshape the targets to a column so they line up with the
            # (N, 1)-shaped model output expected by the criterion.
            target = batch.y.view(-1, 1)

            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        avg_loss = total_loss / num_batches

        # The scheduler monitors the same average training loss.
        scheduler.step(avg_loss)

        if avg_loss < best_loss:
            best_loss = avg_loss
            epochs_no_improve = 0

            # Checkpoint the best weights seen so far.
            torch.save(model.state_dict(), save_path)
            print(f'Model saved at epoch {epoch} with loss {avg_loss}')
        else:
            epochs_no_improve += 1

        if epoch % 10 == 0:
            print(f'Epoch {epoch}, Loss: {avg_loss}')

        if epochs_no_improve >= patience_early_stopping:
            print(f'Early stopping at epoch {epoch}')
            break

In [28]:
# Hyper-parameters handed to the GNN regressor.
input_dim, hidden_dim, output_dim = (
    2,   ## number of node features
    16,
    1,
)

# Build the network, then the Adam optimizer over its parameters and the
# mean-squared-error loss.
model = GNNRegression(input_dim, hidden_dim, output_dim)
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)
criterion = torch.nn.MSELoss()

# Run training on the loader built earlier in the notebook.
train_nn(model=model, criterion=criterion, optimizer=optimizer, loader=loader)


Using : cpu
Model saved at epoch 0 with loss 30.39961739497173
Epoch 0, Loss: 30.39961739497173
Model saved at epoch 1 with loss 1.380727461219704
Model saved at epoch 2 with loss 1.362103943552124
Model saved at epoch 3 with loss 1.3494276236153577
Model saved at epoch 6 with loss 1.346230309955105
Model saved at epoch 7 with loss 1.3455237315519013
Model saved at epoch 8 with loss 1.3398704488202022
Model saved at epoch 9 with loss 1.3379573911936034
Epoch 10, Loss: 1.3405592686648495
Model saved at epoch 13 with loss 1.3326278252033132
Model saved at epoch 14 with loss 1.3313080877283194
Model saved at epoch 16 with loss 1.3227812517298398
Epoch 20, Loss: 1.3234331802150048
Model saved at epoch 21 with loss 1.3143099050452238
Model saved at epoch 22 with loss 1.3051259101452328
Model saved at epoch 26 with loss 1.3051236657620637
Model saved at epoch 27 with loss 1.2926261165426305
Model saved at epoch 29 with loss 1.292539354658475
Epoch 30, Loss: 1.2952757676442463
Model saved at 

In [15]:
# Print len() of every item the loader yields.
# NOTE(review): `batch` stays bound at notebook scope after this loop and is
# read by the following cell, so the loop variable must keep this name.
for batch in loader:
    print(len(batch))

12877


In [18]:
# Display whatever `batch` is currently bound to at notebook scope; presumably
# the last item left over from a `for batch in loader` loop. TODO confirm on a
# fresh kernel (Restart & Run All).
batch

[DataBatch(x=[26, 2], edge_index=[2, 20], edge_attr=[20], y=[2], batch=[26], ptr=[3]),
 DataBatch(x=[26, 2], edge_index=[2, 20], edge_attr=[20], y=[2], batch=[26], ptr=[3]),
 DataBatch(x=[26, 2], edge_index=[2, 20], edge_attr=[20], y=[2], batch=[26], ptr=[3]),
 DataBatch(x=[26, 2], edge_index=[2, 20], edge_attr=[20], y=[2], batch=[26], ptr=[3]),
 DataBatch(x=[26, 2], edge_index=[2, 20], edge_attr=[20], y=[2], batch=[26], ptr=[3]),
 DataBatch(x=[26, 2], edge_index=[2, 20], edge_attr=[20], y=[2], batch=[26], ptr=[3]),
 DataBatch(x=[26, 2], edge_index=[2, 20], edge_attr=[20], y=[2], batch=[26], ptr=[3]),
 DataBatch(x=[26, 2], edge_index=[2, 20], edge_attr=[20], y=[2], batch=[26], ptr=[3]),
 DataBatch(x=[26, 2], edge_index=[2, 20], edge_attr=[20], y=[2], batch=[26], ptr=[3]),
 DataBatch(x=[26, 2], edge_index=[2, 20], edge_attr=[20], y=[2], batch=[26], ptr=[3]),
 DataBatch(x=[26, 2], edge_index=[2, 20], edge_attr=[20], y=[2], batch=[26], ptr=[3]),
 DataBatch(x=[26, 2], edge_index=[2, 20], e