In [None]:
from scipy.stats import gaussian_kde
from sklearn.neighbors import KernelDensity
from KDEpy import FFTKDE
from scipy.stats import norm

from PIL import Image
from glob import glob
from os.path import exists, join, basename
from tqdm import tqdm
from json import load, dump
from multiprocessing import Pool
from umap import UMAP
from matplotlib import pyplot as plt
import pyarrow.feather as feather

import time
import shutil
import gc
import random
import math
import cuml
import matplotlib

import numpy as np
import pandas as pd
import altair as alt
alt.data_transformers.disable_max_rows()

import os
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)

matplotlib.rcParams["figure.dpi"] = 300

SEED = 20221111

# Root of the DiffusionDB working tree. Defaults to the original scratch
# location; set the DIFFUSIONDB_WORK_DIR environment variable to run the
# notebook elsewhere without editing it.
WORK_DIR = os.environ.get("DIFFUSIONDB_WORK_DIR", "/nvmescratch/diffusiondb")
OUTPUT_DIR = join(WORK_DIR, 'outputs')
# Derived from WORK_DIR; the trailing '' keeps the original trailing slash.
PROMPT_EMB_DIR = join(WORK_DIR, 'prompts', '')

In [None]:
# # Get UMAP plot of prompt embedding
# n_parts = 5
# prompts = []
# prompts_emb = []

# for i in tqdm(range(n_parts)):
#     prompt_emb = np.load(join(PROMPT_EMB_DIR, f'prompt-emb-part-{i + 1}-of-19.npz'))
#     prompts.append(prompt_emb['prompts'])
#     prompts_emb.append(prompt_emb['emb'])

# prompts = np.concatenate(prompts, axis=0)
# prompts_emb = np.concatenate(prompts_emb, axis=0)
# prompts_emb.shape

# prompt_num = 50000
# rng = np.random.RandomState(SEED)
# random_indexes = rng.choice(range(prompts_emb.shape[0]), prompt_num, replace=False)

# cur_prompts = prompts[random_indexes]
# cur_prompts_emb = prompts_emb[random_indexes, :]
# cur_prompts_emb.shape

# reducer = UMAP(
#     metric='cosine',
#     n_neighbors=60,
#     min_dist=0.1,
#     spread=1.0,
#     n_components=2,
#     verbose=False,
#     random_state=SEED
# )

# projected_emb= reducer.fit_transform(cur_prompts_emb)

In [None]:
# Precomputed UMAP projection (columns used below: xs, ys, prompts)
umap_csv_path = join(OUTPUT_DIR, 'umap-1m.csv')
umap_1m_df = pd.read_csv(umap_csv_path)
print(umap_1m_df.shape)

# Reproducible subsample (no replacement) used for the KDE and the exports
cur_df = umap_1m_df.sample(n=60000, replace=False, random_state=SEED)
cur_df.shape

In [None]:
# Pull each needed column out of the sample as a NumPy array
xs, ys, prompts = (cur_df[column].to_numpy() for column in ('xs', 'ys', 'prompts'))

# Pair the coordinates into an (n_points, 2) matrix for the KDE
projected_emb = np.column_stack((xs, ys))

In [None]:
# Quick look at the sampled points before density estimation
fig, ax = plt.subplots()
ax.scatter(xs, ys, s=1.0, alpha=0.2, c='steelblue', edgecolors='none')
plt.show()

In [None]:
# df = pd.DataFrame({'x': projected_emb[:, 0], 'y': projected_emb[:, 1]})
# df.to_csv('test-data-2d.csv', index=False)

In [None]:
# # Compute the bandwidth using silverman's rule
n, d = projected_emb.shape
bw = (n * (d + 2) / 4.)**(-1. / (d + 4))

# # Scott's rule
# bw = n**(-1./(d+4))

# from sklearn.model_selection import GridSearchCV

# kde_cv = GridSearchCV(
#     KernelDensity(),
#     {'bandwidth': np.linspace(0.1, 1.0, 30)},
#     cv=5,
#     verbose=2
# )

# kde_cv.fit(projected_emb)

kde = KernelDensity(kernel='gaussian', bandwidth=bw)
kde.fit(projected_emb[:, :])

In [None]:
def _pad_and_square(long_lo, long_hi, short_lo, short_hi):
    """
    Pad the longer axis by 1/50 of its span on each side, then widen the
    shorter axis symmetrically so both axes end up with the same span.

    Returns (long_lo, long_hi, short_lo, short_hi) after adjustment.
    """
    short_span = short_hi - short_lo
    margin = (long_hi - long_lo) / 50
    long_lo = long_lo - margin
    long_hi = long_hi + margin
    extra = ((long_hi - long_lo) - short_span) / 2
    return long_lo, long_hi, short_lo - extra, short_hi + extra


xs = projected_emb[:, 0]
ys = projected_emb[:, 1]

x_min, x_max = np.min(xs), np.max(xs)
y_min, y_max = np.min(ys), np.max(ys)

# Leave padding around the data and make the plotting region square
if (x_max - x_min) > (y_max - y_min):
    x_min, x_max, y_min, y_max = _pad_and_square(x_min, x_max, y_min, y_max)
else:
    y_min, y_max, x_min, x_max = _pad_and_square(y_min, y_max, x_min, x_max)

# Evaluation mesh: grid_size nodes per axis spanning the padded, squared bounds
grid_size = 200
grid_xs = np.linspace(x_min, x_max, grid_size)
grid_ys = np.linspace(y_min, y_max, grid_size)
xx, yy = np.meshgrid(grid_xs, grid_ys)

# One (x, y) row per mesh node
grid = np.column_stack((xx.ravel(), yy.ravel()))
grid.shape

In [None]:
# Current grid bounds, in the order x_min x_max y_min y_max
print(x_min, x_max, y_min, y_max, sep=' ')

In [None]:
# # Scipy
# kde_model = gaussian_kde(projected_emb.T, bw_method='silverman')
# log_density = kde_model.evaluate(grid.T)
# grid_density = np.reshape(log_density, xx.shape)
# grid_density.shape

In [None]:
# Sklearn
# Sklearn: score_samples returns the *log* density at each grid node.
# Exponentiate it under a distinct name (the old code rebound `log_density`
# to the non-log values) and fold the flat result back onto the mesh shape.
grid_log_density = kde.score_samples(grid)
grid_density = np.exp(grid_log_density).reshape(xx.shape)
grid_density.shape

In [None]:
# Distribution of the estimated density values over all grid nodes
fig, ax = plt.subplots(figsize=(10, 3))
ax.hist(grid_density.ravel(), bins=20)
ax.set_xlabel('estimated density')
ax.set_ylabel('number of grid nodes')
ax.set_title('Histogram of KDE grid values')
# plt.show() rather than fig.show(): Figure.show() warns under non-GUI
# backends such as the inline notebook backend.
plt.show()

In [None]:
# Density grid at the 4-decimal precision that the JSON export also uses
np.round(grid_density, 4)

In [None]:
# Filled-contour view of the KDE over the padded square grid
fig, ax = plt.subplots()

ax.set_xlim(x_min, x_max)
ax.set_ylim(y_min, y_max)

# Plot the 4-decimal values (as exported). The top level must cover the
# *rounded* peak: rounding can push it above np.max(grid_density), and
# contourf leaves values above the highest level unfilled. Taking the max of
# both also keeps the levels increasing if everything rounds to 0.
plot_density = grid_density.round(4)
levels_top = max(np.max(plot_density), np.max(grid_density))

ax.set_title(f'KDE on {grid_density.shape[0]} Grid of {projected_emb.shape[0]} Prompts (bw={bw:.2f})')
cfset = ax.contourf(xx, yy, plot_density,
                    levels=np.linspace(0, levels_top, 20),
                    cmap='Blues',
                    alpha=1)

# ax.scatter(projected_emb[:, 0], projected_emb[:, 1], s=0.5, alpha=0.2, c='black', edgecolors='none')
plt.show()


In [None]:
# (rows, cols) of the density grid
np.shape(grid_density)

In [None]:
def locate_cell(x, y, x_min, x_max, y_min, y_max, grid_size):
    """
    Locate the cell where the given point (x, y) falls into.

    Each axis of [x_min, x_max] x [y_min, y_max] is split into `grid_size`
    equal-width bins; bin i covers [min + i * width, min + (i + 1) * width).
    Points on or beyond a bound are clamped into the first / last bin.

    Parameters
    ----------
    x, y : float
        Point coordinates.
    x_min, x_max, y_min, y_max : float
        Bounds of the gridded region.
    grid_size : int
        Number of bins per axis.

    Returns
    -------
    (int, int)
        (x_i, y_i), each in [0, grid_size - 1].

    NOTE(review): the density grid in this notebook is built with
    np.linspace(min, max, grid_size), whose node spacing is
    (max - min) / (grid_size - 1), not (max - min) / grid_size. Confirm which
    convention the consumer of these indices expects.
    """

    def _axis_index(v, v_min, v_max):
        # Handle out-of-bound values first (this also avoids a zero-span divide)
        if v <= v_min:
            return 0
        if v >= v_max:
            return grid_size - 1
        # Same arithmetic for both axes. Scaling before dividing avoids the
        # rounding error of a precomputed step (e.g. 1 // 0.1 == 9.0). The
        # clamp guards against float rounding producing grid_size just below v_max.
        index = int((v - v_min) * grid_size / (v_max - v_min))
        return min(index, grid_size - 1)

    return _axis_index(x, x_min, x_max), _axis_index(y, y_min, y_max)

In [None]:
# Sanity check: cell containing (4, -2) on the current grid bounds.
# Uses grid_size instead of a hard-coded 200 so it tracks the grid definition.
result = locate_cell(4, -2, x_min, x_max, y_min, y_max, grid_size)
result

In [None]:
# Unpadded bounds and grid of the raw projection, for comparison with the
# padded square grid used for the KDE. These use separate names on purpose:
# overwriting x_min/x_max/y_min/y_max (and grid_xs/grid_ys/xx/yy/grid) here
# would make the grid JSON export report ranges that do not match the grid
# grid_density was evaluated on.
raw_xs = projected_emb[:, 0]
raw_ys = projected_emb[:, 1]

raw_x_min, raw_x_max = np.min(raw_xs), np.max(raw_xs)
raw_y_min, raw_y_max = np.min(raw_ys), np.max(raw_ys)

raw_grid_xs = np.linspace(raw_x_min, raw_x_max, grid_size)
raw_grid_ys = np.linspace(raw_y_min, raw_y_max, grid_size)
raw_xx, raw_yy = np.meshgrid(raw_grid_xs, raw_grid_ys)

raw_grid = np.vstack([raw_xx.ravel(), raw_yy.ravel()]).transpose()
raw_grid.shape

In [None]:
# y coordinate of one grid row, for spot-checking cell lookups
probe_row = 54
grid_ys[probe_row]

## Export the Data

In [None]:
# Export the sampled points (coordinates rounded to 4 decimals) with their prompts
umap_60k = {
    'xs': projected_emb[:, 0].astype(float).round(4).tolist(),
    'ys': projected_emb[:, 1].astype(float).round(4).tolist(),
    'prompts': prompts.tolist()
}

# Use a context manager so the file is flushed and closed deterministically
with open(join(OUTPUT_DIR, 'umap-60k.json'), 'w') as out_file:
    dump(umap_60k, out_file)

In [None]:
# umap_50k = {
#     'xs': projected_emb[:, 0],
#     'ys': projected_emb[:, 1],
#     'prompts': cur_prompts.tolist()
# }

# umap_50k_df = pd.DataFrame(umap_50k)
# feather.write_feather(umap_50k_df, join(OUTPUT_DIR, 'umap-50k.feather'), compression="uncompressed")

In [None]:
# Export the density grid with the bounds it spans. xRange/yRange must be the
# bounds grid_density was evaluated on. float() turns NumPy scalars into plain
# Python floats for JSON, without rebinding the global bounds.
grid_density_json = {
    'grid': grid_density.astype(float).round(4).tolist(),
    'xRange': [float(x_min), float(x_max)],
    'yRange': [float(y_min), float(y_max)],
}

# Use a context manager so the file is flushed and closed deterministically
with open(join(OUTPUT_DIR, 'umap-60k-grid.json'), 'w') as out_file:
    dump(grid_density_json, out_file)