In [1]:
# Reload edited modules (e.g. the local `core` package imported below) automatically before each cell runs.
%load_ext autoreload
%autoreload 2

In [2]:
import sys

# Make the sibling research repositories importable (edm, mini_edm,
# DiffusionMemorization); torch_utils, among others, is imported from these paths.
sys.path.extend([
    "/n/holylabs/LABS/kempner_fellows/Users/binxuwang/Github/edm",
    "/n/home12/binxuwang/Github/mini_edm",
    "/n/home12/binxuwang/Github/DiffusionMemorization",
])

In [3]:
import json
from tqdm import tqdm
import re 
import glob
import os
from os.path import join
import torch
import pickle as pkl
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from generate import edm_sampler
# from core.edm_utils import edm_sampler
# from train_edm import create_model, edm_sampler
from torchvision.utils import make_grid, save_image
from torchvision.transforms import ToPILImage

device = "cuda"  # target device for the EDM model and most tensors below

In [4]:
# Directory holding one sub-directory per EDM training run.
train_root = ("/n/holylabs/LABS/kempner_fellows/Users/binxuwang/Github/edm"
              "/training-runs")

In [5]:
!cd /n/holylabs/LABS/kempner_fellows/Users/binxuwang/Github/edm/training-runs ; du -sh *

240M	00004-afhqv2-64x64-uncond-ncsnpp-edm-gpus1-batch64-fp32
1.7G	00013-afhqv2-64x64-uncond-ncsnpp-edm-gpus4-batch768-fp16
237M	00017-afhqv2-64x64-uncond-ddpmpp-edm-gpus4-batch256-fp32
237M	00018-afhqv2-64x64-uncond-ddpmpp-edm-gpus4-batch392-fp32
4.2G	00019-afhqv2-64x64-uncond-ddpmpp-edm-gpus4-batch256-fp32
4.2G	00020-ffhq-64x64-uncond-ddpmpp-edm-gpus4-batch256-fp32
6.5G	00021-afhqv2-64x64-whitened-nonorm-uncond-ddpmpp-edm-gpus4-batch256-fp32
6.5G	00022-afhqv2-64x64-whitened-uncond-ddpmpp-edm-gpus4-batch256-fp32
6.5G	00023-ffhq-64x64-whitened-uncond-ddpmpp-edm-gpus4-batch256-fp32
6.5G	00024-ffhq-64x64-whitened-nonorm-uncond-ddpmpp-edm-gpus4-batch256-fp32
6.5G	00028-afhqv2-64x64-spectral-whiten-uncond-ddpmpp-edm-gpus4-batch256-fp32
6.5G	00029-ffhq-64x64-spectral-whiten-uncond-ddpmpp-edm-gpus4-batch256-fp32
3.3G	00030-afhqv2-64x64-uncond-ddpmpp-edm-gpus4-batch256-fp32
3.3G	00031-ffhq-64x64-uncond-ddpmpp-edm-gpus4-batch256-fp32
12G	00032-cifar10-32x32-uncond-ddpmpp-edm-gpus4-batch256-fp32

In [6]:
# The edm repo path is already added to sys.path in the setup cell at the top of the
# notebook; appending it again here only created a duplicate sys.path entry.
def load_edm_model(ckptdir, ckpt_idx=-1, train_root=train_root, return_epoch=False):
    """Load the EMA network from one ``*.pkl`` snapshot of an EDM training run.

    Args:
        ckptdir: run sub-directory under ``train_root``.
        ckpt_idx: index into the run's ``*.pkl`` files sorted by file name (-1 = last).
        train_root: directory containing the training runs.
        return_epoch: if True, also return the number parsed from the file name.

    Returns:
        The ``'ema'`` entry of the pickle, moved to ``device``; or ``(net, epoch)``
        when ``return_epoch`` is True.

    Raises:
        FileNotFoundError: if the run directory contains no ``*.pkl`` file.
        ValueError: if the chosen file name does not end in ``-<digits>.pkl``.
    """
    run_dir = join(train_root, ckptdir)
    ckpt_list = sorted(glob.glob(join(run_dir, "*.pkl")))
    if not ckpt_list:
        raise FileNotFoundError(f"No *.pkl checkpoints found in {run_dir}")
    ckpt_path = ckpt_list[ckpt_idx]
    # Parse the trailing number from the file name only; the dot is escaped so the
    # pattern matches a literal ".pkl" suffix.
    match = re.search(r'-(\d+)\.pkl$', os.path.basename(ckpt_path))
    if match is None:
        raise ValueError(f"Cannot parse epoch from checkpoint name {ckpt_path}")
    epoch = int(match.group(1))
    print(f"Loading {ckpt_idx}th ckpt", ckpt_path)
    print("Epoch ", epoch)
    # pickle.load can execute arbitrary code: only load checkpoints from trusted runs.
    with open(ckpt_path, 'rb') as f:
        net = pkl.load(f)['ema'].to(device)
    if return_epoch:
        return net, epoch
    return net
    


def load_stats(ckptdir, train_root=train_root):
    """Read a run's ``stats.jsonl`` file (one JSON object per line) into a DataFrame.

    Args:
        ckptdir: run sub-directory under ``train_root``.
        train_root: directory containing the training runs.

    Returns:
        pd.DataFrame with one row per non-blank JSON line.
    """
    train_stats = []
    with open(join(train_root, ckptdir, "stats.jsonl")) as f:
        for line in tqdm(f):
            if not line.strip():
                continue  # tolerate blank/trailing lines instead of failing in json.loads
            train_stats.append(json.loads(line))
    return pd.DataFrame(train_stats)

### Load Datasets

In [7]:
from training.dataset import ImageFolderDataset

In [8]:
# List the dataset files available for analysis (.zip archives, plus .pt files whose names suggest precomputed eigen/whitening data).
!ls /n/holylabs/LABS/kempner_fellows/Users/binxuwang/Github/edm/datasets

afhqv2-64x64-eigen.pt		  ffhq-64x64-eigen.pt
afhqv2-64x64-spectral-whiten.pt   ffhq-64x64-spectral-whiten.pt
afhqv2-64x64-whitened-nonorm.zip  ffhq-64x64-whitened-nonorm.zip
afhqv2-64x64-whitened.zip	  ffhq-64x64-whitened.zip
afhqv2-64x64.zip		  ffhq-64x64.zip
cifar10-32x32.zip


In [9]:
dataroot = "/n/holylabs/LABS/kempner_fellows/Users/binxuwang/Github/edm/datasets"
# The checkpoint analysed below is an FFHQ run, so this loads FFHQ. The older name
# `dataset_afhq` is kept as an alias because later cells still refer to it.
dataset_ffhq = ImageFolderDataset(join(dataroot, "ffhq-64x64.zip"))
dataset_afhq = dataset_ffhq

In [10]:
# Flatten the whole dataset into an (N, C*H*W) matrix in the EDM pixel convention
# (x / 127.5 - 1, i.e. [-1, 1] assuming uint8 pixels) and compute its mean,
# covariance and PCA spectrum on `device`.
Xtsr = np.stack([sample for sample, _ in dataset_afhq], axis=0) # (N, C, H, W)
Xtsr = torch.from_numpy(Xtsr)
Xtsr_norm = Xtsr / 127.5 - 1 # convention of edm model 
edm_Xmat = Xtsr_norm.view(Xtsr_norm.shape[0], -1).to(device)
edm_Xmean = edm_Xmat.mean(dim=0)
# Center once and free it right after use. Writing (edm_Xmat - edm_Xmean) on both
# sides of the matmul would keep two full (N, D) centered copies on the GPU at once.
edm_Xcentered = edm_Xmat - edm_Xmean
edm_Xcov = edm_Xcentered.T @ edm_Xcentered / edm_Xmat.shape[0]
del edm_Xcentered
# eigh returns eigenvalues in ascending order; flip to largest-variance-first.
eigvals, eigvecs = torch.linalg.eigh(edm_Xcov)
eigvals = eigvals.flip(0)
eigvecs = eigvecs.flip(1)
edm_imgshape = Xtsr.shape[1:]
# Root of the mean per-pixel variance: sqrt(trace(Cov) / D).
edm_std_mean = (torch.trace(edm_Xcov) / edm_Xcov.shape[0]).sqrt()

### Clustering structure of FFHQ

In [11]:
from core.analytical_score_lib import mean_isotropic_score, Gaussian_score, delta_GMM_score
from core.analytical_score_lib import explained_var_vec
from core.analytical_score_lib import sample_Xt_batch, sample_Xt_batch
from core.gaussian_mixture_lib import gaussian_mixture_score_batch_sigma_torch, \
    gaussian_mixture_lowrank_score_batch_sigma_torch, compute_cluster

In [12]:
# Fit k-means for several cluster counts (largest first) and keep, per count, the
# per-cluster parameters returned by `compute_cluster`, keyed by n_clusters:
#   Us_col <- eigvec_mat, mus_col <- center_mat, Lambdas_col <- eigval_mat,
#   weights_col <- cluster frequencies normalised to sum to 1.
# The results are left wherever compute_cluster puts them (moved to the GPU when used).
kmeans_batch = 2048
kmeans_random_seed = 42
kmeans_verbose = 0
lambda_EPS = 1E-5
Us_col, mus_col, Lambdas_col, weights_col = {}, {}, {}, {}
for n_clusters in [20, 10, 5, 2, 1]:  # larger counts (50, 100) currently disabled
    kmeans, eigval_mat, eigvec_mat, freq_vec, center_mat = compute_cluster(
        edm_Xmat.cpu(),
        n_clusters=n_clusters,
        kmeans_batch=kmeans_batch,
        kmeans_random_seed=kmeans_random_seed,
        kmeans_verbose=kmeans_verbose,
        lambda_EPS=lambda_EPS,
    )
    weights = freq_vec / freq_vec.sum()
    (Us_col[n_clusters], mus_col[n_clusters],
     Lambdas_col[n_clusters], weights_col[n_clusters]) = eigvec_mat, center_mat, eigval_mat, weights
    print(f"n_clusters={n_clusters}, computed.")

  super()._check_params_vs_input(X, default_n_init=3)


Kmeans fitting completing, loss  31556048.80640143


100%|███████████████████████████████████████████████████████████████████████████████████████████████| 20/20 [02:46<00:00,  8.31s/it]


cov PCA completed for each cluster.
n_clusters=20, computed.


  super()._check_params_vs_input(X, default_n_init=3)


Kmeans fitting completing, loss  33712946.7767118


100%|███████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [01:37<00:00,  9.78s/it]


cov PCA completed for each cluster.
n_clusters=10, computed.


  super()._check_params_vs_input(X, default_n_init=3)


Kmeans fitting completing, loss  36387242.270318665


100%|█████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [01:04<00:00, 12.93s/it]


cov PCA completed for each cluster.
n_clusters=5, computed.


  super()._check_params_vs_input(X, default_n_init=3)


Kmeans fitting completing, loss  40646453.04290734


100%|█████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:43<00:00, 21.96s/it]


cov PCA completed for each cluster.
n_clusters=2, computed.


  super()._check_params_vs_input(X, default_n_init=3)


Kmeans fitting completing, loss  45597426.595639005


100%|█████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:37<00:00, 37.37s/it]


cov PCA completed for each cluster.
n_clusters=1, computed.


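### Mass compute: how much of the EDM score is explained by analytical and GMM scores, across checkpoints and noise levels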

In [None]:
from collections import defaultdict
from tqdm import trange
from functools import partial


In [None]:
# Regularized data covariance for the "gaussian regularize" model, built once on `device`.
# The previous version created a dense D x D identity on the CPU and copied it to the
# GPU on every call.
gaussian_reg_eps = 1E-4
edm_Xcov_reg = edm_Xcov + torch.eye(edm_Xcov.shape[0], device=device) * gaussian_reg_eps

# Analytical reference score models: each maps (Xt, sigma) to a score tensor returned on the CPU.
score_func_col = {
    "mean isotropic": lambda Xt, sigma: mean_isotropic_score(Xt, edm_Xmean, sigma).cpu(), 
    "mean + std isotropic": lambda Xt, sigma: mean_isotropic_score(Xt, edm_Xmean, sigma, sigma0=edm_std_mean).cpu(), 
    "gaussian": lambda Xt, sigma: Gaussian_score(Xt, edm_Xmean, edm_Xcov, sigma).cpu(), 
    "gaussian regularize": lambda Xt, sigma: Gaussian_score(Xt, edm_Xmean, edm_Xcov_reg, sigma).cpu(), 
    "gmm delta": lambda Xt, sigma: delta_GMM_score(Xt, edm_Xmat, sigma).cpu(), 
}
# GMM scores are evaluated directly in the main loop rather than registered here:
# lambdas created inside a `for n_clusters in ...` loop would all see the last value
# of `n_clusters` (Python late binding).

In [None]:
# For every checkpoint of `ckptname` and every noise level sigma, compare the EDM
# network's score with the analytical scores in `score_func_col` and the k-means GMM
# scores. Explained-variance statistics are recorded both for the score S_t and for
# the implied denoiser D_t = S_t * sigma**2 + Xt.
ckptname = "00035-ffhq-64x64-uncond-ddpmpp-edm-gpus4-batch256-fp32"
device = "cuda"
batch_size = 256
Nreps = 4
sigma_list = [0.01, 0.02, 0.05, 0.1, 0.2, 0.5, 0.75, 1.0, 1.5, 2.0, 5.0, 10.0, 20.0, 30.0, 40.0, 80.0]
gmm_n_clusters_list = [20, 10, 5, 2, 1]  # must be keys of mus_col / Us_col / Lambdas_col / weights_col
noise_seed = 0  # base seed; the sigma index is added so each sigma gets its own fixed noise
ckpt_num = len(glob.glob(join(train_root, ckptname, "*.pkl")))
print("Explaining EDM score with GMM and other analytical scores")
df_col = []
for ckpt_idx in trange(ckpt_num):
    edm, epoch = load_edm_model(ckptname, ckpt_idx=ckpt_idx, return_epoch=True)
    edm.to(device).eval()
    print(f"ckpt_idx={ckpt_idx}, epoch={epoch}")
    for sigma_idx, sigma in enumerate(sigma_list):
        # Re-seed so every checkpoint is scored on the same noisy samples for a given
        # sigma, which makes epochs directly comparable.
        # NOTE(review): assumes sample_Xt_batch draws from torch's global RNG — confirm.
        torch.manual_seed(noise_seed + sigma_idx)
        Xt_col = []  # noisy batches, kept on `device` for the GMM pass below
        score_vec_col = defaultdict(list)
        for rep in trange(Nreps, desc=f"sigma {sigma} rep"):
            Xt = sample_Xt_batch(edm_Xmat, batch_size, sigma=sigma).to(device)
            with torch.no_grad():
                edm_Dt = edm(Xt.view(-1, *edm_imgshape), torch.tensor(sigma, device=device), None).detach().cpu()
            edm_Dt = edm_Dt.view(Xt.shape)
            # Score implied by the denoiser output: (D(x) - x) / sigma^2.
            score_vec_col["EDM"].append((edm_Dt - Xt.cpu()) / (sigma**2))
            Xt_col.append(Xt)
            for score_name, analy_score_func in score_func_col.items():
                score_vec_col[score_name].append(analy_score_func(Xt, sigma))

        # GMM scores: move each mixture's parameters to the GPU once per sigma and
        # reuse them for all Nreps batches, rather than once per batch.
        for n_clusters in gmm_n_clusters_list:
            gmm_mus = mus_col[n_clusters].to(device)
            gmm_Us = Us_col[n_clusters].to(device)
            gmm_Lambdas = Lambdas_col[n_clusters].to(device) + sigma**2
            gmm_weights = weights_col[n_clusters].to(device)
            for Xt_batch in Xt_col:
                gmm_scores = gaussian_mixture_score_batch_sigma_torch(Xt_batch,
                    gmm_mus, gmm_Us, gmm_Lambdas, weights=gmm_weights).cpu()
                score_vec_col[f"gmm_{n_clusters}_mode"].append(gmm_scores)
            del gmm_mus, gmm_Us, gmm_Lambdas, gmm_weights
            torch.cuda.empty_cache()

        Xt_all = torch.cat(Xt_col, dim=0)
        for score_name, score_vec_list in score_vec_col.items():
            score_vec_col[score_name] = torch.cat(score_vec_list, dim=0)

        torch.cuda.empty_cache()

        score_edm = score_vec_col["EDM"].to(device)
        edm_Dt = score_edm * (sigma**2) + Xt_all
        for score_name, score in score_vec_col.items():
            score = score.to(device)
            Dnoiser = score * (sigma**2) + Xt_all
            exp_var_vec = explained_var_vec(score_edm, score)
            exp_var_rev_vec = explained_var_vec(score, score_edm)
            exp_var_vec_Dt = explained_var_vec(edm_Dt, Dnoiser)
            exp_var_rev_vec_Dt = explained_var_vec(Dnoiser, edm_Dt)
            St_var_vec = score.pow(2).sum(dim=1)
            Dt_var_vec = Dnoiser.pow(2).sum(dim=1)
            df_col.append({"epoch": epoch, "sigma": sigma, "name": score_name, 
                        "St_EV": exp_var_vec.mean().item(), 
                        "St_EV_std": exp_var_vec.std().item(),
                        "St_EV_rev": exp_var_rev_vec.mean().item(), 
                        "St_EV_rev_std": exp_var_rev_vec.std().item(),
                        "Dt_EV": exp_var_vec_Dt.mean().item(), 
                        "Dt_EV_std": exp_var_vec_Dt.std().item(),
                        "Dt_EV_rev": exp_var_rev_vec_Dt.mean().item(),
                        "Dt_EV_rev_std": exp_var_rev_vec_Dt.std().item(),
                        "St_Var": St_var_vec.mean().item(),
                        "St_Var_std": St_var_vec.std().item(),
                        "Dt_Var": Dt_var_vec.mean().item(), 
                        "Dt_Var_std": Dt_var_vec.std().item(),})
        torch.cuda.empty_cache()

    # Write partial results after every checkpoint (presumably so an interrupted run keeps finished epochs).
    df_syn = pd.DataFrame(df_col)
    df_syn.to_csv("FFHQ_edm_5k_epoch_gmm_exp_var_part.csv")

# Collect all (epoch, sigma, model) records and add "residual" columns, i.e. the
# fraction of variance left unexplained (1 - explained variance), then save to both
# the working directory and the run directory.
df_syn = pd.DataFrame(df_col).assign(
    St_residual=lambda frame: 1 - frame["St_EV"],
    St_rev_residual=lambda frame: 1 - frame["St_EV_rev"],
    Dt_residual=lambda frame: 1 - frame["Dt_EV"],
    Dt_rev_residual=lambda frame: 1 - frame["Dt_EV_rev"],
)
df_syn.to_csv("FFHQ_edm_5k_epoch_gmm_exp_var.csv")
df_syn.to_csv(join(train_root, ckptname, "FFHQ_edm_5k_epoch_gmm_exp_var.csv"))