In [None]:
from glob import glob
from os.path import exists, join, basename
from copy import deepcopy
from tqdm import tqdm
from json import load, dump
from sys import argv
from matplotlib import pyplot as plt
from scipy.spatial.distance import cosine
from scipy.stats import norm
from scipy import stats
from PIL import Image, ImageDraw
from collections import Counter

import re
import os
import shutil
import time
import random
import tarfile
import zipfile
import pickle

import pandas as pd
import numpy as np

# plt.style.use('ggplot')
# plt.rcParams['figure.dpi'] = 300

In [None]:
# Prompt embeddings: an .npz archive read below through its 'prompts' and 'emb' entries
PROMPT_EMB_PATH = "/nvmescratch/diffusiondb/prompt-emb/prompt-emb.npz"
# Folder of per-part image embeddings, one `part-XXXXXX-image-emb.npz` file per part
IMAGE_EMB_DIR = "/nvmescratch/diffusiondb/img-emb"
# Metadata table, read with pd.read_parquet
PARQUET_PATH = "/nvmescratch/diffusiondb/metadata.parquet"

# Folders holding the per-part `part-XXXXXX.zip` image archives
ZIP_DIR1 = "/project/diffusiondb-hugging/diffusiondb-large-part-1/"
ZIP_DIR2 = "/project/diffusiondb-hugging/diffusiondb-large-part-2/"

# Scratch folder that zip archives are copied into before being unpacked
WORK_DIR = "/nvmescratch/diffusiondb/temp/"
# Local folder for unpacked parts (`part-XXXXXX/`) and for the `bad-images` collection
WORKING_IMAGE_DIR = "/nvmescratch/diffusiondb/images/"
# Folder of parts that are already unpacked (`part-XXXXXX/` sub-folders)
REMOTE_IMAGE_DIR = '/project/diffusiondb/images/'

## Compute Distances

In [None]:
# Only the columns needed to pair every image with its prompt are loaded
metadata_df = pd.read_parquet(
    PARQUET_PATH,
    columns=['image_name', 'prompt'],
)
print(metadata_df.shape)
metadata_df.head()

In [None]:
# Build an image name -> row label lookup so that queries avoid scanning the frame
image_name_dict = {
    metadata_df['image_name'][row]: row
    for row in range(metadata_df.shape[0])
}

In [None]:
# Load prompt embedding
# The context manager closes the archive as soon as both arrays are in memory
with np.load(PROMPT_EMB_PATH) as prompt_embs_data:
    prompts = prompt_embs_data['prompts']
    prompts_emb = prompt_embs_data['emb']

In [None]:
# Build a prompt -> position lookup; when a prompt repeats, the last position wins
prompt_dict = {
    prompt_text: position for position, prompt_text in enumerate(prompts)
}

In [None]:
def compute_distance(part_id):
    """
    Compute the cosine distance between every image embedding of one part and
    the embedding of the prompt recorded for that image.

    Args:
        part_id(int): Part number; selects the file
            `part-{part_id:06}-image-emb.npz` inside IMAGE_EMB_DIR.

    Returns:
        (distances, errors): `distances` holds one
            [image name, prompt index, cosine distance] entry per resolved
            image. `errors` holds one [part_id, image name, -1] entry per
            image whose prompt lookup failed, or a single [part_id, '', -1]
            entry when the embedding file itself could not be read.
    """
    errors = []
    distances = []

    # Load image embedding
    try:
        # The context manager closes the archive once both arrays are read
        with np.load(
            join(IMAGE_EMB_DIR, f'part-{part_id:06}-image-emb.npz')
        ) as image_emb_data:
            image_names = image_emb_data['images_name']
            image_emb = image_emb_data['images_emb']

    # Catch `Exception` rather than everything, so that interrupting a long
    # run (KeyboardInterrupt) is not silently recorded as a data error
    except Exception as error:
        print(f'Major error in part {part_id}: {error!r}')
        errors.append([part_id, '', -1])
        return distances, errors

    for (i, name) in enumerate(image_names):
        cur_image_emb = image_emb[i, :]

        # Identify the prompt's embedding; an unknown image name or prompt is
        # recorded as an error and skipped
        try:
            cur_prompt = metadata_df['prompt'][image_name_dict[name]]
            cur_prompt_index = prompt_dict[cur_prompt]
        except KeyError:
            errors.append([part_id, name, -1])
            continue

        cur_prompt_emb = prompts_emb[cur_prompt_index, :]

        distances.append([
            name,
            cur_prompt_index,
            cosine(cur_image_emb, cur_prompt_emb),
        ])

    return distances, errors

In [None]:
distances = []
errors = []

# Parts are numbered from 1 to `num_parts`, inclusive
num_parts = 14000
checkpoint_every = 5000

for part_id in tqdm(range(1, num_parts + 1)):
    local_distances, local_errors = compute_distance(part_id)
    distances.extend(local_distances)
    errors.extend(local_errors)

    # Save the progress. The last part is saved too: 14000 is not a multiple
    # of 5000, so otherwise the results of the final 4000 parts are never written.
    if part_id % checkpoint_every == 0 or part_id == num_parts:
        # `with` flushes and closes the checkpoint file right away
        with open(f"./image_prompt_distance_{part_id:06}.pkl", 'wb') as checkpoint_file:
            pickle.dump(
                {"distances": distances, "errors": errors},
                checkpoint_file,
            )


## Fixing Distance Errors

In [None]:
# Reload the metadata, this time with the part id and the generation settings
metadata_df = pd.read_parquet(
    PARQUET_PATH,
    columns=['image_name', 'part_id', 'prompt', 'cfg', 'step', 'sampler'],
)
print(metadata_df.shape)
metadata_df.head()

In [None]:
# Load prompt embedding
# The context manager closes the archive as soon as both arrays are in memory
with np.load(PROMPT_EMB_PATH) as prompt_embs_data:
    prompts = prompt_embs_data['prompts']
    prompts_emb = prompt_embs_data['emb']

# Set of all prompts that have an embedding, for constant-time membership tests
prompts_set = set(prompts)

In [None]:
# Image name -> part id, and image name -> row label in `metadata_df`
name_to_part = {}
name_to_index = {}

# Fetch each column once instead of on every iteration of the loop
image_name_col = metadata_df['image_name']
part_id_col = metadata_df['part_id']

for i in tqdm(range(0, len(metadata_df))):
    image_name = image_name_col[i]
    name_to_part[image_name] = part_id_col[i]
    name_to_index[image_name] = i

# NOTE: unpickling can execute arbitrary code; only load files produced by this project
with open("./image_prompt_distance.pkl", 'rb') as distance_file:
    image_prompt_distance_data = pickle.load(distance_file)

distances = image_prompt_distance_data['distances']
errors = image_prompt_distance_data['errors']

In [None]:
len(errors)

In [None]:
# # Create a prompt -> index dictionary for faster query
# prompt_dict = {}
# for i in range(0, len(prompts)):
#     prompt_dict[prompts[i]] = i

In [None]:
# prev_prompts_set = set(prompts)
# no_emb_prompts = set()

# for error_prompt in tqdm(error_prompts):
#     if error_prompt.lower() not in prev_prompts_set:
#         no_emb_prompts.add(error_prompt)

In [None]:
# len(no_emb_prompts)

In [None]:
# # Compute CLIP embedding for these no emb prompts
# import torch
# from sentence_transformers import SentenceTransformer, util
# from transformers import CLIPTokenizer

# print(torch.cuda.device_count(), "GPUs")
# device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# print(device)

# # Load CLIP model
# model = SentenceTransformer("clip-ViT-L-14")

# print("Initial # of prompts:", len(no_emb_prompts))

# # Make all prompts lower case and only use unique prompts
# working_prompts = [p.lower() for p in no_emb_prompts]
# working_prompts = set(working_prompts)
# working_prompts = list(working_prompts)
# print("Unique # of prompts:", len(working_prompts))

# tokenizer = model._first_module().processor.tokenizer


# def truncate_sentence(sentence, tokenizer):
#     """
#     Truncate a sentence to fit the CLIP max token limit (77 tokens including the
#     starting and ending tokens).
#     Args:
#         sentence(string): The sentence to truncate.
#         tokenizer(CLIPTokenizer): Pretrained CLIP tokenizer.
#     """

#     cur_sentence = sentence
#     tokens = tokenizer.encode(cur_sentence)

#     if len(tokens) > 77:
#         # Skip the starting token, only include 75 tokens
#         truncated_tokens = tokens[1:76]
#         cur_sentence = tokenizer.decode(truncated_tokens)

#         # Recursive call here, because the encode(decode()) can have different
#         # result
#         return truncate_sentence(cur_sentence, tokenizer)

#     else:
#         return cur_sentence


# truncated_prompts = []
# for p in tqdm(working_prompts):
#     truncated_prompts.append(truncate_sentence(p, tokenizer))

# # Encode prompts
# working_prompt_emb = model.encode(
#     truncated_prompts, device=device, show_progress_bar=True, batch_size=512
# )

# # Save the embedding
# print(working_prompt_emb.shape)

# # np.savez_compressed(
# #     join(PROMPT_EMB_DIR, "prompt-emb.npz"),
# #     prompts=prompts,
# #     emb=prompt_emb,
# # )

## Less-alignment Analysis

In [None]:
# Reload the metadata with every column the alignment analysis reads:
# generation settings (cfg, step, sampler), image size and the `image_nsfw` value
metadata_df = pd.read_parquet(
    PARQUET_PATH,
    columns=[
        "image_name",
        "part_id",
        "prompt",
        "cfg",
        "step",
        "sampler",
        "width",
        "height",
        "image_nsfw",
    ],
)
print(metadata_df.shape)
metadata_df.head()


In [None]:
# Map each image name to its part id and to its row label in the metadata
name_to_part = {}
name_to_index = {}
for row in tqdm(range(len(metadata_df))):
    image_name = metadata_df['image_name'][row]
    name_to_part[image_name] = metadata_df['part_id'][row]
    name_to_index[image_name] = row

In [None]:
# Load the image/prompt distances computed earlier.
# NOTE: unpickling can execute arbitrary code; only load files produced by this project
with open("./image_prompt_distance_all.pkl", 'rb') as distance_file:
    image_prompt_distance_data = pickle.load(distance_file)

distances = image_prompt_distance_data['distances']
errors = image_prompt_distance_data['errors']

In [None]:
# Distance tuple: (name, prompt index, cosine, l1, l2, l-infinity)
# Each metric name maps to its position inside that tuple; metrics start at position 2
distance_tuples = {
    metric: position
    for position, metric in enumerate(('Cosine', 'L1', 'L2', 'L-infinity'), start=2)
}

In [None]:
# for k in distance_tuples:

k = 'Cosine'
distance_scores = [p[distance_tuples[k]] for p in distances]

# `savefig` does not create missing folders
os.makedirs("plots", exist_ok=True)

plt.figure(figsize=(10, 5))
plt.grid(alpha=0.2)
plt.hist(distance_scores, bins=100, edgecolor='white', linewidth=0.5, alpha=0.9)
plt.xticks(fontsize=20)
plt.yticks(fontsize=20)
# The histogram is over distances (the old label referred to token counts)
plt.xlabel(f"{k} distance between prompt and image embeddings", fontsize=16)
plt.ylabel("Number of image/prompt pairs", fontsize=16)
plt.title(f'Distribution of the prompt/image CLIP {k} distance', fontsize=18)
plt.savefig("plots/cosine-dist.pdf", bbox_inches='tight')

plt.show()

# plt.clf()
# plt.cla()
# plt.close()

In [None]:
def get_image_path(name_to_part, name, existing_part_ids):
    """
    Get the path of an image by its name.

    When the image's part is not among `existing_part_ids`, the part's zip
    archive is copied into WORK_DIR and unpacked under WORKING_IMAGE_DIR
    (unless the unpacked folder is already there).

    Args:
        name_to_part(dict): Image name -> part id.
        name(str): Image file name.
        existing_part_ids(set): Part ids already unpacked in REMOTE_IMAGE_DIR.

    Returns:
        str: Path of the image inside REMOTE_IMAGE_DIR or WORKING_IMAGE_DIR.

    Raises:
        KeyError: `name` is not in `name_to_part`.
        FileNotFoundError: The part's archive is in neither zip folder.
    """

    # Find the part id of this image
    part_id = name_to_part[name]
    part_name = f'part-{part_id:06}'

    if part_id in existing_part_ids:
        return join(REMOTE_IMAGE_DIR, part_name, name)

    # Need to download the image's zip file first
    cur_zip = join(WORK_DIR, f"{part_name}.zip")
    cur_img_dir = join(WORKING_IMAGE_DIR, part_name)

    if not exists(cur_img_dir):
        # The original `part_id > 100000` rule picks the folder to try first.
        # Part ids in this notebook stop at 14000, so that rule alone never
        # selects ZIP_DIR2; probing both folders makes its archives reachable.
        if part_id > 100000:
            candidate_dirs = (ZIP_DIR2, ZIP_DIR1)
        else:
            candidate_dirs = (ZIP_DIR1, ZIP_DIR2)

        for zip_dir in candidate_dirs:
            source_zip = join(zip_dir, f"{part_name}.zip")
            if exists(source_zip):
                break
        else:
            raise FileNotFoundError(
                f"{part_name}.zip not found in {ZIP_DIR1} or {ZIP_DIR2}"
            )

        # Copy the archive to local scratch space, then extract it
        shutil.copyfile(source_zip, cur_zip)
        shutil.unpack_archive(cur_zip, cur_img_dir)

    return join(WORKING_IMAGE_DIR, part_name, name)

In [None]:
# distance_method = 'Cosine'
# distance_i = distance_tuples[distance_method]
# sorted_pairs = sorted(distances, key=lambda x: x[distance_i], reverse=True)

In [None]:
# # Fit a gausian distribution
# mean, std = norm.fit([p[distance_i] for p in distances])
# print(mean, std)

# low_pairs = []
# HIGH_BAR = mean + 6 * std

# for p in sorted_pairs:
#     if p[distance_i] > HIGH_BAR:
#         low_pairs.append(p)
        
#     if p[distance_i] < HIGH_BAR:
#         break
    
# print(HIGH_BAR, len(low_pairs))

In [None]:
# # Generate images from these pairs into bad-images
# folders = glob("/project/diffusiondb/images/*")
# existing_part_ids = set(
#     [int(re.sub(r".*part-(\d+)", r"\1", f)) for f in folders if "json" not in f]
# )

# shutil.rmtree(join(WORKING_IMAGE_DIR, 'bad-images'))
# os.makedirs(join(WORKING_IMAGE_DIR, 'bad-images'))

# # Copy low distance images into one folder
# for p in tqdm(low_pairs[:100]):
#     cur_path = get_image_path(name_to_part, p[0], existing_part_ids)
#     local_path = join(WORKING_IMAGE_DIR, 'bad-images', basename(cur_path))\

#     if not exists(local_path):
#         shutil.copyfile(cur_path, local_path)
        
#     img = Image.open(local_path)
#     img.thumbnail((200, 200), Image.Resampling.LANCZOS)
    
#     df_index = name_to_index[p[0]]
#     prompt = metadata_df['prompt'][df_index]
    
#     display(img)
#     print(prompt)

#     # try:
#     #     canvas = ImageDraw.Draw(img)
#     #     canvas.text((10, 5), prompt[:40], fill=(255, 0, 0))
        
#     #     if len(prompt) > 40:
#     #         canvas.text((10, 15), prompt[40:80], fill=(255, 0, 0))
            
#     # except:
#     #     pass
    
#     # img.save(local_path)

We focus on misaligned examples whose `cfg_scale` is positive.

In [None]:
# Rank all pairs from the largest to the smallest distance
distance_method = 'Cosine'
distance_i = distance_tuples[distance_method]
sorted_pairs = list(distances)
sorted_pairs.sort(key=lambda pair: pair[distance_i], reverse=True)

In [None]:
# Fit a Gaussian distribution to all distance scores
mean, std = norm.fit([pair[distance_i] for pair in distances])
print(mean, std)

# Pairs more than four standard deviations above the mean are collected
HIGH_BAR = mean + 4 * std
low_pairs = []

for p in sorted_pairs:
    if p[distance_i] > HIGH_BAR:
        # Read this image's generation settings from the metadata
        cur_i = name_to_index[p[0]]
        cfg, sampler, step, width, height, nsfw = (
            metadata_df[column][cur_i]
            for column in ('cfg', 'sampler', 'step', 'width', 'height', 'image_nsfw')
        )

        # if cfg > 1:
        # Keep only images whose `image_nsfw` value is below 2
        if nsfw < 2:
            low_pairs.append({
                'name': p[0],
                'index': p[1],
                'Cosine': p[2],
                'L1': p[3],
                'L2': p[4],
                'L-infinity': p[5],
                'cfg': cfg,
                'sampler': sampler,
                'step': step,
                'width': width,
                'height': height,
                'nsfw': nsfw
            })

    # The pairs are sorted in descending order, so nothing further can qualify
    if p[distance_i] < HIGH_BAR:
        break

print(HIGH_BAR, len(low_pairs))

In [None]:
# # Chi-square test on the sampler distribution

# samplers = [p['sampler'] for p in low_pairs]
# counter_cur_samplers = Counter(samplers)

# f_obs = [0 for _ in range(9)]
# for k in counter_cur_samplers:
#     f_obs[k - 1] = counter_cur_samplers[k]
    
# counter_all_samplers = Counter(metadata_df['sampler'])
# f_exp = [0 for _ in range(9)]
# for k in counter_all_samplers:
#     f_exp[k - 1] = len(samplers) * (counter_all_samplers[k] / len(metadata_df))
    
# results = stats.chisquare(f_obs=f_obs, f_exp=f_exp)
# results


#### Check `step`

In [None]:
# steps = [p['step'] for p in low_pairs if p['sampler'] == 8 and p['nsfw'] < 2]

# plt.title('Step distribution of 4 sigma away images')
# plt.hist(steps, bins=50)
# plt.show()

### Check size

In [None]:
# lengths = [min(p['height'], p['width']) for p in low_pairs if p['sampler'] == 8 and p['nsfw'] < 2]

# plt.title('Min(height, width) distribution of 4 sigma away images')
# plt.hist(lengths, bins=50)
# plt.show()

### Regression to test these variables

Treat bad image as a binary variable, and test the correlation with all parameters.

In [None]:
import statsmodels.api as sm
import statsmodels.formula.api as smf
from sklearn import linear_model
from sklearn.compose import ColumnTransformer
from sklearn.preprocessing import OneHotEncoder, StandardScaler
from scipy import stats

In [None]:
# # Create the y variable
# ys = [0 for _ in range(len(metadata_df))]

# for p in low_pairs:
#     ys[p['index']] = p['Cosine']
    
# # Create the x matrix

# # cfg, step, sampler, min(width, height)
# ols_df_dict = {
#     'y': ys,
#     'cfg': metadata_df['cfg'],
#     'step': metadata_df['step'],
#     'sampler': metadata_df['sampler'],
#     'length':  np.min(np.array([metadata_df['width'].to_numpy(), metadata_df['height'].to_numpy()]), axis=0)
# }
# ols_df = pd.DataFrame.from_dict(ols_df_dict).dropna()
# ols_df = ols_df[ols_df['cfg'] < 1e10]
# print(ols_df.shape)
# ols_df.head()

# model = smf.ols(formula="y ~ cfg + step + C(sampler) + length", data=ols_df).fit(maxiter=100)
# model.summary()

In [None]:
# # Create a logistic regression 
# ys = [0 for _ in range(len(metadata_df))]

# for p in low_pairs:
#     ys[p['index']] = 1
    
# # Create the x matrix

# # cfg, step, sampler, min(width, height)
# ols_df_dict = {
#     'y': ys,
#     'cfg': metadata_df['cfg'],
#     'step': metadata_df['step'],
#     'sampler': metadata_df['sampler'],
#     # 'length':  np.min(np.array([metadata_df['width'].to_numpy(), metadata_df['height'].to_numpy()]), axis=0)
#     'width':  metadata_df['width'],
#     'height':  metadata_df['height'],
# }
# ols_df = pd.DataFrame.from_dict(ols_df_dict).dropna()
# ols_df = ols_df[ols_df['cfg'] < 1000]
# ols_df = ols_df[ols_df['cfg'] > -1000]
# print(ols_df.shape)
# ols_df.head()

# pos_df = ols_df[ols_df['y'] == 1]
# neg_df = ols_df[ols_df['y'] == 0]

# sampled_df = pd.concat([pos_df, neg_df.sample(pos_df.shape[0] * 10, replace=False)])
# sampled_df = pd.concat([pos_df, neg_df])

# scaler = StandardScaler().fit(sampled_df)
# result = scaler.transform(sampled_df)

# normed_df = pd.DataFrame({
#     'y': sampled_df['y'],
#     'cfg': result[:, 1],
#     'step': result[:, 2],
#     'sampler': sampled_df['sampler'],
#     'length': result[:, 4],
# })

# # model = smf.logit(formula="y ~ cfg + step + C(sampler) + length", data=ols_df).fit(method='nm', maxiter=600)
# model = smf.logit(formula="y ~ cfg + step + C(sampler) + width + height", data=ols_df).fit(method='nm', maxiter=600)
# model.summary()

### Analyze Bad Images after Controlling Parameters

In [None]:
# Rank all pairs from the largest to the smallest distance
distance_method = 'Cosine'
distance_i = distance_tuples[distance_method]
sorted_pairs = list(distances)
sorted_pairs.sort(key=lambda pair: pair[distance_i], reverse=True)

In [None]:
# Fit a Gaussian distribution to all distance scores
mean, std = norm.fit([pair[distance_i] for pair in distances])
print(mean, std)

# Pairs more than four standard deviations above the mean are collected
HIGH_BAR = mean + 4 * std
low_pairs = []

for p in sorted_pairs:
    if p[distance_i] > HIGH_BAR:
        # Read this image's generation settings and prompt from the metadata
        cur_i = name_to_index[p[0]]
        cfg, sampler, step, width, height, nsfw, prompt = (
            metadata_df[column][cur_i]
            for column in (
                'cfg', 'sampler', 'step', 'width', 'height', 'image_nsfw', 'prompt'
            )
        )

        # Control the generation parameters: cfg within 5 of 7, sampler 8,
        # both sides at least 512 pixels, and more than 10 steps
        passes_controls = (
            nsfw < 2
            and abs(cfg - 7) < 5
            and sampler == 8
            and min(width, height) >= 512
            and step > 10
        )
        if passes_controls:
            low_pairs.append({
                'name': p[0],
                'index': p[1],
                'Cosine': p[2],
                'L1': p[3],
                'L2': p[4],
                'L-infinity': p[5],
                'cfg': cfg,
                'sampler': sampler,
                'step': step,
                'width': width,
                'height': height,
                'nsfw': nsfw,
                'prompt': prompt
            })

    # The pairs are sorted in descending order, so nothing further can qualify
    if p[distance_i] < HIGH_BAR:
        break

print(HIGH_BAR, len(low_pairs))

## Generate Images for Error Figure

In [None]:
# Fit a Gaussian distribution to all distance scores
mean, std = norm.fit([pair[distance_i] for pair in distances])
print(mean, std)

# Threshold: four standard deviations above the mean distance
HIGH_BAR = mean + 4 * std

In [None]:
# Collect the pairs above the threshold that match the figure's exact settings
low_pairs = []

for p in sorted_pairs:
    if p[distance_i] > HIGH_BAR:
        # Read this image's generation settings and prompt from the metadata
        cur_i = name_to_index[p[0]]
        cfg, sampler, step, width, height, nsfw, prompt = (
            metadata_df[column][cur_i]
            for column in (
                'cfg', 'sampler', 'step', 'width', 'height', 'image_nsfw', 'prompt'
            )
        )

        # 512x512 images, 50 steps, sampler 8, cfg within 5 of 7, and a prompt
        # shorter than 5 characters that contains the emoji
        matches_figure = (
            nsfw < 2
            and abs(cfg - 7) < 5
            and step == 50
            and width == 512
            and height == 512
            and len(prompt) < 5
            and sampler == 8
            and '😂' in prompt
        )
        if matches_figure:
            low_pairs.append({
                'name': p[0],
                'index': p[1],
                'Cosine': p[2],
                'L1': p[3],
                'L2': p[4],
                'L-infinity': p[5],
                'cfg': cfg,
                'sampler': sampler,
                'step': step,
                'width': width,
                'height': height,
                'nsfw': nsfw,
                'prompt': prompt
            })

    # The pairs are sorted in descending order, so nothing further can qualify
    if p[distance_i] < HIGH_BAR:
        break

print(HIGH_BAR, len(low_pairs))

In [None]:
# Display these bad images
folders = glob(join(REMOTE_IMAGE_DIR, '*'))
existing_part_ids = set(
    [int(re.sub(r".*part-(\d+)", r"\1", f)) for f in folders if "json" not in f]
)

# Start from an empty folder; on a fresh run it does not exist yet
bad_image_dir = join(WORKING_IMAGE_DIR, 'bad-images')
if exists(bad_image_dir):
    shutil.rmtree(bad_image_dir)
os.makedirs(bad_image_dir)

# Copy low distance images into one folder, visiting the pairs in random order
random_indexes = np.random.choice(len(low_pairs), len(low_pairs), replace=False)
count_limit = 50
visited_prompts = set()

# Never index past the end when fewer than `count_limit` pairs were selected
num_to_show = min(count_limit, len(low_pairs))

i = 0
with tqdm(total=num_to_show) as pbar:
    while i < num_to_show:
        p = low_pairs[random_indexes[i]]

        cur_path = get_image_path(name_to_part, p['name'], existing_part_ids)
        local_path = join(bad_image_dir, basename(cur_path))

        if not exists(local_path):
            shutil.copyfile(cur_path, local_path)

        # Shrink in place so the inline preview stays small
        img = Image.open(local_path)
        img.thumbnail((150, 150), Image.Resampling.LANCZOS)

        prompt = p['prompt']

        # if prompt in visited_prompts:
        #     i += 1
        #     continue
        # visited_prompts.add(prompt)

        display(img)
        print(p)
        i += 1
        pbar.update(1)

        # try:
        #     canvas = ImageDraw.Draw(img)
        #     canvas.text((10, 5), prompt[:40], fill=(255, 0, 0))
            
        #     if len(prompt) > 40:
        #         canvas.text((10, 15), prompt[40:80], fill=(255, 0, 0))
                
        #     i += 1
        #     pbar.update(1)
                
        # except:
        #     pass
        
        # img.save(local_path)

### Display Bad Images

In [None]:
# prompts = set([p['prompt'] for p in low_pairs])
# prompts = list(prompts)

# md_string = ''
# md_string += '|||\n'
# md_string += '|:---|:---|\n'

# i = 0
# while i < len(prompts) - 1:
#     md_string += f'|{prompts[i]}|{prompts[i + 1]}|\n'
#     i += 2

# print(md_string[:10000])

### Trying to see the high-tfidf score words in bad prompts

Not working

In [None]:
# from sklearn.feature_extraction.text import CountVectorizer, TfidfVectorizer, TfidfTransformer
# from emoji import demojize

In [None]:
# prompts = set([demojize(p['prompt']) for p in low_pairs])
# prompts = list(prompts)

# # model = CountVectorizer(ngram_range=(1, 2)).fit(prompts)
# # words = model.transform(prompts)

In [None]:
# all_prompts = []
# existing_prompts = set(prompts)

# for p in tqdm(set(metadata_df['prompt'])):
#     cur_p = demojize(p)
#     if cur_p not in existing_prompts:
#         all_prompts.append(cur_p)
        
# all_docs = list(all_prompts)
# bad_prompt_doc = ' '.join(prompts)
# all_docs.append(bad_prompt_doc)

# count_model = CountVectorizer(ngram_range=(1, 2))
# count = count_model.fit_transform(all_docs)

# cout_copy = count.copy()
# cout_copy[len(all_docs) - 1, :] = cout_copy[len(all_docs) - 1, :] / 10

# tfidf_model = TfidfTransformer(use_idf=True)
# tfidf = tfidf_model.fit_transform(cout_copy)

In [None]:
# bad_prompt_tfidf = tfidf[len(all_docs) - 1, :].todense()
# bad_words = []

# for i, name in enumerate(count_model.get_feature_names_out()):
#     if bad_prompt_tfidf[0, i] > 0:
#         bad_words.append([name, bad_prompt_tfidf[0, i]])

# bad_words.sort(key=lambda x: x[1], reverse=True)
# # bad_words[:200]


### Analyze the word count in bad prompts

In [None]:
# Keep the first pair seen for every distinct prompt
added_prompts = set()
unique_low_pairs = []
for pair in low_pairs:
    prompt_text = pair['prompt']
    if prompt_text in added_prompts:
        continue
    added_prompts.add(prompt_text)
    unique_low_pairs.append(pair)

len(unique_low_pairs)

In [None]:
# Character lengths of the distinct bad prompts ...
bad_prompts = {pair['prompt'] for pair in low_pairs}
bad_promtps_length = list(map(len, bad_prompts))

# ... and of all distinct prompts in the metadata
all_prompts = set(metadata_df['prompt'])
all_prompts_length = list(map(len, all_prompts))

In [None]:
bins = np.linspace(0, np.max(all_prompts_length), 100)

In [None]:
# Overlay the two length distributions on shared bins; `density=True`
# normalizes each histogram so the differently sized groups are comparable
plt.figure(figsize=(12, 4))
plt.title('Prompt length distribution')
plt.hist(bad_promtps_length, bins, alpha=0.5, label='bad', density=True)
plt.hist(all_prompts_length, bins, alpha=0.5, label='all', density=True)
plt.xlabel('Prompt length (characters)')
plt.ylabel('Density')
plt.legend(loc='upper right')
plt.show()

In [None]:
# Test the token length of bad prompts vs all prompts
# NOTE: unpickling can execute arbitrary code; only load files produced by this project
with open('./outputs/token_count_truncated.pkl', 'rb') as token_count_file:
    prompt_token_lengths = pickle.load(token_count_file)
prompt_token_lengths = prompt_token_lengths.numpy()

# Prompt -> position in `prompts`
# NOTE(review): assumes `prompt_token_lengths` is ordered like `prompts` — confirm
prompt_to_index = {}

for i in tqdm(range(len(prompts))):
    prompt_to_index[prompts[i]] = i

bad_token_lengths = []

for bp in tqdm(bad_prompts):
    cur_i = prompt_to_index[bp]
    bad_token_lengths.append(prompt_token_lengths[cur_i])

# Welch's t-test (unequal variances), one-sided: are bad prompts shorter?
stats.ttest_ind(bad_token_lengths, prompt_token_lengths, equal_var=False, alternative='less')

### Analyze Lang Distribution in Bad Prompts

In [None]:
# Prompt -> detected language, for the prompts not detected as English.
# NOTE: unpickling can execute arbitrary code; only load files produced by this project
with open("./outputs/non-en-prompts.pkl", "rb") as non_en_file:
    non_en_prompts = pickle.load(non_en_file)

lang_counter = Counter(non_en_prompts.values())

# Position 0 holds the share of English prompts (everything not in `non_en_prompts`)
lang_counter_map = {'en': 0}
lang_counter_list = [(len(prompts) - len(non_en_prompts)) / len(prompts)]

# Every other language gets the next position and its share of all prompts
cur_i = 1
for lang in lang_counter:
    lang_counter_map[lang] = cur_i
    lang_counter_list.append(lang_counter[lang] / len(prompts))
    cur_i += 1

lang_counter_list = np.array(lang_counter_list)
np.sum(lang_counter_list)

In [None]:
# Count the detected language of every bad prompt that is not English
bad_lang_counter = dict(
    Counter(non_en_prompts[bp] for bp in bad_prompts if bp in non_en_prompts)
)

# Shares per language, laid out in the same positions as `lang_counter_list`
bad_lang_counter_list = np.zeros(len(lang_counter_list))
num_bad_prompts = len(bad_prompts)
num_bad_non_en = np.sum(list(bad_lang_counter.values()))

# Position 0 is English: every bad prompt without a non-English label
bad_lang_counter_list[0] = (num_bad_prompts - num_bad_non_en) / num_bad_prompts

for lang, lang_total in bad_lang_counter.items():
    bad_lang_counter_list[lang_counter_map[lang]] = lang_total / num_bad_prompts

np.sum(bad_lang_counter_list)


In [None]:
# Chi-square goodness-of-fit test on two categories.
# NOTE(review): the counts are hard-coded, presumably [English, non-English] bad
# prompts copied from an earlier run — confirm they still match the data
observed_counts = [991, 160]
expected_counts = [1131.4214483066346, 19.578551693365398]

results = stats.chisquare(f_obs=observed_counts, f_exp=expected_counts)
results

#### Check CFG Distribution

In [None]:
# cfg values of the distinct bad prompts; values of 200 and above are left out
bad_cfgs = [p['cfg'] for p in unique_low_pairs if p['cfg']< 200]
plt.hist(bad_cfgs, bins=100)
plt.title('cfg distribution of bad images (cfg < 200)')
plt.xlabel('cfg')
plt.ylabel('Number of images')
plt.show()

In [None]:
# Analyzing the language in bad prompts
# NOTE: unpickling can execute arbitrary code; only load files produced by this project
with open("./outputs/non-en-prompts.pkl", "rb") as non_en_file:
    non_en_prompts = pickle.load(non_en_file)

# Bad prompt -> detected language, for the bad prompts that are not English
bad_prompt_lang = {
    bad_prompt: non_en_prompts[bad_prompt]
    for bad_prompt in bad_prompts
    if bad_prompt in non_en_prompts
}

# Share of bad prompts that are not English
len(bad_prompt_lang) / len(bad_prompts)

In [None]:
# Display these bad images
folders = glob(join(REMOTE_IMAGE_DIR, '*'))
existing_part_ids = set(
    [int(re.sub(r".*part-(\d+)", r"\1", f)) for f in folders if "json" not in f]
)

# Start from an empty folder; on a fresh run it does not exist yet
bad_image_dir = join(WORKING_IMAGE_DIR, 'bad-images')
if exists(bad_image_dir):
    shutil.rmtree(bad_image_dir)
os.makedirs(bad_image_dir)

# Copy low distance images into one folder, visiting the pairs in random order
random_indexes = np.random.choice(len(low_pairs), len(low_pairs), replace=False)
count_limit = 10
visited_prompts = set()

# Never index past the end when fewer than `count_limit` pairs were selected
num_candidates = min(count_limit, len(low_pairs))

i = 0
with tqdm(total=num_candidates) as pbar:
    while i < num_candidates:
        p = low_pairs[random_indexes[i]]

        # Every candidate uses up one slot, whether or not it is displayed.
        # Advancing here also covers the repeated-prompt case below, which
        # previously never moved on and looped forever.
        i += 1
        pbar.update(1)

        # Only show images generated with cfg 7 and a prompt of 30+ characters
        if p['cfg'] != 7 or len(p['prompt']) < 30:
            continue

        cur_path = get_image_path(name_to_part, p['name'], existing_part_ids)
        local_path = join(bad_image_dir, basename(cur_path))

        if not exists(local_path):
            shutil.copyfile(cur_path, local_path)

        # Shrink in place so the inline preview stays small
        img = Image.open(local_path)
        img.thumbnail((300, 300), Image.Resampling.LANCZOS)

        prompt = p['prompt']

        # Show each prompt only once
        if prompt in visited_prompts:
            continue
        visited_prompts.add(prompt)

        display(img)
        print(p)

        # try:
        #     canvas = ImageDraw.Draw(img)
        #     canvas.text((10, 5), prompt[:40], fill=(255, 0, 0))
            
        #     if len(prompt) > 40:
        #         canvas.text((10, 15), prompt[40:80], fill=(255, 0, 0))
                
        #     i += 1
        #     pbar.update(1)
                
        # except:
        #     pass
        
        # img.save(local_path)

## Entropy Analysis

In [None]:
# We want to compute many stats at once

from skimage.measure.entropy import shannon_entropy

# Try the entropy measure on a single image from the bad-images folder
img = Image.open(
    join(
        WORKING_IMAGE_DIR,
        'bad-images',
        '415bf54a-ebe1-4e0a-ae5d-ef0b7c571b15.webp',
    )
)
img_mat = np.array(img)

shannon_entropy(img_mat)

In [None]:
# Pair every image in the bad-images folder with its Shannon entropy
entropy_pairs = []
for image_file in glob(join(WORKING_IMAGE_DIR, 'bad-images', '*.webp')):
    picture = Image.open(image_file)
    entropy_pairs.append([picture, shannon_entropy(np.array(picture))])

In [None]:
entropy_pairs = sorted(entropy_pairs, key=lambda x: x[1])

In [None]:
# for p in entropy_pairs[:10]:
#     display(p[0])
#     print(p[1])

In [None]:
# pickle.load(open('./entropy-pickles/entropy-000001.pkl', 'rb'))

## Same Prompt Different Image

We try to identify prompts that generate very different images.

In [None]:
# Only the image name, its part id and its prompt are needed for this section
metadata_df = pd.read_parquet(
    PARQUET_PATH,
    columns=['image_name', 'part_id', 'prompt'],
)
print(metadata_df.shape)
metadata_df.head()

In [None]:
# Part ids that already have an unpacked folder; entries containing "json" are ignored
folders = glob("/project/diffusiondb/images/*")
existing_part_ids = {
    int(re.sub(r".*part-(\d+)", r"\1", folder))
    for folder in folders
    if "json" not in folder
}

def get_image_path(part_id, name, existing_part_ids):
    """
    Get the path of an image by its name.

    NOTE: this definition replaces the earlier `get_image_path` of this
    notebook, whose first parameter is a name -> part id dictionary instead
    of the part id itself.

    When `part_id` is not among `existing_part_ids`, the part's zip archive is
    copied into WORK_DIR and unpacked under WORKING_IMAGE_DIR (unless the
    unpacked folder is already there).

    Args:
        part_id(int): Part that contains the image.
        name(str): Image file name.
        existing_part_ids(set): Part ids already unpacked in REMOTE_IMAGE_DIR.

    Returns:
        str: Path of the image inside REMOTE_IMAGE_DIR or WORKING_IMAGE_DIR.

    Raises:
        FileNotFoundError: The part's archive is in neither zip folder.
    """
    part_name = f'part-{part_id:06}'

    if part_id in existing_part_ids:
        return join(REMOTE_IMAGE_DIR, part_name, name)

    # Need to download the image's zip file first
    cur_zip = join(WORK_DIR, f"{part_name}.zip")
    cur_img_dir = join(WORKING_IMAGE_DIR, part_name)

    if not exists(cur_img_dir):
        # The original `part_id > 100000` rule picks the folder to try first.
        # Part ids in this notebook stop at 14000, so that rule alone never
        # selects ZIP_DIR2; probing both folders makes its archives reachable.
        if part_id > 100000:
            candidate_dirs = (ZIP_DIR2, ZIP_DIR1)
        else:
            candidate_dirs = (ZIP_DIR1, ZIP_DIR2)

        for zip_dir in candidate_dirs:
            source_zip = join(zip_dir, f"{part_name}.zip")
            if exists(source_zip):
                break
        else:
            raise FileNotFoundError(
                f"{part_name}.zip not found in {ZIP_DIR1} or {ZIP_DIR2}"
            )

        # Copy the archive to local scratch space, then extract it
        shutil.copyfile(source_zip, cur_zip)
        shutil.unpack_archive(cur_zip, cur_img_dir)

    return join(WORKING_IMAGE_DIR, part_name, name)

In [None]:
# Lower-cased prompt -> [(image_name, part_id)], so prompts that differ only
# in letter case share one entry
prompt_to_images = {}

for row in tqdm(metadata_df.itertuples(), total=metadata_df.shape[0]):
    # Each row is (index, image_name, part_id, prompt)
    name, part_id, raw_prompt = row[1], row[2], row[3]
    prompt_to_images.setdefault(raw_prompt.lower(), []).append((name, part_id))

In [None]:
# Only keep prompts with more than 3 occurrences
filtered_prompt_to_images = {
    prompt_text: images
    for prompt_text, images in tqdm(prompt_to_images.items())
    if len(images) > 3
}

print(len(filtered_prompt_to_images))

In [None]:
# `with` flushes and closes the file, so the next cell reads a complete pickle
with open('./outputs/prompt_to_images.pkl', 'wb') as prompt_to_images_file:
    pickle.dump(prompt_to_images, prompt_to_images_file)

In [None]:
# Find the most-used prompts
# NOTE: unpickling can execute arbitrary code; only load files produced by this project
with open('./outputs/prompt_to_images.pkl', 'rb') as prompt_to_images_file:
    prompt_to_images = pickle.load(prompt_to_images_file)

In [None]:
# [prompt, number of images] for every prompt, most frequent first
prompt_image_count_pairs = sorted(
    (
        [prompt_text, len(prompt_to_images[prompt_text])]
        for prompt_text in tqdm(prompt_to_images)
    ),
    key=lambda pair: pair[1],
    reverse=True,
)

In [None]:
# prompt_to_images
# Bar chart of the image counts of the 2000 most frequent prompts
top_pairs = prompt_image_count_pairs[:2000]
plt.title('Prompt Count (top 2000)')
plt.bar(list(range(len(top_pairs))), [pair[1] for pair in top_pairs])
plt.show()

In [None]:
prompt_image_count_pairs[200]

In [None]:
# prompt_max_distance = {}

# last_part_id = -1
# last_image_names = None
# last_image_emb = None
# last_image_names_to_id = None

# for p in tqdm(filtered_prompt_to_images):
    
#     local_embs = []
#     max_distance = -np.inf
#     min_distance = np.inf
    
#     for name, part_id in filtered_prompt_to_images[p]:
        
#         if part_id != last_part_id:
#             image_emb_data = np.load(join(IMAGE_EMB_DIR, f'part-{part_id:06}-image-emb.npz'))

#             last_image_names = image_emb_data['images_name']
#             last_image_names_to_id = {}
#             for i, name in enumerate(last_image_names):
#                 last_image_names_to_id[name] = i

#             last_image_emb = image_emb_data['images_emb']
#             last_part_id = part_id
        
#         cur_i = last_image_names_to_id[name]
#         local_embs.append(last_image_emb[cur_i, :])
    
#     # Compute pair-wise cosine distance
#     for i in range(len(local_embs)):
#         for j in range(i + 1, len(local_embs)):
#             cur_d = cosine(local_embs[i], local_embs[j])
            
#             if cur_d > max_distance:
#                 max_distance = cur_d
                
#             if cur_d < min_distance:
#                 min_distance = cur_d
                
#     emb_mean = np.mean(local_embs, axis=0)
#     emb_std = np.std(local_embs, axis=0)
    
#     record = {
#         'min_distance': min_distance,
#         'max_distance': max_distance,
#         'count': len(local_embs),
#         'images': filtered_prompt_to_images[p]
#     }

In [None]:
# Very different images from the same prompt

# NOTE: unpickling can execute arbitrary code; only load files produced by this project
with open('./outputs/distance-0-100000.pkl', 'rb') as distance_file:
    cur_distance = pickle.load(distance_file)

## Language Detector

In [None]:
import spacy
import spacy_fastlang

In [None]:
# Reload the metadata table, restricted to the columns listed below.
metadata_columns = [
    "image_name",
    "part_id",
    "prompt",
    "cfg",
    "step",
    "sampler",
    "width",
    "height",
    "image_nsfw",
]
metadata_df = pd.read_parquet(PARQUET_PATH, columns=metadata_columns)

print(metadata_df.shape)
metadata_df.head()


In [None]:
# De-duplicate the prompt column. Going through a set means the resulting
# list order is arbitrary, not the order of appearance in the table.
unique_prompts = list(set(metadata_df['prompt']))

In [None]:
# Load the small English spaCy pipeline and append the "language_detector"
# component (presumably registered by the spacy_fastlang import -- confirm).
nlp = spacy.load("en_core_web_sm")
nlp.add_pipe("language_detector")

In [None]:
# non_en_prompts = {}
# nlp = spacy.load("en_core_web_sm")
# nlp.add_pipe("language_detector")

# for p in tqdm(unique_prompts):
#     doc = nlp(p)
#     if doc._.language != 'en':
#         non_en_prompts[p] = doc._.language

In [None]:
# Load the cached prompt -> detected-language mapping. The context manager
# closes the file handle once unpickling is done.
with open("./outputs/non-en-prompts.pkl", "rb") as non_en_file:
    non_en_prompts = pickle.load(non_en_file)

In [None]:
# Tally how many prompts were assigned to each detected language.
langs = non_en_prompts.values()
counter = Counter(langs)

# most_common() without an argument returns every (language, count) pair,
# ordered from the most to the least frequent.
pairs = counter.most_common()
lang_labels = [pair[0] for pair in pairs]
lang_counts = [pair[1] for pair in pairs]

plt.figure(figsize=(18, 4))
plt.bar(lang_labels, lang_counts)
plt.xticks(rotation='vertical')
plt.title('Non-English Prompt Language Distribution')
plt.show()

In [None]:
# The 15 most frequent values in the language counter, with their counts.
counter.most_common(15)

In [None]:
print(len(non_en_prompts))
print(len(counter))

# Number of languages that were detected for more than 100 prompts.
lang_count = sum(1 for lang in counter if counter[lang] > 100)
lang_count

### Compare token length

In [None]:
from diffusers import StableDiffusionPipeline
import torch

In [None]:
# Load the Stable Diffusion v1.4 pipeline; only its tokenizer is used in the
# token-length cells below.
# auth_token = os.environ["HFTOKEN"]
pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")

In [None]:
# Unique prompts as a NumPy array so they can be sliced in batches.
# Note: this rebinds `prompts`; any earlier value held under that name is
# replaced from here on.
prompts = np.array(list(set(metadata_df['prompt'])))
len(prompts)

In [None]:
def batch_tok_length(text_inputs):
    """Return the token count of every prompt in a tokenized batch.

    Args:
        text_inputs: Mapping with an "attention_mask" entry that supports
            `.sum(-1)` (e.g. the tokenizer output requested as tensors).

    Returns:
        The attention mask summed over its last axis, minus 2, i.e. one
        count per prompt (not an average over the batch).
    """
    special_tokens = 2  # the BOS and EOS tags added around each prompt
    mask_totals = text_inputs["attention_mask"].sum(-1)
    return mask_totals - special_tokens


# !! Long running cell. Choose batch size that computer can handle easily
bs = 10000
vocab_size = pipe.tokenizer.vocab_size
nprompts = len(prompts)
total_token_length = torch.zeros(nprompts, dtype=torch.int16)

# Ceiling division: exactly one iteration per (possibly partial) batch, so the
# progress bar total is right even when nprompts is a multiple of bs.
total_iter = (nprompts + bs - 1) // bs

with tqdm(total=total_iter) as pbar:
    for i in range(0, nprompts, bs):
        pidxs = slice(i, i + bs)
        p = prompts[pidxs].tolist()
        text_inputs = pipe.tokenizer(
            p,
            padding=True,
            max_length=pipe.tokenizer.model_max_length,
            truncation=False,
            return_tensors="pt",
        )

        # Per-prompt token counts for this batch, stored in the matching slice
        length = batch_tok_length(text_inputs)
        total_token_length[pidxs] = length

        # Advance the bar only once the batch has actually been processed
        pbar.update(1)

In [None]:
# Map every prompt to its token count
token_count_dict = {
    prompts[idx]: total_token_length[idx]
    for idx in range(len(total_token_length))
}

In [None]:
# Persist the token counts. The context manager guarantees the file is
# flushed and closed when the dump finishes.
# NOTE(review): the tokenizer cell passes truncation=False, yet this file is
# named "truncated" -- presumably that cell was re-run with truncation enabled
# before this dump; confirm.
with open('./outputs/token_count_truncated.pkl', 'wb') as truncated_out:
    pickle.dump(total_token_length, truncated_out)

In [None]:
# Reload the cached token counts; the context manager closes the file handle
# once unpickling is done.
with open('./outputs/token_count_truncated.pkl', 'rb') as truncated_in:
    truncated_total_token_length = pickle.load(truncated_in)

In [None]:
plt.figure(figsize=(10, 5))
plt.grid(alpha=0.2)
n, bins, patches = plt.hist(np.array(truncated_total_token_length),
                            bins=37, edgecolor='white', linewidth=0.5, alpha=0.9)
# Cast the per-bin counts to integers. This must come after plt.hist, which is
# what defines `n`; doing it first fails on a fresh top-to-bottom run.
n = n.astype("int")
plt.xticks(fontsize=20)
plt.yticks(fontsize=20)
plt.xlabel("Number of tokens in prompt", fontsize=16)
plt.title("Distribution of Prompt Length (# of Tokens)", fontsize=18)
# plt.savefig("plots/token_length_dist.pdf", bbox_inches='tight')

In [None]:
# Width of one histogram bin: the gap between two adjacent bin edges.
bins[-2] - bins[-3]

In [None]:
# Persist the token counts. The context manager guarantees the file is
# flushed and closed when the dump finishes.
with open('./outputs/token_count_no_truncated.pkl', 'wb') as untruncated_out:
    pickle.dump(total_token_length, untruncated_out)

In [None]:
# Reload the cached token counts (closing the file right after unpickling)
# and convert the stored tensor to a NumPy array.
with open('./outputs/token_count_no_truncated.pkl', 'rb') as untruncated_in:
    total_token_length = pickle.load(untruncated_in)
total_token_length = total_token_length.numpy()

In [None]:

# plt.figure(figsize=(10, 5))
# plt.grid(alpha=0.2)

# cur_lengths = [l for l in total_token_length if l > 70 and l < 140]
# n, bins, patches = plt.hist(cur_lengths, bins=40, edgecolor='white', linewidth=0.5, alpha=0.9)
# plt.xticks(fontsize=20)
# plt.yticks(fontsize=20)
# plt.xlabel("Number of tokens in prompt", fontsize=16)
# plt.title("Distribution of Prompt Length (# of Tokens)", fontsize=18)
# plt.savefig("plots/token_length_dist_untrunc.pdf", bbox_inches='tight')

### Identify examples in the figure

In [None]:
# Reload the metadata without a column filter, so every column is present.
# The cells below address row fields by position, which depends on this
# full column order.
metadata_df = pd.read_parquet(
    PARQUET_PATH,
)
print(metadata_df.shape)
metadata_df.head()


In [None]:
# Keep the rows whose text field row[2] (presumably the prompt) mentions the
# phrase and is shorter than 120 characters, and whose row[6] value is above 7.
# Other phrases tried with this filter: 'old couple smiling', 'fighting russian'.
selected_rows = [
    row
    for row in tqdm(metadata_df.itertuples(), total=len(metadata_df))
    if 'watercolor painting' in row[2].lower()
    and row[6] > 7
    and len(row[2]) < 120
]

In [None]:
# How many rows passed the filter, followed by one of them as a sample.
print(len(selected_rows))
selected_rows[1]

In [None]:
# Collect the part ids found among the entries of the image folder, skipping
# any entry with "json" in its path.
folders = glob("/project/diffusiondb/images/*")
existing_part_ids = {
    int(re.sub(r".*part-(\d+)", r"\1", folder))
    for folder in folders
    if "json" not in folder
}

def get_image_path(part_id, name, existing_part_ids):
    """
    Resolve the path of image `name` that belongs to part `part_id`.

    If `part_id` is in `existing_part_ids`, the path under REMOTE_IMAGE_DIR is
    returned and nothing is touched on disk. Otherwise the path under
    WORKING_IMAGE_DIR is returned; when that part's folder is missing there,
    its zip is first copied into WORK_DIR and unpacked into the folder.

    Args:
        part_id: Integer part number (formatted as a zero-padded 6-digit name).
        name: File name of the image inside the part folder.
        existing_part_ids: Collection of part ids available in REMOTE_IMAGE_DIR.

    Returns:
        The joined path to the image file.
    """
    part_name = f'part-{part_id:06}'

    # Fast path: the part is already available remotely
    if part_id in existing_part_ids:
        return join(REMOTE_IMAGE_DIR, part_name, name)

    extracted_dir = join(WORKING_IMAGE_DIR, part_name)

    if not exists(extracted_dir):
        zip_name = f"{part_name}.zip"
        local_zip = join(WORK_DIR, zip_name)

        # NOTE(review): parts above 100000 are read from ZIP_DIR2, the rest
        # from ZIP_DIR1 -- confirm this threshold matches how the archive is split.
        source_dir = ZIP_DIR2 if part_id > 100000 else ZIP_DIR1

        # Copy the archive locally, then extract it into the part folder
        shutil.copyfile(join(source_dir, zip_name), local_zip)
        shutil.unpack_archive(local_zip, extracted_dir)

    return join(extracted_dir, name)

In [None]:
# Display these bad images
folders = glob("/project/diffusiondb/images/*")
existing_part_ids = set(
    [int(re.sub(r".*part-(\d+)", r"\1", f)) for f in folders if "json" not in f]
)

# Start from an empty scratch folder. Only remove it when it is already there,
# so the very first run does not fail on a missing directory.
bad_image_dir = join(WORKING_IMAGE_DIR, 'bad-images')
if exists(bad_image_dir):
    shutil.rmtree(bad_image_dir)
os.makedirs(bad_image_dir)

# Window [i, count_limit) of selected_rows to display
i = 165
count_limit = 175

# Slicing clamps the window to the list length, and iterating through tqdm
# keeps the progress bar in step with the rows actually shown.
for p in tqdm(selected_rows[i:count_limit]):
    # Positional row fields as used by the filter cell and get_image_path:
    # p[1] image name, p[2] prompt text, p[3] part id.
    cur_path = get_image_path(p[3], p[1], existing_part_ids)
    local_path = join(bad_image_dir, basename(cur_path))

    if not exists(local_path):
        shutil.copyfile(cur_path, local_path)

    img = Image.open(local_path)
    img.thumbnail((145, 145), Image.Resampling.LANCZOS)

    prompt = p[2]

    display(img)
    print(p)