In [None]:
!pip install git+https://github.com/huggingface/transformers -q
!pip install git+https://github.com/huggingface/diffusers.git -q
!pip install gradio -q

In [None]:
# Log in to the Hugging Face Hub from inside the notebook.
# NOTE(review): the inpainting pipeline further down is loaded with
# use_auth_token=True, which presumably relies on the credentials stored here.
from huggingface_hub import notebook_login

notebook_login()

In [None]:
from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation
from diffusers import StableDiffusionInpaintPipeline
from PIL import Image
import requests

import cv2
import torch
import matplotlib.pyplot as plt

In [None]:
# The preprocessor and the segmentation model are loaded from the same
# CLIPSeg checkpoint, so the identifier is named once and reused.
CLIPSEG_CHECKPOINT = "CIDAS/clipseg-rd64-refined"

processor = CLIPSegProcessor.from_pretrained(CLIPSEG_CHECKPOINT)
model = CLIPSegForImageSegmentation.from_pretrained(CLIPSEG_CHECKPOINT)

In [None]:
# Device the inpainting pipeline runs on, and the checkpoint it is loaded from.
device = "cuda"
model_path = "runwayml/stable-diffusion-inpainting"

# Load the "fp16" revision as float16 weights, then move the pipeline to the device.
# NOTE(review): newer diffusers releases appear to prefer `variant="fp16"` and
# `token=...` over `revision="fp16"` / `use_auth_token=True` — confirm against the
# installed version before changing.
pipe = StableDiffusionInpaintPipeline.from_pretrained(
    model_path,
    torch_dtype=torch.float16,
    revision="fp16",
    use_auth_token=True,
)
pipe = pipe.to(device)

In [None]:
def create_mask(image, prompt):
  """Build a black/white inpainting mask for the part of `image` matching `prompt`.

  CLIPSeg scores the image against the text prompt; the sigmoid of those scores
  is rendered through matplotlib, converted to grayscale, thresholded, and
  finally inverted, so pixels that passed the threshold end up black and all
  other pixels end up white.

  Args:
    image: PIL image to segment.
    prompt: Text describing the object to locate in the image.

  Returns:
    A single-channel PIL image whose pixels are either 0 or 255.
  """
  import io

  import numpy as np

  inputs = processor(text=[prompt], images=[image], padding="max_length", return_tensors="pt")
  with torch.no_grad():
    outputs = model(**inputs)

  # NOTE(review): depending on the transformers version the logits for a single
  # prompt may carry a leading batch dimension of 1; collapse to the 2-D map
  # that imsave expects. For an already 2-D tensor this is a no-op.
  probs = torch.sigmoid(outputs.logits)
  probs = probs.reshape(probs.shape[-2:]).cpu().numpy()

  # Render with matplotlib's defaults as before, but into an in-memory PNG
  # instead of a shared "mask.png" on disk, so overlapping requests cannot
  # overwrite each other's mask and no stray file is left behind.
  buffer = io.BytesIO()
  plt.imsave(buffer, probs, format="png")
  rendered = cv2.imdecode(np.frombuffer(buffer.getvalue(), dtype=np.uint8), cv2.IMREAD_COLOR)

  gray_image = cv2.cvtColor(rendered, cv2.COLOR_BGR2GRAY)
  _, bw_image = cv2.threshold(gray_image, 100, 255, cv2.THRESH_BINARY)

  # The former "fix color format" cvtColor call was dropped: its result was
  # discarded and its input is single-channel, so it never affected the mask.

  # Invert so the thresholded (matched) region is black and the rest is white.
  mask = cv2.bitwise_not(bw_image)
  return Image.fromarray(mask)


In [24]:
def generate_image(image, product_name, target_name, seed=22, guidance_scale=16):
  """Inpaint `image` so the product appears together with `target_name`.

  A mask is derived from `product_name` via `create_mask`, image and mask are
  resized to 512x512, and the inpainting pipeline is run with a seeded
  generator so the same inputs reproduce the same output.

  Args:
    image: PIL image of the product (as delivered by the Gradio image input).
    product_name: Text describing the product; used for masking and in the prompt.
    target_name: Text describing what the product should be shown with.
    seed: Seed for the random generator; change it to get different results.
    guidance_scale: Guidance scale forwarded to the pipeline.

  Returns:
    The first generated image from the pipeline output.

  Raises:
    ValueError: If no image was provided.
  """
  if image is None:
    # Without this guard a click before uploading fails with an opaque
    # AttributeError on `None.resize`.
    raise ValueError("Please upload a product photo before generating.")

  # Uploads may arrive as RGBA / grayscale / palette images; normalise once so
  # both the segmentation step and the pipeline receive 3-channel input.
  image = image.convert("RGB")

  mask = create_mask(image, product_name).resize((512, 512))
  image = image.resize((512, 512))

  # "photography" was previously misspelled in the prompt text.
  prompt = 'a photo of a ' + product_name + ' with ' + target_name + ' product photography'

  # Seed on the same device the pipeline was moved to instead of a literal "cuda".
  generator = torch.Generator(device=device).manual_seed(seed)

  result = pipe(
      prompt=prompt,
      image=image,
      mask_image=mask,
      guidance_scale=guidance_scale,
      generator=generator,
      num_images_per_prompt=1,
  )
  return result.images[0]
  

In [10]:
import gradio as gr

In [None]:
# Two-column layout: product photo and text inputs on the left,
# generated image on the right.
with gr.Blocks() as demo:
    gr.Markdown("# Advertise better with AI")

    with gr.Row():
        with gr.Column():
            input_image = gr.Image(label="Upload your product's photo", type="pil")
            product_name = gr.Textbox(label="Describe your product")
            target_name = gr.Textbox(label="Where do you want to put your product?")
            image_button = gr.Button("Generate")

        with gr.Column():
            image_output = gr.Image()

    # Clicking "Generate" feeds the three inputs to generate_image and shows
    # the returned image in the right-hand column.
    image_button.click(
        generate_image,
        inputs=[input_image, product_name, target_name],
        outputs=image_output,
    )

# debug=True keeps the cell attached to the running app.
demo.launch(debug=True)