# In [None]:
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import seaborn as sns

# Function to identify Pareto efficient points
def is_pareto_efficient(costs):
    """Return a boolean mask marking the Pareto-efficient rows of ``costs``.

    Every column is treated as an objective to minimize. A row is
    efficient when no other row dominates it, i.e. no row is less than or
    equal to it in every column and strictly less in at least one.
    Exact duplicates do not dominate each other, so all copies of an
    efficient point are kept.

    Args:
        costs: Array-like of shape ``(n_points, n_objectives)``. Any number
            of objective columns is supported.

    Returns:
        Boolean ``numpy.ndarray`` of length ``n_points``; ``True`` where the
        corresponding row is Pareto-efficient. Empty input yields an empty
        mask.

    Note:
        Comparisons against NaN are always false, so a row containing NaN
        is never dominated and is therefore reported as efficient.
    """
    # Accept lists/tuples as well as ndarrays.
    costs = np.asarray(costs)
    is_efficient = np.ones(costs.shape[0], dtype=bool)
    for i, c in enumerate(costs):
        # Rows that are no worse than ``c`` in every objective ...
        no_worse = np.all(costs <= c, axis=1)
        # ... and strictly better in at least one dominate ``c``.
        strictly_better = np.any(costs < c, axis=1)
        is_efficient[i] = not np.any(no_worse & strictly_better)
    return is_efficient

# Function to plot Pareto frontier
def plot_pareto(csv_file, output_file, title):
    """Plot the inference-time / MSE Pareto frontier of an experiment CSV.

    Reads ``csv_file``, takes its ``inference_time`` and ``val_mse``
    columns, multiplies inference time by 1000 (the axis is labelled in
    milliseconds, so the column is presumably recorded in seconds), marks
    the Pareto-efficient rows with ``is_pareto_efficient`` and writes a
    scatter plot, with the frontier connected by a line, to ``output_file``.

    Args:
        csv_file: Path (or anything ``pandas.read_csv`` accepts) of a CSV
            containing ``inference_time`` and ``val_mse`` columns.
        output_file: Destination handed to ``savefig``; the figure is
            always written in SVG format.
        title: Title drawn above the plot.

    Raises:
        KeyError: If either required column is missing from the CSV.
    """
    df = pd.read_csv(csv_file)

    # Work on an owned float copy: scaling the array returned by ``.values``
    # in place can write through to the DataFrame, or fail outright when
    # pandas hands back a read-only view (copy-on-write mode).
    # Rows with a missing value are skipped: NaN never compares as
    # dominated, so such rows would otherwise be reported as efficient.
    costs = (
        df[['inference_time', 'val_mse']]
        .dropna()
        .to_numpy(dtype=float, copy=True)
    )
    costs[:, 0] *= 1000  # Convert inference time to ms
    pareto_mask = is_pareto_efficient(costs)

    # Order the frontier by inference time so the connecting line runs
    # left to right instead of zig-zagging in file order.
    pareto_points = costs[pareto_mask]
    pareto_points = pareto_points[np.argsort(pareto_points[:, 0])]

    # Set Seaborn style for professional look
    sns.set_style("whitegrid")
    fig, ax = plt.subplots(figsize=(7, 4), dpi=300)
    try:
        # Plot non-optimal points
        ax.scatter(costs[~pareto_mask, 0], costs[~pareto_mask, 1],
                   c='gray', alpha=0.3, label='Non-optimal')

        # Plot Pareto frontier
        ax.scatter(pareto_points[:, 0], pareto_points[:, 1],
                   c='#1f77b4', label='Pareto Frontier')
        ax.plot(pareto_points[:, 0], pareto_points[:, 1],
                color='#1f77b4', linestyle='-', linewidth=1.5)

        # Customize axes
        ax.set_xlabel('Inference Time (ms)', fontsize=10)
        # NOTE(review): values come from the ``val_mse`` column while the
        # label says "Test MSE" -- confirm which one is intended.
        ax.set_ylabel('Test MSE', fontsize=10)
        ax.set_title(title, fontsize=12, pad=10)

        # Place legend inside the axes, in the upper-left corner
        ax.legend(fontsize=8, loc='upper left')

        # Save plot
        fig.tight_layout()
        fig.savefig(output_file, format='svg', bbox_inches='tight')
    finally:
        # Release the figure even if plotting or saving fails.
        plt.close(fig)

# Render one frontier figure per experiment: (input CSV, output SVG, title).
for csv_path, svg_path, heading in (
    (
        "data/results/long_seq_experiment.csv",
        "data/results/pareto_long_seq.svg",
        "Pareto Frontier: Long Sequences",
    ),
    (
        "data/results/short_seq_experiment.csv",
        "data/results/pareto_short_seq.svg",
        "Pareto Frontier: Short Sequences",
    ),
):
    plot_pareto(csv_path, svg_path, heading)