In [2]:
import glob
import math
import os
import random
from collections import Counter
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path
from typing import List

import faiss
import numpy as np
import soundfile as sf
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

from models import *
from datasets import *

In [None]:
def split_audio(
    y: np.ndarray,
    sample_rate: int,
    segment_sec: int = 30,
    overlap_sec: int = 5,
) -> List[np.ndarray]:
    """Cut a 1-D signal into fixed-length, overlapping segments.

    Segment ``i`` covers samples ``[i * hop, i * hop + seg_len)`` where
    ``seg_len = segment_sec * sample_rate`` and
    ``hop = (segment_sec - overlap_sec) * sample_rate``. Every returned
    segment has exactly ``seg_len`` samples: the last one is zero-padded at
    the end, and a signal shorter than one segment (even an empty one) yields
    a single padded segment.

    Args:
        y: 1-D sample array.
        sample_rate: samples per second of ``y``.
        segment_sec: segment length in seconds (may be fractional; the sample
            count is rounded to the nearest integer).
        overlap_sec: overlap between consecutive segments in seconds; must be
            smaller than ``segment_sec``.

    Returns:
        List of arrays, each of length ``seg_len``.

    Raises:
        ValueError: if the segment length or the hop is not a positive number
            of samples (e.g. ``overlap_sec >= segment_sec``). Without this
            check a zero hop divides by zero and a negative hop slices with
            negative offsets.
    """
    # Round to whole samples so fractional-second durations still slice.
    seg_len = int(round(segment_sec * sample_rate))
    hop = int(round((segment_sec - overlap_sec) * sample_rate))
    if seg_len <= 0:
        raise ValueError(
            f"segment_sec * sample_rate must be positive, got {seg_len} samples"
        )
    if hop <= 0:
        raise ValueError(
            f"overlap_sec ({overlap_sec}) must be smaller than segment_sec ({segment_sec})"
        )

    # Smallest segment count whose last segment reaches the end of y; at least one.
    n_segs = max(1, math.ceil((len(y) - seg_len) / hop) + 1)
    segments = []
    for i in range(n_segs):
        start = i * hop
        chunk = y[start : start + seg_len]
        if len(chunk) < seg_len:
            chunk = np.pad(chunk, (0, seg_len - len(chunk)))
        segments.append(chunk)
    return segments


def process_one(file_path: Path, out_dir: Path, sr=40_000, seg_dur=30, ovlp=5) -> int:
    """Load one file, cut it with ``split_audio`` and save the pieces as 16-bit WAVs.

    The file is read with ``read_audio(..., target_fs=sr, mono=True,
    normalize=False)``. Piece ``k`` (1-based) goes to
    ``out_dir / f"{stem}.seg{k}.wav"``; targets that already exist are not
    rewritten, although the source file is still read every time.

    Returns:
        Number of pieces the file was split into, including pieces whose
        output file already existed.
    """
    audio, _ = read_audio(str(file_path), target_fs=sr, mono=True, normalize=False)
    samples = audio.squeeze(0).cpu().numpy()
    pieces = split_audio(samples, sample_rate=sr, segment_sec=seg_dur, overlap_sec=ovlp)
    stem = file_path.stem
    for k, piece in enumerate(pieces, start=1):
        target = out_dir / f"{stem}.seg{k}.wav"
        if target.exists():
            continue
        sf.write(target, piece, sr, subtype="PCM_16")
    return len(pieces)


def batch_split(
    files: List[str],
    output_dir: str = "data/musan_segments",
    sr: int = 40_000,
    seg_sec: int = 30,
    overlap: int = 5,
    workers: int = 4,
):
    """Split every audio file in ``files`` into segments under ``output_dir``.

    Each file is handled by ``process_one`` on a thread pool. The output
    directory is created if needed.

    Args:
        files: input audio paths.
        output_dir: destination directory for the ``<stem>.seg<k>.wav`` files.
        sr: sample rate passed to ``process_one``.
        seg_sec: segment length in seconds.
        overlap: overlap between consecutive segments in seconds.
        workers: thread-pool size. Values below 1 are raised to 1, because
            ``ThreadPoolExecutor`` rejects ``max_workers <= 0``. ``None`` lets
            the executor choose its default.

    Returns:
        Total number of segments the files were split into (also printed).
        This count includes segment files that already existed and were
        therefore not rewritten by ``process_one``.
    """
    out_dir = Path(output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    # e.g. os.cpu_count() // 2 is 0 on a single-core machine.
    n_workers = None if workers is None else max(1, int(workers))

    total = 0
    with ThreadPoolExecutor(max_workers=n_workers) as pool:
        futures = [
            pool.submit(process_one, Path(f), out_dir, sr, seg_sec, overlap)
            for f in files
        ]
        # as_completed advances the bar as files finish instead of blocking on
        # the first-submitted file; result() re-raises any worker exception.
        for fut in tqdm(as_completed(futures), total=len(futures), desc="splitting"):
            total += fut.result()

    print(f"Done. {total} segments written to {out_dir}")
    return total


# Split every MUSAN music track into segments. process_one skips segment files
# that already exist, so re-running this cell does not rewrite them.
# Sorted so file (and therefore dataset/cache) order is deterministic across runs.
test_musan_full = sorted(glob.glob("data/musan/music/**/*.wav", recursive=True))
# os.cpu_count() may return None, and half of one core is 0 workers; keep at least one.
batch_split(test_musan_full, workers=max(1, (os.cpu_count() or 1) // 2))
test_musan_segment = sorted(glob.glob("data/musan_segments/*.wav"))

In [None]:
def extract_base_name(p: str) -> str:
    """Return the source-track name of a segment path.

    The function removes three things:

    - the directory;
    - the file extension;
    - the last dot-separated component of the stem, i.e. the ``.seg<k>`` tag
      that ``process_one`` appends.

    Example: ``"dir/track.seg3.wav" -> "track"``. Dots inside the track name
    itself are preserved (``"a.b.seg1.wav" -> "a.b"``). A stem containing no
    dot is returned unchanged.
    """
    stem = Path(p).stem
    # rsplit strips only the final component, unlike split(".")[-2], which
    # would keep just the second-to-last piece of a dotted track name.
    return stem.rsplit(".", 1)[0]


class MelQueryDataset(Dataset):
    """Dataset that yields ``n_query`` random mel-spectrogram crops per file.

    Item ``idx`` is ``(specs, name)``:

    - ``specs`` is a list of ``n_query`` outputs of ``audio_to_melspec``. Each
      is computed on a randomly placed ``seg_sec``-second window of
      ``file_paths[idx]``, using row 0 of what
      ``read_audio(path, target_fs=sample_rate, mono=True)`` returns.
    - ``name`` is ``extract_base_name(path)``.

    A window is zero-padded to ``seg_sec * sample_rate`` samples when the file
    is shorter than that, so all crops have the same length and can be stacked.

    Args:
        file_paths: audio files to draw queries from.
        seg_sec: crop length in seconds.
        n_query: number of crops per file.
        sample_rate: ``target_fs`` for ``read_audio`` and ``fs`` for
            ``audio_to_melspec``.
        window_size: forwarded to ``audio_to_melspec``.
        overlap_ratio: forwarded to ``audio_to_melspec``.
        n_mels: forwarded to ``audio_to_melspec``.
        seed: optional base seed. When set, item ``idx`` draws its crop
            offsets from ``random.Random(seed + idx)``, so crops are
            reproducible regardless of access order. When ``None`` (the
            default), the global ``random`` module is used.
    """

    def __init__(
        self,
        file_paths: List[str],
        seg_sec: int = 5,
        n_query: int = 3,
        sample_rate: int = 40_000,
        window_size: int = 2_560,
        overlap_ratio: float = 0.5,
        n_mels: int = 256,
        seed: "int | None" = None,
    ):
        self.paths = file_paths
        self.window_size = window_size
        self.overlap_ratio = overlap_ratio
        self.n_mels = n_mels
        self.sample_rate = sample_rate
        self.seg_len = seg_sec * self.sample_rate  # crop length in samples
        self.n_query = n_query
        self.seed = seed

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        path = self.paths[idx]
        wav, _ = read_audio(path, target_fs=self.sample_rate, mono=True)
        wav = wav[0]
        total = wav.size(0)
        # Per-item RNG when seeded, so cached queries do not depend on iteration order.
        rng = random if self.seed is None else random.Random(self.seed + idx)

        specs = []
        for _ in range(self.n_query):
            start = rng.randint(0, max(0, total - self.seg_len))
            seg = wav[start : start + self.seg_len]
            if seg.size(0) < self.seg_len:
                # Only happens when the file is shorter than seg_len. Pad at the
                # end so this crop matches the length of every other crop.
                seg = torch.nn.functional.pad(seg, (0, self.seg_len - seg.size(0)))
            spec = audio_to_melspec(
                seg,
                window_size=self.window_size,
                overlap_ratio=self.overlap_ratio,
                fs=self.sample_rate,
                n_mels=self.n_mels,
            )
            specs.append(spec)

        return specs, extract_base_name(path)


class MelDocDataset(Dataset):
    """Dataset that yields one mel spectrogram per whole audio file.

    Item ``idx`` is ``(spec, name)``:

    - ``spec`` is ``audio_to_melspec`` applied to row 0 of the signal
      returned by ``read_audio(path, target_fs=sample_rate, mono=True)``.
    - ``name`` is ``extract_base_name(path)``.

    No cropping or padding is done here.
    """

    def __init__(
        self,
        file_paths: List[str],
        sample_rate: int = 40_000,
        window_size: int = 2_560,
        overlap_ratio: float = 0.5,
        n_mels: int = 256,
    ):
        # Spectrogram settings, forwarded unchanged to audio_to_melspec.
        self.sample_rate = sample_rate
        self.window_size = window_size
        self.overlap_ratio = overlap_ratio
        self.n_mels = n_mels
        self.paths = file_paths

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        audio_path = self.paths[idx]
        signal, _sr = read_audio(audio_path, target_fs=self.sample_rate, mono=True)
        mel = audio_to_melspec(
            signal[0],
            window_size=self.window_size,
            overlap_ratio=self.overlap_ratio,
            fs=self.sample_rate,
            n_mels=self.n_mels,
        )
        return mel, extract_base_name(audio_path)

In [None]:
# Raw datasets over the same segment files: queries are random crops
# (MelQueryDataset defaults: 3 crops of 5 s per file), docs are the whole segments.
# NOTE(review): every query crop is cut from a file that is also in the doc set,
# so this measures retrieval of a known segment — confirm that is the intended protocol.
query_raw = MelQueryDataset(test_musan_segment)
doc_raw = MelDocDataset(test_musan_segment)

# NOTE(review): presumably precomputes each item and stores it under the given
# directory, which LazyCachedDataset reads back in the next cell — TODO confirm
# against preprocess_and_cache_lazy.
query_cache = preprocess_and_cache_lazy(query_raw, "data/dataset_cache/musan/query")
doc_cache = preprocess_and_cache_lazy(doc_raw, "data/dataset_cache/musan/doc")

In [3]:
def _transpose_batch(batch):
    """Collate ``[(feat, name), ...]`` into ``[feats_tuple, names_tuple]``."""
    return list(zip(*batch))


query_ds = LazyCachedDataset("data/dataset_cache/musan/query")
doc_ds = LazyCachedDataset("data/dataset_cache/musan/doc")

# Ordered (shuffle=False) so embeddings line up with the names embed_loader collects.
query_loader = DataLoader(query_ds, batch_size=8, shuffle=False, collate_fn=_transpose_batch)
doc_loader = DataLoader(doc_ds, batch_size=8, shuffle=False, collate_fn=_transpose_batch)

In [None]:
def _stack_to_device(xlist, device):
    if isinstance(xlist[0], list):
        xlist = [t for sub in xlist for t in sub]
    return torch.stack(xlist).to(device, non_blocking=True)


def embed_loader(model, loader, device, is_query=False):
    """Run ``model`` over every batch of ``loader`` and collect embeddings with names.

    Each batch is a ``(feats, fn)`` pair: ``feats`` holds one entry per item
    and ``fn`` the matching names. An entry is either one tensor, or a
    list/tuple of tensors (e.g. ``MelQueryDataset`` crops). Nested entries
    are flattened by ``_stack_to_device``, and the item's name is repeated
    once per tensor so names stay aligned row-for-row with the embeddings.

    Args:
        model: network called on the stacked batch; assumed to return one
            embedding row per input tensor (checked below).
        loader: iterable of ``(feats, fn)`` batches.
        device: device the batch is moved to before the forward pass.
        is_query: selects the progress-bar label.

    Returns:
        ``(embeddings, names)``: a float32 array and a list with one name per row.
    """
    embs, names = [], []
    with torch.no_grad():
        for feats, fn in tqdm(loader, desc="embed_query" if is_query else "embed_doc"):
            # Count tensors per item before flattening, so items with
            # different crop counts still get correctly aligned names.
            nested = isinstance(feats[0], (list, tuple))
            counts = [len(item) for item in feats] if nested else [1] * len(feats)

            batch = _stack_to_device(feats, device)
            out = model(batch).cpu().numpy()
            # A row/name mismatch would silently shift every later name.
            assert out.shape[0] == sum(counts), (
                f"model returned {out.shape[0]} rows for {sum(counts)} inputs"
            )
            embs.append(out)

            for name, count in zip(fn, counts):
                names.extend([name] * count)

    return np.concatenate(embs).astype("float32"), names


def process_test(
    model,
    epoch: int,
    top_k: int,
    device: torch.device,
    q_loader=None,
    d_loader=None,
    ckpt_dir: str = "outputs/checkpoints",
):
    """Evaluate one checkpoint and print/return recall@``top_k`` over distinct doc names.

    Steps:

    1. Load ``ckpt["model"]`` from ``{ckpt_dir}/epoch{epoch}.pth`` into
       ``model`` with ``strict=True``.
    2. Embed all queries and documents.
    3. L2-normalise the embeddings and run exact inner-product (cosine)
       search with faiss.

    A query is a hit when its name appears among the first ``top_k``
    *distinct* document names of its ranked neighbours.

    Args:
        model: network to evaluate; its weights are overwritten and it is put
            in eval mode.
        epoch: checkpoint epoch number.
        top_k: number of distinct document names considered per query (>= 1).
        device: device used to compute embeddings.
        q_loader: query loader; defaults to the notebook-level ``query_loader``.
        d_loader: document loader; defaults to the notebook-level ``doc_loader``.
        ckpt_dir: directory holding the ``epoch{N}.pth`` checkpoints.

    Returns:
        The recall as a float in [0, 1].

    Raises:
        ValueError: if ``top_k < 1``.
    """
    if top_k < 1:
        raise ValueError(f"top_k must be >= 1, got {top_k}")
    if q_loader is None:
        q_loader = query_loader
    if d_loader is None:
        d_loader = doc_loader

    ckpt = torch.load(f"{ckpt_dir}/epoch{epoch}.pth", map_location="cpu")
    model.load_state_dict(ckpt["model"], strict=True)
    model.eval()

    q_emb, q_name = embed_loader(model, q_loader, device, is_query=True)
    d_emb, d_name = embed_loader(model, d_loader, device, is_query=False)

    faiss.normalize_L2(q_emb)
    faiss.normalize_L2(d_emb)

    index = faiss.IndexFlatIP(d_emb.shape[1])
    index.add(d_emb)

    # Several documents share one name (one per segment of a track), so the
    # first k neighbours can hold fewer than k distinct names. Fetching
    # top_k * (max docs per name) neighbours guarantees top_k distinct names
    # whenever the index has that many. Cap at ntotal: faiss pads results
    # beyond the index size with -1, which would wrap to d_name[-1].
    max_dup = max(Counter(d_name).values())
    n_search = min(index.ntotal, top_k * max_dup)
    _, I = index.search(q_emb, n_search)

    hit = 0
    for target, neigh in zip(q_name, I):
        seen = set()
        for j in neigh:
            if j < 0:  # faiss padding: no further results
                break
            name = d_name[j]
            if name in seen:
                continue
            if name == target:
                hit += 1
                break
            seen.add(name)
            if len(seen) >= top_k:
                break

    acc = hit / len(q_name)
    print(f"Epoch {epoch:02d} | recall@{top_k}: {acc:.4%}")
    return acc

In [15]:
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
model = ResidualSTANet().to(device)

# One forward pass on a random (1, 1, 256, 256) tensor before any weights are loaded.
# NOTE(review): presumably this initialises lazily-shaped layers so that
# load_state_dict(strict=True) in process_test finds every parameter — confirm.
dummy = torch.randn(1, 1, 256, 256, device=device)
with torch.no_grad():
    _ = model(dummy)

process_test(model, epoch=10, top_k=1, device=device)

embed_query: 100%|██████████| 83/83 [00:04<00:00, 16.84it/s]
embed_doc: 100%|██████████| 792/792 [00:42<00:00, 18.62it/s]


Epoch 10 | recall@1: 92.6768%


In [None]:
# Evaluate the checkpoints saved after epochs 16, 17 and 18, reusing the same model instance.
for i in (16, 17, 18):
    process_test(model, epoch=i, top_k=1, device=device)