In [1]:
import librosa
import os, glob, random
import numpy as np
import soundfile as sf
import faiss
import torch
from torch.utils.data import DataLoader, Dataset

from models import *
from datasets import *

In [4]:
test_dataset = CachedDataset(cache_file="data/dataset_cache/model2/cached_test_model2.pt")
test_loader = DataLoader(
    test_dataset,
    batch_size=64,
    shuffle=False,
    num_workers=8,
    pin_memory=True,
    collate_fn=collate_fn_dif_length,
)

In [6]:
def process_test(model, epoch, top_k, device):
    model.load_state_dict(torch.load(f"outputs/ResNet+Attention/model3_epoch{epoch}.pth"))
    model.eval()

    test_doc_embs = []
    test_query_embs = []

    with torch.no_grad():
        for test_batch in test_loader:
            query_feat_batch, doc_feat_batch = test_batch
            query_feat_batch = query_feat_batch.to(device)
            doc_feat_batch = doc_feat_batch.to(device)

            query_emb = model(query_feat_batch)
            doc_emb = model(doc_feat_batch)

            test_query_embs.append(query_emb.cpu().numpy())
            test_doc_embs.append(doc_emb.cpu().numpy())

    test_query_embs = np.concatenate(test_query_embs, axis=0).astype(np.float32)
    test_doc_embs = np.concatenate(test_doc_embs, axis=0).astype(np.float32)

    dim = test_doc_embs.shape[1]
    index = faiss.IndexFlatIP(dim)
    index.add(test_doc_embs)

    D, I = index.search(test_query_embs, top_k)

    TP = 0
    num_queries = test_query_embs.shape[0]
    for i in range(num_queries):
        # print(f"Query {i} Top-{top_k} Neighbors: {I[i]}")
        if i in I[i]:
            TP += 1

    print(f"Epoch {epoch}: Model accuracy is {TP / num_queries}")

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = DimensionMaskedResNet().to(device)

for epoch in range(4, 32):
    process_test(model, epoch, 1, device)

## Test with Musan

In [None]:
test_musan_full = glob.glob(
    "../dataset/musan/music/**/*.*",
    recursive=True,
)
test_musan_full = [f for f in test_musan_full if f.endswith(".wav")]
print(f"Total musan music files: {len(test_musan_full)}")

Total musan music files: 660


In [None]:
def split_audio(file_path, sample_rate, segment_duration, overlap):
    audio, sr = librosa.load(file_path, sr=sample_rate)
    total_samples = len(audio)
    seg_samples = segment_duration * sample_rate
    step_samples = (segment_duration - overlap) * sample_rate

    segments = []
    for start in range(0, total_samples, step_samples):
        end = start + seg_samples
        segment = audio[start:end]
        if len(segment) < seg_samples:
            pad_length = seg_samples - len(segment)
            segment = np.concatenate([segment, np.zeros(pad_length)])
        segments.append(segment)
        if end >= total_samples:
            break
    return segments

output_dir = "../dataset/musan_segments"
os.makedirs(output_dir, exist_ok=True)

for file_path in test_musan_full:
    segments = split_audio(file_path, FS, segment_duration=30, overlap=5)
    base_name = os.path.splitext(os.path.basename(file_path))[0]
    for i, seg in enumerate(segments):
        output_file = os.path.join(output_dir, f"{base_name}.seg{i+1}.wav")
        sf.write(output_file, seg, FS)
    print(f"Processed {file_path} into {len(segments)} segments.")

In [None]:
test_musan_segment = glob.glob("../dataset/musan_segments/*.*")
test_musan_segment = [f for f in test_musan_segment if f.endswith(".wav")]
print(f"Total musan segment files: {len(test_musan_segment)}")

Total musan segment files: 9898


In [None]:
def extract_base_name(file_path):
    base = os.path.basename(file_path)
    base_no_ext, _ = os.path.splitext(base)
    parts = base_no_ext.split(".")
    return parts[-2] if len(parts) >= 2 else base_no_ext

In [None]:
class Model2QueryDataset(Dataset):
    def __init__(
        self,
        file_paths,
        segment_seconds=SEGMENT_TIME,
        sample_rate=FS,
        num_queries=NUM_AUG_QUERIES,
    ):
        self.file_paths = file_paths
        self.sample_rate = sample_rate
        self.num_queries = num_queries
        self.segment_samples = int(segment_seconds * sample_rate)

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx):
        file = self.file_paths[idx]
        channels, _ = read(file)
        waveform = channels[0]
        total_len = waveform.shape[0]

        queries = []
        for _ in range(self.num_queries):
            if total_len > self.segment_samples:
                start = random.randint(0, total_len - self.segment_samples)
            else:
                start = 0
            end = start + self.segment_samples
            query_wave = waveform[start:end]
            query_spec = transform_to_spectrogram_mel(query_wave)
            queries.append(query_spec)

        filename = extract_base_name(file)
        return queries, filename

In [None]:
query_dataset_raw = Model2QueryDataset(test_musan_full)
# Doc Dataset can just use the one from ACRCloud part
doc_dataset_raw = Model2TestDataset(test_musan_segment)

query_data = preprocess_and_cache(
    query_dataset_raw, "../dataset/dataset_cache/musan/musan_query_model2.pt"
)
doc_data = preprocess_and_cache(
    doc_dataset_raw, "../dataset/dataset_cache/musan/musan_doc_model2.pt"
)

In [None]:
query_data = torch.load("../dataset/dataset_cache/musan/musan_query_model2.pt")
doc_data = torch.load("../dataset/dataset_cache/musan/musan_doc_model2.pt")

In [None]:
query_dataset = CachedDataset(query_data)
doc_dataset = CachedDataset(doc_data)

query_loader = DataLoader(
    query_dataset,
    batch_size=1,
    shuffle=False,
)
doc_loader = DataLoader(doc_dataset, batch_size=1, shuffle=False)

In [None]:
def process_test(model, epoch, top_k):
    model.load_state_dict(torch.load(f"./model_cache/model3_epoch{epoch}.pth"))
    # model.load_state_dict(torch.load(f"./model_cache/ResNet18.pth"))
    model.eval()

    query_filenames = []
    query_embs = []
    with torch.no_grad():
        for query_batch in query_loader:
            query_feats, query_paths = query_batch
            for i in range(NUM_AUG_QUERIES):
                query_feat = query_feats[i].to(device)
                query_emb = model(query_feat)
                query_embs.append(query_emb.cpu().numpy())
                query_filenames.append(query_paths[0])

    doc_filenames = []
    doc_embs = []
    with torch.no_grad():
        for doc_batch in doc_loader:
            doc_feat, doc_path = doc_batch
            doc_feat = doc_feat.to(device)
            doc_emb = model(doc_feat)
            doc_embs.append(doc_emb.cpu().numpy())
            doc_filenames.append(doc_path[0])

    query_embs = np.concatenate(query_embs, axis=0).astype(np.float32)
    doc_embs = np.concatenate(doc_embs, axis=0).astype(np.float32)

    dim = doc_embs.shape[1]
    index = faiss.IndexFlatIP(dim)
    index.add(doc_embs)

    # D, I = index.search(query_embs, top_k)

    # TP = 0
    # num_queries = query_embs.shape[0]
    # for i in range(num_queries):
    #     query_base = query_filenames[i]
    #     retrieved_bases = [doc_filenames[idx] for idx in I[i]]
    #     # print(f"Query {i} (base: {query_base}) Top-{k} Neighbors: {retrieved_bases}")
    #     if query_base in retrieved_bases:
    #         TP += 1

    # print(f"Epoch {epoch}: Model accuracy is {TP / num_queries}")

    D, I = index.search(query_embs, top_k * 2)

    TP = 0
    num_queries = query_embs.shape[0]
    for i in range(num_queries):
        query_filename = query_filenames[i]
        distinct_retrieved = []
        for idx in I[i]:
            candidate = doc_filenames[idx]
            if candidate not in distinct_retrieved:
                distinct_retrieved.append(candidate)
            if len(distinct_retrieved) >= top_k:
                break
        if query_filename in distinct_retrieved:
            TP += 1

    print(f"Epoch {epoch}: Model accuracy is {TP / num_queries}")

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# model = ResNet18Model(embed_dim=EMBED_DIM, hidden_size=HIDDEN_LAYER).to(device)
model = DimensionMaskedCNN().to(device)

for epoch in range(1, 32):
    process_test(model, 7, 10)
    break

Epoch 7: Model accuracy is 0.9904040404040404
