In [None]:
import pickle
import random
from pathlib import Path
import json
import music_tag
import numpy as np
import os
from utils.music_utils import *

from pymilvus import connections, utility
from pymilvus import Collection, DataType, FieldSchema, CollectionSchema

In [None]:
DATASET = Path('MegaSet')

# Tally audio files and pickled feature files anywhere under DATASET.
count_mp3 = count_pkl = 0
for _dirpath, _dirnames, filenames in os.walk(DATASET):
    count_mp3 += sum(1 for name in filenames if name.endswith('.mp3'))
    count_pkl += sum(1 for name in filenames if name.endswith('.pkl'))

print(f"founds {count_mp3} mp3 files and {count_pkl} pkl files")

In [None]:
# Inventory which Python types are stored under each vector key. The loaders
# further down keep only np.ndarray values, so this shows what they will accept.
pkl_files = list(DATASET.glob('**/*.pkl'))

types87 = []
types512 = []

for pkl_path in pkl_files:
    try:
        # NOTE(review): pickle.load can execute arbitrary code; only run this on
        # feature files you generated yourself.
        with pkl_path.open('rb') as f:
            content = pickle.load(f)
        # Read each key separately so each report describes its own field.
        pred87 = content.get('predictions_87')
        embed512 = content.get('embedding_512')
        if pred87 is not None:
            types87.append(type(pred87))
        if embed512 is not None:
            types512.append(type(embed512))
    except Exception as e:
        # Report and keep scanning: one unreadable file should not abort the inventory.
        print(f"Error reading file {pkl_path}: {e}")

unique_types87 = set(types87)
unique_types512 = set(types512)

print(f"Unique types 87: {unique_types87}")
print(f"Unique types 512: {unique_types512}")

In [None]:
# Spot-check the dataset: pick one random .mp3 under DATASET and pass it to print_info.
# NOTE(review): both helpers presumably come from the `utils.music_utils` star import.
print_info(pick_random_mp3(DATASET))

In [None]:
########## Milvus Client ##########

In [None]:
import os
from dotenv import load_dotenv

# Pull settings from a local .env file (if present) into the process environment.
load_dotenv()

# Milvus endpoint and credential; either is None when the variable is unset.
URI, TOKEN = os.getenv("MILVUS_URI"), os.getenv("MILVUS_TOKEN")

In [None]:
# connect to milvus under the "default" alias, which pymilvus Collection/utility
# calls use unless another alias is given.
if not URI:
    raise ValueError("MILVUS_URI is not set; add it to the environment or a .env file.")
# Report the target before connecting so it is visible even if the connection fails or hangs.
print(f"Connecting to DB: {URI}")
connections.connect("default",
                    uri=URI,
                    token=TOKEN)

In [None]:
########## predictions_87 ##########

In [None]:
# Show which collections already exist on the server before the drop/create cells below.
print(utility.list_collections())

In [None]:
# Reset: drop any existing predictions_87 collection so the next cell can create it from scratch.
check_collection_name = 'predictions_87'

if check_collection := utility.has_collection(check_collection_name):
    drop_result = utility.drop_collection(check_collection_name)

print(utility.list_collections())

In [None]:
# create a collection for prediction_87
collection_name = "predictions_87"
# Must equal the length of every predictions_87 vector inserted below.
dimension = 87

# Define the schema. `max_length` only applies to VARCHAR/ARRAY fields, so it is
# not passed for the INT64 primary key (pymilvus ignores it for scalar types).
schema = CollectionSchema(
    fields=[
        # Primary key is the enumerate() index assigned by the loader (auto_id=False).
        FieldSchema(name="id", dtype=DataType.INT64, is_primary=True, auto_id=False),
        FieldSchema(name="path", dtype=DataType.VARCHAR, max_length=280),
        FieldSchema(name="title", dtype=DataType.VARCHAR, max_length=220),
        FieldSchema(name="artist", dtype=DataType.VARCHAR, max_length=120),
        FieldSchema(name="album", dtype=DataType.VARCHAR, max_length=240),
        FieldSchema(name="predictions", dtype=DataType.FLOAT_VECTOR, dim=dimension),
        FieldSchema(name="top_5_genres", dtype=DataType.ARRAY, element_type=DataType.VARCHAR, max_capacity=5, max_length=100),
        ],
    enable_dynamic_field=True,  # Optional, defaults to 'False'.
)

print(f"Creating example collection: {collection_name}")
collection = Collection(collection_name, schema)
print(f"Schema: {schema}")
print("Success!")

In [None]:
import logging
from multiprocessing import Pool

# Send loader errors to app.log (truncated on every run). force=True replaces any
# handlers already attached to the root logger (a Jupyter kernel may install one);
# without it basicConfig silently does nothing in that case.
logging.basicConfig(filename='app.log', filemode='w', format='%(name)s - %(levelname)s - %(message)s', force=True)


def process_file(args, key="predictions_87"):
    """Load one feature pickle and turn it into a row for Milvus insertion.

    Args:
        args: ``(i, path)`` tuple as produced by ``enumerate(pkl_files)``;
            ``i`` becomes the row's primary key.
        key: pickle key holding the vector to index. Defaults to
            ``"predictions_87"``; bind another key (e.g. with
            ``functools.partial``) to reuse this loader for other vectors.

    Returns:
        ``(i, filepath, title, album, artist, vector, top_5_genres)`` when the
        value under ``key`` is an ``np.ndarray``, otherwise ``None``. Files that
        fail to load are logged via ``logging.error`` and also yield ``None``.
    """
    i, path = args
    try:
        # NOTE(review): pickle.load can execute arbitrary code; only use on trusted files.
        with open(path, "rb") as f:
            data = pickle.load(f)
        vector = data.get(key)
        if not isinstance(vector, np.ndarray):
            # No usable vector: the caller filters these rows out.
            return None
        return (
            i,
            data.get("filepath"),
            data.get("title"),
            data.get("album"),
            data.get("artist"),
            vector,
            data.get("top_5_genres"),
        )
    except Exception as e:
        logging.error(f"Error processing file {path}: {str(e)}")
        return None

In [None]:
# Use a multiprocessing pool to process the files in parallel.
# NOTE(review): workers must find `process_file` in __main__; this relies on the
# "fork" start method (Linux default) and may fail under "spawn" (macOS/Windows) — confirm.
with Pool() as p:
    results = p.map(process_file, enumerate(pkl_files))

# Filter out None results (unreadable files or files without a usable vector).
results = [r for r in results if r is not None]

# Rows sent per insert() call.
batch_size = 18

# One (batch_end_offset, documents) entry per batch whose insert raised.
fails = []

# Insert the data into the collection
for i in range(0, len(results), batch_size):
    # Clamp so the last, shorter batch doesn't over-report the running total.
    batch_end = min(i + batch_size, len(results))
    documents = [
        {
            "id": row_id,
            # The schema's VARCHAR/ARRAY fields are not nullable: store missing
            # metadata as empty values instead of failing the whole batch.
            "path": path or "",
            "title": title or "",
            "album": album or "",
            "artist": artist or "",
            "predictions": predictions.tolist(),
            "top_5_genres": top_5_genres or [],
        }
        for row_id, path, title, album, artist, predictions, top_5_genres in results[i:batch_end]
    ]
    try:
        collection.insert(documents)
        print(f"Inserted {batch_end} records into collection {collection.name}.")
    except Exception as e:
        fails.append((batch_end, documents))
        print(f"Error inserting batch {batch_end} into collection {collection.name}. Error: {str(e)}")

# 3m3s (presumably the observed runtime of this cell)

In [None]:
# Summarize failures instead of dumping every document (each carries a full vector).
print(f"Failed batches: {len(fails)}")
for batch_end, failed_docs in fails:
    print(f"  batch ending at offset {batch_end}: {[doc['path'] for doc in failed_docs]}")

In [None]:
# IVF_FLAT index over the predictions vectors using L2 distance, with nlist=128.
index = dict(
    index_type="IVF_FLAT",
    metric_type="L2",
    params=dict(nlist=128),
)
collection.create_index(field_name="predictions", index_params=index)

In [None]:
# Load the collection so the search in the next cell can query it.
collection.load()

In [None]:
# Query with a random track's own predictions_87 vector and list its nearest neighbours.
random_song = pick_random_mp3(DATASET).with_suffix(".pkl")
print(f"Random mp3: {random_song}")
print(get_top_5_genres(random_song, "utils/mtg_jamendo_genre.json"))

with open(random_song, "rb") as f:
    file_info = pickle.load(f)
query_embed = file_info.get("predictions_87")
if query_embed is None:
    raise KeyError(f"{random_song} has no 'predictions_87' entry to query with")
# A plain list avoids depending on how pymilvus handles ndarray search data.
if isinstance(query_embed, np.ndarray):
    query_embed = query_embed.tolist()

search_params = {"metric_type": "L2", "params": {"nprobe": 10}}
# offset=1 skips the top hit, presumably the query track itself — assumes it was inserted.
result = collection.search([query_embed], "predictions", search_params, limit=10, offset=1, output_fields=["path"])
print('\n')
for element in result[0]:
    print(element)

In [None]:
########## embedding_512 ##########

In [None]:
# Show which collections exist on the server before the embeddings_512 drop/create cells below.
print(utility.list_collections())

In [None]:
# Reset: drop any existing embeddings_512 collection so the next cell can create it from scratch.
check_collection_name = 'embeddings_512'

if check_collection := utility.has_collection(check_collection_name):
    drop_result = utility.drop_collection(check_collection_name)

print(utility.list_collections())

In [None]:
# create a collection for embeddings_512
collection_name = "embeddings_512"
# Must equal the length of every embedding_512 vector inserted below.
dimension = 512

# Define the schema. `max_length` only applies to VARCHAR/ARRAY fields, so it is
# not passed for the INT64 primary key (pymilvus ignores it for scalar types).
schema = CollectionSchema(
    fields=[
        # Primary key is the enumerate() index assigned by the loader (auto_id=False).
        FieldSchema(name="id", dtype=DataType.INT64, is_primary=True, auto_id=False),
        FieldSchema(name="path", dtype=DataType.VARCHAR, max_length=280),
        FieldSchema(name="title", dtype=DataType.VARCHAR, max_length=220),
        FieldSchema(name="album", dtype=DataType.VARCHAR, max_length=240),
        FieldSchema(name="artist", dtype=DataType.VARCHAR, max_length=120),
        FieldSchema(name="embedding", dtype=DataType.FLOAT_VECTOR, dim=dimension),
        FieldSchema(name="top_5_genres", dtype=DataType.ARRAY, element_type=DataType.VARCHAR, max_capacity=5, max_length=100),
    ],
    enable_dynamic_field=True,  # Optional, defaults to 'False'.
)

print(f"Creating example collection: {collection_name}")
collection = Collection(collection_name, schema)
print(f"Schema: {schema}")
print("Success!")

In [None]:
def process_file(args, key="embedding_512"):
    """Load one feature pickle and turn it into a row for Milvus insertion.

    This redefinition shadows the earlier loader so the pool below reads
    ``embedding_512`` by default.

    Args:
        args: ``(i, path)`` tuple as produced by ``enumerate(pkl_files)``;
            ``i`` becomes the row's primary key.
        key: pickle key holding the vector to index (default ``"embedding_512"``).

    Returns:
        ``(i, filepath, title, album, artist, vector, top_5_genres)`` when the
        value under ``key`` is an ``np.ndarray``, otherwise ``None``. Files that
        fail to load are logged via ``logging.error`` and also yield ``None``.
    """
    i, path = args
    try:
        # NOTE(review): pickle.load can execute arbitrary code; only use on trusted files.
        with open(path, "rb") as f:
            data = pickle.load(f)
        embedding = data.get(key)
        if not isinstance(embedding, np.ndarray):
            # No usable vector: the caller filters these rows out.
            return None
        return (
            i,
            data.get("filepath"),
            data.get("title"),
            data.get("album"),
            data.get("artist"),
            embedding,
            data.get("top_5_genres"),
        )
    except Exception as e:
        logging.error(f"Error processing file {path}: {str(e)}")
        return None

In [None]:
# Use a multiprocessing pool to process the files in parallel.
# NOTE(review): workers must find `process_file` in __main__; this relies on the
# "fork" start method (Linux default) and may fail under "spawn" (macOS/Windows) — confirm.
with Pool() as p:
    results = p.map(process_file, enumerate(pkl_files))

# Filter out None results (unreadable files or files without a usable vector).
results = [r for r in results if r is not None]

# Rows sent per insert() call.
batch_size = 18

# One (batch_end_offset, documents) entry per batch whose insert raised.
fails = []

# Insert the data into the collection
for i in range(0, len(results), batch_size):
    # Clamp so the last, shorter batch doesn't over-report the running total.
    batch_end = min(i + batch_size, len(results))
    documents = [
        {
            "id": row_id,
            # The schema's VARCHAR/ARRAY fields are not nullable: store missing
            # metadata as empty values instead of failing the whole batch.
            "path": path or "",
            "title": title or "",
            "album": album or "",
            "artist": artist or "",
            "embedding": embedding.tolist(),
            "top_5_genres": top_5_genres or [],
        }
        for row_id, path, title, album, artist, embedding, top_5_genres in results[i:batch_end]
    ]
    try:
        collection.insert(documents)
        print(f"Inserted {batch_end} records into collection {collection.name}.")
    except Exception as e:
        fails.append((batch_end, documents))
        print(f"Error inserting batch {batch_end} into collection {collection.name}. Error: {str(e)}")

# 3m 21s (presumably the observed runtime of this cell)

In [None]:
# Summarize failures instead of dumping every document (each carries a full vector).
print(f"Failed batches: {len(fails)}")
for batch_end, failed_docs in fails:
    print(f"  batch ending at offset {batch_end}: {[doc['path'] for doc in failed_docs]}")

In [None]:
# IVF_FLAT index over the embedding vectors using L2 distance, with nlist=128.
index = dict(
    index_type="IVF_FLAT",
    metric_type="L2",
    params=dict(nlist=128),
)
collection.create_index(field_name="embedding", index_params=index)

In [None]:
# Load the collection so the search in the next cell can query it.
collection.load()

In [None]:
# Query with a random track's own embedding_512 vector and list its nearest neighbours.
random_song = pick_random_mp3(DATASET).with_suffix(".pkl")
print(f"Random mp3: {random_song}")
print(get_top_5_genres(random_song, "utils/mtg_jamendo_genre.json"))

with open(random_song, "rb") as f:
    file_info = pickle.load(f)
query_embed = file_info.get("embedding_512")
if query_embed is None:
    raise KeyError(f"{random_song} has no 'embedding_512' entry to query with")
# A plain list avoids depending on how pymilvus handles ndarray search data.
if isinstance(query_embed, np.ndarray):
    query_embed = query_embed.tolist()

search_params = {"metric_type": "L2", "params": {"nprobe": 10}}
# offset=1 skips the top hit, presumably the query track itself — assumes it was inserted.
result = collection.search([query_embed], "embedding", search_params, limit=10, offset=1, output_fields=["path"])
print('\n')
for element in result[0]:
    print(element)