In [33]:
import os
import pickle
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv
import random
import mediapipe as mp
from pymongo import MongoClient
from neo4j import GraphDatabase
from collections import defaultdict

In [34]:
# MediaPipe Pose setup: module alias, drawing helpers, and a configured estimator.
mp_pose = mp.solutions.pose
mp_drawing = mp.solutions.drawing_utils
pose = mp_pose.Pose(
    static_image_mode=False,        # treat input as a video stream (track across frames)
    smooth_landmarks=True,
    model_complexity=2,             # heaviest / most accurate of MediaPipe's pose models
    min_detection_confidence=0.5,
    min_tracking_confidence=0.5,
)

In [35]:
 #MongoDB connection
mongo_client = MongoClient("mongodb://admin:password@localhost:27017/")
mongo_db = mongo_client["SportsAnalysis"]
labels_collection = mongo_db["metadata"]

# Neo4j connection
neo4j_driver = GraphDatabase.driver("bolt://localhost:7687", auth=("neo4j", "password"))

In [36]:
def fetch_all_video_ids():
    """Return the ``video_id`` of every document in the metadata collection.

    Only the ``video_id`` field is projected. A document without it raises
    ``KeyError``.
    """
    ids = []
    for record in labels_collection.find({}, {'video_id': 1}):
        ids.append(record['video_id'])
    return ids

def fetch_label(video_id):
    """Return the stored ``label`` for ``video_id``, or -1 if no document matches.

    A matching document that lacks a ``label`` field raises ``KeyError``.
    NOTE(review): the -1 sentinel flows unchanged into the dataset's ``y``, and
    ``BCEWithLogitsLoss`` expects targets in [0, 1]. Confirm that every video has a label.
    """
    doc = labels_collection.find_one({'video_id': video_id})
    if not doc:
        return -1
    return doc['label']

In [37]:

def fetch_graphs_from_neo4j(video_id):
    """Load one video's pose graphs from Neo4j, one graph dict per time step.

    Args:
        video_id: value matched against the ``video_id`` property of ``PoseNode``s.

    Returns:
        A list ordered by ascending ``time_index``. Each item is a dict with:
        ``edge_index`` (LongTensor ``[2, E]``, ``[2, 0]`` for a frame with no
        edges), ``edge_attr`` (FloatTensor ``[E, 1]``), ``angle_features`` and
        ``node_features`` (FloatTensor ``[N, 1]``, the latter a clone),
        ``time``, ``source_video``, ``label`` (from ``fetch_label``) and empty
        ``node_mapping`` / ``reverse_mapping`` dicts. Returns ``[]`` when the
        video has no ``PoseNode``s.
    """
    with neo4j_driver.session() as session:
        # All distinct frame indices for this video, in order.
        result = session.run("""
            MATCH (n:PoseNode {video_id: $video_id})
            RETURN DISTINCT n.time_index AS timestep
            ORDER BY timestep ASC
        """, video_id=video_id)
        time_steps = [record["timestep"] for record in result]
        if not time_steps:
            return []

        # The label is per video, so fetch it once rather than once per frame.
        label = fetch_label(video_id)

        graphs = []
        for t in time_steps:
            # Node features: one angle per node, ordered by node_index.
            # NOTE(review): edges below use node_index directly as a row into `x`,
            # which assumes node_index values are 0..N-1 within a frame. Confirm.
            node_query = session.run("""
                MATCH (n:PoseNode {video_id: $video_id, time_index: $t})
                RETURN n.node_index AS idx, n.angle AS angle, n.time AS time
                ORDER BY idx
            """, video_id=video_id, t=t)

            node_data = []
            time_value = 0
            for record in node_query:
                node_data.append(float(record["angle"]))
                # Presumably identical across a frame's nodes; the last record's value is kept.
                time_value = float(record["time"])

            x = torch.tensor(node_data, dtype=torch.float).view(-1, 1)

            # NOTE(review): `b` is not constrained to the same video/time_index, so
            # any CONNECTED_TO edge leaving this frame is included. Confirm that the
            # graph schema only links nodes within a frame.
            edge_query = session.run("""
                MATCH (a:PoseNode {video_id: $video_id, time_index: $t})-[r:CONNECTED_TO]->(b:PoseNode)
                RETURN a.node_index AS src, b.node_index AS dst, r.weight AS weight
            """, video_id=video_id, t=t)

            edge_pairs = []
            edge_weights = []
            for record in edge_query:
                edge_pairs.append([int(record["src"]), int(record["dst"])])
                edge_weights.append([float(record["weight"])])

            # view(-1, 2) / view(-1, 1) keep the documented shapes when a frame
            # has no edges. torch.tensor([]) is 1-D (shape [0]), and .t() on it
            # would not give the [2, 0] edge_index that GCNConv requires.
            edge_index = torch.tensor(edge_pairs, dtype=torch.long).view(-1, 2).t().contiguous()
            edge_attr = torch.tensor(edge_weights, dtype=torch.float).view(-1, 1)

            graphs.append({
                "edge_index": edge_index,
                "edge_attr": edge_attr,
                "angle_features": x,
                "time": time_value,
                "source_video": video_id,
                "label": label,
                "node_mapping": {},  # Optional: mapping if you have remapped indices
                "reverse_mapping": {},
                "node_features": x.clone()
            })
        return graphs

In [38]:
def load_graph_sequences_from_db():
    """Load the graph sequence of every video listed in MongoDB.

    Returns a list with one entry per video: the list of per-frame graph dicts
    returned by ``fetch_graphs_from_neo4j``. A video whose sequence comes back
    empty is skipped. A video whose load raises is reported with ``print`` and
    skipped (best effort), so one bad video does not abort the whole load.
    """
    sequences = []
    for video_id in fetch_all_video_ids():
        try:
            frames = fetch_graphs_from_neo4j(video_id)
            if not frames:
                continue
            sequences.append(frames)
        except Exception as exc:
            print(f"Error loading video {video_id}: {exc}")

    print(f"✅ Loaded {len(sequences)} videos from DB")
    return sequences


In [39]:
# Same alias as in the MediaPipe setup cell; repeated so this cell (which reads
# mp_pose.PoseLandmark in FreeThrowDataset) only depends on the `mp` import.
mp_pose = mp.solutions.pose

class FreeThrowDataset(Dataset):
    """Map-style dataset returning one video's sequence of per-frame graphs.

    ``matrix_data[i]`` is a list of per-timestep dicts with the keys
    ``edge_index``, ``edge_attr``, ``label``, ``angle_features``,
    ``node_mapping``, ``reverse_mapping``, ``node_features``, ``time`` and
    ``source_video``. Indexing converts each dict to a ``Data`` object.
    """

    def __init__(self, matrix_data):
        self.data = matrix_data
        # Landmark indices of the joints whose angles are the graph's nodes.
        # Kept for reference only; __getitem__ does not read it.
        landmark = mp_pose.PoseLandmark
        self.angle_node_landmarks = [
            landmark.LEFT_ELBOW.value,
            landmark.RIGHT_ELBOW.value,
            landmark.LEFT_SHOULDER.value,
            landmark.RIGHT_SHOULDER.value,
            landmark.LEFT_KNEE.value,
            landmark.RIGHT_KNEE.value,
            landmark.LEFT_HIP.value,
            landmark.RIGHT_HIP.value,
        ]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return the list of ``Data`` graphs for sequence ``idx``, in frame order."""
        return [self._timestep_to_graph(step) for step in self.data[idx]]

    @staticmethod
    def _timestep_to_graph(step):
        """Convert one timestep dict into a ``Data`` object.

        Raises TypeError if ``step`` is not a dict. A missing key (KeyError) or any
        other failure is printed and then re-raised.
        """
        if not isinstance(step, dict):
            raise TypeError(f"Expected dictionary, got {type(step)}. Value: {step}")

        try:
            edge_index = step['edge_index']
            edge_attr = step['edge_attr']
            label = step['label']
            features = step['angle_features']

            # Labels may arrive as plain numbers or as tensors already.
            y = label if isinstance(label, torch.Tensor) else torch.tensor([label], dtype=torch.float)

            graph = Data(
                x=features,  # angle values are the node features
                edge_index=edge_index,
                edge_attr=edge_attr,
                y=y,
                num_nodes=features.size(0),
            )

            # Extra per-frame metadata carried along for reference.
            graph.original_to_new_mapping = step['node_mapping']
            graph.new_to_original_mapping = step['reverse_mapping']
            graph.positional_features = step['node_features']
            graph.time = step['time']
            graph.source_video = step['source_video']
            return graph
        except KeyError as err:
            print(f"KeyError: {err} not found in timestep. Available keys: {list(step.keys())}")
            raise
        except Exception as err:
            print(f"Error processing timestep: {err}")
            raise

def collate_fn(batch):
    """Identity collate: hand the DataLoader batch back unchanged.

    Each item is already a list of per-frame graphs of varying length, so no
    stacking is attempted. The batch stays a plain list of sequences.
    """
    sequences = batch
    return sequences

In [40]:
class GCN_LSTM(nn.Module):
    """GCN frame encoder, an LSTM over time, and a linear classification head.

    Each frame graph goes through one ``GCNConv`` + ReLU and is mean-pooled over
    its nodes into a single vector. The per-frame vectors run through a
    single-layer LSTM, and the last time step's output is classified.

    Args:
        in_channels: node feature size (``data.x.size(1)``).
        hidden_channels: GCN output size, also the LSTM input size.
        lstm_hidden: LSTM hidden size.
        num_classes: number of output logits.
        use_edge_weight: if True, pass ``data.edge_attr`` flattened to ``[E]`` to
            the GCN as per-edge weights (assumes a single weight column). The
            default False keeps the original behaviour of ignoring ``edge_attr``.
    """

    def __init__(self, in_channels, hidden_channels, lstm_hidden, num_classes,
                 use_edge_weight=False):
        super().__init__()
        self.use_edge_weight = use_edge_weight
        self.gcn = GCNConv(in_channels, hidden_channels)
        self.lstm = nn.LSTM(hidden_channels, lstm_hidden, batch_first=True)
        self.classifier = nn.Linear(lstm_hidden, num_classes)

    def forward(self, sequence):
        """Classify one sequence of frame graphs.

        Args:
            sequence: non-empty iterable of graph objects exposing ``x`` and
                ``edge_index`` (and ``edge_attr`` when ``use_edge_weight``).

        Returns:
            Logits of shape ``[1, num_classes]``. Only one sequence is processed
            per call (the batch dimension is fixed to 1).
        """
        pooled_frames = []
        for data in sequence:
            if self.use_edge_weight and getattr(data, "edge_attr", None) is not None:
                x = self.gcn(data.x, data.edge_index, data.edge_attr.view(-1))
            else:
                x = self.gcn(data.x, data.edge_index)
            x = torch.relu(x)
            pooled_frames.append(x.mean(dim=0))  # global mean pooling over nodes

        frames = torch.stack(pooled_frames).unsqueeze(0)  # [1, T, hidden_channels]
        lstm_out, _ = self.lstm(frames)
        return self.classifier(lstm_out[:, -1, :])  # use the last time step


In [41]:
# One entry per video: that video's list of per-frame graph dicts.
graphdata = load_graph_sequences_from_db()

✅ Loaded 60 videos from DB


In [46]:
# Sanity check: number of videos loaded, then number of frames in the first video.
print(len(graphdata))
print(len(graphdata[0]))

60
68


In [42]:
# Split at the video level (each item of graphdata is one video's full frame
# sequence), so no video contributes frames to both sets.
# Shuffle a seeded copy so the split is reproducible and re-running this cell
# does not reshuffle `graphdata` in place.
SPLIT_SEED = 42
TRAIN_FRACTION = 0.7

shuffled_videos = list(graphdata)
random.Random(SPLIT_SEED).shuffle(shuffled_videos)
split = int(TRAIN_FRACTION * len(shuffled_videos))
train_matrix = shuffled_videos[:split]
test_matrix = shuffled_videos[split:]
assert len(train_matrix) + len(test_matrix) == len(graphdata)

train_dataset = FreeThrowDataset(train_matrix)
test_dataset = FreeThrowDataset(test_matrix)

# batch_size=1: GCN_LSTM.forward consumes a single sequence, and the identity
# collate_fn keeps each batch as a plain list of sequences.
train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True, collate_fn=collate_fn)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn)


In [50]:
# Seed torch's global RNG so weight initialisation, and the shuffle order of
# train_loader (drawn from the global RNG while iterating), are reproducible.
MODEL_SEED = 42
torch.manual_seed(MODEL_SEED)

# in_channels=1: each node carries a single angle feature.
# num_classes=1: a single logit for binary classification with BCEWithLogitsLoss.
model = GCN_LSTM(in_channels=1, hidden_channels=32, lstm_hidden=16, num_classes=1)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
loss_fn = nn.BCEWithLogitsLoss()

In [51]:
# Train for 50 epochs, one video sequence per optimizer step (batch_size=1).
# The printed value is the loss summed over every training sequence in the epoch.
for epoch in range(50):
    model.train()
    total_loss = 0
    for batch in train_loader:
        sequence = batch[0]  # collate_fn keeps a list; batch_size=1 -> one sequence
        output = model(sequence)
        # The video label is stored on every frame; use the first frame's copy.
        target = sequence[0].y
        loss = loss_fn(output.view(-1), target)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
    print(f"Epoch {epoch+1} - Loss: {total_loss:.4f}")


Epoch 1 - Loss: 28.9152
Epoch 2 - Loss: 26.4335
Epoch 3 - Loss: 26.1221
Epoch 4 - Loss: 26.4773
Epoch 5 - Loss: 26.0238
Epoch 6 - Loss: 25.9259
Epoch 7 - Loss: 25.9511
Epoch 8 - Loss: 25.8912
Epoch 9 - Loss: 25.6281
Epoch 10 - Loss: 25.7117
Epoch 11 - Loss: 25.6445
Epoch 12 - Loss: 25.5260
Epoch 13 - Loss: 25.6860
Epoch 14 - Loss: 25.4818
Epoch 15 - Loss: 25.5574
Epoch 16 - Loss: 25.3836
Epoch 17 - Loss: 26.0285
Epoch 18 - Loss: 25.3294
Epoch 19 - Loss: 25.3650
Epoch 20 - Loss: 25.2722
Epoch 21 - Loss: 25.3602
Epoch 22 - Loss: 25.5061
Epoch 23 - Loss: 25.4994
Epoch 24 - Loss: 25.2227
Epoch 25 - Loss: 25.2028
Epoch 26 - Loss: 25.4472
Epoch 27 - Loss: 25.1104
Epoch 28 - Loss: 24.9300
Epoch 29 - Loss: 25.1114
Epoch 30 - Loss: 25.2619
Epoch 31 - Loss: 24.9557
Epoch 32 - Loss: 25.5159
Epoch 33 - Loss: 25.6152
Epoch 34 - Loss: 24.9046
Epoch 35 - Loss: 25.0050
Epoch 36 - Loss: 24.7620
Epoch 37 - Loss: 25.3235
Epoch 38 - Loss: 25.0255
Epoch 39 - Loss: 24.6949
Epoch 40 - Loss: 24.9693
Epoch 41 

In [52]:
# Evaluate on the held-out videos: predict positive when sigmoid(logit) > 0.5.
model.eval()
correct = 0
total = 0

with torch.no_grad():
    for batch in test_loader:
        sequence = batch[0]  # batch_size=1 -> one sequence
        target = int(sequence[0].y.item())
        output = model(sequence)
        prediction = (torch.sigmoid(output) > 0.5).int().item()
        correct += int(prediction == target)
        total += 1

# Guard against an empty test split, which would otherwise raise ZeroDivisionError.
if total:
    print(f"Test Accuracy: {correct}/{total} = {correct / total:.2%}")
else:
    print("Test Accuracy: n/a (test set is empty)")


Test Accuracy: 16/18 = 88.89%
