In [None]:
# Re-import edited modules automatically before each cell executes, so changes to
# library source code are picked up without restarting the kernel.
%reload_ext autoreload
%autoreload 2

In [None]:
import chromadb
import pandas as pd
from chromadb.config import Settings
import matplotlib.pyplot as plt
import seaborn as sns

In [None]:
# Load the synthetic 2-D point set. Later cells rely on the columns 'id' (used as Chroma
# document ids), 'x'/'y' (used as the embedding) and 'cluster' (a matplotlib colour name
# that serves both as label and as plot colour).
csv_file_path = 'data/synthetic_clusters_colored.csv'
data = pd.read_csv(csv_file_path)

# Fail fast with a clear message rather than a KeyError or a duplicate-id error later on.
missing_columns = {'id', 'x', 'y', 'cluster'} - set(data.columns)
assert not missing_columns, f"{csv_file_path} is missing columns: {sorted(missing_columns)}"
assert data['id'].is_unique, "ids must be unique to be added to a Chroma collection"

In [None]:
# Use the first point of the `query_color` cluster as the query embedding (x, y).
query_color = "purple"
query_candidates = data.loc[data['cluster'] == query_color, ['x', 'y']]
if query_candidates.empty:
    # .iloc[0] on an empty frame would raise a cryptic "out-of-bounds" IndexError.
    raise ValueError(f"No points with cluster == {query_color!r} in the dataset")
query_point = query_candidates.iloc[0]

In [None]:
# Scatter each cluster in its own colour (cluster names are colour names) and mark the
# query point on top of them.
fig, ax = plt.subplots(figsize=(8, 6))
for color in data['cluster'].unique().tolist():
    cluster_data = data[data['cluster'] == color]
    ax.scatter(cluster_data['x'], cluster_data['y'], color=color, label=f'{color.capitalize()} Cluster')

# A star marker with a black outline keeps the query point visible even when it sits
# inside a dense cluster; an edge colour equal to the fill would not stand out.
ax.scatter(query_point['x'], query_point['y'], color='darkorange', edgecolor='black',
           marker='*', s=200, label='Query', zorder=5)

ax.set(title='Visualization of Synthetic Clusters', xlabel='x', ylabel='y')
ax.legend()
plt.show()

In [None]:
# Start from an empty database: the client may still hold collections from an earlier
# run of this notebook, and create_collection() below fails if its name already exists.
client = chromadb.Client()

for existing in client.list_collections():
    # Depending on the chromadb version, list_collections() yields Collection objects
    # or plain collection names; delete_collection() takes the name.
    client.delete_collection(getattr(existing, "name", existing))

In [None]:
# Index settings passed to the collection as metadata under "lmi:" keys (presumably
# read by the library's LMI index when build_index() runs — confirm against the library).
# Chroma metadata values must be scalars, which is why list/dict settings are encoded
# as strings here.
index_configuration = {
    "lmi:epochs": "[2000]",
    "lmi:model_types": "['MLP']",
    "lmi:lrs": "[0.01]",
    "lmi:n_categories": "[2, 2]",
    "lmi:kmeans": "{'verbose': False, 'seed': 2023, 'nredo': 10}",
}
# Backward-compatible alias for the previous (misspelled) variable name.
index_configuraiton = index_configuration

collection_name = "synthetic_collection"
collection = client.create_collection(
    name=collection_name,
    metadata=index_configuration,
)

In [None]:
# Insert every point: the (x, y) coordinates are the embedding, the colour label is stored
# as metadata (so queries can filter on it with `where`), and the 'id' column supplies
# the document ids.
collection.add(
    embeddings=data[['x', 'y']].values.tolist(),
    metadatas=[{"cluster": cluster} for cluster in data['cluster']],
    ids=data['id'].values.tolist(),
)

# Build the index over the inserted points. The returned object is used later as a
# mapping from point id to its bucket (looked up with .get(id)).
bucket_assignment = collection.build_index()

### Visualize buckets

In [None]:
# Attach each point's bucket (from build_index) to the frame; ids missing from the
# assignment get an empty bucket.
data['bucket'] = data['id'].map(lambda point_id: list(bucket_assignment.get(point_id, [])))

# Lists are unhashable, so group on their string form. The ordered-bucket plots further
# down build the same labels with str(list(bucket)) so they match this index.
data['bucket_str'] = data['bucket'].apply(str)
bucket_cluster_counts = data.groupby(['bucket_str', 'cluster']).size().unstack(fill_value=0)

# Long format for seaborn: one row per (bucket, cluster) pair with its count.
plot_data = bucket_cluster_counts.reset_index().melt(id_vars='bucket_str', var_name='cluster', value_name='count')

# Cluster names are matplotlib colour names, so each cluster is drawn in its own colour.
palette = {cluster: cluster for cluster in data['cluster'].unique()}

fig, ax = plt.subplots(figsize=(10, 6))
sns.barplot(data=plot_data, x='bucket_str', y='count', hue='cluster', palette=palette, ax=ax)
ax.set(title='Number of Items from Each Cluster in Each Bucket', xlabel='Bucket', ylabel='Count')

# Label every non-empty bar with its count; NaN heights are skipped because int() cannot
# convert them.
for p in ax.patches:
    bar_height = p.get_height()
    if pd.notna(bar_height) and bar_height > 0:
        ax.annotate(f'{int(bar_height)}', (p.get_x() + p.get_width() / 2., bar_height),
                    ha='center', va='center', fontsize=10, color='black', xytext=(0, 5),
                    textcoords='offset points')

plt.show()

In [None]:
# Regression baseline: the ids (in the order returned) for the red-filtered query issued
# from the purple query point. The checks after each query compare against this.
expected_output = [[
    'id70', 'id74', 'id10', 'id86', 'id29',
    'id76', 'id92', 'id3', 'id23', 'id47',
]]

In [None]:
# Query 1: the 10 nearest points from the "red" cluster to the query point, searching a
# single bucket, with constraint_weight=0.0. filter_color is reused by later cells.
filter_color = "red"
results = collection.query(
    query_embeddings=list(query_point),
    where={"cluster": filter_color},
    include=["metadatas", "embeddings", "distances"],
    n_results=10,
    n_buckets=1,
    constraint_weight=0.0,
)

In [None]:
# Show what the query returned before checking it, so a mismatch can actually be diagnosed.
print(results['ids'])
print(results['distances'])
print(results['metadatas'])
print(results['bucket_order'])
assert results['ids'] == expected_output, (
    f"unexpected ids: got {results['ids']}, expected {expected_output}"
)

In [None]:
# Bucket order reported for the first (only) query. Labels are built as str(list(bucket)),
# the same way 'bucket_str' was built, so they match bucket_cluster_counts' index.
bucket_order = results['bucket_order'][0]
order_labels = [str(list(bucket)) for bucket in bucket_order]

# Number of filter_color points per bucket, in that order (0 for buckets holding none).
cluster_specific_counts = bucket_cluster_counts[filter_color]
ordered_counts = cluster_specific_counts.reindex(order_labels, fill_value=0).reset_index()
ordered_counts.columns = ['bucket', 'count']

fig, ax = plt.subplots(figsize=(10, 6))
# An explicit order keeps the x axis in bucket_order instead of relying on seaborn's inference.
sns.barplot(data=ordered_counts, x='bucket', y='count', color=filter_color, order=order_labels, ax=ax)
ax.set(title=f'Number of Items from {filter_color} Cluster in Ordered Buckets', xlabel='Bucket', ylabel='Count')
plt.show()

In [None]:
# Query 2: the same red-filtered search (filter_color from the first query cell), but with
# constraint_weight=0.51 and use_threshold disabled.
results = collection.query(
    query_embeddings=list(query_point),
    where={"cluster": filter_color},
    include=["metadatas", "embeddings", "distances"],
    n_results=10,
    n_buckets=1,
    constraint_weight=0.51,
    use_threshold=False,  # TODO: this causes problems, hotfix applied for now
)

In [None]:
# Show what the query returned before checking it, so a mismatch can actually be diagnosed.
print(results['ids'])
print(results['distances'])
print(results['metadatas'])
print(results['bucket_order'])
assert results['ids'] == expected_output, (
    f"unexpected ids: got {results['ids']}, expected {expected_output}"
)

In [None]:
# Bucket order reported for the first (only) query. Labels are built as str(list(bucket)),
# the same way 'bucket_str' was built, so they match bucket_cluster_counts' index.
bucket_order = results['bucket_order'][0]
order_labels = [str(list(bucket)) for bucket in bucket_order]

# Number of filter_color points per bucket, in that order (0 for buckets holding none).
cluster_specific_counts = bucket_cluster_counts[filter_color]
ordered_counts = cluster_specific_counts.reindex(order_labels, fill_value=0).reset_index()
ordered_counts.columns = ['bucket', 'count']

fig, ax = plt.subplots(figsize=(10, 6))
# An explicit order keeps the x axis in bucket_order instead of relying on seaborn's inference.
sns.barplot(data=ordered_counts, x='bucket', y='count', color=filter_color, order=order_labels, ax=ax)
ax.set(title=f'Number of Items from {filter_color} Cluster in Ordered Buckets', xlabel='Bucket', ylabel='Count')
plt.show()