In [None]:
import os

# huggingface_hub reads HF_ENDPOINT once, when it is first imported (transitively,
# via diffusers/transformers below), so the mirror must be set before those imports.
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"

import math

import torch
from IPython.display import display
from PIL import Image, ImageOps
from tqdm.notebook import tqdm

from diffusers import DDIMScheduler, DDIMInverseScheduler
from transformers import logging

from pipeline_stable_diffusion_grounded_instruct_pix2pix import StableDiffusionInstructPix2PixPipeline

logging.set_verbosity_error()

In [None]:
def load_pil_image(image_path, resolution=512):
    """Load an image as RGB and resize/crop it to multiple-of-64 dimensions.

    The image is first scaled so its longer side would be ``resolution``. The
    shorter side is then rounded *up* to a multiple of 64. The longer side is
    scaled by the same factor and rounded *down* to a multiple of 64.
    ``ImageOps.fit`` produces the final size, cropping as needed.

    Args:
        image_path: path of the image file to open.
        resolution: target size of the longer side before 64-alignment.

    Returns:
        A PIL RGB image whose width and height are both multiples of 64.
    """
    from fractions import Fraction

    # Close the underlying file handle; convert() returns a new, fully loaded image.
    with Image.open(image_path) as raw_image:
        image = raw_image.convert("RGB")

    width, height = image.size
    short_side = min(width, height)
    long_side = max(width, height)

    # Exact rational arithmetic: with floats, x * (k / x) can come out just
    # below/above k, and floor/ceil then shift a side by a whole 64 px.
    scale = Fraction(resolution) / long_side
    target_short = math.ceil(short_side * scale / 64) * 64
    factor = Fraction(target_short, short_side)

    width = math.floor(width * factor / 64) * 64
    height = math.floor(height * factor / 64) * 64
    image = ImageOps.fit(image, (width, height), method=Image.Resampling.LANCZOS)
    return image

In [None]:
# Device + diffusion pipeline setup.
# (The external mask extractor is constructed in a later cell, once `device` exists.)
device = 'cuda:0'

# Number of DDIM steps used by both the forward scheduler and the inverse scheduler.
num_timesteps = 100

# Instruct-Pix2Pix weights in half precision, with the safety checker disabled.
pipeline = StableDiffusionInstructPix2PixPipeline.from_pretrained(
    "timbrooks/instruct-pix2pix",
    torch_dtype=torch.float16,
    safety_checker=None,
).to(device)

# Build the inverse DDIM scheduler from the forward scheduler's config so the
# two share the same configuration.
forward_config = pipeline.scheduler.config
pipeline.inverse_scheduler = DDIMInverseScheduler.from_config(
    forward_config,
    set_alpha_to_zero=False,
)

pipeline.scheduler.set_timesteps(num_timesteps)
pipeline.inverse_scheduler.set_timesteps(num_timesteps)

In [None]:
def inference(pipeline, image_pil, instruction,
              image_guidance_scale, text_guidance_scale, seed, blending_range,
              *, mask_extractor=None, verbose=None, num_inference_steps=None, device=None):
    """Run one grounded Instruct-Pix2Pix edit on ``image_pil``.

    Steps:
    1. Get a mask for ``(image_pil, instruction)`` from the external mask extractor.
    2. Run ``pipeline.invert`` with ``blending_range`` as ``inv_range``.
    3. Run the editing pipeline with that mask as ``src_mask``.

    Args:
        pipeline: the grounded Instruct-Pix2Pix pipeline (must provide ``invert``
            and be callable).
        image_pil: input PIL image.
        instruction: natural-language edit instruction.
        image_guidance_scale: forwarded as ``image_guidance_scale``.
        text_guidance_scale: forwarded to the pipeline as ``guidance_scale``.
        seed: seed for the ``torch.Generator``; ``None`` skips ``manual_seed``.
        blending_range: ``[start, end]`` timesteps, forwarded as ``inv_range``.
        mask_extractor: object with ``get_external_mask``; defaults to the
            notebook-global ``mask_extractor``.
        verbose: forwarded to ``get_external_mask``; defaults to the
            notebook-global ``verbose`` (``False`` if it is not defined).
        num_inference_steps: defaults to the notebook-global ``num_timesteps``.
        device: device for the ``torch.Generator``; defaults to the
            notebook-global ``device``.

    Returns:
        The first image in the pipeline output's ``images`` list.
    """
    # Fall back to notebook-level globals so existing positional calls keep working.
    notebook_ns = globals()
    if mask_extractor is None:
        mask_extractor = notebook_ns["mask_extractor"]
    if verbose is None:
        verbose = notebook_ns.get("verbose", False)
    if num_inference_steps is None:
        num_inference_steps = notebook_ns["num_timesteps"]
    if device is None:
        device = notebook_ns["device"]

    # The second return value (the chosen noun phrase) is not needed here.
    external_mask_pil, _ = mask_extractor.get_external_mask(image_pil, instruction, verbose=verbose)

    # The return value is unused; the call is kept because the grounded pipeline
    # presumably retains inversion state for the editing pass — TODO confirm in
    # pipeline_stable_diffusion_grounded_instruct_pix2pix.
    pipeline.invert(instruction, image_pil, num_inference_steps=num_inference_steps,
                    inv_range=blending_range)

    # NOTE(review): a freshly constructed torch.Generator starts from torch's fixed
    # default seed, so seed=None is still deterministic; use generator.seed() if a
    # random seed is intended.
    generator = torch.Generator(device)
    if seed is not None:
        generator.manual_seed(seed)

    output = pipeline(instruction, src_mask=external_mask_pil, image=image_pil,
                      guidance_scale=text_guidance_scale, image_guidance_scale=image_guidance_scale,
                      num_inference_steps=num_inference_steps, generator=generator)
    return output.images[0]

In [None]:
# External mask extractor: inference() calls its get_external_mask(image, instruction)
# to obtain the mask passed to the pipeline as src_mask.
# It is built on the same device as the diffusion pipeline.
from external_mask_extractor import ExternalMaskExtractor

mask_extractor = ExternalMaskExtractor(
    device=device,
)

In [None]:
# The default values of Instruct-Pix2Pix are a good starting point,
# but you can often get better results with a higher guidance_scale.

##! Use our method: inference() passes a mask from mask_extractor to the pipeline as src_mask.

verbose = True
image_path = './test_img/test1.jpg'

# Example instructions, each exercising a different kind of grounding:
# edit_instruction = 'turn apples into watermelon'
# edit_instruction = 'turn the third apple into watermelon'   # position
# edit_instruction = 'turn the red apple into watermelon'   # color
# edit_instruction = 'turn the second largest apple into watermelon'   # size
# edit_instruction = 'turn the second red apple into watermelon'   # color + position
edit_instruction = 'turn the second largest red apple into egg'   # color + size

image_guidance_scale = 1.5
guidance_scale = 7.5
# Blending timesteps are defined w.r.t. num_train_steps (=1000).
start_blending_at_tstep = 100
end_blending_at_tstep = 1
blending_range = [start_blending_at_tstep, end_blending_at_tstep]
seed = 42

image = load_pil_image(image_path)
# display() renders inline in the notebook; PIL's Image.show() opens an external viewer instead.
display(image)

edited_image = inference(pipeline, image, edit_instruction, image_guidance_scale,
                         guidance_scale, seed, blending_range)
display(edited_image)


In [None]:
# The default values of Instruct-Pix2Pix are a good starting point,
# but you can often get better results with a higher guidance_scale.

##! use Instruct-Pix2Pix model
# NOTE(review): this cell still calls inference(), so it runs the same mask-extraction +
# grounded pipeline path as the cell above, not a plain Instruct-Pix2Pix baseline.
# It also reuses `image`, `seed` and `blending_range` from the previous cell.

image_guidance_scale = 1.5
guidance_scale = 7.5
edit_instruction = 'make this picture look like a sketch'   # whole-image style edit (no object reference)

display(image)
edited_image = inference(pipeline, image, edit_instruction, image_guidance_scale,
                         guidance_scale, seed, blending_range)
display(edited_image)
