In [None]:
%load_ext autoreload
%autoreload 2

import os
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from modules.plots import normalize_column_of_nparrays, get_nice_name, configure_default_plot_style, set_default_figsize
from pathlib import Path
from modules.datasetify import create_dataset

In [None]:
# Configure Plots
# Both helpers are imported from modules.plots; presumably they adjust the
# global plotting defaults used by every figure below -- confirm in that module.
configure_default_plot_style()
set_default_figsize(
    "single",
    ratio=1.0,
    factor=1,
)

In [None]:
# Arguments
# Identifier of the experiment run; it selects both the input and output folders.
experiment_id = "full"

# Parquet file that the plotting functions below read.
distance_complexity_parquet = f"../output/02/{experiment_id}/combined_dist_compl.parquet"

# Raw identifier -> display name, handed to get_nice_name() by the plot functions.
nice_name_mapping = dict(
    (
        ("raise1k", "RAISE-1k"),
        ("midjourney-v5", "SB-MJ5"),
        ("stable-diffusion-1-3", "SB-SD1.3"),
        ("stable-diffusion-1-4", "SB-SD1.4"),
        ("stable-diffusion-2", "SB-SD2"),
        ("stable-diffusion-xl", "SB-SDXL"),
        ("lpips_vgg_2", "LPIPS$_2$"),
    )
)

# Output
# Base directory under which the plot functions create their own sub-folders.
output_dir = Path(f"../output/plots/{experiment_id}/")

In [None]:
def plot_ds_comp_distr_per_comp_metric(input_parquet, output_dir, nice_name_mapping, bins=30, y_scale_log=False):
    """Plot one overlaid complexity histogram per complexity metric.

    Reads ``input_parquet``, keeps only the rows whose ``repo_id`` equals
    ``'max'`` and, for every value of ``complexity_metric``, draws a single
    figure in which each dataset (column ``dir``) contributes one unfilled
    polygon histogram of its stacked and flattened ``complexity`` values.
    Every figure is written as a PDF into
    ``output_dir / "plot_ds_comp_distr_per_comp_metric"`` and then shown.

    All datasets within one figure share the same bin range, so that the
    per-bin counts of the overlaid curves refer to identical intervals.

    Parameters
    ----------
    input_parquet : str or path-like
        Parquet file providing the columns ``repo_id``, ``dir``,
        ``complexity_metric`` and ``complexity``.
    output_dir : pathlib.Path
        Base directory; the plot sub-folder is created below it if missing.
    nice_name_mapping : dict
        Passed to ``get_nice_name`` to translate dataset identifiers into
        display names for the legend.
    bins : int, optional
        Forwarded to ``seaborn.histplot``; also used in the x-axis label and
        in the output file name.
    y_scale_log : bool, optional
        If true, the count axis is drawn on a log scale.
    """
    plot_output_dir = output_dir / "plot_ds_comp_distr_per_comp_metric"
    plot_output_dir.mkdir(exist_ok=True, parents=True)

    df = pd.read_parquet(input_parquet)
    df = df.query("repo_id == 'max'").copy()
    df["dir"] = df["dir"].map(lambda x: get_nice_name(x, nice_name_mapping))

    for cm, cm_group in df.groupby("complexity_metric", observed=True):
        # NOTE(review): the helper presumably rescales the "complexity" column
        # across the whole metric group -- confirm in modules.plots.
        normalized_complexity_df = normalize_column_of_nparrays(cm_group, "complexity")

        # Flatten each dataset once so the values can be reused for both the
        # shared bin range and the actual histogram.
        per_dataset = [
            (ds, np.stack(ds_group.complexity).flatten())
            for ds, ds_group in normalized_complexity_df.groupby("dir", observed=True)
        ]

        # Without a common range seaborn derives the bin edges from each
        # dataset's own min/max, which makes the overlaid counts incomparable.
        finite_parts = []
        for _, values in per_dataset:
            numeric = np.asarray(values, dtype=float)
            numeric = numeric[np.isfinite(numeric)]
            if numeric.size:
                finite_parts.append(numeric)
        if finite_parts:
            binrange = (
                float(min(part.min() for part in finite_parts)),
                float(max(part.max() for part in finite_parts)),
            )
        else:
            binrange = None  # nothing finite to bin; let seaborn decide

        # Explicit figure/axes: never draw onto a figure left over from an
        # earlier cell, and close exactly the figure created here.
        fig, ax = plt.subplots()
        for ds, values in per_dataset:
            sns.histplot(
                x=values,
                label=ds,
                bins=bins,
                binrange=binrange,
                log_scale=(False, y_scale_log),
                element="poly",
                fill=False,
                ax=ax,
            )

        ax.set_title("Dataset Complexity Distribution")
        ax.set_xlabel(f"{cm.upper()} Complexity ({bins} Bins)")
        ax.set_ylabel(f"Count{' (log)' if y_scale_log else ''}")
        ax.legend(title="Datasets")
        ax.grid(alpha=0.5)
        fig.savefig(plot_output_dir / f"plot_ds_comp_distr_bins_{bins}_y_{'log' if y_scale_log else 'lin'}_{cm}.pdf")
        plt.show()
        plt.close(fig)

In [None]:
# Render the complexity distributions twice: first with a linear count axis,
# then with a logarithmic one. Everything else is identical between the runs.
for use_log_y in (False, True):
    plot_ds_comp_distr_per_comp_metric(
        input_parquet=distance_complexity_parquet,
        output_dir=output_dir,
        nice_name_mapping=nice_name_mapping,
        bins=30,
        y_scale_log=use_log_y,
    )

In [None]:
def plot_ds_comp_vs_dist_per_comp_metric(input_parquet, output_dir, nice_name_mapping, bins=100, vmax=1000):
    """Show a 2D complexity-vs-distance density histogram per metric and dataset.

    Reads ``input_parquet``, keeps only the rows whose ``repo_id`` equals
    ``'max'`` and, for every combination of ``complexity_metric`` and dataset
    (column ``dir``), shows one bivariate histogram of the flattened
    ``complexity`` values (x) against the flattened ``distance`` values (y).

    Parameters
    ----------
    input_parquet : str or path-like
        Parquet file providing the columns ``repo_id``, ``dir``,
        ``distance_metric``, ``complexity_metric``, ``complexity`` and
        ``distance``.
    output_dir : pathlib.Path
        Currently unused: the figures are only shown, never written to disk.
        NOTE(review): confirm whether saving below ``output_dir`` was intended.
    nice_name_mapping : dict
        Passed to ``get_nice_name`` to translate dataset and distance-metric
        identifiers into display names.
    bins : int, optional
        Forwarded to ``seaborn.histplot`` (applies to both axes); also used
        in the x-axis label.
    vmax : float, optional
        Upper limit of the color scale, forwarded through
        ``seaborn.histplot``. A fixed value keeps the colors comparable
        between figures.

    Raises
    ------
    ValueError
        Raised by ``np.stack`` when no rows remain after the ``repo_id``
        filter.
    """
    df = pd.read_parquet(input_parquet)
    df = df.query("repo_id == 'max'").copy()

    # Element-wise DataFrame.map, as in the original code, requires pandas >= 2.1.
    df[["dir", "distance_metric"]] = df[["dir", "distance_metric"]].map(
        lambda x: get_nice_name(x, nice_name_mapping)
    )

    # Global maximum over every remaining row, so all figures share one y-range.
    max_distance = np.stack(df.distance).flatten().max()

    for cm, cm_group in df.groupby("complexity_metric", observed=True):
        normalized_complexity_df = normalize_column_of_nparrays(cm_group, "complexity")

        for ds, ds_group in normalized_complexity_df.groupby("dir", observed=True):
            # One explicit figure per dataset; closed right after it is shown.
            fig, ax = plt.subplots()
            sns.histplot(
                x=np.stack(ds_group.complexity).flatten(),
                y=np.stack(ds_group.distance).flatten(),
                bins=bins,
                # assumes the normalized complexity lies in [0, 1] -- TODO confirm
                binrange=((0, 1), (0, max_distance)),
                stat="density",
                # Numeric color limit; the original passed the string "1000".
                vmax=vmax,
                ax=ax,
            )

            ax.set_title(f"Complexity vs. Distance of {ds}")
            ax.set_xlabel(f"{cm.upper()} Complexity ({bins} Bins)")
            ax.set_ylabel(f"{ds_group.iloc[0].distance_metric.upper()} Distance")
            ax.grid(alpha=0.5)
            plt.show()
            plt.close(fig)

In [None]:
# Show the complexity-vs-distance density maps with 100 bins per axis.
plot_ds_comp_vs_dist_per_comp_metric(
    input_parquet=distance_complexity_parquet,
    output_dir=output_dir,
    nice_name_mapping=nice_name_mapping,
    bins=100,
)