In [4]:
import pandas as pd
df = pd.read_csv('data.csv')
df = df.drop(['jerseyNumber', 'playId'], axis=1)
df.to_csv('gat_data.csv', index=False)

In [9]:
import torch
from torch_geometric.data import Data, DataLoader
from torch_geometric.nn import GATv2Conv
import torch_geometric
import pandas as pd
import numpy as np
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

# Function to calculate Euclidean distance
def dist_between_points(p1: tuple, p2: tuple):
    return np.sqrt((p1[0] - p2[0])**2 + (p1[1] - p2[1])**2)

# GAT Model Definition
class GATModel(torch.nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super(GATModel, self).__init__()
        self.gat1 = GATv2Conv(input_dim, hidden_dim, heads=4, concat=True)
        self.gat2 = GATv2Conv(hidden_dim * 4, output_dim, heads=1, concat=False)

    def forward(self, data):
        x, edge_index, edge_weight, batch = (
            data.x,
            data.edge_index,
            data.edge_attr,
            data.batch,
        )
        x = self.gat1(x, edge_index)  # Node-level embedding
        x = torch.relu(x)
        x = self.gat2(x, edge_index)  # Node-level output
        
        # Aggregate node outputs into graph-level output
        x = torch_geometric.nn.global_mean_pool(x, batch)  # Global mean pooling
        return x

# Prepare data from the dataframe
df = pd.read_csv('gat_data.csv')

graphs = []
for i in range(0, len(df), 22):  # Split every 22 rows
    graph_df = df.iloc[i:i + 22]
    features = graph_df[['x', 'y', 'speed', 'distance', 'direction', 'role']].values
    x = torch.tensor(features, dtype=torch.float)  # Node features
    y = torch.tensor(graph_df.iloc[0, -1], dtype=torch.long)  # Multi-class label (as integers)
    
    # Create fully connected graph
    num_nodes = x.shape[0]
    edge_index = torch.combinations(torch.arange(num_nodes), r=2).t()
    edge_index = torch.cat([edge_index, edge_index.flip(0)], dim=1)  # Bidirectional edges
    
    # Compute edge weights
    edge_weights = []
    for (src, dest) in edge_index.t().tolist():
        x1, y1 = graph_df.iloc[src][['x', 'y']].values
        x2, y2 = graph_df.iloc[dest][['x', 'y']].values
        edge_weights.append(dist_between_points((x1, y1), (x2, y2)))
    edge_weights = torch.tensor(edge_weights, dtype=torch.float)
    
    # Create PyG Data object
    graph = Data(x=x, edge_index=edge_index, edge_attr=edge_weights, y=y)
    graphs.append(graph)

# Dataloader
loader = DataLoader(graphs, batch_size=8, shuffle=True)

# Model, optimizer, and loss function
model = GATModel(input_dim=6, hidden_dim=64, output_dim=6)  # 6 output classes
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
criterion = torch.nn.CrossEntropyLoss()  # Use CrossEntropyLoss for multi-class classification

# Training loop
for epoch in range(50):
    model.train()
    for batch in loader:
        optimizer.zero_grad()
        out = model(batch)
        loss = criterion(out, batch.y)
        loss.backward()
        optimizer.step()
    print(f"Epoch {epoch + 1}, Loss: {loss.item():.4f}")

# Evaluation
model.eval()
true_labels = []
predictions = []

with torch.no_grad():
    for batch in loader:
        out = model(batch)
        preds = torch.argmax(out, dim=1)  # Class predictions
        true_labels.extend(batch.y.cpu().numpy())  # True labels
        predictions.extend(preds.cpu().numpy())

# Compute metrics
accuracy = accuracy_score(true_labels, predictions)
precision = precision_score(true_labels, predictions, average='weighted')
recall = recall_score(true_labels, predictions, average='weighted')
f1 = f1_score(true_labels, predictions, average='weighted')

print("\nEvaluation Metrics:")
print(f"Accuracy: {accuracy:.4f}")
print(f"Precision: {precision:.4f}")
print(f"Recall: {recall:.4f}")
print(f"F1 Score: {f1:.4f}")

# GAT results
# Accuracy: 0.7937
# Precision: 0.8284
# Recall: 0.7937
# F1 Score: 0.7924

# GATv2 results with the same parameters
# Accuracy: 0.6984
# Precision: 0.7219
# Recall: 0.6984
# F1 Score: 0.6953



Epoch 1, Loss: 2.7707
Epoch 2, Loss: 3.3548
Epoch 3, Loss: 8.4832
Epoch 4, Loss: 0.8851
Epoch 5, Loss: 0.8955
Epoch 6, Loss: 1.3867
Epoch 7, Loss: 2.8460
Epoch 8, Loss: 1.8325
Epoch 9, Loss: 2.7069
Epoch 10, Loss: 1.6247
Epoch 11, Loss: 2.0362
Epoch 12, Loss: 1.4994
Epoch 13, Loss: 2.5588
Epoch 14, Loss: 1.4185
Epoch 15, Loss: 0.4424
Epoch 16, Loss: 0.2670
Epoch 17, Loss: 1.5346
Epoch 18, Loss: 0.4612
Epoch 19, Loss: 0.4481
Epoch 20, Loss: 1.3081
Epoch 21, Loss: 1.7244
Epoch 22, Loss: 0.6866
Epoch 23, Loss: 2.2121
Epoch 24, Loss: 0.5077
Epoch 25, Loss: 1.1045
Epoch 26, Loss: 1.1209
Epoch 27, Loss: 0.1961
Epoch 28, Loss: 0.5629
Epoch 29, Loss: 0.4505
Epoch 30, Loss: 0.3988
Epoch 31, Loss: 0.7464
Epoch 32, Loss: 1.7851
Epoch 33, Loss: 1.1116
Epoch 34, Loss: 0.3104
Epoch 35, Loss: 2.9175
Epoch 36, Loss: 1.6415
Epoch 37, Loss: 1.1070
Epoch 38, Loss: 1.3140
Epoch 39, Loss: 0.9100
Epoch 40, Loss: 0.6189
Epoch 41, Loss: 1.0309
Epoch 42, Loss: 0.8824
Epoch 43, Loss: 0.8235
Epoch 44, Loss: 0.28

  _warn_prf(average, modifier, msg_start, len(result))
