In [1]:
import pickle
import random
from pathlib import Path
import json
import music_tag
import numpy as np
import os
from utils.music_utils import *

from pymilvus import connections, utility
from pymilvus import Collection, DataType, FieldSchema, CollectionSchema

In [2]:
DATASET = Path('MegaSet')


def count_files_by_suffix(root, suffixes=('.mp3', '.pkl')):
    """Count files under ``root`` (recursively) whose names end with each suffix.

    Args:
        root: Directory to walk (str or Path). A missing directory yields zero counts.
        suffixes: File-name endings to count; matching is case-sensitive (``str.endswith``).

    Returns:
        dict mapping each suffix to the number of matching files.
    """
    counts = dict.fromkeys(suffixes, 0)
    for _dirpath, _dirnames, filenames in os.walk(root):
        for filename in filenames:
            for suffix in suffixes:
                if filename.endswith(suffix):
                    counts[suffix] += 1
    return counts


# Each .mp3 is expected to have a sibling .pkl (same path, .pkl suffix) holding its
# tags and precomputed vectors, so the two counts should match.
file_counts = count_files_by_suffix(DATASET)
count_mp3 = file_counts['.mp3']
count_pkl = file_counts['.pkl']
print(f"found {count_mp3} mp3 files and {count_pkl} pkl files")

founds 16428 mp3 files and 16428 pkl files


In [3]:
# Sanity check: which Python types are stored under the two vector keys of each .pkl?
# The upload cells below only keep rows whose vector is an np.ndarray.
pkl_files = list(DATASET.glob('**/*.pkl'))

types87 = []
types512 = []

# (pickle key, list collecting the types found under that key) -- previously both lists
# were filled from 'embedding_512', so 'predictions_87' was never actually checked.
vector_keys = (('predictions_87', types87), ('embedding_512', types512))

for pkl_path in pkl_files:
    # NOTE(review): pickle.load can execute arbitrary code; only run this on trusted, self-generated files.
    with pkl_path.open('rb') as f:
        try:
            content = pickle.load(f)
            for key, found_types in vector_keys:
                value = content.get(key)
                if value is not None:
                    found_types.append(type(value))
        except Exception as e:
            print(f"Error reading file {pkl_path}: {e}")

unique_types87 = set(types87)
unique_types512 = set(types512)

print(f"Unique types 87: {unique_types87}")
print(f"Unique types 512: {unique_types512}")

Unique types 87: {<class 'numpy.ndarray'>}
Unique types 512: {<class 'numpy.ndarray'>}


In [4]:
# Display the tags and stored model outputs of one randomly chosen track.
sample_mp3 = pick_random_mp3(DATASET)
print_info(sample_mp3)

filename: 05-Ligature Marks.mp3
filepath: MegaSet/Meshuggah/Immutable/05-Ligature Marks.mp3
folder: MegaSet/Meshuggah/Immutable
filesize: 11.62
title: Ligature Marks
artist: Meshuggah
album: Immutable
year: 2022
tracknumber: 5
genre: Metal
predictions_87: [6.1419181e-04 8.6471241e-04 2.6310179e-03 4.6148770e-03 2.8302471e-04
 3.1838280e-01 1.6303260e-02 2.6449401e-02 8.3537111e-03 5.3867460e-03
 9.6747337e-04 9.1023845e-05 1.9526081e-03 1.3612040e-03 2.1957252e-04
 1.7139756e-03 1.9567600e-03 1.5385854e-02 1.7087556e-03 2.3831488e-04
 1.3356119e-03 3.8597547e-04 1.2208153e-03 5.8145472e-03 7.9370597e-03
 1.3090110e-04 2.3401172e-04 1.0940991e-03 1.5364673e-03 3.2151589e-04
 1.0012333e-02 3.1285489e-03 5.3860265e-04 9.3508430e-02 7.4049999e-04
 1.0770928e-03 1.2291442e-03 1.1018844e-04 8.5543200e-02 3.4213746e-03
 2.3808875e-03 2.6843611e-03 6.2229228e-04 6.0239106e-02 8.3378209e-03
 5.5150371e-02 1.3578821e-03 8.1786246e-04 4.5463437e-04 6.9514581e-04
 4.3259062e-02 9.1353416e-02 3.767

In [5]:
########## Milvus Client ##########

In [6]:
from dotenv import load_dotenv
import os

# Read credentials from a local .env file (if present) into the environment,
# so they are never hard-coded in the notebook.
load_dotenv()

URI = os.getenv("MILVUS_URI")
TOKEN = os.getenv("MILVUS_TOKEN")

# Fail fast: without a URI, pymilvus may fall back to its default local address, and the
# following cells drop collections on whatever server the connection reaches.
if not URI:
    raise RuntimeError("MILVUS_URI is not set; add it to your environment or .env file.")

In [7]:
# Connect to Milvus under the "default" alias, which the utility/Collection calls below use.
# Announce the target first, so a failed or hanging connect still shows which URI was tried.
print(f"Connecting to DB: {URI}")
connections.connect("default",
                    uri=URI,
                    token=TOKEN)

Connecting to DB: https://in03-efa63c0579a14a1.api.gcp-us-west1.zillizcloud.com


In [8]:
########## predictions_87 ##########

In [9]:
# Collections currently present on the server.
existing_collections = utility.list_collections()
print(existing_collections)

['predictions_87', 'embeddings_512']


In [10]:
check_collection_name = 'predictions_87'
check_collection = utility.has_collection(check_collection_name)

# Destructive reset: drop the existing collection (and all of its entities) so it can be
# recreated from scratch with the schema defined in the next cell.
if check_collection:
    drop_result = utility.drop_collection(check_collection_name)

remaining_collections = utility.list_collections()
print(remaining_collections)

['embeddings_512']


In [11]:
# Create the collection that stores each track's 87-dim predictions vector plus its tags.
collection_name = "predictions_87"
dimension = 87

# Define the schema. Field names match the document keys built in the insert cell.
schema = CollectionSchema(
    fields=[
        # Primary key = index of the track's .pkl in `pkl_files` (from enumerate), hence auto_id=False.
        # No max_length here: it only constrains string data and is dropped for INT64 fields
        # (the printed schema never showed it).
        FieldSchema(name="id", dtype=DataType.INT64, is_primary=True, auto_id=False),
        FieldSchema(name="path", dtype=DataType.VARCHAR, max_length=280),
        FieldSchema(name="title", dtype=DataType.VARCHAR, max_length=220),
        FieldSchema(name="artist", dtype=DataType.VARCHAR, max_length=120),
        FieldSchema(name="album", dtype=DataType.VARCHAR, max_length=240),
        FieldSchema(name="predictions", dtype=DataType.FLOAT_VECTOR, dim=dimension),
        # At most 5 genre strings per track, each up to 100 long.
        FieldSchema(name="top_5_genres", dtype=DataType.ARRAY, element_type=DataType.VARCHAR, max_capacity=5, max_length=100),
    ],
    enable_dynamic_field=True,  # Optional, defaults to 'False'; lets inserts carry keys not declared above.
)

print(f"Creating example collection: {collection_name}")
collection = Collection(collection_name, schema)
print(f"Schema: {schema}")
print("Success!")

Creating example collection: predictions_87
Schema: {'auto_id': False, 'description': '', 'fields': [{'name': 'id', 'description': '', 'type': <DataType.INT64: 5>, 'is_primary': True, 'auto_id': False}, {'name': 'path', 'description': '', 'type': <DataType.VARCHAR: 21>, 'params': {'max_length': 280}}, {'name': 'title', 'description': '', 'type': <DataType.VARCHAR: 21>, 'params': {'max_length': 220}}, {'name': 'artist', 'description': '', 'type': <DataType.VARCHAR: 21>, 'params': {'max_length': 120}}, {'name': 'album', 'description': '', 'type': <DataType.VARCHAR: 21>, 'params': {'max_length': 240}}, {'name': 'predictions', 'description': '', 'type': <DataType.FLOAT_VECTOR: 101>, 'params': {'dim': 87}}, {'name': 'top_5_genres', 'description': '', 'type': <DataType.ARRAY: 22>, 'params': {'max_length': 100, 'max_capacity': 5}, 'element_type': <DataType.VARCHAR: 21>}], 'enable_dynamic_field': True}
Success!


In [12]:
import logging
from multiprocessing import Pool

# Set up logging: problems found by process_file go to app.log (overwritten on every run).
# NOTE(review): basicConfig does nothing if the root logger already has handlers -- confirm app.log is written.
logging.basicConfig(filename='app.log', filemode='w', format='%(name)s - %(levelname)s - %(message)s')


def process_file(args, key="predictions_87"):
    """Load one track's .pkl file and return the row to insert into Milvus.

    Args:
        args: ``(i, path)`` pair as produced by ``enumerate(pkl_files)``; ``i`` becomes the
            entity's primary key. It is a single tuple so the function works with ``Pool.map``.
        key: Pickle key holding the vector to upload (``"predictions_87"`` by default).

    Returns:
        ``(i, filepath, title, album, artist, vector, top_5_genres)`` when the vector is an
        ``np.ndarray``; ``None`` when the vector is missing or not an array, or the file cannot
        be read. The reason is logged in both cases.
    """
    i, path = args
    try:
        # NOTE(review): pickle.load can execute arbitrary code; the .pkl files must be trusted.
        with open(path, "rb") as f:
            data = pickle.load(f)
        vector = data.get(key)
        if not isinstance(vector, np.ndarray):
            logging.warning(f"Skipping {path}: {key!r} is {type(vector).__name__}, not np.ndarray")
            return None
        return (
            i,
            data.get("filepath"),
            data.get("title"),
            data.get("album"),
            data.get("artist"),
            vector,
            data.get("top_5_genres"),
        )
    except Exception as e:
        logging.error(f"Error processing file {path}: {e}")
        return None

In [13]:
# Read every .pkl in parallel; each worker returns a row tuple or None.
# NOTE(review): relies on the 'fork' start method -- under 'spawn' (macOS/Windows default),
# workers likely cannot import a function defined in the notebook.
with Pool() as p:
    results = p.map(process_file, enumerate(pkl_files))

# Drop files that could not be read or had no usable vector.
results = [r for r in results if r is not None]

# Number of entities sent per insert() call.
batch_size = 18

# (records processed up to the end of the failed batch, documents) for each failed insert.
fails = []

for i in range(0, len(results), batch_size):
    batch = results[i : i + batch_size]
    # The last batch may be shorter than batch_size, so count what was actually sent.
    end = i + len(batch)
    documents = [
        {
            "id": song_id,
            "path": path,
            "title": title,
            "album": album,
            "artist": artist,
            "predictions": predictions.tolist(),
            "top_5_genres": top_5_genres,
        }
        for song_id, path, title, album, artist, predictions, top_5_genres in batch
    ]
    try:
        collection.insert(documents)
        print(f"Inserted {end} records into collection {collection.name}.")
    except Exception as e:
        fails.append((end, documents))
        print(f"Error inserting batch {end} into collection {collection.name}. Error: {e}")

# Recorded runtime: 3m3s

Inserted 18 records into collection predictions_87.
Inserted 36 records into collection predictions_87.
Inserted 54 records into collection predictions_87.
Inserted 72 records into collection predictions_87.
Inserted 90 records into collection predictions_87.
Inserted 108 records into collection predictions_87.
Inserted 126 records into collection predictions_87.
Inserted 144 records into collection predictions_87.
Inserted 162 records into collection predictions_87.
Inserted 180 records into collection predictions_87.
Inserted 198 records into collection predictions_87.
Inserted 216 records into collection predictions_87.
Inserted 234 records into collection predictions_87.
Inserted 252 records into collection predictions_87.
Inserted 270 records into collection predictions_87.
Inserted 288 records into collection predictions_87.
Inserted 306 records into collection predictions_87.
Inserted 324 records into collection predictions_87.
Inserted 342 records into collection predictions_87

In [14]:
# Report only where each failed batch ends; the full documents (vectors included) stay in `fails` for retrying.
print(f"Failed batches: {[end for end, _ in fails]}")

Failed batches: []


In [15]:
# IVF_FLAT partitions vectors into `nlist` clusters; a search scans only the `nprobe` nearest ones.
# The metric here must be the same one requested in search_params at query time.
index = dict(
    index_type="IVF_FLAT",
    metric_type="L2",
    params=dict(nlist=128),
)
collection.create_index(field_name="predictions", index_params=index)

Status(code=0, message=)

In [16]:
# Load the collection on the Milvus side; Milvus serves search()/query() only on loaded collections.
collection.load()

In [17]:
# Use a random track's own predictions_87 vector as the query and list its nearest neighbours.
random_song = pick_random_mp3(DATASET).with_suffix(".pkl")
print(f"Random mp3: {random_song}")
print(get_top_5_genres(random_song, "utils/mtg_jamendo_genre.json"))

with open(random_song, "rb") as f:
    file_info = pickle.load(f)
query_embed = file_info.get("predictions_87")
if query_embed is None:
    raise KeyError(f"{random_song} has no 'predictions_87' vector to search with")

# metric_type must match the index (L2); nprobe = how many IVF clusters to scan.
search_params = {"metric_type": "L2", "params": {"nprobe": 10}}
# offset=1 skips the best hit, which is presumably the query track itself (it is in the collection).
result = collection.search([query_embed], "predictions", search_params, limit=10, offset=1, output_fields=["path"])
print('\n')
for element in result[0]:
    print(element)

Random mp3: MegaSet/Damian Marley/1996 - Mr_ Marley/01 - Trouble.pkl
['reggae', 'electronic', 'hiphop', 'lounge', 'dub']


id: 2393, distance: 0.013562529347836971, entity: {'path': "MegaSet/Amy Whinehouse/Amy Winehouse - 2003 - Frank/06. Moody's Mood for Love - Teo Licks.mp3"}
id: 14626, distance: 0.015989627689123154, entity: {'path': 'MegaSet/Billy ze Kick/Verdure & libido/06 - dany le rouge.mp3'}
id: 10741, distance: 0.01756531372666359, entity: {'path': 'MegaSet/Nova Tunes/Nova Tunes 16/10-dj_vadim_feat._katherin_de_boer-black_is_the_night.mp3'}
id: 15151, distance: 0.01766939088702202, entity: {'path': 'MegaSet/Martin Jondo/Martin Jondo - Echo And Smoke [2006]/11 martin jondo - jah gringo.mp3'}
id: 13764, distance: 0.017970960587263107, entity: {'path': 'MegaSet/Mano Negra/(1994) casa babylone/12_Drives Me Crazy.mp3'}
id: 11886, distance: 0.01810826174914837, entity: {'path': 'MegaSet/Damian Marley/1996 - Mr_ Marley/08 - Searching (So Much Bubble).mp3'}
id: 1915, distance: 0.0189

In [18]:
########## embedding_512 ##########

In [19]:
# Collections currently present on the server.
existing_collections = utility.list_collections()
print(existing_collections)

['embeddings_512', 'predictions_87']


In [20]:
check_collection_name = 'embeddings_512'
check_collection = utility.has_collection(check_collection_name)

# Destructive reset: drop the existing collection (and all of its entities) so it can be
# recreated from scratch with the schema defined in the next cell.
if check_collection:
    drop_result = utility.drop_collection(check_collection_name)

remaining_collections = utility.list_collections()
print(remaining_collections)

['predictions_87']


In [21]:
# Create the collection that stores each track's 512-dim embedding vector plus its tags.
collection_name = "embeddings_512"
dimension = 512

# Define the schema. Field names match the document keys built in the insert cell.
schema = CollectionSchema(
    fields=[
        # Primary key = index of the track's .pkl in `pkl_files` (from enumerate), hence auto_id=False.
        # No max_length here: it only constrains string data and is dropped for INT64 fields
        # (the printed schema never showed it).
        FieldSchema(name="id", dtype=DataType.INT64, is_primary=True, auto_id=False),
        FieldSchema(name="path", dtype=DataType.VARCHAR, max_length=280),
        FieldSchema(name="title", dtype=DataType.VARCHAR, max_length=220),
        FieldSchema(name="album", dtype=DataType.VARCHAR, max_length=240),
        FieldSchema(name="artist", dtype=DataType.VARCHAR, max_length=120),
        FieldSchema(name="embedding", dtype=DataType.FLOAT_VECTOR, dim=dimension),
        # At most 5 genre strings per track, each up to 100 long.
        FieldSchema(name="top_5_genres", dtype=DataType.ARRAY, element_type=DataType.VARCHAR, max_capacity=5, max_length=100),
    ],
    enable_dynamic_field=True,  # Optional, defaults to 'False'; lets inserts carry keys not declared above.
)

print(f"Creating example collection: {collection_name}")
collection = Collection(collection_name, schema)
print(f"Schema: {schema}")
print("Success!")

Creating example collection: embeddings_512
Schema: {'auto_id': False, 'description': '', 'fields': [{'name': 'id', 'description': '', 'type': <DataType.INT64: 5>, 'is_primary': True, 'auto_id': False}, {'name': 'path', 'description': '', 'type': <DataType.VARCHAR: 21>, 'params': {'max_length': 280}}, {'name': 'title', 'description': '', 'type': <DataType.VARCHAR: 21>, 'params': {'max_length': 220}}, {'name': 'album', 'description': '', 'type': <DataType.VARCHAR: 21>, 'params': {'max_length': 240}}, {'name': 'artist', 'description': '', 'type': <DataType.VARCHAR: 21>, 'params': {'max_length': 120}}, {'name': 'embedding', 'description': '', 'type': <DataType.FLOAT_VECTOR: 101>, 'params': {'dim': 512}}, {'name': 'top_5_genres', 'description': '', 'type': <DataType.ARRAY: 22>, 'params': {'max_length': 100, 'max_capacity': 5}, 'element_type': <DataType.VARCHAR: 21>}], 'enable_dynamic_field': True}
Success!


In [22]:
def process_file(args, key="embedding_512"):
    """Load one track's .pkl file and return the row to insert into the embeddings collection.

    This redefines process_file; it reads the 'embedding_512' vector by default.

    Args:
        args: ``(i, path)`` pair as produced by ``enumerate(pkl_files)``; ``i`` becomes the
            entity's primary key. It is a single tuple so the function works with ``Pool.map``.
        key: Pickle key holding the vector to upload (``"embedding_512"`` by default).

    Returns:
        ``(i, filepath, title, album, artist, vector, top_5_genres)`` when the vector is an
        ``np.ndarray``; ``None`` when the vector is missing or not an array, or the file cannot
        be read. The reason is logged in both cases.
    """
    i, path = args
    try:
        # NOTE(review): pickle.load can execute arbitrary code; the .pkl files must be trusted.
        with open(path, "rb") as f:
            data = pickle.load(f)
        vector = data.get(key)
        if not isinstance(vector, np.ndarray):
            logging.warning(f"Skipping {path}: {key!r} is {type(vector).__name__}, not np.ndarray")
            return None
        return (
            i,
            data.get("filepath"),
            data.get("title"),
            data.get("album"),
            data.get("artist"),
            vector,
            data.get("top_5_genres"),
        )
    except Exception as e:
        logging.error(f"Error processing file {path}: {e}")
        return None

In [23]:
# Read every .pkl in parallel; each worker returns a row tuple or None.
# NOTE(review): relies on the 'fork' start method -- under 'spawn' (macOS/Windows default),
# workers likely cannot import a function defined in the notebook.
with Pool() as p:
    results = p.map(process_file, enumerate(pkl_files))

# Drop files that could not be read or had no usable vector.
results = [r for r in results if r is not None]

# Number of entities sent per insert() call.
batch_size = 18

# (records processed up to the end of the failed batch, documents) for each failed insert.
fails = []

for i in range(0, len(results), batch_size):
    batch = results[i : i + batch_size]
    # The last batch may be shorter than batch_size, so count what was actually sent.
    end = i + len(batch)
    documents = [
        {
            "id": song_id,
            "path": path,
            "title": title,
            "album": album,
            "artist": artist,
            "embedding": embedding.tolist(),
            "top_5_genres": top_5_genres,
        }
        for song_id, path, title, album, artist, embedding, top_5_genres in batch
    ]
    try:
        collection.insert(documents)
        print(f"Inserted {end} records into collection {collection.name}.")
    except Exception as e:
        fails.append((end, documents))
        print(f"Error inserting batch {end} into collection {collection.name}. Error: {e}")

# Recorded runtime: 3m 21s

Inserted 18 records into collection embeddings_512.
Inserted 36 records into collection embeddings_512.
Inserted 54 records into collection embeddings_512.
Inserted 72 records into collection embeddings_512.
Inserted 90 records into collection embeddings_512.
Inserted 108 records into collection embeddings_512.
Inserted 126 records into collection embeddings_512.
Inserted 144 records into collection embeddings_512.
Inserted 162 records into collection embeddings_512.
Inserted 180 records into collection embeddings_512.
Inserted 198 records into collection embeddings_512.
Inserted 216 records into collection embeddings_512.
Inserted 234 records into collection embeddings_512.
Inserted 252 records into collection embeddings_512.
Inserted 270 records into collection embeddings_512.
Inserted 288 records into collection embeddings_512.
Inserted 306 records into collection embeddings_512.
Inserted 324 records into collection embeddings_512.
Inserted 342 records into collection embeddings_512

In [24]:
# Report only where each failed batch ends; the full documents (vectors included) stay in `fails` for retrying.
print(f"Failed batches: {[end for end, _ in fails]}")

Failed batches: []


In [25]:
# IVF_FLAT partitions vectors into `nlist` clusters; a search scans only the `nprobe` nearest ones.
# The metric here must be the same one requested in search_params at query time.
index = dict(
    index_type="IVF_FLAT",
    metric_type="L2",
    params=dict(nlist=128),
)
collection.create_index(field_name="embedding", index_params=index)

Status(code=0, message=)

In [26]:
# Load the collection on the Milvus side; Milvus serves search()/query() only on loaded collections.
collection.load()

In [27]:
# Use a random track's own embedding_512 vector as the query and list its nearest neighbours.
random_song = pick_random_mp3(DATASET).with_suffix(".pkl")
print(f"Random mp3: {random_song}")
print(get_top_5_genres(random_song, "utils/mtg_jamendo_genre.json"))

with open(random_song, "rb") as f:
    file_info = pickle.load(f)
query_embed = file_info.get("embedding_512")
if query_embed is None:
    raise KeyError(f"{random_song} has no 'embedding_512' vector to search with")

# metric_type must match the index (L2); nprobe = how many IVF clusters to scan.
search_params = {"metric_type": "L2", "params": {"nprobe": 10}}
# offset=1 skips the best hit, which is presumably the query track itself (it is in the collection).
result = collection.search([query_embed], "embedding", search_params, limit=10, offset=1, output_fields=["path"])
print('\n')
for element in result[0]:
    print(element)

Random mp3: MegaSet/Muse/Muse - 2009 - The Resistance/04. United States of Eurasia (+Collateral Damage).pkl
['rock', 'classical', 'alternative', 'soundtrack', 'pop']


id: 6130, distance: 16.25131607055664, entity: {'path': 'MegaSet/Muse/Muse - 2009 - The Resistance/10. Exogenesis Symphony Pt. 2 (Cross-pollination).mp3'}
id: 6126, distance: 18.423208236694336, entity: {'path': "MegaSet/Muse/Muse - 2009 - The Resistance/08. I Belong to You (+Mon Coeur S'Ouvre a Ta Voix).mp3"}
id: 238, distance: 20.80367660522461, entity: {'path': 'MegaSet/Norah Jones/Norah Jones - 2012 - Little Broken Hearts/05. Take It Back.mp3'}
id: 6135, distance: 21.548248291015625, entity: {'path': 'MegaSet/Muse/Muse - 2009 - The Resistance/02. Resistance.mp3'}
id: 15877, distance: 23.440109252929688, entity: {'path': 'MegaSet/John Mayer/John Mayer - 2006 - Continuum/10. John Mayer - Dreaming with a Broken Heart.mp3'}
id: 9463, distance: 23.50017738342285, entity: {'path': 'MegaSet/Boston/Third Stage/05-My Destinat