In [1]:
# Wake Word Detection Data Generation Pipeline
# This notebook generates training data (activates, negatives, backgrounds) for wake-word detection.

# --- Imports ---
import glob
import json
import logging
import os
import random
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from dataclasses import dataclass, asdict, field
from pathlib import Path
from typing import List, Dict, Tuple

import requests
from dotenv import load_dotenv
from pydub import AudioSegment

In [2]:
# --- Load Environment Variables ---
# Load variables from .env into os.environ; variables already set in the
# process environment take precedence (load_dotenv does not override them).
load_dotenv()

API_KEY = os.getenv("ELEVEN_API_KEY")
BASE_FOLDER = os.getenv("BASE_FOLDER")
WAKE_WORD = os.getenv("WAKE_WORD", "Hey Jerry").strip()
# The pipeline appends "/<voice_id>/stream" to this URL, so drop any trailing
# slash to avoid generating "...//<voice_id>/stream".
BASE_URL = os.getenv("ELEVEN_API_URL", "https://api.elevenlabs.io/v1/text-to-speech").rstrip("/")

if not API_KEY:
    raise ValueError("ELEVEN_API_KEY not found. Please add it to your .env file.")
if not BASE_FOLDER:
    raise ValueError("BASE_FOLDER not found. Please add it to your .env file.")
if not BASE_URL:
    # Only reachable when ELEVEN_API_URL is set but empty; unset falls back to the default.
    raise ValueError("ELEVEN_API_URL is empty. Remove it from your .env file to use the default, or set a valid URL.")
if not WAKE_WORD:
    raise ValueError("WAKE_WORD is empty. Remove it from your .env file to use the default, or set a phrase.")

# --- Configuration ---
@dataclass
class PipelineConfig:
    """Settings for the wake-word data pipeline.

    Defaults are taken from the environment variables read in the previous
    cell, so ``PipelineConfig()`` works once that cell has run.
    """
    # ElevenLabs API key. Excluded from repr so that displaying or logging the
    # config (e.g. as a notebook cell's last expression) does not leak it.
    api_key: str = field(default=API_KEY, repr=False)
    base_folder: str = BASE_FOLDER      # root directory for all generated data
    wake_word: str = WAKE_WORD          # text synthesized for positive samples
    base_url: str = BASE_URL            # TTS endpoint; "/<voice_id>/stream" is appended per request
    voice_pool_size: int = 18           # intended upper bound on how many voices from the pool are used
    negatives_per_phrase: int = 4       # voices randomly sampled per negative phrase
    audio_sample_rate: int = 44100      # output sample rate in Hz
    audio_channels: int = 1             # output channel count (1 = mono)
    background_duration_ms: int = 10000 # background clips are padded/trimmed to this length
    stability: float = 0.75             # sent to the TTS API as voice_settings.stability
    similarity_boost: float = 0.75      # sent to the TTS API as voice_settings.similarity_boost
    max_workers: int = 4                # concurrent API requests; keep within your plan's concurrency limit
    retry_attempts: int = 3             # total attempts per TTS request
    retry_delay: float = 1.0            # base delay (seconds) for exponential backoff between attempts

Define the Pipeline Class

In [3]:
# ----------------- Pipeline -----------------
class WakeWordDataPipeline:
    """Generate and post-process a wake-word training dataset.

    ``run_full_pipeline`` performs these steps in order:
    1. Synthesize positive clips of the wake word and negative clips of
       confusable phrases via the TTS API.
    2. Convert all audio to WAV with the configured channels and sample rate.
    3. Rename files to ``<prefix>-NNNN.wav``.
    4. Pad or trim background clips to a fixed length.
    5. Write ``metadata/dataset_info.json``.
    """

    def __init__(self, config: "PipelineConfig"):
        """Store ``config``, configure logging and create the output folders.

        Args:
            config: Pipeline settings (credentials, paths, audio format, concurrency).
        """
        self.config = config
        self.logger = self._setup_logging()
        self.voice_ids = self._load_voice_ids()
        self.negative_phrases = self._load_negative_phrases()
        self.folders = self._create_directory_structure()

    def _setup_logging(self) -> logging.Logger:
        """Configure root logging and return this module's logger."""
        # basicConfig is a no-op once the root logger has handlers, so re-running
        # the notebook cell does not duplicate log lines.
        logging.basicConfig(
            level=logging.INFO,
            format='%(asctime)s - %(levelname)s - %(message)s',
            handlers=[logging.StreamHandler()]
        )
        return logging.getLogger(__name__)

    def _load_voice_ids(self) -> List[str]:
        """Return the TTS voice IDs to use, capped at ``config.voice_pool_size``."""
        voices = [
            "pNInz6obpgDQGcFmaJgB", "Xb7hH8MSUJpSbSDYk0k2", "ErXwobaYiN019PkySvjV",
            # Additional voices; uncomment to enlarge the pool (each one adds API calls).
            # "9BWtsMINqrJLrRacOk9x", "VR6AewLTigWG4xSOukaG", "wViXBPUzp2ZZixB1xQuM",
            # "pqHfZKP75CvOlQylNhV4", "N2lVS1w4EtoT3dr4eOWO", "IKne3meq5aSn9XLyUdCD",
            # "XB0fDUnXU5powFXDhCwa", "iP95p4xoKVk53GoZ742B", "2EiwWnXFnvU5JabPnv8n",
            # "AZnzlk1XvdvUeBnXmlld", "ThT5KcBeYPX3keUQqHPh", "LcfcDJNUP1GQjkzn1xUU",
            # "cjVigY5qzO86Huf0OWal", "jsCqWAovK2LkecY7zXl4", "Yko7PKHZNXotIFUBG7I9"
        ]
        # max(0, ...) so a negative setting cannot turn into "drop from the end" slicing.
        return voices[:max(0, self.config.voice_pool_size)]

    def _load_negative_phrases(self) -> List[str]:
        """Return phrases synthesized as negative (non-wake-word) samples."""
        return [
            # Trailing comma matters: without it, uncommenting the next line would
            # silently concatenate "Mary" "Siri" into "MarySiri".
            "Terry", "Harry", "Gary", "Barry", "Larry", "Mary",
            # "Siri", "Google", "Alexa", "Jarvis", "Cortana",
            # "lights", "time", "engine", "help", "joke", "morning",
            # "alarm", "door", "messages", "music", "where", "protocol",
            # "computer", "phone", "radio", "weather", "calendar", "email",
            # "kitchen", "bedroom", "living room", "bathroom", "garage",
            # "turn on", "turn off", "play", "stop", "pause", "resume"
        ]

    def _create_directory_structure(self) -> Dict[str, str]:
        """Create the dataset folder tree under ``base_folder`` and return name -> path."""
        base_path = Path(self.config.base_folder)
        folders = {
            'base': str(base_path),
            'activates': str(base_path / 'activates'),
            'negatives': str(base_path / 'negatives'),
            'backgrounds': str(base_path / 'backgrounds'),
            'processed': str(base_path / 'processed'),
            'metadata': str(base_path / 'metadata')
        }
        for folder_path in folders.values():
            os.makedirs(folder_path, exist_ok=True)
        return folders

    @staticmethod
    def _slugify(text: str) -> str:
        """Turn a phrase into a filesystem-safe filename fragment ("Hey Jerry" -> "hey_jerry")."""
        cleaned = text.replace("'", "").lower()
        # Any non-alphanumeric char (space, "/", ...) becomes "_" so it can't create subpaths.
        return "".join(ch if ch.isalnum() else "_" for ch in cleaned)

    def _backoff_delay(self, attempt: int, retry_after=None) -> float:
        """Return the seconds to wait after failed attempt number ``attempt`` (0-based).

        Uses the server's ``Retry-After`` header when it is a number of seconds.
        Otherwise it falls back to exponential backoff: ``retry_delay * 2**attempt``.
        """
        if retry_after is not None:
            try:
                return max(0.0, float(retry_after))
            except (TypeError, ValueError):
                pass  # HTTP-date form or malformed header: use exponential backoff
        return self.config.retry_delay * (2 ** attempt)

    # --- API call with retries ---
    def _make_api_request(self, voice_id: str, text: str, output_path: str) -> Tuple[bool, str]:
        """Synthesize ``text`` with ``voice_id`` and write the audio bytes to ``output_path``.

        Returns:
            ``(success, message)``; never raises for HTTP, network or file-write errors.
        """
        url = f"{self.config.base_url}/{voice_id}/stream"
        headers = {"xi-api-key": self.config.api_key, "Content-Type": "application/json"}
        payload = {"text": text, "voice_settings": {"stability": self.config.stability, "similarity_boost": self.config.similarity_boost}}

        attempts = self.config.retry_attempts
        for attempt in range(attempts):
            is_last = attempt == attempts - 1
            try:
                response = requests.post(url, headers=headers, json=payload, timeout=30)
            except requests.exceptions.RequestException as e:
                if is_last:
                    return False, f"Request failed: {str(e)}"
                time.sleep(self._backoff_delay(attempt))
                continue

            if response.status_code == 200:
                try:
                    with open(output_path, 'wb') as f:
                        f.write(response.content)
                except OSError as e:
                    return False, f"Write failed for {os.path.basename(output_path)}: {e}"
                return True, f"Success: {os.path.basename(output_path)}"

            # Only timeouts, rate limiting and server errors are transient; retrying
            # e.g. 401 (bad key) or 400 (bad payload) just wastes time and quota.
            status = response.status_code
            retryable = status in (408, 429) or status >= 500
            if is_last or not retryable:
                return False, f"API Error {status}: {response.text}"
            time.sleep(self._backoff_delay(attempt, response.headers.get("Retry-After")))
        return False, "Max retries exceeded"

    def _run_tts_tasks(self, tasks: List[Tuple[str, str, str]]) -> int:
        """Run ``(voice_id, text, output_path)`` TTS jobs concurrently and return the success count."""
        success_count = 0
        with ThreadPoolExecutor(max_workers=self.config.max_workers) as executor:
            future_to_path = {
                executor.submit(self._make_api_request, voice_id, text, path): path
                for voice_id, text, path in tasks
            }
            for future in as_completed(future_to_path):
                try:
                    success, message = future.result()
                except Exception as exc:  # one unexpected failure must not abort the batch
                    success = False
                    message = f"Unexpected error for {os.path.basename(future_to_path[future])}: {exc}"
                if success:
                    success_count += 1
                    self.logger.info(message)
                else:
                    self.logger.error(message)
        return success_count

    # --- Step functions ---
    def generate_positive_samples(self) -> int:
        """Synthesize the wake word once per voice into ``activates``; return the success count."""
        self.logger.info("Generating positive samples...")
        slug = self._slugify(self.config.wake_word)
        tasks = [
            (voice_id, self.config.wake_word,
             os.path.join(self.folders['activates'], f"pos_{slug}_{voice_id}.mp3"))
            for voice_id in self.voice_ids
        ]
        return self._run_tts_tasks(tasks)

    def generate_negative_samples(self) -> int:
        """Synthesize each negative phrase with randomly chosen voices; return the success count.

        NOTE: voice selection uses the global ``random`` state; call
        ``random.seed(...)`` beforehand for a reproducible dataset.
        """
        self.logger.info("Generating negative samples...")
        voices_per_phrase = max(0, min(self.config.negatives_per_phrase, len(self.voice_ids)))
        tasks = []
        for phrase in self.negative_phrases:
            clean_phrase = self._slugify(phrase)
            for voice_id in random.sample(self.voice_ids, voices_per_phrase):
                tasks.append((voice_id, phrase, os.path.join(self.folders['negatives'], f"neg_{clean_phrase}_{voice_id}.mp3")))
        return self._run_tts_tasks(tasks)

    def convert_to_wav(self, folder_path: str) -> int:
        """Convert non-WAV audio in ``folder_path`` to WAV and delete the source files.

        Output uses ``config.audio_channels`` and ``config.audio_sample_rate``.
        Returns the number of files converted; per-file errors are logged and skipped.
        """
        self.logger.info(f"Converting audio in {folder_path} to WAV...")
        audio_extensions = {'.mp3', '.m4a', '.mp4', '.flac', '.ogg'}
        # Case-insensitive match so files like "CLIP.MP3" are not silently skipped.
        audio_files = sorted(
            path for path in glob.glob(os.path.join(folder_path, '*'))
            if os.path.isfile(path) and os.path.splitext(path)[1].lower() in audio_extensions
        )
        count = 0
        for file_path in audio_files:
            try:
                audio = AudioSegment.from_file(file_path)
                audio = audio.set_channels(self.config.audio_channels).set_frame_rate(self.config.audio_sample_rate)
                wav_path = os.path.splitext(file_path)[0] + '.wav'
                audio.export(wav_path, format='wav')
                os.remove(file_path)
                count += 1
                self.logger.info(f"Converted: {os.path.basename(file_path)}")
            except Exception as e:
                self.logger.error(f"Error converting {file_path}: {str(e)}")
        return count

    def standardize_filenames(self, folder_path: str, prefix: str) -> int:
        """Rename every ``*.wav`` in ``folder_path`` to ``<prefix>-0001.wav``, ... in sorted order.

        The rename is done in two phases: first to unique temporary names, then
        to the final names. A one-pass rename could move a file onto a name that
        an existing file still holds (e.g. "abc.wav" -> "bg-0001.wav" while the
        old "bg-0001.wav" is waiting its turn). That would silently overwrite
        data on POSIX, and raise FileExistsError on Windows.

        Returns:
            The number of WAV files renamed.
        """
        self.logger.info(f"Standardizing filenames in {folder_path}...")
        wav_files = sorted(glob.glob(os.path.join(folder_path, '*.wav')))
        staged = []
        for i, file_path in enumerate(wav_files, 1):
            # Hidden, non-.wav temp name: never matches the glob above or a final name.
            tmp_path = os.path.join(folder_path, f".{prefix}-{os.getpid()}-{i:04d}.renaming")
            os.rename(file_path, tmp_path)
            staged.append(tmp_path)
        for i, tmp_path in enumerate(staged, 1):
            os.rename(tmp_path, os.path.join(folder_path, f"{prefix}-{i:04d}.wav"))
        return len(wav_files)

    def process_background_audio(self) -> int:
        """Resample backgrounds in place and pad with silence / trim to ``background_duration_ms``.

        Returns the number of files processed successfully; failures are logged and skipped.
        """
        self.logger.info("Processing background audio...")
        target_ms = self.config.background_duration_ms
        bg_files = glob.glob(os.path.join(self.folders['backgrounds'], '*.wav'))
        count = 0
        for file_path in bg_files:
            try:
                audio = (AudioSegment.from_wav(file_path)
                         .set_channels(self.config.audio_channels)
                         .set_frame_rate(self.config.audio_sample_rate))
                if len(audio) < target_ms:
                    audio += AudioSegment.silent(duration=target_ms - len(audio),
                                                 frame_rate=self.config.audio_sample_rate)
                audio = audio[:target_ms]
                audio.export(file_path, format='wav')
                count += 1
            except Exception as e:
                self.logger.error(f"Error processing background {file_path}: {str(e)}")
        return count

    def generate_metadata(self, counts: "Dict[str, int] | None" = None):
        """Write the run's config (API key redacted), folder map and optional counts to JSON."""
        config_dict = asdict(self.config)
        # Never persist credentials next to a dataset that may be shared.
        if 'api_key' in config_dict:
            config_dict['api_key'] = '<redacted>'
        metadata = {'config': config_dict, 'folders': self.folders}
        if counts is not None:
            metadata['counts'] = counts
        metadata_path = os.path.join(self.folders['metadata'], 'dataset_info.json')
        with open(metadata_path, 'w') as f:
            json.dump(metadata, f, indent=2)
        self.logger.info(f"Metadata saved at {metadata_path}")

    def run_full_pipeline(self):
        """Run every step in order and record per-step counts in the metadata file."""
        self.logger.info("Starting full pipeline...")
        counts = {
            'positives_generated': self.generate_positive_samples(),
            'negatives_generated': self.generate_negative_samples(),
        }
        for name in ('activates', 'negatives', 'backgrounds'):
            self.convert_to_wav(self.folders[name])
        counts['activates'] = self.standardize_filenames(self.folders['activates'], 'pos')
        counts['negatives'] = self.standardize_filenames(self.folders['negatives'], 'neg')
        counts['backgrounds'] = self.standardize_filenames(self.folders['backgrounds'], 'bg')
        counts['backgrounds_processed'] = self.process_background_audio()
        self.generate_metadata(counts)
        self.logger.info(f"Summary: {counts}")
        self.logger.info("Pipeline finished!")

Run the Full Pipeline (positives, negatives, WAV conversion, renaming, backgrounds, metadata)

In [4]:
# ----------------- Usage in Notebook -----------------
# Environment variables were already loaded and validated in the configuration
# cell. Calling load_dotenv() again here would not refresh API_KEY and the
# other globals, so it is intentionally not repeated.
# NOTE: this cell makes paid TTS API calls on every run.
config = PipelineConfig(
    api_key=API_KEY,
    base_folder=BASE_FOLDER,
    wake_word=WAKE_WORD,
    base_url=BASE_URL
)

pipeline = WakeWordDataPipeline(config)
pipeline.run_full_pipeline()

2025-09-13 01:00:35,126 - INFO - Starting full pipeline...
2025-09-13 01:00:35,127 - INFO - Generating positive samples...
2025-09-13 01:00:36,497 - INFO - Success: pos_hey_jerry_ErXwobaYiN019PkySvjV.mp3
2025-09-13 01:00:36,800 - INFO - Success: pos_hey_jerry_Xb7hH8MSUJpSbSDYk0k2.mp3
2025-09-13 01:00:38,193 - INFO - Success: pos_hey_jerry_pNInz6obpgDQGcFmaJgB.mp3
2025-09-13 01:00:38,193 - INFO - Generating negative samples...
2025-09-13 01:00:39,603 - INFO - Success: neg_terry_pNInz6obpgDQGcFmaJgB.mp3
2025-09-13 01:00:39,808 - INFO - Success: neg_terry_Xb7hH8MSUJpSbSDYk0k2.mp3
2025-09-13 01:00:41,044 - INFO - Success: neg_harry_Xb7hH8MSUJpSbSDYk0k2.mp3
2025-09-13 01:00:41,066 - INFO - Success: neg_harry_pNInz6obpgDQGcFmaJgB.mp3
2025-09-13 01:00:42,379 - INFO - Success: neg_gary_ErXwobaYiN019PkySvjV.mp3
2025-09-13 01:00:42,502 - INFO - Success: neg_gary_Xb7hH8MSUJpSbSDYk0k2.mp3
2025-09-13 01:00:43,118 - INFO - API Error 429: {"detail":{"status":"too_many_concurrent_requests","message":"