
Overhaul
Fix #6
Fix #7
HolyWu committed Jun 2, 2024
1 parent 600bbd7 commit 4a04c46
Showing 10 changed files with 588 additions and 188 deletions.
README.md: 16 changes (8 additions & 8 deletions)
@@ -3,16 +3,16 @@ Real-World Blind Super-Resolution via Feature Matching with Implicit High-Resolution Priors


 ## Dependencies
-- [NumPy](https://numpy.org/install)
-- [PyTorch](https://pytorch.org/get-started) 1.13
-- [VapourSynth](http://www.vapoursynth.com/) R55+
+- [PyTorch](https://pytorch.org/get-started/) 2.4.0.dev or later
+- [VapourSynth](http://www.vapoursynth.com/) R66 or later
 
-`trt` requires additional runtime libraries:
-- [CUDA Toolkit](https://developer.nvidia.com/cuda-toolkit) 11.7
-- [cuDNN](https://developer.nvidia.com/cudnn) 8.6
-- [TensorRT](https://developer.nvidia.com/tensorrt) 8.5.3.1
+`trt` requires additional Python packages:
+- [TensorRT](https://developer.nvidia.com/tensorrt/) 10.0.1
+- [Torch-TensorRT](https://pytorch.org/TensorRT/) 2.4.0.dev
 
-For ease of installation on Windows, you can download the 7z file on [Releases](https://github.com/HolyWu/vs-femasr/releases) which contains required runtime libraries and Python wheel file. Either add the unzipped directory to your system `PATH` or copy the DLL files to a directory which is already in your system `PATH`. Finally pip install the Python wheel file.
+To install TensorRT, run `pip install tensorrt==10.0.1 tensorrt-cu12_bindings==10.0.1 tensorrt-cu12_libs==10.0.1 --extra-index-url https://pypi.nvidia.com`
+
+To install Torch-TensorRT, Windows users can pip install the whl file on [Releases](https://github.com/HolyWu/vs-femasr/releases). Linux users can run `pip install --pre torch_tensorrt --index-url https://download.pytorch.org/whl/nightly/cu124` (requires PyTorch nightly build).
 
 
 ## Installation
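For context, a minimal VapourSynth script against the new API might look like this. This is an illustrative sketch, not text from the commit: the source filter (`bs.VideoSource`) and file name are placeholders, while `femasr` and `FeMaSRModel` come from `vsfemasr/__init__.py` below.

```python
import vapoursynth as vs

from vsfemasr import FeMaSRModel, femasr

core = vs.core

clip = core.bs.VideoSource("input.mp4")  # placeholder source filter
# RGBH selects FP16 inference, RGBS selects FP32 (see the docstring below)
clip = clip.resize.Bicubic(format=vs.RGBH, matrix_in_s="709")
clip = femasr(clip, model=FeMaSRModel.FeMaSR_SRX2_model_g, trt=True)
clip.set_output()
```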
pyproject.toml: 28 changes (17 additions & 11 deletions)
@@ -8,26 +8,32 @@ version = "1.2.0"
 description = "FeMaSR function for VapourSynth"
 readme = "README.md"
 requires-python = ">=3.10"
-license = {file = "LICENSE"}
-authors = [{name = "HolyWu", email = "holywu@gmail.com"}]
-keywords = ["FeMaSR", "VapourSynth"]
+license-files = { paths = ["LICENSE"] }
+authors = [
+    { name = "HolyWu", email = "holywu@gmail.com" },
+]
+keywords = [
+    "FeMaSR",
+    "PyTorch",
+    "TensorRT",
+    "VapourSynth",
+]
 classifiers = [
     "Environment :: GPU :: NVIDIA CUDA",
     "License :: CC0 1.0 Universal (CC0 1.0) Public Domain Dedication",
     "Operating System :: OS Independent",
-    "Programming Language :: Python :: 3.10",
-    "Topic :: Multimedia :: Video"
+    "Programming Language :: Python :: 3",
+    "Topic :: Multimedia :: Video",
 ]
 dependencies = [
     "numpy",
     "requests",
-    "tensorrt>=8.5.3.1",
     "timm",
-    "torch>=1.13.0",
-    "torch-tensorrt-fx-only>=1.3.0",
+    "torch>=2.4.0.dev",
     "tqdm",
-    "VapourSynth>=55"
+    "VapourSynth>=66",
 ]
 
 [project.urls]
-"Homepage" = "https://github.com/HolyWu/vs-femasr"
-"Bug Tracker" = "https://github.com/HolyWu/vs-femasr/issues"
+Homepage = "https://github.com/HolyWu/vs-femasr"
+Issues = "https://github.com/HolyWu/vs-femasr/issues"
vsfemasr/__init__.py: 215 changes (97 additions & 118 deletions)
@@ -2,65 +2,72 @@

 import math
 import os
+import warnings
+from enum import IntEnum
 from threading import Lock
 
 import numpy as np
-import tensorrt
 import torch
 import torch.nn.functional as F
 import vapoursynth as vs
-from functorch.compile import memory_efficient_fusion
-from torch_tensorrt.fx import LowerSetting
-from torch_tensorrt.fx.lower import Lowerer
-from torch_tensorrt.fx.utils import LowerPrecision
 
-from .femasr_arch import FeMaSRNet, VectorQuantizer
+from .femasr_arch import FeMaSRNet
 
 __version__ = "1.2.0"
 
 os.environ["CUDA_MODULE_LOADING"] = "LAZY"
 
-package_dir = os.path.dirname(os.path.realpath(__file__))
+warnings.filterwarnings("ignore", "At pre-dispatch tracing")
+warnings.filterwarnings("ignore", "Attempted to insert a get_attr Node with no underlying reference")
+warnings.filterwarnings("ignore", "Node _run_on_acc_0_engine target _run_on_acc_0_engine _run_on_acc_0_engine of")
+warnings.filterwarnings("ignore", "The given NumPy array is not writable")
+
+model_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models")
+
+
+class FeMaSRModel(IntEnum):
+    FeMaSR_SRX2_model_g = 0
+    FeMaSR_SRX4_model_g = 1
 
 
 @torch.inference_mode()
 def femasr(
     clip: vs.VideoNode,
-    device_index: int | None = None,
+    device_index: int = 0,
     num_streams: int = 1,
-    nvfuser: bool = False,
-    cuda_graphs: bool = False,
+    model: FeMaSRModel = FeMaSRModel.FeMaSR_SRX2_model_g,
     trt: bool = False,
-    trt_min_subgraph_size: int = 1,
-    trt_max_workspace_size: int = 1 << 25,
-    trt_cache_path: str = package_dir,
-    model: int = 0,
+    trt_debug: bool = False,
+    trt_workspace_size: int = 0,
+    trt_max_aux_streams: int | None = None,
+    trt_optimization_level: int | None = None,
+    trt_cache_dir: str = model_dir,
 ) -> vs.VideoNode:
     """Real-World Blind Super-Resolution via Feature Matching with Implicit High-Resolution Priors
     :param clip:                    Clip to process. Only RGBH and RGBS formats are supported.
                                     RGBH performs inference in FP16 mode while RGBS performs inference in FP32 mode.
     :param device_index:            Device ordinal of the GPU.
     :param num_streams:             Number of CUDA streams to enqueue the kernels.
-    :param nvfuser:                 Enable fusion through nvFuser. Not allowed in TensorRT. (experimental)
-    :param cuda_graphs:             Use CUDA Graphs to remove CPU overhead associated with launching CUDA kernels
-                                    sequentially. Not allowed in TensorRT.
-    :param trt:                     Use TensorRT for high-performance inference.
-    :param trt_min_subgraph_size:   Minimum node size in a subgraph after splitting. Subgraphs with smaller size will
-                                    fall back to CUDA. Try trt_min_subgraph_size=5 if TensorRT engine cannot be built
-                                    due to insufficient VRAM, but the performance will degrade.
-    :param trt_max_workspace_size:  Maximum workspace size for TensorRT engine.
-    :param trt_cache_path:          Path for TensorRT engine file. Engine will be cached when it's built for the first
-                                    time. Note each engine is created for specific settings such as model path/name,
-                                    precision, workspace etc, and specific GPUs and it's not portable.
     :param model:                   Model to use.
                                     0 = FeMaSR_SRX2_model_g
                                     1 = FeMaSR_SRX4_model_g
+    :param trt:                     Use TensorRT for high-performance inference.
+    :param trt_debug:               Print out verbose debugging information.
+    :param trt_workspace_size:      Size constraints of workspace memory pool.
+    :param trt_max_aux_streams:     Maximum number of auxiliary streams per inference stream that TRT is allowed to use
+                                    to run kernels in parallel if the network contains ops that can run in parallel,
+                                    with the cost of more memory usage. Set this to 0 for optimal memory usage.
+                                    (default = using heuristics)
+    :param trt_optimization_level:  Builder optimization level. Higher level allows TensorRT to spend more building time
+                                    for more optimization options. Valid values include integers from 0 to the maximum
+                                    optimization level, which is currently 5. (default is 3)
+    :param trt_cache_dir:           Directory for TensorRT engine file. Engine will be cached when it's built for the
+                                    first time. Note each engine is created for specific settings such as model
+                                    path/name, precision, workspace etc, and specific GPUs and it's not portable.
     """
     if not isinstance(clip, vs.VideoNode):
         raise vs.Error("femasr: this is not a clip")
 
-    if clip.format.id not in (vs.RGBH, vs.RGBS):
+    if clip.format.id not in [vs.RGBH, vs.RGBS]:
         raise vs.Error("femasr: only RGBH and RGBS formats are supported")
 
     if not torch.cuda.is_available():
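The signature and docstring above replace the old nvFuser/CUDA Graphs options with TensorRT builder knobs. An illustrative call, not part of the commit, assuming the import sketch from the README note above:

```python
# Hypothetical example call; parameter names are from the new signature above.
sr = femasr(
    clip,                                   # RGBH or RGBS clip
    num_streams=2,                          # overlap work across CUDA streams
    model=FeMaSRModel.FeMaSR_SRX4_model_g,  # 4x upscale
    trt=True,                               # compile with Torch-TensorRT (dynamo IR)
    trt_optimization_level=5,               # spend longer building for a faster engine
)
```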
@@ -69,20 +76,10 @@ def femasr(
     if num_streams < 1:
         raise vs.Error("femasr: num_streams must be at least 1")
 
-    if num_streams > vs.core.num_threads:
-        raise vs.Error("femasr: setting num_streams greater than `core.num_threads` is useless")
-
-    if trt:
-        if nvfuser:
-            raise vs.Error("femasr: nvfuser and trt are mutually exclusive")
-
-        if cuda_graphs:
-            raise vs.Error("femasr: cuda_graphs and trt are mutually exclusive")
-
-    if model not in range(2):
-        raise vs.Error("femasr: model must be 0 or 1")
+    if model not in FeMaSRModel:
+        raise vs.Error("femasr: model must be one of the members in FeMaSRModel")
 
-    if os.path.getsize(os.path.join(package_dir, "FeMaSR_SRX2_model_g.pth")) == 0:
+    if os.path.getsize(os.path.join(model_dir, "FeMaSR_SRX2_model_g.pth")) == 0:
         raise vs.Error("femasr: model files have not been downloaded. run 'python -m vsfemasr' first")
 
     torch.set_float32_matmul_precision("high")
@@ -96,87 +93,74 @@
     stream_lock = [Lock() for _ in range(num_streams)]
 
     match model:
-        case 0:
-            model_name = "FeMaSR_SRX2_model_g.pth"
+        case FeMaSRModel.FeMaSR_SRX2_model_g:
             scale = 2
             modulo = 32
-        case 1:
-            model_name = "FeMaSR_SRX4_model_g.pth"
+            downscale = 4
+        case FeMaSRModel.FeMaSR_SRX4_model_g:
             scale = 4
             modulo = 16
-
-    model_path = os.path.join(package_dir, model_name)
-
-    module = FeMaSRNet(codebook_params=[[32, 1024, 512]], LQ_stage=True, scale_factor=scale)
-    module.load_state_dict(torch.load(model_path, map_location="cpu")["params"], strict=False)
-    module.eval().to(device, memory_format=torch.channels_last)
+            downscale = 2
+
+    w = clip.width
+    h = clip.height
+    pad_w = math.ceil(w / modulo) * modulo
+    pad_h = math.ceil(h / modulo) * modulo
+    padding = (0, pad_w - w, 0, pad_h - h)
+
+    model_name = f"{FeMaSRModel(model).name}.pth"
+    state_dict = torch.load(os.path.join(model_dir, model_name), map_location=device, mmap=True)["params"]
+
+    module = FeMaSRNet(
+        input_resolution=(pad_h // downscale, pad_w // downscale),
+        codebook_params=[[32, 1024, 512]],
+        LQ_stage=True,
+        scale_factor=scale,
+    )
+    module.load_state_dict(state_dict, strict=False)
+    module.eval().to(device)
     if fp16:
         module.half()
 
-    pad_w = math.ceil(clip.width / modulo) * modulo
-    pad_h = math.ceil(clip.height / modulo) * modulo
-
-    if nvfuser:
-        module = memory_efficient_fusion(module)
-
-    if cuda_graphs:
-        graph: list[torch.cuda.CUDAGraph] = []
-        static_input: list[torch.Tensor] = []
-        static_output: list[torch.Tensor] = []
-
-        for i in range(num_streams):
-            static_input.append(
-                torch.zeros((1, 3, pad_h, pad_w), dtype=dtype, device=device).to(memory_format=torch.channels_last)
-            )
-
-            torch.cuda.synchronize(device=device)
-            stream[i].wait_stream(torch.cuda.current_stream(device=device))
-            with torch.cuda.stream(stream[i]):
-                module(static_input[i])
-            torch.cuda.current_stream(device=device).wait_stream(stream[i])
-            torch.cuda.synchronize(device=device)
-
-            graph.append(torch.cuda.CUDAGraph())
-            with torch.cuda.graph(graph[i], stream=stream[i]):
-                static_output.append(module(static_input[i]))
-    elif trt:
-        device_name = torch.cuda.get_device_name(device)
-        trt_version = tensorrt.__version__
-        dimensions = f"{pad_w}x{pad_h}"
-        precision = "fp16" if fp16 else "fp32"
+    if trt:
+        import tensorrt
+        import torch_tensorrt
+
         trt_engine_path = os.path.join(
-            os.path.realpath(trt_cache_path),
+            os.path.realpath(trt_cache_dir),
             (
                 f"{model_name}"
-                + f"_{device_name}"
-                + f"_trt-{trt_version}"
-                + f"_{dimensions}"
-                + f"_{precision}"
-                + f"_workspace-{trt_max_workspace_size}"
-                + ".pt"
+                + f"_{pad_w}x{pad_h}"
+                + f"_{'fp16' if fp16 else 'fp32'}"
+                + f"_{torch.cuda.get_device_name(device)}"
+                + f"_trt-{tensorrt.__version__}"
+                + (f"_workspace-{trt_workspace_size}" if trt_workspace_size > 0 else "")
+                + (f"_aux-{trt_max_aux_streams}" if trt_max_aux_streams is not None else "")
+                + (f"_level-{trt_optimization_level}" if trt_optimization_level is not None else "")
+                + ".ep"
             ),
         )
 
         if not os.path.isfile(trt_engine_path):
-            lower_setting = LowerSetting(
-                lower_precision=LowerPrecision.FP16 if fp16 else LowerPrecision.FP32,
-                min_acc_module_size=trt_min_subgraph_size,
-                leaf_module_list={VectorQuantizer},
-                max_workspace_size=trt_max_workspace_size,
-                dynamic_batch=False,
-                tactic_sources=1 << int(tensorrt.TacticSource.EDGE_MASK_CONVOLUTIONS)
-                | 1 << int(tensorrt.TacticSource.JIT_CONVOLUTIONS),
-            )
-            lowerer = Lowerer.create(lower_setting=lower_setting)
-            module = lowerer(
+            inputs = [torch.zeros((1, 3, pad_h, pad_w), dtype=dtype, device=device)]
+
+            module = torch_tensorrt.compile(
                 module,
-                [torch.zeros((1, 3, pad_h, pad_w), dtype=dtype, device=device).to(memory_format=torch.channels_last)],
+                ir="dynamo",
+                inputs=inputs,
+                enabled_precisions={dtype},
+                debug=trt_debug,
+                workspace_size=trt_workspace_size,
+                min_block_size=1,
+                max_aux_streams=trt_max_aux_streams,
+                optimization_level=trt_optimization_level,
+                truncate_double=True,
+                device=device,
             )
-            torch.save(module, trt_engine_path)
+
+            torch_tensorrt.save(module, trt_engine_path, inputs=inputs)
 
-        del module
-        torch.cuda.empty_cache()
-        module = [torch.load(trt_engine_path) for _ in range(num_streams)]
+        module = [torch.export.load(trt_engine_path).module() for _ in range(num_streams)]
 
     index = -1
     index_lock = Lock()
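A worked example of the caching scheme above, with illustrative values: a 1920x1080 RGBH clip on the x2 model pads to 1920x1088 (`modulo = 32`, so `padding = (0, 0, 0, 8)`), and with default builder settings all three conditional suffixes are empty, so the engine file name would look something like:

```python
# "FeMaSR_SRX2_model_g.pth_1920x1088_fp16_NVIDIA GeForce RTX 4090_trt-10.0.1.ep"
# (the GPU name is a placeholder; changing any component of the name
# causes the isfile() check above to fail and triggers a rebuild)
```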
@@ -190,22 +174,14 @@ def inference(n: int, f: list[vs.VideoFrame]) -> vs.VideoFrame:

         with stream_lock[local_index], torch.cuda.stream(stream[local_index]):
             img = frame_to_tensor(f[0], device)
+            img = F.pad(img, padding, "replicate")
 
-            h, w = img.shape[2:]
-            img = F.pad(img, (0, pad_w - w, 0, pad_h - h), "reflect")
-
-            if cuda_graphs:
-                static_input[local_index].copy_(img)
-                graph[local_index].replay()
-                output = static_output[local_index]
-            elif trt:
+            if trt:
                 output = module[local_index](img)
             else:
                 output = module(img)
 
-            output = output[:, :, : h * scale, : w * scale]
-
-            return tensor_to_frame(output, f[1].copy())
+            return tensor_to_frame(output[:, :, : h * scale, : w * scale], f[1].copy())
 
     new_clip = clip.std.BlankClip(width=clip.width * scale, height=clip.height * scale, keep=True)
     return new_clip.std.FrameEval(
@@ -214,12 +190,15 @@ def inference(n: int, f: list[vs.VideoFrame]) -> vs.VideoFrame:


 def frame_to_tensor(frame: vs.VideoFrame, device: torch.device) -> torch.Tensor:
-    array = np.stack([np.asarray(frame[plane]) for plane in range(frame.format.num_planes)])
-    return torch.from_numpy(array).unsqueeze(0).to(device, memory_format=torch.channels_last).clamp(0.0, 1.0)
+    return (
+        torch.stack([torch.from_numpy(np.asarray(frame[plane])).to(device) for plane in range(frame.format.num_planes)])
+        .unsqueeze(0)
+        .clamp(0.0, 1.0)
+    )
 
 
 def tensor_to_frame(tensor: torch.Tensor, frame: vs.VideoFrame) -> vs.VideoFrame:
     array = tensor.squeeze(0).detach().cpu().numpy()
     for plane in range(frame.format.num_planes):
-        np.copyto(np.asarray(frame[plane]), array[plane, :, :])
+        np.copyto(np.asarray(frame[plane]), array[plane])
     return frame
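A note on the rewritten `frame_to_tensor`: `torch.from_numpy` wraps each plane's memory without a host-side copy, and VapourSynth exposes frame planes as read-only, which is what produces the "The given NumPy array is not writable" warning silenced at import time above. A rough sketch of the difference between the two versions, assuming a 3-plane RGB frame:

```python
# old: stack the planes on the host, then upload the whole (3, H, W)
# array to the GPU once, in channels_last layout
#   np.stack(planes) -> torch.from_numpy(...) -> .to(device, channels_last)
# new: wrap each plane without copying (read-only, hence the warning),
# upload per plane, then stack the three device tensors on the GPU
#   torch.from_numpy(np.asarray(frame[plane])).to(device) -> torch.stack(...)
```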
vsfemasr/__main__.py: 2 changes (1 addition & 1 deletion)
@@ -7,7 +7,7 @@
 def download_model(url: str) -> None:
     filename = url.split("/")[-1]
     r = requests.get(url, stream=True)
-    with open(os.path.join(os.path.dirname(os.path.realpath(__file__)), filename), "wb") as f:
+    with open(os.path.join(os.path.dirname(os.path.realpath(__file__)), "models", filename), "wb") as f:
         with tqdm(
             unit="B",
             unit_scale=True,

