In [None]:
%%capture
!pip install git+https://github.com/facebookresearch/segment-anything.git
!pip install opencv-python pycocotools matplotlib onnxruntime onnx transformers diffusers accelerate

In [None]:
%%capture
!wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth

In [None]:
import gc
from pathlib import Path

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
from diffusers import StableDiffusionInpaintPipeline
from PIL import Image
from segment_anything import sam_model_registry, SamPredictor

In [None]:
# Loading the image (OpenCV reads BGR; convert to RGB for matplotlib, SAM and PIL)
image_path = Path("car.jpg")
original_image = cv2.imread(str(image_path))
if original_image is None:
    # cv2.imread does not raise on a missing/unreadable file, it returns None
    raise FileNotFoundError(f"Could not read image: {image_path}")
original_image = cv2.cvtColor(original_image, cv2.COLOR_BGR2RGB)
image_height, image_width = original_image.shape[:2]
print(f"The shape of the image is ({image_height}, {image_width})")

# Defining some points of the image that will help SAM in the segmentation process.
# Points are (x, y) pixel coordinates: the center is meant as a foreground hint on
# the object, the four corners as background hints.
margin = 10  # distance in pixels of the corner points from the image border
center = (image_width // 2, image_height // 2)
first_vertex = (margin, margin)
second_vertex = (image_width - margin, margin)
third_vertex = (image_width - margin, image_height - margin)
fourth_vertex = (margin, image_height - margin)
corner_points = [first_vertex, second_vertex, third_vertex, fourth_vertex]

# Showing the image with the foreground and background points
plt.figure()
plt.imshow(original_image)
plt.scatter([center[0]], [center[1]], c="Green", s=40, label="Foreground point")
plt.scatter(*zip(*corner_points), c="Red", s=40, label="Background points")
plt.legend()
plt.axis('off')
plt.show()

In [None]:
# Loading the SAM model: ViT-H backbone from the checkpoint downloaded above,
# placed on the GPU when CUDA is available and on the CPU otherwise.
model_type = "vit_h"
sam_checkpoint = "sam_vit_h_4b8939.pth"

if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device)

# The predictor wraps the model and answers point/box prompts for one image at a time
predictor = SamPredictor(sam)

In [None]:
# Setting the image to the predictor (SAM embeds the image here; predict() reuses it)
predictor.set_image(original_image)

# Prompt points in (x, y) pixels with their labels: 1 = foreground (the center of
# the object), 0 = background (the four corners).
input_point = np.array([center, first_vertex, second_vertex, third_vertex, fourth_vertex])
input_label = np.array([1, 0, 0, 0, 0])

# Predicting the mask (multimask_output=True returns 3 masks)
masks, scores, logits = predictor.predict(
    point_coords=input_point,
    point_labels=input_label,
    multimask_output=True,
)

# Extracting the best mask (highest predicted quality score)
best_index = int(np.argmax(scores))
image_mask = masks[best_index]

# Storing the mask as an 8-bit grayscale ("L") PNG with values 0 = background and
# 1 = object. These are 0/1, not 0/255, so the file looks black in image viewers;
# keep this encoding in sync with the cell that loads the mask back.
file_name = image_path.stem + "_mask.png"
mask_path = image_path.parent / file_name
# A 2-D uint8 array is mapped to mode "L" automatically; passing `mode` explicitly
# to Image.fromarray is deprecated in recent Pillow releases.
Image.fromarray(image_mask.astype(np.uint8)).save(mask_path)

In [None]:
# Visualizing the results: tint only the pixels inside each candidate mask, so the
# rest of the photo stays readable (imshow(mask, alpha=0.5) would also shade the
# background with the colormap's "False" color).
overlay_color = (1.0, 0.0, 0.0, 0.5)  # RGBA: semi-transparent red
for i, (mask, score) in enumerate(zip(masks, scores)):
    overlay = np.zeros((*mask.shape, 4))
    overlay[mask.astype(bool)] = overlay_color
    plt.figure()
    plt.imshow(original_image)
    plt.imshow(overlay)
    best_marker = " (best)" if i == best_index else ""
    plt.title(f"Mask {i+1} - Score: {score:.3f}{best_marker}")
    plt.axis("off")
plt.show()

In [None]:
# Deleting the model to make room for the diffusion pipeline.
# `del` only drops the Python references; gc.collect() reclaims the objects and
# torch.cuda.empty_cache() hands PyTorch's cached GPU blocks back so the large
# inpainting model loaded next can use that memory.
del sam, predictor
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()

In [None]:
def resize_and_pad(image, target_size=(512, 512), fill_color=(0, 0, 0), resample=None):
    """Fit an image inside ``target_size`` (aspect ratio kept) and pad it to exactly that size.

    Args:
        image: PIL image or NumPy array (converted with ``Image.fromarray``).
        target_size: (width, height) of the output canvas.
        fill_color: RGB color used for the padding.
        resample: Pillow resampling filter for the downscale. ``None`` (default)
            means ``Image.LANCZOS``, as before; ``Image.NEAREST`` keeps a binary
            mask strictly two-valued.

    Returns:
        (padded_image, paste_position, pasted_image_size): the RGB canvas, the
        (left, top) offset of the content inside it, and the content's
        (width, height). These are the values ``remove_padding`` needs.

    Note:
        ``Image.thumbnail`` only shrinks. An image already smaller than
        ``target_size`` is centered at its original size, not enlarged.
    """
    if resample is None:
        resample = Image.LANCZOS

    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)

    # thumbnail() resizes in place, so work on a copy to leave the caller's image untouched
    image = image.copy()
    image.thumbnail(target_size, resample)
    content_width, content_height = image.size

    # The canvas is always RGB; paste() converts other modes (e.g. an "L" mask) to it
    canvas = Image.new("RGB", target_size, fill_color)
    paste_position = (
        (target_size[0] - content_width) // 2,
        (target_size[1] - content_height) // 2,
    )
    canvas.paste(image, paste_position)

    return canvas, paste_position, image.size

def remove_padding(padded_image, paste_position, content_size):
    """Crop the original content back out of a padded canvas.

    Undoes the padding added by ``resize_and_pad``: ``paste_position`` is the
    (left, top) offset of the content and ``content_size`` its (width, height).
    Returns whatever ``padded_image.crop`` returns for that box.
    """
    x0, y0 = paste_position
    width, height = content_size
    crop_box = (x0, y0, x0 + width, y0 + height)
    return padded_image.crop(crop_box)

In [None]:
# Loading the mask written by the SAM step (edit this path to use a mask you already have)
mask_path = image_path.parent / (image_path.stem + "_mask.png")
image_mask = cv2.imread(str(mask_path), cv2.IMREAD_GRAYSCALE)
if image_mask is None:
    # cv2.imread does not raise on a missing/unreadable file, it returns None
    raise FileNotFoundError(f"Could not read mask: {mask_path}")
if image_mask.shape != original_image.shape[:2]:
    # A mismatched mask would be padded differently and no longer line up with the image
    raise ValueError(
        f"Mask size {image_mask.shape} does not match image size {original_image.shape[:2]}"
    )

# Formatting the mask as the diffusion model wants: white (255) = repaint, black (0) = keep.
# Any non-zero pixel counts as the object, so both 0/1 masks (as saved by the SAM step)
# and 0/255 masks work; `1 - mask` on a uint8 0/255 mask would wrap around instead.
object_mask = image_mask > 0
diffusion_mask = Image.fromarray(np.where(object_mask, 0, 255).astype(np.uint8))

# Resizing the image and the mask to the same canvas. The mask is padded with white,
# so the padded border is generated too (outpainting).
target_size = (512, 512)
prepared_image, paste_position, pasted_image_size = resize_and_pad(original_image, target_size)
prepared_mask, _, _ = resize_and_pad(diffusion_mask, target_size, fill_color=(255, 255, 255))

# Showing the padded image and mask
fig, ax = plt.subplots(1, 2, figsize=(10, 5))
ax[0].imshow(prepared_image)
ax[0].set_title("Padded Image")
ax[0].axis("off")
ax[1].imshow(prepared_mask)
ax[1].set_title("Padded Mask")
ax[1].axis("off")
plt.show()

In [None]:
# Creating the inpainting diffusion pipeline.
# NOTE: index 0 is an SDXL checkpoint, meant to be loaded with
# StableDiffusionXLInpaintPipeline (or AutoPipelineForInpainting), not the class used here.
# NOTE(review): the runwayml repository appears to have been removed from the Hub;
# a community mirror is stable-diffusion-v1-5/stable-diffusion-inpainting. Confirm before using index 2.
inpainting_models = (
    "diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
    "stabilityai/stable-diffusion-2-inpainting",
    "runwayml/stable-diffusion-inpainting"
)

device = "cuda" if torch.cuda.is_available() else "cpu"
# float16 halves GPU memory but is not reliably supported for CPU inference
pipe_dtype = torch.float16 if device == "cuda" else torch.float32

pipe = StableDiffusionInpaintPipeline.from_pretrained(
    inpainting_models[1], torch_dtype=pipe_dtype
).to(device)

# Optional memory savers for small GPUs:
#   pipe.enable_attention_slicing()
#   pipe.enable_xformers_memory_efficient_attention()  # needs the xformers package

In [None]:
# Quick preview: a few candidates to sanity-check the prompt and mask before the
# larger batch below. Kept under its own name and displayed; previously the list
# was stored in a singular-named variable that the next cell overwrote unseen.
preview_seed = 0  # fixed seed so the preview is reproducible
preview_images = pipe(
    prompt="A clean white studio background with soft natural shadow under the car",
    image=prepared_image,
    mask_image=prepared_mask,
    # Pin the output to the padded canvas size instead of the model's default resolution
    height=prepared_image.height,
    width=prepared_image.width,
    num_images_per_prompt=3,
    generator=torch.Generator(device=pipe.device).manual_seed(preview_seed),
).images

fig, axs = plt.subplots(1, len(preview_images), figsize=(5 * len(preview_images), 5), squeeze=False)
for i, (ax, preview) in enumerate(zip(axs[0], preview_images)):
    ax.imshow(preview)
    ax.set_title(f"Preview {i}")
    ax.axis("off")
plt.show()

In [None]:
trials = 6
seed = 0  # fixed seed so re-running the notebook reproduces the same candidates

# Applying the pipeline to the image.
# height/width are passed explicitly so every output has exactly the size of
# `prepared_image`. remove_padding's paste_position / pasted_image_size are in
# that coordinate system, and a model with a different default resolution would
# otherwise break the crop.
inpainted_extended_images = pipe(
    prompt="A clean white studio background with natural shadow under the car",
    image=prepared_image,
    mask_image=prepared_mask,
    height=prepared_image.height,
    width=prepared_image.width,
    num_images_per_prompt=trials,
    generator=torch.Generator(device=pipe.device).manual_seed(seed),
    # guidance_scale=10,  # optional: follow the prompt more strictly
).images

# Two images per row; squeeze=False keeps `axs` 2-D even when there is a single row
rows = int(np.ceil(trials / 2))
fig, axs = plt.subplots(rows, 2, figsize=(10, 5 * rows), squeeze=False)
axs = axs.flatten()

for i, inpainted_extended_image in enumerate(inpainted_extended_images):
    # Removing the padding to get back the original aspect ratio
    inpainted_image = remove_padding(inpainted_extended_image, paste_position, pasted_image_size)

    # Storing the inpainted image with padding
    file_name = image_path.stem + f"_inpainted_extended_{i}.png"
    inpainted_extended_path = image_path.parent / file_name
    inpainted_extended_image.save(inpainted_extended_path)

    # Storing the inpainted image
    file_name = image_path.stem + f"_inpainted_{i}.png"
    inpainted_path = image_path.parent / file_name
    inpainted_image.save(inpainted_path)

    # Showing the created image
    axs[i].imshow(inpainted_image)
    axs[i].set_title(f"Inpainted Image {i}")
    axs[i].axis("off")

# Hide the leftover empty panel(s), e.g. when `trials` is odd
for unused_ax in axs[len(inpainted_extended_images):]:
    unused_ax.axis("off")

plt.tight_layout()
plt.show()