# Integrating Speaker Diarization with a Speech Recognition System

This notebook combines WhisperX transcription with NVIDIA NeMo speaker diarization. The resulting pipeline both transcribes speech and labels each part of the transcript with its speaker, answering the question of "who said what" in an audio recording.

### Key Features

1. **Transcription**: Utilizing WhisperX for highly accurate speech-to-text conversion.

2. **Speaker Diarization**: Implementing NVIDIA's NeMo MSDD model to distinguish between speakers.

3. **Music-Speech Separation**: Employing Demucs to isolate speech from background music.

4. **Timestamp Alignment**: Using Wav2Vec2 to precisely sync transcriptions with the original audio.

5. **Punctuation-Based Refinement**: Enhancing speaker attribution by analyzing sentence structure.

### Workflow Overview

1. **Environment Setup**: Install and configure necessary libraries and tools.
2. **Audio Preprocessing**: Separate speech from music using Demucs.
3. **Transcription**: Generate text from speech using WhisperX.
4. **Timestamp Alignment**: Use Wav2Vec2 to precisely sync transcriptions with the original audio.
5. **Diarization**: Apply NeMo's MSDD model to identify distinct speakers.
6. **Alignment and Mapping**: Connect speaker identities with transcribed text.
7. **Post-processing**: Refine results using punctuation analysis and finalize output.
8. **Finalizing**: Generate the results in the desired formats and map speaker IDs to human-readable names.
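Step 8 can be as simple as a lookup table from diarization labels to names. The sketch below assumes sentence dictionaries shaped like the ones produced later by `get_sentences_speaker_mapping` (labels such as `"Speaker 0"`); `rename_speakers` and the example names are hypothetical.

```python
from typing import Dict, List


def rename_speakers(
    sentences: List[Dict[str, object]], names: Dict[str, str]
) -> List[Dict[str, object]]:
    """Return a copy of the sentence list with speaker labels mapped to display names.

    Labels without an entry in `names` are kept unchanged.
    """
    renamed = []
    for sentence in sentences:
        updated = dict(sentence)  # shallow copy so the caller's list is not mutated
        updated["speaker"] = names.get(sentence["speaker"], sentence["speaker"])
        renamed.append(updated)
    return renamed


# Hypothetical example data in the shape used by the notebook
sentences = [
    {"speaker": "Speaker 0", "start_time": 0, "end_time": 1200, "text": "Hello there. "},
    {"speaker": "Speaker 1", "start_time": 1300, "end_time": 2500, "text": "Hi! "},
]
renamed = rename_speakers(sentences, {"Speaker 0": "Alice"})
print([s["speaker"] for s in renamed])
```

Keeping unmapped labels as-is means a partial mapping never drops a speaker from the output.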

### Applications

This integrated approach has wide-ranging applications in:

- Automated meeting transcription
- Media content analysis
- Legal and medical transcription services
- Accessibility enhancement for audio/video content


# 1. Environment Setup

## Core Components

1. **WhisperX**
   - Builds on OpenAI's Whisper (using the faster-whisper backend)
   - Adds batched inference, VAD-based segmentation, and wav2vec2-based word-level alignment

2. **NVIDIA NeMo (ASR Toolkit)**
   - Backbone for speaker diarization
   - Supplies the VAD (MarbleNet), speaker-embedding (TitaNet), and MSDD models used here; its ASR models are not used in this notebook

3. **Demucs**
   - Specializes in music source separation
   - Critical for isolating speech from background audio

## Supporting Libraries

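4. **Demucs Runtime Dependencies** (installed separately because Demucs itself is installed with `--no-deps`)
   - dora-search: Experiment-management library that Demucs imports
   - lameenc: MP3 encoding
   - openunmix: Open-Unmix music source-separation models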

5. **deepmultilingualpunctuation**
   - Restores punctuation in transcripts that lack it
   - Enhances readability of generated text

6. **Utility Tools**
   - wget: Downloads the NeMo diarization config file
   - pydub: Audio file manipulation within Python

## Synergy in Action

Together, these tools cover every stage of the pipeline:

- **Pre-processing**: Clean and prepare audio with Demucs and pydub
- **Transcription**: Convert speech to text using WhisperX
- **Diarization**: Identify speakers with NeMo
- **Post-processing**: Refine output with punctuation and formatting

By integrating these components, we can create a robust pipeline capable of converting a raw audio input to structured, speaker-attributed transcriptions.

In [1]:
!pip install -q git+https://github.com/m-bain/whisperX.git@78dcfaab51005aa703ee21375f81ed31bc248560
!pip install -q --no-build-isolation nemo_toolkit[asr]==1.22.0
!pip install -q --no-deps git+https://github.com/facebookresearch/demucs#egg=demucs
!pip install -q dora-search "lameenc>=1.2" openunmix
!pip install -q deepmultilingualpunctuation
!pip install -q wget pydub


In [2]:
import os
import wget
from omegaconf import OmegaConf
import json
import shutil
from faster_whisper import WhisperModel
import whisperx
import torch
from pydub import AudioSegment
from nemo.collections.asr.models.msdd_models import NeuralDiarizer
from deepmultilingualpunctuation import PunctuationModel
import re
import logging
import nltk
from whisperx.alignment import DEFAULT_ALIGN_MODELS_HF, DEFAULT_ALIGN_MODELS_TORCH
from whisperx.utils import LANGUAGES, TO_LANGUAGE_CODE

  torchaudio.set_audio_backend("soundfile")
[NeMo W 2024-07-22 21:11:26 transformer_bpe_models:59] Could not import NeMo NLP collection which is required for speech translation model.


In [3]:
# Name of the audio file
audio_path = "/content/audio.mp3"

# Whether to enable music removal from speech; helps increase diarization quality but uses a lot of RAM
enable_stemming = True

# (choose from 'tiny.en', 'tiny', 'base.en', 'base', 'small.en', 'small', 'medium.en', 'medium', 'large-v1', 'large-v2', 'large-v3', 'large')
whisper_model_name = "large-v2"

# transcribes numbers as spoken words instead of digits, which helps word alignment and speaker attribution
suppress_numerals = True

# > 0: batched WhisperX transcription; 0: non-batched faster-whisper (needed when no wav2vec2 alignment model exists for the language)
batch_size = 8

language = 'en'  # set to None to autodetect the language

device = "cuda" if torch.cuda.is_available() else "cpu"

# Helper functions for Audio Processing and Transcription

## Configuration and Setup

#### `create_config`
- Creates configuration for NeMo diarization model
- Sets up input/output paths and diarization parameters
- Configures VAD and speaker embedding models

#### `process_language_arg`
- Validates and processes language argument
- Converts language names to codes
- Handles English-only model cases
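`create_config` points NeMo at a manifest file: one JSON object per line, each describing an audio file. The sketch below writes the same single-entry manifest that `create_config` builds; `write_manifest` is a hypothetical standalone helper, not part of NeMo.

```python
import json
import os
import tempfile


def write_manifest(audio_filepath: str, data_dir: str) -> str:
    """Write a single-entry NeMo-style manifest (one JSON object per line)."""
    os.makedirs(data_dir, exist_ok=True)
    entry = {
        "audio_filepath": audio_filepath,
        "offset": 0,
        "duration": None,  # None: process the whole file
        "label": "infer",
        "text": "-",
        "rttm_filepath": None,  # no reference RTTM at inference time
        "uem_filepath": None,
    }
    manifest_path = os.path.join(data_dir, "input_manifest.json")
    with open(manifest_path, "w") as fp:
        fp.write(json.dumps(entry) + "\n")
    return manifest_path


with tempfile.TemporaryDirectory() as tmp:
    path = write_manifest(os.path.join(tmp, "mono_file.wav"), os.path.join(tmp, "data"))
    with open(path) as fp:
        lines = fp.read().splitlines()
    loaded = json.loads(lines[0])
print(len(lines), loaded["label"])
```

To diarize several files in one run, the same format would simply hold one such line per file.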


## Word and Speaker Mapping

#### `get_word_ts_anchor`
- Returns the timestamp used to match a word to a speaker turn: the word's start, midpoint, or end

#### `get_words_speaker_mapping`
- Maps words to speakers based on timestamps
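The core idea of `get_words_speaker_mapping` is a single forward pass over words and speaker turns, both sorted by time. The sketch below is a simplified standalone version that always anchors on the word's start time (the notebook's helper also supports midpoint and end anchors); `map_words_to_speakers` and the sample data are hypothetical.

```python
from typing import Dict, List, Tuple


def map_words_to_speakers(
    words: List[Dict[str, object]], turns: List[Tuple[int, int, int]]
) -> List[Dict[str, object]]:
    """Assign each word the first speaker turn that has not ended before the word starts.

    `words` use seconds (as WhisperX returns them); `turns` are (start_ms, end_ms, speaker).
    Words past the last turn get the last turn's speaker.
    """
    mapping = []
    turn_idx = 0
    for w in words:
        start_ms = int(w["start"] * 1000)
        end_ms = int(w["end"] * 1000)
        # Advance through turns that ended before this word starts
        while turn_idx < len(turns) - 1 and start_ms > turns[turn_idx][1]:
            turn_idx += 1
        mapping.append(
            {"word": w["word"], "start_time": start_ms, "end_time": end_ms,
             "speaker": turns[turn_idx][2]}
        )
    return mapping


words = [
    {"word": "Hello", "start": 0.10, "end": 0.40},
    {"word": "there.", "start": 0.45, "end": 0.80},
    {"word": "Hi!", "start": 1.20, "end": 1.50},
]
turns = [(0, 900, 0), (1000, 2000, 1)]
result = map_words_to_speakers(words, turns)
for row in result:
    print(row["word"], row["speaker"])
```

Because the turn index only moves forward, the pass is linear in the number of words plus turns.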

#### `get_realigned_ws_mapping_with_punctuation`
- Realigns word-speaker mapping considering punctuation and speaker changes
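Diarization boundaries often fall a word or two away from the true speaker change, so a sentence can end up split between two speakers. The sketch below shows the underlying idea in simplified form: give each punctuation-delimited sentence its most frequent speaker. Unlike `get_realigned_ws_mapping_with_punctuation`, it ignores the `max_words_in_sentence` window and the minimum-majority check; `realign_by_sentence` is a hypothetical name.

```python
from collections import Counter
from typing import List

SENTENCE_END = ".?!"


def realign_by_sentence(words: List[str], speakers: List[int]) -> List[int]:
    """Give every word in a sentence the sentence's most frequent speaker.

    Sentences end at words ending in '.', '?' or '!'; a trailing fragment
    without end punctuation is treated as its own sentence.
    """
    realigned = list(speakers)
    start = 0
    for i, word in enumerate(words):
        if word[-1] in SENTENCE_END or i == len(words) - 1:
            # most_common(1) breaks ties by first occurrence within the sentence
            majority, _ = Counter(speakers[start : i + 1]).most_common(1)[0]
            realigned[start : i + 1] = [majority] * (i + 1 - start)
            start = i + 1
    return realigned


words = ["I", "think", "we", "should", "go.", "Sure!"]
speakers = [0, 0, 0, 1, 1, 1]
print(realign_by_sentence(words, speakers))  # [0, 0, 0, 0, 0, 1]
```

Here the diarizer switched to speaker 1 two words early; the majority vote moves "should go." back to speaker 0.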

#### `get_sentences_speaker_mapping`
- Groups words into sentences with speaker assignments


## Transcript Formatting

#### `get_speaker_aware_transcript`
- Writes speaker-aware transcript, organized by turns

#### `format_timestamp`
- Formats milliseconds into readable timestamp string

#### `write_srt`
- Outputs transcript in SRT format
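An SRT file is a sequence of numbered cues: the index, a `start --> end` line with `HH:MM:SS,mmm` timestamps (comma before the milliseconds), the text, and a blank line. The sketch below restates the timestamp arithmetic of `format_timestamp` with `divmod` and builds one cue; `ms_to_timestamp` and `srt_cue` are hypothetical names chosen so they do not shadow the notebook's helpers.

```python
def ms_to_timestamp(
    milliseconds: int, always_include_hours: bool = False, decimal_marker: str = "."
) -> str:
    """Format integer milliseconds as [HH:]MM:SS<marker>mmm."""
    if milliseconds < 0:
        raise ValueError("non-negative timestamp expected")
    hours, rest = divmod(milliseconds, 3_600_000)
    minutes, rest = divmod(rest, 60_000)
    seconds, millis = divmod(rest, 1_000)
    hours_marker = f"{hours:02d}:" if always_include_hours or hours > 0 else ""
    return f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{millis:03d}"


def srt_cue(index: int, start_ms: int, end_ms: int, speaker: str, text: str) -> str:
    """Build one SRT cue; '-->' in the text is escaped so it is not read as a timing line."""
    start = ms_to_timestamp(start_ms, always_include_hours=True, decimal_marker=",")
    end = ms_to_timestamp(end_ms, always_include_hours=True, decimal_marker=",")
    body = text.strip().replace("-->", "->")
    return f"{index}\n{start} --> {end}\n{speaker}: {body}\n"


print(ms_to_timestamp(3_723_004, always_include_hours=True, decimal_marker=","))  # 01:02:03,004
print(srt_cue(1, 0, 1_500, "Speaker 0", " Hello there. "))
```

The `:02d` format codes require integers, which is why the timestamps are kept in whole milliseconds throughout.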


## Utility Functions

#### `find_numeral_symbol_tokens`
- Identifies tokens with numerals or currency symbols

#### `filter_missing_timestamps`
- Fills in missing timestamps in word-level transcriptions

#### `cleanup`
- Removes specified files or directories
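`find_numeral_symbol_tokens` scans the tokenizer vocabulary for any token containing a digit or one of `%`, `$`, `£`, and passes those IDs to Whisper as tokens to suppress. The sketch below runs the same scan on a toy vocabulary (the `Ġ` prefix mimics a BPE word-boundary marker; the IDs are made up). The notebook's helper also prepends `-1`, which in Whisper's decoding options stands for the default set of suppressed non-speech symbols.

```python
from typing import Dict, List

NUMERAL_SYMBOL_CHARS = "0123456789%$£"


def numeral_symbol_token_ids(vocab: Dict[str, int]) -> List[int]:
    """Return, in ascending order, the IDs of tokens containing a digit, %, $ or £."""
    return sorted(
        token_id
        for token, token_id in vocab.items()
        if any(c in NUMERAL_SYMBOL_CHARS for c in token)
    )


toy_vocab = {"Ġhello": 0, "Ġ42": 1, "%": 2, "Ġworld": 3, "Ġ$": 4, "Ġfive": 5}
print(numeral_symbol_token_ids(toy_vocab))  # [1, 2, 4]
```

With these tokens suppressed, the only way for the model to emit "42" is as words ("forty two"), which the wav2vec2 aligner can time.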


## Transcription Functions

#### `transcribe`
- Transcribes audio using Faster Whisper (non-batched)

#### `transcribe_batched`
- Transcribes audio using WhisperX's batched inference for higher throughput

In [4]:
punct_model_langs = [
    "en"
]
wav2vec2_langs = list(DEFAULT_ALIGN_MODELS_TORCH.keys()) + list(
    DEFAULT_ALIGN_MODELS_HF.keys()
)

whisper_langs = sorted(LANGUAGES.keys()) + sorted(
    [k.title() for k in TO_LANGUAGE_CODE.keys()]
)


def create_config(output_dir):
    DOMAIN_TYPE = "telephonic"  # Can be meeting, telephonic, or general based on domain type of the audio file
    CONFIG_FILE_NAME = f"diar_infer_{DOMAIN_TYPE}.yaml"
    CONFIG_URL = f"https://raw.githubusercontent.com/NVIDIA/NeMo/main/examples/speaker_tasks/diarization/conf/inference/{CONFIG_FILE_NAME}"
    MODEL_CONFIG = os.path.join(output_dir, CONFIG_FILE_NAME)
    if not os.path.exists(MODEL_CONFIG):
        MODEL_CONFIG = wget.download(CONFIG_URL, output_dir)

    config = OmegaConf.load(MODEL_CONFIG)

    data_dir = os.path.join(output_dir, "data")
    os.makedirs(data_dir, exist_ok=True)

    meta = {
        "audio_filepath": os.path.join(output_dir, "mono_file.wav"),
        "offset": 0,
        "duration": None,
        "label": "infer",
        "text": "-",
        "rttm_filepath": None,
        "uem_filepath": None,
    }
    with open(os.path.join(data_dir, "input_manifest.json"), "w") as fp:
        json.dump(meta, fp)
        fp.write("\n")

    pretrained_vad = "vad_multilingual_marblenet"
    pretrained_speaker_model = "titanet_large"
    config.num_workers = 0  # Workaround for multiprocessing hanging with ipython issue
    config.diarizer.manifest_filepath = os.path.join(data_dir, "input_manifest.json")
    config.diarizer.out_dir = (
        output_dir  # Directory to store intermediate files and prediction outputs
    )

    config.diarizer.speaker_embeddings.model_path = pretrained_speaker_model
    config.diarizer.oracle_vad = (
        False  # run the VAD model from vad.model_path instead of using reference (oracle) speech segments
    )
    config.diarizer.clustering.parameters.oracle_num_speakers = False

    # Use NVIDIA's pretrained multilingual MarbleNet VAD model
    config.diarizer.vad.model_path = pretrained_vad
    config.diarizer.vad.parameters.onset = 0.8
    config.diarizer.vad.parameters.offset = 0.6
    config.diarizer.vad.parameters.pad_offset = -0.05
    config.diarizer.msdd_model.model_path = (
        "diar_msdd_telephonic"  # Telephonic speaker diarization model
    )

    return config


def get_word_ts_anchor(s, e, option="start"):
    if option == "end":
        return e
    elif option == "mid":
        return (s + e) / 2
    return s


def get_words_speaker_mapping(wrd_ts, spk_ts, word_anchor_option="start"):
    s, e, sp = spk_ts[0]
    wrd_pos, turn_idx = 0, 0
    wrd_spk_mapping = []
    for wrd_dict in wrd_ts:
        ws, we, wrd = (
            int(wrd_dict["start"] * 1000),
            int(wrd_dict["end"] * 1000),
            wrd_dict["word"],
        )
        wrd_pos = get_word_ts_anchor(ws, we, word_anchor_option)
        while wrd_pos > float(e):
            turn_idx += 1
            turn_idx = min(turn_idx, len(spk_ts) - 1)
            s, e, sp = spk_ts[turn_idx]
            if turn_idx == len(spk_ts) - 1:
                e = get_word_ts_anchor(ws, we, option="end")
        wrd_spk_mapping.append(
            {"word": wrd, "start_time": ws, "end_time": we, "speaker": sp}
        )
    return wrd_spk_mapping


sentence_ending_punctuations = ".?!"


def get_first_word_idx_of_sentence(word_idx, word_list, speaker_list, max_words):
    is_word_sentence_end = (
        lambda x: x >= 0 and word_list[x][-1] in sentence_ending_punctuations
    )
    left_idx = word_idx
    while (
        left_idx > 0
        and word_idx - left_idx < max_words
        and speaker_list[left_idx - 1] == speaker_list[left_idx]
        and not is_word_sentence_end(left_idx - 1)
    ):
        left_idx -= 1

    return left_idx if left_idx == 0 or is_word_sentence_end(left_idx - 1) else -1


def get_last_word_idx_of_sentence(word_idx, word_list, max_words):
    is_word_sentence_end = (
        lambda x: x >= 0 and word_list[x][-1] in sentence_ending_punctuations
    )
    right_idx = word_idx
    while (
        right_idx < len(word_list) - 1
        and right_idx - word_idx < max_words
        and not is_word_sentence_end(right_idx)
    ):
        right_idx += 1

    return (
        right_idx
        if right_idx == len(word_list) - 1 or is_word_sentence_end(right_idx)
        else -1
    )


def get_realigned_ws_mapping_with_punctuation(
    word_speaker_mapping, max_words_in_sentence=50
):
    is_word_sentence_end = (
        lambda x: x >= 0
        and word_speaker_mapping[x]["word"][-1] in sentence_ending_punctuations
    )
    wsp_len = len(word_speaker_mapping)

    words_list, speaker_list = [], []
    for k, line_dict in enumerate(word_speaker_mapping):
        word, speaker = line_dict["word"], line_dict["speaker"]
        words_list.append(word)
        speaker_list.append(speaker)

    k = 0
    while k < len(word_speaker_mapping):
        line_dict = word_speaker_mapping[k]
        if (
            k < wsp_len - 1
            and speaker_list[k] != speaker_list[k + 1]
            and not is_word_sentence_end(k)
        ):
            left_idx = get_first_word_idx_of_sentence(
                k, words_list, speaker_list, max_words_in_sentence
            )
            right_idx = (
                get_last_word_idx_of_sentence(
                    k, words_list, max_words_in_sentence - k + left_idx - 1
                )
                if left_idx > -1
                else -1
            )
            if min(left_idx, right_idx) == -1:
                k += 1
                continue

            spk_labels = speaker_list[left_idx : right_idx + 1]
            mod_speaker = max(set(spk_labels), key=spk_labels.count)
            if spk_labels.count(mod_speaker) < len(spk_labels) // 2:
                k += 1
                continue

            speaker_list[left_idx : right_idx + 1] = [mod_speaker] * (
                right_idx - left_idx + 1
            )
            k = right_idx

        k += 1

    k, realigned_list = 0, []
    while k < len(word_speaker_mapping):
        line_dict = word_speaker_mapping[k].copy()
        line_dict["speaker"] = speaker_list[k]
        realigned_list.append(line_dict)
        k += 1

    return realigned_list


def get_sentences_speaker_mapping(word_speaker_mapping, spk_ts):
    sentence_checker = nltk.tokenize.PunktSentenceTokenizer().text_contains_sentbreak
    s, e, spk = spk_ts[0]
    prev_spk = spk

    snts = []
    snt = {"speaker": f"Speaker {spk}", "start_time": s, "end_time": e, "text": ""}

    for wrd_dict in word_speaker_mapping:
        wrd, spk = wrd_dict["word"], wrd_dict["speaker"]
        s, e = wrd_dict["start_time"], wrd_dict["end_time"]
        if spk != prev_spk or sentence_checker(snt["text"] + " " + wrd):
            snts.append(snt)
            snt = {
                "speaker": f"Speaker {spk}",
                "start_time": s,
                "end_time": e,
                "text": "",
            }
        else:
            snt["end_time"] = e
        snt["text"] += wrd + " "
        prev_spk = spk

    snts.append(snt)
    return snts


def get_speaker_aware_transcript(sentences_speaker_mapping, f):
    previous_speaker = sentences_speaker_mapping[0]["speaker"]
    f.write(f"{previous_speaker}: ")

    for sentence_dict in sentences_speaker_mapping:
        speaker = sentence_dict["speaker"]
        sentence = sentence_dict["text"]

        # If this speaker doesn't match the previous one, start a new paragraph
        if speaker != previous_speaker:
            f.write(f"\n\n{speaker}: ")
            previous_speaker = speaker

        # No matter what, write the current sentence
        f.write(sentence + " ")


def format_timestamp(
    milliseconds: int, always_include_hours: bool = False, decimal_marker: str = "."
):
    assert milliseconds >= 0, "non-negative timestamp expected"

    hours = milliseconds // 3_600_000
    milliseconds -= hours * 3_600_000

    minutes = milliseconds // 60_000
    milliseconds -= minutes * 60_000

    seconds = milliseconds // 1_000
    milliseconds -= seconds * 1_000

    hours_marker = f"{hours:02d}:" if always_include_hours or hours > 0 else ""
    return (
        f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}"
    )


def write_srt(transcript, file):
    """
    Write a transcript to a file in SRT format.

    """
    for i, segment in enumerate(transcript, start=1):
        # write srt lines
        print(
            f"{i}\n"
            f"{format_timestamp(segment['start_time'], always_include_hours=True, decimal_marker=',')} --> "
            f"{format_timestamp(segment['end_time'], always_include_hours=True, decimal_marker=',')}\n"
            f"{segment['speaker']}: {segment['text'].strip().replace('-->', '->')}\n",
            file=file,
            flush=True,
        )


def find_numeral_symbol_tokens(tokenizer):
    numeral_symbol_tokens = [
        -1,
    ]
    for token, token_id in tokenizer.get_vocab().items():
        has_numeral_symbol = any(c in "0123456789%$£" for c in token)
        if has_numeral_symbol:
            numeral_symbol_tokens.append(token_id)
    return numeral_symbol_tokens


def _get_next_start_timestamp(word_timestamps, current_word_index, final_timestamp):
    # if current word is the last word
    if current_word_index == len(word_timestamps) - 1:
        return word_timestamps[current_word_index]["start"]

    next_word_index = current_word_index + 1
    while current_word_index < len(word_timestamps) - 1:
        if word_timestamps[next_word_index].get("start") is None:
            # if next word doesn't have a start timestamp
            # merge it with the current word and delete it
            word_timestamps[current_word_index]["word"] += (
                " " + word_timestamps[next_word_index]["word"]
            )

            word_timestamps[next_word_index]["word"] = None
            next_word_index += 1
            if next_word_index == len(word_timestamps):
                return final_timestamp

        else:
            return word_timestamps[next_word_index]["start"]


def filter_missing_timestamps(
    word_timestamps, initial_timestamp=0, final_timestamp=None
):
    # handle the first and last word
    if word_timestamps[0].get("start") is None:
        word_timestamps[0]["start"] = (
            initial_timestamp if initial_timestamp is not None else 0
        )
        word_timestamps[0]["end"] = _get_next_start_timestamp(
            word_timestamps, 0, final_timestamp
        )

    result = [
        word_timestamps[0],
    ]

    for i, ws in enumerate(word_timestamps[1:], start=1):
        # if ws doesn't have a start and end
        # use the previous end as start and next start as end
        if ws.get("start") is None and ws.get("word") is not None:
            ws["start"] = word_timestamps[i - 1]["end"]
            ws["end"] = _get_next_start_timestamp(word_timestamps, i, final_timestamp)

        if ws["word"] is not None:
            result.append(ws)
    return result


def cleanup(path: str):
    """path could either be relative or absolute."""
    # check if file or directory exists
    if os.path.isfile(path) or os.path.islink(path):
        # remove file
        os.remove(path)
    elif os.path.isdir(path):
        # remove directory and all its content
        shutil.rmtree(path)
    else:
        raise ValueError("Path {} is not a file or dir.".format(path))


def process_language_arg(language: str, model_name: str):
    """
    Process the language argument to make sure it's valid and convert language names to language codes.
    """
    if language is not None:
        language = language.lower()
    if language is not None and language not in LANGUAGES:
        if language in TO_LANGUAGE_CODE:
            language = TO_LANGUAGE_CODE[language]
        else:
            raise ValueError(f"Unsupported language: {language}")

    if model_name.endswith(".en") and language != "en":
        if language is not None:
            logging.warning(
                f"{model_name} is an English-only model but received '{language}'; using English instead."
            )
        language = "en"
    return language


def transcribe(
    audio_file: str,
    language: str,
    model_name: str,
    compute_dtype: str,
    suppress_numerals: bool,
    device: str,
):
    from faster_whisper import WhisperModel

    # Faster Whisper non-batched, e.g. FP16 on GPU (compute_dtype="float16")
    whisper_model = WhisperModel(model_name, device=device, compute_type=compute_dtype)

    # or INT8 on GPU: compute_dtype="int8_float16"
    # or INT8 on CPU: compute_dtype="int8"

    if suppress_numerals:
        numeral_symbol_tokens = find_numeral_symbol_tokens(whisper_model.hf_tokenizer)
    else:
        numeral_symbol_tokens = None

    if language is not None and language in wav2vec2_langs:
        word_timestamps = False
    else:
        word_timestamps = True

    segments, info = whisper_model.transcribe(
        audio_file,
        language=language,
        beam_size=5,
        word_timestamps=word_timestamps,  # only requested when no wav2vec2 alignment model exists
        suppress_tokens=numeral_symbol_tokens,
        vad_filter=True,
    )
    whisper_results = []
    for segment in segments:
        whisper_results.append(segment._asdict())
    # clear gpu vram
    del whisper_model
    torch.cuda.empty_cache()
    return whisper_results, language


def transcribe_batched(
    audio_file: str,
    language: str,
    batch_size: int,
    model_name: str,
    compute_dtype: str,
    suppress_numerals: bool,
    device: str,
):
    import whisperx

    # Faster Whisper batched
    whisper_model = whisperx.load_model(
        model_name,
        device,
        compute_type=compute_dtype,
        asr_options={"suppress_numerals": suppress_numerals},
    )
    audio = whisperx.load_audio(audio_file)
    result = whisper_model.transcribe(audio, language=language, batch_size=batch_size)
    del whisper_model
    torch.cuda.empty_cache()
    return result["segments"], result["language"]

# 2. Audio Preprocessing: Isolating Speech with Demucs

Many real-world recordings mix speech with background music, which degrades both transcription and speaker identification. To reduce this interference, the pipeline can run Demucs as a preprocessing step.

## Introducing Demucs

Demucs is a deep learning model for music source separation. Here it is used to extract the vocal (speech) stem from the mixed recording.

### Key Features:
- Separates vocals from the instrumental background
- The `htdemucs` model used in this notebook is Hybrid Transformer Demucs, which processes both the waveform and its spectrogram
- Ships as pretrained models, so no training is needed here

## Benefits for Diarization and Transcription

Adding Demucs to preprocessing offers several advantages:

1. **Enhanced Speech Clarity**: Removes musical interference, making speaker voices more distinct.
2. **Improved Diarization Accuracy**: Allows for more precise speaker attribution by focusing on clean vocal signals.
3. **Better Transcription Results**: Provides cleaner input for speech recognition models like Whisper.

## The Process

1. Input mixed audio file
2. Run Demucs with `--two-stems=vocals`, which writes a `vocals.wav` and a `no_vocals.wav` stem
3. Extract the isolated vocal track
4. Feed clean speech to subsequent diarization and transcription steps
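The steps above rely on Demucs' output layout, `<out_dir>/<model>/<track name>/vocals.wav`, which is what the stemming code in this section assumes. The sketch below builds the same command as an argument list (suitable for `subprocess.run`, which avoids shell-quoting problems with spaces in paths) and computes the expected vocal-stem path; `demucs_command` and `vocals_path` are hypothetical helpers.

```python
import os
import sys
from typing import List


def demucs_command(audio_path: str, out_dir: str, model: str = "htdemucs") -> List[str]:
    """Argument list for a two-stem (vocals / no_vocals) Demucs run."""
    return [sys.executable, "-m", "demucs.separate", "-n", model,
            "--two-stems=vocals", audio_path, "-o", out_dir]


def vocals_path(audio_path: str, out_dir: str, model: str = "htdemucs") -> str:
    """Expected location of the vocal stem for `audio_path`."""
    track = os.path.splitext(os.path.basename(audio_path))[0]
    return os.path.join(out_dir, model, track, "vocals.wav")


cmd = demucs_command("/content/audio.mp3", "temp_outputs")
print(" ".join(cmd[1:]))
print(vocals_path("/content/audio.mp3", "temp_outputs"))
```

In the notebook this could be run as `subprocess.run(cmd).returncode`, keeping the same fallback to the original audio when the return code is non-zero.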

## Impact on Workflow

Demucs is most useful for:
- Podcast transcriptions with background music
- Interview recordings in musical settings
- Any audio content where speech and music coexist

Cleaner input to the diarization and transcription models tends to give more accurate results downstream. If separation fails, the code below logs a warning and falls back to the original audio file.

In [5]:
if enable_stemming:
    # Isolate vocals from the rest of the audio

    return_code = os.system(
        f'python3 -m demucs.separate -n htdemucs --two-stems=vocals "{audio_path}" -o "temp_outputs"'
    )

    if return_code != 0:
        logging.warning("Source splitting failed, using original audio file.")
        vocal_target = audio_path
    else:
        vocal_target = os.path.join(
            "temp_outputs",
            "htdemucs",
            os.path.splitext(os.path.basename(audio_path))[0],
            "vocals.wav",
        )
else:
    vocal_target = audio_path

# 3. Using WhisperX for Audio Transcription

In this section, we transcribe the audio with Whisper, which supports a wide range of languages. With `batch_size` greater than 0, the notebook uses WhisperX's batched inference; with `batch_size = 0`, it uses non-batched faster-whisper instead.

The transcription step produces text segments, each with a start and end timestamp. These segments supply the text that the later steps align to the audio, attribute to speakers, and format.


In [6]:
compute_type = "float16"
# or run on GPU with INT8
# compute_type = "int8_float16"
# or run on CPU with INT8
# compute_type = "int8"

if batch_size != 0:
    whisper_results, language = transcribe_batched(
        vocal_target,
        language,
        batch_size,
        whisper_model_name,
        compute_type,
        suppress_numerals,
        device,
    )
else:
    whisper_results, language = transcribe(
        vocal_target,
        language,
        whisper_model_name,
        compute_type,
        suppress_numerals,
        device,
    )


No language specified, language will be first be detected for each audio file (increases inference time).


100%|█████████████████████████████████████| 16.9M/16.9M [00:01<00:00, 16.9MiB/s]
INFO:pytorch_lightning.utilities.migration.utils:Lightning automatically upgraded your loaded checkpoint from v1.5.4 to v2.0.7. To apply the upgrade to your files permanently, run `python -m pytorch_lightning.utilities.upgrade_checkpoint --file ../root/.cache/torch/whisperx-vad-segmentation.bin`


Model was trained with pyannote.audio 0.0.1, yours is 3.1.1. Bad things might happen unless you revert pyannote.audio to 0.x.
Model was trained with torch 1.10.0+cu102, yours is 2.3.1+cu121. Bad things might happen unless you revert torch to 1.x.
Suppressing numeral and symbol tokens
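The transcription cell hard-codes `compute_type = "float16"` even though `device` may be `"cpu"`. A small sketch of choosing both the compute type and the transcription function from `device` and `batch_size`, using only the options listed in that cell's comments; `transcription_settings` is a hypothetical helper.

```python
from typing import Dict


def transcription_settings(device: str, batch_size: int) -> Dict[str, str]:
    """Suggest a CTranslate2 compute type and which notebook function to call."""
    # FP16 on GPU as in the notebook; INT8 on CPU per the commented alternative
    compute_type = "float16" if device.startswith("cuda") else "int8"
    # batch_size == 0 selects the non-batched faster-whisper path
    function = "transcribe_batched" if batch_size != 0 else "transcribe"
    return {"compute_type": compute_type, "function": function}


print(transcription_settings("cpu", 8))
```

On a GPU with limited memory, `"int8_float16"` (also listed in the cell) is a middle ground between the two.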


# 4. Aligning Transcriptions with Original Audio Using Wav2Vec2

In this section, we use a pretrained Wav2Vec2 model (a self-supervised speech model fine-tuned for character-level recognition) to force-align the transcript with the audio.

The process begins by loading the Wav2Vec2 alignment model for `language`. WhisperX then aligns each transcription segment against the audio in the `vocal_target` file, producing a start and end timestamp for every word rather than only for whole segments.

Combining Whisper's text with Wav2Vec2's word timestamps yields a word-level aligned transcript of `vocal_target`. These word timestamps are what the later steps use to assign each word to a speaker turn from the diarization output.

If no Wav2Vec2 model is available for the language, the word timestamps generated by Whisper are used instead. This requires `batch_size = 0`, because only the non-batched faster-whisper path requests word timestamps from Whisper.
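Alignment can leave some words without timestamps (for example, tokens whose characters are missing from the alignment model's dictionary). The sketch below shows a simplified version of the gap-filling that `filter_missing_timestamps` performs: an untimed word starts where the previous word ended and ends where the next timed word starts. Unlike the notebook's helper, it does not merge runs of untimed words into one entry; `fill_missing_timestamps` is a hypothetical name.

```python
from typing import Dict, List, Optional


def fill_missing_timestamps(
    words: List[Dict[str, object]], initial: float = 0.0, final: Optional[float] = None
) -> List[Dict[str, object]]:
    """Give untimed words a span from the previous word's end to the next timed word's start."""
    filled = [dict(w) for w in words]  # copy so the caller's list is untouched
    for i, w in enumerate(filled):
        if w.get("start") is not None:
            continue
        w["start"] = filled[i - 1]["end"] if i > 0 else initial
        # Next word that has its own start time, if any
        next_start = next(
            (x["start"] for x in filled[i + 1 :] if x.get("start") is not None), None
        )
        if next_start is not None:
            w["end"] = next_start
        else:
            w["end"] = final if final is not None else w["start"]
    return filled


words = [
    {"word": "In", "start": 0.0, "end": 0.2},
    {"word": "2024"},  # no timestamps from the aligner
    {"word": "we", "start": 1.1, "end": 1.3},
]
result = fill_missing_timestamps(words)
print(result[1])
```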


In [7]:
if language in wav2vec2_langs:
    alignment_model, metadata = whisperx.load_align_model(
        language_code=language, device=device
    )
    result_aligned = whisperx.align(
        whisper_results, alignment_model, metadata, vocal_target, device
    )
    word_timestamps = filter_missing_timestamps(
        result_aligned["word_segments"],
        initial_timestamp=whisper_results[0].get("start"),
        final_timestamp=whisper_results[-1].get("end"),
    )

    # clear gpu vram
    del alignment_model
    torch.cuda.empty_cache()
else:
    assert batch_size == 0, (  # TODO: add a better check for word timestamps existence
        f"Unsupported language: {language}, set batch_size to 0"
        " to generate word timestamps using Whisper directly and fix this error."
    )
    word_timestamps = []
    for segment in whisper_results:
        for word in segment["words"]:
            word_timestamps.append({"word": word[2], "start": word[0], "end": word[1]})

Downloading: "https://download.pytorch.org/torchaudio/models/wav2vec2_fairseq_base_ls960_asr_ls960.pth" to /root/.cache/torch/hub/checkpoints/wav2vec2_fairseq_base_ls960_asr_ls960.pth
100%|██████████| 360M/360M [00:02<00:00, 129MB/s]


In [10]:
with open("transcription.txt", "w", encoding="utf-8") as txt:
    for result in whisper_results:
        # One brace-delimited block per segment: text, blank line, start, end
        txt.write(
            "{\n"
            f"  text:{result['text']}\n"
            "\n"
            f"  start: {result['start']}\n"
            f"  end: {result['end']}\n"
            "}\n"
        )
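The brace-delimited text file above is readable but not valid JSON, so it is awkward to load back. A hedged alternative is to dump the same three fields with the standard `json` module; `write_segments_json` and the file name are hypothetical.

```python
import json
import os
import tempfile
from typing import Dict, List


def write_segments_json(segments: List[Dict[str, object]], path: str) -> None:
    """Save start, end and stripped text of each segment as a JSON array."""
    records = [
        {"start": seg["start"], "end": seg["end"], "text": seg["text"].strip()}
        for seg in segments
    ]
    with open(path, "w", encoding="utf-8") as fp:
        json.dump(records, fp, ensure_ascii=False, indent=2)


demo = [{"start": 0.0, "end": 1.5, "text": " Hello there."}]
with tempfile.TemporaryDirectory() as tmp:
    out_path = os.path.join(tmp, "transcription.json")
    write_segments_json(demo, out_path)
    with open(out_path, encoding="utf-8") as fp:
        loaded = json.load(fp)
print(loaded)
```

In the notebook this would be called as `write_segments_json(whisper_results, "transcription.json")`; `ensure_ascii=False` keeps non-English text readable in the file.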

# 5. Using NeMo's MSDD Model for Speaker Diarization

In this section, we use the NVIDIA NeMo MSDD (Multi-scale Diarization Decoder) model to carry out speaker diarization on an audio signal. Speaker diarization involves segmenting the audio into different parts based on the speaker at any given time.
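The diarizer's result is typically stored as an RTTM file, one `SPEAKER` line per turn with whitespace-separated fields: type, file ID, channel, onset (s), duration (s), two placeholder fields, speaker name, and two more placeholders. The sketch below converts such lines into `(start_ms, end_ms, speaker)` tuples, the shape `get_words_speaker_mapping` expects for `spk_ts`. It assumes NeMo labels speakers as `speaker_N` and that the prediction lands in `pred_rttms/mono_file.rttm` under the output directory; `parse_rttm` and the sample lines are illustrative.

```python
from typing import List, Tuple


def parse_rttm(lines: List[str]) -> List[Tuple[int, int, int]]:
    """Convert RTTM SPEAKER lines to (start_ms, end_ms, speaker_index) tuples sorted by start."""
    turns = []
    for line in lines:
        fields = line.split()  # split on any run of whitespace
        if not fields or fields[0] != "SPEAKER":
            continue
        onset, duration = float(fields[3]), float(fields[4])
        start_ms = round(onset * 1000)
        end_ms = start_ms + round(duration * 1000)
        speaker = int(fields[7].split("_")[-1])  # assumes labels like "speaker_0"
        turns.append((start_ms, end_ms, speaker))
    return sorted(turns)


sample = [
    "SPEAKER mono_file 1   1.300   2.100 <NA> <NA> speaker_1 <NA> <NA>",
    "SPEAKER mono_file 1   0.000   1.250 <NA> <NA> speaker_0 <NA> <NA>",
]
speaker_ts = parse_rttm(sample)
print(speaker_ts)  # [(0, 1250, 0), (1300, 3400, 1)]
```

Under the path assumption above, usage would be `with open(os.path.join(temp_path, "pred_rttms", "mono_file.rttm")) as f: speaker_ts = parse_rttm(f.readlines())`.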


In [8]:
# Convert audio to mono for NeMo compatibility
sound = AudioSegment.from_file(vocal_target).set_channels(1)
ROOT = os.getcwd()
temp_path = os.path.join(ROOT, "temp_outputs")
os.makedirs(temp_path, exist_ok=True)
sound.export(os.path.join(temp_path, "mono_file.wav"), format="wav")


In [9]:
# Initialize NeMo MSDD diarization model
msdd_model = NeuralDiarizer(cfg=create_config(temp_path)).to(device)
msdd_model.diarize()

del msdd_model
torch.cuda.empty_cache()

[NeMo I 2024-07-22 21:14:26 msdd_models:1092] Loading pretrained diar_msdd_telephonic model from NGC
[NeMo I 2024-07-22 21:14:26 cloud:68] Downloading from: https://api.ngc.nvidia.com/v2/models/nvidia/nemo/diar_msdd_telephonic/versions/1.0.1/files/diar_msdd_telephonic.nemo to /root/.cache/torch/NeMo/NeMo_1.22.0/diar_msdd_telephonic/3c3697a0a46f945574fa407149975a13/diar_msdd_telephonic.nemo
[NeMo I 2024-07-22 21:14:28 common:913] Instantiating model from pre-trained checkpoint


[NeMo W 2024-07-22 21:14:29 modelPT:161] If you intend to do training or fine-tuning, please call the ModelPT.setup_training_data() method and provide a valid configuration file to setup the train data loader.
    Train config : 
    manifest_filepath: null
    emb_dir: null
    sample_rate: 16000
    num_spks: 2
    soft_label_thres: 0.5
    labels: null
    batch_size: 15
    emb_batch_size: 0
    shuffle: true
    
[NeMo W 2024-07-22 21:14:29 modelPT:168] If you intend to do validation, please call the ModelPT.setup_validation_data() or ModelPT.setup_multiple_validation_data() method and provide a valid configuration file to setup the validation data loader(s). 
    Validation config : 
    manifest_filepath: null
    emb_dir: null
    sample_rate: 16000
    num_spks: 2
    soft_label_thres: 0.5
    labels: null
    batch_size: 15
    emb_batch_size: 0
    shuffle: false
    
[NeMo W 2024-07-22 21:14:29 modelPT:174] Please call the ModelPT.setup_test_data() or ModelPT.setup_multiple

[NeMo I 2024-07-22 21:14:29 features:289] PADDING: 16
[NeMo I 2024-07-22 21:14:29 features:289] PADDING: 16
[NeMo I 2024-07-22 21:14:30 save_restore_connector:249] Model EncDecDiarLabelModel was successfully restored from /root/.cache/torch/NeMo/NeMo_1.22.0/diar_msdd_telephonic/3c3697a0a46f945574fa407149975a13/diar_msdd_telephonic.nemo.
[NeMo I 2024-07-22 21:14:30 features:289] PADDING: 16
[NeMo I 2024-07-22 21:14:31 clustering_diarizer:127] Loading pretrained vad_multilingual_marblenet model from NGC
[NeMo I 2024-07-22 21:14:31 cloud:68] Downloading from: https://api.ngc.nvidia.com/v2/models/nvidia/nemo/vad_multilingual_marblenet/versions/1.10.0/files/vad_multilingual_marblenet.nemo to /root/.cache/torch/NeMo/NeMo_1.22.0/vad_multilingual_marblenet/670f425c7f186060b7a7268ba6dfacb2/vad_multilingual_marblenet.nemo
[NeMo I 2024-07-22 21:14:31 common:913] Instantiating model from pre-trained checkpoint


[NeMo W 2024-07-22 21:14:31 modelPT:161] If you intend to do training or fine-tuning, please call the ModelPT.setup_training_data() method and provide a valid configuration file to setup the train data loader.
    Train config : 
    manifest_filepath: /manifests/ami_train_0.63.json,/manifests/freesound_background_train.json,/manifests/freesound_laughter_train.json,/manifests/fisher_2004_background.json,/manifests/fisher_2004_speech_sampled.json,/manifests/google_train_manifest.json,/manifests/icsi_all_0.63.json,/manifests/musan_freesound_train.json,/manifests/musan_music_train.json,/manifests/musan_soundbible_train.json,/manifests/mandarin_train_sample.json,/manifests/german_train_sample.json,/manifests/spanish_train_sample.json,/manifests/french_train_sample.json,/manifests/russian_train_sample.json
    sample_rate: 16000
    labels:
    - background
    - speech
    batch_size: 256
    shuffle: true
    is_tarred: false
    tarred_audio_filepaths: null
    tarred_shard_strategy: sca

[NeMo I 2024-07-22 21:14:31 features:289] PADDING: 16
[NeMo I 2024-07-22 21:14:31 save_restore_connector:249] Model EncDecClassificationModel was successfully restored from /root/.cache/torch/NeMo/NeMo_1.22.0/vad_multilingual_marblenet/670f425c7f186060b7a7268ba6dfacb2/vad_multilingual_marblenet.nemo.
[NeMo I 2024-07-22 21:14:31 msdd_models:864] Multiscale Weights: [1, 1, 1, 1, 1]
[NeMo I 2024-07-22 21:14:31 msdd_models:865] Clustering Parameters: {
        "oracle_num_speakers": false,
        "max_num_speakers": 8,
        "enhanced_count_thres": 80,
        "max_rp_threshold": 0.25,
        "sparse_search_volume": 30,
        "maj_vote_spk_count": false,
        "chunk_cluster_count": 50,
        "embeddings_per_chunk": 10000
    }
[NeMo I 2024-07-22 21:14:31 speaker_utils:93] Number of files to diarize: 1
[NeMo I 2024-07-22 21:14:31 clustering_diarizer:309] Split long audio file to avoid CUDA memory issue


splitting manifest: 100%|██████████| 1/1 [00:16<00:00, 16.38s/it]

[NeMo I 2024-07-22 21:14:48 classification_models:273] Perform streaming frame-level VAD
[NeMo I 2024-07-22 21:14:48 collections:445] Filtered duration for loading collection is  0.00 hours.
[NeMo I 2024-07-22 21:14:48 collections:446] Dataset loaded with 12 items, total duration of  0.17 hours.
[NeMo I 2024-07-22 21:14:48 collections:448] # 12 files loaded accounting to # 1 labels

vad: 100%|██████████| 12/12 [00:04<00:00,  2.87it/s]

[NeMo I 2024-07-22 21:14:52 clustering_diarizer:250] Generating predictions with overlapping input segments

[NeMo I 2024-07-22 21:15:00 clustering_diarizer:262] Converting frame level prediction to speech/no-speech segment in start and end times format.


creating speech segments: 100%|██████████| 1/1 [00:00<00:00,  1.84it/s]

[NeMo I 2024-07-22 21:15:00 clustering_diarizer:287] Subsegmentation for embedding extraction: scale0, /content/temp_outputs/speaker_outputs/subsegments_scale0.json
[NeMo I 2024-07-22 21:15:00 clustering_diarizer:343] Extracting embeddings for Diarization
[NeMo I 2024-07-22 21:15:00 collections:445] Filtered duration for loading collection is  0.00 hours.
[NeMo I 2024-07-22 21:15:00 collections:446] Dataset loaded with 565 items, total duration of  0.21 hours.
[NeMo I 2024-07-22 21:15:00 collections:448] # 565 files loaded accounting to # 1 labels

[1/5] extract embeddings: 100%|██████████| 9/9 [00:02<00:00,  4.28it/s]


[NeMo I 2024-07-22 21:15:02 clustering_diarizer:389] Saved embedding files to /content/temp_outputs/speaker_outputs/embeddings
[NeMo I 2024-07-22 21:15:02 clustering_diarizer:287] Subsegmentation for embedding extraction: scale1, /content/temp_outputs/speaker_outputs/subsegments_scale1.json
[NeMo I 2024-07-22 21:15:03 clustering_diarizer:343] Extracting embeddings for Diarization
[NeMo I 2024-07-22 21:15:03 collections:445] Filtered duration for loading collection is  0.00 hours.
[NeMo I 2024-07-22 21:15:03 collections:446] Dataset loaded with 683 items, total duration of  0.21 hours.
[NeMo I 2024-07-22 21:15:03 collections:448] # 683 files loaded accounting to # 1 labels


[2/5] extract embeddings: 100%|██████████| 11/11 [00:01<00:00,  5.58it/s]

[NeMo I 2024-07-22 21:15:05 clustering_diarizer:389] Saved embedding files to /content/temp_outputs/speaker_outputs/embeddings
[NeMo I 2024-07-22 21:15:05 clustering_diarizer:287] Subsegmentation for embedding extraction: scale2, /content/temp_outputs/speaker_outputs/subsegments_scale2.json
[NeMo I 2024-07-22 21:15:05 clustering_diarizer:343] Extracting embeddings for Diarization
[NeMo I 2024-07-22 21:15:05 collections:445] Filtered duration for loading collection is  0.00 hours.
[NeMo I 2024-07-22 21:15:05 collections:446] Dataset loaded with 867 items, total duration of  0.22 hours.
[NeMo I 2024-07-22 21:15:05 collections:448] # 867 files loaded accounting to # 1 labels


[3/5] extract embeddings: 100%|██████████| 14/14 [00:02<00:00,  6.03it/s]


[NeMo I 2024-07-22 21:15:07 clustering_diarizer:389] Saved embedding files to /content/temp_outputs/speaker_outputs/embeddings
[NeMo I 2024-07-22 21:15:07 clustering_diarizer:287] Subsegmentation for embedding extraction: scale3, /content/temp_outputs/speaker_outputs/subsegments_scale3.json
[NeMo I 2024-07-22 21:15:07 clustering_diarizer:343] Extracting embeddings for Diarization
[NeMo I 2024-07-22 21:15:07 collections:445] Filtered duration for loading collection is  0.00 hours.
[NeMo I 2024-07-22 21:15:07 collections:446] Dataset loaded with 1166 items, total duration of  0.23 hours.
[NeMo I 2024-07-22 21:15:07 collections:448] # 1166 files loaded accounting to # 1 labels


[4/5] extract embeddings: 100%|██████████| 19/19 [00:03<00:00,  5.11it/s]


[NeMo I 2024-07-22 21:15:11 clustering_diarizer:389] Saved embedding files to /content/temp_outputs/speaker_outputs/embeddings
[NeMo I 2024-07-22 21:15:11 clustering_diarizer:287] Subsegmentation for embedding extraction: scale4, /content/temp_outputs/speaker_outputs/subsegments_scale4.json
[NeMo I 2024-07-22 21:15:11 clustering_diarizer:343] Extracting embeddings for Diarization
[NeMo I 2024-07-22 21:15:11 collections:445] Filtered duration for loading collection is  0.00 hours.
[NeMo I 2024-07-22 21:15:11 collections:446] Dataset loaded with 1798 items, total duration of  0.24 hours.
[NeMo I 2024-07-22 21:15:11 collections:448] # 1798 files loaded accounting to # 1 labels


[5/5] extract embeddings: 100%|██████████| 29/29 [00:03<00:00,  7.41it/s]


[NeMo I 2024-07-22 21:15:15 clustering_diarizer:389] Saved embedding files to /content/temp_outputs/speaker_outputs/embeddings


clustering: 100%|██████████| 1/1 [00:01<00:00,  1.40s/it]

[NeMo I 2024-07-22 21:15:16 clustering_diarizer:464] Outputs are saved in /content/temp_outputs directory

[NeMo W 2024-07-22 21:15:16 der:185] Check if each ground truth RTTMs were present in the provided manifest file. Skipping calculation of Diariazation Error Rate


[NeMo I 2024-07-22 21:15:17 msdd_models:960] Loading embedding pickle file of scale:0 at /content/temp_outputs/speaker_outputs/embeddings/subsegments_scale0_embeddings.pkl
[NeMo I 2024-07-22 21:15:17 msdd_models:960] Loading embedding pickle file of scale:1 at /content/temp_outputs/speaker_outputs/embeddings/subsegments_scale1_embeddings.pkl
[NeMo I 2024-07-22 21:15:17 msdd_models:960] Loading embedding pickle file of scale:2 at /content/temp_outputs/speaker_outputs/embeddings/subsegments_scale2_embeddings.pkl
[NeMo I 2024-07-22 21:15:17 msdd_models:960] Loading embedding pickle file of scale:3 at /content/temp_outputs/speaker_outputs/embeddings/subsegments_scale3_embeddings.pkl
[NeMo I 2024-07-22 21:15:17 msdd_models:960] Loading embedding pickle file of scale:4 at /content/temp_outputs/speaker_outputs/embeddings/subsegments_scale4_embeddings.pkl
[NeMo I 2024-07-22 21:15:17 msdd_models:938] Loading cluster label file from /content/temp_outputs/speaker_outputs/subsegments_scale4_cluste

100%|██████████| 1/1 [00:00<00:00,  8.28it/s]

[NeMo I 2024-07-22 21:15:17 msdd_models:1403]      [Threshold: 0.7000] [use_clus_as_main=False] [diar_window=50]
[NeMo I 2024-07-22 21:15:17 speaker_utils:93] Number of files to diarize: 1
[NeMo I 2024-07-22 21:15:17 speaker_utils:93] Number of files to diarize: 1

[NeMo W 2024-07-22 21:15:17 der:185] Check if each ground truth RTTMs were present in the provided manifest file. Skipping calculation of Diariazation Error Rate


[NeMo I 2024-07-22 21:15:17 speaker_utils:93] Number of files to diarize: 1


[NeMo W 2024-07-22 21:15:17 der:185] Check if each ground truth RTTMs were present in the provided manifest file. Skipping calculation of Diariazation Error Rate


[NeMo I 2024-07-22 21:15:17 speaker_utils:93] Number of files to diarize: 1


[NeMo W 2024-07-22 21:15:17 der:185] Check if each ground truth RTTMs were present in the provided manifest file. Skipping calculation of Diariazation Error Rate


[NeMo I 2024-07-22 21:15:17 msdd_models:1431]   
    


# 6. Mapping Speakers to Words According to Timestamps

This section reads the speaker turns that the NeMo MSDD model wrote to an RTTM file (`pred_rttms/mono_file.rttm`), converts each turn into a `[start_ms, end_ms, speaker_id]` entry, and passes these turns to the `get_words_speaker_mapping` function, which assigns each transcribed word to a speaker based on its timestamp.

The result, `wsm` (word-speaker mapping), is a list of dictionaries, one per word, each holding the word, its speaker, and its start and end times. This word-level attribution records who said what and when, and it is the input to the sentence-level processing that follows.
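The mapping step can be illustrated with a simplified sketch. The function `map_words_to_speakers` is hypothetical and is not the notebook's `get_words_speaker_mapping`, whose internals may differ. It assumes each word dictionary has `"word"`, `"start"` and `"end"` keys (times in seconds, as in WhisperX's aligned word segments), that words and speaker turns are both sorted by time, and that speaker turns use the `[start_ms, end_ms, speaker_id]` layout built in the next cell. Each word is assigned to the first turn that has not ended before the word's anchor point.

```python
from typing import Dict, List


def map_words_to_speakers(
    word_timestamps: List[Dict],
    speaker_ts: List[List[int]],
    anchor: str = "start",
) -> List[Dict]:
    """Hypothetical sketch: label each word with the speaker whose turn covers it.

    Assumes word_timestamps entries have "word", "start", "end" (seconds) and
    speaker_ts entries are [start_ms, end_ms, speaker_id], both sorted by time.
    """
    if anchor not in ("start", "mid", "end"):
        raise ValueError(f"unknown anchor: {anchor!r}")
    if not speaker_ts:
        raise ValueError("speaker_ts must contain at least one turn")

    mapping = []
    turn = 0
    for w in word_timestamps:
        start_ms = round(w["start"] * 1000)
        end_ms = round(w["end"] * 1000)
        # Choose the point in time that decides which turn the word belongs to
        if anchor == "start":
            point = start_ms
        elif anchor == "mid":
            point = (start_ms + end_ms) // 2
        else:
            point = end_ms
        # Skip past turns that ended before this point (never beyond the last turn)
        while turn < len(speaker_ts) - 1 and point > speaker_ts[turn][1]:
            turn += 1
        mapping.append(
            {
                "word": w["word"],
                "start_time": start_ms,
                "end_time": end_ms,
                "speaker": speaker_ts[turn][2],
            }
        )
    return mapping
```

A word that falls in a gap between two turns goes to the next turn, and words after the last turn keep the last speaker. Anchoring on `"mid"` or `"end"` only changes where a word that straddles a turn boundary lands.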


In [11]:
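# Read the speaker turns (timestamps <> speaker labels) from NeMo's RTTM output.
# RTTM fields (whitespace-separated): type, file id, channel, onset (s),
# duration (s), <NA>, <NA>, speaker label, ...
speaker_ts = []
with open(os.path.join(temp_path, "pred_rttms", "mono_file.rttm"), "r") as f:
    for line in f:
        fields = line.split()  # split() tolerates the runs of spaces NeMo writes
        if not fields:
            continue
        start_ms = int(float(fields[3]) * 1000)
        end_ms = start_ms + int(float(fields[4]) * 1000)
        speaker_id = int(fields[7].split("_")[-1])  # "speaker_0" -> 0
        speaker_ts.append([start_ms, end_ms, speaker_id])

# Assign each word to a speaker, using the word's start time as its anchor
wsm = get_words_speaker_mapping(word_timestamps, speaker_ts, "start")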

# 7. Enhancing Speaker Attribution with Punctuation-Based Realignment

This section resolves ambiguous speaker labels, particularly when a single sentence is split between two speakers. Using punctuation to locate sentence boundaries, the code identifies the dominant speaker of each sentence in the transcription.

Consider the following scenario:

```
Speaker A: It's got to come from somewhere else. Yeah, that one's also fun because you know the lows are
Speaker B: going to suck, right? So it's actually it hits you on both sides.
```

When a sentence is divided between two speakers, the code counts the speaker labels of the words in that sentence and assigns the most frequent label to the entire sentence. This matters because Whisper can omit short interjections such as "hmm" and "yeah" that the NeMo diarization model still detects as separate speaker turns; the words surrounding such a turn can then be attributed to the wrong speaker.

The same rule handles a long monologue interrupted by brief remarks from other speakers. Because the monologue's words dominate each sentence, the background comments are disregarded and the whole monologue is attributed to the main speaker.

Realigning speaker labels at sentence boundaries therefore makes speaker attribution in the transcript more consistent.
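The majority-vote rule can be sketched as follows. `realign_by_majority` is a hypothetical, simplified illustration rather than the notebook's `get_realigned_ws_mapping_with_punctuation`, which may differ in details. It assumes each entry of `wsm` has `"word"` and `"speaker"` keys and treats any word ending in `.`, `?` or `!` as the end of a sentence.

```python
from collections import Counter
from typing import Dict, List

ENDING_PUNCTS = ".?!"


def realign_by_majority(wsm: List[Dict]) -> List[Dict]:
    """Hypothetical sketch: give every word in a sentence the sentence's majority speaker.

    Returns new dictionaries; the input list is not modified.
    """
    realigned: List[Dict] = []
    sentence: List[Dict] = []
    for i, word in enumerate(wsm):
        sentence.append(word)
        is_last = i == len(wsm) - 1
        # Close the sentence at ending punctuation, or at the end of the transcript
        if word["word"].endswith(tuple(ENDING_PUNCTS)) or is_last:
            majority = Counter(w["speaker"] for w in sentence).most_common(1)[0][0]
            realigned.extend({**w, "speaker": majority} for w in sentence)
            sentence = []
    return realigned


# The split sentence from the example above: "the lows are" (speaker 0)
# and "going to suck?" (speaker 1) end up with one label.
words = ["It's", "got", "to", "come.", "Yeah,", "the", "lows", "are",
         "going", "to", "suck?", "So", "it", "hits."]
speakers = [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1]
wsm_example = [{"word": w, "speaker": s} for w, s in zip(words, speakers)]
print([w["speaker"] for w in realign_by_majority(wsm_example)])
```

On a tie, `Counter.most_common` returns the label encountered first in the sentence, so the earlier speaker wins.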

In [12]:
if language in punct_model_langs:
    # Restore punctuation in the transcript so that sentence boundaries
    # can be used to realign speaker labels
    punct_model = PunctuationModel(model="kredor/punctuate-all")

    words_list = [word_dict["word"] for word_dict in wsm]

    # Each prediction pairs a word with its predicted punctuation label (index 1)
    labeled_words = punct_model.predict(words_list)

    ending_puncts = ".?!"
    model_puncts = ".,;:!?"

    # Acronyms such as "U.S.A." already end in a period and need special handling
    def is_acronym(word):
        return re.fullmatch(r"\b(?:[a-zA-Z]\.){2,}", word)

    for word_dict, labeled_tuple in zip(wsm, labeled_words):
        word = word_dict["word"]
        # Append predicted sentence-ending punctuation, unless the word already
        # ends with punctuation (acronyms are the exception)
        if (
            word
            and labeled_tuple[1] in ending_puncts
            and (word[-1] not in model_puncts or is_acronym(word))
        ):
            word += labeled_tuple[1]
            # Strip the trailing periods left when an acronym receives another "."
            if word.endswith(".."):
                word = word.rstrip(".")
            word_dict["word"] = word

else:
    logging.warning(
        f"Punctuation restoration is not available for {language} language. Using the original punctuation."
    )

# Realign speaker labels at sentence level, then group words into sentences
wsm = get_realigned_ws_mapping_with_punctuation(wsm)
ssm = get_sentences_speaker_mapping(wsm, speaker_ts)


# 8. Finalizing the Diarization Process

This final section exports the results, cleans up temporary files, and replaces speaker IDs with the speakers' names. The steps are:

1. **Saving the Speaker-Aware Transcript**: The `get_speaker_aware_transcript` function writes a transcript in which each passage of text is labeled with its speaker. The file is saved next to the input audio file, with the same base name and a `.txt` extension.

2. **Exporting Diarization Results in SRT Format**: The `write_srt` function exports the sentence-level segments in SubRip Text (SRT) format, the common subtitle format. Each subtitle entry contains the speaker label, the text, and its start and end timestamps. The file is saved with the same base name as the input audio file and a `.srt` extension.

3. **Cleaning Up Temporary Files**: The `cleanup` function deletes the temporary files and directories created during diarization (the `temp_path` directory), freeing disk space.

4. **Mapping Speaker IDs to Names**: The code reads the saved speaker-aware transcript, replaces the generic speaker IDs (e.g., "Speaker 0", "Speaker 1") with the speakers' actual names, and writes the result to a new file.

After these steps, the outputs are the speaker-aware transcript, the SRT subtitle file, and the transcript with speaker names, ready for content analysis, subtitling, or further processing in other tools.
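A minimal sketch of the SRT export, assuming each sentence segment is a dictionary with `"speaker"`, `"text"`, and `"start_time"`/`"end_time"` in milliseconds (an assumed shape for `ssm`). `write_srt_sketch` and `ms_to_srt_time` are hypothetical helpers; the notebook's `write_srt` may format entries differently. Each SRT entry is a sequence number, a `HH:MM:SS,mmm --> HH:MM:SS,mmm` time range, the text, and a blank separator line.

```python
import io
from typing import Dict, List, TextIO


def ms_to_srt_time(ms: int) -> str:
    """Format milliseconds as an SRT timestamp, HH:MM:SS,mmm."""
    hours, ms = divmod(ms, 3_600_000)
    minutes, ms = divmod(ms, 60_000)
    seconds, ms = divmod(ms, 1_000)
    return f"{hours:02d}:{minutes:02d}:{seconds:02d},{ms:03d}"


def write_srt_sketch(segments: List[Dict], out: TextIO) -> None:
    """Hypothetical sketch: write speaker-labeled segments as numbered SRT entries."""
    for index, seg in enumerate(segments, start=1):
        start = ms_to_srt_time(seg["start_time"])
        end = ms_to_srt_time(seg["end_time"])
        text = seg["text"].strip()
        out.write(f"{index}\n{start} --> {end}\n{seg['speaker']}: {text}\n\n")


buffer = io.StringIO()
write_srt_sketch(
    [{"speaker": "Speaker 0", "start_time": 0, "end_time": 1500, "text": "Hello there. "}],
    buffer,
)
print(buffer.getvalue())
```

Prefixing the text with the speaker label is a convention: the SRT format itself has no dedicated speaker field.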


In [13]:
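base_path = os.path.splitext(audio_path)[0]

# Save the speaker-aware transcript next to the input audio file
with open(f"{base_path}.txt", "w", encoding="utf-8-sig") as f:
    get_speaker_aware_transcript(ssm, f)

# Export the sentence-level segments as SRT subtitles
with open(f"{base_path}.srt", "w", encoding="utf-8-sig") as srt:
    write_srt(ssm, srt)

# Remove the temporary files created during diarization
cleanup(temp_path)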

In [14]:
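import os
import re

base_path = os.path.splitext(audio_path)[0]

# Read the speaker-aware transcript; "utf-8-sig" strips the BOM written above
with open(f"{base_path}.txt", "r", encoding="utf-8-sig") as f:
    text = f.read()

# Replace the speaker IDs with names. Matching whole IDs with \b keeps
# "Speaker 1" from also rewriting the prefix of "Speaker 10".
speaker_names = {"Speaker 0": "Lex Fridman", "Speaker 1": "Andrew Ng"}
text = re.sub(
    r"\bSpeaker \d+\b",
    lambda m: speaker_names.get(m.group(0), m.group(0)),
    text,
)

# Write the renamed transcript to disk
with open(f"{base_path}-with-speakers-names.txt", "w", encoding="utf-8") as f:
    f.write(text)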