In [1]:
import os
import cv2
import numpy as np
import torch
from tracknet import BallTrackerNet
import torch.nn.functional as F
from tqdm import tqdm
from postprocess import postprocess, refine_kps
from homography import get_trans_matrix, refer_kps
import argparse


In [24]:
input_path = 'tennis_court.png'
#input_path = 'swingvision_2.png'
use_refine_kps = False
use_homography = False
model_path = 'model_tennis_court_det.pt'
output_path = 'tennis_court_out.jpg'

model = BallTrackerNet(out_channels=15)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = model.to(device)
model.load_state_dict(torch.load(model_path, map_location=device))
model.eval()

OUTPUT_WIDTH = 640
OUTPUT_HEIGHT = 360

image = cv2.imread(input_path)
img = cv2.resize(image, (OUTPUT_WIDTH, OUTPUT_HEIGHT))
inp = (img.astype(np.float32) / 255.)
inp = torch.tensor(np.rollaxis(inp, 2, 0))
inp = inp.unsqueeze(0)

out = model(inp.float().to(device))[0]
pred = F.sigmoid(out).detach().cpu().numpy()
points = []
for kps_num in range(14):
    heatmap = (pred[kps_num]*255).astype(np.uint8)
    x_pred, y_pred = postprocess(heatmap, scale = 1,low_thresh=170, max_radius=25)
    if use_refine_kps and kps_num not in [8, 12, 9] and x_pred and y_pred:
        x_pred, y_pred = refine_kps(image, int(y_pred), int(x_pred))
    points.append((x_pred, y_pred))

if use_homography:
    matrix_trans = get_trans_matrix(points)
    if matrix_trans is not None:
        points = cv2.perspectiveTransform(refer_kps, matrix_trans)
        points = [np.squeeze(x) for x in points]

for j in range(len(points)):
    if points[j][0] is not None:
        image = cv2.circle(img, (int(points[j][0]), int(points[j][1])),
                           radius=0, color=(0, 0, 255), thickness=10)

cv2.imwrite(output_path, image)

  model.load_state_dict(torch.load(model_path, map_location=device))


True