<a href="https://colab.research.google.com/github/jack-of-all-trades-22/ARQe/blob/main/AI_Enhanced_Product_Photoshoot_Visuals.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

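Problem statement: "AI-Enhanced Product Photoshoot Visuals and Filter" — generate product images from a text prompt with KerasCV's Stable Diffusion, then use an ImageNet-pretrained MobileNetV2 classifier to label the generated images and filter out those that are not relevant.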

In [None]:
!pip install keras_cv

In [None]:
import tensorflow as tf
from keras_cv.models import StableDiffusion
import matplotlib.pyplot as plt
import numpy as np

In [None]:
# Text-to-image generator: KerasCV Stable Diffusion producing 512x512 images.
model_diffusion = StableDiffusion(img_height=512, img_width=512)

In [None]:
# ImageNet-pretrained MobileNetV2; used below to label and filter generated images.
model_classification = tf.keras.applications.MobileNetV2(weights="imagenet")

In [None]:
def generate_images(prompt, batch_size=3, num_steps=50, seed=None):
    """Generate images for a text prompt with the global Stable Diffusion model.

    Args:
        prompt: Text description of the images to generate.
        batch_size: Number of images to generate for the prompt.
        num_steps: Number of diffusion steps passed to ``text_to_image``
            (50 is KerasCV's own default).
        seed: Optional integer seed so repeated runs can reproduce the same
            images; ``None`` leaves generation unseeded (KerasCV's default).

    Returns:
        Whatever ``model_diffusion.text_to_image`` returns; presumably a batch
        of RGB images with one entry per requested image (TODO confirm dtype).
    """
    return model_diffusion.text_to_image(
        prompt, batch_size=batch_size, num_steps=num_steps, seed=seed
    )


In [None]:
def plot_images(images, titles=None):
    """Display images side by side in a single row.

    Args:
        images: Sequence (or batch array) of images accepted by ``imshow``.
        titles: Optional sequence of per-image titles, aligned with
            ``images``; no titles are drawn when ``None``.
    """
    images = list(images)
    if not images:
        # plt.subplots cannot build a zero-column grid; say so instead of
        # showing an empty figure.
        print("No images to plot.")
        return

    n_images = len(images)
    # ~5 inches per panel in one row, instead of a fixed 20x20 canvas that
    # leaves most of a single-row figure empty.
    fig, axes = plt.subplots(1, n_images, figsize=(5 * n_images, 5), squeeze=False)
    labels = titles if titles is not None else [None] * n_images
    for ax, image, label in zip(axes[0], images, labels):
        ax.imshow(image)
        ax.axis("off")
        if label is not None:
            ax.set_title(label)
    plt.show()

In [None]:
def preprocess_images(images):
    """Prepare images for MobileNetV2.

    Each image is resized to the 224x224 input MobileNetV2 expects and then
    passed through ``mobilenet_v2.preprocess_input``.

    Args:
        images: Iterable of individual images.

    Returns:
        A NumPy array stacking the preprocessed images along a new first axis.
    """
    target_size = (224, 224)
    mobilenet_preprocess = tf.keras.applications.mobilenet_v2.preprocess_input
    batch = []
    for img in images:
        batch.append(mobilenet_preprocess(tf.image.resize(img, target_size)))
    return np.array(batch)

In [None]:
def classify_images(images, model, top=5):
    """Classify images with an ImageNet MobileNetV2 model.

    Args:
        images: Sequence of images; resizing and normalization are done by
            ``preprocess_images``.
        model: Keras model whose ``predict`` output is ImageNet class
            probabilities (e.g. ``MobileNetV2(weights='imagenet')``).
        top: Number of highest-scoring classes to return per image; defaults
            to 5, which is ``decode_predictions``' own default.

    Returns:
        One list per image of ``(imagenet_id, label, score)`` tuples as
        produced by ``mobilenet_v2.decode_predictions``.
    """
    batch = preprocess_images(images)
    probabilities = model.predict(batch)
    return tf.keras.applications.mobilenet_v2.decode_predictions(probabilities, top=top)

In [None]:
def filter_non_relevant_images(images, model, relevant_labels=None, threshold=0.5):
    """Return the subset of ``images`` that the classifier deems relevant.

    All images are classified in a single ``model.predict`` call. Each image
    then gets a relevance score:

    * ``relevant_labels is None``: its top-1 class probability, i.e. how
      confidently the classifier recognizes some object in it;
    * otherwise: the total probability assigned to the labels in
      ``relevant_labels`` (label names as returned by
      ``mobilenet_v2.decode_predictions``, e.g. ``"sports_car"``).

    Images whose score is strictly greater than ``threshold`` are kept.

    Args:
        images: Sequence (or batch array) of images to filter.
        model: Classifier whose ``predict`` returns one row of class
            probabilities per image (e.g. ImageNet MobileNetV2).
        relevant_labels: Optional iterable of label names counted as relevant.
        threshold: Score an image must exceed to be kept.

    Returns:
        List of the kept images, in their original order.
    """
    images = list(images)
    if not images:
        # Nothing to classify; also avoids calling predict on an empty batch.
        return []

    # A fixed-index check such as probs[0] > 0.5 is meaningless for an
    # ImageNet head: class 0 is "tench" (a fish), so it would discard virtually
    # every product image. Score against actual label semantics instead.
    probabilities = np.asarray(model.predict(preprocess_images(images)))

    if relevant_labels is None:
        scores = probabilities.max(axis=1)
    else:
        wanted = set(relevant_labels)
        # Decode every class, not just the top 5, so no relevant label's
        # probability mass is missed.
        decoded = tf.keras.applications.mobilenet_v2.decode_predictions(
            probabilities, top=probabilities.shape[1]
        )
        scores = [
            sum(score for _, label, score in row if label in wanted)
            for row in decoded
        ]

    return [image for image, score in zip(images, scores) if score > threshold]

In [None]:
# Example run configuration: the text prompt and how many images to generate for it.
prompt, batch_size = "Car", 3

In [None]:
# Generate candidates, filter them, and show both sets so the filter's effect is visible
generated_images = generate_images(prompt, batch_size=batch_size)
relevant_images = filter_non_relevant_images(generated_images, model_classification)
print(f"Kept {len(relevant_images)} of {len(generated_images)} generated images after filtering.")
plot_images(generated_images)
if relevant_images:
    plot_images(relevant_images)

In [None]:
# Top ImageNet labels for every generated image (same order as the plot above).
classifications = classify_images(images=generated_images, model=model_classification)

In [None]:
# Print each image's ranked predictions as "rank: label (score)".
for image_number, top_predictions in enumerate(classifications, start=1):
    print(f"Image {image_number} classification:")
    for rank, (_, label, score) in enumerate(top_predictions, start=1):
        print(f"{rank}: {label} ({score:.2f})")
    print()