In [None]:
import json
import os
import matplotlib.pyplot as plt
import seaborn as sns

def load_json_files(system_names, dataset, experiment_type, directory="../results"):
    """Collect benchmark metrics from the JSON result files in *directory*.

    A file is selected when its name contains at least one entry of
    ``system_names`` as well as ``dataset`` and ``experiment_type`` (plain
    substring matching on the file name).

    Args:
        system_names: Iterable of substrings; a file matches if any of them
            occurs in its name.
        dataset: Substring that must occur in the file name.
        experiment_type: ``"upload"`` or ``"search"``. It must occur in the
            file name and also selects which metrics are extracted. Any other
            value still parses the matching files but yields no records.
        directory: Folder that is scanned (non-recursively) for result files.

    Returns:
        A list of dicts, ordered by file name. For ``"upload"`` each dict has
        the keys ``system_name``, ``upload_time`` and ``total_time``; for
        ``"search"`` the keys are ``system_name``, ``mean_precision`` and
        ``mean_time``.

    Raises:
        FileNotFoundError: If ``directory`` does not exist.
        KeyError: If a matching file lacks the expected ``results`` entries.
        json.JSONDecodeError: If a matching file is not valid JSON.
    """
    data = []
    # os.listdir() yields entries in arbitrary order; sorting keeps the
    # bar/point order (and therefore the colours) reproducible between runs.
    for filename in sorted(os.listdir(directory)):
        if dataset not in filename or experiment_type not in filename:
            continue
        if not any(system_name in filename for system_name in system_names):
            continue

        with open(os.path.join(directory, filename), 'r', encoding='utf-8') as file:
            content = json.load(file)

        # The text before the first '-' in the file name is used as the label.
        system_name = filename.split('-')[0]
        if experiment_type == "upload":
            results = content['results']
            data.append({
                'system_name': system_name,
                'upload_time': results['upload_time'],
                'total_time': results['total_time'],
            })
        elif experiment_type == "search":
            results = content['results']
            data.append({
                'system_name': system_name,
                # The file stores this under the plural key 'mean_precisions'.
                'mean_precision': results['mean_precisions'],
                'mean_time': results['mean_time'],
            })
    return data

def create_bar_chart(plot_data):
    """Show a grouped bar chart of upload time and total time per system.

    Args:
        plot_data: Sequence of dicts, each providing the keys
            ``system_name``, ``upload_time`` and ``total_time``.

    The figure is displayed with ``plt.show()``; nothing is returned.
    """
    sns.set_theme(style="whitegrid")
    deep_palette = sns.color_palette("deep")

    system_labels = [entry['system_name'] for entry in plot_data]
    upload_series = [entry['upload_time'] for entry in plot_data]
    total_series = [entry['total_time'] for entry in plot_data]

    positions = range(len(system_labels))  # one tick per system
    bar_width = 0.35
    half_width = bar_width / 2

    _, axes = plt.subplots()
    # Each series is shifted half a bar width left/right of the tick so the
    # two bars of a system sit side by side.
    series = (
        (-half_width, upload_series, 'Upload Time', deep_palette[0]),
        (half_width, total_series, 'Total Time', deep_palette[2]),
    )
    for shift, heights, legend_label, bar_color in series:
        axes.bar(
            [position + shift for position in positions],
            heights,
            bar_width,
            label=legend_label,
            color=bar_color,
        )

    axes.set_ylabel('Times')
    axes.set_title('Upload and Total Times by System')
    axes.set_xticks(positions)
    axes.set_xticklabels(system_labels)
    axes.legend()

    plt.show()

def create_scatter_plot(plot_data):
    """Show a scatter plot of mean precision (x) against mean time (y).

    Args:
        plot_data: Sequence of dicts, each providing the keys
            ``system_name``, ``mean_precision`` and ``mean_time``. One point
            is drawn per entry and labelled with its ``system_name``.

    The figure is displayed with ``plt.show()``; nothing is returned.
    """
    sns.set_theme(style="whitegrid")
    deep_palette = sns.color_palette("deep")
    # Keep the two colours used historically (indices 1 and 3) for the first
    # two points, then fall back to the remaining palette entries and cycle,
    # so any number of result files can be plotted without an IndexError.
    preferred = [1, 3]
    color_order = preferred + [
        index for index in range(len(deep_palette)) if index not in preferred
    ]

    plt.figure(figsize=(10, 6))
    for i, item in enumerate(plot_data):
        point_color = deep_palette[color_order[i % len(color_order)]]
        plt.scatter(
            item['mean_precision'],
            item['mean_time'],
            label=item['system_name'],
            color=point_color,
            s=100,
        )
    plt.xlabel('Mean Precision')
    plt.ylabel('Mean Time')
    plt.title('Comparison of Systems on Mean Precision and Mean Time')
    # The x axis extends past the last tick (1) up to 1.25.
    plt.xlim(0, 1.25)
    plt.xticks([0, 0.25, 0.5, 0.75, 1])
    if plot_data:
        # Without any labelled points a legend only triggers a warning.
        plt.legend(loc='upper left')
    plt.show()


# Benchmark selection: result files are matched by these substrings.
dataset = "random_keywords_10k"
system_names = ["qdrant-default", "chroma-default"]

# Upload results -> grouped bar chart.
experiment_type = "upload"
upload_data = load_json_files(
    system_names=system_names,
    dataset=dataset,
    experiment_type=experiment_type,
)
create_bar_chart(plot_data=upload_data)

# Search results -> precision/time scatter plot.
experiment_type = "search"
search_data = load_json_files(
    system_names=system_names,
    dataset=dataset,
    experiment_type=experiment_type,
)
create_scatter_plot(plot_data=search_data)
