In [None]:
# Enable IPython's autoreload extension. Mode 2 reloads imported modules before each cell is
# executed, so edits to the local `utils` helper modules are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

In [None]:
from utils.grounded_sam_helpers import grounded_segmentation, plot_detections

In [None]:
# Checkpoints handed to grounded_segmentation as `detector_id` / `segmenter_id`.
detector_id = "IDEA-Research/grounding-dino-tiny"
segmenter_id = "facebook/sam-vit-base"

# Input image and the text labels to look for in it.
# Alternative test image: "http://images.cocodataset.org/val2017/000000039769.jpg"
image_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg"
labels = ["car", "shadow of car"]

# Passed to grounded_segmentation as `threshold` (presumably the detection score cut-off).
threshold = 0.3

In [None]:
# Run the project helper on the configured image URL and labels, using the detector and
# segmenter checkpoints chosen above. It returns the image array together with the detections.
# (The effect of polygon_refinement=True is defined in utils.grounded_sam_helpers, not shown here.)
image_array, detections = grounded_segmentation(
    image=image_url,
    labels=labels,
    threshold=threshold,
    polygon_refinement=True,
    detector_id=detector_id,
    segmenter_id=segmenter_id,
)
# NOTE(review): the third argument looks like an output file name. "cute_cats.png" is presumably
# left over from a different example image, while this notebook segments a car — confirm and rename.
plot_detections(image_array, detections, "cute_cats.png")

In [None]:
import torch
from transformers import pipeline

# SAM automatic mask-generation pipeline. Use the first GPU when CUDA is available and fall back
# to the CPU (device=-1) otherwise, so the cell no longer fails on machines without a GPU.
generator = pipeline(
    "mask-generation",
    model="facebook/sam-vit-huge",
    device=0 if torch.cuda.is_available() else -1,
)

In [None]:
from PIL import Image
import requests
import matplotlib.pyplot as plt

img_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg"
# Stream the download and decode it while the connection is still open. The context manager
# releases the connection afterwards, the timeout keeps the cell from hanging on a stalled
# request, and raise_for_status() reports HTTP errors directly instead of letting PIL fail on
# an error page with an opaque decoding error.
with requests.get(img_url, stream=True, timeout=30) as response:
    response.raise_for_status()
    # Image.open is lazy; convert() forces the pixel data to be read before the response closes.
    raw_image = Image.open(response.raw).convert("RGB")

display(raw_image)

In [None]:
from utils.sam_helpers import show_masks_on_image
# Run the mask-generation pipeline on the image.
# points_per_batch is forwarded to the pipeline (presumably the number of prompt points
# processed per forward pass — see the transformers documentation).
outputs = generator(raw_image, points_per_batch=64)
# The result is indexed by "masks"; later cells index into and stack these per-object masks.
masks = outputs["masks"]
show_masks_on_image(raw_image, masks)

In [None]:
from torchvision.transforms.functional import to_tensor

# Device used for every tensor created below; switch to "cuda:0" to run on the GPU.
device = "cpu"
# PIL image -> tensor with a leading batch axis, laid out as (b, c, h, w), placed on `device`.
image = to_tensor(raw_image)[None].to(device)
b, c, h, w = image.shape

In [None]:
import torch
# Pixel-coordinate grid: grid[n, i, j] == (i, j), i.e. (row, column), for each of the b images.
# cartesian_prod lists the (row, column) pairs in row-major order, so reshaping to (h, w, 2)
# gives the same layout as stacking an "ij"-indexed meshgrid along the last axis.
grid = (
    torch.cartesian_prod(
        torch.arange(h, dtype=torch.float32, device=device),
        torch.arange(w, dtype=torch.float32, device=device),
    )
    .reshape(h, w, 2)[None]
    .repeat(b, 1, 1, 1)
)

In [None]:
# Preview one of the generated masks (index 1), converted to a float32 tensor.
mask = torch.tensor(outputs["masks"][1], dtype=torch.float32)
plt.imshow(mask, cmap="gray", interpolation="lanczos")
plt.axis("off")

In [None]:
from einops import rearrange, repeat, reduce
# The same mask (index 1) as a float32 tensor, this time created on `device`.
mask = torch.tensor(outputs["masks"][1], device=device, dtype=torch.float32)
# Centre of mass = mask-weighted sum of the (row, column) coordinates / total mask weight.
mass_center = torch.div(
    reduce(mask[None, :, :, None] * grid, "b h w c -> b c", "sum"),
    reduce(mask, "h w -> 1", "sum"),
)
ci, cj = (coordinate.item() for coordinate in mass_center[0])
ci, cj

In [None]:
from utils.transform_matrices import translation, rotation, scale, shear
# Append a constant 1 to every (row, column) pair -> homogeneous coordinates of shape (b, h, w, 3).
grid_homo = torch.cat((grid, torch.ones(b, h, w, 1, device=device)), dim=3)
# Compose one transform around the mask's mass center. Because the coordinates are multiplied
# by `transforms.T` from the right below, each coordinate vector p becomes transforms @ p, so
# the right-most factor of the product is the first one applied.
# NOTE(review): the helpers come from utils.transform_matrices and are presumably 3x3
# homogeneous matrices in (row, column) order — confirm against that module.
transforms = (
    translation(ci, cj)  # applied last: presumably moves the origin back to the mass center
    @ rotation(0.03)  # presumably an angle in radians — confirm
    @ scale(1.01, 1)
    @ shear(0.00, 0.02)
    @ translation(-ci + 20, -cj + 20)  # applied first: presumably mass center -> (20, 20) offset from origin
)
grid_homo = grid_homo @ transforms.T

In [None]:
# Build the sampling grid for grid_sample: divide out the homogeneous coordinate, then map
# pixel indices from [0, size - 1] onto [-1, 1] (with align_corners=True, -1 and +1 are the
# centres of the corner pixels).
grid_out = (
    grid_homo[..., :2]
    .div(grid_homo[..., 2:])
    .div(torch.tensor([h - 1, w - 1], dtype=torch.float32, device=device))
    .mul(2)
    .sub(1)
    # grid_sample expects (x, y) = (column, row) order, so swap the two coordinates.
    .flip(-1)
)
# Resample the image at the transformed coordinates and show the warped result.
out = torch.nn.functional.grid_sample(image, grid_out, align_corners=True)
out_np = rearrange(out.cpu().numpy(), "1 c h w -> h w c")
plt.imshow(out_np)

In [None]:
# Warp the mask with the same sampling grid; the clamp keeps the blend weights inside [0, 1].
mask_out = torch.clamp(
    torch.nn.functional.grid_sample(mask[None, None], grid_out, align_corners=True),
    min=0,
    max=1,
)
# Alpha-blend: warped pixels where the warped mask is set, original pixels everywhere else.
out_composed = (1 - mask_out) * image + mask_out * out
out_composed_np = rearrange(out_composed.cpu().numpy(), "1 c h w -> h w c")
plt.imshow(out_composed_np)

In [None]:
def generate_mask_structure(masks, overlap_threshold=0.9):
    """Sort masks by area and link each mask to the smallest larger mask that contains it.

    Args:
        masks: Tensor of shape (n, h, w). Any dtype is accepted; a pixel belongs to a
            mask when its value is non-zero.
        overlap_threshold: Fraction of a mask's area that must lie inside a larger mask
            for that larger mask to count as its parent. Defaults to 0.9.

    Returns:
        A tuple ``(sorted_masks, parent)``:
        * ``sorted_masks`` — the input masks (original dtype) ordered by decreasing area.
        * ``parent`` — long tensor of shape (n,). ``parent[i]`` is the index, in the
          sorted order, of the parent of mask ``i``, or -1 when it has none. Empty
          (zero-area) masks never get a parent.
    """
    # Boolean view for the overlap tests, so `&` and the area sums also work for
    # float or integer masks; the returned masks keep their original dtype.
    binary = masks != 0
    areas = binary.flatten(1).sum(dim=1)
    order = areas.argsort(descending=True)
    masks, binary, areas = masks[order], binary[order], areas[order]

    n = len(masks)
    parent = torch.full((n,), -1, dtype=torch.long)
    # Index 0 is the largest mask and can never have a parent, so stop at 1.
    for i in range(n - 1, 0, -1):
        area = areas[i].item()
        if area == 0:
            # A threshold of 0 would be met by every other mask; leave empty masks as roots.
            continue
        th = area * overlap_threshold
        # Candidates are scanned from the next-larger mask upwards, so the first hit is
        # the smallest mask that (mostly) contains mask i.
        for j in range(i - 1, -1, -1):
            if (binary[i] & binary[j]).sum().item() >= th:
                parent[i] = j
                break
    return masks, parent

In [None]:
import numpy as np
# Stack the per-object masks from the pipeline output into a single tensor on `device`.
# NOTE: this rebinds `masks`, which an earlier cell set to outputs["masks"].
masks = torch.tensor(np.array(outputs["masks"]), device=device)
# Reorder the masks by decreasing area and compute each mask's parent index (-1 if none).
masks, parent = generate_mask_structure(masks)
masks.shape

In [None]:
# Show every mask in a grid, titled "<index> (<parent index>)". The number of rows follows the
# number of masks, so no mask is silently dropped when there are more than 60 and no empty
# framed axes are drawn when there are fewer.
n_cols = 6
n_rows = max(1, -(-len(masks) // n_cols))  # ceiling division
fig, axs = plt.subplots(n_rows, n_cols, figsize=(2 * n_cols, 2 * n_rows), squeeze=False)
for i, ax in enumerate(axs.flatten()):
    ax.axis("off")  # also blanks the unused trailing axes
    if i >= len(masks):
        continue
    # .cpu() keeps the cell working when `device` is a GPU; matplotlib needs host memory.
    ax.imshow(masks[i].cpu(), cmap="gray", interpolation="lanczos")
    ax.set_title(f"{i} ({parent[i].item()})")
plt.tight_layout()

In [None]:
# Export the mask at index 2 as an 8-bit image: values are scaled by 255 and cast to uint8.
mask_pil = Image.fromarray((255 * masks[2].cpu().numpy()).astype("uint8"))
mask_pil.save("mask.png")