Set up the environment, if needed

In [None]:
## Update the following with your specific version of CUDA, if any. 
# !pip install torch --extra-index-url https://download.pytorch.org/whl/cu113
# !pip install h5py pandas numpy matplotlib diffusers transformers scipy ftfy pyarrow regex wordcloud


In [None]:
from PIL import Image
from pathlib import Path
import os
import json
from diffusers import StableDiffusionPipeline
import regex as re
import pandas as pd
import torch
from tqdm.auto import tqdm
import numpy as np
import matplotlib.pyplot as plt
import wordcloud as wc
import requests

Load the pipeline to get the same tokenizer that Stable Diffusion uses

In [None]:
# auth_token = os.environ["HFTOKEN"]
# The analysis cells below only access `pipe.tokenizer`; the whole pipeline is loaded
# so that the tokenizer is the one that belongs to this checkpoint.
pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")

In [None]:
# Read only the prompt column of the metadata table, then keep it as a Series
metadata = pd.read_parquet('./indexes/metadata-large.parquet', columns=['prompt'])
prompts = metadata['prompt']
print("Length of prompts: ", len(prompts))

In [None]:
# De-duplicate while keeping first-seen order. `set()` ordering of strings changes
# between kernel sessions (hash randomisation), which would make the row order of
# every per-prompt array computed from `prompts` irreproducible.
prompts = list(dict.fromkeys(prompts))
len(prompts)

## Prompt uniqueness?

In [None]:
# `prompts` has already been reduced to a de-duplicated list by this point, so it has
# no `.prompt` attribute and every count would be 1 anyway. Re-read the raw column to
# count how often each prompt occurs in the dataset.
all_prompts = pd.read_parquet(
    './indexes/metadata-large.parquet',
    columns=['prompt']
)['prompt']

# Get count of each prompt (value_counts skips missing values)
ct_dict = all_prompts.value_counts().to_dict()
sprompts = set(ct_dict)

In [None]:
# Tally how many distinct prompts occur exactly N times.
# np.unique gives exact per-value tallies. Feeding the unique values to np.histogram
# as bin edges merges the two largest N into the last (closed) bin, drops the largest
# N from the x axis, and fails outright when there is only one unique value.
x = np.array(list(ct_dict.values()))
n_occurrences, cts = np.unique(x, return_counts=True)

plt.bar(n_occurrences, cts)
plt.yscale("log")
plt.xscale("log")
plt.xlabel("N (number of times a prompt appears in the dataset)")
plt.ylabel("Number of prompts that appear N times")
plt.title("How unique are the prompts?")

## Prompt lengths?

By **specifier clauses** and **token length**

In [None]:
# Pick the separator characters and print the token id(s) the tokenizer assigns to each
sep_ids = [",", ";", "|"]
for sep in sep_ids:
    encoded = pipe.tokenizer.encode(f"know{sep}nothing")
    # Drop the first two and last two ids (presumably BOS + "know" and "nothing" + EOS),
    # leaving only the ids of the separator
    print(f"sep_id: '{sep}': ", encoded[2:-2])
    

In [None]:
def batch_tok_length(text_inputs):
    """Count the attended tokens of every prompt in a tokenized batch.

    Sums ``text_inputs["attention_mask"]`` over the last axis and subtracts 2
    for the BOS and EOS tags added by the tokenizer.
    """
    attended = text_inputs["attention_mask"].sum(dim=-1)
    return attended - 2  # remove BOS and EOS added tags

def batch_concepts(prompt:str):
    """Split a single prompt into its concept strings.

    The prompt is cut at every ``;``, ``,`` or ``|``. Pieces are returned
    unchanged: surrounding whitespace is not stripped and empty pieces are kept.

    Args:
        prompt: the raw prompt text.

    Returns:
        List of the substrings between separators.
    """
    # Raw string + character class: the earlier non-raw '\|' is an invalid escape
    # sequence, which newer Python versions warn about when compiling the cell.
    concepts = re.split(r'[;,|]', prompt)
    return concepts

def batch_num_concepts(text_inputs, split_ids=(267, 282, 347)):
    """Calculate how many concepts are in each prompt, from the tokens.

    Args:
        text_inputs: mapping with an ``"input_ids"`` integer tensor; separators
            are counted along its last axis.
        split_ids: token ids treated as concept separators. The defaults are the
            ids this notebook uses for comma, semicolon and pipe; they are
            tokenizer-specific, so pass other ids for a different vocabulary.

    Returns:
        Integer tensor with one entry per prompt: number of separators + 1.
    """
    iids = text_inputs["input_ids"]
    sep_ids = torch.as_tensor(list(split_ids), dtype=iids.dtype, device=iids.device)
    # Compare every token against every separator id at once, then collapse the
    # separator axis: True wherever a token equals any of the separators.
    is_sep = (iids.unsqueeze(-1) == sep_ids).any(dim=-1)
    return is_sep.sum(-1) + 1  # Number of concepts = number of separators + 1

def tok_frequencies(text_inputs):
    """Count the occurrences of every distinct token id in a tokenized batch.

    Returns the ``(ids, counts)`` pair produced by ``torch.unique`` over all of
    ``text_inputs["input_ids"]``.
    """
    return torch.unique(text_inputs["input_ids"], return_counts=True)

In [None]:
# Convert to a NumPy array so the batching loop below can slice it and call
# `.tolist()` on each slice.
prompts = np.array(prompts)

In [None]:
# !! Long running cell. Choose batch size that computer can handle easily
bs = 10000
vocab_size = pipe.tokenizer.vocab_size
nprompts = len(prompts)
# Ceil division gives the exact batch count; `nprompts // bs + 1` is one too many
# whenever nprompts is an exact multiple of bs.
total_iter = -(-nprompts // bs)

# Accumulators: corpus-wide token counts plus one entry per prompt
tokfreqs = torch.zeros(vocab_size, dtype=torch.int64)
total_nconcepts = torch.zeros(nprompts, dtype=torch.int16)
total_token_length = torch.zeros(nprompts, dtype=torch.int16)

# Walk the batch start offsets directly; the loop ends by itself after the last
# (possibly shorter) batch, so no manual counters or `break` are needed.
for start in tqdm(range(0, nprompts, bs), total=total_iter):
    pidxs = slice(start, start + bs)
    p = prompts[pidxs].tolist()
    text_inputs = pipe.tokenizer(
        p,
        padding="max_length",
        max_length=pipe.tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )
    total_nconcepts[pidxs] = batch_num_concepts(text_inputs)
    total_token_length[pidxs] = batch_tok_length(text_inputs)

    # `ids` comes from torch.unique, so each index appears once and the indexed
    # in-place add cannot drop updates.
    ids, counts = tok_frequencies(text_inputs)
    tokfreqs[ids] += counts

In [None]:

# Histogram of per-prompt token counts.
# The earlier `n = n.astype("int")` line is dropped: on a fresh run `n` is the plain
# int batch counter from the tokenizing cell (no `.astype`), and its value was
# overwritten by `plt.hist` on the very next line anyway.
plt.figure(figsize=(10, 5))
plt.grid(alpha=0.2)
n, bins, patches = plt.hist(np.array(total_token_length), bins=37, edgecolor='white', linewidth=0.5, alpha=0.9)
plt.xticks(fontsize=20)
plt.yticks(fontsize=20)
plt.xlabel("Number of tokens in prompt", fontsize=16)
plt.title("Distribution of Prompt Length (# of Tokens)", fontsize=18)
plt.savefig("plots/token_length_dist.pdf", bbox_inches='tight')

In [None]:
# Histogram of the number of specifier clauses per prompt (log-scaled counts)
plt.figure(figsize=(10, 5))
plt.grid(alpha=0.2)
plt.xticks(fontsize=20)
plt.yticks(fontsize=20)
max_concepts = total_nconcepts.max().item()
# One unit-wide bin per integer count: edges run 0..max+1. If the edges stop at max,
# the final bin is closed on both sides and lumps the counts for max-1 and max together.
n, bins, patches = plt.hist(np.array(total_nconcepts), bins=range(0, max_concepts + 2), edgecolor='#e0e0e0', linewidth=0.5, alpha=0.9)
plt.xlabel("Number of specifier clauses in prompt", fontsize=16)
plt.yscale("log")
# Tick every 5 clauses, but start the axis labels at 1 (the first tick is moved from 0 to 1)
ticks = list(range(0, max_concepts, 5)); ticks[0]=1
plt.xticks(ticks=ticks)
plt.title("Distribution of Prompt Length by Specifier Clause", fontsize=18)
plt.savefig("plots/spec_clause_length.pdf", bbox_inches='tight')

## Concept Frequency

A qualitative analysis of the concepts present in DiffusionDB. We manually filter the top tokens for stop words, combining subtoken representations into meaningful concepts, before displaying in a WordCloud.

In [None]:
# Show top K tokens in the corpus, visually filter as needed
cts, idxs = tokfreqs.topk(k=100)
# Use the tokenizer's public id -> token conversion instead of the private
# `_convert_id_to_token` helper.
top_tokens = pipe.tokenizer.convert_ids_to_tokens(idxs.tolist())
print("\n".join(f"{tok} :: {ct}" for tok, ct in zip(top_tokens, cts.tolist())))

In [None]:
# Filtered and combined tokens
# Hand-curated mapping of concept -> count, used as word-cloud frequencies below.
# NOTE(review): the key "hyper" appears twice in this literal. Python keeps the LAST
# value (107780), so the dict has 40 entries, not 41. Confirm whether the second
# entry was meant to be a different word or should be removed.
words = {
  "art": 784355,
  "detailed": 714959,
  "artstation": 476150,
  "painting": 438349,
  "portrait": 399555,
  "realistic": 365993,
  "8k": 323039,
  "highly": 319087,
  "lighting": 310602,
  "digital": 295669,
  "intricate": 276934,
  "beautiful": 276268,
  "concept": 256254,
  "trending": 245511,
  "style": 235599,
  "4k": 235164,
  "cinematic": 229357,
  "sharp": 228603,
  "greg rutkowski": 222008,
  "render": 221661,
  "illustration": 221422,
  "focus": 210662,
  "high": 188288,
  "fantasy": 177511,
  "octane": 176801,
  "face":162641,
  "photo":161967,
  "light": 155787,
  "black": 131100,
  "wearing": 130106,
  "dark": 124368,
  "smooth": 120759,
  "white": 119682,
  "hyper": 117479,
  "unreal engine": 114896,
  "background": 114650,
  "elegant": 111326,
  "hair": 110355,
  "full": 109023,
  "mucha": 105940,
  "hyper": 107780,  # NOTE(review): duplicate key, silently overrides the 117479 entry above
}

# Number of distinct concepts (duplicate keys collapse to one)
print(len(words))

In [None]:
# Lay out the word cloud from the curated concept counts
cloud = wc.WordCloud(
    width=500,
    height=300,
    background_color="white",
    min_font_size=10,
    relative_scaling=0.0001,
    colormap="Dark2",
).fit_words(words)

# Export the layout as SVG markup and write it to disk
im = cloud.to_svg(True)
with open('./plots/wordcloud.svg', 'w') as svg_file:
    svg_file.write(im)

In [None]:
# `im` is the SVG markup string returned by `to_svg`, which has no `.save()` method.
# Render the cloud to a PIL image and let Pillow write the PDF (rasterised output).
cloud.to_image().save("plots/wordcloud_freqs.pdf")