# Analytics-decoder

Note: This notebook is non-functional boilerplate for developing other methods of decoding desirable analytics from speech interfaces, using the code implemented in the `source/` folder. It does not yet work as a mocked use-case example, but it is a starting point for developing a more functional version.

In [None]:
import os
import logging
import traceback 
from pprint import pprint
from dotenv import load_dotenv
from diart import OnlineSpeakerDiarization, PipelineConfig 
from diart.sources import WebSocketAudioSource
import diart.operators as dops
import rx.operators as ops

from source import analytics
from source import speech

In [None]:
# Read the .env file so its entries become available as environment variables.
# NOTE: os.getenv returns None for an unset variable, so the int()/float()
# conversions below raise TypeError when a numeric setting is missing.
load_dotenv()

# Host and port settings
HOST = os.getenv("HOST")
PORT = int(os.getenv("PORT"))

# Whisper settings
WHISPER_SIZE = os.getenv("WHISPER_SIZE")
WHISPER_COMPRESS_RATIO_THRESHOLD = float(os.getenv("WHISPER_COMPRESS_RATIO_THRESHOLD"))
WHISPER_NO_SPEECH_THRESHOLD = float(os.getenv("WHISPER_NO_SPEECH_THRESHOLD"))

# Pipeline settings
PIPELINE_MAX_SPEAKERS = int(os.getenv("PIPELINE_MAX_SPEAKERS"))
PIPELINE_DURATION = float(os.getenv("PIPELINE_DURATION"))
PIPELINE_STEP = float(os.getenv("PIPELINE_STEP"))
PIPELINE_SAMPLE_RATE = int(os.getenv("PIPELINE_SAMPLE_RATE"))
PIPELINE_TAU = float(os.getenv("PIPELINE_TAU"))
PIPELINE_RHO = float(os.getenv("PIPELINE_RHO"))
PIPELINE_DELTA = float(os.getenv("PIPELINE_DELTA"))
PIPELINE_CHUNK_DURATION = float(os.getenv("PIPELINE_CHUNK_DURATION"))

In [None]:
# Diarization pipeline parameters, taken from the PIPELINE_* settings loaded from .env
# instead of being hard-coded. The previously hard-coded values were duration=5, step=0.5,
# max_speakers=2 and, as suggested by the diart paper, tau_active=0.555, rho_update=0.422,
# delta_new=1.517. The device is fixed to "cuda"; change `device` if no GPU is available.
speech_config = PipelineConfig(
    duration=PIPELINE_DURATION,
    step=PIPELINE_STEP,  # a lower step is more accurate but slower
    latency="min",  # a higher latency is more accurate but slower
    tau_active=PIPELINE_TAU,
    rho_update=PIPELINE_RHO,
    delta_new=PIPELINE_DELTA,
    device="cuda",
    max_speakers=PIPELINE_MAX_SPEAKERS,
)
pprint(speech_config.__dict__, indent=2)

# Length, in seconds, of the audio batch handed to the transcriber.
# NOTE(review): PIPELINE_CHUNK_DURATION is loaded from .env but unused; it presumably
# belongs here -- confirm before wiring it in.
transcription_duration = 10
# Number of pipeline steps buffered per batch (models are applied in batches for efficiency).
# True division + round() avoids float floor-division artifacts (e.g. 10 // 0.1 == 99.0),
# and max() guarantees the buffer never has a count of zero.
batch_size = max(1, round(transcription_duration / speech_config.step))

# Suppress whisper-timestamped warnings for a clean output
logging.getLogger("whisper_timestamped").setLevel(logging.ERROR)
# Speech recogniser (model size taken from WHISPER_SIZE) and online diarization pipeline
asr = speech.WhisperTranscriber(model=WHISPER_SIZE, device="cuda")
dia = OnlineSpeakerDiarization(speech_config)

# Audio source: a WebSocket source on the HOST/PORT configured in .env.
# A local microphone could be used instead:
# source = MicrophoneAudioSource(speech_config.sample_rate)
source = WebSocketAudioSource(speech_config.sample_rate, HOST, PORT)

# Instantiate a new dialogue
dialogue_state = analytics.DialogueState()


def _print_stream_error(err):
    """Print the full traceback of an exception raised inside the stream.

    The traceback is taken from the exception object itself, so it is printed
    correctly even when the callback runs outside an active ``except`` block
    (where ``traceback.print_exc()`` would only print ``NoneType: None``).
    """
    traceback.print_exception(type(err), err, err.__traceback__)


# Chain of operations applied to the incoming audio stream
source.stream.pipe(
    # Format the audio stream into sliding windows of `duration` seconds advancing by `step` seconds
    dops.rearrange_audio_stream(
        speech_config.duration, speech_config.step, speech_config.sample_rate
    ),
    # Wait until a batch is full. The output is a list of audio chunks
    ops.buffer_with_count(count=batch_size),
    # Obtain diarization prediction. The output is a list of pairs `(diarization, audio chunk)`
    ops.map(dia),
    # Concatenate the per-step predictions/chunks into a single chunk spanning the whole batch
    ops.map(speech.concat),
    # Ignore this chunk if it does not contain speech
    ops.filter(lambda ann_wav: ann_wav[0].get_timeline().duration() > 0),
    # Obtain speaker-aware transcriptions. The output is a list of pairs `(speaker: int, caption: str)`
    ops.starmap(asr),
    # Transcriptions
    ops.map(speech.message_transcription),
    # Buffering transcriptions until there is a new turn
    ops.map(lambda text: analytics.buffering_turn(text, dialogue_state.turns_list, verbose=True)),
    # Filter out None inputs
    ops.filter(lambda turn: turn is not None),
    # Send to API and get response
    ops.map(lambda turn: analytics.compute_turn(**turn, dialogue_state=dialogue_state, verbose=True))
).subscribe(
    on_next=print,
    on_error=_print_stream_error,  # print stacktrace if error
)

In [None]:
# Start reading audio from the source so the pipeline subscribed above receives data.
# After running this cell, start the client_mic.py script to stream audio to this notebook.
# NOTE(review): source.read() presumably blocks until the stream ends -- confirm.
print("Listening...")
source.read()