<a href="https://colab.research.google.com/github/abrichr/visual-contact-tracing/blob/master/Visual_Contact_Tracing.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## TODO

* Interpolate missing keypoints between frames
* Smooth keypoint positions across frames
  * Savitzky–Golay or Kalman filter
  * Optical Flow: https://github.com/facebookresearch/DetectAndTrack/blob/d66734498a4331cd6fde87d8269499b8577a2842/lib/core/tracking_engine.py#L600
* Propagate keypoints through frames using optical flow and add to distance matrix: https://arxiv.org/pdf/1804.06208.pdf
* Require minimum number of consecutive contact frames
* Add bounding box area difference to cost?
* Add legend to output video

In [0]:
# install dependencies: (use cu101 because colab has CUDA 10.1)
# NOTE(review): torch/torchvision and pyyaml are pinned, but cython, cython_bbox,
# cocoapi and detectron2 are not, so a future release could break this cell.
!pip install -U torch==1.5 torchvision==0.6 -f https://download.pytorch.org/whl/cu101/torch_stable.html 
!pip install cython pyyaml==5.1 cython_bbox
# COCO API (pycocotools), installed from source
!pip install -U 'git+https://github.com/cocodataset/cocoapi.git#subdirectory=PythonAPI'
import torch, torchvision
# sanity check: installed torch version and whether a GPU is visible
print(torch.__version__, torch.cuda.is_available())
!gcc --version
# opencv is pre-installed on colab
# install detectron2:
!pip install detectron2 -f https://dl.fbaipublicfiles.com/detectron2/wheels/cu101/index.html
# get configs: DEFAULT_CONFIG later points into ./detectron2/configs.
# On a re-run git refuses to clone into the existing directory and prints an
# error; the existing checkout is then used as-is.
!git clone https://github.com/facebookresearch/detectron2

In [0]:
# May need to restart your runtime prior to this to let installation take effect

# Setup detectron2 logger
import detectron2
from detectron2.utils.logger import setup_logger
#setup_logger()

import numpy as np
import cv2
import random
from google.colab import files
from google.colab.patches import cv2_imshow

from detectron2 import model_zoo
from detectron2.engine import DefaultPredictor
from detectron2.config import get_cfg
from detectron2.data import MetadataCatalog

from scipy.spatial.distance import pdist, squareform
import requests
#import subprocess as sp

import imageio
from cython_bbox import bbox_overlaps
from skimage.color import rgba2rgb

import os
import sys
import time
from tqdm.notebook import tqdm
from detectron2.utils import visualizer
from detectron2.utils.visualizer import ColorMode
from detectron2.utils.video_visualizer import VideoVisualizer

Upload a video

In [0]:
# Flip to True to upload a local video instead of using the Drive download below.
UPLOAD_VIDEO = False
if UPLOAD_VIDEO:
  uploaded = files.upload()
  for name, content in uploaded.items():
    size_msg = 'User uploaded file "{name}" with length {length} bytes'.format(
      name=name, length=len(content))
    print(size_msg)

Or load from Google Drive

In [0]:

def download_file_from_google_drive(file_id, file_name):
  # download a file from the Google Drive link
  !rm -f ./cookie
  !curl -c ./cookie -s -L "https://drive.google.com/uc?export=download&id=$file_id" > /dev/null
  confirm_text = !awk '/download/ {print $NF}' ./cookie
  confirm_text = confirm_text[0]
  !curl -Lb ./cookie "https://drive.google.com/uc?export=download&confirm=$confirm_text&id=$file_id" -o $file_name
  with open(file_name, 'rb') as f:
    data = f.read()
    print('downloaded', len(data), 'bytes to', video_filename)


file_id = '0Bzf1l8WmTwu0eUluQ2h1NWZQRjQ'
video_filename = 'salsa_cpp_cam4.avi'
download_file_from_google_drive(file_id, video_filename)

Re-encode the video to H.264 and trim it to the selected time window

In [0]:
RUN_CONFIGS = [
  # start_time_seconds, duration_seconds, start_infected_track
  (1, 10, 3),
  (10, 9.5, 2)
]
run_config = RUN_CONFIGS[1]
start_time_seconds, duration_seconds, start_infected_track = run_config


ffmpeg = imageio.plugins.ffmpeg                                                
try:                                                                           
    ffmpeg.download()                                                          
except:                                                                        
    pass                                                                       
ffmpeg_exe = ffmpeg.get_exe()

video_path = video_filename
video_filename_reenc = video_filename + '-reenc.avi'

cmd_parts = [
  ffmpeg_exe,
  '-i', video_path,
  '-vcodec', 'h264',
  '-acodec', 'aac',
  #'-c', 'copy',
  '-strict',
  '-2',
  '-ss', str(start_time_seconds),
  '-t', str(duration_seconds),
  '-y',
  '-loglevel', 'debug',
  '-an',
  video_filename_reenc
]
cmd = ' '.join(cmd_parts)
print('Running cmd:\n', cmd)                                
! $cmd                                

with open(video_filename_reenc, 'rb') as f:
  data = f.read()
  print('wrote', len(data), 'bytes')

In [0]:
HIDE_KEYPOINTS = False  # drop keypoints before drawing in process_predictions
DEFAULT_CONFIG = 'detectron2/configs/COCO-Keypoints/keypoint_rcnn_R_50_FPN_3x.yaml'
DEFAULT_CONF_THRESH = 0.7  # minimum detection score kept at test time
DEFAULT_OPTS = [
  'MODEL.WEIGHTS',
  model_zoo.get_checkpoint_url("COCO-Keypoints/keypoint_rcnn_R_50_FPN_3x.yaml"),
]

# Override the minimum keypoint score detectron2's Visualizer needs before it
# draws a keypoint (module-private setting):
# https://github.com/facebookresearch/detectron2/blob/04958b93e1232935e126c2fd9e6ccd3f57c3a8f3/detectron2/utils/visualizer.py#L32
# The same threshold is reused later to ignore low-confidence joints.
KEYPOINT_THRESHOLD = 0.04
visualizer._KEYPOINT_THRESHOLD = KEYPOINT_THRESHOLD


default_args = [DEFAULT_CONFIG, DEFAULT_OPTS, DEFAULT_CONF_THRESH]


def setup_cfg(config=DEFAULT_CONFIG, opts=DEFAULT_OPTS, conf_thresh=DEFAULT_CONF_THRESH):
  """Build a frozen detectron2 config.

  Args:
    config: path to a detectron2 YAML config file.
    opts: flat [KEY, VALUE, ...] override list merged after the file.
    conf_thresh: test-time score threshold, applied alike to the RetinaNet,
      ROI-heads and panoptic-FPN instance thresholds.

  Returns:
    The frozen config node.
  """
  cfg = get_cfg()
  cfg.merge_from_file(config)
  cfg.merge_from_list(opts)
  score_threshold_slots = (
    (cfg.MODEL.RETINANET, 'SCORE_THRESH_TEST'),
    (cfg.MODEL.ROI_HEADS, 'SCORE_THRESH_TEST'),
    (cfg.MODEL.PANOPTIC_FPN.COMBINE, 'INSTANCES_CONFIDENCE_THRESH'),
  )
  for node, key in score_threshold_slots:
    setattr(node, key, conf_thresh)
  cfg.freeze()
  return cfg


setup_logger(name="fvcore")
logger = setup_logger()
logger.info("Arguments: " + str(default_args))

cfg = setup_cfg()
predictor = DefaultPredictor(cfg)

video_input = video_filename_reenc
print('video_input:', video_input)
if not os.path.isfile(video_input):
  raise FileNotFoundError(video_input)
video = cv2.VideoCapture(video_input)
print('video:', video)
if not video.isOpened():
  # OpenCV does not raise on unreadable files; without this check the frame
  # loop would silently process zero frames.
  raise RuntimeError(f'could not open video {video_input}')
width = int(video.get(cv2.CAP_PROP_FRAME_WIDTH))
height = int(video.get(cv2.CAP_PROP_FRAME_HEIGHT))
frames_per_second = video.get(cv2.CAP_PROP_FPS)
print('frames_per_second:', frames_per_second)
# Container-reported count; only used as the progress-bar total.
num_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
print('num_frames:', num_frames)
basename = os.path.basename(video_input)

# Writer for the annotated keypoint video (only written if WRITE_KP_OUTPUT).
kp_output_fname = f'{video_filename_reenc}-kp.mp4'
print('kp_output_fname:', kp_output_fname)
kp_output_file = cv2.VideoWriter(
  filename=kp_output_fname,
  fourcc=cv2.VideoWriter_fourcc(*'mp4v'),
  fps=float(frames_per_second),
  frameSize=(width, height),
  isColor=True,
)

metadata = MetadataCatalog.get(
  cfg.DATASETS.TEST[0] if len(cfg.DATASETS.TEST) else "__unused"
)
cpu_device = torch.device("cpu")
instance_mode = ColorMode.IMAGE
video_visualizer = VideoVisualizer(metadata, instance_mode)

def process_predictions(frame, predictions):
  """Draw one frame's detectron2 predictions and return the result as BGR.

  Args:
    frame: BGR image as returned by cv2.VideoCapture.read().
    predictions: dict returned by DefaultPredictor. The first of
      "panoptic_seg", "instances" or "sem_seg" present is drawn; keypoints are
      dropped from "instances" when HIDE_KEYPOINTS is set.

  Returns:
    BGR image (OpenCV channel order) with the predictions drawn.

  Raises:
    ValueError: if predictions contains none of the supported keys.
  """
  # The visualizer works in RGB. BGR->RGB is the same channel swap as RGB->BGR.
  frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
  if "panoptic_seg" in predictions:
    panoptic_seg, segments_info = predictions["panoptic_seg"]
    vis_frame = video_visualizer.draw_panoptic_seg_predictions(
      frame, panoptic_seg.to(cpu_device), segments_info
    )
  elif "instances" in predictions:
    predictions = predictions["instances"].to(cpu_device)
    if HIDE_KEYPOINTS:
      predictions.remove('pred_keypoints')
    vis_frame = video_visualizer.draw_instance_predictions(frame, predictions)
  elif "sem_seg" in predictions:
    vis_frame = video_visualizer.draw_sem_seg(
      frame, predictions["sem_seg"].argmax(dim=0).to(cpu_device)
    )
  else:
    # Previously fell through to an UnboundLocalError on vis_frame.
    raise ValueError(
      f'no drawable predictions; keys: {sorted(predictions.keys())}'
    )

  # Converts Matplotlib RGB format to OpenCV BGR format
  return cv2.cvtColor(vis_frame.get_image(), cv2.COLOR_RGB2BGR)

def _frame_from_video(video):
  while video.isOpened():
    success, frame = video.read()
    if success:
      yield frame
    else:
      break

frame_gen = _frame_from_video(video)
frames = []  # raw frames as read, reused by the tracking/infection cells
all_predictions = []  # one DefaultPredictor output dict per frame
pred_times = []  # per-frame inference latency in seconds (was never defined)
SHOW_NUM_FRAMES = 1  # display this many annotated frames inline
WRITE_KP_OUTPUT = False  # write every annotated frame to kp_output_fname
for i, frame in enumerate(tqdm(frame_gen, total=num_frames)):
  frames.append(frame)
  start = time.time()
  predictions = predictor(frame)
  pred_times.append(time.time() - start)
  all_predictions.append(predictions)

  show_frame = i < SHOW_NUM_FRAMES
  if show_frame or WRITE_KP_OUTPUT:
    vis_frame = process_predictions(frame, predictions)
    if show_frame:
      print('displaying frame', i)
      cv2_imshow(vis_frame)
    if WRITE_KP_OUTPUT:
      kp_output_file.write(vis_frame)

if pred_times:
  print('mean prediction time (s):', sum(pred_times) / len(pred_times))

# 1-D object array of per-frame prediction dicts.
all_predictions = np.array(all_predictions)

video.release()
kp_output_file.release()

if WRITE_KP_OUTPUT:
  print('wrote', os.path.getsize(kp_output_fname), 'bytes to', kp_output_fname)
  files.download(kp_output_fname)

In [0]:
DEFAULT_HIDE_KEYPOINTS = False

# Keypoint order of detectron2's COCO keypoint models, see
# https://github.com/facebookresearch/detectron2/issues/754#issuecomment-579463185
JOINT_NAMES = [
  "nose",
  "left_eye", "right_eye",
  "left_ear", "right_ear",
  "left_shoulder", "right_shoulder",
  "left_elbow", "right_elbow",
  "left_wrist", "right_wrist",
  "left_hip", "right_hip",
  "left_knee", "right_knee",
  "left_ankle", "right_ankle"
]

# TODO: smoothing e.g. kalman or savgol
# https://stackoverflow.com/a/52450682/95989


def _frame_arrays(frame_predictions):
  """Return (keypoints, boxes) numpy arrays for one frame's predictions.

  The instances also carry 'scores' and 'pred_classes', which are not used.
  """
  frame_instances = frame_predictions['instances'].to(cpu_device)
  return (
    np.asarray(frame_instances.pred_keypoints),
    np.asarray(frame_instances.pred_boxes.tensor),
  )


# Per-frame arrays, index-aligned with all_predictions.
all_keypoints = []
all_boxes = []
for frame_keypoints, frame_boxes in map(_frame_arrays, all_predictions):
  all_keypoints.append(frame_keypoints)
  all_boxes.append(frame_boxes)
max_num_instances = max((kp.shape[0] for kp in all_keypoints), default=0)
print('max_num_instances:', max_num_instances)
print('len(all_keypoints):', len(all_keypoints))


# https://github.com/facebookresearch/DetectAndTrack/blob/d66734498a4331cd6fde87d8269499b8577a2842/lib/core/tracking_engine.py#L106
def compute_pairwise_iou(a, b):
  """Pairwise IoU *cost* between two sets of boxes.

  Despite the name, the result is 1 - IoU: 0 for identical boxes, 1 for boxes
  that do not overlap.

  Args:
    a: (N, 4) array of (x1, y1, x2, y2) boxes.
    b: (M, 4) array of (x1, y1, x2, y2) boxes.

  Returns:
    (N, M) array whose [i, j] entry is 1 - IoU(a[i], b[j]).
  """
  # bbox_overlaps (cython) needs contiguous float64 input.
  boxes_a = np.ascontiguousarray(a, dtype=np.float64)
  boxes_b = np.ascontiguousarray(b, dtype=np.float64)
  iou = bbox_overlaps(boxes_a, boxes_b)
  return 1 - iou


def compute_distance_matrix(prev_boxes, cur_boxes):
  """Cost of assigning each previous box (rows) to each current box (columns).

  Currently this is only the IoU cost (1 - IoU) from compute_pairwise_iou;
  lower means a better match.
  """
  # TODO: consider keypoint distance?
  # TODO: weigh cost further away in time more heavily
  iou_cost = compute_pairwise_iou(prev_boxes, cur_boxes)
  return iou_cost


# https://github.com/facebookresearch/DetectAndTrack/blob/d66734498a4331cd6fde87d8269499b8577a2842/lib/core/tracking_engine.py#L184
def bipartite_matching_greedy(C, prev_tracks):
  """Greedily match previous boxes (rows of C) to current boxes (columns).

  Repeatedly takes the cheapest remaining entry (i, j). After that match,
  column j is retired, and so is every row whose track equals prev_tracks[i].
  The rows can contain several boxes of the same track (one per past frame),
  and each track may be matched at most once.

  Args:
    C: (P, M) cost matrix, lower is better; np.inf marks unavailable pairs.
    prev_tracks: length-P sequence giving the track ID of each row.

  Returns:
    (prev_ids, cur_ids): parallel lists of matched row / column indices in the
    order the matches were made. C itself is left unmodified.
  """
  cost = C.copy()
  track_of_row = np.asarray(prev_tracks)
  prev_ids, cur_ids = [], []
  while not (cost == np.inf).all():
    row, col = np.unravel_index(cost.argmin(), cost.shape)
    prev_ids.append(row)
    cur_ids.append(col)
    cost[:, col] = np.inf
    cost[track_of_row == track_of_row[row], :] = np.inf
  return prev_ids, cur_ids


def compute_matches(prev_boxes, cur_boxes, prev_tracks):
  """Match current-frame boxes to boxes from previous frames.

  Args:
    prev_boxes: (P, 4) boxes from the recent past (may be empty).
    cur_boxes: (M, 4) boxes of the current frame.
    prev_tracks: length-P track IDs, one per prev_boxes row.

  Returns:
    int32 array of length M; entry i is the prev_boxes row matched to
    cur_boxes[i], or -1 if it was not matched.
  """
  assert len(prev_boxes) == len(prev_tracks)
  matches = np.full(len(cur_boxes), -1, dtype=np.int32)
  if prev_boxes.size == 0:
    return matches
  cost = compute_distance_matrix(prev_boxes, cur_boxes)
  prev_inds, next_inds = bipartite_matching_greedy(cost, prev_tracks)
  assert len(prev_inds) == len(next_inds)
  for prev_idx, next_idx in zip(prev_inds, next_inds):
    matches[next_idx] = prev_idx
  return matches

def get_frame_tracks(matches, prev_tracks, next_track_id):
  """Assign a track ID to every detection of the current frame.

  Args:
    matches: per-detection row index into prev_tracks, or -1 for no match.
    prev_tracks: track IDs of the previous boxes.
    next_track_id: one-element list used as a mutable counter; each unmatched
      detection takes its value and advances it.

  Returns:
    List of track IDs, one per entry of matches.
  """
  frame_tracks = []
  for m in matches:
    if m != -1 and m < len(prev_tracks):
      frame_tracks.append(prev_tracks[m])
      continue
    # Unmatched: start a new track.
    frame_tracks.append(next_track_id[0])
    next_track_id[0] += 1
    if next_track_id[0] >= MAX_TRACK_IDS:
      # TODO: handle this (wrapping can reuse an ID that is still active)
      print('Exceeded max track ids')
      next_track_id[0] %= MAX_TRACK_IDS
  return frame_tracks

def visualize_predictions(frame, predictions, hide_keypoints=DEFAULT_HIDE_KEYPOINTS):
  """Draw a frame's instance predictions and return the image in BGR order.

  Args:
    frame: image in OpenCV BGR channel order.
    predictions: DefaultPredictor output dict with an "instances" entry.
    hide_keypoints: if true, keypoints are not drawn (a CPU copy of the
      instances is modified, not the input).
  """
  instances = predictions["instances"].to(cpu_device)
  if hide_keypoints:
    instances.remove('pred_keypoints')
  # Swap to the RGB order the visualizer draws in, then back for OpenCV.
  rgb_frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
  drawn = video_visualizer.draw_instance_predictions(rgb_frame, instances)
  return cv2.cvtColor(drawn.get_image(), cv2.COLOR_RGB2BGR)

# compute tracks, inspired by:
# https://github.com/facebookresearch/DetectAndTrack/blob/d66734498a4331cd6fde87d8269499b8577a2842/lib/core/tracking_engine.py#L272
# Greedy IoU tracking: each detection is matched against the boxes of the last
# T frames; unmatched detections start a new track ID.
MAX_TRACK_IDS = 999
T = 60  # number of past frames whose boxes are candidates for matching
SHOW_FRAME_ON_NEW_TRACK = False  # also display frames where a new track starts
all_prev_boxes = []
video_tracks = []  # video_tracks[frame_id][i] = track ID of detection i
next_track_id = [0]  # one-element list so get_frame_tracks can advance it
for frame_id, (frame, predictions) in enumerate(tqdm(zip(frames, all_predictions), total=len(frames))):
  instances = predictions['instances'].to(cpu_device)
  cur_boxes = np.asarray(instances.pred_boxes.tensor)
  prev_boxes = np.vstack(all_prev_boxes) if all_prev_boxes else np.array([])
  # Track IDs of exactly the frames held in all_prev_boxes, so that
  # prev_tracks[k] is the track of prev_boxes[k].
  all_prev_tracks = video_tracks[max(0, frame_id - len(all_prev_boxes)):frame_id]
  # Flatten with an explicit int dtype: np.hstack upcasts to float when any
  # frame in the window had no detections (an empty list), which would turn
  # track IDs into floats ("3.0" labels, float matrix indices later).
  prev_tracks = np.array(
    [track for tracks in all_prev_tracks for track in tracks], dtype=int
  )

  # matches[i] is the prev_boxes row matched to cur_boxes[i], or -1.
  matches = compute_matches(prev_boxes, cur_boxes, prev_tracks)
  frame_tracks = get_frame_tracks(matches, prev_tracks, next_track_id)
  assert len(np.unique(frame_tracks)) == len(frame_tracks), (len(np.unique(frame_tracks)), len(frame_tracks))
  video_tracks.append(frame_tracks)
  all_prev_boxes.append(cur_boxes)
  if len(all_prev_boxes) > T:
    all_prev_boxes = all_prev_boxes[1:]

  has_new_match = bool(np.any(matches == -1))
  if frame_id < 3 or frame_id >= len(frames) - 3 or (SHOW_FRAME_ON_NEW_TRACK and has_new_match):
    print('Visualizing frame_id:', frame_id)
    if has_new_match:
      print('New match:')
    vis_frame = visualize_predictions(frame, predictions)
    # Track ID label at each box's top-left corner: black shadow, white text.
    for box, frame_track in zip(cur_boxes, frame_tracks):
      cv2.putText(vis_frame, str(frame_track), (int(box[0]-5), int(box[1]-5)), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 0))
      cv2.putText(vis_frame, str(frame_track), (int(box[0]-4), int(box[1]-4)), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255))
    cv2_imshow(vis_frame)

# TODO: filter out large position jumps that immediately return after one frame

## Detect Infections

In [0]:
DISTANCE_THRESHOLD = 20  # joints closer than this many pixels are in contact
MIN_CONSECUTIVE_HIT_FRAMES = 1  # frames of continuous contact before infecting
WRITE_INF_OUTPUT = True

print('DISTANCE_THRESHOLD:', DISTANCE_THRESHOLD)
print('MIN_CONSECUTIVE_HIT_FRAMES:', MIN_CONSECUTIVE_HIT_FRAMES)

inf_output_fname = f'{video_filename_reenc}-inf.mp4'
print('inf_output_fname:', inf_output_fname)
inf_output_file = cv2.VideoWriter(
    filename=inf_output_fname,
    fourcc=cv2.VideoWriter_fourcc(*'mp4v'),
    fps=float(frames_per_second),
    frameSize=(width, height),
    isColor=True,
)

print(len(all_predictions), 'frames')
target_joint_names = [
  'left_wrist',
  'right_wrist'
]
infected_tracks = {start_infected_track}  # track IDs infected so far
vis_times = []
inf_times = []
print('len(all_keypoints):', len(all_keypoints))
print('len(video_tracks):', len(video_tracks))
assert(len(all_keypoints) == len(video_tracks))
# Track IDs are 0-based, so we need (largest ID + 1) slots. Frames without
# detections have empty track lists and must be skipped (max() of an empty
# sequence raises).
num_tracks = 1 + max(
  (max(frame_tracks) for frame_tracks in video_tracks if len(frame_tracks)),
  default=-1,
)
print('num_tracks:', num_tracks)
num_joints = len(target_joint_names)
# consecutive_hits[p, q]: number of consecutive frames (up to the previous one)
# in which track-joint p touched track-joint q, where a track-joint index is
# track_id * num_joints + joint.
dim = num_tracks * num_joints
consecutive_hits = np.zeros((dim, dim))
def bgra_2_bgr(bgra):
  """Flatten a BGRA image to BGR so it can be written by cv2.VideoWriter.

  Fast approximation: each output channel is (colour + (255 - alpha)) / 2.
  Opaque pixels keep half their colour and transparent ones are brightened,
  so the result differs from what cv2_imshow shows for the same image. The
  slower alternative that matches cv2_imshow converts BGRA->RGBA, applies
  skimage's rgba2rgb, and converts back to uint8 BGR.
  """
  b, g, r, a = cv2.split(bgra)
  bgr = cv2.merge((b, g, r))
  a = 255 - np.stack([a, a, a], axis=-1)
  bgr = cv2.addWeighted(np.float32(bgr) / 255, .5, np.float32(a) / 255, .5, 0)
  return (bgr * 255).astype(np.uint8)


def _clip_box(x0, y0, x1, y1, frame_w, frame_h):
  """Truncate box corners to ints and clamp them to valid pixel indices.

  Corners can land on or past the last row/column (or below 0 after padding).
  Unclamped, they would raise IndexError or wrap around via negative indexing.
  """
  def clamp(value, size):
    return min(max(int(value), 0), size - 1)
  return clamp(x0, frame_w), clamp(y0, frame_h), clamp(x1, frame_w), clamp(y1, frame_h)


def _draw_outline(img, x0, y0, x1, y1, value):
  """Set a 1-pixel rectangle border in a single-channel image to value.

  Rows y0/y1 cover columns x0..x1-1 and columns x0/x1 cover rows y0..y1-1.
  """
  img[y0, x0:x1] = value
  img[y1, x0:x1] = value
  img[y0:y1, x0] = value
  img[y0:y1, x1] = value


def _print_matrix(title, matrix):
  """Print a matrix in full, floats shown as integers; print options are restored."""
  # np.set_printoptions() with no arguments would not reset `threshold`.
  with np.printoptions(
    threshold=sys.maxsize,
    formatter={'float': lambda x: "{0:0.0f}".format(x)},
  ):
    print(title)
    print(matrix)


# Index of each target joint within a per-instance keypoint array.
target_joint_idxs = [JOINT_NAMES.index(name) for name in target_joint_names]
PRINT_DISTANCES = False
PRINT_CONSECUTIVE_HITS = False
# If True, draw only contacts that lasted MIN_CONSECUTIVE_HIT_FRAMES.
HIGHLIGHT_AFTER_MIN_CONSECUTIVE_ONLY = False
HIT_BOX_RADIUS = 20  # padding in pixels of the box drawn around a contact

for frame_id, (keypoints, frame_tracks) in enumerate(
    tqdm(zip(all_keypoints, video_tracks), total=len(all_keypoints))
):
  assert len(np.unique(frame_tracks)) == len(frame_tracks), (len(np.unique(frame_tracks)), len(frame_tracks))

  # --- Flatten target joints: row i * num_joints + j holds target joint j of
  # the i-th detection. Each keypoint row is (x, y, probability), see
  # https://github.com/facebookresearch/detectron2/blob/master/detectron2/utils/visualizer.py#L703
  target_joint_vals = []
  target_joint_probs = []
  for keypoints_per_instance in keypoints:
    for joint_idx in target_joint_idxs:
      x, y, prob = keypoints_per_instance[joint_idx]
      target_joint_vals.append([x, y])
      target_joint_probs.append(prob)
  # reshape keeps the array 2-D (0 x 2) for frames without detections.
  target_joint_vals = np.array(target_joint_vals).reshape(-1, 2)

  start = time.time()
  num_points = len(target_joint_vals)
  if num_points > 1:
    distances = squareform(pdist(target_joint_vals))
  else:
    # squareform(pdist(...)) yields a 1x1 matrix for zero points, which would
    # not line up with prob_mask below.
    distances = np.zeros((num_points, num_points))
  if PRINT_DISTANCES:
    _print_matrix('distances:', distances)

  # Contact = joints close enough AND both detected with enough confidence.
  hit_mask = distances < DISTANCE_THRESHOLD
  prob_mask = np.array(target_joint_probs) > KEYPOINT_THRESHOLD
  hit_mask[~prob_mask, :] = False
  hit_mask[:, ~prob_mask] = False

  frame_tracks = np.array(frame_tracks, dtype=int)
  hit_count_by_track_joint_idx_tup = {}
  joint_idx_by_track_joint_idx = {}  # track-joint index -> row of target_joint_vals
  for infected_track in infected_tracks:
    infected_positions = np.flatnonzero(frame_tracks == infected_track)
    if not infected_positions.size:
      continue  # track no longer in frame
    infected_idx = infected_positions[0]
    for i_joint in range(num_joints):
      infected_joint_idx = infected_idx * num_joints + i_joint
      infected_track_joint_idx = infected_track * num_joints + i_joint
      for hit_joint_idx in np.flatnonzero(hit_mask[infected_joint_idx]):
        hit_track_idx, hit_joint = divmod(hit_joint_idx, num_joints)
        hit_track = frame_tracks[hit_track_idx]
        if hit_track == infected_track:
          continue  # contact between a person's own joints
        # Key the hit by the joint that was actually touched. Keying it by
        # i_joint would make e.g. left->left and left->right contacts of the
        # same pair collide on one key, so one would overwrite the other.
        hit_track_joint_idx = hit_track * num_joints + hit_joint
        joint_idx_by_track_joint_idx[infected_track_joint_idx] = infected_joint_idx
        joint_idx_by_track_joint_idx[hit_track_joint_idx] = hit_joint_idx
        hit_count_by_track_joint_idx_tup[
          (infected_track_joint_idx, hit_track_joint_idx)
        ] = 1 + consecutive_hits[infected_track_joint_idx][hit_track_joint_idx]

  # Only contacts seen in this frame keep their streak; all others reset.
  # TODO: don't reset if missing joint_idx belongs to track that's not visible
  # in this frame, or if it's within e.g. MIN_CONSECUTIVE_DEAD_KP_FRAMES frames
  # (it might reappear)
  consecutive_hits.fill(0)
  for (a_idx, b_idx), hit_count in hit_count_by_track_joint_idx_tup.items():
    consecutive_hits[a_idx][b_idx] = hit_count
    consecutive_hits[b_idx][a_idx] = hit_count
  if PRINT_CONSECUTIVE_HITS:
    _print_matrix('consecutive_hits:', consecutive_hits)

  all_hit_track_joint_idx_tups = list(hit_count_by_track_joint_idx_tup)
  hit_track_joint_idx_tups = [
    track_joint_idx_tup
    for track_joint_idx_tup, hit_count in hit_count_by_track_joint_idx_tup.items()
    if hit_count >= MIN_CONSECUTIVE_HIT_FRAMES
  ]
  frame_infected_tracks = {
    hit_track_joint_idx // num_joints
    for _, hit_track_joint_idx in hit_track_joint_idx_tups
  }
  inf_times.append(time.time() - start)

  # --- Visualization
  start = time.time()
  frame = frames[frame_id]
  predictions = all_predictions[frame_id]
  instances = predictions['instances'].to(cpu_device)
  cur_boxes = np.asarray(instances.pred_boxes.tensor)
  vis_frame = visualize_predictions(frame, predictions, hide_keypoints=True)
  for box, frame_track in zip(cur_boxes, frame_tracks):
    cv2.putText(vis_frame, str(frame_track), (int(box[0]-5), int(box[1]-5)), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 0))
    cv2.putText(vis_frame, str(frame_track), (int(box[0]-4), int(box[1]-4)), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255))
  frame_h, frame_w = vis_frame.shape[:2]

  b_channel, g_channel, r_channel = cv2.split(vis_frame)
  channels = (b_channel, g_channel, r_channel)
  # Mostly transparent; the highlights below raise the alpha.
  alpha_channel = np.ones(b_channel.shape, dtype=b_channel.dtype) * 64

  # Highlight every infected (or newly infected) person visible in this frame.
  for track in frame_infected_tracks | infected_tracks:
    track_positions = np.flatnonzero(frame_tracks == track)
    if not track_positions.size:
      continue  # track is not visible in this frame
    box = cur_boxes[track_positions[0]]
    xA, yA, xB, yB = _clip_box(box[0], box[1], box[2], box[3], frame_w, frame_h)
    alpha_channel[yA:yB, xA:xB] = 128
    _draw_outline(alpha_channel, xA, yA, xB, yB, 255)
    for channel in channels:
      _draw_outline(channel, xA, yA, xB, yB, 0)

  # Highlight each contact with a padded box around both joints.
  if HIGHLIGHT_AFTER_MIN_CONSECUTIVE_ONLY:
    highlight_track_joint_idx_tups = hit_track_joint_idx_tups
  else:
    highlight_track_joint_idx_tups = all_hit_track_joint_idx_tups

  for infected_track_joint_idx, hit_track_joint_idx in highlight_track_joint_idx_tups:
    xA, yA = target_joint_vals[joint_idx_by_track_joint_idx[infected_track_joint_idx]]
    xB, yB = target_joint_vals[joint_idx_by_track_joint_idx[hit_track_joint_idx]]
    cA, rA, cB, rB = _clip_box(
      min(xA, xB) - HIT_BOX_RADIUS, min(yA, yB) - HIT_BOX_RADIUS,
      max(xA, xB) + HIT_BOX_RADIUS, max(yA, yB) + HIT_BOX_RADIUS,
      frame_w, frame_h,
    )
    # Label "lowTrack:highTrack" for the pair in contact.
    sorted_tracks = sorted([
      infected_track_joint_idx // num_joints, hit_track_joint_idx // num_joints
    ])
    label = f'{sorted_tracks[0]}:{sorted_tracks[1]}'

    alpha_channel[rA:rB, cA:cB] = 255
    for channel in channels:
      _draw_outline(channel, cA, rA, cB, rB, 255)
      cv2.putText(channel, label, (cA-2, rA-2), cv2.FONT_HERSHEY_PLAIN, 1, (0, 0, 0))
      cv2.putText(channel, label, (cA-1, rA-1), cv2.FONT_HERSHEY_PLAIN, 1, (255, 255, 255))

  img_BGRA = cv2.merge((b_channel, g_channel, r_channel, alpha_channel))
  vis_times.append(time.time() - start)

  new_infected_tracks = frame_infected_tracks - infected_tracks
  if new_infected_tracks:
    print('new_infected_tracks:', new_infected_tracks)

  if highlight_track_joint_idx_tups:
    print('frame_id:', frame_id)
    cv2_imshow(img_BGRA)

  if WRITE_INF_OUTPUT:
    inf_output_file.write(bgra_2_bgr(img_BGRA))

  infected_tracks |= frame_infected_tracks

print('final infected_tracks:', infected_tracks)
print('inf_time:', sum(inf_times))
print('vis_time:', sum(vis_times))

inf_output_file.release()

if WRITE_INF_OUTPUT:
  with open(inf_output_fname, 'rb') as f:
    print('wrote', len(f.read()), 'bytes to', inf_output_fname)
  ! ls -alh $inf_output_fname
  # this produces "MessageError: TypeError: Failed to fetch"
  #files.download(inf_output_fname)  

In [0]:
  # Download the infection video from its own cell: calling files.download at
  # the end of the cell that writes it failed with
  # "MessageError: TypeError: Failed to fetch".
  if WRITE_INF_OUTPUT:
    ! ls -alh $inf_output_fname
    files.download(inf_output_fname)

Limitations / Future Work
- Working in 2D results in false positives (due to lack of depth information) and false negatives (due to occlusion)
  - Lack of depth can be mitigated with depth estimation: https://roxanneluo.github.io/Consistent-Video-Depth-Estimation/
  - Occlusion can be mitigated by using multiple cameras: https://arxiv.org/pdf/2003.03972v2.pdf
- Only keypoints are considered, not semantic segmentation masks, which may contain more information about whether contact occurred
- Individuals are not tracked across different videos

In [0]:
# Summary of the clip's properties as reported by OpenCV.
for label, value in (
  ('frames_per_second:', frames_per_second),
  ('num_frames:', num_frames),
):
  print(label, value)