In [2]:
# Import & Settings
import sys
import scipy.io
import json
import numpy as np
import pathlib
import os
import platform
import matplotlib.pyplot as plt
import cv2
import torch
import torchvision
from PIL import Image

from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
SAM_CHECKPOINT = 'sam_vit_h_4b8939.pth'
MODEL_TYPE = 'vit_h'
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
TITLE = "FISH APP"
PATH = "151-200_Hong/MAX_KO2_w1-359 DAPI_s058.tif"
SAM = sam_model_registry[MODEL_TYPE](checkpoint=SAM_CHECKPOINT)
SAM.to(device=DEVICE)

Sam(
  (image_encoder): ImageEncoderViT(
    (patch_embed): PatchEmbed(
      (proj): Conv2d(3, 1280, kernel_size=(16, 16), stride=(16, 16))
    )
    (blocks): ModuleList(
      (0-31): 32 x Block(
        (norm1): LayerNorm((1280,), eps=1e-06, elementwise_affine=True)
        (attn): Attention(
          (qkv): Linear(in_features=1280, out_features=3840, bias=True)
          (proj): Linear(in_features=1280, out_features=1280, bias=True)
        )
        (norm2): LayerNorm((1280,), eps=1e-06, elementwise_affine=True)
        (mlp): MLPBlock(
          (lin1): Linear(in_features=1280, out_features=5120, bias=True)
          (lin2): Linear(in_features=5120, out_features=1280, bias=True)
          (act): GELU(approximate='none')
        )
      )
    )
    (neck): Sequential(
      (0): Conv2d(1280, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (1): LayerNorm2d()
      (2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (3): LayerNorm2d

In [4]:
def import_image(PATH):
    img = Image.open(pathlib.Path(PATH))
    rgb_img = img.convert("RGB")
    image_array = np.array(rgb_img)

    return image_array

def get_scores(masks):
    mask_iou_scores = [single_cell['predicted_iou'] for single_cell in masks]
    return mask_iou_scores

In [5]:
mask_generator1 = SamAutomaticMaskGenerator(SAM)
image_array = import_image(PATH)
masks1 = mask_generator1.generate(image_array)

In [7]:
mask_generator2 = SamAutomaticMaskGenerator(SAM, points_per_side=8)
masks2 = mask_generator2.generate(image_array)

In [8]:
print(get_scores(masks1))
print(get_scores(masks2))

[1.044431209564209, 0.9926577210426331, 0.9921486377716064, 0.9889293909072876, 0.988497257232666, 0.9873670339584351, 0.9854803085327148, 0.9820512533187866, 0.9661861658096313, 0.9603437781333923, 0.9578291177749634, 0.9513841867446899, 0.947485089302063, 0.9443943500518799, 0.9177680015563965, 0.9059348106384277, 0.8953773975372314, 0.8809551000595093]
[1.0386055707931519, 0.9709571599960327, 0.9478321671485901, 0.9415301084518433]
