In [None]:
from glob import glob
from os.path import exists, join, basename
from copy import deepcopy
from tqdm import tqdm
from json import load, dump
from sys import argv
from matplotlib import pyplot as plt
from scipy.spatial.distance import cosine
from scipy.stats import norm
from scipy import stats
from PIL import Image, ImageDraw
from collections import Counter

import re
import os
import shutil
import time
import random
import tarfile
import zipfile
import pickle

import pandas as pd
import numpy as np

# plt.style.use('ggplot')
# plt.rcParams['figure.dpi'] = 300

In [None]:
# Root of the local NVMe scratch copy of DiffusionDB; most working files live here.
WORK_DIR = "/nvmescratch/diffusiondb/"

# Input files stored under WORK_DIR.
PROMPT_EMB_PATH = WORK_DIR + "prompt-emb/prompt-emb.npz"
IMAGE_EMB_DIR = WORK_DIR + "img-emb"
PARQUET_PATH = WORK_DIR + "metadata.parquet"

# Source zip archives of the large dataset, split over two directories.
_ZIP_ROOT = "/project/diffusiondb-hugging/"
ZIP_DIR1 = _ZIP_ROOT + "diffusiondb-large-part-1/"
ZIP_DIR2 = _ZIP_ROOT + "diffusiondb-large-part-2/"

# Local folder that zips are extracted into, and the remote folder holding
# parts that were already extracted.
WORKING_IMAGE_DIR = WORK_DIR + "images/"
REMOTE_IMAGE_DIR = '/project/diffusiondb/images/'

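## Entropy Analysis

Fit a Gaussian to the per-image entropy scores and keep images whose entropy falls far below the mean (9 standard deviations). Then filter them by their generation parameters, keep one image per prompt, and save a random sample of annotated thumbnails for manual inspection.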

In [None]:
# Load only the metadata columns that the entropy analysis below reads.
_metadata_columns = (
    "image_name part_id prompt cfg step sampler width height image_nsfw".split()
)
metadata_df = pd.read_parquet(PARQUET_PATH, columns=_metadata_columns)

print(metadata_df.shape)
metadata_df.head()


In [None]:
# Lookup tables keyed by image file name:
#   name_to_part  -> part_id of the zip archive that holds the image
#   name_to_index -> the image's row in metadata_df
# Later cells index metadata_df columns with name_to_index values as *labels*,
# which only matches row position for a default 0..n-1 index.
assert metadata_df.index.equals(pd.RangeIndex(len(metadata_df))), (
    "metadata_df is expected to have a default RangeIndex"
)

# Built column-wise: a per-row Series lookup is very slow on a table this size.
# As with a row loop, a repeated image_name keeps its last occurrence.
name_to_part = dict(zip(metadata_df['image_name'], metadata_df['part_id']))
name_to_index = {
    image_name: row for row, image_name in enumerate(metadata_df['image_name'])
}

In [None]:
# # Read all entropies
# image_entropies = {}

# for f in tqdm(glob(join(WORK_DIR, 'entropy-pickles', '*.pkl'))):
#     cur_dict = pickle.load(open(f, 'rb'))
#     for name in cur_dict:
#         image_entropies[name] = cur_dict[name]
        
# pickle.dump(image_entropies, open('./outputs/image_entropies.pkl', 'wb'))

In [None]:
# Entropy score per image name, as dumped to this path by the aggregation cell above.
# NOTE: pickle.load can execute arbitrary code - only load files this project wrote.
with open('./outputs/image_entropies.pkl', 'rb') as entropy_file:
    image_entropies = pickle.load(entropy_file)

In [None]:
# Every entropy score as a flat list (in image_entropies' key order), for the fit below.
entropies = list(image_entropies.values())


In [None]:

# plt.figure(figsize=(10, 4))
# plt.title('Image Entropy Distribution (14M)')
# plt.hist(entropies, bins=100)
# plt.show()

In [None]:
# Fit a Gaussian distribution to all entropy scores and flag extreme low outliers.

# How many standard deviations below the mean an entropy must fall to be flagged.
LOW_BAR_N_STD = 9

mean, std = norm.fit(entropies)
LOW_BAR = mean - LOW_BAR_N_STD * std
print(mean, std, LOW_BAR)

# [image_name, entropy] pairs below LOW_BAR, ordered from the lowest entropy up.
# (Lists rather than tuples, to match how later cells index them.)
sorted_pairs = sorted(
    (
        [image_name, entropy]
        for image_name, entropy in tqdm(image_entropies.items(), total=len(image_entropies))
        if entropy < LOW_BAR
    ),
    key=lambda pair: pair[1],
)

In [None]:
# Attach generation metadata to each low-entropy image and keep only those that
# pass the parameter filters below.

# Filter thresholds (same values as the original inline condition).
MAX_NSFW = 2        # keep image_nsfw < 2; NOTE(review): presumably 2 is a sentinel score - confirm
CFG_CENTER = 7      # keep |cfg - CFG_CENTER| < CFG_TOLERANCE
CFG_TOLERANCE = 5
MIN_SIDE = 512      # keep min(width, height) >= MIN_SIDE
MIN_STEP = 10       # keep step > MIN_STEP

low_pairs = []

for image_name, entropy_value in tqdm(sorted_pairs):
    cur_i = name_to_index[image_name]
    cfg = metadata_df['cfg'][cur_i]
    sampler = metadata_df['sampler'][cur_i]
    step = metadata_df['step'][cur_i]
    width = metadata_df['width'][cur_i]
    height = metadata_df['height'][cur_i]
    nsfw = metadata_df['image_nsfw'][cur_i]
    prompt = metadata_df['prompt'][cur_i]

    passes_filters = (
        nsfw < MAX_NSFW
        and abs(cfg - CFG_CENTER) < CFG_TOLERANCE
        and min(width, height) >= MIN_SIDE
        and step > MIN_STEP
    )
    if not passes_filters:
        continue

    low_pairs.append({
        'name': image_name,
        'index': cur_i,
        'entropy': entropy_value,
        'cfg': cfg,
        'sampler': sampler,
        'step': step,
        'width': width,
        'height': height,
        'nsfw': nsfw,
        'prompt': prompt,
    })

print(len(low_pairs))

In [None]:
def get_image_path(name_to_part, name, existing_part_ids):
    """
    Get the path of an image by its name, extracting its part locally if needed.

    Parameters
    ----------
    name_to_part : dict
        Maps image file name -> part id (int) of the zip archive holding it.
    name : str
        Image file name.
    existing_part_ids : set of int
        Part ids whose folders already exist under REMOTE_IMAGE_DIR.

    Returns
    -------
    str
        ``REMOTE_IMAGE_DIR/part-XXXXXX/name`` when the part is in
        ``existing_part_ids``; otherwise ``WORKING_IMAGE_DIR/part-XXXXXX/name``
        after making sure that part has been extracted there.

    Raises
    ------
    KeyError
        If ``name`` is not a key of ``name_to_part``.
    """
    part_id = name_to_part[name]
    part_name = f"part-{part_id:06}"

    if part_id in existing_part_ids:
        return join(REMOTE_IMAGE_DIR, part_name, name)

    # Need to copy and extract the image's zip file first.
    cur_img_dir = join(WORKING_IMAGE_DIR, part_name)
    if not exists(cur_img_dir):
        _extract_part(part_id, part_name, cur_img_dir)

    return join(cur_img_dir, name)


def _extract_part(part_id, part_name, cur_img_dir):
    """Copy one part's zip into WORK_DIR and unpack it into ``cur_img_dir``."""
    # NOTE(review): parts above 100000 are read from ZIP_DIR2 and all others from
    # ZIP_DIR1 - confirm this split point against the dataset's folder layout.
    zip_dir = ZIP_DIR2 if part_id > 100000 else ZIP_DIR1
    cur_zip = join(WORK_DIR, f"{part_name}.zip")

    # Unpack into a sibling staging folder and rename it at the end, so an
    # interrupted extraction never leaves a half-filled cur_img_dir that the
    # exists() check in get_image_path would later treat as complete.
    staging_dir = cur_img_dir + ".partial"
    try:
        shutil.copyfile(join(zip_dir, f"{part_name}.zip"), cur_zip)
        if exists(staging_dir):
            shutil.rmtree(staging_dir)  # leftover from an earlier failed attempt
        os.makedirs(staging_dir)
        shutil.unpack_archive(cur_zip, staging_dir)
        os.rename(staging_dir, cur_img_dir)
    finally:
        # The local zip copy is only needed for extraction; remove it so that
        # sampled parts do not pile up archives in WORK_DIR.
        if exists(cur_zip):
            os.remove(cur_zip)

In [None]:
# Keep only the first candidate seen for each distinct prompt (low_pairs is in
# ascending-entropy order, so that is the lowest-entropy image of the prompt).
first_pair_by_prompt = {}
for candidate in low_pairs:
    first_pair_by_prompt.setdefault(candidate['prompt'], candidate)

added_prompts = set(first_pair_by_prompt)
unique_low_pairs = list(first_pair_by_prompt.values())

len(unique_low_pairs)

In [None]:
# Display these bad images: sample low-entropy candidates, copy them into one
# local folder and stamp each thumbnail with the beginning of its prompt.

# Part ids whose folders already exist under REMOTE_IMAGE_DIR (named part-XXXXXX).
existing_part_ids = set()
for folder in glob(join(REMOTE_IMAGE_DIR, "*")):
    part_match = re.search(r"part-(\d+)$", folder)
    if part_match:  # skips json metadata files and anything else that is not a part folder
        existing_part_ids.add(int(part_match.group(1)))

# Start from an empty output folder; on the first run there is nothing to remove.
BAD_IMAGE_DIR = join(WORKING_IMAGE_DIR, 'bad-images')
if exists(BAD_IMAGE_DIR):
    shutil.rmtree(BAD_IMAGE_DIR)
os.makedirs(BAD_IMAGE_DIR)

count_limit = 100            # maximum number of images to sample
SAMPLE_SEED = 0              # fixed so re-running the cell draws the same sample
THUMBNAIL_SIZE = (300, 300)
PROMPT_LINE_CHARS = 40       # prompt characters drawn per text line (two lines max)
PROMPT_COLOR = (255, 0, 0)


def _stamp_prompt(img, prompt):
    """Best-effort: draw the first two PROMPT_LINE_CHARS-long chunks of ``prompt`` onto ``img``.

    A drawing failure is reported and skipped so one odd image does not stop the
    sample. NOTE(review): likely causes are characters the default font cannot
    render, or an image mode that rejects an RGB fill - confirm.
    """
    try:
        canvas = ImageDraw.Draw(img)
        canvas.text((10, 5), prompt[:PROMPT_LINE_CHARS], fill=PROMPT_COLOR)
        if len(prompt) > PROMPT_LINE_CHARS:
            canvas.text(
                (10, 15),
                prompt[PROMPT_LINE_CHARS:2 * PROMPT_LINE_CHARS],
                fill=PROMPT_COLOR,
            )
    except Exception as err:
        tqdm.write(f"Could not draw prompt on image: {err!r}")


# Random order over all unique candidates; never ask for more than exist.
rng = np.random.default_rng(SAMPLE_SEED)
random_indexes = rng.permutation(len(unique_low_pairs))
sample_size = min(count_limit, len(unique_low_pairs))

for sample_index in tqdm(random_indexes[:sample_size]):
    pair = unique_low_pairs[sample_index]

    cur_path = get_image_path(name_to_part, pair['name'], existing_part_ids)
    local_path = join(BAD_IMAGE_DIR, basename(cur_path))

    if not exists(local_path):
        shutil.copyfile(cur_path, local_path)

    # The context manager closes the source file once the thumbnail is saved.
    with Image.open(local_path) as img:
        img.thumbnail(THUMBNAIL_SIZE, Image.Resampling.LANCZOS)
        _stamp_prompt(img, pair['prompt'])
        img.save(local_path)