In [None]:
#data
#https://www.kaggle.com/datasets/alexandreteles/diffusiondb-metadata
#https://www.kaggle.com/datasets/dschettler8845/diffusiondb-2m-part-0001-to-0100-of-2000
#https://www.kaggle.com/datasets/dschettler8845/diffusiondb-2m-part-0101-to-0200-of-2000
#https://www.kaggle.com/datasets/dschettler8845/diffusiondb-2m-part-0201-to-0300-of-2000
#https://www.kaggle.com/datasets/dschettler8845/diffusiondb-2m-part-0301-to-0400-of-2000
#https://www.kaggle.com/datasets/dschettler8845/diffusiondb-2m-part-0401-to-0500-of-2000
#https://www.kaggle.com/datasets/dschettler8845/diffusiondb-2m-part-0501-to-0600-of-2000
#https://www.kaggle.com/datasets/dschettler8845/diffusiondb-2m-part-0601-to-0700-of-2000
#https://www.kaggle.com/datasets/dschettler8845/diffusiondb-2m-part-0701-to-0800-of-2000
#https://www.kaggle.com/datasets/dschettler8845/diffusiondb-2m-part-0801-to-0900-of-2000
#https://www.kaggle.com/datasets/dschettler8845/diffusiondb-2m-part-0901-to-1000-of-2000

#https://www.kaggle.com/datasets/dschettler8845/diffusiondb-2m-part-1001-to-1100-of-2000
#https://www.kaggle.com/datasets/dschettler8845/diffusiondb-2m-part-1101-to-1200-of-2000
#https://www.kaggle.com/datasets/dschettler8845/diffusiondb-2m-part-1201-to-1300-of-2000
#https://www.kaggle.com/datasets/dschettler8845/diffusiondb-2m-part-1301-to-1400-of-2000
#https://www.kaggle.com/datasets/dschettler8845/diffusiondb-2m-part-1401-to-1500-of-2000
#https://www.kaggle.com/datasets/dschettler8845/diffusiondb-2m-part-1501-to-1600-of-2000
#https://www.kaggle.com/datasets/dschettler8845/diffusiondb-2m-part-1601-to-1700-of-2000
#https://www.kaggle.com/datasets/dschettler8845/diffusiondb-2m-part-1701-to-1800-of-2000
#https://www.kaggle.com/datasets/dschettler8845/diffusiondb-2m-part-1801-to-1900-of-2000
#https://www.kaggle.com/datasets/dschettler8845/diffusiondb-2m-part-1901-to-2000-of-2000
#https://www.kaggle.com/datasets/inversion/sentence-transformers-222

## Cleansing DiffusionDB using vector search

Approaches to cleansing DiffusionDB-2M have been shared in public notebooks and discussions:

https://www.kaggle.com/code/shoheiazuma/diffusiondb-data-cleansing/notebook
https://www.kaggle.com/competitions/stable-diffusion-image-to-prompts/discussion/398529

This notebook cleanses the data in two steps: rule-based filtering, followed by filtering based on the similarity of prompt embedding vectors. To evaluate vector similarity I use faiss, a vector search library commonly used for recommendation and similar-image search.

In [1]:
%pip install faiss-gpu

Collecting faiss-gpu
  Downloading faiss_gpu-1.7.2-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (85.5 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m85.5/85.5 MB[0m [31m13.1 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
[?25hInstalling collected packages: faiss-gpu
Successfully installed faiss-gpu-1.7.2
[0mNote: you may need to restart the kernel to use updated packages.


In [2]:
import sys
import re
import faiss
import torch
import numpy as np
import polars as pl
from pathlib import Path
import torch.nn.functional as F
from tqdm.notebook import tqdm
from sklearn.metrics.pairwise import cosine_similarity

sys.path.append("/kaggle/input/sentence-transformers-222/sentence-transformers")
from sentence_transformers import SentenceTransformer

## Rule-Based Filtering

In [3]:
def check_string(string: str) -> bool:
    # Checks if the given string contains any character other than alphanumeric characters, comma, dot, hyphen or whitespace
    return bool(re.search(r'[^A-Za-z0-9,.\\-\\s]', string))

In [4]:
# Load data from a Parquet file
# For the purpose of illustration, the amount of data will be reduced
pldf = pl.read_parquet("/kaggle/input/diffusiondb-metadata/metadata.parquet", columns=['image_name', 'prompt', 'width', 'height'])

# Select only those images whose width and height fall between 256 and 768 pixels
pldf = pldf.filter(pl.col("width").is_between(256, 768) & pl.col("height").is_between(256, 768))

# Select only those prompts that have five or more words 
pldf = pldf.filter(pl.col("prompt").str.split(" ").apply(lambda x: len(x)>=5))

# Select only those prompts that are not blank, NULL, null, or NaN
pldf = pldf.filter(~pl.col("prompt").str.contains('^(?:\s*|NULL|null|NaN)$'))


pldf = pldf.filter(pl.col("prompt").apply(check_string))
pldf.glimpse()

Rows: 1561961
Columns: 4
$ image_name <str> 2217ccbd-a1c6-47ac-9a2d-79649727c834.png, c78807b7-d55a-4a2d-a6b6-9192b18941ad.png, dc71658a-5e4b-4dca-861a-e1535510348b.png, 48eb7e17-a3cf-4eb8-96a9-d8e3e23fa1af.png, 601d9792-eccd-4850-97a7-edbe91d3464c.png, 3c586acb-14dc-43df-8900-954c336f01b3.png, a5ec307e-7e7b-4740-ad70-9bdb6f417bd1.png, 2919b048-6f68-4ac7-a6d5-060d827abb77.png, 986a21f0-2ad8-4f9f-8e49-7f7db6c80cdc.png, 3c835fdc-9047-4298-ac8a-7461f5490132.png
$ prompt     <str> a portrait of a female robot made from code, very intricate details, octane render, 8 k, trending on artstation , a portrait of a female robot made from a cloud of images being very grateful to the creator, very intricate details, futuristic steampunk, octane render, 8 k, trending on artstation , only memories remain, trending on artstation , dream swimming pool with nobody , a dog doing weights. epic oil painting. , a dog doing weights on fire. epic oil painting. , yoji shinkawa painting of a stylish sniper demo

In [5]:
# Optionally subsample for a quick illustrative run: set N_SAMPLES to e.g. 10_000.
# None keeps every row (the full cleansing run).
N_SAMPLES = None
if N_SAMPLES is not None:
    pldf = pldf.head(N_SAMPLES)
pldf.head()

image_name,prompt,width,height
str,str,u16,u16
"""2217ccbd-a1c6-…","""a portrait of …",512,512
"""c78807b7-d55a-…","""a portrait of …",512,512
"""dc71658a-5e4b-…","""only memories …",512,512
"""48eb7e17-a3cf-…","""dream swimming…",512,512
"""601d9792-eccd-…","""a dog doing we…",512,768
"""3c586acb-14dc-…","""a dog doing we…",512,768
"""a5ec307e-7e7b-…","""yoji shinkawa …",512,768
"""2919b048-6f68-…","""a beautiful pa…",512,512
"""986a21f0-2ad8-…","""male king arth…",512,704
"""3c835fdc-9047-…","""frontal portra…",512,512


## Vectorize using SentenceTransformers

In [6]:
# Embed every remaining prompt with the local all-MiniLM-L6-v2 checkpoint.
# convert_to_tensor=True returns a torch tensor (encoded on device="cuda").
model = SentenceTransformer("/kaggle/input/sentence-transformers-222/all-MiniLM-L6-v2")
prompt_texts = pldf["prompt"].to_numpy()
vector = model.encode(
    prompt_texts,
    batch_size=512,
    show_progress_bar=True,
    device="cuda",
    convert_to_tensor=True,
)

Batches:   0%|          | 0/3051 [00:00<?, ?it/s]

## Similarity filtering using vector search

In [7]:
# Tuning knobs for the near-duplicate search.
threshold = 0.80      # similarity at or above which two prompts count as near-duplicates
n_neighbors = 1000    # number of nearest neighbours requested per query
batch_size = 1000     # queries per search call; searching everything at once may exhaust resources
similar_vectors = []  # collects (row position, near-duplicate positions) pairs

In [8]:
# Build a faiss IndexFlatIP index. 'IP' is the inner product, which equals cosine
# similarity when both the indexed and the query vectors are L2-normalised.
# The dimension is read from the embeddings instead of hard-coding 384, so a
# different SentenceTransformer checkpoint works unchanged.
embedding_dim = vector.shape[1]
index = faiss.IndexFlatIP(embedding_dim)

# faiss expects a C-contiguous float32 matrix; normalise each row before adding.
index.add(np.ascontiguousarray(F.normalize(vector, p=2, dim=1).cpu().numpy(), dtype=np.float32))

In [None]:
# Normalise the queries exactly like the indexed vectors so that inner product equals
# cosine similarity, and copy them to host memory once rather than once per batch.
query_vectors = np.ascontiguousarray(F.normalize(vector, p=2, dim=1).cpu().numpy(), dtype=np.float32)

# Reset so that re-running this cell does not append duplicate entries.
similar_vectors = []
for start in tqdm(range(0, len(query_vectors), batch_size)):
    batch = query_vectors[start:start + batch_size]
    # Neighbourhood search based on cosine similarity.
    similarities, indices = index.search(batch, n_neighbors)

    for offset in range(similarities.shape[0]):
        position = start + offset
        close_vectors = indices[offset, similarities[offset] >= threshold]
        # Record only neighbours positioned after this prompt. This excludes the prompt
        # itself (and faiss's -1 padding) and keeps the first prompt of each near-duplicate
        # group; recording every neighbour would drop both members of each similar pair.
        close_vectors = close_vectors[close_vectors > position]
        similar_vectors.append((position, close_vectors))


  0%|          | 0/1562 [00:00<?, ?it/s]

### Drop near-duplicate prompts

In [None]:
# The positions stored in similar_vectors refer to rows of the frame that was vectorised,
# so refuse to run against an already-filtered pldf (e.g. when this cell is executed twice).
assert len(pldf) == len(similar_vectors), (
    "pldf does not match the vectorised prompts; re-run the notebook from the vectorisation step"
)

# Union of all flagged positions; np.concatenate cannot take an empty list.
neighbour_arrays = [close for _, close in similar_vectors]
drop_positions = (
    np.unique(np.concatenate(neighbour_arrays)) if neighbour_arrays else np.empty(0, dtype=np.int64)
)

# Tag rows with their position, drop the flagged ones, then remove the helper column.
pldf = (
    pldf.with_columns(pl.Series(values=list(range(len(pldf))), name="index"))
    .filter(~pl.col("index").is_in(drop_positions.tolist()))
    .drop("index")
)

In [None]:
# Rewrite each bare image file name as the full path inside whichever of the twenty
# image datasets (parts 0001-0100 ... 1901-2000) contains it; unmatched names are kept as-is.
for part_start in tqdm(range(1, 2000, 100)):
    part_end = part_start + 99
    image_dir = Path(f"/kaggle/input/diffusiondb-2m-part-{part_start:04d}-to-{part_end:04d}-of-2000/")
    names_in_dir = [png.name for png in image_dir.glob("*.png")]
    resolved_path = f"{image_dir}/" + pl.col("image_name")
    pldf = pldf.with_columns(
        pl.when(pl.col("image_name").is_in(names_in_dir))
        .then(resolved_path)
        .otherwise(pl.col("image_name"))
        .alias("image_name")
    )

In [None]:
# Persist the cleansed (image path, prompt) pairs and preview the first rows.
export_df = pldf.select(pl.col("image_name", "prompt"))
export_df.write_csv("diffusiondb.csv")
export_df.head()