In [None]:
import os
import random

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import pyarrow.parquet as pq
from sentence_transformers import SentenceTransformer

In [None]:
# Parameters
# OWNER and MODEL_NAME are joined as "<OWNER>/<MODEL_NAME>" when the
# SentenceTransformer model is loaded further down.
OWNER: str = "sentence-transformers"
MODEL_NAME: str = "all-MiniLM-L6-v2"

# Pooling label; together with MODEL_NAME it selects the embeddings parquet
# file, the embedding column and the names of the output artefacts.
POOLING_STRATEGY: str = "max"

In [None]:
# --- Inputs -----------------------------------------------------------------
# Folder with the cleaned data, and the embeddings parquet file for the
# selected model / pooling strategy inside it.
CLEAN_DATA_PATH = os.path.join("..", "data", "healthhub_small_clean")
CLEANED_EMBEDDING_LIST_PATH = os.path.join(
    CLEAN_DATA_PATH,
    f"healthhub_{MODEL_NAME}_{POOLING_STRATEGY}_embeddings_small_clean.parquet",
)

# --- Outputs ----------------------------------------------------------------
# Both artefacts written by this notebook land in the same folder.
_OUTPUT_DIR = os.path.join("..", "artifacts", "outputs")

# Heatmap image.
OUTPUT_CM_PATH = os.path.join(
    _OUTPUT_DIR, f"{MODEL_NAME}_{POOLING_STRATEGY}_confusion_matrix.png"
)

# Full pairwise similarity matrix as CSV.
OUTPUT_SIM_PATH = os.path.join(
    _OUTPUT_DIR, f"{MODEL_NAME}_{POOLING_STRATEGY}_similarity_score.csv"
)

## Load Embeddings Dataframe

In [None]:
# Read the embeddings parquet file as an Arrow table, then convert it to a
# pandas DataFrame for the merge and similarity steps below.
embeddings_table = pq.read_table(CLEANED_EMBEDDING_LIST_PATH)
embeddings_df = embeddings_table.to_pandas()
# Preview the first rows as a quick sanity check of the load.
embeddings_df.head()

## Load Ground Truth Dataframe

In [None]:
# Spreadsheet holding the ground-truth grouping; only the
# "All Live Healthy" sheet is read.
ground_truth_path = os.path.join(
    "..", "data", "Synapxe Content Prioritisation - Live Healthy_020724.xlsx"
)
ground_df = pd.read_excel(
    ground_truth_path, sheet_name="All Live Healthy", index_col=False
)

# Preview the first rows as a quick sanity check of the load.
ground_df.head()

In [None]:
# Columns taken over from the ground-truth sheet.
ground_truth_cols = ["Combine Group ID", "Page Title"]

# Attach the ground truth to each embedded document (matching the document's
# source against the sheet's URL), keep only documents that have a group id,
# and order them by group so documents of one group sit next to each other.
merge_df = (
    embeddings_df.merge(ground_df, how="left", left_on="doc_source", right_on="URL")
    .loc[:, [*embeddings_df.columns, *ground_truth_cols]]
    .loc[lambda frame: frame["Combine Group ID"].notna()]
    .sort_values(by="Combine Group ID")
    .reset_index(drop=True)
    # Safe to cast only now that rows with a missing group id are gone.
    .astype({"Combine Group ID": int})
)

In [None]:
# Load the model named by the parameters cell; below it is used to score the
# pairwise similarities between the stored embeddings.
model_id = f'{OWNER}/{MODEL_NAME}'
model = SentenceTransformer(model_id)

In [None]:
# Name of the column that stores the embedding of each document.
embedding_col = f"{MODEL_NAME}_{POOLING_STRATEGY}_embeddings"

# One embedding vector per row of merge_df, stacked row-wise into a single
# 2-D array so that row i of the matrix belongs to row i of merge_df.
embeddings_series = merge_df[embedding_col]
row_vectors = embeddings_series.tolist()
embeddings = np.vstack(row_vectors)

print(embeddings.shape)  # (num_documents, embedding_dim)

In [None]:
# Calculate the embedding similarities: the same matrix is passed on both
# sides, so every document is scored against every document (itself included).
# NOTE(review): presumably cosine similarity (the model's configured
# similarity function) - confirm.
similarities = model.similarity(embeddings, embeddings)

print(similarities.shape)  # (num_documents, num_documents)

In [None]:
# Function to darken a hex color
def darken_hex_color(hex_color: str, factor: float = 0.7) -> str:
    """Return a darker shade of ``hex_color`` as an upper-case ``#RRGGBB`` string.

    Each RGB channel is multiplied by ``factor`` and truncated to an integer.

    Args:
        hex_color: Colour written as ``#RRGGBB``. The leading ``#`` is
            optional, surrounding whitespace is ignored, the ``#RGB``
            shorthand is expanded, and an alpha component (``#RGBA`` /
            ``#RRGGBBAA``) is accepted but discarded.
        factor: Brightness multiplier, clamped to ``[0, 1]``. ``1`` leaves the
            colour unchanged, ``0`` yields black.

    Returns:
        The darkened colour formatted as ``#RRGGBB`` in upper case.

    Raises:
        ValueError: If ``hex_color`` is not a hexadecimal colour of 3, 4, 6 or
            8 digits.
    """
    # Ensure factor is between 0 and 1
    factor = max(0, min(1, factor))

    # Normalise the input so the channel slices below always line up; without
    # this a missing "#" would shift every channel by one character.
    digits = hex_color.strip()
    if digits.startswith("#"):
        digits = digits[1:]
    if len(digits) in (3, 4):
        # Shorthand notation: every digit stands for a doubled pair.
        digits = "".join(ch * 2 for ch in digits)

    # Check the characters explicitly: int(..., 16) would also accept signs,
    # underscores and whitespace.
    is_hex = set(digits) <= set("0123456789abcdefABCDEF")
    if len(digits) not in (6, 8) or not is_hex:
        raise ValueError(f"Invalid hex colour: {hex_color!r}")

    # Convert each channel to an integer and darken it (truncating).
    r, g, b = (int(int(digits[i : i + 2], 16) * factor) for i in (0, 2, 4))

    # Convert RGB back to hex
    return f"#{r:02x}{g:02x}{b:02x}".upper()

In [None]:
article_titles = merge_df.loc[:, "doc_title"].tolist()

# Window of documents (rows and columns of the similarity matrix) to plot.
start = 0
end = 20

cutoff_similarities = similarities[start:end, start:end]
cutoff_article_titles = article_titles[start:end]
# Ground-truth group of every document in the window, in plotting order.
cutoff_cluster_ids = merge_df["Combine Group ID"].iloc[start:end].tolist()

# Generate one random colour per ground-truth group in the window. The colours
# are darkened before use, and a locally seeded generator keeps the figure
# identical across runs without touching the global `random` state.
hexadecimal_alphabets = "0123456789ABCDEF"
colour_rng = random.Random(42)
ground_truth_cluster_ids = merge_df.iloc[start:end]["Combine Group ID"].unique()
colours = {
    cluster_id: darken_hex_color(
        "#" + "".join(colour_rng.choice(hexadecimal_alphabets) for _ in range(6))
    )
    for cluster_id in ground_truth_cluster_ids.tolist()
}

fig, ax = plt.subplots(figsize=(20, 18))
sns.heatmap(
    cutoff_similarities,
    xticklabels=cutoff_article_titles,
    yticklabels=cutoff_article_titles,
    annot=True,
    fmt=".2g",
    ax=ax,
)

# Colour each tick label by the group of the document at that position.
# Looking the group up by position rather than by title keeps the colour
# correct when two documents share a title, and guarantees the group id is one
# of the keys of `colours`.
for x_tick_label, y_tick_label, cluster_id in zip(
    ax.get_xticklabels(), ax.get_yticklabels(), cutoff_cluster_ids
):
    colour = colours[cluster_id]
    y_tick_label.set_color(colour)
    x_tick_label.set_color(colour)

plt.tight_layout()
plt.show()

In [None]:
# savefig does not create missing folders, so make sure the output directory
# exists first (it may be absent on a fresh checkout).
os.makedirs(os.path.dirname(OUTPUT_CM_PATH) or ".", exist_ok=True)
ax.figure.savefig(OUTPUT_CM_PATH, dpi=400)

In [None]:
# Wrap the full similarity matrix in a DataFrame so it can be labelled and
# exported in the following cells.
# NOTE(review): `.numpy()` suggests `similarities` is a torch tensor - confirm
# it lives on the CPU.
sim_df = pd.DataFrame(similarities.numpy())

In [None]:
# Label rows and columns with the ground-truth page titles. The similarity
# matrix was built from merge_df's embeddings, so its row order is merge_df's.
page_titles = merge_df['Page Title']
sim_df.index = page_titles
sim_df.columns = page_titles

In [None]:
# to_csv does not create missing folders, so make sure the output directory
# exists first (it may be absent on a fresh checkout).
os.makedirs(os.path.dirname(OUTPUT_SIM_PATH) or ".", exist_ok=True)
# 'utf-8-sig' writes a byte-order mark at the start of the file.
sim_df.to_csv(OUTPUT_SIM_PATH, encoding='utf-8-sig')