# Selecting nice images generated by Stable Diffusion using the CLIP Score
This notebook shows a process to **select the best images generated by stable diffusion** in a text to image setting. There are two potential settings where this might be useful:
1. You have a prompt dataset and just want to **explore** the most promising images and prompts.
2. You have a concrete task and want to find out **which possible prompts** yield the best results for your task.

In [None]:
# Enable IPython's autoreload extension in mode 2 so that edited modules are
# re-imported automatically before code is executed.
%load_ext autoreload
%autoreload 2

In [None]:
!pip install -U sliceguard diffusers[torch] datasets invisible_watermark transformers accelerate safetensors torchmetrics Pillow kaleido

# Step 1: Generate Stable Diffusion Images
This step is simply **generating images** from text using stable diffusion. We use a prompt dataset on the huggingface hub for that. Later we want to filter especially nice images using the CLIP Score metric.

In [None]:
# Some imports you need
from pathlib import Path
import shutil
import uuid
from diffusers import DiffusionPipeline
import torch
import datasets
import pandas as pd

In [None]:
# I just chose the currently most trending prompt dataset on the Hugging Face Hub.
# Replace it with anything that suits your needs better, or with your own
# list of potential prompts.
prompt_dataset = datasets.load_dataset("Gustavosta/Stable-Diffusion-Prompts")

In [None]:
# Instantiate the Stable Diffusion XL base pipeline in half precision and enable
# diffusers' model CPU offloading for it.
# NOTE(review): unlike the refiner below, the base pipeline is loaded without
# `use_safetensors=True` / `variant="fp16"` — confirm whether that is intended.
pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-0.9",
    torch_dtype=torch.float16,
)
pipe.enable_model_cpu_offload()

# Instantiate the refiner pipeline, which later receives the base pipeline's output.
rf_pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-refiner-0.9",
    torch_dtype=torch.float16,
    use_safetensors=True,
    variant="fp16",
)
rf_pipe.enable_model_cpu_offload()

In [None]:
# Generate a bunch of images in the directory "images".
# The prompt dataset is relatively large, so it can make sense to stop early.
# Interrupting the kernel works for that because only `Exception` is caught below.
target_dir = Path("images")
# Always start from an empty output directory: images of a previous run are deleted.
if target_dir.is_dir():
    shutil.rmtree(target_dir)
target_dir.mkdir()

prompts = []
generated_images = []
for record in prompt_dataset["train"]:
    try:
        prompt = record["Prompt"]
        # The base pipeline returns latents, which the refiner turns into the final image.
        latents = pipe(prompt, output_type="latent").images
        image = rf_pipe(prompt=prompt, image=latents).images[0]

        image_path = target_dir / f"{uuid.uuid4()}.png"
        image.save(image_path)
        prompts.append(prompt)
        generated_images.append(str(image_path))

        # Save after every generation so no progress is lost if the run crashes.
        df = pd.DataFrame(data={"image": generated_images, "prompt": prompts})
        df.to_json("sd_dataset.json", orient="records")
    except Exception as exc:
        # Best effort: report the failure and continue with the next prompt.
        # A bare `except:` would also swallow KeyboardInterrupt and make the loop unstoppable.
        print(f"An error occurred while generating image: {exc}")

# Step 2: Generate CLIP Scores for all the examples
The CLIP Score is basically a **correlation between a text and the contents of an image**. It supposedly is highly correlated with human judgement and could thus be used to **filter for promising image generations**.

In [None]:
# Some imports you need for this step
import pandas as pd
import torch
from torchmetrics.multimodal.clip_score import CLIPScore
import PIL
import numpy as np
import torch
from tqdm import tqdm

In [None]:
# Instantiate the CLIP Score metric and move it to the GPU when one is available,
# otherwise keep it on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
metric = CLIPScore(model_name_or_path="openai/clip-vit-large-patch14") # another option: openai/clip-vit-base-patch16
metric = metric.to(device)

In [None]:
# Read the image/prompt dataset that the previous step wrote to "sd_dataset.json"
df = pd.read_json("sd_dataset.json")

In [None]:
# Generate the CLIP score for each image/prompt pair.
# Rows for which scoring raises are recorded as NaN so that `clip_scores`
# stays aligned with the rows of `df`.
clip_scores = []
for _, row in tqdm(df.iterrows(), total=len(df)):
    try:
        # The context manager closes the image file even if something fails below;
        # the pixel data is copied into a numpy array before the file is closed.
        with PIL.Image.open(row["image"]) as img:
            np_img = np.array(img.convert("RGB"))
        clip_score = metric(torch.Tensor(np_img).to(device), row["prompt"]).detach().cpu().numpy()
        clip_scores.append(clip_score)
    except Exception as e:
        # Keep going, but say which image could not be scored and why.
        print(f"Could not score {row['image']}: {e}")
        clip_scores.append(np.nan)

In [None]:
# Attach the CLIP scores as a new column, drop rows containing missing values
# (e.g. prompts that were too long for the metric and therefore got NaN)
# and persist the scored dataset.
score_column = pd.DataFrame(data={"clip_score": clip_scores})
scored_df = pd.concat((df, score_column), axis=1).dropna()
scored_df["prompt"] = scored_df["prompt"].astype("str")
scored_df.to_json("sd_dataset_scored.json", orient="records")

# Step 3: Precompute CLIP embeddings for the image text pairs
We want to identify clusters of images/prompts that are especially well scored (CLIP score). Therefore we have to generate a **text and image representation to cluster on**. We simply also precompute the **CLIP embeddings** for text and images and add them to the dataset.

In [None]:
# Some imports you need
from tqdm import tqdm
import pandas as pd
import torch
from transformers import CLIPProcessor, CLIPModel
from PIL import Image

In [None]:
# Load the scored dataset that the previous step wrote to "sd_dataset_scored.json"
df = pd.read_json("sd_dataset_scored.json")

In [None]:
# Instantiate the CLIP model and the processor of the same checkpoint from the Hugging Face Hub.
# The model is moved to the GPU when one is available.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14", output_hidden_states=True).to(device)
processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")

In [None]:
# Compute one CLIP image embedding and one CLIP text embedding per dataset row.
# Each image/prompt pair is pushed through the model as a batch of size one.
clip_image_embeddings, clip_text_embeddings = [], []
for _, record in tqdm(df.iterrows(), total=len(df)):
    with Image.open(record["image"]) as picture:
        batch = processor(
            text=[record["prompt"]],
            images=[picture],
            return_tensors="pt",
            padding=True,
        ).to(device)
        # Inference only, so no gradients are needed.
        with torch.no_grad():
            result = model(**batch)
    # Take the single item of the batch as a numpy vector.
    clip_image_embeddings.append(result["image_embeds"].detach().cpu().numpy()[0])
    clip_text_embeddings.append(result["text_embeds"].detach().cpu().numpy()[0])

In [None]:
# Save the new dataset. The numpy embeddings are converted to plain lists
# so that they can be stored in the dataframe and serialized to JSON.
df["clip_text_embedding"] = [e.tolist() for e in clip_text_embeddings]
df["clip_image_embedding"] = [e.tolist() for e in clip_image_embeddings]
df.to_json("sd_dataset_scored_embedded.json")

# Step 4: Selection of especially appealing images according to CLIP
We now want to select the most appealing images according to clip score. We here explore two strategies:
1. **Globally** select those clusters with higher than average CLIP score.
2. First detect larger clusters in the data in which we then search for promising images **cluster-by-cluster**.

Note: More on why the second strategy makes sense later.

In [2]:
# Some imports you need
import pandas as pd
import numpy as np
from sliceguard import SliceGuard
from renumics import spotlight
from renumics.spotlight import Image, Embedding

In [3]:
# Load the dataset with scores and embeddings written to "sd_dataset_scored_embedded.json"
df = pd.read_json("sd_dataset_scored_embedded.json")

In [5]:
# Get the text and image embeddings from the dataframe and stack the per-row
# embeddings vertically into one array each.
clip_text_embeddings = np.vstack(df["clip_text_embedding"])
clip_image_embeddings = np.vstack(df["clip_image_embedding"])

## Global selection

In [6]:
# Metric stub: the score is already computed, so the "metric" just averages it.
def return_precomputed_metric(y, y_pred):
    """Return the mean of ``y`` along its first axis.

    ``y_pred`` is accepted only to satisfy the two-argument metric signature
    and is ignored.
    """
    precomputed_mean = y.mean(axis=0)
    return precomputed_mean

In [None]:
# Let sliceguard search the image embedding for clusters that contain at least 2 images
# and whose CLIP score is at least 4.5 OVER the average CLIP score of the whole dataset.
#
# Note: There is no explicit interface for using a precomputed metric. Instead, the metric
# column is supplied for both y and y_pred and the metric function returns the mean of y.
#
# Note 2: metric_mode is set to "min" because we want to search for especially good images.
# Normally the natural choice would be "max", which yields images that are especially bad.
sg = SliceGuard()
sg.find_issues(
    df,
    ["clip_image_embedding"],
    "clip_score",  # y: the precomputed metric column
    "clip_score",  # y_pred: the same column again
    return_precomputed_metric,
    metric_mode="min",
    min_support=2,
    min_drop=4.5,
    precomputed_embeddings={"clip_image_embedding": clip_image_embeddings},
)

In [None]:
# Show the results in Renumics Spotlight, rendering the "image" column as images
# and the "clip_image_embedding" column as an embedding.
sg.report(spotlight_dtype={"image": Image, "clip_image_embedding": Embedding})

In [None]:
# Same search as for the image embeddings, but clustering on the text embeddings:
# clusters need at least 2 samples and a CLIP score at least 4.5 over the dataset average.
sg = SliceGuard()
sg.find_issues(
    df,
    ["clip_text_embedding"],
    "clip_score",  # y: the precomputed metric column
    "clip_score",  # y_pred: the same column again
    return_precomputed_metric,
    metric_mode="min",
    min_support=2,
    min_drop=4.5,
    precomputed_embeddings={"clip_text_embedding": clip_text_embeddings},
)

In [None]:
# Show the results in Renumics Spotlight, rendering the "image" column as images
# and the "clip_text_embedding" column as an embedding.
sg.report(spotlight_dtype={"image": Image, "clip_text_embedding": Embedding})

**So, does this work well? No! The CLIP Score metric is extremely biased towards people portraits and especially well-known concepts like the faces of prominent personalities.**

**What can we do about it? We can probably first compute clusters in the data to detect some sort of "categories" and then apply the search for significantly better clusters for each category (could be such categories as landscapes, people portraits, things, ...)**

**We implemented that below!**

**Note also that the selection based on text embeddings seems to yield slightly more consistent results in our case, meaning the clusters are less prone to containing outliers in the CLIP score metric.**

## Category-wise, adaptive selection

In [7]:
# Choose the single embedding type used for the category-wise selection.
# Valid values: "clip_image_embeddings" or "clip_text_embeddings".
EMBEDDING_TYPE = "clip_image_embeddings"
if EMBEDDING_TYPE not in ("clip_image_embeddings", "clip_text_embeddings"):
    raise RuntimeError("No valid choice for embedding type.")
embeddings = clip_image_embeddings if EMBEDDING_TYPE == "clip_image_embeddings" else clip_text_embeddings

In [8]:
# An additional import. Note that you have to run the above section as well.
from hnne import HNNE

In [9]:
# Detect clusters in the data. The granularity is chosen by finding the first clustering that
# contains more than 3 and fewer than 25 clusters. If no level matches for your data, the
# assertion below fails; just shift the limits then.
hnne = HNNE(metric="euclidean")
projection = hnne.fit_transform(embeddings)
df["projection_x"] = projection[:, 0]
df["projection_y"] = projection[:, 1]

# Reverse the order of the hierarchy levels (columns) and of the matching sizes.
partitions = np.flip(hnne.hierarchy_parameters.partitions, axis=1)
partition_sizes = np.flip(np.array(hnne.hierarchy_parameters.partition_sizes))
print(partition_sizes)

# Store the labels of every level as their own dataframe column.
for partition_idx, level_labels in enumerate(partitions.T):
    df[f"clustering_{partition_idx}"] = level_labels

# Pick the first level whose number of clusters lies strictly between 3 and 25.
chosen_partition_level = next(
    (level for level, size in enumerate(partition_sizes) if 3 < size < 25),
    None,
)
assert chosen_partition_level is not None
clustering_partition = partitions[:, chosen_partition_level]

[  3  13 160]


In [10]:
# Apply sliceguard cluster-by-cluster to find the most promising images per "category".
# The intuition is that this mitigates some bias present in the CLIP Score metric.
# "selection_group" receives the category (cluster) of a selected row and "selection"
# a running id of the selected slice; rows that are never selected keep -1 in both columns.
df["selection_group"] = -1
df["selection"] = -1
current_issue_idx = 0
for cluster_idx in np.unique(clustering_partition):
    # Restrict both the embeddings and the dataframe to the rows of the current category.
    cluster_embeddings = embeddings[clustering_partition==cluster_idx]
    cluster_df = df[clustering_partition==cluster_idx]
    
    # "clip_embedding" is not one of the columns created in this notebook; its values are
    # supplied through precomputed_embeddings. min_drop adapts to the category by using
    # the standard deviation of its CLIP scores.
    sg = SliceGuard()
    issue_df = sg.find_issues(cluster_df, ["clip_embedding"],
                   "clip_score",
                   "clip_score",
                   return_precomputed_metric,
                   metric_mode="min",
                   min_support=2,
                   min_drop=cluster_df["clip_score"].std(),
                  precomputed_embeddings={"clip_embedding": cluster_embeddings})
    # -1 is excluded here — presumably sliceguard's marker for "no issue"; TODO confirm.
    issue_identifiers = np.setdiff1d(issue_df["issue"].unique(), [-1])
    for issue_identifier in issue_identifiers:
        # NOTE(review): assumes issue_df keeps the index labels of cluster_df so that
        # they can be used to address rows of df — confirm against sliceguard.
        issue_indices = issue_df[issue_df["issue"] == issue_identifier].index
        df.loc[issue_indices, "selection_group"] = cluster_idx
        df.loc[issue_indices, "selection"] = current_issue_idx
        current_issue_idx += 1


The overall metric value is 29.57892399299634
Using 2 as minimum support for determining problematic clusters.
Using 3.7496109339632384 as minimum drop for determining problematic clusters.
Identified 2 problematic slices.
The overall metric value is 30.998585888159816
Using 2 as minimum support for determining problematic clusters.
Using 3.659101791567243 as minimum drop for determining problematic clusters.
Identified 1 problematic slices.
The overall metric value is 31.632761600405814
Using 2 as minimum support for determining problematic clusters.
Using 3.702847682845617 as minimum drop for determining problematic clusters.
Identified 6 problematic slices.
The overall metric value is 30.551769518976187
Using 2 as minimum support for determining problematic clusters.
Using 3.6821321892904457 as minimum drop for determining problematic clusters.
Identified 3 problematic slices.
The overall metric value is 30.245432414699994
Using 2 as minimum support for determining problematic clust

In [None]:
# Show the results in Renumics Spotlight. The column "selection" contains the found clusters,
# so just use it for browsing the results.
# The dataframe has no "clip_embedding" column (that name is only used as a feature label
# for sliceguard), so the dtype hints target the embedding columns that actually exist.
spotlight.show(
    df,
    dtype={
        "image": Image,
        "clip_image_embedding": Embedding,
        "clip_text_embedding": Embedding,
    },
)

In [11]:
# Visualize the image clusters in a scatterplot on the 2D projection.
# This was used to create a visual for the blogpost. You don't necessarily need it.
# One PNG file ("slice_00.png", "slice_01.png", ...) is written per selected cluster.
import math
import plotly.express as px
import plotly.graph_objects as go
# Imported under an alias so that the Spotlight `Image` dtype imported earlier is not shadowed.
from PIL import Image as PILImage

# Color by the clustering level that was actually chosen instead of a hard-coded level.
color_column = f"clustering_{chosen_partition_level}"
# Radius of the circle on which the thumbnails are placed around the cluster center,
# and the edge length of each thumbnail (both in projection coordinates).
THUMBNAIL_RADIUS = 1.6
THUMBNAIL_SIZE = 3

groups = np.setdiff1d(df["selection_group"].unique(), [-1])

cluster_idx = 0
for group in groups:
    # `groups` is derived from the values present in df, so every group has at least one row.
    group_samples = df[df["selection_group"] == group]
    clusters = np.setdiff1d(group_samples["selection"].unique(), [-1])
    for cluster in clusters:
        cluster_samples = df[df["selection"] == cluster]

        fig = px.scatter(df, x="projection_x", y="projection_y", color=color_column)
        fig.update_layout(
            showlegend=False,
            coloraxis_showscale=False,
            xaxis=dict(visible=False),
            yaxis=dict(visible=False),
        )

        # Arrange the cluster's images evenly on a circle around the cluster center.
        x_center = cluster_samples["projection_x"].mean()
        y_center = cluster_samples["projection_y"].mean()
        angle_step = (2 * math.pi) / len(cluster_samples)
        for position, (_, row) in enumerate(cluster_samples.iterrows(), start=1):
            angle = angle_step * position
            # The image is added to the figure while the file is open; the context
            # manager then closes the file handle.
            with PILImage.open(row["image"]) as img:
                fig.add_layout_image(
                    x=x_center + THUMBNAIL_RADIUS * math.cos(angle),
                    y=y_center + THUMBNAIL_RADIUS * math.sin(angle),
                    source=img,
                    xref="x",
                    yref="y",
                    sizex=THUMBNAIL_SIZE,
                    sizey=THUMBNAIL_SIZE,
                    xanchor="center",
                    yanchor="middle",
                )

        fig.write_image(f"slice_{cluster_idx:0>2}.png", scale=2)
        # Call fig.show() here instead if you want to inspect the figure inline.
        cluster_idx += 1