In [None]:
%load_ext autoreload
%autoreload 2

import os
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from modules.plots import normalize_column_of_nparrays, get_nice_name, configure_default_plot_style, set_default_figsize
from pathlib import Path
from modules.datasetify import create_dataset

In [None]:
# Configure Plots
# Apply the project's plot style and figure size before any figure is drawn.
# Both helpers are imported from modules.plots; they are assumed to set global
# matplotlib defaults — TODO confirm against that module.
configure_default_plot_style()
set_default_figsize("single", ratio=1.0, factor=1)

In [None]:
# Arguments
# Experiment run identifier; it selects the input and output subdirectories below.
experiment_id = "full-old"
# Parquet input handed to the distance/complexity plotting functions.
distance_complexity_parquet = f"../output/02/{experiment_id}/combined_dist_compl.parquet"
# Parquet input handed to plot_ap_vs_comp_for_comp_metric.
ap_comp_bins_parquet = f"../output/compute_ap_comp_bins/{experiment_id}/combined_ap_comp_bins.parquet"
# Raw identifier -> display label; passed to get_nice_name for dataset,
# distance-metric and model names.
# NOTE(review): the SDXL base key has no "-1.0" suffix while the refiner key
# does — confirm it matches the repo_id values present in the data.
nice_name_mapping = {
    "raise1k": "RAISE-1k",
    "synthbuster/midjourney-v5": "SB-MJ5",
    "synthbuster/stable-diffusion-1-3": "SB-SD1.3",
    "synthbuster/stable-diffusion-1-4": "SB-SD1.4",
    "synthbuster/stable-diffusion-2": "SB-SD2",
    "synthbuster/stable-diffusion-xl": "SB-SDXL",
    "lpips_vgg_2": "LPIPS$_2$",
    "CompVis/stable-diffusion-v1-1": "SD1.1",
    "stabilityai/stable-diffusion-2-base": "SD2",
    "stabilityai/stable-diffusion-xl-base": "SDXL-BASE",
    "stabilityai/stable-diffusion-xl-refiner-1.0": "SDXL-REFINE",
    "kandinsky-2-1": "KD2.1"
}
# save_plot: write each figure to disk as a PDF; show_plot: display each figure.
save_plot = False
show_plot = True

# Output
# Base directory; every plotting function writes into its own subdirectory of it.
output_dir = Path(f"../output/plots/{experiment_id}/")

In [None]:
def plot_comp_distr_for_ds_per_comp_metric(input_parquet, output_dir, nice_name_mapping, bins=30, y_scale_log=False, save_plot=True, show_plot=True):
    """Draw one complexity histogram figure per complexity metric.

    Reads ``input_parquet``, keeps only the rows whose ``repo_id`` equals
    ``'max'`` and, for every ``complexity_metric`` value, overlays one
    unfilled polygon histogram per ``dir`` (dataset) value.

    Args:
        input_parquet: Path of the parquet file to read.
        output_dir: Base directory (must support the ``/`` operator); PDFs are
            written to its ``plot_comp_distr_for_ds_per_comp_metric``
            subdirectory.
        nice_name_mapping: Mapping passed to ``get_nice_name`` to turn ``dir``
            values into legend labels.
        bins: Number of histogram bins.
        y_scale_log: If true, the count axis uses a log scale.
        save_plot: If true, each figure is saved as a PDF.
        show_plot: If true, each figure is displayed.
    """
    records = pd.read_parquet(input_parquet).query("repo_id == 'max'").copy()

    def to_label(raw_name):
        return get_nice_name(raw_name, nice_name_mapping)

    records["dir"] = records["dir"].map(to_label)

    for cm, metric_rows in records.groupby("complexity_metric", observed=True):
        # Normalisation is done per metric by the project helper
        normalized = normalize_column_of_nparrays(metric_rows, "complexity")

        # All datasets share the current axes, one outline each
        for dataset_label, dataset_rows in normalized.groupby("dir", observed=True):
            # Collapse the per-row arrays into a single 1-D sample vector
            flat_complexity = np.stack(dataset_rows["complexity"]).flatten()
            sns.histplot(
                x=flat_complexity,
                label=dataset_label,
                bins=bins,
                log_scale=(False, y_scale_log),
                element="poly",
                fill=False,
            )

        plt.title("Complexity Distribution")
        plt.xlabel(f"{cm.upper()} Complexity ({bins} Bins)")
        plt.ylabel(f"Count{' (log)' if y_scale_log else ''}")
        plt.xlim(0, 1)
        plt.legend(title="Datasets")

        if save_plot:
            target_dir = output_dir / "plot_comp_distr_for_ds_per_comp_metric"
            target_dir.mkdir(exist_ok=True, parents=True)
            target_file = target_dir / f"plot_bins_{bins}_y_{'log' if y_scale_log else 'lin'}_{cm}.pdf"
            plt.savefig(target_file)

        if show_plot:
            plt.show()

        # Start the next metric on a fresh figure
        plt.close()

In [None]:
# Render the complexity distributions twice: first with a linear count axis,
# then with a logarithmic one.
for use_log_counts in (False, True):
    plot_comp_distr_for_ds_per_comp_metric(
        input_parquet=distance_complexity_parquet,
        output_dir=output_dir,
        nice_name_mapping=nice_name_mapping,
        bins=30,
        y_scale_log=use_log_counts,
        save_plot=save_plot,
        show_plot=show_plot,
    )

In [None]:
def plot_dist_vs_comp_per_ds_per_comp_metric(input_parquet, output_dir, nice_name_mapping, bins=100, save_plot=True, show_plot=True, vmax=300):
    """Draw a 2D distance-vs-complexity histogram per dataset and complexity metric.

    Reads ``input_parquet``, keeps only the rows whose ``repo_id`` equals
    ``'max'`` and produces one figure for every (``complexity_metric``,
    ``dir``) combination.

    Args:
        input_parquet: Path of the parquet file to read.
        output_dir: Base directory (must support the ``/`` operator); PDFs are
            written to its ``plot_dist_vs_comp_per_ds_per_comp_metric``
            subdirectory.
        nice_name_mapping: Mapping passed to ``get_nice_name`` to turn ``dir``
            and ``distance_metric`` values into display labels.
        bins: Number of bins along each axis.
        save_plot: If true, each figure is saved as a PDF.
        show_plot: If true, each figure is displayed.
        vmax: Numeric upper limit of the colour scale, forwarded to
            ``sns.histplot``. Previously hard-coded as the string ``"300"``.
    """
    df = pd.read_parquet(input_parquet)

    df = df.query("repo_id == 'max'").copy()

    df[["dir", "distance_metric"]] = df[["dir", "distance_metric"]].map(
        lambda x: get_nice_name(x, nice_name_mapping)
    )

    # Same upper y-limit for every figure, taken over all retained rows
    max_distance = np.stack(df.distance).flatten().max()

    for cm, cm_group in df.groupby("complexity_metric", observed=True):
        normalized_complexity_df = normalize_column_of_nparrays(cm_group, "complexity")

        for ds, ds_group in normalized_complexity_df.groupby("dir", observed=True):
            sns.histplot(
                # Collapse the per-row arrays into 1-D sample vectors
                x=np.stack(ds_group.complexity).flatten(),
                y=np.stack(ds_group.distance).flatten(),
                bins=bins,
                stat="density",
                # A number, not a string: do not rely on implicit coercion downstream
                vmax=vmax
            )

            plt.title(f"Distance vs. Complexity for {ds}")
            plt.xlabel(f"{cm.upper()} Complexity ({bins} Bins)")
            # The label is read from the first row of the group
            plt.ylabel(f"{ds_group.iloc[0].distance_metric.upper()} Distance ({bins} Bins)")
            plt.xlim(0, 1)
            plt.ylim(0, max_distance)

            if save_plot:
                plot_output_dir = output_dir / "plot_dist_vs_comp_per_ds_per_comp_metric"
                plot_output_dir.mkdir(exist_ok=True, parents=True)
                plt.savefig(plot_output_dir / f"plot_bins_{bins}_{ds}_{cm}.pdf")

            if show_plot:
                plt.show()

            # Start the next dataset on a fresh figure
            plt.close()

In [None]:
# Draw the 2D distance-vs-complexity histograms (100 bins per axis) from the
# distance/complexity parquet, honouring the notebook-level save/show switches.
plot_dist_vs_comp_per_ds_per_comp_metric(
    input_parquet=distance_complexity_parquet, 
    output_dir=output_dir, 
    nice_name_mapping=nice_name_mapping, 
    bins=100,
    save_plot=save_plot,
    show_plot=show_plot
)

In [None]:
def plot_mean_dist_vs_comp_for_ds_per_comp_metric(input_parquet, output_dir, nice_name_mapping, bins=30, save_plot=True, show_plot=True):
    """Plot mean distance per complexity bin, one figure per complexity metric.

    Reads ``input_parquet``, keeps only the rows whose ``repo_id`` equals
    ``'max'`` and, for every ``complexity_metric`` value, draws one line per
    ``dir`` (dataset) value. Each line shows the mean distance of the samples
    falling into each of ``bins`` equal-width complexity bins spanning that
    dataset's complexity range; empty bins are skipped.

    Args:
        input_parquet: Path of the parquet file to read.
        output_dir: Base directory (must support the ``/`` operator); PDFs are
            written to its ``plot_mean_dist_vs_comp_for_ds_per_comp_metric``
            subdirectory.
        nice_name_mapping: Mapping passed to ``get_nice_name`` to turn ``dir``
            and ``distance_metric`` values into display labels.
        bins: Number of complexity bins.
        save_plot: If true, each figure is saved as a PDF.
        show_plot: If true, each figure is displayed.
    """
    df = pd.read_parquet(input_parquet)

    df = df.query("repo_id == 'max'").copy()

    df[["dir", "distance_metric"]] = df[["dir", "distance_metric"]].map(
        lambda x: get_nice_name(x, nice_name_mapping)
    )

    # Same upper y-limit for every figure, taken over all retained rows
    max_distance = np.stack(df.distance).flatten().max()

    for cm, cm_group in df.groupby("complexity_metric", observed=True):
        normalized_complexity_df = normalize_column_of_nparrays(cm_group, "complexity")

        for ds, ds_group in normalized_complexity_df.groupby("dir", observed=True):
            # Flatten complexity and distance arrays for binning
            complexities = np.stack(ds_group.complexity).flatten()
            distances = np.stack(ds_group.distance).flatten()

            # Bin the complexity values. np.digitize assigns values equal to the
            # last edge an index one past the final bin; clip so the samples with
            # maximum complexity are counted in the last bin instead of dropped.
            bins_edges = np.linspace(complexities.min(), complexities.max(), bins + 1)
            bin_indices = np.clip(np.digitize(complexities, bins_edges) - 1, 0, bins - 1)

            # Mean distance per bin; bins without samples stay NaN
            mean_distances = np.full(bins, np.nan)
            for i in range(bins):
                in_bin = bin_indices == i
                if in_bin.any():
                    mean_distances[i] = distances[in_bin].mean()

            bin_centers = (bins_edges[:-1] + bins_edges[1:]) / 2

            # Filter out empty bins for plotting
            valid = ~np.isnan(mean_distances)

            sns.lineplot(
                x=bin_centers[valid],
                y=mean_distances[valid],
                label=ds
            )

        plt.title(f"Mean Distance vs. Complexity")
        plt.xlabel(f"{cm.upper()} Complexity ({bins} Bins)")
        # The label is read from the first row of the metric group
        plt.ylabel(f"{cm_group.iloc[0].distance_metric.upper()} Distance")
        plt.xlim(0, 1)
        plt.ylim(0, max_distance)
        plt.legend(title="Datasets")

        if save_plot:
            plot_output_dir = output_dir / "plot_mean_dist_vs_comp_for_ds_per_comp_metric"
            plot_output_dir.mkdir(exist_ok=True, parents=True)
            plt.savefig(plot_output_dir / f"plot_bins_{bins}_{cm}.pdf")

        if show_plot:
            plt.show()

        # Start the next metric on a fresh figure
        plt.close()

In [None]:
# Draw the mean-distance-per-complexity-bin curves (30 bins), one figure per
# complexity metric with one line per dataset.
plot_mean_dist_vs_comp_for_ds_per_comp_metric(
    input_parquet=distance_complexity_parquet, 
    output_dir=output_dir, 
    nice_name_mapping=nice_name_mapping, 
    bins=30,
    save_plot=save_plot,
    show_plot=show_plot
)

In [None]:
def plot_mean_dist_vs_comp_for_comp_metric_per_ds(input_parquet, output_dir, nice_name_mapping, bins=30, save_plot=True, show_plot=True):
    """Plot mean distance per complexity bin, one figure per dataset.

    Reads ``input_parquet``, keeps only the rows whose ``repo_id`` equals
    ``'max'``, normalises the ``complexity`` column separately for every
    ``complexity_metric`` and then, for every ``dir`` (dataset) value, draws
    one line per complexity metric. Each line shows the mean distance of the
    samples falling into each of ``bins`` equal-width complexity bins; empty
    bins are skipped.

    Args:
        input_parquet: Path of the parquet file to read.
        output_dir: Base directory (must support the ``/`` operator); PDFs are
            written to its ``plot_mean_dist_vs_comp_for_comp_metric_per_ds``
            subdirectory.
        nice_name_mapping: Mapping passed to ``get_nice_name`` to turn ``dir``
            and ``distance_metric`` values into display labels.
        bins: Number of complexity bins.
        save_plot: If true, each figure is saved as a PDF.
        show_plot: If true, each figure is displayed.
    """
    df = pd.read_parquet(input_parquet)

    df = df.query("repo_id == 'max'").copy()

    df[["dir", "distance_metric"]] = df[["dir", "distance_metric"]].map(
        lambda x: get_nice_name(x, nice_name_mapping)
    )

    # Same upper y-limit for every figure, taken over all retained rows
    max_distance = np.stack(df.distance).flatten().max()

    # Normalize complexity values metric by metric, then reassemble one frame
    normalized_groups = [
        normalize_column_of_nparrays(cm_group, "complexity")
        for _, cm_group in df.groupby("complexity_metric", sort=False, observed=True)
    ]
    df = pd.concat(normalized_groups, ignore_index=True)

    for ds, ds_group in df.groupby("dir", observed=True):
        for cm, cm_group in ds_group.groupby("complexity_metric", observed=True):
            # Flatten complexity and distance arrays for binning
            complexities = np.stack(cm_group.complexity).flatten()
            distances = np.stack(cm_group.distance).flatten()

            # Bin the complexity values. np.digitize assigns values equal to the
            # last edge an index one past the final bin; clip so the samples with
            # maximum complexity are counted in the last bin instead of dropped.
            bins_edges = np.linspace(complexities.min(), complexities.max(), bins + 1)
            bin_indices = np.clip(np.digitize(complexities, bins_edges) - 1, 0, bins - 1)

            # Mean distance per bin; bins without samples stay NaN
            mean_distances = np.full(bins, np.nan)
            for i in range(bins):
                in_bin = bin_indices == i
                if in_bin.any():
                    mean_distances[i] = distances[in_bin].mean()

            bin_centers = (bins_edges[:-1] + bins_edges[1:]) / 2

            # Filter out empty bins for plotting
            valid = ~np.isnan(mean_distances)

            sns.lineplot(
                x=bin_centers[valid],
                y=mean_distances[valid],
                label=cm.upper()
            )

        plt.title(f"Mean Distance vs. Complexity for {ds}")
        plt.xlabel(f"Complexity ({bins} Bins)")
        # The label is read from the first row of the dataset group
        plt.ylabel(f"{ds_group.iloc[0].distance_metric.upper()} Distance")
        plt.xlim(0, 1)
        plt.ylim(0, max_distance)
        plt.legend(title="Complexity Metrics")

        if save_plot:
            plot_output_dir = output_dir / "plot_mean_dist_vs_comp_for_comp_metric_per_ds"
            plot_output_dir.mkdir(exist_ok=True, parents=True)
            plt.savefig(plot_output_dir / f"plot_bins_{bins}_{ds}.pdf")

        if show_plot:
            plt.show()

        # Start the next dataset on a fresh figure
        plt.close()

In [None]:
# Draw the mean-distance-per-complexity-bin curves (30 bins), one figure per
# dataset with one line per complexity metric.
plot_mean_dist_vs_comp_for_comp_metric_per_ds(
    input_parquet=distance_complexity_parquet, 
    output_dir=output_dir, 
    nice_name_mapping=nice_name_mapping, 
    bins=30,
    save_plot=save_plot,
    show_plot=show_plot
)

In [None]:
def plot_mean_dist_vs_comp_for_reconst_ae_per_ds_per_comp_metric(input_parquet, output_dir, nice_name_mapping, bins=30, save_plot=True, show_plot=True):
    """Plot mean distance per complexity bin for every individual ``repo_id``.

    Reads ``input_parquet``, keeps only the rows whose ``repo_id`` is NOT
    ``'max'`` and produces one figure for every (``complexity_metric``,
    ``dir``) combination with one line per ``repo_id`` value. Each line shows
    the mean distance of the samples falling into each of ``bins`` equal-width
    complexity bins; empty bins are skipped.

    Args:
        input_parquet: Path of the parquet file to read.
        output_dir: Base directory (must support the ``/`` operator); PDFs are
            written to its
            ``plot_mean_dist_vs_comp_for_reconst_ae_per_ds_per_comp_metric``
            subdirectory.
        nice_name_mapping: Mapping passed to ``get_nice_name`` to turn ``dir``,
            ``distance_metric`` and ``repo_id`` values into display labels.
        bins: Number of complexity bins.
        save_plot: If true, each figure is saved as a PDF.
        show_plot: If true, each figure is displayed.
    """
    df = pd.read_parquet(input_parquet)

    df = df.query("repo_id != 'max'").copy()

    df[["dir", "distance_metric", "repo_id"]] = df[["dir", "distance_metric", "repo_id"]].map(
        lambda x: get_nice_name(x, nice_name_mapping)
    )

    # Same upper y-limit for every figure, taken over all retained rows
    max_distance = np.stack(df.distance).flatten().max()

    for cm, cm_group in df.groupby("complexity_metric", observed=True):
        normalized_complexity_df = normalize_column_of_nparrays(cm_group, "complexity")

        for ds, ds_group in normalized_complexity_df.groupby("dir", observed=True):
            for model, model_group in ds_group.groupby("repo_id", observed=True):
                # Flatten complexity and distance arrays for binning
                complexities = np.stack(model_group.complexity).flatten()
                distances = np.stack(model_group.distance).flatten()

                # Bin the complexity values. np.digitize assigns values equal to
                # the last edge an index one past the final bin; clip so the
                # samples with maximum complexity are counted in the last bin
                # instead of dropped.
                bins_edges = np.linspace(complexities.min(), complexities.max(), bins + 1)
                bin_indices = np.clip(np.digitize(complexities, bins_edges) - 1, 0, bins - 1)

                # Mean distance per bin; bins without samples stay NaN
                mean_distances = np.full(bins, np.nan)
                for i in range(bins):
                    in_bin = bin_indices == i
                    if in_bin.any():
                        mean_distances[i] = distances[in_bin].mean()

                bin_centers = (bins_edges[:-1] + bins_edges[1:]) / 2

                # Filter out empty bins for plotting
                valid = ~np.isnan(mean_distances)

                sns.lineplot(
                    x=bin_centers[valid],
                    y=mean_distances[valid],
                    label=model
                )

            plt.title(f"Mean Distance vs. Complexity for {ds}")
            plt.xlabel(f"{cm.upper()} Complexity ({bins} Bins)")
            # The label is read from the first row of the dataset group
            plt.ylabel(f"{ds_group.iloc[0].distance_metric.upper()} Distance")
            plt.xlim(0, 1)
            plt.ylim(0, max_distance)
            plt.legend(title="Reconst. AEs")

            if save_plot:
                plot_output_dir = output_dir / "plot_mean_dist_vs_comp_for_reconst_ae_per_ds_per_comp_metric"
                plot_output_dir.mkdir(exist_ok=True, parents=True)
                plt.savefig(plot_output_dir / f"plot_bins_{bins}_{ds}_{cm}.pdf")

            if show_plot:
                plt.show()

            # Start the next dataset on a fresh figure
            plt.close()

In [None]:
# Draw the mean-distance-per-complexity-bin curves (30 bins) for the individual
# repo_id values (every row whose repo_id is not 'max').
plot_mean_dist_vs_comp_for_reconst_ae_per_ds_per_comp_metric(
    input_parquet=distance_complexity_parquet, 
    output_dir=output_dir, 
    nice_name_mapping=nice_name_mapping, 
    bins=30,
    save_plot=save_plot,
    show_plot=show_plot
)

In [None]:
def plot_ap_vs_comp_for_comp_metric(input_parquet, output_dir, nice_name_mapping, save_plot=True, show_plot=True):
    """Plot average precision against complexity bin centre, one figure per complexity metric.

    Reads ``input_parquet``, keeps only the rows whose ``repo_id`` equals
    ``'max'`` and, for every ``complexity_metric`` value, draws one line per
    ``fake_dir`` value from the ``complexity_bin_center`` and ``ap`` columns.

    Args:
        input_parquet: Path of the parquet file to read.
        output_dir: Base directory (must support the ``/`` operator); PDFs are
            written to its ``plot_ap_vs_comp_for_comp_metric`` subdirectory.
        nice_name_mapping: Mapping passed to ``get_nice_name`` to turn
            ``fake_dir`` values into legend labels.
        save_plot: If true, each figure is saved as a PDF. The file name
            contains the complexity metric so figures for different metrics
            do not overwrite each other.
        show_plot: If true, each figure is displayed.
    """
    df = pd.read_parquet(input_parquet)

    df = df.query("repo_id == 'max'").copy()

    df[["fake_dir"]] = df[["fake_dir"]].map(
        lambda x: get_nice_name(x, nice_name_mapping)
    )

    # Same upper y-limit for every figure, taken over all retained rows
    max_ap = np.stack(df.ap).flatten().max()

    for cm, cm_group in df.groupby("complexity_metric", observed=True):
        for ds, ds_group in cm_group.groupby("fake_dir", observed=True):
            complexity_bin_center = np.stack(ds_group.complexity_bin_center).flatten()
            ap = np.stack(ds_group.ap).flatten()

            # Recover the number of bins as the number of distinct bin centres.
            # The value of the last dataset in the loop is used for the label
            # and the file name below.
            num_bins = len(np.unique(complexity_bin_center))

            sns.lineplot(
                x=complexity_bin_center,
                y=ap,
                label=ds
            )

        plt.title(f"AP vs. Complexity")
        plt.xlabel(f"{cm.upper()} Complexity ({num_bins} Bins)")
        plt.ylabel(f"Average Precision (AP)")
        plt.xlim(0, 1)
        plt.ylim(0, max_ap)
        plt.legend(title="Fake Datasets")

        if save_plot:
            plot_output_dir = output_dir / "plot_ap_vs_comp_for_comp_metric"
            plot_output_dir.mkdir(exist_ok=True, parents=True)
            # Include the metric in the name; without it every metric wrote to
            # the same path and only the last figure survived
            plt.savefig(plot_output_dir / f"plot_bins_{num_bins}_{cm}.pdf")

        if show_plot:
            plt.show()

        # Start the next metric on a fresh figure
        plt.close()

In [None]:
# Draw the AP-vs-complexity curves from the precomputed AP/complexity-bin
# parquet, one figure per complexity metric.
plot_ap_vs_comp_for_comp_metric(
    input_parquet=ap_comp_bins_parquet, 
    output_dir=output_dir, 
    nice_name_mapping=nice_name_mapping,
    save_plot=save_plot,
    show_plot=show_plot
)