In [1]:
import pickle
import random
from pathlib import Path
import json
import music_tag
import numpy as np
import os
from utils.music_utils import *

from pymilvus import connections, utility
from pymilvus import Collection, DataType, FieldSchema, CollectionSchema

In [2]:
DATASET = Path('MegaSet')


def count_files_by_suffix(root, suffixes=('.mp3', '.pkl')) -> dict:
    """Recursively count files under ``root`` whose names end with each suffix.

    Args:
        root: directory to walk (``str`` or ``Path``); a missing directory yields zero counts.
        suffixes: file-name endings to count (matched case-sensitively via ``str.endswith``).

    Returns:
        A dict mapping every suffix in ``suffixes`` to its file count.
    """
    counts = dict.fromkeys(suffixes, 0)
    for _dirpath, _dirnames, filenames in os.walk(root):
        for filename in filenames:
            for suffix in suffixes:
                if filename.endswith(suffix):
                    counts[suffix] += 1
    return counts


# Every .mp3 is expected to have a matching .pkl feature cache next to it.
counts = count_files_by_suffix(DATASET)
count_mp3, count_pkl = counts['.mp3'], counts['.pkl']
print(f"Found {count_mp3} mp3 files and {count_pkl} pkl files")

founds 11637 mp3 files and 11637 pkl files


In [3]:
pkl_files = list(DATASET.glob('**/*.pkl'))

# Collect the Python type of each stored vector so we can confirm that both
# 'predictions_87' and 'embedding_512' are numpy arrays before they are turned
# into Milvus FLOAT_VECTOR payloads.
types87 = []
types512 = []

for pkl_path in pkl_files:
    try:
        # NOTE(review): pickle.load can execute arbitrary code; only use on trusted files.
        with pkl_path.open('rb') as f:
            content = pickle.load(f)
        # Each field is inspected independently, so a file missing one vector
        # still contributes the type of the other.
        predictions = content.get('predictions_87')
        if predictions is not None:
            types87.append(type(predictions))
        embed = content.get('embedding_512')
        if embed is not None:
            types512.append(type(embed))
    except Exception as e:
        print(f"Error reading file {pkl_path}: {e}")

unique_types87 = set(types87)
unique_types512 = set(types512)

print(f"Unique types 87: {unique_types87}")
print(f"Unique types 512: {unique_types512}")

Unique types 87: {<class 'numpy.ndarray'>}
Unique types 512: {<class 'numpy.ndarray'>}


In [4]:
# Sanity check: display metadata and cached features of one randomly chosen track.
sample_mp3 = pick_random_mp3(DATASET)
print_info(sample_mp3)

filename: 12. Katie Melua - The House.mp3
filepath: MegaSet/Katie Melua/Katie Melua - 2010 - The House/12. Katie Melua - The House.mp3
folder: MegaSet/Katie Melua/Katie Melua - 2010 - The House
filesize: 5.88
title: The House
artist: Katie Melua
album: The House
year: 2010
tracknumber: 12
genre: Pop
predictions_87: [0.00402669 0.00223018 0.00242836 0.00310962 0.00212826 0.10296998
 0.00331248 0.15418217 0.03190089 0.02666511 0.00111806 0.00174515
 0.00142048 0.01870859 0.01331209 0.04728153 0.00393937 0.06470482
 0.00551657 0.00110773 0.00949835 0.01578013 0.00371254 0.00745916
 0.00482683 0.00067051 0.00139667 0.0121285  0.00109338 0.001048
 0.00138212 0.1109677  0.00097812 0.15695696 0.00206799 0.00722451
 0.01399639 0.00052159 0.0486715  0.14298022 0.00241574 0.00482373
 0.00082823 0.00294661 0.00091043 0.00219242 0.00794023 0.00247123
 0.00134016 0.00403215 0.0821307  0.00581988 0.03051211 0.01617099
 0.01726036 0.00155757 0.006372   0.04027506 0.00412078 0.01179634
 0.00461762 0.0

In [5]:
########## Milvus Client ##########

In [6]:
import os

from dotenv import load_dotenv

# Populate the process environment from a local .env file, then read the
# Milvus endpoint and access token from it (either may be None if unset).
load_dotenv()

URI, TOKEN = (os.getenv(key) for key in ("MILVUS_URI", "MILVUS_TOKEN"))

In [7]:
# Connect to Milvus using the URI/TOKEN read from the environment.
# Fail fast with an explicit message when the endpoint is not configured,
# instead of handing `uri=None` to pymilvus.
if not URI:
    raise RuntimeError("MILVUS_URI is not set (check your .env file)")

# Announce the target before connecting, so a failing endpoint is visible in the output.
print(f"Connecting to DB: {URI}")
connections.connect("default",
                    uri=URI,
                    token=TOKEN)

Connecting to DB: https://in03-efa63c0579a14a1.api.gcp-us-west1.zillizcloud.com


In [8]:
########## predictions_87 ##########

In [9]:
# Show which collections currently exist on the Milvus server.
existing_collections = utility.list_collections()
print(existing_collections)

['embeddings_512', 'predictions_87']


In [10]:
# Start from a clean slate: if a 'predictions_87' collection already exists, drop it.
# WARNING: this deletes the remote collection (and its data) on every run of this cell.
check_collection_name = 'predictions_87'
check_collection = utility.has_collection(check_collection_name)
drop_result = utility.drop_collection(check_collection_name) if check_collection else None

print(utility.list_collections())

['embeddings_512']


In [11]:
# Create the 'predictions_87' collection: one row per track with its 87-dim
# 'predictions_87' vector, track metadata, and its top-5 genre labels.
collection_name = "predictions_87"
dimension = 87

# Define the schema
schema = CollectionSchema(
    fields=[
        # INT64 primary key supplied by the caller (auto_id=False). `max_length`
        # only applies to VARCHAR fields, so it is not passed here.
        FieldSchema(name="id", dtype=DataType.INT64, is_primary=True, auto_id=False),
        FieldSchema(name="path", dtype=DataType.VARCHAR, max_length=280),
        FieldSchema(name="title", dtype=DataType.VARCHAR, max_length=220),
        FieldSchema(name="artist", dtype=DataType.VARCHAR, max_length=120),
        FieldSchema(name="album", dtype=DataType.VARCHAR, max_length=240),
        FieldSchema(name="predictions", dtype=DataType.FLOAT_VECTOR, dim=dimension),
        # Up to 5 VARCHAR elements, each limited by max_length=100.
        FieldSchema(
            name="top_5_genres",
            dtype=DataType.ARRAY,
            element_type=DataType.VARCHAR,
            max_capacity=5,
            max_length=100,
        ),
    ],
    enable_dynamic_field=True,  # Optional, defaults to 'False'.
)

print(f"Creating example collection: {collection_name}")
collection = Collection(collection_name, schema)
print(f"Schema: {schema}")
print("Success!")

Creating example collection: predictions_87
Schema: {'auto_id': False, 'description': '', 'fields': [{'name': 'id', 'description': '', 'type': <DataType.INT64: 5>, 'is_primary': True, 'auto_id': False}, {'name': 'path', 'description': '', 'type': <DataType.VARCHAR: 21>, 'params': {'max_length': 280}}, {'name': 'title', 'description': '', 'type': <DataType.VARCHAR: 21>, 'params': {'max_length': 220}}, {'name': 'artist', 'description': '', 'type': <DataType.VARCHAR: 21>, 'params': {'max_length': 120}}, {'name': 'album', 'description': '', 'type': <DataType.VARCHAR: 21>, 'params': {'max_length': 240}}, {'name': 'predictions', 'description': '', 'type': <DataType.FLOAT_VECTOR: 101>, 'params': {'dim': 87}}, {'name': 'top_5_genres', 'description': '', 'type': <DataType.ARRAY: 22>, 'params': {'max_length': 100, 'max_capacity': 5}, 'element_type': <DataType.VARCHAR: 21>}], 'enable_dynamic_field': True}
Success!


In [12]:
import logging
from multiprocessing import Pool

# Problems found while reading the pickles are logged to app.log (overwritten on each run).
logging.basicConfig(filename='app.log', filemode='w', format='%(name)s - %(levelname)s - %(message)s')


def process_file(args: tuple, field: str = "predictions_87"):
    """Read one cached feature pickle and extract the row to insert into Milvus.

    Args:
        args: ``(i, path)`` pair as produced by ``enumerate(pkl_files)``; ``i``
            becomes the Milvus primary key and ``path`` is the ``.pkl`` to read.
        field: key of the vector to extract from the pickled dict.

    Returns:
        ``(i, filepath, title, album, artist, vector, top_5_genres)`` when the
        vector is a numpy array; otherwise ``None``. Unreadable files are logged
        as errors, and missing or non-array vectors are logged as warnings.
    """
    i, path = args
    try:
        # NOTE(review): pickle.load can execute arbitrary code; only use on trusted files.
        with open(path, "rb") as f:
            data = pickle.load(f)
        vector = data.get(field)
        if not isinstance(vector, np.ndarray):
            logging.warning(f"Skipping {path}: '{field}' is {type(vector).__name__}, not a numpy array")
            return None
        return (
            i,
            data.get("filepath"),
            data.get("title"),
            data.get("album"),
            data.get("artist"),
            vector,
            data.get("top_5_genres"),
        )
    except Exception as e:
        logging.error(f"Error processing file {path}: {str(e)}")
        return None

In [13]:
# Extract one row per pickle in parallel worker processes.
with Pool() as p:
    results = p.map(process_file, enumerate(pkl_files))

# Drop files that could not be read or had no usable vector.
results = [r for r in results if r is not None]

# Rows per insert call.
batch_size = 18

# (end_index, documents) for every batch whose insert raised.
fails = []

# Insert the data into the collection; a failing batch is recorded and the loop continues.
for i in range(0, len(results), batch_size):
    batch = results[i : i + batch_size]
    # Index just past this batch; correct for the final, possibly shorter, batch.
    end = i + len(batch)
    documents = [
        {
            "id": row_id,
            "path": path,
            "title": title,
            "album": album,
            "artist": artist,
            "predictions": predictions.tolist(),
            "top_5_genres": top_5_genres,
        }
        for row_id, path, title, album, artist, predictions, top_5_genres in batch
    ]
    try:
        collection.insert(documents)
        print(f"Inserted {end} records into collection {collection.name}.")
    except Exception as e:
        fails.append((end, documents))
        print(f"Error inserting batch {end} into collection {collection.name}. Error: {str(e)}")

# 2m 11.0s

Inserted 18 records into collection predictions_87.
Inserted 36 records into collection predictions_87.
Inserted 54 records into collection predictions_87.
Inserted 72 records into collection predictions_87.
Inserted 90 records into collection predictions_87.
Inserted 108 records into collection predictions_87.
Inserted 126 records into collection predictions_87.
Inserted 144 records into collection predictions_87.
Inserted 162 records into collection predictions_87.
Inserted 180 records into collection predictions_87.
Inserted 198 records into collection predictions_87.
Inserted 216 records into collection predictions_87.
Inserted 234 records into collection predictions_87.
Inserted 252 records into collection predictions_87.
Inserted 270 records into collection predictions_87.
Inserted 288 records into collection predictions_87.
Inserted 306 records into collection predictions_87.
Inserted 324 records into collection predictions_87.
Inserted 342 records into collection predictions_87

In [14]:
# Report failed batches by their end index only. Printing `fails` itself would
# dump every document, including full vectors, into the notebook output.
print(f"Failed batches: {[end for end, _documents in fails]}")

Failed batches: []


In [15]:
# IVF_FLAT index over the 87-dim vectors, using L2 distance (the search cell uses L2 too).
index = dict(
    index_type="IVF_FLAT",
    metric_type="L2",
    params=dict(nlist=128),
)
collection.create_index(field_name="predictions", index_params=index)

Status(code=0, message=)

In [16]:
# Load the 'predictions_87' collection (indexed in the previous cell) so the search cell below can query it.
collection.load()

In [17]:
# Pick a random track, print its top-5 genres, then list its nearest neighbours
# by L2 distance on the 'predictions_87' vector.
random_song = pick_random_mp3(DATASET).with_suffix(".pkl")
print(f"Random mp3: {random_song}")
print(get_top_5_genres(random_song, "utils/mtg_jamendo_genre.json"))

with open(random_song, "rb") as f:
    file_info = pickle.load(f)
query_embed = file_info.get("predictions_87")

search_params = {"metric_type": "L2", "params": {"nprobe": 10}}
# offset=1 skips the closest hit, presumably the query track itself.
result = collection.search(
    data=[query_embed],
    anns_field="predictions",
    param=search_params,
    limit=10,
    offset=1,
    output_fields=["path"],
)
print('\n')
for hit in result[0]:
    print(hit)

Random mp3: MegaSet/Casualty/Casualty - Version 5_2 (2008)/008_casualty_-_bad_dreams.pkl
['reggae', 'electronic', 'dub', 'experimental', 'alternative']


id: 2288, distance: 0.012982561253011227, entity: {'path': 'MegaSet/High Tone/High Tone-Bass Temperature/Piste 11 [2].mp3'}
id: 2252, distance: 0.014431831426918507, entity: {'path': 'MegaSet/High Tone/Hight Tone & Improvisators Dub - Highvisators - 100% Dub/06. U-Man Bass.mp3'}
id: 2283, distance: 0.014883620664477348, entity: {'path': 'MegaSet/High Tone/High Tone-Bass Temperature/Piste 05 [2].mp3'}
id: 6164, distance: 0.016342703253030777, entity: {'path': 'MegaSet/Jah Shaka meets Mad Professor - At Ariwa Sounds (1996)/09. Living Dub.mp3'}
id: 10296, distance: 0.020877933129668236, entity: {'path': 'MegaSet/The Black Seeds/The Black Seeds - [2008] Solid Ground/13-the_black_seeds--make_a_move_dub-oma.mp3'}
id: 6134, distance: 0.021576734259724617, entity: {'path': 'MegaSet/Casualty/Casualty - Version 5_2 (2008)/009_casualty_-_diablo_

In [18]:
########## embedding_512 ##########

In [19]:
# Show which collections currently exist on the Milvus server.
existing_collections = utility.list_collections()
print(existing_collections)

['predictions_87', 'embeddings_512']


In [20]:
# Start from a clean slate: if an 'embeddings_512' collection already exists, drop it.
# WARNING: this deletes the remote collection (and its data) on every run of this cell.
check_collection_name = 'embeddings_512'
check_collection = utility.has_collection(check_collection_name)
drop_result = utility.drop_collection(check_collection_name) if check_collection else None

print(utility.list_collections())

['predictions_87']


In [21]:
# Create the 'embeddings_512' collection: one row per track with its 512-dim
# 'embedding_512' vector, track metadata, and its top-5 genre labels.
collection_name = "embeddings_512"
dimension = 512

# Define the schema
schema = CollectionSchema(
    fields=[
        # INT64 primary key supplied by the caller (auto_id=False). `max_length`
        # only applies to VARCHAR fields, so it is not passed here.
        FieldSchema(name="id", dtype=DataType.INT64, is_primary=True, auto_id=False),
        FieldSchema(name="path", dtype=DataType.VARCHAR, max_length=280),
        FieldSchema(name="title", dtype=DataType.VARCHAR, max_length=220),
        FieldSchema(name="album", dtype=DataType.VARCHAR, max_length=240),
        FieldSchema(name="artist", dtype=DataType.VARCHAR, max_length=120),
        FieldSchema(name="embedding", dtype=DataType.FLOAT_VECTOR, dim=dimension),
        # Up to 5 VARCHAR elements, each limited by max_length=100.
        FieldSchema(
            name="top_5_genres",
            dtype=DataType.ARRAY,
            element_type=DataType.VARCHAR,
            max_capacity=5,
            max_length=100,
        ),
    ],
    enable_dynamic_field=True,  # Optional, defaults to 'False'.
)

print(f"Creating example collection: {collection_name}")
collection = Collection(collection_name, schema)
print(f"Schema: {schema}")
print("Success!")

Creating example collection: embeddings_512
Schema: {'auto_id': False, 'description': '', 'fields': [{'name': 'id', 'description': '', 'type': <DataType.INT64: 5>, 'is_primary': True, 'auto_id': False}, {'name': 'path', 'description': '', 'type': <DataType.VARCHAR: 21>, 'params': {'max_length': 280}}, {'name': 'title', 'description': '', 'type': <DataType.VARCHAR: 21>, 'params': {'max_length': 220}}, {'name': 'album', 'description': '', 'type': <DataType.VARCHAR: 21>, 'params': {'max_length': 240}}, {'name': 'artist', 'description': '', 'type': <DataType.VARCHAR: 21>, 'params': {'max_length': 120}}, {'name': 'embedding', 'description': '', 'type': <DataType.FLOAT_VECTOR: 101>, 'params': {'dim': 512}}, {'name': 'top_5_genres', 'description': '', 'type': <DataType.ARRAY: 22>, 'params': {'max_length': 100, 'max_capacity': 5}, 'element_type': <DataType.VARCHAR: 21>}], 'enable_dynamic_field': True}
Success!


In [22]:
def process_file(args: tuple, field: str = "embedding_512"):
    """Read one cached feature pickle and extract the row to insert into Milvus.

    This redefinition defaults to the 512-dim 'embedding_512' vector.

    Args:
        args: ``(i, path)`` pair as produced by ``enumerate(pkl_files)``; ``i``
            becomes the Milvus primary key and ``path`` is the ``.pkl`` to read.
        field: key of the vector to extract from the pickled dict.

    Returns:
        ``(i, filepath, title, album, artist, vector, top_5_genres)`` when the
        vector is a numpy array; otherwise ``None``. Unreadable files are logged
        as errors, and missing or non-array vectors are logged as warnings.
    """
    i, path = args
    try:
        # NOTE(review): pickle.load can execute arbitrary code; only use on trusted files.
        with open(path, "rb") as f:
            data = pickle.load(f)
        vector = data.get(field)
        if not isinstance(vector, np.ndarray):
            logging.warning(f"Skipping {path}: '{field}' is {type(vector).__name__}, not a numpy array")
            return None
        return (
            i,
            data.get("filepath"),
            data.get("title"),
            data.get("album"),
            data.get("artist"),
            vector,
            data.get("top_5_genres"),
        )
    except Exception as e:
        logging.error(f"Error processing file {path}: {str(e)}")
        return None

In [23]:
# Extract one row per pickle in parallel worker processes.
with Pool() as p:
    results = p.map(process_file, enumerate(pkl_files))

# Drop files that could not be read or had no usable vector.
results = [r for r in results if r is not None]

# Rows per insert call.
batch_size = 18

# (end_index, documents) for every batch whose insert raised.
fails = []

# Insert the data into the collection; a failing batch is recorded and the loop continues.
for i in range(0, len(results), batch_size):
    batch = results[i : i + batch_size]
    # Index just past this batch; correct for the final, possibly shorter, batch.
    end = i + len(batch)
    documents = [
        {
            "id": row_id,
            "path": path,
            "title": title,
            "album": album,
            "artist": artist,
            "embedding": embeddings.tolist(),
            "top_5_genres": top_5_genres,
        }
        for row_id, path, title, album, artist, embeddings, top_5_genres in batch
    ]
    try:
        collection.insert(documents)
        print(f"Inserted {end} records into collection {collection.name}.")
    except Exception as e:
        fails.append((end, documents))
        print(f"Error inserting batch {end} into collection {collection.name}. Error: {str(e)}")

# 2m 28.2s

Inserted 18 records into collection embeddings_512.
Inserted 36 records into collection embeddings_512.
Inserted 54 records into collection embeddings_512.
Inserted 72 records into collection embeddings_512.
Inserted 90 records into collection embeddings_512.
Inserted 108 records into collection embeddings_512.
Inserted 126 records into collection embeddings_512.
Inserted 144 records into collection embeddings_512.
Inserted 162 records into collection embeddings_512.
Inserted 180 records into collection embeddings_512.
Inserted 198 records into collection embeddings_512.
Inserted 216 records into collection embeddings_512.
Inserted 234 records into collection embeddings_512.
Inserted 252 records into collection embeddings_512.
Inserted 270 records into collection embeddings_512.
Inserted 288 records into collection embeddings_512.
Inserted 306 records into collection embeddings_512.
Inserted 324 records into collection embeddings_512.
Inserted 342 records into collection embeddings_512

In [24]:
# Report failed batches by their end index only. Printing `fails` itself would
# dump every document, including full 512-dim vectors, into the notebook output.
print(f"Failed batches: {[end for end, _documents in fails]}")

Failed batches: []


In [25]:
# IVF_FLAT index over the 512-dim embeddings, using L2 distance (the search cell uses L2 too).
index = dict(
    index_type="IVF_FLAT",
    metric_type="L2",
    params=dict(nlist=128),
)
collection.create_index(field_name="embedding", index_params=index)

Status(code=0, message=)

In [26]:
# Load the 'embeddings_512' collection (indexed in the previous cell) so the search cell below can query it.
collection.load()

In [27]:
# Pick a random track, print its top-5 genres, then list its nearest neighbours
# by L2 distance on the 'embedding_512' vector.
random_song = pick_random_mp3(DATASET).with_suffix(".pkl")
print(f"Random mp3: {random_song}")
print(get_top_5_genres(random_song, "utils/mtg_jamendo_genre.json"))

with open(random_song, "rb") as f:
    file_info = pickle.load(f)
query_embed = file_info.get("embedding_512")

search_params = {"metric_type": "L2", "params": {"nprobe": 10}}
# offset=1 skips the closest hit, presumably the query track itself.
result = collection.search(
    data=[query_embed],
    anns_field="embedding",
    param=search_params,
    limit=10,
    offset=1,
    output_fields=["path"],
)
print('\n')
for hit in result[0]:
    print(hit)

Random mp3: MegaSet/Dub Incorporation/2005 - Dans Le Décor/09 - Face à soi.pkl
['reggae', 'dub', 'hiphop', 'rock', 'electronic']


id: 2794, distance: 8.968767166137695, entity: {'path': 'MegaSet/Dub Incorporation/2005 - Dans Le Décor/02 - One Shot.mp3'}
id: 2802, distance: 9.882335662841797, entity: {'path': 'MegaSet/Dub Incorporation/2005 - Dans Le Décor/06 - Décor.mp3'}
id: 2804, distance: 10.682575225830078, entity: {'path': 'MegaSet/Dub Incorporation/2005 - Dans Le Décor/01 - Survie.mp3'}
id: 2805, distance: 12.890682220458984, entity: {'path': 'MegaSet/Dub Incorporation/2005 - Dans Le Décor/04 - Chaines.mp3'}
id: 2797, distance: 13.144603729248047, entity: {'path': 'MegaSet/Dub Incorporation/2005 - Dans Le Décor/10 - Speed ( Feat. David Hinds).mp3'}
id: 2795, distance: 13.314383506774902, entity: {'path': 'MegaSet/Dub Incorporation/2005 - Dans Le Décor/12 - Never Stop.mp3'}
id: 2798, distance: 14.644083023071289, entity: {'path': 'MegaSet/Dub Incorporation/2005 - Dans Le Décor/11