Jupyter Notebook for Training of Broke Jumpshot Detector ML Model

Steps: 
1. Imports
2. Add model for transfer learning
3. Load dataset
4. Define model architecture
5. Split into train, validation, and test
6. Train
7. Test Performance
8. Save

Step 1: Imports

In [76]:
#Step 1: Imports
import os
import cv2
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import torchvision
from glob import glob
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
import torch.nn as nn

Step 2: Transfer Learning Model

In [77]:
#Step 2: Transfer Learning Model
#Mediapipe Pose(blazepose), will extract keypoints from images, which will be used to classify shots as broke or not broke
#Movenet is what we're using, but it;s within mediapipe
import mediapipe as mp
mp_pose = mp.solutions.pose
#To extract keypoints: mp_pose.Pose(static_image_mode=True, min_detection_confidence=0.5)

Step 3: Load dataset

In [78]:
#Step 3: Load dataset
datasetPath = "./dataset/"
imageData = []
phases = ["followthrough", "setpoint", "shotpocket"]
labels = ["broke", "butter"]

for phase in phases:
    for label in labels:
        path = os.path.join(datasetPath, phase, label)

        images = glob(os.path.join(path, "*.jpg")) + glob(os.path.join(path, "*.png")) + glob(os.path.join(path, "*.jpeg"))


        for imgFile in images:
            imageData.append((imgFile, phase, label))

print(f"Total images loaded: {len(imageData)}")

Total images loaded: 353


In [79]:
#Normalize keypoints function
def normalize(keypoints):
    
    xs = [kp[0] for kp in keypoints]
    ys = [kp[1] for kp in keypoints]

    min_x, max_x = min(xs), max(xs)
    min_y, max_y = min(ys), max(ys)

    width = max_x - min_x
    height = max_y - min_y

    return [((x-min_x)/width, (y-min_y)/height, z, v)
            for x,y,z,v in keypoints]

In [80]:
#Extract pose data and save in a list
#Each image has its pose data, class(set point, etc), and brokeness
#For each image, extract pose, normalize pose, infer label from filename, append to list
poseData = [] #holds keypoints, phase, label

pose = mp_pose.Pose(static_image_mode=True, model_complexity=2, min_detection_confidence=0.5)

for i in range(len(imageData)):
    imgFile, phase, label = imageData[i]

    img = cv2.imread(imgFile)
    img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    print("Shape:", img_rgb.shape)

    results = pose.process(img_rgb)

    if not results.pose_landmarks:
        print("No pose detected")
    else:
        landmarks = results.pose_landmarks.landmark
        print("Detected", len(landmarks), "landmarks")

    keypoints = []
    for landmark in landmarks:
        keypoints.append((landmark.x, landmark.y, landmark.z, landmark.visibility))
    
    poseData.append({
        "keypoints": normalize(keypoints), 
        "phase": phase, 
        "label": label,
        "path": imgFile
    })


Shape: (774, 720, 3)
Detected 33 landmarks
Shape: (720, 1280, 3)


I0000 00:00:1767655738.160253 7944265 gl_context.cc:357] GL version: 2.1 (2.1 Metal - 90.5), renderer: Apple M1 Pro
W0000 00:00:1767655738.233221 7957657 inference_feedback_manager.cc:114] Feedback manager requires a model with a single signature inference. Disabling support for feedback tensors.
W0000 00:00:1767655738.261067 7957659 inference_feedback_manager.cc:114] Feedback manager requires a model with a single signature inference. Disabling support for feedback tensors.


Detected 33 landmarks
Shape: (888, 1920, 3)
Detected 33 landmarks
Shape: (854, 474, 3)
Detected 33 landmarks
Shape: (1080, 1920, 3)
Detected 33 landmarks
Shape: (1920, 1080, 3)
Detected 33 landmarks
Shape: (1920, 1080, 3)
No pose detected
Shape: (1920, 1080, 3)
Detected 33 landmarks
Shape: (1920, 1080, 3)
Detected 33 landmarks
Shape: (1920, 1080, 3)
Detected 33 landmarks
Shape: (1920, 1080, 3)
Detected 33 landmarks
Shape: (1920, 1080, 3)
Detected 33 landmarks
Shape: (612, 408, 3)
Detected 33 landmarks
Shape: (1920, 1080, 3)
Detected 33 landmarks
Shape: (1920, 1080, 3)
Detected 33 landmarks
Shape: (612, 408, 3)
Detected 33 landmarks
Shape: (1920, 1080, 3)
Detected 33 landmarks
Shape: (1920, 1080, 3)
Detected 33 landmarks
Shape: (408, 612, 3)
Detected 33 landmarks
Shape: (612, 408, 3)
No pose detected
Shape: (1920, 1080, 3)
Detected 33 landmarks
Shape: (1920, 1080, 3)
Detected 33 landmarks
Shape: (1920, 1080, 3)
Detected 33 landmarks
Shape: (1920, 1080, 3)
Detected 33 landmarks
Shape: (1

In [81]:
#Convert to pytorch dataset
labelToIdx = {"broke": 0, "butter": 1}
phaseToIdx = {"shotpocket": 0, "setpoint": 1, "followthrough": 2}

class PoseDataset(torch.utils.data.Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        
        poseVector = np.array(item["keypoints"], dtype=np.float32).flatten()

        phaseVector = torch.zeros(len(phaseToIdx), dtype=torch.float32)
        phase_idx = phaseToIdx[item["phase"]]
        phaseVector[phase_idx] = 1.0

        #Pose + phase vector
        inputVector = torch.tensor(poseVector, dtype=torch.float32)
        inputVector = torch.cat([inputVector, phaseVector])

        label_idx = labelToIdx[item["label"]]

        return inputVector, torch.tensor(label_idx)

In [82]:
#Visualize dataset
dataset = PoseDataset(poseData)
loader = DataLoader(dataset, batch_size=8, shuffle=True)

batch = next(iter(loader))
batchInputs, batchLabels = batch

mp_drawing = mp.solutions.drawing_utils
mp_styles = mp.solutions.drawing_styles

def show_pose_on_image(image_path, phase=None, label=None):
    img = cv2.imread(image_path)
    img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    results = pose.process(img_rgb)

    if not results.pose_landmarks:
        print("No pose detected in:", image_path)
        return

    mp_drawing.draw_landmarks(
        img_rgb,
        results.pose_landmarks,
        mp_pose.POSE_CONNECTIONS,
        landmark_drawing_spec=mp_styles.get_default_pose_landmarks_style()
    )

    plt.figure(figsize=(6,8))
    plt.imshow(img_rgb)
    plt.axis("off")
    plt.show()
    if label and phase:
        print(f"Label: {label}, Phase: {phase}")
    else:
        print(f"Image path: {image_path}")

Step 4: Define model architecture



In [83]:
#Step 4: Define model architecture
#MLP model to classify pose keypoints
class PoseMLP(nn.Module):
    def __init__(self, input_dim = 135, hidden_dim1 = 128, hidden_dim2 = 64, dropout = 0.2, output_dim = 1):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim1),
            nn.BatchNorm1d(hidden_dim1),
            nn.ReLU(),
            nn.Dropout(dropout),


            nn.Linear(hidden_dim1, hidden_dim2),
            nn.BatchNorm1d(hidden_dim2),
            nn.ReLU(),
            nn.Dropout(dropout),

            nn.Linear(hidden_dim2, output_dim)
        )
    
    def forward(self, x):
        return self.net(x).squeeze(-1)


In [84]:
#Compile Model
model = PoseMLP(input_dim=135, hidden_dim1=128, hidden_dim2=64, dropout=0.2, output_dim=1)
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)


Step 5: Split data

In [85]:
#Step 5: Split data within list into train, validation, and test
train_data, test_data = train_test_split(poseData, test_size=0.1, random_state=42)
train_data, val_data = train_test_split(train_data, test_size=0.2/.9, random_state=42)

Step 6: Train

In [86]:
#Step 6: Train
numEpochs = 20 
trainSet = PoseDataset(train_data)
trainLoader = DataLoader(trainSet, batch_size=16, shuffle=True, drop_last=True)
valSet = PoseDataset(val_data)
valLoader = DataLoader(valSet, batch_size=16, shuffle=False)

for batch in trainLoader:
    inputs, labels = batch
    print("Input shape:", inputs.shape)
    print("Labels shape:", labels.shape)
    break

#GPU support
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

for epoch in range(numEpochs):
    model.train()
    for inputs, labels in trainLoader:
        inputs = inputs.to(device)
        labels = labels.to(device)
        #Training step
        predictions = model(inputs)
        loss = criterion(predictions, labels.float())
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        #Validation step
    model.eval()
    valLoss = 0.0
    with torch.no_grad():
        for valInputs, valLabels in valLoader:
            valInputs = valInputs.to(device)
            valLabels = valLabels.to(device)
            valPredictions = model(valInputs).squeeze()
            valLoss += criterion(valPredictions, valLabels.float()).item()
    valLoss /= len(valLoader)

    print(f"Epoch {epoch+1}/{numEpochs}, Train Loss: {loss.item():.4f}, Val Loss: {valLoss:.4f}")

Input shape: torch.Size([16, 135])
Labels shape: torch.Size([16])
Epoch 1/20, Train Loss: 0.6416, Val Loss: 0.6917
Epoch 2/20, Train Loss: 0.7230, Val Loss: 0.6965
Epoch 3/20, Train Loss: 0.6612, Val Loss: 0.7046
Epoch 4/20, Train Loss: 0.5995, Val Loss: 0.7058
Epoch 5/20, Train Loss: 0.5945, Val Loss: 0.7229
Epoch 6/20, Train Loss: 0.7897, Val Loss: 0.7019
Epoch 7/20, Train Loss: 0.4753, Val Loss: 0.7069
Epoch 8/20, Train Loss: 0.7339, Val Loss: 0.7301
Epoch 9/20, Train Loss: 0.4911, Val Loss: 0.7167
Epoch 10/20, Train Loss: 0.7021, Val Loss: 0.7353
Epoch 11/20, Train Loss: 0.4777, Val Loss: 0.7149
Epoch 12/20, Train Loss: 0.5132, Val Loss: 0.7234
Epoch 13/20, Train Loss: 0.6558, Val Loss: 0.7190
Epoch 14/20, Train Loss: 0.4967, Val Loss: 0.7476
Epoch 15/20, Train Loss: 0.5629, Val Loss: 0.7333
Epoch 16/20, Train Loss: 0.4471, Val Loss: 0.7674
Epoch 17/20, Train Loss: 0.5645, Val Loss: 0.9068
Epoch 18/20, Train Loss: 0.3741, Val Loss: 0.7260
Epoch 19/20, Train Loss: 0.4760, Val Loss: 

Step 7: Test Performance

In [87]:
#Step 7: Test Performance
model.eval()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

testSet = PoseDataset(test_data)
testLoader = DataLoader(testSet, batch_size=32, shuffle=False)

allLabels = []
allPredictions = []
allPhases = []

phaseCorrect = {0: 0, 1: 0, 2: 0}
phaseTotal = {0: 0, 1: 0, 2: 0}

with torch.no_grad():
    for poses, labels in testLoader:
        poses = poses.to(device)
        labels = labels.to(device)
        
        predictions = model(poses).squeeze()
        predictedLabels = (torch.sigmoid(predictions) > 0.5).int()

        allLabels.extend(labels.cpu().numpy())
        allPredictions.extend(predictedLabels.cpu().numpy())

        phases = torch.argmax(poses[:, -3:], dim=1).cpu().numpy()
        allPhases.extend(phases)

        for p, pred, true in zip(phases, predictedLabels.cpu().numpy(), labels.cpu().numpy()):
            p = int(p.item())
            phaseTotal[p] += 1
            if pred == true:
                phaseCorrect[p] += 1


accuracyScore = accuracy_score(allLabels, allPredictions)
print(f"Test Accuracy: {accuracyScore * 100}%")

cm = confusion_matrix(allLabels, allPredictions)
print("Confusion Matrix:")
print(cm)

report = classification_report(allLabels, allPredictions, target_names=["broke", "butter"])
print("Classification Report:")
print(report)

idxToPhase = {0: "shotpocket", 1: "setpoint", 2: "followthrough"}

for idx, name in idxToPhase.items():
    if phaseTotal[idx] == 0:
        print(f"Phase: {name}, No samples in test set.")
        continue
    correct = phaseCorrect[idx]
    total = phaseTotal[idx]
    accuracy = (correct / total) * 100
    print(f"Phase: {name}, Accuracy: {accuracy}% ({correct}/{total})")


Test Accuracy: 52.77777777777778%
Confusion Matrix:
[[ 7 10]
 [ 7 12]]
Classification Report:
              precision    recall  f1-score   support

       broke       0.50      0.41      0.45        17
      butter       0.55      0.63      0.59        19

    accuracy                           0.53        36
   macro avg       0.52      0.52      0.52        36
weighted avg       0.52      0.53      0.52        36

Phase: shotpocket, Accuracy: 40.0% (2/5)
Phase: setpoint, Accuracy: 50.0% (7/14)
Phase: followthrough, Accuracy: 58.82352941176471% (10/17)


Step 8: Save the Model

In [88]:
#Model should be able to take in a picture and output its prediction on whether the pose is broke or not broke
#Full Model Save
modelVersion = "v6" #1/05/26 5:18pm, update when retrained
os.makedirs("MLmodels", exist_ok=True)
modelName = "broke_jump_shot_detector_model_" + modelVersion + ".pth"
modelPath = os.path.join("./MLmodels/", modelName)
torch.save(model, modelPath)

In [89]:
#Weights Save
weightVersion = "v6" #1/05/26 5:18pm, update when retrained
os.makedirs("MLweights", exist_ok=True)
weightsName = "broke_jump_shot_detector_weights_" + weightVersion + ".pth"
weightsPath = os.path.join("./MLweights/", weightsName)
torch.save(model.state_dict(), weightsPath)
