Importing Required Modules

In [1]:
import numpy as np
import pyaudio
import sys
from scipy.io import wavfile
import time
import torch

In [2]:
import os
import sys

# Absolute path of the "data" directory that sits next to the notebook's
# working directory (i.e. <cwd>/../data). It is also appended to sys.path so
# modules placed there can be imported.
_notebook_dir = os.getcwd()
data_path = os.path.abspath(os.path.join(_notebook_dir, os.pardir, "data"))
sys.path.append(data_path)

Configuration and model loading

In [3]:
# Audio capture settings: mono 16-bit PCM at 16 kHz.
RATE = 16000
CHUNK = 512  # Adjusted CHUNK size to match Silero VAD expectation for 16kHz
FORMAT = pyaudio.paInt16
CHANNELS = 1

# Segmentation settings used by the listening loop and process_speech_segment.
STOP_MS = 1000  # non-speech duration (ms) after speech that ends a segment
PRE_SPEECH_MS = 200  # audio (ms) kept before the detected speech onset
MAX_DURATION_SECONDS = 16  # longest segment (s) passed to the Wav2Vec2-BERT turn model
VAD_THRESHOLD = 0.7  # Silero VAD probability above which a chunk counts as speech

# Debug copy of the most recently processed segment. os.path.join keeps this
# portable: a hard-coded "\audio\..." suffix only forms a valid path on Windows
# (elsewhere it becomes a single file name containing backslashes).
TEMP_OUTPUT_WAV = os.path.join(data_path, "audio", "temp_output.wav")

In [4]:
try:
    print("Attempting to load Silero VAD model...")
    # Keep torch.hub downloads in a project-local cache directory (resolved
    # relative to the current working directory) instead of the user-wide default.
    torch.hub.set_dir("../model/.torch_hub")
    MODEL, _ = torch.hub.load(
        repo_or_dir="snakers4/silero-vad",
        source="github",
        model="silero_vad",  # Explicitly specify the model name
        onnx=False,
        trust_repo=True,  # Trust the repository without an interactive prompt
    )
    print("Silero VAD model loaded successfully.")
except Exception as e:
    print(f"Error loading Silero VAD model: {e}")
    print(
        "Please ensure you have an internet connection and that torch and torchaudio are installed correctly."
    )
    print("You can try installing them using: pip install torch torchaudio")
    print("If the issue persists, try clearing the torch hub cache:")
    print("import torch; print(torch.hub.get_dir())  # Find the cache directory")
    print("Then manually delete the 'snakers4_silero-vad_...' folder.")
    # Re-raise rather than sys.exit(1). Inside a Jupyter kernel sys.exit only
    # raises SystemExit, which IPython intercepts, and the original traceback is
    # lost. Re-raising keeps the real error visible and stops "Run All" here.
    raise


Attempting to load Silero VAD model...


Using cache found in ../model/.torch_hub\snakers4_silero-vad_master


Silero VAD model loaded successfully.


In [5]:
import torch
from transformers import Wav2Vec2BertForSequenceClassification, AutoFeatureExtractor

In [6]:
# Hub repo id (or local directory) handed to from_pretrained for the turn-end classifier.
MODEL_PATH: str = "pipecat-ai/smart-turn"

In [8]:
# The feature extractor and the classifier both come from the same MODEL_PATH checkpoint.
processor = AutoFeatureExtractor.from_pretrained(MODEL_PATH)
model = Wav2Vec2BertForSequenceClassification.from_pretrained(MODEL_PATH)

In [9]:
# Use the GPU when torch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [10]:
model.to(device)  # nn.Module.to() moves the parameters in place and returns the same module
model.eval()  # switch to inference behaviour; the returned module is displayed as cell output

Wav2Vec2BertForSequenceClassification(
  (wav2vec2_bert): Wav2Vec2BertModel(
    (feature_projection): Wav2Vec2BertFeatureProjection(
      (layer_norm): LayerNorm((160,), eps=1e-05, elementwise_affine=True)
      (projection): Linear(in_features=160, out_features=1024, bias=True)
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): Wav2Vec2BertEncoder(
      (dropout): Dropout(p=0.0, inplace=False)
      (layers): ModuleList(
        (0-23): 24 x Wav2Vec2BertEncoderLayer(
          (ffn1_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
          (ffn1): Wav2Vec2BertFeedForward(
            (intermediate_dropout): Dropout(p=0.0, inplace=False)
            (intermediate_dense): Linear(in_features=1024, out_features=4096, bias=True)
            (intermediate_act_fn): SiLU()
            (output_dense): Linear(in_features=4096, out_features=1024, bias=True)
            (output_dropout): Dropout(p=0.0, inplace=False)
          )
          (self_attn_layer_n

Turn-detection functions

In [11]:
def predict_endpoint(audio_array):
    """
    Classify an audio segment as a finished turn or an unfinished one.

    Uses the module-level `processor`, `model` and `device`.

    Args:
        audio_array: Numpy array of audio samples recorded at 16 kHz.

    Returns:
        dict with two keys:
        - "prediction": 1 when the turn looks complete, 0 otherwise
        - "probability": softmax probability of the "complete" class (index 1)
    """
    features = processor(
        audio_array,
        sampling_rate=16000,
        padding="max_length",
        truncation=True,
        # Length limit in feature frames, as used in training. Presumably this
        # matches MAX_DURATION_SECONDS of audio — TODO confirm.
        max_length=800,
        return_attention_mask=True,
        return_tensors="pt",
    )

    model_inputs = {}
    for name, tensor in features.items():
        model_inputs[name] = tensor.to(device)

    with torch.no_grad():
        scores = model(**model_inputs).logits
        class_probs = scores.softmax(dim=1)
        complete_prob = class_probs[0, 1].item()  # first item in the batch, class 1 = Complete

    is_complete = complete_prob > 0.5
    return {
        "prediction": 1 if is_complete else 0,
        "probability": complete_prob,
    }

In [12]:
def process_speech_segment(audio_buffer, speech_start_time, speech_stop_time):
    """
    Cut the detected utterance out of the buffer and run turn-end prediction on it.

    Args:
        audio_buffer: list of (timestamp, chunk) tuples. Each timestamp is a
            time.time() value; each chunk is a 1-D numpy array of float samples,
            as produced by the recording loop.
        speech_start_time: time.time() value at which speech was first detected.
        speech_stop_time: time.time() value at which end-of-speech was declared.
            Currently unused; kept for interface compatibility.

    Side effects:
        Writes the segment as 16-bit PCM to TEMP_OUTPUT_WAV, creating its
        directory if needed, and prints the prediction. Returns None.
    """
    if not audio_buffer:
        return

    # Start PRE_SPEECH_MS before speech onset. If no chunk is that recent, use the whole buffer.
    start_time = speech_start_time - (PRE_SPEECH_MS / 1000)
    start_index = next(
        (i for i, (t, _) in enumerate(audio_buffer) if t >= start_time), 0
    )
    segment_audio = np.concatenate([chunk for _, chunk in audio_buffer[start_index:]])

    # Assumes the caller ended the segment after about STOP_MS of trailing
    # non-speech; drop all but the last 200 ms of it. Use an explicit end index,
    # because a slice like [:-0] would return an empty array.
    samples_to_remove = int((STOP_MS - 200) / 1000 * RATE)
    if samples_to_remove > 0:
        segment_audio = segment_audio[: max(len(segment_audio) - samples_to_remove, 0)]

    # Limit maximum duration, keeping the END of the segment. End-of-turn
    # prediction depends on how the utterance finishes, so over-long audio is
    # dropped from the start rather than the end.
    max_samples = int(MAX_DURATION_SECONDS * RATE)
    if len(segment_audio) > max_samples:
        segment_audio = segment_audio[len(segment_audio) - max_samples:]

    if len(segment_audio) == 0:
        print("Captured empty audio segment, skipping prediction.")
        return

    # Save the audio for debugging purposes (recording and prediction both use RATE, so no resampling).
    out_dir = os.path.dirname(TEMP_OUTPUT_WAV)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    pcm16 = np.clip(segment_audio * 32767, -32768, 32767).astype(np.int16)
    wavfile.write(TEMP_OUTPUT_WAV, RATE, pcm16)
    print(f"Processing speech segment of length {len(segment_audio) / RATE:.2f} seconds...")

    result = predict_endpoint(segment_audio)

    print("--------")
    print(f"Prediction: {'Complete' if result['prediction'] == 1 else 'Incomplete'}")
    print(f"Probability of complete: {result['probability']:.4f}")


Real-time recording and turn detection

In [18]:
# PyAudio session shared by the device-listing and recording cells.
# The recording cell calls p.terminate() when it finishes.
p = pyaudio.PyAudio()

In [15]:
# Print every audio device PyAudio reports, with its maximum input-channel count
# (devices reporting 0 input channels cannot record).
for idx in range(p.get_device_count()):
    info = p.get_device_info_by_index(idx)
    dev_name = info['name']
    n_inputs = info['maxInputChannels']
    print(f"Device {idx}: {dev_name}, Input Channels: {n_inputs}")


Device 0: Microsoft Sound Mapper - Input, Input Channels: 2
Device 1: Microphone Array (Realtek(R) Au, Input Channels: 2
Device 2: AI Noise-cancelling Input (ASUS, Input Channels: 2
Device 3: Microsoft Sound Mapper - Output, Input Channels: 0
Device 4: Headphones (Realtek(R) Audio), Input Channels: 0
Device 5: Speakers (Realtek(R) Audio), Input Channels: 0
Device 6: AI Noise-cancelling Output (ASU, Input Channels: 0
Device 7: Primary Sound Capture Driver, Input Channels: 2
Device 8: Microphone Array (Realtek(R) Audio), Input Channels: 2
Device 9: AI Noise-cancelling Input (ASUS Utility), Input Channels: 2
Device 10: Primary Sound Driver, Input Channels: 0
Device 11: Headphones (Realtek(R) Audio), Input Channels: 0
Device 12: Speakers (Realtek(R) Audio), Input Channels: 0
Device 13: AI Noise-cancelling Output (ASUS Utility), Input Channels: 0
Device 14: Speakers (Realtek(R) Audio), Input Channels: 0
Device 15: AI Noise-cancelling Output (ASUS Utility), Input Channels: 0
Device 16: Headp

In [19]:
stream = p.open(
    format=FORMAT, channels=CHANNELS, rate=RATE, input=True, frames_per_buffer=CHUNK
)

print("Clearing audio buffer and listening for speech...")

# List of (timestamp, float32 chunk) pairs. While idle it is a rolling window of
# pre-speech audio; once speech starts it accumulates the whole utterance.
audio_buffer = []
silence_frames = 0  # consecutive non-speech chunks since the last speech chunk
speech_start_time = None  # time.time() of the first speech chunk in the current utterance
speech_triggered = False  # True from speech onset until end-of-speech is declared
chunk_seconds = CHUNK / RATE
stop_seconds = STOP_MS / 1000
# How far back (seconds) the idle rolling window reaches.
max_buffer_time = (PRE_SPEECH_MS + STOP_MS) / 1000 + MAX_DURATION_SECONDS

try:
    while True:
        # exception_on_overflow=False: an input overflow is ignored instead of
        # raising and ending the listening loop.
        data = stream.read(CHUNK, exception_on_overflow=False)
        now = time.time()
        audio_np = np.frombuffer(data, dtype=np.int16)
        audio_float32 = audio_np.astype(np.float32) / np.iinfo(np.int16).max

        # Use Silero VAD model directly to get probability
        speech_prob = MODEL(torch.from_numpy(audio_float32).unsqueeze(0), RATE).item()
        is_speech = speech_prob > VAD_THRESHOLD

        if is_speech:
            # Reset on EVERY speech chunk so that STOP_MS measures consecutive
            # trailing silence. Resetting only at speech onset lets short pauses
            # inside an utterance accumulate and end it early. The
            # (STOP_MS - 200) ms end trim in process_speech_segment would then
            # cut real speech.
            silence_frames = 0
            speech_triggered = True
            if speech_start_time is None:
                speech_start_time = now
            audio_buffer.append((now, audio_float32))
        elif speech_triggered:
            # Non-speech after speech: keep it in the segment and count it.
            audio_buffer.append((now, audio_float32))
            silence_frames += 1
            if silence_frames * chunk_seconds >= stop_seconds:
                speech_triggered = False
                speech_stop_time = now

                # Stop the stream before processing
                stream.stop_stream()

                process_speech_segment(audio_buffer, speech_start_time, speech_stop_time)
                audio_buffer = []
                speech_start_time = None
                silence_frames = 0

                # Restart the stream after processing
                stream.start_stream()
                print("Listening for speech...")
        else:
            # Idle: buffer audio so the next utterance has pre-speech context,
            # dropping chunks older than max_buffer_time.
            audio_buffer.append((now, audio_float32))
            while audio_buffer and audio_buffer[0][0] < now - max_buffer_time:
                audio_buffer.pop(0)

except KeyboardInterrupt:
    print("Stopping...")
finally:
    stream.stop_stream()
    stream.close()
    # Releases the PyAudio session. Re-run the `p = pyaudio.PyAudio()` cell
    # before running this cell again.
    p.terminate()

Clearing audio buffer and listening for speech...
Processing speech segment of length 1.38 seconds...
--------
Prediction: Complete
Probability of complete: 0.9972
Listening for speech...
Processing speech segment of length 3.65 seconds...
Stopping...
