In [13]:
import math
import os
import tempfile

import cv2
import fire
import jax.numpy as jnp
import numpy as np
from jax import jit
from onnxruntime import (
    GraphOptimizationLevel, InferenceSession,
    SessionOptions, get_available_providers
)
from tqdm import tqdm

In [14]:
import time

In [24]:
# Directory used to resolve the default model path. `os.path.dirname('')` is
# the empty string, so this resolves relative to the current working directory.
CURRENT_DIR = os.path.realpath(os.path.dirname(''))
# Default intra-op thread count: the number of CPUs reported by the OS, but
# never less than 1. `os.cpu_count()` may return None, hence the `or 1`.
# (The previous `min(1, ...)` could never exceed 1 and raised TypeError on None.)
NUM_THREADS = max(1, os.cpu_count() or 1)

In [33]:
def create_model_for_provider(
    model_path: str,
    provider: str = 'auto',
    num_threads=None
) -> InferenceSession:
    """Create an ONNX Runtime ``InferenceSession`` bound to a single provider.

    Args:
        model_path: Path of the ONNX model passed to ``InferenceSession``.
        provider: Execution provider name. With ``'auto'`` the CUDA provider
            is selected when it is listed by ``get_available_providers()``,
            otherwise the CPU provider; the choice is printed.
        num_threads: Value for ``intra_op_num_threads``. When ``None`` the
            ``NUM_THREADS`` environment variable is used, falling back to the
            module-level ``NUM_THREADS`` constant.

    Returns:
        The configured session, with ``disable_fallback()`` already called.

    Raises:
        AssertionError: If the (resolved) provider is not available.
    """
    available = get_available_providers()
    if provider == 'auto':
        provider = (
            'CUDAExecutionProvider'
            if 'CUDAExecutionProvider' in available
            else 'CPUExecutionProvider'
        )
        print('model provider', provider)
    assert provider in available, \
        f"provider {provider} not found, {available}"

    # Explicit argument wins; otherwise environment variable, then constant.
    if num_threads is None:
        num_threads = int(os.environ.get('NUM_THREADS', NUM_THREADS))

    # Session tuning: thread count plus the highest graph optimization level.
    options = SessionOptions()
    options.intra_op_num_threads = num_threads
    options.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL

    # Load the model graph on the selected provider only.
    session = InferenceSession(model_path, options, providers=[provider])
    session.disable_fallback()
    return session


def get_video(cap, max_frames=10):
    """Yield decoded video frames as normalized, batched model inputs.

    Every frame that is successfully grabbed and retrieved is converted from
    BGR to RGB, transposed from (H, W, C) to (C, H, W), given a leading axis
    of size 1, cast to float32 and divided by 255.

    Args:
        cap: Capture object exposing ``grab()`` and ``retrieve()`` (as
            ``cv2.VideoCapture`` does).
        max_frames: Upper bound on the number of frames yielded. Defaults to
            10, the limit that used to be hard-coded here. Pass ``None`` to
            read until ``grab()`` fails.

    Yields:
        ``np.ndarray`` of dtype float32 and shape (1, C, H, W).
    """
    yielded = 0
    while max_frames is None or yielded < max_frames:
        if not cap.grab():
            # No more frames available from the capture.
            break
        ok, frame = cap.retrieve()
        if not ok:
            # Frame was grabbed but could not be decoded; skip it without
            # counting it against the frame limit.
            continue
        rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        yield np.expand_dims(np.transpose(rgb, (2, 0, 1)), 0).astype(np.float32) / 255.0
        yielded += 1


@jit
def compute_border(fgr, pha, border):
    """Blend the foreground over a border colour and binarize the alpha matte.

    Args:
        fgr: Foreground colour array.
        pha: Alpha array, broadcastable against ``fgr``.
        border: Border colour, broadcastable against ``fgr``.

    Returns:
        A tuple ``(output_img_border, ipha)``:
        ``output_img_border`` is ``fgr * pha + (1 - pha) * border`` clipped to
        [0, 1], scaled by 255 and cast to uint8; ``ipha`` is a uint8 mask that
        is 255 where ``pha > 0.2`` and 0 elsewhere.
    """
    blended = fgr * pha + (1 - pha) * border
    # Use jax.numpy here: inside @jit the operands are traced values, and the
    # NumPy function only works on them if NumPy happens to forward the call
    # to the tracer's own method.
    blended = jnp.clip(blended, 0.0, 1.0)
    output_img_border = (blended * 255.0).astype('uint8')
    ipha = ((pha > 0.2) * 255).astype('uint8')
    return output_img_border, ipha

@jit
def compute_border_2(output_img_border, img_dilation_filter, green):
    """Combine the bordered image with the background colour using a mask.

    Where ``img_dilation_filter`` is 1 the bordered image is kept; where it is
    0 the result is ``green * 255``. Intermediate mask values mix the two
    linearly.
    """
    kept_foreground = output_img_border * img_dilation_filter
    filled_background = (1 - img_dilation_filter) * green * 255
    return kept_foreground + filled_background


@jit
def compute_without_border(fgr, pha, green):
    """Blend the foreground over the background colour, scaled to 0..255.

    Args:
        fgr: Foreground colour array.
        pha: Alpha array, broadcastable against ``fgr``.
        green: Background colour, broadcastable against ``fgr``.

    Returns:
        ``fgr * pha + (1 - pha) * green`` clipped to [0, 1] and multiplied by
        255. The result is left as floating point; no integer cast is applied
        here.
    """
    blended = fgr * pha + (1 - pha) * green
    # Use jax.numpy here: inside @jit the operands are traced values, and the
    # NumPy function only works on them if NumPy happens to forward the call
    # to the tracer's own method.
    blended = jnp.clip(blended, 0.0, 1.0)
    return blended * 255.0

def write_frame(fgr, pha, border, green, use_border):
    """Composite one model output into a BGR uint8 image and return it.

    Despite the name, nothing is written to disk here; the caller receives
    the composed frame.

    Args:
        fgr: Foreground output with a leading batch axis (only index 0 is
            used for the returned image).
        pha: Alpha output matching ``fgr``.
        border: Border colour used when ``use_border`` is true.
        green: Background colour.
        use_border: When true, the thresholded alpha mask is dilated with a
            5x5 kernel and the foreground is blended over ``border`` inside
            the dilated region, with ``green`` outside it. When false, the
            foreground is blended directly over ``green``.

    Returns:
        The first image of the batch, converted from RGB to BGR.
    """
    if not use_border:
        composed = np.array(compute_without_border(fgr, pha, green)).astype('uint8')
        return cv2.cvtColor(composed[0], cv2.COLOR_RGB2BGR)

    bordered, mask = compute_border(fgr, pha, border)
    # Bring the results back to NumPy arrays before handing them to OpenCV.
    bordered = np.array(bordered)
    mask = np.array(mask)

    # Grow the binary mask so a rim of border colour surrounds the subject.
    kernel = np.ones((5, 5))
    dilated = cv2.dilate(np.array(mask[0]), kernel, iterations=1)
    # Restore the leading and trailing axes and rescale the mask to 0..1.
    dilation_filter = dilated[np.newaxis, ..., np.newaxis].astype('float32') / 255.0

    composed = np.array(
        compute_border_2(bordered, dilation_filter, green)
    ).astype('uint8')
    return cv2.cvtColor(composed[0], cv2.COLOR_RGB2BGR)


def generate_result(cap, all_frames, sess, model_path, downsample):
    """Run the model over the frames of ``cap`` and yield its outputs.

    Args:
        cap: Capture object forwarded to ``get_video``.
        all_frames: Frame count used as the progress bar total.
        sess: Session whose ``run`` accepts the inputs ``src``, ``r1i``..
            ``r4i`` and ``downsample_ratio``.
        model_path: Model path; inputs and recurrent states are cast to
            float16 if it contains ``'fp16'``, to float32 if it contains
            ``'fp32'``, and left untouched otherwise.
        downsample: Value fed as ``downsample_ratio`` (always float32).

    Yields:
        ``(fgr, pha)`` with axes reordered by ``[0, 2, 3, 1]``.
    """
    # The cast only depends on the model path, so decide it once up front.
    if 'fp16' in model_path:
        cast_to = 'float16'
    elif 'fp32' in model_path:
        cast_to = 'float32'
    else:
        cast_to = None

    frames = tqdm(get_video(cap=cap), total=math.ceil(all_frames))
    # Initial recurrent states; replaced by the model's outputs after each run.
    rec = [np.zeros([1, 1, 1, 1], dtype=np.float32)] * 4
    downsample_ratio = np.array([downsample], dtype=np.float32)
    channels_last = [0, 2, 3, 1]

    for src in frames:
        batch_inputs = src
        if cast_to is not None:
            batch_inputs = batch_inputs.astype(cast_to)
            rec = [state.astype(cast_to) for state in rec]

        feed = {
            'src': batch_inputs,
            'r1i': rec[0][-1:],
            'r2i': rec[1][-1:],
            'r3i': rec[2][-1:],
            'r4i': rec[3][-1:],
            'downsample_ratio': downsample_ratio,
        }
        # First two outputs are foreground and alpha; the rest become the
        # recurrent states for the next frame.
        fgr, pha, *rec = sess.run([], feed)
        yield np.transpose(fgr, channels_last), np.transpose(pha, channels_last)


def convert(
    input_file,
    output_file,
    model_path=os.path.join(CURRENT_DIR, 'rvm_mobilenetv3_int8.onnx'),
    downsample=0.5,
    green_color=[0, 255, 0],
    use_border=False,
    border_color=[255, 255, 255],
    num_threads=None
):
    """Run the matting model over a video and print the elapsed seconds.

    Args:
        input_file: Path of the input video; must exist.
        output_file: Currently unused by this function (the audio-combining
            step that consumed it is commented out).
        model_path: Path of the ONNX model; must exist.
        downsample: Downsample ratio forwarded to ``generate_result``.
        green_color: Background colour as three 0..255 values. The default
            list is never mutated.
        use_border: Forwarded to ``write_frame``.
        border_color: Border colour as three 0..255 values. The default list
            is never mutated.
        num_threads: Forwarded to ``create_model_for_provider``.

    Raises:
        AssertionError: If ``input_file`` or ``model_path`` does not exist.

    Side effects:
        Opens a ``cv2.VideoWriter`` on ``'tmp.mp4'`` in the working directory.
        Frames are composited but not written to it, because the ``write``
        call is commented out.
    """
    assert os.path.exists(input_file), 'Input file not found'
    assert os.path.exists(model_path), 'Model not found'
    ss = time.time()

    sess = create_model_for_provider(model_path, num_threads=num_threads)

    # Colours are given as 0..255 and used as 0..1 floats during compositing.
    green = np.array(green_color).reshape([1, 1, 3]) / 255.
    border = np.array(border_color).reshape([1, 1, 3]) / 255.

    cap = cv2.VideoCapture(input_file)
    try:
        all_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        fps = float(cap.get(cv2.CAP_PROP_FPS))
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter('tmp.mp4', fourcc, fps, (width, height))
        try:
            for fgr, pha in generate_result(cap, all_frames, sess, model_path, downsample):
                frame = write_frame(fgr, pha, border, green, use_border)
                # Writing is left disabled, as in the original cell
                # (presumably so the timing covers inference and compositing
                # only - confirm before re-enabling).
                # out.write(frame)
        finally:
            # Release the writer even when inference or compositing raises.
            out.release()
    finally:
        # Release the capture even when anything above raises.
        cap.release()
    # combine_audio(f.name, input_file, output_file)
    print(time.time() - ss)

In [36]:
# Time the int8 model for several intra-op thread counts; each setting is
# run twice back to back so the two timings can be compared.
for t in (1, 2, 4, 6, 8, 10, 12):
    print(t)
    for _ in range(2):
        convert('../IMG_5022.MOV', './out.mp4', '../rvm/rvm_mobilenetv3_int8.onnx', num_threads=t)

  0%|          | 0/365 [00:00<?, ?it/s]

1
model provider CPUExecutionProvider


  3%|▎         | 10/365 [00:21<12:28,  2.11s/it]


21.515475749969482
model provider CPUExecutionProvider


  3%|▎         | 10/365 [00:21<12:28,  2.11s/it]


21.49342131614685
2
model provider CPUExecutionProvider


  3%|▎         | 10/365 [00:13<08:04,  1.36s/it]


14.09043288230896
model provider CPUExecutionProvider


  3%|▎         | 10/365 [00:13<07:54,  1.34s/it]


13.81650686264038
4
model provider CPUExecutionProvider


  3%|▎         | 10/365 [00:09<05:53,  1.00it/s]


10.384968042373657
model provider CPUExecutionProvider


  3%|▎         | 10/365 [00:10<05:55,  1.00s/it]


10.437133073806763
6
model provider CPUExecutionProvider


  3%|▎         | 10/365 [00:16<09:46,  1.65s/it]


16.953156232833862
model provider CPUExecutionProvider


  3%|▎         | 10/365 [00:16<09:40,  1.64s/it]


16.772194147109985
8
model provider CPUExecutionProvider


  3%|▎         | 10/365 [00:22<13:13,  2.24s/it]


22.76658844947815
model provider CPUExecutionProvider


  3%|▎         | 10/365 [00:22<13:34,  2.30s/it]


23.38491153717041
10
model provider CPUExecutionProvider


  3%|▎         | 10/365 [00:29<17:26,  2.95s/it]


29.896848917007446
model provider CPUExecutionProvider


  3%|▎         | 10/365 [00:32<19:01,  3.22s/it]


32.61180758476257
12
model provider CPUExecutionProvider


  3%|▎         | 10/365 [00:34<20:37,  3.49s/it]


35.331199407577515
model provider CPUExecutionProvider


  3%|▎         | 10/365 [00:35<20:58,  3.54s/it]

35.890734910964966





In [37]:
# Time the fp32 model for several intra-op thread counts; each setting is
# run twice back to back so the two timings can be compared.
for t in (1, 2, 4, 6, 8, 10, 12):
    print(t)
    for _ in range(2):
        convert('../IMG_5022.MOV', './out.mp4', '../rvm_mobilenetv3_fp32.onnx', num_threads=t)

  0%|          | 0/365 [00:00<?, ?it/s]

1
model provider CPUExecutionProvider


  3%|▎         | 10/365 [00:15<09:07,  1.54s/it]
  0%|          | 0/365 [00:00<?, ?it/s]

15.5608389377594
model provider CPUExecutionProvider


  3%|▎         | 10/365 [00:15<09:22,  1.58s/it]
  0%|          | 0/365 [00:00<?, ?it/s]

15.982490062713623
2
model provider CPUExecutionProvider


  3%|▎         | 10/365 [00:09<05:32,  1.07it/s]
  0%|          | 0/365 [00:00<?, ?it/s]

9.50753927230835
model provider CPUExecutionProvider


  3%|▎         | 10/365 [00:09<05:32,  1.07it/s]
  0%|          | 0/365 [00:00<?, ?it/s]

9.486950159072876
4
model provider CPUExecutionProvider


  3%|▎         | 10/365 [00:06<03:52,  1.52it/s]
  0%|          | 0/365 [00:00<?, ?it/s]

6.698566913604736
model provider CPUExecutionProvider


  3%|▎         | 10/365 [00:06<03:47,  1.56it/s]
  0%|          | 0/365 [00:00<?, ?it/s]

6.531675577163696
6
model provider CPUExecutionProvider


  3%|▎         | 10/365 [00:10<06:05,  1.03s/it]
  0%|          | 0/365 [00:00<?, ?it/s]

10.430461645126343
model provider CPUExecutionProvider


  3%|▎         | 10/365 [00:11<06:33,  1.11s/it]
  0%|          | 0/365 [00:00<?, ?it/s]

11.219563961029053
8
model provider CPUExecutionProvider


  3%|▎         | 10/365 [00:14<08:34,  1.45s/it]
  0%|          | 0/365 [00:00<?, ?it/s]

14.65671420097351
model provider CPUExecutionProvider


  3%|▎         | 10/365 [00:14<08:19,  1.41s/it]
  0%|          | 0/365 [00:00<?, ?it/s]

14.221046924591064
10
model provider CPUExecutionProvider


  3%|▎         | 10/365 [00:18<11:10,  1.89s/it]
  0%|          | 0/365 [00:00<?, ?it/s]

19.016494750976562
model provider CPUExecutionProvider


  3%|▎         | 10/365 [00:19<11:23,  1.93s/it]
  0%|          | 0/365 [00:00<?, ?it/s]

19.396165370941162
12
model provider CPUExecutionProvider


  3%|▎         | 10/365 [00:22<13:06,  2.22s/it]
  0%|          | 0/365 [00:00<?, ?it/s]

22.29765295982361
model provider CPUExecutionProvider


  3%|▎         | 10/365 [00:21<12:42,  2.15s/it]

21.615721940994263



