In [20]:
import pickle
import json
import os
import logging
from dotenv import load_dotenv
load_dotenv()

import numpy as np
from pathlib import Path
from multiprocessing import Pool

from pymilvus import Collection, DataType, FieldSchema, CollectionSchema, MilvusClient, utility, connections

In [21]:
# Milvus connection settings, read from the process environment
# (which load_dotenv() at the top of the notebook may have populated from a .env file).
URI, TOKEN = (os.environ.get(env_key) for env_key in ("MILVUS_URI", "MILVUS_TOKEN"))

In [22]:
# Single client handle; the schema, collection and insert cells below all go through it.
milvus_client = MilvusClient(uri=URI, token=TOKEN)

In [None]:
# Target collection and the dimensionality of its vector field.
collection_name = "predictions_87"
dimension = 87

# Schema: primary keys are supplied by us (auto_id=False); dynamic fields are enabled.
schema = MilvusClient.create_schema(auto_id=False, enable_dynamic_field=True)
for field_kwargs in (
    dict(field_name="filename", datatype=DataType.INT64, is_primary=True),
    dict(field_name="vector", datatype=DataType.FLOAT_VECTOR, dim=dimension),
):
    schema.add_field(**field_kwargs)

# Indexes: STL_SORT on the INT64 key, AUTOINDEX with inner-product metric on the vector.
index_params = milvus_client.prepare_index_params()
for index_kwargs in (
    dict(field_name="filename", index_type="STL_SORT"),
    dict(field_name="vector", index_type="AUTOINDEX", metric_type="IP"),
):
    index_params.add_index(**index_kwargs)

# Create the collection with the schema and indexes defined above.
milvus_client.create_collection(
    collection_name=collection_name,
    schema=schema,
    index_params=index_params,
)

In [None]:
# Sanity check: the collection created above should appear in this listing.
existing_collections = milvus_client.list_collections()
existing_collections

In [17]:
# Set up logging: records from this notebook (e.g. per-file load errors) go to app.log,
# which is truncated on every run (filemode='w').
# force=True (Python 3.8+) removes any handlers already attached to the root logger.
# Without it, basicConfig is a silent no-op when the kernel has already attached one,
# and nothing would ever reach app.log.
logging.basicConfig(
    filename='app.log',
    filemode='w',
    format='%(name)s - %(levelname)s - %(message)s',
    force=True,
)

def process_file(args):
    """Load one pickled record and extract its primary key and embedding.

    Args:
        args: ``(i, path)`` pair, as produced by ``enumerate(pkl_files)``. ``i`` is
            passed through unchanged so each result can be matched to its input file.

    Returns:
        ``(i, filename, vector)`` when the pickle has a non-None ``"filename"`` and an
        ``"embeddings_87"`` value that is a numpy array. Otherwise ``None``; the reason
        is logged (error for unreadable files, warning for invalid content).
    """
    i, path = args
    try:
        # SECURITY: pickle.load can execute arbitrary code -- only run this on trusted files.
        with open(path, "rb") as f:
            data = pickle.load(f)
        # Assumes each pickle holds a dict-like object; anything else raises and is logged below.
        filename = data.get("filename")
        vector = data.get("embeddings_87")
    except Exception as e:
        logging.error(f"Error processing file {path}: {str(e)}")
        return None

    if filename is None:
        # The value becomes the collection's primary key; a row without one would make
        # its whole insert batch fail, so skip it here instead.
        logging.warning(f"Skipping file {path}: missing 'filename'")
        return None
    if not isinstance(vector, np.ndarray):
        logging.warning(f"Skipping file {path}: 'embeddings_87' is not a numpy array")
        return None
    return (i, filename, vector)

In [None]:
# Dataset root (relative to the working directory); every *.pkl beneath it, at any depth, is one record.
DATASET = Path('MUSIC_DATASET')
pkl_files = [*DATASET.glob('**/*.pkl')]

In [None]:
# Use a multiprocessing pool to load the files in parallel; Pool.map keeps pkl_files order.
# NOTE(review): process_file is defined in this notebook (__main__). Under the "spawn"
# start method (default on macOS/Windows), worker processes cannot import it, so this
# cell presumably relies on "fork" (Linux) -- confirm on the target platform.
with Pool() as p:
    results = p.map(process_file, enumerate(pkl_files))

# process_file returns None for files it could not use; drop them.
results = [r for r in results if r is not None]

# Rows sent per insert request.
batch_size = 18

# One (end_offset, id_batch, vector_batch) entry per batch whose insert raised, kept for retry/inspection.
fails = []
# Rows actually accepted so far (failed batches are not counted).
inserted = 0

# Insert the data into the collection
for i in range(0, len(results), batch_size):
    batch = results[i : i + batch_size]
    id_batch = tuple(r[1] for r in batch)
    vector_batch = tuple(r[2] for r in batch)
    # The last batch may be short, so use the real end offset rather than i + batch_size.
    end = i + len(batch)
    # MilvusClient.insert takes a dict or a list of dicts keyed by field name (the schema
    # defines "filename" and "vector"), not positional vector/id sequences.
    # Numpy arrays are converted to plain Python lists.
    # Assumes each embedding is 1-D with `dimension` entries -- TODO confirm.
    rows = [
        {"filename": fid, "vector": np.asarray(vec).tolist()}
        for fid, vec in zip(id_batch, vector_batch)
    ]
    try:
        milvus_client.insert(collection_name=collection_name, data=rows)
        inserted += len(rows)
        print(f"Inserted {inserted} records into collection {collection_name}.")

    except Exception as e:
        fails.append((end, id_batch, vector_batch))
        print(f"Error inserting batch {end} into collection {collection_name}. Error: {str(e)}")

