Skip to content

Commit

Permalink
Add render module
Browse files Browse the repository at this point in the history
  • Loading branch information
Ichunjo committed Nov 27, 2021
1 parent bc57e98 commit a334b2c
Showing 1 changed file with 370 additions and 0 deletions.
370 changes: 370 additions & 0 deletions vardautomation/render.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,370 @@
"""Node rendering helpers"""

__all__ = [
'clip_async_render',
'WaveHeader', 'audio_async_render'
]

import struct
from enum import Enum, IntEnum
from typing import (BinaryIO, Callable, Dict, List, Optional, TextIO, Tuple,
Union, overload)

import numpy as np
import vapoursynth as vs
from rich.progress import (BarColumn, Progress, ProgressColumn, Task,
TextColumn, TimeRemainingColumn)
from rich.text import Text

from .utils import Properties


class FPSColumn(ProgressColumn):
def render(self, task: Task) -> Text:
return Text(f"{task.speed or 0:.02f} fps")


def get_render_progress() -> Progress:
return Progress(
TextColumn("{task.description}"),
BarColumn(),
TextColumn("{task.completed}/{task.total}"),
TextColumn("{task.percentage:>3.02f}%"),
FPSColumn(),
TimeRemainingColumn(),
)


RenderCallback = Callable[[int, vs.VideoFrame], None]


@overload
def clip_async_render(clip: vs.VideoNode, # type: ignore [misc]
outfile: Optional[BinaryIO] = None,
timecodes: None = ...,
progress: Optional[str] = "Rendering clip...",
callback: Union[RenderCallback, List[RenderCallback], None] = None) -> None:
...


@overload
def clip_async_render(clip: vs.VideoNode,
outfile: Optional[BinaryIO] = None,
timecodes: TextIO = ...,
progress: Optional[str] = "Rendering clip...",
callback: Union[RenderCallback, List[RenderCallback], None] = None) -> List[float]:
...


def clip_async_render(clip: vs.VideoNode, # noqa: C901
outfile: Optional[BinaryIO] = None,
timecodes: Optional[TextIO] = None,
progress: Optional[str] = "Rendering clip...",
callback: Union[RenderCallback, List[RenderCallback], None] = None) -> Union[None, List[float]]:
"""
Render a clip by requesting frames asynchronously using clip.frames,
providing for callback with frame number and frame object.
This is mostly a re-implementation of VideoNode.output, but a little bit slower since it's pure python.
You only really need this when you want to render a clip while operating on each frame in order
or you want timecodes without using vspipe.
Original function borrowed from lvsfunc.render.clip_async_render.
:param clip: Clip to render.
:param outfile: Y4MPEG render output BinaryIO handle. If None, no Y4M output is performed.
Use ``sys.stdout.buffer`` for stdout. (Default: None)
:param timecodes: Timecode v2 file TextIO handle. If None, timecodes will not be written.
:param progress: String to use for render progress display.
If empty or ``None``, no progress display.
:param callback: Single or list of callbacks to be preformed. The callbacks are called
when each sequential frame is output, not when each frame is done.
:return: List of timecodes from rendered clip.
"""
cbl = [] if callback is None else callback if isinstance(callback, list) else [callback]

if progress:
p = get_render_progress()
task = p.add_task(progress, total=clip.num_frames)
p.start()

def _progress_cb(n: int, f: vs.VideoFrame) -> None:
p.update(task, advance=1)

cbl.append(_progress_cb)

if outfile:
if clip.format is None:
raise ValueError("clip_async_render: 'Cannot render a variable format clip to y4m!'")
if clip.format.color_family not in (vs.YUV, vs.GRAY):
raise ValueError("clip_async_render: 'Can only render YUV and GRAY clips to y4m!'")
if clip.format.color_family == vs.GRAY:
y4mformat = "mono"
else:
try:
formats: Dict[Tuple[int, int], str] = {
(1, 1): "420",
(1, 0): "422",
(0, 0): "444",
(2, 2): "410",
(2, 0): "411",
(0, 1): "440",
}
y4mformat = formats[(clip.format.subsampling_w, clip.format.subsampling_h)]
except KeyError as key_err:
raise ValueError("clip_async_render: 'What have you done'") from key_err

y4mformat = f"{y4mformat}p{clip.format.bits_per_sample}" if clip.format.bits_per_sample > 8 else y4mformat
header = f"YUV4MPEG2 C{y4mformat} W{clip.width} H{clip.height} F{clip.fps.numerator}:{clip.fps.denominator} Ip A0:0\n"
outfile.write(header.encode("utf-8"))

if timecodes:
timecodes.write("# timestamp format v2\n")

tc_list = [0.0]

for n, f in enumerate(clip.frames(close=True)):
for cb in cbl:
cb(n, f)
if timecodes:
_write_timecodes(f, timecodes, tc_list)
if outfile:
_finish_frame_video(f, outfile)
if progress:
p.stop() # type: ignore

return tc_list if timecodes else None


def _finish_frame_video(frame: vs.VideoFrame, outfile: BinaryIO) -> None:
outfile.write("FRAME\n".encode("utf-8"))
for plane in frame: # type: ignore [attr-defined]
outfile.write(plane)


def _write_timecodes(frame: vs.VideoFrame, timecodes: TextIO, tc_list: List[float]) -> None:
tc = tc_list[-1] + Properties.get_prop(frame, '_DurationNum', int) / Properties.get_prop(frame, '_DurationDen', int)
tc_list.append(tc)
timecodes.write(f"{round(tc * 1000):d}\n")


class WaveFormat(IntEnum):
"""
WAVE form wFormatTag IDs
Complete list is in mmreg.h in Windows 10 SDK.
"""
PCM = 0x0001
IEEE_FLOAT = 0x0003
EXTENSIBLE = 0xFFFE


class WaveHeader(IntEnum):
"""
Wave headers
"""
WAVE = 0
WAVE64 = 1
AUTO = 2


WAVE_RIFF_TAG = b'RIFF'
WAVE_WAVE_TAG = b'WAVE'
WAVE_FMT_TAG = b'fmt '
WAVE_DATA_TAG = b'data'

WAVE64_RIFF_UUID = (0x72, 0x69, 0x66, 0x66, 0x2E, 0x91, 0xCF, 0x11, 0xA5, 0xD6, 0x28, 0xDB, 0x04, 0xC1, 0x00, 0x00)
WAVE64_WAVE_UUID = (0x77, 0x61, 0x76, 0x65, 0xF3, 0xAC, 0xD3, 0x11, 0x8C, 0xD1, 0x00, 0xC0, 0x4F, 0x8E, 0xDB, 0x8A)
WAVE64_FMT_UUID = (0x66, 0x6D, 0x74, 0x20, 0xF3, 0xAC, 0xD3, 0x11, 0x8C, 0xD1, 0x00, 0xC0, 0x4F, 0x8E, 0xDB, 0x8A)
WAVE64_DATA_UUID = (0x64, 0x61, 0x74, 0x61, 0xF3, 0xAC, 0xD3, 0x11, 0x8C, 0xD1, 0x00, 0xC0, 0x4F, 0x8E, 0xDB, 0x8A)
WAVE_FMT_EXTENSIBLE_SUBFORMAT = (
(WaveFormat.PCM, 0x00, 0x00, 0x00, 0x00, 0x00, 0x10, 0x00, 0x80, 0x00, 0x00, 0xAA, 0x00, 0x38, 0x9B, 0x71),
(WaveFormat.IEEE_FLOAT, 0x00, 0x00, 0x00, 0x00, 0x00, 0x10, 0x00, 0x80, 0x00, 0x00, 0xAA, 0x00, 0x38, 0x9B, 0x71)
)


def audio_async_render(audio: vs.AudioNode,
outfile: BinaryIO,
header: WaveHeader = WaveHeader.AUTO,
progress: Optional[str] = "Rendering audio...") -> None:
"""
Render an audio by requesting frames asynchronously using audio.frames.
Implementation-like of VideoNode.output for an AudioNode that isn't in the Cython side yet.
:param audio: Audio to render.
:param outfile: Render output BinaryIO handle.
:param header: Kind of Wave header.
WaveHeader.AUTO adds a Wave64 header if the audio
* Has more than 2 channels
* Has a bitdepth > 16
* Has more than 44100 samples
:param progress: String to use for render progress display.
If empty or ``None``, no progress display.
"""
if progress:
p = get_render_progress()
task = p.add_task(progress, total=audio.num_frames)
p.start()

bytes_per_output_sample = (audio.bits_per_sample + 7) // 8
block_align = audio.num_channels * bytes_per_output_sample
bytes_per_second = audio.sample_rate * block_align
data_size = audio.num_samples * block_align

if header == WaveHeader.AUTO:
conditions = (audio.num_channels > 2, audio.bits_per_sample > 16, audio.num_samples > 44100)
header_func, use_w64 = (_w64_header, WaveHeader.WAVE64) if any(conditions) else (_wav_header, WaveHeader.WAVE)
else:
use_w64 = header
header_func = (_wav_header, _w64_header)[header]

outfile.write(header_func(audio, bytes_per_second, block_align, data_size))

for f in audio.frames(close=True):
if progress:
p.update(task, advance=1) # type: ignore
_finish_frame_audio(f, outfile, audio.bits_per_sample == 24)
# Determine file size and place the value at the correct position
# at the beginning of the file
size = outfile.tell()
if use_w64:
outfile.seek(16)
outfile.write(struct.pack('<Q', size))
else:
outfile.seek(4)
outfile.write(struct.pack('<I', size - 8))
if progress:
p.stop() # type: ignore


def _wav_header(audio: vs.AudioNode, bps: int, block_align: int, data_size: int) -> bytes:
header = WAVE_RIFF_TAG
# Add 4 bytes for the length later
header += b'\x00\x00\x00\x00'
header += WAVE_WAVE_TAG

header += WAVE_FMT_TAG
format_tag = WaveFormat.IEEE_FLOAT if audio.sample_type == vs.FLOAT else WaveFormat.PCM

fmt_chunk_data = struct.pack(
'<HHIIHH', format_tag, audio.num_channels, audio.sample_rate,
bps, block_align, audio.bits_per_sample
)
header += struct.pack('<I', len(fmt_chunk_data))
header += fmt_chunk_data

if len(header) + data_size > 0xFFFFFFFE:
raise ValueError('Data exceeds wave file size limit')

header += WAVE_DATA_TAG
header += struct.pack('<I', data_size)
return header


def _w64_header(audio: vs.AudioNode, bps: int, block_align: int, data_size: int) -> bytes:
# RIFF-GUID
header = bytes(WAVE64_RIFF_UUID)
# Add 8 bytes for the length later
header += b'\x00\x00\x00\x00\x00\x00\x00\x00'
# WAVE-GUID
header += bytes(WAVE64_WAVE_UUID)
# FMT-GUID
fmt_guid = bytes(WAVE64_FMT_UUID)
header += fmt_guid

# We only support WAVEFORMATEXTENSIBLE for WAVE64 header
format_tag = WaveFormat.EXTENSIBLE

# cb_size should be 22 for WAVEFORMATEXTENSIBLE with PCM
cb_size = 22
fmt_chunk_data = struct.pack(
'<HHIIHHHHI', format_tag, audio.num_channels, audio.sample_rate,
bps, block_align, audio.bits_per_sample, cb_size,
audio.bits_per_sample, # valid bit per sample
audio.channel_layout
)
# Add the subformat GUID, first 2 bytes have format type, 1 being PCM
fmt_chunk_data += bytes(WAVE_FMT_EXTENSIBLE_SUBFORMAT[audio.sample_type])

# Add the FMT size
# Length of the FMT-GUID + length of FMT data and 8 for the bytes themself
header += struct.pack('<Q', len(fmt_guid) + 8 + len(fmt_chunk_data))
header += fmt_chunk_data

# DATA-GUID
data_uuid = bytes(WAVE64_DATA_UUID)
header += data_uuid
header += struct.pack('<Q', data_size + len(data_uuid) + 8)
return header


def _finish_frame_audio(frame: vs.AudioFrame, outfile: BinaryIO, _24bit: bool) -> None:
# For some reason f[i] is faster than list(f) or just passing f to stack
data = np.stack([frame[i] for i in range(frame.num_channels)], axis=1) # type: ignore

if _24bit:
if data.ndim == 1:
# Convert to a 2D array with a single column
data.shape += (1, )
# Data values are stored in 32 bits so we must convert them to 24 bits
# Then by shifting first 0 bits, then 8, then 16, the resulting output is 24 bit little-endian.
data = ((data // 2 ** 8).reshape(data.shape + (1, )) >> np.array([0, 8, 16])) # type: ignore [attr-defined]
outfile.write(data.ravel().astype(np.uint8).tobytes())
else:
outfile.write(data.ravel().view(np.int8).tobytes()) # type: ignore


class SceneChangeMode(Enum):
WWXD = 0
SCXVID = 1
WWXD_SCXVID_UNION = 2
WWXD_SCXVID_INTERSECTION = 3


def find_scene_changes(clip: vs.VideoNode, mode: SceneChangeMode = SceneChangeMode.WWXD) -> List[int]: # noqa: C901
"""
Generate a list of scene changes (keyframes).
Dependencies:
* vapoursynth-wwxd
* vapoursynth-scxvid (Optional: scxvid mode)
:param clip: Clip to search for scene changes. Will be rendered in its entirety.
:param mode: Scene change detection mode:
* WWXD: Use wwxd
* SCXVID: Use scxvid
* WWXD_SCXVID_UNION: Union of wwxd and sxcvid (must be detected by at least one)
* WWXD_SCXVID_INTERSECTION: Intersection of wwxd and scxvid (must be detected by both)
:return: List of scene changes.
"""
frames: List[int] = []
clip = clip.resize.Bilinear(640, 360, format=vs.YUV420P8)

if mode in (SceneChangeMode.WWXD, SceneChangeMode.WWXD_SCXVID_UNION, SceneChangeMode.WWXD_SCXVID_INTERSECTION):
clip = clip.wwxd.WWXD()
if mode in (SceneChangeMode.SCXVID, SceneChangeMode.WWXD_SCXVID_UNION, SceneChangeMode.WWXD_SCXVID_INTERSECTION):
clip = clip.scxvid.Scxvid()

def _cb(n: int, f: vs.VideoFrame) -> None:
if mode == SceneChangeMode.WWXD:
if Properties.get_prop(f, "Scenechange", int) == 1:
frames.append(n)
elif mode == SceneChangeMode.SCXVID:
if Properties.get_prop(f, "_SceneChangePrev", int) == 1:
frames.append(n)
elif mode == SceneChangeMode.WWXD_SCXVID_UNION:
if Properties.get_prop(f, "Scenechange", int) == 1 or Properties.get_prop(f, "_SceneChangePrev", int) == 1:
frames.append(n)
elif mode == SceneChangeMode.WWXD_SCXVID_INTERSECTION:
if Properties.get_prop(f, "Scenechange", int) == 1 and Properties.get_prop(f, "_SceneChangePrev", int) == 1:
frames.append(n)

clip_async_render(clip, progress="Detecting scene changes...", callback=_cb)

return sorted(frames)

0 comments on commit a334b2c

Please sign in to comment.