# Model Pre-Training

## Setup

In [32]:
%matplotlib inline
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [33]:
# Import libraries
import sys
import os
import json
import numpy as np
import matplotlib.pyplot as plt
# Notebook-wide plotting defaults: apply the "tableau-colorblind10" style sheet
# and raise the default font size used by every figure created below.
plt.style.use("tableau-colorblind10")
plt.rcParams["font.size"] = 18

In [34]:
# Make the sibling ../model directory importable.
# The path is resolved to an absolute one, and it is only appended when it is
# not already listed, so re-running this cell does not add duplicate entries.
path_to_model_directory = os.path.abspath("../model")
if sys.path.count(path_to_model_directory) == 0:
    sys.path.append(path_to_model_directory)

# Now you can import your class
from PreTrainer import PreTrainer
from data import validate_data_format, ServeDataset, HitDataset
from TennisShotEmbedder import TennisShotEmbedder

In [35]:
# Configure the PreTrainer

# Training hyperparameters passed to PreTrainer below.
BATCH_SIZE = 32
LR = 0.001
EPOCHS = 50

# Data and model-config locations.
# The defaults are the original machine-specific absolute paths (note that they
# live in two different users' home directories). Set TENNIS_SHOT_LABELS_DIR
# and/or TENNIS_MODEL_CONFIG in the environment to run this notebook on another
# machine without editing the cell.
SHOT_LABELS_DIR = os.environ.get(
    "TENNIS_SHOT_LABELS_DIR",
    '/home/florsanders/adl_ai_tennis_coach/data/tenniset/shot_labels',
)
train_path = f"{SHOT_LABELS_DIR}/train"
val_path = f"{SHOT_LABELS_DIR}/val"
test_path = f"{SHOT_LABELS_DIR}/test"
model_config_file = os.environ.get(
    "TENNIS_MODEL_CONFIG",
    '/home/tawab/e6691-2024spring-project-TECO-as7092-gyt2107-fps2116/src/model/configs/default.yaml',
)

In [5]:
# Validate Train Data
# Run the project's format check on train_path. The result is kept because it
# is written to invalid_files.json further down.
# NOTE(review): presumably the return value identifies the invalid files —
# confirm against validate_data_format in the data module.
train_invalid_data_path = validate_data_format(train_path)

Total files: 2248
Invalid Files: 0
Valid Files: 2248
Percentage of valid files: 100.0%


In [6]:
# Validate Val Data
# Run the project's format check on val_path. The result is kept because it is
# written to invalid_files.json further down.
# NOTE(review): presumably the return value identifies the invalid files —
# confirm against validate_data_format in the data module.
val_invalid_data_path = validate_data_format(val_path)

Total files: 565
Invalid Files: 0
Valid Files: 565
Percentage of valid files: 100.0%


In [7]:
# Validate Test Data
# Run the project's format check on test_path. The result is kept because it is
# written to invalid_files.json further down.
# NOTE(review): presumably the return value identifies the invalid files —
# confirm against validate_data_format in the data module.
test_invalid_data_path = validate_data_format(test_path)

Total files: 370
Invalid Files: 0
Valid Files: 370
Percentage of valid files: 100.0%


In [8]:
# Record the validation result of each split in ./invalid_files.json.
# The file is only written when it does not exist yet; an existing file is
# left untouched, so delete it first to refresh the record.
if not os.path.exists("./invalid_files.json"):
    invalid_files = dict(
        train=train_invalid_data_path,
        val=val_invalid_data_path,
        test=test_invalid_data_path,
    )
    with open("./invalid_files.json", "w") as f:
        json.dump(invalid_files, f)

## Training

In [9]:
# Trainer Setup
# Build the PreTrainer from the values defined in the configuration cell.
trainer = PreTrainer(
    # Where the data and model definition come from, and where checkpoints go.
    train_path=train_path,
    val_path=val_path,
    model_config_path=model_config_file,
    model_save_path='trained_models',
    # Optimisation hyperparameters.
    batch_size=BATCH_SIZE,
    lr=LR,
    epochs=EPOCHS,
)

Using device: cuda


In [None]:
# Perform training
# Long-running cell: trainer.train() returns two histories, which the plotting
# cell below uses as the training and validation loss curves.
train_history, val_history = trainer.train()

100%|█████████████████████████████████████████████████████████| 70/70 [00:17<00:00,  3.98it/s]


Epoch 1, Loss: 0.2303910300667797
Validation Loss: 0.06699909932083553
Saving model at trained_models with validation loss of 0.06699909932083553


100%|█████████████████████████████████████████████████████████| 70/70 [00:16<00:00,  4.17it/s]


Epoch 2, Loss: 0.058986256537692884
Validation Loss: 0.05578536871406767
Saving model at trained_models with validation loss of 0.05578536871406767


100%|█████████████████████████████████████████████████████████| 70/70 [00:17<00:00,  4.04it/s]


Epoch 3, Loss: 0.052577965121184074
Validation Loss: 0.05340303522017267
Saving model at trained_models with validation loss of 0.05340303522017267


100%|█████████████████████████████████████████████████████████| 70/70 [00:16<00:00,  4.19it/s]


Epoch 4, Loss: 0.05141281597316265
Validation Loss: 0.05322703222433726
Saving model at trained_models with validation loss of 0.05322703222433726


100%|█████████████████████████████████████████████████████████| 70/70 [00:16<00:00,  4.20it/s]


Epoch 5, Loss: 0.051381815703851835
Validation Loss: 0.05275031427542368
Saving model at trained_models with validation loss of 0.05275031427542368


100%|█████████████████████████████████████████████████████████| 70/70 [00:17<00:00,  3.99it/s]


Epoch 6, Loss: 0.04935416265257767
Validation Loss: 0.048730045557022095
Saving model at trained_models with validation loss of 0.048730045557022095


100%|█████████████████████████████████████████████████████████| 70/70 [00:16<00:00,  4.13it/s]


Epoch 7, Loss: 0.04634011709796531
Validation Loss: 0.049208645398418106


100%|█████████████████████████████████████████████████████████| 70/70 [00:17<00:00,  4.11it/s]


Epoch 8, Loss: 0.0454807491174766
Validation Loss: 0.045298157466782465
Saving model at trained_models with validation loss of 0.045298157466782465


100%|█████████████████████████████████████████████████████████| 70/70 [00:17<00:00,  4.03it/s]


Epoch 9, Loss: 0.04351637552359274
Validation Loss: 0.044081577927702
Saving model at trained_models with validation loss of 0.044081577927702


100%|█████████████████████████████████████████████████████████| 70/70 [00:16<00:00,  4.17it/s]


Epoch 10, Loss: 0.0422158256439226
Validation Loss: 0.04212299361824989
Saving model at trained_models with validation loss of 0.04212299361824989


100%|█████████████████████████████████████████████████████████| 70/70 [00:17<00:00,  4.09it/s]


Epoch 11, Loss: 0.04079023421342884
Validation Loss: 0.041117426111466356
Saving model at trained_models with validation loss of 0.041117426111466356


100%|█████████████████████████████████████████████████████████| 70/70 [00:17<00:00,  4.04it/s]


Epoch 12, Loss: 0.04024311953357288
Validation Loss: 0.04121794437782632


100%|█████████████████████████████████████████████████████████| 70/70 [00:16<00:00,  4.23it/s]


Epoch 13, Loss: 0.04035384670964309
Validation Loss: 0.04032427062176996
Saving model at trained_models with validation loss of 0.04032427062176996


100%|█████████████████████████████████████████████████████████| 70/70 [00:17<00:00,  4.10it/s]


Epoch 14, Loss: 0.03976871578821114
Validation Loss: 0.039829686801466674
Saving model at trained_models with validation loss of 0.039829686801466674


100%|█████████████████████████████████████████████████████████| 70/70 [00:17<00:00,  4.07it/s]


Epoch 15, Loss: 0.0393363142652171
Validation Loss: 0.0396665186724729
Saving model at trained_models with validation loss of 0.0396665186724729


 81%|██████████████████████████████████████████████▍          | 57/70 [00:13<00:03,  4.29it/s]

In [None]:
# Plot training & validation history
# One loss curve per split, indexed by epoch.
fig, ax = plt.subplots(figsize=(8, 4.5))

# Convert to an array first so the division also works when the history is a
# plain Python list (list / int raises TypeError).
# NOTE(review): this assumes train_history holds per-epoch loss *sums* over the
# batches; if PreTrainer already averages per batch, this division double-
# normalises the training curve — confirm against PreTrainer.train().
train_loss_per_batch = np.asarray(train_history) / len(trainer.train_loader)
ax.plot(train_loss_per_batch, label="Train")
ax.plot(val_history, label="Validation")

ax.set_xlabel("Epoch")
ax.set_ylabel("Loss")
ax.set_xlim(0, EPOCHS-1)
ax.set_ylim(0)
ax.grid(True)
ax.legend()
fig.tight_layout()
# A dangling, incomplete `fig.` expression stood here and made the whole cell a
# SyntaxError; it has been removed. If the figure should be written to disk,
# add an explicit fig.savefig(<path>) call at this point.
plt.show()