In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
import numpy as np
import pandas as pd
import torch
import torchaudio
from transformers import Wav2Vec2Model, Wav2Vec2FeatureExtractor
from sklearn.neighbors import NearestNeighbors
import os
from tqdm import tqdm
import multiprocessing
# Configuration
AUDIO_DIR =  "/content/drive/My Drive/fma_small"
METADATA_PATH = '/content/drive/My Drive/fma_metadata/tracks.csv'
EMBEDDING_DIM = 768  # Wav2Vec base model
CHECKPOINT_DIR = "/content/drive/My Drive/"
EMBEDDING_DIM = 768
SAMPLE_RATE = 16000
CHUNK_SIZE = 5

# 1. Load FMA Metadata
def load_fma_metadata():
    # Load the tracks.csv with proper multi-level header
    tracks = pd.read_csv(METADATA_PATH, index_col=0, header=[0, 1])

    tracks.columns = ['_'.join(col) for col in tracks.columns]

    small = tracks[tracks['set_subset'] == 'small']
    print(small.shape)

    metadata = small[[
        'track_title',
        'artist_name',
        'album_title',
        'album_id'
    ]]

    return metadata

metadata = load_fma_metadata()


(8000, 52)


In [None]:
# 2. Initialize Wav2Vec
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("facebook/wav2vec2-base")
model = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base").to(device)
model.share_memory()

In [4]:
import json
import numpy as np
import torch
import torchaudio
import os
from tqdm import tqdm

def get_audio_path(track_id):
    tid_str = '{:06d}'.format(track_id)
    return os.path.join(AUDIO_DIR, tid_str[:3], tid_str + '.mp3')

def audio_to_embedding(audio_path):
    try:
        waveform, sr = torchaudio.load(audio_path)
        if sr != SAMPLE_RATE:
            resampler = torchaudio.transforms.Resample(sr, SAMPLE_RATE)
            waveform = resampler(waveform)

        # Ensure mono audio by taking mean if multiple channels exist
        if waveform.dim() > 1 and waveform.shape[0] > 1:
            waveform = waveform.mean(dim=0, keepdim=True)

        # Ensure we have exactly CHUNK_SIZE seconds
        target_samples = SAMPLE_RATE * CHUNK_SIZE
        if waveform.shape[1] < target_samples:
            padding = target_samples - waveform.shape[1]
            waveform = torch.nn.functional.pad(waveform, (0, padding))
        else:
            waveform = waveform[:, :target_samples]

        # Normalize audio
        waveform = (waveform - waveform.mean()) / (waveform.std() + 1e-8)

        inputs = feature_extractor(
            waveform.squeeze().numpy(),
            sampling_rate=SAMPLE_RATE,
            return_tensors="pt",
            padding=True,

        ).to(device)

        with torch.no_grad():
            outputs = model(**inputs)

        # Ensure consistent embedding size
        embedding = outputs.last_hidden_state.mean(dim=1).squeeze().cpu().numpy()
        if embedding.ndim == 0:  # handle scalar case
            embedding = np.array([embedding])
        return embedding
    except Exception as e:
        print(f"Error processing {audio_path}: {str(e)}")
        return None

def build_embedding_database():
    # Initialize with empty array of correct shape if file doesn't exist
    embeddings = []
    try:
        existing_embeddings = np.load('/content/drive/My Drive/embeddings.npy', allow_pickle=True)
        if existing_embeddings.size > 0:
            embeddings = list(existing_embeddings)
    except (FileNotFoundError, ValueError):
        pass

    valid_tracks = []
    try:
        with open("valid_tracks.txt", "r") as f:
            valid_tracks = [int(line.strip()) for line in f]
    except FileNotFoundError:
        pass

    cnt = 0
    for track_id, row in tqdm(metadata.iterrows(), total=len(metadata)):
        if track_id in valid_tracks:  # Skip already processed tracks
            continue

        audio_path = get_audio_path(track_id)

        if not os.path.exists(audio_path):
            print(f"Missing audio file: {audio_path}")
            continue

        emb = audio_to_embedding(audio_path)
        if emb is not None:
            embeddings.append(emb)
            valid_tracks.append(track_id)

        cnt += 1
        if cnt % 5 == 0:
            np.save('/content/drive/My Drive/embeddings.npy', np.vstack(embeddings))
            with open("valid_tracks.txt", "w") as f:
                f.write("\n".join(map(str, valid_tracks)))

    if embeddings:
        np.save('/content/drive/My Drive/embeddings.npy', np.vstack(embeddings))
        with open("valid_tracks.txt", "w") as f:
            f.write("\n".join(map(str, valid_tracks)))

    return np.vstack(embeddings) if embeddings else np.array([]), valid_tracks

In [121]:
print("Building embedding database")
embeddings, track_ids = build_embedding_database()
print(f"Generated {len(embeddings)} embeddings")

nn_model = NearestNeighbors(n_neighbors=5, metric='cosine')
nn_model.fit(embeddings)


Building embedding database...


 56%|█████▌    | 4471/8000 [2:44:06<1:29:23,  1.52s/it]

Error processing /content/drive/My Drive/fma_small/099/099134.mp3: Failed to open the input "/content/drive/My Drive/fma_small/099/099134.mp3" (Invalid argument).
Exception raised from get_input_format_context at /__w/audio/audio/pytorch/audio/src/libtorio/ffmpeg/stream_reader/stream_reader.cpp:42 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x96 (0x7decb8d6c1b6 in /usr/local/lib/python3.11/dist-packages/torch/lib/libc10.so)
frame #1: c10::detail::torchCheckFail(char const*, char const*, unsigned int, std::string const&) + 0x64 (0x7decb8d15a76 in /usr/local/lib/python3.11/dist-packages/torch/lib/libc10.so)
frame #2: <unknown function> + 0x42034 (0x7decfc669034 in /usr/local/lib/python3.11/dist-packages/torio/lib/libtorio_ffmpeg4.so)
frame #3: torio::io::StreamingMediaDecoder::StreamingMediaDecoder(std::string const&, std::optional<std::string> const&, std::optional<std::map<std::string, std::string, std::less<std::string>, std::allocator<std

 61%|██████▏   | 4904/8000 [2:59:50<1:15:26,  1.46s/it]

Error processing /content/drive/My Drive/fma_small/108/108925.mp3: Failed to open the input "/content/drive/My Drive/fma_small/108/108925.mp3" (Invalid argument).
Exception raised from get_input_format_context at /__w/audio/audio/pytorch/audio/src/libtorio/ffmpeg/stream_reader/stream_reader.cpp:42 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x96 (0x7decb8d6c1b6 in /usr/local/lib/python3.11/dist-packages/torch/lib/libc10.so)
frame #1: c10::detail::torchCheckFail(char const*, char const*, unsigned int, std::string const&) + 0x64 (0x7decb8d15a76 in /usr/local/lib/python3.11/dist-packages/torch/lib/libc10.so)
frame #2: <unknown function> + 0x42034 (0x7decfc669034 in /usr/local/lib/python3.11/dist-packages/torio/lib/libtorio_ffmpeg4.so)
frame #3: torio::io::StreamingMediaDecoder::StreamingMediaDecoder(std::string const&, std::optional<std::string> const&, std::optional<std::map<std::string, std::string, std::less<std::string>, std::allocator<std

 87%|████████▋ | 6966/8000 [4:15:48<26:22,  1.53s/it]

Error processing /content/drive/My Drive/fma_small/133/133297.mp3: Failed to open the input "/content/drive/My Drive/fma_small/133/133297.mp3" (Invalid argument).
Exception raised from get_input_format_context at /__w/audio/audio/pytorch/audio/src/libtorio/ffmpeg/stream_reader/stream_reader.cpp:42 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x96 (0x7decb8d6c1b6 in /usr/local/lib/python3.11/dist-packages/torch/lib/libc10.so)
frame #1: c10::detail::torchCheckFail(char const*, char const*, unsigned int, std::string const&) + 0x64 (0x7decb8d15a76 in /usr/local/lib/python3.11/dist-packages/torch/lib/libc10.so)
frame #2: <unknown function> + 0x42034 (0x7decfc669034 in /usr/local/lib/python3.11/dist-packages/torio/lib/libtorio_ffmpeg4.so)
frame #3: torio::io::StreamingMediaDecoder::StreamingMediaDecoder(std::string const&, std::optional<std::string> const&, std::optional<std::map<std::string, std::string, std::less<std::string>, std::allocator<std

100%|██████████| 8000/8000 [4:54:15<00:00,  2.21s/it]


Generated 8027 embeddings


In [58]:
import faiss

class MusicSimilaritySearchANNPQ:
    def __init__(self, metadata, embeddings, track_ids, audio_dir=AUDIO_DIR, n_neighbors=5):
        self.metadata = metadata
        self.audio_dir = audio_dir
        self.dim = embeddings.shape[1]
        self.track_ids = track_ids  # Keep list of track IDs aligned to embeddings

        # Normalize embeddings
        faiss.normalize_L2(embeddings)
        self.embeddings = embeddings

        # Build FAISS index
        self.index = faiss.IndexPQ(self.dim, 8, 8)
        self.index.train(embeddings)
        self.index.add(embeddings)

    def get_similar_tracks(self, query_embedding, k=5):
        query_embedding = query_embedding.astype('float32')
        faiss.normalize_L2(query_embedding.reshape(1, -1))

        distances, indices = self.index.search(query_embedding.reshape(1, -1), k)
        similar_tracks = []
        for i in range(k):
            idx = indices[0][i]
            if idx >= len(self.track_ids):
                continue  # Avoid out-of-bounds

            track_id = self.track_ids[idx]
            if track_id not in self.metadata.index:
                continue  # Skip if metadata is missing

            row = self.metadata.loc[track_id]
            track_info = {
                'track_id': track_id,
                'title': row['track_title'],
                'artist': row['artist_name'],
                'album': row['album_title'],
                'distance': 1 - distances[0][i],  # similarity = 1 - distance
                'audio_path': get_audio_path(track_id)
            }
            similar_tracks.append(track_info)
        return similar_tracks

    def search_by_audio(self, input_mp3_path, k=5):
        query_embedding = audio_to_embedding(input_mp3_path)
        if query_embedding is None:
            return []
        return self.get_similar_tracks(query_embedding, k)

def load_embeddings():
    embedding_file = os.path.join(CHECKPOINT_DIR, 'embeddings.npy')
    valid_tracks_file = os.path.join(CHECKPOINT_DIR, 'valid_tracks.txt')
    print(embedding_file, valid_tracks_file)

    print("Loading pre-computed embeddings...")
    embeddings = np.load(embedding_file).astype('float32')  # Ensure float32
    with open(valid_tracks_file, 'r') as f:
        valid_tracks = [int(line.strip()) for line in f]
    valid_metadata = metadata[metadata.index.isin(valid_tracks)]
    return valid_metadata, embeddings


In [59]:
valid_metadata, embeddings = load_embeddings()

with open(os.path.join(CHECKPOINT_DIR, 'valid_tracks.txt'), 'r') as f:
    track_ids = [int(line.strip()) for line in f]

valid_metadata = metadata[metadata.index.isin(track_ids)]

music_search = MusicSimilaritySearchANNPQ(valid_metadata, embeddings, track_ids)


/content/drive/My Drive/embeddings.npy /content/drive/My Drive/valid_tracks.txt
Loading pre-computed embeddings...


In [60]:
input_mp3 = "006331.mp3"

if os.path.exists(input_mp3):
    print(f"\nFinding similar tracks for: {input_mp3}")
    similar_tracks = music_search.search_by_audio(input_mp3)

    for i, track in enumerate(similar_tracks, 1):
        print(f"\nSimilar track #{i}:")
        print(f"Title: {track['title']}")
        print(f"Artist: {track['artist']}")
        print(f"Album: {track['album']}")
        print(f"Similarity: {1 - track['distance']:.3f}")
        print(f"Audio path: {track['audio_path']}")
else:
    print(f"Input file not found: {input_mp3}")



Finding similar tracks for: 006331.mp3

Similar track #1:
Title: Crossover
Artist: EPMD
Album: Live at ATP 2008
Similarity: 0.071
Audio path: /content/drive/My Drive/fma_small/006/006439.mp3

Similar track #2:
Title: one last try
Artist: Chicken Jones
Album: Sour Soul 12
Similarity: 0.079
Audio path: /content/drive/My Drive/fma_small/043/043867.mp3

Similar track #3:
Title: Ciganka Je Malena
Artist: Gogofski
Album: Live at the 2016 Golden Festival
Similarity: 0.085
Audio path: /content/drive/My Drive/fma_small/132/132139.mp3

Similar track #4:
Title: Success
Artist: K. Sparks
Album: Diagnosis: Success
Similarity: 0.085
Audio path: /content/drive/My Drive/fma_small/054/054365.mp3

Similar track #5:
Title: Enyi Wana Damu  Sharmila and Black Star Orchestra
Artist: The Sounds of Taraab
Album: Live at WFMU on Rob Weisberg's Show on 4/7/2007
Similarity: 0.085
Audio path: /content/drive/My Drive/fma_small/004/004236.mp3


In [61]:
import time
from sklearn.metrics.pairwise import cosine_similarity

def evaluate_model(music_search, valid_metadata, embeddings, valid_track_ids, k=5, sim_threshold=0.85):
    top_k_hits = 0
    reciprocal_ranks = []
    total_queries = 0
    total_latency = 0
    total_fp = 0
    total_sim_scores = []
    simulated_clicks = 0
    session_lengths = []

    for i, query_id in enumerate(valid_track_ids):
        query_path = get_audio_path(query_id)
        if not os.path.exists(query_path):
            continue

        start_time = time.time()
        query_embedding = audio_to_embedding(query_path)
        if query_embedding is None:
            continue

        results = music_search.get_similar_tracks(query_embedding, k)
        latency = time.time() - start_time
        total_latency += latency

        total_queries += 1
        correct_found = False
        session_length = 0

        for rank, result in enumerate(results, 1):
            retrieved_id = result['track_id']
            sim = 1 - result['distance']
            total_sim_scores.append(sim)
            session_length += 1

            if retrieved_id == query_id:
                top_k_hits += 1
                reciprocal_ranks.append(1 / rank)
                correct_found = True

            elif sim > sim_threshold:
                simulated_clicks += 1
            else:
                total_fp += 1

        if not correct_found:
            reciprocal_ranks.append(0)

        session_lengths.append(session_length if correct_found else 1)

    top_k_accuracy = top_k_hits / total_queries
    mrr = sum(reciprocal_ranks) / total_queries
    avg_cos_sim = sum(total_sim_scores) / len(total_sim_scores)
    false_positive_rate = total_fp / (total_queries * k)

    avg_latency = total_latency / total_queries
    avg_session_length = sum(session_lengths) / total_queries

    return {
        "Top-5 Accuracy": top_k_accuracy,
        "Mean Reciprocal Rank": mrr,
        "Average Cosine Similarity": avg_cos_sim,
        "False Positive Rate": false_positive_rate,
        "Search Latency (s)": avg_latency,
        "Simulated Session Length": avg_session_length
    }



In [62]:
track_ids = track_ids[:50]
metrics = evaluate_model(music_search, valid_metadata, embeddings, track_ids)
print("\n--- Evaluation Results ---")
for metric, value in metrics.items():
    print(f"{metric}: {value:.4f}")


--- Evaluation Results ---
Top-5 Accuracy: 0.0200
Mean Reciprocal Rank: 0.0200
Average Cosine Similarity: 0.1472
False Positive Rate: 0.9960
Search Latency (s): 1.6852
Simulated Conversion Rate: 0.0000
Simulated Session Length: 1.0800


In [63]:
from sklearn.neighbors import NearestNeighbors

class MusicSimilaritySearchKDTree:
    def __init__(self, metadata, embeddings, track_ids, audio_dir=AUDIO_DIR, n_neighbors=5):
        self.metadata = metadata
        self.audio_dir = audio_dir
        self.embeddings = embeddings
        self.track_ids = track_ids
        self.nn_model = NearestNeighbors(n_neighbors=n_neighbors, algorithm='kd_tree', metric='euclidean')
        self.nn_model.fit(embeddings)

    def get_similar_tracks(self, query_embedding, k=5):
        query_embedding = query_embedding.astype('float32').reshape(1, -1)
        distances, indices = self.nn_model.kneighbors(query_embedding, n_neighbors=k)
        similar_tracks = []
        for i in range(k):
            idx = indices[0][i]
            if idx >= len(self.track_ids):
                continue

            track_id = self.track_ids[idx]
            if track_id not in self.metadata.index:
                continue

            row = self.metadata.loc[track_id]
            track_info = {
                'track_id': track_id,
                'title': row['track_title'],
                'artist': row['artist_name'],
                'album': row['album_title'],
                'distance': distances[0][i],
                'audio_path': get_audio_path(track_id)
            }
            similar_tracks.append(track_info)
        return similar_tracks

    def search_by_audio(self, input_mp3_path, k=5):
        query_embedding = audio_to_embedding(input_mp3_path)
        if query_embedding is None:
            return []
        return self.get_similar_tracks(query_embedding, k)


In [64]:
valid_metadata, embeddings = load_embeddings()

with open(os.path.join(CHECKPOINT_DIR, 'valid_tracks.txt'), 'r') as f:
    track_ids = [int(line.strip()) for line in f]

valid_metadata = metadata[metadata.index.isin(track_ids)]

music_search = MusicSimilaritySearchKDTree(valid_metadata, embeddings, track_ids)


/content/drive/My Drive/embeddings.npy /content/drive/My Drive/valid_tracks.txt
Loading pre-computed embeddings...


In [65]:
# Run evaluation
track_ids = track_ids[:50]
metrics = evaluate_model(music_search, valid_metadata, embeddings, track_ids)
print("\n--- Evaluation Results ---")
for metric, value in metrics.items():
    print(f"{metric}: {value:.4f}")


--- Evaluation Results ---
Top-5 Accuracy: 0.0200
Mean Reciprocal Rank: 0.0100
Average Cosine Similarity: -1.3704
False Positive Rate: 0.9080
Search Latency (s): 1.6754
Simulated Conversion Rate: 0.0840
Simulated Session Length: 1.0800
