In [None]:
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Generate product attributes and descriptions from images

## Overview

This notebook shows how to generate attributes and descriptions of products based on product images in a GCS bucket.  
It uses the [Stanford Online Products dataset](https://cvgl.stanford.edu/projects/lifted_struct/) and the Vertex AI Imagen [Captioning](https://cloud.google.com/vertex-ai/docs/generative-ai/model-reference/image-captioning) and [Visual Question Answering (VQA)](https://cloud.google.com/vertex-ai/docs/generative-ai/model-reference/visual-question-answering) models to generate product attributes.  
It then uses [Gemini](https://cloud.google.com/vertex-ai/docs/generative-ai/model-reference/gemini) to generate product sales descriptions, with Spark UDFs parallelizing the API calls.

#### **Steps**
Using Spark,
1) It reads a metadata table of the [Stanford Online Products dataset](https://cvgl.stanford.edu/projects/lifted_struct/) from the **public_datasets** dataset in the [metastore](../../public_datasets/dataproc_metastore/metastore_public_datasets_quickstart.ipynb) (to use this specific dataset, the notebook must be attached to the public metastore).
This metadata table contains the paths of the image files in the bucket.
If you want to apply this to a different dataset, you can read the image files in your bucket with `spark.read.format("binaryFile")` (no metastore needed) - more details [here](../../public_datasets/dataproc_metastore/metastore_public_datasets_quickstart.ipynb).
2) It calls Vertex AI Imagen for Captioning and VQA to get product attributes for each image.
3) It calls the Vertex AI Gemini API to get product sales descriptions based on the image and the generated attributes.

### Setup

Make sure the service account running this notebook has the required permissions:

- **Run the notebook**
  - AI Platform Notebooks Service Agent
  - Notebooks Admin
  - Vertex AI Administrator
- **Read tables from Dataproc Metastore**
  - Dataproc Metastore Editor
  - Dataproc Metastore Metadata Editor
  - Dataproc Metastore Metadata User
  - Dataproc Metastore Service Agent
- **Read files from bucket**
  - Storage Object Viewer
- **Run Dataproc jobs**
  - Dataproc Service Agent
  - Dataproc Worker
- **Call Google APIs**
  - Service Usage Consumer
- **BigQuery**
  - BigQuery Data Editor

#### Imports

In [None]:
import time

from pyspark.sql.functions import regexp_replace, concat
from pyspark.sql.functions import udf, col, lit

import google.auth
import google.auth.transport.requests
import requests

import pandas as pd
# Show long generated texts in full (no column-width truncation) and at least 20 rows when a frame is truncated.
pd.options.display.max_colwidth = None
pd.options.display.min_rows = 20

#### Authentication

In [None]:
# Load Application Default Credentials and fetch an access token right away, so that
# credentials.token can be sent as a Bearer token by the API-calling UDFs below.
# NOTE(review): access tokens are short-lived; if later API calls start failing with 401,
# re-run this cell.
auth_req = google.auth.transport.requests.Request()
credentials, project_id = google.auth.default()
credentials.refresh(auth_req)

#### Setup Spark Session

In [None]:
from pyspark.sql import SparkSession

In [None]:
# Hive support lets Spark read tables registered in the attached Dataproc Metastore.
spark = (
  SparkSession.builder
  .appName("Image attributes and descriptions generation")
  .enableHiveSupport()
  .getOrCreate()
)

#### Read dataset

In [None]:
### Read the Stanford Online Products metadata table from the attached public Dataproc Metastore
binaries_df = spark.table("public_datasets.stanford_online_products")

In [None]:
### Another option is to read from the bucket directly
# BINARIES_BUCKET_PATH = "gs://dataproc-metastore-public-binaries/stanford_online_products/"
# binaries_df = spark.read.format("binaryFile").option("recursiveFileLookup", "true").load(BINARIES_BUCKET_PATH)

In [None]:
# Number of product images to process. Each image triggers 8 Imagen calls and 1 Gemini call
# in the cells below, so keep this small while experimenting.
NUM_IMAGES = 5

# Select the paths of NUM_IMAGES product images. limit() without an ordering does not
# guarantee which rows are returned.
paths_df = binaries_df.select("path").limit(NUM_IMAGES)

#### Define prompts to get image attributes

In [None]:
# Visual questions sent to the Imagen VQA model, one per product attribute.
# The brand and year prompts ask the model to answer "unanswerable" rather than guess.
prompt_color = "What are the product colors?"
prompt_gender = "The product shown in the image is most appropriate to be used by men, women or both?"
prompt_brand = "What is the brand of the product shown in the image? reply unanswerable if you do not know for sure"
prompt_style = "What is the style of the product shown in the image? ex: modern, casual, tech"
prompt_material = "What is the material of the product shown in the image? ex: steel, wood, rubber"
prompt_purpose = "What is the purpose or usage of this product?"
prompt_year = "What is the year of the product? reply unanswerable if you do not know for sure"

#### Define UDF and call Image Captioning and VQA APIs to generate product attributes

In [None]:
def visual_qa(prompt, gcs_uri, max_retries=8, retry_delay=5.0, timeout=60):
  """Ask the Vertex AI Imagen `imagetext` model about one image.

  With a non-empty prompt the model answers it as a visual question (VQA). With an
  empty (or None) prompt the prompt is omitted and the model captions the image instead.

  Args:
    prompt: Question about the image, or "" / None to request a caption.
    gcs_uri: gs:// URI of the image.
    max_retries: Retries after a 429 (quota exceeded) response before giving up.
    retry_delay: Seconds to wait before the first retry; doubled after every attempt
      and capped at 60 seconds.
    timeout: Seconds to wait for each HTTP request.

  Returns:
    The first prediction, or an error string starting with "Error getting prediction".
    Errors are returned rather than raised so that one bad image does not fail the Spark task.
  """
  model_url = f"https://us-central1-aiplatform.googleapis.com/v1/projects/{project_id}/locations/us-central1/publishers/google/models/imagetext:predict"

  instance = {"image": {"gcsUri": gcs_uri}}
  # Passing no prompt triggers image captioning instead of visual-question-answering.
  if prompt:
    instance["prompt"] = prompt

  request = {
      "instances": [instance],
      "parameters": {
        "sampleCount": 1
      }
  }

  # NOTE(review): the token comes from the credentials refreshed once in the authentication
  # cell and is not refreshed here; long runs may start failing with 401 — confirm.
  headers = {'Authorization': 'Bearer %s' % credentials.token,
             'x-goog-user-project': project_id,
             'Content-Type': 'application/json; charset=utf-8'}

  attempts = max(max_retries, 0) + 1
  for attempt in range(attempts):
    try:
      prediction = requests.post(model_url, headers=headers, json=request, timeout=timeout).json()
    except (requests.exceptions.RequestException, ValueError) as exc:
      # Network failure, timeout, or a non-JSON body: report it in the column.
      return f"Error getting prediction: {exc}"

    if "predictions" in prediction and prediction["predictions"]:
      return prediction["predictions"][0]
    if "error" not in prediction:
      return "Error getting predictions"

    error = prediction["error"]
    quota_exceeded = isinstance(error, dict) and error.get("code") == 429
    if not quota_exceeded or attempt == attempts - 1:
      return f"Error getting prediction: {error}"
    # Quota exceeded: back off (bounded) instead of recursing forever.
    time.sleep(min(retry_delay * 2 ** attempt, 60))

In [None]:
# Spark UDF wrapper around visual_qa; each call yields a string column value.
visual_qa_udf = udf(visual_qa, returnType="string")

In [None]:
# Output column -> Imagen prompt. The empty prompt makes visual_qa return an image caption
# instead of answering a question.
attribute_prompts = {
  "description": "",
  "color": prompt_color,
  "gender": prompt_gender,
  "brand": prompt_brand,
  "style": prompt_style,
  "material": prompt_material,
  "purpose": prompt_purpose,
  "year": prompt_year,
}

image_metadata_df = paths_df
for attribute, question in attribute_prompts.items():
  image_metadata_df = image_metadata_df.withColumn(attribute, visual_qa_udf(lit(question), col("path")))

# Mark for caching BEFORE the first action, so the preview below and the Gemini step reuse
# the same API results instead of calling Imagen again.
image_metadata_df = image_metadata_df.cache()

In [None]:
# Preview up to 5 rows with every value truncated to 10 characters; this action runs the Imagen UDF calls.
image_metadata_df.show(n=5, truncate=10)

In [None]:
# Keep the generated attributes in memory so the Gemini step below can reuse them instead of
# recomputing the Imagen UDF calls. cache() is lazy: only actions run after it populate the cache.
image_metadata_df.cache()

#### Define UDF and call Gemini API to generate product sales descriptions

In [None]:
def generate_descriptions(gcs_uri, description, color, gender, brand, style, material, purpose, year,
                          max_retries=8, retry_delay=5.0, timeout=120):
  """Write a product sales description with Gemini from a product image and its attributes.

  Args:
    gcs_uri: gs:// URI of the product image; it is passed to Gemini as file data.
    description, color, gender, brand, style, material, purpose, year: Product attributes
      (in this notebook, the Imagen captioning/VQA outputs), inserted into the prompt
      as PRODUCT DATA.
    max_retries: Retries after a 429 (quota exceeded) response before giving up.
    retry_delay: Seconds to wait before the first retry; doubled after every attempt
      and capped at 60 seconds.
    timeout: Seconds to wait for each HTTP request.

  Returns:
    The generated text (first candidate, concatenated across streamed chunks), or an
    error string starting with "Error getting prediction" when the request fails.
  """
  import mimetypes

  def gemini_predict(gcs_uri, prompt):
    """Send the image and prompt to Gemini's streamGenerateContent endpoint and join the streamed text."""
    # NOTE(review): gemini-pro-vision may no longer be served on Vertex AI; confirm the model ID.
    model_url = f"https://us-central1-aiplatform.googleapis.com/v1/projects/{project_id}/locations/us-central1/publishers/google/models/gemini-pro-vision:streamGenerateContent"

    # Infer the image MIME type from the extension; fall back to the previously hard-coded JPEG.
    mime_type, _ = mimetypes.guess_type(gcs_uri or "")
    if not mime_type or not mime_type.startswith("image/"):
      mime_type = "image/jpeg"

    request_body = {
      "contents": {
        "role": "user",
        "parts": [
          {
            "fileData": {
              "mimeType": mime_type,
              "fileUri": gcs_uri
            }
          },
          {
            "text": prompt
          }
        ]
      },
      "safety_settings": {
        "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
        "threshold": "BLOCK_LOW_AND_ABOVE"
      },
      "generation_config": {
        "temperature": 0.4,
        "topP": 1.0,
        "topK": 32,
        "maxOutputTokens": 2048
      }
    }
    headers = {'Authorization': 'Bearer %s' % credentials.token,
               'x-goog-user-project': project_id,
               'Content-Type': 'application/json'}

    attempts = max(max_retries, 0) + 1
    for attempt in range(attempts):
      try:
        chunks = requests.post(model_url, headers=headers, json=request_body, timeout=timeout).json()
      except (requests.exceptions.RequestException, ValueError) as exc:
        # Network failure, timeout, or a non-JSON body: report it in the column.
        return f"Error getting prediction: {exc}"

      # Handle a single JSON object (e.g. an {"error": ...} response) as well as the usual
      # list of streamed chunks; iterating a dict would silently walk its keys.
      if isinstance(chunks, dict):
        chunks = [chunks]
      error = next((chunk["error"] for chunk in chunks if "error" in chunk), None)
      if error is None:
        # .get() chains tolerate chunks without content/parts instead of raising KeyError.
        return "".join(
          part.get("text", "")
          for chunk in chunks
          for candidate in chunk.get("candidates", [])[:1]
          for part in candidate.get("content", {}).get("parts", [])
        )

      quota_exceeded = isinstance(error, dict) and error.get("code") == 429
      if not quota_exceeded or attempt == attempts - 1:
        return f"Error getting prediction: {error}"
      time.sleep(min(retry_delay * 2 ** attempt, 60))

  prompt = f"""
        You are a retail expert and know how to write beautiful, elegant and concise product descriptions, based on data about the product.
        Based on the PRODUCT DATA, and the image of the product, you are able to provide the PRODUCT SALES DESCRIPTION.

        Here is one EXAMPLE:

        START____________________________________
        PRODUCT DATA:
        Product description: Brown Fashion Sneakers
        Color: brown
        Gender: Women
        Brand: NONE
        Style: Fashion Flat heel
        Material: Polyurethane

        PRODUCT SALES DESCRIPTION:
        A pair of brown sneakers with white soles on a white background is a stylish and comfortable choice for women who want to add a touch of color to their wardrobe. These sneakers are made of polyurethane, which is a durable and lightweight material that will keep your feet comfortable all day long. The flat heel makes them easy to wear for all-day activities, and the lace-up closure ensures a secure fit.
        These sneakers are perfect for a variety of occasions, from running errands to meeting friends. They can be dressed up or down, depending on your personal style. Pair them with a casual dress or jeans for a relaxed look, or dress them up with a skirt or pants for a more formal look.
        If you are looking for a stylish and comfortable pair of sneakers, these brown sneakers with white soles are a great option. They are made of durable materials, are easy to wear, and can be dressed up or down.

        You are a retail expert and know how to write beautiful, elegant and concise product descriptions, based on data about the product.
        Based on the PRODUCT DATA, and the image of the product, you are able to provide the PRODUCT SALES DESCRIPTION.
        
        Generate a PRODUCT SALES DESCRIPTION for this product:

        START____________________________________
        PRODUCT DATA:
        Product description: {description}
        Color: {color}
        Gender: {gender}
        Brand: {brand}
        Style: {style}
        Material: {material}
        Purpose: {purpose}
        Year: {year}

        PRODUCT SALES DESCRIPTION:
    """

  return gemini_predict(gcs_uri, prompt)

In [None]:
# Spark UDF wrapper around generate_descriptions; each call yields a string column value.
generate_descriptions_udf = udf(generate_descriptions, returnType="string")

In [None]:
# Column order must match generate_descriptions' positional parameters (image path first).
image_descriptions_df = image_metadata_df.withColumn(
  "sales_description",
  generate_descriptions_udf(
    *[col(column_name) for column_name in
      ("path", "description", "color", "gender", "brand", "style", "material", "purpose", "year")]
  ),
)

In [None]:
# Sort by image path and add a browser URL: the gs:// scheme is dropped and the bucket/object
# path is appended to the storage host. Collecting to pandas renders the (small) result as a table.
(
  image_descriptions_df
  .sort(col("path").asc())
  .withColumn(
    "url",
    regexp_replace(concat(lit("https://storage.mtls.cloud.google.com/"), col("path")), "gs://", ""),
  )
  .toPandas()
)

|                          path|                   description|color|gender|       brand| style|material|     purpose|        year|             sales_description|                           url|
|------------------------------|------------------------------|-----|------|------------|------|--------|------------|------------|------------------------------|------------------------------|
|gs://dataproc-metastore-pub...|a kitchen with wooden cabin...|brown| women|unanswerable|modern|    wood|     kitchen|unanswerable| This beautiful kitchen is ...|https://storage.mtls.cloud....|
|gs://dataproc-metastore-pub...|a close up of a wooden pill...|brown| women|unanswerable|modern|    wood|     cabinet|unanswerable| This beautiful wooden cabi...|https://storage.mtls.cloud....|
|gs://dataproc-metastore-pub...|a crystal chandelier is han...|clear| women|unanswerable|modern|   metal|  decoration|unanswerable| This stunning crystal chan...|https://storage.mtls.cloud....|
|gs://dataproc-metastore-pub...|a lamp with a giraffe shape...|white| women|unanswerable|modern|    wood|        lamp|unanswerable| This elegant lamp is sure ...|https://storage.mtls.cloud....|
|gs://dataproc-metastore-pub...|a brown and white crocheted...|white| women|unanswerable|modern|  rubber|unanswerable|unanswerable| This crocheted brown and w...|https://storage.mtls.cloud....|