# In [None]:
import json
import os
import matplotlib.pyplot as plt
import seaborn as sns

def load_json_files(experiment_names, dataset, experiment_type,
                    directory="../results/results_hnm_2k"):
    """Collect benchmark results for the requested experiments from JSON files.

    A file in ``directory`` is selected when its name contains ``dataset``,
    ``experiment_type`` and at least one of ``experiment_names``. Files are
    visited in sorted filename order, so the output order is deterministic.
    A file whose name contains several experiment names yields one record per
    matching name.

    Args:
        experiment_names: Iterable of experiment-name substrings to match.
            Any iterable works, including one-shot ones such as generators.
        dataset: Dataset tag that must appear in the filename.
        experiment_type: ``"upload"`` or ``"search"``. Any other value
            yields an empty list.
        directory: Folder containing the result JSON files.

    Returns:
        A list of dicts. Each has ``system_name`` (the filename text before
        the first ``-``) and ``experiment_name``, plus:
        ``upload_time``/``total_time`` for ``"upload"``, or
        ``mean_precision``/``mean_time`` for ``"search"``. The values are read
        from the file's top-level ``results`` object.

    Raises:
        KeyError: if a selected file lacks the expected ``results`` fields.
    """
    # Per experiment type: (key in the returned record, key under content['results']).
    # Note the JSON key for precision is plural ("mean_precisions").
    field_map = {
        "upload": (("upload_time", "upload_time"), ("total_time", "total_time")),
        "search": (("mean_precision", "mean_precisions"), ("mean_time", "mean_time")),
    }
    fields = field_map.get(experiment_type)
    if fields is None:
        return []

    # Materialize once: the original re-iterated this for every file, which
    # silently broke when a generator was passed in.
    names = list(experiment_names)

    data = []
    for filename in sorted(os.listdir(directory)):
        if dataset not in filename or experiment_type not in filename:
            continue
        matched = [name for name in names if name in filename]
        if not matched:
            continue
        with open(os.path.join(directory, filename), "r", encoding="utf-8") as file:
            results = json.load(file)["results"]
        system_name = filename.split("-")[0]  # system name is the filename prefix
        for name in matched:
            record = {"system_name": system_name, "experiment_name": name}
            record.update({out_key: results[src_key] for out_key, src_key in fields})
            data.append(record)
    return data

def create_bar_chart(plot_data, upload_type, dataset=None, output_dir="images"):
    """Plot upload time vs. total time per system as a grouped bar chart.

    The figure is saved as
    ``<output_dir>/<dataset with leading/trailing '-' stripped>_upload.png``
    and then shown.

    Args:
        plot_data: Records as returned by ``load_json_files(..., "upload")``.
            Each must have ``system_name``, ``upload_time`` and ``total_time``.
        upload_type: Text placed after "Upload " in the chart title.
        dataset: Dataset tag used to build the output filename. When omitted,
            falls back to the module-level ``dataset`` variable. Earlier
            versions of this function always read that global.
        output_dir: Directory for the PNG; created if it does not exist.

    Raises:
        NameError: if ``dataset`` is not passed and no module-level
            ``dataset`` is defined.
    """
    if dataset is None:
        # Backward compatibility with callers that relied on the global.
        try:
            dataset = globals()["dataset"]
        except KeyError:
            raise NameError(
                "create_bar_chart: pass `dataset` explicitly "
                "(no module-level `dataset` is defined)"
            ) from None

    sns.set_theme(style="whitegrid")
    palette = sns.color_palette("deep")
    upload_color, total_color = palette[0], palette[2]

    labels = [item['system_name'] for item in plot_data]
    upload_times = [item['upload_time'] for item in plot_data]
    total_times = [item['total_time'] for item in plot_data]

    x = range(len(labels))  # one tick (and one pair of bars) per record
    width = 0.35  # width of each bar; the pair is centred on its tick

    fig, ax = plt.subplots(figsize=(6.4, 4.8))
    ax.bar([pos - width / 2 for pos in x], upload_times, width,
           label='Upload Time', color=upload_color)
    ax.bar([pos + width / 2 for pos in x], total_times, width,
           label='Total Time', color=total_color)

    ax.set_ylabel('Time (s)', fontsize=16)
    ax.set_title(f'Upload {upload_type}', fontsize=16)
    ax.set_xticks(x)
    ax.set_xticklabels(labels)
    ax.tick_params(axis='x', labelsize=14)
    ax.tick_params(axis='y', labelsize=14)
    ax.legend()
    fig.tight_layout()

    # savefig does not create missing directories.
    os.makedirs(output_dir, exist_ok=True)
    fig.savefig(os.path.join(output_dir, f'{dataset.strip("-")}_upload.png'), format='png')
    plt.show()

def create_scatter_plot(plot_data, systems_colors, dataset, search_type,
                        output_dir="images"):
    """Plot mean search time against mean precision, one marker per record.

    The figure is saved as
    ``<output_dir>/<dataset with leading/trailing '-' stripped>_search.png``
    and then shown.

    Args:
        plot_data: Records as returned by ``load_json_files(..., "search")``.
            Each must have ``system_name``, ``experiment_name``,
            ``mean_precision`` and ``mean_time``.
        systems_colors: Mapping from experiment name to a matplotlib color.
        dataset: Dataset tag used to build the output filename.
        search_type: Chart title.
        output_dir: Directory for the PNG; created if it does not exist.

    Raises:
        KeyError: if a record's ``experiment_name`` is missing from
            ``systems_colors``.
    """
    sns.set_theme(style="whitegrid")
    fig, ax = plt.subplots(figsize=(6.4, 4.8))
    for item in plot_data:
        # Color is keyed by experiment name; the legend shows the system name.
        ax.scatter(item['mean_precision'], item['mean_time'],
                   label=item['system_name'],
                   color=systems_colors[item['experiment_name']],
                   s=200, alpha=1)  # alpha=1 -> fully opaque markers
    ax.set_xlabel('Mean Precision', fontsize=16)
    ax.set_ylabel('Mean Time (s)', fontsize=16)
    ax.set_title(search_type, fontsize=16)
    # Fixed x window. NOTE(review): precision values below 0.7 are not visible.
    ax.set_xlim(0.7, 1.05)
    ax.set_xticks([0.7, 0.8, 0.9, 1])
    ax.tick_params(axis='x', labelsize=14)
    ax.tick_params(axis='y', labelsize=14)
    ax.legend(loc='upper left')
    fig.tight_layout()

    # savefig does not create missing directories.
    os.makedirs(output_dir, exist_ok=True)
    fig.savefig(os.path.join(output_dir, f'{dataset.strip("-")}_search.png'), format='png')
    plt.show()

# Plot color for each experiment configuration (seaborn "deep" palette index).
# The keys also serve as the experiment-name filters for load_json_files.
systems_colors = {
    experiment: sns.color_palette("deep")[palette_idx]
    for experiment, palette_idx in (
        ("qdrant-m-16-ef-100", 3),
        ("chroma-m-16-ef-100", 1),
        ("lvd-model-mlp4-ncat-20-epoch-200-lrs-001-nbuck-2-bthr-02-constw--1", 4),
        ("weaviate-m-16-ef-128", 2),
        ("milvus-m-16-ef-128", -1),
    )
}


experiment_names = systems_colors.keys()

# Display name used in titles, and the tag that must appear in result filenames.
dataset_name = "H&M"
dataset = "-hnm_2k-"

# A tag containing "no" selects the without-metadata titles
# (presumably marking a no-metadata dataset variant — confirm naming).
if "no" in dataset:
    search_type = "Regular kANN Search"
    upload_type = f"{dataset_name} Without Metadata"
else:
    search_type = f"Constrained Search on {dataset_name}"
    upload_type = f"{dataset_name} With Metadata"

# Upload benchmark -> grouped bar chart.
experiment_type = "upload"
upload_data = load_json_files(experiment_names, dataset, experiment_type)
create_bar_chart(upload_data, upload_type)

# Search benchmark -> precision/latency scatter plot.
experiment_type = "search"
search_data = load_json_files(experiment_names, dataset, experiment_type)
create_scatter_plot(search_data, systems_colors, dataset, search_type)