In [1]:
!pip install keras_cv

Collecting keras_cv
  Downloading keras_cv-0.8.2-py3-none-any.whl (613 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/613.1 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [91m━━━━[0m[90m╺[0m[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m61.4/613.1 kB[0m [31m1.7 MB/s[0m eta [36m0:00:01[0m[2K     [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[90m╺[0m [32m604.2/613.1 kB[0m [31m9.5 MB/s[0m eta [36m0:00:01[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m613.1/613.1 kB[0m [31m7.5 MB/s[0m eta [36m0:00:00[0m
Collecting keras-core (from keras_cv)
  Downloading keras_core-0.1.7-py3-none-any.whl (950 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m950.8/950.8 kB[0m [31m33.5 MB/s[0m eta [36m0:00:00[0m
Collecting namex (from keras-core->keras_cv)
  Downloading namex-0.0.7-py3-none-any.whl (5.8 kB)
Installing collected packages: namex, keras-core, keras_cv
Successfully installed keras-core-0.1.7 kera

In [2]:
import tensorflow as tf
from keras_cv.models import StableDiffusion
import matplotlib.pyplot as plt
import numpy as np

Using TensorFlow backend


In [3]:
# Build the KerasCV Stable Diffusion text-to-image pipeline at 512x512 output.
model_diffusion = StableDiffusion(
    img_height=512,
    img_width=512,
)

By using this model checkpoint, you acknowledge that its usage is subject to the terms of the CreativeML Open RAIL-M license at https://raw.githubusercontent.com/CompVis/stable-diffusion/main/LICENSE


In [4]:
# ImageNet-pretrained MobileNetV2; used below to classify and filter the
# generated images.
model_classification = tf.keras.applications.MobileNetV2(
    weights="imagenet",
)

Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/mobilenet_v2/mobilenet_v2_weights_tf_dim_ordering_tf_kernels_1.0_224.h5


In [5]:
def generate_images(prompt, batch_size=3, seed=None, model=None):
    """Generate images for a text prompt with Stable Diffusion.

    Args:
        prompt: Text prompt describing the desired image content.
        batch_size: Number of images to generate for the prompt.
        seed: Optional seed forwarded to ``text_to_image`` so that re-running
            the notebook reproduces the same images. ``None`` (the default)
            keeps the previous unseeded behaviour.
        model: Optional object exposing ``text_to_image``. Defaults to the
            notebook-level ``model_diffusion`` pipeline.

    Returns:
        Whatever ``model.text_to_image`` returns. For the KerasCV pipeline this
        is presumably a batch of uint8 RGB images (verify against the installed
        keras_cv version).
    """
    if model is None:
        # Resolved at call time so the function works once the pipeline cell ran.
        model = model_diffusion
    return model.text_to_image(prompt, batch_size=batch_size, seed=seed)


In [6]:
def plot_images(images, titles=None):
    """Display images side by side in a single row.

    Args:
        images: Sized sequence of images accepted by ``plt.imshow`` (e.g. the
            output of ``generate_images``).
        titles: Optional sequence of per-image titles, same length as
            ``images``. When omitted the panels are untitled.

    Returns:
        None. The figure is rendered with ``plt.show()``. An empty ``images``
        prints a notice instead of drawing an empty figure.
    """
    n_images = len(images)
    if n_images == 0:
        # plt.subplots(1, 0) would raise, and a blank figure is just noise.
        print("plot_images: no images to plot.")
        return

    # Size the canvas to the row of panels (~5in each) instead of a fixed
    # 20x20in figure, which left most of the area as empty vertical space.
    # squeeze=False keeps `axes` 2-D even when there is only one image.
    fig, axes = plt.subplots(1, n_images, figsize=(5 * n_images, 5), squeeze=False)
    for idx, (ax, image) in enumerate(zip(axes[0], images)):
        ax.imshow(image)
        if titles is not None:
            ax.set_title(titles[idx])
        ax.axis("off")
    fig.tight_layout()
    plt.show()

In [7]:
def preprocess_images(images, target_size=(224, 224)):
    """Resize images and apply MobileNetV2 input preprocessing.

    Args:
        images: Iterable of HxWxC images (numpy arrays or tensors). Each image
            is resized independently, so inputs may differ in size.
        target_size: ``(height, width)`` to resize to. Defaults to 224x224, the
            size expected by the ``MobileNetV2(weights='imagenet')`` model built
            in this notebook.

    Returns:
        A numpy batch of shape ``(N, height, width, C)`` after
        ``mobilenet_v2.preprocess_input``. An empty input yields an empty array
        of shape ``(0, height, width, 3)`` (assumes RGB, as MobileNetV2 does).
    """
    resized_images = [tf.image.resize(img, target_size) for img in images]
    if not resized_images:
        return np.zeros((0, *target_size, 3), dtype=np.float32)
    # Stack once so preprocess_input runs on the whole batch, not per image.
    batch = tf.keras.applications.mobilenet_v2.preprocess_input(tf.stack(resized_images))
    return np.asarray(batch)

In [8]:
def classify_images(images, model, top=5):
    """Classify images with an ImageNet model and decode the top predictions.

    Args:
        images: Iterable of images; see ``preprocess_images`` for the accepted
            formats.
        model: Keras model whose ``predict`` output can be decoded by
            ``mobilenet_v2.decode_predictions`` (e.g. ImageNet MobileNetV2).
        top: Number of highest-scoring labels to return per image.

    Returns:
        One list per input image of ``(imagenet_id, label, score)`` tuples, as
        produced by ``decode_predictions``. An empty input returns ``[]``
        without calling the model.
    """
    images = list(images)
    if not images:
        # Avoid calling predict() on an empty batch.
        return []

    preprocessed_images = preprocess_images(images)
    predictions = model.predict(preprocessed_images)
    return tf.keras.applications.mobilenet_v2.decode_predictions(predictions, top=top)

In [9]:
def filter_non_relevant_images(images, model, relevant_class_indices=(0,), threshold=0.5):
    """Keep only the images the classifier scores as relevant.

    An image is kept when the summed predicted probability of the classes in
    ``relevant_class_indices`` is strictly greater than ``threshold``.

    NOTE(review): in this notebook ``model`` is an ImageNet classifier
    (MobileNetV2), whose output is presumably a 1000-way class distribution,
    not a binary relevance score. The default ``relevant_class_indices=(0,)``
    reproduces the original placeholder check ``prediction[0][0] > 0.5``. That
    check tests ImageNet class index 0, which is unrelated to the prompt, so it
    will rarely keep anything. Pass the class indices that match your prompt,
    or a real binary classifier, to get a meaningful filter.

    Args:
        images: Iterable of images; see ``preprocess_images``.
        model: Keras model whose ``predict`` returns one score row per image.
        relevant_class_indices: Output columns whose scores count as relevant.
        threshold: Minimum (exclusive) summed score for an image to be kept.

    Returns:
        List of the input images that passed the filter, in input order.
    """
    images = list(images)
    if not images:
        return []

    # One batched forward pass instead of one predict() call per image.
    predictions = np.asarray(model.predict(preprocess_images(images)))
    relevance = predictions[:, list(relevant_class_indices)].sum(axis=1)
    return [image for image, score in zip(images, relevance) if score > threshold]

In [10]:
# Example configuration: the text prompt and how many images to sample for it.
prompt, batch_size = "Headphones", 3

In [None]:
# Generate candidate images, run the relevance filter, and plot all candidates.
generated_images = generate_images(prompt, batch_size=batch_size)
relevant_images = filter_non_relevant_images(generated_images, model_classification)
# Surface the filter result; previously `relevant_images` was computed and discarded.
print(f"{len(relevant_images)} of {len(generated_images)} generated images passed the relevance filter.")
plot_images(generated_images)

Downloading data from https://github.com/openai/CLIP/blob/main/clip/bpe_simple_vocab_16e6.txt.gz?raw=true
Downloading data from https://huggingface.co/fchollet/stable-diffusion/resolve/main/kcv_encoder.h5
Downloading data from https://huggingface.co/fchollet/stable-diffusion/resolve/main/kcv_diffusion_model.h5

In [None]:
# Run the ImageNet classifier over every generated image.
classifications = classify_images(
    generated_images,
    model=model_classification,
)

In [None]:
# Print the decoded (label, score) predictions for each image, one block per image.
for image_number, top_predictions in enumerate(classifications, start=1):
    print(f"Image {image_number} classification:")
    for rank, (imagenet_id, label, score) in enumerate(top_predictions, start=1):
        print(f"{rank}: {label} ({score:.2f})")
    print()