# Custom TTNet Implementation

## Setup

### Imports

In [1]:
import os
import glob
import cv2 as cv
import numpy as np
import pandas as pd
from PIL import Image
import albumentations as A
from tqdm.notebook import tqdm
from turbojpeg import TurboJPEG
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import autocast, GradScaler
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

### GPU

In [2]:
# Select the compute device: the first CUDA GPU when one is visible, else the CPU.
use_cuda = torch.cuda.is_available()
if use_cuda:
  device = torch.device("cuda:0")
else:
  device = torch.device("cpu")

print(f'Device: {device}')
if use_cuda:
  gpu_name = torch.cuda.get_device_name(0)
  print(f'GPU: {gpu_name}')

# Let cuDNN benchmark its convolution algorithms and pick one per input shape.
torch.backends.cudnn.benchmark = True

Device: cuda:0
GPU: NVIDIA GeForce RTX 2080 Ti


### TTNet Params

In [3]:
# --- Locations on disk -------------------------------------------------------
dataset_base_path = './Dataset/images'   # root holding one sub-folder per split
savePath_base = "./Trained_Models/"      # checkpoints are written below this folder
outputPath = "./Results/"                # loss plots are written here

# --- Optimisation / loss settings --------------------------------------------
learning_rate = 1e-4
eps = 1e-7        # added inside log() in the losses to avoid log(0)
sigma = 1         # std-dev (in pixels) of the 1-D Gaussian ball targets
event_num = 4     # Fly = 0, Bounce = 1, Hit = 2, Out = 3

# --- Image resolutions -------------------------------------------------------
data_width, data_height = 1920, 1080     # resolution of the frames on disk
TTN_width, TTN_height = 320, 128         # resolution fed to the network

# --- Reproducibility ---------------------------------------------------------
seed = 42
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)

## Dataset

### Configure

In [37]:
# Frames 2..9 of a window are registered as extra 'image' targets so that one
# call to the pipeline applies the same sampled augmentation to all nine frames
# (and to the ball keypoint).
_extra_frame_targets = {f'img{frame_idx}': 'image' for frame_idx in range(2, 10)}

# Training pipeline: random geometric / colour augmentation, then a resize to
# the network input resolution. A.Normalize is intentionally not part of it.
_train_steps = [
  A.RandomCrop(height=(int(data_height*0.85)), width=(int(data_width*0.85)), p=0.5),
  A.Rotate(limit=15, p=0.5),
  A.HorizontalFlip(p=0.5),
  A.ColorJitter(brightness=0.2, contrast=0.2, saturation=0, hue=0.1, p=0.5),
  A.Resize(height=TTN_height, width=TTN_width, interpolation=1, always_apply=True, p=1),
]
train_transform = A.Compose(
  _train_steps,
  keypoint_params=A.KeypointParams(format='xy'),
  additional_targets=dict(_extra_frame_targets),
)

# Test pipeline: only the resize to the network input resolution.
_test_steps = [
  A.Resize(height=TTN_height, width=TTN_width, interpolation=1, always_apply=True, p=1),
]
test_transform = A.Compose(
  _test_steps,
  keypoint_params=A.KeypointParams(format='xy'),
  additional_targets=dict(_extra_frame_targets),
)

def smooth_event(events,event_num):
  """Turn the event labels of one frame window into a soft target for one class.

  The window is searched outwards from its middle element; the smallest offset
  ``n`` at which ``event_num`` occurs (on either side) gives the target
  ``cos(n * pi / 8)``. When the event is absent ``n`` stays at 5.

  Args:
    events: indexable sequence of per-frame event labels (each passed to ``int``).
    event_num: the event class to look for.

  Returns:
    The cosine-smoothed value, or the integer ``0`` when that value is below 0.01.
  """
  centre = len(events) // 2
  half_span = (len(events) + 1) // 2

  # Offsets (distance from the centre) at which the requested event appears.
  hit_offsets = [
    step for step in range(half_span)
    if int(events[centre + step]) == event_num or int(events[centre - step]) == event_num
  ]
  nearest = hit_offsets[0] if hit_offsets else 5

  prob = np.cos(nearest * np.pi / 8)
  return 0 if prob < 0.01 else prob

class TTN_Dataset(Dataset):
  """Sliding-window dataset of consecutive frames with ball position and event targets.

  The directory ``{dataset_base_path}/{split}`` is expected to contain one folder
  per game, each containing one folder per clip, each containing an
  ``Annotation.csv`` whose first four columns are read as: image file name,
  ball x, ball y and event label.

  Each sample is a window of ``window_size`` consecutive frames; the ball
  position is taken from the middle frame of the window.

  Args:
    split: name of the split sub-folder; ``"train"`` selects ``train_transform``,
      anything else selects ``test_transform``.
    window_size: number of frames per window. The module-level transforms
      register targets ``image`` and ``img2`` .. ``img9``, so values other than
      9 need matching transforms.
  """

  def __init__(self, split, window_size=9):
    self.split = split
    self.window_size = window_size
    self.window_paths = []   # one list of `window_size` image paths per sample
    self.xy = []             # one [(x, y)] keypoint list per sample (full-resolution pixels)
    self.event_probs = []    # one tensor of `event_num` smoothed targets per sample
    self.jpeg_reader = TurboJPEG()

    split_dir = f"{dataset_base_path}/{split}"
    # Iterate through each game
    for game in os.listdir(split_dir):
      game_dir = f"{split_dir}/{game}"
      # Iterate through each clip
      for clip in os.listdir(game_dir):
        # Read and store annotation information
        clip_dir = f"{game_dir}/{clip}"
        annotation = pd.read_csv(f"{clip_dir}/Annotation.csv")
        img_paths = np.asarray(annotation.iloc[:, 0])
        ball_x = np.asarray(annotation.iloc[:, 1])
        ball_y = np.asarray(annotation.iloc[:, 2])
        event = np.asarray(annotation.iloc[:, 3])

        # NOTE(review): the range starts at `window_size`, so the window ending
        # at frame `window_size - 1` (frames 0..window_size-1) is never emitted.
        # Kept as-is because changing it alters the dataset size — confirm intent.
        for frame_no in range(self.window_size, len(img_paths)):
          window_start = frame_no + 1 - self.window_size
          window_end = frame_no + 1
          middle_index = frame_no - self.window_size + (self.window_size + 1) // 2

          # Get frame window image paths
          self.window_paths.append(
            [f"{clip_dir}/{name}" for name in img_paths[window_start:window_end]]
          )

          # Get xy location of the middle frame; a coordinate equal to the frame
          # size is pulled back by one pixel so it lies inside the image.
          x = ball_x[middle_index]
          y = ball_y[middle_index]
          if x == data_width:
            x -= 1
          if y == data_height:
            y -= 1
          self.xy.append([(x, y)])

          # One smoothed target per event class, computed over the window labels.
          event_prob = torch.zeros((event_num))
          for i in range(event_num):
            event_prob[i] = smooth_event(event[window_start:window_end], i)
          self.event_probs.append(event_prob)

  def loader(self, start_index):
    """Decode and transform the frames of sample ``start_index``.

    Returns:
      A tuple of the ``window_size`` transformed frames followed by the
      transformed ball keypoint; ``(0, 0)`` is used when the transform returns
      no keypoint.
    """
    # Get images and keypoints. Files are opened in a `with` block so the
    # handles are closed immediately instead of waiting for garbage collection.
    imgs = []
    for frame_path in self.window_paths[start_index][:self.window_size]:
      with open(frame_path, 'rb') as in_file:
        # Pixel-format code 0 — presumably TurboJPEG's RGB format; confirm.
        imgs.append(self.jpeg_reader.decode(in_file.read(), 0))
    kp = self.xy[start_index]

    # Target names expected by the transforms: 'image', 'img2', 'img3', ...
    targets = {'image': imgs[0]}
    for target_no, frame in enumerate(imgs[1:], start=2):
      targets[f'img{target_no}'] = frame

    # Apply transformation and get outputs
    transform = train_transform if self.split == "train" else test_transform
    transformed = transform(keypoints=kp, **targets)

    frames = [transformed[target_name] for target_name in targets]
    xy_downscale = transformed['keypoints']
    xy_downscale = xy_downscale[0] if xy_downscale else (0, 0)

    return (*frames, xy_downscale)

  def __getitem__(self, index):
    """Return ``(window_imgs, xy_downscale, centre_path, event_probs)`` for one sample.

    ``window_imgs`` is the channel-wise concatenation of all frames, moved to
    channel-first layout.
    """
    *frames, xy_downscale = self.loader(index)
    window_imgs = np.concatenate(frames, axis=2)
    window_imgs = np.rollaxis(window_imgs, 2, 0)
    window_imgs = torch.from_numpy(window_imgs)

    centre_path = self.window_paths[index][self.window_size//2]
    event_probs = self.event_probs[index]

    return window_imgs, xy_downscale, centre_path, event_probs

  def __len__(self):
    """Number of frame windows collected over all clips."""
    return len(self.window_paths)


### Load Data

In [12]:
# Batch sizes and number of worker processes used by both loaders.
train_batch = 64
test_batch = 64
num_workers = 16

print("Loading Train...")
# Training data is reshuffled every epoch and the last partial batch is dropped.
train_loader = DataLoader(
  TTN_Dataset(split="train"),
  batch_size=train_batch,
  shuffle=True,
  num_workers=num_workers,
  pin_memory=True,
  drop_last=True,
)
print("Loading Test...")
# Test data keeps its order and keeps the last partial batch.
test_loader = DataLoader(
  TTN_Dataset(split="test"),
  batch_size=test_batch,
  shuffle=False,
  num_workers=num_workers,
  pin_memory=True,
)
print("Dataset total batches:")
print(f"Training : {len(train_loader)}")
print(f"Test : {len(test_loader)}")

Loading Train...
Loading Test...
Dataset total batches:
Training : 632
Test : 64


## Losses

In [6]:
# TODO This entire section probably needs a rewrite
def gaussian_1d(pos, mu, sigma):
  """Evaluate an unnormalised 1-D Gaussian, ``exp(-((pos - mu) / sigma)^2 / 2)``.

  The value is 1 where ``pos == mu``; no normalisation constant is applied.
  """
  z = (pos - mu) / sigma
  return torch.exp(-(z ** 2) / 2)

def ball_loss(pre_output,true_label):
  """Binary cross-entropy between predicted and Gaussian-smoothed ball coordinates.

  Args:
    pre_output: tensor whose rows hold ``TTN_width`` x-axis scores followed by
      the y-axis scores (values are used directly inside ``log``, so they are
      expected to lie in [0, 1]).
    true_label: per-sample ``(x, y)`` ball positions in network-input pixels,
      indexable as ``true_label[i] -> (x, y)`` (tensor or nested sequence).

  Returns:
    Scalar tensor: mean BCE over the x axis plus mean BCE over the y axis.

  A sample whose position is not strictly inside ``(0, TTN_width)`` x
  ``(0, TTN_height)`` gets an all-zero target; this covers the ``(0, 0)`` and
  ``-1`` placeholders produced elsewhere in this file for a missing ball.
  """
  # Work on the device of the prediction instead of the module-level `device`.
  dev = pre_output.device
  xy = torch.as_tensor(true_label, dtype=torch.float32, device=dev)
  x_true = xy[:, 0:1]   # column vectors so they broadcast against the axes
  y_true = xy[:, 1:2]

  x_pred = pre_output[:, :TTN_width]
  y_pred = pre_output[:, TTN_width:]

  # Pixel-index axes are built once for the whole batch.
  x_axis = torch.arange(0, TTN_width, device=dev, dtype=torch.float32).unsqueeze(0)
  y_axis = torch.arange(0, TTN_height, device=dev, dtype=torch.float32).unsqueeze(0)

  x_target = gaussian_1d(x_axis, x_true, sigma=sigma)
  y_target = gaussian_1d(y_axis, y_true, sigma=sigma)

  # Zero the target for out-of-frame samples and cut the Gaussian tails below 1e-2.
  in_frame = (x_true > 0) & (x_true < TTN_width) & (y_true > 0) & (y_true < TTN_height)
  x_target = torch.where(in_frame & (x_target >= 1e-2), x_target, torch.zeros_like(x_target))
  y_target = torch.where(in_frame & (y_target >= 1e-2), y_target, torch.zeros_like(y_target))

  # Match the dtype of the prediction (it may be half precision under autocast).
  x_target = x_target.to(pre_output.dtype)
  y_target = y_target.to(pre_output.dtype)

  loss_x = - torch.mean(x_target * torch.log(x_pred + eps) + (1 - x_target) * torch.log(1 - x_pred + eps))
  loss_y = - torch.mean(y_target * torch.log(y_pred + eps) + (1 - y_target) * torch.log(1 - y_pred + eps))
  return loss_x + loss_y

def event_loss(pre_output,true_label):
  """Class-weighted binary cross-entropy for the event-spotting head.

  The fixed per-class weights ``(1, 57, 40, 3)`` are normalised to sum to 1 and
  broadcast over the batch; ``eps`` keeps the logarithms finite.

  Args:
    pre_output: predicted event scores, one column per event class.
    true_label: target event values with the same shape.

  Returns:
    Scalar tensor holding the negated mean of the weighted log-likelihood.
  """
  class_weights = torch.tensor((1, 57, 40, 3)).to(device)
  class_weights = class_weights.view(-1, class_weights.shape[0])   # shape (1, num_classes)
  class_weights = class_weights / class_weights.sum()

  positive_term = true_label * torch.log(pre_output + eps)
  negative_term = (1. - true_label) * torch.log(1 - pre_output + eps)
  return -torch.mean(class_weights * (positive_term + negative_term))


## TTNet Config

### Define Block Types

In [7]:
class ConvBlock(nn.Module):
  """3x3 convolution -> batch norm -> ReLU -> 2x2 max-pool.

  The batch norm is created with ``track_running_stats=False`` and the pooling
  uses ``ceil_mode=True``, so odd spatial sizes are rounded up when halved.
  """

  def __init__(self, in_channels, out_channels):
    super().__init__()
    conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
    norm = nn.BatchNorm2d(out_channels, track_running_stats=False)
    activation = nn.ReLU(inplace=True)
    pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0, ceil_mode=True)
    self.block = nn.Sequential(conv, norm, activation, pool)

  def forward(self, x):
    """Apply the block to a (N, C, H, W) tensor."""
    return self.block(x)

class ConvBlock_without_Pooling(nn.Module):
  """3x3 convolution -> batch norm -> ReLU, keeping the spatial size unchanged."""

  def __init__(self, in_channels, out_channels):
    super().__init__()
    conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
    norm = nn.BatchNorm2d(out_channels)
    activation = nn.ReLU()
    self.block = nn.Sequential(conv, norm, activation)

  def forward(self, x):
    """Apply the block to a (N, C, H, W) tensor."""
    return self.block(x)

### Define Segments

In [8]:
class BallDetection(nn.Module):
  """Ball-localisation stage: a convolutional trunk followed by a dense head.

  The head emits ``TTN_width + TTN_height`` sigmoid scores per sample (x-axis
  scores first, then y-axis scores).

  Args:
    frame_window: number of stacked RGB frames; the input has ``frame_window * 3`` channels.
    dropout_p: dropout probability used in the trunk, before the head and inside the head.
  """

  def __init__(self, frame_window, dropout_p):
    super().__init__()

    # 1x1 stem followed by six pooled conv blocks (each halves H and W).
    trunk_layers = [
      nn.Conv2d(in_channels = frame_window*3, out_channels=64, kernel_size=1, stride=1, padding=0, bias=False),
      nn.BatchNorm2d(64,track_running_stats=False),
      nn.ReLU(inplace=True),
      ConvBlock(in_channels=64, out_channels=64),
      ConvBlock(in_channels=64, out_channels=64),
      nn.Dropout2d(p=dropout_p),
      ConvBlock(in_channels=64, out_channels=128),
      ConvBlock(in_channels=128, out_channels=128),
      ConvBlock(in_channels=128, out_channels=256),
      ConvBlock(in_channels=256, out_channels=256),
    ]
    self.convBlocks = nn.Sequential(*trunk_layers)

    # in_features=2560 fixes the flattened trunk output size this head accepts.
    head_layers = [
      nn.Linear(in_features=2560, out_features=1792),
      nn.ReLU(inplace=True),
      nn.Dropout(p=dropout_p),
      nn.Linear(in_features=1792, out_features=896),
      nn.ReLU(inplace=True),
      nn.Dropout(p=dropout_p),
      nn.Linear(in_features=896, out_features=int(TTN_width+TTN_height)),
      nn.Sigmoid(),
    ]
    self.FC = nn.Sequential(*head_layers)

    self.dropout2d = nn.Dropout2d(p=dropout_p)

  def forward(self, x):
    """Return ``(scores, trunk_features)``; the features are taken before the final dropout."""
    block6_out = self.convBlocks(x)
    dropped = self.dropout2d(block6_out)
    flat = dropped.contiguous().view(dropped.shape[0], -1)
    return self.FC(flat), block6_out

class EventSpotting(nn.Module):
  """Event head operating on the concatenated global and local trunk features.

  The two feature maps are joined along the channel axis (512 channels are
  expected by the first convolution) and reduced to ``event_num`` sigmoid scores.

  Args:
    dropout_p: probability used by every ``Dropout2d`` in the convolutional part.
  """

  def __init__(self, dropout_p):
    super().__init__()

    conv_layers = [
      nn.Conv2d(512, 64, kernel_size=1, stride=1, padding=0, bias=False),
      nn.BatchNorm2d(64),
      nn.ReLU(),
      nn.Dropout2d(p=dropout_p),
      ConvBlock_without_Pooling(in_channels=64, out_channels=64),
      nn.Dropout2d(p=dropout_p),
      ConvBlock_without_Pooling(in_channels=64, out_channels=64),
      nn.Dropout2d(p=dropout_p),
    ]
    self.convBlocks = nn.Sequential(*conv_layers)

    # in_features=640 fixes the flattened feature size this head accepts.
    dense_layers = [
      nn.Linear(in_features=640, out_features=512),
      nn.ReLU(),
      nn.Linear(in_features=512, out_features=event_num),
      nn.Sigmoid(),
    ]
    self.FC = nn.Sequential(*dense_layers)

  def forward(self, global_features, local_features):
    """Return the per-class event scores for a batch of feature-map pairs."""
    joined = torch.cat((global_features, local_features), dim=1)
    features = self.convBlocks(joined)
    flat = features.contiguous().view(features.size(0), -1)
    return self.FC(flat)

### TTNet Model

In [9]:
class TTNet(nn.Module):
  """Global ball detection with optional local refinement and event spotting.

  Args:
    dropout_p: dropout probability forwarded to every stage.
    frame_window: number of stacked RGB frames per sample.
    threshold: global-stage scores below this value are zeroed before the
      arg-max that selects the crop centre.
    tasks: collection of stage names; ``"local"`` enables the local stage and
      ``"event"`` enables the event head (which only runs when the local stage
      is also enabled).
    mean, std: per-channel (R, G, B) statistics used by :meth:`norm`.
  """

  def __init__(self, dropout_p, frame_window, threshold, tasks, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    super().__init__()
    # Assign stages
    self.local_stage,self.event_spotting = None,None
    self.global_stage = BallDetection(frame_window, dropout_p)
    if "local" in tasks:
      self.local_stage = BallDetection(frame_window, dropout_p)
    if "event" in tasks:
      self.event_spotting = EventSpotting(dropout_p)

    self.threshold = threshold
    # The input stacks whole frames along the channel axis (R,G,B,R,G,B,...),
    # so the 3-channel statistics are tiled once per frame. Non-persistent
    # buffers follow the module across devices without adding state_dict keys,
    # which keeps existing checkpoints loadable.
    self.register_buffer(
      "mean", torch.tensor(mean).view(1, 3, 1, 1).repeat(1, frame_window, 1, 1), persistent=False)
    self.register_buffer(
      "std", torch.tensor(std).view(1, 3, 1, 1).repeat(1, frame_window, 1, 1), persistent=False)

  def forward(self, x):
    """Run the enabled stages.

    Returns:
      ``(global_out, local_out, local_in, crop_params, event_out)``; entries
      belonging to a disabled stage are ``None``.
    """
    local_out, local_in, crop_params, event_out = None, None, None, None

    global_out, global_features = self.global_stage(x)
    if self.local_stage is not None:
      local_in, crop_params = self.crop_imgs(x, global_out)
      local_out, local_features = self.local_stage(local_in)
      if self.event_spotting is not None:
        event_out = self.event_spotting(global_features, local_features)

    return global_out, local_out, local_in, crop_params, event_out

  # TODO This section needs a rewrite
  def crop_imgs(self, x, global_xy):
    """Cut a network-sized crop around the globally predicted ball position.

    Each sample is upsampled to ``data_height`` x ``data_width`` and a
    ``TTN_height`` x ``TTN_width`` window around the predicted position is
    copied out. A window clipped by the right/bottom image border is centred in
    a zero-filled canvas.

    Returns:
      ``(crops, crop_params)`` where ``crop_params[i]`` is
      ``[ball_detected, x_min, y_min, x_max, y_max, padding_x, padding_y]``.
    """
    global_xy_copy = global_xy.detach().clone()
    global_xy_copy[global_xy_copy < self.threshold] = 0
    crop_params = []

    global_output = torch.zeros_like(x)
    half_width = int(TTN_width / 2)
    half_height = int(TTN_height / 2)

    for i in range(x.shape[0]):
      pos_x = global_xy_copy[i, :TTN_width]
      pos_y = global_xy_copy[i, TTN_width:]

      # Decided per sample: a detection in an earlier sample of the batch must
      # not mark the following samples as detected.
      ball_detected = bool(torch.sum(pos_x) != 0) and bool(torch.sum(pos_y) != 0)
      if ball_detected:
        x_center = int(torch.argmax(pos_x))
        y_center = int(torch.argmax(pos_y))
      else:
        # No confident detection: fall back to the centre of the frame.
        x_center = half_width
        y_center = half_height

      # Map the centre to the original resolution and clip the window to the frame.
      x_center = int(x_center * (data_width/TTN_width))
      y_center = int(y_center * (data_height/TTN_height))
      x_min = max(0, x_center - half_width)
      y_min = max(0, y_center - half_height)
      x_max = min(data_width, x_min + TTN_width)
      y_max = min(data_height, y_min + TTN_height)
      crop_width = x_max - x_min
      crop_height = y_max - y_min

      # Padding is zero when the window was not clipped.
      padding_x = int((TTN_width - crop_width) / 2)
      padding_y = int((TTN_height - crop_height) / 2)

      # Upsample one sample at a time so the full-resolution copy of the whole
      # batch never has to exist in memory at once.
      full_res = F.interpolate(x[i:i + 1], (data_height, data_width))[0]
      global_output[i, :, padding_y:(padding_y + crop_height), padding_x:(padding_x + crop_width)] = \
        full_res[:, y_min:y_max, x_min:x_max]

      crop_params.append([ball_detected,x_min,y_min,x_max,y_max,padding_x,padding_y])

    return global_output,crop_params

  def norm(self,x):
    """Scale ``x`` from [0, 255] to [0, 1] and standardise it with ``mean`` / ``std``."""
    # Follow the input's device instead of assuming CUDA is available.
    mean = self.mean.to(x.device)
    std = self.std.to(x.device)
    return (x / 255. - mean) / std

def freeze_model(model, freeze_list):
  """Enable gradients everywhere except for parameters matching ``freeze_list``.

  A parameter is frozen when any entry of ``freeze_list`` occurs as a substring
  of its name; every other parameter is (re-)enabled.

  Returns:
    The same ``model`` object, modified in place.
  """
  for layer_name, p in model.named_parameters():
    is_frozen = any(freeze_module in layer_name for freeze_module in freeze_list)
    p.requires_grad = not is_frozen
  print("Frozen")
  return model

# TODO Rewrite this
def get_local_groundtruth(global_ball_pos_xyz,crop_params):
  """Express the ground-truth ball position in the coordinates of each local crop.

  Args:
    global_ball_pos_xyz: pair ``(xs, ys)`` of per-sample tensors holding the ball
      position at network-input resolution.
    crop_params: per-sample ``[ball_detected, x_min, y_min, x_max, y_max,
      padding_x, padding_y]`` lists as produced by ``TTNet.crop_imgs``.

  Returns:
    A list of ``[local_x, local_y]`` pairs; ``[-1, -1]`` when no ball was
    detected or the ball falls outside the crop.
  """
  local_ball_pos_xyz = []
  for i, (ball_detected, x_min, y_min, x_max, y_max, padding_x, padding_y) in enumerate(crop_params):
    local_x = local_y = -1
    if ball_detected:
      # Scale the label up to the original resolution, then shift it into the crop.
      ori_x = global_ball_pos_xyz[0][i].item()/TTN_width*data_width
      ori_y = global_ball_pos_xyz[1][i].item()/TTN_height*data_height
      shifted_x = max(ori_x - x_min + padding_x, -1)
      shifted_y = max(ori_y - y_min + padding_y, -1)

      # Only keep the position when both coordinates land inside the crop.
      if TTN_width>shifted_x>=0 and TTN_height>shifted_y>=0:
        local_x, local_y = shifted_x, shifted_y
    local_ball_pos_xyz.append([local_x,local_y])

  return local_ball_pos_xyz

## Training Phase 1

In [10]:
# Phase 1 settings
model_saved = True      # True: resume from a stored phase 1 checkpoint
ph1_epochs = 30         # last epoch number to train
global_weight = 1       # multiplier applied to the global ball loss
# Per-epoch mean losses (separate list objects)
ph1_train_loss_log, ph1_test_loss_log = [], []

In [11]:
# Phase 1 trains the global stage only, so no optional task is enabled.
tasks = []
model = TTNet(dropout_p=0.5, frame_window=9, threshold=0.01, tasks=tasks)
model = model.to(device)

# Adam over every parameter, a step LR schedule and an AMP gradient scaler.
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
lr_scheduler = StepLR(optimizer, step_size=10, gamma=0.1)
scaler = GradScaler()

In [None]:
# Train the model (phase 1: global ball-detection stage only)
if (model_saved):
  # Resume from a fixed phase 1 checkpoint.
  savePath = f"{savePath_base}Phase1/TTNet_Phase1_22.pth"
  print(f"Loading model from path: {savePath}")
  # map_location lets the checkpoint load on the currently selected device.
  checkpoint = torch.load(savePath, map_location=device)
  model.load_state_dict(checkpoint['model'])
  optimizer.load_state_dict(checkpoint['optimizer'])
  lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
  start_epoch = checkpoint['cur_epoch']
  ph1_train_loss_log = checkpoint['train_loss_log']
  ph1_test_loss_log = checkpoint['val_loss_log']
  print(f"Load phase 1 at epoch {start_epoch} succeed")
else:
  start_epoch = 0
  print("Phase 1: No model to load, start to train at epoch 0")

print("START TO TRAIN PHASE 1: Global Stage ...")
for epoch in range(start_epoch+1, ph1_epochs + 1):
  # ---- Training pass ----
  model.train()
  batch_num = len(train_loader)
  train_loss_total = 0

  with tqdm(train_loader, unit="batch") as trepoch:
    for data_batch in trepoch:
      trepoch.set_description(f"Train Epoch {epoch}")
      # Read in train batch; the collated keypoints arrive as (xs, ys) and are
      # stacked into one (batch, 2) tensor.
      window_imgs, xy_downscale, _, _ = data_batch
      window_batch = torch.as_tensor(window_imgs,dtype=torch.float32).to(device)
      xy_downscale = torch.stack((xy_downscale[0],xy_downscale[1])).transpose(0,1)

      # Calculate train (global) loss
      with autocast():
        global_out, _, _, _, _ = model(window_batch)
        train_loss = ball_loss(global_out,xy_downscale) * global_weight

      optimizer.zero_grad(set_to_none=True)
      scaler.scale(train_loss).backward()
      scaler.step(optimizer)
      scaler.update()

      loss = train_loss.item()
      train_loss_total += loss
      trepoch.set_postfix(loss=loss)

  # Advance the StepLR schedule once per epoch; without this call the learning
  # rate never decays. NOTE: checkpoints written before this call existed carry
  # a scheduler state that was never advanced.
  lr_scheduler.step()

  # Log training losses
  tqdm.write(f"Train\t epoch: {epoch}/{ph1_epochs}\t loss: {train_loss_total/batch_num}")
  ph1_train_loss_log.append(train_loss_total/batch_num)

  # ---- Evaluation pass ----
  model.eval()
  with torch.no_grad():
    batch_num = len(test_loader)
    test_loss_total = 0
    with tqdm(test_loader, unit="batch") as tepoch:
      for data_batch in tepoch:
        tepoch.set_description(f"Test Epoch {epoch}")
        # Read in test batch
        window_imgs ,xy_downscale, _,_ = data_batch
        window_batch = torch.as_tensor(window_imgs,dtype=torch.float32).to(device)
        xy_downscale = torch.stack((xy_downscale[0],xy_downscale[1])).transpose(0,1)

        # Model forward step and calculate test (global) loss
        with autocast():
          global_out, _, _, _, _ = model(window_batch)
          test_loss = ball_loss(global_out,xy_downscale) * global_weight

        loss = test_loss.item()
        test_loss_total += loss
        tepoch.set_postfix(loss=loss)

    # Log test losses
    ph1_test_loss_log.append(test_loss_total/batch_num)
    tqdm.write(f"Test\t epoch: {epoch}/{ph1_epochs}\t loss: {test_loss_total/batch_num}")

  # ---- Checkpoint (one file per epoch) ----
  tqdm.write("Saving model")
  state = {'model':model.state_dict(),'optimizer':optimizer.state_dict(),'lr_scheduler':lr_scheduler.state_dict(),'cur_epoch':epoch,'train_loss_log':ph1_train_loss_log,'val_loss_log':ph1_test_loss_log}
  savePath = f"{savePath_base}Phase1/TTNet_Phase1_{epoch}.pth"
  torch.save(state, savePath)
  model_saved = True

  # ---- Plot the losses ----
  # The x axis follows the length of each log so a log restored from a
  # checkpoint can never mismatch it; the figure is closed after use so that
  # one high-dpi figure per epoch does not pile up in memory.
  fig = plt.figure(dpi=1200)
  plt.plot(range(1, len(ph1_train_loss_log)+1),ph1_train_loss_log,label='Train loss')
  plt.plot(range(1, len(ph1_test_loss_log)+1),ph1_test_loss_log,label='Test loss')
  plt.legend()
  plt.xlabel('Epoch')
  plt.ylabel('Loss')
  plt.title('Training Loss')
  plt.savefig(outputPath+"loss_ph1.png")
  plt.show()
  plt.close(fig)

In [None]:
def phase1_test(frame_writer=None):
  """Measure the pixel error of the global stage on ``test_loader``.

  For every test sample the arg-max of the x and y scores is scaled to the
  original resolution and compared with the (scaled) label.

  Args:
    frame_writer: optional object with a ``write(image)`` method. When given,
      the centre frame of each window is read from disk, the prediction is
      drawn on it as a circle and the frame is passed to ``write``. By default
      no image is read.

  Returns:
    ``(avg_dists, x_dists, y_dists)``: three one-element lists holding the mean
    Euclidean error (computed on integer-truncated coordinates), the mean
    absolute x error and the mean absolute y error. All three lists are empty
    when the loader yields no samples.
  """
  model.eval()
  dist = 0
  dist_x = 0
  dist_y = 0
  count = 0
  avg_dists = []
  x_dists = []
  y_dists = []
  with torch.no_grad():
    with tqdm(test_loader, unit="batch") as tepoch:
      for data_batch in tepoch:
        tepoch.set_description(f"Test")
        # Read in test batch
        window_imgs, xy_downscale, window_centre, _ = data_batch
        window_batch = torch.as_tensor(window_imgs,dtype=torch.float32).to(device)

        with autocast():
          global_out, _, _, _, _ = model(window_batch)

        # Arg-max over each axis for the whole batch at once.
        pred_cols = torch.argmax(global_out[:, :TTN_width], dim=1).tolist()
        pred_rows = torch.argmax(global_out[:, TTN_width:TTN_width+TTN_height], dim=1).tolist()

        for i, (col, row) in enumerate(zip(pred_cols, pred_rows)):
          # Prediction and label, both scaled to the original resolution.
          pred_x = col*data_width/TTN_width
          pred_y = row*data_height/TTN_height
          pred = np.array([int(pred_x), int(pred_y)])

          ori_x = xy_downscale[0][i].item()*(data_width/TTN_width)
          ori_y = xy_downscale[1][i].item()*(data_height/TTN_height)
          ori = np.array([int(ori_x), int(ori_y)])

          dist += np.linalg.norm(ori-pred)
          dist_x += np.abs(ori_x - pred_x)
          dist_y += np.abs(ori_y - pred_y)
          count += 1

          # Only touch the disk when somebody consumes the annotated frame.
          if frame_writer is not None:
            centre_img = cv.imread(window_centre[i])
            if centre_img is not None:
              cv.circle(centre_img, (int(pred_x),int(pred_y)), 8,  (0, 0, 255), 2)
              frame_writer.write(centre_img)

    if count == 0:
      # Nothing was evaluated; avoid dividing by zero.
      print("No test samples were evaluated")
      return avg_dists, x_dists, y_dists

    avg_dists.append(dist/count)
    x_dists.append(dist_x/count)
    y_dists.append(dist_y/count)

    print(f"Avg Dist = {dist/count}")
    print(f"X Dist = {dist_x/count}")
    print(f"Y Dist = {dist_y/count}")
  return avg_dists, x_dists, y_dists
    
# Evaluate every stored phase 1 checkpoint (epochs 1..30) on the test set and
# print the losses that were logged for that epoch.
global_weight = 1
for i in range(1, 31):
  saved_epoch = i
  tasks = []
  savePath = f"{savePath_base}Phase1/TTNet_Phase1_{i}.pth"

  # A fresh global-only model receives the weights of this epoch.
  model = TTNet(dropout_p=0.5, frame_window=9, threshold=0.01, tasks=tasks).to(device)
  checkpoint = torch.load(savePath)
  model.load_state_dict(checkpoint['model'])
  ph1_train_loss_log = checkpoint['train_loss_log']
  ph1_test_loss_log = checkpoint['val_loss_log']

  print(f"Epoch {saved_epoch}")
  avg_dists, x_dists, y_dists = phase1_test()

  # The logs are indexed from 0 while epochs are numbered from 1.
  log_index = saved_epoch-1
  print(f"Train loss\t : {ph1_train_loss_log[log_index]}")
  print(f"Test loss\t : {ph1_test_loss_log[log_index]}")
  print()

## Training Phase 2

In [24]:
# Phase 2 settings
ph2_saved = False       # True: resume from a stored phase 2 checkpoint
ph2_epochs = 30         # last epoch number to train
local_weight, event_weight = 1, 2   # multipliers of the local and event losses

# Per-epoch mean losses (total / local / event); each name gets its own list.
ph2_train_loss_log, ph2_train_loc_log, ph2_train_event_log = [], [], []
ph2_test_loss_log, ph2_test_loc_log, ph2_test_event_log = [], [], []

In [None]:
# Phase 2 trains the local stage and the event head; the global stage is frozen.
freeze_list = ["global_stage"]
tasks = ["local", "event"]

model = TTNet(dropout_p=0.5, frame_window=9, threshold=0.01, tasks=tasks)
model = model.to(device)
model = freeze_model(model,freeze_list)

# Only parameters that still require gradients are handed to the optimizer.
train_params = list(filter(lambda param: param.requires_grad, model.parameters()))
optimizer = torch.optim.AdamW(train_params, lr=learning_rate)
lr_scheduler = StepLR(optimizer, step_size=10, gamma=0.1)
scaler = GradScaler()

# TODO Rewrite this
def load_weights_local_stage(pretrained_dict):
  """Duplicate every ``global_stage`` entry of a state dict under ``local_stage``.

  The dotted name component equal to ``global_stage`` is replaced wherever it
  appears in the key, so both ``global_stage.FC.0.weight`` and a prefixed key
  such as ``module.global_stage.FC.0.weight`` are handled. (Replacing a fixed
  component index produced keys like ``global_stage.local_stage.0.weight`` for
  unprefixed names, which match no parameter.)

  Args:
    pretrained_dict: mapping of parameter name -> value.

  Returns:
    A new dict containing all original entries plus the ``local_stage`` copies
    (the values are shared, not cloned). The input is not modified.
  """
  local_weights_dict = {}
  for layer_name, v in pretrained_dict.items():
    layer_name_parts = layer_name.split('.')
    if 'global_stage' in layer_name_parts:
      layer_name_parts[layer_name_parts.index('global_stage')] = 'local_stage'
      local_weights_dict['.'.join(layer_name_parts)] = v

  return {**pretrained_dict, **local_weights_dict}

In [None]:
if ph2_saved:
  # Resume from a fixed phase 2 checkpoint.
  savePath = f"{savePath_base}Phase2/TTNet_Phase2_1.pth"
  print("Loading model from path: ")
  print(savePath)
  checkpoint = torch.load(savePath, map_location=device)
  model.load_state_dict(checkpoint['model'])
  optimizer.load_state_dict(checkpoint['optimizer'])
  lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
  start_epoch = checkpoint['cur_epoch']
  ph2_train_loss_log = checkpoint['train_loss_log']
  ph2_test_loss_log = checkpoint['val_loss_log']
  # The per-task logs are only present in checkpoints that stored them.
  ph2_train_loc_log = checkpoint.get('train_loc_log', ph2_train_loc_log)
  ph2_train_event_log = checkpoint.get('train_event_log', ph2_train_event_log)
  ph2_test_loc_log = checkpoint.get('val_loc_log', ph2_test_loc_log)
  ph2_test_event_log = checkpoint.get('val_event_log', ph2_test_event_log)
  print(f"Load phase 2 at epoch {start_epoch} succeed")
else:
  # Start from the final phase 1 weights; the local stage is initialised with a
  # copy of the global-stage weights.
  checkpoint = torch.load(savePath_base+"Phase1/TTNet_Phase1_30.pth", map_location='cpu')
  model_state_dict = model.state_dict()
  pretrained_dict = {k: v for k, v in checkpoint['model'].items() if k in model_state_dict}
  pretrained_dict = load_weights_local_stage(pretrained_dict)
  # Keep only names the model really has, so nothing unknown enters the state dict.
  pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_state_dict}
  if not any(k.startswith('local_stage.') for k in pretrained_dict):
    print("Warning: no local_stage weights were derived from the phase 1 checkpoint")
  model_state_dict.update(pretrained_dict)
  model.load_state_dict(model_state_dict, strict=False)
  model = model.to(device)

  start_epoch = 0
  print("Phase 2: No model to load, start to train at epoch 0")

print('START TO TRAIN PHASE 2: Local + Event Stage ...')
for epoch in range(start_epoch+1, ph2_epochs+1):
  # ---- Training pass ----
  model.train()
  batch_num = len(train_loader)
  train_loss_total = 0
  tr_loc_loss_total = 0
  tr_event_loss_total = 0

  with tqdm(train_loader, unit="batch") as trepoch:
    for data_batch in trepoch:
      trepoch.set_description(f"Train Epoch {epoch}")
      # Read in train batch
      window_imgs, xy_downscale, _, event_probs = data_batch
      window_batch = torch.as_tensor(window_imgs,dtype=torch.float32).to(device, non_blocking=True)

      # Model forward step
      with autocast():
        _, local_out, _, crop_params, event_out = model(window_batch)

      # Event targets, and the ball label expressed in local-crop coordinates
      event_probs = torch.as_tensor(event_probs,dtype=torch.float32).to(device, non_blocking=True)
      local_ball_xy = get_local_groundtruth(xy_downscale,crop_params)
      local_ball_xy = torch.as_tensor(local_ball_xy,dtype=torch.float32).to(device)

      # Calculate train (local+event) loss
      with autocast():
        local_loss_train = ball_loss(local_out,local_ball_xy) * local_weight
        event_loss_train = event_loss(event_out,event_probs) * event_weight
        train_loss = local_loss_train + event_loss_train

      optimizer.zero_grad()
      scaler.scale(train_loss).backward()
      scaler.step(optimizer)
      scaler.update()

      loss = train_loss.item()
      train_loss_total += loss
      tr_loc_loss_total += local_loss_train.item()
      tr_event_loss_total += event_loss_train.item()
      trepoch.set_postfix(loss=loss)

  # Advance the StepLR schedule once per epoch; without this call the learning
  # rate never decays.
  lr_scheduler.step()

  # Log training losses
  tqdm.write(f"Train\t epoch: {epoch}/{ph2_epochs}\t loss: {train_loss_total/batch_num} Local : {tr_loc_loss_total/batch_num} Event {tr_event_loss_total/batch_num}")
  ph2_train_loss_log.append(train_loss_total/batch_num)
  ph2_train_loc_log.append(tr_loc_loss_total/batch_num)
  ph2_train_event_log.append(tr_event_loss_total/batch_num)

  # ---- Evaluation pass ----
  model.eval()
  with torch.no_grad():
    batch_num = len(test_loader)
    test_loss_total = 0
    test_loc_loss_total = 0
    test_event_loss_total = 0
    with tqdm(test_loader, unit="batch") as tepoch:
      for data_batch in tepoch:
        tepoch.set_description(f"Test Epoch {epoch}")
        # Read in test batch
        window_imgs, xy_downscale, _, event_probs = data_batch
        window_batch = torch.as_tensor(window_imgs,dtype=torch.float32).to(device, non_blocking=True)

        # Model forward step
        with autocast():
          _, local_out, _, crop_params, event_out = model(window_batch)

        # Event targets, and the ball label expressed in local-crop coordinates
        event_probs = torch.as_tensor(event_probs,dtype=torch.float32).to(device, non_blocking=True)
        local_ball_xy = get_local_groundtruth(xy_downscale,crop_params)
        local_ball_xy = torch.as_tensor(local_ball_xy,dtype=torch.float32).to(device)

        # Calculate test (local+event) loss
        with autocast():
          local_loss_test = ball_loss(local_out,local_ball_xy) * local_weight
          event_loss_test = event_loss(event_out,event_probs) * event_weight
          test_loss = local_loss_test + event_loss_test

        loss = test_loss.item()
        test_loss_total += loss
        test_loc_loss_total += local_loss_test.item()
        test_event_loss_total += event_loss_test.item()
        tepoch.set_postfix(loss=loss)

    ph2_test_loss_log.append(test_loss_total/batch_num)
    ph2_test_loc_log.append(test_loc_loss_total/batch_num)
    ph2_test_event_log.append(test_event_loss_total/batch_num)
    tqdm.write(f"Test\t epoch: {epoch}/{ph2_epochs}\t loss: {test_loss_total/batch_num} Local : {test_loc_loss_total/batch_num} Event {test_event_loss_total/batch_num}")

  # ---- Checkpoint (one file per epoch) ----
  tqdm.write("Saving model")
  # savePath_base already ends with '/', so no extra separator is inserted.
  savePath = f"{savePath_base}Phase2/TTNet_Phase2_{epoch}.pth"
  state = {
    'model':model.state_dict(),
    'optimizer':optimizer.state_dict(),
    'lr_scheduler':lr_scheduler.state_dict(),
    'cur_epoch':epoch,
    'train_loss_log':ph2_train_loss_log,
    'val_loss_log':ph2_test_loss_log,
    # Per-task logs are stored too, so a resumed run keeps their history.
    'train_loc_log':ph2_train_loc_log,
    'train_event_log':ph2_train_event_log,
    'val_loc_log':ph2_test_loc_log,
    'val_event_log':ph2_test_event_log,
  }
  torch.save(state, savePath)
  ph2_saved = True

  # ---- Plot the losses ----
  # The x axis follows the length of each log so it always matches the data.
  plt.clf()
  plt.plot(range(1, len(ph2_train_loss_log)+1),ph2_train_loss_log,label='Train loss')
  plt.plot(range(1, len(ph2_test_loss_log)+1),ph2_test_loss_log,label='Test loss')
  plt.legend()
  plt.xlabel('Epoch')
  plt.ylabel('Loss')
  plt.title('Training Loss')
  plt.savefig(outputPath+"loss_ph2.png")
  plt.show()

## Training Phase 3

In [None]:
# Phase 3 settings
ph3_saved = False       # True: resume from a stored phase 3 checkpoint
ph3_epochs = 20         # last epoch number to train
global_weight, local_weight, event_weight = 1, 1, 1   # loss multipliers

# Per-epoch mean training losses (total / global / local / event)
ph3_train_loss_log, ph3_tr_global_log = [], []
ph3_tr_local_log, ph3_tr_event_log = [], []

# Per-epoch mean test losses (total / global / local / event)
ph3_test_loss_log, ph3_test_global_log = [], []
ph3_test_local_log, ph3_test_event_log = [], []

In [None]:
# Phase 3 trains all stages together, so nothing is frozen here.
tasks = ["local", "event"]
model = TTNet(dropout_p=0.5, frame_window=9, threshold=0.01, tasks=tasks)
model = model.to(device)

# Parameters that require gradients are handed to the optimizer.
train_params = list(filter(lambda param: param.requires_grad, model.parameters()))
optimizer = torch.optim.AdamW(train_params, lr=learning_rate)
lr_scheduler = StepLR(optimizer, step_size=10, gamma=0.1)
scaler = GradScaler()

In [None]:
if not ph3_saved:
  # Fresh phase 3 run: initialise from a phase 2 checkpoint, copying every
  # entry whose name also exists in the current model.
  checkpoint = torch.load(savePath_base+"Phase2/TTNet_Phase2_25.pth", map_location='cpu')
  model_state_dict = model.state_dict()
  pretrained_dict = {
    name: value
    for name, value in checkpoint['model'].items()
    if name in model_state_dict
  }
  model_state_dict.update(pretrained_dict)
  model.load_state_dict(model_state_dict,strict=False)
  model = model.to(device)

  start_epoch = 0
  print("Phase 3: No model to load, start to train at epoch 0")
else:
  # Resume phase 3: restore model, optimizer, scheduler and the loss logs.
  savePath = savePath_base+"Phase3/TTNet_Phase3_15.pth"
  print("Loading model from path: ")
  print(savePath)
  checkpoint = torch.load(savePath)
  model.load_state_dict(checkpoint['model'])
  optimizer.load_state_dict(checkpoint['optimizer'])
  lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
  start_epoch = checkpoint['cur_epoch']
  ph3_train_loss_log = checkpoint['train_loss_log']
  ph3_test_loss_log = checkpoint['val_loss_log']
  print(f"Load phase 3 at epoch {start_epoch} succeed")

print('START TO TRAIN PHASE 3: All Stages ...')

# Create the checkpoint/plot folders up front so a missing directory cannot
# make torch.save / savefig fail after a whole epoch has been trained.
ph3_ckpt_dir = os.path.join(savePath_base, "Phase3")
os.makedirs(ph3_ckpt_dir, exist_ok=True)
os.makedirs(outputPath, exist_ok=True)

for epoch in range(start_epoch+1, ph3_epochs+1):
  # ---------------- Training pass ----------------
  model.train()
  batch_num = len(train_loader)
  train_loss_total = 0
  tr_global_total = 0
  tr_local_total = 0
  tr_event_total = 0

  with tqdm(train_loader, unit="batch") as trepoch:
    trepoch.set_description(f"Train Epoch {epoch}")
    for data_batch in trepoch:
      # Read in train batch
      window_imgs, xy_downscale, _, event_probs = data_batch
      window_batch = torch.as_tensor(window_imgs,dtype=torch.float32).to(device, non_blocking=True)
      xy_downscale_batch = torch.stack((xy_downscale[0],xy_downscale[1])).transpose(0,1)

      # Model forward step (mixed precision)
      with autocast():
        global_out, local_out, _, crop_params, event_out = model(window_batch)

      # Event targets, and the ball position re-expressed relative to the
      # crop chosen by the model (crop_params) for the local loss
      event_probs = torch.as_tensor(event_probs,dtype=torch.float32).to(device, non_blocking=True)
      local_ball_xy = get_local_groundtruth(xy_downscale,crop_params)
      local_ball_xy = torch.as_tensor(local_ball_xy,dtype=torch.float32).to(device)

      # Calculate train (global+local+event) loss
      with autocast():
        global_loss_train = ball_loss(global_out,xy_downscale_batch) * global_weight
        local_loss_train = ball_loss(local_out,local_ball_xy) * local_weight
        event_loss_train = event_loss(event_out,event_probs) * event_weight
        train_loss = global_loss_train + local_loss_train + event_loss_train

      # Backward pass on the scaled loss
      optimizer.zero_grad(set_to_none=True)
      scaler.scale(train_loss).backward()

      loss = train_loss.detach().cpu().numpy()
      glo = global_loss_train.detach().cpu().numpy()
      local = local_loss_train.detach().cpu().numpy()
      event = event_loss_train.detach().cpu().numpy()

      train_loss_total += loss
      tr_global_total += glo
      tr_local_total += local
      tr_event_total += event

      # Optimiser step through the scaler, then update the scale factor
      scaler.step(optimizer)
      scaler.update()
      trepoch.set_postfix(loss=loss)

  # StepLR only decays the learning rate when step() is called; advance it
  # once per epoch so that step_size=10 means "every 10 epochs".
  lr_scheduler.step()

  # Log training losses (averaged over the batches of this epoch)
  tqdm.write(f"Train\t epoch: {epoch}/{ph3_epochs}\t loss: {train_loss_total/batch_num} Global: {tr_global_total/batch_num} Local: {tr_local_total/batch_num} Event: {tr_event_total/batch_num}")
  ph3_train_loss_log.append(train_loss_total/batch_num)
  ph3_tr_global_log.append(tr_global_total/batch_num)
  ph3_tr_local_log.append(tr_local_total/batch_num)
  ph3_tr_event_log.append(tr_event_total/batch_num)

  # ---------------- Evaluation pass ----------------
  model.eval()
  with torch.no_grad():
    batch_num = len(test_loader)
    test_loss_total = 0
    test_global_total = 0
    test_local_total = 0
    test_event_total = 0
    with tqdm(test_loader, unit="batch") as tepoch:
      tepoch.set_description(f"Test Epoch {epoch}")
      for data_batch in tepoch:
        # Read in val batch
        window_imgs, xy_downscale, _, event_probs = data_batch
        window_batch = torch.as_tensor(window_imgs,dtype=torch.float32).to(device, non_blocking=True)
        xy_downscale_batch = torch.stack((xy_downscale[0],xy_downscale[1])).transpose(0,1)

        # Model forward step
        with autocast():
          global_out, local_out, _, crop_params, event_out = model(window_batch)

        # Event targets and crop-relative ball position
        event_probs = torch.as_tensor(event_probs,dtype=torch.float32).to(device, non_blocking=True)
        local_ball_xy = get_local_groundtruth(xy_downscale,crop_params)
        local_ball_xy = torch.as_tensor(local_ball_xy,dtype=torch.float32).to(device)

        # Calculate val (global+local+event) loss
        with autocast():
          global_loss_test = ball_loss(global_out,xy_downscale_batch) * global_weight
          local_loss_test = ball_loss(local_out,local_ball_xy) * local_weight
          event_loss_test = event_loss(event_out,event_probs) * event_weight
          test_loss = global_loss_test + local_loss_test + event_loss_test

        loss = test_loss.detach().cpu().numpy()
        glo = global_loss_test.detach().cpu().numpy()
        local = local_loss_test.detach().cpu().numpy()
        event = event_loss_test.detach().cpu().numpy()

        test_loss_total += loss
        test_global_total += glo
        test_local_total += local
        test_event_total += event
        tepoch.set_postfix(loss=loss)

    ph3_test_loss_log.append(test_loss_total/batch_num)
    ph3_test_global_log.append(test_global_total/batch_num)
    ph3_test_local_log.append(test_local_total/batch_num)
    ph3_test_event_log.append(test_event_total/batch_num)
    tqdm.write(f"Test\t epoch: {epoch}/{ph3_epochs}\t loss: {test_loss_total/batch_num} Global: {test_global_total/batch_num} Local: {test_local_total/batch_num} Event: {test_event_total/batch_num}")

  # ---------------- Checkpoint ----------------
  tqdm.write("Saving model")
  savePath = os.path.join(ph3_ckpt_dir, f"TTNet_Phase3_{epoch}.pth")
  state = {
    'model': model.state_dict(),
    'optimizer': optimizer.state_dict(),
    'lr_scheduler': lr_scheduler.state_dict(),
    'cur_epoch': epoch,
    'train_loss_log': ph3_train_loss_log,
    'val_loss_log': ph3_test_loss_log,
  }
  torch.save(state, savePath)
  ph3_saved = True

  # ---------------- Loss curves ----------------
  # A new high-dpi figure is opened every epoch; close it once it has been
  # saved and shown so figures do not pile up in memory over the run.
  fig = plt.figure(dpi=1200)
  plt.plot(range(1, epoch+1),ph3_train_loss_log,label='Train loss')
  plt.plot(range(1, epoch+1),ph3_test_loss_log,label='Test loss')
  plt.legend()
  plt.xlabel('Epoch')
  plt.ylabel('Loss')
  plt.title('Training Loss')
  plt.savefig(outputPath+"loss_ph3.png")
  plt.show()
  plt.close(fig)

## Testing

In [None]:
# Rebuild the phase 3 network and load the weights of the chosen epoch.
tasks = ["local", "event"]
model = TTNet(dropout_p=0.5, frame_window=9, threshold=0.01, tasks=tasks).to(device)
savePath = savePath_base+"Phase3/TTNet_Phase3_6.pth"
# map_location keeps the load working when the checkpoint was written on a GPU
# but this session runs on a different device (e.g. the CPU fallback).
checkpoint = torch.load(savePath, map_location=device)
model.load_state_dict(checkpoint['model'])

<All keys matched successfully>

In [None]:
# Inference-time transform: resize only, applied identically to the first
# frame ('image') and to the eight extra frames registered as image targets.
_extra_frame_keys = ('img2', 'img3', 'img4', 'img5', 'img6', 'img7', 'img8', 'img9')
transform = A.Compose(
  [A.Resize(height=TTN_height, width=TTN_width, interpolation=1, always_apply=True, p=1)],
  additional_targets=dict.fromkeys(_extra_frame_keys, 'image'),
)

class demo_data_loader():
  """Map-style dataset of sliding frame windows over the JPEGs in ./tmp.

  Every item is one window of `window_size` consecutive frames (sorted by file
  name), resized by the module-level `transform` and stacked channel-wise.

  NOTE: `transform` must declare one image target per frame in the window
  ('image', 'img2', ..., 'img<window_size>'); as defined in this notebook it
  declares nine, so `window_size` has to stay at 9 unless it is extended.
  """

  def __init__(self, window_size=9):
    self.window_size = window_size
    self.jpeg_reader = TurboJPEG()
    self.img_paths = sorted(glob.glob('./tmp/*.jpg'))
    self.window_paths = self._build_windows(self.img_paths, window_size)

  @staticmethod
  def _build_windows(img_paths, window_size):
    """Return every run of `window_size` consecutive paths, in order.

    Yields an empty list when there are fewer paths than `window_size`.
    """
    return [
      img_paths[start:start+window_size]
      for start in range(len(img_paths)-window_size+1)
    ]

  def loader(self, start_index):
    """Decode and transform the frames of window `start_index`.

    Returns a tuple with one transformed image per frame, in window order.
    """
    # Decode every frame of the window; the context manager closes each file
    # handle as soon as its bytes have been read.
    imgs = []
    for path in self.window_paths[start_index]:
      with open(path, 'rb') as in_file:
        # Second argument is the TurboJPEG pixel format
        # (0 is presumably TJPF_RGB - confirm against the PyTurboJPEG docs).
        imgs.append(self.jpeg_reader.decode(in_file.read(), 0))

    # First frame goes in as 'image', the rest as 'img2', 'img3', ...
    keys = ['image'] + [f'img{n}' for n in range(2, len(imgs)+1)]
    transformed = transform(**dict(zip(keys, imgs)))
    return tuple(transformed[key] for key in keys)

  def __getitem__(self, index):
    """Return (stacked window tensor, path of the window's centre frame)."""
    frames = self.loader(index)
    # Stack frames along axis 2, then move that axis to the front
    # (assumes decoded frames are H x W x C arrays).
    window_imgs = np.concatenate(frames, axis=2)
    window_imgs = np.rollaxis(window_imgs, 2, 0)
    window_imgs = torch.from_numpy(window_imgs)
    centre_path = self.window_paths[index][self.window_size//2]

    return window_imgs, centre_path

  def __len__(self):
    """Number of complete windows available."""
    return len(self.window_paths)

# Extracts all the frames from a video and saves them
def preprocess(video_path):
  """Dump every frame of `video_path` to tmp/NNNN.jpg (zero-padded index).

  Progress is printed every 200 frames. Nothing is written if the video
  cannot be opened or contains no frames.

  NOTE: file names are padded to 4 digits, so name-sorted order only matches
  frame order for videos of fewer than 10000 frames.
  """
  os.makedirs("tmp", exist_ok=True)
  cap = cv.VideoCapture(video_path)
  try:
    count = 0
    ret, frame = cap.read()
    while ret:
      cv.imwrite(f"tmp/{count:04d}.jpg", frame)
      ret, frame = cap.read()
      count += 1
      if count % 200 == 0:
        print(f"Frame {count}")
  finally:
    # The original referenced `cap.release` without calling it, so the
    # capture was never actually released.
    cap.release()

# Predicts the ball location and events for a new video
def predict(video_path, preprocessing=True):
  """Run the module-level `model` over a video and write an annotated copy.

  Args:
    video_path: path of the input video.
    preprocessing: when True, tmp/ is emptied and the video's frames are
      re-extracted into it; when False, the frames already in tmp/ are used.

  Returns:
    (preds, events): `preds` holds one [x, y] ball position per processed
    window and `events` one list of the four event scores per window, both in
    frame order. The annotated video is written to ./Results/videos/out.avi.
  """
  # (label, score threshold, text y position, BGR colour), indexed like the
  # model's event output.
  event_labels = [
    ("Flying", 0.8, 100, (255, 0, 0)),
    ("Bounce", 0.5, 150, (0, 255, 0)),
    ("Hit", 0.5, 200, (0, 0, 255)),
    ("Out of image", 0.6, 250, (255, 255, 0)),
  ]

  os.makedirs("tmp", exist_ok=True)
  if preprocessing:
    print("Preprocessing")
    for f in os.listdir("tmp"):
      os.remove(f"tmp/{f}")
    preprocess(video_path)
  print("Loading demo dataset")
  demo_dataset = demo_data_loader()
  # drop_last=False: at inference every window must be processed, otherwise
  # the frames of the final partial batch are missing from the output.
  demo_loader = DataLoader(demo_dataset, batch_size=16, shuffle=False, num_workers=num_workers, pin_memory=True, drop_last=False)

  # The output video reuses the frame rate of the input
  cap = cv.VideoCapture(video_path)
  fps = cap.get(cv.CAP_PROP_FPS)
  cap.release()
  os.makedirs('./Results/videos', exist_ok=True)
  fourcc = cv.VideoWriter_fourcc(*'MPEG')
  out = cv.VideoWriter('./Results/videos/out.avi', fourcc, fps, (data_width, data_height))

  print("Making Predictions")
  preds = []
  events = []
  model.eval()
  try:
    with torch.no_grad():
      for window_imgs, centre_path in tqdm(demo_loader):
        # Read in demo batch
        window_batch = torch.as_tensor(window_imgs,dtype=torch.float32).to(device, non_blocking=True)
        # Model forward step
        global_out, local_out, _, crop_params, event_out = model(window_batch)

        for out_index in range(local_out.shape[0]):
          # Global bounding box
          _, x_min, y_min, x_max, y_max, _, _ = crop_params[out_index]

          # Ball prediction: the first TTN_width entries of each output are
          # searched for x, the remaining entries for y
          local_x = torch.argmax(local_out[out_index,:TTN_width]).item()
          local_y = torch.argmax(local_out[out_index,TTN_width:]).item()
          global_x = torch.argmax(global_out[out_index,:TTN_width]).item()
          global_y = torch.argmax(global_out[out_index,TTN_width:]).item()

          # Scale the global estimate up to the source resolution, then offset
          # by the local estimate measured from the centre of the local window
          x_pred = global_x*data_width/TTN_width-TTN_width/2+local_x
          y_pred = global_y*data_height/TTN_height-TTN_height/2+local_y
          x_pred = 0 if x_pred<0 else int(x_pred)
          y_pred = 0 if y_pred<0 else int(y_pred)
          preds.append([x_pred, y_pred])

          # Event prediction
          event_preds = event_out[out_index]
          scores = [event_preds[k].item() for k in range(len(event_labels))]
          events.append(scores)

          # Label frame
          centre_img = cv.imread(centre_path[out_index])

          # Draw the prediction if the ball is in the frame
          if scores[3] < 0.6:
            cv.rectangle(centre_img, (x_min,y_min), (x_max,y_max), (0,0,255), 2)
            cv.circle(centre_img,(x_pred,y_pred), 8,  (0, 0, 255), 2)

          # Overlay every event whose score exceeds its threshold
          for (label, threshold, text_y, colour), score in zip(event_labels, scores):
            if score > threshold:
              event_text = f"Event: {label} {score:.2f}"
              cv.putText(centre_img, event_text, (50, text_y), cv.FONT_HERSHEY_SIMPLEX, 1.0, colour, 2)

          out.write(centre_img)
  finally:
    # Always finalise the video file, even if prediction fails part-way
    out.release()
  return preds, events

# Set preprocessing to True the first time you predict a video so that its
# frames are extracted into ./tmp; with preprocessing=False the frames that
# are already in ./tmp are reused as they are.
preds, events = predict("./Dataset/videos/test2.mp4", preprocessing=False)

Preprocessing
Frame 200
Frame 400
Frame 600
Frame 800
Frame 1000
Frame 1200
Frame 1400
Frame 1600
Loading demo dataset
Making Predictions


  0%|          | 0/111 [00:00<?, ?it/s]