## Counting Grains : Slice and Dice 

We use EfficientSAM in our workflow:

1. The image is sliced into smaller windows.
2. Each slice is run through EfficientSAM to generate the annotations.
3. The annotation counts from all slices are summed.
4. Finally, the annotated slices are stitched back together.

This script provides an example of how to get a segment-everything visualization result from EfficientSAM using a weight file.

The basic method is the same as SAM: we generate a grid of point prompts over the image and get the masks. Currently we compute all the masks at once, so it requires a large amount of memory. If you face an OOM issue, consider reducing GRID_SIZE. In the future we will update to a more efficient version that calculates the masks in local crops.

The post-processing part is taken from the original SAM project to get a better visualization result, and part of the visualization code is borrowed from the MobileSAM project. Many thanks!

In [10]:
import matplotlib.pyplot as plt
import numpy as np
import torch
from torchvision.transforms import ToTensor
from PIL import Image
import os
import cv2
import tqdm 
# Number of prompt points along each image axis; the segmentation step builds a
# GRID_SIZE x GRID_SIZE grid of point prompts per image.
GRID_SIZE = 32

In [None]:
# Install the segment-anything package directly from its GitHub repository.
!pip install git+https://github.com/facebookresearch/segment-anything.git

In [3]:
from segment_anything.utils.amg import (
    batched_mask_to_box,
    calculate_stability_score,
    mask_to_rle_pytorch,
    remove_small_regions,
    rle_to_mask,
)
from torchvision.ops.boxes import batched_nms, box_area
def process_small_region(rles):
    """Post-process RLE masks and drop near-duplicate ones.

    Each mask is decoded, passed through ``remove_small_regions`` twice
    (mode ``"holes"`` and mode ``"islands"``, both with ``min_area=100``),
    and the results are de-duplicated with box NMS (IoU threshold 0.7).

    Args:
        rles: List whose elements are indexable, with the RLE of one mask at
            index 0. The list is modified in place: entries of kept masks
            that were altered by the clean-up are re-encoded.

    Returns:
        A list with one decoded mask per entry kept by NMS. An empty input
        yields an empty list.
    """
    # torch.cat() rejects an empty list, so handle "no masks" explicitly.
    if len(rles) == 0:
        return []

    min_area = 100
    nms_thresh = 0.7
    new_masks = []
    scores = []
    for rle in rles:
        mask = rle_to_mask(rle[0])

        mask, changed = remove_small_regions(mask, min_area, mode="holes")
        unchanged = not changed
        mask, changed = remove_small_regions(mask, min_area, mode="islands")
        unchanged = unchanged and not changed

        new_masks.append(torch.as_tensor(mask).unsqueeze(0))
        # Give score=0 to changed masks and score=1 to unchanged masks
        # so NMS will prefer ones that didn't need postprocessing
        scores.append(float(unchanged))

    # Recalculate boxes and remove any new duplicates
    masks = torch.cat(new_masks, dim=0)
    boxes = batched_mask_to_box(masks)
    keep_by_nms = batched_nms(
        boxes.float(),
        torch.as_tensor(scores),
        torch.zeros_like(boxes[:, 0]),  # categories
        iou_threshold=nms_thresh,
    )

    # Only recalculate RLEs for masks that have changed
    for i_mask in keep_by_nms:
        if scores[i_mask] == 0.0:
            mask_torch = masks[i_mask].unsqueeze(0)
            rles[i_mask] = mask_to_rle_pytorch(mask_torch)
    return [rle_to_mask(rles[i][0]) for i in keep_by_nms]

In [4]:
def get_predictions_given_embeddings_and_queries(
    img, points, point_labels, model, iou_threshold=0.7, stability_threshold=0.9
):
    """Run the model on point prompts and keep only confident, stable masks.

    Args:
        img: Image tensor without a batch axis; a leading axis is added
            before it is passed to ``model``.
        points: Point prompts, forwarded to ``model`` unchanged.
        point_labels: Labels for ``points``, forwarded to ``model`` unchanged.
        model: Callable returning ``(predicted_masks, predicted_iou)``.
            Presumably the last axis of ``predicted_iou`` (axis 2) indexes
            the candidate masks of one prompt -- TODO confirm.
        iou_threshold: A prompt is kept only if its best predicted IoU is
            strictly greater than this value.
        stability_threshold: A prompt is kept only if the stability score of
            its best candidate is strictly greater than this value.

    Returns:
        ``(masks, iou)``: the surviving masks binarized with ``>= 0.0`` (all
        candidates per prompt, best first) and the best predicted IoU of
        each surviving prompt.
    """
    predicted_masks, predicted_iou = model(img[None, ...], points, point_labels)

    # Reorder the candidates of every prompt so the highest predicted IoU
    # comes first, and apply the same ordering to the masks.
    sorted_ids = torch.argsort(predicted_iou, dim=-1, descending=True)
    predicted_iou = torch.take_along_dim(predicted_iou, sorted_ids, dim=2)
    predicted_masks = torch.take_along_dim(
        predicted_masks, sorted_ids[..., None, None], dim=2
    )

    # Drop the batch axis and look only at the best candidate's IoU.
    predicted_masks = predicted_masks[0]
    best_iou = predicted_iou[0, :, 0]

    # First filter: predicted IoU.
    keep_iou = best_iou > iou_threshold
    best_iou = best_iou[keep_iou]
    masks = predicted_masks[keep_iou]

    # Second filter: stability score of the best candidate.
    stability = calculate_stability_score(masks, 0.0, 1.0)[:, 0]
    keep_stable = stability > stability_threshold
    masks = masks[keep_stable]
    best_iou = best_iou[keep_stable]

    # Binarize the mask logits at 0.
    return torch.ge(masks, 0.0), best_iou

def run_everything_ours(img_path, model):
    """Segment everything in one image using a grid of point prompts.

    Args:
        img_path: Path of the image file to segment.
        model: Segmentation model; it is moved to the CPU before inference.

    Returns:
        ``(predicted_masks, predicted_iou)``. ``predicted_masks`` is the list
        produced by ``process_small_region`` (after duplicate removal), while
        ``predicted_iou`` is returned as produced by the prediction step, so
        the two may differ in length.

    Raises:
        OSError: If OpenCV cannot read ``img_path``.
    """
    model = model.cpu()

    # Use the function argument (not a same-named global) as the image source.
    image = cv2.imread(img_path)
    if image is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise OSError(f"Could not read image: {img_path!r}")
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    img_tensor = ToTensor()(image)
    _, original_image_h, original_image_w = img_tensor.shape

    # Build a GRID_SIZE x GRID_SIZE grid of (x, y) prompt points over the image.
    xy = []
    for i in range(GRID_SIZE):
        curr_x = 0.5 + i / GRID_SIZE * original_image_w
        for j in range(GRID_SIZE):
            curr_y = 0.5 + j / GRID_SIZE * original_image_h
            xy.append([curr_x, curr_y])
    points = torch.from_numpy(np.array(xy))
    num_pts = points.shape[0]
    point_labels = torch.ones(num_pts, 1)

    with torch.no_grad():
        predicted_masks, predicted_iou = get_predictions_given_embeddings_and_queries(
            img_tensor.cpu(),
            points.reshape(1, num_pts, 1, 2).cpu(),
            point_labels.reshape(1, num_pts, 1).cpu(),
            model,
        )

    # Keep only the best candidate of each surviving prompt and encode it.
    rle = [mask_to_rle_pytorch(m[0:1]) for m in predicted_masks]
    if len(rle) == 0:
        # Nothing passed the filters; skip post-processing of an empty batch.
        return [], predicted_iou
    predicted_masks = process_small_region(rle)
    return predicted_masks, predicted_iou

In [5]:
def show_anns_ours(mask, ax):
    """Draw every mask as a randomly colored, half-transparent overlay.

    Args:
        mask: Sequence of 2-D masks usable as boolean indices into an image
            of the same height and width. May be empty.
        ax: Axes-like object providing ``set_autoscale_on`` and ``imshow``.

    Returns:
        None. With an empty ``mask`` nothing is drawn.
    """
    ax.set_autoscale_on(False)
    if len(mask) == 0:
        # No annotations: there is no mask to take the overlay size from.
        return

    height, width = mask[0].shape[:2]
    # RGBA canvas that starts fully transparent (alpha channel = 0).
    overlay = np.ones((height, width, 4))
    overlay[:, :, 3] = 0
    for ann in mask:
        # Random RGB color with alpha 0.5 for the pixels of this mask.
        overlay[ann] = np.concatenate([np.random.random(3), [0.5]])
    ax.imshow(overlay)

In [6]:
# Fetch the EfficientSAM sources and switch the working directory to the
# checkout, so relative paths used afterwards are resolved from there.
# NOTE(review): re-running this cell calls os.chdir("EfficientSAM") again from
# inside the checkout -- confirm that is intended before re-executing.
!git clone https://github.com/yformer/EfficientSAM.git
import os
os.chdir("EfficientSAM")

fatal: destination path 'EfficientSAM' already exists and is not an empty directory.


In [7]:
from efficient_sam.build_efficient_sam import build_efficient_sam_vits
import zipfile

# Unpack the zipped efficient_sam_vits checkpoint into ./weights, then build
# the model.
# NOTE(review): both paths are relative to the current working directory,
# which an earlier cell changes to "EfficientSAM" -- confirm the archive path
# resolves in your setup.
with zipfile.ZipFile("EfficientSAM/weights/efficient_sam_vits.pt.zip", 'r') as zip_ref:
    zip_ref.extractall("weights")
efficient_sam_vits_model = build_efficient_sam_vits()

#### Slice the image into smaller units

```
$grains.size
(3024, 4032)
```

In [8]:
# Source photograph. A raw string keeps the backslashes literal instead of
# relying on "\E" and "\g" being unrecognized escape sequences.
filepath = r"E:\EfficientSAM\grains_of_sand.jpg"
grains = Image.open(filepath)

# Edge length, in pixels, of each square slice.
slice_px = 520

os.makedirs("grain_slices", exist_ok=True)

# Walk the image in slice_px steps; grains.size is (width, height) and a PIL
# crop box is (left, upper, right, lower).
# NOTE(review): when the image size is not a multiple of slice_px, the boxes
# on the right/bottom edge extend past the image -- confirm how that padded
# area affects the grain count.
for left in range(0, grains.size[0], slice_px):
    for upper in range(0, grains.size[1], slice_px):
        box = (left, upper, left + slice_px, upper + slice_px)
        grain_slice = grains.crop(box)
        # os.path.join keeps the output inside grain_slices on every OS.
        grain_slice.save(os.path.join("grain_slices", f"grains_{left}_{upper}.jpg"))

# Collect the names of all slice files.
filenames = [fp for fp in os.listdir("grain_slices") if fp.endswith(".jpg")]

len(filenames)

48

In [None]:
os.makedirs("grain_masks", exist_ok=True)

# Size of the rendered overlay figures; identical for every slice.
width_px = 520
height_px = 520
dpi = 96

mask_count = 0

for image_path in tqdm.tqdm(filenames):

    # Create output filename: grains_<left>_<upper>.jpg -> ..._mask.jpg
    output_file = os.path.splitext(image_path)[0] + "_mask.jpg"
    output_file = os.path.join("grain_masks", output_file)

    # Run SAM
    image_path = os.path.join("grain_slices", image_path)
    mask_efficient_sam_vits, mask_iou = run_everything_ours(image_path, efficient_sam_vits_model)

    # One entry per slice is discounted; clamp so that a slice without any
    # prediction cannot decrease the running total.
    # NOTE(review): confirm what the "- 1" is meant to exclude, and whether
    # len(mask_efficient_sam_vits) (the masks actually drawn) is the
    # intended quantity to count.
    mask_count += max(len(mask_iou) - 1, 0)

    # Generate output mask file
    fig, ax = plt.subplots(1, 1, figsize=(width_px / dpi, height_px / dpi), dpi=dpi)

    with Image.open(image_path) as slice_image:
        image = np.array(slice_image)
    ax.imshow(image)
    show_anns_ours(mask_efficient_sam_vits, ax)
    ax.axis('off')
    fig.savefig(output_file, bbox_inches="tight", pad_inches=0, dpi=dpi)
    # Release the figure; otherwise one figure per slice stays open in memory.
    plt.close(fig)

### Stitch the slices together into a single image

In [13]:
# Create a blank RGB canvas with the same dimensions as the source photograph.
# The context manager closes the file once its size has been read.
with Image.open("E:\\EfficientSAM\\grains_of_sand.jpg") as source_image:
    result_width, result_height = source_image.size
result_image = Image.new("RGB", (result_width, result_height))

In [14]:
# Tile geometry used when the overlays were rendered.
width_px, height_px, dpi = 520, 520, 96

# Gather every rendered overlay (JPEG files only) from the mask folder.
filenames = [
    entry
    for entry in os.listdir("grain_masks")
    if entry.endswith(".jpg")
]

print(f"Total masks files: {len(filenames)}")

Total masks files: 48


In [20]:

for file in filenames:
    # Recover the slice's upper-left corner from "grains_<left>_<upper>_mask.jpg".
    file_box = file.replace("grains_", "").replace("_mask.jpg", "")
    # Pillow documents the paste box as a tuple, so build a tuple, not a list.
    box = tuple(int(fp) for fp in file_box.split("_"))

    # Read the overlay and bring it to the 520x520 slice size; the context
    # manager closes the file once the resized copy exists.
    with Image.open(os.path.join("grain_masks", file)) as tile:
        img = tile.resize((520, 520))

    # Stitch the tile into the result image at its original position.
    result_image.paste(im=img, box=box)

In [21]:
# Write the stitched overlay image to the current working directory.
result_image.save("grain_mask.jpg")

In [23]:
# Report the running total accumulated while the slices were segmented.
print(f"Total number of grains: {mask_count}")

Total number of grains: 39690
