# 3. Creating functions for handling retail user interactions

This notebook walks you through implementing RAG with the database that we have created using Claude 3 models. 

Firstly, we create functions to deal with product images and use an LLM to select the best product image for display based on the product title. This helps to address scenarios where products come with images of material closeups and/or size charts, along with images of models wearing the product, and we want to pick the images from which users can get the most insights on the product at one glance as the thumbnail image.

Secondly, we create a function that uses an LLM to construct and refine a search query for use with a Bedrock knowledge base, based on user input and feedback as well as selected product information. This allows users to give input based on visual cues, e.g. "I want something like this shirt but in black".

3.0. [Set up](#3.0)

3.1. [Test and use the Bedrock Knowledge Base Retrieve API](#3.1)

3.2. [Create functions to handle product images](#3.2)

3.3. [Translate user input into more relevant product search results](#3.3)

## 3.0 <a id="3.0">Set up</a>

In [None]:
# run this cell to upgrade to the latest version of boto3 if required, and restart the kernel
!pip install --upgrade --force --quiet botocore boto3

In [None]:
# Reload imported modules before each cell runs, so edits to the util/*.py files
# written by the %%writefile cells below are picked up without restarting the kernel
%load_ext autoreload
%autoreload 2

In [None]:
import boto3
import sagemaker

import pandas as pd
import ast
import json

<div class="alert alert-block alert-warning">

IMPORTANT! Please copy and paste the required information for your <b>RDS Aurora PostgreSQL database</b> in the cell below.
    
</div>

In [None]:
# Derive the default S3 bucket, region and account ID from the SageMaker session
sess = sagemaker.Session()
bucket = sess.default_bucket()
region = sess.boto_region_name
accountid = sess.account_id()
# Path names for the product metadata and the knowledge base data; the second one
# is used to build the S3 URI on the next line
product_db_data_path = 'amazon-reviews-fashion-metadata'
bedrock_kb_data_path = 'bedrock-kb-data'
bedrock_kb_datasource_uri = f's3://{bucket}/{bedrock_kb_data_path}/'

# Placeholders: replace each '<TODO>' with the values of your RDS Aurora PostgreSQL database
database_identifier='<TODO>'
database_arn='<TODO>'
database_secret_arn='<TODO>'
database_name='<TODO>'

In [None]:
# Make sure the util/ directory exists before the %%writefile cells below write into it
%mkdir -p util

In [None]:
# Required by the local-notebook helper cells below (get_img_url_from_df and
# get_title_from_df), which look up product titles and image URLs in this dataframe.
items = pd.read_csv(
    'items.txt',
    sep='|',
    index_col=False,
    names=['asin', 'title', 'brand', 'price', 'description', 'image'],
)
# Preview the first rows only instead of rendering the whole table
items.head()

## 3.1 <a id="3.1">Test and use the Bedrock Knowledge Base Retrieve API</a>

### Run a test query

<div class="alert alert-block alert-warning">

IMPORTANT! Please copy and paste the <b>Bedrock Knowledge Base ID</b> for the knowledge base that you are using in the cell below.
    
</div>

In [None]:
# Placeholder: replace '<TODO>' with the ID of the Bedrock knowledge base you are using
bedrock_kb_id = '<TODO>'

In [None]:
from util.bedrockkb import bedrock_kb_retrieve

# Sample query against the knowledge base. no_kb_results is passed as the third
# argument (presumably the number of results to return -- confirm in util/bedrockkb.py)
search_query = 'shoes for women'
no_kb_results = 3
search_list = bedrock_kb_retrieve(bedrock_kb_id, search_query, no_kb_results)

In [None]:
# Inspect the raw output of the retrieve call
search_list

### Create a wrapper to extract a list of asin from the output of the Bedrock knowledge base Retrieve API call

In [None]:
%%writefile util/getasinlist.py

def get_asin_list(search_list):
    """Extract product ASINs from Bedrock knowledge base retrieval results.

    Each result carries the location of its source document under
    ``['location']['s3Location']['uri']``; the ASIN is taken to be the file
    name of that URI without its extension.

    Parameters
    ----------
    search_list : list of dict, dict, or str
        A list of retrieval results, a single retrieval result, or a single
        S3 URI string.

    Returns
    -------
    list of str or None
        One ASIN per result (a one-element list for a single result or URI),
        or None when the input is of any other type.
    """
    def _asin_from_uri(uri):
        # 's3://bucket/prefix/B00ABC.txt' -> 'B00ABC'
        return uri.split('/')[-1].split('.')[0]

    if isinstance(search_list, list):
        return [_asin_from_uri(item['location']['s3Location']['uri'])
                for item in search_list]
    if isinstance(search_list, dict):
        # A single retrieval result still yields a list, like the list case
        return [_asin_from_uri(search_list['location']['s3Location']['uri'])]
    if isinstance(search_list, str):
        # A bare string is treated as the S3 URI itself
        return [_asin_from_uri(search_list)]
    return None

In [None]:
# For local notebook processing only

# Get product image URL(s) from dataframe using asin
def get_img_url_from_df(asin, df):
    """Return the image URLs of the product with the given asin as a list.

    The 'image' cell of the first row whose 'asin' equals ``asin`` is wrapped
    in ``['`` ... ``']`` and parsed as a Python list literal, so a cell holding
    several URLs separated by ``', '`` yields several list entries.

    Prints 'No image found.' and returns None when no row matches or the cell
    cannot be parsed.
    """
    try:
        record = df[df['asin'] == asin].to_dict('records')[0]
        return ast.literal_eval("['" + record['image'] + "']")
    except Exception:
        # Still best-effort, but KeyboardInterrupt/SystemExit are no longer swallowed
        print('No image found.')
        return None

# Look up a product title in the dataframe by asin
def get_title_from_df(asin, df):
    """Return the title of the first row whose 'asin' equals ``asin``.

    The title is returned as a string wrapped in ``['`` and ``']``.
    Raises IndexError when no row matches.
    """
    matches = df.loc[df['asin'] == asin]
    first_row = matches.to_dict('records')[0]
    wrapped_title = "['" + first_row['title'] + "']"
    return str(wrapped_title)

In [None]:
from util.getasinlist import get_asin_list
# Extract the ASINs from the results of the test query above
asin_list=get_asin_list(search_list)
asin_list

## 3.2 <a id="3.2">Create functions to handle product images</a>

### Create a function to display images returned from product image URLs

In [None]:
%%writefile util/gallery.py
from IPython.display import HTML, Image

def gallery(images, row_height='200px'):
    """Show a set of images in a gallery that flexes with the width of the notebook.

    Parameters
    ----------
    images: list of str
        URLs of the images to display. Each URL is used as the image source
        and is also shown as the caption underneath the image.

    row_height: str
        CSS height value assigned to every image. Defaults to '200px' so that
        all rows in the gallery have equal height; pass 'auto' to show images
        with their native dimensions.

    Returns
    -------
    IPython.display.HTML or None
        The gallery markup, or None when ``images`` is empty or None.
    """
    # Local import keeps the block self-contained
    from html import escape

    if not images:
        return None

    # Values are interpolated into attributes and text, so escape them to keep
    # characters such as quotes, '<' or '&' in a URL from breaking the markup
    height = escape(str(row_height), quote=True)
    figures = []
    for image in images:
        src = escape(str(image), quote=True)
        figures.append(f'''
                <figure style="margin: 5px !important;">
                  <img src="{src}" style="height: {height}">
                  <figcaption style="font-size: 0.6em">{src}</figcaption>
                </figure>
            ''')
    figures_html = ''.join(figures)
    return HTML(data=f'''
            <div style="display: flex; flex-flow: row wrap; text-align: center;">
            {figures_html}
            </div>
        ''')

In [None]:
# Look up the image URL(s) of the first retrieved product in the local dataframe
asin = asin_list[0]
url=get_img_url_from_df(asin, items)
url

In [None]:
from util.gallery import gallery

# Render all image URLs of the product side by side
gallery(url)

### Add helpers to process images from URL and/or handle bytes to string conversion

In [None]:
%%writefile util/imagehelpers.py

import requests
from PIL import Image
from io import BytesIO
import base64

# Screen an image URL list for blank or invalid URLs
def filter_image_url(img_list, timeout=10):
    """Return the URLs from ``img_list`` that can be fetched and opened as an image.

    Parameters
    ----------
    img_list : list of str or None
        Candidate image URLs.
    timeout : int or float
        Seconds to wait for each HTTP request before giving up on that URL.

    Returns
    -------
    list of str or None
        The URLs that passed the check, in their original order, or None when
        ``img_list`` is empty/None or no URL passed.
    """
    if not img_list:
        return None

    images = []
    for url in img_list:
        try:
            # The timeout stops one unresponsive host from hanging the caller;
            # the context manager releases the connection after the check
            with requests.get(url, stream=True, timeout=timeout) as response:
                # Image.open raises if the stream is not a recognisable image
                Image.open(response.raw)
            images.append(url)
        except Exception:
            # Best-effort screening: any failure simply drops this URL
            continue
    return images or None

# Download the content behind an image URL and return it as raw bytes
def url_image_processing(imgurl):
    """Fetch ``imgurl`` with an HTTP GET and return the response body as bytes."""
    return BytesIO(requests.get(imgurl).content).getvalue()

# Wrap raw file bytes in an in-memory binary stream
def get_bytesio_from_bytes(image_bytes):
    """Return a new BytesIO object initialised with ``image_bytes``."""
    return BytesIO(image_bytes)

# Encode raw file bytes as a base64 text string
def get_base64_from_bytes(image_bytes):
    """Return the base64 encoding of ``image_bytes`` as a str."""
    stream = get_bytesio_from_bytes(image_bytes)
    encoded = base64.b64encode(stream.getvalue())
    return encoded.decode("utf-8")

### Create an image picker to pick the most suitable image based on the product title

This is to address situations where the merchants upload images of size charts instead of attractive product images. The recommended approach for doing this in practice would be to perform a batch inference for each product and store the best image in a database table rather than doing it using real time inference with an API call.

In [None]:
%%writefile util/pickimg.py

import boto3
import ast, json
from util.imagehelpers import *

def pick_img(text_input, img_list):
    """Ask Claude 3 Sonnet on Bedrock to pick the image that best matches a text.

    Parameters
    ----------
    text_input : str
        Text (e.g. the product title) that the chosen image should match.
    img_list : list of str
        Candidate image URLs. URLs rejected by ``filter_image_url`` are
        dropped before the model is called.

    Returns
    -------
    str or None
        The URL of the chosen image, or None when no candidate URL is usable
        or the model's answer does not contain a valid image number.
    """
    import re  # local import: only used to parse the model's answer

    model_id = "anthropic.claude-3-sonnet-20240229-v1:0"
    accept = "application/json"
    content_type = "application/json"

    # Screen the candidates first so that no client or request is created
    # when there is nothing to choose from
    img_list = filter_image_url(img_list)
    if img_list is None:
        return None

    def set_default(obj):
        # json.dumps fallback: serialise sets as lists, reject anything else
        if isinstance(obj, set):
            return list(obj)
        raise TypeError

    system_prompt = """
    Based on the user's text and images, select and output the identifier for the image that best matches the text and is most visually appealing to shoppers.
    Please output only the image number enclosed in a python array and nothing else."""

    content = [
        {
            "type": "text",
            "text": text_input
        }
    ]
    for number, img in enumerate(img_list, start=1):
        image_data = get_base64_from_bytes(url_image_processing(img))
        # Label every image so that the "image number" the prompt asks for is
        # unambiguous and 1-based
        content.append({
            "type": "text",
            "text": f"Image {number}:"
        })
        content.append({
            "type": "image",
            "source": {
                "type": "base64",
                # NOTE(review): declared as JPEG regardless of the actual image format
                "media_type": "image/jpeg",
                "data": image_data,
            }
        })

    messages = [{
        "role": "user",
        "content": content
    }]

    body = json.dumps({
        "anthropic_version": "bedrock-2023-05-31",
        "max_tokens": 10,
        "temperature": 0,
        "system": system_prompt,
        "messages": messages
    }, default=set_default)

    bedrock = boto3.client(service_name='bedrock-runtime')
    response = bedrock.invoke_model(
        body=body,
        modelId=model_id,
        accept=accept,
        contentType=content_type,
    )
    response_body = json.loads(response.get('body').read())['content']
    answer = response_body[0]['text']

    # Accept "[2]", "2", " [2]\n" and similar; anything without a number is a miss
    match = re.search(r'-?\d+', answer)
    if match is None:
        return None
    choice = int(match.group())

    # Image numbers are 1-based, so the valid answers are 1..len(img_list)
    if 1 <= choice <= len(img_list):
        return img_list[choice - 1]
    return None

In [None]:
from util.pickimg import pick_img

# Pick the image that best matches the product title
title = get_title_from_df(asin, items)
ans = pick_img(title, url)
# A plain dict of strings is already JSON-compatible, so no json round trip is needed
product = {"asin": asin, "title": title, "image": ans}
print(product)
# pick_img returns None when no image could be selected; show nothing in that case
gallery([ans] if ans else [])

## 3.3 <a id="3.3">Translate user input into more relevant product search results</a>

### Create a query refiner

The query refiner can help translate user input and/or feedback into a query for interacting with a Bedrock knowledge base. 
- It can interpret user requests and construct a search query
- It can factor in user feedback based on a selected product (e.g. a user can refer to an item that was suggested and ask for something similar in brighter colors), by taking the image URL provided to the user as input.
- It can consider the recent search queries (query history) for constructing the new query.

In [None]:
%%writefile util/refinequery.py

import boto3
import json, logging, re
from util.imagehelpers import *

def refine_query(input_text,query_history=None,product=None,log_level='ERROR'):
    """Turn user input into a search query with Claude 3 Sonnet on Bedrock.

    Parameters
    ----------
    input_text : str
        The current user request.
    query_history : list of str, optional
        Earlier queries, most recent first. At most the first four are sent
        to the model as context.
    product : dict, optional
        The product the user refers to, with the keys 'title' and 'image'
        (an image URL). When 'image' is empty or None only the title is sent.
    log_level : str
        Logging level name for this module's logger, e.g. 'INFO' to log the
        messages sent to the model.

    Returns
    -------
    tuple of (str, list of str)
        The refined query, and the updated history: the refined query followed
        by ``query_history``, truncated to three entries.
    """
    model_id = "anthropic.claude-3-sonnet-20240229-v1:0"
    accept = "application/json"
    content_type = "application/json"

    history = list(query_history) if query_history else []

    content = [
        {
            "type": "text",
            "text": f"CURRENT USER QUERY: {input_text}"
        }
    ]
    for query in history[:4]:
        content.append({
            "type": "text",
            "text": f"PAST USER QUERY: {query}"
        })

    if product:
        title = product['title']
        img = product['image']
        # Download before touching `content` so that a failed download leaves it unchanged
        image_data = get_base64_from_bytes(url_image_processing(img)) if img else None
        content.append({
            "type": "text",
            "text": f"IMAGE title: {title}"
        })
        # The image is optional: an upstream image picker may not have found one
        if image_data is not None:
            content.append({
                "type": "image",
                "source": {
                    "type": "base64",
                    # NOTE(review): declared as JPEG regardless of the actual image format
                    "media_type": "image/jpeg",
                    "data": image_data,
                }
            })

    messages = [{
        "role": "user",
        "content": content
    }]

    system_prompt = """
    Your task is to help the user find fashion apparel that they like.
    Please only construct and output a descriptive text search query optimized for search with pgvector. Do not output anything else.
    Be as descriptive as reasonable and use information from the user input to identify features like age group, gender, color, material etc to help find fashion apparel. 
    Do not assume or hallucinate.
    Always refine the search query by using the CURRENT USER QUERY to update your output, factoring in PAST USER QUERY and IMAGE."""

    # Configure only this module's logger; the root logger level is left alone
    # so that calling this function does not change logging for everything else
    logging.basicConfig()
    logger = logging.getLogger(__name__)
    logger.setLevel(logging.getLevelName(log_level))
    logger.info(f"Messages {messages}")

    def set_default(obj):
        # json.dumps fallback: serialise sets as lists, reject anything else
        if isinstance(obj, set):
            return list(obj)
        raise TypeError

    body = json.dumps({
        "anthropic_version": "bedrock-2023-05-31",
        "max_tokens": 4000,
        "temperature": 0,
        "system": system_prompt,
        "messages": messages
    }, default=set_default)

    bedrock = boto3.client(service_name='bedrock-runtime')
    response = bedrock.invoke_model(
        body=body,
        modelId=model_id,
        accept=accept,
        contentType=content_type,
    )
    response_body = json.loads(response.get('body').read())['content']
    output_query = str(response_body[0]['text'])

    def clean_text(text):
        # Drop anything that looks like an HTML/XML tag and trim whitespace.
        # Punctuation is kept on purpose: it can carry meaning in a search query.
        return re.sub(r'<.*?>', '', text).strip()

    output_query = clean_text(output_query)

    # Newest query first; keep the three most recent entries
    query_list = ([output_query] + history)[:3]

    return output_query, query_list

### Testing the query refiner for creating an initial search query

In [None]:
from util.refinequery import refine_query

search_query="I want to buy stretchy breathable sports wear for exercise."

# First turn: no query history and no selected product yet
query, query_history=refine_query(search_query)
query

In [None]:
# Take the first result for the refined query and look it up in the local dataframe
search_asin = get_asin_list(bedrock_kb_retrieve(bedrock_kb_id, query, no_kb_results))[0]
search_img = get_img_url_from_df(search_asin, items)
search_title = get_title_from_df(search_asin, items)
print(search_title)
gallery(search_img)

In [None]:
# Choose one image among the product's images based on its title, then show it
best_img = pick_img(search_title, search_img)
gallery([best_img])

### Incorporating user feedback into the search history and search results 

In [None]:
# The product the user is giving feedback on. A plain dict is already
# JSON-compatible, so no json round trip is needed.
search_product = {"asin": search_asin, "title": search_title, "image": best_img}

search_query1="I want shirts for men."
# Follow-up turn: pass the previous query history and the selected product as context
query1, query_history1=refine_query(search_query1, query_history, search_product)
query1

### Creating a function to get product information from RDS using the RDS Data API

In [None]:
%%writefile util/getinfo.py

import boto3
import ast, json, re

def get_info_from_db(asin, database_arn, database_secret_arn, database_name):
    """Fetch asin, title and image URLs of one product through the RDS Data API.

    Parameters
    ----------
    asin : str
        Product identifier. Every non-word character is removed before the lookup.
    database_arn : str
        Passed as ``resourceArn`` to the RDS Data API call.
    database_secret_arn : str
        Passed as ``secretArn`` to the RDS Data API call.
    database_name : str
        Passed as ``database`` to the RDS Data API call.

    Returns
    -------
    dict
        'asin' and 'title' as JSON-encoded strings (i.e. including the
        surrounding double quotes) and 'image' as a list of image URLs.

    Raises
    ------
    IndexError
        If the products table has no row for the given asin.
    """
    # Keep word characters only; e.g. the quotes of a JSON-encoded asin are removed
    asin = re.sub(r'\W+', '', asin)

    rdsdata = boto3.client('rds-data')

    # Bind the asin as a named parameter instead of formatting it into the SQL text
    response = rdsdata.execute_statement(
        resourceArn=database_arn,
        secretArn=database_secret_arn,
        sql="SELECT asin, title, image FROM products WHERE asin = :asin;",
        parameters=[{'name': 'asin', 'value': {'stringValue': asin}}],
        database=database_name,
    )

    records = response['records']
    if not records:
        # Same exception type as indexing an empty result, but with a usable message
        raise IndexError(f"No product found for asin '{asin}'")

    info = records[0]
    # The stored value is wrapped in [' ... '] and parsed as a list literal, so a
    # value holding several URLs separated by "', '" becomes several list entries
    info_image = ast.literal_eval("['"+json.dumps(info[2]['stringValue']).strip('\\\'"')+"']")

    info_dict = {'asin':json.dumps(info[0]['stringValue']),
                 'title':json.dumps(info[1]['stringValue']),
                 'image':info_image}

    return info_dict

In [None]:
from util.getinfo import get_info_from_db

asin_list = get_asin_list(bedrock_kb_retrieve(bedrock_kb_id, query1, no_kb_results))
# Take the first result that is not the product the user was just shown
asin = list(filter(lambda x :x!=search_asin, asin_list))[0]
asin_result = get_info_from_db(asin, database_arn, database_secret_arn, database_name)
# Replace the list of image URLs with the single image chosen by pick_img
asin_result['image'] = pick_img(asin_result['title'], asin_result['image'])
print(asin_result)
gallery([asin_result['image']])