<a href="https://colab.research.google.com/github/Shrinkhal01/-Real-time-Object-Detection-using-TensorFlow-and-OpenCV/blob/main/colabs/optical_flow_track_assist.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Demo: annotating a point track with optical flow

This notebook illustrates how we use optical flow to facilitate human annotation for point tracking. Note that it is very hard to annotate a point track extensively along a whole video sequence. However, we find that dense optical flow estimation these days is fast and accurate. In this demo, we use [RAFT](https://pytorch.org/vision/stable/auto_examples/plot_optical_flow.html) to compute the dense optical flow for us.

We then ask the annotator to select a point in the starting frame and the corresponding point location in the ending frame. A dynamic programming algorithm is used to optimize the estimated track given the starting and ending point locations. Note that the algorithm here differs from the one we use in the original annotation system (Dijkstra's algorithm).

The dynamic programming algorithm here requires large matrix computations, so running it on a GPU will be a lot faster.

In [None]:
!pip install mediapy mako flow_vis

In [None]:
# @title Imports {form-width: "25%"}

import copy
import io
import flow_vis
import functools
import gc
import IPython
import mediapy as media
import numpy as np
from PIL import Image
from google.colab import html
import base64
from mako import template
import torch
import torchvision
from tqdm import tqdm

In [None]:
# Use the GPU whenever one is visible to torch; everything still runs on the
# CPU, only much more slowly.
if torch.cuda.is_available():
  device = "cuda"
else:
  device = "cpu"

# Gradients are switched off globally from here on.
torch.set_grad_enabled(False)

In [None]:
# @title Load an Exemplar Video {form-width: "25%"}

# The clip is read from the fixed path below, so it has to be present there
# before this cell runs. A sample clip can be fetched with:
# !wget -P tapnet/examplar_videos https://storage.googleapis.com/dm-tapnet/horsejump-high.mp4
video = media.resize_video(media.read_video('/content/jet.mp4'), (480, 768))

# Axis 0 indexes frames; axes 1 and 2 are the spatial size used by later cells.
height, width = video.shape[1], video.shape[2]
media.show_video(video, fps=10)

In [None]:
# @title Predict Optical Flows with RAFT {form-width: "25%"}

from torchvision.models.optical_flow import raft_large
from torchvision.models.optical_flow import Raft_Large_Weights

if video.shape[0] < 2:
  raise ValueError('Optical flow needs at least two frames, got %d.' % video.shape[0])

model = raft_large(weights=Raft_Large_Weights.DEFAULT, progress=False).to(device)
model = model.eval()


def _to_raft_input(frame):
  """Turns one video frame into a 1 x C x H x W float tensor on `device`.

  Pixel values are mapped with v / 127.5 - 1, i.e. to [-1, 1] for 8-bit input
  (assumes `video` holds uint8 frames -- TODO confirm).
  """
  scaled = frame.astype(np.float32) / 127.5 - 1.0
  return torch.tensor(scaled.transpose(2, 0, 1)[None]).to(device)


# optical_flows[i] is the flow from frame i to frame i + 1, stored as
# (height, width, 2). Each frame is converted and uploaded only once: the
# "next" tensor of one pair is reused as the "previous" tensor of the following.
optical_flows = []
prev_input = _to_raft_input(video[0])
for i in tqdm(range(video.shape[0] - 1)):
  next_input = _to_raft_input(video[i + 1])
  # The model returns a list of flow estimates; keep the last one, first
  # (only) batch element.
  flow = model(prev_input, next_input)[-1][0].cpu().numpy()
  optical_flows.append(flow.transpose(1, 2, 0))
  prev_input = next_input
optical_flows = np.stack(optical_flows)

# Release Memory after Prediction
del model, prev_input, next_input
gc.collect()
torch.cuda.empty_cache()

print(optical_flows.shape)
print(np.abs(optical_flows).max())

In [None]:
# @title Visualize Optical Flows {form-width: "25%"}

# Colour-code every flow field (one image per consecutive frame pair) and play
# the result back as a video.
flow_viz = np.stack([flow_vis.flow_to_color(frame_flow) for frame_flow in optical_flows])

media.show_video(flow_viz, fps=10)

In [None]:
# @title HTML Template {form-width: "25%"}

class Img(html.Element):
  """An `img` element whose `src` can be assigned a numpy image directly.

  Assigning to `src` JPEG-encodes the array and stores it as a base64 data URL.
  """

  def __init__(self, src=None, show=False):
    """Creates the element.

    Args:
      src: optional image array to display.
      show: if True the element starts with `display:block`, otherwise with
        `display:none`.
    """
    super().__init__('img')
    if src is not None:
      self.src = src
    display_rule = 'display:block;' if show else 'display:none;'
    self.set_attribute('style', display_rule + 'margin:0px;')

  @property
  def src(self):
    """The current value of the element's `src` property."""
    return self.get_property('src')

  @src.setter
  def src(self, value):
    encoded = base64.b64encode(self._to_jpeg(value)).decode('utf-8')
    self.set_property('src', 'data:image/jpeg;base64,' + encoded)

  def _to_jpeg(self, np_image):
    """Returns the JPEG-encoded bytes of `np_image`."""
    with io.BytesIO() as buf:
      Image.fromarray(np_image).save(buf, format="JPEG")
      return buf.getvalue()

In [None]:
# @title Dynamic Programming Algorithm {form-width: "25%"}

def interpolate(flows, frame1, click1, frame2, click2, radius=20):
  """Finds the flow-consistent pixel path between two annotated points.

  Runs a dynamic program over frames `frame1`..`frame2`. Moving from a pixel in
  frame t to a pixel in frame t + 1 costs the L1 distance between the actual
  displacement and the optical flow stored at the source pixel. Only
  displacements of at most `radius` pixels per axis are considered.

  Args:
    flows: array indexed as [frame, row, col, (dx, dy)]; `flows[t]` is the flow
      used for the step from frame t to frame t + 1.
    frame1: index of the starting frame.
    click1: (x, y) of the point in `frame1`. Rounded via `click2idx` and clamped
      into the frame.
    frame2: index of the ending frame.
    click2: (x, y) of the point in `frame2`. Rounded and clamped like `click1`.
    radius: maximum per-axis displacement, in pixels, between adjacent frames.

  Returns:
    A tuple `(path, min_cost)`. `path` is an integer array with one (x, y) row
    per frame from `frame1` to `frame2` inclusive, ending at the clamped
    `click2`. `min_cost` is a 0-d numpy array with the accumulated cost of that
    path; since unreachable pixels start at 1e10, a value of that magnitude
    means no path within `radius` per step connects the two clicks.
  """
  height, width = flows.shape[1:3]

  # Clicks come from the browser and may fall slightly outside the frame.
  # Clamp them so they neither raise IndexError nor wrap around through
  # negative indexing.
  x1, y1 = click2idx(click1)
  x2, y2 = click2idx(click2)
  x1, x2 = (min(max(v, 0), width - 1) for v in (x1, x2))
  y1, y2 = (min(max(v, 0), height - 1) for v in (y1, y2))

  window = 2 * radius + 1
  # offset_cost[i, j] = (dx, dy) of the window cell (i, j) relative to its centre.
  dx, dy = np.meshgrid(np.arange(-radius, radius + 1), np.arange(-radius, radius + 1))
  offset_cost = torch.tensor(np.stack([dx, dy], axis=-1)).to(device)

  # Back-pointers are only needed for the steps between the two clicks.
  num_steps = max(frame2 - frame1, 0)
  back_i = np.zeros((num_steps, height, width), dtype=np.int32)
  back_j = np.zeros((num_steps, height, width), dtype=np.int32)

  # Cost of reaching every pixel of the current frame; only the start is free.
  forward_cost = torch.ones((height, width)).to(device) * 1e10
  forward_cost[y1, x1] = 0

  for step, t in enumerate(range(frame1, frame2)):
    # cost_unfold[h, w, i, j] is the cost of the candidate source pixel
    # (h + i - radius, w + j - radius); out-of-frame candidates cost 1e10.
    cost_pad = torch.nn.functional.pad(forward_cost, (radius, radius, radius, radius), 'constant', value=1e10)
    cost_unfold = cost_pad.unfold(0, window, 1).unfold(1, window, 1)
    del cost_pad
    gc.collect()
    torch.cuda.empty_cache()

    # Same windowing for the flow vectors of the candidate source pixels.
    flow_cuda = torch.tensor(flows[t]).to(device)
    flow_pad = torch.nn.functional.pad(flow_cuda, (0, 0, radius, radius, radius, radius), 'constant', value=1e10)
    flow_unfold = flow_pad.unfold(0, window, 1).unfold(1, window, 1).permute(0, 1, 3, 4, 2)
    del flow_cuda, flow_pad
    gc.collect()
    torch.cuda.empty_cache()

    # A source at offset d from the target has to move by -d to land on it;
    # penalise the L1 gap between that and the flow observed at the source.
    cost = cost_unfold + torch.abs(-offset_cost[None, None] - flow_unfold).sum(axis=-1)
    cost = cost.reshape(height, width, -1)
    forward_cost, argmin_indices = torch.min(cost, axis=-1)
    del cost
    gc.collect()
    torch.cuda.empty_cache()

    argmin_indices = argmin_indices.cpu().numpy()
    src_i = argmin_indices // window + np.arange(height)[:, None] - radius
    src_j = argmin_indices % window + np.arange(width)[None] - radius
    # Pixels with no reachable source may pick a padded cell; keep the stored
    # coordinates inside the frame so backtracking can never index out of range.
    back_i[step] = np.clip(src_i, 0, height - 1)
    back_j[step] = np.clip(src_j, 0, width - 1)

  # The path is pinned to the annotated end point.
  min_cost = forward_cost[y2, x2].cpu().numpy()

  # Walk the back-pointers from the end click to the start frame.
  min_i, min_j = y2, x2
  min_ij = [(min_j, min_i)]
  for step in range(num_steps - 1, -1, -1):
    min_i, min_j = back_i[step, min_i, min_j], back_j[step, min_i, min_j]
    min_ij.append((min_j, min_i))
  min_ij.reverse()

  del forward_cost
  gc.collect()
  torch.cuda.empty_cache()

  return np.stack(min_ij), min_cost

In [None]:
# @title Reset the Annotated Trajectories {form-width: "25%"}

# One slot per frame; None means the frame has not been clicked.
clicks = [None] * video.shape[0]

In [None]:
# @title Start Annotation {form-width: "25%"}

def mouse_position(event, frame_id):
  """Click handler: records the click coordinates for frame `frame_id`.

  Args:
    event: mapping carrying the click's 'clientX' and 'clientY' values.
    frame_id: index into `clicks` of the frame that was clicked.
  """
  clicks[frame_id] = [event['clientX'], event['clientY']]
  # The overlay is only redrawn when the cell runs again.
  print('\r', 'Please re-run this cell ...', end='')

def click2idx(click):
  """Rounds an (x, y) click to the nearest integer pixel coordinates.

  Args:
    click: a pair (x, y) of numbers.

  Returns:
    The pair (x, y) as Python ints, each rounded with the built-in `round`.
  """
  col, row = click
  return int(round(col)), int(round(row))

def _clamp_to_frame(click):
  """Rounds an (x, y) click to integer pixel coordinates inside the frame."""
  x, y = click2idx(click)
  return min(max(x, 0), width - 1), min(max(y, 0), height - 1)


# Rebuild the whole track from the clicks recorded so far. Between two clicked
# frames the path comes from interpolate(); after the most recent click the
# point is carried forward frame by frame along the optical flow.
cur_pos = None
frames2 = []
all_pos = np.zeros([video.shape[0], 2], dtype=int)
has_pos = np.zeros(video.shape[0], dtype=bool)  # frames that carry a tracked point
last_click = None
for i in range(video.shape[0]):
  if clicks[i] and last_click:
    # Two anchors are known: overwrite the flow-only guess between them.
    # Clicks are clamped first so an out-of-frame click cannot index outside
    # the cost map.
    all_pos[last_click[0]:i+1, :], forward_cost = interpolate(
        optical_flows, last_click[0], _clamp_to_frame(last_click[1]),
        i, _clamp_to_frame(clicks[i]))

  if clicks[i]:
    # Work on a copy: cur_pos is advanced in place below.
    cur_pos = copy.copy(clicks[i])
    last_click = (i, clicks[i])
  if cur_pos:
    x, y = _clamp_to_frame(cur_pos)
    all_pos[i,0] = x
    all_pos[i,1] = y
    has_pos[i] = True
    if i < optical_flows.shape[0]:
      cur_pos[0] += optical_flows[i, y, x, 0]
      cur_pos[1] += optical_flows[i, y, x, 1]

# Draw a 5x5 marker on every frame that has a position: red on clicked frames,
# cyan on propagated ones.
for i in range(video.shape[0]):
  fr = np.copy(video[i])
  if has_pos[i]:
    # NOTE(review): the -5 shift presumably compensates for clientX/clientY
    # being viewport- rather than image-relative -- confirm.
    x, y = all_pos[i] - 5
    # Clamp both slice bounds at 0 so a marker near the top/left edge is
    # cropped instead of wrapping around through negative indices.
    rows = slice(max(y - 2, 0), max(y + 3, 0))
    cols = slice(max(x - 2, 0), max(x + 3, 0))
    fr[rows, cols, :3] = (255, 0, 0) if clicks[i] else (0, 255, 255)
  frames2.append(fr)

# One hidden <img> per frame (only the first is shown); each forwards its
# clicks to mouse_position with the frame index bound.
imgs=[]
for i, fr in enumerate(frames2):
  img = Img(src=fr, show=i==0)
  img.add_event_listener('click', functools.partial(mouse_position, frame_id=i))
  imgs.append(img)
# JavaScript array literal with the element ids, in frame order.
img_ids = "[" + "".join("\"" + str(img._guid) + "\"," for img in imgs) + "]"

MAKO_TEMPLATE="""
<input type="range" min="0" max="${num_frames-1}" value="0" class="slider" id="myRange">
<script>
img_ids=${img_ids}
slider=document.getElementById("myRange");
cur_frame=0
slider.oninput = function() {
  idx = this.value;
  for (var i = 0; i<${num_frames}; i++){
    document.getElementById(img_ids[i]).style.display="none"
  }
  document.getElementById(img_ids[idx]).style.display="block"
}
</script>
"""
viz_tpl = template.Template(MAKO_TEMPLATE, strict_undefined=True)
script = viz_tpl.render(num_frames=len(frames2),img_ids=img_ids)

display(IPython.display.HTML(" ".join([img._repr_html_() for img in imgs])+script))

################################################################################
# Instructions:
#
# 1) click anywhere on the first frame to get a point to track.
# 2) re-run this cell to see where it goes
# 3) click a point on any other frame, and the demo will find the shortest path.
################################################################################