In [1]:
#Installations
!pip install diffusers transformers gradio scipy ftfy "ipywidgets>=7,<8"
!pip install git+https://github.com/facebookresearch/segment-anything.git
!wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth

Collecting ftfy
  Downloading ftfy-6.3.1-py3-none-any.whl.metadata (7.3 kB)
Collecting jedi>=0.16 (from ipython>=4.0.0->ipywidgets<8,>=7)
  Downloading jedi-0.19.2-py2.py3-none-any.whl.metadata (22 kB)
Downloading ftfy-6.3.1-py3-none-any.whl (44 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.8/44.8 kB[0m [31m1.2 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading jedi-0.19.2-py2.py3-none-any.whl (1.6 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.6/1.6 MB[0m [31m17.7 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: jedi, ftfy
Successfully installed ftfy-6.3.1 jedi-0.19.2
Collecting git+https://github.com/facebookresearch/segment-anything.git
  Cloning https://github.com/facebookresearch/segment-anything.git to /tmp/pip-req-build-fqlycewl
  Running command git clone --filter=blob:none --quiet https://github.com/facebookresearch/segment-anything.git /tmp/pip-req-build-fqlycewl
  Resolved https://github.com/facebookresearch/segm

In [2]:
#Imports
import torch
import numpy as np
import cv2
import PIL
from segment_anything import SamPredictor, sam_model_registry
from PIL import Image, ImageOps
import matplotlib.pyplot as plt
import gradio as gr
from diffusers import StableDiffusionXLInpaintPipeline
import tensorflow as tf
import tensorflow_hub as hub
import os

# Environment setup: select the TF-Hub model load format (compressed) and pick
# the torch device used by SAM/SDXL — the GPU when CUDA is available, else CPU.
os.environ['TFHUB_MODEL_LOAD_FORMAT'] = 'COMPRESSED'
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"


In [3]:
# Initialize SAM on the selected device
print("Loading SAM model...")
sam = sam_model_registry["vit_b"](checkpoint="sam_vit_b_01ec64.pth")
# Use the `device` chosen in the setup cell instead of hard-coding "cuda",
# which would raise on a CPU-only runtime.
sam.to(device)
predictor = SamPredictor(sam)

# Initialize SDXL inpainting (fp16 weights on GPU, fp32 on CPU)
print("Loading SDXL model...")
pipe = StableDiffusionXLInpaintPipeline.from_pretrained(
    "diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
    variant="fp16" if device == "cuda" else None,
    use_safetensors=True,
).to(device)

# Initialize Style Transfer on the CPU to balance load and GPU RAM with SAM/SDXL
print("Loading Style Transfer model...")
with tf.device('/CPU:0'):
    hub_model = hub.load('https://tfhub.dev/google/magenta/arbitrary-image-stylization-v1-256/2')

print("All models loaded successfully!")

Loading SAM model...
Loading SDXL model...


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


model_index.json:   0%|          | 0.00/690 [00:00<?, ?B/s]

Fetching 18 files:   0%|          | 0/18 [00:00<?, ?it/s]

scheduler_config.json:   0%|          | 0.00/479 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/746 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/758 [00:00<?, ?B/s]

text_encoder/model.fp16.safetensors:   0%|          | 0.00/246M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/472 [00:00<?, ?B/s]

text_encoder_2/model.fp16.safetensors:   0%|          | 0.00/1.39G [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/737 [00:00<?, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/460 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/725 [00:00<?, ?B/s]

config.json: 0.00B [00:00, ?B/s]

unet/diffusion_pytorch_model.fp16.safete(…):   0%|          | 0.00/5.14G [00:00<?, ?B/s]

config.json:   0%|          | 0.00/659 [00:00<?, ?B/s]

vae/diffusion_pytorch_model.fp16.safeten(…):   0%|          | 0.00/167M [00:00<?, ?B/s]

Loading pipeline components...:   0%|          | 0/7 [00:00<?, ?it/s]

The config attributes {'decay': 0.9999, 'inv_gamma': 1.0, 'min_decay': 0.0, 'optimization_step': 37000, 'power': 0.6666666666666666, 'update_after_step': 0, 'use_ema_warmup': False} were passed to UNet2DConditionModel, but are not expected and will be ignored. Please verify your config.json configuration file.


Loading Style Transfer model...
All models loaded successfully!


In [4]:
# ==================== SAM FUNCTIONS ====================

# Function for coordinate visualization
def visualize_coordinates(image, x0, y0, x1, y1):
    """Draw the SAM bounding box and its corner coordinates on a copy of the image.

    Args:
        image: PIL image (or array) to annotate.
        x0, y0: First corner of the box, in pixels (drawn as the upper label).
        x1, y1: Opposite corner of the box, in pixels (drawn as the lower label).

    Returns:
        A new PIL image with a green rectangle and coordinate labels; None when
        ``image`` is None; or the unmodified ``image`` if drawing fails (for
        example, when a coordinate field is empty). Errors are printed.
    """
    if image is None:
        return None

    try:
        # np.array copies the pixels, so the caller's image is never modified.
        vis_image = np.array(image)
        height = vis_image.shape[0]

        first_corner = (int(x0), int(y0))
        second_corner = (int(x1), int(y1))
        color = (0, 255, 0)

        cv2.rectangle(vis_image, first_corner, second_corner, color, 3)

        # Keep both labels inside the frame. A box touching the top edge would
        # otherwise put the first label's baseline above row 0, and a box
        # touching the bottom edge would push the second label below the last row.
        upper_label_y = max(first_corner[1] - 10, 20)
        lower_label_y = min(second_corner[1] + 25, height - 5)

        cv2.putText(vis_image, f'({first_corner[0]},{first_corner[1]})',
                    (first_corner[0], upper_label_y),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.7, color, 2)
        cv2.putText(vis_image, f'({second_corner[0]},{second_corner[1]})',
                    (second_corner[0], lower_label_y),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.7, color, 2)

        return Image.fromarray(vis_image)

    except Exception as e:
        print(f"Visualization error: {e}")
        return image

def segment_from_box(input_image, x0, y0, x1, y1):
    """Generate an object mask with SAM from a bounding-box prompt.

    Args:
        input_image: PIL image (or HxWx3 array) to segment.
        x0, y0, x1, y1: Two opposite box corners in pixels. Either ordering
            is accepted, and values outside the image are clamped to its bounds.

    Returns:
        PIL 'L'-mode image (255 = object, 0 = background) for the highest-scoring
        of SAM's candidate masks. Returns None when the image or a coordinate is
        missing, or when segmentation fails (the error is printed).
    """
    if input_image is None:
        return None
    if any(coord is None for coord in (x0, y0, x1, y1)):
        # A gr.Number field was cleared in the UI.
        return None

    try:
        # SamPredictor.set_image is documented to take an HWC uint8 RGB image;
        # coerce RGBA / grayscale PIL inputs so the channel count is always 3.
        if isinstance(input_image, Image.Image) and input_image.mode != 'RGB':
            input_image = input_image.convert('RGB')
        image_np = np.array(input_image)
        height, width = image_np.shape[:2]

        # SAM expects a well-formed XYXY box (left <= right, top <= bottom), so
        # sort each axis and keep the corners inside the image.
        left, right = sorted((float(x0), float(x1)))
        top, bottom = sorted((float(y0), float(y1)))
        input_box = np.array([[
            np.clip(left, 0, width - 1),
            np.clip(top, 0, height - 1),
            np.clip(right, 0, width - 1),
            np.clip(bottom, 0, height - 1),
        ]])

        predictor.set_image(image_np)
        masks, scores, _ = predictor.predict(box=input_box, multimask_output=True)

        # Pick the candidate mask with the highest predicted score.
        best_mask = masks[np.argmax(scores)]
        return Image.fromarray((best_mask * 255).astype(np.uint8))
    except Exception as e:
        print(f"SAM segmentation error: {e}")
        return None

def generate_background_mask(mask_image):
    """Return the inverse of an object mask, i.e. a mask that selects the background.

    Args:
        mask_image: PIL image or numpy array holding the object mask.

    Returns:
        PIL image in mode 'L' with inverted intensities. Returns None when no
        mask is given or when the inversion fails (the error is printed).
    """
    if mask_image is None:
        return None

    try:
        source = Image.fromarray(mask_image) if isinstance(mask_image, np.ndarray) else mask_image
        grayscale = source.convert('L')
        return ImageOps.invert(grayscale)
    except Exception as e:
        print(f"Background mask generation error: {e}")
        return None


In [16]:
 #==================== SDXL FUNCTIONS ====================
def _resize_for_sdxl(image, mask, target_size=1024, multiple=64):
    """Resize an image/mask pair so the longer side is ~target_size and both sides are multiples of `multiple`."""
    aspect_ratio = image.width / image.height

    if aspect_ratio > 1:
        new_w, new_h = target_size, int(target_size / aspect_ratio)
    else:
        new_w, new_h = int(target_size * aspect_ratio), target_size

    # Round down to the required multiple, but never to 0: very wide or very
    # tall images would otherwise produce a zero-length side and crash resize().
    new_w = max(multiple, (new_w // multiple) * multiple)
    new_h = max(multiple, (new_h // multiple) * multiple)

    return (image.resize((new_w, new_h), Image.LANCZOS),
            mask.resize((new_w, new_h), Image.LANCZOS))


def _run_sdxl_variations(image, mask, prompt, negative_prompt, num_variations=3, base_seed=42):
    """Run the SDXL inpainting pipeline once per seed (base_seed, base_seed+1, ...) and return the images."""
    results = []
    for i in range(num_variations):
        # Build the generator on the active `device` rather than hard-coding
        # "cuda", so generation also works when device == "cpu".
        gen = torch.Generator(device=device).manual_seed(base_seed + i)

        result = pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            image=image,
            mask_image=mask,
            generator=gen,
            guidance_scale=7.5,
            strength=0.95,
            num_inference_steps=30
        ).images[0]
        results.append(result)

    return results


def inpaint_with_prompt(input_dict, prompt, negative_prompt):
    """SDXL inpainting of the region painted in a gr.ImageEditor.

    Args:
        input_dict: ImageEditor value with a 'background' PIL image and a
            'layers' list of PIL images. Painted pixels are the non-transparent
            ones (alpha > 0) in any layer.
        prompt: Text describing what to generate in the painted region.
        negative_prompt: Text describing what to avoid.

    Returns:
        List of 3 PIL images, one per seed, at the SDXL working resolution.
        Returns the original background 3 times when nothing was painted, and
        [None] * 3 when the input is missing or generation fails (the error is
        printed).
    """
    if input_dict is None:
        return [None] * 3

    try:
        image_pil = input_dict.get('background')
        layers = input_dict.get('layers') or []
        if image_pil is None:
            return [None] * 3

        if image_pil.mode != 'RGB':
            image_pil = image_pil.convert('RGB')

        # The mask is the union of the alpha channels of every drawn layer, not
        # just the first one.
        alphas = [np.array(layer.convert('RGBA'))[:, :, 3] for layer in layers]
        if not alphas:
            return [image_pil] * 3
        mask_array = np.maximum.reduce(alphas)

        if mask_array.max() == 0:
            return [image_pil] * 3

        mask_pil_processed = Image.fromarray((mask_array > 0).astype(np.uint8) * 255, mode='L')

        image_pil, mask_pil_processed = _resize_for_sdxl(image_pil, mask_pil_processed)
        return _run_sdxl_variations(image_pil, mask_pil_processed, prompt, negative_prompt)

    except Exception as e:
        print(f"SDXL inpainting error: {e}")
        return [None] * 3


def inpaint_with_uploaded_mask(image, mask, prompt, negative_prompt):
    """SDXL inpainting with an uploaded mask (e.g. one produced by the SAM tab).

    Args:
        image: PIL image to edit.
        mask: PIL image; white marks the region to regenerate.
        prompt: Text describing what to generate in the masked region.
        negative_prompt: Text describing what to avoid.

    Returns:
        List of 3 PIL images, one per seed, at the SDXL working resolution.
        Returns the RGB image 3 times when the mask is entirely black, and
        [None] * 3 when an input is missing or generation fails (the error is
        printed).
    """
    if image is None or mask is None:
        return [None] * 3

    try:
        if image.mode != 'RGB':
            image = image.convert('RGB')

        if mask.mode != 'L':
            mask = mask.convert('L')

        # Same shortcut as the manual path: an all-black mask selects nothing,
        # so skip the (slow) pipeline.
        if mask.getextrema()[1] == 0:
            return [image] * 3

        image, mask = _resize_for_sdxl(image, mask)
        return _run_sdxl_variations(image, mask, prompt, negative_prompt)

    except Exception as e:
        print(f"SDXL inpainting with uploaded mask error: {e}")
        return [None] * 3


In [6]:
# ==================== STYLE TRANSFER FUNCTIONS ====================
def tensor_to_image(tensor):
    """Convert a float image tensor with values expected in [0, 1] to a PIL image.

    Args:
        tensor: Array-like (e.g. a tf.Tensor or numpy array) of shape (H, W, C)
            or (1, H, W, C).

    Returns:
        8-bit PIL image.

    Raises:
        ValueError: if a 4-D input has a leading (batch) dimension other than 1.
    """
    # Scale, clip and round before casting. A float -> uint8 cast of values
    # outside [0, 255] is undefined (it typically wraps), and a plain cast
    # truncates rather than rounds.
    pixels = np.clip(np.asarray(tensor, dtype=np.float32) * 255.0, 0, 255)
    pixels = np.rint(pixels).astype(np.uint8)

    if pixels.ndim > 3:
        if pixels.shape[0] != 1:
            raise ValueError(f"Expected a batch of 1 image, got shape {pixels.shape}")
        pixels = pixels[0]

    return PIL.Image.fromarray(pixels)

def load_img_for_style(image_pil, max_dim=512):
    """Convert an image into a batched float tensor for the stylization model.

    Args:
        image_pil: PIL image (or HxWxC array) to preprocess.
        max_dim: Target length in pixels of the longer spatial side. The aspect
            ratio is preserved, and smaller images are upscaled to this size too.

    Returns:
        float32 tf.Tensor of shape (1, H', W', C), with pixel values scaled
        from [0, 255] to [0, 1].
    """
    # The Magenta model takes RGB images. Convert RGBA / grayscale / palette PIL
    # images so an alpha channel (or a missing channel) never reaches it.
    if isinstance(image_pil, PIL.Image.Image) and image_pil.mode != 'RGB':
        image_pil = image_pil.convert('RGB')

    img = tf.constant(np.array(image_pil), dtype=tf.float32) / 255.0

    # Scale so that the longer of (height, width) becomes max_dim.
    shape = tf.cast(tf.shape(img)[:-1], tf.float32)
    scale = max_dim / tf.reduce_max(shape)
    new_shape = tf.cast(shape * scale, tf.int32)

    img = tf.image.resize(img, new_shape)
    return img[tf.newaxis, :]

def apply_style_transfer(content_image, style_image):
    """Stylize ``content_image`` with the look of ``style_image`` via the TF-Hub model.

    Args:
        content_image: PIL image providing the structure.
        style_image: PIL image providing the artistic style.

    Returns:
        Stylized PIL image. Returns None when an input is missing or when
        stylization fails (the error is printed).
    """
    if content_image is None or style_image is None:
        return None

    try:
        # Force CPU execution, the same device scope hub_model was loaded under.
        with tf.device('/CPU:0'):
            content_batch = tf.constant(load_img_for_style(content_image))
            style_batch = tf.constant(load_img_for_style(style_image))
            outputs = hub_model(content_batch, style_batch)
            return tensor_to_image(outputs[0])

    except Exception as e:
        print(f"Style transfer error: {e}")
        return None

In [13]:
# ==================== COLORING FUNCTIONS ====================
def color_region(image, mask, r, g, b, blend_strength):
    """Tint the region selected by ``mask`` with an RGB color.

    Args:
        image: PIL image (or array) to recolor.
        mask: PIL image (or array). White selects the region to tint, black
            keeps the original pixels, and intermediate grays blend
            proportionally. If its size differs from ``image`` it is resampled
            to match.
        r, g, b: Target color channel values (0-255).
        blend_strength: Percentage 0-100. At 0 the original pixels are kept;
            at 100 the masked region is replaced by the flat color.

    Returns:
        PIL RGB image. Returns None when an input is missing or when processing
        fails (the error is printed).
    """
    if image is None or mask is None:
        return None

    try:
        image_np = np.array(image)
        mask_np = np.array(mask)

        # Work on 3-channel RGB: expand grayscale images, drop an alpha channel.
        if image_np.ndim == 2:
            image_np = cv2.cvtColor(image_np, cv2.COLOR_GRAY2RGB)
        else:
            image_np = image_np[:, :, :3]

        # Reduce an RGB/RGBA mask to a single channel.
        if mask_np.ndim == 3:
            mask_np = cv2.cvtColor(mask_np, cv2.COLOR_RGB2GRAY)

        # The mask may come from a different-resolution copy of the image (e.g.
        # a resized SDXL output). Resample it so it lines up pixel-for-pixel
        # instead of failing to broadcast.
        height, width = image_np.shape[:2]
        if mask_np.shape[:2] != (height, width):
            mask_np = cv2.resize(mask_np, (width, height), interpolation=cv2.INTER_NEAREST)

        original = image_np.astype(np.float32)
        weight = (mask_np.astype(np.float32) / 255.0)[:, :, np.newaxis]
        color = np.array([r, g, b], dtype=np.float32)
        blend_factor = blend_strength / 100.0

        # Tint (rather than replace) the original, then apply it only where the
        # mask is set. Everything stays in float and is quantized once at the end.
        tinted = original * (1 - blend_factor) + color * blend_factor
        final = tinted * weight + original * (1 - weight)

        return Image.fromarray(np.clip(final, 0, 255).astype(np.uint8))

    except Exception as e:
        print(f"Coloring error: {e}")
        return None


In [8]:
# ==================== BACKGROUND REMOVAL FUNCTIONS ====================
def remove_background(image, mask):
    """Make the background transparent using a background mask.

    Args:
        image: PIL image (or array) to cut out.
        mask: PIL image (or array) in which white (> 127) marks BACKGROUND. It is
            inverted to obtain the opaque foreground. If its size differs from
            ``image`` it is resampled to match.

    Returns:
        PIL RGBA image with the foreground at alpha 255 and the background at
        alpha 0. Returns None when an input is missing or when processing fails
        (the error is printed).
    """
    if image is None or mask is None:
        return None

    try:
        image_np = np.array(image)
        mask_np = np.array(mask)

        # Reduce an RGB/RGBA mask to a single channel.
        if mask_np.ndim == 3:
            mask_np = cv2.cvtColor(mask_np, cv2.COLOR_RGB2GRAY)

        # Work on 3-channel RGB: expand grayscale images, drop an existing alpha.
        if image_np.ndim == 2:
            rgb = cv2.cvtColor(image_np, cv2.COLOR_GRAY2RGB)
        else:
            rgb = image_np[:, :, :3]

        # Resample a mismatched mask so the alpha channel matches the image size.
        height, width = rgb.shape[:2]
        if mask_np.shape[:2] != (height, width):
            mask_np = cv2.resize(mask_np, (width, height), interpolation=cv2.INTER_NEAREST)

        # Binarize the background mask, then invert it to get the foreground alpha.
        _, background_binary = cv2.threshold(mask_np, 127, 255, cv2.THRESH_BINARY)
        alpha = cv2.bitwise_not(background_binary)

        # Build RGBA directly; no BGR round-trip is needed.
        rgba = np.dstack([rgb, alpha])
        return Image.fromarray(rgba)

    except Exception as e:
        print(f"Background removal error: {e}")
        return None


In [None]:
# ==================== GRADIO INTERFACE ====================
# One gr.Blocks app with a tab per tool; each tab wires its controls to the
# functions defined in the cells above.
with gr.Blocks(theme=gr.themes.Soft(), title="Complete Image Editor") as demo:
    gr.Markdown("# 🎨 Complete Image Editor Suite")
    gr.Markdown("### SAM Segmentation | SDXL Inpainting | Style Transfer | Coloring | Background Removal")

    with gr.Tabs():
        # ==================== SAM TAB ====================
        with gr.Tab("SAM Segmentation"):
            gr.Markdown("### Generate precise masks using Segment Anything Model")

            with gr.Row():
                with gr.Column():
                    sam_input_image = gr.Image(type="pil", label="Upload Image")

                    with gr.Row():
                        sam_x0 = gr.Number(label="x0 (Left)", value=100)
                        sam_y0 = gr.Number(label="y0 (Top)", value=100)
                        sam_x1 = gr.Number(label="x1 (Right)", value=300)
                        sam_y1 = gr.Number(label="y1 (Bottom)", value=300)

                    with gr.Row():
                        sam_preview_btn = gr.Button("Preview Box", variant="secondary")
                        sam_segment_btn = gr.Button("Generate SAM Mask", variant="primary")

                with gr.Column():
                    sam_preview_image = gr.Image(label="Preview with Bounding Box")
                    sam_output_mask = gr.Image(label="Generated Mask")

            # Every SAM handler takes the same (image, x0, y0, x1, y1) inputs.
            sam_box_inputs = [sam_input_image, sam_x0, sam_y0, sam_x1, sam_y1]

            sam_preview_btn.click(
                fn=visualize_coordinates,
                inputs=sam_box_inputs,
                outputs=sam_preview_image
            )

            sam_segment_btn.click(
                fn=segment_from_box,
                inputs=sam_box_inputs,
                outputs=sam_output_mask
            )

            # Auto-refresh the preview when any coordinate changes or a new image is uploaded.
            for preview_trigger in [sam_x0, sam_y0, sam_x1, sam_y1, sam_input_image]:
                preview_trigger.change(
                    fn=visualize_coordinates,
                    inputs=sam_box_inputs,
                    outputs=sam_preview_image
                )

        # ==================== BACKGROUND MASK TAB ====================
        with gr.Tab("Background Mask"):
            gr.Markdown("### Generate background mask by inverting object mask")

            with gr.Row():
                with gr.Column():
                    bg_input_mask = gr.Image(type="pil", label="Upload Object Mask")
                    bg_generate_btn = gr.Button("Generate Background Mask", variant="primary")

                with gr.Column():
                    bg_output_mask = gr.Image(label="Background Mask")

            bg_generate_btn.click(
                fn=generate_background_mask,
                inputs=bg_input_mask,
                outputs=bg_output_mask
            )

        # ==================== SDXL TAB ====================
        with gr.Tab("SDXL Inpainting"):
            gr.Markdown("### AI-powered image editing with Stable Diffusion XL")

            with gr.Tabs():
                # Manual masking subtab: paint the region to edit with a white brush.
                with gr.Tab("Manual Masking"):
                    with gr.Row():
                        with gr.Column():
                            sdxl_input_editor = gr.ImageEditor(
                                label="Draw White over area to edit",
                                type="pil",
                                brush=gr.Brush(default_size=20, colors=["#FFFFFF"])
                            )

                            sdxl_prompt = gr.Textbox(
                                label="Prompt (leave empty to remove object)",
                                placeholder="e.g., 'a red sports car' or leave empty for removal"
                            )

                            sdxl_negative_prompt = gr.Textbox(
                                label="Negative Prompt",
                                value="blurry, distorted, painting, cartoon, watermark"
                            )

                            sdxl_manual_btn = gr.Button("Apply SDXL Edit", variant="primary")

                        with gr.Column():
                            sdxl_manual_outputs = [gr.Image(label=f"Result {i+1}") for i in range(3)]

                # Upload mask subtab: reuse a mask produced elsewhere (e.g. the SAM tab).
                with gr.Tab("Upload Mask"):
                    with gr.Row():
                        with gr.Column():
                            sdxl_upload_image = gr.Image(type="pil", label="Upload Image")
                            sdxl_upload_mask = gr.Image(type="pil", label="Upload Mask")

                            sdxl_upload_prompt = gr.Textbox(
                                label="Prompt",
                                placeholder="e.g., 'a red sports car' or leave empty for removal"
                            )

                            sdxl_upload_negative = gr.Textbox(
                                label="Negative Prompt",
                                value="blurry, distorted, painting, cartoon, watermark"
                            )

                            sdxl_upload_btn = gr.Button("Apply SDXL with Uploaded Mask", variant="primary")

                        with gr.Column():
                            sdxl_upload_outputs = [gr.Image(label=f"Result {i+1}") for i in range(3)]

            # Connect SDXL functions (each returns three variations)
            sdxl_manual_btn.click(
                fn=inpaint_with_prompt,
                inputs=[sdxl_input_editor, sdxl_prompt, sdxl_negative_prompt],
                outputs=sdxl_manual_outputs
            )

            sdxl_upload_btn.click(
                fn=inpaint_with_uploaded_mask,
                inputs=[sdxl_upload_image, sdxl_upload_mask, sdxl_upload_prompt, sdxl_upload_negative],
                outputs=sdxl_upload_outputs
            )

        # ==================== STYLE TRANSFER TAB ====================
        with gr.Tab("Style Transfer"):
            gr.Markdown("### Apply artistic styles to your images")

            with gr.Row():
                with gr.Column():
                    style_content_image = gr.Image(type="pil", label="Content Image")
                    style_style_image = gr.Image(type="pil", label="Style Image")
                    style_transfer_btn = gr.Button("Apply Style Transfer", variant="primary")

                with gr.Column():
                    style_output = gr.Image(label="Stylized Result")

            style_transfer_btn.click(
                fn=apply_style_transfer,
                inputs=[style_content_image, style_style_image],
                outputs=style_output
            )

        # ==================== COLORING TAB ====================
        with gr.Tab("Coloring"):
            gr.Markdown("### Color specific regions with custom RGB values")

            with gr.Row():
                with gr.Column():
                    color_input_image = gr.Image(type="pil", label="Upload Image")
                    color_input_mask = gr.Image(type="pil", label="Upload Mask")

                    with gr.Row():
                        color_r = gr.Slider(0, 255, value=255, label="Red")
                        color_g = gr.Slider(0, 255, value=0, label="Green")
                        color_b = gr.Slider(0, 255, value=0, label="Blue")

                    color_blend_strength = gr.Slider(
                        0, 100, value=30,
                        label="Blend Strength (%)",
                        info="0% = no tint, 100% = full color"
                    )

                    color_apply_btn = gr.Button("Apply Coloring", variant="primary")

                with gr.Column():
                    color_output = gr.Image(label="Colored Result")

            color_apply_btn.click(
                fn=color_region,
                inputs=[color_input_image, color_input_mask, color_r, color_g, color_b, color_blend_strength],
                outputs=color_output
            )

        # ==================== BACKGROUND REMOVAL TAB ====================
        with gr.Tab("Background Removal"):
            gr.Markdown("### Remove background to create transparent PNG")

            with gr.Row():
                with gr.Column():
                    bg_remove_image = gr.Image(type="pil", label="Upload Image")
                    bg_remove_mask = gr.Image(type="pil", label="Upload Background Mask")
                    bg_remove_btn = gr.Button("Remove Background", variant="primary")

                with gr.Column():
                    bg_remove_output = gr.Image(label="Result (Transparent PNG)")

            bg_remove_btn.click(
                fn=remove_background,
                inputs=[bg_remove_image, bg_remove_mask],
                outputs=bg_remove_output
            )

    # ==================== INSTRUCTIONS ====================
    gr.Markdown("""
    ## 📋 Instructions

    **SAM Segmentation**: Upload image and specify bounding box coordinates to generate precise masks

    **Background Mask**: Upload an object mask to automatically generate its inverse (background mask)

    **SDXL Inpainting**:
    - Manual: Draw on image editor and enter prompt for AI editing
    - Upload: Use pre-generated SAM masks for precise editing

    **Style Transfer**: Upload content and style images for artistic transformation (runs on CPU)

    **Coloring**: Upload image and mask, set RGB values to color specific regions

    **Background Removal**: Upload image and background mask to create transparent PNG

    ### 💡 Workflow Tips:
    1. Use SAM to generate precise masks
    2. Create background mask if needed
    3. Use masks in SDXL for AI editing
    4. Apply style transfer for artistic effects
    5. Use coloring for specific color changes
    6. Remove backgrounds for transparent images
    """)

# share=True publishes a temporary public *.gradio.live URL; debug=True keeps the
# cell running so errors and logs are shown in the notebook.
demo.launch(share=True, debug=True)


Colab notebook detected. This cell will run indefinitely so that you can see errors and logs. To turn off, set debug=False in launch().
* Running on public URL: https://6cc982dcf4e19aa3d4.gradio.live

This share link expires in 1 week. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)


  0%|          | 0/28 [00:00<?, ?it/s]

  0%|          | 0/28 [00:00<?, ?it/s]

  0%|          | 0/28 [00:00<?, ?it/s]

Keyboard interruption in main thread... closing server.
Killing tunnel 127.0.0.1:7860 <> https://6cc982dcf4e19aa3d4.gradio.live


