In [1]:
import pickle
import random
from pathlib import Path
import json
import music_tag
import numpy as np

from music_utils import *

from pymilvus import connections, utility
from pymilvus import Collection, DataType, FieldSchema, CollectionSchema

In [2]:
DATASET = Path("Music_Dataset")

# check if the .pkl files are valid
pkl_files = list(DATASET.rglob('*.pkl'))
valid_files = [check_file_info(pkl_file) for pkl_file in pkl_files]
print(f"Number of valid files: {sum(valid_files)} | Number of invalid files: {len(valid_files) - sum(valid_files)}")

Number of valid files: 174 | Number of invalid files: 0


In [3]:
types87 = []
types512 = []

for pkl_path in pkl_files:
    with pkl_path.open('rb') as f:
        try:
            content = pickle.load(f)
            embed = content.get('embedding_512')
            if embed is not None:
                types87.append(type(embed))
                types512.append(type(embed))
        except Exception as e:
            print(f"Error reading file {pkl_path}: {e}")

unique_types87 = set(types87)
unique_types512 = set(types512)

print(f"Unique types 87: {unique_types87}")
print(f"Unique types 512: {unique_types512}")

Unique types 87: {<class 'numpy.ndarray'>}
Unique types 512: {<class 'numpy.ndarray'>}


In [4]:
print_info(pick_random_mp3(DATASET))

filename: 08 Dub Pistols - Point Blank - Point Blank.mp3
filepath: Music_Dataset/Dub_Pistols/08 Dub Pistols - Point Blank - Point Blank.mp3
folder: Music_Dataset/Dub_Pistols
filesize: 10.28
title: Point Blank
artist: Dub Pistols
album: Point Blank
year: 1998
tracknumber: 8
genre: Electronic
predictions_87: [9.8139979e-04 2.0125259e-03 4.8705768e-03 2.5960850e-02 6.5091685e-03
 5.3468816e-02 1.0176755e-03 5.2671704e-02 7.7658012e-03 1.0173839e-03
 2.7207640e-04 7.1531860e-04 2.4420131e-02 2.5757801e-04 2.5672841e-04
 1.9329896e-02 4.4482294e-04 3.4126728e-03 3.4647752e-04 2.7871925e-02
 4.2696315e-04 1.1454568e-03 1.5394966e-01 1.8224511e-03 1.6941880e-03
 6.3558104e-03 1.6666764e-02 1.2132645e-02 1.0800989e-02 4.7003655e-03
 4.9221735e-03 4.0789612e-02 6.1270767e-03 7.0964789e-01 8.2458956e-03
 3.1663079e-02 4.1977563e-03 9.9393986e-03 7.8052819e-02 1.3477497e-03
 3.4186963e-02 7.5531667e-03 1.0674270e-02 4.0053626e-04 5.8707679e-03
 1.5159677e-03 7.7537581e-02 6.8314813e-02 3.9541754e

In [5]:
########## Milvus Client ##########

In [6]:
from dotenv import load_dotenv
import os
load_dotenv()

URI = os.getenv("MILVUS_URI")
TOKEN = os.getenv("MILVUS_TOKEN")

In [None]:
# connect to milvus
connections.connect("default",
                    uri=URI,
                    token=TOKEN)
print(f"Connecting to DB: {URI}")

In [8]:
########## predictions_87 ##########

In [9]:
print(utility.list_collections())

['predictions_87', 'embeddings_512']


In [10]:
# Check if the collection exists
check_collection_name = 'predictions_87'

check_collection = utility.has_collection(check_collection_name )
if check_collection:
    drop_result = utility.drop_collection(check_collection_name )

print(utility.list_collections())

['embeddings_512']


In [11]:
# create a collection for prediction_87
collection_name = "predictions_87"
dimension = 87

# Define the schema
schema = CollectionSchema(
    fields=[
        FieldSchema(name="id",dtype=DataType.INT64,is_primary=True,auto_id=False,max_length=100),
        FieldSchema(name="path", dtype=DataType.VARCHAR, max_length=280),
        FieldSchema(name="title", dtype=DataType.VARCHAR, max_length=220),
        FieldSchema(name="artist", dtype=DataType.VARCHAR, max_length=120),
        FieldSchema(name="album", dtype=DataType.VARCHAR, max_length=240),
        FieldSchema(name="predictions", dtype=DataType.FLOAT_VECTOR, dim=dimension),
        FieldSchema(name="top_5_genres", dtype=DataType.ARRAY, element_type=DataType.VARCHAR, max_capacity=5, max_length=100),
        ],
    enable_dynamic_field=True,  # Optional, defaults to 'False'.
)

print(f"Creating example collection: {collection_name}")
collection = Collection(collection_name, schema)
print(f"Schema: {schema}")
print("Success!")

Creating example collection: predictions_87
Schema: {'auto_id': False, 'description': '', 'fields': [{'name': 'id', 'description': '', 'type': <DataType.INT64: 5>, 'is_primary': True, 'auto_id': False}, {'name': 'path', 'description': '', 'type': <DataType.VARCHAR: 21>, 'params': {'max_length': 280}}, {'name': 'title', 'description': '', 'type': <DataType.VARCHAR: 21>, 'params': {'max_length': 220}}, {'name': 'artist', 'description': '', 'type': <DataType.VARCHAR: 21>, 'params': {'max_length': 120}}, {'name': 'album', 'description': '', 'type': <DataType.VARCHAR: 21>, 'params': {'max_length': 240}}, {'name': 'predictions', 'description': '', 'type': <DataType.FLOAT_VECTOR: 101>, 'params': {'dim': 87}}, {'name': 'top_5_genres', 'description': '', 'type': <DataType.ARRAY: 22>, 'params': {'max_length': 100, 'max_capacity': 5}, 'element_type': <DataType.VARCHAR: 21>}], 'enable_dynamic_field': True}
Success!


In [12]:
import logging
from multiprocessing import Pool

# Set up logging
logging.basicConfig(filename='app.log', filemode='w', format='%(name)s - %(levelname)s - %(message)s')


def process_file(args):
    i, path = args
    try:
        with open(path, "rb") as f:
            data = pickle.load(f)
            filepath = data.get("filepath")
            title = data.get("title")
            album = data.get("album")
            artist = data.get("artist")
            predictions = data.get("predictions_87")
            top_5_genres = data.get("top_5_genres")

        if isinstance(predictions, np.ndarray):
            return (i, filepath, title, album, artist, predictions, top_5_genres)
    except Exception as e:
        logging.error(f"Error processing file {path}: {str(e)}")
        return None

In [13]:
# Use a multiprocessing pool to process the files in parallel
with Pool() as p:
    results = p.map(process_file, enumerate(pkl_files))

# Filter out None results
results = [r for r in results if r is not None]

# Batch size
batch_size = 18

fails = []

# Insert the data into the collection
for i in range(0, len(results), batch_size):
    id_batch, path_batch, title_batch, album_batch, artist_batch, predictions_batch, top_5_genres_batch = zip(*results[i : i + batch_size])
    documents = [
        {
            "id": id,
            "path": path,
            "title": title,
            "album": album,
            "artist": artist,
            "predictions": predictions.tolist(),
            "top_5_genres": top_5_genres
        }
        for id, path, title, album, artist, predictions, top_5_genres in zip(
            id_batch, path_batch, title_batch, album_batch, artist_batch, predictions_batch, top_5_genres_batch,
        )
    ]
    try:
        collection.insert(documents)
        print(f"Inserted {i + batch_size} records into collection {collection.name}.")
    except Exception as e:
        fails.append((i + batch_size, documents))
        print(f"Error inserting batch {i + batch_size} into collection {collection.name}. Error: {str(e)}")



Inserted 18 records into collection predictions_87.
Inserted 36 records into collection predictions_87.
Inserted 54 records into collection predictions_87.
Inserted 72 records into collection predictions_87.
Inserted 90 records into collection predictions_87.
Inserted 108 records into collection predictions_87.
Inserted 126 records into collection predictions_87.
Inserted 144 records into collection predictions_87.
Inserted 162 records into collection predictions_87.
Inserted 180 records into collection predictions_87.


In [14]:
print(f"Failed batches: {fails}")

Failed batches: []


In [15]:
index = {
    "index_type": "IVF_FLAT",
    "metric_type": "L2",
    "params": {"nlist": 128},
}
collection.create_index(field_name="predictions", index_params=index)

Status(code=0, message=)

In [16]:
collection.load()

In [17]:
random_song = pick_random_mp3(DATASET).with_suffix(".pkl")
print(f"Random mp3: {random_song}")
print(get_top_5_genres(random_song, "mtg_jamendo_genre.json"))

with open(random_song, "rb") as f:
    file_info = pickle.load(f)
    query_embed = file_info.get("predictions_87")
    
search_params = {"metric_type": "L2", "params": {"nprobe": 10}}
result = collection.search([query_embed], "predictions", search_params, limit=10,offset=1, output_fields=["path"])

for element in result[0]:
    print(element)

Random mp3: Music_Dataset/Groundation/01_fight_all_you_can.pkl
['reggae', 'dub', 'rock', 'world', 'pop']
id: 93, distance: 0.007208452094346285, entity: {'path': 'Music_Dataset/Dub_Incorporation/04 - Rudeboy.mp3'}
id: 69, distance: 0.008759484626352787, entity: {'path': 'Music_Dataset/Damian_Marley/06_-_Kingston_12.mp3'}
id: 140, distance: 0.008851049467921257, entity: {'path': 'Music_Dataset/Groundation/03_Young_Tree.mp3'}
id: 139, distance: 0.009532451629638672, entity: {'path': 'Music_Dataset/Groundation/02_Congress_Man.mp3'}
id: 79, distance: 0.010738189332187176, entity: {'path': 'Music_Dataset/Doniki/01 My Bredren.mp3'}
id: 78, distance: 0.013725015334784985, entity: {'path': 'Music_Dataset/Doniki/02 Eretrians.mp3'}
id: 70, distance: 0.01601240038871765, entity: {'path': 'Music_Dataset/Damian_Marley/11_-_Me_Name_Jr._Gong.mp3'}
id: 99, distance: 0.01986054889857769, entity: {'path': 'Music_Dataset/Rockamovya/01 - Take The Night.mp3'}
id: 95, distance: 0.0212133526802063, entity: {

In [18]:
########## embedding_512 ##########

In [19]:
print(utility.list_collections())

['predictions_87', 'embeddings_512']


In [20]:
# Check if the collection exists
check_collection_name = 'embeddings_512'

check_collection = utility.has_collection(check_collection_name )
if check_collection:
    drop_result = utility.drop_collection(check_collection_name )

print(utility.list_collections())

['predictions_87']


In [21]:
# create a collection for embeddings_512
collection_name = "embeddings_512"
dimension = 512

# Define the schema
schema = CollectionSchema(
    fields=[
        FieldSchema(name="id",dtype=DataType.INT64,is_primary=True,auto_id=False,max_length=100),
        FieldSchema(name="path", dtype=DataType.VARCHAR, max_length=280),
        FieldSchema(name="title", dtype=DataType.VARCHAR, max_length=220),
        FieldSchema(name="album", dtype=DataType.VARCHAR, max_length=240),
        FieldSchema(name="artist", dtype=DataType.VARCHAR, max_length=120),
        FieldSchema(name="embedding", dtype=DataType.FLOAT_VECTOR, dim=dimension),
        FieldSchema(name="top_5_genres",dtype=DataType.ARRAY,element_type=DataType.VARCHAR,max_capacity=5,max_length=100),
    ],
    enable_dynamic_field=True,  # Optional, defaults to 'False'.
)

print(f"Creating example collection: {collection_name}")
collection = Collection(collection_name, schema)
print(f"Schema: {schema}")
print("Success!")

Creating example collection: embeddings_512
Schema: {'auto_id': False, 'description': '', 'fields': [{'name': 'id', 'description': '', 'type': <DataType.INT64: 5>, 'is_primary': True, 'auto_id': False}, {'name': 'path', 'description': '', 'type': <DataType.VARCHAR: 21>, 'params': {'max_length': 280}}, {'name': 'title', 'description': '', 'type': <DataType.VARCHAR: 21>, 'params': {'max_length': 220}}, {'name': 'album', 'description': '', 'type': <DataType.VARCHAR: 21>, 'params': {'max_length': 240}}, {'name': 'artist', 'description': '', 'type': <DataType.VARCHAR: 21>, 'params': {'max_length': 120}}, {'name': 'embedding', 'description': '', 'type': <DataType.FLOAT_VECTOR: 101>, 'params': {'dim': 512}}, {'name': 'top_5_genres', 'description': '', 'type': <DataType.ARRAY: 22>, 'params': {'max_length': 100, 'max_capacity': 5}, 'element_type': <DataType.VARCHAR: 21>}], 'enable_dynamic_field': True}
Success!


In [22]:
def process_file(args):
    i, path = args
    try:
        with open(path, "rb") as f:
            data = pickle.load(f)
            filepath = data.get("filepath")
            title = data.get("title")
            album = data.get("album")
            artist = data.get("artist")
            predictions = data.get("embedding_512") #
            top_5_genres = data.get("top_5_genres")

        if isinstance(predictions, np.ndarray):
            return (i, filepath, title, album, artist, predictions, top_5_genres)
    except Exception as e:
        logging.error(f"Error processing file {path}: {str(e)}")
        return None

In [23]:
# Use a multiprocessing pool to process the files in parallel
with Pool() as p:
    results = p.map(process_file, enumerate(pkl_files))

# Filter out None results
results = [r for r in results if r is not None]

# Batch size
batch_size = 18

fails = []

# Insert the data into the collection
for i in range(0, len(results), batch_size):
    id_batch, path_batch, title_batch, album_batch, artist_batch, embeddings_batch, top_5_genres_batch = zip(*results[i : i + batch_size])
    documents = [
        {
            "id": id,
            "path": path,
            "title": title,
            "album": album,
            "artist": artist,
            "embedding": embeddings.tolist(),
            "top_5_genres": top_5_genres
        }
        for id, path, title, album, artist, embeddings, top_5_genres in zip(
            id_batch, path_batch, title_batch, album_batch, artist_batch, embeddings_batch, top_5_genres_batch,
        )
    ]
    try:
        collection.insert(documents)
        print(f"Inserted {i + batch_size} records into collection {collection.name}.")
    except Exception as e:
        fails.append((i + batch_size, documents))
        print(f"Error inserting batch {i + batch_size} into collection {collection.name}. Error: {str(e)}")



Inserted 18 records into collection embeddings_512.
Inserted 36 records into collection embeddings_512.
Inserted 54 records into collection embeddings_512.
Inserted 72 records into collection embeddings_512.
Inserted 90 records into collection embeddings_512.
Inserted 108 records into collection embeddings_512.
Inserted 126 records into collection embeddings_512.
Inserted 144 records into collection embeddings_512.
Inserted 162 records into collection embeddings_512.
Inserted 180 records into collection embeddings_512.


In [24]:
print(f"Failed batches: {fails}")

Failed batches: []


In [25]:
index = {
    "index_type": "IVF_FLAT",
    "metric_type": "L2",
    "params": {"nlist": 128},
}
collection.create_index(field_name="embedding", index_params=index)

Status(code=0, message=)

In [26]:
collection.load()

In [27]:
random_song = pick_random_mp3(DATASET).with_suffix(".pkl")
print(f"Random mp3: {random_song}")
print(get_top_5_genres(random_song, "mtg_jamendo_genre.json"))

with open(random_song, "rb") as f:
    file_info = pickle.load(f)
    query_embed = file_info.get("embedding_512")
    
search_params = {"metric_type": "L2", "params": {"nprobe": 10}}
result = collection.search([query_embed], "embedding", search_params, limit=10, offset=1, output_fields=["path"])
print('\n')
for element in result[0]:
    print(element)

Random mp3: Music_Dataset/Sofi_tukker/Sofi Tukker - awoo.pkl
['electronic', 'house', 'dance', 'pop', 'ambient']


id: 12, distance: 81.15666961669922, entity: {'path': 'Music_Dataset/Les_Blaireaux/13 Berlin.mp3'}
id: 13, distance: 87.51583862304688, entity: {'path': 'Music_Dataset/Les_Blaireaux/07 Natalia Poutine.mp3'}
id: 52, distance: 94.92430877685547, entity: {'path': 'Music_Dataset/Massive_Attack/Massive_Attack_-_Blue_Lines_-01-_Safe_from_Harm.mp3'}
id: 18, distance: 96.50566101074219, entity: {'path': 'Music_Dataset/Major_Lazer/03 Major Lazer - Get Free (feat. Amber of Dirty Projectors).mp3'}
id: 159, distance: 97.5042495727539, entity: {'path': 'Music_Dataset/A_Tribe_Called_Quest/01_Push_It_Along.mp3'}
id: 116, distance: 99.24071502685547, entity: {'path': 'Music_Dataset/Birdy_Nam_Nam/15 Abbesses.mp3'}
id: 67, distance: 100.14732360839844, entity: {'path': 'Music_Dataset/Adele/01 Adele - Rolling In The Deep.mp3'}
id: 51, distance: 102.6116943359375, entity: {'path': 'Music_Datas