In [None]:
from sentence_transformers import SentenceTransformer
from sklearn.cluster import AgglomerativeClustering
import numpy as np
import pandas as pd
import umap
import matplotlib.pyplot as plt

In [None]:
# Instantiate the sentence-embedding model by its checkpoint name; later cells
# use `embedder` to turn query strings into vectors.
model_name = 'all-MiniLM-L6-v2'
embedder = SentenceTransformer(model_name)

In [None]:
# Read the raw query log and collapse duplicate query strings into a single
# row each: 'weight' is summed and 'count' is averaged across the duplicates.
# With as_index=False the result has a fresh 0..n-1 index and the columns
# query, weight, count (in that order).
corpus = pd.read_csv("queries-prod.csv")
cleaned_corpus = corpus.groupby(['query'], as_index=False).agg({'weight': 'sum', 'count': 'mean'})
print(cleaned_corpus)

# encode() is documented to take a string or a list of strings. Hand it a
# plain list rather than a pandas Series so the call does not depend on how
# the Series happens to be indexed. Later cells look rows of cleaned_corpus up
# by embedding position, so the order here must match the frame's row order.
corpus_embeddings = embedder.encode(cleaned_corpus["query"].tolist())

In [None]:
# Rescale every embedding row to unit length. keepdims=True keeps the norms as
# a column so the division broadcasts row-wise.
row_norms = np.linalg.norm(corpus_embeddings, axis=1, keepdims=True)
corpus_embeddings = corpus_embeddings / row_norms

In [None]:
# Cluster the normalized embeddings hierarchically. No cluster count is fixed
# up front (n_clusters=None); merging stops at the distance threshold instead.
# Linkage and metric are left at the library defaults. An alternative that was
# tried and left disabled: affinity='cosine', linkage='average',
# distance_threshold=0.4.
clustering_model = AgglomerativeClustering(
    n_clusters=None,
    distance_threshold=1.5,
)

# fit() returns the fitted estimator; labels_ holds one cluster id per row.
cluster_assignment = clustering_model.fit(corpus_embeddings).labels_

In [None]:
# Bucket every aggregated row of cleaned_corpus under the cluster id it was
# assigned. Position in cluster_assignment is used as the row position in
# cleaned_corpus.
clustered_sentences = {}
for row_pos, label in enumerate(cluster_assignment):
    row_values = cleaned_corpus.iloc[row_pos].to_list()
    clustered_sentences.setdefault(label, []).append(row_values)

separator = "=============================================="

print("Number of clusters ", len(clustered_sentences))
print(separator)
for label, members in clustered_sentences.items():
    # Each member is [query, weight, count] given the aggregation that builds
    # cleaned_corpus, so row[1] is the summed weight and row[2] the mean count.
    # NOTE(review): the reported "size" adds weight and count together for
    # every member -- confirm this is the intended metric.
    cluster_size = 0
    for row in members:
        cluster_size += row[2] + row[1]
    # Cluster ids are shown 1-based (label + 1).
    print(f"Cluster {label+1} size {cluster_size}")
    print(members)
    print("")
    print(separator)

In [None]:
# Project the embeddings down to 2-D for plotting. UMAP is stochastic, so the
# seed is pinned to keep the layout reproducible across Restart & Run All.
umap_data = umap.UMAP(
    n_neighbors=15,
    n_components=2,
    min_dist=0.0,
    metric='cosine',
    random_state=42,
).fit_transform(corpus_embeddings)
result = pd.DataFrame(umap_data, columns=['x', 'y'])
result['labels'] = clustering_model.labels_

# Visualize clusters
# Rows labelled -1 are treated as outliers and drawn in grey; everything else
# is coloured by cluster label.
# NOTE(review): agglomerative clustering is not expected to emit -1 labels, so
# `outliers` is probably always empty here -- kept in case a noise-producing
# clusterer is swapped in.
outliers = result.loc[result['labels'] == -1, :]
clustered = result.loc[result['labels'] != -1, :]

# Draw through the explicit Axes and pass the scatter handle to the colorbar
# so the figure does not depend on pyplot's implicit "current" state.
fig, ax = plt.subplots(figsize=(20, 10))
ax.scatter(outliers['x'], outliers['y'], color='#BDBDBD', s=1)
cluster_points = ax.scatter(clustered['x'], clustered['y'], c=clustered['labels'], s=1, cmap='hsv_r')
fig.colorbar(cluster_points, ax=ax, label='cluster label')
ax.set_title('UMAP projection of query embeddings, coloured by cluster')
ax.set_xlabel('UMAP dimension 1')
ax.set_ylabel('UMAP dimension 2')
plt.show()