In [None]:
from glob import glob
from os.path import exists, join, basename
from copy import deepcopy
from tqdm import tqdm
from json import load, dump
from sys import argv
from matplotlib import pyplot as plt
from scipy.spatial.distance import cosine
from scipy.stats import norm
from scipy import stats
from PIL import Image, ImageDraw
from collections import Counter
from sklearn.neighbors import KernelDensity
from functools import reduce

import re
import os
import shutil
import time
import random
import tarfile
import zipfile
import pickle
import ndjson

import pandas as pd
import numpy as np

# Seed for the np.random.RandomState subsampling cells further down.
SEED = 20220101

# plt.style.use('ggplot')
# plt.rcParams['figure.dpi'] = 300

In [None]:
# All working paths hang off one root so the notebook can run on another
# machine by setting DIFFUSIONDB_WORK_DIR instead of editing every path.
WORK_DIR = os.environ.get("DIFFUSIONDB_WORK_DIR", "/nvmescratch/diffusiondb/")

PROMPT_EMB_PATH = join(WORK_DIR, "prompt-emb", "prompt-emb.npz")
IMAGE_EMB_DIR = join(WORK_DIR, "img-emb")
PARQUET_PATH = join(WORK_DIR, "metadata.parquet")
# Trailing "" keeps the trailing slash of the directory path.
WORKING_IMAGE_DIR = join(WORK_DIR, "images", "")

# Source locations that live outside WORK_DIR.
ZIP_DIR1 = "/project/diffusiondb-hugging/diffusiondb-large-part-1/"
ZIP_DIR2 = "/project/diffusiondb-hugging/diffusiondb-large-part-2/"
REMOTE_IMAGE_DIR = '/project/diffusiondb/images/'

OUTPUT_DIR = join(WORK_DIR, 'outputs')

### Transform Image Embedding to UMAP

In [None]:
# Read only the metadata columns this notebook uses. Column order matters:
# later cells index itertuples() rows by position (row[1] .. row[9]).
metadata_df = pd.read_parquet(
    PARQUET_PATH,
    columns="image_name part_id prompt cfg step sampler width height seed image_nsfw".split(),
)
print(metadata_df.shape)
metadata_df.head()


In [None]:
prompts = {p.lower() for p in metadata_df['prompt']}  # distinct lower-cased prompts

In [None]:
# count = 0
# limit = 300
# for p in prompts:
#     if 'dying' in p and '' in p and count < limit:
#         print(p)
#         print()
#         count += 1

In [None]:
# Find unique hyperparameter + prompt pairs.
# Each tuple is (prompt, cfg, step, sampler, width, height, seed), i.e. positions
# 3-9 of an itertuples() row (0 is the index, 1 image_name, 2 part_id).
image_tuples = [
    tuple(row[3:10])
    for row in tqdm(metadata_df.itertuples(), total=len(metadata_df))
]

In [None]:
image_tuples = {*image_tuples}  # each distinct (prompt, hyperparameters, seed) tuple once

In [None]:
# Select one image name per unique (prompt, hyperparameter, seed) tuple and map
# every image name to its DataFrame index label (row[0] of itertuples()).
target_names = set()
name_to_index = {}
seen_tuples = set()

for row in tqdm(metadata_df.itertuples(), total=len(metadata_df)):
    cur_tuple = row[3:10]  # (prompt, cfg, step, sampler, width, height, seed)
    name_to_index[row[1]] = row[0]

    # Every tuple is in `image_tuples` by construction, so membership alone would
    # select every image; keep only the first image seen for each tuple.
    if cur_tuple in image_tuples and cur_tuple not in seen_tuples:
        seen_tuples.add(cur_tuple)
        target_names.add(row[1])

In [None]:
# Gather the image embeddings of all target images found in the first `limit`
# embedding shards. selected_names[i] is the image of selected_embs[i].
selected_names = []
selected_embs = []

limit = 600
# Sorted so the same shards are picked on every run (glob order is arbitrary);
# slicing reads exactly `limit` shards.
emb_files = sorted(glob(join(IMAGE_EMB_DIR, '*.npz')))[:limit]

for f in tqdm(emb_files):
    try:
        # NpzFile keeps its file handle open until closed; `with` releases it.
        with np.load(f) as cur_emb:
            images_name = cur_emb['images_name']
            images_emb = cur_emb['images_emb']

        for i, name in enumerate(images_name):
            if name in target_names:
                selected_names.append(name)
                selected_embs.append(images_emb[i, :])

    except Exception as e:
        # Best effort: report the unreadable shard and continue with the rest.
        print(e, f)

In [None]:
# Temp testing: count how many target images each embedding shard contains.
# Uses its own variable names so it does not clobber `selected_names` /
# `selected_embs`, which the subsampling cell needs (and needs aligned).
matched_names = []
dir_count = {}

for f in tqdm(sorted(glob(join(IMAGE_EMB_DIR, '*.npz')))):
    try:
        with np.load(f) as cur_emb:
            images_name = cur_emb['images_name']

        cur_matches = [name for name in images_name if name in target_names]
        matched_names.extend(cur_matches)
        dir_count[basename(f)] = len(cur_matches)

    except Exception as e:
        # Best effort: unreadable shards are reported and left out of dir_count.
        print(e, f)

In [None]:
# (shard, match count) pairs from fewest to most matches, then the total count.
dir_count_pair = sorted(dir_count.items(), key=lambda kv: kv[1])
np.sum([*dir_count.values()])

In [None]:
sum(1 for _ in target_names)  # number of images selected for embedding lookup

In [None]:
# np.savez_compressed(join(OUTPUT_DIR, 'unique_img_embeddings.npz'), names=selected_names, embs=selected_embs)

In [None]:
# data = np.load(join(OUTPUT_DIR, 'unique_img_embeddings.npz'))
# # selected_names = []
# # selected_embs = []

In [None]:
# Reproducibly subsample the gathered embeddings to at most `target_size` rows,
# keeping names aligned with embeddings.
rng = np.random.RandomState(SEED)
target_size = 150000
# choice(..., replace=False) raises if asked for more items than exist.
sample_size = min(target_size, len(selected_embs))
random_index = rng.choice(len(selected_embs), sample_size, replace=False)

subset_embs = np.vstack([selected_embs[i] for i in random_index])
subset_names = np.array([selected_names[i] for i in random_index])

In [None]:
# Model used below through `.transform` to project image embeddings.
# NOTE: pickle.load can execute arbitrary code; only load files you trust.
with open('./outputs/umap-18m.pickle', 'rb') as fp:
    umap_data = pickle.load(fp)

In [None]:
# Project the sampled embeddings in fixed-size batches; umap_results[k] and
# umap_names[k] describe the same batch of images.
step = 5000
# Derived from the data rather than hard-coded so it always matches subset_embs.
upper_limit = len(subset_embs)

umap_names = []
umap_results = []

with tqdm(total=upper_limit) as pbar:
    # range() stops before upper_limit, so no empty trailing batch is sent to
    # transform (the previous `end <= upper_limit` loop re-entered with one).
    for start in range(0, upper_limit, step):
        end = min(start + step, upper_limit)
        umap_results.append(umap_data.transform(subset_embs[start:end, :]))
        umap_names.append(subset_names[start:end])
        pbar.update(end - start)

In [None]:
# Flatten the per-batch results so row i of projected_emb is the UMAP position
# of image projected_name[i]; persist the coordinates.
projected_name = np.array([name for batch in umap_names for name in batch])
projected_emb = np.vstack(umap_results)

np.savez_compressed('./outputs/image-umap-150k.npz', umap=projected_emb)
projected_emb.shape

In [None]:
# Find the prompt and generation hyperparameters of each projected image.
# name_to_index stores DataFrame index *labels* (row[0] of itertuples()), so
# rows are fetched with label-based .loc; positional .iloc only agrees with
# those labels when the index happens to be a default RangeIndex.
projected_prompts = []
projected_parameters = []

for name in tqdm(projected_name):
    row = metadata_df.loc[name_to_index[name]]
    projected_prompts.append(row['prompt'])

    cur_parameters = f"cfg: {row['cfg']}, step: {row['step']}, sampler: {row['sampler']}, width: {row['width']}, height: {row['height']}, seed: {row['seed']}"
    projected_parameters.append(cur_parameters)

In [None]:
# One record per projected image: [x, y, image name, prompt, hyperparameters],
# with coordinates rounded to 3 decimals to keep the export small.
umap_data_short = [
    [
        round(float(projected_emb[i, 0]), 3),
        round(float(projected_emb[i, 1]), 3),
        name,
        projected_prompts[i],
        projected_parameters[i],
    ]
    for i, name in enumerate(projected_name)
]

with open("./outputs/image-umap-150k.ndjson", "w") as fp:
    ndjson.dump(umap_data_short, fp)


In [None]:
def plot_umap(cur_prompts_emb, plot_dir=None):
    """Scatter-plot 2-D UMAP coordinates, save the figure as a JPEG and show it.

    The visible window is limited to mean +/- 4 standard deviations on each
    axis so a few far-away points do not squash the main cloud; such points
    are still drawn, just outside the window.

    Parameters
    ----------
    cur_prompts_emb : array-like of shape (n, >=2)
        Projected coordinates; column 0 is plotted on x, column 1 on y.
        (This notebook passes image-embedding projections despite the name.)
    plot_dir : str, optional
        Directory the figure is saved to as ``image-umap-{n}.jpg``. Defaults
        to ``join(WORK_DIR, 'plots')`` and is created if it does not exist.
    """
    coords = np.asarray(cur_prompts_emb)
    xs = coords[:, 0]
    ys = coords[:, 1]
    n_points = len(coords)

    sigma_scale = 4
    x_mean, x_std = np.mean(xs), np.std(xs)
    y_mean, y_std = np.mean(ys), np.std(ys)

    fig, ax = plt.subplots()
    ax.scatter(xs, ys, s=0.3, alpha=0.3, c='steelblue', edgecolors='none')
    ax.set_xlim(x_mean - x_std * sigma_scale, x_mean + x_std * sigma_scale)
    ax.set_ylim(y_mean - y_std * sigma_scale, y_mean + y_std * sigma_scale)
    ax.set_title(f'UMAP {n_points} Image Embeddings')

    if plot_dir is None:
        plot_dir = join(WORK_DIR, 'plots')
    # savefig does not create missing directories.
    os.makedirs(plot_dir, exist_ok=True)
    fig.savefig(
        join(plot_dir, f"image-umap-{n_points}.jpg"),
        dpi=300,
        bbox_inches='tight'
    )
    plt.show()

In [None]:
plot_umap(cur_prompts_emb=projected_emb)  # 150k image-embedding projection

In [None]:
# Load the stored 'umap' coordinates; `with` closes the NpzFile handle afterwards.
with np.load("./outputs/image-14m-umap.npz") as umap_npz:
    projected_emb_all = umap_npz['umap']

In [None]:
# Reproducibly subsample at most `target_size` points for the KDE.
# Note: this rebinds `projected_emb`, replacing the 150k projection from above.
rng = np.random.RandomState(SEED)
target_size = 2000000
# choice(..., replace=False) raises if asked for more items than exist.
sample_size = min(target_size, len(projected_emb_all))
random_index = rng.choice(len(projected_emb_all), sample_size, replace=False)
projected_emb = projected_emb_all[random_index]

In [None]:
# Bandwidth from Silverman's rule-of-thumb factor: (n (d + 2) / 4) ** (-1 / (d + 4)).
# NOTE(review): this is the factor for unit-variance data; sklearn's
# KernelDensity applies `bandwidth` directly in coordinate units, so it is not
# rescaled by the spread of the UMAP coordinates here. Alternatives: Scott's
# rule n ** (-1 / (d + 4)), or a cross-validated bandwidth (e.g. GridSearchCV).
n, d = projected_emb.shape
bw = ((n * (d + 2)) / 4.0) ** (-1.0 / (d + 4))

kde = KernelDensity(bandwidth=bw, kernel='gaussian')
kde.fit(projected_emb)

In [None]:
# Fixed square window for evaluating the density. Both spans are ~33.9205, so
# the window is square; the bounds were presumably produced once by padding
# the data range by 1/50 of the larger span and widening the smaller axis to
# match. Recompute them if the projection changes.
x_min, y_min, x_max, y_max = -17.16386748, -17.17534323, 16.75664148, 16.74516573

# Evaluate on a grid_size x grid_size lattice.
grid_size = 200
grid_xs = np.linspace(x_min, x_max, num=grid_size)
grid_ys = np.linspace(y_min, y_max, num=grid_size)
xx, yy = np.meshgrid(grid_xs, grid_ys)

# One (x, y) row per lattice point, in row-major order of xx / yy.
grid = np.column_stack((xx.ravel(), yy.ravel()))
grid.shape

In [None]:
print(" ".join(str(bound) for bound in (x_min, y_min, x_max, y_max)))  # grid window

In [None]:
# score_samples returns log-density; exponentiate to get the density itself and
# fold it back onto the 2-D lattice (`log_density` keeps the log values).
log_density = kde.score_samples(grid)
grid_density = np.exp(log_density).reshape(xx.shape)
grid_density.shape

In [None]:
fig, ax = plt.subplots()

ax.set_xlim(x_min, x_max)
ax.set_ylim(y_min, y_max)

# Filled contours of the KDE (image-embedding UMAP sample), 20 evenly spaced
# levels from 0 to the peak density. Values are rounded to 4 decimals, the
# same rounding the JSON export applies.
ax.set_title(f'KDE on {grid_density.shape[0]} Grid of {projected_emb.shape[0]} Image Embeddings (bw={bw:.2f})')
cfset = ax.contourf(xx, yy, grid_density.round(4),
                    levels=np.linspace(0, np.max(grid_density), 20),
                    cmap='Blues',
                    alpha=1)

In [None]:
# Plain Python floats for the JSON export.
x_min, x_max, y_min, y_max = float(x_min), float(x_max), float(y_min), float(y_max)

# Density lattice (rounded to 4 decimals) plus the x/y extents it covers.
# NOTE(review): the KDE sample above has up to 2,000,000 points while the file
# name says 1m; confirm the intended name.
grid_density_json = {
    'grid': grid_density.astype(float).round(4).tolist(),
    'xRange': [x_min, x_max],
    'yRange': [y_min, y_max],
}
with open(join(OUTPUT_DIR, 'umap-image-1m-grid.json'), 'w') as fp:
    dump(grid_density_json, fp)

## Visualize the Change Between Prompt and Image Embeddings

In [None]:
# Prompt UMAP export; the cells below read its 'xs', 'ys' and 'prompts' lists.
with open('./outputs/umap-1m.json', 'r') as fp:
    prompt_umap_data = load(fp)

In [None]:
# Prompt string -> its row in the prompt UMAP export (a later duplicate wins).
prompt_to_index = {
    prompt_umap_data['prompts'][i]: i
    for i in tqdm(range(len(prompt_umap_data['xs'])))
}

In [None]:
# Pair each projected image with the UMAP position of its prompt.
# Image coordinates are rebuilt from `umap_results`, which is row-aligned with
# `projected_prompts`; `projected_emb` cannot be used here because an earlier
# cell rebinds it to an unrelated 2M-point sample for the KDE.
image_umap_coords = np.vstack(umap_results)

image_prompt_umap = {
    'image_xs': [],
    'image_ys': [],
    'prompt_xs': [],
    'prompt_ys': []
}

# Number of images whose lower-cased prompt is not a key of prompt_to_index.
# NOTE(review): assumes the prompt UMAP export stores lower-cased prompts; confirm.
error = 0

for i in tqdm(range(len(projected_prompts))):
    prompt_i = prompt_to_index.get(projected_prompts[i].lower())
    if prompt_i is None:
        error += 1
        continue

    image_prompt_umap['image_xs'].append(image_umap_coords[i, 0])
    image_prompt_umap['image_ys'].append(image_umap_coords[i, 1])
    # Prompt coordinates live at the prompt's own row, not at the image row i.
    image_prompt_umap['prompt_xs'].append(prompt_umap_data['xs'][prompt_i])
    image_prompt_umap['prompt_ys'].append(prompt_umap_data['ys'][prompt_i])

In [None]:
# (n, 2) coordinate arrays, one row per matched image/prompt pair.
image_umap_array = np.column_stack((image_prompt_umap['image_xs'], image_prompt_umap['image_ys']))
prompt_umap_array = np.column_stack((image_prompt_umap['prompt_xs'], image_prompt_umap['prompt_ys']))

In [None]:
# Per pair, two rows: [image_x, prompt_x] then [image_y, prompt_y], i.e. the
# x- and y-coordinates of the segment from an image to its prompt.
lines = np.stack([image_umap_array, prompt_umap_array], axis=2).reshape(-1, 2)

In [None]:
# Draw a segment from each image's UMAP position to its prompt's position.
n_lines = min(10000, len(image_umap_array))

fig, ax = plt.subplots()
ax.set_title(f'Lines from Image Embedding to Prompt Embedding ({n_lines})')

# plot() draws each column of 2-D x/y arrays as its own line, so (2, n) arrays
# of [start; end] coordinates give n segments in a single call. (Passing the
# image and prompt points directly as x and y would plot (ix, px) -> (iy, py).)
seg_xs = np.vstack([image_umap_array[:n_lines, 0], prompt_umap_array[:n_lines, 0]])
seg_ys = np.vstack([image_umap_array[:n_lines, 1], prompt_umap_array[:n_lines, 1]])
ax.plot(seg_xs, seg_ys, color='steelblue', linewidth=0.5, alpha=0.06)

plt.show()