In [3]:
import argparse
import os
import copy
import gradio as gr

import numpy as np
import torch
from PIL import Image, ImageDraw, ImageFont


# Grounding DINO
import GroundingDINO.groundingdino.datasets.transforms as T
from GroundingDINO.groundingdino.models import build_model
from GroundingDINO.groundingdino.util import box_ops
from GroundingDINO.groundingdino.util.slconfig import SLConfig
from GroundingDINO.groundingdino.util.utils import clean_state_dict, get_phrases_from_posmap

# segment anything
from segment_anything import build_sam, SamPredictor 
import cv2
import numpy as np
import matplotlib.pyplot as plt

# diffusers
import PIL
import requests
import torch
from io import BytesIO
from stable_diffusion_masked_diffedit import StableDiffusionMaskedDiffeditPipeline
from diffusers import DDIMScheduler

# chatgpt
from chatgpt import call_chatgpt

# blip2
from transformers import AutoProcessor, Blip2ForConditionalGeneration

from operator import itemgetter

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
def load_model(model_config_path, model_checkpoint_path, device):
    """Build a GroundingDINO model from a config file and load its checkpoint.

    Args:
        model_config_path: Path of the config file, parsed with ``SLConfig.fromfile``.
        model_checkpoint_path: Path of a checkpoint whose ``"model"`` entry holds
            the state dict.
        device: Device identifier recorded on the config before the model is built.

    Returns:
        The model switched to eval mode. The checkpoint is deserialized onto the
        CPU; this function does not itself move the model to ``device``.
    """
    cfg = SLConfig.fromfile(model_config_path)
    cfg.device = device
    model = build_model(cfg)

    state = torch.load(model_checkpoint_path, map_location="cpu")
    # strict=False: key mismatches are reported (and printed) instead of raising.
    print(model.load_state_dict(clean_state_dict(state["model"]), strict=False))

    model.eval()
    return model

In [5]:
# Paths and device used by every later cell. `maskedit` reads `config_file`,
# `grounded_checkpoint`, `sam_checkpoint`, `device`, `blip_processor` and
# `blip_model` as globals, so this cell must run before it is called.
config_file = 'GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py'  # change the path of the model config file
grounded_checkpoint = 'groundingdino_swint_ogc.pth'  # change the path of the model
sam_checkpoint = 'sam_vit_h_4b8939.pth'
# Use the GPU when one is visible to torch, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Load the GroundingDINO detector and the BLIP-2 captioning model.
model = load_model(config_file, grounded_checkpoint, device=device)
blip_processor = AutoProcessor.from_pretrained('Salesforce/blip2-flan-t5-xl')
# BLIP-2 weights are loaded in float16 and then moved to `device`.
blip_model = Blip2ForConditionalGeneration.from_pretrained('Salesforce/blip2-flan-t5-xl', torch_dtype=torch.float16).to(device) 

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


final text_encoder_type: bert-base-uncased


Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.seq_relationship.weight', 'cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


_IncompatibleKeys(missing_keys=[], unexpected_keys=['label_enc.weight'])


Loading checkpoint shards: 100%|██████████| 2/2 [00:11<00:00,  5.60s/it]


In [6]:
def load_image(image_path, height=512, width=512,):
    """Load an image and prepare both a PIL copy and a normalized tensor.

    Args:
        image_path: Location of the image, passed to ``Image.open``.
        height: Forwarded to ``preprocess_image``.
        width: Forwarded to ``preprocess_image``.

    Returns:
        ``(image_pil, image)``: the preprocessed PIL image, and the result of
        applying the resize / to-tensor / normalize transform chain to it.
    """
    to_model_input = T.Compose([
        T.RandomResize([800], max_size=1333),
        T.ToTensor(),
        T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])

    rgb = Image.open(image_path).convert("RGB")
    image_pil = preprocess_image(rgb, height, width)

    # The transforms take (image, target) pairs; there is no target here.
    image_tensor, _ = to_model_input(image_pil, None)  # 3, h, w
    return image_pil, image_tensor


def get_grounding_output(model, image, caption, box_threshold, text_threshold, with_logits=True, device="cpu", return_max=False):
    """Run the grounding model on one image and keep the confident detections.

    Args:
        model: Callable detector exposing ``.to(device)`` and ``.tokenizer``; it is
            invoked as ``model(image[None], captions=[caption])`` and must return a
            mapping with ``"pred_logits"`` and ``"pred_boxes"``.
        image: Image tensor without a batch dimension (one is added here).
        caption: Text prompt. It is lower-cased, stripped and terminated with ".".
        box_threshold: A query is kept when its highest sigmoid logit exceeds this.
        text_threshold: Per-token threshold used to extract the matched phrase.
        with_logits: Append the (truncated) best score to every phrase.
        device: Device the model and the image are moved to for inference.
        return_max: Return only the single highest-scoring detection.

    Returns:
        ``(boxes, phrases)``: a CPU tensor of kept boxes (one row per detection)
        and the list of phrases. When nothing passes ``box_threshold`` both are
        empty, including when ``return_max`` is set (this used to raise
        ``ValueError`` from ``max()`` on an empty sequence).
    """
    caption = caption.lower().strip()
    if not caption.endswith("."):
        caption = caption + "."

    model = model.to(device)
    image = image.to(device)
    with torch.no_grad():
        outputs = model(image[None], captions=[caption])
    logits = outputs["pred_logits"].cpu().sigmoid()[0]  # (nq, 256)
    boxes = outputs["pred_boxes"].cpu()[0]  # (nq, 4)

    # Boolean-mask indexing already yields new tensors, so no clone is needed.
    keep = logits.max(dim=1)[0] > box_threshold
    logits_filt = logits[keep]  # num_filt, 256
    boxes_filt = boxes[keep]  # num_filt, 4

    tokenizer = model.tokenizer
    tokenized = tokenizer(caption)

    pred_phrases = []
    scores = []
    for logit in logits_filt:
        score = logit.max().item()
        phrase = get_phrases_from_posmap(logit > text_threshold, tokenized, tokenizer)
        scores.append(score)
        if with_logits:
            pred_phrases.append(phrase + f"({str(score)[:4]})")
        else:
            pred_phrases.append(phrase)

    if return_max:
        if not scores:
            # Nothing detected: hand back the empty results and let the caller decide.
            return boxes_filt, pred_phrases
        # First index wins on ties, matching max() over enumerate().
        best = max(range(len(scores)), key=scores.__getitem__)
        return boxes_filt[best].unsqueeze(0), [pred_phrases[best]]

    return boxes_filt, pred_phrases

def show_mask(mask, ax, random_color=False, opacity=0.6):
    """Overlay a mask on a matplotlib-style axes as a translucent RGBA image.

    Args:
        mask: Array whose last two dimensions are (height, width); it is reshaped
            to ``(h, w, 1)`` and multiplied by the overlay colour.
        ax: Object providing ``imshow``.
        random_color: Draw the RGB part of the colour at random instead of using
            the fixed blue.
        opacity: Alpha channel of the overlay. It now applies to random colours
            as well (that branch used a hard-coded 0.6, which is also the default).
    """
    if random_color:
        rgb = np.random.random(3)
    else:
        rgb = np.array([30/255, 144/255, 255/255])
    color = np.concatenate([rgb, np.array([opacity])], axis=0)

    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)


def show_box(box, ax, label):
    """Draw a labelled green rectangle outline for one box.

    Args:
        box: Indexable whose first four entries are read as (x0, y0, x1, y1).
        ax: Axes receiving the rectangle patch and the text.
        label: Text placed at the (x0, y0) corner.
    """
    left, top, right, bottom = box[0], box[1], box[2], box[3]
    outline = plt.Rectangle(
        (left, top),
        right - left,
        bottom - top,
        edgecolor='green',
        facecolor=(0,0,0,0),  # fully transparent fill: outline only
        lw=2,
    )
    ax.add_patch(outline)
    ax.text(left, top, label)
    
def adjust_encoding_ratio(instructions, feedbacks, cur_encoding_ratio):
    """Ask ``call_chatgpt`` for an encoding ratio given the current one.

    Args:
        instructions: Iterable of instruction texts; each gets a sentence stating
            the current encoding ratio appended.
        feedbacks: Passed through to ``call_chatgpt`` as ``feedbacks``.
        cur_encoding_ratio: Value interpolated into every instruction.

    Returns:
        Whatever ``call_chatgpt`` returns for the augmented instructions.
    """
    template = '{}. The current encoding ratio is {}.'
    full_instructions = [template.format(text, cur_encoding_ratio) for text in instructions]
    return call_chatgpt(full_instructions, feedbacks=feedbacks)

def preprocess_image(image, height=512, width=512, left=0, right=0, top=0, bottom=0):
    """Crop margins, centre-crop to a square and resize to ``width`` x ``height``.

    Args:
        image: A file path (str), a numpy array, or anything ``np.array`` accepts
            (for example a PIL image).
        height: Output height in pixels.
        width: Output width in pixels.
        left, right, top, bottom: Number of pixels to trim from each side before
            the centre crop. Each value is clamped so at least one pixel remains.

    Returns:
        A PIL image of size ``(width, height)``.
    """
    if isinstance(image, str):
        image = np.array(Image.open(image))
    elif not isinstance(image, np.ndarray):
        image = np.array(image)

    if image.ndim == 3:
        image = image[:, :, :3]  # drop an alpha channel if there is one
        h, w, _ = image.shape
    else:
        h, w = image.shape

    left = min(left, w - 1)
    right = min(right, w - left - 1)
    # Vertical clamp depends on the height only; it previously subtracted `left`.
    top = min(top, h - 1)
    bottom = min(bottom, h - top - 1)
    image = image[top:h - bottom, left:w - right]

    h, w = image.shape[:2]
    if h < w:
        offset = (w - h) // 2
        image = image[:, offset:offset + h]
    elif w < h:
        offset = (h - w) // 2
        image = image[offset:offset + w]

    # PIL's resize takes (width, height); the arguments used to be swapped.
    return Image.fromarray(image).resize((width, height))

In [8]:
def _blip2_answer(image_pil, text):
    """Return BLIP-2's decoded, stripped generation for ``text`` on ``image_pil``."""
    inputs = blip_processor(image_pil, text=text, return_tensors="pt").to(device, torch.float16)
    generated_ids = blip_model.generate(**inputs, max_new_tokens=20)
    return blip_processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()


@torch.no_grad()
def maskedit(image_path, user_instructions, is_blip2_description, encoding_ratio, mask_mode, 
             return_max=True, box_threshold=0.3, text_threshold=0.25, is_show_box=True,
             is_show_mask=None):
    """Edit the region of an image selected by a natural-language instruction.

    Steps: optional BLIP-2 description, prompt generation through
    ``call_chatgpt``, box detection with GroundingDINO, mask prediction with SAM,
    then the masked DiffEdit Stable Diffusion pipeline. Reads the notebook
    globals ``config_file``, ``grounded_checkpoint``, ``sam_checkpoint``,
    ``device``, ``blip_processor`` and ``blip_model``.

    Args:
        image_path: Path of the input image.
        user_instructions: The edit instruction given to ``call_chatgpt``.
        is_blip2_description: Also give ``call_chatgpt`` a BLIP-2 description.
        encoding_ratio: Forwarded to the DiffEdit pipeline.
        mask_mode: ``'merge'`` unions all predicted masks; any other value uses
            the first predicted mask.
        return_max: Keep only the highest-scoring detection.
        box_threshold: Detection threshold for ``get_grounding_output``.
        text_threshold: Phrase threshold for ``get_grounding_output``.
        is_show_box: Draw boxes and labels on the preview figure.
        is_show_mask: Draw the preview figure. ``None`` (default) falls back to a
            notebook-level ``is_show_mask`` variable, or ``False`` if it is unset.

    Returns:
        ``(edited_image, mask_pil, prompts, image_pil)`` where ``prompts`` is
        ``[prompt_inversion, prompt]`` and ``image_pil`` is the preprocessed input.

    Raises:
        ValueError: If no box is detected for the segmentation prompt.
    """
    if is_show_mask is None:
        # Backward compatibility: this flag used to be read from a notebook global.
        is_show_mask = globals().get("is_show_mask", False)

    # load image
    image_pil, image = load_image(image_path, height=512, width=512,)

    if is_blip2_description:
        # First ask for the medium, then let BLIP-2 complete "<Medium> of ...".
        medium = _blip2_answer(image_pil, "Is this a photo, a painting, a drawing, or other kind of arts?")
        blip_prompt_2 = "{} of".format(medium.capitalize())
        description = "{} {}.".format(blip_prompt_2, _blip2_answer(image_pil, blip_prompt_2))
        descriptions = [description]
        print('Blip2 description: {}'.format(description))
    else:
        descriptions = None

    instructions = [user_instructions]
    det_prompt, prompt_inversion, prompt, _ = call_chatgpt(instructions, descriptions=descriptions)
    print('Segmentation prompt: {}'.format(det_prompt))
    print('Inverted prompt: {}'.format(prompt_inversion))
    print('Editing prompt: {}'.format(prompt))

    prompts = [prompt_inversion, prompt]

    # load model
    model = load_model(config_file, grounded_checkpoint, device=device)

    # run grounding dino model
    boxes_filt, pred_phrases = get_grounding_output(
        model, image, det_prompt, box_threshold, text_threshold, device=device, return_max=return_max,
    )
    if boxes_filt.size(0) == 0:
        # Fail early with a clear message rather than an IndexError on masks[0][0].
        raise ValueError(
            'No region matching "{}" was detected (box_threshold={}).'.format(det_prompt, box_threshold)
        )

    # initialize SAM
    predictor = SamPredictor(build_sam(checkpoint=sam_checkpoint).to(device))
    # Segment exactly the image that was used for detection. Re-reading the file
    # with cv2.imread can yield a differently oriented image (it applies EXIF
    # orientation, PIL does not), which would misalign the boxes.
    image = np.array(image_pil)
    predictor.set_image(image)

    # Boxes arrive as normalized (cx, cy, w, h); convert to pixel (x0, y0, x1, y1).
    W, H = image_pil.size
    boxes_filt = boxes_filt * torch.Tensor([W, H, W, H])
    boxes_filt[:, :2] -= boxes_filt[:, 2:] / 2
    boxes_filt[:, 2:] += boxes_filt[:, :2]

    boxes_filt = boxes_filt.cpu()
    transformed_boxes = predictor.transform.apply_boxes_torch(boxes_filt, image.shape[:2]).to(device)

    masks, _, _ = predictor.predict_torch(
        point_coords = None,
        point_labels = None,
        boxes = transformed_boxes,
        multimask_output = False,
    )

    if is_show_mask:
        # draw output image
        plt.figure(figsize=(10, 10))
        plt.imshow(image)
        for mask in masks:
            show_mask(mask.cpu().numpy(), plt.gca(), random_color=True)
        if is_show_box:
            for box, label in zip(boxes_filt, pred_phrases):
                show_box(box.numpy(), plt.gca(), label)

        plt.axis('off')

    # Masked diffedit pipeline
    if mask_mode == 'merge':
        # Union of all masks: a pixel is set when any mask covers it.
        masks = torch.sum(masks, dim=0).unsqueeze(0) > 0
    mask = masks[0][0].cpu().numpy() # simply choose the first mask, which will be refine in the future release
    mask_pil = Image.fromarray(mask)

    scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False,
                              set_alpha_to_one=False)
    pipe = StableDiffusionMaskedDiffeditPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", scheduler=scheduler)
    pipe.to(device)

    edited_image = pipe(prompt, image=image_pil, mask_image=mask_pil, prompt_inversion=prompt_inversion,
                        encoding_ratio=encoding_ratio, height=512, width=512, ).images[0]
    edited_image = edited_image.resize((512, 512))

    return edited_image, mask_pil, prompts, image_pil

In [28]:
# Input / output locations for the example run.
image_name = "cat_dog"
image_path = './example_input/{}.jpg'.format(image_name)
output_image_path = "./output_image/{}.jpg".format(image_name)
output_mask_path = './output_mask/{}_mask.jpg'.format(image_name)

# What to edit, and whether BLIP-2 should describe the image first.
user_instructions = "Change the cat to a dog"
is_blip2_description = True
is_show_mask = True

# Editing strength and detection thresholds.
encoding_ratio = 0.5
box_threshold = 0.3
text_threshold = 0.25

# "max" keeps only the best-scoring detection; any other mode keeps them all.
mask_mode = "max"
return_max = mask_mode == "max"


In [None]:
# Run the full edit using the settings from the configuration cell.
output_image, output_mask, prompts, input_image = maskedit(
    image_path=image_path,
    user_instructions=user_instructions,
    is_blip2_description=is_blip2_description,
    encoding_ratio=encoding_ratio,
    mask_mode=mask_mode,
    return_max=return_max,
    box_threshold=box_threshold,
    text_threshold=text_threshold,
)