In [1]:
from omegaconf import OmegaConf
import torch, torch.nn.functional as F
from torchvision.utils import make_grid, save_image
from pytorch_lightning import seed_everything
from PIL import Image
from torchvision import transforms
from tqdm import tqdm
import kornia
import os, sys
sys.path.append(os.getcwd()),
sys.path.append('./unipaint/')
sys.path.append('./unipaint/src/clip')
sys.path.append('./unipaint/src/taming-transformers')
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.util import instantiate_from_config, load_model_from_config, ExemplarAugmentor

from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_core.messages import HumanMessage
from langchain_core.prompts.image import ImagePromptTemplate
from langchain.prompts import PromptTemplate
from langchain_core.prompt_values import ImageURL
from langchain_core.pydantic_v1 import BaseModel, Field, validator
from langchain.output_parsers import PydanticOutputParser
from google.cloud import vision
import requests
from io import BytesIO

from PIL import Image, ImageDraw, ImageFont
import requests
from io import BytesIO
import json

#### Setup

In [2]:
# --- Fine-tuning ---
num_iter = 100  # number of fine-tuning iterations used by inpaint(..., tuning=True)
lr = 1e-5       # Adam learning rate for fine-tuning

# --- Model ---
config_path = "unipaint/configs/stable-diffusion/v1-inference.yaml"
ckpt_path = "unipaint/ckpt/sd-v1-4-full-ema.ckpt"  # path to SD checkpoint

# --- Sampling ---
h = w = 512      # working image resolution (pixels)
scale = 24       # classifier-free guidance scale
ddim_steps = 50
ddim_eta = 0.0
n_samples = 2    # images generated per inpaint call

# Set SEED to an int (e.g. 42) for reproducible runs; None keeps runs unseeded.
SEED = None
if SEED is not None:
    seed_everything(SEED)

out_path = "outputs/"
gpu_id = '0'

# --- Credentials ---
# Read from the environment so keys are never committed with the notebook.
GOOGLE_GEMINI_API_KEY = os.environ.get("GOOGLE_GEMINI_API_KEY", "")
GOOGLE_VISION_API_KEY = os.environ.get("GOOGLE_VISION_API_KEY", "")
PID = os.environ.get("GOOGLE_CLOUD_PID", "")

#### Define some useful functions

In [3]:
def tsave(tensor, save_path, **kwargs):
    """Save ``tensor`` to ``save_path`` with torchvision's ``save_image``.

    Always passes ``normalize=True, scale_each=True, value_range=(-1, 1)``, i.e. the
    tensor is treated as [-1, 1] data. Extra keyword arguments (e.g. ``nrow``) are
    forwarded unchanged.
    """
    defaults = dict(normalize=True, scale_each=True, value_range=(-1, 1))
    save_image(tensor, save_path, **defaults, **kwargs)

#### Get bounding boxes and generate masks for sequential inpainting

In [4]:
# Structured LLM answer: the detected object names judged to be furniture/appliances.
# Intentionally no class docstring: pydantic would emit it as the JSON-schema
# "description", which ends up in the parser's format instructions (prompt text).
class TagsForList(BaseModel):
    furniture_list: list[str] = Field(
        ...,
        description="List of furnitures",
    )

class Localizer:
    """Find furniture/appliance bounding boxes in an image given by URL.

    Pipeline: Google Cloud Vision object localization -> a Gemini chat model (via
    LangChain) selects which detected object names are furniture or appliances ->
    the boxes for those names are returned.
    """

    def __init__(self, gemini_api_key: str = None, vision_api_key: str = None, pid: str = None):
        """Store credentials and build the prompt / LLM / parser pieces.

        Args:
            gemini_api_key: Gemini key; defaults to module-level ``GOOGLE_GEMINI_API_KEY``.
            vision_api_key: Cloud Vision key; defaults to ``GOOGLE_VISION_API_KEY``.
            pid: quota project id; defaults to ``PID``.

        Module-level fallbacks are read when the instance is created.
        """
        self.GOOGLE_GEMINI_API_KEY = GOOGLE_GEMINI_API_KEY if gemini_api_key is None else gemini_api_key
        self.GOOGLE_VISION_API_KEY = GOOGLE_VISION_API_KEY if vision_api_key is None else vision_api_key
        self.PID = PID if pid is None else pid
        self.llm = ChatGoogleGenerativeAI(model = 'gemini-pro', google_api_key = self.GOOGLE_GEMINI_API_KEY,
                                         temperature = 0)

        self.parser = PydanticOutputParser(pydantic_object = TagsForList)
        self.prompt = PromptTemplate(
            template = """Answer the user query. \n {format_instructions}\n{query}\n
            You select only the items from the list that can be categorized as furniture or appliances.
            """,
            input_variables = ["query"],
            partial_variables = {"format_instructions" : self.parser.get_format_instructions()}
        )

        # Furniture names returned by the LLM on the most recent ``query`` call.
        self.furnitures = []

    @staticmethod
    def _boxes_from_annotations(objects):
        """Map object name -> (x, y) of vertex 0 and of vertex 2 of its normalized box.

        Vertices 0 and 2 are taken as top-left and bottom-right (assumes the Vision
        API's clockwise-from-top-left vertex order). Annotations with fewer than three
        vertices are skipped so both dicts always have the same keys. Duplicate names
        keep the last box seen.
        """
        top_left, bottom_right = {}, {}
        for obj in objects:
            verts = list(obj.bounding_poly.normalized_vertices)
            if len(verts) < 3:
                continue
            top_left[obj.name] = (verts[0].x, verts[0].y)
            bottom_right[obj.name] = (verts[2].x, verts[2].y)
        return top_left, bottom_right

    def localize_objects(self, url:str, api_key:str = None, pid:str = None):
        """Run Cloud Vision object localization on the image at ``url``.

        Args:
            url: image URL (downloaded with ``requests``).
            api_key: Vision key; defaults to ``self.GOOGLE_VISION_API_KEY``.
            pid: quota project id; defaults to ``self.PID``.

        Returns:
            ``(top_left, bottom_right)`` dicts mapping object name -> normalized (x, y).

        Raises:
            requests.HTTPError: if the image download returns an error status.
        """
        if api_key is None:
            api_key = self.GOOGLE_VISION_API_KEY
        if pid is None:
            pid = self.PID

        # Use the resolved key (the explicit argument must not be ignored).
        client = vision.ImageAnnotatorClient(
            client_options = {"api_key": api_key, "quota_project_id": pid})

        res = requests.get(url, timeout=30)
        res.raise_for_status()
        img = vision.Image(content = res.content)

        objects = client.object_localization(image = img).localized_object_annotations
        return self._boxes_from_annotations(objects)

    def query(self, url:str):
        """Localize objects at ``url`` and keep only those the LLM labels furniture/appliance.

        Returns:
            ``(top_left, bottom_right)`` dicts restricted to the selected names. Names the
            LLM returns that were not detected are ignored instead of raising ``KeyError``.
        """
        upper_left_axis, bottom_right_axis = self.localize_objects(url = url)

        query_sentence = str(list(upper_left_axis.keys()))

        chain = self.prompt | self.llm | self.parser
        llm_output = chain.invoke({"query" : query_sentence})
        self.furnitures = llm_output.furniture_list

        # The LLM may rephrase or invent names; keep only ones that were actually detected.
        kept = [key for key in self.furnitures if key in upper_left_axis and key in bottom_right_axis]
        upper_left_axis = {key: upper_left_axis[key] for key in kept}
        bottom_right_axis = {key: bottom_right_axis[key] for key in kept}

        return upper_left_axis, bottom_right_axis


def gen_masks(url:str, bounding_box_axis_top_left:dict, bounding_box_axis_bottom_right:dict, save_fn = None, expanding_factor:float = 1.05):
    """Download ``url`` and write one rectangular mask image per object box.

    Files written (``save_fn`` is a path prefix without extension):
      * ``{save_fn}.jpg``: the downloaded image, converted to RGB;
      * ``{save_fn}_{name}_mask.jpg``: a white box on black, one per distinct box;
      * ``{save_fn}.json``: ``{name: mask_path}`` for every mask written.

    Args:
        url: image URL.
        bounding_box_axis_top_left: ``{name: (x, y)}`` top-left corners, presumably
            normalized to [0, 1] (they are scaled by the image size), e.g. as
            returned by ``Localizer.query``.
        bounding_box_axis_bottom_right: ``{name: (x, y)}`` bottom-right corners, same
            convention. Names missing here are skipped.
        save_fn: output path prefix; defaults to the URL's file name without extension.
            Its directory is created if needed.
        expanding_factor: enlarges each box (see the note on ``rect``).

    Returns:
        None.

    Raises:
        requests.HTTPError: if the download returns an error status.
    """
    from urllib.parse import urlparse

    if save_fn is None:
        # Strip the extension: ".jpg" is appended below (previously produced "x.jpg.jpg").
        save_fn = os.path.splitext(os.path.basename(urlparse(url).path))[0] or "image"

    save_dir = os.path.dirname(save_fn)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    res = requests.get(url, timeout=30)
    res.raise_for_status()
    # JPEG cannot store alpha or palette images; normalize to RGB before saving.
    img = Image.open(BytesIO(res.content)).convert("RGB")
    img.save(save_fn + '.jpg')

    width, height = img.size
    seen_rects = set()
    dict_json = {}

    # Pair corners by name rather than by dict position.
    for obj_nm, top_left_tup in bounding_box_axis_top_left.items():
        bottom_right_tup = bounding_box_axis_bottom_right.get(obj_nm)
        if bottom_right_tup is None:
            continue

        # NOTE: the top-left corner is divided and the bottom-right multiplied by
        # ``expanding_factor``, i.e. the box is scaled about the image origin: boxes
        # farther from the top-left corner grow more, and the bottom-right corner may
        # land past the image edge.
        rect = (int(top_left_tup[0] * width / expanding_factor),
                int(top_left_tup[1] * height / expanding_factor),
                int(bottom_right_tup[0] * width * expanding_factor),
                int(bottom_right_tup[1] * height * expanding_factor))

        if rect in seen_rects:  # several names can share one box; write it once
            continue
        seen_rects.add(rect)

        img_mask = Image.new("RGB", (width, height), "black")
        ImageDraw.Draw(img_mask).rectangle(rect, fill = "white")
        mask_fn = save_fn + f'_{obj_nm}' + '_mask.jpg'
        img_mask.save(mask_fn)
        dict_json[obj_nm] = mask_fn

    with open(save_fn + ".json", "w") as json_file:
        json.dump(dict_json, json_file)

    return None


In [5]:
# Interior photos to localize furniture in and build masks for.
urls = ['https://img.maisonkorea.com/2020/03/msk_5e65a1179ab47.jpg',
        'https://cdn.ggumim.co.kr/cache/star/600/20201222151726uFqauJF8wD.jpg',
        'https://img.maisonkorea.com/2020/03/msk_5e659f7c6816a.jpg']

# Credentials: environment variables win; values already set in the setup cell are
# kept instead of being overwritten with empty strings.
GOOGLE_GEMINI_API_KEY = os.environ.get("GOOGLE_GEMINI_API_KEY", GOOGLE_GEMINI_API_KEY)
GOOGLE_VISION_API_KEY = os.environ.get("GOOGLE_VISION_API_KEY", GOOGLE_VISION_API_KEY)
PID = os.environ.get("GOOGLE_CLOUD_PID", PID or "104910716278680354156")

os.makedirs('./temp', exist_ok=True)  # masks and downloaded images are written here

lc = Localizer()

for j, url in enumerate(urls):
    bb_dict_top_left, bb_dict_bottom_right = lc.query(url)
    gen_masks(url=url,
              bounding_box_axis_top_left=bb_dict_top_left,
              bounding_box_axis_bottom_right=bb_dict_bottom_right,
              save_fn=f'./temp/sample_{j:02d}')

#### Exemplar inpainting

In [6]:
# NOTE: if fine-tuning runs out of CUDA memory, try
#   os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
# before the first model is moved to the GPU (the allocator reads it at CUDA init).

def inpaint(config_path, ckpt_path, item_type, org_image_path:str, mask_path:str,
            ref_path:str, out_path:str, item_desc:str = "", save_fn:str = "", tuning = False,
            n_iter:int = None):
    """Exemplar-guided inpainting of the masked region of an image with latent diffusion.

    Loads a fresh model from ``config_path``/``ckpt_path`` whose personalization
    config is initialized from ``item_type`` and the exemplar ``ref_path``. If
    ``tuning`` is set, it then fine-tunes the UNet with two losses:
      * a foreground loss on augmented exemplars, conditioned on ``item_desc``;
      * an unconditional background loss on the original image outside the mask.
    Finally it DDIM-samples ``n_samples`` images and saves them as one grid.

    Args:
        config_path: OmegaConf YAML for the model.
        ckpt_path: model checkpoint.
        item_type: initializer word for the personalization config.
        org_image_path: image to inpaint.
        mask_path: mask image; pixels > 0 after bilinear resizing are inpainted.
        ref_path: exemplar image of the item to insert.
        out_path: output directory (created if missing).
        item_desc: text prompt used for fine-tuning and sampling.
        save_fn: output file name inside ``out_path``. If empty, it is derived from
            ``org_image_path`` and ``item_type`` (an empty name would make the output
            path a directory and fail only after sampling).
        tuning: fine-tune the UNet on the exemplar/background before sampling.
        n_iter: fine-tuning iterations; defaults to the module-level ``num_iter``.

    Returns:
        The model; after fine-tuning it is switched back to eval mode.

    Also reads module-level settings: gpu_id, lr, h, w, ddim_steps, n_samples,
    scale, ddim_eta, num_iter.
    """
    if n_iter is None:
        n_iter = num_iter
    if not save_fn:
        stem = os.path.splitext(os.path.basename(org_image_path))[0]
        save_fn = f"{stem}_{item_type}.jpg"

    # --- config: personalize with the item word and the exemplar image ---
    config = OmegaConf.load(config_path)
    config.model.params.personalization_config.params.initializer_words = [item_type]
    config.model.params.personalization_config.params.initializer_images = [ref_path]

    device = torch.device(f"cuda:{gpu_id}") if torch.cuda.is_available() else torch.device("cpu")

    # --- model ---
    model = load_model_from_config(config, ckpt_path, device)
    sampler = DDIMSampler(model)
    os.makedirs(out_path, exist_ok=True)

    # --- inputs (PIL's resize takes (width, height)) ---
    image = Image.open(org_image_path).convert('RGB').resize((w, h), Image.Resampling.BILINEAR)
    mask = Image.open(mask_path).convert('L').resize((w, h), Image.Resampling.BILINEAR)
    image_ref = Image.open(ref_path).convert('RGB').resize((w, h), Image.Resampling.BILINEAR)

    # VAE decode (clamped to [-1, 1], detached) and encode.
    D = lambda _x: torch.clamp(model.decode_first_stage(_x), min=-1, max=1).detach()
    E = lambda _x: model.get_first_stage_encoding(model.encode_first_stage(_x))
    # Images -> batch of 1 in [-1, 1]; masks -> batch of 1 with 1.0 wherever the pixel is > 0.
    img_transforms = transforms.Compose([transforms.ToTensor(), transforms.Lambda(lambda x: x.unsqueeze(0) * 2. - 1)])
    mask_transforms = transforms.Compose([transforms.ToTensor(), transforms.Lambda(lambda x: (x.unsqueeze(0) > 0).float())])

    def C(_txt, enable_emb_manager=False):
        """Text conditioning; gradients are enabled only when the embedding manager is used."""
        _txt = [_txt] if isinstance(_txt, str) else _txt
        with torch.enable_grad() if enable_emb_manager else torch.no_grad():
            return model.get_learned_conditioning(_txt, enable_emb_manager)

    x = img_transforms(image).to(device)
    m = mask_transforms(mask).to(device)
    x_ref = img_transforms(image_ref).to(device)

    x_in = x * (1 - m)  # masked region set to 0 (mid-gray in the [-1, 1] range)
    z_xm = E(x_in)
    z_m = F.interpolate(m, size=(h // 8, w // 8))  # mask at latent resolution
    z_m = kornia.morphology.dilation(z_m, torch.ones((3, 3), device=device))

    # Attention masks for the multi-scale UNet layers, keyed by str(size * size).
    attn_mask = {}
    for attn_size in [64, 32, 16, 8]:
        attn_mask[str(attn_size**2)] = (F.interpolate(m, (attn_size, attn_size), mode='bilinear'))[0, 0, ...]

    # --- optional fine-tuning ---
    if tuning:
        optimizer = torch.optim.Adam(list(model.model.parameters()), lr=lr)
        exemplar_augmentor = ExemplarAugmentor(mask=mask)

        model.train()
        pbar = tqdm(range(n_iter), desc='Fine-tune the model')
        for _ in pbar:
            optimizer.zero_grad()

            # Foreground: conditioned noise prediction on an augmented exemplar, inside its mask.
            x_reff, x_reff_mask = exemplar_augmentor(x_ref)
            z_reff = E(x_reff)
            z_reff_mask = F.interpolate(x_reff_mask, size=(h // 8, w // 8), mode='bilinear')

            t_emb = torch.randint(model.num_timesteps, (1,), device=device)
            c_ref = C(item_desc, True).detach()
            uc = C("").detach()
            noise1 = torch.randn_like(z_xm)
            z_ref_t = model.q_sample(z_reff, t_emb, noise=noise1)
            pred_noise_ref = model.apply_model(z_ref_t, t_emb, c_ref)
            loss_ref = F.mse_loss(pred_noise_ref * z_reff_mask, noise1 * z_reff_mask)

            # Background: unconditional noise prediction on the masked original, outside the dilated mask.
            t_emb2 = torch.randint(model.num_timesteps, (1,), device=device)
            noise2 = torch.randn_like(z_xm)
            z_bg_t = model.q_sample(z_xm, t_emb2, noise=noise2)
            pred_noise_bg = model.apply_model(z_bg_t, t_emb2, uc)
            loss_bg = F.mse_loss(pred_noise_bg * (1 - z_m), noise2 * (1 - z_m))

            loss = loss_bg + loss_ref
            loss.backward()
            optimizer.step()

            losses_dict = {"loss": loss, "loss_bg": loss_bg, "loss_ref": loss_ref}
            pbar.set_postfix({k: v.item() for k, v in losses_dict.items()})

        model.eval()  # never sample with training-mode layers

    # --- generate ---
    with torch.no_grad(), torch.autocast(device.type):
        samples_z, _ = sampler.sample(S=ddim_steps, batch_size=n_samples, shape=[4, h // 8, w // 8],
                                      conditioning=C(item_desc, True).repeat(n_samples, 1, 1),
                                      blend_interval=[0, 1],
                                      x0=z_xm.repeat(n_samples, 1, 1, 1),
                                      mask=z_m.repeat(n_samples, 1, 1, 1),
                                      attn_mask=attn_mask,
                                      x_T=None,
                                      unconditional_guidance_scale=scale,
                                      eta=ddim_eta,
                                      verbose=False)

    with torch.no_grad():
        samples = D(samples_z)
    tsave(samples, os.path.join(out_path, save_fn), nrow=n_samples)

    return model

In [7]:
# Fine-tune on the exemplar desk, then inpaint it into sample_00's "Desk" mask region.
fine_tuned_model = inpaint(
    item_type='desk',
    config_path=config_path,
    ckpt_path=ckpt_path,
    org_image_path='./temp/sample_00.jpg',
    mask_path='./temp/sample_00_Desk_mask.jpg',
    ref_path='./data/desk/thumbnails_nobg/50.jpg',
    item_desc='modern black desk in cozy room',
    out_path='./outputs/',
    save_fn='sample_00_st_01_D.jpg',
    tuning=True,
)

LatentDiffusion: Running in eps-prediction mode
DiffusionWrapper has 859.52 M params.
making attention of type 'vanilla' with 512 in_channels
Working with z of shape (1, 4, 32, 32) = 4096 dimensions.
making attention of type 'vanilla' with 512 in_channels
initializer_words is desk


Fine-tune the model: 100%|█████████| 150/150 [2:22:24<00:00, 56.96s/it, loss=0.0104, loss_bg=0.00782, loss_ref=0.00263]


Data shape for DDIM sampling is (2, 4, 64, 64), eta 0.0
Running DDIM Sampling with 50 timesteps


DDIM Sampler: 100%|████████████████████████████████████████████████████████████████████| 50/50 [04:50<00:00,  5.82s/it]


 blend_interval = [0, 1] 



In [9]:
# Baseline for comparison: same inputs, no fine-tuning, prompt '#'.
# inpaint() loads a brand-new model, so release the previous one first; otherwise two
# full SD models are held in memory while the new one loads.
# NOTE(review): the name is kept for compatibility, but this model is NOT fine-tuned.
import gc
fine_tuned_model = None
gc.collect()
torch.cuda.empty_cache()

fine_tuned_model = inpaint(
    item_type='desk',
    config_path=config_path,
    ckpt_path=ckpt_path,
    org_image_path='./temp/sample_00.jpg',
    mask_path='./temp/sample_00_Desk_mask.jpg',
    ref_path='./data/desk/thumbnails_nobg/50.jpg',
    item_desc='#',
    out_path='./outputs/',
    save_fn='sample_00_st_01_E.jpg',
    tuning=False,
)

LatentDiffusion: Running in eps-prediction mode
DiffusionWrapper has 859.52 M params.
making attention of type 'vanilla' with 512 in_channels
Working with z of shape (1, 4, 32, 32) = 4096 dimensions.
making attention of type 'vanilla' with 512 in_channels
initializer_words is desk
Data shape for DDIM sampling is (2, 4, 64, 64), eta 0.0
Running DDIM Sampling with 50 timesteps


DDIM Sampler: 100%|████████████████████████████████████████████████████████████████████| 50/50 [03:16<00:00,  3.92s/it]


 blend_interval = [0, 1] 

