In [None]:
# Write a script that makes bar plots comparing centralized and federated
# (decentralized) results.
# One figure is produced per metric. Each figure has 8 bars arranged in pairs:
# every model gets one bar for its centralized result and one bar for its
# federated result.
# For the federated models, the final (last-round) value of the metric is used.
# The bars are grouped by model; within each pair the centralized bar is red
# and the federated bar is blue.
# The x-axis tick labels are the model names.
# The y-axis shows the value of the metric.
# The title is "<metric> vs Models".
# The legend distinguishes the centralized and federated bars.
# Each plot is saved as a PNG file named bar_<metric>.png.
# The script will be called plot_bars.py

# Import libraries
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np

# Define metrics and models
metrics = ["Loss", "Accuracy", "Precision", "Recall"]
models = ["log_reg", "xgb", "mlp", "dqn"]
# (central colour, federated colour) for each model's pair of bars
colors = [('red', 'blue'), ('red', 'blue'), ('red', 'blue'), ('red', 'blue')]
# Display names used as x-axis tick labels, in the same order as `models`
labels = ['Logistic Regression', 'XGBoost', 'MLP', "DQN"]

# Load central results. Rows are looked up by the "Model Name" column and
# metric values are read from the column named after each metric.
central_results = pd.read_csv("central_res.csv")

# Set width of bar
barWidth = 0.3

# Increase the space between model groups
space_between_groups = 1.0

# Load each model's per-round federated results once, instead of re-reading
# the same file for every metric.
federated_results = {model: pd.read_csv(f"{model}_plot.csv") for model in models}

# Positions of the bars on the x-axis. They do not depend on the metric, so
# compute them once. Model group i starts at i * (1 + space_between_groups);
# the federated bar sits 2*barWidth to the right of the central bar, leaving
# a gap of one bar width between the two bars of a pair.
group_index = np.arange(len(models))
r1 = group_index * (1 + space_between_groups)
r2 = r1 + 2 * barWidth
# With the default center alignment, the midpoint of a pair is r1 + barWidth;
# put the model label there so it is centred under its two bars.
tick_positions = r1 + barWidth

# One figure per metric
for metric in metrics:
    fig = plt.figure(figsize=(10, 5))

    for i, (model, (color1, color2)) in enumerate(zip(models, colors)):
        # Final value of the metric for the federated model (last row of the
        # per-round file)
        final_value = federated_results[model][metric].iloc[-1]

        # Central value for the metric from central_res.csv; fail with a clear
        # message rather than an opaque IndexError if the model is missing.
        matches = central_results.loc[central_results["Model Name"] == model, metric]
        if matches.empty:
            raise ValueError(f"No central result for model {model!r} in central_res.csv")
        central_value = matches.iloc[0]

        # Plot the pair with an explicit width (matplotlib's default of 0.8
        # would make the two bars overlap). Only the first pair carries legend
        # labels so each legend entry appears once; empty labels are skipped
        # by the legend.
        plt.bar(r1[i], central_value, width=barWidth, color=color1,
                label="Central Model" if i == 0 else "")
        plt.bar(r2[i], final_value, width=barWidth, color=color2,
                label="Federated Model" if i == 0 else "")

    # Customize the plot once, after all bars for this metric are drawn
    plt.xlabel('Models')
    plt.ylabel(metric)
    plt.title(f"{metric} vs Models")
    plt.xticks(tick_positions, labels)

    # Place the legend below the axes
    plt.legend(loc='upper center', bbox_to_anchor=(0.5, -0.1), ncol=2, frameon=False)
    plt.tight_layout()

    # Save the figure, then close it so figures do not accumulate in memory
    plt.savefig(f"bar_{metric}.png")
    plt.close(fig)