<a href="https://colab.research.google.com/github/Troyanovsky/stable_diffusion_prompt_optimizer/blob/main/Prompt_Optimizer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Stable Diffusion Prompt Optimizer

This Colab notebook is an attempt to automatically optimize Stable Diffusion prompts.

## Rationale

The rationale is that a good prompt needs objects, an art movement, an art medium, artists, and flavor modifiers.

The closer these tags are to one another, the more harmonious they are, and the higher the quality of the images Stable Diffusion tends to produce. For example, an Impressionist painting (an art movement) pairs well with Claude Monet, Edouard Manet, and Pierre-Auguste Renoir (artists), oil painting (medium), and unblended color and natural light (flavors).

## Process
To automate this process, I did the following:
1. Collected lists of art movement, artist, art medium, and flavor tags that work in Stable Diffusion.
2. Turned them into vector embeddings using pre-trained embeddings.
3. Compared them with the user's existing tags to find art movements, artists, art mediums, and flavors to add to the prompt.
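
The three steps above can be sketched in a few lines. This is only an illustration under stated assumptions: the two-dimensional vectors in `TOY_EMBEDDINGS` are made up and stand in for a pre-trained model, and the helper names `cosine` and `mean_vector` are hypothetical rather than the notebook's own functions.

```python
import math

# Hypothetical 2-D embeddings standing in for a pre-trained model.
TOY_EMBEDDINGS = {
    "water lilies": (0.9, 0.2),
    "soft light": (0.8, 0.3),
    "claude monet": (0.85, 0.25),
    "h. r. giger": (0.1, 0.9),
    "oil painting": (0.7, 0.4),
    "3d render": (0.2, 0.8),
}

# Step 1: candidate tag lists per category (tiny stand-ins for the real lists).
candidates = {
    "artists": ["claude monet", "h. r. giger"],
    "mediums": ["oil painting", "3d render"],
}

def cosine(a, b):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.hypot(*a) * math.hypot(*b))

def mean_vector(vectors):
    """Element-wise mean of a non-empty list of vectors."""
    return tuple(sum(column) / len(vectors) for column in zip(*vectors))

# Step 2: embed the user's tags and average them into one query vector.
user_tags = ["water lilies", "soft light"]
query = mean_vector([TOY_EMBEDDINGS[tag] for tag in user_tags])

# Step 3: pick the closest candidate in each category.
suggestions = {
    category: max(tags, key=lambda tag: cosine(query, TOY_EMBEDDINGS[tag]))
    for category, tags in candidates.items()
}
print(suggestions)  # {'artists': 'claude monet', 'mediums': 'oil painting'}
```

The notebook itself samples at random from the few closest candidates instead of always taking the single best match, which adds variety between runs.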

## How to use
If you only need to optimize a prompt, run the setup cells under Pre-req and then Section 2.

If you want to create your own tag lists and embeddings, refer to Section 1.

# Pre-req

In [1]:
!pip install spacy

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [2]:
!python -m spacy download en_core_web_md

2023-04-25 16:20:53.233388: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting en-core-web-md==3.5.0
  Downloading https://github.com/explosion/spacy-models/releases/download/en_core_web_md-3.5.0/en_core_web_md-3.5.0-py3-none-any.whl (42.8 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 42.8/42.8 MB 11.5 MB/s eta 0:00:00
Installing collected packages: en-core-web-md
Successfully installed en-core-web-md-3.5.0
✔ Download and installation successful
You can now load the package via spacy.load('en_core_web_md')


In [3]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [4]:
# Path to your Google Drive folder.
# Download the files in the tags folder of the GitHub repository,
# put them into a Google Drive folder,
# and replace the path below with that folder.
folder_path = "/content/drive/MyDrive/stable_diffusion"

In [5]:
import spacy
import pickle

# Load the spaCy model
nlp = spacy.load("en_core_web_md")

In [6]:
import os
import glob

def load_directory(directory_path):
    """
    Loads all the text files in a directory and returns their contents as a dictionary with the key being the
    file name without extension and the value being the file content.

    Parameters:
        directory_path (str): The path to the directory containing the text files.

    Returns:
        A dictionary where each key is the file name without extension and each value is the contents of a text file.
    """
    # Navigate to the specified directory
    %cd {directory_path}
    
    # Use glob to get a list of all the text files in the directory
    text_files = glob.glob("*.txt")
    
    # Read the contents of each file into memory and store them in a dictionary
    data = {}
    for filename in text_files:
        # Get the file name without extension
        file_name = os.path.splitext(filename)[0]
        
        # Read the contents of the file
        with open(filename) as file:
            file_content = file.read()
        
        # Add the file content to the dictionary with the file name as the key
        data[file_name] = file_content
    
    return data


In [7]:
tags_data = load_directory(folder_path)

/content/drive/MyDrive/stable_diffusion


In [8]:
# Define a function to load the tags data and split into lists
def load_and_split_tags(tags_data):
    tag_lists = {}
    for key, value in tags_data.items():
        # Split the string of potential tags into a list
        tag_list = [tag.strip() for tag in value.split("\n") if tag.strip()]
        # If the resulting tag list is not empty, add it to the dict
        if tag_list:
            tag_lists[key] = tag_list
    return tag_lists

tag_lists = load_and_split_tags(tags_data)
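
The two helpers above could also be combined into a single function that reads the files without changing the working directory. The following is a minimal sketch, assuming UTF-8 text files with one tag per line; `load_tag_lists` is a hypothetical name and the file contents in the demo are made up.

```python
import tempfile
from pathlib import Path

def load_tag_lists(directory):
    """Map each .txt file's stem to its non-empty, stripped lines."""
    tag_lists = {}
    for path in sorted(Path(directory).glob("*.txt")):
        lines = path.read_text(encoding="utf-8").splitlines()
        tags = [line.strip() for line in lines if line.strip()]
        if tags:  # skip files that contain no usable tags
            tag_lists[path.stem] = tags
    return tag_lists

# Demo with a throwaway directory and made-up file contents.
with tempfile.TemporaryDirectory() as tmp:
    Path(tmp, "mediums.txt").write_text("oil painting\n\n  watercolor  \n", encoding="utf-8")
    Path(tmp, "empty.txt").write_text("\n\n", encoding="utf-8")
    demo = load_tag_lists(tmp)

print(demo)  # {'mediums': ['oil painting', 'watercolor']}
```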

## Section 1: Setting Up Embeddings

In [None]:
# ONLY RUN ONCE THE FIRST TIME TO CREATE EMBEDDINGS
# Loop over each tag list, convert tags to vectors, and save with pickle
for key, tag_list in tag_lists.items():
    # Convert each tag to its vector representation using spaCy
    tag_vectors = []
    for tag in tag_list:
        doc = nlp(tag)
        tag_vectors.append(doc.vector)
    
    # Save the list of tag vectors as a binary file with pickle
    file_path = os.path.join(folder_path, f"{key}_vector.pkl")
    with open(file_path, "wb") as f:
        pickle.dump(tag_vectors, f)
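
The lookup functions in Section 2 map a vector's index back to the matching entry in `tag_lists`, so the pickled vectors stay valid only while the text files keep the same order and length. One possible safeguard, sketched here with made-up two-dimensional vectors, is to pickle the tags alongside their vectors. The `{"tags": ..., "vectors": ...}` layout is an assumption for illustration and is not the format this notebook writes.

```python
import os
import pickle
import tempfile

# Made-up tags and vectors; real vectors would come from nlp(tag).vector.
tags = ["oil painting", "watercolor"]
vectors = [[0.7, 0.4], [0.6, 0.5]]

with tempfile.TemporaryDirectory() as folder:
    file_path = os.path.join(folder, "mediums_vector.pkl")

    # Store tags and vectors together so index i always refers to the same tag.
    with open(file_path, "wb") as f:
        pickle.dump({"tags": tags, "vectors": vectors}, f)

    with open(file_path, "rb") as f:
        loaded = pickle.load(f)

# A cheap consistency check before using the embeddings.
assert len(loaded["tags"]) == len(loaded["vectors"])
print(loaded["tags"][1], loaded["vectors"][1])  # watercolor [0.6, 0.5]
```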

## Section 2: Using the Optimizer

In [9]:
# Load artists_vector.pkl into variable artists_vector
with open(os.path.join(folder_path, "artists_vector.pkl"), "rb") as f:
    artists_vector = pickle.load(f)

# Load flavors_vector.pkl into variable flavors_vector
with open(os.path.join(folder_path, "flavors_vector.pkl"), "rb") as f:
    flavors_vector = pickle.load(f)

# Load mediums_vector.pkl into variable mediums_vector
with open(os.path.join(folder_path, "mediums_vector.pkl"), "rb") as f:
    mediums_vector = pickle.load(f)

# Load movements_vector.pkl into variable movements_vector
with open(os.path.join(folder_path, "movements_vector.pkl"), "rb") as f:
    movements_vector = pickle.load(f)

In [10]:
def convert_to_tags(user_prompt):
    # Split the prompt by commas, strip surrounding whitespace, and drop empty tags
    tags = [tag.strip() for tag in user_prompt.split(",") if tag.strip()]
    return tags

In [11]:
from sklearn.metrics.pairwise import cosine_similarity
import random

def find_similar_artists(input_vector):
    most_similar = []
    
    # Compute the cosine similarity between the input vector and every artist vector
    sims_artists = cosine_similarity([input_vector], artists_vector)
    sims_artists = sims_artists[0]
    # Randomly choose two tags from the top 5 artists
    top_artists = [tag_lists['artists'][i] for i in sims_artists.argsort()[-5:][::-1]]
    chosen_artists = random.sample(top_artists, k=2)
    most_similar.extend(chosen_artists)

    return most_similar

def find_similar_flavors(input_vector):
    most_similar = []

    # Compute the cosine similarity between the input vector and every flavor vector
    sims_flavors = cosine_similarity([input_vector], flavors_vector)
    sims_flavors = sims_flavors[0]
    # Randomly choose two tags from the top 5 flavors
    top_flavors = [tag_lists['flavors'][i] for i in sims_flavors.argsort()[-5:][::-1]]
    chosen_flavors = random.sample(top_flavors, k=2)
    most_similar.extend(chosen_flavors)

    return most_similar
    
def find_similar_mediums(input_vector):
    most_similar = []

    # Compute the cosine similarity between the input vector and every medium vector
    sims_medium = cosine_similarity([input_vector], mediums_vector)
    sims_medium = sims_medium[0]
    # Randomly choose one tag from the top 3 medium
    top_mediums = [tag_lists['mediums'][i] for i in sims_medium.argsort()[-3:][::-1]]
    chosen_medium = random.choice(top_mediums)
    most_similar.append(chosen_medium)

    return most_similar
    
def find_similar_movements(input_vector):
    most_similar = []

    # Compute the cosine similarity between the input vector and every movement vector
    sims_movements = cosine_similarity([input_vector], movements_vector)
    sims_movements = sims_movements[0]
    # Randomly choose one tag from the top 3 movements
    top_movements = [tag_lists['movements'][i] for i in sims_movements.argsort()[-3:][::-1]]
    chosen_movement = random.choice(top_movements)
    most_similar.append(chosen_movement)
    
    return most_similar
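
The four functions above differ only in which vectors and tags they use, how many top matches they keep, and how many tags they sample. They could be folded into one helper along the following lines. This is a sketch, not the notebook's implementation: `sample_similar` is a hypothetical name, the cosine similarity is written out in plain Python instead of calling scikit-learn, and the demo data is made up.

```python
import math
import random

def cosine(a, b):
    """Cosine similarity between two equal-length, non-zero vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.hypot(*a) * math.hypot(*b))

def sample_similar(input_vector, vectors, tags, top_k, n_pick, rng=random):
    """Pick n_pick tags at random from the top_k most similar to input_vector."""
    sims = [cosine(input_vector, vector) for vector in vectors]
    # Indices of the top_k highest similarities, best first
    ranked = sorted(range(len(sims)), key=lambda i: sims[i], reverse=True)[:top_k]
    top_tags = [tags[i] for i in ranked]
    return rng.sample(top_tags, k=min(n_pick, len(top_tags)))

# Made-up tags and 2-D vectors for illustration.
demo_tags = ["a", "b", "c", "d"]
demo_vectors = [(1.0, 0.0), (0.9, 0.1), (0.0, 1.0), (0.5, 0.5)]
picked = sample_similar((1.0, 0.05), demo_vectors, demo_tags,
                        top_k=2, n_pick=1, rng=random.Random(0))
print(picked)  # one of the two closest tags, "a" or "b"
```

Passing a seeded `random.Random` instance makes the sampling reproducible, which helps when comparing prompts across runs.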

In [12]:
def categorizeUserTags(user_tags):
    user_artists = []
    user_flavors = []
    user_mediums = []
    user_movements = []
    other_tags = []
    artists_list = [tag.lower() for tag in tag_lists["artists"]]
    flavors_list = [tag.lower() for tag in tag_lists["flavors"]]
    mediums_list = [tag.lower() for tag in tag_lists["mediums"]]
    movements_list = [tag.lower() for tag in tag_lists["movements"]]

    for tag in user_tags:
        # The reference lists are lowercased, so compare case-insensitively
        tag_lower = tag.lower()
        if tag_lower in artists_list:
            user_artists.append(tag)
        elif tag_lower in flavors_list:
            user_flavors.append(tag)
        elif tag_lower in mediums_list:
            user_mediums.append(tag)
        elif tag_lower in movements_list:
            user_movements.append(tag)
        else:
            other_tags.append(tag)
    
    return user_artists, user_flavors, user_mediums, user_movements, other_tags
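
The same categorization can be written with sets, which makes each membership test constant-time on large tag lists. The sketch below is one possible variant under stated assumptions: `categorize` is a hypothetical name, it returns a dictionary instead of five separate lists, and it assumes no category is itself called `"other"`.

```python
def categorize(user_tags, tag_lists):
    """Group user tags by the first category whose list contains them."""
    # Lowercased sets for case-insensitive, constant-time lookups
    lookup = {
        category: {tag.lower() for tag in tags}
        for category, tags in tag_lists.items()
    }
    grouped = {category: [] for category in tag_lists}
    grouped["other"] = []
    for tag in user_tags:
        for category, known in lookup.items():
            if tag.lower() in known:
                grouped[category].append(tag)
                break
        else:  # no category matched
            grouped["other"].append(tag)
    return grouped

# Made-up lists for illustration.
demo_lists = {"artists": ["Claude Monet"], "mediums": ["oil painting"]}
result = categorize(["claude monet", "Oil Painting", "a pond"], demo_lists)
print(result)
# {'artists': ['claude monet'], 'mediums': ['Oil Painting'], 'other': ['a pond']}
```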

In [13]:
def getAvgTagsVector(tags_list):
    tag_vectors = [nlp(tag).vector for tag in tags_list]

    avg_vector = sum(tag_vectors) / len(tag_vectors)

    return avg_vector
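
Averaging divides by the number of tags, so an empty tag list would raise a `ZeroDivisionError`. A small guard makes that failure explicit. The sketch below uses plain Python lists in place of spaCy vectors, and `average_vector` is a hypothetical name.

```python
def average_vector(vectors):
    """Element-wise mean of equal-length vectors; raises on empty input."""
    if not vectors:
        raise ValueError("need at least one vector to average")
    return [sum(column) / len(vectors) for column in zip(*vectors)]

avg = average_vector([[1.0, 2.0], [3.0, 4.0]])
print(avg)  # [2.0, 3.0]
```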

In [14]:
def process_tags(user_prompt):
    original_tags = convert_to_tags(user_prompt)
    # An empty prompt yields no tags, or only empty strings
    if not any(original_tags):
        return "Prompt too short, please consider adding more details like art style, art medium, lighting, mood, etc."

    avg_vector = getAvgTagsVector(original_tags)

    tags_to_add = []

    user_tags = original_tags.copy()[1:]

    user_artists, user_flavors, user_mediums, user_movements, other_tags = categorizeUserTags(original_tags.copy())

    if len(user_artists) < 2:
        most_similar_artists = find_similar_artists(avg_vector)
        tags_to_add.extend(most_similar_artists)
    if len(user_flavors) < 2:
        most_similar_flavors = find_similar_flavors(avg_vector)
        tags_to_add.extend(most_similar_flavors)    
    if len(user_mediums) < 1:
        most_similar_mediums = find_similar_mediums(avg_vector)
        tags_to_add.extend(most_similar_mediums)    
    if len(user_movements) < 1:
        most_similar_movements = find_similar_movements(avg_vector)
        tags_to_add.extend(most_similar_movements)


    processed_tags = original_tags[:1] + user_artists + user_flavors + user_mediums + user_movements + other_tags + tags_to_add
    
    # Deduplicate the tags while preserving their order
    seen = set()
    deduplicated_tags = []
    for tag in processed_tags:
        if tag not in seen:
            deduplicated_tags.append(tag)
            seen.add(tag)
    
    # Construct the new prompt
    new_prompt = ','.join(deduplicated_tags)

    return new_prompt
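
The order-preserving deduplication at the end of `process_tags` can also be expressed with `dict.fromkeys`, since dictionaries keep insertion order. This is a sketch with made-up tags. Like the loop above, it is case-sensitive, so `ink wash` and `Ink Wash` count as different tags.

```python
tags = ["mountains", "ink wash", "mountains", "Ink Wash", "mist"]

# dict preserves insertion order, so this keeps the first occurrence of each tag.
deduplicated = list(dict.fromkeys(tags))
print(",".join(deduplicated))  # mountains,ink wash,Ink Wash,mist
```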

## Inference

In [19]:
user_prompt = "mountains and rivers, traditional chinese painting"

In [20]:
new_prompt = process_tags(user_prompt)
print(new_prompt)

mountains and rivers,traditional chinese painting,national geographic,Theophanes the Greek,a renaissance painting,arts and crafts movement
