In [1]:
from google.colab import drive
drive.mount('/content/drive')


Mounted at /content/drive


In [2]:
import cv2
import numpy as np
import torch

HEIGHT = 288
WIDTH = 512

def get_model(model_name, num_frame, input_type):
    """ Create model by name and configuration parameters. """
    if model_name == 'TrackNetV2':
        from model import TrackNetV2 as TrackNet
    if model_name in ['TrackNetV2']:
        model = TrackNet(in_dim=num_frame*3, out_dim=num_frame)
    return model

def get_frame_unit(frame_list, num_frame):
    """ Generate input sequences from video frames. """
    batch = []
    h, w, _ = frame_list[0].shape
    h_ratio, w_ratio = h / HEIGHT, w / WIDTH

    def get_unit(frames):
        """ Resize and normalize frames. """
        resized_frames = np.array([]).reshape(0, HEIGHT, WIDTH)
        for img in frames:
            img = cv2.resize(img, (WIDTH, HEIGHT))
            img = np.moveaxis(img, -1, 0)
            resized_frames = np.concatenate((resized_frames, img), axis=0)
        return resized_frames

    for i in range(0, len(frame_list), num_frame):
        frames = get_unit(frame_list[i:i+num_frame]) / 255.0
        batch.append(frames)

    return torch.FloatTensor(np.array(batch))

def get_object_center(heatmap):
    """ Extract ball center coordinates from the heatmap. """
    if np.amax(heatmap) == 0:
        return 0, 0  # No detection
    else:
        contours, _ = cv2.findContours(heatmap.copy(), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        rects = [cv2.boundingRect(ctr) for ctr in contours]
        max_rect = max(rects, key=lambda r: r[2] * r[3])
        return int(max_rect[0] + max_rect[2] / 2), int(max_rect[1] + max_rect[3] / 2)


In [3]:
import torch
import torch.nn as nn

class ChannelAttentionModule(nn.Module):
    def __init__(self, channel, ratio=16):
        super(ChannelAttentionModule, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)

        self.shared_MLP = nn.Sequential(
            nn.Conv2d(channel, channel // ratio, 1, bias=False),
            nn.ReLU(),
            nn.Conv2d(channel // ratio, channel, 1, bias=False)
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avgout = self.shared_MLP(self.avg_pool(x))
        maxout = self.shared_MLP(self.max_pool(x))
        return self.sigmoid(avgout + maxout)

class SpatialAttentionModule(nn.Module):
    def __init__(self):
        super(SpatialAttentionModule, self).__init__()
        self.conv2d = nn.Conv2d(in_channels=2, out_channels=1, kernel_size=7, stride=1, padding=3)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avgout = torch.mean(x, dim=1, keepdim=True)
        maxout, _ = torch.max(x, dim=1, keepdim=True)
        out = torch.cat([avgout, maxout], dim=1)
        out = self.sigmoid(self.conv2d(out))
        return out

class CBAM(nn.Module):
    def __init__(self, channel):
        super(CBAM, self).__init__()
        self.channel_attention = ChannelAttentionModule(channel)
        self.spatial_attention = SpatialAttentionModule()

    def forward(self, x):
        out = self.channel_attention(x) * x
        # out = self.spatial_attention(out) * out
        return out

class Conv2DBlock(nn.Module):
    """ Conv + ReLU + BN"""
    def __init__(self, in_dim, out_dim, kernel_size, padding='same', bias=True, **kwargs):
        super(Conv2DBlock, self).__init__(**kwargs)
        self.conv = nn.Conv2d(in_dim, out_dim, kernel_size=kernel_size, padding=padding, bias=bias)
        self.bn = nn.BatchNorm2d(out_dim)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        return x

class Double2DConv(nn.Module):
    """ Conv2DBlock x 2"""
    def __init__(self, in_dim, out_dim):
        super(Double2DConv, self).__init__()
        self.conv_1 = Conv2DBlock(in_dim, out_dim, (3, 3))
        self.conv_2 = Conv2DBlock(out_dim, out_dim, (3, 3))

    def forward(self, x):
        x = self.conv_1(x)
        x = self.conv_2(x)
        return x

class Double2DConv2(nn.Module):
    """ Conv2DBlock x 2"""
    def __init__(self, in_dim, out_dim):
        super(Double2DConv2, self).__init__()
        self.conv_1 = Conv2DBlock(in_dim, out_dim, (1, 1))
        self.conv_2 = Conv2DBlock(out_dim, out_dim, (3, 3))

        self.conv_3 = Conv2DBlock(in_dim, out_dim, (3, 3))
        self.conv_4 = Conv2DBlock(out_dim, out_dim, (3, 3))

        self.conv_5 = Conv2DBlock(in_dim, out_dim, (5, 5))
        self.conv_6 = Conv2DBlock(out_dim, out_dim, (3, 3))

        self.conv_7 = Conv2DBlock(out_dim*3, out_dim, (3, 3))

    def forward(self, x):
        x1 = self.conv_1(x)
        x1 = self.conv_2(x1)

        x2 = self.conv_3(x)
        x2 = self.conv_4(x2)

        x3 = self.conv_5(x)
        x3 = self.conv_6(x3)

        x = torch.cat([x1, x2, x3], dim=1)

        x = self.conv_7(x)
        x = x + x2

        return x

class Triple2DConv(nn.Module):
    def __init__(self, in_dim, out_dim):
        super(Triple2DConv, self).__init__()
        self.conv_1 = Conv2DBlock(in_dim, out_dim, (3, 3))
        self.conv_2 = Conv2DBlock(out_dim, out_dim, (3, 3))
        self.conv_3 = Conv2DBlock(out_dim, out_dim, (3, 3))

    def forward(self, x):
        x = self.conv_1(x)
        x = self.conv_2(x)
        x = self.conv_3(x)
        return x

class TrackNetV2(nn.Module):
    """ Original structure but less two layers
        Total params: 10,161,411
        Trainable params: 10,153,859
        Non-trainable params: 7,552
    """
    def __init__(self, in_dim=9, out_dim=3):
        super(TrackNetV2, self).__init__()
        self.down_block_1 = Double2DConv2(in_dim=in_dim, out_dim=64)
        self.down_block_2 = Double2DConv2(in_dim=64, out_dim=128)
        self.down_block_3 = Double2DConv2(in_dim=128, out_dim=256)
        self.bottleneck = Triple2DConv(in_dim=256, out_dim=512)
        self.up_block_1 = Double2DConv(in_dim=768, out_dim=256)
        self.up_block_2 = Double2DConv(in_dim=384, out_dim=128)
        self.up_block_3 = Double2DConv(in_dim=192, out_dim=64)
        self.predictor = nn.Conv2d(64, out_dim, (1, 1))
        self.sigmoid = nn.Sigmoid()
        self.cbam1 = CBAM(channel=256) #only channel attention
        self.cbam2 = CBAM(channel=128)
        self.cbam3 = CBAM(channel=64)

        self.cbam0_2 = CBAM(channel=256)
        self.cbam1_2 = CBAM(channel=128)
        self.cbam2_2 = CBAM(channel=64)

    def forward(self, x):
        """ model input shape: (F*3, 288, 512), output shape: (F, 288, 512) """
        x1 = self.down_block_1(x)                                   # (64, 288, 512)
        x = nn.MaxPool2d((2, 2), stride=(2, 2))(x1)                 # (64, 144, 256)
        x2 = self.down_block_2(x)                                   # (128, 144, 256)
        x = nn.MaxPool2d((2, 2), stride=(2, 2))(x2)                 # (128, 72, 128)
        x3 = self.down_block_3(x)                                   # (256, 72, 128), one less conv layer
        x = nn.MaxPool2d((2, 2), stride=(2, 2))(x3)                 # (256, 36, 64)
        x = self.bottleneck(x)                                      # (512, 36, 64)
        x3 = self.cbam0_2(x3)
        x = torch.cat([nn.Upsample(scale_factor=2)(x), x3], dim=1)  # (768, 72, 128) 256+512

        x = self.up_block_1(x)                                      # (256, 72, 128), one less conv layer
        x = self.cbam1(x)
        x2 = self.cbam1_2(x2)
        x = torch.cat([nn.Upsample(scale_factor=2)(x), x2], dim=1)  # (384, 144, 256) 256+128

        x = self.up_block_2(x)                                      # (128, 144, 256)
        x = self.cbam2(x)
        x1 = self.cbam2_2(x1)
        x = torch.cat([nn.Upsample(scale_factor=2)(x), x1], dim=1)  # (192, 288, 512) 128+64

        x = self.up_block_3(x)                                      # (64, 288, 512)
        x = self.cbam3(x)
        x = self.predictor(x)                                       # (3, 288, 512)
        x = self.sigmoid(x)
        return  x


from torchsummary import summary
Tr = TrackNetV2().cuda()
summary(Tr, (9, 288, 512))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 288, 512]             640
       BatchNorm2d-2         [-1, 64, 288, 512]             128
              ReLU-3         [-1, 64, 288, 512]               0
       Conv2DBlock-4         [-1, 64, 288, 512]               0
            Conv2d-5         [-1, 64, 288, 512]          36,928
       BatchNorm2d-6         [-1, 64, 288, 512]             128
              ReLU-7         [-1, 64, 288, 512]               0
       Conv2DBlock-8         [-1, 64, 288, 512]               0
            Conv2d-9         [-1, 64, 288, 512]           5,248
      BatchNorm2d-10         [-1, 64, 288, 512]             128
             ReLU-11         [-1, 64, 288, 512]               0
      Conv2DBlock-12         [-1, 64, 288, 512]               0
           Conv2d-13         [-1, 64, 288, 512]          36,928
      BatchNorm2d-14         [-1, 64, 2

In [None]:
import os
import cv2
import torch
# from utils import get_model, get_frame_unit, get_object_center

# Directly assign values to variables
video_file = '/content/drive/MyDrive/test-video.mp4'  # Specify the path to your input video file
model_file = '/content/drive/MyDrive/tracknet-v3-pretrained-model.pt'  # Path to your model file
num_frame = 3  # Number of frames to process at a time
batch_size = 1  # Batch size for processing
save_dir = '/content/drive/MyDrive/pred_result'  # Directory to save the output

# Extract video name and format
video_name = os.path.splitext(os.path.basename(video_file))[0]
video_format = os.path.splitext(video_file)[1][1:]
out_video_file = f'{save_dir}/{video_name}_pred.{video_format}'
out_csv_file = f'{save_dir}/{video_name}_ball.csv'

# Load model checkpoint
checkpoint = torch.load(model_file)
param_dict = checkpoint['param_dict']
model_name = param_dict['model_name']
num_frame = param_dict['num_frame']
input_type = param_dict['input_type']

# Create the output directory if it doesn't exist
os.makedirs(save_dir, exist_ok=True)

# Load model
model = TrackNetV2(in_dim=num_frame * 3, out_dim=num_frame).cuda()  # Manually create model
checkpoint = torch.load(model_file)  # Load weights
model.load_state_dict(checkpoint['model_state_dict'])  # Apply weights
model.eval()

# Video output configuration
if video_format == 'avi':
    fourcc = cv2.VideoWriter_fourcc(*'DIVX')
elif video_format == 'mp4':
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
else:
    raise ValueError('Invalid video format.')

# Write CSV file header
with open(out_csv_file, 'w') as f:
    f.write('Frame,Visibility,X,Y\n')

# Video capture configuration
cap = cv2.VideoCapture(video_file)
fps = int(cap.get(cv2.CAP_PROP_FPS))
w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
success = True
frame_count = 0
num_final_frame = 0
ratio = h / HEIGHT
out = cv2.VideoWriter(out_video_file, fourcc, fps, (w, h))

while success:
    print(f'Number of sampled frames: {frame_count}')
    # Sample frames to form input sequence
    frame_queue = []
    for _ in range(num_frame * batch_size):
        success, frame = cap.read()
        if not success:
            break
        else:
            frame_count += 1
            frame_queue.append(frame)

    if not frame_queue:
        break

    # If mini-batch is incomplete
    if len(frame_queue) % num_frame != 0:
        frame_queue = []
        # Record the length of remaining frames
        num_final_frame = len(frame_queue) + 1
        print(num_final_frame)
        # Adjust the sample timestamp of cap
        frame_count = frame_count - num_frame * batch_size
        cap.set(cv2.CAP_PROP_POS_FRAMES, frame_count)
        # Re-sample mini-batch
        for _ in range(num_frame * batch_size):
            success, frame = cap.read()
            if not success:
                break
            else:
                frame_count += 1
                frame_queue.append(frame)
        if len(frame_queue) % num_frame != 0:
            continue

    x = get_frame_unit(frame_queue, num_frame)

    # Inference
    with torch.no_grad():
        y_pred = model(x.cuda())
    y_pred = y_pred.detach().cpu().numpy()
    h_pred = (y_pred > 0.5).astype('uint8') * 255
    h_pred = h_pred.reshape(-1, HEIGHT, WIDTH)

    for i in range(h_pred.shape[0]):
        if num_final_frame > 0 and i < (num_frame * batch_size - num_final_frame - 1):
            print('Skipping frame already written to output video.')
            continue
        else:
            img = frame_queue[i].copy()
            cx_pred, cy_pred = get_object_center(h_pred[i])
            cx_pred, cy_pred = int(ratio * cx_pred), int(ratio * cy_pred)
            vis = 1 if cx_pred > 0 and cy_pred > 0 else 0
            # Write prediction result
            with open(out_csv_file, 'a') as f:
                f.write(f'{frame_count - (num_frame * batch_size) + i},{vis},{cx_pred},{cy_pred}\n')
            if cx_pred != 0 or cy_pred != 0:
                cv2.circle(img, (cx_pred, cy_pred), 5, (0, 0, 255), -1)
            out.write(img)

out.release()
print('Done.')


KeyboardInterrupt: 

**Performance** **Metrics**

In [None]:
import pandas as pd
import numpy as np
from sklearn.metrics import mean_squared_error, mean_absolute_error, accuracy_score, precision_score, recall_score, f1_score

# File paths
pred_csv = "/content/drive/MyDrive/pred_result/test-video_ball.csv"  # Predicted CSV
target_csv = "/content/drive/MyDrive/target.csv"  # Manually labeled CSV

# Load CSV files
pred_df = pd.read_csv(pred_csv)
target_df = pd.read_csv(target_csv)

# Ensure both dataframes are sorted by frame number
pred_df = pred_df.sort_values(by="Frame").reset_index(drop=True)
target_df = target_df.sort_values(by="Frame").reset_index(drop=True)

# Merge predictions with ground truth based on frame numbers
df = pd.merge(target_df, pred_df, on="Frame", suffixes=("_true", "_pred"))

# Compute errors in X and Y coordinates (only for visible frames)
valid_rows = df["Visibility_true"] == 1  # Consider only frames where the ball is visible

mse_x = mean_squared_error(df.loc[valid_rows, "X_true"], df.loc[valid_rows, "X_pred"])
mse_y = mean_squared_error(df.loc[valid_rows, "Y_true"], df.loc[valid_rows, "Y_pred"])
mae_x = mean_absolute_error(df.loc[valid_rows, "X_true"], df.loc[valid_rows, "X_pred"])
mae_y = mean_absolute_error(df.loc[valid_rows, "Y_true"], df.loc[valid_rows, "Y_pred"])

# Compute visibility classification metrics
accuracy = accuracy_score(df["Visibility_true"], df["Visibility_pred"])
precision = precision_score(df["Visibility_true"], df["Visibility_pred"])
recall = recall_score(df["Visibility_true"], df["Visibility_pred"])
f1 = f1_score(df["Visibility_true"], df["Visibility_pred"])

# Print results
print(f"Visibility Accuracy: {accuracy:.3f}")
print(f"Precision: {precision:.3f}, Recall: {recall:.3f}, F1-score: {f1:.3f}")
print(f"Mean Squared Error (X): {mse_x:.3f}, Mean Squared Error (Y): {mse_y:.3f}")
print(f"Mean Absolute Error (X): {mae_x:.3f}, Mean Absolute Error (Y): {mae_y:.3f}")


Visibility Accuracy: 0.839
Precision: 0.877, Recall: 0.940, F1-score: 0.908
Mean Squared Error (X): 1000252.585, Mean Squared Error (Y): 101904.648
Mean Absolute Error (X): 946.832, Mean Absolute Error (Y): 269.149


In [4]:
import torch
from torch.utils.data import Dataset
from PIL import Image
import os
import pandas as pd
from torchvision import transforms
from sklearn.model_selection import train_test_split
import numpy as np

def generate_gaussian_heatmap(height, width, center, sigma=5):
    """
    Generate a 2D Gaussian heatmap.

    Args:
        height (int): Height of the heatmap.
        width (int): Width of the heatmap.
        center (tuple): (x, y) coordinates of the Gaussian center.
        sigma (float): Standard deviation of the Gaussian.

    Returns:
        torch.Tensor: Heatmap of shape (1, height, width) with values in [0, 1].
    """
    x_coord = np.arange(width)
    y_coord = np.arange(height)
    x_grid, y_grid = np.meshgrid(x_coord, y_coord)

    # Unpack center coordinates
    cx, cy = center
    # Calculate the Gaussian
    gaussian = np.exp(-((x_grid - cx) ** 2 + (y_grid - cy) ** 2) / (2 * sigma ** 2))

    # Normalize to [0,1]
    gaussian = (gaussian - gaussian.min()) / (gaussian.max() - gaussian.min() + 1e-8)

    # Return as tensor with shape (1, height, width)
    return torch.tensor(gaussian, dtype=torch.float32).unsqueeze(0)

class VideoFrameDataset(Dataset):
    def __init__(self, image_dir, annotation_file, split="train", transform=None, test_size=0.2, random_state=42, sigma=5):
        self.image_dir = image_dir
        self.transform = transform
        self.sigma = sigma

        # Load annotations from CSV
        annotations = pd.read_csv(annotation_file)
        # Sort by Frame number to ensure sequential order
        annotations = annotations.sort_values("Frame").reset_index(drop=True)

        # Split into train and test
        train_data, test_data = train_test_split(annotations, test_size=test_size, random_state=random_state)
        self.annotations = train_data if split == "train" else test_data

    def __len__(self):
        # We need three consecutive frames, so subtract 2
        return len(self.annotations) - 2

    def __getitem__(self, idx):
        frames = []
        # For each sample, stack 3 consecutive frames
        for i in range(3):
            row = self.annotations.iloc[idx + i]
            # Adjust filename format if needed (here assumed as "{Frame}.png")
            img_path = os.path.join(self.image_dir, f"{int(row['Frame'])}.png")
            image = Image.open(img_path).convert("RGB")
            if self.transform:
                image = self.transform(image)  # Expected shape: (3, H, W)
            frames.append(image)

        # Stack 3 frames along the channel dimension → (9, H, W)
        stacked_frames = torch.cat(frames, dim=0)

        # Use the keypoint from the last frame to generate the heatmap target.
        last_row = self.annotations.iloc[idx + 2]
        # The keypoint coordinates (X, Y) should be scaled according to the transformed image size.
        # Here we assume the coordinates in the CSV are given for the resized image.
        center = (last_row['X'], last_row['Y'])
        # Get height and width from one transformed image
        _, H, W = frames[0].shape
        heatmap = generate_gaussian_heatmap(H, W, center, sigma=self.sigma)

        return stacked_frames, heatmap



In [6]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
from torch.utils.data import DataLoader

# Define transformations (ensure images are resized to 512x288 as per paper)
transform = transforms.Compose([
    transforms.Resize((288, 512)),
    transforms.ToTensor()
])

# Initialize dataset – note that the dataset class should convert coordinate annotations to heatmaps
# You may create a helper function to generate a Gaussian heatmap given (X, Y)
image_dir = "/content/drive/MyDrive/data/frames"
annotation_file = "/content/drive/MyDrive/data/target.csv"
train_dataset = VideoFrameDataset(image_dir, annotation_file, split="train", transform=transform)
test_dataset  = VideoFrameDataset(image_dir, annotation_file, split="test", transform=transform)

train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
test_loader  = DataLoader(test_dataset, batch_size=16, shuffle=False)

# Load the TrackNetV3 model (ensure your model includes the MCC encoder and improved decoder with Channel Attention)
model = TrackNetV2()
checkpoint_path = "/content/drive/MyDrive/data/tracknet-v3-pretrained-model.pt"
checkpoint = torch.load(checkpoint_path, map_location="cuda")
model.load_state_dict(checkpoint["model_state_dict"])

# Freeze layers if desired, then unfreeze the decoder layers that incorporate Channel Attention
for param in model.parameters():
    param.requires_grad = False

# For example, unfreeze the decoder layers (adjust based on your model's attribute names)
for name, param in model.named_parameters():
    if any(layer in name for layer in ['decoder', 'cbam', 'predictor']):
        param.requires_grad = True

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Define Weighted BCE Loss (WBCE) as in the paper or use a custom loss function
# Here, we use a placeholder for the custom loss
def weighted_bce_loss(pred, target, weight):
    bce = - (1 - weight) * target * torch.log(pred + 1e-8) - weight * (1 - target) * torch.log(1 - pred + 1e-8)
    return torch.mean(bce)

criterion = weighted_bce_loss  # Replace with your implementation; alternatively, you can adapt nn.BCEWithLogitsLoss

optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-4)

num_epochs = 30  # As per the paper

for epoch in range(num_epochs):
    model.train()
    total_loss = 0
    for images, heatmaps in train_loader:
        images, heatmaps = images.to(device), heatmaps.to(device)  # heatmaps are the converted ground-truth

        optimizer.zero_grad()
        outputs = model(images)  # outputs shape: (batch, 3, 288, 512) as heatmaps

        # Compute loss; set weight factor based on your strategy
        loss = criterion(outputs, heatmaps, weight=0.5)  # Example weight; adjust as needed
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {total_loss:.4f}")

# Save the fine-tuned model
torch.save(model.state_dict(), "/content/drive/MyDrive/data/transfer_tracknet_pickleball.pt")


Epoch [1/30], Loss: 0.0858
Epoch [2/30], Loss: 0.0815
Epoch [3/30], Loss: 0.0780
Epoch [4/30], Loss: 0.0753
Epoch [5/30], Loss: 0.0730
Epoch [6/30], Loss: 0.0712
Epoch [7/30], Loss: 0.0698
Epoch [8/30], Loss: 0.0687
Epoch [9/30], Loss: 0.0680
Epoch [10/30], Loss: 0.0672
Epoch [11/30], Loss: 0.0665
Epoch [12/30], Loss: 0.0660
Epoch [13/30], Loss: 0.0654
Epoch [14/30], Loss: 0.0650
Epoch [15/30], Loss: 0.0645
Epoch [16/30], Loss: 0.0641
Epoch [17/30], Loss: 0.0636
Epoch [18/30], Loss: 0.0632
Epoch [19/30], Loss: 0.0627
Epoch [20/30], Loss: 0.0622
Epoch [21/30], Loss: 0.0617
Epoch [22/30], Loss: 0.0612
Epoch [23/30], Loss: 0.0608
Epoch [24/30], Loss: 0.0604
Epoch [25/30], Loss: 0.0600
Epoch [26/30], Loss: 0.0595
Epoch [27/30], Loss: 0.0592
Epoch [28/30], Loss: 0.0589
Epoch [29/30], Loss: 0.0585
Epoch [30/30], Loss: 0.0581


In [8]:
import torch
import numpy as np

# Helper functions to compute pixel-wise accuracy and IoU.
def compute_metrics(outputs, targets, threshold=0.5):
    """
    Compute pixel-wise accuracy and IoU for a batch.
    Assumes outputs and targets are torch tensors of the same shape.
    """
    # Threshold the outputs to obtain binary predictions
    preds = (outputs > threshold).float()

    # Compute pixel-wise accuracy
    correct = (preds == targets).float().sum()
    total = torch.numel(preds)
    accuracy = correct / total

    # Compute Intersection over Union (IoU)
    # Add a small epsilon to avoid division by zero.
    eps = 1e-8
    intersection = (preds * targets).sum(dim=[1, 2, 3])
    union = (preds + targets - preds * targets).sum(dim=[1, 2, 3]) + eps
    iou = (intersection / union).mean()  # Average IoU over batch

    return accuracy.item(), iou.item()

# Evaluation function that computes loss, accuracy, and IoU.
def evaluate_model_metrics(model, data_loader, criterion, device, weight=0.5, threshold=0.5):
    model.eval()  # Set model to evaluation mode
    total_loss = 0.0
    total_acc = 0.0
    total_iou = 0.0
    total_samples = 0

    with torch.no_grad():
        for images, heatmaps in data_loader:
            images, heatmaps = images.to(device), heatmaps.to(device)
            outputs = model(images)

            # Compute loss for the batch
            loss = criterion(outputs, heatmaps, weight=weight)
            batch_size = images.size(0)
            total_loss += loss.item() * batch_size
            total_samples += batch_size

            # Compute metrics for the current batch
            acc, iou = compute_metrics(outputs, heatmaps, threshold=threshold)
            total_acc += acc * batch_size
            total_iou += iou * batch_size

    avg_loss = total_loss / total_samples
    avg_acc = total_acc / total_samples
    avg_iou = total_iou / total_samples

    return avg_loss, avg_acc, avg_iou

# Load the original model checkpoint (old model)
old_model = TrackNetV2()  # Make sure this matches your model class definition
old_checkpoint = torch.load("/content/drive/MyDrive/data/tracknet-v3-pretrained-model.pt", map_location=device)
old_model.load_state_dict(old_checkpoint["model_state_dict"])
old_model.to(device)

# Load the fine-tuned (new) model
new_model = TrackNetV2()  # Create a fresh instance if necessary
new_model.load_state_dict(torch.load("/content/drive/MyDrive/data/transfer_tracknet_pickleball.pt", map_location=device))
new_model.to(device)

# Evaluate both models on the test dataset
old_loss, old_acc, old_iou = evaluate_model_metrics(old_model, test_loader, criterion, device)
new_loss, new_acc, new_iou = evaluate_model_metrics(new_model, test_loader, criterion, device)

print(f"Old Model - Test Loss: {old_loss:.4f}, Accuracy: {old_acc:.4f}, IoU: {old_iou:.4f}")
print(f"New Model - Test Loss: {new_loss:.4f}, Accuracy: {new_acc:.4f}, IoU: {new_iou:.4f}")


Old Model - Test Loss: 0.0024, Accuracy: 0.9716, IoU: 0.0001
New Model - Test Loss: 0.0012, Accuracy: 0.9717, IoU: 0.0000
