In [12]:
import whisperx
import gc
import os
from dotenv import load_dotenv

load_dotenv()

token = os.environ.get("HUGGINGFACE_ACCESS_TOKEN")

device = "cuda"
audio_file = "001.mp3"
batch_size = 16
compute_type = "float16"

model = whisperx.load_model("large-v2", device, compute_type=compute_type)

# save model to local path (optional)
# model_dir = "/path/"
# model = whisperx.load_model("large-v2", device, compute_type=compute_type, download_root=model_dir)

audio = whisperx.load_audio(audio_file)

Lightning automatically upgraded your loaded checkpoint from v1.5.4 to v2.0.7. To apply the upgrade to your files permanently, run `python -m pytorch_lightning.utilities.upgrade_checkpoint --file ../../.cache/torch/whisperx-vad-segmentation.bin`


No language specified, language will be first be detected for each audio file (increases inference time).
Model was trained with pyannote.audio 0.0.1, yours is 3.3.1. Bad things might happen unless you revert pyannote.audio to 0.x.
Model was trained with torch 1.10.0+cu102, yours is 2.3.1+cu121. Bad things might happen unless you revert torch to 1.x.


In [13]:
result = model.transcribe(audio, batch_size=batch_size)
# print(result["segments"])

model_a, metadata = whisperx.load_align_model(
    language_code=result["language"], device=device
)
result = whisperx.align(
    result["segments"], model_a, metadata, audio, device, return_char_alignments=False
)

# print(result["segments"])

diarize_model = whisperx.DiarizationPipeline(use_auth_token=token, device=device)

diarize_segments = diarize_model(audio)
diarize_model(audio, min_speakers=2, max_speakers=2)

result = whisperx.assign_word_speakers(diarize_segments, result)

print(diarize_segments)

Detected language: en (1.00) in first 30s of audio...
                               segment label     speaker        start  \
0    [ 00:00:06.780 -->  00:00:10.189]     A  SPEAKER_00     6.780969   
1    [ 00:00:10.392 -->  00:00:10.965]     B  SPEAKER_00    10.392219   
2    [ 00:00:10.965 -->  00:00:11.286]     C  SPEAKER_03    10.965969   
3    [ 00:00:12.484 -->  00:00:16.197]     D  SPEAKER_03    12.484719   
4    [ 00:00:16.787 -->  00:00:18.374]     E  SPEAKER_03    16.787844   
..                                 ...   ...         ...          ...   
822  [ 00:55:18.330 -->  00:55:22.887]   AEQ  SPEAKER_03  3318.330969   
823  [ 00:55:23.747 -->  00:55:24.979]   AER  SPEAKER_03  3323.747844   
824  [ 00:55:27.713 -->  00:55:28.607]   AES  SPEAKER_03  3327.713469   
825  [ 00:55:30.902 -->  00:55:31.274]   AET  SPEAKER_03  3330.902844   
826  [ 00:55:36.167 -->  00:55:36.724]   AEU  SPEAKER_03  3336.167844   

             end  intersection        union  
0      10.189719  -3311

In [27]:
from datetime import timedelta

segments = result["segments"]

f = open("transcription.txt", "x", encoding="utf-8")
for segment in segments:
    start = segment["start"]
    end = segment["end"]
    speaker = segment["speaker"] if segment.get("speaker") else "UNKNOWN"
    text = segment["text"]
    line = (
        f"[{timedelta(seconds=start)} - {timedelta(seconds=end)}] {speaker} - {text}\n"
    )
    f.write(line)
f.close()