In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import pandas as pd
from PIL import Image
from torch.utils.data import Dataset
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import os
import torch
import math
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm  
import numpy as np
import gc

In [None]:
# Prefer the GPU whenever PyTorch can see one; otherwise fall back to CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f'Using device: {device}')

In [None]:
class PadelNetDataset(Dataset):
    """Sliding-window frame dataset for ball tracking.

    Reads ``Label.csv`` from ``base_path`` (columns used: ``file name``,
    ``visibility``, ``x-coordinate``, ``y-coordinate``). For every row index
    ``i >= frame_info - 1`` it yields the ``frame_info`` consecutive frames
    ending at row ``i`` (concatenated on the channel axis) together with the
    label of row ``i``.

    Args:
        base_path: Directory containing ``Label.csv``; the image names in the
            CSV are resolved relative to this directory.
        frame_info: Number of consecutive frames stacked per sample (>= 1).
        resize_size: ``(H, W)`` every frame is resized to; label coordinates
            are rescaled into this size.
        transform: Optional per-frame transform. Defaults to Resize + ToTensor.
            It must return a ``(C, H, W)`` tensor so frames can be concatenated.

    Raises:
        ValueError: If ``frame_info`` is smaller than 1.
    """

    def __init__(self, base_path, frame_info=3, resize_size=(360, 640), transform=None):
        super().__init__()
        if frame_info < 1:
            raise ValueError(f'frame_info must be >= 1, got {frame_info}')
        self.base_path = base_path
        self.new_H, self.new_W = resize_size
        self.frame_info = frame_info
        self.transform = transform or transforms.Compose([
            transforms.Resize((resize_size[0], resize_size[1])),
            transforms.ToTensor()
        ])

        # Read the label file once and index it once. (It used to be re-read and
        # re-indexed for every *.jpg in base_path, duplicating each sample N times.)
        self.label = pd.read_csv(os.path.join(base_path, 'Label.csv'))
        # A window ending at row idx needs rows idx-frame_info+1 .. idx, so the
        # first usable index is frame_info-1; short CSVs yield no samples.
        self.data = [
            {'clip_path': base_path, 'label_df': self.label, 'center_idx': idx}
            for idx in range(frame_info - 1, len(self.label))
        ]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(imgs, target)``.

        ``imgs`` is the channel-wise concatenation of the ``frame_info`` frames
        ending at the sample's row. ``target`` is a float32 tensor
        ``[visibility, x, y]`` with x/y rescaled to ``resize_size``, or -1/-1
        when the CSV coordinates are NaN.
        """
        entry = self.data[idx]
        label_df = entry['label_df']
        center_idx = entry['center_idx']

        frames = []
        for i in range(center_idx - self.frame_info + 1, center_idx + 1):
            img_path = os.path.join(self.base_path, label_df.iloc[i]['file name'])
            # `with` closes the file handle deterministically; convert() returns
            # an independent, fully loaded image.
            with Image.open(img_path) as raw:
                img = raw.convert('RGB')
            if i == center_idx:
                # Original size of the labelled frame, used to rescale its coordinates.
                orig_W, orig_H = img.size
            frames.append(self.transform(img) if self.transform else img)
        imgs = torch.cat(frames, dim=0)

        row = label_df.iloc[center_idx]
        visibility = row['visibility']
        x = row['x-coordinate']
        y = row['y-coordinate']
        if math.isnan(x) or math.isnan(y):
            x_resized, y_resized = -1, -1
        else:
            x_resized = x * (self.new_W / orig_W)
            y_resized = y * (self.new_H / orig_H)

        target = torch.tensor([visibility, x_resized, y_resized], dtype=torch.float32)
        return imgs, target


In [None]:
# Evaluation settings.
test_path = './padel_dataset/final_testing'
resize_size = (128, 256)  # (H, W) frames are resized to; divisible by 8 to match the network's 3 pool/upsample stages
batch_size, frame_info = 16, 3

In [None]:
test_dataset = PadelNetDataset(test_path, frame_info=frame_info, resize_size=resize_size)  # use the configured frame stack size

In [None]:
test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=batch_size,               # honour the configured batch size
    shuffle=False,                       # deterministic evaluation order
    num_workers=0,
    pin_memory=(device.type == 'cuda'),  # pinning only speeds up host->GPU copies
)

In [None]:
import torch.nn as nn
import torch
class ConvBlock(nn.Module):
    """Conv2d followed by ReLU, then BatchNorm2d (post-activation normalisation).

    The layers live in ``self.block`` at indices 0 (conv), 1 (ReLU) and 2 (BN);
    these indices form the ``state_dict`` keys, so their order must not change.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, pad=1, stride=1, bias=True):
        super().__init__()
        conv = nn.Conv2d(in_channels, out_channels, kernel_size,
                         stride=stride, padding=pad, bias=bias)
        stages = [conv, nn.ReLU(), nn.BatchNorm2d(out_channels)]
        self.block = nn.Sequential(*stages)

    def forward(self, x):
        """Apply conv -> ReLU -> BN to ``x``."""
        out = self.block(x)
        return out

class BallTrackerNet(nn.Module):
    """Encoder/decoder network mapping stacked frames to per-pixel class logits.

    The input has ``frame_info * 3`` channels (``frame_info`` RGB frames stacked).
    A VGG16-style encoder halves the resolution three times (3 max-pools). A
    decoder doubles it three times (3 upsamples). The output therefore has
    ``out_channels`` channels and, when H and W are divisible by 8, the input's
    spatial size.

    The attribute names ``VGG16`` / ``deconvnet`` and the layer order inside
    them determine the ``state_dict`` keys of saved checkpoints.
    """

    def __init__(self, frame_info=3, out_channels=256):
        super().__init__()
        self.out_channels = out_channels

        # Layer plans: a (c_in, c_out) pair is a ConvBlock, 'pool' a 2x2/stride-2
        # max-pool, 'up' an nn.Upsample with scale_factor=2.
        encoder_plan = [
            (frame_info * 3, 64), (64, 64), 'pool',
            (64, 128), (128, 128), 'pool',
            (128, 256), (256, 256), (256, 256), 'pool',
            (256, 512), (512, 512), (512, 512),
        ]
        decoder_plan = [
            'up', (512, 256), (256, 256), (256, 256),
            'up', (256, 128), (128, 128),
            'up', (128, 64), (64, 64), (64, out_channels),
        ]

        def build(plan):
            layers = []
            for step in plan:
                if step == 'pool':
                    layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
                elif step == 'up':
                    layers.append(nn.Upsample(scale_factor=2))
                else:
                    c_in, c_out = step
                    layers.append(ConvBlock(in_channels=c_in, out_channels=c_out))
            return nn.Sequential(*layers)

        self.VGG16 = build(encoder_plan)      # feature extractor
        self.deconvnet = build(decoder_plan)  # upsampling head
        self._init_weights()

    def forward(self, x):
        """Return logits of shape (B, out_channels, H', W') for input ``x``."""
        features = self.VGG16(x)
        return self.deconvnet(features)

    def _init_weights(self):
        """Conv weights ~ U(-0.05, 0.05), conv biases = 0; BatchNorm scale = 1, shift = 0."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.uniform_(m.weight, -0.05, 0.05)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)

In [None]:
def generate_heatmap(targets, H=360, W=640, sigma2=10):
    """Render Gaussian ground-truth heatmaps as integer class maps.

    Args:
        targets: (B, 3) tensor of (visibility, x, y); x and y are pixel
            coordinates in the H x W heatmap.
        H, W: Heatmap height and width in pixels.
        sigma2: Gaussian variance (sigma squared), in pixels^2.

    Returns:
        (B, 1, H, W) long tensor on ``targets.device`` with values in [0, 255].
        Each visible target (visibility != 0) gets a Gaussian normalised so its
        brightest pixel is 255. Invisible targets, and targets whose Gaussian
        underflows to zero everywhere on the grid, get all zeros.
    """
    targets = torch.as_tensor(targets)
    dev = targets.device
    visible = (targets[:, 0] != 0).view(-1, 1, 1)
    # float32 keeps the arithmetic identical to a float32 pixel grid.
    cx = targets[:, 1].float().view(-1, 1, 1)
    cy = targets[:, 2].float().view(-1, 1, 1)

    # One pixel grid for the whole batch, built on the targets' own device.
    yy, xx = torch.meshgrid(
        torch.arange(H, device=dev, dtype=torch.float32),
        torch.arange(W, device=dev, dtype=torch.float32),
        indexing='ij',
    )
    gauss = torch.exp(-((xx - cx) ** 2 + (yy - cy) ** 2) / (2 * sigma2))  # (B, H, W)
    peak = gauss.amax(dim=(1, 2), keepdim=True)
    # A zero peak would make gauss / peak = 0/0 = NaN, which .long() turns into
    # garbage class ids; such rows are left empty instead.
    keep = visible & (peak > 0)
    heat = torch.where(keep, gauss / peak, torch.zeros_like(gauss))
    return (heat.unsqueeze(1) * 255).long()

In [None]:
import torch.nn.functional as F
import cv2
from scipy.spatial import distance
def postprocess_heatmap(heatmap, scale=1):
    """Locate a single ball centre in ``heatmap`` with Hough circle detection.

    The heatmap is scaled by 255, cast to uint8, binarised at 127 and passed to
    ``cv2.HoughCircles`` (radius 2..7 px). A position is returned only when
    exactly one circle is detected.

    NOTE(review): assumes ``heatmap`` is a numpy array with values in [0, 1];
    values that exceed 255 after scaling do not survive the uint8 cast
    meaningfully — confirm against callers.

    Returns:
        ``(x, y)`` as ints multiplied by ``scale``, or ``(None, None)``.
    """
    as_bytes = (heatmap * 255).astype(np.uint8)
    _, mask = cv2.threshold(as_bytes, 127, 255, cv2.THRESH_BINARY)
    found = cv2.HoughCircles(mask, cv2.HOUGH_GRADIENT, dp=1, minDist=1,
                             param1=50, param2=2, minRadius=2, maxRadius=7)
    if found is None or len(found[0]) != 1:
        return None, None
    cx, cy = found[0][0][:2]
    return int(cx * scale), int(cy * scale)

def validate(model, val_loader, criterion, min_dist=5):
    """Run ``model`` over ``val_loader`` and report loss and detection metrics.

    For each sample, the predicted ball position is the first pixel holding the
    highest predicted heatmap class. A prediction "exists" when that class is
    greater than 0. Per-sample counting:
      * TP: prediction exists, ball visible, distance to GT < ``min_dist``
      * FP: prediction exists and (ball not visible or distance >= ``min_dist``)
      * FN: no prediction, ball visible
      * TN: no prediction, ball not visible

    Args:
        model: Network returning (B, C, H, W) per-pixel class logits.
        val_loader: Yields ``(images, targets)``, where ``targets`` is (B, 3) =
            (visibility, x, y) in input-pixel coordinates.
        criterion: Called as ``criterion(outputs, gt)``, where ``gt`` is the
            (B, H, W) long class map built by ``generate_heatmap``.
        min_dist: Pixel radius within which a prediction counts as correct.

    Returns:
        ``(mean_loss, precision, recall, f1)``. ``mean_loss`` is the mean of the
        per-batch losses (NaN if the loader yields nothing).
    """
    model.eval()
    # Follow the model's device rather than relying on a notebook global.
    run_device = next(model.parameters()).device
    losses = []
    tp = fp = fn = tn = 0
    with torch.no_grad():
        for images, targets in tqdm(val_loader, desc='Validation', leave=False, dynamic_ncols=True):
            images = images.to(run_device)
            targets = targets.to(run_device)          # (B, 3)
            vis_gt, x_gt, y_gt = targets[:, 0], targets[:, 1], targets[:, 2]

            outputs = model(images)                   # (B, C, H, W) logits
            B, _, H, W = outputs.shape
            # Build GT at the output resolution so loss shapes always agree.
            gt = generate_heatmap(targets, H=H, W=W).squeeze(1)  # (B, H, W)
            losses.append(criterion(outputs, gt).item())

            # softmax is monotonic, so argmax of the raw logits gives the same class.
            pred_flat = outputs.argmax(dim=1).view(B, -1)        # (B, H*W)
            peak_idx = pred_flat.argmax(dim=1)                   # (B,)
            y_pred = (peak_idx // W).float()
            x_pred = (peak_idx % W).float()
            dist = torch.sqrt((x_pred - x_gt) ** 2 + (y_pred - y_gt) ** 2)

            pred_exists = pred_flat.amax(dim=1) > 0
            gt_exists = vis_gt != 0
            tp += (pred_exists & gt_exists & (dist < min_dist)).sum().item()
            fp += (pred_exists & (~gt_exists | (dist >= min_dist))).sum().item()
            fn += (~pred_exists & gt_exists).sum().item()
            tn += (~pred_exists & ~gt_exists).sum().item()

    eps = 1e-15  # avoids 0/0 when there are no positives
    precision = tp / (tp + fp + eps)
    recall = tp / (tp + fn + eps)
    f1 = 2 * precision * recall / (precision + recall + eps)

    print(f'   Precision = {precision:.4f}')
    print(f'   Recall    = {recall:.4f}')
    print(f'   F1 Score  = {f1:.4f}')
    mean_loss = float(np.mean(losses)) if losses else float('nan')
    return mean_loss, precision, recall, f1

In [None]:
# Per-pixel classification of heatmap intensity (class ids 0..255 from generate_heatmap).
criterion = torch.nn.CrossEntropyLoss()

In [None]:
# Build the network with the same frame stacking as the dataset, then load the best checkpoint.
model = BallTrackerNet(frame_info=frame_info)
# map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
state_dict = torch.load('model_best.pth', map_location=device, weights_only=True)
model.load_state_dict(state_dict)
model = model.to(device)  # assignment avoids dumping the full model repr as cell output

In [None]:
# NOTE: the first value returned by validate() is the mean test loss, not a distance;
# the variable name is kept for any later cells that reference it.
test_dist, precision, recall, f1 = validate(model, test_loader, criterion)
print("Test loss {:.04f} \t precision: {:.04f} \t recall: {:.04f}\t f1: {:.04f}".format(
    test_dist, precision, recall, f1
))