In [None]:
import os
import sys
import logging
import traceback 
import hashlib
import pandas as pd
from dotenv import load_dotenv
from pprint import pprint

import rx
import rx.operators as ops
import diart.operators as dops
from diart import OnlineSpeakerDiarization, PipelineConfig 
from diart.sources import WebSocketAudioSource

from source import clair
from source import speech
from source.utils import Logger

In [None]:
# Load environment variables from .env file
load_dotenv()


def _required_env(name, cast):
    """Read environment variable `name` and convert it with `cast`.

    Args:
        name: Name of the environment variable to read.
        cast: Callable (e.g. `int` or `float`) applied to the raw string.

    Returns:
        The converted value.

    Raises:
        RuntimeError: If the variable is unset or `cast` rejects its value.
            The message names the offending variable, which the bare
            `int(None)` / `float("abc")` errors do not.
    """
    raw = os.environ.get(name)
    if raw is None:
        raise RuntimeError(
            f"Missing required environment variable {name!r}; define it in .env"
        )
    try:
        return cast(raw)
    except ValueError as exc:
        raise RuntimeError(
            f"Environment variable {name!r}={raw!r} is not a valid {cast.__name__}"
        ) from exc


# String settings are read as-is and are None when unset.
HOST = os.environ.get("HOST")
WHISPER_SIZE = os.environ.get("WHISPER_SIZE")

# Numeric settings are mandatory: fail early, naming the variable at fault.
PORT = _required_env("PORT", int)
WHISPER_COMPRESS_RATIO_THRESHOLD = _required_env("WHISPER_COMPRESS_RATIO_THRESHOLD", float)
WHISPER_NO_SPEECH_THRESHOLD = _required_env("WHISPER_NO_SPEECH_THRESHOLD", float)
PIPELINE_MAX_SPEAKERS = _required_env("PIPELINE_MAX_SPEAKERS", int)
PIPELINE_DURATION = _required_env("PIPELINE_DURATION", float)
PIPELINE_STEP = _required_env("PIPELINE_STEP", float)
PIPELINE_SAMPLE_RATE = _required_env("PIPELINE_SAMPLE_RATE", int)
PIPELINE_TAU = _required_env("PIPELINE_TAU", float)
PIPELINE_RHO = _required_env("PIPELINE_RHO", float)
PIPELINE_DELTA = _required_env("PIPELINE_DELTA", float)
PIPELINE_CHUNK_DURATION = _required_env("PIPELINE_CHUNK_DURATION", float)

In [None]:
# Base URL of the CLAIR service; None when CLAIR_URL is not set in the environment.
CLAIR_URL = os.environ.get("CLAIR_URL")

# Activate the CLAIR configuration used for this session.
# The keyword list previously contained the misspelling 'kinectic', which can
# never match the intended term; it is corrected to 'kinetic'.
req = clair.activate_configuration(
    mode='ssrl',
    language='EN',
    keywords=['force', 'energy conservation', 'kinetic', 'potential'],
    host=CLAIR_URL
)
# Display the outcome of the activation call as the cell output.
req.status_code, req.reason, req.text

In [None]:
# Diarization pipeline parameters, taken from the environment-backed constants.
# The device is pinned to "cuda" here; change it to run on a machine without a GPU.
speech_config = PipelineConfig(
    duration=PIPELINE_DURATION,
    step=PIPELINE_STEP, # When lower is more accurate but slower
    latency="min",  # When higher is more accurate but slower
    tau_active=PIPELINE_TAU, # suggested by diart paper 
    rho_update=PIPELINE_RHO, # suggested by diart paper
    delta_new=PIPELINE_DELTA,  # suggested by diart paper
    device="cuda",
    max_speakers=PIPELINE_MAX_SPEAKERS,
)
pprint(speech_config.__dict__, indent=2)

# Split the stream into chunks of seconds for transcription
transcription_duration = 10 # seconds
# Apply models in batches for better efficiency.
# Float floor-division under-counts when the step is not exactly representable
# (10 // 0.1 == 99.0), so divide first and floor with a tiny tolerance.
# Clamp to at least one chunk so a step longer than the window cannot
# produce a zero-sized buffer.
batch_size = max(1, int(transcription_duration / speech_config.step + 1e-9))

# Suppress whisper-timestamped warnings for a clean output
logging.getLogger("whisper_timestamped").setLevel(logging.ERROR)
# Whisper model size comes from WHISPER_SIZE; the transcriber is placed on the GPU.
asr = speech.WhisperTranscriber(model=WHISPER_SIZE, device="cuda")
dia = OnlineSpeakerDiarization(speech_config)

# Instantiate a new dialogue
dialogue = []
# Short session identifier: first 6 hex chars of the MD5 of the current
# timestamp (second resolution). Used only as a label, not for security.
group_id = hashlib.md5(pd.Timestamp.now().strftime('%Y%m%d%H%M%S').encode()).hexdigest()[:6]
# Mutable holder passed to clair.buffering_turn in the streaming cell.
last_processed_turn = {'last_turn': None}

In [None]:
# Set up the audio source: audio is received over a WebSocket at HOST:PORT.
# (For local testing a microphone source could be used instead, e.g.
# MicrophoneAudioSource(speech_config.sample_rate) -- not imported in this notebook.)
source = WebSocketAudioSource(speech_config.sample_rate, HOST, PORT)


def _forward_output(output):
    """Print a non-empty API response and send it back through the audio source."""
    if output:
        print(output)
        source.send(output)


def _report_stream_error(error):
    """Print the traceback of an error delivered by the stream.

    Uses the error object itself rather than `traceback.print_exc()`, which
    only works while an exception is being handled at call time.
    """
    traceback.print_exception(type(error), error, error.__traceback__)


# Chain of operations applied to the incoming audio stream
source.stream.pipe(
    # Format audio stream into sliding windows of `duration` seconds advanced by `step` seconds
    dops.rearrange_audio_stream(
        speech_config.duration, speech_config.step, speech_config.sample_rate
    ),
    # Wait until a batch is full. The output is a list of audio chunks
    ops.buffer_with_count(count=batch_size),
    # Obtain diarization prediction. The output is a list of pairs `(diarization, audio chunk)`
    ops.map(dia),
    # Concatenate the per-step predictions/chunks of a batch into a single chunk
    ops.map(speech.concat),
    # Skip ASR when the diarization timeline is empty (no speech detected)
    ops.map(lambda ann_wav: ('', '') if ann_wav[0].get_timeline().duration() == 0 else asr(*ann_wav)),
    # Turn the ASR result into a message; empty inputs become an empty string
    ops.map(lambda speaker_caption: '' if speaker_caption == ('', '') else speech.message_transcription(speaker_caption)),
    # Buffering transcriptions until there is a new turn, adjusted to handle multiple turns
    ops.map(lambda text: clair.buffering_turn(text, dialogue, group_id, 
                                              turn_threshold=5,
                                              silence_threshold=3, 
                                              last_processed_turn=last_processed_turn,
                                              verbose=True)),
    # Use flat_map to handle each turn in the list individually
    ops.flat_map(lambda turns: rx.from_iterable(turns)),
    # Filter out empty turns based on the 'text' content
    ops.filter(lambda turn: 'text' not in turn or turn['text'].strip() != ''),
    # Send to API and get response
    ops.map(lambda turn: clair.send_to_api_and_get_response(**turn, dialogue=dialogue, host=CLAIR_URL, verbose=True))
).subscribe(
    on_next=_forward_output,
    on_error=_report_stream_error,  # print stacktrace if error
)

# Both the session log and the CSV export are written under logs/
os.makedirs('logs', exist_ok=True)

# Save the original stdout
original_stdout = sys.stdout
# Create the logger *before* entering the try block: if construction fails,
# stdout has not been replaced and there is nothing to clean up.
stdout_logger = Logger(f'logs/{group_id}.txt', original_stdout)
sys.stdout = stdout_logger
try:
    print(f"Listening... {group_id}")
    # Stream of data; blocks until the source stops or the cell is interrupted
    source.read()
finally:
    print("Stopped listening")
    # Restore stdout first so it is never left pointing at a closed log
    sys.stdout = original_stdout
    # Ensure the log file is properly closed
    stdout_logger.log.close()

    print(dialogue)
    # Export transcribed dialogue to a csv file
    if dialogue:
        for msg in dialogue:
            msg['group'] = group_id
            # Wrap the text in double quotes unless it already contains one.
            # NOTE(review): to_csv applies its own quoting on top of this, so
            # the text may end up double-escaped in the file -- confirm this
            # is the format downstream readers expect.
            msg['text'] = f'"{msg["text"]}"' if '"' not in msg['text'] else f"{msg['text']}"
        pd.DataFrame(dialogue)[['group', 'username', 'timestamp', 'text']]\
            .to_csv(f'logs/{group_id}.csv', index=False, sep="|")
    else:
        # An empty DataFrame has no columns, so selecting them would raise KeyError
        print("No dialogue captured; skipping CSV export")