In [2]:
import torch
import torch.nn as nn
import cv2
from tqdm import tqdm
import numpy as np
import argparse
from itertools import groupby
from scipy.spatial import distance

In [3]:
import catboost as ctb

In [4]:
import pandas as pd

In [60]:
# NOTE: process_video and its helpers are defined in cells further down this notebook
# (In [29], In [6], In [55], In [8], In [9]); run those first on a fresh kernel.
# Forward slashes work as path separators on Windows as well as Linux/macOS.
process_video('sample_videos/jason_serve.mp4',
              'output_videos/jason_serve.avi',
              'Ball Detection Models/Pre-trained/model_best.pt',
              'Bounce Detection Models/model_bounce.ctb',
              'Bounce Detection Models/model_hit.ctb',
              extrapolation=True, extra=False)

100%|██████████| 121/121 [01:14<00:00,  1.63it/s]


In [29]:
def process_video(video_path, output_path, ball_model_path, bounce_model_path, hit_model_path, extrapolation, extra):
    """ Detect the ball in a video and write an annotated copy
    :params
        video_path: input video file
        output_path: path of the annotated output video
        ball_model_path: BallTrackerNet state_dict checkpoint
        bounce_model_path: CatBoost model scoring bounces (loaded only when extra is True)
        hit_model_path: CatBoost model scoring hits (loaded only when extra is True)
        extrapolation: if True, fill short gaps in the ball track by interpolation
        extra: if True, predict bounces/hits and draw them on the video
    """
    # Fall back to CPU so the pipeline also runs on machines without CUDA.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    ball_model = BallTrackerNet()
    # torch.load unpickles the checkpoint; only load files you trust.
    ball_model.load_state_dict(torch.load(ball_model_path, map_location=device))
    ball_model = ball_model.to(device)
    # Inference mode: BatchNorm must use its running statistics, not per-batch ones.
    ball_model.eval()

    frames, fps = read_video(video_path)
    ball_track, dists = infer_model(frames, ball_model)
    ball_track = remove_outliers(ball_track, dists)

    if extrapolation:
        for start, end in split_track(ball_track):
            ball_track[start:end] = interpolation(ball_track[start:end])

    bounce_prediction = None
    hit_prediction = None

    if extra:
        bounce_model = ctb.CatBoostRegressor()
        bounce_model.load_model(bounce_model_path)

        hit_model = ctb.CatBoostRegressor()
        hit_model.load_model(hit_model_path)

        features = prepare_new_data(ball_track, 3)
        # prepare_new_data drops frames without a complete feature window (the first and
        # last frames, frames with no ball) but keeps each row's frame number as its index.
        # Scatter the predictions back so entry i always describes frame i; frames without
        # features get a score of 0, i.e. no event.
        frame_ids = features.index.to_numpy()
        bounce_prediction = np.zeros(len(ball_track))
        hit_prediction = np.zeros(len(ball_track))
        if len(frame_ids):
            bounce_prediction[frame_ids] = bounce_model.predict(features)
            hit_prediction[frame_ids] = hit_model.predict(features)

    write_track(frames, ball_track, bounce_prediction, hit_prediction, output_path, fps, extra)

In [6]:
# Feature engineering for the CatBoost bounce/hit models. The models themselves are
# loaded in process_video, which passes prepare_new_data's output to their predict().

def create_features(df, num_frames=3):
    eps = 1e-15
    for i in range(1, num_frames):
        df['x_lag_{}'.format(i)] = df['x-coordinate'].shift(i)
        df['x_lag_inv_{}'.format(i)] = df['x-coordinate'].shift(-i)
        df['y_lag_{}'.format(i)] = df['y-coordinate'].shift(i)
        df['y_lag_inv_{}'.format(i)] = df['y-coordinate'].shift(-i)
        df['x_diff_{}'.format(i)] = abs(df['x_lag_{}'.format(i)] - df['x-coordinate'])
        df['y_diff_{}'.format(i)] = df['y_lag_{}'.format(i)] - df['y-coordinate']
        df['x_diff_inv_{}'.format(i)] = abs(df['x_lag_inv_{}'.format(i)] - df['x-coordinate'])
        df['y_diff_inv_{}'.format(i)] = df['y_lag_inv_{}'.format(i)] - df['y-coordinate']
        df['x_div_{}'.format(i)] = abs(df['x_diff_{}'.format(i)]/(df['x_diff_inv_{}'.format(i)] + eps))
        df['y_div_{}'.format(i)] = df['y_diff_{}'.format(i)]/(df['y_diff_inv_{}'.format(i)] + eps)
    
    # Drop rows with NaN values generated by lag features
    for i in range(1, num_frames):
        df = df[df['x_lag_{}'.format(i)].notna()]
        df = df[df['x_lag_inv_{}'.format(i)].notna()]
    df = df[df['x-coordinate'].notna()]  

    return df

def prepare_new_data(coordinate_list, num_frames=3):
    """Build the feature matrix consumed by the CatBoost bounce/hit models.

    :params
        coordinate_list: sequence of (x, y) ball positions, or a DataFrame that has
            'x-coordinate' and 'y-coordinate' columns
        num_frames: temporal window size forwarded to create_features
    :return
        DataFrame holding the diff, diff_inv and div features for x, followed by the
        same three groups for y, restricted to the rows create_features keeps
    """
    frame_coords = pd.DataFrame(coordinate_list, columns=['x-coordinate', 'y-coordinate'])
    enriched = create_features(frame_coords, num_frames)

    # Column order: axis -> feature kind -> offset, e.g. x_diff_1, x_diff_2, x_diff_inv_1, ...
    feature_columns = [
        f'{axis}_{kind}_{offset}'
        for axis in ('x', 'y')
        for kind in ('diff', 'diff_inv', 'div')
        for offset in range(1, num_frames)
    ]
    return enriched[feature_columns]

In [55]:
def read_video(path_video):
    """ Read every frame of a video file into memory
    :params
        path_video: path to video file
    :return
        frames: list of video frames as returned by cv2.VideoCapture.read
        fps: frames per second, rounded to the nearest integer
    :raises
        OSError: if the video file cannot be opened
    """
    cap = cv2.VideoCapture(path_video)
    try:
        if not cap.isOpened():
            raise OSError('Cannot open video file: {}'.format(path_video))
        # round() instead of int(): int() truncates e.g. 29.97 to 29, which makes the
        # written video play ~3% slower than the source.
        fps = int(round(cap.get(cv2.CAP_PROP_FPS)))

        frames = []
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            frames.append(frame)
    finally:
        cap.release()
    return frames, fps

def infer_model(frames, model):
    """ Run pretrained model on a consecutive list of frames
    Each prediction uses the current frame and the two before it, resized to 640x360
    and stacked along the channel axis, so the first two frames get no prediction.
    The model is switched to eval mode and run without autograd.
    :params
        frames: list of consecutive video frames
        model: pretrained ball-detection model; inputs are sent to the device its
            parameters live on
    :return
        ball_track: list of detected ball points, (None, None) where nothing was
            detected; one entry per frame
        dists: list of euclidean distances between two neighbouring ball points,
            -1 where either point is missing
    """
    height = 360
    width = 640
    device = next(model.parameters()).device
    model.eval()

    # The first two frames have no full 3-frame window.
    warmup = min(2, len(frames))
    dists = [-1] * warmup
    ball_track = [(None, None)] * warmup
    with torch.no_grad():
        for num in tqdm(range(2, len(frames))):
            img = cv2.resize(frames[num], (width, height))
            img_prev = cv2.resize(frames[num - 1], (width, height))
            img_preprev = cv2.resize(frames[num - 2], (width, height))
            imgs = np.concatenate((img, img_prev, img_preprev), axis=2)
            imgs = imgs.astype(np.float32) / 255.0
            imgs = np.rollaxis(imgs, 2, 0)  # HWC -> CHW
            inp = np.expand_dims(imgs, axis=0)

            out = model(torch.from_numpy(inp).float().to(device))
            output = out.argmax(dim=1).cpu().numpy()
            x_pred, y_pred = postprocess(output)
            ball_track.append((x_pred, y_pred))

            # `is not None` so a detection at x == 0 is not mistaken for a miss.
            if ball_track[-1][0] is not None and ball_track[-2][0] is not None:
                dist = distance.euclidean(ball_track[-1], ball_track[-2])
            else:
                dist = -1
            dists.append(dist)
    return ball_track, dists

def remove_outliers(ball_track, dists, max_dist=100):
    """ Remove outliers from model prediction
    A jump larger than `max_dist` between neighbouring detections marks either the
    current point (when the following step is also a large jump, or is missing) or,
    when the point before the jump had no detected predecessor, that previous point.
    ball_track is modified in place and also returned.
    :params
        ball_track: list of detected ball points, (None, None) where missing
        dists: list of euclidean distances between two neighbouring ball points,
            -1 where either point is missing; same length as ball_track
        max_dist: maximum distance between two neighbouring ball points
    :return
        ball_track: list of ball points
    """
    # Iterate over a fixed index array and never mutate it: removing items from the
    # list being iterated skips the element after each removal.
    outliers = np.where(np.array(dists) > max_dist)[0]
    for i in outliers:
        # A jump on the last step has no following step; treat that as "missing".
        next_dist = dists[i + 1] if i + 1 < len(dists) else -1
        if next_dist > max_dist or next_dist == -1:
            ball_track[i] = (None, None)
        elif i > 0 and dists[i - 1] == -1:
            ball_track[i - 1] = (None, None)
    return ball_track

def split_track(ball_track, max_gap=4, max_dist_gap=80, min_track=5):
    """ Split ball track into several subtracks in each of which we will perform
    ball interpolation.
    A run of missing frames that has detections on both sides splits the track when
    it is at least `max_gap` frames long, or when the distance across it divided by
    its length exceeds `max_dist_gap`.
    :params
        ball_track: list of detected ball points, (None, None) where missing
        max_gap: gaps of this many consecutive missing frames or more split the track;
            shorter gaps stay inside a subtrack and get interpolated
        max_dist_gap: maximum distance per missing frame across a gap for the points on
            either side to remain in one subtrack
        min_track: a subtrack is kept only if it spans more than this many frames
    :return
        result: list of [start, end) index pairs (end exclusive) into ball_track
    """
    # 1 = missing, 0 = detected (note: an x coordinate of 0 is also counted as missing)
    list_det = [0 if x[0] else 1 for x in ball_track]
    # run-length encoding: (flag, run length) for each maximal run of equal flags
    groups = [(k, sum(1 for _ in g)) for k, g in groupby(list_det)]

    cursor = 0  # index of the first frame of the current group
    min_value = 0  # start index of the subtrack being built
    result = []
    for i, (k, l) in enumerate(groups):
        # only gaps with detections on both sides (leading/trailing gaps are skipped)
        if (k == 1) & (i > 0) & (i < len(groups) - 1):
            # last detection before the gap vs first detection after it
            dist = distance.euclidean(ball_track[cursor-1], ball_track[cursor+l])
            if (l >=max_gap) | (dist/l > max_dist_gap):
                if cursor - min_value > min_track:
                    result.append([min_value, cursor])
                    # the next subtrack starts at the last missing frame of this gap
                    min_value = cursor + l - 1
                # NOTE(review): if the segment before this gap is too short, min_value
                # is not advanced, so the next subtrack will span this gap -- confirm
                # that is intended.
        cursor += l
    # close the final subtrack (it may include trailing missing frames)
    if len(list_det) - min_value > min_track:
        result.append([min_value, len(list_det)])
    return result

def interpolation(coords):
    """ Run ball interpolation in one subtrack
    Missing points (None) are filled by linear interpolation between the nearest
    detected points. Missing points before the first or after the last detection take
    that detection's value, because np.interp clamps outside the sample range.
    :params
        coords: list of (x, y) ball coordinates of one subtrack, None where missing
    :return
        track: list of (x, y) tuples of one subtrack. If an axis has no detected value
            at all, the subtrack is returned unchanged (still None), because there is
            nothing to interpolate from.
    """
    xs = np.array([p[0] if p[0] is not None else np.nan for p in coords])
    ys = np.array([p[1] if p[1] is not None else np.nan for p in coords])

    missing_x = np.isnan(xs)
    missing_y = np.isnan(ys)
    # np.interp raises on an empty sample set; this also covers an empty subtrack.
    if missing_x.all() or missing_y.all():
        return list(coords)

    frame_idx = np.arange(len(coords))
    xs[missing_x] = np.interp(frame_idx[missing_x], frame_idx[~missing_x], xs[~missing_x])
    ys[missing_y] = np.interp(frame_idx[missing_y], frame_idx[~missing_y], ys[~missing_y])

    track = list(zip(xs, ys))
    return track

def write_track(frames, ball_track, bounce_pred, hit_pred, path_output_video, fps, extra, trace=5,
                bounce_threshold=0.3, hit_threshold=0.1):
    """ Write .avi file with detected ball tracks
    :params
        frames: list of original video frames; markers are drawn onto these arrays in place
        ball_track: list of (x, y) ball coordinates, one per frame, expressed for a
            1280x720 frame (rescaled here to the real frame size); None where missing
        bounce_pred: per-frame bounce scores indexed by frame number (used only when extra)
        hit_pred: per-frame hit scores indexed by frame number (used only when extra)
        path_output_video: path to output video
        fps: frames per second
        extra: if True, classify each frame as bounce / hit / neither and draw the result
        trace: number of frames with detected trace
        bounce_threshold: a frame is a bounce candidate when its bounce score exceeds this
        hit_threshold: a frame is a hit candidate when its hit score exceeds this
    """
    height, width = frames[0].shape[:2]
    # ball_track is in 1280x720 coordinates; scale to the actual video size.
    ratio_x = width / 1280.0
    ratio_y = height / 720.0

    def to_pixel(point):
        # map a 1280x720 (x, y) point to integer pixel coordinates in this video
        return int(point[0] * ratio_x), int(point[1] * ratio_y)

    if extra:
        # Only frames that have a prediction are written. Truncate before iterating so
        # the loop bound matches the shortened list.
        frames = frames[:min(len(bounce_pred), len(hit_pred))]

    bounces = []  # frame numbers classified as bounces so far; drawn on every later frame
    out = cv2.VideoWriter(path_output_video, cv2.VideoWriter_fourcc(*'DIVX'),
                          fps, (width, height))
    try:
        for num, frame in enumerate(frames):
            hit = False
            if extra:
                bounce_score, hit_score = bounce_pred[num], hit_pred[num]
                is_bounce = bounce_score > bounce_threshold
                is_hit = hit_score > hit_threshold
                # When both fire, the larger score wins (ties go to hit).
                if is_bounce and (not is_hit or bounce_score > hit_score):
                    bounces.append(num)
                elif is_hit:
                    hit = True

                # Bounces stay permanently on the image.
                for bounce_frame in bounces:
                    position = ball_track[bounce_frame]
                    if position[0] is not None:
                        frame = cv2.circle(frame, to_pixel(position), radius=3,
                                           color=(255, 0, 0), thickness=-1)

            # Trace is green on hit frames, red otherwise (OpenCV colours are BGR).
            color = (0, 255, 0) if hit else (0, 0, 255)
            for i in range(trace):
                idx = num - i
                if idx <= 0 or ball_track[idx][0] is None:
                    break
                frame = cv2.circle(frame, to_pixel(ball_track[idx]), radius=0,
                                   color=color, thickness=10 - i)
            out.write(frame)
    finally:
        out.release()

In [8]:
def postprocess(feature_map, scale=2):
    """ Extract the ball centre from the model's per-pixel class map.
    :params
        feature_map: array of per-pixel class indices with 360*640 elements (it is
            reshaped to (360, 640) below). Note: `*= 255` modifies the caller's NumPy
            array in place.
        scale: multiplier applied to the detected centre; the default 2 maps the
            640x360 model grid to 1280x720 coordinates
    :return
        (x, y) centre of the first circle returned by cv2.HoughCircles, or
        (None, None) otherwise
    """
    # NOTE(review): infer_model passes argmax class indices in 0..255; multiplied by 255
    # they exceed the uint8 range and astype(np.uint8) below wraps them, so high-valued
    # classes can end up under the 127 threshold. Confirm against how the model was
    # trained before changing this.
    feature_map *= 255
    feature_map = feature_map.reshape((360, 640))
    feature_map = feature_map.astype(np.uint8)
    # binarise: pixels above 127 become 255, everything else 0
    ret, heatmap = cv2.threshold(feature_map, 127, 255, cv2.THRESH_BINARY)
    circles = cv2.HoughCircles(heatmap, cv2.HOUGH_GRADIENT, dp=1, minDist=1, param1=50, param2=2, minRadius=2,
                               maxRadius=7)
    x,y = None, None
    if circles is not None:
        # NOTE(review): the Python bindings typically return circles shaped (1, N, 3);
        # if so, this check is always true and the first circle is used even when
        # several are detected -- confirm intended.
        if len(circles) == 1:
            x = circles[0][0][0]*scale
            y = circles[0][0][1]*scale
    return x, y

In [9]:
class ConvBlock(nn.Module):
    """Conv2d followed by ReLU and then BatchNorm2d.

    The three layers are kept in ``self.block`` (an nn.Sequential), so checkpoint keys
    are ``block.0.*`` for the convolution and ``block.2.*`` for the batch norm.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, pad=1, stride=1, bias=True):
        super().__init__()
        conv = nn.Conv2d(in_channels, out_channels, kernel_size,
                         stride=stride, padding=pad, bias=bias)
        activation = nn.ReLU()
        norm = nn.BatchNorm2d(out_channels)
        self.block = nn.Sequential(conv, activation, norm)

    def forward(self, x):
        """Apply conv -> ReLU -> batch norm to ``x``."""
        y = self.block(x)
        return y

class BallTrackerNet(nn.Module):
    """Encoder/decoder that maps a stack of frames to per-pixel class scores.

    Input has 9 channels (three stacked 3-channel frames). Three 2x max-pools are
    followed by three 2x upsamplings, so the output grid matches the input size when
    H and W are divisible by 8. ``forward`` returns a (batch, out_channels, H*W)
    tensor, softmax-normalised over dim 1 only when ``testing`` is True.
    """

    # (attribute name, layer kind, in_channels, out_channels) in registration and
    # execution order. Attribute names must stay conv1..conv18 so pretrained
    # state_dict keys keep matching; out_channels None means self.out_channels.
    _LAYERS = (
        ('conv1', 'conv', 9, 64),
        ('conv2', 'conv', 64, 64),
        ('pool1', 'pool', None, None),
        ('conv3', 'conv', 64, 128),
        ('conv4', 'conv', 128, 128),
        ('pool2', 'pool', None, None),
        ('conv5', 'conv', 128, 256),
        ('conv6', 'conv', 256, 256),
        ('conv7', 'conv', 256, 256),
        ('pool3', 'pool', None, None),
        ('conv8', 'conv', 256, 512),
        ('conv9', 'conv', 512, 512),
        ('conv10', 'conv', 512, 512),
        ('ups1', 'up', None, None),
        ('conv11', 'conv', 512, 256),
        ('conv12', 'conv', 256, 256),
        ('conv13', 'conv', 256, 256),
        ('ups2', 'up', None, None),
        ('conv14', 'conv', 256, 128),
        ('conv15', 'conv', 128, 128),
        ('ups3', 'up', None, None),
        ('conv16', 'conv', 128, 64),
        ('conv17', 'conv', 64, 64),
        ('conv18', 'conv', 64, None),
    )

    def __init__(self, out_channels=256):
        super().__init__()
        self.out_channels = out_channels

        for name, kind, c_in, c_out in self._LAYERS:
            if kind == 'conv':
                width = self.out_channels if c_out is None else c_out
                layer = ConvBlock(in_channels=c_in, out_channels=width)
            elif kind == 'pool':
                layer = nn.MaxPool2d(kernel_size=2, stride=2)
            else:
                layer = nn.Upsample(scale_factor=2)
            setattr(self, name, layer)  # nn.Module.__setattr__ registers the submodule

        self.softmax = nn.Softmax(dim=1)
        self._init_weights()

    def forward(self, x, testing=False):
        """Run every layer in order; flatten the spatial dims of the result."""
        batch_size = x.size(0)
        for name, _kind, _c_in, _c_out in self._LAYERS:
            x = getattr(self, name)(x)
        out = x.reshape(batch_size, self.out_channels, -1)
        return self.softmax(out) if testing else out

    def _init_weights(self):
        """Uniform(-0.05, 0.05) conv weights, zero biases, identity batch norms."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.uniform_(m.weight, -0.05, 0.05)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)