In [None]:
import os, glob
import numpy as np
import faiss
import torch
from torch.utils.data import DataLoader

from datasets import *
from models import *

In [5]:
# Document ("reference") side: every .mp3 found anywhere under the songkey tree,
# in the order glob walks it.
audio_files = [
    path
    for path in glob.iglob("../dataset/acrcloud/songkey/**/*.*", recursive=True)
    if path.endswith(".mp3")
]
print(f"Total audio number: {len(audio_files)}")

Total audio number: 150


In [6]:
# Query side: every .wav clip found anywhere under the parts2songkey tree,
# in the order glob walks it.
query_files = [
    path
    for path in glob.iglob("../dataset/acrcloud/parts2songkey/**/*.*", recursive=True)
    if path.endswith(".wav")
]
print(f"Total audio number: {len(query_files)}")

Total audio number: 351


In [None]:
# Wrap each file list in a raw mel-spectrogram dataset, then preprocess it.
# Judging by the recorded log, preprocess_and_cache writes a cache file at the
# given path when none exists; its return value is presumably the processed
# samples consumed by CachedDataset further down — confirm against datasets.py.
query_dataset_raw = MelSpecDataset(query_files)
doc_dataset_raw = MelSpecDataset(audio_files)

query_data = preprocess_and_cache(
    query_dataset_raw,
    "../dataset/dataset_cache/acrcloud_query_model2.pt",
)
doc_data = preprocess_and_cache(
    doc_dataset_raw,
    "../dataset/dataset_cache/acrcloud_doc_model2.pt",
)

No cache file found. Starting dataset preprocessing...
Checkpoint reached: processed 100 out of 351 samples.
Checkpoint reached: processed 200 out of 351 samples.
Checkpoint reached: processed 300 out of 351 samples.
Processing complete. Cached dataset saved to '../dataset/dataset_cache/acrcloud_query_model2.pt'.
No cache file found. Starting dataset preprocessing...
Checkpoint reached: processed 100 out of 150 samples.
Processing complete. Cached dataset saved to '../dataset/dataset_cache/acrcloud_doc_model2.pt'.


In [6]:
# Reload both preprocessed caches from disk. NOTE(review): the previous cell
# already assigns query_data/doc_data from preprocess_and_cache, so this looks
# redundant; it lets the data be restored without re-running that cell.
query_data, doc_data = (
    torch.load("../dataset/dataset_cache/acrcloud_query_model2.pt"),
    torch.load("../dataset/dataset_cache/acrcloud_doc_model2.pt"),
)

In [7]:
# Wrap the cached samples and iterate them one sample per batch, in their
# stored order (no shuffling), so loader order matches cache order.
query_dataset, doc_dataset = CachedDataset(query_data), CachedDataset(doc_data)

query_loader, doc_loader = (
    DataLoader(dataset, batch_size=1, shuffle=False)
    for dataset in (query_dataset, doc_dataset)
)

In [None]:
def extract_base_name(file_path):
    """Return the key used to match a query clip to its source document.

    The key is the last dot-separated component of the file name after the
    directory and the final extension are removed, e.g.
    ``"dir/Artist.Title.ABC123.mp3" -> "ABC123"``. A name with no inner dots
    is returned as its stem (``"song.wav" -> "song"``). The evaluation code
    counts a hit when a query's key equals a retrieved document's key.
    """
    stem, _ = os.path.splitext(os.path.basename(file_path))
    # rsplit always yields at least one element, so no empty-list guard is needed.
    return stem.rsplit(".", 1)[-1]

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = DimensionMaskedResNet().to(device)

# Checkpoints evaluated: ./model_cache/model3_epoch{1..31}.pth
EVAL_EPOCHS = range(1, 32)
# Neighbours retrieved per query. A query is a hit if its key appears anywhere
# in the top-k, so the printed "accuracy" is a top-k hit rate (recall@k).
k = 10


def _embed_all(net, loader, device):
    """Embed every sample from ``loader`` with ``net`` (no gradients).

    Each batch is unpacked as ``(features, paths)``. Returns
    ``(embeddings, paths)``: a float32 array stacked along axis 0, and a list
    holding one path per embedding row, in loader order.
    """
    emb_chunks, all_paths = [], []
    with torch.no_grad():
        for feats, batch_paths in loader:
            emb_chunks.append(net(feats.to(device)).cpu().numpy())
            # Keep every path in the batch so names stay aligned with
            # embedding rows for any batch_size, not just 1.
            all_paths.extend(batch_paths)
    return np.concatenate(emb_chunks, axis=0).astype(np.float32), all_paths


for epoch in EVAL_EPOCHS:
    # map_location lets checkpoints saved from a GPU load on the CPU fallback.
    model.load_state_dict(
        torch.load(f"./model_cache/model3_epoch{epoch}.pth", map_location=device)
    )
    model.eval()

    # Both sides are re-embedded every epoch because the weights change.
    # These module-level names are also read by the re-scoring cell below.
    query_embs, query_filenames = _embed_all(model, query_loader, device)
    doc_embs, doc_filenames = _embed_all(model, doc_loader, device)

    dim = doc_embs.shape[1]
    # NOTE(review): IndexFlatIP ranks by raw inner product; that equals cosine
    # similarity only if the model's embeddings are L2-normalised — confirm.
    index = faiss.IndexFlatIP(dim)
    index.add(doc_embs)
    print(f"Indexed {index.ntotal} document embeddings")

    D, I = index.search(query_embs, k)

    TP = 0
    num_queries = query_embs.shape[0]
    for i in range(num_queries):
        query_base = extract_base_name(query_filenames[i])
        # faiss pads results with -1 when k > index.ntotal; skip those so
        # doc_filenames[-1] cannot silently produce a false hit.
        retrieved_bases = {
            extract_base_name(doc_filenames[idx]) for idx in I[i] if idx >= 0
        }
        if query_base in retrieved_bases:
            TP += 1

    print(f"Epoch num: {epoch}, Model accuracy is {TP / num_queries}")

Indexed 150 document embeddings
Epoch num: 1, Model accuracy is 0.1908831908831909
Indexed 150 document embeddings
Epoch num: 2, Model accuracy is 0.1623931623931624
Indexed 150 document embeddings
Epoch num: 3, Model accuracy is 0.24216524216524216
Indexed 150 document embeddings
Epoch num: 4, Model accuracy is 0.19943019943019943
Indexed 150 document embeddings
Epoch num: 5, Model accuracy is 0.21082621082621084
Indexed 150 document embeddings
Epoch num: 6, Model accuracy is 0.14814814814814814
Indexed 150 document embeddings
Epoch num: 7, Model accuracy is 0.2336182336182336
Indexed 150 document embeddings
Epoch num: 8, Model accuracy is 0.2849002849002849
Indexed 150 document embeddings
Epoch num: 9, Model accuracy is 0.33903133903133903
Indexed 150 document embeddings
Epoch num: 10, Model accuracy is 0.2621082621082621
Indexed 150 document embeddings
Epoch num: 11, Model accuracy is 0.18518518518518517
Indexed 150 document embeddings
Epoch num: 12, Model accuracy is 0.347578347578

In [None]:
# Re-score retrieval for the embeddings currently held in the kernel.
# NOTE(review): this cell defines none of its inputs. It reuses doc_embs,
# query_embs, doc_filenames, query_filenames (and extract_base_name) left over
# from the last iteration of the per-epoch loop, so its result depends on
# hidden kernel state, and it duplicates that loop's scoring logic.
dim = doc_embs.shape[1]
index = faiss.IndexFlatIP(dim)
index.add(doc_embs)
print(f"Indexed {index.ntotal} document embeddings")

k = 10
D, I = index.search(query_embs, k)

TP = 0
num_queries = query_embs.shape[0]
for i in range(num_queries):
    query_base = extract_base_name(query_filenames[i])
    # faiss pads results with -1 when k > index.ntotal; drop them so
    # doc_filenames[-1] cannot count as a hit.
    retrieved_bases = {
        extract_base_name(doc_filenames[idx]) for idx in I[i] if idx >= 0
    }
    if query_base in retrieved_bases:
        TP += 1

print(f"Model accuracy is {TP / num_queries}")

Indexed 150 document embeddings
Model accuracy is 0.24786324786324787
