In [None]:
from PIL import Image
from glob import glob
from os.path import exists, join, basename
from tqdm import tqdm
from json import load, dump
from multiprocessing import Pool
from umap import UMAP
from matplotlib import pyplot as plt

import time
import shutil
import gc
import random
import math
import cuml
import matplotlib
import pickle

import numpy as np
import pandas as pd
import altair as alt
alt.data_transformers.disable_max_rows()

import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)

matplotlib.rcParams["figure.dpi"] = 300

# Random seed used when fitting UMAP
SEED = 20221111

# Every input and output directory is derived from WORK_DIR, so relocating the
# scratch space only requires changing this one line.
WORK_DIR = "/nvmescratch/diffusiondb"
OUTPUT_DIR = join(WORK_DIR, 'outputs')
# The empty last component keeps the trailing slash of the path
PROMPT_EMB_DIR = join(WORK_DIR, 'prompts', '')

In [None]:
# Load the prompt embeddings (input of the UMAP plot), one .npz shard at a time
n_parts = 19
prompts = []
prompts_emb = []

for part in tqdm(range(1, n_parts + 1)):
    # The shard count in the file name follows n_parts instead of a second hard-coded 19
    part_path = join(PROMPT_EMB_DIR, f'prompt-emb-part-{part}-of-{n_parts}.npz')

    # np.load returns an archive object for .npz files; the context manager
    # closes it as soon as both arrays have been read out.
    with np.load(part_path) as prompt_emb:
        prompts.append(prompt_emb['prompts'])
        prompts_emb.append(prompt_emb['emb'])


In [None]:
# Merge the per-shard arrays into one array each.
# The names are reused for the merged result, so only concatenate while they
# still hold the list of shards; this keeps the cell safe to re-run.
if isinstance(prompts, list):
    prompts = np.concatenate(prompts, axis=0)

if isinstance(prompts_emb, list):
    prompts_emb = np.concatenate(prompts_emb, axis=0)

In [None]:
# Check the dimensions of the concatenated embedding array
prompts_emb.shape

In [None]:
# nn = 60
# min_dist = 0.9

# reducer = UMAP(
#     n_neighbors=nn,
#     min_dist=min_dist,
#     spread=5.0,
#     metric='cosine',
#     n_components=2,
#     verbose=True,
#     random_state=SEED
# )

# projected_emb = reducer.fit_transform(prompts_emb[random_indexes, :])

In [None]:
# nn_candidates = [10, 20, 30, 40, 50, 60, 70, 80, 90, 100]
# mdist_candidates = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]
# spreads = [0.5, 1.0, 5.0, 10.0]

# params = []
# for nn in nn_candidates:
#     for mdist in mdist_candidates:
#         for s in spreads:
#             params.append((nn, mdist, s))
            

In [None]:
def plot_umap(cur_prompts, cur_prompts_emb, nn, min_dist, spread, param_id=1):
    """Project embeddings to 2-D with cuML's UMAP, scatter-plot them and save the plot.

    Args:
        cur_prompts: Prompts stored next to the projected points. Must have as
            many entries as the projection has rows, because both are placed
            in the same DataFrame.
        cur_prompts_emb: Embeddings passed to ``fit_transform``.
        nn: Value for UMAP's ``n_neighbors``.
        min_dist: Value for UMAP's ``min_dist``.
        spread: Value for UMAP's ``spread``.
        param_id: Number used as the zero-padded prefix of the saved file name.

    Returns:
        A tuple ``(reducer_cuml, projected_df)``: the fitted reducer and a
        DataFrame with the columns ``x``, ``y`` and ``prompt``.

    Side effects:
        Creates ``WORK_DIR/plots`` if needed and writes a JPEG into it. The
        figure is left open so the notebook still renders it inline.
    """
    from os import makedirs

    reducer_cuml = cuml.UMAP(
        n_neighbors=nn,
        min_dist=min_dist,
        metric='cosine',
        spread=spread,
        n_components=2,
        verbose=False,
        random_state=SEED
    )

    # Fit UMAP
    projected_emb_cuml = reducer_cuml.fit_transform(cur_prompts_emb)

    projected_df = pd.DataFrame(
        {
            "x": projected_emb_cuml[:, 0],
            "y": projected_emb_cuml[:, 1],
            "prompt": cur_prompts,
        }
    )

    # Ignore far-away outliers: the view is limited to mean +/- sigma_scale
    # standard deviations on each axis.
    sigma_scale = 4
    x_mean = np.mean(projected_df["x"])
    x_std = np.std(projected_df["x"])
    y_mean = np.mean(projected_df["y"])
    y_std = np.std(projected_df["y"])

    # A fresh figure per call, so repeated calls (e.g. a parameter sweep in one
    # cell) never draw on top of an earlier plot.
    fig, ax = plt.subplots()
    ax.scatter(
        projected_df['x'],
        projected_df['y'],
        s=0.1,
        alpha=0.03,
        c='steelblue',
        edgecolors='none'
    )
    ax.set_xlim((x_mean - x_std * sigma_scale, x_mean + x_std * sigma_scale))
    ax.set_ylim((y_mean - y_std * sigma_scale, y_mean + y_std * sigma_scale))
    ax.set_title(f'UMAP {len(cur_prompts)} Prompts (nn={nn}, mdist={min_dist}, spread={spread})')

    # savefig does not create missing directories
    plot_dir = join(WORK_DIR, 'plots')
    makedirs(plot_dir, exist_ok=True)
    fig.savefig(
        join(plot_dir, f"umap{param_id:03}-nn={nn}-mdist={min_dist}-spread={spread}.jpg"),
        dpi=300,
        bbox_inches='tight'
    )

    return reducer_cuml, projected_df

In [None]:
# UMAP hyper-parameters for this run
nn, min_dist, spread = 60, 0.1, 1.0

# Alternative input: train the UMAP on a random subset of the prompt embeddings
# prompt_num = 300000
# rng = np.random.RandomState(SEED)
# random_indexes = rng.choice(range(prompts_emb.shape[0]), prompt_num, replace=False)
# random_indexes

# cur_prompts = prompts[random_indexes]
# cur_prompts_emb = prompts_emb[random_indexes, :]

# Current input: all loaded prompts and their embeddings
cur_prompts, cur_prompts_emb = prompts, prompts_emb

reducer_cuml, projected_df = plot_umap(
    cur_prompts,
    cur_prompts_emb,
    nn=nn,
    min_dist=min_dist,
    spread=spread,
    param_id=1,
)

In [None]:
# Write the projected coordinates and their prompts to the output directory
umap_csv_path = join(OUTPUT_DIR, 'umap-18m.csv')
projected_df.to_csv(umap_csv_path)

In [None]:
# Serialize the fitted reducer next to the CSV so it can be loaded again later
reducer_pickle_path = join(OUTPUT_DIR, 'umap-18m.pickle')
with open(reducer_pickle_path, 'wb') as pickle_file:
    pickle.dump(reducer_cuml, pickle_file)

In [None]:
# cur_projected = projected_emb_cuml

# projected_df = pd.DataFrame(
#     {
#         "x": cur_projected[:, 0],
#         "y": cur_projected[:, 1],
#         "prompt": prompts[random_indexes],
#     }
# )

# umap_hw_ratio = (np.max(projected_df["y"]) - np.min(projected_df["y"])) / (
#     np.max(projected_df["x"]) - np.min(projected_df["x"])
# )

# y_mean = np.mean(projected_df["y"])
# y_std = np.std(projected_df["y"])

# x_mean = np.mean(projected_df["x"])
# x_std = np.std(projected_df["x"])

In [None]:
# param_id = 1

# plt.scatter(projected_df['x'], projected_df['y'], s=0.8, alpha=0.06, c='steelblue', edgecolors='none')
# sigma_scale = 4
# plt.xlim((x_mean - x_std * sigma_scale, x_mean + x_std * sigma_scale))
# plt.ylim((y_mean - y_std * sigma_scale, y_mean + y_std * sigma_scale))
# plt.title(f'UMAP {len(random_indexes)} Prompts (nn={nn}, mdist={min_dist}, spread={spread})')

# plot_dir = join(WORK_DIR, 'plots')
# plt.savefig(
#     join(plot_dir, f"umap{param_id:03}-nn={nn}-mdist={min_dist}-spread={spread}.jpg"),
#     dpi=300,
#     bbox_inches='tight'
# )
# plt.show()

In [None]:
# alt.Chart(projected_df).mark_circle(
#     color='steelblue',
#     opacity=0.3
# ).encode(
#     x='x:Q',
#     y='y:Q',
#     tooltip='prompt:N'
# ).properties(
#     title=f'UMAP Plot of {projected_emb.shape[0]} Points',
#     width=900,
#     height=int(umap_hw_ratio * 900)
# ).interactive()

## Topic Modeling

In [None]:
from bertopic import BERTopic
from hdbscan import HDBSCAN

In [None]:
# Number of prompts used for the topic-modeling experiments. It has to be
# defined before its first use here so the notebook runs top to bottom on a
# fresh kernel; the value matches the one assigned in the BERTopic cell.
prompt_num = 5000

# Reduce the embeddings to 5 dimensions before clustering
topic_reducer = UMAP(
    n_neighbors=15,
    min_dist=0.1,
    spread=1.0,
    metric='cosine',
    n_components=5,
    verbose=True,
    # Seeded like the 2-D projection so that repeated runs give the same result
    random_state=SEED
)

emb_dim5 = topic_reducer.fit_transform(prompts_emb[:prompt_num, :])

In [None]:
# Clustering settings are kept in one mapping, then handed to HDBSCAN
hdbscan_params = dict(
    min_cluster_size=15,
    metric="euclidean",
    cluster_selection_method="eom",
    prediction_data=True,
)
hdbscan_model = HDBSCAN(**hdbscan_params)


In [None]:
# Cluster the 5-component UMAP output
hdbscan_model.fit(emb_dim5)

In [None]:
# Display the labels stored on the fitted clustering model
hdbscan_model.labels_

In [None]:
# Fit BERTopic on the first `prompt_num` prompts together with their embeddings
prompt_num = 5000
topic_docs = prompts[:prompt_num]
topic_embs = prompts_emb[:prompt_num, :]

topic_model = BERTopic(min_topic_size=5, verbose=True)
topic_model.fit(topic_docs, embeddings=topic_embs)

In [None]:
# Display the topic overview of the fitted BERTopic model
topic_model.get_topic_info()