In [1]:
import gradio as gr
import numpy as np
from PIL import Image
import nest_asyncio

nest_asyncio.apply()

def create_mask(mask_data):
    # Convert the mask data from gradio sketch to binary mask
    mask = mask_data['mask'].astype(np.uint8)
    mask = mask[:,:,0] # Take first channel since sketch returns RGB
    mask = (mask > 0).astype(np.uint8) * 255
    
    # Convert to PIL Image for saving
    mask_img = Image.fromarray(mask)
    
    return mask_img

# Create Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Image Mask Creation Tool")
    gr.Markdown("Draw on the image to create a binary mask")
    
    # Load and display the rhino image with mask overlay
    mask_input = gr.ImageMask(
        value="rhino-suit.png",
        label="Draw Mask", 
        source="upload",  # Changed from "value" to "upload" since "value" is not a valid source
        tool="sketch"
    )
    
    # Output mask display
    output_mask = gr.Image(label="Generated Mask")

    # Button to generate mask
    submit_btn = gr.Button("Generate Mask")
    submit_btn.click(
        fn=create_mask,
        inputs=[mask_input],
        outputs=output_mask
    )

demo.launch()


  from .autonotebook import tqdm as notebook_tqdm


Running on local URL:  http://127.0.0.1:7860

To create a public link, set `share=True` in `launch()`.




IMPORTANT: You are using gradio version 3.50.2, however version 4.44.1 is available, please upgrade.
--------


In [3]:
import gradio as gr
import numpy as np
from PIL import Image
import nest_asyncio
import fal_client
import requests
from io import BytesIO

nest_asyncio.apply()

def on_queue_update(update):
    if isinstance(update, fal_client.InProgress):
        for log in update.logs:
           print(log["message"])

def process_image(image_mask_dict, mask_data, prompt):
    # Extract image from the ImageMask component dictionary
    image = Image.fromarray(image_mask_dict['image'])
    
    # Create binary mask from sketch
    mask = mask_data['mask'].astype(np.uint8)
    mask = mask[:,:,0] # Take first channel since sketch returns RGB
    mask = (mask > 0).astype(np.uint8) * 255
    mask_img = Image.fromarray(mask)
    
    # Save temporary files
    image.save("temp_image.png")
    mask_img.save("temp_mask.png")
    
    # Upload to fal.ai
    image_url = fal_client.upload_file("temp_image.png")
    mask_url = fal_client.upload_file("temp_mask.png")
    
    # Run inpainting
    result = fal_client.subscribe(
        "fal-ai/flux-general/inpainting",
        arguments={
            "image_url": image_url,
            "mask_url": mask_url,
            "prompt": prompt
        },
        with_logs=True,
        on_queue_update=on_queue_update,
    )
    
    # Get result image
    result_url = result['images'][0]['url']
    response = requests.get(result_url)
    result_img = Image.open(BytesIO(response.content))
    
    return result_img

# Create Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Image Inpainting with Flux")
    gr.Markdown("1. Upload an image\n2. Draw mask on areas to change\n3. Enter prompt\n4. Generate new image")
    
    with gr.Row():
        # Input image with mask
        mask_input = gr.ImageMask(
            label="Upload Image & Draw Mask",
            source="upload",
            tool="sketch"
        )
        # Output image
        output_image = gr.Image(label="Generated Image")
        
    # Prompt input
    prompt_input = gr.Textbox(
        label="Prompt",
        placeholder="Describe what should replace the masked area..."
    )

    # Generate button
    submit_btn = gr.Button("Generate New Image")
    submit_btn.click(
        fn=process_image,
        inputs=[
            mask_input,  # Pass the entire ImageMask component instead of trying to access .image
            mask_input,
            prompt_input
        ],
        outputs=output_image
    )

demo.launch()


Running on local URL:  http://127.0.0.1:7862

To create a public link, set `share=True` in `launch()`.




IMPORTANT: You are using gradio version 3.50.2, however version 4.44.1 is available, please upgrade.
--------
Generating 1 images...
Generating 1 images...
Generating 1 images...
Generating 1 images...
  0%|          | 0/24 [00:00<?, ?it/s]
  4%|▍         | 1/24 [00:00<00:05,  4.39it/s]
  8%|▊         | 2/24 [00:00<00:04,  4.89it/s]
 12%|█▎        | 3/24 [00:00<00:04,  5.08it/s]
 17%|█▋        | 4/24 [00:00<00:03,  5.17it/s]
 21%|██        | 5/24 [00:00<00:03,  5.23it/s]
 25%|██▌       | 6/24 [00:01<00:03,  5.26it/s]
Generating 1 images...
  0%|          | 0/24 [00:00<?, ?it/s]
  4%|▍         | 1/24 [00:00<00:05,  4.39it/s]
  8%|▊         | 2/24 [00:00<00:04,  4.89it/s]
 12%|█▎        | 3/24 [00:00<00:04,  5.08it/s]
 17%|█▋        | 4/24 [00:00<00:03,  5.17it/s]
 21%|██        | 5/24 [00:00<00:03,  5.23it/s]
 25%|██▌       | 6/24 [00:01<00:03,  5.26it/s]
Generating 1 images...
  0%|          | 0/24 [00:00<?, ?it/s]
  4%|▍         | 1/24 [00:00<00:05,  4.39it/s]
  8%|▊         | 2/24 [0