The code and data are simplified to fit within the upload limit (50 MB). The full code and data can be found on our GitHub.

Overview of Code:
- Setup: Imports and required installations
- Metric Functions: Feature extraction and functions implementing the metric definitions
- Human Correlation: Experiment testing how well each metric correlates with human judgement. Note that for small sample sizes (n=10), Density behaves as expected; however, for the full dataset (n=100000), we see the opposite effect.
- Mode Shrinkage Test: Experiment testing the impact of increasing the classifier-free guidance (CFG) scale, i.e. decreasing the diversity of generated outputs while increasing their fidelity
- Mode Drop Test: Gradually drop modes to test the sensitivity of recall measures
- Additional Experiments: Experiments shown in the appendix: Density metric behavior with an increasing number of samples, and PCE scores for memorized images

# Setup

In [None]:
pip install frechetdist xformers diffusers timm

In [None]:
import shutil
import os
import random
import math

import torch
import torchvision.models as models
import torchvision.transforms as transforms
from torchvision.utils import save_image
from torch.utils.data import DataLoader, Subset
from torchvision.datasets import ImageFolder
from torch.nn.functional import interpolate

from diffusion import create_diffusion
from diffusers.models import AutoencoderKL
from download import find_model
from models import DiT_models
import argparse

import numpy as np
from numpy import pi
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd

from scipy import linalg
from scipy.special import digamma, gamma, loggamma
from scipy.linalg import sqrtm

from sklearn.neighbors import NearestNeighbors, KDTree
from sklearn.linear_model import LinearRegression
import sklearn.metrics
from sklearn.metrics import pairwise_distances

from PIL import Image
from tqdm import tqdm
import xformers
from collections import defaultdict, Counter
import pickle

In [None]:
# Default matplotlib parameters shared by every figure in this notebook.
_notebook_rc = {
    # text sizes
    'font.size': 24,
    'axes.labelsize': 26,
    'axes.titlesize': 24,
    'xtick.labelsize': 15,
    'ytick.labelsize': 17,
    'legend.fontsize': 24,
    'legend.handlelength': 2,
    # lines, markers and error bars
    'lines.linewidth': 5,
    'lines.markersize': 11,
    'lines.markeredgewidth': 3,
    'errorbar.capsize': 7,
    # font family
    'font.family': 'sans-serif',
    'font.sans-serif': 'DejaVu Sans',
}
plt.rcParams.update(_notebook_rc)

In [None]:
# Image pre-processing: bicubic resize to 224x224, convert to a tensor, then
# normalise with the standard ImageNet per-channel mean / std.
imagenet_mean = [0.485, 0.456, 0.406]
imagenet_std = [0.229, 0.224, 0.225]

_resize_224 = transforms.Resize((224, 224), interpolation=transforms.InterpolationMode.BICUBIC)
_normalize = transforms.Normalize(imagenet_mean, imagenet_std)
transform = transforms.Compose([_resize_224, transforms.ToTensor(), _normalize])

# Feature Extraction

In [None]:
# Feature extraction using DinoV2 encoder
def extract_and_save_features(image_dir, file_name, transform):
    """Embed every image under ``image_dir`` with DINOv2 ViT-L/14 and save the result.

    Args:
        image_dir: root directory in ``torchvision.datasets.ImageFolder`` layout
            (one sub-directory per class).
        file_name: destination ``.npy`` path. Missing parent directories are created.
        transform: pre-processing applied to each image before encoding.

    The saved array has one row per image, in ``ImageFolder`` order (the loader
    uses ``shuffle=False``). Downstream code relies on that order to map rows
    back to classes.
    """
    gen_dataset = ImageFolder(image_dir, transform=transform)
    gen_loader = DataLoader(gen_dataset, batch_size=64, shuffle=False)

    model = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitl14')
    model.eval()
    model.cuda()
    model = torch.nn.DataParallel(model)
    print("DOING ", file_name)

    batches = []
    with torch.no_grad():
        for inputs, _ in tqdm(gen_loader, desc='Extracting Features', leave=True):
            outputs = model(inputs.cuda())
            batches.append(outputs.cpu().numpy())
    gen_features = np.concatenate(batches, axis=0)

    # np.save does not create missing parent directories, and the first
    # extraction cells write into output/ before any other cell creates it.
    out_dir = os.path.dirname(file_name)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    np.save(file_name, gen_features)

In [None]:
# For a subset of models trained on ImageNet, extract features.
# Commented-out entries are not part of this simplified release.
source_directories = [
    'train',
    #'rq',
    #'styleganxl',
    #'ADM',
    #'ADMG',
    'ADMG-ADMU',
    #'biggan',
    #'dit',
    #'gigagan',
    #'ldm'
]

for base_name in source_directories:
    # images/imagenet-<name>/  ->  <cwd>/output/imagenet-<name>_features.npy
    source_dir, file_name = (
        f'images/imagenet-{base_name}',
        os.path.join(os.getcwd(), f'output/imagenet-{base_name}_features.npy'),
    )
    extract_and_save_features(source_dir, file_name, transform)

In [None]:
# For a subset of models trained on CIFAR10, extract features.
# Commented-out entries are not part of this simplified release.
source_directories = [
    'train',
    #'ACGAN-Mod',
    #'BigGAN-Deep',
    #'iDDPM-DDIM',
    'LOGAN',
    #'LSGM-ODE',
    #'MHGAN',
    #'NVAE',
    'PFGMPP',
    #'ReACGAN',
    #'RESFLOW',
    'StyleGAN2-ada',
    #'WGAN-GP',
    #'StyleGAN-XL'
]

# (image directory, output feature file) for every selected model.
_cifar_jobs = [
    (f'images/cifar10-{name}', os.path.join(os.getcwd(), f'output/cifar10-{name}_features.npy'))
    for name in source_directories
]
for source_dir, file_name in _cifar_jobs:
    extract_and_save_features(source_dir, file_name, transform)

# Metric Functions

In [None]:
# Computes the Frechet Distance between two datasets rep. by multivariate normal distributions
def compute_fd(reps1, reps2, eps=1e-6):
    """Return the Frechet distance between Gaussians fitted to two feature sets.

        FD = ||mu1 - mu2||^2 + Tr(S1) + Tr(S2) - 2 Tr((S1 S2)^(1/2))

    Args:
        reps1, reps2: arrays of shape (n_samples, n_features); rows are samples.
        eps: diagonal offset added to both covariances when the matrix square
            root of their product is not finite. This happens, for example, with
            singular covariances when n_samples <= n_features.

    Returns:
        The Frechet distance (a NumPy float).
    """
    mu1 = np.atleast_1d(np.mean(reps1, axis=0))
    mu2 = np.atleast_1d(np.mean(reps2, axis=0))
    # np.cov squeezes a single-feature covariance to a 0-d array; sqrtm needs 2-D.
    sigma1 = np.atleast_2d(np.cov(reps1, rowvar=False))
    sigma2 = np.atleast_2d(np.cov(reps2, rowvar=False))

    diff = mu1 - mu2

    try:
        covmean = sqrtm(sigma1.dot(sigma2))
        finite = np.isfinite(covmean).all()
    except (ValueError, np.linalg.LinAlgError):
        finite = False
    if not finite:
        # Regularise both covariances and take the root of their product. The
        # product of the two separate roots is NOT (S1 S2)^(1/2) in general,
        # because S1 and S2 need not commute.
        offset = eps * np.eye(sigma1.shape[0])
        covmean = sqrtm((sigma1 + offset).dot(sigma2 + offset))

    # Numerical error can leave a tiny imaginary component.
    if np.iscomplexobj(covmean):
        covmean = covmean.real

    tr_covmean = np.trace(covmean)
    frechet_distance = np.dot(diff, diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean
    return frechet_distance

In [None]:
# Functions to calculate our metric defined in Section 5: precision cross-entropy,
# recall cross-entropy, and recall entropy.

# helper function
def volume_of_unit_ball_log(d):
    """Return log(V_d), the natural log of the volume of the unit ball in R^d.

    V_d = pi^(d/2) / Gamma(d/2 + 1). The exact log-gamma function is used (no
    Stirling approximation). Working in log space keeps the value finite for
    large d, where Gamma itself would overflow.
    """
    half_d = d / 2
    log_gamma_term = loggamma(half_d + 1)
    return half_d * np.log(pi) - log_gamma_term

# Calculate cross entropy
def cross_entropy(N, M, k, nu_k, d):
    """k-nearest-neighbour estimate of a cross-entropy.

    Returns  log(M) - psi(k) + log(V_d) + (d / N) * sum_i log(nu_k[i]),
    where psi is the digamma function and V_d is the unit-ball volume in R^d.

    Args:
        N: number of query samples (the length of ``nu_k``).
        M: number of reference samples.
        k: neighbour order.
        nu_k: per query sample, the distance to its k-th nearest neighbour among
            the reference samples. This is how calculate_realism_scores calls it.
        d: feature dimension.
    """
    log_unit_ball = volume_of_unit_ball_log(d)
    per_sample = np.log(M) - digamma(k) + log_unit_ball + d*np.log(nu_k)
    return (1 / N) * np.sum(per_sample)

# Calculate entropy
def entropy(N, k, rho_k, d):
    """k-nearest-neighbour (Kozachenko-Leonenko-type) entropy estimate of one sample set.

    Returns  log(N - 1) - psi(k) + log(V_d) + (d / N) * sum_i log(rho_k[i]),
    where psi is the digamma function and V_d is the unit-ball volume in R^d.

    Args:
        N: number of samples.
        k: neighbour order.
        rho_k: per sample, the distance to its k-th nearest neighbour within the
            same set, excluding the sample itself (as calculate_realism_scores
            supplies it).
        d: feature dimension.
    """
    log_unit_ball = volume_of_unit_ball_log(d)
    per_sample = np.log(N-1) - digamma(k) + log_unit_ball + d*np.log(rho_k)
    return (1 / N) * np.sum(per_sample)

In [None]:
# Calculate all knn metrics analyzed

def calc_precision(dist_R, dist_RG_pairs, M):
    """Fraction of generated samples that lie inside at least one real k-NN ball.

    Args:
        dist_R: real-to-real neighbour distances, one row per real sample; the
            last column is used as that sample's ball radius.
        dist_RG_pairs: (n_real, n_gen) distance matrix from real to generated samples.
        M: number of generated samples (the normaliser).
    """
    radius = dist_R[:, -1][:, np.newaxis]
    # Column j is True if generated sample j falls inside any real ball.
    gen_is_covered = (dist_RG_pairs <= radius).any(axis=0)
    return gen_is_covered.sum() / M

def calc_density(dist_R, dist_RG_pairs, k, M):
    """Density: the total number of (real ball, generated sample) memberships, divided by k*M.

    Each real sample contributes a ball whose radius is the last column of its
    row in ``dist_R``. Every generated sample inside a ball is counted once per
    ball, so the result can exceed 1.
    """
    memberships = dist_RG_pairs <= dist_R[:, -1, np.newaxis]
    per_k = np.sum(memberships) / k
    return per_k / M

def calc_coverage(dist_R, dist_RG_pairs):
    """Coverage: the fraction of real samples whose k-NN ball contains at least one generated sample.

    The ball radius of each real sample is the last column of its row in ``dist_R``.
    """
    real_has_gen = (dist_RG_pairs <= dist_R[:, -1, None]).any(axis=1)
    return real_has_gen.mean()

def calc_PC(G, nbrs_G, nbrs_R, M, k, C):
    """Precision coverage: the share of points of ``G`` with at least k other-set points nearby.

    Each point of ``G`` gets a radius equal to the distance to its (C*k)-th
    neighbour within ``G``. This assumes ``nbrs_G`` was fitted on ``G``, so
    column 0 of the self-query is the point itself; hence C*k + 1 neighbours
    are requested. The point counts if its k-th nearest neighbour in the set
    indexed by ``nbrs_R`` lies within that radius. Returns count / M.

    Swapping the roles of the two sets gives recall coverage.
    """
    own_dist, _ = nbrs_G.kneighbors(G, C * k + 1)
    own_radius = own_dist[:, -1]

    other_dist, _ = nbrs_R.kneighbors(G, k)
    within = other_dist[:, -1] <= own_radius
    return np.sum(within) / M

# k = 2 for representative dataset, used k=5 for full dataset
def calculate_realism_scores(R, G, k = 2, C = 3):
    """Compute every sample-based metric compared in this notebook.

    Args:
        R: real (reference) features, one row per sample.
        G: generated features, one row per sample, with the same width as R.
        k: neighbour order used by all k-NN based metrics.
        C: multiplier for PC/RC; a point's radius is the distance to its (C*k)-th
            neighbour within its own set.

    Returns:
        Tuple ``(fd, density, coverage, pc, rc, pce, rce, re)`` where
        ``pce = CE(G;R) - H(R)``, ``rce = CE(R;G) - H(R)`` and ``re = H(G) - H(R)``.
        CE and H are the k-NN cross-entropy and entropy estimates; CE(A;B) uses
        the distances from A's samples to their k-th neighbour in B.
    """
    dim = len(R[0])
    prc_k = C * k

    # Within-set neighbour graphs. Querying a set against itself returns the
    # point itself at column 0, hence k + 1 neighbours.
    nbrs_R = NearestNeighbors(n_neighbors=prc_k+1, algorithm='auto', n_jobs=-1).fit(R)
    dist_R, _ = nbrs_R.kneighbors(R, k+1)
    nbrs_G = NearestNeighbors(n_neighbors=prc_k+1, algorithm='auto', n_jobs=-1).fit(G)
    dist_G, _ = nbrs_G.kneighbors(G, k+1)

    # Cross-set queries never return the query point, so k neighbours suffice;
    # column k-1 is the k-th nearest neighbour in the other set.
    dist_RG, _ = nbrs_G.kneighbors(R, k)
    dist_GR, _ = nbrs_R.kneighbors(G, k)

    # Real x generated distance matrix. Only this orientation is needed.
    dist_RG_pairs = pairwise_distances(R, G, n_jobs=-1)

    # density + coverage (section 4)
    density = calc_density(dist_R, dist_RG_pairs, k, len(G))
    coverage = calc_coverage(dist_R, dist_RG_pairs)

    # prc (section 4); rc is symmetric to pc. Pass the caller's C through so
    # PC/RC use the same multiplier as the neighbour graphs above.
    pc = calc_PC(G, nbrs_G, nbrs_R, len(G), k, C)
    rc = calc_PC(R, nbrs_R, nbrs_G, len(R), k, C)

    # fd
    fd = compute_fd(R, G)

    # information theoretic score (section 5)
    ce_gr = cross_entropy(len(G), len(R), k, dist_GR[:, k-1], dim)
    ce_rg = cross_entropy(len(R), len(G), k, dist_RG[:, k-1], dim)
    e_r = entropy(len(R), k, dist_R[:, k], dim)
    e_g = entropy(len(G), k, dist_G[:, k], len(G[0]))

    return fd, density, coverage, pc, rc, ce_gr-e_r, ce_rg-e_r, e_g-e_r


# Metric Analysis

In [None]:
# Calculate the correlation coefficient between metrics and human error rate and print correlation matrix

def _print_ranking(title, ranking):
    """Print ``title`` followed by a 1-based numbered list of (name, score) pairs."""
    print(title)
    for rank, (model, score) in enumerate(ranking, start=1):
        print(f"{rank}. {model}: {score}")


def calc_human_corr(real_features, model_features, human_ranking):
    """Score each model, print per-metric rankings, and plot metric/human correlations.

    Args:
        real_features: reference features passed to ``calculate_realism_scores``.
        model_features: mapping of model name -> generated feature array.
        human_ranking: mapping of model name -> human error rate, where a higher
            rate is treated as a better model. Every key must also be present in
            ``model_features``.

    Raises:
        KeyError: if a model in ``human_ranking`` has no entry in
            ``model_features``. This is checked before any scoring is done.

    Prints the human ranking, then one ranking per metric (best to worst), then
    the Pearson correlation matrix between the human scores and each metric,
    and shows that matrix as a heatmap. Before correlating, "higher is better"
    metrics and the human error rate are negated, so that on every row a lower
    value means a better model.
    """
    missing = [name for name in human_ranking if name not in model_features]
    if missing:
        raise KeyError(f"human_ranking models without features: {missing}")

    # Order in which calculate_realism_scores returns its values.
    returned_metrics = ('FD', 'Density', 'Coverage', 'PC', 'RC', 'PCE', 'RCE', 'RE')
    # (metric, higher_is_better), in the order used both for printing and for
    # the rows/columns of the correlation matrix.
    display_metrics = (
        ('PC', True), ('RC', True), ('PCE', False), ('RCE', False),
        ('RE', True), ('FD', False), ('Density', True), ('Coverage', True),
    )

    scores = {metric: {} for metric in returned_metrics}
    for model_name, features in model_features.items():
        print(model_name)
        values = calculate_realism_scores(real_features, features)
        for metric, value in zip(returned_metrics, values):
            scores[metric][model_name] = value
        print("done")

    _print_ranking("Human Ranking (Best to Worst):",
                   sorted(human_ranking.items(), key=lambda item: -item[1]))

    print("\nModel-Based Rankings:")
    for metric, higher_is_better in display_metrics:
        sign = -1 if higher_is_better else 1
        ranking = sorted(scores[metric].items(), key=lambda item: sign * item[1])
        _print_ranking(f"\nRankings based on {metric} Score (Best to Worst):", ranking)

    # Orient every row so that lower = better, aligned to human_ranking's order.
    model_names = list(human_ranking)
    rows = [[-human_ranking[name] for name in model_names]]
    for metric, higher_is_better in display_metrics:
        rows.append([-scores[metric][name] if higher_is_better else scores[metric][name]
                     for name in model_names])

    # print correlation matrix
    correlation_matrix = np.corrcoef(np.array(rows))
    print(correlation_matrix)
    labels = ['Human'] + [metric for metric, _ in display_metrics]
    plt.figure(figsize=(10, 8))
    sns.heatmap(correlation_matrix, annot=True, cmap='coolwarm', fmt=".2f",
                xticklabels=labels, yticklabels=labels)
    plt.title("Correlation Heatmap between Rankings")
    plt.show()


In [None]:
# Human error rates for the sampled generated datasets, from Stein et al. (2024).
# calc_human_corr treats a higher error rate as a better model.
# Full 13-model list (only the subset below ships with this simplified release):
#human_ranking = {'RESFLOW':0.088, 'NVAE':0.1308, 'ACGAN-Mod':0.148, 'WGAN-GP':0.169, 'LOGAN':0.2056, 'ReACGAN':0.335, 'MHGAN':0.336, 'BigGAN-Deep': 0.386, 'StyleGAN2':0.393, 'StyleGAN-XL':0.3988, 'iDDPM-DDIM':0.399, 'PFGMPP':0.4358, 'LSGM-ODE':0.436}
human_ranking = dict(LOGAN=0.2056, StyleGAN2=0.393, PFGMPP=0.4358)

In [None]:
real_features = np.load('output/cifar10-train_features.npy')

# Display name (must match the human_ranking keys) -> feature-file stem.
# Commented-out models are not part of this simplified release.
_cifar_feature_stems = {
    #'ACGAN-Mod': 'ACGAN-Mod',
    #'BigGAN-Deep': 'BigGAN-Deep',
    #'iDDPM-DDIM': 'iDDPM-DDIM',
    'LOGAN': 'LOGAN',
    #'LSGM-ODE': 'LSGM-ODE',
    #'MHGAN': 'MHGAN',
    #'NVAE': 'NVAE',
    'PFGMPP': 'PFGMPP',
    #'ReACGAN': 'ReACGAN',
    #'RESFLOW': 'RESFLOW',
    'StyleGAN2': 'StyleGAN2-ada',
    #'WGAN-GP': 'WGAN-GP',
    #'StyleGAN-XL': 'StyleGAN-XL',
}
model_features = {
    name: np.load(f'output/cifar10-{stem}_features.npy')
    for name, stem in _cifar_feature_stems.items()
}

In [None]:
calc_human_corr(real_features=real_features, model_features=model_features, human_ranking=human_ranking)

# Mode Shrinkage Test

In [None]:
# The full experiment draws 15 class labels for each of 2 class sets, at random
# and without replacement, from the ImageNet indices 0-999.
# This simplified release uses a fixed set of 5 classes that matches the bundled data.
class_labels = [
    319, 121, 299, 32, 269,
]

In [None]:
# Code adapted from official DiT model documentation to run model to generate images at varying CFG parameters

def run_model(class_labels, cfg_scales=(1.5, 3.0, 4.5, 6.0), samples_per_class=10,
              batch_size=10, batch_number=1, seed=None):
    """Generate DiT-XL/2 samples for each class label at several CFG scales.

    Images are written to
    ``./output/cfg_{cfg_scale}/{batch_number}/{class_label}/sample_{i}.png``
    (i starts at 1). A later cell reads this layout back with ImageFolder.

    Args:
        class_labels: ImageNet class indices (0-999) to sample.
        cfg_scales: classifier-free guidance scales to sweep.
        samples_per_class: images generated per (class, CFG scale) pair.
        batch_size: maximum number of images denoised at once.
        batch_number: run index used in the output path.
        seed: optional seed for ``torch.manual_seed``. ``None`` leaves the
            global RNG untouched.
    """
    # configuration
    device = "cuda"
    model_choice = "DiT-XL/2"
    vae_choice = "mse"
    image_size = 256
    num_classes = 1000
    num_sampling_steps = 250
    ckpt_path = None

    # setup
    if ckpt_path is None:
        assert model_choice == "DiT-XL/2", "Only DiT-XL/2 models are available for auto-download."
        assert image_size in [256, 512]
        assert num_classes == 1000

    if seed is not None:
        torch.manual_seed(seed)

    latent_size = image_size // 8
    model = DiT_models[model_choice](input_size=latent_size, num_classes=num_classes).to(device)
    state_dict = find_model(f"DiT-XL-2-{image_size}x{image_size}.pt")
    model.load_state_dict(state_dict)
    model.eval()
    diffusion = create_diffusion(str(num_sampling_steps))
    vae = AutoencoderKL.from_pretrained(f"stabilityai/sd-vae-ft-{vae_choice}").to(device)

    # Inference only: without no_grad, vae.decode records an autograd graph for
    # every batch and holds on to its activations.
    with torch.no_grad():
        for cfg_scale in cfg_scales:
            for class_label in class_labels:
                print(f"Generating {samples_per_class} samples for class {class_label} with CFG scale: {cfg_scale}")
                output_dir = f"./output/cfg_{cfg_scale}/{batch_number}/{class_label}"
                os.makedirs(output_dir, exist_ok=True)

                for start in range(0, samples_per_class, batch_size):
                    current_batch_size = min(batch_size, samples_per_class - start)
                    z = torch.randn(current_batch_size, 4, latent_size, latent_size, device=device)
                    y = torch.full((current_batch_size,), class_label, device=device)
                    # CFG doubles the batch: a conditional half plus an unconditional
                    # half labelled with the extra "null" class index num_classes.
                    z = torch.cat([z, z], 0)
                    y_null = torch.full((current_batch_size,), num_classes, device=device)
                    y = torch.cat([y, y_null], 0)

                    model_kwargs = {'y': y, 'cfg_scale': cfg_scale}
                    samples = diffusion.p_sample_loop(
                        model.forward_with_cfg, z.shape, z, clip_denoised=False,
                        model_kwargs=model_kwargs, progress=True, device=device,
                    )
                    samples, _ = samples.chunk(2, dim=0)  # keep only the conditional half
                    # Latent scaling constant taken from the official DiT sampling code.
                    decoded_samples = vae.decode(samples / 0.18215).sample

                    for i in range(current_batch_size):
                        sample_filename = os.path.join(output_dir, f"sample_{start + i + 1}.png")
                        save_image(decoded_samples[i], sample_filename, normalize=True, value_range=(-1, 1))


run_model(class_labels)

In [None]:
# Extract features for the images generated above: one feature file per CFG scale,
# read from run/batch directory "1" of each cfg_<scale> folder.
source_directories = [f'cfg_{scale}' for scale in ('1.5', '3.0', '4.5', '6.0')]

for source_dir in source_directories:
    full_source_dir = os.path.expanduser(f'./output/{source_dir}/1')
    file_name = os.path.join(os.getcwd(), f'./output/{source_dir}_features.npy')
    extract_and_save_features(full_source_dir, file_name, transform)

In [None]:
# Calculate metrics for all levels of CFG
def run_metrics_all():
    """Score every CFG feature set against the ImageNet reference features.

    Returns:
        A long-format DataFrame with columns 'Run', 'Class Set', 'CFG' (string
        scale label), 'Metric' and 'Value', holding one row per
        (run, class set, CFG, metric).
    """
    metric_names = ['FD', 'Density', 'Coverage', 'PC', 'RC', 'PCE', 'RCE', 'RE']
    cfg_labels = ['1.5', '3.0', '4.5', '6.0']
    columns = {name: [] for name in ('Run', 'Class Set', 'CFG', 'Metric', 'Value')}

    for run in range(1):  # full code uses 6 batches
        class_set = 1
        if class_set == 1:
            reference_path = './output/imagenet-train_features.npy'
        else:
            reference_path = './output/imagenet-train2_features.npy'
        reference = np.load(reference_path)

        cfg_features = {
            label: np.load(f'./output/cfg_{label}_features.npy') for label in cfg_labels
        }

        for label, features in cfg_features.items():
            results = calculate_realism_scores(reference, features)
            for name, value in zip(metric_names, results):
                print(name, value)
                for column, entry in zip(columns, (run, class_set, label, name, value)):
                    columns[column].append(entry)

    return pd.DataFrame(columns)

df_results = run_metrics_all()

In [None]:
# Plotting functions

def _plot_metric_means(data, ax, metrics, markers, colors):
    """Draw, for each metric and class set, the mean metric value at every CFG scale.

    Args:
        data: long-format DataFrame with 'Metric', 'Class Set', 'CFG' (string)
            and 'Value' columns.
        ax: matplotlib Axes to draw on.
        metrics: metric names to plot, in legend order.
        markers: metric -> (marker symbol, marker size).
        colors: metric -> line colour.
    """
    cfgs = sorted(data['CFG'].unique(), key=float)
    # Class set 1 is drawn solid, class set 2 dashed.
    line_styles = {1: '-', 2: '--'}

    for metric in metrics:
        marker, size = markers[metric]
        for class_set in data['Class Set'].unique():
            subset = data[(data['Metric'] == metric) & (data['Class Set'] == class_set)]
            # groupby orders the string CFG keys lexicographically (e.g. '10.0' < '3.0'),
            # so re-align the means with the numerically sorted x values.
            means = subset.groupby('CFG')['Value'].mean().reindex(cfgs)
            ax.plot(cfgs, means, f"{line_styles[class_set]}{marker}", color=colors[metric],
                    markersize=size, label=f'{metric} ({class_set})')


def plot_d_c(data, ax):
    """Panel (a): Density and Coverage versus CFG scale."""
    _plot_metric_means(data, ax, ['Density', 'Coverage'],
                       markers={'Density': ('x', 10), 'Coverage': ('D', 5)},
                       colors={'Density': 'mediumblue', 'Coverage': 'orange'})
    ax.set_title('(a) Density and Coverage')
    ax.set_xlabel('CFG Scale')
    ax.set_ylabel('Score')
    ax.set_ylim(bottom=0)
    ax.legend(loc='lower left', fontsize=17)
    ax.grid(True)


def plot_pc_rc(data, ax):
    """Panel (b): precision coverage and recall coverage versus CFG scale."""
    _plot_metric_means(data, ax, ['PC', 'RC'],
                       markers={'PC': ('x', 10), 'RC': ('D', 5)},
                       colors={'PC': 'mediumblue', 'RC': 'orange'})
    ax.set_title('(b) Precision and Recall Coverage')
    ax.set_xlabel('CFG Scale')
    ax.set_ylabel('Score')
    ax.set_ylim(bottom=0)
    ax.legend()
    ax.grid(True)


def plot_re(data, ax):
    """Panel (d): recall entropy versus CFG scale."""
    _plot_metric_means(data, ax, ['RE'],
                       markers={'RE': ('^', 7)},
                       colors={'RE': 'green'})
    ax.set_title('(d) Recall Entropy')
    ax.set_xlabel('CFG Scale')
    ax.set_ylabel('Score')
    ax.legend()
    ax.grid(True)


def plot_prce(data, ax):
    """Panel (c): precision and recall cross-entropy versus CFG scale."""
    _plot_metric_means(data, ax, ['PCE', 'RCE'],
                       markers={'PCE': ('x', 10), 'RCE': ('D', 5)},
                       colors={'PCE': 'mediumblue', 'RCE': 'orange'})
    ax.set_title('(c) Precision and Recall Cross-Entropy')
    ax.set_xlabel('CFG Scale')
    ax.set_ylabel('Score')
    ax.set_ylim(bottom=-500)
    ax.legend(fontsize=22)
    ax.grid(True)


fig, axs = plt.subplots(2, 2, figsize=(18, 12))

plot_d_c(df_results, axs[0, 0])
plot_pc_rc(df_results, axs[0, 1])
plot_prce(df_results, axs[1, 0])
plot_re(df_results, axs[1, 1])

plt.tight_layout()
plt.show()

# Mode Drop Test

In [None]:
# Set up the generated dataset for mode dropping. The unmodified feature set is
# also saved as the "_5" variant, which calc_for_drop loads as the no-drop
# baseline (presumably all 5 classes of the simplified dataset).
_full_features_path = 'output/imagenet-ADMG-ADMU_features.npy'
_baseline_path = 'output/imagenet-ADMG-ADMU_features_5.npy'

gen_features = np.load(_full_features_path)
print(gen_features.shape)
np.save(_baseline_path, gen_features)

In [None]:
# Create a dictionary for the # of samples per class to aid in dropping
def count_samples_per_class(source_dir, transform):
    """Return ``{class_name: number_of_samples}`` in ImageFolder class order.

    The counts come from the dataset's own ``targets``, so they always match the
    row order of the features written by ``extract_and_save_features``, which
    iterates the same ImageFolder with ``shuffle=False``. Counting only
    ``*.png`` files on disk would disagree whenever ImageFolder also picks up
    other image extensions or files in nested sub-directories. The rows would
    then be mapped to the wrong classes when modes are dropped.
    """
    full_dataset = ImageFolder(source_dir, transform=transform)
    per_index = Counter(full_dataset.targets)
    return {
        class_name: per_index.get(idx, 0)
        for class_name, idx in full_dataset.class_to_idx.items()
    }

source_dir= 'images/imagenet-ADMG-ADMU'
class_counts = count_samples_per_class(source_dir, transform)

In [None]:
# Mode dropping
def drop_random_classes(features, class_counts, classes_to_drop):
    """Return the rows of ``features`` whose class is not in ``classes_to_drop``.

    Row labels are rebuilt from ``class_counts``. This assumes ``features`` is
    ordered class by class, in the dict's iteration order, with
    ``class_counts[c]`` consecutive rows for class ``c``. The selection here is
    deterministic; the random choice of classes happens in the caller.
    """
    row_labels = np.array([label for label, count in class_counts.items() for _ in range(count)])
    keep_mask = ~np.isin(row_labels, classes_to_drop)
    return features[keep_mask]

# determine which classes remain and can be dropped
def drop(features, class_counts, keep, batch, dropped_classes, num_classes_to_drop=100):
    """Drop more random classes and save the features that remain.

    ``num_classes_to_drop`` classes that are not yet in ``dropped_classes`` are
    drawn with the global NumPy RNG (without replacement) and appended to
    ``dropped_classes`` in place. The features of every dropped class are then
    removed, and the rest are saved to
    ``output/imagenet-ADMG-ADMU_features_{keep}_{batch}.npy``.

    Returns:
        The same ``dropped_classes`` list, now extended.
    """
    already_dropped = set(dropped_classes)
    candidates = [cls for cls in class_counts if cls not in already_dropped]
    dropped_classes.extend(np.random.choice(candidates, num_classes_to_drop, replace=False))
    print(len(dropped_classes))

    reduced_features = drop_random_classes(features, class_counts, dropped_classes)
    print(reduced_features.shape)
    np.save(f'output/imagenet-ADMG-ADMU_features_{keep}_{batch}.npy', reduced_features)
    return dropped_classes


# The full experiment used keep counts 900, 800, ..., 100:
#keep_num = np.arange(100, 1000, 100).tolist()
#keep_num.reverse()
# Simplified dataset: drop one more class at each step.
keep_num = [4, 3, 2, 1]

for batch in range(1): # originally 10 batches
    dropped_classes = []
    for keep in keep_num:
        # drop() extends dropped_classes in place and returns the same list.
        dropped_classes = drop(gen_features, class_counts, keep, batch + 1,
                               dropped_classes, num_classes_to_drop=1)


In [None]:
# Calculate metrics for mode dropping
def calc_for_drop(model_name, batch):
    """Score each mode-dropped variant of ``model_name`` and pickle the results.

    Loads ``output/imagenet-{model_name}_features_{keep}_{batch}.npy`` for keep
    = 1..4, plus the undropped ``output/imagenet-{model_name}_features_5.npy``.
    Each variant is scored against the ImageNet train features. The results are
    written to ``output/imagenetscores_recall_{batch}.pkl`` as the tuple
    (FD, Density, Coverage, PC, RC, PCE, RCE, RE), where each element maps
    variant name -> score and variants are in the order 1, 2, 3, 4, 5.
    """
    real_features = np.load('output/imagenet-train_features.npy')

    variant_paths = {
        f'{model_name}_{keep}': f'output/imagenet-{model_name}_features_{keep}_{batch}.npy'
        for keep in (1, 2, 3, 4)
    }
    variant_paths[f'{model_name}_5'] = f'output/imagenet-{model_name}_features_5.npy'
    variants = {name: np.load(path) for name, path in variant_paths.items()}

    # Same order as the values returned by calculate_realism_scores.
    metric_order = ('FD', 'Density', 'Coverage', 'PC', 'RC', 'PCE', 'RCE', 'RE')
    tables = {metric: {} for metric in metric_order}

    for variant, features in variants.items():
        print(variant)
        results = calculate_realism_scores(real_features, features)
        for metric, value in zip(metric_order, results):
            tables[metric][variant] = value
        print("done")

    # Dump
    with open(f'output/imagenetscores_recall_{batch}.pkl', 'wb') as f:
        pickle.dump(tuple(tables[metric] for metric in metric_order), f)


In [None]:
calc_for_drop(model_name='ADMG-ADMU', batch=1)

In [None]:
# Plotting setup: load the pickled mode-drop scores and summarise them per metric.
# Position of each metric in the tuple pickled by calc_for_drop
# (FD, Density, Coverage, PC, RC, PCE, RCE, RE).
metric_slots = {"Coverage": 2, "RC": 4, "PC": 3, "Density": 1,
                "PCE": 5, "RCE": 6, "RE": 7, "FD": 0}
collected = {name: [] for name in metric_slots}

for i in range(1):
    with open(f'output/imagenetscores_recall_{i+1}.pkl', 'rb') as f:
        scores = pickle.load(f)
    for name, slot in metric_slots.items():
        # Variants were stored as keep 1..5; reverse so that index 0 is the
        # undropped "_5" set (0 modes dropped).
        collected[name].append(list(reversed(list(scores[slot].values()))))

coverage_scores = np.array(collected["Coverage"])
rc_scores = np.array(collected["RC"])
pc_scores = np.array(collected["PC"])
density_scores = np.array(collected["Density"])
rce_scores = np.array(collected["RCE"])
pce_scores = np.array(collected["PCE"])
re_scores = np.array(collected["RE"])
fd_scores = np.array(collected["FD"])

# Per-position mean and standard deviation across runs. With a single run the
# standard deviations are zero, so the error bars are not visible in the
# simplified code.
means = {name: np.mean(np.array(collected[name]), axis=0) for name in metric_slots}
std_devs = {name: np.std(np.array(collected[name]), axis=0) for name in metric_slots}

In [None]:
# Plot for mode dropping experiment

# X positions: number of modes dropped. (The full experiment used
# [0, 100, 200, ..., 900].) Intentionally NOT named `models`: in the setup cell
# that name is the `torchvision.models` module, and rebinding it to a list
# breaks any cell that uses torchvision models if it is re-run afterwards.
modes_dropped = [0, 1, 2, 3, 4]

# Styling shared by every error-bar curve.
errorbar_style = dict(elinewidth=2, capthick=2, ecolor='gray')
# (marker, markersize, color) for the 1st, 2nd and 3rd curve drawn on a panel.
curve_styles = [('x', 10, 'mediumblue'), ('D', 7, 'orange'), ('^', 7, 'green')]


def _draw_mode_drop_panel(ax, x, metrics, title):
    """Draw mean +/- std curves on `ax` and apply the axis formatting shared by all panels.

    Args:
        ax: matplotlib Axes to draw on.
        x: x positions (number of modes dropped).
        metrics: (key into `means`/`std_devs`, legend label) pairs, drawn in
            order with the styles listed in `curve_styles`.
        title: panel title.
    """
    for (key, label), (marker, size, color) in zip(metrics, curve_styles):
        ax.errorbar(x, means[key], yerr=std_devs[key], label=label,
                    marker=marker, markersize=size, color=color, **errorbar_style)
    ax.set_xlabel('# of Modes Dropped')
    ax.set_ylabel('Score')
    ax.set_xticks(x)
    ax.tick_params(axis='x', rotation=45)
    ax.set_title(title)


fig, axes = plt.subplots(2, 2, figsize=(11, 9))

# Subplot 1
# NOTE(review): Density as computed in this file is not bounded above by 1
# (calc_density_kd divides a neighbour count by k * |G|), so values above 1
# would be clipped by this y-limit.
ax1 = axes[0, 0]
_draw_mode_drop_panel(ax1, modes_dropped, [('Density', 'D'), ('Coverage', 'C')],
                      '(a) Density and Coverage')
ax1.set_ylim(0, 1)
ax1.legend()

# Subplot 2
ax2 = axes[0, 1]
_draw_mode_drop_panel(ax2, modes_dropped, [('PC', 'PC'), ('RC', 'RC')], '(b) PC and RC')
ax2.set_ylim(0, 1)
ax2.legend()

# Subplot 3 (fixed y-ticks instead of a [0, 1] limit)
ax3 = axes[1, 0]
_draw_mode_drop_panel(ax3, modes_dropped, [('PCE', 'PCE'), ('RCE', 'RCE'), ('RE', 'RE')],
                      '(c) PCE, RCE, and RE')
ax3.set_yticks(range(0, 601, 100))
ax3.legend(fontsize=19)

# Subplot 4
ax4 = axes[1, 1]
_draw_mode_drop_panel(ax4, modes_dropped, [('FD', 'FD')], '(d) Fréchet Distance')
ax4.set_yticks(range(0, 701, 100))
ax4.legend(fontsize=26)

fig.tight_layout()
plt.show()

# Additional Experiments

## Sample Convergence

In [None]:
# Metric + sampling functions
def generate_multid_gaussian_samples(mean, cov, num_samples):
    """Return `num_samples` i.i.d. draws from the multivariate normal N(mean, cov).

    Draws from NumPy's global random state (``np.random``), so results are
    reproducible only if the caller seeds it. The result has one row per draw.
    """
    draws = np.random.multivariate_normal(mean, cov, size=num_samples)
    return draws

# Optimized version of select metrics using KD trees

def calc_density_kd(R, G, tree_G, k):
    """Density score of generated samples G w.r.t. real samples R, using KD-trees.

    Each real sample gets a ball whose radius is the distance to its k-th
    nearest *other* real sample. Density is the total number of
    (real ball, generated sample) incidences divided by k * |G|.

    Args:
        R: real samples, a 2-D array (one row per sample).
        G: generated samples, a 2-D array with the same number of columns.
        tree_G: KDTree built on G.
        k: neighbourhood size.

    Returns:
        Density score. It is not bounded above by 1.
    """
    tree_R = KDTree(R)
    # k + 1 neighbours because each real point is returned as its own nearest neighbour.
    dist_R, _ = tree_R.query(R, k=k + 1)
    radii_R = dist_R[:, -1]

    # One vectorized query with per-point radii replaces a Python loop of
    # single-point queries: count_only returns one count per real sample.
    density_counts = tree_G.query_radius(R, r=radii_R, count_only=True)
    overall_density = np.sum(density_counts) / (k * len(G))
    return overall_density

def calc_PC_kd(G, tree_G, tree_R, M, k, C):
    """Precision Coverage (PC) of generated samples, using KD-trees.

    Every generated sample gets a ball whose radius is the distance to its
    (C * k)-th nearest *other* generated sample. A ball counts as valid when
    it contains at least k real samples. PC = (number of valid balls) / M.

    Args:
        G: generated samples (the points whose balls are evaluated).
        tree_G: KD-tree over G (needs ``query``).
        tree_R: KD-tree over the real samples (needs ``query_radius``).
        M: normaliser, normally len(G).
        k: minimum number of real samples required inside a ball.
        C: multiplier for the ball's neighbour rank (C * k).
    """
    neighbour_rank = C * k
    # Ask for one extra neighbour: each generated point is its own nearest neighbour in tree_G.
    gen_dists = tree_G.query(G, k=neighbour_rank + 1)[0]
    ball_radii = gen_dists[:, -1]

    # How many real samples fall inside each generated sample's ball.
    real_in_ball = tree_R.query_radius(G, r=ball_radii, count_only=True)
    covered = real_in_ball >= k
    return np.sum(covered) / M

def calculate_realism_scores_kd(R, G, k=5, C=3):
    """Compute, print and return Density and Precision Coverage for R vs. G.

    Builds one KD-tree per sample set and hands them to the KD-tree metric
    helpers.

    Args:
        R: real samples.
        G: generated samples.
        k: neighbourhood size shared by both metrics.
        C: PC neighbour-rank multiplier (balls use the C * k-th neighbour).

    Returns:
        (density, pc) tuple.
    """
    real_tree, gen_tree = KDTree(R), KDTree(G)

    density = calc_density_kd(R, G, gen_tree, k)
    pc = calc_PC_kd(G, gen_tree, real_tree, len(G), k, C)

    for line in (f"Density: {density}", f"Precision Coverage (PC): {pc}"):
        print(line)

    return density, pc


In [None]:
# Analyze the behavior of density and precision coverage for two different distributions

# R and G share the same covariance; G's mean is shifted by (2, 2).
mean_R = [0, 0]
mean_G = [2, 2]
cov_R = np.array([[1, 0.1], [0.1, 1]])
cov_G = np.array([[1, 0.1], [0.1, 1]])
k = 25
# 30 evenly spaced sample sizes from 10k to 1.5M (used for both R and G).
sample_sizes = np.linspace(10000, 1500000, num=30).astype(int)

density, pc = [], []

for n_points in sample_sizes:
    R = generate_multid_gaussian_samples(mean_R, cov_R, n_points)
    G = generate_multid_gaussian_samples(mean_G, cov_G, n_points)
    print(G.shape)
    density_value, pc_value = calculate_realism_scores_kd(R, G, k)
    density.append(density_value)
    pc.append(pc_value)

In [None]:
# Plot sample experiment: Density and PC as the sample size grows.

fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(sample_sizes, density, 'r-o', label='Density')
ax.plot(sample_sizes, pc, 'b-o', label='PC')
ax.axhline(y=1, color='gray', linestyle='--')  # dashed reference line at y=1
ax.set_xlabel('Sample Size')
ax.set_ylabel('Metric Values')
ax.set_title('Density and PC Over Increasing Sample Sizes')
ax.legend()
ax.set_ylim(0, 1.2)
plt.show()

In [None]:
# 2D visualization of distributions (same means/covariances as the experiment above)

N_R = 500
N_G = 500

R = np.random.multivariate_normal(mean_R, cov_R, N_R)
G = np.random.multivariate_normal(mean_G, cov_G, N_G)

fig, ax = plt.subplots(figsize=(8, 6))
for points, colour, name in ((R, 'green', 'R'), (G, 'hotpink', 'G')):
    ax.scatter(points[:, 0], points[:, 1], color=colour, alpha=0.5, label=name)

ax.set_title('2D Visualization of R and G')
ax.legend()
ax.grid(True)
ax.axis('equal')
plt.show()

## Memorization

Here we perform our sample-level memorization experiment. Note that, because this simplified version uses only a limited sample set, the images displayed below are not the ones we would identify as memorized when using the full dataset.

In [None]:
# Calculate sample-level contribution to PCE score
def cross_entropy_sample(N, M, k, nu_k, d):
    """Per-sample terms of the k-NN cross-entropy estimator.

    For each query sample i, with distance nu_k[i] to its k-th nearest
    neighbour in a reference set of size M, returns

        log(M) - digamma(k) + c_bar(d) + d * log(nu_k[i]),

    where c_bar(d) = volume_of_unit_ball_log(d). The mean of the returned
    terms is the cross-entropy estimate.

    Args:
        N: number of query samples. Not needed for the per-sample terms; kept
            for call compatibility.
        M: size of the reference set the neighbours were searched in.
        k: neighbour rank used to obtain nu_k.
        nu_k: k-th nearest-neighbour distances, one per query sample.
        d: feature dimensionality.

    Returns:
        Array of per-sample terms with the same shape as nu_k. A zero distance
        (an exact copy of a reference sample) gives -inf.
    """
    psi_k = digamma(k)
    c_bar = volume_of_unit_ball_log(d)
    # Only the per-sample terms are returned; the aggregate is left to the caller.
    inner_term = np.log(M) - psi_k + c_bar + d * np.log(nu_k)
    return inner_term

def entropy_sample(N, k, rho_k, d):
    """Per-sample terms of the k-NN entropy estimator.

    For each sample i, with distance rho_k[i] to its k-th nearest *other*
    sample in the same set of size N, returns

        log(N - 1) - digamma(k) + c_bar(d) + d * log(rho_k[i]),

    where c_bar(d) = volume_of_unit_ball_log(d). The mean of the returned
    terms is the entropy estimate.

    Args:
        N: number of samples in the set.
        k: neighbour rank used to obtain rho_k.
        rho_k: k-th nearest-neighbour distances, one per sample.
        d: feature dimensionality.

    Returns:
        Array of per-sample terms with the same shape as rho_k. A zero distance
        (duplicated samples) gives -inf.
    """
    psi_k = digamma(k)
    c_bar = volume_of_unit_ball_log(d)
    # Only the per-sample terms are returned; the aggregate is left to the caller.
    inner_term = np.log(N - 1) - psi_k + c_bar + d * np.log(rho_k)
    return inner_term

def calculate_pce_scores(R, G, ce_k = 1):
    """Per-sample PCE scores of generated samples and their nearest real neighbours.

    The score of generated sample j is its cross-entropy term (from the
    ce_k-th nearest-neighbour distance of G[j] to the real set R) minus the
    *mean* entropy term of R. The mean of the returned scores therefore equals
    the aggregate estimate mean(cross-entropy terms) - mean(entropy terms).

    Args:
        R: real features, a 2-D array (one row per sample).
        G: generated features, same number of columns as R.
        ce_k: neighbour rank k used by both estimators.

    Returns:
        (scores, ind):
            scores: one score per row of G; low values mean G[j] is very close to a real sample.
            ind: indices into R of each generated sample's ce_k + 1 nearest
                real neighbours (column 0 is the nearest).
    """
    d = len(R[0])

    # Real-to-real neighbours. Querying R against its own index returns each
    # point as its own first neighbour, so the k-th *other* neighbour is column ce_k.
    nbrs_R = NearestNeighbors(n_neighbors=ce_k + 1, algorithm='auto', n_jobs=-1).fit(R)
    dist_R, _ = nbrs_R.kneighbors(R, ce_k + 1)

    # Generated-to-real neighbours (no self match, so the k-th neighbour is column ce_k - 1).
    # No full pairwise distance matrices are needed: they would cost O(|R| * |G|) memory.
    dist_GR, ind = nbrs_R.kneighbors(G, ce_k + 1)

    ce_gr = cross_entropy_sample(len(G), len(R), ce_k, dist_GR[:, ce_k - 1], d)
    e_r = entropy_sample(len(R), ce_k, dist_R[:, ce_k], d)

    # e_r has one term per *real* sample, unrelated to any particular generated
    # sample, so subtract its mean rather than pairing G[j] with R[j]
    # element-wise (which also fails to broadcast when len(G) != len(R)).
    return ce_gr - np.mean(e_r), ind

In [None]:
# Calculate PCE for samples of LOGAN

# NOTE(review): assumes feature rows are in the same order as the ImageFolder
# samples used for display below — confirm.
real_features = np.load('output/cifar10-train_features.npy')
model_features = {'LOGAN': np.load('output/cifar10-LOGAN_features.npy')}

scores_PCE, PCE_ind = [], []
for model_name, features in model_features.items():
    print(model_name)
    sample_pce, neighbour_ind = calculate_pce_scores(real_features, features, 1)
    # Results are overwritten per model; only one model is evaluated here.
    scores_PCE = sample_pce
    PCE_ind = neighbour_ind[:, 0]  # index of each generated sample's nearest real sample
    print("done")

# (score, generated-sample index) pairs, lowest PCE first.
indexed_scores = sorted(
    zip(scores_PCE, range(scores_PCE.shape[0])),
    key=lambda pair: pair[0],
)

In [None]:
# Print the generated images alongside their nearest real neighbor
def display_images_in_batches(indexed_scores, gen_dataset, real_dataset, qp_ind, batch_size=5, max_batches=None):
    """Show generated images (top row) above their nearest real images (bottom row).

    One figure is shown per batch of `batch_size` entries, in the order of
    `indexed_scores`. A trailing partial batch is shown too.

    Args:
        indexed_scores: sequence of (score, gen_index) pairs.
        gen_dataset: dataset with a `.samples` list of (path, label) for generated images.
        real_dataset: dataset with a `.samples` list of (path, label) for real images.
        qp_ind: qp_ind[gen_index] is the index into real_dataset.samples of
            that generated image's nearest real neighbour.
        batch_size: images per figure (columns).
        max_batches: if given, show at most this many figures; None shows all.

    NOTE(review): assumes gen_index follows gen_dataset.samples order, i.e. the
    same order as the features the scores were computed from — confirm.
    """
    starts = range(0, len(indexed_scores), batch_size)
    if max_batches is not None:
        starts = starts[:max_batches]

    for start in starts:
        batch = indexed_scores[start:start + batch_size]
        # Two rows per figure; width scales with the number of columns (25 in for 5).
        plt.figure(figsize=(5 * batch_size, 10))

        for col, (score, gen_idx) in enumerate(batch):
            gen_path, _ = gen_dataset.samples[gen_idx]
            real_path, _ = real_dataset.samples[qp_ind[gen_idx]]

            # Top row: generated image with its PCE score.
            plt.subplot(2, batch_size, col + 1)
            with Image.open(gen_path) as image:  # close the file once it is drawn
                plt.imshow(image)
            plt.title(f"PCE: {score:.2f}", fontsize=40)
            plt.axis('off')

            # Bottom row: nearest real image.
            plt.subplot(2, batch_size, batch_size + col + 1)
            with Image.open(real_path) as image:
                plt.imshow(image)
            plt.axis('off')

        plt.show()

# Usage: show LOGAN samples (lowest PCE first) above their nearest CIFAR-10 training images.
# NOTE(review): `transform` must be defined by an earlier cell; it is not defined in this section.
gen_dataset = ImageFolder(root='images/cifar10-LOGAN', transform=transform)
real_dataset = ImageFolder(root='images/cifar10-train', transform=transform)
display_images_in_batches(
    indexed_scores,
    gen_dataset=gen_dataset,
    real_dataset=real_dataset,
    qp_ind=PCE_ind,
    batch_size=5,
)
