In [1]:
import os
import pickle
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv
import random



In [2]:
graphdir = "GraphData/"
graphdata = []

# Load all pickle files
for filename in os.listdir(graphdir):
    if filename.endswith(".pkl"):
        filepath = os.path.join(graphdir, filename)
        with open(filepath, 'rb') as f:
            data = pickle.load(f)
            graphdata.append(data)

In [3]:
graphdata[0][1]

{'graph': <networkx.classes.graph.Graph at 0x1b67de17280>,
 'node_features': tensor([[ 5.9674e-01,  3.8303e-01, -3.9097e-01,  9.9787e-01],
         [ 6.0173e-01,  3.4315e-01, -3.5382e-01,  9.9507e-01],
         [ 6.0871e-01,  3.4554e-01, -3.5418e-01,  9.9536e-01],
         [ 6.1513e-01,  3.4892e-01, -3.5422e-01,  9.9619e-01],
         [ 5.7466e-01,  3.3674e-01, -3.6979e-01,  9.9662e-01],
         [ 5.6337e-01,  3.3541e-01, -3.7053e-01,  9.9723e-01],
         [ 5.5158e-01,  3.3500e-01, -3.7083e-01,  9.9762e-01],
         [ 6.0972e-01,  3.6857e-01, -1.3627e-01,  9.9849e-01],
         [ 5.2440e-01,  3.5615e-01, -2.1193e-01,  9.9738e-01],
         [ 6.0704e-01,  4.2948e-01, -3.0080e-01,  9.9907e-01],
         [ 5.7367e-01,  4.2146e-01, -3.2778e-01,  9.9909e-01],
         [ 6.6430e-01,  6.4296e-01, -4.1268e-02,  9.9974e-01],
         [ 4.1306e-01,  6.1594e-01, -1.0622e-01,  9.9892e-01],
         [ 6.6434e-01,  8.9774e-01, -1.5040e-01,  5.0578e-01],
         [ 2.9108e-01,  9.5926e-01, -2.134

In [4]:
#Remaping apprach : Keeps graph structure and angle data
class FreeThrowDataset(Dataset):
    def __init__(self, matrix_data):
        self.data = matrix_data
        # Store the node indices that correspond to the angles
        # This can be inferred from your data sample
        self.angle_nodes = [11, 12, 13, 14, 15, 16, 23, 24, 25, 26, 27, 28] 

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sequence = self.data[idx]
        data_sequence = []
        
        for timestep in sequence:
            # Extract angles, edge info and label
            angles = timestep['angles']
            edge_index = timestep['edge_index']
            edge_attr = timestep['edge_attr']
            label = timestep['label']
            
            # Get all unique node indices from edge_index
            unique_nodes = torch.unique(edge_index)
            
            # Create mapping from original node indices to new consecutive indices (0, 1, 2...)
            node_mapping = {int(old_idx): new_idx for new_idx, old_idx in enumerate(unique_nodes)}
            
            # Remap edge indices using the mapping
            new_edge_index = torch.zeros_like(edge_index)
            for i in range(edge_index.size(1)):
                new_edge_index[0, i] = node_mapping[int(edge_index[0, i])]
                new_edge_index[1, i] = node_mapping[int(edge_index[1, i])]
            
            # Create features for all nodes in the new mapping
            # Initially set all to 0 (or some default value)
            num_graph_nodes = len(node_mapping)
            node_features = torch.zeros((num_graph_nodes, 1))
            
            # Map angles to the appropriate nodes in our new indexing
            for i, angle_val in enumerate(angles):
                if i < len(self.angle_nodes):  # Ensure we don't exceed available angles
                    orig_node_idx = self.angle_nodes[i]
                    if orig_node_idx in node_mapping:  # Check if this node exists in our graph
                        new_node_idx = node_mapping[orig_node_idx]
                        node_features[new_node_idx, 0] = angle_val
            
            # Create graph data object with remapped indices and features
            y = torch.tensor([label], dtype=torch.float)
            data = Data(
                x=node_features, 
                edge_index=new_edge_index, 
                edge_attr=edge_attr, 
                y=y, 
                num_nodes=num_graph_nodes
            )
            
            # Store original indices mapping for debugging or reference
            data.original_to_new_mapping = node_mapping
            
            data_sequence.append(data)
            
        return data_sequence

def collate_fn(batch):
    return batch

In [5]:
class GCN_LSTM(nn.Module):
    def __init__(self, in_channels, hidden_channels, lstm_hidden, num_classes):
        super().__init__()
        self.gcn = GCNConv(in_channels, hidden_channels)
        self.lstm = nn.LSTM(hidden_channels, lstm_hidden, batch_first=True)
        self.classifier = nn.Linear(lstm_hidden, num_classes)

    def forward(self, sequence):
        gcn_outputs = []
        for data in sequence:
            x = self.gcn(data.x, data.edge_index)
            x = torch.relu(x)
            pooled = x.mean(dim=0)  # Global mean pooling
            gcn_outputs.append(pooled)

        gcn_outputs = torch.stack(gcn_outputs).unsqueeze(0)  # [1, T, F]
        lstm_out, _ = self.lstm(gcn_outputs)
        out = self.classifier(lstm_out[:, -1, :])  # Use last time step
        return out


In [10]:
random.shuffle(graphdata)
split = int(0.66 * len(graphdata))
train_matrix = graphdata[:split]
test_matrix = graphdata[split:]

train_dataset = FreeThrowDataset(train_matrix)
test_dataset = FreeThrowDataset(test_matrix)

train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True, collate_fn=collate_fn)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn)


In [11]:
model = GCN_LSTM(in_channels=1, hidden_channels=32, lstm_hidden=16, num_classes=1)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
loss_fn = nn.BCEWithLogitsLoss()

In [12]:
for epoch in range(10):
    model.train()
    total_loss = 0
    for batch in train_loader:
        sequence = batch[0]  # batch size = 1
        target = sequence[0].y
        output = model(sequence)
        loss = loss_fn(output.view(-1), target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    print(f"Epoch {epoch+1} - Loss: {total_loss:.4f}")


Epoch 1 - Loss: 22.2404
Epoch 2 - Loss: 16.4594
Epoch 3 - Loss: 15.1236
Epoch 4 - Loss: 14.7029
Epoch 5 - Loss: 13.6588
Epoch 6 - Loss: 13.2527
Epoch 7 - Loss: 13.2183
Epoch 8 - Loss: 13.1706
Epoch 9 - Loss: 13.1930
Epoch 10 - Loss: 13.0082


In [13]:
model.eval()
correct = 0
total = 0

with torch.no_grad():
    for batch in test_loader:
        sequence = batch[0]
        target = int(sequence[0].y.item())
        output = model(sequence)
        prediction = (torch.sigmoid(output) > 0.5).int().item()
        correct += int(prediction == target)
        total += 1

print(f"Test Accuracy: {correct}/{total} = {correct / total:.2%}")


Test Accuracy: 10/12 = 83.33%
