# In [1]:
import os
import posixpath
import warnings
from io import BytesIO
from urllib.parse import urlsplit

import requests
from langchain.output_parsers import PydanticOutputParser
from langchain.prompts import PromptTemplate
from langchain_core.messages import HumanMessage
from langchain_core.prompt_values import ImageURL
from langchain_core.prompts.image import ImagePromptTemplate
from langchain_core.pydantic_v1 import BaseModel, Field, validator
from langchain_google_genai import ChatGoogleGenerativeAI
from PIL import Image, ImageDraw

# Gemini API key, read from the environment so the secret never lives in source.
# Set it before running, e.g. `export GOOGLE_API_KEY=...`.
# NOTE(security): a key was previously committed here in plain text; treat it as
# leaked and revoke/rotate it in the Google Cloud / AI Studio console.
GOOGLE_API_KEY = os.environ.get("GOOGLE_API_KEY")

# In [2]:
# Structured-output schema that PydanticOutputParser fills from the text LLM's reply.
# The field `description=` strings are rendered into the format instructions sent to
# the model, so they are part of the prompt. Documentation lives in `#` comments
# because a class docstring would also be copied into that schema.
class TagsForAxisAspects(BaseModel):
    # One (x, y) pair per suggested box; gen_masks pairs the i-th top-left entry
    # with the i-th bottom-right entry and scales them by the image width/height.
    bounding_box_axis_top_left: list[tuple[float, float]] = Field(
        default=...,  # required field
        description="axis of top-left bounding box via query",
    )
    bounding_box_axis_bottom_right: list[tuple[float, float]] = Field(
        default=...,  # required field
        description="axis of bottom-right bounding box via query",
    )
    
class SentenceToAspect:
    """Convert a free-text description of bounding boxes into ``TagsForAxisAspects``.

    A text-only Gemini model receives the parser's format instructions plus the
    caller's sentence. ``PydanticOutputParser`` then parses the model's reply.
    """

    def __init__(self):
        self.GOOGLE_API_KEY = GOOGLE_API_KEY
        # temperature=0 keeps the text -> structured-output step as repeatable as possible.
        self.llm = ChatGoogleGenerativeAI(
            model='gemini-pro',
            google_api_key=self.GOOGLE_API_KEY,
            temperature=0,
        )

        self.parser = PydanticOutputParser(pydantic_object=TagsForAxisAspects)
        # JSON-schema description of TagsForAxisAspects, baked into the prompt up front.
        format_instructions = self.parser.get_format_instructions()
        self.prompt = PromptTemplate(
            template = """Answer the user query. \n {format_instructions}\n{query}\n
            
            You must convey the input axis of bounding box, which is presented to [upper_left_x, upper_left_y, lower_right_x, lower_right_y] to aspect outputs.
            """,
            input_variables=["query"],
            partial_variables={"format_instructions": format_instructions},
        )

    def query(self, query_sentence: str):
        """Run prompt -> LLM -> parser on *query_sentence*; return the parsed model."""
        pipeline = self.prompt | self.llm | self.parser
        parsed = pipeline.invoke({"query": query_sentence})
        return parsed


class ImageToAspect(SentenceToAspect):
    """Ask a Gemini vision model where bed-sized furniture fits, then parse its answer.

    The vision model (``self.lmm``) answers in free text. The inherited text
    pipeline (``SentenceToAspect.query``) converts that text into a
    ``TagsForAxisAspects`` object.
    """

    def __init__(self):
        # Sets self.GOOGLE_API_KEY, self.llm, self.parser and self.prompt.
        super().__init__()

        # temperature=0 (not `temporary`, which is not a model parameter) so the
        # vision step is as repeatable as the API allows.
        self.lmm = ChatGoogleGenerativeAI(
            model='gemini-pro-vision',
            google_api_key=self.GOOGLE_API_KEY,
            temperature=0,
            top_p=0.2,
        )

    def query_room_img(self, url: str = None):
        """Send the room photo at *url* to the vision model and parse its suggestions.

        Args:
            url: Image URL, passed to the model as an ``image_url`` content part.

        Returns:
            The ``TagsForAxisAspects`` parsed from the model's free-text reply.

        Raises:
            ValueError: If ``url`` is empty or not given.

        Side effects:
            Stores the request in ``self.message`` and the raw model reply in
            ``self.return_sentence_img``.
        """
        if not url:
            raise ValueError("query_room_img() requires an image url")

        # Prompt (Korean), in English: "This is a photo of an empty room. Mask a
        # suitable region for bed-sized furniture. Let's think step by step. Give
        # the box's top-left and bottom-right coordinates numerically and draw the
        # box on the original image. Exclude surrounding objects, include wall and
        # floor as much as possible. Coordinates must be relative values in [0, 1]
        # of the image width/height. Do not mask areas with a door or wardrobe.
        # Give one box first, then keep suggesting further boxes outside it."
        prompt_text = """이것은 빈 방의 사진이야. 너는 여기서 침대 크기의 가구를 배치할 적절한 방 영역의 영역을 masking해야 해.

let's think step by step.

해당 masking에 대한 bounding box의 top-left 좌표와 bottom-right 좌표를 수치로 제시하고, 원본 그림 위에 해당 bounding box를 그려줘.

주변의 사물은 box에서 최대한 exclude 하고, wall과 floor는 bounding box에 최대한 include 시켜줘.

단, 좌표의 수치는 이미지의 width나 height에 대한 0에서 1 사이의 상대값으로 제시해야 돼.

또, 문이나 장농이 있는 영역은 masking하면 안돼.

일단 하나를 제시한 다음, 그 bounding box의 바깥에서 또 다른 추천 box가 있다면 계속해서 제시해줘.
"""
        self.message = HumanMessage(
            content=[
                {'type': 'text', 'text': prompt_text},
                {'type': 'image_url', 'image_url': url},
            ]
        )

        self.return_sentence_img = self.lmm.invoke([self.message])
        # Second pass: the text model restates the reply in the structured schema.
        return self.query(query_sentence=self.return_sentence_img.content)

# In [9]:
def _default_save_stem(url):
    """Return a file stem from the last path segment of *url*.

    The query string, fragment and file extension are dropped. Query strings can
    contain characters such as ``?``, ``&`` and ``=``, and the caller appends its
    own suffixes. Falls back to ``"image"`` when the path has no final segment.
    """
    name = posixpath.basename(urlsplit(url).path)
    stem, _ = posixpath.splitext(name)
    return stem or "image"


def _box_to_pixels(top_left, bottom_right, width, height):
    """Convert one normalized (x, y) corner pair into an ``(x0, y0, x1, y1)`` pixel box.

    The model is asked for values in [0, 1], but nothing enforces that. Each
    coordinate is therefore clamped to [0, 1], and the corners are reordered so
    that x0 <= x1 and y0 <= y1, because ImageDraw.rectangle rejects inverted
    boxes in recent Pillow versions.
    """
    def clamp(value):
        return min(max(float(value), 0.0), 1.0)

    x0, x1 = sorted((clamp(top_left[0]), clamp(bottom_right[0])))
    y0, y1 = sorted((clamp(top_left[1]), clamp(bottom_right[1])))
    return (int(x0 * width), int(y0 * height), int(x1 * width), int(y1 * height))


def gen_masks(url: str, save_fn=None, *, timeout: float = 30.0):
    """Download a room photo, ask the vision model for free regions, and save a mask.

    Two JPEG files are written: ``<save_fn>.jpg`` (the downloaded image) and
    ``<save_fn>_mask.jpg`` (a black image of the same size with every suggested
    box filled white). Paths are relative to the current directory unless
    ``save_fn`` contains a directory.

    Args:
        url: HTTP(S) URL of the room image; it is also passed to the vision model.
        save_fn: Output file stem, without extension. Defaults to the URL's last
            path segment with its query string and extension removed.
        timeout: Seconds to wait for the image download.

    Returns:
        None.

    Raises:
        requests.HTTPError: If the download returns an error status.
    """
    if save_fn is None:
        # Bare stem only: the '.jpg' / '_mask.jpg' suffixes are appended when saving.
        save_fn = _default_save_stem(url)

    res = requests.get(url, timeout=timeout)
    # Give a clear HTTP error, not an opaque PIL decode error, when the host rejects us.
    res.raise_for_status()
    img = Image.open(BytesIO(res.content))
    # JPEG cannot store alpha or palette images (RGBA, P, LA, ...), so normalize first.
    if img.mode not in ("L", "RGB", "CMYK"):
        img = img.convert("RGB")
    img.save(save_fn + '.jpg')

    # The vision model proposes regions; the text model structures them.
    ita = ImageToAspect()
    axis_aspect = ita.query_room_img(url)

    top_lefts = axis_aspect.bounding_box_axis_top_left
    bottom_rights = axis_aspect.bounding_box_axis_bottom_right
    if len(top_lefts) != len(bottom_rights):
        # zip() below pairs corners by index and silently drops the surplus.
        warnings.warn(
            f"model returned {len(top_lefts)} top-left and {len(bottom_rights)} "
            "bottom-right corners; unmatched corners are ignored",
            stacklevel=2,
        )

    # Black canvas; every suggested region is painted white.
    width, height = img.size
    img_mask = Image.new("RGB", (width, height), "black")
    draw = ImageDraw.Draw(img_mask)
    for top_left, bottom_right in zip(top_lefts, bottom_rights):
        draw.rectangle(_box_to_pixels(top_left, bottom_right, width, height), fill="white")

    img_mask.save(save_fn + '_mask.jpg')


# In [10]:
# Sample empty-room photos. Each one yields sample_XX.jpg and sample_XX_mask.jpg.
urls = [
    'https://img.freepik.com/premium-photo/clean-empty-room-background_28504-220.jpg',
    'https://as2.ftcdn.net/v2/jpg/04/48/91/67/1000_F_448916762_6Ld63L05E9VqKaT78irG8GjHASXSSwff.jpg',
    'https://previews.123rf.com/images/photoauris/photoauris1302/photoauris130200015/17922352-3-%EC%B0%A8%EC%9B%90-%EB%B9%88-%EB%B0%A9%EC%9D%84-%EB%A0%8C%EB%8D%94%EB%A7%81.jpg',
    'https://media.istockphoto.com/id/612028082/ko/%EC%82%AC%EC%A7%84/%ED%9D%B0%EC%83%89-%EC%B0%BD%EB%AC%B8%EC%9D%B4-%EC%9E%88%EB%8A%94-%EB%B9%88-%EB%B0%A9.jpg?s=1024x1024&w=is&k=20&c=lloKw1IMYB2ztgVvhBxfSo8OIrLYI-4Xq4EnctTIG1s=',
]

for j, url in enumerate(urls):
    # A zero-padded index keeps the output files in URL order when listed by name.
    gen_masks(url=url, save_fn=f'sample_{j:02d}')