In [None]:
%load_ext autoreload
%autoreload 2
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from aeroblade.paper import configure_mpl, get_nice_name, set_figsize

# Apply the shared paper styling and the default figure size used by every
# figure that does not request an explicit size.
configure_mpl()
set_figsize("single", ratio=1.0, factor=0.49)

# All figures below are written into this directory (created if missing).
output_dir = Path("../output/02/jpeg-comp/figures")
output_dir.mkdir(parents=True, exist_ok=True)

# Load the distance/complexity results, keep only the rows with
# repo_id == 'max', and replace raw `dir` / `distance_metric` identifiers
# with their display names.
combined = pd.read_parquet("../output/02/jpeg-comp/combined_dist_compl.parquet")
combined = combined.query("repo_id == 'max'")
combined[["dir", "distance_metric"]] = combined[["dir", "distance_metric"]].map(get_nice_name)

# Define reusable functions
def plot_histogram(data, x, y, title, xlabel, ylabel, filename, output_dir, binrange, vmax):
    """Plot a 2-D density histogram of ``data[y]`` against ``data[x]`` and save it.

    A fresh figure is created using the current rcParams figure size, and it is
    always closed afterwards. Repeated calls therefore never draw onto a stale
    figure and never leak open figures, even when plotting or saving fails.

    Args:
        data: DataFrame holding the two columns to plot.
        x: Column name plotted on the horizontal axis.
        y: Column name plotted on the vertical axis.
        title: Figure title.
        xlabel: Label for the horizontal axis.
        ylabel: Label for the vertical axis.
        filename: Output file name, relative to ``output_dir``.
        output_dir: ``pathlib.Path`` of the directory the figure is saved to.
        binrange: ``((xmin, xmax), (ymin, ymax))`` range covered by the 100x100 bins.
        vmax: Upper colour-scale limit, passed through ``sns.histplot``'s extra
            keyword arguments.
    """
    # Use a dedicated figure instead of whatever pyplot currently has active.
    fig, ax = plt.subplots()
    try:
        sns.histplot(
            x=data[x],
            y=data[y],
            stat="density",
            bins=100,
            legend=False,
            binrange=binrange,
            vmax=vmax,
            ax=ax,
        )
        ax.set_title(title)
        ax.set_xlabel(xlabel)
        ax.set_ylabel(ylabel)
        fig.savefig(output_dir / filename)
    finally:
        # Close even on failure so figures don't accumulate across the loop.
        plt.close(fig)


def plot_mean_curve(data, x, y, title, xlabel, ylabel, filename, output_dir, xlim, ylim):
    """Plot the binned mean of ``data[y]`` as a function of ``data[x]``, save and show it.

    ``data[x]`` is split into 100 equal-width bins. For every non-empty bin, the
    mean of ``x`` and the mean of ``y`` form one point of the curve. Bins whose
    means are NaN are dropped. The caller's DataFrame is not modified.

    Args:
        data: DataFrame containing columns ``x`` and ``y``.
        x: Column that is binned and plotted on the horizontal axis.
        y: Column averaged per bin and plotted on the vertical axis.
        title: Figure title.
        xlabel: Label for the horizontal axis.
        ylabel: Label for the vertical axis.
        filename: Output file name, relative to ``output_dir``.
        output_dir: ``pathlib.Path`` of the directory the figure is saved to.
        xlim: ``(min, max)`` limits of the horizontal axis.
        ylim: ``(min, max)`` limits of the vertical axis.
    """
    # Bin a copy of the two columns so the caller's DataFrame is not mutated.
    binned = data[[x, y]].assign(
        Complexity_bin=pd.cut(data[x], bins=100, labels=False, retbins=False)
    )
    aggregated_data = (
        binned.groupby("Complexity_bin")
        .agg({x: "mean", y: "mean"})  # mean x and mean y within each bin
        .dropna()
    )

    # NOTE(review): this fixed 8x6 in size overrides the rcParams figure size
    # used by the histograms -- confirm that is intended.
    fig, ax = plt.subplots(figsize=(8, 6))
    try:
        sns.lineplot(data=aggregated_data, x=x, y=y, ax=ax)
        ax.set_title(title)
        ax.set_xlabel(xlabel)
        ax.set_ylabel(ylabel)
        ax.set_xlim(xlim)
        ax.set_ylim(ylim)
        ax.grid(True)
        fig.savefig(output_dir / filename)
        plt.show()
    finally:
        # Close even on failure so figures don't accumulate across the loop.
        plt.close(fig)


def plot_combined_mean_curves(groups, x, y, title, xlabel, ylabel, filename, output_dir, xlim, ylim):
    """Overlay the binned mean curve of every group in one figure, save and show it.

    For each ``(label, group_df)`` pair in ``groups``, the array-like cells of
    columns ``x`` and ``y`` are flattened into one long sample. ``x`` is split
    into 100 equal-width bins, and the per-bin means of ``x`` and ``y`` are drawn
    as a line labelled with the group key.

    Args:
        groups: Iterable of ``(label, DataFrame)`` pairs, e.g. a pandas groupby.
        x: Column holding array-like values for the horizontal axis; it is binned.
        y: Column holding array-like values averaged per bin.
        title: Figure title.
        xlabel: Label for the horizontal axis.
        ylabel: Label for the vertical axis.
        filename: Output file name, relative to ``output_dir``.
        output_dir: ``pathlib.Path`` of the directory the figure is saved to.
        xlim: ``(min, max)`` limits of the horizontal axis.
        ylim: ``(min, max)`` limits of the vertical axis.
    """
    fig, ax = plt.subplots(figsize=(8, 6))
    try:
        for nice_dir, group_df in groups:
            # Read the columns named by x / y (not hardcoded names). Flatten each
            # cell and concatenate; unlike np.stack this also accepts cells of
            # different lengths, and gives the same result for equal shapes.
            group_data = pd.DataFrame({
                x: np.concatenate([np.ravel(v) for v in group_df[x]]),
                y: np.concatenate([np.ravel(v) for v in group_df[y]]),
            })

            group_data["Complexity_bin"] = pd.cut(
                group_data[x], bins=100, labels=False, retbins=False
            )
            aggregated_data = (
                group_data.groupby("Complexity_bin")
                .agg({x: "mean", y: "mean"})
                .dropna()
            )

            sns.lineplot(data=aggregated_data, x=x, y=y, label=nice_dir, ax=ax)

        ax.set_title(title)
        ax.set_xlabel(xlabel)
        ax.set_ylabel(ylabel)
        ax.set_xlim(xlim)
        ax.set_ylim(ylim)
        ax.legend()
        ax.grid(True)
        fig.savefig(output_dir / filename)
        plt.show()
    finally:
        # Close even on failure so figures don't accumulate.
        plt.close(fig)

# Process and plot group histograms and mean curves.
# Shared plotting window. It is also used as the histogram bin range below, so
# the histogram and mean-curve axes cannot drift apart.
xlim = (0.05, 0.4)
ylim = (0, 0.07)

for nice_dir, group_df in combined.groupby("dir", observed=True):
    # The complexity/distance cells are array-like; flatten every row of the
    # group into one long sample. ravel + concatenate matches the former
    # np.stack(...).flatten() for equal-shape cells and also accepts ragged ones.
    complexity_values = np.concatenate([np.ravel(v) for v in group_df.complexity])
    distance_values = np.concatenate([np.ravel(v) for v in group_df.distance])

    group_data = pd.DataFrame({"complexity": complexity_values, "distance": distance_values})

    # Plot histogram
    plot_histogram(
        group_data, "complexity", "distance",
        f"Distance vs Complexity ({nice_dir})",
        "Complexity", group_df.iloc[0].distance_metric,
        f"dist_vs_compl_{nice_dir}.pdf", output_dir,
        binrange=(xlim, ylim), vmax=1000,
    )

    # Plot mean curve
    plot_mean_curve(
        group_data, "complexity", "distance",
        f"Mean Distance vs Complexity ({nice_dir})",
        "Complexity", group_df.iloc[0].distance_metric,
        f"mean_dist_vs_compl_{nice_dir}.pdf", output_dir,
        xlim, ylim,
    )

# Pool every directory except 'Real' into one "All Generated" sample.
all_generated = combined.query("dir != 'Real'")
# Flatten each array-like cell before concatenating. This handles 0-d,
# multi-dimensional and ragged cells alike, matching the per-group flattening.
all_generated_data = pd.DataFrame({
    "complexity": np.concatenate([np.ravel(v) for v in all_generated.complexity]),
    "distance": np.concatenate([np.ravel(v) for v in all_generated.distance]),
})

# Plot histogram for all generated data (bin range = shared plotting window)
plot_histogram(
    all_generated_data, "complexity", "distance",
    "Distance vs Complexity (All Generated)",
    "Complexity", combined.iloc[0].distance_metric,
    "dist_vs_compl_all_generated.pdf", output_dir,
    binrange=(xlim, ylim), vmax=1000,
)

# Plot mean curve for all generated data
plot_mean_curve(
    all_generated_data, "complexity", "distance",
    "Mean Distance vs Complexity (All Generated)",
    "Complexity", combined.iloc[0].distance_metric,
    "mean_dist_vs_compl_all_generated.pdf", output_dir,
    xlim, ylim,
)

# Overlay the mean curves of every `dir` group (including 'Real') in a single
# figure. The pooled "All Generated" sample is not a `dir` value, so it is
# not one of these curves.
groups = combined.groupby("dir", observed=True)
plot_combined_mean_curves(
    groups=groups,
    x="complexity",
    y="distance",
    title="Mean Distance vs Complexity (All Groups)",
    xlabel="Complexity",
    ylabel=combined.iloc[0].distance_metric,
    filename="combined_mean_dist_vs_compl.pdf",
    output_dir=output_dir,
    xlim=xlim,
    ylim=ylim,
)
