# <center>Video Dubbing Full Pipeline</center>

In [1]:
import os 

import torch 
import torchaudio
import numpy as np
from moviepy import VideoFileClip, AudioFileClip

In [2]:
def extract_audio_from_mp4(video_path: str, target_sr: int = 16000, temp_dir='./temp-audios', delete_file=True) -> tuple[np.ndarray, int]:
    """Extract the audio track of a video as a single-channel waveform.

    MoviePy writes the track to a temporary 16-bit PCM WAV file at
    ``target_sr``; the file is then read back with torchaudio.

    Args:
        video_path: Path to the input video file.
        target_sr: Sample rate of the returned waveform, in Hz.
        temp_dir: Directory for the temporary WAV file; created if missing.
        delete_file: Remove the temporary WAV file before returning.

    Returns:
        ``(waveform, sample_rate)``: ``waveform`` is always a 1-D array
        (channels are averaged) and ``sample_rate`` equals ``target_sr``.

    Raises:
        ValueError: If the video has no audio track.
    """
    os.makedirs(temp_dir, exist_ok=True)
    temp_audio_path = os.path.join(temp_dir, "temp_audio.wav")

    video = VideoFileClip(video_path)
    try:
        audio: AudioFileClip = video.audio
        if audio is None:
            raise ValueError(f"No audio track found in {video_path!r}")
        audio.write_audiofile(temp_audio_path, codec='pcm_s16le', fps=target_sr)
    finally:
        # Release the reader processes/handles held by the clip.
        video.close()

    try:
        audio_data, sr = torchaudio.load(temp_audio_path)
    finally:
        # Clean up even when loading fails.
        if delete_file and os.path.exists(temp_audio_path):
            os.remove(temp_audio_path)

    if sr != target_sr:
        audio_data = torchaudio.transforms.Resample(orig_freq=sr, new_freq=target_sr)(audio_data)

    # The loaded tensor is (channels, samples). Averaging over dim 0 downmixes
    # multi-channel audio and also drops the channel axis of mono audio, so the
    # result is 1-D in both cases (previously mono stayed (1, samples)).
    audio_data = audio_data.mean(dim=0)

    # After the optional resampling above the data is at target_sr.
    return audio_data.numpy(), target_sr

In [3]:
# Input video to dub.
# NOTE(review): machine-specific absolute path — adjust before running elsewhere.
video_path = "/home/maksim/Repos/video-dubbing/test-videos/videoplayback.mp4"

# Waveform of the video's audio track and its sample rate (target_sr defaults to 16000).
audio, sr = extract_audio_from_mp4(video_path)

MoviePy - Writing audio in ./temp-audios/temp_audio.wav


                                                                      

MoviePy - Done.


In [4]:
from dataclasses import dataclass, field
import numpy as np


@dataclass
class DubbingSegment:
    start: int = field(repr=True) # Начало сегмента (индекс в исходном аудио)
    end: int = field(repr=True) # Конец сегмента
    audio: np.ndarray = field(repr=False) # Аудио сегмент

    transcription: str = field(repr=True, default=None)
    translation: str = field(repr=True, default=None)

    tts_wav: np.ndarray = field(repr=False, default=None)


@dataclass
class ProcessingContext:
    """State object that each processor receives and returns."""

    original_audio: np.ndarray = field(repr=False)
    sample_rate: int
    temp_dir: str

    segments: list[DubbingSegment] = field(repr=False, default_factory=list)

    speech_audio: np.ndarray | None = field(repr=False, default=None)

    # Correspondence of time intervals between original_audio and speech_audio.
    timestamps_mapping: dict[tuple[int, int], DubbingSegment] | None = field(repr=False, default=None)

In [5]:
from abc import ABC, abstractmethod


class BaseProcessor(ABC):
    """Common interface of a pipeline stage: a callable that takes a context and returns one."""

    @abstractmethod
    def __call__(self, context: ProcessingContext) -> ProcessingContext:
        """Run this stage on ``context`` and return the resulting context."""
        ...


class VADProcessor(BaseProcessor):
    """Marker base class for voice-activity-detection stages."""


class ASRProcessor(BaseProcessor):
    """Marker base class for speech-recognition stages."""


class MTProcessor(BaseProcessor):
    """Marker base class for machine-translation stages."""


class TTSProcessor(BaseProcessor):
    """Marker base class for text-to-speech stages."""

In [6]:
import torch


class SileroVADProcessor(VADProcessor):
    """Voice-activity-detection stage using Silero VAD loaded through ``torch.hub``.

    Splits ``context.original_audio`` into speech segments, concatenates them
    into ``context.speech_audio`` and records which interval of the
    concatenated audio belongs to which segment.
    """

    def __init__(self, model_path: str = "", threshold: float = 0.5,  
                  min_silence_duration_ms=1000, 
                  min_speech_duration_ms=1000):
        """
        To load the model locally, download it first: git clone <silerovad repo>
        Then pass the path to the root of the cloned repository as model_path.

        Args:
            model_path: Root of a local clone; empty string loads from GitHub.
            threshold: Forwarded to ``get_speech_timestamps``.
            min_silence_duration_ms: Forwarded to ``get_speech_timestamps``.
            min_speech_duration_ms: Forwarded to ``get_speech_timestamps``.
        """
        if model_path:
            repo_or_dir, source = model_path, 'local'
        else:
            repo_or_dir, source = 'snakers4/silero-vad', 'github'

        self.silerovad, utils = torch.hub.load(repo_or_dir=repo_or_dir,
                                               model='silero_vad',
                                               force_reload=True,
                                               source=source)

        # Only the first helper returned by the hub entry point is used.
        self.get_speech_timestamps = utils[0]

        self.threshold = threshold
        self.min_silence_duration_ms = min_silence_duration_ms
        self.min_speech_duration_ms = min_speech_duration_ms


    def __call__(self, context: ProcessingContext) -> ProcessingContext:
        """Detect speech, append segments to the context and build the speech-only audio.

        Sets ``context.speech_audio`` (all segment audio concatenated, same
        dtype as the segments) and ``context.timestamps_mapping`` whose keys are
        half-open ``(start, end)`` sample intervals inside ``speech_audio``.
        """
        speech_timestamps = self.get_speech_timestamps(context.original_audio, 
                                                  self.silerovad, 
                                                  threshold=self.threshold, 
                                                  sampling_rate=context.sample_rate,
                                                  min_silence_duration_ms=self.min_silence_duration_ms, 
                                                  min_speech_duration_ms=self.min_speech_duration_ms)

        for ts in speech_timestamps:
            start = ts['start']
            end = ts['end']
            context.segments.append(DubbingSegment(start=start, 
                                                 end=end, 
                                                 audio=context.original_audio[start:end]))

        new2old = {}
        chunks = []
        offset = 0

        for segment in context.segments:
            # Segments are concatenated back to back, so the next interval
            # starts exactly where this one ends (no +1 gap, which would make
            # the mapping drift by one sample per segment).
            next_offset = offset + (segment.end - segment.start)
            new2old[(offset, next_offset)] = segment
            offset = next_offset

            chunks.append(segment.audio)

        if chunks:
            context.speech_audio = np.concatenate(chunks)
        else:
            # No speech detected: keep an empty array rather than failing.
            context.speech_audio = np.zeros(0, dtype=np.asarray(context.original_audio).dtype)
        context.timestamps_mapping = new2old

        return context 

In [8]:
# Fresh pipeline state built from the extracted audio; the processors below fill it in.
context = ProcessingContext(original_audio=audio, sample_rate=sr, temp_dir="./temp-dir")

In [9]:
# Local Silero VAD checkout, loaded with source='local'.
# NOTE(review): machine-specific absolute path — adjust before running elsewhere.
vad_path = "/home/maksim/Models/SileroVAD/snakers4-silero-vad"

vad_pipe = SileroVADProcessor(vad_path)

# Adds speech segments, speech_audio and timestamps_mapping to the context.
context = vad_pipe(context=context)

In [10]:
from faster_whisper import WhisperModel, BatchedInferencePipeline

class FasterWhisperProcessor(ASRProcessor):
    """Speech-recognition stage using faster-whisper with word-level timestamps.

    Transcribes ``context.speech_audio`` and appends every recognized word to
    the ``transcription`` of the segment whose interval in
    ``context.timestamps_mapping`` contains it.
    """

    # Sample rate the audio is resampled to before transcription.
    sr = 16_000

    def __init__(self, model_size_or_path: str = "tiny.en", device: str = "cpu", compute_type: str = "int8", batch_size=8):
        """Load the Whisper model.

        Args:
            model_size_or_path: Model name or local path, passed to ``WhisperModel``.
            device: Passed to ``WhisperModel``.
            compute_type: Passed to ``WhisperModel``.
            batch_size: Stored on the instance; not used by this class.
        """
        self.model = WhisperModel(model_size_or_path=model_size_or_path,
                                  device=device,
                                  compute_type=compute_type)
        self.batch_size = batch_size
        self.last_segment: DubbingSegment | None = None


    @staticmethod
    def _closest_segment(start: int, end: int, context: ProcessingContext):
        """Return the mapped segment that overlaps ``[start, end)`` most (or is nearest), or None."""
        best_segment = None
        best_overlap = None
        for (seg_start, seg_end), segment in context.timestamps_mapping.items():
            # Positive: overlap length; negative: minus the gap to the interval.
            overlap = min(end, seg_end) - max(start, seg_start)
            if best_overlap is None or overlap > best_overlap:
                best_segment, best_overlap = segment, overlap
        return best_segment


    def _put_word_in_segment(self, word: str, start: int, end: int, context: ProcessingContext):
        """Append ``word`` to the segment containing ``[start, end]``.

        ``start`` and ``end`` are sample indices into ``context.speech_audio``.
        A word that is not fully inside any interval goes to the segment that
        received the previous word; if there is none yet, to the closest
        interval. Returns the (mutated) context.
        """
        target = None

        for (seg_start, seg_end), candidate in context.timestamps_mapping.items():
            if start >= seg_start and end <= seg_end:
                target = candidate
                break

        if target is None:
            target = self.last_segment
        if target is None:
            # First word straddles a boundary: no previous segment to fall back on.
            target = self._closest_segment(start, end, context)
        if target is None:
            # Empty mapping: nowhere to put the word.
            return context

        target.transcription = (target.transcription or "") + word
        self.last_segment = target

        return context

    
    def __call__(self, context: ProcessingContext) -> ProcessingContext:
        """Transcribe the speech-only audio and distribute words over the segments.

        Raises:
            ValueError: If the context has no ``speech_audio`` or
                ``timestamps_mapping`` yet.
        """
        if context.speech_audio is None or context.timestamps_mapping is None:
            raise ValueError("speech_audio and timestamps_mapping must be set before ASR; "
                             "run a VAD processor first")

        # Do not carry a segment over from a previous context.
        self.last_segment = None

        speech_audio = context.speech_audio
        if len(speech_audio) == 0:
            return context

        if context.sample_rate != self.sr:
            resampler = torchaudio.transforms.Resample(orig_freq=context.sample_rate,
                                                       new_freq=self.sr)
            speech_audio = resampler(torch.as_tensor(speech_audio, dtype=torch.float32)).numpy()

        # Transcribe the (possibly resampled) local copy, not context.speech_audio.
        whisper_segments, _ = self.model.transcribe(speech_audio, word_timestamps=True)

        for whisper_segment in whisper_segments:
            for word in whisper_segment.words:
                # Word times are in seconds; the mapping keys are sample indices
                # at context.sample_rate, so convert with that rate.
                self._put_word_in_segment(word=word.word,
                                          start=int(word.start * context.sample_rate),
                                          end=int(word.end * context.sample_rate),
                                          context=context)

        return context

In [11]:
# Local faster-whisper model directory.
# NOTE(review): machine-specific absolute path — adjust before running elsewhere.
whisper_path = "/home/maksim/Models/FasterWhisper/tiny-en"

asr_pipe = FasterWhisperProcessor(whisper_path)

# Fills segment.transcription for the segments in the context.
context = asr_pipe(context=context)

In [12]:
# Spot check: transcription of the first segment.
context.segments[0].transcription

" What you're doing right now at this very moment is killing you."

In [13]:
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline

class HelsinkiEnRuProcessor(MTProcessor):
    """Machine-translation stage (English to Russian) using a Helsinki-NLP OPUS-MT model."""

    def __init__(self, model_path: str | None = None, device: str = 'cpu'):
        """Build the translation pipeline.

        Args:
            model_path: Local model directory; when empty, the
                ``Helsinki-NLP/opus-mt-en-ru`` checkpoint name is used instead.
            device: Passed to the ``transformers`` pipeline.
        """
        model_source = model_path if model_path else "Helsinki-NLP/opus-mt-en-ru"

        model = AutoModelForSeq2SeqLM.from_pretrained(model_source)
        tokenizer = AutoTokenizer.from_pretrained(model_source)

        self.pipe = pipeline(
            task="translation", 
            model=model, 
            tokenizer=tokenizer,
            device=device)


    def _process_sample(self, text_en: str) -> str:
        """Translate one English string and return the Russian text."""
        return self.pipe(text_en)[0]['translation_text']


    def __call__(self, context: ProcessingContext) -> ProcessingContext:
        """Set ``translation`` on every segment from its ``transcription``.

        Segments without any transcribed text get an empty translation instead
        of being sent to the model.
        """
        for segment in context.segments:
            text_en = segment.transcription
            if text_en and text_en.strip():
                segment.translation = self._process_sample(text_en)
            else:
                # Nothing was recognized in this segment.
                segment.translation = ""

        return context

In [14]:
# Local OPUS-MT en->ru model directory.
# NOTE(review): machine-specific absolute path — adjust before running elsewhere.
mt_path = "/home/maksim/Models/OpusEnRu"

mt_pipe = HelsinkiEnRuProcessor(model_path=mt_path)

# Fills segment.translation for every segment in the context.
context = mt_pipe(context=context)

Device set to use cpu


In [15]:
# Spot check: translation of the first segment.
print(context.segments[0].translation)

То, что ты сейчас делаешь, убивает тебя.


# TTS

In [16]:
import soundfile as sf

In [17]:
from TTS.api import TTS

class XTTSProcessor(TTSProcessor):
    """Text-to-speech stage using Coqui XTTS v2 with a reference speaker recording."""

    # Sample rate of the waveforms stored in segment.tts_wav.
    output_sample_rate = 24_000

    def __init__(self, target_spk: str | None = None, model_path: str | None = None, device: str = 'cpu', temp_dir="./temp-dir/tts/"):
        """Load the model.

        Args:
            target_spk: Path to a reference speaker WAV. When None, each
                segment's own audio is used as the reference.
            model_path: Local model directory (must contain ``config.json``);
                when empty, ``tts_models/multilingual/multi-dataset/xtts_v2`` is used.
            device: Device the model is moved to.
            temp_dir: Directory for temporary per-segment reference WAV files.
        """
        self.target_spk = target_spk

        if model_path:
            self.model = TTS(model_path=model_path, config_path=f'{model_path}/config.json').to(device)
        else:
            self.model = TTS("tts_models/multilingual/multi-dataset/xtts_v2").to(device)

        self.temp_dir = temp_dir


    def _process_sample(self, text_ru: str, speaker_wav: str) -> np.ndarray:
        """Synthesize ``text_ru`` in the voice of ``speaker_wav`` and return it as a float32 array."""
        wav = self.model.tts(text=text_ru, speaker_wav=speaker_wav, language='ru')
        # Convert to an array so the result matches the declared return type.
        return np.asarray(wav, dtype=np.float32)


    def __call__(self, context: ProcessingContext) -> ProcessingContext:
        """Set ``tts_wav`` on every segment that has a non-empty translation."""
        os.makedirs(self.temp_dir, exist_ok=True)

        for i, segment in enumerate(context.segments):
            if not segment.translation:
                # Nothing to synthesize; tts_wav is left untouched.
                continue

            if self.target_spk is not None:
                segment.tts_wav = self._process_sample(segment.translation, self.target_spk)
                continue

            audio_path = os.path.join(self.temp_dir, f"{i}.wav")

            # The segment audio is at the context's sample rate.
            sf.write(audio_path, segment.audio, context.sample_rate)
            try:
                segment.tts_wav = self._process_sample(segment.translation, audio_path)
            finally:
                # Remove the reference file even if synthesis fails.
                os.remove(audio_path)

        return context

In [20]:
class SileroTTSProcessor(TTSProcessor):
    """Text-to-speech stage using the Silero ``v4_ru`` model loaded through ``torch.hub``."""

    # Sample rate requested from apply_tts and of the waveforms in segment.tts_wav.
    output_sample_rate = 48_000

    def __init__(self, model_path: str | None = None, device: str = 'cpu', speaker: str = "xenia"):
        """Load the model.

        Args:
            model_path: Root of a local clone of the models repository; when
                empty, ``snakers4/silero-models`` is loaded from GitHub.
            device: Device the model is moved to.
            speaker: Voice name passed to ``apply_tts``.
        """
        if model_path:
            repo_or_dir, source = model_path, 'local'
        else:
            repo_or_dir, source = 'snakers4/silero-models', 'github'

        self.silero_tts, _ = torch.hub.load(repo_or_dir=repo_or_dir,
                                            model='silero_tts',
                                            language='ru',
                                            speaker='v4_ru',
                                            source=source)

        self.silero_tts.to(device)

        self.speaker = speaker


    def _process_sample(self, text_ru: str, speaker: str) -> np.ndarray:
        """Synthesize ``text_ru`` with the given voice and return the waveform as an array."""
        wav = self.silero_tts.apply_tts(text=text_ru,
                                        speaker=speaker,
                                        sample_rate=self.output_sample_rate)
        # .cpu() is a no-op on CPU and makes .numpy() valid on other devices.
        return wav.cpu().numpy()


    def __call__(self, context: ProcessingContext) -> ProcessingContext:
        """Set ``tts_wav`` on every segment that has a non-empty translation."""
        for segment in context.segments:
            if not segment.translation:
                # Nothing to synthesize; tts_wav is left untouched.
                continue
            segment.tts_wav = self._process_sample(segment.translation, self.speaker)
        return context

In [21]:
# Local checkout of the Silero models repository, loaded with source='local'.
# NOTE(review): machine-specific absolute path — adjust before running elsewhere.
silero_tts_path = "/home/maksim/Models/SileroModels"

tts_pipe = SileroTTSProcessor(silero_tts_path)

# Fills segment.tts_wav for the segments in the context.
context = tts_pipe(context)

In [32]:
import ffmpeg


def merge_segments_with_alignment(context: ProcessingContext, tts_sr: int) -> np.ndarray:
    """Place every synthesized segment at its original position on a silent track.

    Args:
        context: Pipeline state; ``segments`` supply ``start``/``end`` (sample
            indices into ``original_audio``) and ``tts_wav``.
        tts_sr: Sample rate of the ``tts_wav`` arrays.

    Returns:
        A float array as long as ``context.original_audio``. Each segment's
        speech is resampled to ``context.sample_rate`` and written from
        ``segment.start``; speech longer than the segment is truncated at
        ``segment.end``. Segments without ``tts_wav`` stay silent.
    """
    output_audio = np.zeros(len(context.original_audio))

    # The rates are the same for every segment, so build the resampler once.
    resampler = torchaudio.transforms.Resample(orig_freq=tts_sr, new_freq=context.sample_rate)

    for segment in context.segments:
        if segment.tts_wav is None:
            continue

        # float32 is what the resampling kernel expects.
        tts_wav = resampler(torch.as_tensor(segment.tts_wav, dtype=torch.float32)).numpy()

        n_samples = min(len(tts_wav), segment.end - segment.start)
        output_audio[segment.start:segment.start + n_samples] = tts_wav[:n_samples]

    return output_audio


def merge_audio_video(audio: np.ndarray, sr: int, video_path: str, output_path: str):
    """Write ``audio`` to a temporary WAV and mux it with the video stream of ``video_path``.

    Args:
        audio: Waveform to use as the new audio track.
        sr: Sample rate of ``audio``.
        video_path: Source video; only its video stream is used (copied as is).
        output_path: Destination file; audio is encoded as AAC.
    """
    temp_dir = "./temp-audios"
    # The directory may not exist yet (or may have been removed).
    os.makedirs(temp_dir, exist_ok=True)
    audio_path = temp_dir + "/output.wav"

    # Honour the caller's sample rate instead of assuming 16 kHz.
    sf.write(audio_path, audio, sr)

    try:
        video_stream = ffmpeg.input(video_path).video
        audio_stream = ffmpeg.input(audio_path).audio

        ffmpeg.output(audio_stream, video_stream, output_path, vcodec="copy", acodec="aac").run()
    finally:
        # Remove the temporary WAV even when ffmpeg fails.
        os.remove(audio_path)

In [33]:
# Lay the synthesized segments out on a track as long as the original audio.
outp_audio = merge_segments_with_alignment(context, tts_pipe.output_sample_rate)

# Mux the dubbed track with the original video stream.
merge_audio_video(outp_audio, 16_000, video_path, "./outp-4.mp4")

ffmpeg version 6.1.1-3ubuntu5 Copyright (c) 2000-2023 the FFmpeg developers
  built with gcc 13 (Ubuntu 13.2.0-23ubuntu3)
  configuration: --prefix=/usr --extra-version=3ubuntu5 --toolchain=hardened --libdir=/usr/lib/x86_64-linux-gnu --incdir=/usr/include/x86_64-linux-gnu --arch=amd64 --enable-gpl --disable-stripping --disable-omx --enable-gnutls --enable-libaom --enable-libass --enable-libbs2b --enable-libcaca --enable-libcdio --enable-libcodec2 --enable-libdav1d --enable-libflite --enable-libfontconfig --enable-libfreetype --enable-libfribidi --enable-libglslang --enable-libgme --enable-libgsm --enable-libharfbuzz --enable-libmp3lame --enable-libmysofa --enable-libopenjpeg --enable-libopenmpt --enable-libopus --enable-librubberband --enable-libshine --enable-libsnappy --enable-libsoxr --enable-libspeex --enable-libtheora --enable-libtwolame --enable-libvidstab --enable-libvorbis --enable-libvpx --enable-libwebp --enable-libx265 --enable-libxml2 --enable-libxvid --enable-libzimg --ena

# Combining everything into a single pipeline

In [34]:
@dataclass
class ProcessorConfig:
    """Description of one pipeline stage: a label, the processor class and its constructor kwargs."""

    # Stage label, e.g. "vad" or "asr".
    stage: str
    # Processor class to instantiate.
    model: type
    # Keyword arguments for the processor's constructor.
    params: dict


# CPU-only pipeline configuration: the processors listed under "pipeline" are
# instantiated in order with their "params" as keyword arguments.
# NOTE(review): the model paths are machine-specific absolute paths — adjust before running elsewhere.
cpu_config = {
    "pipeline": [
        # Voice activity detection.
        ProcessorConfig(
            stage="vad",
            model=SileroVADProcessor,
            params={
                "model_path": "/home/maksim/Models/SileroVAD/snakers4-silero-vad", 
                "threshold": 0.5,  
                "min_silence_duration_ms": 1000, 
                "min_speech_duration_ms": 1000
            }
        ),
        # Speech recognition.
        ProcessorConfig(
            stage="asr",
            model=FasterWhisperProcessor,
            params={
                "model_size_or_path": "/home/maksim/Models/FasterWhisper/tiny-en", 
                "device": "cpu", 
                "compute_type": "int8"
            }
        ), 
        # Machine translation (en -> ru).
        ProcessorConfig(
            stage="mt",
            model=HelsinkiEnRuProcessor,
            params={
                "model_path": "/home/maksim/Models/OpusEnRu", 
                "device": 'cpu'
            }
        ),  
        # Speech synthesis.
        ProcessorConfig(
            stage="tts",
            model=SileroTTSProcessor,
            params={
                "model_path": "/home/maksim/Models/SileroModels", 
                "device": 'cpu', 
                "speaker": "xenia"
            }
        )
    ],
    # Working directory for temporary files.
    "temp-dir": "./video-dubbing-temp-dir"}

In [38]:
import ffmpeg


class VideoDubber:
    """End-to-end dubbing: extract audio, run the configured processors, mux the new track.

    ``config`` is a dict with two keys: ``"pipeline"``, a list of objects with
    ``model`` (processor class) and ``params`` (constructor kwargs), and
    ``"temp-dir"``, the working directory for temporary files.
    """

    def __init__(self, config: dict):
        """Instantiate every processor of ``config["pipeline"]`` in order."""
        self.processors = [processor.model(**processor.params)
                           for processor in config["pipeline"]]

        self.temp_dir = config["temp-dir"]
    

    def _extract_audio_from_mp4(self, 
                                video_path: str, 
                                target_sr: int = 16000) -> tuple[np.ndarray, int]:
        """Return the video's audio track as a 1-D waveform at ``target_sr``.

        Raises:
            ValueError: If the video has no audio track.
        """
        os.makedirs(self.temp_dir, exist_ok=True)

        # Join with a separator so the file really lands inside temp_dir.
        temp_audio_path = os.path.join(self.temp_dir, "original_audio.wav")

        video = VideoFileClip(video_path)
        try:
            audio: AudioFileClip = video.audio
            if audio is None:
                raise ValueError(f"No audio track found in {video_path!r}")
            audio.write_audiofile(temp_audio_path, codec='pcm_s16le', fps=target_sr)
        finally:
            # Release the reader processes/handles held by the clip.
            video.close()

        try:
            audio_data, sr = torchaudio.load(temp_audio_path)
        finally:
            if os.path.exists(temp_audio_path):
                os.remove(temp_audio_path)

        if sr != target_sr:
            audio_data = torchaudio.transforms.Resample(orig_freq=sr, new_freq=target_sr)(audio_data)

        # The loaded tensor is (channels, samples); averaging over dim 0 gives a
        # 1-D waveform for mono and multi-channel audio alike.
        audio_data = audio_data.mean(dim=0)

        return audio_data.numpy(), target_sr


    def _merge_audio_video(self, audio: np.ndarray, sr: int, video_path: str, output_path: str):
        """Write ``audio`` (at ``sr``) to a temporary WAV and mux it with the video stream."""
        audio_path = os.path.join(self.temp_dir, "output.wav")

        # Honour the caller's sample rate instead of assuming 16 kHz.
        sf.write(audio_path, audio, sr)

        try:
            video_stream = ffmpeg.input(video_path).video
            audio_stream = ffmpeg.input(audio_path).audio

            ffmpeg.output(audio_stream, video_stream, output_path, vcodec="copy", acodec="aac").run()
        finally:
            # Remove the temporary WAV even when ffmpeg fails.
            os.remove(audio_path)


    def _tts_sample_rate(self) -> int:
        """Return ``output_sample_rate`` of the last processor that defines it."""
        for processor in reversed(self.processors):
            rate = getattr(processor, "output_sample_rate", None)
            if rate is not None:
                return rate
        raise AttributeError("no processor in the pipeline defines output_sample_rate")


    def _merge_segments_with_alignment(self, context: ProcessingContext) -> np.ndarray:
        """Place every synthesized segment at its original position on a silent track.

        Speech longer than its segment is truncated at ``segment.end``;
        segments without ``tts_wav`` stay silent.
        """
        output_audio = np.zeros(len(context.original_audio))

        # The rates are the same for every segment, so build the resampler once.
        resampler = torchaudio.transforms.Resample(orig_freq=self._tts_sample_rate(),
                                                   new_freq=context.sample_rate)

        for segment in context.segments:
            if segment.tts_wav is None:
                continue

            # float32 is what the resampling kernel expects.
            tts_wav = resampler(torch.as_tensor(segment.tts_wav, dtype=torch.float32)).numpy()

            n_samples = min(len(tts_wav), segment.end - segment.start)
            output_audio[segment.start:segment.start + n_samples] = tts_wav[:n_samples]

        return output_audio


    def __call__(self, input_video_path: str, output_video_path: str):
        """Dub ``input_video_path`` and write the result to ``output_video_path``."""
        os.makedirs(self.temp_dir, exist_ok=True)

        try:
            audio, sr = self._extract_audio_from_mp4(input_video_path, target_sr=16_000)

            context = ProcessingContext(original_audio=audio, sample_rate=sr, temp_dir=self.temp_dir)

            for processor in self.processors:
                context = processor(context)

            output_audio = self._merge_segments_with_alignment(context=context)

            self._merge_audio_video(output_audio, sr, input_video_path, output_video_path)
        finally:
            try:
                os.rmdir(self.temp_dir)
            except OSError:
                # Best-effort cleanup: the directory still holds other files
                # (os.rmdir only removes empty directories), so leave it.
                pass

In [39]:
# Instantiates all processors listed in cpu_config["pipeline"].
dubber = VideoDubber(config=cpu_config)

Device set to use cpu


In [40]:
# Run the full pipeline on the test video and write the dubbed result.
dubber(video_path, "./outp.mp4")

MoviePy - Writing audio in ./video-dubbing-temp-diroriginal_audio.wav


                                                                      

MoviePy - Done.


ffmpeg version 6.1.1-3ubuntu5 Copyright (c) 2000-2023 the FFmpeg developers
  built with gcc 13 (Ubuntu 13.2.0-23ubuntu3)
  configuration: --prefix=/usr --extra-version=3ubuntu5 --toolchain=hardened --libdir=/usr/lib/x86_64-linux-gnu --incdir=/usr/include/x86_64-linux-gnu --arch=amd64 --enable-gpl --disable-stripping --disable-omx --enable-gnutls --enable-libaom --enable-libass --enable-libbs2b --enable-libcaca --enable-libcdio --enable-libcodec2 --enable-libdav1d --enable-libflite --enable-libfontconfig --enable-libfreetype --enable-libfribidi --enable-libglslang --enable-libgme --enable-libgsm --enable-libharfbuzz --enable-libmp3lame --enable-libmysofa --enable-libopenjpeg --enable-libopenmpt --enable-libopus --enable-librubberband --enable-libshine --enable-libsnappy --enable-libsoxr --enable-libspeex --enable-libtheora --enable-libtwolame --enable-libvidstab --enable-libvorbis --enable-libvpx --enable-libwebp --enable-libx265 --enable-libxml2 --enable-libxvid --enable-libzimg --ena