In [None]:
from qdrant_client import QdrantClient

from src.collection.query_collection import (
    filter_search,
    get_semantically_similar_results,
)
from src.utils.utils import load_qdrant_client, load_config
from src.utils.utils import load_model
from src.collection.evaluate_collection import (
    calculate_precision,
    calculate_recall,
    calculate_f1_score,
    calculate_f2_score,
    get_unique_labels,
    get_data_for_evaluation,
)

from dotenv import load_dotenv
import os
import pickle
import numpy as np
from collections import defaultdict

# Load variables from a .env file (python-dotenv) so os.getenv can see them.
load_dotenv()

# Connection and dataset settings. os.getenv returns None for a variable that
# is not set, and a str for one that is.
QDRANT_HOST = os.getenv("QDRANT_HOST")
# NOTE(review): this is a str (or None), not an int — confirm load_qdrant_client accepts that.
QDRANT_PORT = os.getenv("QDRANT_PORT")
COLLECTION_NAME = os.getenv("COLLECTION_NAME")
HF_MODEL_NAME = os.getenv("HF_MODEL_NAME")
PUBLISHING_PROJECT_ID = os.getenv("PUBLISHING_PROJECT_ID")
EVALUATION_TABLE = os.getenv("EVALUATION_TABLE")
# Wrap the table name in backticks (presumably so it can be used as a quoted
# identifier in a query — confirm). When the variable is unset this produces
# the literal string "`None`" rather than failing.
EVALUATION_TABLE = f"`{EVALUATION_TABLE}`"

In [None]:
# Read the project configuration; the path is relative to the working directory.
config = load_config("../.config/config.json")
# NOTE(review): if config.get returns None for a missing key (dict semantics —
# confirm load_config's return type), float(None) raises TypeError here.
similarity_threshold = float(config.get("similarity_threshold_1"))

# regex_ids is indexed by label text in later cells (e.g. regex_ids["applications"]).
# NOTE(review): pickle.load can execute arbitrary code — only load trusted files.
with open("../data/regex_ids.pkl", "rb") as f:
    regex_ids = pickle.load(f)

In [None]:
# Shared handles used by the cells below: the Qdrant client for searching and
# the model whose encode() produces the query embeddings.
qdrant = load_qdrant_client(QDRANT_HOST, port=QDRANT_PORT)
model = load_model(HF_MODEL_NAME)

In [None]:
# Embed one example label to spot-check retrieval before the full evaluation.
query_embedding = model.encode("applications")

In [None]:
# Search the collection with the example embedding.
# score_threshold=0.1 is a hard-coded cut-off used only for this spot check.
results = get_semantically_similar_results(
    client=qdrant,
    collection_name=COLLECTION_NAME,
    query_embedding=query_embedding,
    score_threshold=0.1,
)

In [None]:
# Display the raw hits returned by the spot-check query.
results

In [None]:
# Convert each hit id to str before it is scored against the ids in regex_ids
# (presumably those are stored as strings — verify against the pickle).
result_ids = [str(result.id) for result in results]
result_ids

In [None]:
# Ids stored for the singular and plural spellings of the example label.
app_ids = regex_ids["application"]
apps_ids = regex_ids["applications"]

# Show the overlap between the two spellings, the plural ids, and how many
# ids the singular spelling has.
print(f"intersection: {sorted(list(set(app_ids) & set(apps_ids)))}")
print(f"apps_ids: {sorted(apps_ids)}")
print(f"count of app_ids: {len(app_ids)}")

In [None]:
# Score the spot-check hits against the ids stored under regex_ids["applications"];
# the F-scores are derived from the precision and recall computed here.
precision = calculate_precision(result_ids, apps_ids)
recall = calculate_recall(result_ids, apps_ids)
f1_score = calculate_f1_score(precision, recall)
f2_score = calculate_f2_score(precision, recall)

In [None]:
# Recall is the metric that matters here. Low precision only means we recommend
# more records than necessary; that is acceptable as long as every relevant
# record is among the recommendations.

# Low precision = many false positives (expected with a low similarity threshold).
print(f"precision: {precision}")
# High recall = few false negatives (expected with a low similarity threshold).
print(f"recall: {recall}")
# Assuming the helpers follow the standard F-beta definitions: F1 weights
# precision and recall equally, F2 weights recall more heavily than precision.
print(f"f1_score: {f1_score}")
print(f"f2_score: {f2_score}")

In [None]:
# Load the evaluation records for the project/table configured in the environment.
data = get_data_for_evaluation(
    project_id=PUBLISHING_PROJECT_ID,
    evaluation_table=EVALUATION_TABLE,
)

# Labels extracted from the evaluation data; each one is used as a query below.
unique_labels = get_unique_labels(data)

In [None]:
def calculate_metrics(
    unique_label,
    regex_ids,
    model,
    client,
    similarity_threshold,
    collection_name=None,
):
    """Compute precision and recall of semantic search for a single label.

    The label text is embedded with ``model``, the collection is searched with
    that embedding, and the ids of the returned points are scored against the
    ids stored for the label in ``regex_ids``.

    Args:
        unique_label: Label text; used both as the query text and as the key
            into ``regex_ids``.
        regex_ids: Mapping from label to the ids treated as relevant for it.
        model: Object exposing ``encode(text)``, which returns the query
            embedding.
        client: Client passed through to ``get_semantically_similar_results``.
        similarity_threshold: Passed through as ``score_threshold``.
        collection_name: Collection to search. When omitted, the notebook-level
            ``COLLECTION_NAME`` is used.

    Returns:
        A ``(precision, recall)`` tuple as produced by ``calculate_precision``
        and ``calculate_recall``.

    Raises:
        KeyError: If ``regex_ids`` is a plain dict and has no entry for
            ``unique_label``.
    """
    if collection_name is None:
        collection_name = COLLECTION_NAME

    # Ids of the records considered relevant for this label.
    relevant_records = regex_ids[unique_label]

    # Embed the label text to use it as the search query.
    query_embedding = model.encode(unique_label)

    # Best effort: a failed search is reported and then scored as "nothing
    # retrieved". The default must be bound before the try block, otherwise an
    # exception leaves `results` undefined and the function dies with
    # UnboundLocalError instead of continuing.
    results = []
    try:
        results = get_semantically_similar_results(
            client=client,
            collection_name=collection_name,
            query_embedding=query_embedding,
            score_threshold=similarity_threshold,
        )
    except Exception as e:
        print(f"get_semantically_similar_results error: {e}")

    result_ids = [str(result.id) for result in results]

    precision = calculate_precision(result_ids, relevant_records)
    recall = calculate_recall(result_ids, relevant_records)

    return precision, recall

In [None]:
# Evaluate every label at every similarity threshold from 0.0 to 1.0 (step 0.1).
# The thresholds are built as i / 10 instead of np.arange(0, 1.1, 0.1): arange
# yields values such as 0.30000000000000004, and the thresholds are stored as
# dictionary keys that later cells look up and display by value.
thresholds = [step / 10 for step in range(11)]

# Each entry has the shape {label: {threshold: metric}}.
precision_values = []
recall_values = []
for unique_label in unique_labels:
    for threshold in thresholds:
        precision, recall = calculate_metrics(
            unique_label=unique_label,
            regex_ids=regex_ids,
            model=model,
            client=qdrant,
            similarity_threshold=threshold,
        )
        precision_values.append({unique_label: {threshold: precision}})
        recall_values.append({unique_label: {threshold: recall}})

# 33 mins to run all records

In [None]:
# Inspect the collected {label: {threshold: precision}} entries.
precision_values

In [None]:
# Optional: uncomment the lines below to pickle the precision and recall values
# (a later cell reads ../data/precision_values.pkl).
# with open("../data/precision_values.pkl", "wb") as f:
#     pickle.dump(precision_values, f)

# with open("../data/recall_values.pkl", "wb") as f:
#     pickle.dump(recall_values, f)

In [None]:
def calculate_mean_values(data_list):
    """Average the recorded values per threshold across all entries.

    Args:
        data_list: Iterable of mappings shaped like
            ``{label: {threshold: value}}``.

    Returns:
        Dict mapping each threshold to the arithmetic mean of all values
        recorded for it, keyed in order of first appearance.
    """
    totals = {}
    occurrences = {}

    for entry in data_list:
        # The label itself is irrelevant here; only the threshold -> value pairs count.
        for per_threshold in entry.values():
            for threshold, value in per_threshold.items():
                totals[threshold] = totals.get(threshold, 0) + value
                occurrences[threshold] = occurrences.get(threshold, 0) + 1

    return {
        threshold: totals[threshold] / occurrences[threshold] for threshold in totals
    }


# Mean precision and mean recall per threshold across all labels
# (the results are stored, not printed).
mean_precision_values = calculate_mean_values(precision_values)
mean_recall_values = calculate_mean_values(recall_values)

In [None]:
import matplotlib.pyplot as plt

# Plot mean precision and mean recall against the similarity threshold.
# The means are read directly from mean_precision_values / mean_recall_values
# rather than aliased to `data`: that name already holds the evaluation records
# loaded earlier in the notebook and must not be rebound to a different object.
fig, ax = plt.subplots()

ax.plot(
    list(mean_precision_values.keys()),
    list(mean_precision_values.values()),
    label="Precision",
)
ax.plot(
    list(mean_recall_values.keys()),
    list(mean_recall_values.values()),
    label="Recall",
)
ax.set_xlabel("Threshold")
ax.set_ylabel("Precision/Recall")
ax.set_title("Precision and Recall vs Threshold")
ax.set_xlim(0, 1)
ax.set_ylim(0, 1)
ax.legend()

# Remove top and right-hand side borders
ax.spines["top"].set_visible(False)
ax.spines["right"].set_visible(False)

# Add thin grey gridlines with an opacity of 0.3
ax.grid(color="grey", linestyle=":", linewidth=0.5, alpha=0.3)

# Tick the x-axis in increments of 0.1
ax.set_xticks([i / 10 for i in range(11)])

plt.show()

In [None]:
# Prepare per-threshold precision values for the boxplots.
# NOTE: this file is only written by the pickle.dump lines that are currently
# commented out earlier in the notebook; it must exist before this cell runs.
# NOTE(review): pickle.load can execute arbitrary code — only load trusted files.
with open("../data/precision_values.pkl", "rb") as f:
    precision_values = pickle.load(f)

# Drop the label level: {label: {threshold: value}} -> {threshold: value}.
# The flattened list gets its own name so that precision_values keeps the shape
# calculate_mean_values expects if that cell is run again.
precision_by_threshold = [list(item.values())[0] for item in precision_values]

# Round the threshold keys to 2 decimals (e.g. 0.30000000000000004 -> 0.3).
rounded_precision_values = [
    {round(key, 2): value for key, value in item.items()}
    for item in precision_by_threshold
]
rounded_precision_values

In [None]:
def get_threshold_values(data, input_threshold=0.0, tolerance=1e-9):
    """Collect every value recorded for one threshold.

    Thresholds are floats, so keys are matched within ``tolerance`` instead of
    with ``==``: a threshold produced by repeated steps of 0.1 (for example
    0.30000000000000004) would otherwise never match the key 0.3.

    Args:
        data: Iterable of ``{threshold: value}`` mappings with numeric keys.
        input_threshold: Threshold whose values should be collected.
        tolerance: Largest absolute difference at which a key still counts as
            equal to ``input_threshold``.

    Returns:
        List of the matching values in iteration order; empty when no key
        matches.
    """
    return [
        value
        for item in data
        for threshold, value in item.items()
        if abs(threshold - input_threshold) <= tolerance
    ]


# Values per threshold for the boxplots, keyed by threshold (0.0 to 1.0, step 0.1).
# Thresholds are generated as i / 10 so they equal the rounded keys of
# rounded_precision_values; np.arange(0, 1.1, 0.1) produces values such as
# 0.30000000000000004, which are not equal to the key 0.3 and would also show
# up as noisy labels when the dictionary keys are displayed.
plotting_values = {
    threshold: get_threshold_values(rounded_precision_values, input_threshold=threshold)
    for threshold in (step / 10 for step in range(11))
}

# Summary statistics for the three lowest thresholds (0.0, 0.1 and 0.2).
meta_list = []

for threshold_value in (step / 10 for step in range(3)):
    threshold_list = get_threshold_values(
        rounded_precision_values, input_threshold=threshold_value
    )
    meta_list.append(
        {
            "threshold": threshold_value,
            "mean": np.mean(threshold_list),
            "median": np.median(threshold_list),
            "std": np.std(threshold_list),
            "min": np.min(threshold_list),
            "max": np.max(threshold_list),
        }
    )

plotting_values
# fig, ax = plt.subplots()

# ax.boxplot(plotting_values.values())
# ax.set_xticklabels(data.keys())

# plt.show()

In [None]:
!pip install plotly

In [None]:
import plotly.graph_objs as go
from plotly.offline import plot

# Define layout options. The layout has to be attached to the figure before it
# is plotted, otherwise the title and axis labels never appear.
layout = go.Layout(
    title="Horizontal Boxplots", xaxis=dict(title="Values"), yaxis=dict(title="Dataset")
)

fig = go.Figure(layout=layout)

# One horizontal box per threshold. With orientation="h" the sample values
# belong on the x axis; the trace name labels the box.
for threshold, values in plotting_values.items():
    fig.add_trace(go.Box(x=values, name=f"Threshold: {threshold}", orientation="h"))

# Plot the figure
plot(fig)