In [None]:
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Generate product attributes and descriptions from images

## Overview

This notebook shows how to generate attributes and descriptions of products based on product images in a GCS bucket.  
It uses the [Stanford Online Products dataset](https://cvgl.stanford.edu/projects/lifted_struct/) and the Vertex AI Imagen [Captioning](https://cloud.google.com/vertex-ai/docs/generative-ai/model-reference/image-captioning) & [VQA](https://cloud.google.com/vertex-ai/docs/generative-ai/model-reference/visual-question-answering) models to generate product attributes.  
It then uses the [Gemini](https://cloud.google.com/vertex-ai/docs/generative-ai/model-reference/gemini) API to generate product sales descriptions, with Spark UDFs parallelizing the API calls.

#### **Steps**
Using Spark,
1) It reads the image files of the [Stanford Online Products dataset](https://cvgl.stanford.edu/projects/lifted_struct/) located in [gs://dataproc-metastore-public-binaries/stanford_online_products](https://console.cloud.google.com/storage/browser/dataproc-metastore-public-binaries/stanford_online_products)    
We will create a metadata table pointing to the paths of the image files in the bucket.  
2) It calls Vertex AI Imagen for Captioning and VQA to get product attributes for each image.
3) It calls the Vertex AI Gemini API to get product sales descriptions based on the image and its generated attributes.

### Setup

Make sure the service account running this notebook has the required permissions:

- **Run the notebook**
  - AI Platform Notebooks Service Agent
  - Notebooks Admin
  - Vertex AI Administrator
- **Read files from bucket**
  - Storage Object Viewer
- **Run Dataproc jobs**
  - Dataproc Service Agent
  - Dataproc Worker
- **Call Google APIs**
  - Service Usage Consumer
- **BigQuery**
  - BigQuery Data Editor

#### Imports

In [None]:
import time

from pyspark.sql.functions import regexp_replace, concat
from pyspark.sql.functions import udf, col, lit

import google.auth
import google.auth.transport.requests
import requests

import pandas as pd
# Show full cell contents (no column-width truncation) and at least 20 rows when
# pandas truncates long frames; set both options in one call as pat/value pairs.
pd.set_option("display.max_colwidth", None, "display.min_rows", 20)

In [None]:
# When using Dataproc Serverless, installed packages are automatically available on all nodes
!pip install --upgrade google-cloud-aiplatform
# When using a Dataproc cluster, you will need to install these packages during cluster creation: https://cloud.google.com/dataproc/docs/tutorials/python-configuration

#### Authentication

In [None]:
# HTTP transport object that google-auth uses to refresh tokens.
auth_req = google.auth.transport.requests.Request()
# Resolve Application Default Credentials; project_id is later passed to vertexai.init().
credentials, project_id = google.auth.default()
# Refresh once up front so credential problems surface here rather than in later API calls.
credentials.refresh(auth_req)

#### Setup Spark Session

In [None]:
from pyspark.sql import SparkSession

In [None]:
# Create (or reuse) the Spark session used by every cell below.
spark = (
    SparkSession.builder
    .appName("Image attributes and descriptions generation")
    .enableHiveSupport()
    .getOrCreate()
)

#### Read dataset

In [None]:
BINARIES_BUCKET_PATH = "gs://dataproc-metastore-public-binaries/stanford_online_products/"
# Load every file under the bucket prefix, descending into sub-folders, as a binaryFile DataFrame.
binaries_df = (
    spark.read
    .format("binaryFile")
    .option("recursiveFileLookup", "true")
    .load(BINARIES_BUCKET_PATH)
)

In [None]:
# Number of product images to process. Every selected image costs several Imagen
# calls and one Gemini call further down, so keep this small for a demo run.
NUM_IMAGES = 10

# Keep only the file paths of a small sample of images. limit() without an ordering
# returns whichever rows Spark reads first, not a sorted "first N".
paths_df = binaries_df.select("path").limit(NUM_IMAGES)
paths_df.cache()

|                                                                                          path|
|----------------------------------------------------------------------------------------------|
|gs://dataproc-metastore-public-binaries/stanford_online_products/sofa_final/181714736872_0.JPG|
|gs://dataproc-metastore-public-binaries/stanford_online_products/sofa_final/181661485577_1.JPG|
|gs://dataproc-metastore-public-binaries/stanford_online_products/sofa_final/171860974117_1.JPG|
|gs://dataproc-metastore-public-binaries/stanford_online_products/sofa_final/171860974117_2.JPG|
|gs://dataproc-metastore-public-binaries/stanford_online_products/sofa_final/181661485577_0.JPG|

#### Define prompts to get image attributes

In [None]:
# Questions sent to the Imagen VQA model, one per generated attribute column.
# Prompts that may not be answerable from the image ask for the literal reply
# "unanswerable" instead of a guess.
prompt_color = "What are the product colors?"
prompt_gender = "The product shown in the image is most appropriate to be used by men, women or both?"
prompt_brand = "What is the brand of the product shown in the image? reply unanswerable if you do not know for sure"
prompt_style = "What is the style of the product shown in the image? ex: modern, casual, tech"
prompt_material = "What is the material of the product shown in the image? ex: steel, wood, rubber"
prompt_purpose = "What is the purpose or usage of this product?"
prompt_year = "What is the year of the product? reply unanswerable if you do not know for sure"

#### Define UDF and call Image Captioning and VQA APIs to generate product attributes

In [None]:
import vertexai
from vertexai.vision_models import Image, ImageTextModel

# Point the Vertex AI SDK at the notebook's project and the us-central1 region.
vertexai.init(location="us-central1", project=project_id)

def visual_qa(prompt, gcs_uri):
    """Ask the Imagen image-text model about a single image stored in GCS.

    Used as a Spark UDF, so it is invoked once per row for every attribute column.

    Args:
        prompt: Question to ask about the image. An empty string (or None) requests
            a plain image caption instead of a visual question answer.
        gcs_uri: ``gs://`` URI of the image.

    Returns:
        The first caption/answer returned by the model, or None when the model
        returned no results (Spark stores it as null instead of failing the task).
    """
    # NOTE(review): the model handle is created on every call, and the module-level
    # vertexai.init() runs in the notebook (driver) process; executors presumably
    # fall back to the SDK's default project/location — confirm for your setup.
    model = ImageTextModel.from_pretrained("imagetext@001")
    source_img = Image(gcs_uri=gcs_uri)

    if not prompt:
        # No question given: fall back to image captioning.
        results = model.get_captions(
            image=source_img,
            language="en",
            number_of_results=1,
        )
    else:
        results = model.ask_question(
            image=source_img,
            question=prompt,
            number_of_results=1,
        )

    # Indexing an empty result list would raise IndexError and fail the Spark task.
    return results[0] if results else None

In [None]:
# Wrap visual_qa as a Spark UDF; "string" spells out udf()'s default return type.
visual_qa_udf = udf(visual_qa, "string")

In [None]:
# Output column -> question asked about the image. The empty prompt makes
# visual_qa return a plain caption instead of a VQA answer.
ATTRIBUTE_PROMPTS = {
    "description": "",
    "color": prompt_color,
    "gender": prompt_gender,
    "brand": prompt_brand,
    "style": prompt_style,
    "material": prompt_material,
    "purpose": prompt_purpose,
    "year": prompt_year,
}

# One Imagen UDF call per attribute and image. Cache immediately, before any action,
# so the slow remote API results are computed once and reused by every later action
# (show, the Gemini step) instead of being recomputed by each of them.
image_metadata_df = paths_df.select(
    "*",
    *[
        visual_qa_udf(lit(attribute_prompt), col("path")).alias(column_name)
        for column_name, attribute_prompt in ATTRIBUTE_PROMPTS.items()
    ],
).cache()

In [None]:
# Preview 5 rows, truncating each cell to 10 characters to keep the table narrow.
image_metadata_df.show(n=5, truncate=10)

In [None]:
# Mark the generated attributes for caching so later actions (the Gemini step below)
# can reuse them instead of re-calling the Imagen APIs.
# NOTE(review): cache() is lazy and is applied here after the show() action above,
# so the rows computed for show() were not cached; caching before the first action
# avoids repeating those API calls.
image_metadata_df.cache()

#### Define UDF and call Gemini API to generate product sales descriptions

In [None]:
import vertexai
from vertexai.generative_models import GenerativeModel, Part , HarmCategory, HarmBlockThreshold

# Initialize the Vertex AI SDK with the same project and region as the Imagen cell
# above, so this cell also works when it is run on its own.
vertexai.init(project=project_id, location="us-central1")

def gemini_predict(gcs_uri, prompt, model_name="gemini-1.0-pro-vision", mime_type="image/jpeg"):
    """Generate text with Gemini from a text prompt plus one image stored in GCS.

    Args:
        gcs_uri: ``gs://`` URI of the image sent to the model after the prompt.
        prompt: Text prompt.
        model_name: Gemini model to call. NOTE(review): ``gemini-1.0-pro-vision`` has
            been deprecated by Google; pass a currently supported multimodal model
            name if calls fail.
        mime_type: MIME type of the image at ``gcs_uri`` (the product images here are JPEG).

    Returns:
        The concatenated text of all streamed response chunks. Chunks that carry no
        text (e.g. because the safety settings below blocked them) are skipped, so a
        fully blocked response yields an empty string instead of failing the Spark task.
    """
    model = GenerativeModel(model_name)
    generation_config = {"max_output_tokens": 2048, "temperature": 0.4, "top_p": 1, "top_k": 32}
    safety_settings = {
        HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_LOW_AND_ABOVE,
        HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_LOW_AND_ABOVE,
    }

    responses = model.generate_content(
        [prompt, Part.from_uri(gcs_uri, mime_type=mime_type)],
        generation_config=generation_config,
        safety_settings=safety_settings,
        stream=True,
    )

    text_parts = []
    for chunk in responses:
        try:
            text_parts.append(chunk.text)
        except ValueError:
            # `.text` raises ValueError when a chunk has no candidates/parts,
            # which is what a safety-blocked chunk looks like.
            continue
    return "".join(text_parts)

In [None]:
def generate_descriptions(gcs_uri, description, color, gender, brand, style, material, purpose, year):
    """Ask Gemini for a product sales description based on the image and its attributes.

    Builds a one-shot prompt (an example product followed by this product's data)
    and sends it, together with the image at ``gcs_uri``, through ``gemini_predict``.

    Args:
        gcs_uri: ``gs://`` URI of the product image.
        description: Caption generated for the image.
        color, gender, brand, style, material, purpose, year: Attribute answers
            generated by the visual question answering step. Missing values (None,
            empty, or "unanswerable") are rendered as "NONE", like the example.

    Returns:
        The sales description text returned by ``gemini_predict``.
    """

    def _field(value):
        # Spark nulls arrive as None, and the brand/year questions ask the model to
        # reply "unanswerable" when unsure. Render both like the example's
        # "Brand: NONE" so the model sees one consistent missing-value marker.
        if value is None or str(value).strip().lower() in ("", "unanswerable"):
            return "NONE"
        return value

    prompt = f"""
        You are a retail expert and know how to write beautiful, elegant and concise product descriptions, based on data about the product.
        Based on the PRODUCT DATA, and the image of the product, you are able to provide the PRODUCT SALES DESCRIPTION.

        Here is one EXAMPLE:

        START____________________________________
        PRODUCT DATA:
        Product description: Brown Fashion Sneakers
        Color: brown
        Gender: Women
        Brand: NONE
        Style: Fashion Flat heel
        Material: Polyurethane

        PRODUCT SALES DESCRIPTION:
        A pair of brown sneakers with white soles on a white background is a stylish and comfortable choice for women who want to add a touch of color to their wardrobe. These sneakers are made of polyurethane, which is a durable and lightweight material that will keep your feet comfortable all day long. The flat heel makes them easy to wear for all-day activities, and the lace-up closure ensures a secure fit.
        These sneakers are perfect for a variety of occasions, from running errands to casual weekend outings. They can be dressed up or down, depending on your personal style. Pair them with a casual dress or jeans for a relaxed look, or dress them up with a skirt or pants for a more formal look.
        If you are looking for a stylish and comfortable pair of sneakers, these brown sneakers with white soles are a great option. They are made of durable materials, are easy to wear, and can be dressed up or down.
        END______________________________________

        You are a retail expert and know how to write beautiful, elegant and concise product descriptions, based on data about the product.
        Based on the PRODUCT DATA, and the image of the product, you are able to provide the PRODUCT SALES DESCRIPTION.

        Generate a PRODUCT SALES DESCRIPTION for this product:

        START____________________________________
        PRODUCT DATA:
        Product description: {_field(description)}
        Color: {_field(color)}
        Gender: {_field(gender)}
        Brand: {_field(brand)}
        Style: {_field(style)}
        Material: {_field(material)}
        Purpose: {_field(purpose)}
        Year: {_field(year)}

        PRODUCT SALES DESCRIPTION:
    """

    return gemini_predict(gcs_uri, prompt)

generate_descriptions_udf = udf(generate_descriptions)

In [None]:
# Add a Gemini-generated sales description per image, fed by the path and all generated attributes
# (columns listed in the order of generate_descriptions' parameters).
image_descriptions_df = image_metadata_df.withColumn(
    "sales_description",
    generate_descriptions_udf(
        col("path"),
        col("description"),
        col("color"),
        col("gender"),
        col("brand"),
        col("style"),
        col("material"),
        col("purpose"),
        col("year"),
    ),
)

In [None]:
# Sort by path and add a browser URL per image by replacing only the leading "gs://"
# scheme with the Cloud Storage web endpoint. toPandas() collects everything to the
# driver, which is acceptable because the input was limited to a few images above.
(
    image_descriptions_df
    .orderBy(col("path").asc())
    .withColumn(
        "url",
        regexp_replace(col("path"), "^gs://", "https://storage.mtls.cloud.google.com/"),
    )
    .toPandas()
)

|                          path|                   description|color|gender|       brand| style|material|     purpose|        year|             sales_description|                           url|
|------------------------------|------------------------------|-----|------|------------|------|--------|------------|------------|------------------------------|------------------------------|
|gs://dataproc-metastore-pub...|a kitchen with wooden cabin...|brown| women|unanswerable|modern|    wood|     kitchen|unanswerable| This beautiful kitchen is ...|https://storage.mtls.cloud....|
|gs://dataproc-metastore-pub...|a close up of a wooden pill...|brown| women|unanswerable|modern|    wood|     cabinet|unanswerable| This beautiful wooden cabi...|https://storage.mtls.cloud....|
|gs://dataproc-metastore-pub...|a crystal chandelier is han...|clear| women|unanswerable|modern|   metal|  decoration|unanswerable| This stunning crystal chan...|https://storage.mtls.cloud....|
|gs://dataproc-metastore-pub...|a lamp with a giraffe shape...|white| women|unanswerable|modern|    wood|        lamp|unanswerable| This elegant lamp is sure ...|https://storage.mtls.cloud....|
|gs://dataproc-metastore-pub...|a brown and white crocheted...|white| women|unanswerable|modern|  rubber|unanswerable|unanswerable| This crocheted brown and w...|https://storage.mtls.cloud....|