# TrackNet Tennis Ball Tracking Inference

This notebook performs tennis ball tracking inference using a pre-trained TrackNet model. It processes video clips to detect and track tennis ball positions, generating both prediction data and visualization videos.

## Overview
- Load pre-trained TrackNet model
- Process video frames to detect ball positions
- Generate tracking predictions and trajectory visualizations
- Export results to CSV and video files

## Import Required Libraries

Import all necessary Python libraries for deep learning, computer vision, and data processing:
- **PyTorch**: For neural network model loading and inference
- **OpenCV**: For video processing and image manipulation
- **NumPy**: For numerical computations
- **Pandas**: For data handling and CSV export
- **collections**: For the fixed-length frame buffer (`deque`)
- **pathlib**: For platform-independent file paths
- **tqdm**: For progress bars during processing

In [11]:
import torch
import torch.nn as nn
import cv2
import numpy as np
from pathlib import Path
import collections
from tqdm import tqdm
import pandas as pd

## Configuration Setup

Configure all inference parameters including:
- **Model paths**: Location of pre-trained TrackNet weights
- **Input video**: Source video file for ball tracking
- **Model dimensions**: Input resolution (640x360) for processing
- **Output settings**: Paths for generated videos and CSV predictions
- **Visualization parameters**: Colors, sizes, and thickness for ball detection display

In [None]:
class InferenceConfig:
    # Path Configuration
    MODEL_PATH = Path('models/model_best.pt')
    VIDEO_CLIP_PATH = Path('videos/Clip1.mp4') 

    # Model & Input
    INPUT_WIDTH = 640
    INPUT_HEIGHT = 360
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

    # Output paths
    OUTPUT_VIDEO_PATH = Path('videos/Clip1_tracknet_track.mp4')
    OUTPUT_TRACKING_VIDEO_PATH = Path('videos/Clip1_tracknet_trajectory.mp4')
    OUTPUT_CSV_PATH = Path('clip1_tracknet_predictions.csv')

    # Ball marker (OpenCV colors are BGR, so (0, 0, 255) is red)
    CIRCLE_COLOR = (0, 0, 255)
    CIRCLE_RADIUS = 5
    CIRCLE_THICKNESS = 3

    # Trajectory trail (also red in BGR)
    TRAIL_COLOR = (0, 0, 255)
    TRAIL_THICKNESS = 3

# Create an instance of the config
config = InferenceConfig()
print(f"Using device: {config.DEVICE}")
print(f"Loading model from: {config.MODEL_PATH}")
print(f"Processing video clip starting at: {config.VIDEO_CLIP_PATH}")
print(f"Output will be saved to: {config.OUTPUT_VIDEO_PATH}")

Using device: cuda
Loading model from: models\model_best.pt
Processing video clip starting at: videos\Clip1.mp4
Output will be saved to: videos\Clip1_tracknet_track.mp4


## TrackNet Model Architecture

Define the TrackNet neural network architecture and helper functions:

### Components:
- **ConvBlock**: Basic convolutional block with Conv2D, ReLU, and BatchNorm
- **TrackNet**: Complete encoder-decoder architecture
  - **Encoder**: VGG16-style convolutional layers with max pooling
  - **Decoder**: Upsampling layers to reconstruct heatmap predictions
- **postprocess()**: Converts model heatmap output to ball coordinates using circle detection

The model takes 3 consecutive frames (9 channels) as input and outputs a heatmap indicating ball position.
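
As a rough illustration of the 9-channel input, the sketch below stacks three RGB frames along the channel axis with NumPy. The frames are synthetic arrays and `stack_frames` is a hypothetical helper name used only for this example; the notebook itself builds the tensor with PyTorch in the inference cell.

```python
import numpy as np

def stack_frames(frames):
    """Stack three HxWx3 uint8 frames into one (9, H, W) float32 array in [0, 1]."""
    if len(frames) != 3:
        raise ValueError("expected exactly 3 frames")
    # Convert each frame from HWC to CHW, scale to [0, 1], then join the channels
    chw = [f.astype(np.float32).transpose(2, 0, 1) / 255.0 for f in frames]
    return np.concatenate(chw, axis=0)

# Synthetic frames at the model input resolution (360 rows x 640 columns)
frames = [np.full((360, 640, 3), v, dtype=np.uint8) for v in (0, 128, 255)]
stacked = stack_frames(frames)
print(stacked.shape)  # (9, 360, 640)
```

Channels 0-2 come from the oldest frame and channels 6-8 from the newest, matching the order in which frames enter the buffer.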

In [None]:
class ConvBlock(nn.Module):
    """
    Basic convolutional block with Conv2D + ReLU + BatchNorm.
    
    Args:
        in_channels (int): Number of input channels
        out_channels (int): Number of output channels
        kernel_size (int): Convolution kernel size (default: 3)
        pad (int): Padding size (default: 1)
        stride (int): Convolution stride (default: 1)
        bias (bool): Whether to use bias in convolution (default: True)
    """
    def __init__(self, in_channels, out_channels, kernel_size=3, pad=1, stride=1, bias=True):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, padding=pad, bias=bias),
            nn.ReLU(),
            nn.BatchNorm2d(out_channels)
        )

    def forward(self, x):
        return self.block(x)

class TrackNet(nn.Module):
    """
    TrackNet model for tennis ball tracking using encoder-decoder architecture.
    
    Takes 3 consecutive frames (9 channels) as input and outputs a 256-class
    heatmap for ball localization. Architecture follows VGG16-style encoder
    with DeconvNet-style decoder for pixel-level classification.
    
    Args:
        out_channels (int): Number of output classes (default: 256)
    """
    def __init__(self, out_channels=256):
        super().__init__()
        self.out_channels = out_channels

        # Encoder: VGG16-style feature extraction
        # Block 1: 9 -> 64 channels, spatial size: 640x360 -> 320x180
        self.conv1 = ConvBlock(in_channels=9, out_channels=64)
        self.conv2 = ConvBlock(in_channels=64, out_channels=64)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Block 2: 64 -> 128 channels, spatial size: 320x180 -> 160x90
        self.conv3 = ConvBlock(in_channels=64, out_channels=128)
        self.conv4 = ConvBlock(in_channels=128, out_channels=128)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Block 3: 128 -> 256 channels, spatial size: 160x90 -> 80x45
        self.conv5 = ConvBlock(in_channels=128, out_channels=256)
        self.conv6 = ConvBlock(in_channels=256, out_channels=256)
        self.conv7 = ConvBlock(in_channels=256, out_channels=256)
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Block 4: 256 -> 512 channels (bottleneck), spatial size: 80x45
        self.conv8 = ConvBlock(in_channels=256, out_channels=512)
        self.conv9 = ConvBlock(in_channels=512, out_channels=512)
        self.conv10 = ConvBlock(in_channels=512, out_channels=512)

        # Decoder: DeconvNet-style upsampling
        # Upsample 1: 80x45 -> 160x90, 512 -> 256 channels
        self.ups1 = nn.Upsample(scale_factor=2)
        self.conv11 = ConvBlock(in_channels=512, out_channels=256)
        self.conv12 = ConvBlock(in_channels=256, out_channels=256)
        self.conv13 = ConvBlock(in_channels=256, out_channels=256)

        # Upsample 2: 160x90 -> 320x180, 256 -> 128 channels
        self.ups2 = nn.Upsample(scale_factor=2)
        self.conv14 = ConvBlock(in_channels=256, out_channels=128)
        self.conv15 = ConvBlock(in_channels=128, out_channels=128)

        # Upsample 3: 320x180 -> 640x360, 128 -> 64 -> out_channels
        self.ups3 = nn.Upsample(scale_factor=2)
        self.conv16 = ConvBlock(in_channels=128, out_channels=64)
        self.conv17 = ConvBlock(in_channels=64, out_channels=64)
        self.conv18 = ConvBlock(in_channels=64, out_channels=self.out_channels)

        # Not applied in forward(); argmax over the raw outputs selects the same class
        self.softmax = nn.Softmax(dim=1)
        self._init_weights()
                  
    def forward(self, x):
        """
        Forward pass through TrackNet.
        
        Args:
            x (torch.Tensor): Input tensor of shape (N, 9, H, W)
            
        Returns:
            torch.Tensor: Output heatmap of shape (N, 256, H, W)
        """
        # Encoder path
        x = self.conv1(x)
        x = self.conv2(x)    
        x = self.pool1(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = self.pool2(x)
        x = self.conv5(x)
        x = self.conv6(x)
        x = self.conv7(x)
        x = self.pool3(x)
        x = self.conv8(x)
        x = self.conv9(x)
        x = self.conv10(x)
        
        # Decoder path
        x = self.ups1(x)
        x = self.conv11(x)
        x = self.conv12(x)
        x = self.conv13(x)
        x = self.ups2(x)
        x = self.conv14(x)
        x = self.conv15(x)
        x = self.ups3(x)
        x = self.conv16(x)
        x = self.conv17(x)
        x = self.conv18(x)

        return x
    
    def _init_weights(self):
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.uniform_(module.weight, -0.05, 0.05)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)

            elif isinstance(module, nn.BatchNorm2d):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)    
    
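def postprocess(heatmap):
    """
    Locate the ball in a predicted heatmap with a Hough circle transform.

    Args:
        heatmap (np.ndarray): 2D array of per-pixel class indices (0-255)

    Returns:
        tuple: (center_x, center_y, radius) in heatmap pixels, or
        (None, None, None) if no circle is found
    """
    heatmap = heatmap.astype(np.uint8)
    # Keep only pixels whose class index exceeds 127
    _, binary_heatmap = cv2.threshold(heatmap, 127, 255, cv2.THRESH_BINARY)
    circles = cv2.HoughCircles(binary_heatmap, cv2.HOUGH_GRADIENT, dp=1, minDist=1,
                               param1=50, param2=2, minRadius=2, maxRadius=7)

    # HoughCircles returns an array of shape (1, N, 3), so len(circles) is 1
    # whenever any circle is found; the first circle is used.
    if circles is not None and len(circles) == 1:
        x, y, r = circles[0][0]
        return x, y, r

    return None, None, None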

print("Model and helper functions defined.")

Model and helper functions defined.
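
Every `ConvBlock` above uses a 3x3 kernel with padding 1 and stride 1, so only the three max-pool layers and the three upsampling layers change the spatial size. The sketch below traces that arithmetic in plain Python; `tracknet_output_size` is a hypothetical helper for illustration, and it assumes `MaxPool2d` floors odd sizes (its default behavior).

```python
def tracknet_output_size(width, height, num_pools=3):
    """Trace (width, height) through the 2x2 max-pools and the matching 2x upsamples."""
    w, h = width, height
    for _ in range(num_pools):
        w, h = w // 2, h // 2   # MaxPool2d(kernel_size=2, stride=2) floors odd sizes
    bottleneck = (w, h)
    for _ in range(num_pools):
        w, h = w * 2, h * 2     # Upsample(scale_factor=2)
    return bottleneck, (w, h)

print(tracknet_output_size(640, 360))  # ((80, 45), (640, 360))
print(tracknet_output_size(640, 363))  # ((80, 45), (640, 360))
```

The second call shows why the input dimensions should be divisible by 8: a 363-pixel height would come back as 360, so the heatmap would no longer align with the input frame.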


## Model Loading and Initialization

Load the pre-trained TrackNet model:
1. **Instantiate** the TrackNet architecture
2. **Load weights** from the saved model file (`model_best.pt`)
3. **Set evaluation mode** so that the BatchNorm layers use their stored running statistics instead of per-batch statistics
4. **Handle a missing file** by stopping with a clear message if the weights cannot be found

In [None]:
# Instantiate the model
model = TrackNet().to(config.DEVICE)

# Load the best model's weights
try:
    model.load_state_dict(torch.load(config.MODEL_PATH, map_location=config.DEVICE))
    print("Model weights loaded successfully!")
except FileNotFoundError:
    print(f"Error: Model file not found at {config.MODEL_PATH}. Please check the path in the config.")
    # Re-raise to stop execution (an assert would be skipped under python -O)
    raise

# Set the model to evaluation mode
model.eval()
print("Model is in evaluation mode.")

Model weights loaded successfully!
Model is in evaluation mode.


## Video Processing and Prediction Generation

Process the input video to generate ball position predictions:

### Process:
1. **Open video** and extract metadata (frames, dimensions)
2. **Frame buffering**: Maintain sliding window of 3 consecutive frames
3. **Model inference**: Process frame triplets through TrackNet
4. **Coordinate extraction**: Convert heatmap predictions to ball positions
5. **Scale conversion**: Transform coordinates from model input size to original video dimensions
6. **Data collection**: Store predictions in structured format for CSV export
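
The frame-buffering step can be sketched with integers standing in for frames. The example below uses `collections.deque` with `maxlen=3`, as the notebook does, to show which frame IDs end up in each triplet; `triplet_frame_ids` is a hypothetical helper name.

```python
from collections import deque

def triplet_frame_ids(num_frames, window=3):
    """Return (frame_id, buffered_ids) for every frame at which the buffer is full."""
    buffer = deque(maxlen=window)
    results = []
    for frame_id in range(num_frames):
        buffer.append(frame_id)  # the oldest entry is dropped automatically
        if len(buffer) == window:
            results.append((frame_id, tuple(buffer)))
    return results

windows = triplet_frame_ids(5)
for frame_id, ids in windows:
    print(frame_id, ids)
# 2 (0, 1, 2)
# 3 (1, 2, 3)
# 4 (2, 3, 4)
```

Because the buffer needs three frames before the first inference, frames 0 and 1 never receive a prediction, and each prediction is attributed to the newest frame in its triplet.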

### Output:
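- CSV file with one row per processed frame (`NaN` where no ball is detected)
- Bounding boxes stored as `tracknet_x`, `tracknet_y`, `tracknet_w`, `tracknet_h`, a format compatible with tracking evaluation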
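
As a worked example of the scale conversion and bounding box format, the sketch below maps a detection from model-input pixels to original-video pixels. The detection values are hypothetical, `to_original_bbox` is an illustrative helper name, and the default sizes assume a 1280x720 source video.

```python
def to_original_bbox(px, py, pr, input_size=(640, 360),
                     original_size=(1280, 720), box_size=20):
    """Map a detection in model-input pixels to an (x, y, w, h) box in original pixels."""
    scale_x = original_size[0] / input_size[0]
    scale_y = original_size[1] / input_size[1]
    center_x = px * scale_x
    center_y = py * scale_y
    radius = pr * (scale_x + scale_y) / 2.0  # average of both scales for the radius
    # Top-left corner is offset by the scaled radius; width and height are fixed
    return int(center_x - radius), int(center_y - radius), box_size, box_size

# Hypothetical detection at (100, 50) with radius 4 in the 640x360 frame
bbox = to_original_bbox(100.0, 50.0, 4.0)
print(bbox)  # (192, 92, 20, 20)
```

Note that the corner depends on the detected radius while the box size is constant, so the box is centered on the ball only when the scaled radius is 10 px.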

In [None]:
# Initialize video capture and validate file
cap = cv2.VideoCapture(str(config.VIDEO_CLIP_PATH))
if not cap.isOpened():
    # Raise instead of using assert, which is skipped under python -O
    raise IOError(f"Could not open video file at {config.VIDEO_CLIP_PATH}")

# Extract video properties
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
original_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
original_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

# Initialize frame buffer and prediction storage
frame_buffer = collections.deque(maxlen=3)  # Rolling window of 3 frames
predictions_list = []
frame_id_counter = 0

print(f"\nProcessing video: {total_frames} frames, {original_width}x{original_height}")
pbar = tqdm(total=total_frames, desc="Generating Predictions")

while cap.isOpened():
    ret, frame = cap.read()
    if not ret:
        break
    
    # Resize frame to model input dimensions
    resized_frame = cv2.resize(frame, (config.INPUT_WIDTH, config.INPUT_HEIGHT))
    frame_buffer.append(resized_frame)
    
    # Process when we have 3 consecutive frames
    if len(frame_buffer) == 3:
        # Convert frames to tensor format: BGR->RGB, normalize, stack channels
        imgs_list = [torch.from_numpy(cv2.cvtColor(f, cv2.COLOR_BGR2RGB)).float() / 255.0 for f in frame_buffer]
        input_tensor = torch.cat([t.permute(2, 0, 1) for t in imgs_list], dim=0).unsqueeze(0).to(config.DEVICE)
        
        # Run model inference
        with torch.no_grad():
            predictions = model(input_tensor)
        
        # Extract ball coordinates from heatmap
        pred_heatmap = torch.argmax(predictions.squeeze(0), dim=0).cpu().numpy()
        px, py, pr = postprocess(pred_heatmap)
        
        if px is not None:
            # Scale coordinates to original video dimensions
            scale_x = original_width / config.INPUT_WIDTH
            scale_y = original_height / config.INPUT_HEIGHT
            
            center_x = px * scale_x
            center_y = py * scale_y
            radius = pr * ((scale_x + scale_y) / 2.0)
            
            # Convert to bounding box format for tracking evaluation:
            # the top-left corner is offset by the scaled radius, while the
            # box width and height are fixed at 20 px
            tracknet_x = center_x - radius
            tracknet_y = center_y - radius
            tracknet_w = 20
            tracknet_h = 20
            
            predictions_list.append({
                'frame_id': frame_id_counter,
                'tracknet_x': int(tracknet_x),
                'tracknet_y': int(tracknet_y),
                'tracknet_w': int(tracknet_w),
                'tracknet_h': int(tracknet_h)
            })
        else:
            # Record NaN for undetected ball
            predictions_list.append({
                'frame_id': frame_id_counter,
                'tracknet_x': np.nan,
                'tracknet_y': np.nan,
                'tracknet_w': np.nan,
                'tracknet_h': np.nan
            })
            
    frame_id_counter += 1
    pbar.update(1)

# Clean up resources
pbar.close()
cap.release()

# Export predictions to CSV
tracknet_df = pd.DataFrame(predictions_list)
tracknet_df.to_csv(config.OUTPUT_CSV_PATH, index=False)

print(f"\n--- Prediction Generation Complete ---")
print(f"Predictions saved to: {config.OUTPUT_CSV_PATH}")
print("\nFirst 10 rows of the prediction data:")
print(tracknet_df.head(10))

# Summary statistics
detected_frames = tracknet_df['tracknet_x'].notna().sum()
print(f"\nDetected a ball in {detected_frames} out of {total_frames} frames.")


Processing video: 207 frames, 1280x720


Generating Predictions: 100%|██████████| 207/207 [00:06<00:00, 32.47it/s]


--- Prediction Generation Complete ---
Predictions saved to: clip1_tracknet_predictions.csv

First 10 rows of the prediction data:
   frame_id  tracknet_x  tracknet_y  tracknet_w  tracknet_h
0         2       594.0       384.0        20.0        20.0
1         3       594.0       364.0        20.0        20.0
2         4       592.0       348.0        20.0        20.0
3         5       590.0       332.0        20.0        20.0
4         6       590.0       316.0        20.0        20.0
5         7       590.0       304.0        20.0        20.0
6         8       590.0       294.0        20.0        20.0
7         9       588.0       284.0        20.0        20.0
8        10       586.0       272.0        20.0        20.0
9        11       586.0       266.0        20.0        20.0

Detected a ball in 186 out of 207 frames.
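
The summary above divides by all 207 frames, although the first two frames only fill the buffer and never receive a prediction. The short calculation below, a sketch using the numbers printed in this run, shows the detection rate under both denominators; `detection_rates` is a hypothetical helper name.

```python
def detection_rates(detected, total_frames, window=3):
    """Return the number of predicted frames and the detection rate per denominator."""
    predicted_frames = total_frames - (window - 1)  # frames that got a prediction
    return predicted_frames, detected / total_frames, detected / predicted_frames

predicted_frames, rate_all, rate_predicted = detection_rates(186, 207)
print(predicted_frames)           # 205
print(f"{rate_all:.1%}")          # 89.9%
print(f"{rate_predicted:.1%}")    # 90.7%
```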





## Trajectory Visualization Generation

Create an annotated video showing ball tracking results:

### Features:
1. **Per-frame detection**: Run the model on each frame triplet and detect the ball position
2. **Trajectory trail**: Draw connected lines showing ball's path over time
3. **Current position**: Highlight current frame's detected ball location
4. **Video output**: Generate MP4 with overlaid tracking visualizations

### Visualization Elements:
- **Red circles**: Current ball position
- **Red trajectory lines**: Historical ball path
- **Progress bar**: `tqdm` status shown in the notebook output, not drawn on the video

The resulting video gives a visual check of tracking quality for analysis and presentation.
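
The trajectory trail is drawn by connecting each detected position to the next one. The sketch below shows that pairing in plain Python without any drawing; the points are hypothetical and `trail_segments` is an illustrative helper name.

```python
def trail_segments(points):
    """Pair each point with its successor, giving the line segments of the trail."""
    return list(zip(points, points[1:]))

# Hypothetical detections in original-video pixels
trajectory_points = [(604, 394), (604, 374), (602, 358)]
segments = trail_segments(trajectory_points)
for start, end in segments:
    print(start, "->", end)
# (604, 394) -> (604, 374)
# (604, 374) -> (602, 358)
```

Frames without a detection add no point to the list, so the trail bridges such gaps with a straight line between the neighboring detections.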

In [None]:
# Initialize video capture for trajectory visualization
cap = cv2.VideoCapture(str(config.VIDEO_CLIP_PATH))
if not cap.isOpened():
    # Raise instead of using assert, which is skipped under python -O
    raise IOError(f"Could not open video file at {config.VIDEO_CLIP_PATH}")

# Extract video properties
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
original_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
original_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
fps = int(cap.get(cv2.CAP_PROP_FPS))

# Setup video writer for output
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
video_writer = cv2.VideoWriter(str(config.OUTPUT_TRACKING_VIDEO_PATH), fourcc, fps, (original_width, original_height))

# Initialize processing variables
frame_buffer = collections.deque(maxlen=3)  # Rolling window of 3 frames
trajectory_points = []  # Store all detected ball positions for trail

print(f"\nProcessing video: {total_frames} frames, {fps} fps, {original_width}x{original_height}")
pbar = tqdm(total=total_frames, desc="Processing Video")

while cap.isOpened():
    ret, frame = cap.read()
    if not ret:
        break

    # Resize frame for model input
    resized_frame = cv2.resize(frame, (config.INPUT_WIDTH, config.INPUT_HEIGHT))
    frame_buffer.append(resized_frame)
    
    current_ball_pos = None  # Reset for current frame
    
    # Process when we have 3 consecutive frames
    if len(frame_buffer) == 3:
        # Convert frames to tensor format: BGR->RGB, normalize, stack channels
        imgs_list = [torch.from_numpy(cv2.cvtColor(f, cv2.COLOR_BGR2RGB)).float() / 255.0 for f in frame_buffer]
        input_tensor = torch.cat([t.permute(2, 0, 1) for t in imgs_list], dim=0).unsqueeze(0).to(config.DEVICE)
        
        # Run model inference
        with torch.no_grad():
            predictions = model(input_tensor)
        
        # Extract ball coordinates from heatmap
        pred_heatmap = torch.argmax(predictions.squeeze(0), dim=0).cpu().numpy()
        # postprocess() returns (x, y, radius); the radius is not needed here
        px, py, _ = postprocess(pred_heatmap)
        
        if px is not None:
            # Scale coordinates to original video dimensions
            original_x = int(px * (original_width / config.INPUT_WIDTH))
            original_y = int(py * (original_height / config.INPUT_HEIGHT))
            current_ball_pos = (original_x, original_y)
            trajectory_points.append(current_ball_pos)
            
    # Draw trajectory trail as connected lines
    if len(trajectory_points) > 1:
        for i in range(1, len(trajectory_points)):
            cv2.line(frame, trajectory_points[i - 1], trajectory_points[i], config.TRAIL_COLOR, config.TRAIL_THICKNESS)
            
    # Draw current ball position
    if current_ball_pos is not None:
        cv2.circle(frame, current_ball_pos, config.CIRCLE_RADIUS, config.CIRCLE_COLOR, config.CIRCLE_THICKNESS)

    video_writer.write(frame)
    pbar.update(1)

# Clean up resources
pbar.close()
cap.release()
video_writer.release()
print(f"\n--- Inference Complete ---")
print(f"Trajectory video saved to: {config.OUTPUT_TRACKING_VIDEO_PATH}")


Processing video: 207 frames, 30 fps, 1280x720


Processing Video: 100%|██████████| 207/207 [00:06<00:00, 30.28it/s]


--- Inference Complete ---
Trajectory video saved to: videos\Clip1_tracknet_trajectory.mp4



