In [1]:
%matplotlib inline
%load_ext autoreload
%autoreload 2

In [2]:
import torch
import torch.nn as nn 
import torch.optim as optim 
import torch.nn.functional as F
import numpy as np
import random
import os
import time
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence
from torch_geometric.data import Batch
from torch_geometric.data import Data

In [3]:
import sys

# Directory that is added to the import search path; presumably it holds the
# project modules imported right after this (PreTrainer, data, model_builder).
# NOTE(review): the path is relative, so it resolves against the kernel's
# current working directory.
path_to_model_directory = '../model'

# Append the directory only when it is missing, so re-running this cell does
# not add duplicate entries to sys.path
if path_to_model_directory not in sys.path:
    sys.path.append(path_to_model_directory)

# Now you can import your class
from PreTrainer import PreTrainer
from data import TennisDataset
from model_builder import build_tennis_embedder


In [4]:
# Build the model from a YAML config file.
# NOTE(review): this is an absolute, user-specific path; it has to be adjusted
# before the notebook can run on another machine.
config_file = '/home/tawab/e6691-2024spring-project-TECO-as7092-gyt2107-fps2116/src/model/configs/default.yaml'
model = build_tennis_embedder(config_file)

#### Things to do for loading data
- What do we want to do with the samples that are skipped while loading (empty `positions_2d` / `poses_3d` arrays or a missing data file)?

In [5]:
# Load the train and validation datasets from their shot-label directories.
# NOTE(review): both paths are absolute and user-specific.
# The "Skipping ..." lines in this cell's output are presumably emitted by
# TennisDataset for samples with empty arrays or a missing data file.
train_path = '/home/florsanders/adl_ai_tennis_coach/data/tenniset/shot_labels/train'
val_path = '/home/florsanders/adl_ai_tennis_coach/data/tenniset/shot_labels/val'
train_dataset = TennisDataset(labels_path=train_path)
print('Train Dataset Loaded Successfully')
val_dataset = TennisDataset(labels_path=val_path)
print('Val Dataset Loaded Successfully')

Skipping V006_0068: Incorrect dimensions - positions_2d (0, 2), poses_3d (0,)
Skipping V006_0179: Incorrect dimensions - positions_2d (0, 2), poses_3d (0,)
Skipping V007_0183: Incorrect dimensions - positions_2d (0, 2), poses_3d (0,)
Skipping V007_0184: Incorrect dimensions - positions_2d (0, 2), poses_3d (0,)
Skipping V008_0003: Data file not found.
Skipping V008_0056: Incorrect dimensions - positions_2d (0, 2), poses_3d (0,)
Skipping V008_0156: Incorrect dimensions - positions_2d (0, 2), poses_3d (0,)
Skipping V009_0017: Incorrect dimensions - positions_2d (0, 2), poses_3d (0,)
Skipping V009_0924: Incorrect dimensions - positions_2d (0, 2), poses_3d (0,)
Skipping V009_0947: Incorrect dimensions - positions_2d (0, 2), poses_3d (0,)
Skipping V009_0948: Incorrect dimensions - positions_2d (0, 2), poses_3d (0,)
Skipping V009_1045: Incorrect dimensions - positions_2d (0, 2), poses_3d (0,)
Skipping V009_1281: Incorrect dimensions - positions_2d (0, 2), poses_3d (0,)
Skipping V009_1282: Inc

In [6]:
# Report how many samples each split contains
print('Training Dataset Length:', len(train_dataset))
print('Validation Dataset Length:', len(val_dataset))

# Unpack one training sample (index 2) into its four parts: 3D poses,
# 2D positions, the list of per-frame pose graphs, and the target
pose3d, position2d, pose_graph, target = train_dataset[2]
print('Pose3d Shape:', pose3d.shape)
print('Position2d Shape:', position2d.shape)
print('Target Shape:', target.shape)
# pose_graph supports len() and indexing; each element is a graph object
print('Number of Graphs: ', len(pose_graph))
print('Pose Graph Data Shapes (x & edge_index): ' ,pose_graph[0])
# Summary statistics of the first graph in the sequence
print(f'Number of features: {pose_graph[0].num_features}')
print(f'Number of nodes: {pose_graph[0].num_nodes}')
print(f'Number of edges: {pose_graph[0].num_edges}')

Training Dataset Length: 2248
Validation Dataset Length: 565
Pose3d Shape: (24, 17, 3)
Position2d Shape: (24, 2)
Target Shape: (24, 17, 3)
Number of Graphs:  24
Pose Graph Data Shapes (x & edge_index):  Data(x=[17, 3], edge_index=[2, 38])
Number of features: 3
Number of nodes: 17
Number of edges: 38


In [7]:
# Training hyperparameters, passed to the DataLoaders and the PreTrainer
# in later cells.
# NOTE(review): no random seed is set in the visible cells, so shuffling and
# training results are not reproducible between runs.
BATCH_SIZE = 32  # samples per batch
LR = 0.001       # learning rate
EPOCHS = 10      # number of training epochs

#### Things to do with the collate function
- How do we want to handle `pose3d` and `position2d` arrays whose dtype is `np.object_`? --> At the moment those samples are skipped.
- How do we want to pad the per-sequence graph lists to the same length? --> At the moment the last graph is repeated.

In [8]:
from torch.nn.utils.rnn import pad_sequence
from torch_geometric.data import Batch, Data
import torch
import numpy as np

def pad_graphs(graphs, max_frames):
    """Bring every graph sequence to exactly ``max_frames`` entries.

    Sequences shorter than ``max_frames`` are extended by repeating their
    final graph; all other sequences are cut down to their first
    ``max_frames`` graphs.

    Args:
        graphs: iterable of graph lists, one list per sequence.
        max_frames: number of graphs every returned list should contain.

    Returns:
        A new list holding one padded or truncated list per input sequence,
        in the same order as ``graphs``.
    """
    result = []
    for sequence in graphs:
        missing = max_frames - len(sequence)
        if missing > 0:
            # Too short: fill the gap with copies of the last graph (same object)
            fitted = sequence + [sequence[-1]] * missing
        else:
            # Long enough: keep only the leading max_frames graphs
            fitted = sequence[:max_frames]
        result.append(fitted)
    return result

def my_collate_fn(batch):
    """Collate dataset samples into padded tensors and per-sequence graph batches.

    Every element of ``batch`` is a
    ``(pose3d, position2d, pose_graphs, target)`` tuple. A sample is dropped
    entirely when:

    * ``pose3d``, ``position2d`` or ``pose_graphs`` is ``None``;
    * ``pose3d`` or ``position2d`` has dtype ``np.object_``;
    * ``pose_graphs`` is empty;
    * ``pose3d`` is not 3-dimensional or ``position2d`` is not 2-dimensional;
    * ``pose_graphs`` is not a ``list`` (a message is printed in this case).

    Dropping the whole sample keeps index ``i`` of the returned graph list
    aligned with row ``i`` of the returned tensors.

    Args:
        batch: iterable of 4-tuples as described above.

    Returns:
        A tuple ``(pose3d_padded, position2d_padded, batched_graphs,
        targets_padded)``. The tensors are float32 and padded along the first
        (sequence) dimension with ``pad_sequence(..., batch_first=True)``.
        ``batched_graphs`` holds one ``Batch`` per kept sample, each built
        from a graph list brought to the longest sequence length in the batch
        by ``pad_graphs``. When no sample is kept, three empty tensors and an
        empty list are returned.
    """
    pose3d = []
    position2d = []
    targets = []
    all_graphs = []

    for pose3d_item, position2d_item, pose_graph_items, target_item in batch:
        # Skip samples with missing components
        if pose3d_item is None or position2d_item is None or pose_graph_items is None:
            continue
        # Object arrays cannot be converted to float tensors directly
        if pose3d_item.dtype == np.object_ or position2d_item.dtype == np.object_:
            continue
        if len(pose_graph_items) == 0:
            continue
        # Ensure the correct dimensionality
        if pose3d_item.ndim != 3 or position2d_item.ndim != 2:
            continue
        if not isinstance(pose_graph_items, list):
            # Drop the whole sample, not just its graphs; otherwise the graph
            # list and the tensors would no longer line up index by index.
            print("Skipping a graph item due to incorrect type.")
            continue

        all_graphs.append(pose_graph_items)
        pose3d.append(torch.tensor(pose3d_item, dtype=torch.float32))
        position2d.append(torch.tensor(position2d_item, dtype=torch.float32))
        targets.append(torch.tensor(target_item, dtype=torch.float32))

    # Nothing survived the filtering: return empty placeholders instead of
    # calling max() on an empty sequence.
    if not all_graphs:
        return torch.Tensor(), torch.Tensor(), [], torch.Tensor()

    # Pad the tensor sequences to the longest sequence in the batch
    pose3d_padded = pad_sequence(pose3d, batch_first=True)
    position2d_padded = pad_sequence(position2d, batch_first=True)
    targets_padded = pad_sequence(targets, batch_first=True)

    # Pad the graph lists to the maximum number of frames in the batch and
    # merge each list into a single Batch object
    max_frames = max(len(graphs) for graphs in all_graphs)
    padded_graphs = pad_graphs(all_graphs, max_frames)
    batched_graphs = [Batch.from_data_list(graph_list) for graph_list in padded_graphs]

    return pose3d_padded, position2d_padded, batched_graphs, targets_padded


In [9]:
# Create the DataLoaders. Both use my_collate_fn to pad variable-length
# samples; only the training loader shuffles its data.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=my_collate_fn)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=my_collate_fn)

In [14]:
# Sanity check: pull a single collated batch from the training loader and
# print the shapes of its components, then stop.
for pose3d, position2d, pose_graph, target in train_loader:
    # Kept at notebook scope for interactive inspection after the cell runs
    y = position2d
    z = pose_graph
    print('Pose3d Shape:', pose3d.shape)
    print('Position2d Shape:', position2d.shape)
    print('Target Shape:', target.shape)
    print('Number of Sequence Graphs in Batch:', len(pose_graph))
    print('Sequence Graph Data Shapes (x & edge_index): ' ,pose_graph[0])
    # Index 0 matches the "First Sequence" label and also works when the
    # batch holds fewer than four sequences
    print('Number of Graphs in First Sequence:', pose_graph[0].num_graphs)
    # Node-to-graph assignment vector of the first sequence
    x = pose_graph[0].batch
    break

Pose3d Shape: torch.Size([32, 67, 17, 3])
Position2d Shape: torch.Size([32, 67, 2])
Target Shape: torch.Size([32, 67, 17, 3])
Number of Sequence Graphs in Batch: 32
Sequence Graph Data Shapes (x & edge_index):  DataBatch(x=[1139, 3], edge_index=[2, 2546], batch=[1139], ptr=[68])
Number of Graphs in First Sequence: 67


In [10]:
# Trainer setup: hand the model, both DataLoaders and the hyperparameters
# defined earlier in the notebook to the PreTrainer
trainer = PreTrainer(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    batch_size=BATCH_SIZE,
    lr=LR,
    epochs=EPOCHS
)

Using device: cuda


In [11]:
# Start training; the cell output shows one "Loss" and one "Test Loss" line
# per epoch (the latter is presumably computed on val_loader -- TODO confirm)
trainer.train()


Epoch 1, Loss: 4.519310984812992
Test Loss: 3.9174769984351263
Epoch 2, Loss: 3.4493121865769507
Test Loss: 1.9674275981055365
Epoch 3, Loss: 0.9792592051163526
Test Loss: 0.7634573198027081
Epoch 4, Loss: 0.7425360855921893
Test Loss: 0.6903862721390195
Epoch 5, Loss: 0.7162655608754762
Test Loss: 0.6654880245526632
Epoch 6, Loss: 0.6936964623525109
Test Loss: 0.7403709126843346
Epoch 7, Loss: 0.6784240993815409
Test Loss: 0.6443121847179201
Epoch 8, Loss: 0.6514911639018798
Test Loss: 0.6629044148657057
Epoch 9, Loss: 0.653844780065644
Test Loss: 0.6156553261809878
Epoch 10, Loss: 0.6356361739232507
Test Loss: 0.6150351448191537
