In [1]:
import os

# Change dir to the workspace dir. The guard keeps this cell idempotent: a bare
# os.chdir("../") climbs one more level every time the cell is re-run, which
# silently breaks the relative `data/...` paths and `src` imports used later.
# NOTE(review): assumes the workspace root is the directory that contains the
# `src/` package — confirm for this repository layout.
if not os.path.isdir("src"):
    os.chdir("../")

In [2]:
from pathlib import Path
import re

def find_best_checkpoint(root_dir):
    """Find, for each model directory, the checkpoint with the lowest validation loss.

    Every immediate sub-directory of ``root_dir`` is treated as one model and
    its directory name is used as the key. All ``*.ckpt`` files below it are
    searched recursively, and the validation loss is read from the
    ``val_loss=<number>`` fragment of the file name.

    Args:
        root_dir: Directory (``str`` or ``Path``) holding one sub-directory
            per model.

    Returns:
        dict mapping model directory name to the ``Path`` of the checkpoint
        with the lowest parsed validation loss. Models without any checkpoint
        whose name carries a parsable ``val_loss`` are omitted.

    Raises:
        FileNotFoundError: If ``root_dir`` does not exist (from ``iterdir``).
    """
    root_path = Path(root_dir)
    best_checkpoints = {}

    # Match exactly one float literal after "val_loss=": optional sign,
    # digits with an optional fraction (or a bare fraction), and an optional
    # exponent. Unlike a greedy `[\d.]+`, this never swallows the dot of the
    # ".ckpt" extension, reads "1e-04" as 0.0001 instead of 1, and cannot
    # hand an unparsable string such as "." or "1.2.3" to float().
    loss_pattern = re.compile(
        r'val_loss=([-+]?(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?)'
    )

    # sorted() makes the traversal order (and therefore tie-breaking between
    # equal losses and the key order of the result) deterministic instead of
    # depending on filesystem enumeration order.
    for model_dir in sorted(root_path.iterdir()):
        if not model_dir.is_dir():
            continue

        lowest_loss = float('inf')
        best_checkpoint_path = None

        # Checkpoints may live in nested run/checkpoints folders.
        for checkpoint_file in sorted(model_dir.rglob('*.ckpt')):
            match = loss_pattern.search(checkpoint_file.name)
            if match is None:
                # e.g. "last.ckpt" — no loss encoded in the name.
                continue

            current_loss = float(match.group(1))
            # Strict "<" keeps the first checkpoint seen among equal losses.
            if current_loss < lowest_loss:
                lowest_loss = current_loss
                best_checkpoint_path = checkpoint_file

        if best_checkpoint_path is not None:
            best_checkpoints[model_dir.name] = best_checkpoint_path

    return best_checkpoints

# Example usage: look up the best checkpoint of every model below the models
# directory and report one line per model.
root_directory = 'data/models'
best_checkpoints = find_best_checkpoint(root_directory)
for model in best_checkpoints:
    checkpoint = best_checkpoints[model]
    print(f"Best checkpoint for {model}: {checkpoint}")


Best checkpoint for transformer: data/models/transformer/B32_H64_L1_LR1e-04_D0.3/checkpoints/epoch=8-val_loss=37.4.ckpt
Best checkpoint for zoo: data/models/zoo/B32_H64_L1_LR1e-04_D0.3/checkpoints/epoch=8-val_loss=46.7.ckpt


In [3]:
import numpy as np
import polars as pl
import torch
from lightning.pytorch import Trainer
from torch.utils.data import DataLoader

from src.datasets import BDB2024_Dataset, load_datasets
from src.models import LitModel



# Build one prediction frame per architecture from its best checkpoint.
# NOTE(review): `predict_model_as_df` is not defined in the visible part of
# this notebook — confirm it is defined or imported before this point so the
# notebook survives Restart & Run All.
zoo_df = predict_model_as_df(best_checkpoints["zoo"])
trfm_df = predict_model_as_df(best_checkpoints["transformer"])

# Preview three randomly sampled rows of each frame, zoo first.
for preds_df in (zoo_df, trfm_df):
    display(preds_df.sample(3))

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Restoring states from the checkpoint path at data/models/zoo/B32_H64_L1_LR1e-04_D0.3/checkpoints/epoch=8-val_loss=46.7.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]
Loaded model weights from the checkpoint at data/models/zoo/B32_H64_L1_LR1e-04_D0.3/checkpoints/epoch=8-val_loss=46.7.ckpt


Predicting DataLoader 0: 100%|██████████| 724/724 [00:04<00:00, 154.14it/s]


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Restoring states from the checkpoint path at data/models/zoo/B32_H64_L1_LR1e-04_D0.3/checkpoints/epoch=8-val_loss=46.7.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]
Loaded model weights from the checkpoint at data/models/zoo/B32_H64_L1_LR1e-04_D0.3/checkpoints/epoch=8-val_loss=46.7.ckpt


Predicting DataLoader 0: 100%|██████████| 218/218 [00:01<00:00, 162.48it/s]


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Restoring states from the checkpoint path at data/models/zoo/B32_H64_L1_LR1e-04_D0.3/checkpoints/epoch=8-val_loss=46.7.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]
Loaded model weights from the checkpoint at data/models/zoo/B32_H64_L1_LR1e-04_D0.3/checkpoints/epoch=8-val_loss=46.7.ckpt


Predicting DataLoader 0: 100%|██████████| 94/94 [00:00<00:00, 151.36it/s]


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Restoring states from the checkpoint path at data/models/transformer/B32_H64_L1_LR1e-04_D0.3/checkpoints/epoch=8-val_loss=37.4.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]
Loaded model weights from the checkpoint at data/models/transformer/B32_H64_L1_LR1e-04_D0.3/checkpoints/epoch=8-val_loss=37.4.ckpt


Predicting DataLoader 0: 100%|██████████| 724/724 [00:15<00:00, 45.86it/s] 


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Restoring states from the checkpoint path at data/models/transformer/B32_H64_L1_LR1e-04_D0.3/checkpoints/epoch=8-val_loss=37.4.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]
Loaded model weights from the checkpoint at data/models/transformer/B32_H64_L1_LR1e-04_D0.3/checkpoints/epoch=8-val_loss=37.4.ckpt


Predicting DataLoader 0: 100%|██████████| 218/218 [00:00<00:00, 326.14it/s]


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Restoring states from the checkpoint path at data/models/transformer/B32_H64_L1_LR1e-04_D0.3/checkpoints/epoch=8-val_loss=37.4.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]
Loaded model weights from the checkpoint at data/models/transformer/B32_H64_L1_LR1e-04_D0.3/checkpoints/epoch=8-val_loss=37.4.ckpt


Predicting DataLoader 0: 100%|██████████| 94/94 [00:00<00:00, 332.58it/s]


gameId,playId,mirrored,ballCarrierNflId,ballCarrierName,tackle_frameId,tackle_event,tackle_x,tackle_y,tackle_x_rel,tackle_y_rel,play_origin_x,play_origin_y,tackle_event_enum,frameId,tackle_x_rel_pred,tackle_y_rel_pred,dataset_split,tackle_x_pred,tackle_y_pred,model_type,used_play_features,batch_size,hidden_dim,num_layers,dropout,learning_rate
i64,i64,bool,i64,str,i64,str,f64,f64,f64,f64,f64,f64,i64,i64,f32,f32,str,f64,f64,str,bool,i32,i32,i32,f64,f64
2022101604,2921,True,43971,"""C.J. Ham""",29,"""tackle""",47.56,27.63,6.78,-4.74,40.78,32.37,0,21,10.57,-8.98,"""val""",51.35,23.39,"""zoo""",False,32,64,1,0.3,0.0001
2022100300,2686,True,47819,"""Deebo Samuel""",42,"""tackle""",109.38,40.93,9.22,0.16,100.16,40.77,0,32,12.08,13.16,"""val""",112.24,53.93,"""zoo""",False,32,64,1,0.3,0.0001
2022103004,2511,False,43334,"""Derrick Henry""",44,"""tackle""",98.08,39.03,8.8,9.37,89.28,29.66,0,16,10.7,6.45,"""train""",99.98,36.11,"""zoo""",False,32,64,1,0.3,0.0001


gameId,playId,mirrored,ballCarrierNflId,ballCarrierName,tackle_frameId,tackle_event,tackle_x,tackle_y,tackle_x_rel,tackle_y_rel,play_origin_x,play_origin_y,tackle_event_enum,frameId,tackle_x_rel_pred,tackle_y_rel_pred,dataset_split,tackle_x_pred,tackle_y_pred,model_type,used_play_features,batch_size,hidden_dim,num_layers,dropout,learning_rate
i64,i64,bool,i64,str,i64,str,f64,f64,f64,f64,f64,f64,i64,i64,f32,f32,str,f64,f64,str,bool,i32,i32,i32,f64,f64
2022103000,3868,True,40129,"""Latavius Murray""",52,"""tackle""",107.7,41.33,14.63,12.09,93.07,29.24,0,42,14.26,11.67,"""train""",107.33,40.91,"""transformer""",False,32,64,1,0.3,0.0001
2022091808,3526,True,54558,"""Tyrion Davis-Price""",36,"""tackle""",108.19,25.23,6.35,-4.56,101.84,29.79,0,28,5.4,-5.7,"""val""",107.24,24.09,"""transformer""",False,32,64,1,0.3,0.0001
2022091200,315,True,53464,"""Javonte Williams""",54,"""out_of_bounds""",55.95,53.17,19.05,23.23,36.9,29.94,1,3,12.36,-0.52,"""val""",49.26,29.42,"""transformer""",False,32,64,1,0.3,0.0001


: 