# Evaluation Results

In This notebook the results on the test set are generated

### imports

In [1]:
import os
import torch.nn as nn
import torch
import torch.optim as optim
import pandas as pd
from torch.optim.lr_scheduler import ReduceLROnPlateau
import random
import numpy as np
from tqdm import tqdm
from DatasetEval import BallDataSet
from BallTrackNet import BallTrackNet
from helpers import train, validate, WeightedBCELoss, BinaryFocalLoss
import cv2
import torch.nn.functional as F
import time
from scipy.spatial import distance
import json
import math
import time

### functions

In [2]:
def custom_collate(batch):
    # Unpack batch into respective components
    images, heatmaps, points, dataframes = zip(*batch)
    
    # Stack tensors for images, heatmaps, points, and indices
    images = torch.stack(images)
    heatmaps = torch.stack(heatmaps)
    points = torch.stack(points)
    
    return images, heatmaps, points, list(dataframes)

def create_dataloader(path, heatmap_size, input_number, output_number, batch_size):
    
    test = BallDataSet(path=path, split='test', heatmap_size=heatmap_size, input_size=input_number, output_size=output_number, recreate_heatmap=False)

    test_loader = torch.utils.data.DataLoader(
        test,
        batch_size=batch_size,
        shuffle=False,
        num_workers=8,
        pin_memory=True,
        prefetch_factor=4,
        collate_fn=custom_collate
    )

    return test, test_loader

def load_model(path, input_size, output_size, gpus):
    model = BallTrackNet(input_size=input_size, output_size=output_size)
    model = nn.DataParallel(model, device_ids=gpus)
    model = model.to(f'cuda:{gpus[0]}')
    model.load_state_dict(torch.load(path))
    model.eval()
    return model

In [3]:
def calculate_position(output, threshold):
    # Binarize the heatmap
    binary_map = (output > threshold).cpu().numpy().astype(np.uint8)

    # Perform connected component analysis
    num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(binary_map, connectivity=8)

    if num_labels > 1:  # Exclude background label 0
        # Find the largest region (excluding background label 0)
        largest_label = 1 + np.argmax(stats[1:, cv2.CC_STAT_AREA])  # Index 1 is the first component
        largest_centroid = centroids[largest_label]
        return int(largest_centroid[0]), int(largest_centroid[1])
    else:
        # No region found
        return -1, -1
    

def calculate_frame_metrics(predicted_point, ground_truth_point, min_dist=4):
    # Check if both predicted and ground truth points are valid
    if predicted_point != (-1, -1) and ground_truth_point != (-1, -1):
        # Calculate distance
        if distance.euclidean(predicted_point, ground_truth_point) < min_dist:
            return ("tp")               # True Positive: Prediction matches ground truth
        else:
            return ("fp")               # False Positive: Prediction exists but doesn't match ground truth
    elif predicted_point != (-1, -1):
        return ("fp")                   # False Positive: Prediction exists but no ground truth
    elif ground_truth_point != (-1, -1):
        return ("fn")                   # False Negative: Ground truth exists but no prediction
    else:
        return ("tn")                   # True Negative: Neither prediction nor ground truth exists


def calculate_metrics(stats):
    tp = stats.get("tp", 0)
    tn = stats.get("tn", 0)
    fp = stats.get("fp", 0)
    fn = stats.get("fn", 0)

    # Avoid division by zero with if 
    accuracy = (tp + tn) / (tp + tn + fp + fn) if (tp + tn + fp + fn) > 0 else 0
    precision = tp / (tp + fp) if (tp + fp) > 0 else 0
    recall = tp / (tp + fn ) if (tp + fn) > 0 else 0
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0

    return accuracy, precision, recall, f1


def tensor_to_point(gt_point):
    # Convert tensor([x], [y]) to (x, y)
    if isinstance(gt_point, list) or isinstance(gt_point, tuple):
        return tuple(p.item() for p in gt_point)
    elif isinstance(gt_point, torch.Tensor):
        return tuple(gt_point.tolist())
    else:
        raise ValueError("Unsupported ground truth format")


def validate(model, test_loader, device, min_dist=4, threshold=0.5, criterion=WeightedBCELoss()):
    losses = []
    # Create a dictionary for tp, tn, fp, fn
    stats = {"tp": 0, "tn": 0, "fp": 0, "fn": 0}

    # Store predictions and ground truth for each frame
    frame_results = []

    model.eval()

    # Wrap val_loader with tqdm for a progress bar
    with tqdm(total=len(test_loader), desc="Validation", unit="batch") as pbar:
        for iter_id, batch in enumerate(test_loader):
            with torch.no_grad():
                inputs = batch[0].float().to(device)  # Input frames
                heatmap = batch[1].float().to(device)  # Ground truth heatmap
                points = batch[2]  # Ground truth ball positions
                dataframes_batch = batch[3]

                # Forward pass
                outputs = model(inputs)
                loss = criterion(outputs, heatmap)
                losses.append(loss.item())

                for batch_idx in range(outputs.shape[0]):
                    dataframe_entries = dataframes_batch[batch_idx].reset_index(drop=True)
                    elements_to_skip = outputs[batch_idx].shape[0]- len(dataframe_entries)
                    counter = 0
                    for frame_idx in range(outputs[batch_idx].shape[0]):
                        
                        # Process each batch separately
                        output = outputs[batch_idx][frame_idx]
                        gt_point = points[batch_idx][frame_idx]

                        #  get correct entry
                        if elements_to_skip > 0:
                            entry = None
                            elements_to_skip -= 1
                        else:
                            entry = dataframe_entries.iloc[counter]
                            counter += 1

                        gt_point_formatted = tensor_to_point(gt_point) 

                        # Calculate predicted position for the frame
                        predicted_position = calculate_position(output, threshold)

                        case = calculate_frame_metrics(predicted_position, gt_point_formatted, min_dist)

                        # Update stats
                        stats[case] += 1

                        if entry is not None:
                            # Append results for this frame
                            frame_results.append({
                                'subset': entry['subset'],
                                'video': entry['video'],
                                'clip': entry['clip'],
                                'frame': entry['frame'],
                                'points': entry['points'],
                                'window_index': entry['window_index'],
                                "predicted_position": predicted_position,
                                "ground_truth_position": gt_point_formatted
                            })

                # Update the tqdm bar
                pbar.set_postfix({'loss': round(np.mean(losses), 6)})
                pbar.update(1)

    # Calculate metrics
    accuracy, precision, recall, f1 = calculate_metrics(stats)

    return np.mean(losses), accuracy, precision, recall, f1, frame_results

def calculate_computation_time(test_loader, device, model):
    model.eval()  # Set the model to evaluation mode

    start_time = time.time()  # Record the start time
    
    with torch.no_grad():
        for batch in test_loader:
            inputs = batch[0].float().to(device)  # Input frames
            outputs = model(inputs)  # Perform inference
    
    end_time = time.time()  # Record the end time

    elapsed_time = end_time - start_time  # Calculate elapsed time
    return elapsed_time

## Evaluate models

In [4]:
gpus=[0]

## Create Results DF

In [10]:
def get_Test_Results(gpus, input_size = 3, output_size=3, augmented = True):
    dataset, test_dataloader = create_dataloader(path= "../FinalDataset", heatmap_size = 10, input_number=input_size, output_number=output_size, batch_size=32)
    
    path = f"exps/TrackNet_{input_size}-in-{output_size}-out_aug_{augmented}/model_best.pt"
    
    model = load_model(path=path, input_size=input_size, output_size=output_size, gpus=gpus)
    
    test_loss, accuracy, precision, recall, f1, frame_results = validate(model, test_dataloader, device=gpus[0])

    comp_time = calculate_computation_time(test_loader=test_dataloader , device=gpus[0], model=model)
    fps = 2877 / comp_time

    return test_loss, accuracy, precision, recall, f1, comp_time, fps, frame_results

In [11]:
versions = [
        {"augmentation": True, "input_number": 3, "output_number": 3},
        {"augmentation": False, "input_number": 3, "output_number": 3},
        {"augmentation": True, "input_number": 3, "output_number": 1},
        {"augmentation": False, "input_number": 3, "output_number": 1},
        {"augmentation": True, "input_number": 5, "output_number": 3},
        {"augmentation": False, "input_number": 5, "output_number": 3},
        {"augmentation": True, "input_number": 5, "output_number": 5},
        {"augmentation": False, "input_number": 5, "output_number": 5}
    ]

all_results = []
for version in versions:
    test_loss, accuracy, precision, recall, f1, comp_time, fps, frame_results = get_Test_Results(input_size = version["input_number"], output_size=version["output_number"], augmented = version["augmentation"], gpus = gpus)

    # create object for results
    results = {
        "version": version,
        "test_loss": test_loss,
        "accuracy": accuracy,
        "precision": precision,
        "recall": recall,
        "f1": f1,
        "time": comp_time,
        "fps": fps,
        "frame_results": frame_results
    }

    all_results.append(results)
    
df_results = pd.DataFrame(all_results)

  model.load_state_dict(torch.load(path))


Windows: 961


Validation: 100%|██████████| 31/31 [00:12<00:00,  2.53batch/s, loss=4.75e+3]
  model.load_state_dict(torch.load(path))


Windows: 961


Validation: 100%|██████████| 31/31 [00:10<00:00,  2.90batch/s, loss=4.16e+3]
  model.load_state_dict(torch.load(path))


Windows: 2877


Validation: 100%|██████████| 90/90 [00:24<00:00,  3.75batch/s, loss=1.8e+3] 
  model.load_state_dict(torch.load(path))


Windows: 2877


Validation: 100%|██████████| 90/90 [00:23<00:00,  3.79batch/s, loss=2.29e+3]
  model.load_state_dict(torch.load(path))


Windows: 961


Validation: 100%|██████████| 31/31 [00:14<00:00,  2.13batch/s, loss=4.22e+3]
  model.load_state_dict(torch.load(path))


Windows: 961


Validation: 100%|██████████| 31/31 [00:11<00:00,  2.78batch/s, loss=4.89e+3]
  model.load_state_dict(torch.load(path))


Windows: 580


Validation: 100%|██████████| 19/19 [00:08<00:00,  2.17batch/s, loss=8.5e+3] 
  model.load_state_dict(torch.load(path))


Windows: 580


Validation: 100%|██████████| 19/19 [00:08<00:00,  2.18batch/s, loss=6.63e+3]


In [12]:
df_results

Unnamed: 0,version,test_loss,accuracy,precision,recall,f1,time,fps,frame_results
0,"{'augmentation': True, 'input_number': 3, 'out...",4753.948033,0.872008,0.942683,0.901578,0.921673,8.694051,330.91595,"[{'subset': 'New', 'video': 'Video_1', 'clip':..."
1,"{'augmentation': False, 'input_number': 3, 'ou...",4161.026125,0.851197,0.923313,0.892418,0.907603,8.86119,324.674231,"[{'subset': 'New', 'video': 'Video_1', 'clip':..."
2,"{'augmentation': True, 'input_number': 3, 'out...",1795.755328,0.861661,0.946661,0.883312,0.91389,21.28565,135.161485,"[{'subset': 'New', 'video': 'Video_1', 'clip':..."
3,"{'augmentation': False, 'input_number': 3, 'ou...",2292.572355,0.808481,0.901713,0.857265,0.878928,21.145362,136.058206,"[{'subset': 'New', 'video': 'Video_1', 'clip':..."
4,"{'augmentation': True, 'input_number': 5, 'out...",4217.184523,0.829344,0.9058,0.882703,0.894102,9.226051,311.834388,"[{'subset': 'New', 'video': 'Video_1', 'clip':..."
5,"{'augmentation': False, 'input_number': 5, 'ou...",4885.998479,0.865765,0.946089,0.889859,0.917113,9.441132,304.730405,"[{'subset': 'New', 'video': 'Video_1', 'clip':..."
6,"{'augmentation': True, 'input_number': 5, 'out...",8495.575709,0.88069,0.9574,0.897489,0.926477,6.4216,448.019195,"[{'subset': 'New', 'video': 'Video_1', 'clip':..."
7,"{'augmentation': False, 'input_number': 5, 'ou...",6633.399764,0.862414,0.928326,0.903132,0.915556,6.857523,419.539229,"[{'subset': 'New', 'video': 'Video_1', 'clip':..."


In [None]:
df_results.to_csv("../results/Pre_Results_TN.csv", index=False)