In [1]:
import torch
import torchaudio
from whisperx.vads import SileroCustom
from whisperx.audio import N_SAMPLES, SAMPLE_RATE, load_audio, log_mel_spectrogram
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
default_vad_options = {
    "chunk_size": 30,  # needed by silero since binarization happens before merge_chunks
    "vad_onset": 0.500,
    "vad_offset": 0.363,
    "vad_onnx": True,
    "silero_merge_cutoff": 0.1
}

vad_model = SileroCustom(
    **default_vad_options
)

>>Performing voice activity detection using Silero...


Using cache found in /home/ubuntu/.cache/torch/hub/snakers4_silero-vad_master


In [3]:
def load_audio(audio_path, target_sr=16000):
    audio, sr = torchaudio.load(audio_path)
    if sr != target_sr:
        audio = torchaudio.functional.resample(audio, sr, target_sr)
        sr = target_sr
    return audio, sr

def get_vad_segments(waveform, sr):
    vad_input = {'waveform': waveform.numpy(), 'sample_rate': sr}
    vad_segments = vad_model(vad_input)

    segments = []
    for segment in vad_segments:
        seg_start = int(segment.start * sr)
        seg_end = int(segment.end * sr)
        segments.append(waveform[seg_start:seg_end])
    return segments

In [4]:
vad_segments = []

waveform, sr = load_audio('/home/ubuntu/v2v-voice-library/data/fisher/audios/006/fe_03_00600.wav')
for i in range(waveform.shape[0]):
    segments = get_vad_segments(waveform[i], sr)
    vad_segments.extend(segments)


In [5]:
torch_dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float16

model = AutoModelForSpeechSeq2Seq.from_pretrained("openai/whisper-large-v3-turbo", torch_dtype=torch_dtype, device_map="auto")
processor = AutoProcessor.from_pretrained("openai/whisper-large-v3-turbo")

In [6]:
def preprocess_audio(audio_data):
    """Preprocess audio data for the model."""
    # The HF WhisperFeatureExtractor uses 80 mel bins by default
    # Access it from the feature_extractor's config
    
    n_mels = 128  if 'v3' in model.name_or_path else 80 # Default value for Whisper models
    if hasattr(processor, "feature_extractor") and hasattr(processor.feature_extractor, "config"):
        n_mels = getattr(processor.feature_extractor.config, "num_mel_bins", 80)
    
    features = log_mel_spectrogram(
        audio_data,
        n_mels=n_mels,
        padding=N_SAMPLES - audio_data.shape[0] if audio_data.shape[0] < N_SAMPLES else 0,
    )
    # Convert features to match model's dtype
    return features.to(device=model.device, dtype=torch_dtype)

In [7]:
generate_kwargs = {
    "max_new_tokens": 100,
    "num_beams": 5,
    "condition_on_prev_tokens": False,
    "compression_ratio_threshold": 1.35,  # zlib compression ratio threshold (in token space)
    "temperature": 0.0,
    "logprob_threshold": -1.0,
    "no_speech_threshold": 0.6,
    "return_timestamps": False,
    "return_dict_in_generate": True,
}

In [8]:
input_features = []
for segment in vad_segments:
    input_features.append(preprocess_audio(segment))

In [9]:
import time
input_features = torch.stack(input_features)

start_time = time.time()
batched_responses = []
for batch in torch.split(input_features, 16):
    output = model.generate(batch, **generate_kwargs)
    batched_responses.append(output)
    torch.cuda.empty_cache()
end_time = time.time()
print(f"Time taken: {end_time - start_time} seconds")


Due to a bug fix in https://github.com/huggingface/transformers/pull/28687 transcription using a multilingual Whisper will default to language detection followed by transcription instead of translation to English.This might be a breaking change for your use case. If you want to instead always translate your audio to English, make sure to pass `language='en'`.
Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.43.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.
The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
From v4.47 onwards, when a model cache is to be returned, `generate` will return a `Cache` instance instead by default (as opposed to the legacy tuple of tuples format). If

Time taken: 77.0274965763092 seconds


In [10]:
text = output.sequences[0].tolist()
text = processor.decode(text, skip_special_tokens=False)
print(text)

<|startoftranscript|><|ko|><|transcribe|><|notimestamps|> MBC 뉴스 박진주입니다.<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|end

In [28]:
print(output.sequences[0].shape)

torch.Size([53])
