In [1]:
##imports
import logging
import os
import sys
import traceback
from contextlib import contextmanager

In [2]:
##spicier imports
import diart.operators as dops
import numpy as np
import rich
import rx.operators as ops
import whisper_timestamped as whisper
from diart import OnlineSpeakerDiarization, PipelineConfig
from diart.sources import MicrophoneAudioSource
from pyannote.core import Annotation, SlidingWindowFeature, SlidingWindow, Segment

  from .autonotebook import tqdm as notebook_tqdm
The torchaudio backend is switched to 'soundfile'. Note that 'sox_io' is not supported on Windows.
The torchaudio backend is switched to 'soundfile'. Note that 'sox_io' is not supported on Windows.


Importing the dtw module. When using in academic works please cite:
  T. Giorgino. Computing and Visualizing Dynamic Time Warping Alignments in R: The dtw Package.
  J. Stat. Soft., doi:10.18637/jss.v031.i07.



In [3]:
## Helper: merge a batch of (diarization, audio) chunks into one pair
def concat(chunks, collar=0.05):
    """
    Merge a batch of `(diarization, waveform)` pairs into a single pair.

    All annotations are unioned into one `Annotation` carrying the first
    chunk's `uri`. Contiguous single-speaker regions separated by pauses
    shorter than `collar` seconds are then fused via `Annotation.support`.
    The waveform samples are stacked along axis 0 and re-wrapped with a
    `SlidingWindow` copied from the first chunk's window.

    Args:
        chunks: non-empty sequence of `(Annotation, SlidingWindowFeature)` pairs.
        collar: largest gap (seconds) bridged when merging speaker regions.

    Returns:
        `(Annotation, SlidingWindowFeature)` covering the whole batch.
    """
    head_annotation, head_waveform = chunks[0]

    merged = Annotation(uri=head_annotation.uri)
    for chunk_annotation, _ in chunks:
        merged.update(chunk_annotation)
    merged = merged.support(collar)

    samples = np.concatenate([chunk_waveform.data for _, chunk_waveform in chunks], axis=0)
    head_window = head_waveform.sliding_window
    window = SlidingWindow(head_window.duration, head_window.step, head_window.start)
    return merged, SlidingWindowFeature(samples, window)



In [4]:
## Helper: render speaker-tagged captions as chat-style lines
def message_transcription(transcription):
    """Format `(speaker, text)` captions as one message per line.

    Args:
        transcription: iterable of `(speaker, text)` pairs. `speaker` is an
            integer id, or -1 when no speaker was assigned to the text.

    Returns:
        The lines joined with newlines. Captions with a speaker are prefixed
        with "Speaker<id>: "; captions with speaker -1 are emitted unchanged.
    """
    def _format(speaker, text):
        # Unattributed text is shown without any prefix
        if speaker == -1:
            return text
        return "Speaker" + str(speaker) + ": " + text

    return "\n".join(_format(speaker, text) for speaker, text in transcription)


In [6]:
@contextmanager
def suppress_stdout():
    """Temporarily redirect `sys.stdout` to the null device.

    Used to silence Whisper's console output. The previous `sys.stdout` is
    restored on exit, even if the body raises.
    """
    saved_stream = sys.stdout
    with open(os.devnull, "w") as sink:
        sys.stdout = sink
        try:
            yield
        finally:
            sys.stdout = saved_stream


class WhisperTranscriber:
    """Speaker-aware streaming transcriber built on whisper_timestamped.

    Each call transcribes one audio chunk with Whisper, conditioning on all
    text transcribed so far. It then assigns every resulting segment to the
    diarization speaker active during it.
    """

    def __init__(self, model="small", device=None):
        """Load the Whisper model.

        Args:
            model: model name passed to `whisper.load_model`.
            device: optional device passed to `whisper.load_model`.
        """
        self.model = whisper.load_model(model, device=device)
        # Running text of everything transcribed so far; it is fed back as the
        # prompt so each chunk is conditioned on the conversation history.
        self._buffer = ""

    def transcribe(self, waveform):
        """Transcribe one audio chunk with Whisper.

        Args:
            waveform: object exposing a `.data` sample array (a
                `SlidingWindowFeature` in this notebook's pipeline).

        Returns:
            The dict returned by `whisper.transcribe`. This class reads its
            `"text"` and `"segments"` entries.
        """
        # Flatten to 1-D float32, then pad/trim to the 30 s window Whisper requires
        audio = waveform.data.astype("float32").reshape(-1)
        audio = whisper.pad_or_trim(audio)

        # Transcribe the given audio while suppressing Whisper's stdout logs
        with suppress_stdout():
            transcription = whisper.transcribe(
                self.model,
                audio,
                # Condition the model on past transcriptions
                initial_prompt=self._buffer,
                verbose=True,  # to avoid progress bar (the text it prints is suppressed above)
            )

        return transcription

    @staticmethod
    def _speaker_index(label):
        """Parse the integer id from a diarization label such as "speaker3"."""
        return int(label.split("speaker")[1])

    def identify_speakers(self, transcription, diarization, time_shift):
        """Assign a speaker to each transcribed segment.

        Args:
            transcription: Whisper output dict with a `"segments"` list. Each
                segment has `"text"`, `"start"`/`"end"` and (from
                whisper_timestamped) a `"words"` list with per-word times.
            diarization: pyannote `Annotation` covering the same audio.
            time_shift: absolute start time (seconds) of the transcribed chunk.
                It is added to Whisper's chunk-relative timestamps so they line
                up with the diarization timeline.

        Returns:
            List of `(speaker, text)` pairs, one per segment. `speaker` is -1
            when no speaker is active during the segment.
        """
        speaker_captions = []
        for segment in transcription["segments"]:
            # Prefer word-level bounds; fall back to the segment bounds when
            # no word timestamps were produced (avoids IndexError on []).
            words = segment.get("words") or []
            if words:
                seg_start, seg_end = words[0]["start"], words[-1]["end"]
            else:
                seg_start, seg_end = segment["start"], segment["end"]

            # Crop diarization to the segment's absolute timestamps
            dia = diarization.crop(Segment(time_shift + seg_start, time_shift + seg_end))

            speakers = dia.labels()
            if not speakers:
                # No speakers were detected
                speaker = -1
            elif len(speakers) == 1:
                # Only one speaker is active in this segment
                speaker = self._speaker_index(speakers[0])
            else:
                # Multiple speakers: choose the label with the most speech time.
                # np.argmax yields a position in `speakers`, so map it back to
                # the label and parse the real speaker id from it.
                durations = [dia.label_duration(spk) for spk in speakers]
                speaker = self._speaker_index(speakers[int(np.argmax(durations))])
            speaker_captions.append((speaker, segment["text"]))

        return speaker_captions

    def __call__(self, diarization, waveform):
        """Transcribe `waveform` and attribute each segment to a speaker.

        Args:
            diarization: pyannote `Annotation` covering `waveform`.
            waveform: `SlidingWindowFeature` whose `sliding_window.start` gives
                the chunk's absolute start time.

        Returns:
            List of `(speaker, text)` pairs (see `identify_speakers`).
        """
        # Step 1: Transcribe
        transcription = self.transcribe(waveform)
        # Extend the prompt history used to condition later chunks
        self._buffer += transcription["text"]
        # The audio may not be the beginning of the conversation
        time_shift = waveform.sliding_window.start
        # Step 2: Assign speakers
        return self.identify_speakers(transcription, diarization, time_shift)


In [7]:
# Raise the whisper_timestamped logger threshold to ERROR so its warnings don't clutter the caption output
logging.getLogger("whisper_timestamped").setLevel(level=logging.ERROR)


In [8]:
# Diarization pipeline parameters (largely untuned).
# Per the original note, a GPU can also be selected with device=torch.device("cuda").
config = PipelineConfig(
    # Sliding-window geometry: 5 s windows advanced every 0.5 s
    duration=5,
    step=0.5,
    latency="min",
    # Speaker-tracking thresholds; see diart's PipelineConfig documentation
    tau_active=0.5,
    rho_update=0.1,
    delta_new=0.57,
)

In [22]:
## Set up both ends of the stream: the online diarization pipeline built from `config`,
## and a microphone source sampling at the pipeline's configured sample rate
dia = OnlineSpeakerDiarization(config)
source = MicrophoneAudioSource(config.sample_rate)

In [23]:
# Whisper model size for transcription; pass device="cuda" instead of None to run on a GPU
asr = WhisperTranscriber(model="base", device=None)

In [11]:
# Split the stream into 2s chunks for transcription
transcription_duration = 2
# Number of `config.step`-spaced diarization outputs batched into one
# transcription chunk (models are applied per batch for efficiency).
# Use round(a / b) rather than a // b: float floor-division truncates for
# steps that aren't exactly representable (2 // 0.1 == 19.0, not 20).
# Clamp to 1 so a step longer than the chunk still yields a valid batch.
batch_size = max(1, int(round(transcription_duration / config.step)))

In [24]:
# Chain of operations turning the microphone stream into speaker-tagged messages.
# Keep the Disposable so the pipeline can be torn down later with subscription.dispose().
subscription = source.stream.pipe(
    # Format audio stream to sliding windows of `config.duration` seconds
    # advanced every `config.step` seconds
    dops.rearrange_audio_stream(
        config.duration, config.step, config.sample_rate
    ),
    # Wait until a batch is full
    # The output is a list of audio chunks
    ops.buffer_with_count(count=batch_size),
    # Obtain diarization prediction
    # The output is a list of pairs `(diarization, audio chunk)`
    ops.map(dia),
    # Concatenate the per-step predictions/chunks into a single transcription chunk
    ops.map(concat),
    # Ignore this chunk if it does not contain speech
    ops.filter(lambda ann_wav: ann_wav[0].get_timeline().duration() > 0),
    # Obtain speaker-aware transcriptions
    # The output is a list of pairs `(speaker: int, caption: str)`
    ops.starmap(asr),
    # Render each caption as a "SpeakerN: text" line (plain text, no colors)
    ops.map(message_transcription),
).subscribe(
    # NOTE(review): rich.print parses `[...]` as markup, so bracketed transcript
    # text may be altered; consider rich.markup.escape if that matters.
    on_next=rich.print,
    # Print the stack trace of the error actually received, rather than relying
    # on an exception being "currently handled" as print_exc() does
    on_error=lambda e: traceback.print_exception(type(e), e, e.__traceback__),
)

<rx.disposable.disposable.Disposable at 0x17e0f978610>

In [25]:
## Make the magic happen: start pushing microphone audio into `source.stream`,
## which feeds the pipeline subscribed in the previous cell.
# NOTE(review): read() appears to occupy this cell while audio streams — confirm how it
# is meant to be stopped (e.g. interrupting the kernel).
print("Listening...")
source.read()

Listening...


Traceback (most recent call last):
  File "C:\Users\jedwards23\Anaconda3\envs\diart\lib\site-packages\diart\sources.py", line 173, in read
    self.stream.on_next(self._queue.get_nowait())
  File "C:\Users\jedwards23\Anaconda3\envs\diart\lib\site-packages\rx\subject\subject.py", line 55, in on_next
    super().on_next(value)
  File "C:\Users\jedwards23\Anaconda3\envs\diart\lib\site-packages\rx\core\observer\observer.py", line 26, in on_next
    self._on_next_core(value)
  File "C:\Users\jedwards23\Anaconda3\envs\diart\lib\site-packages\rx\subject\subject.py", line 62, in _on_next_core
    observer.on_next(value)
  File "C:\Users\jedwards23\Anaconda3\envs\diart\lib\site-packages\rx\core\observer\autodetachobserver.py", line 26, in on_next
    self._on_next(value)
  File "C:\Users\jedwards23\Anaconda3\envs\diart\lib\site-packages\rx\core\operators\map.py", line 41, in on_next
    obv.on_next(result)
  File "C:\Users\jedwards23\Anaconda3\envs\diart\lib\site-packages\rx\core\observer\autod