# Soccer Ball Tracking 

In [None]:
import numpy as np
import pandas as pd
import os
import cv2
import math
import shutil
import matplotlib.pyplot as plt
from tqdm import tqdm
from PIL import Image, ImageSequence

In [None]:
# Notebook cell: force a clean reinstall of aiohttp in the running environment.
# The `!pip` lines are IPython shell magics and only work inside a notebook.
# Uninstall aiohttp
!pip uninstall -y aiohttp

# Clean up residual files if any
# (the path is hard-coded to a python3.10 conda layout and aiohttp 3.9.1)
dir_path = '/opt/conda/lib/python3.10/site-packages/aiohttp-3.9.1.dist-info'
if os.path.exists(dir_path):
    shutil.rmtree(dir_path)

# Reinstall aiohttp
!pip install aiohttp

# Verify the installation by importing the package and printing its version
import aiohttp
print(aiohttp.__version__)

In [2]:
# Install PyTorch packages from the CUDA 11.7 wheel index.
!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu117

# Install TensorFlow with its CUDA extras.
!pip install tensorflow[and-cuda]

# `parse` is used below to pull match/rally ids out of file paths.
!pip install parse
import parse
import tensorflow as tf
import torch
import torch.nn as nn
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

Looking in indexes: https://download.pytorch.org/whl/cu117
Collecting keras<2.16,>=2.15.0 (from tensorflow[and-cuda])
  Downloading keras-2.15.0-py3-none-any.whl.metadata (2.4 kB)
Collecting nvidia-cublas-cu12==12.2.5.6 (from tensorflow[and-cuda])
  Downloading nvidia_cublas_cu12-12.2.5.6-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.2.142 (from tensorflow[and-cuda])
  Downloading nvidia_cuda_cupti_cu12-12.2.142-py3-none-manylinux1_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cuda-nvcc-cu12==12.2.140 (from tensorflow[and-cuda])
  Downloading nvidia_cuda_nvcc_cu12-12.2.140-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.2.140 (from tensorflow[and-cuda])
  Downloading nvidia_cuda_nvrtc_cu12-12.2.140-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.2.140 (from tensorflow[and-cuda])
  Downloading nvidia_cuda_runtime_cu12-12.2.140-py3-none-manylinux1_x86_64.whl.metadata 

## Model

In [2]:
class ChannelAttentionModule(nn.Module):
    """ Channel attention gate.

        The input is globally average-pooled and max-pooled to 1x1, both
        descriptors go through the same two-layer 1x1-conv MLP, and the sum
        is squashed with a sigmoid.

        args:
            channel - An int specifying the number of input channels
            ratio - An int by which the hidden width of the MLP is reduced
                    (hidden width is channel // ratio)
    """
    def __init__(self, channel, ratio=16):
        super().__init__()
        hidden = channel // ratio
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)

        # One MLP shared by both pooled descriptors.
        self.shared_MLP = nn.Sequential(
            nn.Conv2d(channel, hidden, 1, bias=False),
            nn.ReLU(),
            nn.Conv2d(hidden, channel, 1, bias=False),
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """ Return per-channel weights with shape (N, channel, 1, 1). """
        from_avg = self.shared_MLP(self.avg_pool(x))
        from_max = self.shared_MLP(self.max_pool(x))
        combined = from_avg + from_max
        return self.sigmoid(combined)
 
class SpatialAttentionModule(nn.Module):
    """ Spatial attention gate.

        The channel axis is collapsed into a mean map and a max map, the two
        maps are stacked and passed through a 7x7 convolution and a sigmoid.
    """
    def __init__(self):
        super().__init__()
        self.conv2d = nn.Conv2d(in_channels=2, out_channels=1, kernel_size=7, stride=1, padding=3)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """ Return a single-channel attention map with the same H and W as x. """
        # The spatial size is kept; only the channel axis is reduced.
        mean_map = torch.mean(x, dim=1, keepdim=True)
        max_map = torch.max(x, dim=1, keepdim=True)[0]
        stacked = torch.cat([mean_map, max_map], dim=1)
        return self.sigmoid(self.conv2d(stacked))
 
class CBAM(nn.Module):
    """ Attention block that re-weights the channels of its input.

        A spatial attention sub-module is constructed as well, but forward()
        does not apply it: only the channel attention gate is used.

        args:
            channel - An int specifying the number of input channels
    """
    def __init__(self, channel):
        super().__init__()
        self.channel_attention = ChannelAttentionModule(channel)
        self.spatial_attention = SpatialAttentionModule()

    def forward(self, x):
        """ Scale x by its channel attention weights. """
        channel_gate = self.channel_attention(x)
        # self.spatial_attention is not applied here.
        return channel_gate * x

class Conv2DBlock(nn.Module):
    """ Conv2d followed by BatchNorm2d and then ReLU.

        args:
            in_dim - An int specifying the number of input channels
            out_dim - An int specifying the number of output channels
            kernel_size - The convolution kernel size
            padding - The convolution padding ('same' keeps H and W)
            bias - A bool, whether the convolution has a bias term
            **kwargs - Forwarded to nn.Module.__init__
    """
    def __init__(self, in_dim, out_dim, kernel_size, padding='same', bias=True, **kwargs):
        super().__init__(**kwargs)
        self.conv = nn.Conv2d(in_dim, out_dim, kernel_size=kernel_size, padding=padding, bias=bias)
        self.bn = nn.BatchNorm2d(out_dim)
        self.relu = nn.ReLU()

    def forward(self, x):
        """ Apply convolution, batch normalisation and ReLU, in that order. """
        return self.relu(self.bn(self.conv(x)))

class Double2DConv(nn.Module):
    """ Two 3x3 Conv2DBlocks applied one after the other.

        args:
            in_dim - An int specifying the number of input channels
            out_dim - An int specifying the number of output channels
    """
    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.conv_1 = Conv2DBlock(in_dim, out_dim, (3, 3))
        self.conv_2 = Conv2DBlock(out_dim, out_dim, (3, 3))

    def forward(self, x):
        """ Run x through conv_1 and then conv_2. """
        return self.conv_2(self.conv_1(x))

class Double2DConv2(nn.Module):
    """ Three parallel two-layer branches fused by a 3x3 convolution.

        The branches start with a 1x1, a 3x3 and a 5x5 Conv2DBlock
        respectively, each followed by a 3x3 Conv2DBlock. Their outputs are
        concatenated on the channel axis, reduced back to out_dim channels
        by conv_7, and the 3x3 branch output is added as a residual.

        args:
            in_dim - An int specifying the number of input channels
            out_dim - An int specifying the number of output channels
    """
    def __init__(self, in_dim, out_dim):
        super().__init__()
        # 1x1 branch
        self.conv_1 = Conv2DBlock(in_dim, out_dim, (1, 1))
        self.conv_2 = Conv2DBlock(out_dim, out_dim, (3, 3))
        # 3x3 branch (also used as the residual)
        self.conv_3 = Conv2DBlock(in_dim, out_dim, (3, 3))
        self.conv_4 = Conv2DBlock(out_dim, out_dim, (3, 3))
        # 5x5 branch
        self.conv_5 = Conv2DBlock(in_dim, out_dim, (5, 5))
        self.conv_6 = Conv2DBlock(out_dim, out_dim, (3, 3))
        # Fusion of the concatenated branches
        self.conv_7 = Conv2DBlock(out_dim*3, out_dim, (3, 3))

    def forward(self, x):
        """ Return the fused multi-branch features plus the 3x3 branch. """
        branch_1x1 = self.conv_2(self.conv_1(x))
        branch_3x3 = self.conv_4(self.conv_3(x))
        branch_5x5 = self.conv_6(self.conv_5(x))

        stacked = torch.cat([branch_1x1, branch_3x3, branch_5x5], dim=1)
        fused = self.conv_7(stacked)
        return fused + branch_3x3
    
class Triple2DConv(nn.Module):
    """ Three 3x3 Conv2DBlocks applied one after the other.

        args:
            in_dim - An int specifying the number of input channels
            out_dim - An int specifying the number of output channels
    """
    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.conv_1 = Conv2DBlock(in_dim, out_dim, (3, 3))
        self.conv_2 = Conv2DBlock(out_dim, out_dim, (3, 3))
        self.conv_3 = Conv2DBlock(out_dim, out_dim, (3, 3))

    def forward(self, x):
        """ Run x through conv_1, conv_2 and conv_3 in order. """
        for layer in (self.conv_1, self.conv_2, self.conv_3):
            x = layer(x)
        return x

class TrackNetV2(nn.Module):
    """ U-Net style heatmap predictor with channel attention.

        Three multi-branch down blocks, a bottleneck, and three up blocks
        with skip connections. Each skip tensor and each up-block output is
        re-weighted by a CBAM block (channel attention only).

        NOTE(review): the original docstring quoted "Total params:
        10,161,411 / Trainable: 10,153,859 / Non-trainable: 7,552" for the
        "original structure but less two layers"; those figures have not
        been re-checked against this modified architecture.

        args:
            in_dim - An int specifying the number of input channels
            out_dim - An int specifying the number of output heatmaps
    """
    def __init__(self, in_dim=9, out_dim=3):
        super().__init__()
        # Encoder
        self.down_block_1 = Double2DConv2(in_dim=in_dim, out_dim=64)
        self.down_block_2 = Double2DConv2(in_dim=64, out_dim=128)
        self.down_block_3 = Double2DConv2(in_dim=128, out_dim=256)
        self.bottleneck = Triple2DConv(in_dim=256, out_dim=512)
        # Decoder (input widths are upsampled channels + skip channels)
        self.up_block_1 = Double2DConv(in_dim=768, out_dim=256)
        self.up_block_2 = Double2DConv(in_dim=384, out_dim=128)
        self.up_block_3 = Double2DConv(in_dim=192, out_dim=64)
        self.predictor = nn.Conv2d(64, out_dim, (1, 1))
        self.sigmoid = nn.Sigmoid()
        # Attention on the decoder outputs (only channel attention is applied)
        self.cbam1 = CBAM(channel=256)
        self.cbam2 = CBAM(channel=128)
        self.cbam3 = CBAM(channel=64)

        # Attention on the skip connections
        self.cbam0_2 = CBAM(channel=256)
        self.cbam1_2 = CBAM(channel=128)
        self.cbam2_2 = CBAM(channel=64)

    def forward(self, x):
        """ model input shape: (F*3, 288, 512), output shape: (F, 288, 512) """
        # Parameter-free layers; one instance of each is reused below.
        pool = nn.MaxPool2d((2, 2), stride=(2, 2))
        upsample = nn.Upsample(scale_factor=2)

        skip_1 = self.down_block_1(x)               # (64, 288, 512)
        x = pool(skip_1)                            # (64, 144, 256)
        skip_2 = self.down_block_2(x)               # (128, 144, 256)
        x = pool(skip_2)                            # (128, 72, 128)
        skip_3 = self.down_block_3(x)               # (256, 72, 128)
        x = pool(skip_3)                            # (256, 36, 64)
        x = self.bottleneck(x)                      # (512, 36, 64)

        skip_3 = self.cbam0_2(skip_3)
        x = torch.cat([upsample(x), skip_3], dim=1) # (768, 72, 128) 512+256
        x = self.up_block_1(x)                      # (256, 72, 128)
        x = self.cbam1(x)

        skip_2 = self.cbam1_2(skip_2)
        x = torch.cat([upsample(x), skip_2], dim=1) # (384, 144, 256) 256+128
        x = self.up_block_2(x)                      # (128, 144, 256)
        x = self.cbam2(x)

        skip_1 = self.cbam2_2(skip_1)
        x = torch.cat([upsample(x), skip_1], dim=1) # (192, 288, 512) 128+64
        x = self.up_block_3(x)                      # (64, 288, 512)
        x = self.cbam3(x)

        x = self.predictor(x)                       # (out_dim, 288, 512)
        return self.sigmoid(x)


# from torchsummary import summary
# Tr = TrackNetV2().cuda()
# summary(Tr, (9, 288, 512))

In [3]:
# Frame height in pixels used by the frame helpers below (resize / reshape target).
HEIGHT = 288
# Frame width in pixels used by the frame helpers below (resize / reshape target).
WIDTH = 512
# Presumably the root directory of the dataset on Kaggle - confirm against callers.
data_dir = '/kaggle/input/soccerball-tracking/ball_tracking'


def list_dirs(directory):
    """ Return the sorted entries of a directory, each prefixed with the directory path.

        args:
            directory - A str of the directory to list

        returns:
            A sorted list of normalised paths using '/' as separator
    """
    entries = []
    for name in os.listdir(directory):
        full_path = os.path.normpath(os.path.join(directory, name))
        entries.append(full_path.replace("\\", "/"))
    entries.sort()
    return entries

def get_model(model_name, num_frame, input_type):
    """ Create model by name and the configuration parameter.

        args:
            model_name - A str of model name (only 'TrackNetV2' is supported)
            num_frame - An int specifying the length of a single input sequence
            input_type - A str specifying input type; currently unused, kept
                for interface compatibility

        returns:
            model - A torch.nn.Module with num_frame*3 input channels and
                num_frame output channels

        raises:
            ValueError - if model_name is not a supported model
    """
    if model_name in ['TrackNetV2']:
        return TrackNetV2(in_dim=num_frame*3, out_dim=num_frame)

    # Previously an unknown name fell through to `return model` with `model`
    # unbound, which surfaced as a confusing UnboundLocalError.
    raise ValueError(f'Unknown model name: {model_name!r}')

def model_summary(model, model_name):
    """ Print the parameter count and parameter memory footprint of a model.

        args:
            model - A torch.nn.Module
            model_name - A str of model name, used only in the printed report
    """
    param_sizes = [(p.nelement(), p.element_size()) for p in model.parameters()]
    total_count = sum(count for count, _ in param_sizes)
    total_bytes = sum(count * width for count, width in param_sizes)

    rule = '======================================='
    print(rule)
    print(f'Model: {model_name}')
    print(f'Number of parameters: {total_count}.')
    print(f'Memory usage of : {total_bytes/1024/1024:.4f} MB')
    print(rule)

def frame_first_RGB(input, input_type):
    """ Helper function for transforming x to cv image format.

        args:
            input - A numpy.ndarray of RGB image sequences
                '2d': shape (N, F*3, H, W)
                otherwise: shape (N, 3, F, H, W)
            input_type - A str specifying input type
                '2d' for frames stacked along the channel dimension
                anything else is treated as '3d' (frames on an extra dimension)

        returns:
            A numpy.ndarray of RGB image sequences with shape (N, F, H, W, 3).
            The '2d' result is float64 (as before); the '3d' result keeps the
            input dtype.

        raises:
            ValueError - if a '2d' input's channel count is not a multiple of 3
    """
    assert len(input.shape) > 3
    if input_type != '2d': # (N, 3, F, H ,W)
        return np.transpose(input, (0, 2, 3, 4, 1))

    # Case of input_type == '2d': (N, F*3, H, W) -> (N, H, W, F*3)
    channels_last = np.transpose(input, (0, 2, 3, 1))
    n, h, w, c = channels_last.shape
    if c % 3 != 0:
        raise ValueError(f'channel count {c} is not a multiple of 3')

    # Split the channel axis into (F, 3) and move F forward. This replaces the
    # per-frame np.concatenate loops (quadratic copying) and no longer
    # hard-wires H and W to the module-level HEIGHT / WIDTH.
    frames = channels_last.reshape(n, h, w, c // 3, 3).transpose(0, 3, 1, 2, 4)
    return frames.astype(np.float64)

def frame_first_RGBD(input, input_type):
    """ Helper function for transforming x to cv image format.

        The fourth channel of every frame is dropped; only RGB is returned.

        args:
            input - A numpy.ndarray of RGBD image sequences
                '2d': shape (N, F*4, H, W)
                otherwise: shape (N, 4, F, H, W)
            input_type - A str specifying input type
                '2d' for frames stacked along the channel dimension
                anything else is treated as '3d' (frames on an extra dimension)

        returns:
            A numpy.ndarray of RGB image sequences with shape (N, F, H, W, 3).
            The '2d' result is float64 (as before); the '3d' result keeps the
            input dtype.

        raises:
            ValueError - if a '2d' input's channel count is not a multiple of 4
    """
    assert len(input.shape) > 3
    if input_type != '2d':
        # (N, 4, F, H ,W): drop the last channel, then move channels last
        rgb_only = input[:, :-1, :, :, :]
        return np.transpose(rgb_only, (0, 2, 3, 4, 1))

    # Case of input_type == '2d': (N, F*4, H, W) -> (N, H, W, F*4)
    channels_last = np.transpose(input, (0, 2, 3, 1))
    n, h, w, c = channels_last.shape
    if c % 4 != 0:
        raise ValueError(f'channel count {c} is not a multiple of 4')

    # Split the channel axis into (F, 4), keep the first three channels of
    # each frame and move F forward. This replaces the per-frame
    # np.concatenate loops and the dependency on HEIGHT / WIDTH.
    per_frame = channels_last.reshape(n, h, w, c // 4, 4)[..., :3]
    return per_frame.transpose(0, 3, 1, 2, 4).astype(np.float64)

def frame_first_Gray(input, input_type):
    """ Helper function for transforming y to cv image format.

        args:
            input - A numpy.ndarray of gray scale image sequences
                '2d': shape (N, F, H, W)
                otherwise: shape (N, 1, F, H, W)
            input_type - A str specifying input type
                '2d' returns the input unchanged
                anything else is treated as '3d' and has axis 1 squeezed out

        returns:
            A numpy.ndarray of gray scale image sequences with shape (N, F, H, W)
    """
    assert len(input.shape) > 3
    if input_type != '2d':
        # (N, 1, F, H ,W) -> (N, F, H, W)
        return np.squeeze(input, axis=1)
    # (N, F, H ,W) is already frame-first
    return input

def get_num_frames(video_file):
    """ Return the number of extracted frames for a video.

        The count is the number of entries in the rally's frame directory
        ('{match_dir}/frame/{rally_id}'); the video itself is not opened.

        args:
            video_file - A str of video file path with format '{data_dir}/{split}/match{match_id}/video/{rally_id}.mp4

        returns:
            A int specifying the number of entries in the frame directory
    """
    assert video_file[-4:] == '.mp4'
    print(video_file)

    # Split the path into the match directory and the rally id.
    parsed = parse.parse('{}/video/{}.mp4', video_file)
    match_dir, rally_id = parsed

    frame_dir = f'{match_dir}/frame/{rally_id}'
    assert os.path.exists(frame_dir)

    frame_files = os.listdir(frame_dir)
    return len(frame_files)

def generate_frames(video_file):
    """ Sample frames from the video.

        Frames are written as '{match_dir}/frame/{rally_id}/{n}.jpg' until the
        video ends or as many frames as there are rows in the label csv
        ('{match_dir}/csv/{rally_id}_ball.csv') have been written. If the
        frame directory already holds exactly that many entries nothing is
        done; if it holds a different number it is deleted and regenerated.

        args:
            video_file - A str of video file path with format '{data_dir}/{split}/match{match_id}/video/{rally_id}.mp4
    """
    skip_message = f'{video_file} no match csv file.'

    # Explicit checks instead of a bare `except:` around asserts, so real
    # errors (e.g. KeyboardInterrupt, typos) are no longer swallowed.
    parsed = None
    if video_file[-4:] == '.mp4':
        parsed = parse.parse('{}/video/{}.mp4', video_file)
    if parsed is None:
        print(skip_message)
        return

    match_dir, rally_id = parsed
    csv_file = f'{match_dir}/csv/{rally_id}_ball.csv'
    if not (os.path.exists(video_file) and os.path.exists(csv_file)):
        print(skip_message)
        return

    # Read the labels once; they bound the number of frames to extract.
    label_df = pd.read_csv(csv_file, encoding='utf8')
    num_labels = len(label_df)

    frame_dir = f'{match_dir}/frame/{rally_id}'
    if os.path.exists(frame_dir):
        if len(list_dirs(frame_dir)) == num_labels:
            # Already processed.
            return
        # A previous run was incomplete or inconsistent: start over.
        shutil.rmtree(frame_dir)
    os.makedirs(frame_dir)

    cap = cv2.VideoCapture(video_file)
    try:
        num_frames = 0
        # Sample frames until video end or exceed the number of labels
        while num_frames != num_labels:
            success, image = cap.read()
            if not success:
                break
            cv2.imwrite(f'{frame_dir}/{num_frames}.jpg', image, [cv2.IMWRITE_JPEG_QUALITY, 95])
            num_frames += 1
    finally:
        # Release the capture handle even if reading or writing fails.
        cap.release()

def get_eval_frame_pathes(tuple_array, data):
    """ Get frame pathes according to the evaluation tuple results.

        args:
            tuple_array - An iterable of evaluation result tuples
                each tuple specifying (sequence_id, frame_no)
            data - A dictionary which stored the information for building dataset
                data['filename']: frame path sequences indexed as [sequence_id][frame_no]

        returns:
            A sorted list of frame pathes
    """
    filenames_of = lambda seq_id, frame_no: data['filename'][seq_id][frame_no]
    return sorted(filenames_of(i, f) for (i, f) in tqdm(tuple_array))

def get_eval_statistic(data_dir, path_list):
    """ Count the number of frame pathes from each rally.

        args:
            data_dir - A str of the root directory of the dataset
            path_list - A list of frame pathes with format
                '{data_dir}/{split}/match{match_id}/frame/{rally_id}/{frame_no}.jpg'

        returns:
            A dictionary specifying the statistic, ordered by count descending
                each pair specifying {'{match_id}_{rally_id}': path_count}
    """
    path_pattern = data_dir + '/{}/match{}/frame/{}/{}.jpg'
    counts = {}
    for path in tqdm(path_list):
        _, m_id, c_id, _ = parse.parse(path_pattern, path)
        rally = f'{m_id}_{c_id}'
        counts[rally] = counts.get(rally, 0) + 1

    # Stable sort: rallies with equal counts keep first-seen order.
    ranked = sorted(counts.items(), key=lambda pair: pair[1], reverse=True)
    return dict(ranked)

##################################  Training Functions ##################################
def WeightedBinaryCrossEntropy(y, y_pred):
    """ Weighted binary cross entropy between heatmaps.

        The positive term is weighted by (1 - y_pred)^2 and the negative term
        by y_pred^2; probabilities are clamped to [1e-7, 1] before the log.

        args:
            y - A torch.Tensor of ground-truth heatmaps
            y_pred - A torch.Tensor of predicted heatmaps, same shape as y

        returns:
            A scalar torch.Tensor, the mean loss over all elements
    """
    log_pred = torch.log(torch.clamp(y_pred, 1e-7, 1))
    log_not_pred = torch.log(torch.clamp(1 - y_pred, 1e-7, 1))

    positive_term = torch.square(1 - y_pred) * y * log_pred
    negative_term = torch.square(y_pred) * (1 - y) * log_not_pred

    return torch.mean(-(positive_term + negative_term))

def FocalWBCE(y, y_pred, gamma=1):
    """ Weighted binary cross entropy with an extra focal modulation.

        On top of the (1 - y_pred)^2 / y_pred^2 weights, the positive term is
        scaled by clamp(1 - y_pred)^gamma and the negative term by
        clamp(y_pred)^gamma. Probabilities are clamped to [1e-7, 1].

        args:
            y - A torch.Tensor of ground-truth heatmaps
            y_pred - A torch.Tensor of predicted heatmaps, same shape as y
            gamma - The focal exponent; the default 1 reproduces the previous
                hard-coded behaviour, 0 disables the extra modulation

        returns:
            A scalar torch.Tensor, the mean loss over all elements
    """
    pred = torch.clamp(y_pred, 1e-7, 1)
    not_pred = torch.clamp(1 - y_pred, 1e-7, 1)

    positive_term = torch.square(1 - y_pred) * (not_pred ** gamma) * y * torch.log(pred)
    negative_term = torch.square(y_pred) * (pred ** gamma) * (1 - y) * torch.log(not_pred)

    return torch.mean((-1) * (positive_term + negative_term))

def train(epoch, model, optimizer, loss_fn, data_loader, input_type, display_step, save_dir):
    """ Run one training epoch on the GPU.

        args:
            epoch - The epoch number, used only in the progress bar description
            model - The torch.nn.Module to train
            optimizer - The optimizer updating the model parameters
            loss_fn - A callable invoked as loss_fn(y, y_pred)
            data_loader - An iterable yielding (index, x, y, coordinates) batches
            input_type - A str specifying input type, forwarded to show_prediction
            display_step - An int; every display_step steps a prediction is
                visualised and the progress bar is refreshed
            save_dir - A str specifying the save directory for visualisations

        returns:
            A float, the mean of the per-step losses
    """
    model.train()
    progress = tqdm(data_loader)
    step_losses = []
    for step, (_, x, y, c) in enumerate(progress):
        x = x.float().cuda()
        y = y.float().cuda()

        optimizer.zero_grad()
        y_pred = model(x)
        loss = loss_fn(y, y_pred)
        step_losses.append(loss.item())
        loss.backward()
        optimizer.step()

        if (step + 1) % display_step != 0:
            continue
        show_prediction(x, y, y_pred, c, input_type, save_dir)
        progress.set_description(f'Epoch [{epoch}]')
        progress.set_postfix(loss=loss.item())
    return float(np.mean(step_losses))

def evaluation(model, data_loader, tolerance, input_type):
    """ Evaluate the model on a data loader (on the GPU).

        Predicted heatmaps are binarised at 0.5 and every frame is classified
        as TP / TN / FP1 / FP2 / FN by get_confusion_matrix.

        args:
            model - The torch.nn.Module to evaluate
            data_loader - An iterable yielding (index, x, y, coordinates) batches
            tolerance - The distance tolerance forwarded to get_confusion_matrix
            input_type - A str specifying input type

        returns:
            accuracy, precision, recall - The metrics from get_metric
            TP, TN, FP1, FP2, FN - Lists of (sequence_id, frame_no) tuples
    """
    model.eval()
    progress = tqdm(data_loader)
    TP, TN, FP1, FP2, FN = [], [], [], [], []
    totals = (TP, TN, FP1, FP2, FN)
    for step, (i, x, y, c) in enumerate(progress):
        x = x.float().cuda()
        y = y.float().cuda()
        with torch.no_grad():
            y_pred = model(x)
        # Binarise the predicted heatmap
        y_pred = y_pred > 0.5
        batch_results = get_confusion_matrix(i, y_pred, y, c, tolerance, input_type=input_type)
        for total, found in zip(totals, batch_results):
            total.extend(found)

        progress.set_description('Evaluation')
        progress.set_postfix(TP=len(TP), TN=len(TN), FP1=len(FP1), FP2=len(FP2), FN=len(FN))

    accuracy, precision, recall = get_metric(len(TP), len(TN), len(FP1), len(FP2), len(FN))
    print(f'\nacc: {accuracy:.4f}\tprecision: {precision:.4f}\trecall: {recall:.4f}\tTP: {len(TP)}\tTN: {len(TN)}\tFP1: {len(FP1)}\tFP2: {len(FP2)}\tFN: {len(FN)}')
    return accuracy, precision, recall, TP, TN, FP1, FP2, FN

def get_confusion_matrix(indices, y_pred, y_true, y_coor, tolerance, input_type='3d'):
    """ Classify every frame of a mini batch as TP, TN, FP1, FP2 or FN.

        args:
            indices - Sequence ids of the batch, indexable by batch position
            y_pred - A torch.Tensor of predicted heatmap sequences
            y_true - A torch.Tensor of ground-truth heatmap sequences
            y_coor - Ground-truth coordinate sequences, indexed as [n][f] -> (x, y)
            tolerance - The maximum distance between the predicted and the
                ground-truth centre for a true positive
            input_type - A str specifying input type, forwarded to frame_first_Gray

        returns:
            TP, TN, FP1, FP2, FN - Lists of tuples of all the prediction results
                                    each tuple specifying (sequence_id, frame_no)
    """
    TP, TN, FP1, FP2, FN = [], [], [], [], []
    y_pred = frame_first_Gray(y_pred.detach().cpu().numpy(), input_type)
    y_true = frame_first_Gray(y_true.detach().cpu().numpy(), input_type)
    num_frame = y_pred.shape[1]

    for n in range(y_pred.shape[0]):
        for f in range(num_frame):
            y_p = y_pred[n][f]
            y_t = y_true[n][f]
            c_t = y_coor[n][f]
            pred_peak = np.amax(y_p)
            true_peak = np.amax(y_t)
            frame_id = (int(indices[n]), int(f))

            if pred_peak == 0 and true_peak == 0:
                # True Negative: neither prediction nor ground truth has a ball
                TN.append(frame_id)
            elif pred_peak > 0 and true_peak == 0:
                # False Positive 2: a ball is predicted but there is none
                FP2.append(frame_id)
            elif pred_peak == 0 and true_peak > 0:
                # False Negative: a ball exists but none is predicted
                FN.append(frame_id)
            elif pred_peak > 0 and true_peak > 0:
                # Both have a ball: compare the predicted centre with the label
                h_pred = (y_p * 255).astype('uint8')
                (cnts, _) = cv2.findContours(h_pred.copy(), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
                boxes = [cv2.boundingRect(ctr) for ctr in cnts]

                # First bounding box with the largest area
                target = boxes[0]
                largest = target[2] * target[3]
                for box in boxes[1:]:
                    area = box[2] * box[3]
                    if area > largest:
                        target, largest = box, area

                cx_pred = int(target[0] + target[2] / 2)
                cy_pred = int(target[1] + target[3] / 2)
                cx_true = int(c_t[0])
                cy_true = int(c_t[1])
                dist = math.sqrt(pow(cx_pred-cx_true, 2)+pow(cy_pred-cy_true, 2))

                # Too far from the ground truth counts as False Positive 1
                bucket = FP1 if dist > tolerance else TP
                bucket.append(frame_id)
    return TP, TN, FP1, FP2, FN

def get_metric(TP, TN, FP1, FP2, FN):
    """ Compute accuracy, precision and recall from prediction counts.

        A metric whose denominator is zero is reported as 0.

        args:
            TP, TN, FP1, FP2, FN - The count for each result type of prediction

        returns:
            accuracy, precision, recall - Each float specifying the value of metric
    """
    def _ratio(numerator, denominator):
        # Explicit zero check replaces the previous bare `except:` blocks,
        # which would also have hidden unrelated errors.
        return numerator / denominator if denominator else 0

    accuracy = _ratio(TP + TN, TP + TN + FP1 + FP2 + FN)
    precision = _ratio(TP, TP + FP1 + FP2)
    recall = _ratio(TP, TP + FN)
    return accuracy, precision, recall

##################################  Prediction Functions ##################################
def get_frame_unit(frame_list, num_frame):
    """ Build a mini batch of input sequences from decoded frames.

        Consecutive groups of num_frame frames are resized to (HEIGHT, WIDTH),
        their colour channels are stacked along the first axis, and values
        are scaled by 1/255.

        args:
            frame_list - A list of image arrays with shape (h, w, channels)
            num_frame - An int specifying the length of a single input sequence

        return:
            A torch.FloatTensor of the mini batch; each unit has one
            (HEIGHT, WIDTH) plane per channel per frame
    """
    # Unpacking fails fast on an empty list or on frames that are not 3-D.
    _h, _w, _ = frame_list[0].shape

    def stack_channels(sequence):
        """ Resize each frame and stack all channel planes on the first axis. """
        # The empty float seed keeps the result in floating point.
        planes = [np.array([]).reshape(0, HEIGHT, WIDTH)]
        for img in sequence:
            resized = cv2.resize(img, (WIDTH, HEIGHT))
            planes.append(np.moveaxis(resized, -1, 0))
        return np.concatenate(planes, axis=0)

    # Form a mini batch of input sequence
    batch = []
    for start in range(0, len(frame_list), num_frame):
        unit = stack_channels(frame_list[start: start+num_frame])
        batch.append(unit / 255.)

    return torch.FloatTensor(np.array(batch))

def get_object_center(heatmap):
    """ Get coordinates from the heatmap.

        args:
            heatmap - A numpy.ndarray of a single heatmap with shape (H, W)

        returns:
            ints specifying center coordinates of object;
            (0, 0) when the heatmap maximum is 0
    """
    if np.amax(heatmap) == 0:
        # No response in the heatmap
        return 0, 0

    # Bounding boxes of all response areas in the heatmap
    (cnts, _) = cv2.findContours(heatmap.copy(), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    boxes = [cv2.boundingRect(ctr) for ctr in cnts]

    # First box with the largest area among all contours
    target = boxes[0]
    largest = target[2] * target[3]
    for box in boxes[1:]:
        area = box[2] * box[3]
        if area > largest:
            target, largest = box, area

    left, top, box_w, box_h = target[0], target[1], target[2], target[3]
    return int((left + box_w / 2)), int((top + box_h / 2))

def get_pred_type(cx_pred, cy_pred, cx, cy, tolerance):
    """ Get the result type of the prediction.

        A coordinate pair of (0, 0) means "no ball".

        args:
            cx_pred, cy_pred - ints specifying the predicted coordinates
            cx, cy - ints specifying the ground-truth coordinates
            tolerance - The maximum distance for a prediction to count as TP

        returns:
            A str specifying the result type of the prediction:
            'TN', 'FP2', 'FN', 'FP1' or 'TP'
    """
    pred_has_ball = not (cx_pred == 0 and cy_pred == 0)
    gt_has_ball = not (cx == 0 and cy == 0)

    if not gt_has_ball:
        # No ball in the ground truth
        return 'FP2' if pred_has_ball else 'TN'
    if not pred_has_ball:
        return 'FN'

    # Both have a ball: decide by Euclidean distance
    dist = math.sqrt(pow(cx_pred-cx, 2)+pow(cy_pred-cy, 2))
    return 'FP1' if dist > tolerance else 'TP'

################################  Visualization Functions ################################

def _best_epoch_metrics(acc_dict):
    """ Return (accuracy, precision, recall) at the first epoch with the highest accuracy. """
    best = int(np.argmax(acc_dict['accuracy']))
    return acc_dict['accuracy'][best], acc_dict['precision'][best], acc_dict['recall'][best]


def _plot_metric_curves(acc_dict, prefix):
    """ Plot the accuracy / precision / recall curves labelled '{prefix}_{metric}'. """
    for metric in ('accuracy', 'precision', 'recall'):
        plt.plot(acc_dict[metric], label=f'{prefix}_{metric}')


def plot_result(loss_list=None, train_acc_dict=None, test_acc_dict=None, num_frame=3, save_dir='', model_name=''):
    """ Plot training performance.

        Writes '{save_dir}/loss.jpg' (only when loss_list is given) and
        '{save_dir}/performance.jpg'. The x-axis label of the performance
        plot reports the metrics of the epoch with the best accuracy.

        args:
            loss_list - A list of epoch losses
            train_acc_dict - A dictionary which stored statistic of evaluation on training set
                structure {'TP':[], 'TN': [], 'FP1': [], 'FP2': [], 'FN': [], 'accuracy': [], 'precision': [], 'recall': []}
            test_acc_dict - A dictionary which stored statistic of evaluation on testing set
                structure {'TP':[], 'TN': [], 'FP1': [], 'FP2': [], 'FN': [], 'accuracy': [], 'precision': [], 'recall': []}
            num_frame - An int specifying the length of a single input sequence
            save_dir - A str specifying the save directory
            model_name - A str of model name
    """
    # Plot training epoch losses
    if loss_list:
        plt.title(f'{model_name} (f = {num_frame})\nTraining Loss (WBCE)')
        plt.xlabel('epoch')
        plt.ylabel('loss')
        plt.plot(loss_list)
        plt.tight_layout()
        plt.savefig(f'{save_dir}/loss.jpg')
        plt.clf()

    # Plot accuracy, precision, recall result from evaluation
    plt.title(f'{model_name} (f = {num_frame})\nPerformance')
    test_line = train_line = None
    if test_acc_dict:
        test_acc, test_precision, test_recall = _best_epoch_metrics(test_acc_dict)
        _plot_metric_curves(test_acc_dict, 'test')
        test_line = f' test  accuracy: {test_acc*100.:.2f} %  precision: {test_precision*100.:.2f} %  recall: {test_recall*100.:.2f} %'
    if train_acc_dict:
        train_acc, train_precision, train_recall = _best_epoch_metrics(train_acc_dict)
        _plot_metric_curves(train_acc_dict, 'train')
        train_line = f'train  accuracy: {train_acc*100.:.2f} %  precision: {train_precision*100.:.2f} %  recall: {train_recall*100.:.2f} %'

    # Label fixes: "epochn" typo removed, and the train-only case is no
    # longer labelled as "test".
    summary_lines = [line for line in (train_line, test_line) if line is not None]
    if summary_lines:
        plt.xlabel('\n'.join(['epoch'] + summary_lines))

    plt.ylabel('metric')
    plt.ylim((0.,1.))
    if summary_lines:
        # Only request a legend when labelled curves were drawn.
        plt.legend()
    plt.tight_layout()
    plt.savefig(f'{save_dir}/performance.jpg')
    plt.close()

def plot_eval_statistic(FN_res, FP1_res, FP2_res, split, save_file, figsize=(12, 5)):
    """ Plot the distribution of FN, FP1,and FP2 in all rallies.

        Saves a stacked bar chart to '{save_file}.jpg' with one bar per rally.

        args:
            FN_res, FP1_res, FP2_res - Dictionaries which stored the statistic of each prediction result type
                each pair specifying {'{match_id}_{rally_id}': path_count}
            split - A str specify the split of dataset
            save_file - A str specifying the save file name
            figsize - A tuple specifying the size of figure with shape (W, H)
    """
    # Use the union of keys so a rally that only has FP1/FP2 errors is still
    # plotted; previously only the rallies present in FN_res were shown.
    rally_key = sorted(set(FN_res) | set(FP1_res) | set(FP2_res))

    # Ensure every rally has value (missing entries count as 0)
    FN_list = np.array([FN_res.get(k, 0) for k in rally_key])
    FP1_list = np.array([FP1_res.get(k, 0) for k in rally_key])
    FP2_list = np.array([FP2_res.get(k, 0) for k in rally_key])
    total_count = FN_list+FP1_list+FP2_list
    # np.max raises on an empty array, so fall back to 0 when there are no rallies.
    highest = np.max(total_count) if total_count.size else 0

    # Plot stack bar chart
    width = 0.8
    x_tick = np.arange(len(rally_key))
    plt.figure(figsize=figsize)
    plt.title(f'{split} Set Error Analysis')
    plt.xlabel('clip label')
    plt.ylabel('frame count')
    plt.ylim((0.,highest+60))
    plt.bar(x_tick, FN_list, color='b', label='FN', width=width)
    plt.bar(x_tick, FP1_list, bottom=FN_list, color='g', label='FP1', width=width)
    plt.bar(x_tick, FP2_list, bottom=FN_list+FP1_list, color='r', label='FP2', width=width)
    plt.xticks(x_tick, rally_key, rotation=90)
    # Write the total error count above every bar
    for i, c in zip(x_tick, total_count):
        plt.text(x=i-width , y=c+10 , s=c, fontsize=12)
    plt.legend()
    plt.tight_layout()
    plt.savefig(f'{save_file}.jpg')
    plt.close()

def show_prediction(x, y, y_pred, y_coor, input_type, save_dir):
    """ Visualize the input sequence with its predicted heatmap.
        Save as a gif image at '{save_dir}/pred_cur.gif'.

        Each gif frame is a 2x2 mosaic: input frame (with the ground-truth position
        marked) and ground-truth heatmap on top, predicted heatmap and its
        thresholded (> 0.5) map at the bottom.

        args:
            x - A tensor of input sequences (must support .detach().cpu().numpy())
            y - A tensor of ground-truth heatmap sequences
            y_pred - A tensor of predicted heatmap sequences
            y_coor - A tensor of ground-truth coordinate sequences, indexed as [n][f] -> (x, y)
            input_type - A str specifying input type
                '2d' for stacking all the frames at RGB channel dimension result in shape (H, W, F*3)
                '3d' for stacking all the frames at extra dimension result in shape (F, H, W, 3)
            save_dir - A str specifying the save directory
    """
    imgs = []
    x, y, y_pred, y_coor = x.detach().cpu().numpy(), y.detach().cpu().numpy(), y_pred.detach().cpu().numpy(), y_coor.detach().cpu().numpy()

    # Transform to cv image format (N, F, H , W, C)
    x = frame_first_RGB(x, input_type)
    y = frame_first_Gray(y, input_type)
    y_pred = frame_first_Gray(y_pred, input_type)

    # Only plot the first sequence in the mini-batch
    x, y, y_pred, y_coor = x[0], y[0], y_pred[0], y_coor[0]
    y_map = y_pred > 0.5

    # Scale value from [0, 1] to [0, 255]
    x = (x * 255).astype('uint8')
    y = (y * 255).astype('uint8')
    y_p = (y_pred * 255).astype('uint8')
    y_m = (y_map * 255).astype('uint8')

    # Build one mosaic image per frame
    for f in range(y_coor.shape[0]):
        # Stack channels to form RGB images
        tmp_y = cv2.cvtColor(y[f], cv2.COLOR_GRAY2BGR)
        tmp_pred = cv2.cvtColor(y_p[f], cv2.COLOR_GRAY2BGR)
        tmp_map = cv2.cvtColor(y_m[f], cv2.COLOR_GRAY2BGR)
        tmp_x = x[f]
        assert tmp_x.shape == tmp_y.shape == tmp_pred.shape == tmp_map.shape

        # Mark ground-truth label (a coordinate of 0 is treated as "no label")
        gt_x, gt_y = int(y_coor[f][0]), int(y_coor[f][1])
        if gt_x > 0 and gt_y > 0:
            cv2.circle(tmp_x, (gt_x, gt_y), 2, (255, 0, 0), -1)
        up_img = cv2.hconcat([tmp_x, tmp_y])
        down_img = cv2.hconcat([tmp_pred, tmp_map])
        img = cv2.vconcat([up_img, down_img])

        # Cast cv image to PIL image for saving gif format
        imgs.append(Image.fromarray(img))

    # Write the image sequence to gif once, after all frames are collected
    # (saving inside the loop re-encoded the whole gif for every frame).
    if imgs:
        imgs[0].save(f'{save_dir}/pred_cur.gif', format='GIF', save_all=True, append_images=imgs[1:], duration=1000, loop=0)

## Dataset

In [4]:
class Badminton_Dataset(Dataset):
    """ Dataset of fixed-length frame sequences with per-frame ball heatmaps.

        Every sample is a window of `num_frame` consecutive frames taken from one
        rally, together with the labelled ball coordinate and visibility of each
        frame. The index of all windows of a split is cached in an .npz file.
    """

    # Directory the sequence index (.npz) is written to and read back from.
    CACHE_DIR = "/kaggle/working/"

    def __init__(self, root_dir=data_dir, split='train', mode='2d', num_frame=3, slideing_step=1, frame_dir=None, debug=False):
        """
        Args:
            root_dir (string): Directory which contains one sub-directory per split.
            split (string): Name of the split sub-directory, e.g. 'train' or 'test'.
            mode (string): '2d' stacks the frames on the channel axis,
                any other value stacks them on an extra frame axis.
            num_frame (int): Number of frames in one input sequence.
            slideing_step (int): Stride of the sliding window over a rally.
            frame_dir (string, optional): If given (and debug is False), only the
                rally stored in this '<match_dir>/frame/<rally_id>' directory is used.
            debug (bool): If True, only the first 256 cached sequences are used.
        """
        self.HEIGHT = 288
        self.WIDTH = 512

        self.mag = 1
        self.sigma = 8

        self.root_dir = root_dir
        self.split = split
        self.mode = mode
        self.num_frame = num_frame
        self.slideing_step = slideing_step

        # Check, write and read the cache at the same location. (Checking under
        # root_dir while loading from CACHE_DIR could never find the file that
        # _gen_frame_files() had written.)
        cache_file = self._cache_file()
        if not os.path.exists(cache_file):
            self._gen_frame_files()

        with np.load(cache_file) as data_dict:
            if debug:
                num_debug = 256
                self.frame_files = data_dict['filename'][:num_debug] # (N, F)
                self.coordinates = data_dict['coordinates'][:num_debug] # (N, F, 2)
                self.visibility = data_dict['visibility'][:num_debug] # (N, F)
            elif frame_dir:
                self.frame_files, self.coordinates, self.visibility = self._gen_frame_unit(frame_dir)
            else:
                self.frame_files = data_dict['filename'] # (N, F)
                self.coordinates = data_dict['coordinates'] # (N, F, 2)
                self.visibility = data_dict['visibility'] # (N, F)

    def _cache_file(self):
        """ Return the path of the cached sequence index of the current configuration. """
        return os.path.join(self.CACHE_DIR, f'f{self.num_frame}_s{self.slideing_step}_{self.split}.npz')

    def _get_rally_dirs(self):
        """ Return the frame directories of all rallies of the split, ordered by match number. """
        match_dirs = list_dirs(os.path.join(self.root_dir, self.split))
        match_dirs = sorted(match_dirs, key=lambda s: int(s.split('match')[-1]))
        rally_dirs = []
        for match_dir in match_dirs:
            rally_dirs.extend(list_dirs(os.path.join(match_dir, 'frame')))
        return rally_dirs

    def _read_labels(self, csv_file, frame_dir):
        """ Read one rally label file.

            Returns the frame paths and the X, Y, Visibility columns, all sorted by
            frame id. Missing label values are replaced by 0.
        """
        label_df = pd.read_csv(csv_file, encoding='utf8').sort_values(by='Frame').fillna(0)
        frame_file = np.array([os.path.join(frame_dir, f'{f_id}.jpg') for f_id in label_df['Frame']])
        x, y, vis = np.array(label_df['X']), np.array(label_df['Y']), np.array(label_df['Visibility'])
        return frame_file, x, y, vis

    def _slide_windows(self, frame_file, x, y, vis):
        """ Slide a window of num_frame frames over one rally.

            A window is kept only if all of its frame files exist on disk.
            Returns three lists (frames, coordinates, visibility), one entry per window.
        """
        n = self.num_frame
        # Test every file once instead of once per window it takes part in.
        exists = [os.path.exists(p) for p in frame_file]
        frames, coords, visible = [], [], []
        # '+ 1' so that the last full window of the rally is included as well.
        for i in range(0, len(frame_file) - n + 1, self.slideing_step):
            if not all(exists[i:i + n]):
                continue
            frames.append(list(frame_file[i:i + n]))
            coords.append(list(zip(x[i:i + n], y[i:i + n])))
            visible.append(list(vis[i:i + n]))
        return frames, coords, visible

    def _to_arrays(self, frames, coords, visible):
        """ Convert the window lists to arrays of shape (N, F), (N, F, 2) and (N, F). """
        n = len(frames)
        frame_files = np.array(frames, dtype=str).reshape(n, self.num_frame)
        coordinates = np.array(coords, dtype=np.float64).reshape(n, self.num_frame, 2)
        visibility = np.array(visible, dtype=np.float64).reshape(n, self.num_frame)
        return frame_files, coordinates, visibility

    def _gen_frame_files(self):
        """ Build the sequence index of the whole split and save it to the cache file. """
        frames, coords, visible = [], [], []

        # Generate input sequences from each rally
        for rally_dir in tqdm(self._get_rally_dirs()):
            match_dir, rally_id = parse.parse('{}/frame/{}', rally_dir)
            csv_file = os.path.join(match_dir, 'csv', f'{rally_id}_ball.csv')
            try:
                frame_file, x, y, vis = self._read_labels(csv_file, rally_dir)
            except OSError:
                # A rally without a readable label file is skipped, not fatal.
                print(f'Label file {rally_id}_ball.csv not found.')
                continue

            # Collect in lists; concatenating arrays per window is quadratic.
            rally_frames, rally_coords, rally_visible = self._slide_windows(frame_file, x, y, vis)
            frames.extend(rally_frames)
            coords.extend(rally_coords)
            visible.extend(rally_visible)

        frame_files, coordinates, visibility = self._to_arrays(frames, coords, visible)
        np.savez(self._cache_file(), filename=frame_files, coordinates=coordinates, visibility=visibility)

    def _gen_frame_unit(self, frame_dir):
        """ Build the sequence index of the single rally stored in frame_dir.

            frame_dir must have the form '<match_dir>/frame/<rally_id>'.
            Returns (frame_files, coordinates, visibility).
        """
        match_dir, rally_id = parse.parse('{}/frame/{}', frame_dir)
        csv_file = f'{match_dir}/csv/{rally_id}_ball.csv'
        frame_file, x, y, vis = self._read_labels(csv_file, frame_dir)
        return self._to_arrays(*self._slide_windows(frame_file, x, y, vis))

    def _get_heatmap(self, cx, cy, visible):
        """ Return a binary disc of radius sigma around (cx, cy), scaled by mag.

            The result is all zeros if the ball is not visible. Shape is (1, H, W)
            in '2d' mode and (1, 1, H, W) otherwise.
        """
        shape = (1, self.HEIGHT, self.WIDTH) if self.mode == '2d' else (1, 1, self.HEIGHT, self.WIDTH)
        if not visible:
            return np.zeros(shape)
        x, y = np.meshgrid(np.linspace(1, self.WIDTH, self.WIDTH), np.linspace(1, self.HEIGHT, self.HEIGHT))
        dist2 = ((y - (cy + 1))**2) + ((x - (cx + 1))**2)
        # Threshold with a mask: overwriting dist2 in two passes only works while sigma**2 >= 1.
        heatmap = (dist2 <= self.sigma**2).astype(np.float64) * self.mag
        return heatmap.reshape(shape)

    def __len__(self):
        return len(self.frame_files)

    def _load_frame(self, path):
        """ Load an image, resize it to (WIDTH, HEIGHT) and return it channel-first as (3, H, W). """
        img = tf.keras.utils.load_img(path)
        img = tf.keras.utils.img_to_array(img.resize(size=(self.WIDTH, self.HEIGHT)))
        return np.moveaxis(img, -1, 0)

    def __getitem__(self, idx):
        """ Return (idx, frames, heatmaps, coors) of one sequence.

            frames are scaled to [0, 1]; coors are the label coordinates rescaled to
            the (WIDTH, HEIGHT) network resolution.
        """
        frame_file = self.frame_files[idx]
        # Work on a copy: integer indexing returns a view, so rescaling it in place
        # would shrink the stored coordinates again on every access.
        coors = self.coordinates[idx].astype(np.float64)
        vis = self.visibility[idx]

        # Get the resize scaler
        first_frame = cv2.imread(frame_file[0])
        if first_frame is None:
            raise FileNotFoundError(f'Cannot read frame {frame_file[0]}')
        h, w, _ = first_frame.shape
        h_ratio, w_ratio = h / self.HEIGHT, w / self.WIDTH

        # Transform the coordinate: column 0 is x (width axis), column 1 is y (height axis)
        coors[:, 0] = coors[:, 0] / w_ratio
        coors[:, 1] = coors[:, 1] / h_ratio

        imgs = [self._load_frame(frame_file[i]) for i in range(self.num_frame)]
        maps = [self._get_heatmap(int(coors[i][0]), int(coors[i][1]), vis[i]) for i in range(self.num_frame)]

        if self.mode == '2d':
            frames = np.concatenate(imgs, axis=0)       # (F*3, H, W)
            heatmaps = np.concatenate(maps, axis=0)     # (F, H, W)
        else:
            frames = np.stack(imgs, axis=1)             # (3, F, H, W)
            heatmaps = np.concatenate(maps, axis=1)     # (1, F, H, W)

        # Scale pixel values to [0, 1] in float64.
        frames = frames.astype(np.float64) / 255.
        return idx, frames, heatmaps, coors

## Training

In [None]:
import os
import json
import time
import argparse
import numpy as np

import torch
from torch.utils.data import DataLoader

from tqdm import tqdm

# Let cuDNN benchmark its kernels (PyTorch setting).
torch.backends.cudnn.benchmark = True

# Training configuration
model_name = "TrackNetV2"
num_frame = 3
input_type = "2d"
epochs = 15
batch_size = 8
learning_rate = 0.001
tolerance = 4
save_dir = "exp"
resume_training = False
debug = False
save_dir = f'{save_dir}_debug' if debug else save_dir
display_step = 4 if debug else 100

os.makedirs(save_dir, exist_ok=True)

# Load dataset
print(f'Data dir: {data_dir}')
print(f'Data input type: {input_type}')
train_dataset = Badminton_Dataset(root_dir=data_dir, split='train', mode=input_type, num_frame=num_frame, slideing_step=1, debug=debug)
# Evaluation uses non-overlapping windows (stride == num_frame).
eval_test_dataset = Badminton_Dataset(root_dir=data_dir, split='test', mode=input_type, num_frame=num_frame, slideing_step=num_frame, debug=debug)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4, drop_last=True, pin_memory=False)  # pin_memory changed to False
eval_loader = DataLoader(eval_test_dataset, batch_size=batch_size, shuffle=False, num_workers=4, drop_last=False, pin_memory=False)  # pin_memory changed to False
if __name__ == '__main__':

    # create model
    model = get_model(model_name, num_frame, input_type).cuda()
    model_summary(model, model_name)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    if not resume_training:
        loss_list = []
        test_acc_dict = {'TP':[], 'TN': [], 'FP1': [], 'FP2': [], 'FN': [], 'accuracy': [], 'precision': [], 'recall': []}
        start_epoch = 0
        max_test_acc = 0.
    else:
        # Resume from the checkpoint this script writes after every epoch.
        checkpoint_file = f'{save_dir}/model_cur.pt'
        if not os.path.exists(checkpoint_file):
            raise FileNotFoundError(f'No checkpoint found in {save_dir}')
        checkpoint = torch.load(checkpoint_file)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        loss_list = checkpoint['loss_list']
        test_acc_dict = checkpoint['test_acc']
        start_epoch = checkpoint['epoch'] + 1
        # model_cur.pt is saved before the evaluation of its epoch, so the
        # accuracy history may still be empty.
        max_test_acc = max(test_acc_dict['accuracy'], default=0.)
        print(f'Resume training from epoch {start_epoch}.')

    # training loop
    train_start_time = time.time()
    # Keep `epoch` defined for the final save even if the loop body never runs.
    epoch = start_epoch - 1
    for epoch in range(start_epoch, epochs):
        start_time = time.time()
        loss = train(epoch, model, optimizer, WeightedBinaryCrossEntropy, train_loader, input_type, display_step, save_dir)
        loss_list.append(loss)
        # Checkpoint right after training so an interrupted evaluation does not lose the epoch.
        torch.save(dict(epoch=epoch,
                        model_state_dict=model.state_dict(),
                        optimizer_state_dict=optimizer.state_dict(),
                        loss_list=loss_list,
                        test_acc=test_acc_dict), f'{save_dir}/model_cur.pt')

        accuracy, precision, recall, TP, TN, FP1, FP2, FN = evaluation(model, eval_loader, tolerance, input_type)
        # Only the number of samples of each result type is recorded.
        TP, TN, FP1, FP2, FN = len(TP), len(TN), len(FP1), len(FP2), len(FN)

        test_acc_dict['TP'].append(TP)
        test_acc_dict['TN'].append(TN)
        test_acc_dict['FP1'].append(FP1)
        test_acc_dict['FP2'].append(FP2)
        test_acc_dict['FN'].append(FN)
        test_acc_dict['accuracy'].append(accuracy)
        test_acc_dict['precision'].append(precision)
        test_acc_dict['recall'].append(recall)

        print(f'[epoch: {epoch})]\tEpoch runtime: {(time.time() - start_time) / 3600.:.2f} hrs')
        plot_result(loss_list, None, test_acc_dict, num_frame, save_dir, model_name)

        # Keep the checkpoint with the best test accuracy so far (ties favour the later epoch).
        if test_acc_dict['accuracy'][-1] >= max_test_acc:
            max_test_acc = test_acc_dict['accuracy'][-1]
            torch.save(dict(epoch=epoch,
                            model_state_dict=model.state_dict(),
                            optimizer_state_dict=optimizer.state_dict(),
                            loss_list=loss_list,
                            test_acc=test_acc_dict), f'{save_dir}/model_best.pt')

    torch.save(dict(epoch=epoch,
                    model_state_dict=model.state_dict(),
                    optimizer_state_dict=optimizer.state_dict(),
                    loss_list=loss_list,
                    test_acc=test_acc_dict), f'{save_dir}/model_last.pt')

    print(f'runtime: {(time.time() - train_start_time) / 3600.:.2f} hrs')
    print('Done......')

## Prediction on a video using the best model checkpoint (`model_best.pt`)

In [8]:
video_file = "/kaggle/input/tracknet-v3-model-testing/00002.mp4"
model_file = "/kaggle/input/model-for-ball/exp/model_best.pt"
num_frame = 3
batch_size = 1
save_dir = "/kaggle/working/"

video_name = video_file.split('/')[-1][:-4]
video_format = video_file.split('/')[-1][-3:]
out_video_file = f'{save_dir}/{video_name}_pred.{video_format}'
out_csv_file = f'{save_dir}/{video_name}_ball.csv'

checkpoint = torch.load(model_file)
model_name = "TrackNetV2"
input_type = "2d"

os.makedirs(save_dir, exist_ok=True)

# Load model
model = get_model(model_name, num_frame, input_type).cuda()
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

# Video output configuration
if video_format == 'avi':
    fourcc = cv2.VideoWriter_fourcc(*'DIVX')
elif video_format == 'mp4':
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
else:
    raise ValueError('Invalid video format.')


def find_rectangle_from_intersection(intersection_x, intersection_y, height=22, width=22):
    """ Return (top, left, height, width) of a box centred on the given point. """
    left = intersection_x - width / 2
    top = intersection_y - height / 2
    return float(top), float(left), float(height), float(width)


# Cap configuration
cap = cv2.VideoCapture(video_file)
fps = int(cap.get(cv2.CAP_PROP_FPS))
w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
success = True
frame_count = 0
num_final_frame = 0
# Scale factors from the heatmap resolution back to the video resolution.
w_ratio, h_ratio = w / WIDTH, h / HEIGHT
ratio = h_ratio  # name kept for cells that still refer to it
out = cv2.VideoWriter(out_video_file, fourcc, fps, (w, h))
seq_size = num_frame * batch_size  # frames consumed per inference step

try:
    # The context manager guarantees the csv is flushed and closed.
    with open(out_csv_file, 'w') as csv_out:
        # Write csv file head
        csv_out.write('Frame,Visibility,X,Y\n')

        while success:
            print(f'Number of sampled frames: {frame_count}')
            # Sample frames to form input sequence
            frame_queue = []
            for _ in range(seq_size):
                success, frame = cap.read()
                if not success:
                    break
                frame_count += 1
                frame_queue.append(frame)

            if not frame_queue:
                break

            # If mini batch incomplete
            if len(frame_queue) % num_frame != 0:
                # Record the number of remaining frames BEFORE discarding them
                num_final_frame = len(frame_queue)
                frame_queue = []
                # Rewind so that the last seq_size frames of the video form the batch
                frame_count = frame_count - seq_size
                if frame_count < 0:
                    # The video is shorter than one input sequence.
                    break
                cap.set(cv2.CAP_PROP_POS_FRAMES, frame_count)
                # Re-sample mini batch
                for _ in range(seq_size):
                    success, frame = cap.read()
                    if not success:
                        break
                    frame_count += 1
                    frame_queue.append(frame)
                if len(frame_queue) % num_frame != 0:
                    continue

            x = get_frame_unit(frame_queue, num_frame)

            # Inference
            with torch.no_grad():
                y_pred = model(x.cuda())
            y_pred = y_pred.detach().cpu().numpy()
            # Binarise the predicted heatmaps into 0/255 masks, one per frame
            h_pred = ((y_pred > 0.5) * 255.).astype('uint8')
            h_pred = h_pred.reshape(-1, HEIGHT, WIDTH)

            # Index of the first frame held in frame_queue
            batch_start = frame_count - len(frame_queue)
            # In the re-sampled last batch the leading frames were already written
            num_written = seq_size - num_final_frame if num_final_frame > 0 else 0

            for i in range(h_pred.shape[0]):
                if i < num_written:
                    # Special case of last incomplete mini batch
                    # Ignore the frame which is already written to the output video
                    continue
                img = frame_queue[i].copy()
                cx_pred, cy_pred = get_object_center(h_pred[i])
                cx_pred, cy_pred = int(w_ratio * cx_pred), int(h_ratio * cy_pred)
                vis = 1 if cx_pred > 0 and cy_pred > 0 else 0
                # Write prediction result
                csv_out.write(f'{batch_start + i},{vis},{cx_pred},{cy_pred}\n')
                if cx_pred != 0 or cy_pred != 0:
                    top, left, height, width = find_rectangle_from_intersection(cx_pred, cy_pred)
                    cv2.rectangle(img, (int(left), int(top)), (int(left + width), int(top + height)), (255, 0, 0), 2)
                out.write(img)
finally:
    # Release the video handles even if inference fails half way.
    cap.release()
    out.release()
print('Done.')

Number of sampled frames: 0
Number of sampled frames: 3
Number of sampled frames: 6
Number of sampled frames: 9
Number of sampled frames: 12
Number of sampled frames: 15
Number of sampled frames: 18
Number of sampled frames: 21
Number of sampled frames: 24
Number of sampled frames: 27
Number of sampled frames: 30
Number of sampled frames: 33
Number of sampled frames: 36
Number of sampled frames: 39
Number of sampled frames: 42
Number of sampled frames: 45
Number of sampled frames: 48
Number of sampled frames: 51
Number of sampled frames: 54
Number of sampled frames: 57
Number of sampled frames: 60
Number of sampled frames: 63
Number of sampled frames: 66
Number of sampled frames: 69
Number of sampled frames: 72
Number of sampled frames: 75
Number of sampled frames: 78
Number of sampled frames: 81
Number of sampled frames: 84
Number of sampled frames: 87
Number of sampled frames: 90
Number of sampled frames: 93
Number of sampled frames: 96
Number of sampled frames: 99
Number of sampled 