In [1]:
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Generate product attributes and descriptions from images

<table align="left">
<tr>
<td>
<a href="https://github.com/GoogleCloudPlatform/ai-ml-recipes/blob/main/notebooks/generative_ai/content_generation/product_attributes_from_image.ipynb">
<img src="https://cloud.google.com/ml-engine/images/github-logo-32px.png" alt="GitHub logo">
View on GitHub
</a>
</td>
<td>
<a href="https://console.cloud.google.com/vertex-ai/workbench/deploy-notebook?download_url=https://raw.githubusercontent.com/GoogleCloudPlatform/ai-ml-recipes/main/notebooks/generative_ai/content_generation/product_attributes_from_image.ipynb">
<img src="https://lh3.googleusercontent.com/UiNooY4LUgW_oTvpsNhPpQzsstV5W8F7rYgxgGBD85cWJoLmrOzhVs_ksK_vgx40SHs7jCqkTkCk=e14-rj-sc0xffffff-h130-w32" alt="Vertex AI logo">
Open in Vertex AI Workbench
</a>
</td>
</tr>
</table>

## Overview

This notebook shows how to generate attributes and descriptions of products based on product images in a GCS bucket.  
It uses the [Stanford Online Products dataset](https://cvgl.stanford.edu/projects/lifted_struct/).
It uses [Gemini](https://cloud.google.com/vertex-ai/docs/generative-ai/model-reference/gemini) models to generate product attributes and sales descriptions, and Spark UDFs to parallelize the API calls across images.

#### **Steps**
Using Spark, the notebook:
1) Reads the images of the [Stanford Online Products dataset](https://cvgl.stanford.edu/projects/lifted_struct/) located in [gs://dataproc-metastore-public-binaries/stanford_online_products](https://console.cloud.google.com/storage/browser/dataproc-metastore-public-binaries/stanford_online_products).    
It builds a metadata table pointing to the paths of the image files in the bucket.   
2) Calls the Vertex AI Gemini API to get product attributes and a product sales description for each image.

### Setup

Make sure the service account running this notebook has the required permissions:

- **Run the notebook**
  - AI Platform Notebooks Service Agent
  - Notebooks Admin
  - Vertex AI Administrator
- **Read files from bucket**
  - Storage Object Viewer
- **Run Dataproc jobs**
  - Dataproc Service Agent
  - Dataproc Worker
- **Call Google APIs**
  - Service Usage Consumer
- **BigQuery**
  - BigQuery Data Editor

#### Imports

In [2]:
import time

from pyspark.sql.functions import regexp_replace, concat
from pyspark.sql.functions import udf, col, lit

import google.auth
import google.auth.transport.requests
import requests

import pandas as pd

# Display settings for the results table at the end of the notebook:
# never truncate long text cells (the generated sales descriptions),
# and show 20 rows whenever pandas shortens a long frame.
pd.options.display.max_colwidth = None
pd.options.display.min_rows = 20

ModuleNotFoundError: No module named 'pyspark'

In [None]:
# When using Dataproc Serverless, installed packages are automatically available on all nodes
!pip install --upgrade google-cloud-aiplatform -q
# When using a Dataproc cluster, you will need to install these packages during cluster creation: https://cloud.google.com/dataproc/docs/tutorials/python-configuration

#### Authentication

In [None]:
# Resolve Application Default Credentials and force a token refresh up front,
# so a missing or misconfigured identity surfaces early in this cell.
# NOTE(review): `credentials` and `project_id` are not referenced by any later visible
# cell; presumably the Vertex AI SDK resolves credentials itself — confirm this cell
# is only a sanity check.
auth_req = google.auth.transport.requests.Request()
credentials, project_id = google.auth.default()
credentials.refresh(auth_req)

#### Setup Spark Session

In [None]:
from pyspark.sql import SparkSession

In [None]:
# Create (or reuse) the Spark session; Hive support is enabled on the builder.
spark = (
    SparkSession.builder
    .appName("Image attributes and descriptions generation")
    .enableHiveSupport()
    .getOrCreate()
)

#### Read dataset

In [None]:
BINARIES_BUCKET_PATH = "gs://dataproc-metastore-public-binaries/stanford_online_products/"

# Read every file under the prefix (recursively) as a binary file row.
# Only JPEG file names are kept: each path is later sent to Gemini with a JPEG MIME type,
# so any other object under the prefix (e.g. text metadata files, if present) would likely
# be rejected by the API.
binaries_df = (
    spark.read.format("binaryFile")
    .option("recursiveFileLookup", "true")
    .option("pathGlobFilter", "*.{jpg,JPG,jpeg,JPEG}")
    .load(BINARIES_BUCKET_PATH)
)

In [None]:
# Keep only the file paths, capped at 10 rows. `limit` has no ordering here, so these are
# an arbitrary 10 images rather than the "first" ones. Caching lets later actions reuse
# the sample instead of listing the bucket again.
paths_df = (
    binaries_df
    .select("path")
    .limit(10)
)
paths_df.cache()

|                                                                                          path|
|----------------------------------------------------------------------------------------------|
|gs://dataproc-metastore-public-binaries/stanford_online_products/sofa_final/181714736872_0.JPG|
|gs://dataproc-metastore-public-binaries/stanford_online_products/sofa_final/181661485577_1.JPG|
|gs://dataproc-metastore-public-binaries/stanford_online_products/sofa_final/171860974117_1.JPG|
|gs://dataproc-metastore-public-binaries/stanford_online_products/sofa_final/171860974117_2.JPG|
|gs://dataproc-metastore-public-binaries/stanford_online_products/sofa_final/181661485577_0.JPG|

#### Define prompt to get image attributes

In [None]:
# System instructions sent with every Gemini request: the model's role, the tone
# expected for the sales copy, and the required output format.
system_instructions = [
    "You are a retail expert and your job is to generate structured information about products based on the images of these products.",
    "You also know how to write beautiful, elegant and concise product descriptions, based on data about a product.",
    "Respond in the JSON format.",
]

In [None]:
def attributes_prompt():
    """Return the user prompt that asks Gemini for product attributes.

    The prompt lists one question per attribute and shows a worked example of the
    answer as a JSON object. The attribute names are the same keys declared in
    ``response_schema``.

    Returns:
        str: Prompt text to send alongside each product image.
    """
    # Plain (non-f) string: nothing is interpolated, so the JSON example's braces can be
    # written literally. The example is valid JSON, consistent with the instruction to
    # respond in JSON.
    return """
<h5>Instructions</h5>
Analyze the content and generate the following attributes of these products based on the following questions:

product: "What product is this?"
color: "What are the product's colors?"
gender: "Is the product shown in the image most appropriate to be used by men, women, all or other?"
brand: "What is the brand of the product shown in the image? Reply unanswerable if you do not know for sure."
style: "What is the style of the product shown in the image? ex: modern, casual, tech"
material: "What is the material of the product shown in the image? ex: steel, wood, rubber"
purpose: "What is the purpose or usage of this product?"
year: "What is the year of the product? Reply unanswerable if you do not know for sure."
sales_description: "A beautiful, elegant and concise product description."

<h4>Example</h4>
{
  "product": "Brown Fashion Sneakers",
  "color": "Brown",
  "gender": "Women",
  "brand": "unanswerable",
  "style": "Fashion Flat heel",
  "material": "Polyurethane",
  "purpose": "unanswerable",
  "year": "unanswerable",
  "sales_description": "These brown sneakers with white soles are a stylish and comfortable choice for women who want to add a touch of color to their wardrobe. They are made of polyurethane, a durable and lightweight material that will keep your feet comfortable all day long. The flat heel makes them easy to wear for all-day activities, and the lace-up closure ensures a secure fit. These sneakers are perfect for a variety of occasions, from running errands to meeting friends. They can be dressed up or down, depending on your personal style. Pair them with a casual dress or jeans for a relaxed look, or dress them up with a skirt or pants for a more formal look."
}

<h4>Response</h4>
"""

In [None]:
# JSON schema that constrains Gemini's answer: an object whose attributes are all strings
# and all required. Both lists are built from a single sequence of field names.
response_schema = {
    "type": "object",
    "properties": {
        field: {"type": "string"}
        for field in (
            "product",
            "color",
            "gender",
            "brand",
            "style",
            "material",
            "purpose",
            "year",
            "sales_description",
        )
    },
}
response_schema["required"] = list(response_schema["properties"])

#### Define UDF and call Gemini API to generate product attributes

In [None]:
from vertexai.generative_models import GenerativeModel, GenerationConfig, Part, Image, Content, HarmCategory, HarmBlockThreshold

def predict(uri, prompt, system_instructions=system_instructions, response_schema=response_schema, content_type="image/jpg", temperature=1, model_name="gemini-1.5-pro"):
    """Ask Gemini about one image and return the raw text of its JSON answer.

    Used as the body of a Spark UDF, so it is called once per row.

    Args:
        uri: Location of the image, e.g. a ``gs://`` path from ``paths_df``.
        prompt: User prompt sent together with the image.
        system_instructions: System instructions for the model.
        response_schema: Schema that constrains the JSON the model returns.
        content_type: MIME type of the image. The non-standard ``"image/jpg"``
            (the default, kept for compatibility) is sent as ``"image/jpeg"``.
        temperature: Sampling temperature.
        model_name: Gemini model to call.

    Returns:
        str: ``response.text``, requested as ``application/json`` matching
        ``response_schema``.
    """
    # "image/jpg" is not a registered MIME type; the registered type for JPEG is "image/jpeg".
    if content_type == "image/jpg":
        content_type = "image/jpeg"

    model = GenerativeModel(model_name=model_name, system_instruction=system_instructions)

    prompt_content = Content(
        role="user",
        parts=[
            Part.from_uri(uri, content_type),
            Part.from_text(prompt),
        ],
    )

    # Apply the threshold to each concrete harm category. HARM_CATEGORY_UNSPECIFIED is a
    # placeholder enum value rather than a category responses are rated on, so a setting
    # keyed only on it would not relax filtering.
    safety_settings = {
        category: HarmBlockThreshold.BLOCK_ONLY_HIGH
        for category in (
            HarmCategory.HARM_CATEGORY_HATE_SPEECH,
            HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
            HarmCategory.HARM_CATEGORY_HARASSMENT,
            HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
        )
    }

    response = model.generate_content(
        prompt_content,
        generation_config=GenerationConfig(
            temperature=temperature,
            response_mime_type="application/json",
            response_schema=response_schema,
        ),
        safety_settings=safety_settings,
    )

    # NOTE(review): if no usable candidate comes back (e.g. it was blocked), `response.text`
    # may raise, which would fail the Spark task — confirm whether per-row failures should
    # be tolerated instead.
    return response.text

In [None]:
# Wrap `predict` as a Spark UDF that yields the raw JSON text as a string column.
predict_udf = udf(predict, returnType="string")

In [None]:
# Add Gemini's raw JSON answer for each image path. The same prompt text is passed as a
# literal column to every row. This is lazy: Gemini is only called when an action
# (show, toPandas, ...) evaluates the column.
image_metadata_df = paths_df.withColumn(
    "gemini_analysis",
    predict_udf(col("path"), lit(attributes_prompt())),
)

In [None]:
# Cache *before* the first action: show() evaluates the Gemini UDF, and without a cache the
# later cells would call the API again for the same rows (and, with sampling at
# temperature=1 by default, could get different text than what is previewed here).
image_metadata_df.cache()
# Preview 5 rows, truncating each cell to 10 characters.
image_metadata_df.show(5, 10)

In [None]:
# Persist the Gemini results so later actions (JSON parsing, toPandas) reuse them
# instead of calling the API again for the same rows.
image_metadata_df.cache()

In [None]:
from pyspark.sql.functions import from_json, col
from pyspark.sql.types import StructType, StructField, StringType

# Spark schema used to parse Gemini's JSON answer. It is derived from the fields of
# `response_schema` (the schema Gemini's output is constrained to), so the two field lists
# cannot drift apart. Every field is a nullable string, matching the "string" types
# declared there.
schema = StructType(
    [StructField(field_name, StringType(), True) for field_name in response_schema["properties"]]
)

In [None]:
# Gemini is asked for `application/json`, so the text should already be bare JSON.
# Defensively strip only a surrounding Markdown code fence (```json ... ```). Deleting every
# "json" / "```" substring anywhere would corrupt values that happen to contain them.
# The JSON is then parsed with `schema` and flattened into one column per attribute.
df_final = (
    image_metadata_df
    .withColumn(
        "exploded_data",
        from_json(
            regexp_replace(col("gemini_analysis"), r"^\s*```(?:json)?\s*|\s*```\s*$", ""),
            schema,
        ),
    )
    .select(col("path"), col("exploded_data.*"))
)

In [None]:
# Turn each gs://bucket/object path into a browser URL on the
# storage.mtls.cloud.google.com endpoint by dropping the "gs://" scheme.
df_final = df_final.withColumn(
    "url",
    concat(
        lit("https://storage.mtls.cloud.google.com/"),
        regexp_replace(col("path"), "gs://", ""),
    ),
)

In [None]:
# Collect the parsed results to the driver as a pandas DataFrame for display.
# Acceptable here because `paths_df` caps the input at 10 rows; avoid on large results.
df_final.toPandas()