In [13]:
import pandas as pd
from os.path import join
from training import dataset
import os
import click
import tqdm.auto as tqdm    
import pickle
import numpy as np
import scipy.linalg
import torch
import dnnlib
from torch_utils import distributed as dist
from torch.utils.data import DataLoader, TensorDataset
#----------------------------------------------------------------------------

def calculate_inception_stats(
    image_path, num_expected=None, seed=0, max_batch_size=64,
    num_workers=3, prefetch_factor=2, device=torch.device('cuda'),
):
    """Compute the mean and covariance of Inception-v3 features over an image dataset.

    Args:
        image_path: Location handed to ``dataset.ImageFolderDataset(path=...)``.
        num_expected: Forwarded as ``max_size`` to the dataset (``None`` = no cap).
        seed: Forwarded as ``random_seed`` to the dataset.
        max_batch_size: Upper bound on the number of images per batch.
        num_workers: ``DataLoader`` worker count.
        prefetch_factor: ``DataLoader`` prefetch factor (callers in this notebook
            pass ``None`` together with ``num_workers=0``).
        device: Device on which the detector runs and statistics are accumulated.

    Returns:
        ``(mu, sigma)`` as NumPy float64 arrays of shape ``(2048,)`` and
        ``(2048, 2048)``: the feature mean and the unbiased (N-1) covariance.

    Note:
        The distributed barriers / all-reduce calls are left commented out, so
        the statistics are only complete when running in a single process.
    """
    # Rank 0 goes first.
    # if dist.get_rank() != 0:
    #     torch.distributed.barrier()

    # Load Inception-v3 model.
    # This is a direct PyTorch translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
    # NOTE(security): pickle.load executes code from the downloaded file; only use trusted URLs.
    print('Loading Inception-v3 model...')
    detector_url = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/inception-2015-12-05.pkl'
    detector_kwargs = dict(return_features=True)
    feature_dim = 2048
    with dnnlib.util.open_url(detector_url, verbose=(dist.get_rank() == 0)) as f:
        detector_net = pickle.load(f).to(device)

    # List images.
    print(f'Loading images from "{image_path}"...')
    dataset_obj = dataset.ImageFolderDataset(path=image_path, max_size=num_expected, random_seed=seed)
    # if num_expected is not None and len(dataset_obj) < num_expected:
    #     raise click.ClickException(f'Found {len(dataset_obj)} images, but expected at least {num_expected}')
    # if len(dataset_obj) < 2:
    #     raise click.ClickException(f'Found {len(dataset_obj)} images, but need at least 2 to compute statistics')

    # Other ranks follow.
    # if dist.get_rank() == 0:
    #     torch.distributed.barrier()

    # Divide images into batches: the batch count is rounded up to a multiple of
    # the world size, then each rank takes every world_size-th batch.
    num_batches = ((len(dataset_obj) - 1) // (max_batch_size * dist.get_world_size()) + 1) * dist.get_world_size()
    all_batches = torch.arange(len(dataset_obj)).tensor_split(num_batches)
    rank_batches = all_batches[dist.get_rank() :: dist.get_world_size()]
    data_loader = torch.utils.data.DataLoader(dataset_obj, batch_sampler=rank_batches, num_workers=num_workers, prefetch_factor=prefetch_factor)

    # Accumulate statistics (sum of features and sum of outer products) in float64.
    print(f'Calculating statistics for {len(dataset_obj)} images...')
    mu = torch.zeros([feature_dim], dtype=torch.float64, device=device)
    sigma = torch.zeros([feature_dim, feature_dim], dtype=torch.float64, device=device)
    for images, _labels in tqdm.tqdm(data_loader, unit='batch', disable=(dist.get_rank() != 0)):
        # torch.distributed.barrier()
        if images.shape[0] == 0:
            continue
        # Force exactly three channels: replicate grayscale, drop a fourth channel.
        if images.shape[1] == 1:
            images = images.repeat([1, 3, 1, 1])
        if images.shape[1] == 4:
            images = images[:, :3, :, :]
        # Inference only: disable autograd so no graph is attached to the
        # accumulators (matches calculate_inception_stats_from_numpy).
        with torch.no_grad():
            features = detector_net(images.to(device), **detector_kwargs).to(torch.float64)
        mu += features.sum(0)
        sigma += features.T @ features

    # Calculate grand totals.
    # torch.distributed.all_reduce(mu)
    # torch.distributed.all_reduce(sigma)
    mu /= len(dataset_obj)
    sigma -= mu.ger(mu) * len(dataset_obj)  # sum(f f^T) - N * mu mu^T
    sigma /= len(dataset_obj) - 1
    return mu.cpu().numpy(), sigma.cpu().numpy()


def calculate_inception_stats_from_numpy(
    images_np, num_expected=None, seed=0, max_batch_size=64,
    num_workers=4, prefetch_factor=2, device=torch.device('cuda'),
):
    """Compute the mean and covariance of Inception-v3 features for an in-memory image array.

    Args:
        images_np: uint8 NumPy array in channels-last layout ``(N, H, W, C)``
            (it is permuted to ``(N, C, H, W)`` below). Values must be raw
            0-255 pixels, NOT normalized to [0, 1].
        num_expected: If given, only the first ``num_expected`` images are used.
        seed: Unused; kept for signature parity with ``calculate_inception_stats``.
        max_batch_size: Batch size of the ``DataLoader``.
        num_workers: ``DataLoader`` worker count.
        prefetch_factor: ``DataLoader`` prefetch factor (callers in this notebook
            pass ``None`` together with ``num_workers=0``).
        device: Device on which the detector runs and statistics are accumulated.

    Returns:
        ``(mu, sigma)`` as NumPy float64 arrays of shape ``(2048,)`` and
        ``(2048, 2048)``: the feature mean and the unbiased (N-1) covariance.
    """
    # Load Inception-v3 model.
    # This is a direct PyTorch translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
    # NOTE(security): pickle.load executes code from the downloaded file; only use trusted URLs.
    print('Loading Inception-v3 model...')
    detector_url = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/inception-2015-12-05.pkl'
    detector_kwargs = dict(return_features=True)
    feature_dim = 2048
    with dnnlib.util.open_url(detector_url, verbose=False) as f:
        detector_net = pickle.load(f).to(device)

    # Prepare the dataset
    if num_expected is not None:
        images_np = images_np[:num_expected]
    # NOTE: no normalization, raw uint8 format, [0,1] will yield incorrect results!!!!
    # dataset_tensor = torch.from_numpy(images_np).float() / 255.0  # Normalize to [0, 1]
    dataset_tensor = torch.from_numpy(images_np)  # no normalization, raw uint8 format
    dataset_tensor = dataset_tensor.permute(0, 3, 1, 2)  # (N, H, W, C) -> (N, C, H, W)
    # The dataset deliberately stays on the host: DataLoader worker processes
    # (num_workers > 0) cannot safely serve CUDA tensors, and each batch is
    # moved to `device` inside the loop below anyway.
    if dataset_tensor.shape[1] == 1:
        dataset_tensor = dataset_tensor.repeat([1, 3, 1, 1])
    elif dataset_tensor.shape[1] == 4:
        dataset_tensor = dataset_tensor[:, :3, :, :]
    assert dataset_tensor.shape[1] == 3
    assert dataset_tensor.dtype == torch.uint8
    tensor_dataset = TensorDataset(dataset_tensor)
    data_loader = DataLoader(
        tensor_dataset,
        batch_size=max_batch_size,
        shuffle=False,
        num_workers=num_workers,
        # pin_memory=True,
        prefetch_factor=prefetch_factor,
    )
    # Accumulate statistics (sum of features and sum of outer products) in float64.
    print(f'Calculating statistics for {len(tensor_dataset)} images...')
    mu = torch.zeros([feature_dim], dtype=torch.float64, device=device)
    sigma = torch.zeros([feature_dim, feature_dim], dtype=torch.float64, device=device)
    for images, in tqdm.tqdm(data_loader, unit='batch', ):
        with torch.no_grad():
            features = detector_net(images.to(device), **detector_kwargs).to(torch.float64)
        mu += features.sum(0)
        sigma += features.T @ features
    # Calculate grand totals.
    mu /= len(tensor_dataset)
    sigma -= mu.ger(mu) * len(tensor_dataset)  # sum(f f^T) - N * mu mu^T
    sigma /= len(tensor_dataset) - 1
    return mu.cpu().numpy(), sigma.cpu().numpy()


def calculate_fid_from_inception_stats(mu, sigma, mu_ref, sigma_ref):
    """Return the Frechet distance between two Gaussians given by (mean, covariance).

    Computes ``||mu - mu_ref||^2 + Tr(sigma + sigma_ref - 2 * sqrtm(sigma @ sigma_ref))``
    and returns the real part as a Python float.
    """
    mean_gap = mu - mu_ref
    mean_term = np.square(mean_gap).sum()
    # sqrtm may return a complex matrix; only the real part of the total is kept.
    cov_sqrt, _ = scipy.linalg.sqrtm(np.dot(sigma, sigma_ref), disp=False)
    cov_term = np.trace(sigma + sigma_ref - cov_sqrt * 2)
    total = mean_term + cov_term
    return float(np.real(total))


### Main computation loop

In [14]:
# Filesystem roots for reference FID statistics, reference datasets,
# generated samples, and evaluation outputs.
fid_root = "/n/holylfs06/LABS/kempner_fellow_binxuwang/Users/binxuwang/Datasets/EDM_datasets/fid-refs"
refdata_root = "/n/holylfs06/LABS/kempner_fellow_binxuwang/Users/binxuwang/Datasets/EDM_datasets/datasets"
sample_root = "/n/holylfs06/LABS/kempner_fellow_binxuwang/Users/binxuwang/DL_Projects/edm_analy_sampler_benchmark/samples/"
eval_root = "/n/holylfs06/LABS/kempner_fellow_binxuwang/Users/binxuwang/DL_Projects/edm_analy_sampler_benchmark/eval/"
# Per-dataset lookup tables, all keyed by the same short dataset name.
# Checkpoint name; also used as the sub-folder name under sample_root / eval_root.
model_ckpt_dict = {"afhqv264": "edm-afhqv2-64x64-uncond-vp",
                   "ffhq64": "edm-ffhq-64x64-uncond-vp",
                   "cifar10": "edm-cifar10-32x32-uncond-vp"}
# Reference image archive under refdata_root.
refdata_dict = {"afhqv264": "afhqv2-64x64.zip",
                "ffhq64": "ffhq-64x64.zip",
                "cifar10": "cifar10-32x32.zip"}
# Reference Inception statistics file under fid_root.
refstats_dict = {"afhqv264": "afhqv2-64x64.npz",
                "ffhq64": "ffhq-64x64.npz",
                "cifar10": "cifar10-32x32.npz"}
# Remote copies of the reference statistics. The ffhq64 entry must be the
# 64x64 statistics so it matches refstats_dict / refdata_dict.
refstats_url_dict = {"ffhq64": "https://nvlabs-fi-cdn.nvidia.com/edm/fid-refs/ffhq-64x64.npz",
            "afhqv264": "https://nvlabs-fi-cdn.nvidia.com/edm/fid-refs/afhqv2-64x64.npz",
            "cifar10": "https://nvlabs-fi-cdn.nvidia.com/edm/fid-refs/cifar10-32x32.npz"}


In [17]:
dataset_name = "cifar10" # "ffhq64" # "cifar10"  #

model_ckpt = model_ckpt_dict[dataset_name]
model_dir = join(sample_root, model_ckpt)
eval_dir = join(eval_root, model_ckpt)
os.makedirs(eval_root, exist_ok=True)
os.makedirs(eval_dir, exist_ok=True)

# Reference statistics (mean / covariance of Inception features) for this dataset.
ref = np.load(join(fid_root, refstats_dict[dataset_name]))
ref_mu, ref_sigma = ref["mu"], ref["sigma"]
# with dnnlib.util.open_url(refstats_url_dict[dataset_name]) as f:
#     ref = dict(np.load(f))
#     ref_mu, ref_sigma = ref["mu"], ref["sigma"]

# For every sampler sub-folder in model_dir: collect all .npz shards, concatenate
# their "samples" arrays along the first axis, and compute FID against the reference.
fid_col = []
# for subdir in ["uni_pc_bh1_25_skip20.0"]:#["dpm_solver_v3_25_skip80.0"]:#tqdm.tqdm(os.listdir(model_dir)):
for subdir in sorted(os.listdir(model_dir)):
    sampler_str = subdir
    sampler_dir = join(model_dir, subdir)
    # Stray files next to the sampler folders would make os.listdir fail.
    if not os.path.isdir(sampler_dir):
        continue
    print(f"Evaluating {sampler_str}...")
    # Sort the shards: the concatenation is truncated to the first `num_expected`
    # images below, so the shard order decides which samples are evaluated.
    npz_files = sorted(join(sampler_dir, f) for f in os.listdir(sampler_dir) if f.endswith(".npz"))
    if len(npz_files) == 0:
        print(f"No npz files found in {model_dir}/{subdir}")
        continue
    sample_all = np.concatenate([np.load(f)["samples"] for f in npz_files], axis=0) # shape (N, H, W, 3) uint8 format
    print(sample_all.shape)
    Mu, Sigma = calculate_inception_stats_from_numpy(sample_all, num_expected=50000,
                                   seed=0, max_batch_size=512, num_workers=0, prefetch_factor=None,
                                   device=torch.device('cuda'))
    # print(Mu.shape, Sigma.shape)
    fid = calculate_fid_from_inception_stats(Mu, Sigma, ref_mu, ref_sigma)
    print(f"{sampler_str} FID: {fid:.2f}")
    # save Mu, Sigma, fid to eval_dir
    np.savez(join(eval_dir, f"{sampler_str}_stats.npz"), mu=Mu, sigma=Sigma, fid=fid)
    fid_col.append({"sampler": sampler_str, "fid": fid, "dataset": dataset_name, }) # "ckpt": model_ckpt,
    # Rewrite the summary CSV after every sampler so the rows collected so far
    # are on disk even if the loop is interrupted.
    pd.DataFrame(fid_col).to_csv(join(eval_dir, "fid_by_sampler.csv"))
    # raise ValueError("stop here")

Evaluating dpm_solver++_10_skip1.0...
(50176, 32, 32, 3)
Loading Inception-v3 model...
Calculating statistics for 50000 images...


  0%|          | 0/98 [00:00<?, ?batch/s]

dpm_solver++_10_skip1.0 FID: 29.34
Evaluating dpm_solver++_10_skip10.0...
(50176, 32, 32, 3)
Loading Inception-v3 model...
Calculating statistics for 50000 images...


  0%|          | 0/98 [00:00<?, ?batch/s]

dpm_solver++_10_skip10.0 FID: 2.94
Evaluating dpm_solver++_10_skip2.5...
(50176, 32, 32, 3)
Loading Inception-v3 model...
Calculating statistics for 50000 images...


  0%|          | 0/98 [00:00<?, ?batch/s]

dpm_solver++_10_skip2.5 FID: 7.23
Evaluating dpm_solver++_10_skip20.0...
(50176, 32, 32, 3)
Loading Inception-v3 model...
Calculating statistics for 50000 images...


  0%|          | 0/98 [00:00<?, ?batch/s]

dpm_solver++_10_skip20.0 FID: 2.84
Evaluating dpm_solver++_10_skip40.0...
(50176, 32, 32, 3)
Loading Inception-v3 model...
Calculating statistics for 50000 images...


  0%|          | 0/98 [00:00<?, ?batch/s]

dpm_solver++_10_skip40.0 FID: 2.83
Evaluating dpm_solver++_10_skip5.0...
(50176, 32, 32, 3)
Loading Inception-v3 model...
Calculating statistics for 50000 images...


  0%|          | 0/98 [00:00<?, ?batch/s]

dpm_solver++_10_skip5.0 FID: 3.62
Evaluating dpm_solver++_10_skip80.0...
(50176, 32, 32, 3)
Loading Inception-v3 model...
Calculating statistics for 50000 images...


  0%|          | 0/98 [00:00<?, ?batch/s]

dpm_solver++_10_skip80.0 FID: 3.06
Evaluating dpm_solver++_12_skip1.0...
(50176, 32, 32, 3)
Loading Inception-v3 model...
Calculating statistics for 50000 images...


  0%|          | 0/98 [00:00<?, ?batch/s]

dpm_solver++_12_skip1.0 FID: 28.28
Evaluating dpm_solver++_12_skip10.0...
(50176, 32, 32, 3)
Loading Inception-v3 model...
Calculating statistics for 50000 images...


  0%|          | 0/98 [00:00<?, ?batch/s]

dpm_solver++_12_skip10.0 FID: 2.44
Evaluating dpm_solver++_12_skip2.5...
(50176, 32, 32, 3)
Loading Inception-v3 model...
Calculating statistics for 50000 images...


  0%|          | 0/98 [00:00<?, ?batch/s]

dpm_solver++_12_skip2.5 FID: 6.28
Evaluating dpm_solver++_12_skip20.0...
(50176, 32, 32, 3)
Loading Inception-v3 model...
Calculating statistics for 50000 images...


  0%|          | 0/98 [00:00<?, ?batch/s]

dpm_solver++_12_skip20.0 FID: 2.36
Evaluating dpm_solver++_12_skip40.0...
(50176, 32, 32, 3)
Loading Inception-v3 model...
Calculating statistics for 50000 images...


  0%|          | 0/98 [00:00<?, ?batch/s]

dpm_solver++_12_skip40.0 FID: 2.44
Evaluating dpm_solver++_12_skip5.0...
(50176, 32, 32, 3)
Loading Inception-v3 model...
Calculating statistics for 50000 images...


  0%|          | 0/98 [00:00<?, ?batch/s]

dpm_solver++_12_skip5.0 FID: 2.96
Evaluating dpm_solver++_12_skip80.0...
(50176, 32, 32, 3)
Loading Inception-v3 model...
Calculating statistics for 50000 images...


  0%|          | 0/98 [00:00<?, ?batch/s]

dpm_solver++_12_skip80.0 FID: 2.58
Evaluating dpm_solver++_15_skip1.0...
(50176, 32, 32, 3)
Loading Inception-v3 model...
Calculating statistics for 50000 images...


  0%|          | 0/98 [00:00<?, ?batch/s]

dpm_solver++_15_skip1.0 FID: 27.39
Evaluating dpm_solver++_15_skip10.0...
(50176, 32, 32, 3)
Loading Inception-v3 model...
Calculating statistics for 50000 images...


  0%|          | 0/98 [00:00<?, ?batch/s]

dpm_solver++_15_skip10.0 FID: 2.15
Evaluating dpm_solver++_15_skip2.5...
(50176, 32, 32, 3)
Loading Inception-v3 model...
Calculating statistics for 50000 images...


  0%|          | 0/98 [00:00<?, ?batch/s]

dpm_solver++_15_skip2.5 FID: 5.59
Evaluating dpm_solver++_15_skip20.0...
(50176, 32, 32, 3)
Loading Inception-v3 model...
Calculating statistics for 50000 images...


  0%|          | 0/98 [00:00<?, ?batch/s]

dpm_solver++_15_skip20.0 FID: 2.13
Evaluating dpm_solver++_15_skip40.0...
(50176, 32, 32, 3)
Loading Inception-v3 model...
Calculating statistics for 50000 images...


  0%|          | 0/98 [00:00<?, ?batch/s]

dpm_solver++_15_skip40.0 FID: 2.19
Evaluating dpm_solver++_15_skip5.0...
(50176, 32, 32, 3)
Loading Inception-v3 model...
Calculating statistics for 50000 images...


  0%|          | 0/98 [00:00<?, ?batch/s]

dpm_solver++_15_skip5.0 FID: 2.56
Evaluating dpm_solver++_15_skip80.0...
(50176, 32, 32, 3)
Loading Inception-v3 model...
Calculating statistics for 50000 images...


  0%|          | 0/98 [00:00<?, ?batch/s]

dpm_solver++_15_skip80.0 FID: 2.26
Evaluating dpm_solver++_20_skip1.0...
(50176, 32, 32, 3)
Loading Inception-v3 model...
Calculating statistics for 50000 images...


  0%|          | 0/98 [00:00<?, ?batch/s]

dpm_solver++_20_skip1.0 FID: 26.62
Evaluating dpm_solver++_20_skip10.0...
(50176, 32, 32, 3)
Loading Inception-v3 model...
Calculating statistics for 50000 images...


  0%|          | 0/98 [00:00<?, ?batch/s]

dpm_solver++_20_skip10.0 FID: 2.05
Evaluating dpm_solver++_20_skip2.5...
(50176, 32, 32, 3)
Loading Inception-v3 model...
Calculating statistics for 50000 images...


  0%|          | 0/98 [00:00<?, ?batch/s]

dpm_solver++_20_skip2.5 FID: 5.09
Evaluating dpm_solver++_20_skip20.0...
(50176, 32, 32, 3)
Loading Inception-v3 model...
Calculating statistics for 50000 images...


  0%|          | 0/98 [00:00<?, ?batch/s]

dpm_solver++_20_skip20.0 FID: 2.05
Evaluating dpm_solver++_20_skip40.0...
(50176, 32, 32, 3)
Loading Inception-v3 model...
Calculating statistics for 50000 images...


  0%|          | 0/98 [00:00<?, ?batch/s]

dpm_solver++_20_skip40.0 FID: 2.09
Evaluating dpm_solver++_20_skip5.0...
(50176, 32, 32, 3)
Loading Inception-v3 model...
Calculating statistics for 50000 images...


  0%|          | 0/98 [00:00<?, ?batch/s]

dpm_solver++_20_skip5.0 FID: 2.34
Evaluating dpm_solver++_20_skip80.0...
(50176, 32, 32, 3)
Loading Inception-v3 model...
Calculating statistics for 50000 images...


  0%|          | 0/98 [00:00<?, ?batch/s]

dpm_solver++_20_skip80.0 FID: 2.12
Evaluating dpm_solver++_25_skip1.0...
(50176, 32, 32, 3)
Loading Inception-v3 model...
Calculating statistics for 50000 images...


  0%|          | 0/98 [00:00<?, ?batch/s]

dpm_solver++_25_skip1.0 FID: 26.22
Evaluating dpm_solver++_25_skip10.0...
(50176, 32, 32, 3)
Loading Inception-v3 model...
Calculating statistics for 50000 images...


  0%|          | 0/98 [00:00<?, ?batch/s]

dpm_solver++_25_skip10.0 FID: 2.03
Evaluating dpm_solver++_25_skip2.5...
(50176, 32, 32, 3)
Loading Inception-v3 model...
Calculating statistics for 50000 images...


  0%|          | 0/98 [00:00<?, ?batch/s]

dpm_solver++_25_skip2.5 FID: 4.86
Evaluating dpm_solver++_25_skip20.0...
(50176, 32, 32, 3)
Loading Inception-v3 model...
Calculating statistics for 50000 images...


  0%|          | 0/98 [00:00<?, ?batch/s]

dpm_solver++_25_skip20.0 FID: 2.04
Evaluating dpm_solver++_25_skip40.0...
(50176, 32, 32, 3)
Loading Inception-v3 model...
Calculating statistics for 50000 images...


  0%|          | 0/98 [00:00<?, ?batch/s]

dpm_solver++_25_skip40.0 FID: 2.08
Evaluating dpm_solver++_25_skip5.0...
(50176, 32, 32, 3)
Loading Inception-v3 model...
Calculating statistics for 50000 images...


  0%|          | 0/98 [00:00<?, ?batch/s]

dpm_solver++_25_skip5.0 FID: 2.27
Evaluating dpm_solver++_25_skip80.0...
(50176, 32, 32, 3)
Loading Inception-v3 model...
Calculating statistics for 50000 images...


  0%|          | 0/98 [00:00<?, ?batch/s]

dpm_solver++_25_skip80.0 FID: 2.09
Evaluating dpm_solver++_5_skip1.0...
(50176, 32, 32, 3)
Loading Inception-v3 model...
Calculating statistics for 50000 images...


  0%|          | 0/98 [00:00<?, ?batch/s]

dpm_solver++_5_skip1.0 FID: 34.46
Evaluating dpm_solver++_5_skip10.0...
(50176, 32, 32, 3)
Loading Inception-v3 model...
Calculating statistics for 50000 images...


  0%|          | 0/98 [00:00<?, ?batch/s]

dpm_solver++_5_skip10.0 FID: 12.27
Evaluating dpm_solver++_5_skip2.5...
(50176, 32, 32, 3)
Loading Inception-v3 model...
Calculating statistics for 50000 images...


  0%|          | 0/98 [00:00<?, ?batch/s]

dpm_solver++_5_skip2.5 FID: 14.25
Evaluating dpm_solver++_5_skip20.0...
(50176, 32, 32, 3)
Loading Inception-v3 model...
Calculating statistics for 50000 images...


  0%|          | 0/98 [00:00<?, ?batch/s]

dpm_solver++_5_skip20.0 FID: 14.37
Evaluating dpm_solver++_5_skip40.0...
(50176, 32, 32, 3)
Loading Inception-v3 model...
Calculating statistics for 50000 images...


  0%|          | 0/98 [00:00<?, ?batch/s]

dpm_solver++_5_skip40.0 FID: 18.29
Evaluating dpm_solver++_5_skip5.0...
(50176, 32, 32, 3)
Loading Inception-v3 model...
Calculating statistics for 50000 images...


  0%|          | 0/98 [00:00<?, ?batch/s]

dpm_solver++_5_skip5.0 FID: 11.37
Evaluating dpm_solver++_5_skip80.0...
(50176, 32, 32, 3)
Loading Inception-v3 model...
Calculating statistics for 50000 images...


  0%|          | 0/98 [00:00<?, ?batch/s]

dpm_solver++_5_skip80.0 FID: 25.11
Evaluating dpm_solver++_6_skip1.0...
(50176, 32, 32, 3)
Loading Inception-v3 model...
Calculating statistics for 50000 images...


  0%|          | 0/98 [00:00<?, ?batch/s]

dpm_solver++_6_skip1.0 FID: 34.26
Evaluating dpm_solver++_6_skip10.0...
(50176, 32, 32, 3)
Loading Inception-v3 model...
Calculating statistics for 50000 images...


  0%|          | 0/98 [00:00<?, ?batch/s]

dpm_solver++_6_skip10.0 FID: 6.78
Evaluating dpm_solver++_6_skip2.5...
(50176, 32, 32, 3)
Loading Inception-v3 model...
Calculating statistics for 50000 images...


  0%|          | 0/98 [00:00<?, ?batch/s]

dpm_solver++_6_skip2.5 FID: 11.79
Evaluating dpm_solver++_6_skip20.0...
(50176, 32, 32, 3)
Loading Inception-v3 model...
Calculating statistics for 50000 images...


  0%|          | 0/98 [00:00<?, ?batch/s]

dpm_solver++_6_skip20.0 FID: 7.22
Evaluating dpm_solver++_6_skip40.0...
(50176, 32, 32, 3)
Loading Inception-v3 model...
Calculating statistics for 50000 images...


  0%|          | 0/98 [00:00<?, ?batch/s]

dpm_solver++_6_skip40.0 FID: 9.00
Evaluating dpm_solver++_6_skip5.0...
(50176, 32, 32, 3)
Loading Inception-v3 model...
Calculating statistics for 50000 images...


  0%|          | 0/98 [00:00<?, ?batch/s]

dpm_solver++_6_skip5.0 FID: 7.40
Evaluating dpm_solver++_6_skip80.0...
(50176, 32, 32, 3)
Loading Inception-v3 model...
Calculating statistics for 50000 images...


  0%|          | 0/98 [00:00<?, ?batch/s]

dpm_solver++_6_skip80.0 FID: 12.23
Evaluating dpm_solver++_8_skip1.0...
(50176, 32, 32, 3)
Loading Inception-v3 model...
Calculating statistics for 50000 images...


  0%|          | 0/98 [00:00<?, ?batch/s]

dpm_solver++_8_skip1.0 FID: 31.50
Evaluating dpm_solver++_8_skip10.0...
(50176, 32, 32, 3)
Loading Inception-v3 model...
Calculating statistics for 50000 images...


  0%|          | 0/98 [00:00<?, ?batch/s]

dpm_solver++_8_skip10.0 FID: 3.96
Evaluating dpm_solver++_8_skip2.5...
(50176, 32, 32, 3)
Loading Inception-v3 model...
Calculating statistics for 50000 images...


  0%|          | 0/98 [00:00<?, ?batch/s]

dpm_solver++_8_skip2.5 FID: 9.13
Evaluating dpm_solver++_8_skip20.0...
(50176, 32, 32, 3)
Loading Inception-v3 model...
Calculating statistics for 50000 images...


  0%|          | 0/98 [00:00<?, ?batch/s]

dpm_solver++_8_skip20.0 FID: 3.71
Evaluating dpm_solver++_8_skip40.0...
(50176, 32, 32, 3)
Loading Inception-v3 model...
Calculating statistics for 50000 images...


  0%|          | 0/98 [00:00<?, ?batch/s]

dpm_solver++_8_skip40.0 FID: 3.82
Evaluating dpm_solver++_8_skip5.0...
(50176, 32, 32, 3)
Loading Inception-v3 model...
Calculating statistics for 50000 images...


  0%|          | 0/98 [00:00<?, ?batch/s]

dpm_solver++_8_skip5.0 FID: 4.93
Evaluating dpm_solver++_8_skip80.0...
(50176, 32, 32, 3)
Loading Inception-v3 model...
Calculating statistics for 50000 images...


  0%|          | 0/98 [00:00<?, ?batch/s]

dpm_solver++_8_skip80.0 FID: 4.56
Evaluating dpm_solver_v3_10_skip1.0...
(50176, 32, 32, 3)
Loading Inception-v3 model...
Calculating statistics for 50000 images...


  0%|          | 0/98 [00:00<?, ?batch/s]

dpm_solver_v3_10_skip1.0 FID: 31.93
Evaluating dpm_solver_v3_10_skip10.0...
(50176, 32, 32, 3)
Loading Inception-v3 model...
Calculating statistics for 50000 images...


  0%|          | 0/98 [00:00<?, ?batch/s]

KeyboardInterrupt: 

In [12]:
# Re-print the most recent FID with four decimals. Relies on `sampler_str` and
# `fid` left in the kernel by the evaluation loop cell.
print(f"{sampler_str} FID: {fid:.4f}")

uni_pc_bh1_25_skip20.0 FID: 2.0426


In [None]:
# Frechet distance between the saved Inception statistics of two samplers.
# Requires the `<sampler>_stats.npz` files written by the evaluation loop into eval_dir.
sampler_1 = "uni_pc_bh1_20_skip80.0"
# sampler_2 = "dpm_solver_v3_20_skip80.0"  # alternative comparison target
sampler_2 = "uni_pc_bh2_20_skip10.0"

data1 = np.load(join(eval_dir, f"{sampler_1}_stats.npz"))
data2 = np.load(join(eval_dir, f"{sampler_2}_stats.npz"))
calculate_fid_from_inception_stats(data1["mu"], data1["sigma"], data2["mu"], data2["sigma"])


### Debug original dataset

Note: the original code expects raw uint8 inputs in the range [0, 255]; no normalization is applied.

In [5]:
# Path to the reference image archive for the currently selected dataset.
refdata_path = join(refdata_root, refdata_dict[dataset_name])

In [6]:
# Open the reference archive with the same dataset class (and the same
# max_size / random_seed arguments) that calculate_inception_stats uses.
dataset_obj = dataset.ImageFolderDataset(path=refdata_path, max_size=50000, random_seed=0)

In [7]:
# Inspect one item to check the pixel format the dataset yields
# (the output below shows a uint8 image array plus an empty float32 label array).
dataset_obj[0]

(array([[[ 59,  43,  50, ..., 158, 152, 148],
         [ 16,   0,  18, ..., 123, 119, 122],
         [ 25,  16,  49, ..., 118, 120, 109],
         ...,
         [208, 201, 198, ..., 160,  56,  53],
         [180, 173, 186, ..., 184,  97,  83],
         [177, 168, 179, ..., 216, 151, 123]],
 
        [[ 62,  46,  48, ..., 132, 125, 124],
         [ 20,   0,   8, ...,  88,  83,  87],
         [ 24,   7,  27, ...,  84,  84,  73],
         ...,
         [170, 153, 161, ..., 133,  31,  34],
         [139, 123, 144, ..., 148,  62,  53],
         [144, 129, 142, ..., 184, 118,  92]],
 
        [[ 63,  45,  43, ..., 108, 102, 103],
         [ 20,   0,   0, ...,  55,  50,  57],
         [ 21,   0,   8, ...,  50,  50,  42],
         ...,
         [ 96,  34,  26, ...,  70,   7,  20],
         [ 96,  42,  30, ...,  94,  34,  34],
         [116,  94,  87, ..., 140,  84,  72]]], dtype=uint8),
 array([], dtype=float32))

### Debug the FID feature-statistics code

In [15]:
# Sanity check: recompute the Inception statistics of the reference dataset with
# calculate_inception_stats and compare them to the published reference statistics.
# load ref data
refdata_path = join(refdata_root, refdata_dict[dataset_name])
ref_mu_new, ref_sigma_new = calculate_inception_stats(refdata_path, num_expected=50000,
                                   seed=0, max_batch_size=512, num_workers=0, prefetch_factor=None,
                                   device=torch.device('cuda'))
#%%
# fiddata = np.load(join(fid_root, refstats_dict[dataset_name]))
# NOTE: this rebinds the kernel-global `ref`, `ref_mu` and `ref_sigma` that the
# evaluation loop cell also assigns.
with dnnlib.util.open_url(refstats_url_dict[dataset_name]) as f:
    ref = dict(np.load(f))
    ref_mu, ref_sigma = ref["mu"], ref["sigma"]
#%%
# A value near zero means the recomputed statistics match the published ones.
calculate_fid_from_inception_stats(ref_mu_new, ref_sigma_new, ref_mu, ref_sigma) # 2.77 E-5

Loading Inception-v3 model...
Loading images from "/n/holylfs06/LABS/kempner_fellow_binxuwang/Users/binxuwang/Datasets/EDM_datasets/datasets/cifar10-32x32.zip"...
Calculating statistics for 50000 images...


  0%|          | 0/98 [00:00<?, ?batch/s]

2.7743751329426898e-05

### Debug the feature network

In [None]:
# Load the Inception-v3 feature detector on its own for interactive debugging;
# mirrors the loading code inside the calculate_inception_stats* functions.
# NOTE(security): pickle.load executes code from the downloaded file; only use trusted URLs.
detector_url = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/inception-2015-12-05.pkl'
detector_kwargs = dict(return_features=True)
feature_dim = 2048
with dnnlib.util.open_url(detector_url, verbose=False) as f:
    detector_net = pickle.load(f).to('cuda')

### Debug the dataset formation

In [None]:
# Build a TensorDataset from `sample_all` (left in the kernel by the evaluation loop).
# NOTE(review): this cell normalizes to float [0, 1], whereas
# calculate_inception_stats_from_numpy keeps raw uint8 and warns that [0, 1]
# inputs yield incorrect results; this looks like the earlier, wrong variant
# kept for debugging. Confirm before reusing.
images_np = sample_all
dataset_tensor = torch.from_numpy(images_np).float() / 255.0  # Normalize to [0, 1]
dataset_tensor = dataset_tensor.permute(0, 3, 1, 2)  # Ensure shape (N, 3, H, W)
dataset_tensor = dataset_tensor.to("cuda")
# Force exactly three channels: replicate grayscale, drop a fourth channel.
if dataset_tensor.shape[1] == 1:
    dataset_tensor = dataset_tensor.repeat([1, 3, 1, 1])
elif dataset_tensor.shape[1] == 4:
    dataset_tensor = dataset_tensor[:, :3, :, :]
assert dataset_tensor.shape[1] == 3
tensor_dataset = TensorDataset(dataset_tensor)

### Dev zone

In [None]:
import os
from os.path import join
import re
import numpy as np
import matplotlib.pyplot as plt
import tqdm

def find_unique_suffixes(folder_name):
    """Return the distinct "skipXX_noiseXX" tags found in a folder's file names.

    Only regular files directly inside ``folder_name`` are considered, and only
    the part of each name before the first '.' is searched. The result is a
    list sorted by the integer that follows "skip".
    """
    tag_re = re.compile(r'skip\d+_noise\d+')

    tags = set()
    for entry in os.listdir(folder_name):
        if not os.path.isfile(os.path.join(folder_name, entry)):
            continue
        hit = tag_re.search(entry.split('.')[0])
        if hit:
            tags.add(hit.group())

    # Numeric sort so that e.g. "skip10_..." comes after "skip2_...".
    return sorted(tags, key=lambda tag: int(tag.split('_')[0].replace('skip', '')))


def crop_all_from_montage(img, totalnum, imgsize=32, pad=2):
    """Cut the first ``totalnum`` tiles out of a padded montage, in row-major order.

    The montage is treated as a grid of ``imgsize`` x ``imgsize`` tiles with
    ``pad`` pixels before each tile; the grid shape is derived from ``img.shape``.
    Returns a list of array slices of ``img``, one per tile.
    """
    stride = imgsize + pad
    grid_shape = ((img.shape[0] - pad) // stride, (img.shape[1] - pad) // stride)
    crops = []
    for tile_idx in range(totalnum):
        # unravel_index also rejects indices beyond the grid with a ValueError.
        row, col = np.unravel_index(tile_idx, grid_shape)
        top = pad + stride * row
        left = pad + stride * col
        crops.append(img[top:top + imgsize, left:left + imgsize, :])
    return crops


def find_files_with_suffix(folder_path, target_suffix):
    """Return the sorted names of regular files in ``folder_path`` that contain ``target_suffix``.

    The match is a plain substring test anywhere in the file name.
    """
    return sorted(
        name
        for name in os.listdir(folder_path)
        if os.path.isfile(os.path.join(folder_path, name)) and target_suffix in name
    )


In [None]:
# Pipeline for the montage-based experiment: crop individual images out of the
# saved montage figures, compute Inception statistics per "skipXX_noiseXX"
# suffix, and write one FID per suffix to a CSV.
# Relies on `pd`, `torch`, `dnnlib` and the FID helpers defined in the first cell.

# tabdir = r"E:\OneDrive - Harvard University\NeurIPS2023_Diffusion\Tables"
# Machine-specific output locations; only the last assignment takes effect.
tabdir = r"D:\DL_Projects\Vision\edm_analy_sample\summary"
tabdir = r"/n/scratch3/users/b/biw905/edm_analy_sample/summary"
tabdir = r"/home/binxu/DL_Projects/edm_analy_sample/summary"
# Tile size and number of tiles per montage for each dataset.
imgsize_dict = {"ffhq64": 64, "afhqv264": 64, "cifar10": 32, }
max_batch_size_dict = {"ffhq64": 64, "afhqv264": 64, "cifar10": 256, }
# refstats_dict = {"ffhq64": "https://nvlabs-fi-cdn.nvidia.com/edm/fid-refs/ffhq-256.npz",
#             "afhqv264": "https://nvlabs-fi-cdn.nvidia.com/edm/fid-refs/afhqv2-64x64.npz",
#             "cifar10": "https://nvlabs-fi-cdn.nvidia.com/edm/fid-refs/cifar10-32x32.npz"}

# refstats_url = r"ffhq-64x64.npz"
# NOTE: this rebinds the kernel-global `refstats_dict` defined in the config cell.
refstats_dict = {#"ffhq64": "ffhq-256.npz",
                "ffhq64": "ffhq-64x64.npz",
                "afhqv264": "afhqv2-64x64.npz",
                "cifar10": "cifar10-32x32.npz"}


# NOTE: this also rebinds the kernel-global `dataset_name` used by earlier cells.
dataset_name = "ffhq64" # "afhqv264"  # "ffhq64" # "cifar10"  #
# Machine-specific montage folders; only the last assignment takes effect.
figdir = rf"D:\DL_Projects\Vision\edm_analy_sample\{dataset_name}_uncond_vp_edm_theory"
figdir = rf"/n/scratch3/users/b/biw905/edm_analy_sample/{dataset_name}_uncond_vp_edm_theory"
figdir = rf"/home/binxu/DL_Projects/edm_analy_sample/{dataset_name}_uncond_vp_edm_theory"
croproot = figdir + "_crops"
imgsize = imgsize_dict[dataset_name]
max_batch_size = max_batch_size_dict[dataset_name]
# NOTE(review): despite the name this is a bare file name, not a URL; it is later
# passed to dnnlib.util.open_url, presumably resolved relative to the working
# directory. Confirm.
refstats_url = refstats_dict[dataset_name]

# One crop folder per suffix found among the montage file names.
suffixes = find_unique_suffixes(figdir)
for suffix in suffixes:
    os.makedirs(join(croproot, suffix), exist_ok=True)
os.makedirs(tabdir, exist_ok=True)
#%%
# load all mtg figures crop and save into folders
for suffix in suffixes:
    mtglist = find_files_with_suffix(figdir, suffix)
    for mtgname in tqdm.tqdm(mtglist):
        # File names carry the inclusive seed range as "rnd<start>-<end>_".
        numbers_before_after_dash = re.findall(r'rnd(\d+)-(\d+)_', mtgname)
        rnd_start, rnd_end = numbers_before_after_dash[0]
        rnd_batch = list(range(int(rnd_start), int(rnd_end) + 1))
        mtg_arr = plt.imread(join(figdir, mtgname))
        imgcrops = crop_all_from_montage(mtg_arr, max_batch_size, imgsize=imgsize, pad=2)
        # zip stops at the shorter sequence, so surplus crops or seeds are dropped.
        for imgcrop, rnd_id in zip(imgcrops, rnd_batch):
            plt.imsave(join(croproot, suffix, f"rnd{rnd_id:06d}.png"), imgcrop)

#%
fid_batch_size = 256
fid_col = []
for suffix in suffixes:
    Mu, Sigma = calculate_inception_stats(join(croproot, suffix), num_expected=50000,
                                   seed=0, max_batch_size=fid_batch_size, num_workers=0, prefetch_factor=None,
                                   device=torch.device('cuda'))
    # https://nvlabs-fi-cdn.nvidia.com/edm/fid-refs/cifar10-32x32.npz
    # The reference statistics are reloaded for every suffix although
    # refstats_url does not change inside the loop.
    with dnnlib.util.open_url(refstats_url) as f:
        ref = dict(np.load(f))

    fid = calculate_fid_from_inception_stats(Mu, Sigma, ref['mu'], ref['sigma'])
    print(f"{suffix} FID: {fid:.2f}")
    fid_col.append(fid)
#%%
# sorted(os.listdir(croproot))
# with the folder name column
# One row per suffix, in the same order as the FID loop above.
df = pd.DataFrame(fid_col, columns=["FID"], index=suffixes)
# df.to_csv(join(croproot, "fid_by_skipping.csv"))
df.to_csv(join(tabdir, f"{dataset_name}_fid_by_skipping_new.csv"))




In [None]:
#%% DEV ZONE
# Earlier, CIFAR-10-specific draft of the crop + FID pipeline.
# NOTE(review): `t_steps` is used below but is not defined anywhere in this
# notebook (see the TODO), so the cropping section cannot run on a fresh kernel.

# TODO: find a way to compute t_steps from skipsteps
# suffixes = find_unique_suffixes(r"D:\DL_Projects\Vision\edm_analy_sample\ffhq64_uncond_vp_edm_theory")
# # for skipstep in [0, 1, 2, 3, 4, 5, 6, 7, 8, ]:  # range(1, num_steps):
# #     sigma_max_skip = t_steps[skipstep]
# #     print(f"skip{skipstep}_noise{sigma_max_skip:.0f}")
# filenames = find_files_with_suffix(r"D:\DL_Projects\Vision\edm_analy_sample\ffhq64_uncond_vp_edm_theory",
#                                       suffixes[0])
# #%%
# figdir = "/home/binxu/DL_Projects/edm_analy_sample/cifar10_uncond_vp_edm"
# croproot = "/home/binxu/DL_Projects/edm_analy_sample/cifar10_uncond_vp_edm_crops"

#%%
# load all mtg figures crop and save into folders
seeds = list(range(50000))
max_batch_size = 256
# iterate batches
# NOTE(review): np.array_split makes 50000 // 256 = 195 near-equal chunks of
# 256-257 seeds rather than fixed batches of 256; confirm this matches the seed
# ranges encoded in the montage file names.
rank_batches = np.array_split(seeds, len(seeds) // max_batch_size)
for batch_seeds in rank_batches:
    # `figdir` and `croproot` are taken from whatever an earlier cell left in the kernel.
    for skip, sigma_max_skip in zip([0, 1, 2, 3, 4, 5, 6, 7, 8, ],
                                    t_steps):
        assert os.path.exists(join(figdir, f"rnd{batch_seeds[0]:06d}-{batch_seeds[-1]:06d}"
                                           f"_skip{skip}_noise{sigma_max_skip:.0f}.png"))
        mtg_arr = plt.imread(join(figdir, f"rnd{batch_seeds[0]:06d}-{batch_seeds[-1]:06d}"
                                             f"_skip{skip}_noise{sigma_max_skip:.0f}.png"))
        imgcrops = crop_all_from_montage(mtg_arr, len(batch_seeds), imgsize=32, pad=2)
        for imgid, imgcrop in enumerate(imgcrops):
            plt.imsave(join(croproot, f"skip{skip}_noise{sigma_max_skip:.0f}", f"rnd{batch_seeds[imgid]:06d}.png"), imgcrop)

#%%
# From here on `croproot` points at the CIFAR-10 crop folder.
croproot = "/home/binxu/DL_Projects/edm_analy_sample/cifar10_uncond_vp_edm_crops"
# This value is overwritten by the loop variable of the same name below.
foldername = f"skip1_noise58"
#%%
# from fid import calculate_inception_stats
# skip, sigma_max_skip = 0, t_steps[0]
fid_col = []
for foldername in sorted(os.listdir(croproot)):
    Mu, Sigma = calculate_inception_stats(join(croproot, foldername), num_expected=50000,
                                   seed=0, max_batch_size=256, num_workers=0, prefetch_factor=None,
                                device=torch.device('cuda'))

    # The reference statistics are re-downloaded / reopened on every iteration.
    with dnnlib.util.open_url("https://nvlabs-fi-cdn.nvidia.com/edm/fid-refs/cifar10-32x32.npz") as f:
        ref = dict(np.load(f))

    fid = calculate_fid_from_inception_stats(Mu, Sigma, ref['mu'], ref['sigma'])
    print(f"{foldername} FID: {fid:.2f}")
    fid_col.append(fid)
#%%
# tabdir = r"E:\OneDrive - Harvard University\NeurIPS2023_Diffusion\Tables"
tabdir = r"/home/binxu/DL_Projects/edm_analy_sample/summary"

# sorted(os.listdir(croproot))
# with the folder name column
# The index re-lists croproot; it lines up with fid_col only if the folder
# contents have not changed since the loop above.
df = pd.DataFrame(fid_col, columns=["FID"], index=sorted(os.listdir(croproot)))
# df.to_csv(join(croproot, "fid_by_skipping.csv"))
df.to_csv(join(tabdir, "fid_by_skipping.csv"))


#%%
#%%