In [15]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor

# Enter a bfloat16 autocast context once and deliberately never exit it, so every CUDA op
# executed in the rest of the notebook runs under bf16 autocast.
_bf16_autocast = torch.autocast(device_type="cuda", dtype=torch.bfloat16)
_bf16_autocast.__enter__()

# GPUs with compute capability 8.x or newer (Ampere+) support TF32; enable it for both
# matmuls and cuDNN convolutions.
# https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
_gpu_major = torch.cuda.get_device_properties(0).major
if _gpu_major >= 8:
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

# helper function to show masks
def show_mask(mask, ax, random_color=False, borders=True):
    """Overlay one segmentation mask on a matplotlib Axes as a translucent RGBA layer.

    Args:
        mask: array whose last two dims are (h, w). Leading singleton dims, e.g. a
            (1, h, w) mask from a batched SAM2 prediction, are dropped.
        ax: matplotlib Axes to draw into (only ``ax.imshow`` is used).
        random_color: if True, use a random RGB colour instead of the default blue.
            The alpha is 0.6 either way.
        borders: if True, also draw the external contours of the mask in white.
    """
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([30/255, 144/255, 255/255, 0.6])
    h, w = mask.shape[-2:]
    # cv2.findContours requires a 2-D single-channel uint8 image. Flatten any leading
    # singleton dims once here, not only when building the colour overlay.
    mask = mask.reshape(h, w).astype(np.uint8)
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    if borders:
        import cv2
        contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
        # Slightly simplify each contour so the drawn outline looks smoother.
        contours = [cv2.approxPolyDP(contour, epsilon=0.01, closed=True) for contour in contours]
        mask_image = cv2.drawContours(mask_image, contours, -1, (1, 1, 1, 0.5), thickness=2)
    ax.imshow(mask_image)

def show_points(coords, labels, ax, marker_size=375):
    """Scatter prompt points on ``ax`` as stars: label 1 in green, label 0 in red.

    ``coords`` is indexed as rows of (x, y); ``labels`` is a parallel array used as a
    boolean selector for each group.
    """
    groups = [
        (coords[labels == 1], 'green'),
        (coords[labels == 0], 'red'),
    ]
    for selected, colour in groups:
        ax.scatter(
            selected[:, 0],
            selected[:, 1],
            color=colour,
            marker='*',
            s=marker_size,
            edgecolor='white',
            linewidth=1.25,
        )

def show_box(box, ax):
    """Draw an xyxy box ``[x0, y0, x1, y1]`` on ``ax`` as an unfilled green rectangle."""
    left, top = box[0], box[1]
    width = box[2] - box[0]
    height = box[3] - box[1]
    outline = plt.Rectangle(
        (left, top), width, height,
        edgecolor='green', facecolor=(0, 0, 0, 0), lw=2,
    )
    ax.add_patch(outline)

def show_masks(image, masks, scores, point_coords=None, box_coords=None, input_labels=None, borders=True):
    """Show each (mask, score) pair in its own 10x10 figure drawn over ``image``.

    Optional point prompts (which require ``input_labels``) and a box prompt are drawn
    on every figure. Titles with the mask number and score appear only when more than
    one score is given.
    """
    for number, (mask, score) in enumerate(zip(masks, scores), start=1):
        plt.figure(figsize=(10, 10))
        plt.imshow(image)
        ax = plt.gca()
        show_mask(mask, ax, borders=borders)
        if point_coords is not None:
            assert input_labels is not None
            show_points(point_coords, input_labels, ax)
        if box_coords is not None:
            show_box(box_coords, ax)
        if len(scores) > 1:
            plt.title(f"Mask {number}, Score: {score:.3f}", fontsize=18)
        plt.axis('off')
        plt.show()


# building predictor
# Both paths are relative: the checkpoint is resolved against the notebook's working
# directory. NOTE(review): the config name is presumably resolved by build_sam2 from the
# installed sam2 package's configs — confirm.
sam2_checkpoint = "segment-anything-2/checkpoints/sam2_hiera_large.pt"
model_cfg = "sam2_hiera_l.yaml"

# Load the SAM2 model with these weights onto the GPU.
sam2_model = build_sam2(model_cfg, sam2_checkpoint, device="cuda")

# Wrap the model in the image predictor used below (set_image / predict with box prompts).
predictor = SAM2ImagePredictor(sam2_model)


In [3]:
from ultralytics import YOLO

# Loaded YOLO models keyed by weights path, so repeated calls don't reload from disk.
_YOLO_MODELS = {}


def _load_yolo(weights):
    """Return the YOLO model for ``weights``, loading it only on first use."""
    if weights not in _YOLO_MODELS:
        _YOLO_MODELS[weights] = YOLO(weights)
    return _YOLO_MODELS[weights]


def yolo_keypoints(image_path, conf=0.5, weights="yolov8n.pt", model=None):
    """Detect objects with YOLO and return their bounding boxes in xyxy pixel format.

    Despite the name, this returns detection boxes, not keypoints.

    Args:
        image_path: image source passed straight to the model (a single path, or a batch
            of sources for multi-image inference).
        conf: minimum detection confidence forwarded to the model.
        weights: YOLO weights to load when ``model`` is not given (pretrained YOLOv8n by
            default). Loaded models are cached and reused across calls.
        model: optional already-built callable model; overrides ``weights``.

    Returns:
        For a single result, a numpy array of shape (N, 4) holding [x1, y1, x2, y2] boxes.
        For several results, a list with one such numpy array per image.
        An empty list if the model returns no results.
    """
    if model is None:
        model = _load_yolo(weights)
    results = model(image_path, conf=conf)  # one Results object per input image

    # Move every image's boxes to host memory as numpy, so the single- and multi-image
    # cases return the same element type.
    boxes = [result.boxes.xyxy.cpu().detach().numpy() for result in results]
    if len(boxes) == 1:
        return boxes[0]
    return boxes

In [37]:
# Load the source image as an RGB uint8 array and hand it to the SAM2 predictor; every
# box prompt below is predicted against this image.
image_path = "test1.webp"
with Image.open(image_path) as pil_image:  # closes the file handle once converted
    image = np.array(pil_image.convert("RGB"))

predictor.set_image(image)

# Box prompts in xyxy pixel format come from the YOLOv8 detector.
input_box = yolo_keypoints(image_path)

# Keep at most `top_k` detections (these are boxes, not keypoints).
# NOTE(review): assumes the detector returns boxes sorted by confidence — confirm.
top_k = 5
input_box = input_box[:top_k]
if len(input_box) == 0:
    # Fail loudly here instead of handing SAM2 an empty prompt batch.
    raise ValueError(f"No objects detected in {image_path!r}; nothing to segment.")
print(input_box.shape)

# multimask_output=False asks for a single mask per box; [None, :] adds a leading axis.
masks, scores, logits = predictor.predict(
    point_coords=None,
    point_labels=None,
    box=input_box[None, :],
    multimask_output=False,
)

# NOTE: save_masks is defined in a later cell; that cell must be run before this one,
# otherwise a fresh Restart & Run All raises NameError here.
save_masks(masks, dir='./masks')



image 1/1 /home/ubuntu/arnab/sd_inpainting_datagen/test1.webp: 384x640 8 persons, 1 dog, 1 handbag, 1 chair, 16.2ms
Speed: 1.7ms preprocess, 16.2ms inference, 21.5ms postprocess per image at shape (1, 3, 640, 640)


(5, 4)


In [36]:
import os

def save_masks(masks, prefix="mask", dir="masks"):
    """
    Save masks as grayscale PNG images named ``{prefix}_{i}.png`` in ``dir``.

    Args:
    - masks (numpy.ndarray): masks of shape [N, 1, h, w] (batched SAM2 output),
      [N, h, w], or a single [h, w] mask. For 4-D input only channel 0 of each item
      is written. Values are multiplied by 255 and cast to uint8, so they are expected
      to lie in [0, 1] (binary or probability masks).
    - prefix (str): The prefix for the saved image filenames.
    - dir (str): The directory where the masks will be saved (created if missing).

    Raises:
    - ValueError: if ``masks`` does not have 2, 3 or 4 dimensions.
    """
    # Ensure the directory exists
    os.makedirs(dir, exist_ok=True)

    # Normalise to a stack of 2-D masks, shape [N, h, w].
    masks = np.asarray(masks)
    if masks.ndim == 2:
        masks = masks[None]
    elif masks.ndim == 4:
        masks = masks[:, 0]
    elif masks.ndim != 3:
        raise ValueError(f"expected masks with 2, 3 or 4 dims, got shape {masks.shape}")

    for i, mask in enumerate(masks):
        # A 2-D uint8 array becomes a mode "L" (8-bit grayscale) PIL image directly.
        pil_image = Image.fromarray((mask * 255).astype(np.uint8))
        pil_image.save(os.path.join(dir, f"{prefix}_{i}.png"))

In [31]:
# Sanity check: the smallest value in the 4th predicted mask.
print(masks[3].min())

0.0


In [43]:
def get_prompts():
    """Return the text prompt that guides the inpainting step.

    Currently a fixed placeholder string.
    """
    # TODO: generate prompts with GPT instead of returning a hard-coded one.
    return "a lady with long hair"

In [38]:
# https://github.com/runwayml/stable-diffusion#inpainting-with-stable-diffusion

from diffusers import StableDiffusionInpaintPipeline
import torch

# Load the Stable Diffusion 2 inpainting pipeline with float16 weights.
pipe = StableDiffusionInpaintPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-inpainting",
    torch_dtype=torch.float16,
)

# Move the pipeline to the GPU.
# NOTE(review): the bfloat16 autocast entered in the first cell is never exited, so this
# float16 pipeline presumably also runs under bf16 autocast — confirm that is intended.
pipe.to("cuda")

Loading pipeline components...: 100%|██████████| 6/6 [00:39<00:00,  6.50s/it]
100%|██████████| 50/50 [00:06<00:00,  7.90it/s]


In [44]:
prompt = get_prompts()

# The pipeline takes PIL images; in the mask, white pixels are repainted and black
# pixels are kept as-is.
with Image.open(image_path) as src:
    # .webp inputs may be RGBA/P mode; the pipeline expects 3-channel RGB.
    input_image = src.convert("RGB")
print(input_image.size)
with Image.open('masks/mask_2.png') as src:
    mask_image = src.convert("L")
print(mask_image.size)

# Without explicit height/width the pipeline renders at its square default resolution,
# squashing non-square inputs. Use the input size rounded down to a multiple of 8, which
# the pipeline requires.
gen_width = input_image.width // 8 * 8
gen_height = input_image.height // 8 * 8
image = pipe(
    prompt=prompt,
    image=input_image,
    mask_image=mask_image,
    height=gen_height,
    width=gen_width,
).images[0]
# Resize back so the result lines up pixel-for-pixel with the source image and its mask.
image = image.resize(input_image.size)
image.save("./test1.png")

(951, 535)
(951, 535)


100%|██████████| 50/50 [00:07<00:00,  6.56it/s]
