In [12]:
import sys
import os
project_root = os.getcwd()  # This will use the current working directory
sys.path.append(os.path.join(project_root, 'Code'))
from utils import *
from data import *
from model import *
from train import *
from validate import *
from visualization import *

In [13]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')

Using device: cuda


## Data

In [3]:
batch_size = 16
resize_size=(128,256)
preprocess_path = './Dataset_preprocess'
base_path = './Dataset'
frame_info = 3
torch.manual_seed(42)
full_dataset = TrackNetDataset(base_path, resize_size=resize_size)
train_size = int(0.8 * len(full_dataset))
val_size = int(0.15 * len(full_dataset))
test_size = len(full_dataset) - train_size - val_size
train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(full_dataset, [train_size, val_size, test_size])
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=0,
    pin_memory=True
)
val_loader = torch.utils.data.DataLoader(
    val_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=0,
    pin_memory=True
)
test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=0,
    pin_memory=True
)

## Train

In [4]:
epoch_num = 50
best_lev_dist = 1500
model_save_name = 'model_best_eca1.pth'
use_eca =True

In [5]:
# Tracknet
gc.collect()
torch.cuda.empty_cache()
model = BallTrackerNet(use_eca=use_eca).to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-2)
criterion = nn.CrossEntropyLoss()
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.95, patience=1, verbose=True, threshold=1e-2)



In [7]:
train_loss_res = []
val_loss_res = []
val_dist_res =[]
precision_res = []
recall_res = []
f1_res = []

In [None]:

for epoch in range(0, epoch_num):
    print("\nEpoch: {}/{}".format(epoch+1, epoch_num))
    # Call train and validate 
    train_loss = train(model, train_loader, optimizer, criterion)
    train_loss_res.append(train_loss)
    val_loss, val_dist, precision, recall, f1 = validate(model, val_loader, criterion, min_dist=5)
    scheduler.step(val_dist)
    print("\nEpoch {}/{}: \t Train Loss {:.04f} ".format(
          epoch + 1,
          epoch_num,
          train_loss
          ))
    print("Val loss {:.04f} \t Val dist {:.04f} \t precision: {:.04f} \t recall: {:.04f}\t f1: {:.04f}".format(
          val_loss, val_dist, precision, recall, f1
          ))
    val_loss_res.append(val_loss)
    val_dist_res.append(val_dist.cpu().item())
    precision_res.append(precision)
    recall_res.append(recall)
    f1_res.append(f1)
    torch.cuda.empty_cache()
    if val_dist <= best_lev_dist:
      best_lev_dist = val_dist
      # Save your model checkpoint here
      print("Saving model")
      torch.save(model.state_dict(), model_save_name)
df_metrics = pd.DataFrame({
"train_loss": train_loss_res,
"val_loss": val_loss_res,
"val_dist": val_dist_res,
"precision": precision_res,
"recall": recall_res,
"f1": f1_res
})
df_metrics.to_csv('training_metrics.csv', index=False)


Epoch: 1/50


                                                                                                                        


Epoch 1/50: 	 Train Loss 1.0466 
Val loss 0.1308 	 Val dist 133.5867 	 precision: 0.0000 	 recall: 0.0000	 f1: 0.0000
Saving model

Epoch: 2/50


                                                                                                                        


Epoch 2/50: 	 Train Loss 0.1095 
Val loss 0.0973 	 Val dist 133.5867 	 precision: 0.0000 	 recall: 0.0000	 f1: 0.0000
Saving model

Epoch: 3/50


                                                                                                                        


Epoch 3/50: 	 Train Loss 0.0936 
Val loss 0.0896 	 Val dist 133.5867 	 precision: 0.0000 	 recall: 0.0000	 f1: 0.0000
Saving model

Epoch: 4/50


                                                                                                                        


Epoch 4/50: 	 Train Loss 0.0881 
Val loss 0.0865 	 Val dist 133.5867 	 precision: 0.0000 	 recall: 0.0000	 f1: 0.0000
Saving model

Epoch: 5/50


                                                                                                                           


Epoch 5/50: 	 Train Loss 0.0682 
Val loss 0.0499 	 Val dist 36.4178 	 precision: 0.9232 	 recall: 0.7419	 f1: 0.8227
Saving model

Epoch: 6/50


                                                                                                                                 


Epoch 6/50: 	 Train Loss 0.0457 
Val loss 1263960.9515 	 Val dist 27.6233 	 precision: 0.8917 	 recall: 0.7922	 f1: 0.8390
Saving model

Epoch: 7/50


                                                                                                                                     


Epoch 7/50: 	 Train Loss 0.0406 
Val loss 90287908272.9099 	 Val dist 10.8571 	 precision: 0.9245 	 recall: 0.9507	 f1: 0.9374
Saving model

Epoch: 8/50


                                                                                                                                 


Epoch 8/50: 	 Train Loss 0.0375 
Val loss 9259071.6910 	 Val dist 14.5827 	 precision: 0.9612 	 recall: 0.9046	 f1: 0.9321

Epoch: 9/50


                                                                                                                                 


Epoch 9/50: 	 Train Loss 0.0354 
Val loss 87617060.3070 	 Val dist 14.2857 	 precision: 0.9648 	 recall: 0.9137	 f1: 0.9385

Epoch: 10/50


                                                                                                                                 


Epoch 10/50: 	 Train Loss 0.0335 
Val loss 46667233.8180 	 Val dist 11.3459 	 precision: 0.9777 	 recall: 0.9295	 f1: 0.9530

Epoch: 11/50


                                                                                                                                   


Epoch 11/50: 	 Train Loss 0.0326 
Val loss 496803009.7009 	 Val dist 10.0403 	 precision: 0.9626 	 recall: 0.9391	 f1: 0.9507
Saving model

Epoch: 12/50


                                                                                                                                  


Epoch 12/50: 	 Train Loss 0.0314 
Val loss 2818339510.5726 	 Val dist 5.2535 	 precision: 0.9575 	 recall: 0.9761	 f1: 0.9667
Saving model

Epoch: 13/50


                                                                                                                                    


Epoch 13/50: 	 Train Loss 0.0305 
Val loss 70273849201.7749 	 Val dist 13.1611 	 precision: 0.9857 	 recall: 0.9108	 f1: 0.9468

Epoch: 14/50


                                                                                                                                     


Epoch 14/50: 	 Train Loss 0.0298 
Val loss 893661334565.2261 	 Val dist 9.1294 	 precision: 0.9775 	 recall: 0.9416	 f1: 0.9592

Epoch: 15/50


                                                                                                                                   


Epoch 15/50: 	 Train Loss 0.0288 
Val loss 10015489730.6821 	 Val dist 6.7534 	 precision: 0.9757 	 recall: 0.9577	 f1: 0.9666

Epoch: 16/50


                                                                                                                                  


Epoch 16/50: 	 Train Loss 0.0283 
Val loss 78071650093.6966 	 Val dist 3.3679 	 precision: 0.9690 	 recall: 0.9869	 f1: 0.9779
Saving model

Epoch: 17/50


                                                                                                                                 


Epoch 17/50: 	 Train Loss 0.0276 
Val loss 4730960617.1850 	 Val dist 3.0158 	 precision: 0.9556 	 recall: 0.9960	 f1: 0.9754
Saving model

Epoch: 18/50


                                                                                                                            


Epoch 18/50: 	 Train Loss 0.0271 
Val loss 33900.4359 	 Val dist 2.9937 	 precision: 0.9698 	 recall: 0.9909	 f1: 0.9802
Saving model

Epoch: 19/50


                                                                                                                                


Epoch 19/50: 	 Train Loss 0.0264 
Val loss 39376334.1991 	 Val dist 6.8974 	 precision: 0.9801 	 recall: 0.9564	 f1: 0.9681

Epoch: 20/50


                                                                                                                              


Epoch 20/50: 	 Train Loss 0.0255 
Val loss 3230013.5889 	 Val dist 2.7408 	 precision: 0.9801 	 recall: 0.9903	 f1: 0.9852
Saving model

Epoch: 21/50


                                                                                                                           


Epoch 21/50: 	 Train Loss 0.0251 
Val loss 3846.3808 	 Val dist 2.0674 	 precision: 0.9794 	 recall: 0.9939	 f1: 0.9866
Saving model

Epoch: 22/50


                                                                                                                           


Epoch 22/50: 	 Train Loss 0.0245 
Val loss 1795.1266 	 Val dist 2.3324 	 precision: 0.9791 	 recall: 0.9932	 f1: 0.9861

Epoch: 23/50


                                                                                                                         


Epoch 23/50: 	 Train Loss 0.0242 
Val loss 55.7928 	 Val dist 2.7502 	 precision: 0.9822 	 recall: 0.9892	 f1: 0.9857

Epoch: 24/50


                                                                                                                            


Epoch 24/50: 	 Train Loss 0.0234 
Val loss 16767.9548 	 Val dist 4.2303 	 precision: 0.9845 	 recall: 0.9774	 f1: 0.9809

Epoch: 25/50


                                                                                                                              


Epoch 25/50: 	 Train Loss 0.0229 
Val loss 1105838.3267 	 Val dist 2.3524 	 precision: 0.9833 	 recall: 0.9932	 f1: 0.9882

Epoch: 26/50


                                                                                                                           


Epoch 26/50: 	 Train Loss 0.0222 
Val loss 47515.5196 	 Val dist 2.2269 	 precision: 0.9750 	 recall: 0.9978	 f1: 0.9863

Epoch: 27/50


                                                                                                                             


Epoch 27/50: 	 Train Loss 0.0216 
Val loss 428356.8543 	 Val dist 1.9537 	 precision: 0.9823 	 recall: 0.9946	 f1: 0.9884
Saving model

Epoch: 28/50


                                                                                                                              


Epoch 28/50: 	 Train Loss 0.0212 
Val loss 930923.1270 	 Val dist 2.3670 	 precision: 0.9784 	 recall: 0.9932	 f1: 0.9857

Epoch: 29/50


                                                                                                                              


Epoch 29/50: 	 Train Loss 0.0207 
Val loss 1143559.4298 	 Val dist 1.7837 	 precision: 0.9809 	 recall: 0.9957	 f1: 0.9882
Saving model

Epoch: 30/50


                                                                                                                           


Epoch 30/50: 	 Train Loss 0.0204 
Val loss 4204.0245 	 Val dist 2.3635 	 precision: 0.9836 	 recall: 0.9918	 f1: 0.9877

Epoch: 31/50


                                                                                                                            


Epoch 31/50: 	 Train Loss 0.0200 
Val loss 18245.2227 	 Val dist 2.2054 	 precision: 0.9830 	 recall: 0.9946	 f1: 0.9888

Epoch: 32/50


                                                                                                                           


Epoch 32/50: 	 Train Loss 0.0194 
Val loss 7416.3854 	 Val dist 3.4900 	 precision: 0.9881 	 recall: 0.9814	 f1: 0.9848

Epoch: 33/50


                                                                                                                            


Epoch 33/50: 	 Train Loss 0.0191 
Val loss 161399.2919 	 Val dist 1.8174 	 precision: 0.9816 	 recall: 0.9971	 f1: 0.9893

Epoch: 34/50


                                                                                                                            


Epoch 34/50: 	 Train Loss 0.0187 
Val loss 13343.1785 	 Val dist 1.9232 	 precision: 0.9844 	 recall: 0.9961	 f1: 0.9902

Epoch: 35/50


                                                                                                                           


Epoch 35/50: 	 Train Loss 0.0181 
Val loss 1080.1534 	 Val dist 2.0630 	 precision: 0.9844 	 recall: 0.9957	 f1: 0.9900

Epoch: 36/50


                                                                                                                           


Epoch 36/50: 	 Train Loss 0.0177 
Val loss 5041.9540 	 Val dist 2.2711 	 precision: 0.9844 	 recall: 0.9921	 f1: 0.9882

Epoch: 37/50


                                                                                                                            


Epoch 37/50: 	 Train Loss 0.0174 
Val loss 18707.8473 	 Val dist 2.1594 	 precision: 0.9847 	 recall: 0.9936	 f1: 0.9891

Epoch: 38/50


                                                                                                                          


Epoch 38/50: 	 Train Loss 0.0168 
Val loss 161.2413 	 Val dist 1.9147 	 precision: 0.9841 	 recall: 0.9957	 f1: 0.9899

Epoch: 39/50


                                                                                                                           


Epoch 39/50: 	 Train Loss 0.0165 
Val loss 2040.7609 	 Val dist 2.0868 	 precision: 0.9826 	 recall: 0.9943	 f1: 0.9884

Epoch: 40/50


                                                                                                                             


Epoch 40/50: 	 Train Loss 0.0163 
Val loss 790191.8917 	 Val dist 2.6113 	 precision: 0.9812 	 recall: 0.9914	 f1: 0.9863

Epoch: 41/50


                                                                                                                             


Epoch 41/50: 	 Train Loss 0.0157 
Val loss 539826.7872 	 Val dist 2.2255 	 precision: 0.9830 	 recall: 0.9939	 f1: 0.9884

Epoch: 42/50


                                                                                                                             


Epoch 42/50: 	 Train Loss 0.0154 
Val loss 255820.6534 	 Val dist 1.9830 	 precision: 0.9827 	 recall: 0.9961	 f1: 0.9893

Epoch: 43/50


                                                                                                                           


Epoch 43/50: 	 Train Loss 0.0151 
Val loss 7270.3653 	 Val dist 1.9771 	 precision: 0.9841 	 recall: 0.9950	 f1: 0.9895

Epoch: 44/50


                                                                                                                            


Epoch 44/50: 	 Train Loss 0.0146 
Val loss 45022.6970 	 Val dist 2.2393 	 precision: 0.9840 	 recall: 0.9918	 f1: 0.9879

Epoch: 45/50


                                                                                                                            


Epoch 45/50: 	 Train Loss 0.0144 
Val loss 61683.2207 	 Val dist 2.1546 	 precision: 0.9847 	 recall: 0.9936	 f1: 0.9891

Epoch: 46/50


                                                                                                                             


Epoch 46/50: 	 Train Loss 0.0139 
Val loss 362226.4944 	 Val dist 1.8531 	 precision: 0.9837 	 recall: 0.9964	 f1: 0.9900

Epoch: 47/50


                                                                                                                            


Epoch 47/50: 	 Train Loss 0.0138 
Val loss 18404.1762 	 Val dist 2.0776 	 precision: 0.9823 	 recall: 0.9932	 f1: 0.9877

Epoch: 48/50


                                                                                                                             


Epoch 48/50: 	 Train Loss 0.0135 
Val loss 426734.1559 	 Val dist 2.1183 	 precision: 0.9851 	 recall: 0.9943	 f1: 0.9897

Epoch: 49/50


                                                                                                                             


Epoch 49/50: 	 Train Loss 0.0132 
Val loss 91230.2658 	 Val dist 2.2529 	 precision: 0.9819 	 recall: 0.9932	 f1: 0.9875

Epoch: 50/50


                                                                                                                             


Epoch 50/50: 	 Train Loss 0.0129 
Val loss 149974.9262 	 Val dist 2.1355 	 precision: 0.9826 	 recall: 0.9939	 f1: 0.9882




In [9]:
from visualization import *
plot_training_metrics(train_loss_res, val_loss_res, val_dist_res, precision_res, recall_res, f1_res, save_path="training_merics_eca1.jpg")


NameError: name 'plot_training_metrics' is not defined

In [15]:
val_dist_res = [a.cpu().item() for a in val_dist_res]

In [16]:
df_metrics = pd.DataFrame({
        "train_loss": train_loss_res,
        "val_loss": val_loss_res,
        "val_dist": val_dist_res,
        "precision": precision_res,
        "recall": recall_res,
        "f1": f1_res
    })
df_metrics.to_csv('training_metrics_eca1.csv', index=False)

## Test

In [None]:
model = BallTrackerNet(use_eca=use_eca).to(device)  # Ensure this matches your model
# Load the model state_dict
model.load_state_dict(torch.load(model_save_name))
model.to(device)
# Set model to evaluation mode
test_loss, test_dist, precision, recall, f1 = validate(model, test_loader, criterion, min_dist=2)
print("Test loss {:.04f} \t Test dist {:.04f} \t precision: {:.04f} \t recall: {:.04f}\t f1: {:.04f}".format(
          test_loss, test_dist, precision, recall, f1
          ))
torch.cuda.empty_cache()

  model.load_state_dict(torch.load(model_save_name))
                                                                                                                 

Val loss 0.0327 	 Val dist 0.8197 	 precision: 0.9874 	 recall: 1.0000	 f1: 0.9936




## Visualization

In [11]:
visualize_predictions(model, test_loader, output_dir="visualizations", device=device)
print("Saved comparison frames to ./visualizations/")

Saved comparison frames to ./visualizations/
