In [1]:
import os
import time
import numpy as np
import random
import subprocess
import logging
# from your_yolo_module import run_yolo_world  # Replace with your actual YOLO-World function
from IPython.display import Audio

# Setup logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(message)s')

In [2]:
# 1. call the main fucntion
# 2. request_new_waveform("sample_name") returns the spike encoded input as a 1d tensor
# which is sent to predict(model, input)
# 3. predict(model, input) calls the AIfES C code through a subprocess

In [3]:
# This entire cell will be moved to a new file
from snntorch import spikegen
import torchaudio
from torch.utils.data import DataLoader, Dataset
import torch.nn.functional as F

from pathlib import Path
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split

# === Label Encoding ===
all_labels =  ["bed", "bird", "cat", "dog", "house", "marvin", "tree", "mask", "frame", "unknown", "silence"]
label_encoder = LabelEncoder()
label_encoder.fit(all_labels) # encode labels as indices

# # === DATASET ===
class KeywordSpottingDataset(Dataset):
    def __init__(self, samples, transform=None):
        self.samples = samples
        self.transform = transform

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        path, label_idx = self.samples[index]
        waveform, sample_rate = torchaudio.load(path)
        label_name = path.parent.name  # Get the class name from the folder
        if self.transform:
            waveform = self.transform(waveform)
        return waveform, sample_rate, label_name
    
def load_test_dataset_from_txt(txt_file):
    test_samples = []
    with open(txt_file, "r") as f:
        for line in f:
            # path_str, label_idx_str = line.strip().split(",")
            parts = line.strip().split(",")
            if len(parts) < 2:
                print(f"Skipping line: {line.strip()}")
                continue
            path_str, label_idx_str = parts
            test_samples.append((Path(path_str), int(label_idx_str)))
    return test_samples

def preprocess(raw_waveform):
    mel_transform = torchaudio.transforms.MelSpectrogram(
        sample_rate=16000,
        n_fft=400,
        hop_length=160,
        n_mels=64
    )
    amplitude_to_db = torchaudio.transforms.AmplitudeToDB()
    target_length = 16000

    # Pad or trim waveform
    if raw_waveform.size(1) < target_length:
        pad_size = target_length - raw_waveform.size(1)
        raw_waveform = F.pad(raw_waveform, (0, pad_size))
    else:
        raw_waveform = raw_waveform[:, :target_length]

    # Mel spectrogram and normalization
    mel_spec = mel_transform(raw_waveform).squeeze(0)
    mel_spec = amplitude_to_db(mel_spec)
    mel_spec = (mel_spec - mel_spec.mean()) / (mel_spec.std() + 1e-5)
    mel_spec = mel_spec.unsqueeze(0)  # Shape: [1, 64, time]
    encoded_input = spikegen.rate(mel_spec, num_steps=35)
    
    return encoded_input

def request_new_waveform(sample_name, play_audio=False):
    # picks a sample waveform from the KWS test set until microphone setup is done.
    # the sample to be picked from the dataset can be specified by the sample_name as we afre emulating the microphone.
    sample_name = sample_name.lower()
    # Convert label name to index
    if sample_name not in all_labels:
        raise ValueError(f"Label '{sample_name}' is not in the list of known labels.")
    label_idx = label_encoder.transform([sample_name])[0]

    test_samples = load_test_dataset_from_txt("test_dataset_list.txt")
    # test_dataset = KeywordSpottingDataset(test_samples)
    label_idx = label_encoder.transform([sample_name])[0]
    matching_samples = [sample for sample in test_samples if sample[1] == label_idx]
    
    if not matching_samples:
        raise RuntimeError(f"No test samples found for class '{sample_name}'")
    
    selected_path,_ = random.choice(matching_samples)
    raw_waveform, sample_rate = torchaudio.load(selected_path)
    
    if play_audio:
        display(Audio(raw_waveform.numpy().squeeze(), rate=sample_rate))

    encoded_sample = preprocess(raw_waveform)
        
    return encoded_sample


In [4]:
# Testing the function
sample = request_new_waveform("cat", True)
print(sample.shape)

Skipping line: ../../datasets/kws_dataset/mask/mask_ED
Skipping line: ../../datasets/kws_dataset/mask/mask_ED


torch.Size([35, 1, 64, 101])


In [5]:
def activate_led():
    # Placeholder for activating an LED to indicate listening to user
    print("LED activated: System is listening...")

In [14]:
def predict(encoded_input, bin_filename=None):
    if encoded_input is None:
        print("Input waveform required!")
        return None
    if bin_filename is None:
        bin_filename = "aifes_kws_input.bin"
    flattened_input = encoded_input.flatten().numpy().astype(np.float32)
    flattened_input.tofile(bin_filename)
    
    print("Starting inference")
    flag = False
    try:
        result = subprocess.run(["./main", bin_filename], capture_output=True, text=True, check=True)
        output = result.stdout
        for line in output.splitlines():
            if "prediction class index" in line:
                flag = True
                class_index = int(line.strip().split("index")[1])
                class_name = label_encoder.inverse_transform([class_index])
                print(f"predicted class index: {class_index}, class name: {class_name}") 
                break
        if not flag:
            print("No prediction found.") 
            return None
        print("Inference complete")        
        return class_name[0]
        
    except subprocess.CalledProcessError as e:
        print(f"Inference failed: {e.stderr}")
        return None


In [15]:
# Testing the code
sample_name = "marvin"
sample = request_new_waveform(sample_name, False)
sample_idx = label_encoder.transform([sample_name])
print("idx: ", sample_idx)
bin_filename = f"input_waveform_{sample_name}{sample_idx[0]}.bin"
predicted_class = predict(sample)
print(predicted_class)


Skipping line: ../../datasets/kws_dataset/mask/mask_ED
Skipping line: ../../datasets/kws_dataset/mask/mask_ED
idx:  [6]
Starting inference
predicted class index: 6, class name: ['marvin']
Inference complete
marvin


In [16]:
def trigger_vision_model(action, object):
    # calls YOLO-World script
    pass

In [17]:
#  This is the ideal case. No edge cases are dealt here
# Just a vanilla integration model


def main():
    wakeword = "Marvin"
    object_classes =["cat", "car", "dog", "bed", "bird", "house", "tree"]
    logging.info("System initialized. Listening for wake word...")
    wakeword_input = request_new_waveform(wakeword)
    
    predicted_keyword = predict(wakeword_input)
    if predicted_keyword == wakeword :  # this looks reduntant now, but in real case request_new_waveform() will give a random word. so it must be validated
        logging.info("Wake word 'Marvin' detected.")        
        action_word_input = request_new_waveform("action") # hardcoded now, later the sample name must be removed as arg to retrieve a random waveform
        predicted_keyword = predict(action_word_input)
        
        if predicted_keyword in ["frame", "mask"]:
            action_word = predicted_keyword
            object_input = request_new_waveform("object")
            predicted_keyword = predict(object_input)
            if predicted_keyword in object_classes:
                object_word = predicted_keyword
                logging.info(f"Command recognized: {action_word} {object_word}")
                logging.info("Activating Vision pipeline")
                trigger_vision_model(action_word, object_word)
            else:
                logging.warning("Unrecognized object. Please try again.")
        else:
            logging.warning("Unrecognized action. Please try again.")
    else:
        logging.info("No wake word detected.")
            

In [18]:
if __name__ == "__main__":
    main()

2025-07-04 11:54:07,850 - System initialized. Listening for wake word...


Skipping line: ../../datasets/kws_dataset/mask/mask_ED
Skipping line: ../../datasets/kws_dataset/mask/mask_ED
Starting inference


2025-07-04 11:54:09,330 - No wake word detected.


predicted class index: 6, class name: ['marvin']
Inference complete
