In [1]:
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_core.messages import HumanMessage
from langchain_core.prompts.image import ImagePromptTemplate
from langchain.prompts import PromptTemplate
from langchain_core.prompt_values import ImageURL
from langchain_core.pydantic_v1 import BaseModel, Field, validator
from langchain.output_parsers import PydanticOutputParser
from google.cloud import vision
import requests
from io import BytesIO

GOOGLE_GEMINI_API_KEY = ""
GOOGLE_VISION_API_KEY = ""
PID = "ai-solution-genai-hack2skill"

urls = ['https://img.maisonkorea.com/2020/03/msk_5e65a1179ab47.jpg',
       'https://upload.wikimedia.org/wikipedia/commons/thumb/c/cf/Wohnzimmer_2007.jpg/1200px-Wohnzimmer_2007.jpg',
       'https://static.hyundailivart.co.kr/upload_mall/board/ME00000044/B200042482/tplt/0000215270_20220224105746834.jpg']

In [2]:
class TagsForList(BaseModel):
    
    furniture_list: list[str] = Field(..., description = "List of furnitures");

class Localizer:

    def __init__(self):

        self.GOOGLE_GEMINI_API_KEY = GOOGLE_GEMINI_API_KEY
        self.GOOGLE_VISION_API_KEY = GOOGLE_VISION_API_KEY
        self.llm = ChatGoogleGenerativeAI(model = 'gemini-pro', google_api_key = self.GOOGLE_GEMINI_API_KEY,
                                         temperature = 0)

        self.parser = PydanticOutputParser(pydantic_object = TagsForList)
        self.prompt = PromptTemplate(
            template = """Answer the user query. \n {format_instructions}\n{query}\n
            You select only the items from the list that can be categorized as furniture or appliances.
            """,
            input_variables = ["query"],
            partial_variables = {"format_instructions" : self.parser.get_format_instructions()}
        )

    def localize_objects(self, url, api_key = GOOGLE_API_KEY, pid = PID):

        client = vision.ImageAnnotatorClient(\
            client_options = {"api_key": self.GOOGLE_VISION_API_KEY, "quota_project_id": pid})
    
        res = requests.get(url)
        img = vision.Image(content = res.content)
        
        objects = client.object_localization(image = img).localized_object_annotations
    
        obj_list = []
        upper_left_axis_list = []
        bottom_right_axis_list = []
        
        for object_ in objects:
            
            obj_list.append(object_.name)
    
            for i, vertex in enumerate(object_.bounding_poly.normalized_vertices):
                
                if i == 0:
                    upper_left_axis_list.append((vertex.x, vertex.y))
    
                if i == 2:
                    bottom_right_axis_list.append((vertex.x, vertex.y))
    
        return dict(zip(obj_list, upper_left_axis_list)), dict(zip(obj_list, bottom_right_axis_list))
    
    def query(self, url):

        upper_left_axis, bottom_right_axis = self.localize_objects(url = url)

        query_sentence = str(list(upper_left_axis.keys()))

        chain = self.prompt | self.llm | self.parser
        llm_output = chain.invoke({"query" : query_sentence})
        furnitures = llm_output.furniture_list

        upper_left_axis = {key: upper_left_axis[key] for key in furnitures}
        bottom_right_axis = {key: bottom_right_axis[key] for key in furnitures}

        return upper_left_axis


NameError: name 'GOOGLE_API_KEY' is not defined

In [None]:
output_bbgs = []
loc = Localizer()
for url in urls:
    output_bbg = loc.query(url)
    output_bbgs.append(output_bbg)

In [None]:
output_bbgs[2]

In [None]:
from PIL import Image, ImageDraw, ImageFont
import requests
from io import BytesIO

def draw_rectangles(url:str, upper_left_dict:dict, bottom_right_dict:dict):
    
    res = requests.get(url)
    img = Image.open(BytesIO(res.content))
    draw = ImageDraw.Draw(img)

    width = img.width; height = img.height
    font = ImageFont.load_default()    

    for item in upper_left_dict.keys():

        outline_color = 'red'; text_color = 'blue'
        
        upper_left_axis = (int(upper_left_dict[item][0] * width), int(upper_left_dict[item][1] * height))
        bottom_right_axis = (int(bottom_right_dict[item][0] * width), int(bottom_right_dict[item][1] * height))
        draw.rectangle([upper_left_axis, bottom_right_axis], outline = outline_color, width = 2)
        draw.text(upper_left_axis, item, fill = text_color, font = font)

    return img

In [None]:
for url, output_bbg in zip(urls, output_bbgs):
    img = draw_rectangles(url = url,
                   upper_left_dict = output_bbg[0], bottom_right_dict = output_bbg[1])
    img.save('vai_'+url.split('/')[-1])