Skip to content

Commit

Permalink
scripts/vsmlrt.py: added support for rife v2 implementation
Browse files Browse the repository at this point in the history
(experimental) rife v2 models can be downloaded on https://github.com/AmusementClub/vs-mlrt/releases/tag/external-models ("rife_v2_v{version}.7z"). It leverages onnx's shape tensor to reduce memory transaction from cpu to gpu by 36.4%. It also handles padding internally so explicit padding is not required.
  • Loading branch information
WolframRhodium committed Feb 27, 2023
1 parent 4d1625e commit bd0ff98
Showing 1 changed file with 20 additions and 9 deletions.
29 changes: 20 additions & 9 deletions scripts/vsmlrt.py
@@ -1,4 +1,4 @@
__version__ = "3.15.12"
__version__ = "3.15.13"

__all__ = [
"Backend", "BackendV2",
Expand Down Expand Up @@ -828,6 +828,25 @@ def RIFEMerge(
multiple = int(multiple_frac.numerator)
scale = float(Fraction(scale))

# use v2 implementation by default
network_path = os.path.join(
models_path,
"rife_v2",
f"rife_v{model // 10}.{model % 10}{'_ensemble' if ensemble else ''}.onnx"
)
if os.path.exists(network_path) and scale == 1.0:
clips = [clipa, clipb, mask]
multiple = 1 # v2 implements internal padding
else:
# v2 onnx not found, try v1
network_path = os.path.join(
models_path,
"rife",
f"rife_v{model // 10}.{model % 10}{'_ensemble' if ensemble else ''}.onnx"
)

clips = [clipa, clipb, mask, *get_rife_input(clipa)]

(tile_w, tile_h), (overlap_w, overlap_h) = calc_tilesize(
tiles=tiles, tilesize=tilesize,
width=clip.width, height=clip.height,
Expand All @@ -845,14 +864,6 @@ def RIFEMerge(
trt_opt_shapes=(tile_w, tile_h)
)

network_path = os.path.join(
models_path,
"rife",
f"rife_v{model // 10}.{model % 10}{'_ensemble' if ensemble else ''}.onnx"
)

clips = [clipa, clipb, mask, *get_rife_input(clipa)]

if scale == 1.0:
return inference_with_fallback(
clips=clips, network_path=network_path,
Expand Down

0 comments on commit bd0ff98

Please sign in to comment.