# In [None]:
import json
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np

# Locations of the saved hyperparameter-search outputs (relative paths).
best_hyp_path = (
    '../best_hyperparameters_saved/finished_fl/'
    'best_hyperparameters_flaubert.json'
)
log_hyp_path = (
    '../best_hyperparameters_saved/finished_fl/'
    'hyperparameter_log_flaubert.json'
)

# Load the grid search results from the file
# Function to load JSON lines
def load_json_lines(file_path):
    """Read a JSON Lines file and return its records as a list.

    Args:
        file_path: Path of a text file holding one JSON document per line.

    Returns:
        A list with the decoded value of every non-blank line, in file order.
        An empty file yields an empty list.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    # JSON text is UTF-8 by specification, so do not depend on the
    # platform's default encoding.
    with open(file_path, 'r', encoding='utf-8') as f:
        # Blank lines (typically a trailing empty line at the end of a log)
        # carry no record, and json.loads would raise on them.
        return [json.loads(line) for line in f if line.strip()]
    

# All records of the hyperparameter-search log, one per line of the file.
grid_search_results = load_json_lines(log_hyp_path)

# The summary file is a single JSON document; parse it, then pull out the
# winning hyperparameters and the metrics stored alongside them.
with open(best_hyp_path) as f:
    best_hyperparameters_data = json.load(f)

best_hyperparameters = best_hyperparameters_data['best_parameters']
best_metrics = best_hyperparameters_data['metrics']

def plot_grid_search_results(grid_search_results):
    """Plot accuracy against batch size, with one line per learning rate.

    Args:
        grid_search_results: Records accepted by ``pd.DataFrame``. Each one
            needs an ``accuracy`` value and a ``hyperparameters`` mapping
            with ``learning_rate``, ``batch_size`` and ``epochs`` keys.

    The figure is displayed with ``plt.show()``; nothing is returned.
    """
    results = pd.DataFrame(grid_search_results)

    # Lift each hyperparameter out of the nested mapping into its own column.
    for name in ('learning_rate', 'batch_size', 'epochs'):
        results[name] = results['hyperparameters'].apply(
            lambda params, key=name: params[key]
        )

    plt.figure(figsize=(12, 8))
    # lineplot draws on the current Axes and returns it, so the decoration
    # below targets the same Axes as the pyplot state functions would.
    ax = sns.lineplot(
        data=results,
        x='batch_size',
        y='accuracy',
        hue='learning_rate',
        marker='o',
    )
    ax.set_title('Grid Search Results')
    ax.set_xlabel('Batch Size')
    ax.set_ylabel('Accuracy')
    ax.legend(title='Learning Rate')
    ax.grid(True)
    plt.show()


# Plotting the confusion matrix
def plot_confusion_matrix(confusion_matrix, labels):
    """Display a confusion matrix as an annotated heatmap.

    Args:
        confusion_matrix: 2-D table of counts. Cells are annotated with the
            integer ``"d"`` format, so the values must be integers.
        labels: Class names used as tick labels on both axes.

    The y-axis is titled 'True Label' and the x-axis 'Predicted Label'.
    The figure is displayed with ``plt.show()``; nothing is returned.
    """
    plt.figure(figsize=(10, 7))
    # heatmap returns the Axes it drew on; label that Axes directly.
    ax = sns.heatmap(
        confusion_matrix,
        annot=True,
        fmt="d",
        cmap="Blues",
        xticklabels=labels,
        yticklabels=labels,
    )
    ax.set_xlabel('Predicted Label')
    ax.set_ylabel('True Label')
    ax.set_title('Confusion Matrix')
    plt.show()

# Plotting the metrics
def plot_metrics(metrics):
    """Show the scalar evaluation metrics as a horizontal bar chart.

    Args:
        metrics: Mapping of metric name to value. Only real-number values
            are plotted; anything else (for example a nested
            ``confusion_matrix`` list) is skipped wherever it sits in the
            mapping.

    The x-axis is fixed to the range 0..1 and every bar is annotated with
    its value to three decimals. The figure is displayed with
    ``plt.show()``; nothing is returned.
    """
    # Select by type rather than by position: dropping "the last entry"
    # only works while the non-scalar item happens to be stored last, and
    # otherwise hides a real metric or hands a nested list to barh.
    scalar_metrics = {
        name: value
        for name, value in metrics.items()
        if isinstance(value, (int, float)) and not isinstance(value, bool)
    }
    names = list(scalar_metrics)
    scores = list(scalar_metrics.values())

    plt.figure(figsize=(10, 5))
    plt.barh(names, scores, color='skyblue')
    plt.xlabel('Score')
    plt.title('Evaluation Metrics')
    plt.xlim(0, 1)
    # Write each score at the tip of its bar.
    for index, value in enumerate(scores):
        plt.text(value, index, f'{value:.3f}')
    plt.show()


# Labels for the confusion matrix; the same list is used for both axes.
# NOTE(review): presumably CEFR proficiency levels, and presumably listed in
# the same order as the rows/columns of the stored confusion matrix — confirm.
labels = ['A1', 'A2', 'B1', 'B2', 'C1', 'C2']

# Plotting the results: each call below opens its own figure and shows it,
# so the figures appear in this order.
plot_grid_search_results(grid_search_results)
plot_confusion_matrix(best_metrics['confusion_matrix'], labels)
plot_metrics(best_metrics)
