In [60]:
import numpy as np
import tensorflow as tf

# Show the TensorFlow version. The TFLite interpreter used below is
# `tf.lite.Interpreter` from this same package, so this also identifies the
# TFLite runtime in use.
print(tf.__version__)


2.8.0


In [71]:
# Load the converted RepNet model and pin its first input to a single clip of
# 64 RGB frames at 112x112, the clip shape get_counts() feeds it (batch_size=1).
interpreter = tf.lite.Interpreter(model_path="tflite/repnet.tflite")
input_details = interpreter.get_input_details()

# strict=True: presumably only dimensions marked dynamic (-1) in the input's
# shape_signature may be resized -- see the printed signature below.
interpreter.resize_tensor_input(
    input_index=input_details[0]['index'],
    tensor_size=(1, 64, 112, 112, 3),
    strict=True,
)

# Allocate after the resize, before later cells call set_tensor()/invoke().
interpreter.allocate_tensors()

In [72]:
# Details dict of the model's first (video) input tensor.
# NOTE(review): check whether get_counts() still reads this global
# (`i['index']`) before renaming or removing it.
i = interpreter.get_input_details()[0]

# The printed signature keeps -1 on the batch axis even after the resize to 1
# in the previous cell -- presumably TFLite's marker for a dynamic dimension.
print(i['shape_signature'])

[ -1  64 112 112   3]


In [63]:
o = interpreter.get_output_details()

# Identify RepNet's outputs by the size of their last axis rather than by
# tensor index. Signatures noted for this model:
#   raw_scores:           [1 64 32]
#   within_period_scores: [1 64 1]
#   period_scores:        [1 61 512]
LAST_DIM_TO_OUTPUT_NAME = {
    32: 'raw_scores',
    1: 'within_period_scores',
    512: 'period_scores',
}

name_to_index = {}
for t in o:
    # Size of the final axis decides which output this tensor is.
    dim = int(t['shape_signature'][-1])
    output_name = LAST_DIM_TO_OUTPUT_NAME.get(dim)
    if output_name is None:
        raise ValueError(
            f"Unknown dimension {dim} for output {t['name']!r} "
            f"(shape_signature={list(t['shape_signature'])})")
    # Two tensors with the same last axis would silently overwrite each other.
    if output_name in name_to_index:
        raise ValueError(
            f"Outputs {name_to_index[output_name]} and {t['index']} both have "
            f"last dimension {dim}; cannot tell which is {output_name!r}")
    name_to_index[output_name] = t['index']

# get_counts() needs these two; fail here rather than with a KeyError later.
missing_outputs = {'raw_scores', 'within_period_scores'} - name_to_index.keys()
if missing_outputs:
    raise ValueError(f"Model outputs not found: {sorted(missing_outputs)}")

print(name_to_index)

{'raw_scores': 545, 'period_scores': 302, 'within_period_scores': 579}


In [73]:
def _resolve_output_indices(interpreter, raw_scores_dim):
  """Find the 'raw_scores' / 'within_period_scores' output indices by last-axis size.

  Args:
    interpreter: Allocated `tf.lite.Interpreter`.
    raw_scores_dim: Last-axis size identifying 'raw_scores'; a last axis of 1
      identifies 'within_period_scores'.

  Returns:
    Dict with keys 'raw_scores' and 'within_period_scores'.

  Raises:
    ValueError: if either output cannot be found.
  """
  indices = {}
  for detail in interpreter.get_output_details():
    last_dim = int(detail['shape_signature'][-1])
    if last_dim == raw_scores_dim:
      indices['raw_scores'] = detail['index']
    elif last_dim == 1:
      indices['within_period_scores'] = detail['index']
  missing = sorted({'raw_scores', 'within_period_scores'} - set(indices))
  if missing:
    raise ValueError(
        'Could not find model outputs %s by last-axis size.' % missing)
  return indices


def _preprocess_frames(frames, image_size):
  """Map pixels v -> (v - 127.5) / 127.5 and resize to image_size x image_size."""
  imgs = tf.cast(frames, tf.float32)
  imgs = (imgs - 127.5) / 127.5
  return tf.image.resize(imgs, (image_size, image_size))


def get_counts(interpreter, frames, strides, batch_size,
               threshold,
               within_period_threshold,
               output_indices=None):
  """Pass frames through the TFLite model and convert period predictions to counts.

  For each candidate stride the video is subsampled, cut into clips of 64
  frames and run through the model. The stride whose `get_score` result is
  highest is kept, and its per-frame predictions are repeated back to the
  original frame rate.

  Args:
    interpreter: Allocated `tf.lite.Interpreter` whose first input takes clips
      of shape [batch_size, 64, 112, 112, 3].
    frames: Stack of RGB frames, presumably (num_frames, H, W, 3) with pixel
      values in [0, 255] -- TODO confirm against the data source.
    strides: Non-empty sequence of candidate temporal strides (ints).
    batch_size: Clips per `invoke()`. It must equal the batch dimension the
      interpreter's input was allocated with.
    threshold: If the mean per-frame score is below this, all per-frame counts
      are zeroed.
    within_period_threshold: Frames whose periodicity score is at or below
      this contribute no count.
    output_indices: Optional dict mapping 'raw_scores' and
      'within_period_scores' to output tensor indices. If None, they are read
      from the interpreter's output details. A last axis of 32 identifies
      'raw_scores'; a last axis of 1 identifies 'within_period_scores'.

  Returns:
    Tuple (pred_period, pred_score, within_period, per_frame_counts,
    chosen_stride). pred_period is inf when no frame is counted.

  Raises:
    ValueError: if `frames` or `strides` is empty, if the interpreter's input
      batch differs from `batch_size`, or if the outputs cannot be found.
  """
  model_num_frames = 64
  model_image_size = 112

  seq_len = len(frames)
  if seq_len == 0:
    raise ValueError('get_counts() needs at least one frame.')
  if len(strides) == 0:
    raise ValueError('get_counts() needs at least one stride.')

  # Read tensor indices from the interpreter itself instead of relying on
  # globals defined in other cells.
  input_detail = interpreter.get_input_details()[0]
  input_index = input_detail['index']
  input_batch = int(input_detail['shape'][0])
  if input_batch != batch_size:
    raise ValueError(
        'batch_size=%d but the interpreter input is allocated for batch %d; '
        'resize_tensor_input() + allocate_tensors() first.'
        % (batch_size, input_batch))
  if output_indices is None:
    output_indices = _resolve_output_indices(interpreter, model_num_frames // 2)

  frames = _preprocess_frames(frames, model_image_size)
  clip_shape = [batch_size, model_num_frames,
                model_image_size, model_image_size, 3]

  raw_scores_list = []
  scores = []
  within_period_scores_list = []

  for stride in strides:
    # Number of source frames consumed by one invoke() at this stride.
    span = batch_size * model_num_frames * stride
    num_batches = -(-seq_len // span)  # integer ceil division
    raw_scores_per_stride = []
    within_period_score_stride = []
    for batch_idx in range(num_batches):
      start = batch_idx * span
      idxes = tf.range(start, start + span, stride)
      # Past the end of the video, repeat the last frame to fill the clip.
      idxes = tf.clip_by_value(idxes, 0, seq_len - 1)
      curr_frames = tf.reshape(tf.gather(frames, idxes), clip_shape)

      # The interpreter takes numpy input; convert explicitly to its dtype.
      interpreter.set_tensor(
          input_index, np.asarray(curr_frames, dtype=input_detail['dtype']))
      interpreter.invoke()

      raw_scores = interpreter.get_tensor(output_indices['raw_scores'])
      within_period_scores = interpreter.get_tensor(
          output_indices['within_period_scores'])

      raw_scores_per_stride.append(
          np.reshape(raw_scores, [-1, model_num_frames // 2]))
      within_period_score_stride.append(
          np.reshape(within_period_scores, [-1, 1]))

    raw_scores_per_stride = np.concatenate(raw_scores_per_stride, axis=0)
    raw_scores_list.append(raw_scores_per_stride)
    within_period_score_stride = np.concatenate(
        within_period_score_stride, axis=0)
    pred_score, within_period_score_stride = get_score(
        raw_scores_per_stride, within_period_score_stride)
    scores.append(pred_score)
    within_period_scores_list.append(within_period_score_stride)

  # Stride chooser: keep the stride with the highest combined score and
  # upsample its predictions back to one row per original frame.
  argmax_strides = int(np.argmax(scores))
  chosen_stride = strides[argmax_strides]
  raw_scores = np.repeat(
      raw_scores_list[argmax_strides], chosen_stride, axis=0)[:seq_len]
  within_period = np.repeat(
      within_period_scores_list[argmax_strides], chosen_stride,
      axis=0)[:seq_len]
  within_period_binary = np.asarray(within_period > within_period_threshold)

  # Count each frame. More noisy but adapts to changes in speed.
  pred_score = tf.reduce_mean(within_period)
  per_frame_periods = tf.argmax(raw_scores, axis=-1) + 1
  # A frame with period p (in strided frames) contributes 1/(stride*p) of a
  # repetition; periods shorter than 3 are treated as no repetition.
  per_frame_counts = tf.where(
      tf.math.less(per_frame_periods, 3),
      0.0,
      tf.math.divide(1.0,
                     tf.cast(chosen_stride * per_frame_periods, tf.float32)),
  )
  per_frame_counts *= tf.cast(within_period_binary, tf.float32)

  # Avoid a divide-by-zero warning when no frame is counted.
  total_count = float(np.sum(per_frame_counts))
  pred_period = seq_len / total_count if total_count > 0 else np.inf

  if pred_score < threshold:
    print('No repetitions detected in video as score '
          '%0.2f is less than threshold %0.2f.'%(pred_score, threshold))
    per_frame_counts = np.zeros(len(per_frame_counts))

  return (pred_period, pred_score, within_period,
          per_frame_counts, chosen_stride)

In [74]:
def get_score(period_score, within_period_score):
  """Combine period-classification confidence with per-frame periodicity.

  Each frame's score is sqrt(sigmoid(periodicity) * max softmax(period)).
  The period confidence is forced to 0 where the predicted period
  (argmax + 1) is shorter than 3.

  Args:
    period_score: Per-frame logits over period lengths (last axis), presumably
      (num_frames, max_period).
    within_period_score: Per-frame periodicity logits. Only column 0 is used,
      so presumably (num_frames, 1).

  Returns:
    (pred_score, within_period_score): the mean combined score and the
    per-frame combined scores.
  """
  periodicity = tf.nn.sigmoid(within_period_score)[:, 0]
  period_probs = tf.nn.softmax(period_score, axis=-1)
  predicted_period = 1 + tf.argmax(period_score, axis=-1)
  confidence = tf.where(
      predicted_period < 3, 0.0, tf.reduce_max(period_probs, axis=-1))
  combined = np.sqrt(periodicity * confidence)
  return tf.reduce_mean(combined), combined

In [75]:
# FPS while recording video from webcam.
# Run settings. The `#@param` tags are presumably Colab form-field markers;
# in Jupyter they are ordinary comments.
# NOTE(review): only THRESHOLD and WITHIN_PERIOD_THRESHOLD are passed to
# get_counts() below; the other settings are not read by get_counts().

# FPS while recording video from webcam.
WEBCAM_FPS = 16  #@param {type:"integer"}

# Time in seconds to record video on webcam.
RECORDING_TIME_IN_SECONDS = 8.  #@param {type:"number"}

# Threshold to consider periodicity in entire video: if the mean per-frame
# score is below it, get_counts() zeroes every per-frame count.
THRESHOLD = 0.2  #@param {type:"number"}

# Threshold to consider periodicity for individual frames in video: frames
# scoring at or below it contribute no count.
WITHIN_PERIOD_THRESHOLD = 0.5  #@param {type:"number"}

# Use this setting for better results when it is
# known the action is repeating at constant speed.
CONSTANT_SPEED = False  #@param {type:"boolean"}

# Use median filtering in time to ignore noisy frames.
MEDIAN_FILTER = True  #@param {type:"boolean"}

# Use this setting for better results when it is
# known the entire video is periodic/repeating and
# has no aperiodic frames.
FULLY_PERIODIC = False  #@param {type:"boolean"}

# Plot score in visualization video.
PLOT_SCORE = False  #@param {type:"boolean"}

# Visualization video's FPS.
VIZ_FPS = 30  #@param {type:"integer"}
# Backward-compatible alias for the previous, misspelled name.
IZ_FPS = VIZ_FPS

In [76]:
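# NOTE: frame-loading helper kept (commented out) for regenerating
# data/hummingbird.npy from data/hummingbird.gif; it needs OpenCV (cv2).
# import cv2
# import numpy as np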


# def read_video(video_filename, width=224, height=224):
#   """Read video from file."""
#   cap = cv2.VideoCapture(video_filename)
#   fps = cap.get(cv2.CAP_PROP_FPS)
#   frames = []
#   if cap.isOpened():
#     while True:
#       success, frame_bgr = cap.read()
#       if not success:
#         break
#       frame_rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)
#       frame_rgb = cv2.resize(frame_rgb, (width, height))
#       frames.append(frame_rgb)
#   frames = np.asarray(frames)
#   return frames, fps

In [77]:
# imgs, vid_fps = read_video("data/hummingbird.gif")

In [78]:
# Frames presumably cached earlier via read_video("data/hummingbird.gif");
# uncomment the save line (and the helper cell) to regenerate the cache.
# np.save("data/hummingbird.npy", imgs)
import numpy as np
imgs = np.load("data/hummingbird.npy")
# get_counts() gathers frames along axis 0 and reshapes clips to (..., 3),
# so the cache must be a 4-D stack of 3-channel frames.
assert imgs.ndim == 4 and imgs.shape[-1] == 3, imgs.shape

In [80]:
# Count repetitions in the cached clip, trying temporal strides 1-4 and
# keeping whichever stride scores highest.
(pred_period, pred_score, within_period,
 per_frame_counts, chosen_stride) = get_counts(
    interpreter=interpreter,
    frames=imgs,
    strides=[1, 2, 3, 4],
    batch_size=1,
    threshold=THRESHOLD,
    within_period_threshold=WITHIN_PERIOD_THRESHOLD,
)

Current frame shape (1, 64, 112, 112, 3)
Current frame shape (1, 64, 112, 112, 3)
Current frame shape (1, 64, 112, 112, 3)
Current frame shape (1, 64, 112, 112, 3)
Current frame shape (1, 64, 112, 112, 3)
Current frame shape (1, 64, 112, 112, 3)
Current frame shape (1, 64, 112, 112, 3)
Current frame shape (1, 64, 112, 112, 3)
Current frame shape (1, 64, 112, 112, 3)
Current frame shape (1, 64, 112, 112, 3)
Current frame shape (1, 64, 112, 112, 3)
Current frame shape (1, 64, 112, 112, 3)


In [83]:
# Summarise what get_counts() returned; the repetition count is the sum of
# the per-frame fractional counts.
print('Predicted period: %0.2f' % float(pred_period))
print('Predicted score: %0.2f' % float(pred_score))
print('Chosen stride: %d' % int(chosen_stride))
print(f'Predicted counts: {sum(per_frame_counts)}')

Predicted period: 25.80
Predicted score: 0.88
Chosen stride: 1
Predicted counts: 10.076630592346191
