<a href="https://colab.research.google.com/github/Goderr/Background-generation/blob/main/Img_gen.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install git+https://github.com/facebookresearch/segment-anything.git
!pip install gradio
!pip install diffusers["torch"] transformers

In [None]:
import gradio as gr
import numpy as np
import torch
from diffusers import StableDiffusionInpaintPipeline
from PIL import Image
from segment_anything import SamPredictor, sam_model_registry


# SAM ViT-H checkpoint. This path expects Google Drive mounted at /content/drive
# (Colab); adjust it when running elsewhere.
checkpoint = "/content/drive/MyDrive/sam_vit_h_4b8939.pth"

model_type = "vit_h"

# Prefer the GPU, but fall back to CPU (slow) instead of crashing when CUDA is absent.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Half precision only on the GPU; many fp16 kernels are unavailable on CPU.
torch_dtype = torch.float16 if device == "cuda" else torch.float32
low_cpu_mem_usage = True

# Segment Anything: turns the user's clicked points into an object mask.
sam = sam_model_registry[model_type](checkpoint=checkpoint)
sam.to(device)
predictor = SamPredictor(sam)

# Stable Diffusion 2 inpainting: repaints the white region of the mask from a prompt.
pipe = StableDiffusionInpaintPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-inpainting",
    torch_dtype=torch_dtype,
    # Previously assigned but never passed; forward it so it actually takes effect.
    low_cpu_mem_usage=low_cpu_mem_usage,
)
pipe = pipe.to(device)

with gr.Blocks() as demo:
    with gr.Row():
        input_img = gr.Image(label="Input")
        mask_img = gr.Image(label="Mask")
        output_img = gr.Image(label="Output")

    # Layout container for the prompt (a Row, not a nested gr.Blocks app).
    with gr.Row():
        prompt_text = gr.Textbox(lines=1, label="Prompt")

    with gr.Row():
        submit = gr.Button("Submit")

    # Pixel coordinates the user has clicked on the *current* input image.
    # NOTE(review): module-level, so it is shared by every browser session of
    # this app; switch to gr.State if several users may use it at once.
    selected_pixels = []

    def reset_points():
        """Forget clicked points and clear the stale mask when the input image changes.

        Without this, points clicked on a previous image would be sent to SAM
        together with the points for the new image.
        """
        selected_pixels.clear()
        return None

    def generate_mask(image, evt: gr.SelectData):
        """Segment the object under all clicked points and return an inpainting mask.

        Args:
            image: The image from the Input component (gradio's default numpy
                array; presumably RGB HxWx3 uint8 — confirm if image_mode changes).
            evt: The click event; ``evt.index`` holds the clicked pixel coordinates.

        Returns:
            A PIL image of the *inverted* SAM mask. The selected object is False
            (black, preserved) and everything else is True (white), so the
            inpainting pipeline regenerates the background.
        """
        selected_pixels.append(evt.index)

        predictor.set_image(image)
        input_points = np.array(selected_pixels)
        # Every click is a foreground point (label 1).
        input_label = np.ones(input_points.shape[0])
        mask, _, _ = predictor.predict(
            point_coords=input_points,
            point_labels=input_label,
            multimask_output=False,  # only one mask on the product
        )

        # Invert so the background, not the product, is what gets repainted.
        mask = np.logical_not(mask)
        mask = Image.fromarray(mask[0, :, :])
        return mask

    def inpaint(img, mask, prompt):
        """Regenerate the white region of ``mask`` in ``img`` according to ``prompt``.

        Args:
            img: Input image array from the Input component.
            mask: Mask image array from the Mask component (white = repaint).
            prompt: Text prompt for the diffusion model.

        Returns:
            The generated PIL image, resized back to the input image's size.

        Raises:
            gr.Error: if the image or the mask is missing (nothing clicked yet).
        """
        if img is None or mask is None:
            raise gr.Error("Upload an image and click on the product to create a mask first.")

        img = Image.fromarray(img)
        mask = Image.fromarray(mask)
        original_size = img.size

        # The pipeline is run at 512x512, so both inputs are resized to that.
        output = pipe(
            prompt=prompt,
            image=img.resize((512, 512)),
            mask_image=mask.resize((512, 512)),
        ).images[0]

        # Undo the square resize so the result keeps the uploaded image's aspect ratio.
        return output.resize(original_size)

    input_img.change(reset_points, inputs=None, outputs=[mask_img])
    input_img.select(generate_mask, [input_img], [mask_img])
    submit.click(inpaint, inputs=[input_img, mask_img, prompt_text], outputs=[output_img])

if __name__ == "__main__":
    demo.launch()