In [1]:
import os
import pickle
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv
import random



In [2]:
graphdir = "GraphData/"
graphdata = []

# Load all pickle files
for filename in os.listdir(graphdir):
    if filename.endswith(".pkl"):
        filepath = os.path.join(graphdir, filename)
        with open(filepath, 'rb') as f:
            data = pickle.load(f)
            graphdata.append(data)

In [3]:
graphdata[0][1]

{'graph': <networkx.classes.graph.Graph at 0x1846777ad30>,
 'node_features': tensor([[ 5.9391e-01,  3.7739e-01, -3.9466e-01,  9.9787e-01],
         [ 6.0078e-01,  3.3996e-01, -3.6062e-01,  9.9508e-01],
         [ 6.0762e-01,  3.4313e-01, -3.6099e-01,  9.9537e-01],
         [ 6.1390e-01,  3.4680e-01, -3.6104e-01,  9.9619e-01],
         [ 5.7383e-01,  3.3162e-01, -3.7300e-01,  9.9663e-01],
         [ 5.6228e-01,  3.3003e-01, -3.7370e-01,  9.9724e-01],
         [ 5.5041e-01,  3.2993e-01, -3.7395e-01,  9.9762e-01],
         [ 6.0957e-01,  3.6739e-01, -1.5495e-01,  9.9850e-01],
         [ 5.2296e-01,  3.5365e-01, -2.1433e-01,  9.9740e-01],
         [ 6.0434e-01,  4.2409e-01, -3.0932e-01,  9.9907e-01],
         [ 5.6956e-01,  4.1482e-01, -3.3171e-01,  9.9909e-01],
         [ 6.6580e-01,  6.4528e-01, -9.0966e-02,  9.9974e-01],
         [ 4.1218e-01,  6.1576e-01, -1.1522e-01,  9.9890e-01],
         [ 6.6696e-01,  8.9283e-01, -3.2242e-01,  5.2545e-01],
         [ 2.9072e-01,  9.5876e-01, -2.452

In [4]:
#Remaping apprach : Keeps graph structure and angle data
class FreeThrowDataset(Dataset):
    def __init__(self, matrix_data):
        self.data = matrix_data
        # Store the node indices that correspond to the angles
        # This can be inferred from your data sample
        self.angle_nodes = [11, 12, 13, 14, 15, 16, 23, 24, 25, 26, 27, 28] 

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sequence = self.data[idx]
        data_sequence = []
        
        for timestep in sequence:
            # Extract angles, edge info and label
            angles = timestep['angles']
            edge_index = timestep['edge_index']
            edge_attr = timestep['edge_attr']
            label = timestep['label']
            
            # Get all unique node indices from edge_index
            unique_nodes = torch.unique(edge_index)
            
            # Create mapping from original node indices to new consecutive indices (0, 1, 2...)
            node_mapping = {int(old_idx): new_idx for new_idx, old_idx in enumerate(unique_nodes)}
            
            # Remap edge indices using the mapping
            new_edge_index = torch.zeros_like(edge_index)
            for i in range(edge_index.size(1)):
                new_edge_index[0, i] = node_mapping[int(edge_index[0, i])]
                new_edge_index[1, i] = node_mapping[int(edge_index[1, i])]
            
            # Create features for all nodes in the new mapping
            # Initially set all to 0 (or some default value)
            num_graph_nodes = len(node_mapping)
            node_features = torch.zeros((num_graph_nodes, 1))
            
            # Map angles to the appropriate nodes in our new indexing
            for i, angle_val in enumerate(angles):
                if i < len(self.angle_nodes):  # Ensure we don't exceed available angles
                    orig_node_idx = self.angle_nodes[i]
                    if orig_node_idx in node_mapping:  # Check if this node exists in our graph
                        new_node_idx = node_mapping[orig_node_idx]
                        node_features[new_node_idx, 0] = angle_val
            
            # Create graph data object with remapped indices and features
            y = torch.tensor([label], dtype=torch.float)
            data = Data(
                x=node_features, 
                edge_index=new_edge_index, 
                edge_attr=edge_attr, 
                y=y, 
                num_nodes=num_graph_nodes
            )
            
            # Store original indices mapping for debugging or reference
            data.original_to_new_mapping = node_mapping
            
            data_sequence.append(data)
            
        return data_sequence

def collate_fn(batch):
    return batch

In [5]:
class GCN_LSTM(nn.Module):
    def __init__(self, in_channels, hidden_channels, lstm_hidden, num_classes):
        super().__init__()
        self.gcn = GCNConv(in_channels, hidden_channels)
        self.lstm = nn.LSTM(hidden_channels, lstm_hidden, batch_first=True)
        self.classifier = nn.Linear(lstm_hidden, num_classes)

    def forward(self, sequence):
        gcn_outputs = []
        for data in sequence:
            x = self.gcn(data.x, data.edge_index)
            x = torch.relu(x)
            pooled = x.mean(dim=0)  # Global mean pooling
            gcn_outputs.append(pooled)

        gcn_outputs = torch.stack(gcn_outputs).unsqueeze(0)  # [1, T, F]
        lstm_out, _ = self.lstm(gcn_outputs)
        out = self.classifier(lstm_out[:, -1, :])  # Use last time step
        return out


In [6]:
random.shuffle(graphdata)
split = int(0.7 * len(graphdata))
train_matrix = graphdata[:split]
test_matrix = graphdata[split:]

train_dataset = FreeThrowDataset(train_matrix)
test_dataset = FreeThrowDataset(test_matrix)

train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True, collate_fn=collate_fn)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn)


In [7]:
model = GCN_LSTM(in_channels=1, hidden_channels=32, lstm_hidden=16, num_classes=1)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
loss_fn = nn.BCEWithLogitsLoss()

In [8]:
for epoch in range(10):
    model.train()
    total_loss = 0
    for batch in train_loader:
        sequence = batch[0]  # batch size = 1
        target = sequence[0].y
        output = model(sequence)
        loss = loss_fn(output.view(-1), target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    print(f"Epoch {epoch+1} - Loss: {total_loss:.4f}")


Epoch 1 - Loss: 10.3410
Epoch 2 - Loss: 9.7905
Epoch 3 - Loss: 9.3556
Epoch 4 - Loss: 9.1096
Epoch 5 - Loss: 9.0799
Epoch 6 - Loss: 9.0732
Epoch 7 - Loss: 9.0513
Epoch 8 - Loss: 9.0335
Epoch 9 - Loss: 9.0357
Epoch 10 - Loss: 9.0378


In [10]:
model.eval()
correct = 0
total = 0

with torch.no_grad():
    for batch in test_loader:
        sequence = batch[0]
        target = int(sequence[0].y.item())
        output = model(sequence)
        prediction = (torch.sigmoid(output) > 0.5).int().item()
        correct += int(prediction == target)
        total += 1

print(f"Test Accuracy: {correct}/{total} = {correct / total:.2%}")


Test Accuracy: 6/7 = 85.71%
