# Online Rolling-Buffer/Chunk-Based Diarization & STT Pipeline
Adapted from Juanma Coria's <a href='https://betterprogramming.pub/color-your-captions-streamlining-live-transcriptions-with-diart-and-openais-whisper-6203350234ef'>Blog Post</a> "Color Your Captions: Streamlining Live Transcriptions with "Diart" and OpenAI's Whisper". 


The pipeline incorporates both Diart: an "<a href='github.com/juanmc2005/diart/blob/main/paper.pdf'>Overlap-Aware Low-Latency Online Speaker Diarization Based on End-to-End Local Segmentation</a>", and Whisper from OpenAI. 


In [1]:
# Imports:

# general
import os
import sys
import logging
import traceback
import numpy as np

# diarization module
from diart import SpeakerDiarizationConfig, SpeakerDiarization
from diart.sources import AppleDeviceAudioSource, FileAudioSource, MicrophoneAudioSource

# asr module
import whisper_timestamped as whisper
from pyannote.core import Segment
from contextlib import contextmanager

# chain operators
import diart.operators as dops
import rich
import rx.operators as ops

Importing the dtw module. When using in academic works please cite:
  T. Giorgino. Computing and Visualizing Dynamic Time Warping Alignments in R: The dtw Package.
  J. Stat. Soft., doi:10.18637/jss.v031.i07.



## Diarization Module

- Streaming using buffered chunks presents a unique challenge, but also brings with it unique benefits over offline diarization/transcription pipelines. Offline models can suffer from long processing times, particularly on longer conversations. Diart can handle continuous audio streaming with low and constant memory costs, and also improves performance as the stream goes on (due to continual learning through improved speaker centroids). However, DER can be quite high at the onset of streaming!!! We have to keep this in mind as we develop our downstream pipeline

<strong>Diart Methods:</strong>
* Segmentation
    - end-to-end speaker segmentation model used to produce local speaker activity probabilites for each frame. 
    - Tau_active threshold controls min prob to be tagged in chunk

* Incremental Clustering
    - Segmentation-Driven Speaker Embedding: using modified x-vector based TDNN-based architecture with a statistical pooling layer that weighs frames based on speaker activity probs.
    - Constrained Incremental Clustering: ensures no two local speakers are assigned the same global speaker & handles overlapping speech.
    - Detection of New Speakers and Centroid Updates: Based on Delta_new threshold for new speakers, and centroids are only updated if the active duratiuon exceeds Rho_update.

* Latency Adjustment: initial buffer must be 5s, later buffers can be set to a min value of 500 ms for heightened responsiveness (λ is Latency!).
    - When longer latency (λ) is permitted, several positions of the rolling buffer can be combined in an ensemble-like manner (ideally improving accuracy!!!)

In [2]:
# Suppressing whisper-timestamped warnings for clean output
logging.getLogger("whisper_timestamped").setLevel(logging.ERROR)

In [3]:
 # set configurations for the speakerdiarization module
config = SpeakerDiarizationConfig( #Parameters from Juanma - FINE-TUNE THEM!
    duration=5, # Chunk duration in s - default: 5
    step=0.5, # Sliding window step in s - default -0.5
    latency="min", # System latency in s 
    tau_active=0.5, 
    rho_update=0.1,
    delta_new=0.57
)

# construct diarizer pipeline & set source
diarizer = SpeakerDiarization(config)
source = FileAudioSource("/Users/gael/Desktop/WorkFiles/ToyProjects/diarized_stt/data/3.wav", sample_rate=config.sample_rate) #ADJUST! - Does not work.
# For active listening demo set the source to: 
# source = MicrophoneAudioSource(sample_rate=config.sample_rate)

Lightning automatically upgraded your loaded checkpoint from v1.5.4 to v2.2.5. To apply the upgrade to your files permanently, run `python -m pytorch_lightning.utilities.upgrade_checkpoint ../../../../.cache/torch/pyannote/models--pyannote--segmentation/snapshots/660b9e20307a2b0cdb400d0f80aadc04a701fc54/pytorch_model.bin`


Model was trained with pyannote.audio 0.0.1, yours is 3.2.0. Bad things might happen unless you revert pyannote.audio to 0.x.
Model was trained with torch 1.10.0+cu102, yours is 2.3.1. Bad things might happen unless you revert torch to 1.x.


Lightning automatically upgraded your loaded checkpoint from v1.2.7 to v2.2.5. To apply the upgrade to your files permanently, run `python -m pytorch_lightning.utilities.upgrade_checkpoint ../../../../.cache/torch/pyannote/models--pyannote--embedding/snapshots/4db4899737a38b2d618bbd74350915aa10293cb2/pytorch_model.bin`
Lightning automatically upgraded your loaded checkpoint from v1.2.7 to v2.2.5. To apply the upgrade to your files permanently, run `python -m pytorch_lightning.utilities.upgrade_checkpoint ../../../../.cache/torch/pyannote/models--pyannote--embedding/snapshots/4db4899737a38b2d618bbd74350915aa10293cb2/pytorch_model.bin`


Model was trained with pyannote.audio 0.0.1, yours is 3.2.0. Bad things might happen unless you revert pyannote.audio to 0.x.
Model was trained with torch 1.8.1+cu102, yours is 2.3.1. Bad things might happen unless you revert torch to 1.x.
Model was trained with pyannote.audio 0.0.1, yours is 3.2.0. Bad things might happen unless you revert pyannote.audio to 0.x.
Model was trained with torch 1.8.1+cu102, yours is 2.3.1. Bad things might happen unless you revert torch to 1.x.


### Note: Tuning Parameters - From Coria Et. Al.'s 2021  <a href='github.com/juanmc2005/diart/blob/main/paper.pdf'>Paper</a>

-  Tau_active: tunable speaker activity probability threshold (0-1). (Speakers whose activity exceeds Tau_active during a chunk constitute the set of local speakers, their activity probabilities are then passed downstream to the incremental clustering step--thus handling the overlapping-speech problem fromt the start as opposed to in post-processing).

- Rho_update: tunable parameter controlling the rate at which speaker embeddings are updated (seconds). (Similar to/represents the learning rate)

- Delta_new: tunable parameter controlling the threshold for considering new speakers in a chunk (0-2). Lower values will make the system more sensitive to different voices (Mathematically represents the minimum distance required between embeddings for the clusters to be considered seperable)

<a href='github.com/juanmc2005/diart/blob/main/src/diart/console/tune.py'>diart.tune</a> - can be used to automatically tune these parameters.

## Whisper/ASR Module

In [4]:
# Class definition
@contextmanager
def suppress_stdout():
    # Aux function to suppress (extremely long) whisper logs (https://thesmithfam.org/blog/2012/10/25/temporarily-suppress-console-output-in-python/)
    with open(os.devnull, "w") as devnull: # open the devnull file (discards all data written to it!)
        old_stdout = sys.stdout # save current std output (to restore later)
        sys.stdout = devnull # redirect std output to devnull (discarding any logs)
        try: 
            yield # temporarily exit context manager, allowing wrapped code to execute with the suppressed std output
        finally: # ensure original stdout is restored even if an exception occures (don't wanna play with that!!)
            sys.stdout = old_stdout 

class WhisperTranscriber:
    def __init__(self, model="small", device=None):
        self.model = whisper.load_model(model, device=device)
        self._buffer = ""

    def transcribe(self, waveform):
        """Transcribe audio using Whisper"""
        # Pad/trim audio to fit 30 seconds as required by Whisper
        audio = waveform.data.astype("float32").reshape(-1)
        audio = whisper.pad_or_trim(audio)

        # Transcribe the given audio while suppressing logs
        with suppress_stdout():
            transcription = whisper.transcribe(
                self.model,
                audio,
                # We use past transcriptions to condition the model
                initial_prompt=self._buffer,
                verbose=True  # to avoid progress bar
            )

        return transcription

    def identify_speakers(self, transcription, diarization, time_shift):
        """Iterate over transcription segments to assign speakers"""
        speaker_captions = []
        for segment in transcription["segments"]:

            # Crop diarization to the segment timestamps
            start = time_shift + segment["words"][0]["start"]
            end = time_shift + segment["words"][-1]["end"]
            dia = diarization.crop(Segment(start, end))

            # Assign a speaker to the segment based on diarization
            speakers = dia.labels()
            num_speakers = len(speakers)
            if num_speakers == 0:
                # No speakers were detected
                caption = (-1, segment["text"])
            elif num_speakers == 1:
                # Only one speaker is active in this segment
                spk_id = int(speakers[0].split("speaker")[1])
                caption = (spk_id, segment["text"])
            else:
                # Multiple speakers, select the one that speaks the most
                max_speaker = int(np.argmax([
                    dia.label_duration(spk) for spk in speakers
                ]))
                caption = (max_speaker, segment["text"])
            speaker_captions.append(caption)

        return speaker_captions

    def __call__(self, diarization, waveform):
        # Step 1: Transcribe
        transcription = self.transcribe(waveform)
        # Update transcription buffer
        self._buffer += transcription["text"]
        # The audio may not be the beginning of the conversation
        time_shift = waveform.sliding_window.start
        # Step 2: Assign speakers
        speaker_transcriptions = self.identify_speakers(transcription, diarization, time_shift)
        return speaker_transcriptions

In [5]:
asr = WhisperTranscriber

## Combining Both Modules

In [6]:
import numpy as np
from pyannote.core import Annotation, SlidingWindowFeature, SlidingWindow

def concat(chunks, collar=0.05):
    """
    Concatenate predictions and audio
    given a list of `(diarization, waveform)` pairs
    and merge contiguous single-speaker regions
    with pauses shorter than `collar` seconds.
    """
    first_annotation = chunks[0][0]
    first_waveform = chunks[0][1]
    annotation = Annotation(uri=first_annotation.uri)
    data = []
    for ann, wav in chunks:
        annotation.update(ann)
        data.append(wav.data)
    annotation = annotation.support(collar)
    window = SlidingWindow(
        first_waveform.sliding_window.duration,
        first_waveform.sliding_window.step,
        first_waveform.sliding_window.start,
    )
    data = np.concatenate(data, axis=0)
    return annotation, SlidingWindowFeature(data, window)

def colorize_transcription(transcription):
    """
    Unify a speaker-aware transcription represented as
    a list of `(speaker: int, text: str)` pairs
    into a single text colored by speakers.
    """
    colors = 2 * [
        "bright_red", "bright_blue", "bright_green", "orange3", "deep_pink1",
        "yellow2", "magenta", "cyan", "bright_magenta", "dodger_blue2"
    ]
    result = []
    for speaker, text in transcription:
        if speaker == -1:
            # No speakerfound for this text, use default terminal color
            result.append(text)
        else:
            result.append(f"[{colors[speaker]}]{text}")
    return "\n".join(result)

In [7]:
# Split the stream into 2s chunks for transcription
transcription_duration = 2
# Apply models in batches for better efficiency
batch_size = int(transcription_duration // config.step)

# Chain of operations to apply on the stream of microphone audio
source.stream.pipe(
    # Format audio stream to sliding windows of 5s with a step of 500ms
    dops.rearrange_audio_stream(
        config.duration, config.step, config.sample_rate
    ),
    # Wait until a batch is full
    # The output is a list of audio chunks
    ops.buffer_with_count(count=batch_size),
    # Obtain diarization prediction
    # The output is a list of pairs `(diarization, audio chunk)`
    ops.map(diarizer),
    # Concatenate 500ms predictions/chunks to form a single 2s chunk
    ops.map(concat),
    # Ignore this chunk if it does not contain speech
    ops.filter(lambda ann_wav: ann_wav[0].get_timeline().duration() > 0),
    # Obtain speaker-aware transcriptions
    # The output is a list of pairs `(speaker: int, caption: str)`
    ops.starmap(asr),
    # Color transcriptions according to the speaker
    # The output is plain text with color references for rich
    ops.map(colorize_transcription),
).subscribe(
    on_next=rich.print,  # print colored text
    on_error=lambda _: traceback.print_exc()  # print stacktrace if error
)

<rx.disposable.disposable.Disposable at 0x2a047dc70>

In [8]:
print('listening...')
source.read()

listening...


Traceback (most recent call last):
  File "/Users/gael/miniconda3/envs/ml_env/lib/python3.8/site-packages/rx/core/operators/map.py", line 37, in on_next
    result = _mapper(value)
  File "/Users/gael/miniconda3/envs/ml_env/lib/python3.8/site-packages/rx/operators/__init__.py", line 2662, in <lambda>
    return pipe(map(lambda values: cast(Mapper, mapper)(*values)))
  File "/var/folders/15/wd23przs6r1gtq0g6ks4kvkh0000gn/T/ipykernel_2807/3041333293.py", line 15, in __init__
    self.model = whisper.load_model(model, device=device)
  File "/Users/gael/miniconda3/envs/ml_env/lib/python3.8/site-packages/whisper_timestamped/transcribe.py", line 2441, in load_model
    extension = os.path.splitext(name)[-1] if os.path.isfile(name) else None
  File "/Users/gael/miniconda3/envs/ml_env/lib/python3.8/genericpath.py", line 30, in isfile
    st = os.stat(path)
TypeError: stat: path should be string, bytes, os.PathLike or integer, not Annotation
