In [None]:
import os
import torch
from torch.utils.data import Dataset


class fourteen_commands(Dataset):
    """Map-style dataset over ``.pt`` files grouped in per-label directories.

    Files are discovered under ``base_dir/<label>/*.pt`` for each label in
    ``LABELS``. NOTE: despite the class name, only the ten labels listed in
    ``LABELS`` are indexed; directories for any other label are ignored.

    An optional text file (``test_list_path``) names the held-out files, one
    path per line. Only the base name of each line is used, with its
    extension replaced by ``.pt``; the directory part of the line is ignored.
    With ``train=True`` the dataset contains every discovered file that is
    NOT held out; with ``train=False`` it contains only the held-out files.

    Args:
        base_dir: Directory containing one sub-directory per label.
        test_list_path: Path to the held-out list, or ``None``. If a path is
            given but does not exist, a warning is emitted and no file is
            treated as held out.
        return_label_name: If true, ``__getitem__`` returns the label string
            instead of its integer index.
        train: Selects the training split (true) or held-out split (false).

    Attributes:
        filepaths: Paths of the selected ``.pt`` files, ordered by label
            (in ``LABELS`` order) and then by file name.
        labels: Integer label index for each entry of ``filepaths``.
        label_to_index: Label name -> integer index.
        index_to_label: Integer index -> label name.
    """

    # Single source of truth for the label set; position is the class index.
    LABELS = ("no", "yes", "up", "down", "left", "right", "on", "off", "stop", "go")

    def __init__(
        self, base_dir, test_list_path=None, return_label_name=False, train=True
    ):
        super().__init__()
        self.filepaths = []
        self.labels = []
        self.label_to_index = {name: index for index, name in enumerate(self.LABELS)}
        self.index_to_label = dict(enumerate(self.LABELS))
        self.return_label_name = return_label_name

        held_out = self._read_held_out_names(test_list_path)
        want_train = bool(train)

        for label in self.LABELS:
            label_path = os.path.join(base_dir, label)
            if not os.path.isdir(label_path):
                continue
            # os.listdir returns entries in arbitrary order; sort so that the
            # index -> sample mapping is reproducible across runs and machines.
            for file_name in sorted(os.listdir(label_path)):
                if not file_name.endswith(".pt"):
                    continue
                # Training split keeps files that are not held out; the
                # held-out split keeps exactly the files that are.
                if (file_name in held_out) != want_train:
                    self.filepaths.append(os.path.join(label_path, file_name))
                    self.labels.append(self.label_to_index[label])

    @staticmethod
    def _read_held_out_names(test_list_path):
        """Return the set of ``<stem>.pt`` names listed in *test_list_path*."""
        names = set()
        if not test_list_path:
            return names
        if not os.path.exists(test_list_path):
            # A missing list would otherwise silently put every file into the
            # training split (and leave the held-out split empty).
            import warnings

            warnings.warn(
                f"test list {test_list_path!r} not found; "
                "no files will be treated as held out",
                stacklevel=3,
            )
            return names
        with open(test_list_path, "r") as f:
            for line in f:
                entry = line.strip()
                if not entry:
                    # Blank lines would otherwise register the bogus name ".pt".
                    continue
                stem = os.path.splitext(os.path.basename(entry))[0]
                names.add(stem + ".pt")
        return names

    def __len__(self):
        """Return the number of selected files."""
        return len(self.filepaths)

    def __getitem__(self, idx):
        """Load sample *idx* and return ``(loaded_object, label)``.

        The first element is whatever ``torch.load`` returns for the file
        (presumably a precomputed spectrogram tensor, going by the variable
        name; confirm against the code that wrote the files). The second is
        the label name if ``return_label_name`` is set, else the label index.
        """
        # SECURITY: torch.load may unpickle arbitrary objects depending on the
        # torch version / ``weights_only`` default; only load trusted files.
        spectrogram = torch.load(self.filepaths[idx])
        label_index = self.labels[idx]
        if self.return_label_name:
            return spectrogram, self.index_to_label[label_index]
        return spectrogram, label_index