#### Demo Video:

https://www.youtube.com/watch?v=A-GUBdgqG-U

#### Model Weights:

https://drive.google.com/file/d/1SpzsQcu1HqlhECYaFfR7HLXJ7c2fzUtu/view?usp=sharing

In [None]:
from repnet import *
from utils import *

from keras.models import load_model

import gradio as gr

# Load both models once, at import time, so every request handled by the
# functions below reuses the same instances. Paths are relative to the
# current working directory.
repnet_model = get_repnet_model('./repnet_ckpt')
classification_model = load_model('classification_model.h5')

# NOTE(review): WEBCAM_FPS and RECORDING_TIME_IN_SECONDS are not referenced
# in the visible code — presumably left over from a webcam variant; confirm.
WEBCAM_FPS = 16
RECORDING_TIME_IN_SECONDS = 8.
# Settings forwarded to `get_counts` by `count`.
THRESHOLD = 0.2
WITHIN_PERIOD_THRESHOLD = 0.5
CONSTANT_SPEED = False
MEDIAN_FILTER = True
FULLY_PERIODIC = False
PLOT_SCORE = False  # NOTE(review): not referenced in the visible code.
# Playback rate of the rendered visualisation: `count` passes
# 1_000 / VIZ_FPS as the per-frame interval (presumably milliseconds —
# confirm against viz_reps).
VIZ_FPS = 30

def count(path_to_video, output_path='output.mp4'):
    """Count repetitions in a video with RepNet and render an annotated clip.

    Args:
        path_to_video: Path of the input video file.
        output_path: Where the visualisation is written. Defaults to
            'output.mp4', the original fixed location. NOTE(review): a single
            fixed path is shared by every request, so concurrent Gradio calls
            can overwrite each other's output; pass a unique path if the app
            serves several users at once.

    Returns:
        ``output_path``, the path of the rendered video.
    """
    # The source frame rate is not used: playback speed is fixed by VIZ_FPS.
    imgs, _vid_fps = read_video(path_to_video)

    # Only the per-frame counts feed the visualisation; the other results of
    # get_counts (pred_period, pred_score, within_period, chosen_stride) are
    # not needed here.
    _, _, _, per_frame_counts, _ = get_counts(
        repnet_model,
        imgs,
        strides=[1, 2, 3, 4],
        batch_size=20,
        threshold=THRESHOLD,
        within_period_threshold=WITHIN_PERIOD_THRESHOLD,
        constant_speed=CONSTANT_SPEED,
        median_filter=MEDIAN_FILTER,
        fully_periodic=FULLY_PERIODIC,
    )

    viz_reps(
        imgs,
        per_frame_counts,
        alpha=.5,
        interval=1_000 / VIZ_FPS,
        tmp_path=output_path,
    )
    # Return the same path that was written, so the two can never diverge.
    return output_path

def classify(path_to_video):
    """Run the frame + optical-flow classifier on a video.

    Args:
        path_to_video: Path of the input video file.

    Returns:
        dict mapping each class name (via CLASSES_INDEX_REVERSE) to its score,
        ordered from highest to lowest score. Scores are converted to plain
        Python floats: numpy scalars may not serialise cleanly when Gradio
        sends the result to the browser.
    """
    frames, optical_flow_frames = frame_extraction(path_to_video)

    # Add a leading batch axis of size 1; the model is given one clip.
    frames = np.expand_dims(np.array(frames), axis=0)
    optical_flow_frames = np.expand_dims(np.array(optical_flow_frames), axis=0)

    # predict returns one row per batch element; take the single row.
    scores = classification_model.predict([frames, optical_flow_frames])[0]

    # argsort is ascending; reverse it so the best-scoring class comes first,
    # which is what "top" promises.
    top_indices = np.argsort(scores)[::-1]

    return {CLASSES_INDEX_REVERSE[i]: float(scores[i]) for i in top_indices}

def run(url_input, video_input):
    """Gradio callback: count repetitions in and classify the chosen video.

    A non-blank URL takes precedence over an uploaded video.

    Args:
        url_input: Text from the URL textbox; may be empty or None.
        video_input: Path of the uploaded video, or None if nothing was
            uploaded.

    Returns:
        Tuple of (path of the annotated video, dict of class scores).

    Raises:
        ValueError: if neither a URL nor a video was provided.
    """
    # Treat None and whitespace-only text the same as an empty textbox.
    url = (url_input or '').strip()

    if url:
        video_path = download_youtube_video(url)
    elif video_input is not None:
        video_path = video_input
    else:
        # Without this branch video_path would be unbound and the call would
        # die with an UnboundLocalError instead of a readable message.
        raise ValueError('Please provide a YouTube URL or upload a video.')

    repnet_res = count(video_path)
    class_res = classify(video_path)
    return repnet_res, class_res

# Web UI: a URL textbox and a video upload go in; the annotated video and
# the five highest-scoring classes come out.
url_box = gr.Textbox()
video_upload = gr.Video()
annotated_video = gr.Video()
class_label = gr.Label(num_top_classes=5)

demo = gr.Interface(
    fn=run,
    inputs=[url_box, video_upload],
    outputs=[annotated_video, class_label],
)
demo.launch()