In [1]:
import pickle
import random
from pathlib import Path
import json
import music_tag
import numpy as np

from music_utils import *

from pymilvus import connections, utility
from pymilvus import Collection, DataType, FieldSchema, CollectionSchema

In [2]:
DATASET = Path("Music_Dataset")
# list the path of all the .pkl files
pkl_files = list(DATASET.rglob('*.pkl'))
# check if the .pkl files are valid
valid_files = [check_file_info(pkl_file) for pkl_file in pkl_files]
print(f"Number of valid files: {sum(valid_files)} | Number of invalid files: {len(valid_files) - sum(valid_files)}")

Number of valid files: 174 | Number of invalid files: 0


In [3]:
types87 = []
types512 = []

for pkl_path in pkl_files:
    with pkl_path.open('rb') as f:
        try:
            content = pickle.load(f)
            embed = content.get('embedding_512')
            if embed is not None:
                types87.append(type(embed))
                types512.append(type(embed))
        except Exception as e:
            print(f"Error reading file {pkl_path}: {e}")

unique_types87 = set(types87)
unique_types512 = set(types512)

print(f"Unique types 87: {unique_types87}")
print(f"Unique types 512: {unique_types512}")

Unique types 87: {<class 'numpy.ndarray'>}
Unique types 512: {<class 'numpy.ndarray'>}


In [4]:
print_info(pick_random_mp3(DATASET))

filename: 01_Rehab.mp3
filepath: Music_Dataset/Amy_Whinehouse/01_Rehab.mp3
folder: Music_Dataset/Amy_Whinehouse
filesize: 8.28
title: Rehab
artist: Amy Winehouse
album: Back To Black (Deluxe Edition)
year: None
tracknumber: 1
genre: Pop
predictions_87: [0.0137573  0.00787907 0.00550216 0.01150016 0.00887729 0.125214
 0.00893105 0.04464517 0.00872509 0.03716161 0.0035134  0.00543459
 0.01179131 0.00099303 0.00342366 0.02123244 0.00120388 0.00799035
 0.00206727 0.00677059 0.00117313 0.00716032 0.02799957 0.00169101
 0.00217736 0.00263762 0.01154235 0.01169348 0.01869302 0.01457836
 0.00332781 0.02558908 0.0033474  0.25421107 0.00414976 0.04483763
 0.00655554 0.00181256 0.04970998 0.02087118 0.06964471 0.01897717
 0.02035891 0.00426782 0.00467592 0.00557946 0.07791467 0.01519128
 0.00302347 0.00396947 0.1014649  0.0075494  0.02356417 0.01140175
 0.03421957 0.00418526 0.01481703 0.03088906 0.00073046 0.00478129
 0.00256177 0.00542741 0.00592084 0.00773648 0.34483925 0.0139124
 0.08931171 0

In [5]:
########## Milvus Client ##########

In [6]:
URI = "https://in03-efa63c0579a14a1.api.gcp-us-west1.zillizcloud.com"

TOKEN = "e58021f476f7b39e5d84eb5c804e27bfec1a7fb89b6e01f7560ac57877be699b9b1f109a2ba8fabefd2fa26f2efab109ebdd79f0"

In [7]:
# connect to milvus
connections.connect("default",
                    uri=URI,
                    token=TOKEN)
print(f"Connecting to DB: {URI}")

Connecting to DB: https://in03-efa63c0579a14a1.api.gcp-us-west1.zillizcloud.com


In [8]:
########## predictions_87 ##########

In [9]:
print(utility.list_collections())

['predictions_87', 'embeddings_512']


In [10]:
# Check if the collection exists
check_collection_name = 'predictions_87'

check_collection = utility.has_collection(check_collection_name )
if check_collection:
    drop_result = utility.drop_collection(check_collection_name )

print(utility.list_collections())

['embeddings_512']


In [11]:
# create a collection for prediction_87
collection_name = "predictions_87"
dimension = 87

# Define the schema
schema = CollectionSchema(
    fields=[
        FieldSchema(name="id",dtype=DataType.INT64,is_primary=True,auto_id=False,max_length=100),
        FieldSchema(name="path", dtype=DataType.VARCHAR, max_length=280),
        FieldSchema(name="title", dtype=DataType.VARCHAR, max_length=220),
        FieldSchema(name="artist", dtype=DataType.VARCHAR, max_length=120),
        FieldSchema(name="album", dtype=DataType.VARCHAR, max_length=240),
        FieldSchema(name="predictions", dtype=DataType.FLOAT_VECTOR, dim=dimension),
        FieldSchema(name="top_5_genres", dtype=DataType.ARRAY, element_type=DataType.VARCHAR, max_capacity=5, max_length=100),
        ],
    enable_dynamic_field=True,  # Optional, defaults to 'False'.
)

print(f"Creating example collection: {collection_name}")
collection = Collection(collection_name, schema)
print(f"Schema: {schema}")
print("Success!")

Creating example collection: predictions_87
Schema: {'auto_id': False, 'description': '', 'fields': [{'name': 'id', 'description': '', 'type': <DataType.INT64: 5>, 'is_primary': True, 'auto_id': False}, {'name': 'path', 'description': '', 'type': <DataType.VARCHAR: 21>, 'params': {'max_length': 280}}, {'name': 'title', 'description': '', 'type': <DataType.VARCHAR: 21>, 'params': {'max_length': 220}}, {'name': 'artist', 'description': '', 'type': <DataType.VARCHAR: 21>, 'params': {'max_length': 120}}, {'name': 'album', 'description': '', 'type': <DataType.VARCHAR: 21>, 'params': {'max_length': 240}}, {'name': 'predictions', 'description': '', 'type': <DataType.FLOAT_VECTOR: 101>, 'params': {'dim': 87}}, {'name': 'top_5_genres', 'description': '', 'type': <DataType.ARRAY: 22>, 'params': {'max_length': 100, 'max_capacity': 5}, 'element_type': <DataType.VARCHAR: 21>}], 'enable_dynamic_field': True}
Success!


In [12]:
import logging
from multiprocessing import Pool

# Set up logging
logging.basicConfig(filename='app.log', filemode='w', format='%(name)s - %(levelname)s - %(message)s')


def process_file(args):
    i, path = args
    try:
        with open(path, "rb") as f:
            data = pickle.load(f)
            filepath = data.get("filepath")
            title = data.get("title")
            album = data.get("album")
            artist = data.get("artist")
            predictions = data.get("predictions_87")
            top_5_genres = data.get("top_5_genres")

        if isinstance(predictions, np.ndarray):
            return (i, filepath, title, album, artist, predictions, top_5_genres)
    except Exception as e:
        logging.error(f"Error processing file {path}: {str(e)}")
        return None

In [13]:
# Use a multiprocessing pool to process the files in parallel
with Pool() as p:
    results = p.map(process_file, enumerate(pkl_files))

# Filter out None results
results = [r for r in results if r is not None]

# Batch size
batch_size = 18

fails = []

# Insert the data into the collection
for i in range(0, len(results), batch_size):
    id_batch, path_batch, title_batch, album_batch, artist_batch, predictions_batch, top_5_genres_batch = zip(*results[i : i + batch_size])
    documents = [
        {
            "id": id,
            "path": path,
            "title": title,
            "album": album,
            "artist": artist,
            "predictions": predictions.tolist(),
            "top_5_genres": top_5_genres
        }
        for id, path, title, album, artist, predictions, top_5_genres in zip(
            id_batch, path_batch, title_batch, album_batch, artist_batch, predictions_batch, top_5_genres_batch,
        )
    ]
    try:
        collection.insert(documents)
        print(f"Inserted {i + batch_size} records into collection {collection.name}.")
    except Exception as e:
        fails.append((i + batch_size, documents))
        print(f"Error inserting batch {i + batch_size} into collection {collection.name}. Error: {str(e)}")

# 2m20s


Inserted 18 records into collection predictions_87.
Inserted 36 records into collection predictions_87.
Inserted 54 records into collection predictions_87.
Inserted 72 records into collection predictions_87.
Inserted 90 records into collection predictions_87.
Inserted 108 records into collection predictions_87.
Inserted 126 records into collection predictions_87.
Inserted 144 records into collection predictions_87.
Inserted 162 records into collection predictions_87.
Inserted 180 records into collection predictions_87.


In [14]:
print(f"Failed batches: {fails}")

Failed batches: []


In [15]:
index = {
    "index_type": "IVF_FLAT",
    "metric_type": "L2",
    "params": {"nlist": 128},
}
collection.create_index(field_name="predictions", index_params=index)

Status(code=0, message=)

In [16]:
collection.load()

In [17]:
random_song = pick_random_mp3(DATASET).with_suffix(".pkl")
print(f"Random mp3: {random_song}")
print(get_top_5_genres(random_song, "mtg_jamendo_genre.json"))

with open(random_song, "rb") as f:
    file_info = pickle.load(f)
    query_embed = file_info.get("predictions_87")
    
search_params = {"metric_type": "L2", "params": {"nprobe": 10}}
result = collection.search([query_embed], "predictions", search_params, limit=10,offset=1, output_fields=["path"])

for element in result[0]:
    print(element)

Random mp3: Music_Dataset/Java/01-Sex_Accordeon_Et_Alcool.pkl
['hiphop', 'rap', 'electronic', 'experimental', 'reggae']
id: 72, distance: 0.004329318180680275, entity: {'path': 'Music_Dataset/Soul_Square/02_-_Living_the_Dream_(feat._Justis).mp3'}
id: 27, distance: 0.0067001529969275, entity: {'path': 'Music_Dataset/CYNE/02_The_River.mp3'}
id: 133, distance: 0.007741497829556465, entity: {'path': 'Music_Dataset/Jurassic_5/005_Quality_Control.mp3'}
id: 48, distance: 0.010727956891059875, entity: {'path': 'Music_Dataset/Dr_Dre/04. Still D.R.E.mp3'}
id: 62, distance: 0.011915805749595165, entity: {'path': 'Music_Dataset/The_Streets/The Streets - 11 - The Irony Of It All.mp3'}
id: 49, distance: 0.013396674767136574, entity: {'path': 'Music_Dataset/Dr_Dre/02. The Watcher.mp3'}
id: 28, distance: 0.014681385830044746, entity: {'path': 'Music_Dataset/CYNE/11_Stomping_Ground.mp3'}
id: 74, distance: 0.023833826184272766, entity: {'path': 'Music_Dataset/Soul_Square/03_-_Take_It_Back_(feat._Blezz).

In [18]:
########## embedding_512 ##########

In [19]:
print(utility.list_collections())

['embeddings_512', 'predictions_87']


In [20]:
# Check if the collection exists
check_collection_name = 'embeddings_512'

check_collection = utility.has_collection(check_collection_name )
if check_collection:
    drop_result = utility.drop_collection(check_collection_name )

print(utility.list_collections())

['predictions_87']


In [21]:
# create a collection for embeddings_512
collection_name = "embeddings_512"
dimension = 512

# Define the schema
schema = CollectionSchema(
    fields=[
        FieldSchema(name="id",dtype=DataType.INT64,is_primary=True,auto_id=False,max_length=100),
        FieldSchema(name="path", dtype=DataType.VARCHAR, max_length=280),
        FieldSchema(name="title", dtype=DataType.VARCHAR, max_length=220),
        FieldSchema(name="album", dtype=DataType.VARCHAR, max_length=240),
        FieldSchema(name="artist", dtype=DataType.VARCHAR, max_length=120),
        FieldSchema(name="embedding", dtype=DataType.FLOAT_VECTOR, dim=dimension),
        FieldSchema(name="top_5_genres",dtype=DataType.ARRAY,element_type=DataType.VARCHAR,max_capacity=5,max_length=100),
    ],
    enable_dynamic_field=True,  # Optional, defaults to 'False'.
)

print(f"Creating example collection: {collection_name}")
collection = Collection(collection_name, schema)
print(f"Schema: {schema}")
print("Success!")

Creating example collection: embeddings_512
Schema: {'auto_id': False, 'description': '', 'fields': [{'name': 'id', 'description': '', 'type': <DataType.INT64: 5>, 'is_primary': True, 'auto_id': False}, {'name': 'path', 'description': '', 'type': <DataType.VARCHAR: 21>, 'params': {'max_length': 280}}, {'name': 'title', 'description': '', 'type': <DataType.VARCHAR: 21>, 'params': {'max_length': 220}}, {'name': 'album', 'description': '', 'type': <DataType.VARCHAR: 21>, 'params': {'max_length': 240}}, {'name': 'artist', 'description': '', 'type': <DataType.VARCHAR: 21>, 'params': {'max_length': 120}}, {'name': 'embedding', 'description': '', 'type': <DataType.FLOAT_VECTOR: 101>, 'params': {'dim': 512}}, {'name': 'top_5_genres', 'description': '', 'type': <DataType.ARRAY: 22>, 'params': {'max_length': 100, 'max_capacity': 5}, 'element_type': <DataType.VARCHAR: 21>}], 'enable_dynamic_field': True}
Success!


In [22]:
def process_file(args):
    i, path = args
    try:
        with open(path, "rb") as f:
            data = pickle.load(f)
            filepath = data.get("filepath")
            title = data.get("title")
            album = data.get("album")
            artist = data.get("artist")
            predictions = data.get("embedding_512") #
            top_5_genres = data.get("top_5_genres")

        if isinstance(predictions, np.ndarray):
            return (i, filepath, title, album, artist, predictions, top_5_genres)
    except Exception as e:
        logging.error(f"Error processing file {path}: {str(e)}")
        return None

In [23]:
# Use a multiprocessing pool to process the files in parallel
with Pool() as p:
    results = p.map(process_file, enumerate(pkl_files))

# Filter out None results
results = [r for r in results if r is not None]

# Batch size
batch_size = 18

fails = []

# Insert the data into the collection
for i in range(0, len(results), batch_size):
    id_batch, path_batch, title_batch, album_batch, artist_batch, embeddings_batch, top_5_genres_batch = zip(*results[i : i + batch_size])
    documents = [
        {
            "id": id,
            "path": path,
            "title": title,
            "album": album,
            "artist": artist,
            "embedding": embeddings.tolist(),
            "top_5_genres": top_5_genres
        }
        for id, path, title, album, artist, embeddings, top_5_genres in zip(
            id_batch, path_batch, title_batch, album_batch, artist_batch, embeddings_batch, top_5_genres_batch,
        )
    ]
    try:
        collection.insert(documents)
        print(f"Inserted {i + batch_size} records into collection {collection.name}.")
    except Exception as e:
        fails.append((i + batch_size, documents))
        print(f"Error inserting batch {i + batch_size} into collection {collection.name}. Error: {str(e)}")

# 2m30s


Inserted 18 records into collection embeddings_512.
Inserted 36 records into collection embeddings_512.
Inserted 54 records into collection embeddings_512.
Inserted 72 records into collection embeddings_512.
Inserted 90 records into collection embeddings_512.
Inserted 108 records into collection embeddings_512.
Inserted 126 records into collection embeddings_512.
Inserted 144 records into collection embeddings_512.
Inserted 162 records into collection embeddings_512.
Inserted 180 records into collection embeddings_512.


In [24]:
print(f"Failed batches: {fails}")

Failed batches: []


In [25]:
index = {
    "index_type": "IVF_FLAT",
    "metric_type": "L2",
    "params": {"nlist": 128},
}
collection.create_index(field_name="embedding", index_params=index)

Status(code=0, message=)

In [26]:
collection.load()

In [27]:
random_song = pick_random_mp3(DATASET).with_suffix(".pkl")
print(f"Random mp3: {random_song}")
print(get_top_5_genres(random_song, "mtg_jamendo_genre.json"))

with open(random_song, "rb") as f:
    file_info = pickle.load(f)
    query_embed = file_info.get("embedding_512")
    
search_params = {"metric_type": "L2", "params": {"nprobe": 10}}
result = collection.search([query_embed], "embedding", search_params, limit=10, offset=1, output_fields=["path"])
print('\n')
for element in result[0]:
    print(element)

Random mp3: Music_Dataset/Damian_Marley/11_-_Me_Name_Jr._Gong.pkl
['reggae', 'dub', 'hiphop', 'world', 'electronic']


id: 69, distance: 24.275529861450195, entity: {'path': 'Music_Dataset/Damian_Marley/06_-_Kingston_12.mp3'}
id: 71, distance: 24.611507415771484, entity: {'path': 'Music_Dataset/Damian_Marley/03_-_10,000_Chariots.mp3'}
id: 138, distance: 29.747173309326172, entity: {'path': 'Music_Dataset/Groundation/01_fight_all_you_can.mp3'}
id: 37, distance: 29.939144134521484, entity: {'path': 'Music_Dataset/Tanya_Stephens/Tanya Stephens - These_Streets.mp3'}
id: 93, distance: 30.485294342041016, entity: {'path': 'Music_Dataset/Dub_Incorporation/04 - Rudeboy.mp3'}
id: 22, distance: 30.70628547668457, entity: {'path': 'Music_Dataset/Nneka/nneka - africans.mp3'}
id: 113, distance: 30.7811279296875, entity: {'path': 'Music_Dataset/Winston_McAnuff/06_Sort_Me_Out.mp3'}
id: 34, distance: 31.458003997802734, entity: {'path': 'Music_Dataset/Ben_Harper/01_-_Ben_harper_-_With_my_two_hands.mp3