In [1]:
# set variables
import os

# Fill in the placeholders below before running the rest of the notebook.
videoFilePath = ""  # input video that ffmpeg reads
saveDirPath = ""  # directory that receives the generated files
# os.path.join inserts the path separator when saveDirPath has no trailing
# one; plain string concatenation would glue directory and file name together.
audioFilePath = os.path.join(saveDirPath, "")  # second argument: audio file name
srtFilePath = os.path.join(saveDirPath, "")  # second argument: subtitle file name
languages = ""  # forwarded as the "language" generation option of the ASR pipeline
use_denoising = True  # gates the denoising step that runs before transcription

model_name = "openai/whisper-large-v3"

In [None]:
# extract audio from video
import subprocess

# Pass the arguments as a list and skip the shell entirely: paths containing
# quotes, "$" or back-ticks then reach ffmpeg verbatim instead of being
# interpreted by (or breaking) a shell command line.
command = [
    "ffmpeg",
    "-y",
    "-i", videoFilePath,
    "-ab", "160k",
    "-ac", "2",
    "-ar", "44100",
    "-vn",
    audioFilePath,
]

try:
    return_code = subprocess.call(command)
except FileNotFoundError:
    # Without a shell, a missing ffmpeg executable raises instead of yielding
    # a non-zero status; map it to the conventional "command not found" code
    # so the report below still runs.
    return_code = 127

if return_code == 0:
    print(f"Successfully converted audio, please check \"{audioFilePath}\"")
else:
    print(f"Error: ffmpeg command failed with return code {return_code}")


In [None]:
import torch
import torchaudio
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
from denoiser import pretrained
from denoiser.dsp import convert_audio

def reduce_noise(audio, denoiser_model, device='cuda'):
    audio = audio.to(device)

    with torch.no_grad(), torch.amp.autocast(device_type=device, dtype=torch.float16):
        denoised_audio = denoiser_model(audio)

    denoised_audio = denoised_audio.squeeze()
        
    return denoised_audio

def load_audio(file_path, target_sr=16000):
    """Load an audio file as a 1-D mono waveform sampled at ``target_sr`` Hz.

    Args:
        file_path: Path handed to ``torchaudio.load``.
        target_sr: Sample rate of the returned waveform; the audio is
            resampled when the file uses a different rate.

    Returns:
        A tensor of shape ``(num_samples,)``. The first dimension of the loaded
        tensor is treated as the channel axis and averaged away, so mono and
        multi-channel files both yield a 1-D result.
    """
    waveform, sampling_rate = torchaudio.load(file_path)
    if sampling_rate != target_sr:
        resampler = torchaudio.transforms.Resample(orig_freq=sampling_rate, new_freq=target_sr)
        waveform = resampler(waveform)
    # Averaging over dim 0 down-mixes several channels and simply drops the
    # axis for a single channel (mean of one value is that value), so the
    # shape no longer depends on the channel count of the file.
    return torch.mean(waveform, dim=0, keepdim=False)

# Load the audio unconditionally: the transcription and playback cells read
# `speech` whether or not denoising is enabled, so it must always be defined.
speech = load_audio(audioFilePath)

if use_denoising:
    denoiser_model = pretrained.dns64().cuda()
    # Add a leading axis before the waveform is handed to the model;
    # reduce_noise squeezes the size-1 axes out of the result again.
    speech = reduce_noise(speech.unsqueeze(0), denoiser_model)

    # Drop the denoiser and release cached GPU memory before the ASR model
    # is loaded.
    del denoiser_model
    torch.cuda.empty_cache()
else:
    # Flatten a possible size-1 channel axis so the un-denoised waveform is
    # 1-D, like the squeezed denoiser output.
    speech = speech.reshape(-1)

In [4]:
# set CUDA
# Device index handed to the pipeline: 0 when CUDA is usable, -1 otherwise.
if torch.cuda.is_available():
    device = 0
else:
    device = -1

# load whisper & processor
# Both are fetched by name from `model_name`; the processor supplies the
# tokenizer and feature extractor used when the pipeline is built.
processor = AutoProcessor.from_pretrained(model_name)
model = AutoModelForSpeechSeq2Seq.from_pretrained(model_name)

In [None]:
# init pipeline
# Generation options forwarded to the model on every pipeline call.
generate_kwargs = {
    "task": "transcribe",
    "num_beams": 10,  # original note: "default 5" — TODO confirm against the model's generation config
    #"length_penalty": 1.0,
    #"early_stopping": True,
    #"no_repeat_ngram_size": 3,
    #"repetition_penalty": 1.2
}
# Only pin the language when one was configured. An empty placeholder string
# is not a language name; leaving the option out is expected to let the model
# detect the language itself — verify against the Whisper generate docs.
if languages:
    generate_kwargs["language"] = languages

asr_pipeline = pipeline(
    "automatic-speech-recognition",
    model=model,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
    chunk_length_s=30,
    batch_size=1,
    device=device,
    return_timestamps=True,  # needed so the result carries per-chunk timestamps
    generate_kwargs=generate_kwargs,
)

In [None]:
# The pipeline is fed a NumPy array, so convert a tensor (e.g. the denoiser
# output, possibly still on the GPU) first. isinstance replaces the fragile
# comparison against the string form of the type.
if isinstance(speech, torch.Tensor):
    # Cast to float32 as well: the denoiser runs under float16 autocast, so
    # its output may be half precision.
    speech = speech.detach().cpu().float().numpy()

result = asr_pipeline(speech)
# "chunks" holds the timestamped segments; fall back to an empty list when
# the key is absent.
transcript = result.get('chunks', [])

# Drop the pipeline and release cached GPU memory.
# NOTE(review): the global `model` still references the weights, so this alone
# may not free them — confirm whether `model` should be deleted too.
del asr_pipeline
torch.cuda.empty_cache()

In [None]:
# display audio

from IPython.display import Audio, display
import torch
import numpy as np

def play_audio(audio, sr=16000):
    """Show an inline audio player for ``audio`` in the notebook.

    Args:
        audio: ``torch.Tensor`` or ``numpy.ndarray`` holding the samples.
        sr: Sample rate passed to the player as ``rate``.

    Raises:
        TypeError: If ``audio`` is neither a tensor nor a NumPy array.
    """
    if not isinstance(audio, (torch.Tensor, np.ndarray)):
        raise TypeError("Audio data must be in torch.Tensor or numpy.ndarray format.")

    # Tensors are moved to the CPU and stripped of size-1 dimensions before
    # conversion; NumPy arrays are used unchanged.
    samples = audio.cpu().squeeze().numpy() if isinstance(audio, torch.Tensor) else audio
    display(Audio(samples, rate=sr))


# Play back `speech` with play_audio's default sample rate (sr=16000).
play_audio(speech)

In [None]:
def _format_srt_timestamp(seconds):
    """Render a time offset in seconds as an SRT ``HH:MM:SS,mmm`` timestamp."""
    # Round once to whole milliseconds and split with integer arithmetic.
    # Truncating the fractional part of a float instead loses a millisecond
    # for values such as 2.3 (0.2999999... * 1000 -> 299).
    total_ms = int(round(seconds * 1000))
    hrs, remainder = divmod(total_ms, 3_600_000)
    mins, remainder = divmod(remainder, 60_000)
    secs, millis = divmod(remainder, 1000)
    return f"{hrs:02}:{mins:02}:{secs:02},{millis:03}"


def create_srt(transcript, srt_path):
    """Write timestamped transcript segments to ``srt_path`` in SRT format.

    Args:
        transcript: Iterable of dicts, each with a ``'timestamp'`` entry
            holding a ``(start, end)`` pair in seconds and a ``'text'`` entry.
        srt_path: Destination file; it is created or overwritten as UTF-8.

    Segments without a usable timestamp (missing key, ``None`` value, or a
    ``None`` start/end) are reported with a warning and left out. Written cues
    are numbered consecutively from 1, so skipped segments leave no gaps in
    the numbering.
    """
    cue_number = 0
    with open(srt_path, 'w', encoding='utf-8') as srt_file:
        # idx is the position in the transcript and is only used in warnings;
        # cue_number counts the cues actually written.
        for idx, segment in enumerate(transcript, start=1):
            # Verify that the 'timestamp' key exists and the value is valid
            if 'timestamp' not in segment or segment['timestamp'] is None:
                print(f"Warning: Segment {idx} missing 'timestamp' key or 'timestamp' is None. Skipping.")
                continue

            start_time, end_time = segment['timestamp']

            # Verify that the start and end times are not None
            if start_time is None or end_time is None:
                print(f"Warning: Segment {idx} has None start or end time. Skipping.")
                continue

            text = segment['text'].strip()
            cue_number += 1

            srt_file.write(f"{cue_number}\n")
            srt_file.write(
                f"{_format_srt_timestamp(start_time)} --> {_format_srt_timestamp(end_time)}\n"
            )
            srt_file.write(f"{text}\n\n")


# create .srt file
# Writes the timestamped chunks in `transcript` to the configured srtFilePath.
create_srt(transcript, srtFilePath)

# Report where the subtitle file was written.
print(f"Subtitle file created at \"{srtFilePath}\"")