In [None]:
%load_ext autoreload
%autoreload 2

from collections import defaultdict
import os
import numpy as np
from pathlib import Path
import pandas as pd
import tqdm
import matplotlib.pyplot as plt
from matplotlib import ticker
from dataclasses import field, dataclass
import logging
from multiprocessing import Pool
import seaborn as sns
import time
from torchvision.transforms import Compose
import functools
import torch
import torch.nn as nn
from diffusers import UNet2DModel
from diffusers import DDPMScheduler
from typing import Optional

import matplotlib
from matplotlib.lines import Line2D
import matplotlib.patches as mpatches
import matplotlib.text as mtext
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
import matplotlib.colors as mcolors
from matplotlib.colors import rgb_to_hsv, hsv_to_rgb, to_rgb

# LaTeX + matplotlib Preamble

In [None]:
# Base palette, written as 0-255 RGB triples and rescaled to the 0-1 floats
# that matplotlib expects.
COLORS = [
    tuple(channel / 255 for channel in rgb)
    for rgb in [
        (204, 57, 42),    # palette1
        (79, 155, 143),   # palette2
        (44, 97, 194),    # palette3
        (217, 116, 89),   # palette4
        (228, 197, 119),  # palette5
        (155, 106, 145),  # palette6 converted from "#9B6A91"
        (51, 110, 49),    # palette7 converted from "#336E31"
        (198, 5, 79),     # palette8 converted from "#C6054F"
    ]
]
sns.palplot(COLORS)  # preview the palette in the cell output

golden_ratio = (1 + 5 ** 0.5) / 2

markerlist = ['o', 4, 5, 6, 7, 'p', 'd', 'P']

# ICLR 2024 width: 397.48499pt
text_width_pt = 397.48499  # in pt
text_width = text_width_pt / 72.27  # in inches

# Font sizes: ticks (fs_m1), regular figure text (fs), figure titles (fs_p1).
fs_m1, fs, fs_p1 = 7, 9, 10

# Line width passed to the 'axes' rc group.
axes_lw = 0.7

In [None]:
import matplotlib

%matplotlib inline
# Global matplotlib styling: apply the font sizes and line width defined in the
# constants cell to every figure created afterwards.
matplotlib.rc('font', size=fs)          # controls default text sizes
matplotlib.rc('axes', titlesize=fs)     # fontsize of the axes title
matplotlib.rc('axes', labelsize=fs)    # fontsize of the x and y labels
matplotlib.rc('axes', linewidth=axes_lw)    # line width of the axes frame (spines)
matplotlib.rc('xtick', labelsize=fs_m1)    # fontsize of the tick labels
matplotlib.rc('ytick', labelsize=fs_m1)    # fontsize of the tick labels
matplotlib.rc('legend', fontsize=fs_m1)    # legend fontsize
matplotlib.rc('figure', titlesize=fs_p1)  # fontsize of the figure title

# Saved figures get an opaque white background.
plt.rcParams["savefig.facecolor"] = "white"

# matplotlib.rc('font', **{'family': 'serif', 'serif': ['Times New Roman']})
matplotlib.rc('font', **{'family': 'serif'})

# Render all text through LaTeX. This is enabled unconditionally and needs a
# working TeX installation; set usetex=False if TeX is not installed.
matplotlib.rc('text', usetex=True)

# Extra LaTeX package made available to every text element.
plt.rcParams.update({'text.latex.preamble': r'\usepackage{amsfonts}'})


In [None]:
def darken_color(color, factor=0.7):
    """
    Darken a color by scaling each of its RGB channels by ``factor``.

    ``factor=1`` leaves the color unchanged and ``factor=0`` yields black.
    Input can be a matplotlib color string, hex string, or RGB tuple; the
    result is an RGB tuple of floats. If matplotlib cannot parse ``color``
    (``ValueError``), the input is returned unchanged.

    Examples:
    >> darken_color('g', 0.3)
    >> darken_color('#F034A3', 0.6)
    >> darken_color((0.3, 0.55, 0.1), 0.5)
    """
    try:
        red, green, blue = mcolors.to_rgb(color)
    except ValueError:
        # Best effort: hand back whatever the caller passed in.
        return color
    return (factor * red, factor * green, factor * blue)

def lighten_color(color, factor=0.7):
    """
    Lighten a color by blending each RGB channel towards white.

    Every channel ``c`` becomes ``(1 - factor) * c + factor``, so ``factor=0``
    leaves the color unchanged and ``factor=1`` yields white.
    Input can be a matplotlib color string, hex string, or RGB tuple; the
    result is an RGB tuple of floats. If matplotlib cannot parse ``color``
    (``ValueError``), the input is returned unchanged.

    Examples:
    >> lighten_color('g', 0.3)
    >> lighten_color('#F034A3', 0.6)
    >> lighten_color((0.3, 0.55, 0.1), 0.5)
    """
    try:
        rgb = mcolors.to_rgb(color)
    except ValueError:
        # Best effort: hand back whatever the caller passed in.
        return color
    keep = 1 - factor  # share of the original channel value that is kept
    return tuple(keep * channel + factor for channel in rgb)




def interpolate_hsv(color1, color2, t):
    """
    Blend two colors in HSV space, moving the hue along the shorter arc.

    Parameters:
        color1, color2: end-point colors in any form matplotlib's ``to_rgb``
            accepts (e.g., rgb tuple, hex, named color)
        t: interpolation parameter in [0, 1]; 0 gives color1, 1 gives color2

    Returns:
        interpolated RGB color as a tuple
    """
    start_h, start_s, start_v = rgb_to_hsv(np.array(to_rgb(color1)))
    end_h, end_s, end_v = rgb_to_hsv(np.array(to_rgb(color2)))

    # Hue lives on a circle of circumference 1: wrap the difference so that we
    # never travel more than half way around.
    hue_delta = end_h - start_h
    if hue_delta > 0.5:
        hue_delta -= 1.0
    elif hue_delta < -0.5:
        hue_delta += 1.0

    hsv = np.array(
        [
            (start_h + t * hue_delta) % 1.0,
            (1 - t) * start_s + t * end_s,
            (1 - t) * start_v + t * end_v,
        ]
    )

    # hsv_to_rgb is applied to a 1x1 "image" and flattened back to 3 values.
    return tuple(hsv_to_rgb(hsv.reshape(1, 1, 3)).flatten())


In [None]:

@dataclass
class DatasetMeasurementPlotConfig:
    """Settings for one dataset/measurement panel of an LDS figure.

    The three per-method lists (``influence_paths``, ``influence_paths_untuned``
    and ``method_names``) are index-aligned: entry ``k`` of each describes the
    same attribution method.
    """

    # Title drawn above the panel.
    axtitle: str
    # Dataset identifier forwarded to the LDS scoring script.
    dataset: str
    # Directory with the measurements, forwarded to the LDS scoring script.
    measurement_dir: str

    # --- Per method settings ---
    # One .npy score file per method (the main values that get plotted).
    influence_paths: list[str]
    # Score file computed with the untuned damping factor for each method;
    # None means there is no untuned variant for that method.
    influence_paths_untuned: list[Optional[str]]
    # Display name of each method (used as y tick labels).
    method_names: list[str]

    # Forwarded to the LDS scoring script; None is passed on as 'null'.
    train_model_idxs_path: Optional[str] = "/srv/shared/outputs/idxs/DatasetType.cifar2/idx_train.csv"
    # Forwarded to the LDS scoring script.
    retrained_model_idxs_dir: str = "/srv/shared/outputs/idxs/DatasetType.cifar2/retrain"
    # Forwarded to the LDS scoring script.
    num_retrained_models: int = 100



In [None]:
import re
import subprocess

# Repository root relative to the notebook's working directory; it is resolved
# to an absolute path when locating scripts/lds_score.py.
repository_base = Path("..")

def extract_correlation_values(output: str) -> tuple[Optional[float], Optional[float], Optional[float], Optional[float]]:
    """Parse the four LDS summary numbers out of the scoring script's stdout.

    Each value is looked up by its label (first occurrence wins). The labels are
    matched case-sensitively and spelled exactly as searched for here, including
    the "Accross" spelling.

    Args:
        output: Text to search.

    Returns:
        ``(rank_correlation, stde, accross_seeds_correlation, across_seeds_stde)``;
        an entry is ``None`` when its label (followed by a number) is not found.
    """
    # A signed decimal number with optional exponent, captured as group 1.
    number = r"([-+]?(?:\d*\.\d+|\d+)(?:[eE][-+]?\d+)?)"
    labels = (
        r"Rank correlation mean: ",
        r"Rank correlation stde: ",
        r"Accross seeds \(to other averaged\) rank correlation mean: ",
        r"Accross seeds \(to other averaged\) rank correlation stde: ",
    )

    values = []
    for label in labels:
        found = re.search(label + number, output)
        values.append(None if found is None else float(found.group(1)))
    return tuple(values)


def extract_lds_scores_for_method(
    influence_path: Path,
    dataset: str,
    retrained_model_idxs_dir: str,
    measurement_dir: str,
    train_model_idxs_path: Optional[str],
    num_retrained_models: int,
):
    """Run ``scripts/lds_score.py`` for one score file and parse its output.

    Args:
        influence_path: ``.npy`` file with the scores of one method.
        dataset: Passed to the script as ``dataset=...``.
        retrained_model_idxs_dir: Passed as ``retrained_model_idxs_dir=...``.
        measurement_dir: Passed as ``measurement_dir=...``.
        train_model_idxs_path: Passed as ``train_model_idxs_path=...``;
            ``None`` (or an empty string) is sent as ``null``.
        num_retrained_models: Passed as ``num_retrained_models=...``.

    Returns:
        ``(rank_correlation, stde, accross_seeds_correlation, across_seeds_stde)``
        as parsed by ``extract_correlation_values``.

    Raises:
        FileNotFoundError: If the score file or its directory does not exist.
        ValueError: If the mean of the scores is NaN, or if any of the four
            values cannot be parsed from the script's stdout (the message then
            contains the command, return code, stdout and stderr).
    """
    import shlex  # only needed to render the command for the log message

    if not influence_path.parent.exists():
        # Directory not found
        raise FileNotFoundError(f"Directory not found: {influence_path.parent}")
    if not influence_path.exists():
        raise FileNotFoundError(f"Missing {influence_path}")
    if np.isnan(np.load(influence_path).mean()):
        raise ValueError(f"Some values are NaN in {influence_path}.")

    # Build an argument vector instead of a shell string so that paths with
    # spaces or shell metacharacters reach the script unmodified.
    argv = [
        "python",
        f"{repository_base.resolve()}/scripts/lds_score.py",
        "--config-name",
        "compute_lds_score_default",
        f"dataset={dataset}",
        f"retrained_model_idxs_dir={retrained_model_idxs_dir}",
        f"measurement_dir={measurement_dir}",
        f"train_model_idxs_path={train_model_idxs_path if train_model_idxs_path else 'null'}",
        f"num_retrained_models={num_retrained_models}",
        f"influence_path={influence_path}",
    ]
    command = shlex.join(argv)  # copy-pasteable shell form, for logs only
    print(f"Running command: {command}")

    # Execute the command and capture stdout/stderr
    completed = subprocess.run(argv, capture_output=True)
    stdoutput = completed.stdout.decode()
    stderror = completed.stderr.decode()

    # Extract and collect the correlation values
    rank_correlation, stde, accross_seeds_correlation, across_seeds_stde = extract_correlation_values(stdoutput)
    if rank_correlation is None or stde is None or accross_seeds_correlation is None or across_seeds_stde is None:
        raise ValueError(
            f"Failed to extract values for {influence_path} \n"
            f"rank_correlation={rank_correlation}, stde={stde}, "
            f"accross_seeds_correlation={accross_seeds_correlation}, "
            f"across_seeds_stde={across_seeds_stde}"
            f"\nCommand:\n{command}"
            f"\nReturn code: {completed.returncode}"
            f"\nOutput:\n{stdoutput}"
            f"\nError:\n{stderror}"
        )
    return rank_correlation, stde, accross_seeds_correlation, across_seeds_stde

def get_lds_scores_for_config(config: DatasetMeasurementPlotConfig):
    """Compute the LDS scores of every method listed in ``config``.

    Each path in ``config.influence_paths`` and each non-``None`` path in
    ``config.influence_paths_untuned`` is scored with
    ``extract_lds_scores_for_method`` (one script run per path).

    Args:
        config: Panel configuration providing the score paths and the settings
            forwarded to the scoring script.

    Returns:
        A 6-tuple
        ``(rank_correlations, stde_values, across_seeds_correlation,
        across_seeds_stde, rank_correlations_untuned, stde_values_untuned)``.
        The first two and the last two are lists with one entry per method;
        the untuned lists hold ``None`` where the config has no untuned path.
        The two across-seeds values are taken from the first method.

    Raises:
        ValueError: If ``config.influence_paths`` is empty.
        AssertionError: If the across-seeds values differ between methods.
    """
    if not config.influence_paths:
        # Without this the function would only fail at the very end with an
        # opaque IndexError on the across-seeds lists.
        raise ValueError("config.influence_paths is empty; at least one method is required.")

    def score(influence_path):
        """Score one file with the settings shared by all methods of `config`."""
        return extract_lds_scores_for_method(
            influence_path=Path(influence_path),
            dataset=config.dataset,
            retrained_model_idxs_dir=config.retrained_model_idxs_dir,
            measurement_dir=config.measurement_dir,
            train_model_idxs_path=config.train_model_idxs_path,
            num_retrained_models=config.num_retrained_models,
        )

    rank_correlations = []
    stde_values = []
    across_seeds_correlations = []
    across_seeds_stde_values = []
    for influence_path in config.influence_paths:
        rank_correlation, stde, across_seeds_correlation, across_seeds_stde = score(influence_path)
        # Only the magnitude of the correlation is kept; its sign is discarded.
        # NOTE(review): presumably because methods differ in sign convention - confirm.
        rank_correlations.append(np.abs(rank_correlation))
        stde_values.append(stde)
        across_seeds_correlations.append(across_seeds_correlation)
        across_seeds_stde_values.append(across_seeds_stde)

    # Assert across seed correlations almost the same:
    for across_seed in across_seeds_correlations[1:]:
        assert np.isclose(across_seeds_correlations[0], across_seed, rtol=1e-8), f"Across seed correlation values are not the same: {across_seeds_correlations}"
    for across_seed_stde in across_seeds_stde_values[1:]:
        assert np.isclose(across_seeds_stde_values[0], across_seed_stde, rtol=1e-8), f"Across seed stde values are not the same: {across_seeds_stde_values}"

    # Also extract the untuned values
    rank_correlations_untuned, stde_values_untuned = [], []
    for influence_path in config.influence_paths_untuned:
        if influence_path is None:
            # Keep the lists index-aligned with the methods.
            rank_correlations_untuned.append(None)
            stde_values_untuned.append(None)
            continue
        rank_correlation, stde, _, _ = score(influence_path)
        rank_correlations_untuned.append(np.abs(rank_correlation))
        stde_values_untuned.append(stde)

    return (
        rank_correlations,
        stde_values,
        across_seeds_correlations[0],
        across_seeds_stde_values[0],
        rank_correlations_untuned,
        stde_values_untuned,
    )

In [None]:
# Test config: a single CIFAR-2 loss-measurement panel.
# Note that the non-None `influence_paths_untuned` entries below point to the
# same files as the corresponding `influence_paths` entries.
config_loss_cifar2 = DatasetMeasurementPlotConfig(
    axtitle="\\texttt{CIFAR}-2",
    dataset="cifar2",
    measurement_dir="/srv/shared/outputs/measurements/DatasetType.cifar2/ddpm_samples_loss_measurement",
    influence_paths=[
        "/srv/shared/outputs/DatasetType.cifar2/idx_train-0/clip-similarity-scores/scores.npy",
        "/srv/shared/outputs/DatasetType.cifar2/idx_train-0/trak_for_ddpm_samples__measure_LOSS_trainloss_LOSS_250measure_250loss__projdim32768_float32_damping1e-1/scores.npy",
        "/srv/shared/outputs/DatasetType.cifar2/idx_train-0/trak_for_ddpm_samples__measure_SQUARE_NORM_trainloss_SQUARE_NORM_250measure_250loss__projdim32768_float32_damping1/scores.npy",
        "/srv/shared/outputs/DatasetType.cifar2/idx_train-0/influence_for_ddpm_samples_LOSS_1000kfac_250measurement_250loss__quantize_8bits__damping1e-8/influence_scores.npy"
    ],
    # Index-aligned with `influence_paths`; None = no untuned variant.
    influence_paths_untuned=[
        None,
        "/srv/shared/outputs/DatasetType.cifar2/idx_train-0/trak_for_ddpm_samples__measure_LOSS_trainloss_LOSS_250measure_250loss__projdim32768_float32_damping1e-1/scores.npy",
        "/srv/shared/outputs/DatasetType.cifar2/idx_train-0/trak_for_ddpm_samples__measure_SQUARE_NORM_trainloss_SQUARE_NORM_250measure_250loss__projdim32768_float32_damping1/scores.npy",
        "/srv/shared/outputs/DatasetType.cifar2/idx_train-0/influence_for_ddpm_samples_LOSS_1000kfac_250measurement_250loss__quantize_8bits__damping1e-8/influence_scores.npy"
    ],
    # Index-aligned with `influence_paths`.
    method_names=[
        "Clip Embedding Cosine Distance",
        "TRAK",
        "D-TRAK",
        "K-FAC Influence",
    ],
    train_model_idxs_path="/srv/shared/outputs/idxs/DatasetType.cifar2/idx_train.csv",
    retrained_model_idxs_dir="/srv/shared/outputs/idxs/DatasetType.cifar2/retrain"
)

In [None]:
def make_figure(
    configs,
    results,
    suptitle,
    colors=COLORS,
    height=0.28,
    width=1.0,
    extra_separators_at: list[int] | tuple[int] = tuple(),
    extra_separator_space=0.1,
    rank_correlation_fontsize=fs,
):
    """Plot one row of LDS panels, one panel per config.

    Every panel lists the methods of its config plus a final "Exact Retraining"
    row (the across-seeds correlation) from top to bottom. For each row the
    rank correlation is drawn in percent with a +/- one standard-error bar and
    annotated with its value. Untuned results, where available, are drawn as
    lighter hollow markers with a small gray value label.

    Args:
        configs: Sequence of ``DatasetMeasurementPlotConfig``, one per panel.
        results: One 6-tuple per config, in the layout returned by
            ``get_lds_scores_for_config``.
        suptitle: Currently unused (no figure title is drawn); kept so existing
            callers keep working.
        colors: One color per row, indexed by row position; needs at least
            ``len(method_names) + 1`` entries.
        height: Figure height as a fraction of ``text_width``.
        width: Figure width as a fraction of ``text_width``.
        extra_separators_at: Row indices; every row at or below such an index
            is shifted down by ``extra_separator_space`` (per index), which
            visually groups the rows.
        extra_separator_space: Vertical gap added per separator.
        rank_correlation_fontsize: Font size of the value annotations; the
            standard-error part and the untuned labels use 70% of it.

    Returns:
        ``(fig, axes)`` as created by ``plt.subplots``.
    """
    fig, axes = plt.subplots(
        ncols=len(configs), nrows=1, figsize=(text_width * width, text_width * height)
    )
    plt.tight_layout()
    small_fontsize = int(rank_correlation_fontsize * 0.7)

    for i, (result, config) in enumerate(zip(results, configs)):
        assert len(config.method_names) == len(config.influence_paths)
        # plt.subplots returns a bare Axes (not an array) for a single column.
        ax = axes[i] if len(configs) > 1 else axes
        (
            rank_correlations,
            stde_values,
            accross_seeds_correlation,
            across_seeds_stde,
            rank_correlations_untuned,
            stde_values_untuned,
        ) = result
        all_method_names = config.method_names + ["Exact Retraining"]
        all_rank_correlations = rank_correlations + [accross_seeds_correlation]
        all_stdes = stde_values + [across_seeds_stde]

        # Method labels go on the right; only the last panel shows them.
        ax.yaxis.tick_right()
        if i != len(configs) - 1:
            ax.yaxis.set_visible(False)
        # Draw tick marks on both the left and right side of the plot
        ax.yaxis.set_ticks_position("both")

        # Rows are counted from the top, so their y coordinates are negative.
        method_y_locs_from_up = [
            j + sum(extra_separator_space for sep in extra_separators_at if j >= sep)
            for j in range(len(all_method_names))
        ]
        method_y_locs = -np.array(method_y_locs_from_up)

        # Set the y-axis labels (ticks must be given in increasing order):
        ax.set_yticks(method_y_locs[::-1])
        ax.set_yticklabels(all_method_names[::-1])
        # Leave 70% headroom to the right of the best score for annotations.
        max_xlim = max(all_rank_correlations) * 100 * 1.7

        # --- Main (tuned) results ---
        for j, (_method_name, rank_correlation, stde) in enumerate(
            zip(all_method_names, all_rank_correlations, all_stdes)
        ):
            method_y_loc = method_y_locs[j]
            color = colors[j]
            ax.errorbar(
                x=[rank_correlation * 100],
                y=[method_y_loc],
                # stde is a fraction like rank_correlation, so it has to be
                # scaled to percent as well to match the x axis.
                xerr=[stde * 100],
                fmt="o",
                color=color,
                markersize=2,
                capsize=3,
                elinewidth=0.5,
                clip_on=False,
            )
            # "<value>%" followed by a smaller "+/- <stde>" (LaTeX markup).
            value_string = (
                f"${100 * rank_correlation:.1f}\\%$"
                f"{{\\fontsize{{{small_fontsize}pt}}{{11pt}}\\selectfont"
                f"$\\pm {100 * stde:.1f}$}}"
            )
            ax.text(
                # Just to the right of the error bar's end.
                100 * rank_correlation + 100 * stde + 0.02 * max_xlim,
                method_y_loc,
                value_string,
                ha="left",
                va="center",
                fontsize=rank_correlation_fontsize,
                # White background so the guide line does not cross the text:
                bbox=dict(facecolor="white", edgecolor="none"),
                zorder=-1,
            )

        # --- Untuned results (skipped where the config has none) ---
        for j, (_method_name, rank_correlation, stde, rank_correlation_tuned) in enumerate(
            zip(all_method_names, rank_correlations_untuned, stde_values_untuned, rank_correlations)
        ):
            if rank_correlation is None or stde is None:
                continue
            method_y_loc = method_y_locs[j]
            color = colors[j]
            ax.errorbar(
                x=[rank_correlation * 100],
                y=[method_y_loc],
                xerr=[stde * 100],
                fmt="o",
                mfc='none',
                color=lighten_color(color, 0.4),
                markersize=4,
                capsize=0,
                elinewidth=0.5,
                clip_on=False,
            )
            value_string = f"(${100 * rank_correlation:.1f}\\%)$"
            # Put the label at the left edge if the tuned result already
            # occupies the right part of the panel, otherwise at the right edge.
            text_to_left = (100 * rank_correlation_tuned) > 0.75 * max_xlim
            ax.text(
                0.001 * max_xlim if text_to_left else 0.999 * max_xlim,
                method_y_loc + 0.2,
                value_string,
                # Anchor on the side the label sits at so it stays inside the axes.
                ha="left" if text_to_left else "right",
                va="bottom",
                fontsize=small_fontsize,
                color="gray",
                zorder=-1,
            )

        # Add gray lines to guide the eye to label:
        for j in range(len(all_method_names)):
            ax.axhline(
                method_y_locs[j],
                color=[0.9] * 3,
                linewidth=0.5,
                linestyle="-",
                zorder=-2,
            )
        ax.set_title(config.axtitle, pad=8)
        sns.despine(ax=ax, offset=3, left=True, right=False)
        # Tick labels are in bottom-to-top order, i.e. reversed w.r.t. `colors`.
        for j, tick in enumerate(ax.get_yticklabels()):
            tick.set_color(
                darken_color(colors[len(all_method_names) - 1 - j], factor=0.7)
            )

        ax.set_xlim(0, max_xlim)

    # Shared x label, placed slightly below the figure.
    fig.supxlabel(r"Rank Correlation $\%$ (LDS)", x=0.5, y=-0.1)

    # Reserve the right part of the figure for the method labels.
    right_margin = 0.76
    fig.subplots_adjust(right=right_margin, wspace=0.12)
    return fig, axes

## Loss

In [None]:
# Panels of the loss-measurement figure: CIFAR-2, CIFAR-10 and ArtBench.
# Within each config, `influence_paths`, `influence_paths_untuned` and
# `method_names` are index-aligned (None = no untuned variant).
configs = [
    DatasetMeasurementPlotConfig(
        axtitle="\\texttt{CIFAR-2}",
        dataset="cifar2",
        measurement_dir="/srv/shared/outputs/measurements/DatasetType.cifar2/ddpm_samples_loss_measurement",
        influence_paths=[
            "/srv/shared/outputs/DatasetType.cifar2/idx_train-0/clip-similarity-scores/scores.npy",
            "/srv/shared/outputs/DatasetType.cifar2/idx_train-0/trak_for_ddpm_samples__measure_LOSS_trainloss_LOSS_250measure_250loss__projdim32768_float32_damping1e-1/scores.npy",
            "/srv/shared/outputs/DatasetType.cifar2/idx_train-0/influence_for_ddpm_samples_LOSS_125ekfac_250measurement_250loss__quantize_8bits_damping1e-10/influence_scores.npy",
            "/srv/shared/outputs/DatasetType.cifar2/idx_train-0/trak_for_ddpm_samples__measure_SQUARE_NORM_trainloss_SQUARE_NORM_250measure_250loss__projdim32768_float32_damping1/scores.npy",
        ],
        influence_paths_untuned=[
            None,
            "/srv/shared/outputs/DatasetType.cifar2/idx_train-0/trak_for_ddpm_samples__measure_LOSS_trainloss_LOSS_250measure_250loss__projdim32768_float32_damping1e-6/scores.npy",
            "/srv/shared/outputs/DatasetType.cifar2/idx_train-0/influence_for_ddpm_samples_LOSS_125ekfac_250measurement_250loss__quantize_8bits_damping1e-8/influence_scores.npy",
            "/srv/shared/outputs/DatasetType.cifar2/idx_train-0/trak_for_ddpm_samples__measure_SQUARE_NORM_trainloss_SQUARE_NORM_250measure_250loss__projdim32768_float32_damping1e-6/scores.npy",
        ],
        # Bound once with := so the two configs below can reuse the same list.
        method_names=(method_names := [
            "CLIP Similarity",
            "TRAK",
            "\\textbf{K-FAC Influence}",
            "D-TRAK",
        ]),
        train_model_idxs_path="/srv/shared/outputs/idxs/DatasetType.cifar2/idx_train.csv",
        retrained_model_idxs_dir="/srv/shared/outputs/idxs/DatasetType.cifar2/retrain"
    ),
    DatasetMeasurementPlotConfig(
        axtitle="\\texttt{CIFAR-10}",
        dataset="cifar10",
        measurement_dir="/srv/shared/outputs/measurements/DatasetType.cifar10/ddpm_samples_loss_measurement",
        influence_paths=[
            "/srv/shared/outputs/DatasetType.cifar10/idx_train-0/clip-similarity-scores/scores.npy",
            "/srv/shared/outputs/DatasetType.cifar10/idx_train-0/trak_for_ddpm_samples__measure_LOSS_trainloss_LOSS_250measure_250loss__projdim32768_float32_damping1e-1/scores.npy",
            "/srv/shared/outputs/DatasetType.cifar10/idx_train-0/influence_for_ddpm_samples_LOSS_125ekfac_250measurement_250loss__quantize_8bits_damping1e-11/influence_scores.npy",
            "/srv/shared/outputs/DatasetType.cifar10/idx_train-0/trak_for_ddpm_samples__measure_SQUARE_NORM_trainloss_SQUARE_NORM_250measure_250loss__projdim32768_float32_damping1/scores.npy",
        ],
        influence_paths_untuned=[
            None,
            "/srv/shared/outputs/DatasetType.cifar10/idx_train-0/trak_for_ddpm_samples__measure_LOSS_trainloss_LOSS_250measure_250loss__projdim32768_float32_damping1e-6/scores.npy",
            "/srv/shared/outputs/DatasetType.cifar10/idx_train-0/influence_for_ddpm_samples_LOSS_125ekfac_250measurement_250loss__quantize_8bits_damping1e-8/influence_scores.npy",
            "/srv/shared/outputs/DatasetType.cifar10/idx_train-0/trak_for_ddpm_samples__measure_SQUARE_NORM_trainloss_SQUARE_NORM_250measure_250loss__projdim32768_float32_damping1e-6/scores.npy",
        ],
        method_names=method_names,
        # None is forwarded to the scoring script as 'null'.
        train_model_idxs_path=None,
        retrained_model_idxs_dir="/srv/shared/outputs/idxs/DatasetType.cifar10/retrain"
    ),
    DatasetMeasurementPlotConfig(
        axtitle="\\texttt{ArtBench}",
        dataset="artbench",
        measurement_dir="/srv/shared/outputs/measurements/DatasetType.artbench/idx_train-0ddpm_samples_LOSS_measurement",
        influence_paths=[
            "/srv/shared/outputs/DatasetType.artbench/idx_train-0/clip-similarity-scores/scores.npy",
            "/srv/shared/outputs/DatasetType.artbench/idx_train-0/trak_for_ddpm_samples__measure_LOSS_trainloss_LOSS_125measure_125loss__projdim32768_float32_damping1e-2/scores.npy",
            "/srv/shared/outputs/DatasetType.artbench/idx_train-0/influence_for_ddpm_samples_LOSS_63ekfac_125measurement_125loss__quantize_8bits__damping1e-10/influence_scores.npy",
            "/srv/shared/outputs/DatasetType.artbench/idx_train-0/trak_for_ddpm_samples__measure_SQUARE_NORM_trainloss_SQUARE_NORM_125measure_125loss__projdim32768_float32_damping1/scores.npy",
        ],
        influence_paths_untuned=[
            None,
            "/srv/shared/outputs/DatasetType.artbench/idx_train-0/trak_for_ddpm_samples__measure_LOSS_trainloss_LOSS_125measure_125loss__projdim32768_float32_damping1e-6/scores.npy",
            "/srv/shared/outputs/DatasetType.artbench/idx_train-0/influence_for_ddpm_samples_LOSS_63ekfac_125measurement_125loss__quantize_8bits__damping1e-8/influence_scores.npy",
            "/srv/shared/outputs/DatasetType.artbench/idx_train-0/trak_for_ddpm_samples__measure_SQUARE_NORM_trainloss_SQUARE_NORM_125measure_125loss__projdim32768_float32_damping1e-6/scores.npy",
        ],
        method_names=method_names,
        train_model_idxs_path=None,
        retrained_model_idxs_dir="/srv/shared/outputs/idxs/DatasetType.artbench/retrain",
        # Overrides the dataclass default of 100.
        num_retrained_models=50,
    )
]

In [None]:
# Score every config in order. Each call launches the LDS scoring script once
# per influence path, so this cell can take a while.
results = list(map(get_lds_scores_for_config, configs))

In [None]:
# Loss-measurement figure: the three method rows share one color, followed by
# one color each for the last method and for "Exact Retraining".
fig, axes = make_figure(
    configs,
    results,
    suptitle="Loss measurement",
    height=0.22,
    extra_separators_at=[3, 4],
    extra_separator_space=0.3,
    colors=[COLORS[2]] * 3 + [COLORS[3]] + [COLORS[4]],
    rank_correlation_fontsize=fs - 2,
)
save_filename = Path("figures/lds_scores_loss.pdf").resolve()
# savefig does not create missing directories, so make sure the target exists.
save_filename.parent.mkdir(parents=True, exist_ok=True)
print(save_filename)
fig.savefig(save_filename, bbox_inches="tight")

## ELBO

In [None]:
# ELBO* measurement
# Within each config, `influence_paths`, `influence_paths_untuned` and
# `method_names` are index-aligned (None = no untuned variant).
# NOTE(review): the two measurement_dir names differ ("elbo" vs "simpelbo");
# confirm both directories hold the intended measurement.
configs = [
    DatasetMeasurementPlotConfig(
        axtitle="\\texttt{CIFAR-2}",
        dataset="cifar2",
        measurement_dir="/srv/shared/outputs/measurements/DatasetType.cifar2/ddpm_samples_elbo_measurement",
        influence_paths=[
            "/srv/shared/outputs/DatasetType.cifar2/idx_train-0/clip-similarity-scores/scores.npy",
            "/srv/shared/outputs/DatasetType.cifar2/idx_train-0/trak_for_ddpm_samples__measure_SIMPLIFIED_ELBO_trainloss_LOSS_250measure_250loss__projdim32768_float32_damping1e-1/scores.npy",
            "/srv/shared/outputs/DatasetType.cifar2/idx_train-0/influence_for_ddpm_samples_SIMPLIFIED_ELBO_125ekfac_250measurement_250loss__quantize_8bits_damping1e-10/influence_scores.npy",
            "/srv/shared/outputs/DatasetType.cifar2/idx_train-0/trak_for_ddpm_samples__measure_SQUARE_NORM_trainloss_SQUARE_NORM_250measure_250loss__projdim32768_float32_damping1/scores.npy",
            "/srv/shared/outputs/DatasetType.cifar2/idx_train-0/influence_for_ddpm_samples_LOSS_125ekfac_250measurement_250loss__quantize_8bits_damping1e-10/influence_scores.npy",
        ],
        influence_paths_untuned=[
            None,
            "/srv/shared/outputs/DatasetType.cifar2/idx_train-0/trak_for_ddpm_samples__measure_SIMPLIFIED_ELBO_trainloss_LOSS_250measure_250loss__projdim32768_float32_damping1e-6/scores.npy",
            "/srv/shared/outputs/DatasetType.cifar2/idx_train-0/influence_for_ddpm_samples_SIMPLIFIED_ELBO_125ekfac_250measurement_250loss__quantize_8bits_damping1e-8/influence_scores.npy",
            "/srv/shared/outputs/DatasetType.cifar2/idx_train-0/trak_for_ddpm_samples__measure_SQUARE_NORM_trainloss_SQUARE_NORM_250measure_250loss__projdim32768_float32_damping1e-6/scores.npy",
            "/srv/shared/outputs/DatasetType.cifar2/idx_train-0/influence_for_ddpm_samples_LOSS_125ekfac_250measurement_250loss__quantize_8bits_damping1e-8/influence_scores.npy",
        ],
        # make_figure hides the y axis of every panel but the last, so these
        # names are not displayed; the CIFAR-10 config below carries its own.
        method_names=(method_names := [
            "CLIP Cosine Similarity",
            "TRAK$^\\ast$",
            "K-FAC Influence",
            "D-TRAK$^\\ast$",
            # Braces balanced: the label is compiled by LaTeX (usetex=True).
            "K-FAC Influence{\\tiny{(m. loss)}}",
        ]),
        train_model_idxs_path="/srv/shared/outputs/idxs/DatasetType.cifar2/idx_train.csv",
        retrained_model_idxs_dir="/srv/shared/outputs/idxs/DatasetType.cifar2/retrain"
    ),
    DatasetMeasurementPlotConfig(
        axtitle="\\texttt{CIFAR-10}",
        dataset="cifar10",
        measurement_dir="/srv/shared/outputs/measurements/DatasetType.cifar10/ddpm_samples_simpelbo_measurement",
        influence_paths=[
            "/srv/shared/outputs/DatasetType.cifar10/idx_train-0/clip-similarity-scores/scores.npy",
            "/srv/shared/outputs/DatasetType.cifar10/idx_train-0/trak_for_ddpm_samples__measure_SIMPLIFIED_ELBO_trainloss_LOSS_250measure_250loss__projdim32768_float32_damping1e-1/scores.npy",
            "/srv/shared/outputs/DatasetType.cifar10/idx_train-0/influence_for_ddpm_samples_SIMPLIFIED_ELBO_125ekfac_250measurement_250loss__quantize_8bits_damping1e-10/influence_scores.npy",
            "/srv/shared/outputs/DatasetType.cifar10/idx_train-0/trak_for_ddpm_samples__measure_SQUARE_NORM_trainloss_SQUARE_NORM_250measure_250loss__projdim32768_float32_damping1/scores.npy",
            "/srv/shared/outputs/DatasetType.cifar10/idx_train-0/influence_for_ddpm_samples_LOSS_125ekfac_250measurement_250loss__quantize_8bits_damping1e-11/influence_scores.npy",
        ],
        influence_paths_untuned=[
            None,
            "/srv/shared/outputs/DatasetType.cifar10/idx_train-0/trak_for_ddpm_samples__measure_SIMPLIFIED_ELBO_trainloss_LOSS_250measure_250loss__projdim32768_float32_damping1e-6/scores.npy",
            "/srv/shared/outputs/DatasetType.cifar10/idx_train-0/influence_for_ddpm_samples_SIMPLIFIED_ELBO_125ekfac_250measurement_250loss__quantize_8bits_damping1e-8/influence_scores.npy",
            "/srv/shared/outputs/DatasetType.cifar10/idx_train-0/trak_for_ddpm_samples__measure_SQUARE_NORM_trainloss_SQUARE_NORM_250measure_250loss__projdim32768_float32_damping1e-6/scores.npy",
            "/srv/shared/outputs/DatasetType.cifar10/idx_train-0/influence_for_ddpm_samples_LOSS_125ekfac_250measurement_250loss__quantize_8bits_damping1e-8/influence_scores.npy",
        ],
        method_names=[
            "CLIP Cosine Similarity",
            "TRAK",
            "\\textbf{K-FAC Influence}",
            "D-TRAK",
            # Braces balanced: the label is compiled by LaTeX (usetex=True).
            "\\textbf{K-FAC Influence} {\\tiny{(m. loss)}}",
        ],
        # None is forwarded to the scoring script as 'null'.
        train_model_idxs_path=None,
        retrained_model_idxs_dir="/srv/shared/outputs/idxs/DatasetType.cifar10/retrain"
    )
]

In [None]:

# Score every config in order. Each call launches the LDS scoring script once
# per influence path, so this cell can take a while.
results = list(map(get_lds_scores_for_config, configs))

In [None]:
# ELBO-measurement figure: one color for the first three rows, one for the
# next two, and one for "Exact Retraining".
fig, axes = make_figure(
    configs,
    results,
    suptitle="ELBO measurement",
    height=0.23,
    extra_separators_at=[3, 5],
    extra_separator_space=0.3,
    colors=[COLORS[2]] * 3 + [COLORS[3]] * 2 + [COLORS[4]],
    rank_correlation_fontsize=fs - 2,
)
save_filename = Path("figures/lds_scores_elbosimpelbo.pdf").resolve()
# savefig does not create missing directories, so make sure the target exists.
save_filename.parent.mkdir(parents=True, exist_ok=True)
print(save_filename)

fig.savefig(save_filename, bbox_inches="tight")

# Probability of sampling trajectory:

In [None]:
# Sampling-trajectory log-probability measurement
configs = [
    DatasetMeasurementPlotConfig(
        axtitle="\\texttt{CIFAR-2}",
        dataset="cifar2",
        measurement_dir="/srv/shared/outputs/measurements/DatasetType.cifar2/idx_train-0ddpm_samples_SAMPLING_TRAJECTORY_LOGPROB_measurement",
        influence_paths=[
            "/srv/shared/outputs/DatasetType.cifar2/idx_train-0/clip-similarity-scores/scores.npy",
            "/srv/shared/outputs/DatasetType.cifar2/idx_train-0/trak_for_ddpm_samples__measure_SQUARE_NORM_trainloss_SQUARE_NORM_250measure_250loss__projdim32768_float32_damping1/scores.npy",
            "/srv/shared/outputs/DatasetType.cifar2/idx_train-0/influence_for_ddpm_samples_SIMPLIFIED_ELBO_125ekfac_250measurement_250loss__quantize_8bits_damping1e-10/influence_scores.npy",
            "/srv/shared/outputs/DatasetType.cifar2/idx_train-0/influence_for_ddpm_samples_LOSS_125ekfac_250measurement_250loss__quantize_8bits_damping1e-10/influence_scores.npy",
        ],
        influence_paths_untuned=[
            None,
            "/srv/shared/outputs/DatasetType.cifar2/idx_train-0/trak_for_ddpm_samples__measure_SQUARE_NORM_trainloss_SQUARE_NORM_250measure_250loss__projdim32768_float32_damping1e-6/scores.npy",
            "/srv/shared/outputs/DatasetType.cifar2/idx_train-0/influence_for_ddpm_samples_SIMPLIFIED_ELBO_125ekfac_250measurement_250loss__quantize_8bits_damping1e-8/influence_scores.npy",
            "/srv/shared/outputs/DatasetType.cifar2/idx_train-0/influence_for_ddpm_samples_LOSS_125ekfac_250measurement_250loss__quantize_8bits_damping1e-8/influence_scores.npy",
        ],
        method_names=(method_names := [
            "CLIP Cosine Similarity",
            "D-TRAK",
            "\\textbf{K-FAC Influence}{\\tiny{(m. ELBO)}",
            "\\textbf{K-FAC Influence}{\\tiny{(m. loss)}",
        ]),
        train_model_idxs_path="/srv/shared/outputs/idxs/DatasetType.cifar2/idx_train.csv",
        retrained_model_idxs_dir="/srv/shared/outputs/idxs/DatasetType.cifar2/retrain"
    ),
    DatasetMeasurementPlotConfig(
        axtitle="\\texttt{CIFAR-10}",
        dataset="cifar10",
        measurement_dir="/srv/shared/outputs/measurements/DatasetType.cifar10/idx_train-0ddpm_samples_SAMPLING_TRAJECTORY_LOGPROB_measurement",
        influence_paths=[
            "/srv/shared/outputs/DatasetType.cifar10/idx_train-0/clip-similarity-scores/scores.npy",
            "/srv/shared/outputs/DatasetType.cifar10/idx_train-0/trak_for_ddpm_samples__measure_SQUARE_NORM_trainloss_SQUARE_NORM_250measure_250loss__projdim32768_float32_damping1/scores.npy",
            "/srv/shared/outputs/DatasetType.cifar10/idx_train-0/influence_for_ddpm_samples_SIMPLIFIED_ELBO_125ekfac_250measurement_250loss__quantize_8bits_damping1e-8/influence_scores.npy",
            "/srv/shared/outputs/DatasetType.cifar10/idx_train-0/influence_for_ddpm_samples_LOSS_125ekfac_250measurement_250loss__quantize_8bits_damping1e-11/influence_scores.npy",
        ],
        influence_paths_untuned=[
            None,
            "/srv/shared/outputs/DatasetType.cifar10/idx_train-0/trak_for_ddpm_samples__measure_SQUARE_NORM_trainloss_SQUARE_NORM_250measure_250loss__projdim32768_float32_damping1e-6/scores.npy",
            "/srv/shared/outputs/DatasetType.cifar10/idx_train-0/influence_for_ddpm_samples_SIMPLIFIED_ELBO_125ekfac_250measurement_250loss__quantize_8bits_damping1e-8/influence_scores.npy",
            "/srv/shared/outputs/DatasetType.cifar10/idx_train-0/influence_for_ddpm_samples_LOSS_125ekfac_250measurement_250loss__quantize_8bits_damping1e-8/influence_scores.npy",
        ],
        method_names=method_names,
        train_model_idxs_path=None,
        retrained_model_idxs_dir="/srv/shared/outputs/idxs/DatasetType.cifar10/retrain"
    )
]

In [None]:
# One LDS-score result per entry of `configs`, in the same order.
results = [get_lds_scores_for_config(plot_config) for plot_config in configs]

In [None]:
# Draw the LDS-score figure for the sampling-trajectory probability measurement
# and export it as a PDF.
# NOTE(review): six colours are supplied while the configs of this section list
# four methods, and a separator is requested at index 4 — confirm against
# `make_figure` that this is intended.
fig, axes = make_figure(
    configs,
    results,
    suptitle="Sampling Trajectory Prob. measurement",
    extra_separators_at=[1, 4],
    extra_separator_space=0.3,
    colors=(
        [COLORS[2], darken_color(COLORS[2], factor=0.7)]
        + [COLORS[3]]
        + 2 * [darken_color(COLORS[3], factor=0.8)]
        + [COLORS[4]]
    ),
)
save_filename = Path("figures/lds_scores_sampling_trajectory_prob.pdf").resolve().absolute()
# `savefig` does not create missing directories; create the target folder so
# the cell also works after a fresh checkout / Restart & Run All.
save_filename.parent.mkdir(parents=True, exist_ok=True)
print(save_filename)

fig.savefig(save_filename, bbox_inches="tight")

### Log-likelihood (dequantised datasets)

In [None]:
# Log-likelihood measurement on the uniformly dequantised CIFAR-2 dataset.

# Method labels; the per-path comments below tie each label to its score file.
# The `{\tiny{...}}` groups are brace-balanced so the LaTeX snippet is well-formed.
method_names = [
    "CLIP Cosine Similarity",
    "TRAK {\\tiny{(m. ELBO)}}",
    "TRAK {\\tiny{(m. loss)}}",
    "D-TRAK",
    "\\textbf{K-FAC Influence} {\\tiny{(m. ELBO)}}",
    "\\textbf{K-FAC Influence} {\\tiny{(m. loss)}}",
]

configs = [
    DatasetMeasurementPlotConfig(
        axtitle="\\texttt{CIFAR-2 (uniformly dequantised)}",
        dataset="cifar2deq",
        measurement_dir="/srv/shared/outputs/measurements/DatasetType.cifar2deq/ddpm_samples_log_likelihood_measurement",
        influence_paths=[
            # CLIP Similarity
            "/srv/shared/outputs/DatasetType.cifar2deq/idx_train-0/clip-similarity-scores/scores.npy",
            # TRAK (m. ELBO)
            "/srv/shared/outputs/DatasetType.cifar2deq/idx_train-0/trak_for_ddpm_samples__measure_SIMPLIFIED_ELBO_trainloss_LOSS_250measure_250loss__projdim32768_float32_damping1e-2/scores.npy",
            # TRAK (m. loss)
            "/srv/shared/outputs/DatasetType.cifar2deq/idx_train-0/trak_for_ddpm_samples__measure_LOSS_trainloss_LOSS_250measure_250loss__projdim32768_float32_damping1e-2/scores.npy",
            # D-TRAK
            "/srv/shared/outputs/DatasetType.cifar2deq/idx_train-0/trak_for_ddpm_samples__measure_SQUARE_NORM_trainloss_SQUARE_NORM_250measure_250loss__projdim32768_float32_damping1/scores.npy",
            # K-FAC Influence (m. ELBO)
            "/srv/shared/outputs/DatasetType.cifar2deq/idx_train-0/influence_for_ddpm_samples_SIMPLIFIED_ELBO_125ekfac_250measurement_250loss__quantize_8bits_damping1e-8/influence_scores.npy",
            # K-FAC Influence (m. loss)
            "/srv/shared/outputs/DatasetType.cifar2deq/idx_train-0/influence_for_ddpm_samples_LOSS_125ekfac_250measurement_250loss__quantize_8bits_damping1e-10/influence_scores.npy",
        ],
        # Same ordering as `influence_paths`. The two TRAK entries were
        # previously swapped (LOSS listed in the "m. ELBO" slot and vice versa).
        influence_paths_untuned=[
            # CLIP Similarity has no untuned variant.
            None,
            # TRAK (m. ELBO)
            "/srv/shared/outputs/DatasetType.cifar2deq/idx_train-0/trak_for_ddpm_samples__measure_SIMPLIFIED_ELBO_trainloss_LOSS_250measure_250loss__projdim32768_float32_damping1e-6/scores.npy",
            # TRAK (m. loss)
            "/srv/shared/outputs/DatasetType.cifar2deq/idx_train-0/trak_for_ddpm_samples__measure_LOSS_trainloss_LOSS_250measure_250loss__projdim32768_float32_damping1e-6/scores.npy",
            # D-TRAK
            "/srv/shared/outputs/DatasetType.cifar2deq/idx_train-0/trak_for_ddpm_samples__measure_SQUARE_NORM_trainloss_SQUARE_NORM_250measure_250loss__projdim32768_float32_damping1e-6/scores.npy",
            # K-FAC Influence (m. ELBO)
            "/srv/shared/outputs/DatasetType.cifar2deq/idx_train-0/influence_for_ddpm_samples_SIMPLIFIED_ELBO_125ekfac_250measurement_250loss__quantize_8bits_damping1e-8/influence_scores.npy",
            # K-FAC Influence (m. loss)
            "/srv/shared/outputs/DatasetType.cifar2deq/idx_train-0/influence_for_ddpm_samples_LOSS_125ekfac_250measurement_250loss__quantize_8bits_damping1e-8/influence_scores.npy",
        ],
        method_names=method_names,
        # NOTE(review): the index files below live under `DatasetType.cifar2`
        # although `dataset` is "cifar2deq" — presumably the dequantised dataset
        # reuses the CIFAR-2 splits; confirm.
        train_model_idxs_path="/srv/shared/outputs/idxs/DatasetType.cifar2/idx_train.csv",
        retrained_model_idxs_dir="/srv/shared/outputs/idxs/DatasetType.cifar2/retrain"
    ),
]


In [None]:
# Score every configuration defined above; the output order follows `configs`.
results = list(map(get_lds_scores_for_config, configs))

In [None]:
# Draw the LDS-score figure for the log-likelihood measurement and export it as
# a PDF. The title previously read "Loss measurement", which did not match the
# log-likelihood measurement directory and output filename used in this section.
fig, axes = make_figure(
    configs,
    results,
    suptitle="Log-likelihood measurement",
    height=0.27,
    extra_separators_at=[],
    extra_separator_space=0.3,
    colors=[COLORS[0]] + [COLORS[1]] * 3 + [COLORS[2]] * 2 + [COLORS[4]],
    rank_correlation_fontsize=fs - 2,
)
save_filename = Path("figures/lds_scores_log_likelihood.pdf").resolve().absolute()
# `savefig` does not create missing directories; create the target folder so
# the cell also works after a fresh checkout / Restart & Run All.
save_filename.parent.mkdir(parents=True, exist_ok=True)
print(save_filename)
fig.savefig(save_filename, bbox_inches="tight")

# K-FAC ablation

In [None]:

# K-FAC ablation on CIFAR-2 (loss measurement): a single panel comparing
# K-FAC / EK-FAC variants. The three lists below are ordered identically.
configs = [
    DatasetMeasurementPlotConfig(
        axtitle="",
        dataset="cifar2",
        measurement_dir="/srv/shared/outputs/measurements/DatasetType.cifar2/ddpm_samples_loss_measurement",
        influence_paths=[
            "/srv/shared/outputs/DatasetType.cifar2/idx_train-0/influence_for_ddpm_samples_250kfac_empirical_reduce_250measurement_250loss__quantize_8bits_damping1e-5/influence_scores.npy",
            "/srv/shared/outputs/DatasetType.cifar2/idx_train-0/influence_for_ddpm_samples_250kfac_empirical_expand_250measurement_250loss__quantize_8bits_damping1e-9/influence_scores.npy",
            "/srv/shared/outputs/DatasetType.cifar2/idx_train-0/influence_for_ddpm_samples_250kfac_mc_reduce_250measurement_250loss__quantize_8bits_damping1e-7/influence_scores.npy",
            "/srv/shared/outputs/DatasetType.cifar2/idx_train-0/influence_for_ddpm_samples_250kfac_mc_expand_250measurement_250loss__quantize_8bits_damping1e-9/influence_scores.npy",
            "/srv/shared/outputs/DatasetType.cifar2/idx_train-0/influence_for_ddpm_samples_125ekfac_empirical_expand_250measurement_250loss__quantize_8bits_damping1e-10/influence_scores.npy",
            "/srv/shared/outputs/DatasetType.cifar2/idx_train-0/influence_for_ddpm_samples_125ekfac_mc_reduce_250measurement_250loss__quantize_8bits_damping1e-11/influence_scores.npy",
            "/srv/shared/outputs/DatasetType.cifar2/idx_train-0/influence_for_ddpm_samples_125ekfac_mc_expand_250measurement_250loss__quantize_8bits_damping1e-10/influence_scores.npy",
        ],
        # Same variants as above, all with damping 1e-8.
        influence_paths_untuned=[
            "/srv/shared/outputs/DatasetType.cifar2/idx_train-0/influence_for_ddpm_samples_250kfac_empirical_reduce_250measurement_250loss__quantize_8bits_damping1e-8/influence_scores.npy",
            "/srv/shared/outputs/DatasetType.cifar2/idx_train-0/influence_for_ddpm_samples_250kfac_empirical_expand_250measurement_250loss__quantize_8bits_damping1e-8/influence_scores.npy",
            "/srv/shared/outputs/DatasetType.cifar2/idx_train-0/influence_for_ddpm_samples_250kfac_mc_reduce_250measurement_250loss__quantize_8bits_damping1e-8/influence_scores.npy",
            "/srv/shared/outputs/DatasetType.cifar2/idx_train-0/influence_for_ddpm_samples_250kfac_mc_expand_250measurement_250loss__quantize_8bits_damping1e-8/influence_scores.npy",
            "/srv/shared/outputs/DatasetType.cifar2/idx_train-0/influence_for_ddpm_samples_125ekfac_empirical_expand_250measurement_250loss__quantize_8bits_damping1e-8/influence_scores.npy",
            "/srv/shared/outputs/DatasetType.cifar2/idx_train-0/influence_for_ddpm_samples_125ekfac_mc_reduce_250measurement_250loss__quantize_8bits_damping1e-8/influence_scores.npy",
            "/srv/shared/outputs/DatasetType.cifar2/idx_train-0/influence_for_ddpm_samples_125ekfac_mc_expand_250measurement_250loss__quantize_8bits_damping1e-8/influence_scores.npy",
        ],
        # LaTeX labels. `\\ ` is written with an escaped backslash: a bare `\ `
        # is an invalid Python escape sequence (the resulting string is unchanged).
        # The fourth label gains its missing closing brace, and the EK-FAC labels
        # use the same "<name> {variant}" layout as the K-FAC ones.
        method_names=[
            "$\\mathrm{GGN}^\\texttt{loss\\ }$ K-FAC {reduce}",
            "$\\mathrm{GGN}^\\texttt{loss\\ }$ K-FAC {expand}",
            "$\\mathrm{GGN}^\\texttt{model}$ K-FAC {reduce}",
            "$\\mathrm{GGN}^\\texttt{model}$ K-FAC {expand}",
            "$\\mathrm{GGN}^\\texttt{loss\\ }$ \\textbf{E}K-FAC {expand}",
            "$\\mathrm{GGN}^\\texttt{model}$ \\textbf{E}K-FAC {reduce}",
            "$\\mathrm{GGN}^\\texttt{model}$ \\textbf{E}K-FAC {expand}",
        ],
        train_model_idxs_path="/srv/shared/outputs/idxs/DatasetType.cifar2/idx_train.csv",
        retrained_model_idxs_dir="/srv/shared/outputs/idxs/DatasetType.cifar2/retrain"
    ),
]

In [None]:
# Evaluate the LDS scores for each ablation configuration, keeping list order.
results = [get_lds_scores_for_config(plot_config) for plot_config in configs]

In [None]:
# Draw the LDS-score figure for the K-FAC ablation and export it as a PDF.
fig, ax = make_figure(
    configs,
    results,
    suptitle="KFAC Ablation",
    extra_separators_at=[4, 7],
    extra_separator_space=0.4,
    colors=(
        # First group of four colours, then a group of three, then one more.
        [
            interpolate_hsv(COLORS[0], COLORS[3], 0.3),
            COLORS[0],
            interpolate_hsv(COLORS[1], COLORS[3], 0.3),
            COLORS[1],
        ]
        + [
            COLORS[0],
            interpolate_hsv(COLORS[1], COLORS[3], 0.3),
            COLORS[1],
        ]
        + [COLORS[4]]
    ),
    height=0.33,
    width=0.8,
)
save_filename = Path("figures/lds_scores_kfac_ablation.pdf").resolve().absolute()
# `savefig` does not create missing directories; create the target folder so
# the cell also works after a fresh checkout / Restart & Run All.
save_filename.parent.mkdir(parents=True, exist_ok=True)
print(save_filename)

fig.savefig(save_filename, bbox_inches="tight")