In [1]:
import os
import numpy as np
import torch
from torch.utils.data import Dataset 
import re

class BinauralCueDataset(Dataset):
    """Map-style dataset of binaural cue arrays stored as ``.npz`` files.

    Every sample file must be named ``main_audio_<audio id>_azi<azimuth>.npz``
    and contain the arrays ``"itd"``, ``"ild"`` and ``"ic"``.

    Args:
        npz_dir: Directory that is scanned (non-recursively) for sample files.
        audio_ids: Audio ids to keep. Any iterable of ints, or a ``range`` /
            ``set`` / ``frozenset``; defaults to ids 1..100.

    Attributes:
        dir: The directory passed as ``npz_dir``.
        files: Sorted list of the selected file names (not full paths).

    Each item is a ``(cue, label)`` tuple:
        cue: ``float32`` NumPy array with the three cues stacked on a new
            leading axis in the order ``(itd, ild, ic)``. The original author
            notes the layout as ``[3, filters, frames]``; the trailing axes
            are whatever the stored arrays provide.
        label: ``azimuth // 5`` as a Python ``int``.
            NOTE(review): the original comment states 72 classes (0-71),
            which holds only if file names carry azimuths 0..359 -- confirm
            that no file uses ``azi360``.
    """

    # Complete file-name layout; group 1 is the audio id, group 2 the azimuth.
    _FILENAME_RE = re.compile(r'main_audio_(\d+)_azi(\d+)\.npz')
    # Azimuth extraction used when building the label.
    _AZIMUTH_RE = re.compile(r'azi(\d+)')

    def __init__(self, npz_dir, audio_ids=range(1, 101)):
        self.dir = npz_dir

        # Snapshot arbitrary iterables once: membership tests stay O(1) and a
        # one-shot iterator is not silently consumed by repeated ``in`` checks.
        # Ranges and sets already offer constant-time membership.
        if isinstance(audio_ids, (range, set, frozenset)):
            wanted_ids = audio_ids
        else:
            wanted_ids = frozenset(audio_ids)

        selected = []
        for name in os.listdir(npz_dir):
            # fullmatch (not match) so names with trailing junk such as
            # "main_audio_1_azi10.npz.bak.npz" are not picked up as samples.
            match = self._FILENAME_RE.fullmatch(name)
            if match and int(match.group(1)) in wanted_ids:
                selected.append(name)

        # Sorted so that index -> file is stable across runs and platforms.
        self.files = sorted(selected)

    def __len__(self) -> int:
        """Return the number of selected sample files."""
        return len(self.files)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(cue, label)``.

        Raises:
            IndexError: If ``idx`` is out of range for ``self.files``.
            KeyError: If the file lacks one of ``"itd"``, ``"ild"``, ``"ic"``.
        """
        name = self.files[idx]
        path = os.path.join(self.dir, name)

        # NpzFile keeps the archive open until closed; the context manager
        # releases the file handle as soon as the arrays have been read.
        with np.load(path) as data:
            cue = np.stack(
                [data[key].astype(np.float32) for key in ("itd", "ild", "ic")],
                axis=0,
            )

        # Extract the azimuth from the file name and bucket it in 5-unit steps.
        azimuth = int(self._AZIMUTH_RE.search(name).group(1))
        label = azimuth // 5

        return cue, label


: 