In [None]:
import os
import time
import random
import subprocess
import logging
from your_yolo_module import run_yolo_world  # Replace with your actual YOLO-World function

# Configure the root logger once for the whole notebook: INFO and above,
# with every record prefixed by its timestamp.
logging.basicConfig(
    format="%(asctime)s - %(message)s",
    level=logging.INFO,
)

In [None]:
# 1. call the main function
# 2. request_new_waveform("sample_name") returns the spike encoded input as a 1d tensor
# which is sent to predict(input)
# 3. predict(input) calls the AIfES C code through a subprocess

In [None]:
# This entire cell will be moved to a new file
import torchaudio
from torch.utils.data import DataLoader, Dataset

from pathlib import Path
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split

# === Label Encoding ===
# Vocabulary of keywords the pipeline knows about.
all_labels = [
    "bed",
    "bird",
    "cat",
    "dog",
    "house",
    "marvin",
    "tree",
    "mask",
    "frame",
    "unknown",
    "silence",
]
# Maps label names <-> integer indices. fit() returns the encoder itself, so
# construction and fitting are chained.
# NOTE(review): LabelEncoder orders classes by sorted value, not by list position —
# confirm the C model's class indices follow the same ordering.
label_encoder = LabelEncoder().fit(all_labels)

# # === DATASET ===
class KeywordSpottingDataset(Dataset):
    """Dataset over ``(audio_path, label_index)`` pairs whose audio is loaded on access.

    Args:
        samples: Indexable collection of ``(path, label_idx)`` pairs. ``path`` must
            expose ``.parent.name`` (e.g. a ``pathlib.Path``).
        transform: Optional callable applied to the loaded waveform.
    """

    def __init__(self, samples, transform=None):
        self.transform = transform
        self.samples = samples

    def __len__(self):
        """Return the number of stored samples."""
        return len(self.samples)

    def __getitem__(self, index):
        """Load one sample and return ``(waveform, sample_rate, label_name)``.

        The stored label index is not used here; the label name is taken from the
        name of the folder that contains the audio file.
        """
        audio_path, _label_idx = self.samples[index]
        loaded = torchaudio.load(audio_path)
        audio, rate = loaded
        folder_label = audio_path.parent.name
        audio = self.transform(audio) if self.transform else audio
        return audio, rate, folder_label
    
def load_test_dataset_from_txt(txt_file):
    """Read ``path,label_index`` lines from a text file into a list of samples.

    Each line is split on its LAST comma, so paths that themselves contain commas
    are accepted. Blank lines are ignored silently; lines with no comma, an empty
    path, or a label that is not an integer are reported and skipped instead of
    aborting the whole load.

    Args:
        txt_file: Path of the listing file to read.

    Returns:
        list[tuple[Path, int]]: ``(audio_path, label_index)`` pairs in file order.
    """
    test_samples = []
    with open(txt_file, "r") as f:
        for raw_line in f:
            line = raw_line.strip()
            if not line:
                continue

            # rpartition keeps everything before the final comma as the path.
            path_str, separator, label_idx_str = line.rpartition(",")
            path_str = path_str.strip()
            if not separator or not path_str:
                print(f"Skipping line: {line}")
                continue

            try:
                label_idx = int(label_idx_str)
            except ValueError:
                # e.g. a header row or a corrupted label field
                print(f"Skipping line: {line}")
                continue

            test_samples.append((Path(path_str), label_idx))
    return test_samples

def request_new_waveform(sample_name):
    """Return a test-set waveform for ``sample_name``, emulating the microphone.

    Picks a sample waveform from the KWS test set until the microphone setup is
    done. The keyword to "hear" is chosen by the caller through ``sample_name``;
    among the test samples carrying that label one is drawn at random.

    Args:
        sample_name: Label name; must be one of ``all_labels``.

    Returns:
        The waveform produced by ``KeywordSpottingDataset`` for the chosen sample.

    Raises:
        ValueError: If ``sample_name`` is not a known label, or the test listing
            contains no sample with that label.
    """
    # Convert label name to index
    if sample_name not in all_labels:
        raise ValueError(f"Label '{sample_name}' is not in the list of known labels.")
    label_idx = label_encoder.transform([sample_name])[0]

    test_samples = load_test_dataset_from_txt("test_dataset_list.txt")
    # Assumes the indices stored in test_dataset_list.txt were produced by this
    # same label_encoder — TODO confirm against the script that wrote the file.
    matching_samples = [sample for sample in test_samples if sample[1] == label_idx]
    if not matching_samples:
        raise ValueError(f"No test samples found for label '{sample_name}'.")

    test_dataset = KeywordSpottingDataset(matching_samples)
    chosen_index = random.randrange(len(test_dataset))
    waveform, _sample_rate, _label_name = test_dataset[chosen_index]
    return waveform


In [None]:
def activate_led():
    """Signal that the system is listening.

    Placeholder: prints a status line instead of driving a real LED.
    """
    status_message = "LED activated: System is listening..."
    print(status_message)

In [None]:
def predict(input_waveform):
    """Run the external inference binary on a waveform and return the predicted class.

    The waveform is written to ``input_waveform.bin`` and ``./main`` is invoked with
    that file name. Its stdout is scanned for lines containing
    ``prediction class index``; the integer following that phrase (optionally
    separated by ``:`` or ``=``) is decoded through ``label_encoder``. If several
    such lines are present, the last one wins.

    Args:
        input_waveform: Object exposing ``tofile(filename)`` (e.g. a NumPy array).
            Objects lacking ``tofile`` but exposing ``detach`` (e.g. torch tensors)
            are first converted with ``.detach().cpu().numpy()``.

    Returns:
        The result of ``label_encoder.inverse_transform([class_index])`` for the
        prediction found, or ``None`` when the input is missing, the subprocess
        exits with a non-zero status, or no parsable prediction line is printed.
    """
    if input_waveform is None:
        print("Input waveform required!")
        return None

    # torch tensors have no .tofile(); go through NumPy for the binary dump.
    if not hasattr(input_waveform, "tofile") and hasattr(input_waveform, "detach"):
        input_waveform = input_waveform.detach().cpu().numpy()

    input_filename = "input_waveform.bin"
    input_waveform.tofile(input_filename)

    print("Starting inference")
    try:
        result = subprocess.run(["./main", input_filename], capture_output=True, text=True, check=True)
    except subprocess.CalledProcessError as e:
        print(f"Inference failed: {e.stderr}")
        return None

    marker = "prediction class index"
    class_name = None  # stays None when stdout carries no usable prediction
    for line in result.stdout.splitlines():
        if marker not in line:
            continue
        # Accept "index 3", "index: 3" and "index = 3", ignoring trailing text.
        tokens = line.split(marker, 1)[1].replace(":", " ").replace("=", " ").split()
        try:
            class_index = int(tokens[0])
        except (IndexError, ValueError):
            print(f"Could not parse prediction line: {line.strip()}")
            continue
        class_name = label_encoder.inverse_transform([class_index])
        print(f"predicted class index: {class_index}, class name: {class_name}")

    if class_name is None:
        print("No prediction found.")
    print("Inference complete")
    return class_name


In [None]:
def trigger_vision_model(action, object):
    """Hand the recognised command over to the YOLO-World script.

    Not implemented yet: the arguments are ignored and nothing is returned.
    """
    return None

In [None]:
# This is the ideal case; no edge cases are dealt with here.
# Just a vanilla integration model.


def _as_keyword(prediction):
    """Reduce a predict() result to a plain string, or None if there is none.

    predict() hands back the output of ``inverse_transform`` (a one-element
    sequence) or ``None``; plain strings are passed through unchanged.
    """
    if prediction is None or isinstance(prediction, str):
        return prediction
    try:
        return str(prediction[0])
    except (TypeError, IndexError, KeyError):
        return None


def main(args=None):
    """Run one wake word -> action -> object round and trigger the vision model.

    The microphone is emulated: each stage requests a test-set waveform for a
    fixed keyword, classifies it with ``predict`` and validates the result before
    moving on. Any stage that fails validation is logged and ends the round.

    Args:
        args: Currently unused. Optional so the function can be called as
            ``main()`` as well as ``main(args)``.
    """
    # Keywords are lower-case to match the label vocabulary used by
    # request_new_waveform().
    wakeword = "marvin"
    action_words = ["frame", "mask"]
    # NOTE(review): "car" is not among all_labels, so it can never be predicted — confirm.
    object_classes = ["cat", "car", "dog", "bed", "bird", "house", "tree"]
    # Hardcoded sample names for now; later the sample name must be removed as an
    # argument so that a random waveform is retrieved.
    action_sample = "frame"
    object_sample = "cat"

    logging.info("System initialized. Listening for wake word...")
    wakeword_input = request_new_waveform(wakeword)

    # This check looks redundant now, but in the real case request_new_waveform()
    # will give a random word, so the prediction must be validated.
    if _as_keyword(predict(wakeword_input)) != wakeword:
        logging.info("No wake word detected.")
        return
    logging.info("Wake word 'Marvin' detected.")

    action_word = _as_keyword(predict(request_new_waveform(action_sample)))
    if action_word not in action_words:
        logging.warning("Unrecognized action. Please try again.")
        return

    object_word = _as_keyword(predict(request_new_waveform(object_sample)))
    if object_word not in object_classes:
        logging.warning("Unrecognized object. Please try again.")
        return

    logging.info(f"Command recognized: {action_word} {object_word}")
    logging.info("Activating Vision pipeline")
    trigger_vision_model(action_word, object_word)
            

In [None]:
if __name__ == "__main__":
    # main() declares an `args` parameter; there is no command-line parsing yet,
    # so pass None explicitly rather than calling it with no arguments.
    main(None)