<a href="https://colab.research.google.com/github/Willyzw/SuperPointPretrainedNetwork/blob/master/superpoint_handson.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

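# SuperPoint Hands-On
This demo runs the pre-trained SuperPoint network on a short sample video, tracks the detected keypoints across frames, and writes a side-by-side visualization (point tracks, raw detections, confidence heatmap) to `demo.mp4`.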

In [1]:
# Fetch the demo repository. Later cells import `demo_superpoint` and open
# 'superpoint_v1.pth' and 'assets/nyu_snippet.mp4' by relative path, so the
# working directory must be the root of the checkout.
!git clone https://github.com/Willyzw/SuperPointPretrainedNetwork
# No space between '%' and the magic name: recent IPython versions reject
# '% cd'. The path is relative to where the clone was just created, so the
# cell also works outside Colab's /content directory.
%cd SuperPointPretrainedNetwork

Cloning into 'SuperPointPretrainedNetwork'...
remote: Enumerating objects: 81, done.[K
remote: Total 81 (delta 0), reused 0 (delta 0), pack-reused 81[K
Unpacking objects: 100% (81/81), done.
/content/SuperPointPretrainedNetwork


In [1]:
import time
import numpy as np
import cv2
import PIL.Image as pil
from IPython.display import Video
from demo_superpoint import SuperPointFrontend, PointTracker, VideoStreamer, myjet

In [2]:
# Input source: VideoStreamer abstracts over the different kinds of image input.
vs = VideoStreamer(
    "assets/nyu_snippet.mp4",
    camid=0,
    height=480,
    width=640,
    skip=1,
    img_glob='*.png',
)

# Network front end: runs SuperPoint and post-processes its raw outputs.
print('==> Loading pre-trained network.')
frontend_kwargs = dict(
    weights_path='superpoint_v1.pth',
    nms_dist=4,
    conf_thresh=0.015,
    nn_thresh=0.7,
    cuda=False,
)
fe = SuperPointFrontend(**frontend_kwargs)
print('==> Successfully loaded pre-trained network.')

# Tracker: links point matches from consecutive frames into tracks. It reuses
# the front end's descriptor-matching threshold so both stay consistent.
tracker = PointTracker(5, nn_thresh=fe.nn_thresh)


==> Processing Video Input.
==> Loading pre-trained network.
==> Successfully loaded pre-trained network.


In [3]:
# Run SuperPoint over every frame of the input and render a three-panel video:
# point tracks | raw detections | confidence heatmap.
vs.reset()
print('==> Running Demo.')

# 'mp4v' is the FourCC that matches the .mp4 container (the previous 'XVID'
# tag belongs to AVI and is not a valid tag for MP4 output).
# NOTE(review): the frame size must equal three 640x480 panels side by side,
# i.e. the height/width requested from VideoStreamer -- keep them in sync.
video = cv2.VideoWriter("demo.mp4", cv2.VideoWriter_fourcc(*'mp4v'), 3.0, (640*3,480))
try:
  while True:
    start = time.time()

    # Get a new image; a False status signals that the input is exhausted.
    img, status = vs.next_frame()
    if status is False:
      break

    # Get points and descriptors (timed separately from frame loading).
    start1 = time.time()
    pts, desc, heatmap = fe.run(img)
    end1 = time.time()

    # Add points and descriptors to the tracker.
    tracker.update(pts, desc)

    # Get tracks for points which were matched successfully across all frames.
    tracks = tracker.get_tracks(2)

    # Primary output - Show point tracks overlaid on top of the input image.
    # The single-channel image is replicated to 3 channels and scaled to uint8
    # (assumes img values lie in [0, 1] -- TODO confirm against VideoStreamer).
    out1 = (np.dstack((img, img, img)) * 255.).astype('uint8')
    tracks[:, 1] /= float(fe.nn_thresh) # Normalize track scores to [0,1].
    tracker.draw_tracks(out1, tracks)

    # Extra output -- Show current point detections.
    out2 = (np.dstack((img, img, img)) * 255.).astype('uint8')
    for pt in pts.T:
      pt1 = (int(round(pt[0])), int(round(pt[1])))
      cv2.circle(out2, pt1, 1, (0, 255, 0), -1, lineType=cv2.LINE_AA)
    cv2.putText(out2, 'Raw Point Detections', (4, 12), cv2.FONT_HERSHEY_DUPLEX, 0.4, (255, 255, 255), lineType=cv2.LINE_AA)

    # Extra output -- Show the point confidence heatmap.
    if heatmap is not None:
      # Floor the confidences so the log below stays finite, then map
      # -log(confidence) to [0, 1] and look it up in the 10-entry colormap.
      min_conf = 0.001
      heatmap[heatmap < min_conf] = min_conf
      heatmap = -np.log(heatmap)
      heatmap = (heatmap - heatmap.min()) / (heatmap.max() - heatmap.min() + .00001)
      out3 = myjet[np.round(np.clip(heatmap*10, 0, 9)).astype('int'), :]
      out3 = (out3*255).astype('uint8')
    else:
      out3 = np.zeros_like(out2)
    cv2.putText(out3, 'Raw Point Confidences', (4, 12), cv2.FONT_HERSHEY_DUPLEX, 0.4, (255, 255, 255), lineType=cv2.LINE_AA)

    # Compute runtime. The network figure is measured from start1 so that it
    # covers only fe.run(), as the printed label says; the total covers the
    # whole iteration including frame loading and drawing.
    end = time.time()
    net_t = (1./ float(end1 - start1))
    total_t = (1./ float(end - start))

    # Print and show result image
    print('Processed image %d (net+post_process: %.2f FPS, total: %.2f FPS).'\
          % (vs.i, net_t, total_t))
    out = np.hstack((out1, out2, out3))
    video.write(out)
finally:
  # Always finalize the file, even if a frame fails, so that demo.mp4 is
  # complete and playable when the next cell loads it.
  video.release()

==> Running Demo.
Processed image 1 (net+post_process: 3.15 FPS, total: 3.06 FPS).
Processed image 2 (net+post_process: 2.66 FPS, total: 2.41 FPS).
Processed image 3 (net+post_process: 3.63 FPS, total: 3.15 FPS).
Processed image 4 (net+post_process: 3.72 FPS, total: 3.22 FPS).
Processed image 5 (net+post_process: 3.97 FPS, total: 3.36 FPS).
Processed image 6 (net+post_process: 3.83 FPS, total: 3.28 FPS).
Processed image 7 (net+post_process: 3.67 FPS, total: 3.16 FPS).
Processed image 8 (net+post_process: 3.94 FPS, total: 3.36 FPS).
Processed image 9 (net+post_process: 3.73 FPS, total: 3.19 FPS).
Processed image 10 (net+post_process: 3.39 FPS, total: 2.94 FPS).
Processed image 11 (net+post_process: 3.96 FPS, total: 3.37 FPS).
Processed image 12 (net+post_process: 3.86 FPS, total: 3.29 FPS).
Processed image 13 (net+post_process: 3.84 FPS, total: 3.29 FPS).
Processed image 14 (net+post_process: 3.88 FPS, total: 3.30 FPS).
Processed image 15 (net+post_process: 3.80 FPS, total: 3.26 FPS).
P

In [4]:
# embed=True inlines the file's bytes into the cell output. Without it the
# output only references the relative path "demo.mp4", which the notebook
# frontend cannot resolve once the kernel's working directory has been changed
# (and which does not resolve on Colab).
Video("demo.mp4", embed=True)