In [1]:
import json
import os
import sys

# Make the modules in ../model (e.g. PreTrainer) importable from this notebook.
# The path is relative, so it is resolved against the kernel's working directory.
path_to_model_directory = '../model'
already_on_path = path_to_model_directory in sys.path
if not already_on_path:
    sys.path.append(path_to_model_directory)

# Now you can import your class
from PreTrainer import PreTrainer, validate_data_format

In [2]:
# Training hyperparameters and data/config locations used by the PreTrainer below.
BATCH_SIZE = 32
LR = 0.001
EPOCHS = 10

# The defaults are absolute paths on the original training machine. Set the
# SHOT_LABELS_DIR / MODEL_CONFIG_FILE environment variables to run elsewhere.
SHOT_LABELS_DIR = os.environ.get(
    'SHOT_LABELS_DIR',
    '/home/florsanders/adl_ai_tennis_coach/data/tenniset/shot_labels',
)
train_path = os.path.join(SHOT_LABELS_DIR, 'train')
val_path = os.path.join(SHOT_LABELS_DIR, 'val')
model_config_file = os.environ.get(
    'MODEL_CONFIG_FILE',
    '/home/tawab/e6691-2024spring-project-TECO-as7092-gyt2107-fps2116/src/model/configs/default.yaml',
)

In [3]:
# Validate the training split with the project's validate_data_format helper.
# It prints total/invalid/valid file counts (see output below). Its return value
# is JSON-dumped in a later cell; presumably it is the list of invalid files —
# TODO confirm against validate_data_format.
train_invalid_data_path = validate_data_format(train_path)

Total files: 2248
Invalid Files: 30
Valid Files: 2218
Percentage of valid files: 98.66548042704626%


In [4]:
# Validate the validation split the same way as the training split.
# The return value is JSON-dumped in a later cell; presumably it is the list of
# invalid files — TODO confirm against validate_data_format.
val_invalid_data_path = validate_data_format(val_path)

Total files: 565
Invalid Files: 5
Valid Files: 560
Percentage of valid files: 99.11504424778761%


In [22]:
# Persist the validation results for both splits as JSON files in the
# current working directory (train first, then val).
invalid_lists_by_file = {
    'train_invalid_data.json': train_invalid_data_path,
    'val_invalid_data.json': val_invalid_data_path,
}
for json_name, invalid_list in invalid_lists_by_file.items():
    with open(json_name, 'w') as fh:
        json.dump(invalid_list, fh)

In [3]:
# Build the PreTrainer from the hyperparameters and paths configured above.
# Checkpoints are written under ./trained_models. The constructor prints
# "Skipping ..." messages for unusable samples (see output below), so it appears
# to load/filter the data eagerly — NOTE(review): confirm in PreTrainer.
trainer_kwargs = dict(
    batch_size=BATCH_SIZE,
    lr=LR,
    epochs=EPOCHS,
    train_path=train_path,
    val_path=val_path,
    model_config_path=model_config_file,
    model_save_path='trained_models',
)
trainer = PreTrainer(**trainer_kwargs)

Skipping V006_0068: Incorrect dimensions - positions_2d (0, 2), poses_3d (0,)
Skipping V006_0179: Incorrect dimensions - positions_2d (0, 2), poses_3d (0,)
Skipping V007_0183: Incorrect dimensions - positions_2d (0, 2), poses_3d (0,)
Skipping V007_0184: Incorrect dimensions - positions_2d (0, 2), poses_3d (0,)
Skipping V008_0003: Data file not found.
Skipping V008_0056: Incorrect dimensions - positions_2d (0, 2), poses_3d (0,)
Skipping V008_0156: Incorrect dimensions - positions_2d (0, 2), poses_3d (0,)
Skipping V009_0017: Incorrect dimensions - positions_2d (0, 2), poses_3d (0,)
Skipping V009_0924: Incorrect dimensions - positions_2d (0, 2), poses_3d (0,)
Skipping V009_0947: Incorrect dimensions - positions_2d (0, 2), poses_3d (0,)
Skipping V009_0948: Incorrect dimensions - positions_2d (0, 2), poses_3d (0,)
Skipping V009_1045: Incorrect dimensions - positions_2d (0, 2), poses_3d (0,)
Skipping V009_1281: Incorrect dimensions - positions_2d (0, 2), poses_3d (0,)
Skipping V009_1282: Inc

In [4]:
trainer.train()

100%|██████████| 71/71 [00:17<00:00,  4.06it/s]


Epoch 1, Loss: 6.9366803589001504
Validation Loss: 2.0761478741963706
Saving model at trained_models with validation loss of 2.0761478741963706


100%|██████████| 71/71 [00:15<00:00,  4.52it/s]


Epoch 2, Loss: 2.0405447617383072
Validation Loss: 1.7640170653661091
Saving model at trained_models with validation loss of 1.7640170653661091


100%|██████████| 71/71 [00:15<00:00,  4.52it/s]


Epoch 3, Loss: 1.9298397940649112
Validation Loss: 1.7268546554777358
Saving model at trained_models with validation loss of 1.7268546554777358


100%|██████████| 71/71 [00:15<00:00,  4.49it/s]


Epoch 4, Loss: 1.9005043288351784
Validation Loss: 1.7138367295265198
Saving model at trained_models with validation loss of 1.7138367295265198


100%|██████████| 71/71 [00:15<00:00,  4.52it/s]


Epoch 5, Loss: 1.8822210640974448
Validation Loss: 1.6981180177794561
Saving model at trained_models with validation loss of 1.6981180177794561


100%|██████████| 71/71 [00:15<00:00,  4.51it/s]


Epoch 6, Loss: 1.910371723309369
Validation Loss: 1.692180057366689
Saving model at trained_models with validation loss of 1.692180057366689


100%|██████████| 71/71 [00:15<00:00,  4.49it/s]


Epoch 7, Loss: 1.86625565273661
Validation Loss: 1.6973976360427008


100%|██████████| 71/71 [00:15<00:00,  4.52it/s]


Epoch 8, Loss: 1.8665103895563475
Validation Loss: 1.7120945784780714


100%|██████████| 71/71 [00:15<00:00,  4.56it/s]


Epoch 9, Loss: 1.8387023546326329
Validation Loss: 1.6665880613856845
Saving model at trained_models with validation loss of 1.6665880613856845


100%|██████████| 71/71 [00:15<00:00,  4.53it/s]


Epoch 10, Loss: 1.7953675884596059
Validation Loss: 1.5780009494887457
Saving model at trained_models with validation loss of 1.5780009494887457


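### Sanity Check: Padding-Mask Logic Used in Training

The cell below re-implements the padding/masking steps (it does not call `PreTrainer`) on a synthetic batch of variable-length sequences. It checks that valid frames are left untouched and padded frames are zeroed.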

In [128]:
import torch
from torch.nn.utils.rnn import pad_sequence

# Define the maximum sequence length based on your previous logic
max_length = 10 + 31

# Create target tensor initialized to a specific value for easy verification (e.g., all ones)
targets = torch.ones([32, max_length, 17, 3], dtype=torch.float32)

# Create masks and pose_graphs with varying sequence lengths
masks = []
pose_graphs = []

for i in range(32):
    length = 10 + i  # Variable sequence lengths
    pose_graphs.append(torch.ones(length, 17, 3))  # Simulate real data
    masks.append(torch.ones(length, dtype=torch.bool))  # True where data is valid

# Pad pose_graphs and masks to uniform lengths
pose_graphs_padded = pad_sequence(pose_graphs, batch_first=True)
masks_padded = pad_sequence(masks, batch_first=True, padding_value=0)

# Expand the mask to match targets' dimensions
expanded_masks = masks_padded.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, 17, 3)

# Apply the mask to the targets
masked_targets = targets * expanded_masks

# Verification steps
for i in range(32):
    expected_length = 10 + i
    actual_mask = expanded_masks[i]
    actual_data = masked_targets[i]

    # Check the areas expected to be masked
    if torch.any(actual_data[:expected_length] != 1):
        print(f"Error: Data corruption in unmasked area for batch {i}")
    if torch.any(actual_data[expected_length:] != 0):
        print(f"Error: Incomplete masking in padded area for batch {i}")
    else:
        print(f"Mask and data verified successfully for batch {i}")

# Print the shape of tensors to verify
print("Shape of original targets:", targets.shape)
print("Shape of pose_graphs (after padding):", pose_graphs_padded.shape)
print("Shape of masks (after padding and expanding):", expanded_masks.shape)
print("Shape of masked targets:", masked_targets.shape)


Mask and data verified successfully for batch 0
Mask and data verified successfully for batch 1
Mask and data verified successfully for batch 2
Mask and data verified successfully for batch 3
Mask and data verified successfully for batch 4
Mask and data verified successfully for batch 5
Mask and data verified successfully for batch 6
Mask and data verified successfully for batch 7
Mask and data verified successfully for batch 8
Mask and data verified successfully for batch 9
Mask and data verified successfully for batch 10
Mask and data verified successfully for batch 11
Mask and data verified successfully for batch 12
Mask and data verified successfully for batch 13
Mask and data verified successfully for batch 14
Mask and data verified successfully for batch 15
Mask and data verified successfully for batch 16
Mask and data verified successfully for batch 17
Mask and data verified successfully for batch 18
Mask and data verified successfully for batch 19
Mask and data verified success