In [1]:
# Import & Settings
import sys
import scipy.io
import json
import numpy as np
import pathlib
import os
import platform
import matplotlib.pyplot as plt
import cv2
import torch
import torchvision
from PIL import Image

import fishLoader

from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def import_image_RGB(PATH):
    """Load an image file and return it as an RGB numpy array.

    ``Image.open`` is lazy and keeps the file handle open, so the file is
    opened in a ``with`` block. ``convert`` forces the pixel data to load into
    a new in-memory image before the handle is released. This avoids leaking
    one open file per call until garbage collection.

    Args:
        PATH: str or path-like pointing at the image file.

    Returns:
        np.ndarray built from the RGB-converted image (for Pillow "RGB" images
        this is presumably (H, W, 3) uint8 -- confirm for unusual inputs).
    """
    with Image.open(pathlib.Path(PATH)) as img:
        # convert() loads the data and returns an independent image, so it
        # stays valid after the source file is closed.
        rgb_img = img.convert("RGB")

    return np.array(rgb_img)

def get_scores(masks):
    """Collect the ``'predicted_iou'`` value of every mask record, in order.

    Args:
        masks: iterable of dict-like mask records (e.g. the output of
            ``SamAutomaticMaskGenerator.generate``); each must have a
            ``'predicted_iou'`` key.

    Returns:
        list of the ``'predicted_iou'`` values, one per record.
    """
    scores = []
    for record in masks:
        scores.append(record['predicted_iou'])
    return scores

def mask_to_matrix(masks, shape=None, rng=None) -> np.ndarray:
    """Render a list of masks into an RGBA float32 overlay image.

    Painting rules:
    - Each mask is painted with a random opaque colour (alpha 1).
    - The largest mask, by ``np.sum`` of its segmentation, is instead painted
      (0, 0, 0, 0), i.e. made invisible. Presumably this is the
      background / whole-field segment -- TODO confirm.
    - Pixels covered by no mask stay (1, 1, 1, 0), i.e. transparent.
    - Where masks overlap, later masks overwrite earlier ones.

    Args:
        masks: sequence (or iterable) of dicts with a ``'segmentation'``
            entry. Assumed to be a boolean (H, W) array, as the boolean
            indexing below requires -- confirm against the mask generator's
            output mode.
        shape: optional (H, W) of the output. Defaults to the shape of the
            first mask's segmentation, or (2048, 2048) when ``masks`` is empty.
        rng: optional ``np.random.Generator`` for reproducible colours. When
            None, the global ``np.random`` state is used.

    Returns:
        np.ndarray of dtype float32 and shape (H, W, 4).
    """
    masks = list(masks)  # iterated twice below, so materialize any iterator

    if shape is None:
        # Size the canvas from the masks; a fixed 2048x2048 canvas fails with
        # a boolean-index shape mismatch for any other image size.
        shape = np.shape(masks[0]['segmentation'])[:2] if masks else (2048, 2048)

    img = np.ones((*shape, 4), dtype=np.float32)
    img[:, :, 3] = 0  # start fully transparent

    if not masks:
        # Nothing to paint; np.argmax would raise on an empty size list.
        return img

    draw = np.random.random if rng is None else rng.random

    # Pixel count of each mask, used to find the largest one
    mask_sizes = [np.sum(single_cell['segmentation']) for single_cell in masks]
    largest_mask_index = int(np.argmax(mask_sizes))

    for index, single_cell in enumerate(masks):
        m = single_cell['segmentation']

        if index == largest_mask_index:
            color_mask = np.zeros(4, dtype=np.float32)  # invisible
        else:
            # Random RGB in [0, 1), fully opaque
            color_mask = np.concatenate([draw(3).astype(np.float32), [1]]).astype(np.float32)

        img[m] = color_mask

    return img

In [3]:
# --- Configuration ------------------------------------------------------------
SAM_CHECKPOINT = 'sam_vit_h_4b8939.pth'  # weights file, relative to the working dir
MODEL_TYPE = 'vit_h'                     # sam_model_registry key; must match the checkpoint
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
TITLE = "FISH APP"
PATH = "151-200_Hong/MAX_KO2_w1-359 DAPI_s058.tif"
CELLS_IN_IMG = 4  # presumably the expected cell count for PATH -- TODO confirm

# --- Model --------------------------------------------------------------------
# nn.Module.to() moves parameters in place and returns the module itself, so
# chaining it binds SAM to the same (now device-resident) model object.
SAM = sam_model_registry[MODEL_TYPE](checkpoint=SAM_CHECKPOINT).to(device=DEVICE)

# --- Image --------------------------------------------------------------------
IMG = import_image_RGB(PATH)              # numpy RGB array fed to the mask generator
IMG_RAW = Image.open(pathlib.Path(PATH))  # lazily-opened PIL image in its original mode

In [11]:
pps_tryout = {}
for i in range(5,32):
    mg = SamAutomaticMaskGenerator(SAM, points_per_side=i)
    mask = mg.generate(IMG)