In [None]:
import json
import os
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import torch.nn as nn
from tensorboardX import SummaryWriter
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
from torch.utils.data import Dataset
import os
import cv2
import numpy as np
import json
import numpy as np
from sympy import Line 
import sympy
from scipy.spatial import distance

In [None]:
# Training configuration read by the cells below.

# Best validation accuracy seen so far; the training loop overwrites it
# whenever a validation run scores higher.
val_best_accuracy = 0
# Number of iterations of the outer epoch loop.
num_epochs = 25
# Passed to train() as max_iters (cap on batches per epoch).
steps_per_epoch = 300
# Validation runs on epochs where epoch > 0 and epoch % val_intervals == 0.
# NOTE(review): with num_epochs = 25 that is epoch 20 only - confirm intended.
val_intervals = 20
# Batch size used for both the train and the val DataLoader.
batch_size = 2
# Learning rate handed to the Adam optimizer.
lr = 1e-5

In [None]:
def gaussian2D(shape, sigma=1):
    """Evaluate an unnormalised 2-D Gaussian on a grid centred at zero.

    Args:
        shape: ``(rows, cols)`` extent of the grid.
        sigma: standard deviation of the Gaussian, in grid units.

    Returns:
        ``np.ndarray`` holding ``exp(-(x^2 + y^2) / (2 * sigma^2))``. Entries
        below machine epsilon times the largest entry are set to exactly 0.
    """
    half_rows, half_cols = ((extent - 1.) / 2. for extent in shape)
    rows, cols = np.ogrid[-half_rows:half_rows + 1, -half_cols:half_cols + 1]

    kernel = np.exp(-(cols * cols + rows * rows) / (2 * sigma * sigma))
    # Flush numerically negligible tails so they do not leak into the heatmap.
    negligible = kernel < np.finfo(kernel.dtype).eps * kernel.max()
    kernel[negligible] = 0
    return kernel

def draw_umich_gaussian(heatmap, center, radius, k=1):
    """Stamp a Gaussian blob onto ``heatmap`` in place, keeping the per-pixel maximum.

    Args:
        heatmap: 2-D array that is modified in place.
        center: ``(x, y)`` position of the blob; truncated to integers.
        radius: integer half-size of the blob; the kernel is
            ``2 * radius + 1`` pixels wide with ``sigma = diameter / 6``.
        k: scale factor applied to the kernel before merging.

    Returns:
        The same ``heatmap`` object that was passed in.
    """
    size = 2 * radius + 1
    kernel = gaussian2D((size, size), sigma=size / 6)

    cx, cy = int(center[0]), int(center[1])
    rows, cols = heatmap.shape[0:2]

    # How far the kernel may reach on each side before leaving the heatmap.
    reach_left = min(cx, radius)
    reach_right = min(cols - cx, radius + 1)
    reach_top = min(cy, radius)
    reach_bottom = min(rows - cy, radius + 1)

    target = heatmap[cy - reach_top:cy + reach_bottom, cx - reach_left:cx + reach_right]
    patch = kernel[radius - reach_top:radius + reach_bottom, radius - reach_left:radius + reach_right]
    # An empty slice means the blob lies outside the heatmap: nothing to draw.
    if min(patch.shape) > 0 and min(target.shape) > 0:
        np.maximum(target, patch * k, out=target)
    return heatmap


def gaussian_radius(det_size, min_overlap=0.7):
    """Return the smallest of three candidate radii derived from a box size.

    Each candidate comes from a quadratic with coefficients ``(a, b, c)``
    built from the box height/width and ``min_overlap``.

    NOTE(review): every candidate is computed as ``(b + sqrt(b^2 - 4ac)) / 2``,
    i.e. without dividing by ``2 * a``. This mirrors the original code exactly;
    confirm it is the intended formula before relying on the values.

    Args:
        det_size: ``(height, width)`` of the box.
        min_overlap: overlap threshold used in the three quadratics.

    Returns:
        The minimum of the three candidate radii.
    """
    height, width = det_size

    def _candidate(a, b, c):
        # Same expression the three cases share; deliberately not a true root.
        return (b + np.sqrt(b ** 2 - 4 * a * c)) / 2

    return min(
        _candidate(1,
                   (height + width),
                   width * height * (1 - min_overlap) / (1 + min_overlap)),
        _candidate(4,
                   2 * (height + width),
                   (1 - min_overlap) * width * height),
        _candidate(4 * min_overlap,
                   -2 * min_overlap * (height + width),
                   (min_overlap - 1) * width * height),
    )

def line_intersection(line1, line2):
    """Find the intersection point of two infinite lines.

    Args:
        line1: ``(x1, y1, x2, y2)`` - two points on the first line.
        line2: ``(x1, y1, x2, y2)`` - two points on the second line.

    Returns:
        The ``coordinates`` of the sympy ``Point2D`` where the lines cross,
        or ``None`` when sympy reports no intersection or the first result
        is not a single point.
    """
    first = Line((line1[0], line1[1]), (line1[2], line1[3]))
    second = Line((line2[0], line2[1]), (line2[2], line2[3]))

    hits = first.intersection(second)
    if len(hits) > 0 and isinstance(hits[0], sympy.geometry.point.Point2D):
        return hits[0].coordinates
    return None


def is_point_in_image(x, y, input_width=1280, input_height=720):
    """Tell whether ``(x, y)`` is a known point inside the image bounds.

    Args:
        x: horizontal coordinate, or ``None`` when the point is unknown.
        y: vertical coordinate, or ``None`` when the point is unknown.
        input_width: image width; ``x == input_width`` still counts as inside.
        input_height: image height; ``y == input_height`` still counts as inside.

    Returns:
        ``True`` when both coordinates are present and within
        ``[0, input_width]`` x ``[0, input_height]``; ``False`` otherwise.
    """
    # Compare against None explicitly: a coordinate of 0 is falsy, yet it is
    # a valid position on the left/top border of the image.
    if x is None or y is None:
        return False
    return bool(0 <= x <= input_width and 0 <= y <= input_height)

In [None]:
class ConvBlock(nn.Module):
    """Convolution followed by ReLU and then batch normalisation.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels produced by the convolution.
        kernel_size: convolution kernel size.
        pad: padding passed to the convolution.
        stride: convolution stride.
        bias: whether the convolution has a bias term.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, pad=1, stride=1, bias=True):
        super(ConvBlock, self).__init__()
        conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, padding=pad, bias=bias)
        activation = nn.ReLU()
        norm = nn.BatchNorm2d(out_channels)
        # The activation is applied before the normalisation; keep this order.
        self.block = nn.Sequential(conv, activation, norm)

    def forward(self, x):
        """Run ``x`` through conv -> ReLU -> BatchNorm."""
        return self.block(x)

class ChannelAttention(nn.Module):
    """Rescale each channel by a gate computed from pooled statistics.

    Global average- and max-pooled descriptors go through the same two-layer
    1x1-conv bottleneck; their sum is squashed with a sigmoid and multiplied
    into the input.

    Args:
        in_channels: number of channels of the input tensor.
        reduction: the bottleneck has ``in_channels // reduction`` channels.
    """

    def __init__(self, in_channels, reduction=16):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)

        hidden_channels = in_channels // reduction
        squeeze = nn.Conv2d(in_channels, hidden_channels, 1, bias=False)
        excite = nn.Conv2d(hidden_channels, in_channels, 1, bias=False)
        # One shared bottleneck serves both pooled descriptors.
        self.fc = nn.Sequential(squeeze, nn.ReLU(), excite)

        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return ``x`` multiplied by its per-channel attention gate."""
        gate = self.sigmoid(self.fc(self.avg_pool(x)) + self.fc(self.max_pool(x)))
        return x * gate

class TrackNet(nn.Module):
    """Encoder-decoder CNN that maps an image to ``out_channels`` feature maps.

    The encoder stacks ConvBlocks with three 2x max-pools; the decoder stacks
    ConvBlocks with three 2x upsamplings, each decoder stage followed by a
    ChannelAttention gate. The last layer is itself a ConvBlock, so the raw
    output comes out of a batch-norm layer; no sigmoid is applied here.

    Args:
        input_channels: channels of the input image tensor.
        out_channels: number of output maps.
    """

    def __init__(self, input_channels=3, out_channels=14):
        super().__init__()
        self.out_channels = out_channels
        self.input_channels = input_channels

        # Encoder, stage 1 (full resolution).
        self.conv1 = ConvBlock(self.input_channels, 64)
        self.conv2 = ConvBlock(64, 64)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        # Encoder, stage 2 (1/2 resolution).
        self.conv3 = ConvBlock(64, 128)
        self.conv4 = ConvBlock(128, 128)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        # Encoder, stage 3 (1/4 resolution).
        self.conv5 = ConvBlock(128, 256)
        self.conv6 = ConvBlock(256, 256)
        self.conv7 = ConvBlock(256, 256)
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        # Bottleneck (1/8 resolution).
        self.conv8 = ConvBlock(256, 512)
        self.conv9 = ConvBlock(512, 512)
        self.conv10 = ConvBlock(512, 512)
        # Decoder, stage 1.
        self.ups1 = nn.Upsample(scale_factor=2)
        self.conv11 = ConvBlock(512, 256)
        self.conv12 = ConvBlock(256, 256)
        self.ca1 = ChannelAttention(in_channels=256)
        # Decoder, stage 2.
        self.ups2 = nn.Upsample(scale_factor=2)
        self.conv13 = ConvBlock(256, 128)
        self.conv14 = ConvBlock(128, 128)
        self.ca2 = ChannelAttention(in_channels=128)
        # Decoder, stage 3.
        self.ups3 = nn.Upsample(scale_factor=2)
        self.conv15 = ConvBlock(128, 64)
        self.conv16 = ConvBlock(64, 64)
        self.ca3 = ChannelAttention(in_channels=64)
        # Output head.
        self.conv17 = ConvBlock(64, self.out_channels)

        self._init_weights()

    def forward(self, x):
        """Apply every layer in construction order and return the raw output maps."""
        pipeline = (
            self.conv1, self.conv2, self.pool1,
            self.conv3, self.conv4, self.pool2,
            self.conv5, self.conv6, self.conv7, self.pool3,
            self.conv8, self.conv9, self.conv10,
            self.ups1, self.conv11, self.conv12, self.ca1,
            self.ups2, self.conv13, self.conv14, self.ca2,
            self.ups3, self.conv15, self.conv16, self.ca3,
            self.conv17,
        )
        for layer in pipeline:
            x = layer(x)
        return x

    def _init_weights(self):
        """Re-initialise conv weights uniformly in [-0.05, 0.05] and reset batch-norm layers."""
        for layer in self.modules():
            if isinstance(layer, nn.Conv2d):
                nn.init.uniform_(layer.weight, -0.05, 0.05)
                if layer.bias is not None:
                    nn.init.constant_(layer.bias, 0)
                continue
            if isinstance(layer, nn.BatchNorm2d):
                nn.init.constant_(layer.weight, 1)
                nn.init.constant_(layer.bias, 0)


In [None]:
class courtDataset(Dataset):
    """Dataset of court images with keypoint heatmap targets.

    Annotations are read from ``data_<mode>.json`` inside ``path_dataset``;
    each record provides an ``'id'`` (image file stem) and ``'kps'`` (list of
    ``[x, y]`` keypoints in input-image pixels). Images are read from the
    ``images`` sub-directory as ``<id>.png``.

    Args:
        mode: ``'train'`` or ``'val'``; selects the annotation file.
        input_height: height of the annotated image coordinate frame.
        input_width: width of the annotated image coordinate frame.
        scale: down-scaling factor between annotation frame and network frame.
        hp_radius: radius of the Gaussian drawn for each keypoint.
        path_dataset: root directory of the dataset. Defaults to the Kaggle
            location the notebook was written for.
    """

    def __init__(self, mode, input_height=720, input_width=1280, scale=2, hp_radius=55,
                 path_dataset='/kaggle/input/court-keypoints/data'):
        assert mode in ['train', 'val'], 'incorrect mode'
        self.mode = mode
        self.input_height = input_height
        self.input_width = input_width
        self.output_height = int(input_height/scale)
        self.output_width = int(input_width/scale)
        self.num_joints = 14
        self.hp_radius = hp_radius
        self.scale = scale

        self.path_dataset = path_dataset
        self.path_images = os.path.join(self.path_dataset, 'images')
        with open(os.path.join(self.path_dataset, 'data_{}.json'.format(mode)), 'r') as f:
            self.data = json.load(f)
        print('mode = {}, len = {}'.format(mode, len(self.data)))

    def filter_data(self):
        """Return the records whose keypoints all lie strictly inside the image.

        ``self.data`` itself is left untouched.
        """
        new_data = []
        for record in self.data:
            kps = np.array(record['kps'])
            max_x, max_y = kps.max(axis=0)[:2]
            min_x, min_y = kps.min(axis=0)[:2]
            if 0 < min_x and max_x < self.input_width and 0 < min_y and max_y < self.input_height:
                new_data.append(record)
        return new_data

    def __getitem__(self, index):
        """Load one sample.

        Returns:
            ``(inp, hm_hp, kps, name)`` where ``inp`` is the resized image as
            float32 CHW in [0, 1], ``hm_hp`` holds ``num_joints + 1`` heatmaps
            (the extra one marks the intersection of lines kps[0]-kps[3] and
            kps[1]-kps[2]), ``kps`` is the integer keypoint array and ``name``
            is the image id without extension.

        Raises:
            FileNotFoundError: if the image cannot be read.
        """
        record = self.data[index]
        img_name = record['id'] + '.png'
        kps = record['kps']

        img_path = os.path.join(self.path_images, img_name)
        img = cv2.imread(img_path)
        if img is None:
            # cv2.imread signals failure by returning None; fail with the path
            # instead of an opaque error from cv2.resize.
            raise FileNotFoundError('could not read image: {}'.format(img_path))
        img = cv2.resize(img, (self.output_width, self.output_height))
        inp = (img.astype(np.float32) / 255.)
        inp = np.rollaxis(inp, 2, 0)

        hm_hp = np.zeros((self.num_joints+1, self.output_height, self.output_width), dtype=np.float32)
        draw_gaussian = draw_umich_gaussian

        for i in range(len(kps)):
            x_pt, y_pt = kps[i][0], kps[i][1]
            # Both image borders are inclusive; points outside get an empty heatmap.
            if 0 <= x_pt <= self.input_width and 0 <= y_pt <= self.input_height:
                x_pt_int = int(x_pt/self.scale)
                y_pt_int = int(y_pt/self.scale)
                draw_gaussian(hm_hp[i], (x_pt_int, y_pt_int), self.hp_radius)

        # Extra channel: intersection of the lines kps[0]-kps[3] and kps[1]-kps[2].
        center = line_intersection((kps[0][0], kps[0][1], kps[3][0], kps[3][1]),
                                   (kps[1][0], kps[1][1], kps[2][0], kps[2][1]))
        # line_intersection returns None when the lines do not meet in a single
        # point; leave the extra heatmap empty rather than crash on unpacking.
        if center is not None:
            x_ct, y_ct = center
            draw_gaussian(hm_hp[self.num_joints], (int(x_ct/self.scale), int(y_ct/self.scale)), self.hp_radius)

        return inp, hm_hp, np.array(kps, dtype=int), img_name[:-4]

    def __len__(self):
        """Number of annotated samples."""
        return len(self.data)

In [None]:
def postprocess(heatmap, scale=2, low_thresh=155, min_radius=10, max_radius=30):
    """Locate one keypoint in a heatmap by thresholding and Hough circle detection.

    Args:
        heatmap: single heatmap image (``val`` passes a uint8 array in 0..255).
        scale: factor that maps heatmap coordinates back to input-image pixels.
        low_thresh: binarisation threshold applied before circle detection.
        min_radius: smallest circle radius accepted by the detector.
        max_radius: largest circle radius accepted by the detector.

    Returns:
        ``(x, y)`` of the first detected circle centre multiplied by ``scale``,
        or ``(None, None)`` when no circle is found.
    """
    _, binary = cv2.threshold(heatmap, low_thresh, 255, cv2.THRESH_BINARY)
    found = cv2.HoughCircles(binary, cv2.HOUGH_GRADIENT, dp=1, minDist=20, param1=50, param2=2,
                             minRadius=min_radius, maxRadius=max_radius)
    if found is None:
        return None, None

    first_circle = found[0][0]
    return first_circle[0] * scale, first_circle[1] * scale

def refine_kps(img, x_ct, y_ct, crop_size=40):
    """Snap a keypoint to the crossing of two line segments detected around it.

    Note the axis convention: ``x_ct`` indexes the first image axis (it is
    clamped against the image height) and ``y_ct`` the second one, while the
    result is returned in swapped order.

    Args:
        img: image array passed on to ``detect_lines``.
        x_ct: integer index of the keypoint along the first image axis.
        y_ct: integer index of the keypoint along the second image axis.
        crop_size: half-size of the window searched around the keypoint.

    Returns:
        ``(refined_y_ct, refined_x_ct)``. Equals ``(y_ct, x_ct)`` unless
        exactly two merged lines are found and they cross strictly inside
        the window.
    """
    img_height, img_width = img.shape[:2]
    row_lo = max(x_ct - crop_size, 0)
    row_hi = min(img_height, x_ct + crop_size)
    col_lo = max(y_ct - crop_size, 0)
    col_hi = min(img_width, y_ct + crop_size)

    patch = img[row_lo:row_hi, col_lo:col_hi]
    segments = detect_lines(patch)

    unchanged = (y_ct, x_ct)
    if len(segments) <= 1:
        return unchanged

    segments = merge_lines(segments)
    if len(segments) != 2:
        return unchanged

    crossing = line_intersection(segments[0], segments[1])
    if not crossing:
        return unchanged

    # The intersection is (column, row) within the patch.
    row_off = int(crossing[1])
    col_off = int(crossing[0])
    inside_patch = 0 < row_off < patch.shape[0] and 0 < col_off < patch.shape[1]
    if not inside_patch:
        return unchanged

    return col_lo + col_off, row_lo + row_off


def detect_lines(image):
    """Detect line segments in a BGR image with a probabilistic Hough transform.

    The image is converted to grayscale and binarised at 155 first.

    Returns:
        A sequence of ``[x1, y1, x2, y2]`` segments: an ``(N, 4)`` array for
        several segments, a one-element list for a single segment, or an
        empty list when nothing is detected.
    """
    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    _, binary = cv2.threshold(gray, 155, 255, cv2.THRESH_BINARY)
    segments = np.squeeze(cv2.HoughLinesP(binary, 1, np.pi / 180, 30, minLineLength=10, maxLineGap=30))

    # A 0-d result means the detector returned nothing.
    if len(segments.shape) == 0:
        return []
    # A single segment is squeezed down to a flat 4-vector: wrap it again so
    # callers can always iterate over segments.
    if len(segments) == 4 and not isinstance(segments[0], np.ndarray):
        return [segments]
    return segments

def merge_lines(lines):
    """Fuse near-duplicate line segments.

    Segments are sorted by their first coordinate. Each not-yet-consumed
    segment absorbs every later one whose start and end points are both
    closer than 20 to its own; absorbing replaces the running segment with
    the integer midpoint of the two.

    Args:
        lines: iterable of ``[x1, y1, x2, y2]`` segments.

    Returns:
        List of merged segments.
    """
    ordered = sorted(lines, key=lambda seg: seg[0])
    consumed = [False] * len(ordered)
    merged = []

    for idx, current in enumerate(ordered):
        if consumed[idx]:
            continue
        for other_idx in range(idx + 1, len(ordered)):
            if consumed[other_idx]:
                continue
            # Re-read the running segment: it may have been averaged already.
            x1, y1, x2, y2 = current
            x3, y3, x4, y4 = ordered[other_idx]
            start_gap = distance.euclidean((x1, y1), (x3, y3))
            end_gap = distance.euclidean((x2, y2), (x4, y4))
            if start_gap < 20 and end_gap < 20:
                current = np.array(
                    [int((x1+x3)/2), int((y1+y3)/2), int((x2+x4)/2), int((y2+y4)/2)],
                    dtype=np.int32)
                consumed[other_idx] = True
        merged.append(current)
    return merged

In [None]:
# Training split: reshuffled every epoch.
train_dataset = courtDataset(mode='train')
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True,
                          num_workers=1, pin_memory=True)

# Validation split: fixed order.
val_dataset = courtDataset(mode='val')
val_loader = DataLoader(dataset=val_dataset, batch_size=batch_size, shuffle=False,
                        num_workers=1, pin_memory=True)

In [None]:
def val(model, val_loader, criterion, device, epoch, num_kps=14, max_dist=7):
    """Evaluate the model on ``val_loader`` and report keypoint detection metrics.

    For every sample and each of the first ``num_kps`` output channels the
    sigmoid heatmap is turned into a point with ``postprocess`` and compared
    with the ground-truth keypoint:

    * prediction and ground truth both in the image: true positive if they
      are closer than ``max_dist``, false positive otherwise;
    * only the prediction in the image: false positive;
    * only the ground truth in the image: false negative;
    * neither in the image: true negative.

    Args:
        model: network to evaluate; switched to eval mode.
        val_loader: iterable of batches ``(images, heatmaps, keypoints, ...)``.
        criterion: loss applied to ``sigmoid(model(images))`` and the heatmaps.
        device: device the batches are moved to.
        epoch: epoch number, used for logging only.
        num_kps: number of leading output channels scored as keypoints.
        max_dist: maximum distance for a prediction to count as a hit.

    Returns:
        ``(mean_loss, tp, fp, fn, tn, precision, accuracy)``. For an empty
        loader the mean loss is NaN and both ratios are 0.0.
    """
    model.eval()
    losses = []
    tp, fp, fn, tn = 0, 0, 0, 0
    eps = 1e-15
    # Bound up front so an empty loader cannot leave them undefined.
    precision, accuracy = 0.0, 0.0
    num_iters = len(val_loader)

    for iter_id, batch in enumerate(val_loader):
        with torch.no_grad():
            batch_size = batch[0].shape[0]
            out = model(batch[0].float().to(device))
            kps = batch[2]
            gt_hm = batch[1].float().to(device)
            # Compute the sigmoid once; it feeds both the loss and the decoding.
            prob = torch.sigmoid(out)
            loss = criterion(prob, gt_hm)

        pred = prob.detach().cpu().numpy()
        for bs in range(batch_size):
            for kps_num in range(num_kps):
                heatmap = (pred[bs][kps_num] * 255).astype(np.uint8)
                x_pred, y_pred = postprocess(heatmap)
                x_gt = kps[bs][kps_num][0].item()
                y_gt = kps[bs][kps_num][1].item()

                pred_in_image = is_point_in_image(x_pred, y_pred)
                gt_in_image = is_point_in_image(x_gt, y_gt)
                if pred_in_image and gt_in_image:
                    dst = distance.euclidean((x_pred, y_pred), (x_gt, y_gt))
                    if dst < max_dist:
                        tp += 1
                    else:
                        fp += 1
                elif pred_in_image:
                    fp += 1
                elif gt_in_image:
                    fn += 1
                else:
                    tn += 1

        # Metrics are cumulative over all batches seen so far.
        precision = round(tp / (tp + fp + eps), 5)
        accuracy = round((tp + tn) / (tp + tn + fp + fn + eps), 5)
        print('val, epoch = {}, iter_id = {}/{}, loss = {}, tp = {}, fp = {}, fn = {}, tn = {}, precision = {}, '
              'accuracy = {}'.format(epoch, iter_id, num_iters, round(loss.item(), 5), tp, fp, fn, tn,
                                     precision, accuracy))
        losses.append(loss.item())

    mean_loss = np.mean(losses) if losses else float('nan')
    return mean_loss, tp, fp, fn, tn, precision, accuracy

In [None]:
# Prefer the GPU when one is visible to PyTorch.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# 15 output maps: the dataset builds num_joints + 1 = 15 target heatmaps.
model = TrackNet(out_channels=15)
model = model.to(device)

criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=lr, betas=(0.9, 0.999), weight_decay=0)

In [None]:
def train(model, train_loader, optimizer, criterion, device, epoch, max_iters=1000):
    """Train ``model`` for at most ``max_iters`` batches of ``train_loader``.

    Args:
        model: network to optimise; switched to train mode.
        train_loader: sized iterable of batches ``(images, heatmaps, ...)``.
        optimizer: optimizer stepping the model parameters.
        criterion: loss applied to ``sigmoid(model(images))`` and the heatmaps.
        device: device the batches are moved to.
        epoch: epoch number, used for logging only.
        max_iters: upper bound on the number of batches processed; capped at
            ``len(train_loader)``.

    Returns:
        Mean loss over the processed batches, or NaN when none was processed.
    """
    model.train()
    losses = []
    max_iters = min(max_iters, len(train_loader))
    if max_iters <= 0:
        return float('nan')

    for iter_id, batch in enumerate(train_loader):
        # Clear gradients first so nothing stale leaks into this step.
        optimizer.zero_grad()
        out = model(batch[0].float().to(device))
        gt_hm_hp = batch[1].float().to(device)
        loss = criterion(torch.sigmoid(out), gt_hm_hp)

        loss.backward()
        optimizer.step()
        print('train, epoch = {}, iter_id = {}/{}, loss = {}'.format(epoch, iter_id, max_iters, loss.item()))
        losses.append(loss.item())
        # Stop once exactly max_iters batches were processed, without pulling
        # an extra batch from the loader.
        if iter_id + 1 >= max_iters:
            break

    return np.mean(losses)

In [None]:
# Output locations for this experiment: TensorBoard logs and checkpoints.
exp_id = 'default'
exps_path = './exps/{}'.format(exp_id)
tb_path = os.path.join(exps_path, 'plots')
# exist_ok avoids the check-then-create race and makes the cell safe to re-run.
os.makedirs(tb_path, exist_ok=True)
log_writer = SummaryWriter(tb_path)
model_last_path = os.path.join(exps_path, 'model_last.pt')
model_best_path = os.path.join(exps_path, 'model_best.pt')

In [None]:
# Main loop: train every epoch, validate on the configured interval, keep the
# best-scoring and the most recent weights on disk.
for epoch in range(num_epochs):
    train_loss = train(model, train_loader, optimizer, criterion, device, epoch, steps_per_epoch)
    log_writer.add_scalar('Train/training_loss', train_loss, epoch)

    # Logical `and` (not bitwise `&`): both operands are plain booleans.
    if epoch > 0 and epoch % val_intervals == 0:
        val_loss, tp, fp, fn, tn, precision, accuracy = val(model, val_loader, criterion, device, epoch)
        print('val loss = {}'.format(val_loss))
        log_writer.add_scalar('Val/loss', val_loss, epoch)
        log_writer.add_scalar('Val/tp', tp, epoch)
        log_writer.add_scalar('Val/fp', fp, epoch)
        log_writer.add_scalar('Val/fn', fn, epoch)
        log_writer.add_scalar('Val/tn', tn, epoch)
        log_writer.add_scalar('Val/precision', precision, epoch)
        log_writer.add_scalar('Val/accuracy', accuracy, epoch)
        if accuracy > val_best_accuracy:
            val_best_accuracy = accuracy
            torch.save(model.state_dict(), model_best_path)

    # Save the latest weights after every epoch. Saving only on validation
    # epochs would drop everything trained after the last validation run.
    torch.save(model.state_dict(), model_last_path)

In [None]:
# IPython line magic: list the current working directory.
# NOTE(review): presumably a leftover check that ./exps was written - confirm before keeping.
%ls