<a href="https://colab.research.google.com/github/Jerry-at-GH/fansub-utils/blob/main/auto_timing.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# 自动打轴

- 输入的译稿需要符合如下格式（如无需中文，`\N`后留空即可）：
    ```
    日日日日？\N中中？
    日日日\N中中中（注释以左括号开头
    （注释以左括号开头
    日日 日日日\N中中 中中中
    ```
- 运行结束后，会自动下载输出文件 `fa+vad.ass`
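- 设计目标是尽量贴紧实际说话时间，未做前后留白
- 单独一行的 `OP`、`ED` 以及以左括号开头的注释行不参与对齐，会以 `Top` 样式、约 1 秒的时长放在相邻台词附近
- 输出中 Effect 栏标记为 `!` 的行是被自动修正过的行（起止时间颠倒，或时长过短被前后各延长 200 ms），建议人工检查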

In [None]:
# @title  {"run":"auto","display-mode":"form"}
VIDEO_PATH = "" # @param {"type":"string","placeholder":"如果文件在 Drive 内：/content/drive/MyDrive/..."}
TRANSCRIPT_PATH = "" # @param {"type":"string","placeholder":"如果文件在 Drive 内：/content/drive/MyDrive/..."}
# VIDEO_PATH: input video/audio; it is converted to 16 kHz mono for VAD/alignment and is also
#   passed directly to the vocal separator.
# TRANSCRIPT_PATH: UTF-8 transcript, one subtitle per line, in the format described in the notes above.

## 请运行下列所有代码

### Preparation

In [None]:
from google.colab import drive
# Mount Google Drive so VIDEO_PATH / TRANSCRIPT_PATH may point under /content/drive/MyDrive/.
drive.mount('/content/drive')

In [None]:
!apt update
!apt install libc++1

In [None]:
# Shallow-clone NeMo only for its forced-aligner script (tools/nemo_forced_aligner/align.py, used below).
# NOTE(review): `git clone` reports an error (harmlessly) on re-runs if ./NeMo already exists.
!git clone https://github.com/NVIDIA/NeMo --depth=1
!uv pip install pysubs2 "nemo_toolkit[asr]" "audio-separator[gpu]"
# TEN VAD is installed from its Git repository.
!uv pip install -U --force-reinstall -v git+https://github.com/TEN-framework/ten-vad.git
# Pin numpy below 2.0 (presumably for compatibility with the packages above — TODO confirm).
# NOTE(review): if numpy was already imported in this session, a runtime restart is presumably
# required for the reinstalled version to take effect.
!uv pip install "numpy<2.0.0" scipy --force-reinstall

In [None]:
!ffmpeg -i $VIDEO_PATH -ac 1 -ar 16000 mono_16k.wav

### VAD

In [None]:
import numpy as np
import pysubs2
from scipy.signal import resample


AUDIO = "mono_16k.wav"  # 16 kHz mono mix written by the ffmpeg cell above
VOCALS_AUDIO = "vocals_mono_16k.wav"  # separated-vocals track, created inside WORK_DIR below
WORK_DIR = "vad"
FRAME_HOP_SEC = 0.01  # seconds per VAD frame (10 ms); used for every frame <-> time conversion in this cell
# Recreate an empty work directory on every run.
!rm -rf $WORK_DIR
!mkdir $WORK_DIR


def zeroify_period(signal, start_sec, end_sec, hop_sec=None):
    """Set the frames of ``signal`` from ``start_sec`` to ``end_sec`` (both inclusive) to 0, in place.

    Negative times count back from the end of the signal. Times are converted to frame
    indices by dividing by the hop and truncating.

    Args:
        signal: array supporting scalar slice assignment (e.g. a numpy array of per-frame values).
        start_sec: first time to clear, in seconds (negative = relative to the end).
        end_sec: last time to clear, in seconds (negative = relative to the end); its frame is included.
        hop_sec: seconds per frame; ``None`` (default) uses the module-level ``FRAME_HOP_SEC``.
    """
    hop = FRAME_HOP_SEC if hop_sec is None else hop_sec

    def _to_index(sec):
        frames = sec / hop
        position = frames if sec >= 0 else len(signal) + frames
        # The epsilon absorbs binary floating-point error before truncation:
        # 0.29 / 0.01 evaluates to 28.999999999999996, which plain int() turns into 28.
        return int(position + 1e-9)

    signal[_to_index(start_sec) : _to_index(end_sec) + 1] = 0


def bias_toward_half(x, k=0.4):
    """Map ``x`` in [0, 1] through the symmetric curve ``x**k / (x**k + (1 - x)**k)``.

    For k > 0 the curve keeps 0, 0.5 and 1 fixed. With k < 1 (the default 0.4) values are
    pulled toward 0.5; with k > 1 they are pushed toward 0 or 1 instead (a sharpening).
    Works element-wise on numpy arrays as well as on scalars.
    """
    head = x**k
    tail = (1 - x) ** k
    return head / (head + tail)


def pred_to_vad(pred, onset, offset, min_on_sec=0.0, min_off_sec=0.0, hop_sec=None):
    """Binarize per-frame speech scores into speech segments.

    Steps:
      1. Hysteresis: outside speech, a frame with ``val >= onset`` starts speech; inside
         speech, frames with ``val >= offset`` continue it and the first frame with
         ``val < offset`` ends it (that frame is not re-tested against ``onset``).
      2. Segments whose span ``end - start`` (in frames) is below ``min_on_sec`` are dropped.
      3. Consecutive kept segments with ``next_start - prev_end`` below ``min_off_sec``
         are merged into one.

    Args:
        pred: 1-D sequence of per-frame scores.
        onset: threshold for entering speech.
        offset: threshold for staying in speech.
        min_on_sec: minimum segment span to keep, in seconds.
        min_off_sec: gaps shorter than this (in seconds) are bridged.
        hop_sec: seconds per frame; ``None`` (default) uses the module-level ``FRAME_HOP_SEC``.

    Returns:
        ``(merged, vad)``: ``merged`` is a list of ``(start_frame, end_frame)`` tuples with
        inclusive ends; ``vad`` is a boolean array shaped like ``pred`` that is True on
        exactly those frames.
    """
    hop = FRAME_HOP_SEC if hop_sec is None else hop_sec
    min_on_frames = round(min_on_sec / hop)
    min_off_frames = round(min_off_sec / hop)

    # hysteresis threshold
    vad_labels = np.zeros_like(pred, dtype=bool)
    in_speech = False
    for i, val in enumerate(pred):
        if not in_speech:
            if val >= onset:
                in_speech = True
                vad_labels[i] = True
        else:
            if val < offset:
                in_speech = False
            else:
                vad_labels[i] = True

    def get_segments(labels):
        """Return contiguous True runs of ``labels`` as inclusive ``(start, end)`` index pairs."""
        segs = []
        start = None
        for i, val in enumerate(labels):
            if val and start is None:
                start = i
            elif not val and start is not None:
                segs.append((start, i - 1))
                start = None
        if start is not None:
            segs.append((start, len(labels) - 1))
        return segs

    # remove short speech segments
    filtered = [(start, end) for start, end in get_segments(vad_labels) if end - start >= min_on_frames]

    # merge short non-speech gaps
    merged = []
    for seg in filtered:
        if merged and seg[0] - merged[-1][1] < min_off_frames:
            merged[-1] = (merged[-1][0], seg[1])
        else:
            merged.append(seg)

    vad = np.zeros_like(vad_labels, dtype=bool)
    for s, e in merged:
        vad[s : e + 1] = True

    return merged, vad


def to_srt(merged, srt_file):
    """Save ``(start_frame, end_frame)`` segments to ``srt_file`` as placeholder subtitles.

    Each segment becomes one event whose text is "???"; frame indices are converted to
    milliseconds with ``FRAME_HOP_SEC``. The file is written by ``pysubs2.SSAFile.save``.
    """

    def frame_to_ms(frame):
        return frame * 1000 * FRAME_HOP_SEC

    track = pysubs2.SSAFile()
    for first_frame, last_frame in merged:
        event = pysubs2.SSAEvent(start=frame_to_ms(first_frame), end=frame_to_ms(last_frame), text="???")
        track.append(event)
    track.save(srt_file)


# ten, 10ms/frame
from ten_vad import TenVad
import scipy.io.wavfile as Wavfile

# Score the 16 kHz mix with TEN VAD in non-overlapping 160-sample (10 ms) hops.
sr, data = Wavfile.read(AUDIO)
hop_size = 160  # samples per 10 ms frame at 16 kHz
threshold = 0.5
ten_vad_instance = TenVad(hop_size, threshold)
num_frames = data.shape[0] // hop_size
frame_scores = []
for frame_no in range(num_frames):
    chunk = data[frame_no * hop_size : (frame_no + 1) * hop_size]
    # Keep the first element of TenVad.process's result for each hop.
    frame_scores.append(ten_vad_instance.process(chunk)[0])
# The first hop's score is dropped (presumably to re-align TEN VAD's output with frame timing — TODO confirm).
ten = np.array(frame_scores)[1:]
# np.save(f"{WORK_DIR}/ten.npy", ten)


# heuristic
from audio_separator.separator import Separator
import librosa
import shutil


# Vocal separation: split the original input into vocals / instrumental stems and build an
# energy-based speech heuristic from the vocals only (presumably so background music and
# effects do not read as speech — TODO confirm).
separator = Separator()
separator.load_model(model_filename="mel_band_roformer_vocals_fv4_gabox.ckpt")
# The dict renames the output stems; the moves below assume "vocals.wav" and "inst.wav" were
# written to the current directory (`output_files` itself is not used).
output_files = separator.separate(globals().get('VIDEO_PATH', None), {"Vocals": "vocals", "Instrumental": "inst"})
shutil.move("vocals.wav", f"{WORK_DIR}/vocals.wav")
shutil.move("inst.wav", f"{WORK_DIR}/inst.wav")
# Downmix/resample the vocals to 16 kHz mono: WORK_DIR/VOCALS_AUDIO.
!ffmpeg -i $WORK_DIR/vocals.wav -ac 1 -ar 16000 $WORK_DIR/vocals_mono_16k.wav

# No sr= argument, so librosa resamples to its default rate (22050 Hz) rather than keeping
# 16 kHz. Both features are resampled to len(ten) frames below, so the final arrays line up
# with the TEN VAD frames regardless of this rate.
y, sr = librosa.load(f"{WORK_DIR}/{VOCALS_AUDIO}")

# Frame energy, resampled onto the TEN VAD grid and clipped to [0, original max]
# (presumably to remove Fourier-resampling overshoot).
rms = librosa.feature.rms(y=y)[0]
rms_downsampled = np.clip(resample(rms, len(ten)), 0, rms.max())

# Spectral roll-off, resampled and clipped the same way. The +0.1 offset on the waveform is
# presumably there to keep roll-off well-defined on digitally silent frames — TODO confirm.
rolloff = librosa.feature.spectral_rolloff(y=y + 0.1, sr=sr)[0]
rolloff_downsampled = np.clip(resample(rolloff, len(ten)), 0, rolloff.max())

# Heuristic speech score: energy x roll-off, min-max normalized to [0, 1].
# NOTE(review): divides by zero if the product is constant (e.g. completely silent vocals).
heuristic = rms_downsampled * rolloff_downsampled
heuristic = (heuristic - heuristic.min()) / (heuristic.max() - heuristic.min())
# np.save(f"{WORK_DIR}/heuristic.npy", heuristic)


# ensemble: average the TEN VAD score with a sharpened heuristic score, then binarize.
# The heuristic window [0.005, 0.105] is rescaled to [0, 1]; bias_toward_half with k=3
# then pushes those values toward 0 or 1.
heuristic_score = bias_toward_half(np.clip((heuristic - 0.005), 0, 0.1) / 0.1, 3)
pred = np.stack([ten, heuristic_score], axis=1).mean(axis=1)
pred[pred < 0] = 0
# NOTE(review): onset (0.4) is below offset (0.5), the reverse of the usual hysteresis
# ordering, so speech started at 0.4-0.5 ends on the very next frame — confirm this is intended.
merged, vad = pred_to_vad(pred, 0.4, 0.5, min_on_sec=0.03)
to_srt(merged, "vad.srt")

### Forced alignment

In [None]:
import json
import re
import pysubs2


AUDIO = r"mono_16k.wav"  # the same 16 kHz mono mix used by the VAD cell
WORK_DIR = "nfa"
# Recreate an empty work directory on every run.
!rm -rf $WORK_DIR
!mkdir $WORK_DIR


# Characters kept for alignment: ASCII letters/digits, hiragana, katakana and CJK ideographs.
WORD_CHARS = "A-Za-z0-9\u3040-\u309f\u30a0-\u30ff\u3400-\u4dbf\u4e00-\u9fff\uf900-\ufaff"
_NON_WORD_RE = re.compile(rf"[^{WORD_CHARS}]+")
# A note starts at a half- or full-width left parenthesis and runs to the end of the text.
_NOTE_START_RE = re.compile(r"[(（]")


def _alignment_text(raw):
    """Return the part of transcript line ``raw`` to force-align, as space-separated words.

    Only the source-language part (before the first literal ``\\N``) is used, and it is cut
    at the first left parenthesis so notes are never aligned. Returns "" for lines that are
    exactly ``OP``/``ED``, lines that start with a note, and lines with no word characters.
    """
    if raw in ("OP", "ED"):
        return ""
    source = raw.split(r"\N")[0]
    # Only a parenthesis in the source part matters: a note inside the translation
    # (e.g. "日日日\N中中中（注释") must not stop "日日日" from being aligned.
    spoken = _NOTE_START_RE.split(source, maxsplit=1)[0]
    return _NON_WORD_RE.sub(" ", spoken).strip()


with open(TRANSCRIPT_PATH, "r", encoding="utf-8") as f:
    raw_lines = [l.strip() for l in f]

# One dict per non-blank line:
#   raw       - the stripped line, emitted unchanged as subtitle text
#   processed - text sent to the aligner ("" = not aligned)
#   match     - processed with spaces removed, consumed token by token during matching
lines = []
for raw in raw_lines:
    if not raw:
        continue
    processed = _alignment_text(raw)
    lines.append({"raw": raw, "processed": processed, "match": _NON_WORD_RE.sub("", processed)})


# NFA manifest: one JSON line describing a single utterance. Aligned transcript lines are
# joined with "|", the separator NFA is told to split segments on
# (additional_segment_grouping_separator below).
manifest_filepath = f"{WORK_DIR}/manifest.json"
aligned_text = "|".join(l["processed"] for l in lines if l["processed"])
manifest_record = json.dumps({"audio_filepath": AUDIO, "text": aligned_text})
with open(manifest_filepath, "w") as f:
    f.write(f"{manifest_record}\n")


# Run NeMo Forced Aligner with the Japanese Parakeet TDT-CTC model on the manifest above.
# The code below only reads the token-level CTM (nfa_output/ctm/tokens/*.ctm); the
# ass_file_config.* options presumably only style NFA's own .ass outputs.
!python /content/NeMo/tools/nemo_forced_aligner/align.py \
  pretrained_name="nvidia/parakeet-tdt_ctc-0.6b-ja" \
  manifest_filepath=$manifest_filepath \
  output_dir=$WORK_DIR/nfa_output/ \
  additional_segment_grouping_separator="|" \
  ass_file_config.vertical_alignment="bottom" \
  ass_file_config.text_already_spoken_rgb="[66,245,212]" \
  ass_file_config.text_being_spoken_rgb="[242,222,44]" \
  ass_file_config.text_not_yet_spoken_rgb="[223,242,239]"


# Parse NFA's token-level CTM. Each row becomes
#   [start_sec, duration_sec, display_token, match_token]
# where display_token shows "<b>" as "^" (used in the debug karaoke text) and match_token
# keeps only WORD_CHARS, so blanks and punctuation become "".
ctm_filepath = f"{WORK_DIR}/nfa_output/ctm/tokens/{AUDIO.split('/')[-1].split('.')[0]}.ctm"
with open(ctm_filepath, "r", encoding="utf-8") as f:
    ctm_lines = f.read().split("\n")
seg = []
for c in ctm_lines:
    if not c:
        continue
    fields = c.split(" ")
    raw_token = fields[4]
    seg.append(
        [
            float(fields[2]),
            float(fields[3]),
            raw_token.replace("<b>", "^"),
            re.sub(rf"[^{WORD_CHARS}]+", "", raw_token.replace("<b>", "")),
        ]
    )

# Walk the tokens in order, consuming each aligned line's "match" text; when a line's text
# is fully consumed, the tokens from start_idx to the current one form that line.
start_idx = 0
line_iter = iter([line for line in lines if line["match"]])
try:
    current_line = next(line_iter)
    match = current_line["match"]
    for idx, segment in enumerate(seg):
        token = segment[3]
        # Consume the token only as an exact prefix. str.lstrip(token) would treat it as a
        # character *set*: for repeated characters ("ここ" with token "こ") it also eats the
        # next token's text and closes the line one token early.
        if match.startswith(token):
            match = match[len(token):]
        if match != "":
            continue

        # First token with real text at or after start_idx opens the line.
        start_idx = next(i for i in range(start_idx, len(seg)) if seg[i][3])
        start = seg[start_idx][0] * 1000
        # End at the end of the row after the last matched token (presumably the blank "<b>"
        # NFA emits between tokens — TODO confirm), or at the last row at end of file.
        last = min(idx + 1, len(seg) - 1)
        end = (seg[last][0] + seg[last][1]) * 1000
        line_tokens = seg[start_idx : idx + 2]
        speaker = current_line.get("name", "")
        current_line["debug_ass"] = pysubs2.SSAEvent(
            start=start,
            end=end,
            text="".join([f"{{\\kf{round(seg_item[1] * 100)}}}{seg_item[2]}" for seg_item in line_tokens]),
            name=speaker,
        )
        # Tokens longer than 0.8 s presumably absorbed silence. Trim their excess from the
        # line, split between start and end by where those tokens sit (weighted by excess^2):
        # excess near the line start mostly delays the start, near the end mostly advances the end.
        span = idx + 1 - start_idx
        long_seg = [[(i / span) ** 1.5, s[1] - 0.8] for i, s in enumerate(line_tokens) if s[1] > 0.8]
        if long_seg:
            avg_position = sum(x[0] * (x[1] ** 2) for x in long_seg) / sum(x[1] ** 2 for x in long_seg)
            total_delta = sum(x[1] for x in long_seg) * 1000
            start = start + (1 - avg_position) * total_delta
            end = end - avg_position * total_delta
        current_line["res_ass"] = pysubs2.SSAEvent(start=start, end=end, text=current_line["raw"], name=speaker)
        # Skip the row that closed this line before searching for the next line's start.
        start_idx = idx + 2
        current_line = next(line_iter)
        match = current_line["match"]
except StopIteration:
    pass  # every line with alignable text has been timed

# Tokens ran out before every line was matched (e.g. the aligner's tokens differ from the text).
unaligned = [l["raw"] for l in lines if l["match"] and "res_ass" not in l]
if unaligned:
    print(f"WARNING: {len(unaligned)} line(s) could not be matched to the alignment; first: {unaligned[0]}")
# Output file with two styles: "Default" (bottom-centre) for aligned dialogue and "Top"
# (same look, top-centre) for lines that were not aligned; 1920x1080 script resolution.
subs = pysubs2.SSAFile()
subs.styles["Default"] = pysubs2.SSAStyle(
    fontname="IPAexGothic",
    fontsize=40,
    primarycolor=pysubs2.Color(r=255, g=255, b=255, a=0),
    secondarycolor=pysubs2.Color(r=0, g=213, b=255, a=0),
    outlinecolor=pysubs2.Color(r=0, g=0, b=0, a=0),
    backcolor=pysubs2.Color(r=74, g=74, b=74, a=0),
    bold=True,
    alignment=pysubs2.Alignment.BOTTOM_CENTER,
    shadow=0,
)
subs.styles["Top"] = subs.styles["Default"].copy()
subs.styles["Top"].alignment = pysubs2.Alignment.TOP_CENTER
subs.info["PlayResX"] = "1920"
subs.info["PlayResY"] = "1080"
# Debug output: raw token timings as \kf karaoke. Lines the matcher never reached have no
# "debug_ass" event; skip them instead of failing with KeyError.
subs.events = [l["debug_ass"] for l in lines if l["processed"] and "debug_ass" in l]
subs.save(f"{WORK_DIR}/debug.ass")
# Final events in transcript order. Lines with an aligned "res_ass" event keep it. Every other
# line (comments, OP/ED markers, and any dialogue the matcher could not place) gets a nominal
# 1-second "Top"-style event, positioned relative to its nearest timed neighbours.
subs.events = []
is_timed = ["res_ass" in l for l in lines]
for idx, l in enumerate(lines):
    if is_timed[idx]:
        subs.events.append(l["res_ass"])
        continue
    prev_event = next((lines[i]["res_ass"] for i in range(idx - 1, -1, -1) if is_timed[i]), None)
    next_event = next((lines[i]["res_ass"] for i in range(idx + 1, len(lines)) if is_timed[i]), None)
    if prev_event and next_event:
        # Centre a 1 s slot on the midpoint between the two neighbours.
        x = (prev_event.end + next_event.start) / 2.0
        start = x - 500
        end = x + 500
    elif next_event:
        start = next_event.start - 1000
        end = next_event.start
    elif prev_event:
        start = prev_event.end
        end = prev_event.end + 1000
    else:
        start = 0
        end = 1000
    # Dialogue that should have been aligned but was not is flagged with effect "!" for manual timing.
    effect = "!" if l["processed"] else ""
    subs.events.append(pysubs2.SSAEvent(start=start, end=end, text=l["raw"], style="Top", effect=effect))
subs.save("fa.ass")

### VAD-assisted refinement

In [None]:
import os
import pysubs2

SUBS = r"fa.ass"  # forced-alignment output written by the previous cell
VAD = r"vad.srt"  # VAD segments written by the VAD cell

# all in milliseconds
THRESHOLD_GAP = 40  # min gap before/after a line to be considered a true gap (for start_with_gap/end_with_gap)
THRESHOLD_SAME_TIME = 40  # min difference between two timestamps to be considered different
# Search windows as [offset_before, offset_after] relative to the boundary being adjusted:
WINDOW_START_WITH_GAP = [-1800, 1800]  # isolated starts snap to the nearest VAD segment start in this window
WINDOW_END_WITH_GAP = [-1800, 1800]  # isolated ends snap to the nearest VAD segment end in this window
WINDOW_ADJOINT = [-800, 800]  # back-to-back boundaries move to a VAD gap overlapping this window


def IS_DIALOGUE(e):
    """Return True for events in the "Default" style; those are the ones refined against VAD.

    Events in any other style (e.g. "Top" comment lines) are passed through untouched.
    """
    style_name = e.style
    return style_name == "Default"


def has_overlap(start1, end1, start2, end2):
    """Return True if closed intervals [start1, end1] and [start2, end2] intersect.

    Touching endpoints count as overlap. Neither interval is required to be well-ordered;
    the test is simply that neither one ends before the other starts.
    """
    return not (end1 < start2 or end2 < start1)


subs = pysubs2.load(SUBS)
subs.sort()
# Partition (keeping sorted order) into dialogue, which is adjusted below, and everything
# else, which is written back unchanged.
subs_dialog, subs_other = [], []
for ev in subs.events:
    (subs_dialog if IS_DIALOGUE(ev) else subs_other).append(ev)

vad = pysubs2.load(VAD)
vad.sort()

# Dialogue lines whose start is isolated: no other dialogue event overlaps the
# THRESHOLD_GAP ms just before it. Only these starts are snapped to VAD.
start_with_gap = []
for i, e in enumerate(subs_dialog):
    lead_in = (e.start - THRESHOLD_GAP, e.start)
    crowded = any(
        has_overlap(other.start, other.end, *lead_in)
        for j, other in enumerate(subs_dialog)
        if j != i
    )
    if not crowded:
        start_with_gap.append((i, e.start))

count = 0
lo_off, hi_off = WINDOW_START_WITH_GAP
for i, this_start in start_with_gap:
    # Move the start to the VAD segment start closest to it, strictly inside the window.
    nearby = [v.start for v in vad.events if this_start + lo_off < v.start < this_start + hi_off]
    if not nearby:
        continue
    subs_dialog[i].start = min(nearby, key=lambda t: abs(t - this_start))
    count += 1
print(f"start_with_gap: adjusted {count}/{len(start_with_gap)} boundaries")


# Dialogue lines whose end is isolated: no other dialogue event overlaps the
# THRESHOLD_GAP ms just after it. Only these ends are snapped to VAD.
end_with_gap = []
for i, e in enumerate(subs_dialog):
    lead_out = (e.end, e.end + THRESHOLD_GAP)
    crowded = any(
        has_overlap(other.start, other.end, *lead_out)
        for j, other in enumerate(subs_dialog)
        if j != i
    )
    if not crowded:
        end_with_gap.append((i, e.end))

count = 0
lo_off, hi_off = WINDOW_END_WITH_GAP
for i, this_end in end_with_gap:
    # Move the end to the VAD segment end closest to it, strictly inside the window.
    nearby = [v.end for v in vad.events if this_end + lo_off < v.end < this_end + hi_off]
    if not nearby:
        continue
    subs_dialog[i].end = min(nearby, key=lambda t: abs(t - this_end))
    count += 1
print(f"end_with_gap: adjusted {count}/{len(end_with_gap)} boundaries")


# Back-to-back pairs: dialogue i ends within THRESHOLD_SAME_TIME ms of dialogue j's start.
adjoint = []
for i, e_i in enumerate(subs_dialog):
    for j, e_j in enumerate(subs_dialog):
        if i != j and abs(e_i.end - e_j.start) < THRESHOLD_SAME_TIME:
            adjoint.append((i, j, e_i.end))

# Silences between consecutive VAD segments, as (previous end, next start).
vad_gaps = [(prev.end, nxt.start) for prev, nxt in zip(vad.events, vad.events[1:])]

count = 0
win_lo, win_hi = WINDOW_ADJOINT
for i, j, this_boundary in adjoint:
    touching = [
        (gap_start, gap_end)
        for gap_start, gap_end in vad_gaps
        if has_overlap(gap_start, gap_end, this_boundary + win_lo, this_boundary + win_hi)
    ]
    if not touching:
        continue
    # Among gaps touching the window, take the one whose start is closest to the boundary
    # (first one on ties); line i ends where the silence begins, line j starts where it ends.
    gap_start, gap_end = min(touching, key=lambda g: abs(g[0] - this_boundary))
    subs_dialog[i].end = gap_start
    subs_dialog[j].start = gap_end
    if gap_end - gap_start > 0:
        count += 1
print(f"adjoint: adjusted {count}/{len(adjoint)} boundaries")

# Resolve near-coincident boundaries between neighbours in subs_dialog order (sorted before the
# adjustments above). Order matters: each comparison sees the mutations made for earlier lines.
for i, e in enumerate(subs_dialog):
    # Starts within THRESHOLD_SAME_TIME of the previous line's start: start when the previous line ends.
    if i - 1 >= 0 and abs(e.start - subs_dialog[i - 1].start) < THRESHOLD_SAME_TIME:
        e.start = subs_dialog[i - 1].end
    # Ends within THRESHOLD_SAME_TIME of the previous line's end: the previous line ends when this one starts.
    if i - 1 >= 0 and abs(e.end - subs_dialog[i - 1].end) < THRESHOLD_SAME_TIME:
        subs_dialog[i - 1].end = e.start
# Inverted (or zero-length) events are swapped back and flagged with effect "!" for manual review.
for e in subs_dialog:
    if e.start >= e.end:
        e.effect = "!"
        e.start, e.end = e.end, e.start
# Events of 400 ms or less are padded by 200 ms on each side and flagged "!".
for e in subs_dialog:
    if e.duration <= 400:
        e.effect = "!"
        e.start -= 200
        e.end += 200
# Dialogue first, then the untouched non-dialogue events (no re-sort).
subs.events = subs_dialog + subs_other

# Saved next to SUBS as "<SUBS name without extension>+vad.ass", i.e. fa+vad.ass.
subs.save(os.path.join(os.path.dirname(SUBS), f"{'.'.join(os.path.basename(SUBS).split('.')[:-1])}+vad.ass"))

### Download the result

In [None]:
from google.colab import files

# Download the refined subtitles. The previous cell saved fa+vad.ass relative to the working
# directory; this absolute path assumes that directory is /content (Colab's default).
files.download("/content/fa+vad.ass")