# In [2]:
import os
import random
import torch
from torch.utils.data import Dataset
from typing import List, Optional

class three_class_commands(Dataset):
    """Dataset that maps per-label folders of ``.pt`` files onto three classes.

    ``base_dir`` is expected to contain one sub-folder per label.  Files in a
    folder whose name is in ``known_labels`` get class ``"known"`` (0); files
    in any other folder get class ``"unknown"`` (1) and are randomly
    undersampled; entries of ``silence_list`` get class ``"silence"`` (2) and
    are sampled with replacement.  A folder literally named ``silence`` is
    never scanned -- silence items come only from ``silence_list``.

    The train/test split is driven by ``test_list_path``: every file whose
    name (with its extension replaced by ``.pt``) appears in that list belongs
    to the test split, everything else to the train split.

    NOTE(review): matching uses the file name only, not the label folder, so
    one listed name marks files with that name in *every* label folder.
    Confirm this matches the format of the test list in use.

    Args:
        base_dir: Directory holding one sub-folder of ``.pt`` files per label.
        test_list_path: Optional text file with one held-out file path per
            line.  If falsy, no file is held out.  If given but missing, a
            warning is emitted and no file is held out.
        return_label_name: If true, ``__getitem__`` returns the class name
            (``"known"`` / ``"unknown"`` / ``"silence"``) instead of its index.
        train: Select the train split (files *not* in the test list) when
            true, otherwise the test split.
        target_count_train: Number of unknown and of silence items to draw for
            the train split.
        target_count_test: Number of unknown and of silence items to draw for
            the test split.
        silence_list: Optional sequence of file paths used as silence items.

    Sampling uses the global ``random`` module; directory listings are sorted
    so that seeding ``random`` beforehand gives a reproducible dataset.
    """

    def __init__(
        self,
        base_dir: str,
        test_list_path: Optional[str] = None,
        return_label_name: bool = False,
        train: bool = True,
        target_count_train: int = 2100,
        target_count_test: int = 250,
        silence_list: Optional[List[str]] = None,
    ):
        super().__init__()
        self.filepaths = []
        self.labels = []
        self.return_label_name = return_label_name

        self.known_labels = [
            "no", "yes", "up", "down", "left", "right", "on", "off", "stop", "go"
        ]

        self.label_to_index = {
            "known": 0,
            "unknown": 1,
            "silence": 2,
        }
        self.index_to_label = {idx: label for label, idx in self.label_to_index.items()}

        held_out_names = self._read_held_out_names(test_list_path)
        known_labels = set(self.known_labels)
        want_held_out = not train
        target_count = target_count_train if train else target_count_test

        unknown_samples = []

        # Sorted listings keep the candidate order (and therefore the result
        # of seeded sampling) independent of filesystem enumeration order.
        for label in sorted(os.listdir(base_dir)):
            label_path = os.path.join(base_dir, label)

            if not os.path.isdir(label_path):
                continue
            if label == "silence":
                continue  # Silence comes from ``silence_list`` only.

            is_known = label in known_labels

            for file in sorted(os.listdir(label_path)):
                if not file.endswith(".pt"):
                    continue
                # Keep the file only if its held-out status matches the split.
                if (file in held_out_names) != want_held_out:
                    continue

                path = os.path.join(label_path, file)
                if is_known:
                    self.filepaths.append(path)
                    self.labels.append(self.label_to_index["known"])
                else:
                    unknown_samples.append(path)

        # Undersample the unknown class.  ``random.sample`` raises if asked
        # for more items than exist, so cap the request at what is available.
        if unknown_samples:
            sample_size = min(target_count, len(unknown_samples))
            undersampled_unknown = random.sample(unknown_samples, k=sample_size)
            self.filepaths.extend(undersampled_unknown)
            self.labels.extend(
                [self.label_to_index["unknown"]] * len(undersampled_unknown)
            )

        # Oversample silence with replacement.  An empty list is skipped
        # because ``random.choices`` cannot draw from an empty population.
        if silence_list is not None and len(silence_list) > 0:
            oversampled_silence = random.choices(silence_list, k=target_count)
            self.filepaths.extend(oversampled_silence)
            self.labels.extend(
                [self.label_to_index["silence"]] * len(oversampled_silence)
            )

    @staticmethod
    def _read_held_out_names(test_list_path):
        """Return the set of ``.pt`` file names listed in ``test_list_path``.

        Each non-blank line is reduced to its base name and its extension is
        replaced by ``.pt``.  Returns an empty set when no path is given or
        the file does not exist (the latter also emits a warning).
        """
        names = set()
        if not test_list_path:
            return names

        if not os.path.exists(test_list_path):
            import warnings

            warnings.warn(
                f"test list {test_list_path!r} not found; no files are held out",
                stacklevel=3,
            )
            return names

        with open(test_list_path, "r", encoding="utf-8") as f:
            for line in f:
                entry = line.strip()
                if not entry:
                    continue  # A blank line would otherwise register ".pt".
                stem = os.path.splitext(os.path.basename(entry))[0]
                names.add(stem + ".pt")
        return names

    def __len__(self):
        """Return the number of items in the assembled split."""
        return len(self.filepaths)

    def __getitem__(self, idx):
        """Load item ``idx`` and return ``(loaded_object, label)``.

        The object is whatever ``torch.load`` yields for the stored path
        (presumably a spectrogram tensor -- confirm against the code that
        wrote the files).  The label is the class index, or the class name
        when ``return_label_name`` is set.
        """
        spectrogram = torch.load(self.filepaths[idx])
        label_index = self.labels[idx]
        if self.return_label_name:
            return spectrogram, self.index_to_label[label_index]
        return spectrogram, label_index
