In [None]:
!pip install regex tqdm
!pip install diffusers transformers accelerate scipy
!pip install -U xformers
!pip install opencv-python

In [None]:
!pip install git+https://github.com/facebookresearch/segment-anything.git
!pip install pycocotools matplotlib onnxruntime onnx

## Segment Anything (SAM) with the Stable Diffusion 2 inpainting model

In [None]:
import torch
from torchvision import transforms
from torchvision.transforms.functional import to_pil_image, to_tensor

import PIL, cv2
from PIL import Image

from io import BytesIO
from IPython.display import display
import base64, json, requests
from matplotlib import pyplot as plt

import numpy as np
import copy

import sys

from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor

In [None]:
!wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth

## Load the Stable Diffusion 2 inpainting model

In [None]:
from diffusers import StableDiffusionInpaintPipeline, EulerDiscreteScheduler

# Hugging Face Hub repository of the Stable Diffusion 2 inpainting checkpoint.
model_dir = "stabilityai/stable-diffusion-2-inpainting"

# Euler discrete scheduler built from the checkpoint's own scheduler config;
# it is handed to the pipeline below in place of the pipeline's default.
scheduler = EulerDiscreteScheduler.from_pretrained(model_dir, subfolder="scheduler")

# Half-precision weights from the repo's "fp16" revision, moved to the GPU in one chain
# (DiffusionPipeline.to returns the pipeline itself).
# NOTE(review): newer diffusers releases deprecate revision="fp16" in favour of
# variant="fp16" -- confirm against the installed diffusers version.
pipe = StableDiffusionInpaintPipeline.from_pretrained(
    model_dir,
    torch_dtype=torch.float16,
    revision="fp16",
    scheduler=scheduler,
).to("cuda")

# Memory-efficient attention via xformers (installed in the first cell).
pipe.enable_xformers_memory_efficient_attention()

In [None]:
# Load the photo, take a square crop anchored to its bottom edge, and resize it to
# the resolution used for both SAM segmentation and Stable Diffusion inpainting.
target_width, target_height = 512, 512
source_image_path = 'mix909-AsJirOOLN_s-unsplash.jpg'

# convert("RGB") loads the pixels (so the file can be closed right away) and
# guarantees three channels: SAM's generate() is fed an HxWx3 array below, and an
# RGBA/grayscale/CMYK file would otherwise produce a differently shaped array.
with Image.open(source_image_path) as raw_image:
    source_image = raw_image.convert("RGB")

width, height = source_image.size

# Square crop of side min(width, height): bottom-anchored and horizontally centred.
# For portrait images this is exactly the bottom width x width region. The old box
# (0, height - width, width, height) went out of bounds for landscape images, and
# PIL fills out-of-bounds crop areas with black.
side = min(width, height)
left = (width - side) // 2
source_image = source_image.crop((left, height - side, left + side, height))

source_image = source_image.resize((target_width, target_height), Image.LANCZOS)

# uint8 HxWx3 array for SAM.
segmentation_image = np.asarray(source_image)
display(source_image)

In [None]:
# Segment Anything ViT-H model, loaded from the checkpoint fetched by the wget cell.
model_type = "vit_h"
sam_checkpoint = "sam_vit_h_4b8939.pth"
device = "cuda"

# nn.Module.to() returns the module itself, so build-and-move fits in one expression.
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint).to(device=device)

# Prompt-free mask generator over the whole image. pred_iou_thresh filters masks by
# the model's own quality estimate (reported as 'predicted_iou' further down); see
# SamAutomaticMaskGenerator's docstring for the meaning of the other settings.
mask_generator = SamAutomaticMaskGenerator(
    model=sam,
    points_per_side=32,
    pred_iou_thresh=0.95,
    crop_n_layers=1,
    crop_n_points_downscale_factor=2,
    min_mask_region_area=100,
)

In [None]:
# Run SAM's automatic mask generator over the 512x512 source image.
masks = mask_generator.generate(segmentation_image)

print(f"Number of masks generated: {len(masks)}")
if masks:
  # Each mask is a dict; list its fields (the later cells read 'segmentation',
  # 'area' and 'predicted_iou').
  print(masks[0].keys())
else:
  # masks[0] would raise IndexError here, and later cells index into `masks` too.
  print("No masks were generated; consider lowering pred_iou_thresh in the mask generator settings.")

In [None]:
def show_anns(anns, ax=None):
  """Overlay SAM masks on a matplotlib axes, each labelled with its index in `anns`.

  Args:
    anns: list of SAM mask dicts; each needs a 2-D boolean 'segmentation'
      array and an 'area' value.
    ax: axes to draw on; defaults to the current axes (plt.gca()).

  Masks are drawn largest-first so smaller masks stay visible on top. Each label
  is the mask's position in `anns`, not in the drawing order, so it can be used
  directly as `mask_index`. Colours come from NumPy's global RNG. Does nothing
  for an empty list.
  """
  if len(anns) == 0:
    return

  if ax is None:
    ax = plt.gca()
  ax.set_autoscale_on(False)

  sorted_anns = sorted(enumerate(anns), key=lambda item: item[1]['area'], reverse=True)

  for original_idx, ann in sorted_anns:
    m = ann['segmentation']

    # RGBA overlay: one random colour everywhere, 35% opacity only inside the mask.
    overlay = np.empty((m.shape[0], m.shape[1], 4))
    overlay[..., :3] = np.random.random(3)
    overlay[..., 3] = m * 0.35
    ax.imshow(overlay)

    # [-2] is the contour list under both OpenCV 3 (3 return values) and OpenCV 4 (2).
    contours = cv2.findContours(m.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]
    if not contours:
      continue

    # Put the label on the largest region; contours[0] may be a tiny stray fragment.
    largest = max(contours, key=cv2.contourArea)
    moments = cv2.moments(largest)

    if moments["m00"] != 0:
      cx = int(moments["m10"] / moments["m00"])
      cy = int(moments["m01"] / moments["m00"])

      ax.text(cx, cy, str(original_idx), color='white', fontsize=16, ha='center', va='center', fontweight='bold')

In [None]:
# Draw every SAM mask over the source image; the number on each mask is its index in `masks`.
fig, ax = plt.subplots(figsize=(10, 10))
ax.imshow(source_image)
show_anns(masks)  # show_anns draws on plt.gca(), which is `ax` here
ax.axis('off')
plt.show()

In [None]:
# One line per mask: pixel area and SAM's predicted IoU, to help pick `mask_index`.
for mask_idx, mask_info in enumerate(masks):
  area, iou = mask_info['area'], mask_info['predicted_iou']
  print(f"{mask_idx}: Area: {area} | IoU: {iou}")

In [None]:
# Index of the SAM mask to inpaint, picked from the labelled overlay / area listing.
mask_index = 9

# Fail with a clear message rather than a bare IndexError if SAM produced fewer masks.
if not -len(masks) <= mask_index < len(masks):
  raise IndexError(f"mask_index={mask_index} is out of range: SAM generated {len(masks)} masks")

segmentation_mask = masks[mask_index]['segmentation']

# Inpainting mask: white (255) = region to repaint, black (0) = keep. An explicit
# 8-bit 'L' image avoids relying on Pillow's 1-bit handling of boolean arrays.
stable_diffusion_mask = PIL.Image.fromarray(segmentation_mask.astype(np.uint8) * 255)
display(stable_diffusion_mask)


In [None]:
# Inpaint the selected mask once for every prompt. encoded_images[i] is the result
# for inpainting_prompts[i]; create_image_grid relies on that one-to-one pairing.
inpainting_prompts = ['A fire style skirt', 'Pookie logo on the skirt']
# Other prompts tried: "A yellow flowery skirt", "A skirt with egyptian hieroglyphs",
# "A skirt with green turtles", "A skirt with beautiful scenery of sunset".

# Images generated per prompt (the pipeline's own parameter). Keep it at 1 so the
# images stay aligned with the prompt titles in the grid. The previous code used
# this variable as a prompt count, so extra prompts were silently skipped (or
# indexing failed when it exceeded the number of prompts).
num_images_per_prompt = 1

# A single seeded generator shared by all calls: the whole run is reproducible,
# but each prompt's result also depends on the calls made before it.
generator = torch.Generator(device="cuda").manual_seed(155)

encoded_images = []
for prompt in inpainting_prompts:
  result = pipe(prompt=prompt,
                generator=generator,
                image=source_image,
                mask_image=stable_diffusion_mask,
                num_images_per_prompt=num_images_per_prompt)
  encoded_images.extend(result.images)

In [None]:
def create_image_grid(original_image, images, names, rows, columns, figsize=(15, 15)):
    """Show `original_image` followed by titled result images in a rows x columns grid.

    Args:
        original_image: image placed, untitled, in the top-left cell.
        images: sequence of images, or a tensor whose first dimension indexes
            images; tensor entries are passed through `torch.sigmoid` and
            converted with `to_pil_image`.
        names: one title per entry of `images`, in the same order.
        rows: number of grid rows.
        columns: number of grid columns.
        figsize: matplotlib figure size in inches.

    Raises:
        AssertionError: if `images` and `names` differ in length, or the grid has
            fewer than len(images) + 1 cells.

    The caller's `images` and `names` are never modified; leftover cells are blank.
    """
    # Remember the input kind before `images` is converted to a list.
    is_tensor_input = torch.is_tensor(images)

    if is_tensor_input:
        assert images.size(0) == len(names), "Number of images and names should be equal"
        # Squash raw tensor values into [0, 1] before converting to PIL.
        images = [to_pil_image(torch.sigmoid(img)) for img in images]
    else:
        # New list: accepts tuples and leaves the caller's list untouched.
        images = list(images)
        assert len(images) == len(names), "Number of images and names should be equal"

    # One extra cell is needed for the original image. Without this check an
    # oversized input fails later with an IndexError on the axes grid.
    assert len(images) + 1 <= rows * columns, (
        f"A {rows}x{columns} grid has room for {rows * columns - 1} images plus the "
        f"original, but {len(images)} were given"
    )

    all_images = [original_image] + images
    all_names = [''] + list(names)

    # squeeze=False keeps `axes` 2-D even when rows or columns is 1.
    _, axes = plt.subplots(rows, columns, figsize=figsize, squeeze=False)

    # axes.flat walks the grid row by row (same order as divmod(idx, columns)).
    for idx, ax in enumerate(axes.flat):
        if idx < len(all_images):
            # Tensor inputs can become single-channel images; show those in gray.
            # cmap is ignored for RGB images.
            cmap = 'gray' if idx > 0 and is_tensor_input else None
            ax.imshow(all_images[idx], cmap=cmap)
            ax.set_title(all_names[idx])
        # Hide axes on every cell, including unused ones.
        ax.axis('off')

    # Adjust subplot spacing to avoid overlapping titles.
    plt.tight_layout()
    plt.show()


In [None]:
# Source photo in the first cell, then one inpainted result per prompt, titled by its prompt.
create_image_grid(source_image, encoded_images, inpainting_prompts, rows=2, columns=2)

## SAM 2 with the Stable Diffusion 2 inpainting model