In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from pyprojroot import here
import json
import torch

def _descriptor_color(metric, descriptor_colors, default="black"):
    """Return the color of the first descriptor key that occurs in ``metric``.

    ``descriptor_colors`` is searched in insertion order; ``default`` is returned
    when no key is a substring of the metric name.
    """
    return next((color for key, color in descriptor_colors.items() if key in metric), default)


def _plot_unit_interval_panel(ax, data, column, color, ylabel, title, xlabel,
                              show_title, show_xlabel):
    """Plot ``data[column]`` against ``data["num_steps"]`` on a fixed [0, 1] y-axis."""
    ax.plot(data["num_steps"], data[column], color=color)
    ax.set_ylabel(ylabel)
    ax.set_ylim([0, 1])
    ax.yaxis.set_major_locator(plt.MaxNLocator(6))
    if show_title:
        ax.set_title(title)
    if show_xlabel:
        ax.set_xlabel(xlabel)


def _plot_mmd_panel(ax_mmd, data, mmd_metrics, descriptor_colors):
    """Overlay every MMD metric on ``ax_mmd``, each on its own right-hand twin y-axis."""
    # One twin axis per metric; each successive right spine is pushed 35 pt further
    # outward so the tick labels of neighbouring axes do not overlap.
    twin_axes = []
    for j in range(len(mmd_metrics)):
        ax_twin = ax_mmd.twinx()
        ax_twin.spines['right'].set_position(('outward', 35 * j))
        twin_axes.append(ax_twin)

    # The host axis only carries the shared y-label; its own ticks and spine are hidden.
    ax_mmd.set_yticks([])
    ax_mmd.set_ylabel("MMD²")
    ax_mmd.spines['left'].set_visible(False)

    for j, (metric, ax_twin) in enumerate(zip(mmd_metrics, twin_axes)):
        color = _descriptor_color(metric, descriptor_colors)
        values = data[metric]
        max_val = values.max()
        if max_val > 0:
            # Rescale by the leading power of ten; the exponent is annotated separately.
            power = int(np.floor(np.log10(max_val)))
            ax_twin.plot(data["num_steps"], values / 10 ** power, color=color)
        else:
            # log10 is undefined for a non-positive (or all-NaN) maximum: plot unscaled.
            power = None
            ax_twin.plot(data["num_steps"], values, color=color)
        ax_twin.tick_params(axis='y', labelcolor=color)
        ax_twin.yaxis.set_major_locator(plt.MaxNLocator(6))
        if power is not None:
            # Hand-tuned axes-fraction offset meant to sit above the j-th outward spine.
            ax_twin.text(1.02 + j * 0.28, 1.2, f'$\\times 10^{{{power}}}$',
                         transform=ax_twin.transAxes,
                         color=color, fontsize=8,
                         verticalalignment='top',
                         horizontalalignment='left')


def plot_metrics_advanced(datasets_data, dataset_labels=None, xlabel="Number of Steps", 
                         mmd_metrics=None, title_prefix="", single_dataset=False):
    """
    Plot validity, PolyGraphScore and MMD² curves for one or several datasets.

    Each dataset occupies one row of three panels: validity and PolyGraphScore
    (both on a fixed [0, 1] y-axis), and all MMD metrics overlaid on separate
    right-hand y-axes, each rescaled by its leading power of ten. The
    ``num_steps`` column provides the x-values of every panel.

    Args:
        datasets_data: List of DataFrames (for multiple datasets; a lone DataFrame
            is also accepted) or a single DataFrame when ``single_dataset`` is True.
            In single-dataset mode only the first entry of a list is drawn.
        dataset_labels: Row labels for the datasets (ignored for a single dataset).
            ``None`` or a list shorter than ``datasets_data`` leaves the remaining
            rows unlabeled instead of dropping them.
        xlabel: Label for the x-axis of the bottom row.
        mmd_metrics: List of MMD metric (column) names to overlay in the third panel.
        title_prefix: Prefix for the column titles on the top row.
        single_dataset: Whether to draw a single 1x3 row (True) or an
            ``n_datasets`` x 3 grid (False).

    Returns:
        The matplotlib figure object.

    Raises:
        ValueError: If there is no dataset to plot.
    """
    if mmd_metrics is None:
        mmd_metrics = ["orbit_mmd", "degree_mmd", "spectral_mmd", "clustering_mmd", "gin_mmd"]

    # Normalize inputs so that every dataset is paired with a (possibly empty) label.
    if single_dataset:
        if not isinstance(datasets_data, list):
            datasets_data = [datasets_data]
        # The 1x3 layout has room for exactly one dataset, and no row label is drawn.
        datasets_data = datasets_data[:1]
        dataset_labels = [""] * len(datasets_data)
    else:
        if isinstance(datasets_data, pd.DataFrame):
            datasets_data = [datasets_data]
        if dataset_labels is None:
            dataset_labels = []
        elif isinstance(dataset_labels, str):
            dataset_labels = [dataset_labels]
        else:
            dataset_labels = list(dataset_labels)
        # Pad with "" (= no label) so zip() below never silently drops a dataset.
        dataset_labels = (dataset_labels + [""] * len(datasets_data))[:len(datasets_data)]

    n_datasets = len(datasets_data)
    if n_datasets == 0:
        raise ValueError("datasets_data must contain at least one dataset")

    # Set up color scheme (sns.set_palette also changes the global color cycle).
    sns.set_palette("colorblind")
    colors = sns.color_palette("colorblind")
    descriptor_colors = {
        "orbit": colors[0],
        "degree": colors[1],
        "spectral": colors[2],
        "clustering": colors[3],
        "gin": colors[4],
    }

    # squeeze=False always yields a 2-D (rows x 3) array of axes.
    if single_dataset:
        fig, axes = plt.subplots(1, 3, figsize=(10, 2), squeeze=False)
        legend_y = -0.04
    else:
        fig, axes = plt.subplots(n_datasets, 3, figsize=(10, n_datasets * 1.8), squeeze=False)
        legend_y = -0.02

    legend_elements = [
        plt.Line2D([0], [0], color="#7e9ef7", lw=2, label="Validity"),
        plt.Line2D([0], [0], color="black", lw=2, label="PolyGraphScore"),
    ]
    legend_elements += [
        plt.Line2D([0], [0], color=_descriptor_color(metric, descriptor_colors), lw=2,
                   label=metric.replace("_mmd", "").title() + " RBF")
        for metric in mmd_metrics
    ]

    for i, data in enumerate(datasets_data):
        ax_validity, ax_polyscore, ax_mmd = axes[i]
        is_first_row = i == 0
        is_last_row = i == n_datasets - 1

        _plot_unit_interval_panel(
            ax_validity, data, "validity", "#7e9ef7", "Validity",
            f"{title_prefix}Validity" if title_prefix else "Validity",
            xlabel, is_first_row, is_last_row,
        )
        _plot_unit_interval_panel(
            ax_polyscore, data, "polyscore", "black", "PGS",
            f"{title_prefix}PolyGraphScore" if title_prefix else "PolyGraphScore",
            xlabel, is_first_row, is_last_row,
        )
        _plot_mmd_panel(ax_mmd, data, mmd_metrics, descriptor_colors)
        if is_last_row:
            ax_mmd.set_xlabel(xlabel)
        if is_first_row:
            ax_mmd.set_title(f"{title_prefix}MMD Metrics" if title_prefix else "MMD Metrics")

    # The legend sits below the axes (negative y); tight_layout() does not reserve
    # room for it, so callers are expected to save with bbox_inches='tight'.
    fig.legend(handles=legend_elements, loc='lower center', bbox_to_anchor=(0.5, legend_y),
               ncol=len(legend_elements), frameon=False)
    # tight_layout() recomputes every subplot margin, so no separate subplots_adjust().
    fig.tight_layout()

    # Row labels are added after layout so they can be centred on each row's actual
    # position, which works for any number of rows.
    if not single_dataset:
        for row_axes, label in zip(axes, dataset_labels):
            if label:
                box = row_axes[0].get_position()
                fig.text(-0.01, (box.y0 + box.y1) / 2, label,
                         rotation=90, verticalalignment='center', fontsize=12, fontweight='bold')
    return fig


def plot_single_dataset_advanced(results_df, mmd_metrics=None, xlabel="Number of Steps"):
    """
    Draw one results table as a single row of three panels.

    Thin convenience wrapper: forwards to ``plot_metrics_advanced`` with
    ``single_dataset=True`` and without dataset labels or a title prefix.

    Args:
        results_df: DataFrame with the results to plot.
        mmd_metrics: MMD column names for the third panel (``None`` selects the
            plotting function's default list).
        xlabel: Text for the x-axis label.

    Returns:
        The matplotlib figure produced by ``plot_metrics_advanced``.
    """
    plot_kwargs = {
        "xlabel": xlabel,
        "mmd_metrics": mmd_metrics,
        "single_dataset": True,
    }
    return plot_metrics_advanced(results_df, **plot_kwargs)


def plot_multiple_datasets_advanced(datasets_data, dataset_labels, xlabel="Number of Steps", mmd_metrics=None):
    """
    Draw several results tables as a grid with one row of three panels per dataset.

    Thin convenience wrapper: forwards its arguments to ``plot_metrics_advanced``
    with ``single_dataset=False`` and no title prefix.

    Args:
        datasets_data: List of DataFrames, one per dataset.
        dataset_labels: List of row labels, one per dataset.
        xlabel: Text for the x-axis label.
        mmd_metrics: MMD column names for the third panel (``None`` selects the
            plotting function's default list).

    Returns:
        The matplotlib figure produced by ``plot_metrics_advanced``.
    """
    return plot_metrics_advanced(
        datasets_data,
        dataset_labels,
        xlabel,
        mmd_metrics,
        single_dataset=False,
    )

# Shared matplotlib style and random seeds for the whole notebook.
RCPARAMS_PATH = "/fs/pool/pool-mlsb/polygraph/rcparams.json"
SEED = 42

with open(RCPARAMS_PATH, "r") as f:
    style = json.load(f)

plt.rcParams.update(style)

np.random.seed(SEED)
# Trailing ';' keeps IPython from echoing the returned torch Generator as cell output.
torch.manual_seed(SEED);

In [None]:
import ipywidgets as widgets
from IPython.display import display

# Dropdown selecting which metric's CSV files (`{selected_metric}_*.csv`) later cells load.
dropdown = widgets.Dropdown(
    options=[('Jensen-Shannon-Distance', 'jsd'), ('Informedness', 'informedness-adaptive')],
    value='jsd',
    description='Metric:',
)

# print() inside an observe-callback is not shown in the cell on JupyterLab, so the
# confirmation message is routed through an Output widget displayed with the dropdown.
status_output = widgets.Output()

def on_value_change(change):
    """Store the newly selected value in the global ``selected_metric`` and report it."""
    global selected_metric
    selected_metric = change.new
    status_output.clear_output()
    with status_output:
        print(f"Selected metric changed to: {selected_metric}")

dropdown.observe(on_value_change, names='value')

# Display the dropdown together with its status line
display(dropdown, status_output)

# Initial value; cells below only pick up a new selection when they are re-run.
selected_metric = dropdown.value

## Number of Denoising Iterations

In [None]:
print(selected_metric)
experiment = "model-quality/denoising-iterations"

results = pd.read_csv(
    f"/fs/pool/pool-mlsb/polygraph/{experiment}/{selected_metric}_planar.csv"
)
display(results)

# MMD metrics to plot (this global is also read by compute_dataset_correlations below)
mmd_metrics = ["orbit_mmd", "degree_mmd", "spectral_mmd", "clustering_mmd", "gin_mmd"]

# savefig does not create missing directories, so make sure the output folder exists.
plot_dir = here() / ".local/plots"
plot_dir.mkdir(parents=True, exist_ok=True)

# Create and save the plot using the advanced plotting function
fig = plot_single_dataset_advanced(results, mmd_metrics=mmd_metrics, xlabel="Number of Denoising Steps")
fig.savefig(
    plot_dir / f"{experiment.replace('/', '_')}_validity_polyscore_all_mmd_{selected_metric}.pdf",
    bbox_inches='tight'
)
plt.show()
plt.close(fig)

In [None]:
# Print a LaTeX table with all PolyGraphScores
pgs_columns = ["num_steps", "validity", "polyscore", "orbit_pgs", "degree_pgs", "spectral_pgs", "clustering_pgs", "gin_pgs"]
pgs_df = results[pgs_columns]
# Raw string: "\#" is not a valid Python escape, but LaTeX needs the literal backslash.
header = [r"\# Steps", "Validity", "PolyGraphScore", "Orbit PGS", "Degree PGS", "Spectral PGS", "Clustering PGS", "GIN PGS"]
# to_latex formatters must return strings: step count, percentage, then one \num{} per score column.
formatters = (
    [str, lambda x: "\\formatpercent{" + str(x) + "}"]
    + [lambda x: "\\num[round-mode=places, round-precision=2]{" + str(x) + "}"] * (len(pgs_columns) - 2)
)
print(pgs_df.to_latex(columns=pgs_columns, index=False, header=header, formatters=formatters, column_format="l|cc|ccccc"))

In [None]:
# Print a LaTeX table with all MMDs
mmd_columns = ["num_steps", "validity", "orbit_mmd", "degree_mmd", "spectral_mmd", "clustering_mmd", "gin_mmd"]
mmd_df = results[mmd_columns]
# Raw string: "\#" is not a valid Python escape, but LaTeX needs the literal backslash.
header = [r"\# Steps", "Validity", "Orbit RBF", "Degree RBF", "Spectral RBF", "Clustering RBF", "GIN RBF"]
# Exactly one formatter per column (7): step count, percentage, then one \num{} per MMD column.
formatters = (
    [str, lambda x: "\\formatpercent{" + str(x) + "}"]
    + [lambda x: "\\num[round-mode=places, round-precision=4]{" + str(x) + "}"] * (len(mmd_columns) - 2)
)
print(mmd_df.to_latex(columns=mmd_columns, index=False, header=header, formatters=formatters, column_format="l|c|ccccc"))

In [20]:
# Pearson correlation between validity and the other metrics, computed on `results`
# as loaded by the denoising-iterations cell.
iter_val_metrics = ["polyscore", "orbit_mmd", "degree_mmd", "spectral_mmd", "clustering_mmd", "gin_mmd"]
# Correlate only the needed columns: avoids building the full matrix and failing on
# non-numeric columns (DataFrame.corr no longer drops them by default in pandas >= 2.0).
iter_val_pearson_correlations = (
    results[["validity", *iter_val_metrics]]
    .corr(method="pearson")["validity"][iter_val_metrics]
)
iter_val_pearson_correlations

## Number of Training Epochs

In [None]:
print(f"Using metric: {selected_metric}")
datasets = ["planar-procedural", "sbm-procedural", "lobster-procedural"]
datasets_labels = ["Planar", "SBM", "Lobster"]

# Load one results table per dataset. The loop deliberately does not reuse the name
# `results`, so the denoising data loaded earlier stays intact for the cells that use it.
datasets_data = []
for dataset, label in zip(datasets, datasets_labels):
    samples_dir = f"/fs/pool/pool-mlsb/polygraph/digress-samples/{dataset}"
    datasets_data.append(pd.read_csv(f"{samples_dir}/{selected_metric}_{label.lower()}.csv"))

# savefig does not create missing directories, so make sure the output folder exists.
plot_dir = here() / ".local/plots"
plot_dir.mkdir(parents=True, exist_ok=True)

fig = plot_multiple_datasets_advanced(datasets_data, datasets_labels)
# Save before showing, in the same order as the denoising-iterations cell.
fig.savefig(
    plot_dir / f"all_training_epochs_{selected_metric}.pdf",
    bbox_inches='tight'
)
plt.show()
plt.close(fig)

In [None]:
import scipy.stats

# Correlation functions selectable by name; each returns (statistic, p-value).
_CORRELATION_FUNCTIONS = {
    "spearman": scipy.stats.spearmanr,
    "pearson": scipy.stats.pearsonr,
}


def compute_dataset_correlations(method="spearman", reference_col="num_steps", flip_metrics=("validity",),
                                 dataset_names=None, dataset_labels=None, metric_variant=None,
                                 mmd_metric_names=None):
    """Correlate each metric with ``reference_col`` for every training-epochs dataset.

    For every dataset, reads
    ``/fs/pool/pool-mlsb/polygraph/digress-samples/{dataset}/{metric_variant}_{label.lower()}.csv``
    and correlates ``reference_col`` with validity, polyscore and each MMD column.

    Args:
        method: ``"spearman"`` or ``"pearson"``.
        reference_col: Column every metric is correlated against.
        flip_metrics: Metrics whose correlation is negated before being stored.
        dataset_names: Dataset directory names; defaults to the global ``datasets``.
        dataset_labels: Display labels, one per dataset; defaults to the global
            ``datasets_labels``. The lowercased label is part of the CSV name.
        metric_variant: CSV prefix; defaults to the global ``selected_metric``.
        mmd_metric_names: MMD columns to include; defaults to the global ``mmd_metrics``.

    Returns:
        Dict mapping dataset label -> {metric name: correlation coefficient}.

    Raises:
        ValueError: If ``method`` is not a supported correlation method.
    """
    if method not in _CORRELATION_FUNCTIONS:
        raise ValueError(
            f"Unknown correlation method {method!r}; expected one of {sorted(_CORRELATION_FUNCTIONS)}"
        )
    correlate = _CORRELATION_FUNCTIONS[method]

    # Fall back to the notebook-level globals defined in earlier cells.
    if dataset_names is None:
        dataset_names = datasets
    if dataset_labels is None:
        dataset_labels = datasets_labels
    if metric_variant is None:
        metric_variant = selected_metric
    if mmd_metric_names is None:
        mmd_metric_names = mmd_metrics
    metric_names = ['validity', 'polyscore'] + list(mmd_metric_names)

    dataset_correlations = {}
    for dataset, label in zip(dataset_names, dataset_labels):
        experiment = f"/fs/pool/pool-mlsb/polygraph/digress-samples/{dataset}"
        results = pd.read_csv(f"{experiment}/{metric_variant}_{label.lower()}.csv")

        correlations = {}
        for metric in metric_names:
            corr, _ = correlate(results[reference_col], results[metric])
            correlations[metric] = -corr if metric in flip_metrics else corr

        dataset_correlations[label] = correlations

    return dataset_correlations

# Correlation (default method: Spearman) of every metric with `num_steps`, one row per dataset.
dataset_correlations = pd.DataFrame(compute_dataset_correlations()).T

columns = ['validity', 'polyscore', 'orbit_mmd', 'degree_mmd', 'spectral_mmd', 'clustering_mmd', 'gin_mmd']
header = ["Validity", "PGS", "Orbit RBF", "Deg. RBF", "Spec. RBF", "Clust. RBF", "GIN RBF"]
formatters = [lambda x: f"{100 * x:.2f}"] * len(columns)

# compute_dataset_correlations already negates 'validity' (its default flip_metrics),
# so negating the whole frame restores validity's sign and flips every other metric.
dataset_correlations = -dataset_correlations[columns]

latex_table = dataset_correlations.to_latex(
    columns=columns,
    index=True,
    header=header,
    formatters=formatters,
    column_format="l|cc|ccccc",
)
print(latex_table)

In [None]:
# Pearson correlation of each metric with validity: the denoising sweep (one "Planar" row)
# stacked on top of the training checkpoints of every dataset.
train_val_pearson_corrleations = compute_dataset_correlations(method="pearson", reference_col="validity")

columns = ['polyscore', 'orbit_mmd', 'degree_mmd', 'spectral_mmd', 'clustering_mmd', 'gin_mmd']
header = ["PGS", "Orbit RBF", "Deg. RBF", "Spec. RBF", "Clust. RBF", "GIN RBF"]
formatters = [lambda x: f"{100 * x:.2f}"] * len(columns)

denoising_rows = iter_val_pearson_correlations.to_frame(name="Planar").T
training_rows = pd.DataFrame(train_val_pearson_corrleations).T[columns]
# All correlations are sign-flipped before printing.
pearson_df = -pd.concat([denoising_rows, training_rows], keys=["Denoising", "Training"], axis=0)[columns]

print(pearson_df.to_latex(columns=columns, index=True, header=header, formatters=formatters, column_format="ll|cccccc"))