In [None]:
import json
import numpy as np
from tqdm.notebook import tqdm
import chromadb
import statistics

In [None]:
# File paths
dir_path = "data/hnm/"
vectors_path = dir_path + "vectors.npy"
payloads_path = dir_path + "payloads.jsonl"
tests_path = dir_path + "tests.jsonl"


def _load_jsonl(path):
    """
    Read a JSON Lines file into a python list, one parsed entry per line.

    The file is opened as UTF-8 explicitly so the result does not depend on
    the platform's default encoding, and blank lines are skipped instead of
    being handed to json.loads (which raises on an empty string).

    :param path: Path of the .jsonl file to read
    :return: A list with one decoded JSON value per non-blank line
    """
    with open(path, 'r', encoding='utf-8') as file:
        return [json.loads(line) for line in file if line.strip()]


# Load vectors as numpy array
vectors = np.load(vectors_path)

# Load payloads.jsonl as python list
payloads = _load_jsonl(payloads_path)

# Load tests.jsonl as python list
tests = _load_jsonl(tests_path)

# Shown as the cell output: a quick size check of everything that was loaded
(vectors.shape, len(payloads), len(tests))


In [None]:
# Check whether all filter conditions have same format
def check_conditions_format(queries):
    """
    Tell whether every query's filter follows the single expected layout:

        {'and': [{<field>: {'match': {'value': <value>}}}, ...]}

    Anything that deviates from that layout (a missing 'conditions' entry,
    a different top-level key, a clause with several fields, a non-'match'
    operator, or a wrongly typed piece) is reported as False rather than
    raising.

    :param queries: An iterable of query dicts; the filter is read from each
                    query's 'conditions' key
    :return: True if every query matches the layout, False as soon as one
             does not (an empty iterable gives True)
    """
    for query in queries:
        conditions = query.get('conditions', {})
        # Check the outermost keys
        if not isinstance(conditions, dict) or list(conditions.keys()) != ['and']:
            return False
        clauses = conditions['and']
        if not isinstance(clauses, (list, tuple)):
            return False
        # Check the inner structure
        for condition in clauses:
            # The type test comes first so len() is never applied to a value
            # that has no length
            if not isinstance(condition, dict) or len(condition) != 1:
                return False
            for value in condition.values():
                if not isinstance(value, dict) or list(value.keys()) != ['match']:
                    return False
                match = value['match']
                if not isinstance(match, dict) or list(match.keys()) != ['value']:
                    return False
    return True

# Run the format check over every loaded test query; the boolean result is
# shown as the cell output.
check_conditions_format(tests)

In [None]:
def preprocess_payloads(payloads):
    """
    Swap every None value found in the payloads for the string 'None'.

    The payload dicts are modified in place; nothing is returned.

    :param payloads: A list of payload dicts
    :return: None
    """
    for record in payloads:
        # Gather the affected keys first, then overwrite them one by one
        empty_fields = [field for field, content in record.items() if content is None]
        for field in empty_fields:
            record[field] = 'None'

# Rewrites the loaded payload dicts in place (None values become 'None');
# the call returns nothing, so the cell shows no output.
preprocess_payloads(payloads)

In [None]:
# Configuration from SISAP 2023 Indexing Challenge - LMI except n_categories
# Every setting is stored as a string; each one is written here as the
# python value it spells out and turned into text with repr().
index_configuraiton = {
    "lmi:epochs": repr([200]),
    "lmi:model_types": repr(['MLP-4']),
    "lmi:lrs": repr([0.01]),
    "lmi:n_categories": repr([20]),
    "lmi:kmeans": repr({'verbose': False, 'seed': 2023, 'min_points_per_centroid': 1000}),
}

In [None]:
# Name under which the collection is registered with the client
collection_name = "synthetic_collection"

# Client created with default arguments; the index settings defined earlier
# are attached to the new collection as its metadata.
client = chromadb.Client()
collection = client.create_collection(name=collection_name, metadata=index_configuraiton)

In [None]:
# Insert the vectors together with their payloads, batch_size rows per call.
# A row's position in `vectors` (as a string) is used as its document id.
batch_size = 40000
dataset_size = vectors.shape[0]
batch_starts = range(0, dataset_size, batch_size)
for start in tqdm(batch_starts, desc="Adding documents"):
    end = start + batch_size
    # The last batch may be shorter, so the id range is capped at dataset_size
    batch_ids = range(start, min(end, dataset_size))
    collection.add(
        embeddings=vectors[start:end].tolist(),
        metadatas=payloads[start:end],
        ids=[str(row) for row in batch_ids],
    )

In [None]:
%%time
# Build the collection's index and keep whatever build_index() returns;
# %%time reports how long the cell took.
# NOTE(review): judging by the variable name the return value describes the
# bucket each document was assigned to — confirm against build_index().
bucket_assignment = collection.build_index()

In [None]:
def convert_condition_to_simple_dict(condition):
    """
    Translate a test-file filter into the dict passed as `where` to the query.

    The input has the layout {'and': [{<field>: {'match': {'value': <value>}}}, ...]}.
    Every clause is converted, so no part of the filter is dropped:

    * one clause      -> {<field>: <value>}
    * several clauses -> {'$and': [{<field>: <value>}, ...]}
    * no clauses      -> None (nothing to filter on)

    NOTE(review): the multi-clause form assumes the collection's `where`
    argument understands Chroma's '$and' operator — confirm for the build in use.

    :param condition: The 'conditions' entry of a test query
    :return: The converted filter, or None when the 'and' list is empty
    """
    simple_filters = []
    for clause in condition['and']:
        # A clause is expected to hold exactly one field; only its first key is read
        key = list(clause.keys())[0]
        simple_filters.append({key: clause[key]['match']['value']})

    if not simple_filters:
        return None
    if len(simple_filters) == 1:
        return simple_filters[0]
    return {'$and': simple_filters}

def calculate_precision(relevant_ids, retrieved_ids):
    """
    Fraction of the retrieved ids that are also listed as relevant.

    Retrieved ids are converted with int() before comparing; duplicates on
    either side are counted once.

    :param relevant_ids: Ids regarded as correct answers
    :param retrieved_ids: Ids returned by the query (anything int() accepts)
    :return: hits / number of distinct retrieved ids, or 0 when nothing was retrieved
    """
    retrieved = {int(doc_id) for doc_id in retrieved_ids}
    relevant = set(relevant_ids)
    if not retrieved:
        return 0
    hits = retrieved.intersection(relevant)
    return len(hits) / len(retrieved)

In [None]:
%%time
queries_evaluated = []

for query_id in [ 32]:
    print("query_id", query_id)
    test_query_object = tests[query_id]
    results = collection.query(
        query_embeddings=test_query_object["query"],
        include=["metadatas",  'distances'],
        where=convert_condition_to_simple_dict(test_query_object["conditions"]),
        n_results=25,
        n_buckets=2,
        constraint_weight=0.0,
    )
    queries_evaluated.append(calculate_precision(tests[query_id]["closest_ids"], results['ids'][0]))

In [None]:
# Precision values collected by the evaluation loop, in evaluation order
print(queries_evaluated)

In [None]:
# List the evaluated queries whose precision is exactly zero.
# The reported numbers are positions within queries_evaluated.
print(queries_evaluated)
indexes = [
    position
    for position, precision_value in enumerate(queries_evaluated)
    if precision_value == 0.0
]
print("indexes with zero precision: ", indexes)

In [None]:
# Summary over all evaluated queries: arithmetic mean, then median
average_precision = sum(queries_evaluated) / len(queries_evaluated)
print("Average precision: ", average_precision)
median_precision = statistics.median(queries_evaluated)
print("Median precision: ", median_precision)