In [3]:
import sys
import os
project_root = os.getcwd()  # This will use the current working directory
sys.path.append(os.path.join(project_root, 'Code'))
from utils import *
from data import *
from model import *
from train import *
from validate import *
from visualization import *

In [4]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')

Using device: cuda


## Data

In [5]:
batch_size = 16
resize_size=(128,256)
preprocess_path = './Dataset_preprocess'
base_path = './Dataset'
frame_info = 3

full_dataset = TrackNetDataset(base_path, resize_size=resize_size)
train_size = int(0.8 * len(full_dataset))
val_size = int(0.15 * len(full_dataset))
test_size = len(full_dataset) - train_size - val_size
train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(full_dataset, [train_size, val_size, test_size])
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=0,
    pin_memory=True
)
val_loader = torch.utils.data.DataLoader(
    val_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=0,
    pin_memory=True
)
test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=0,
    pin_memory=True
)

## Train

In [6]:
epoch_num = 50
best_lev_dist = 1500
model_save_name = 'model_best_eca.pth'
use_eca =True

In [7]:
# Tracknet
gc.collect()
torch.cuda.empty_cache()
model = BallTrackerNet(use_eca=use_eca).to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-2)
criterion = nn.CrossEntropyLoss()
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.95, patience=1, verbose=True, threshold=1e-2)



In [None]:
for epoch in range(0, epoch_num):
    print("\nEpoch: {}/{}".format(epoch+1, epoch_num))
    # Call train and validate 
    train_loss = train(model, train_loader, optimizer, criterion)
    val_loss, val_dist, precision, recall, f1 = validate(model, val_loader, criterion, min_dist=2)
    scheduler.step(val_dist)
    print("\nEpoch {}/{}: \t Train Loss {:.04f} ".format(
          epoch + 1,
          epoch_num,
          train_loss
          ))
    print("Val loss {:.04f} \t Val dist {:.04f} \t precision: {:.04f} \t recall: {:.04f}\t f1: {:.04f}".format(
          val_loss, val_dist, precision, recall, f1
          ))
    torch.cuda.empty_cache()
    if val_dist <= best_lev_dist:
        best_lev_dist = val_dist
        # Save your model checkpoint here
        print("Saving model")
        torch.save(model.state_dict(), model_save_name)

## Test

In [10]:
model = BallTrackerNet(use_eca=use_eca).to(device)  # Ensure this matches your model
# Load the model state_dict
model.load_state_dict(torch.load(model_save_name))
model.to(device)
# Set model to evaluation mode
test_loss, test_dist, precision, recall, f1 = validate(model, test_loader, criterion, min_dist=2)
print("Val loss {:.04f} \t Val dist {:.04f} \t precision: {:.04f} \t recall: {:.04f}\t f1: {:.04f}".format(
          test_loss, test_dist, precision, recall, f1
          ))
torch.cuda.empty_cache()

  model.load_state_dict(torch.load(model_save_name))
                                                                                                                 

Val loss 0.0327 	 Val dist 0.8197 	 precision: 0.9874 	 recall: 1.0000	 f1: 0.9936




## Visualization

In [11]:
visualize_predictions(model, test_loader, output_dir="visualizations", device=device)
print("Saved comparison frames to ./visualizations/")

Saved comparison frames to ./visualizations/
