In [1]:
import json
from pathlib import Path

# Load local credentials from secrets.json in the current working directory.
# The notebook later reads the "hf_token" entry from this dict.
secrets_file = Path(".") / "secrets.json"
# Open in binary mode so the json module detects the encoding itself
# (UTF-8/16/32, BOM tolerated) instead of decoding with the platform's locale
# encoding, which is not UTF-8 on every system.
with open(secrets_file, "rb") as f:
    secrets = json.load(f)

In [None]:
import whisperx
import gc 
import torch

# Speech-to-text pipeline: transcribe -> align -> diarize.
# Each stage rebinds `result`, so only the output of the last stage that ran
# is available to the cells below.

# Run on the GPU when torch reports CUDA is available, otherwise on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Input recording; path is relative to the notebook's working directory.
audio_file = "sample/two-mates-having-a-chat.m4a"
batch_size = 16 # passed to model.transcribe(); reduce if low on GPU mem
compute_type = "float16" if device == "cuda" else "int8" # float16 on GPU, int8 on CPU; use "int8" on GPU too if low on GPU mem (may reduce accuracy)

# 1. Transcribe with original whisper (batched)
model = whisperx.load_model("large-v2", device, compute_type=compute_type)

# save model to local path (optional)
# model_dir = "/path/"
# model = whisperx.load_model("large-v2", device, compute_type=compute_type, download_root=model_dir)

audio = whisperx.load_audio(audio_file)
# The transcription result is read below via its "segments" and "language" keys.
result = model.transcribe(audio, batch_size=batch_size)
print(result["segments"]) # before alignment

# delete model if low on GPU resources
# (drop the reference first, otherwise there is nothing for the collector to free)
# del model; gc.collect(); torch.cuda.empty_cache()

# 2. Align whisper output
# The alignment model is chosen from the language reported by the transcription.
model_a, metadata = whisperx.load_align_model(language_code=result["language"], device=device)
result = whisperx.align(result["segments"], model_a, metadata, audio, device, return_char_alignments=False)

print(result["segments"]) # after alignment

# delete model if low on GPU resources
# (drop the reference first, otherwise there is nothing for the collector to free)
# del model_a; gc.collect(); torch.cuda.empty_cache()

# 3. Assign speaker labels
# The Hugging Face token comes from secrets.json; .get() yields None when the
# "hf_token" entry is absent.
# NOTE(review): presumably the diarization model download needs a valid token — confirm.
diarize_model = whisperx.DiarizationPipeline(use_auth_token=secrets.get("hf_token"), device=device)
# add min/max number of speakers if known
diarize_segments = diarize_model(audio)
# diarize_model(audio, min_speakers=min_speakers, max_speakers=max_speakers)
# NOTE(review): segments that overlap no diarization turn may come back without
# a "speaker" key — verify against the installed whisperx version.
result = whisperx.assign_word_speakers(diarize_segments, result)


In [None]:
# Print one "<speaker> <text>" line per transcript segment.
# A segment that was not matched to any diarization turn has no "speaker" key,
# so fall back to a placeholder label instead of raising KeyError.
for segment in result["segments"]:
    print(segment.get("speaker", "UNKNOWN"), segment["text"])

In [3]:
import ffmpeg

# Split the source recording into one WAV clip per transcript segment.
# Clips are written to output_audio/ as "<speaker>_<segment index>.wav".
output_dir = Path("output_audio")
output_dir.mkdir(exist_ok=True)

for i, segment in enumerate(result["segments"]):
    start_time = segment["start"]
    end_time = segment["end"]
    # Segments that were not matched to a diarization turn have no "speaker"
    # key; file them under a placeholder label rather than raising KeyError.
    # Underscores are stripped so the stem contains exactly one "_", which
    # separates the speaker label from the segment index.
    speaker = segment.get("speaker", "UNKNOWN").replace("_", "")
    output_file = output_dir / f"{speaker}_{i}.wav"

    # Skip clips that were already extracted
    if output_file.exists():
        continue

    try:
        # Extract the audio between start_time and end_time for this segment.
        # loglevel="error" plus capture_stderr makes ffmpeg's own error text
        # available in the exception below.
        (
            ffmpeg
            .input(audio_file, ss=start_time, to=end_time)
            .output(str(output_file), format='wav', loglevel="error")
            .run(overwrite_output=True, capture_stderr=True)
        )
    except ffmpeg.Error as e:
        # Remove any partially written clip; otherwise the exists() check
        # above would treat it as finished on the next run.
        output_file.unlink(missing_ok=True)
        detail = e.stderr.decode(errors="replace") if e.stderr else "no details captured"
        print(f"Failed to extract segment {i} ({output_file.name}): {detail}")
        raise

In [8]:
from pathlib import Path
import ffmpeg

# Group the per-segment clips in output_dir by speaker.
# Segment clip stems have the form "<speaker>_<index>".
speaker_files = {}
for clip_path in output_dir.glob("*.wav"):
    speaker, sep, index = clip_path.stem.rpartition('_')
    # Ignore anything that is not a segment clip — in particular the
    # "<speaker>_merged.wav" outputs of a previous run, whose suffix is not a
    # number and would otherwise break the int() conversion.
    if not sep or not speaker or not index.isdecimal():
        continue
    speaker_files.setdefault(speaker, []).append((int(index), clip_path))

# Order each speaker's clips by segment index so the merged audio is chronological
for clips in speaker_files.values():
    clips.sort(key=lambda item: item[0])

# Merge each speaker's clips into "<speaker>_merged.wav", then delete the clips
for speaker, files in speaker_files.items():
    inputs = [ffmpeg.input(str(clip_path)) for _, clip_path in files]
    # A single clip is passed straight through; several clips are joined by one
    # audio-only concat filter rather than a chain of pairwise concats.
    merged = inputs[0] if len(inputs) == 1 else ffmpeg.concat(*inputs, v=0, a=1)
    output_file = output_dir / f"{speaker}_merged.wav"
    try:
        # stderr is captured so the error text is available below
        # (without capture_stderr, e.stderr is always None).
        merged.output(str(output_file), format='wav', loglevel="error").run(
            overwrite_output=True, capture_stderr=True
        )
    except ffmpeg.Error as e:
        detail = e.stderr.decode(errors="replace") if e.stderr else "no details captured"
        print(f"Error occurred: {detail}")
        # Keep the source clips so the merge can be retried
        continue
    # Merge succeeded: the individual clips are no longer needed
    for _, clip_path in files:
        clip_path.unlink()