# <b>Deep Learning:</b>
# Improving existing segmentators performance with zero-shot segmentators

This Notebook implements the code used in our paper "Improving existing segmentators performance with zero-shot segmentators".

In our study, we used the predicted segmentation masks from state-of-the-art methods **DeepLabV3+** https://github.com/VainF/DeepLabV3Plus-Pytorch and **PVTv2** https://github.com/whai362/PVT.

From these masks, we produce some checkpoints (point prompts) to feed to the **SAM** (**Segment Anything**, https://github.com/facebookresearch/segment-anything) or **SEEM** (**Segment Everything Everywhere All at Once**, https://github.com/UX-Decoder/Segment-Everything-Everywhere-All-At-Once) models for *Post-Processing Segmentation Enhancement*.

The **Segment Anything Model (SAM)** produces high quality object masks from input prompts such as points or boxes, and it can be used to generate masks for all objects in an image. It has been trained on a dataset of 11 million images and 1.1 billion masks, and has strong zero-shot performance on a variety of segmentation tasks.

Similarly to SAM, **Segment Everything Everywhere All at Once (SEEM)** allows users to easily segment an image using prompts of different types including visual prompts (points, marks, boxes, scribbles and image segments) and language prompts (text and audio), etc. It can also work with any combinations of prompts or generalize to custom prompts.


We devised 4 different methods for producing checkpoints:
 - A: the pixel whose coordinates are, along each dimension, the average of the coordinates of all the mask's pixels
 - B: the center of mass of the mask
 - C: one (or more) pixels drawn (uniformly or not...) randomly inside the mask area
 - D: pixels drawn from the intersection of a uniform grid of fixed step size and the mask. "b" stands for the intersection between the grid and the eroded mask, where the mask is shrunk by 10 pixels.
---

### In order to run the script, you need to:
 - set the path to your data folder    (in "Parameters of the script" cell)
 - prepare the dataset in this folder structure:

~~~
    ├── dataset
        ├── name of dataset
        │   ├── imgs                    # rgb source images
        │   ├── gt                      # gt segmentation images
        │   └── segmentator_deeplab     # example of segmentator source: deeplab
                                        # in this folder you must put
                                        #   - .png images (segmentator's output logits)
                                        #   - .bmp images (segmentator's output binary mask)
        └── ...
~~~

---



### Necessary imports

In [None]:
import torch                    ## pip install torch
import matplotlib.pyplot as plt ## pip install matplotlib
import cv2                      ## pip install opencv-python
from skimage import measure     ## pip install scikit-image
from scipy import ndimage       ## pip install scipy
import scipy.io
from tqdm import tqdm           ## pip install tqdm

import numpy as np              ## pip install numpy
import pickle                   ## standard library (no installation needed)

import os,glob
import sys
sys.path.append("..")
np.set_printoptions(threshold=sys.maxsize)

import time
import math

###
## to download SAM:
## git clone git@github.com:facebookresearch/segment-anything.git
## cd segment-anything; pip install -e .
###

from segment_anything import sam_model_registry, SamPredictor

### For reproducibility, seed of random generators

In [None]:
# Fix the seed of PyTorch's random generator.
torch.manual_seed(0)
# Fix the seed of NumPy's global random generator (np.random is used further
# down by show_mask for random colours and by Sampler.sample_pixels_random).
np.random.seed(0)

### Helper function for reading images and masks

In [None]:
def read_img(path:str) -> np.ndarray:
    """Load the image stored at `path` and return it in RGB channel order.

    Args:
        path: location of the image file on disk.

    Returns:
        The image decoded by `cv2.imread`, converted from BGR to RGB.

    Raises:
        FileNotFoundError: if OpenCV could not read the file (`cv2.imread`
            returned None), instead of failing later inside `cv2.cvtColor`.
    """
    bgr = cv2.imread(path)
    if bgr is None:
        raise FileNotFoundError(f"Could not read image: {path}")
    return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

def read_bmask(path:str) -> np.ndarray:
    """Load a mask as a single-channel float array scaled by 1/255.

    Args:
        path: location of the mask file on disk.

    Returns:
        The grayscale image divided by 255.0 (so a pixel value of 255 becomes 1.0).

    Raises:
        FileNotFoundError: if OpenCV could not read the file (`cv2.imread`
            returned None), instead of failing on `None / 255.0`.
    """
    gray = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
    if gray is None:
        raise FileNotFoundError(f"Could not read mask: {path}")
    return gray / 255.0

def read_rmask(path:str) -> np.ndarray:
    """Load a mask exactly as stored (no channel or depth conversion).

    Args:
        path: location of the mask file on disk.

    Returns:
        The array returned by `cv2.imread(path, cv2.IMREAD_UNCHANGED)`.

    Raises:
        FileNotFoundError: if OpenCV could not read the file (`cv2.imread`
            returned None), instead of silently handing None to the caller.
    """
    raw = cv2.imread(path, cv2.IMREAD_UNCHANGED)
    if raw is None:
        raise FileNotFoundError(f"Could not read mask: {path}")
    return raw

def get_data(paths:list) -> list:
    """Read one sample (image, GT mask, segmentator binary mask, segmentator raw mask).

    Args:
        paths: four paths, in the order image, ground-truth mask,
            segmentator binary mask, segmentator raw mask.

    Returns:
        The four loaded arrays, in the same order as `paths` (as a tuple).

    Raises:
        AssertionError: if the spatial sizes of the four arrays differ.
    """
    img_path, gt_mask_path, dplabv3_bmask_path, dplabv3_rmask_path = paths

    img = read_img(img_path)
    gt_mask = read_bmask(gt_mask_path)
    dplabv3_bmask = read_bmask(dplabv3_bmask_path)
    dplabv3_rmask = read_rmask(dplabv3_rmask_path)

    # Compare only height and width for the arrays that may carry channels.
    sizes = (img.shape[:2], gt_mask.shape, dplabv3_bmask.shape, dplabv3_rmask.shape[:2])
    assert all(size == sizes[0] for size in sizes[1:]), \
        f"Error: shape mismatch, {sizes[0]} {sizes[1]} {sizes[2]} {sizes[3]}"

    return img, gt_mask, dplabv3_bmask, dplabv3_rmask

### Helper functions for displaying points, boxes, and masks.

In [None]:
def show_mask(mask: np.ndarray, ax, random_color:bool=False):
    """Draw `mask` on `ax` as a semi-transparent coloured overlay.

    Args:
        mask: array whose last two dimensions are height and width.
        ax: object exposing `imshow` (e.g. a matplotlib axes).
        random_color: pick a random RGB colour instead of the fixed blue;
            the alpha channel is 0.6 in both cases.
    """
    rgba = (
        np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
        if random_color
        else np.array([30/255, 144/255, 255/255, 0.6])
    )
    height, width = mask.shape[-2:]
    # Broadcast (H, W, 1) * (1, 1, 4) -> (H, W, 4) RGBA image.
    overlay = mask.reshape(height, width, 1) * rgba.reshape(1, 1, -1)
    ax.imshow(overlay)
    
def show_points(coords:np.ndarray, labels:np.ndarray, ax, marker_size=375):
    """Scatter prompt points on `ax`: label 1 in green, label 0 in red.

    Args:
        coords: array of (x, y) points, one row per point.
        labels: array with one entry per row of `coords` (1 or 0).
        ax: object exposing `scatter` (e.g. a matplotlib axes).
        marker_size: marker size forwarded as `s`.
    """
    # Star markers; for small dots use s=10 and edgecolor equal to the fill colour.
    for wanted_label, colour in ((1, 'green'), (0, 'red')):
        selected = coords[labels==wanted_label]
        ax.scatter(selected[:, 0], selected[:, 1], color=colour, marker='*',
                   s=marker_size, edgecolor='white', linewidth=1.25)
    
def show_box(box, ax):
    """Add a green, unfilled rectangle for `box` = (x0, y0, x1, y1) to `ax`."""
    left, top, right, bottom = box[0], box[1], box[2], box[3]
    outline = plt.Rectangle((left, top), right - left, bottom - top,
                            edgecolor='green', facecolor=(0,0,0,0), lw=2)
    ax.add_patch(outline)

def draw_img(img:np.ndarray, input_point:np.ndarray=None, input_label:np.ndarray=None, \
             mask: np.ndarray=None, title:plt.title = None, plt_show:bool=True):
    """Plot `img` in a 3x3-inch figure, optionally with a mask overlay and prompt points.

    Args:
        img: image to display.
        input_point, input_label: prompt points and their labels; drawn only
            when both are given.
        mask: optional mask drawn over the image with `show_mask`.
        title: optional figure title.
        plt_show: call `plt.show()` at the end when True.
    """
    plt.clf()
    plt.figure(figsize=(3,3))
    plt.imshow(img)

    draw_points = input_point is not None and input_label is not None
    if mask is not None:
        show_mask(mask, plt.gca())
    if draw_points:
        show_points(input_point, input_label, plt.gca())

    if title is not None:
        plt.title(title, fontsize=18)
    plt.axis('off')

    if not plt_show:
        return
    plt.show()
        
def draw_results(masks:list, scores:list, gt_mask:np.ndarray, \
                 input_point:np.ndarray, input_label:np.ndarray, image:np.ndarray=None):
    """Display every predicted mask with its score and its IoU against the ground truth.

    Args:
        masks: predicted masks, one per candidate.
        scores: quality score associated with each mask.
        gt_mask: ground-truth mask the IoU is computed against.
        input_point, input_label: prompt points and labels drawn on the image.
        image: image to draw on. When omitted, the notebook-level variable
            `img` is used, as in the original code.
    """
    # The original relied on a global `img`; keep that as the fallback.
    canvas = img if image is None else image
    gt_bool = np.asarray(gt_mask).astype(bool)
    for i, (mask, score) in enumerate(zip(masks, scores)):
        # IoU is computed here because no module-level `get_iou` exists
        # (it is only available as a method of Metrics).
        mask_bool = np.asarray(mask).astype(bool)
        union = np.logical_or(mask_bool, gt_bool).sum()
        # Two empty masks agree perfectly, same convention as Metrics.get_iou_common.
        iou = np.logical_and(mask_bool, gt_bool).sum() / union if union else 1.0
        title = f"Mask {i+1}, Score: {score:.3f}, IoU: {iou:.3f}"
        draw_img(canvas, input_point, input_label, mask, title)

### Helper function for computing how many blobs a mask contains

In [None]:
def get_mask_of_blobs(mask: np.ndarray) -> np.ndarray:
    """Label the blobs of `mask` with `skimage.measure.label`.

    Returns:
        An array in which every pixel holds the ID of the blob it belongs to.
    """
    return measure.label(mask)

### Helper function for logging

In [None]:
def log(paths: tuple):
    """Print the four paths of a sample, one per line.

    Args:
        paths: exactly four paths, in the order image, GT mask,
            segmentator binary mask, segmentator raw mask.
    """
    img_path, gt_mask_path, dplabv3_bmask_path, dplabv3_rmask_path = paths
    rows = (
        ("img_path          :", img_path),
        ("gt_mask_path      :", gt_mask_path),
        ("dplabv3_bmask_path:", dplabv3_bmask_path),
        ("dplabv3_rmask_path:", dplabv3_rmask_path),
    )
    for caption, value in rows:
        print(caption, value)

### Metrics
(between a predicted mask and a Ground Truth mask)

In [None]:
from sklearn.metrics import jaccard_score

class Metrics():
    """Accumulate segmentation metrics between predicted and ground-truth masks.

    Call `step(pred, gt)` once per image and `get_results()` at the end.
    Both are bound in `__init__` according to the dataset name:

    - default ("common"): per-image IoU, Dice, MAE, F-beta and E-measure,
      averaged over the images by `get_results`;
    - name containing "SKIN": pixel counts (TP/FP/TN/FN) are summed over all
      images and IoU/Dice are computed once from the totals;
    - name containing "Locuste": like the common mode, but IoU comes from
      `sklearn.metrics.jaccard_score` and Dice uses a 1e-7 smoothing term.
    """
    eps=np.finfo(np.double).eps

    def __init__(self, dataset=None):
        """Select the metric mode from `dataset` (None selects the common mode)."""
        self.reset()
        self.step        = self.step_common
        self.get_iou     = self.get_iou_common
        self.get_dice    = self.get_dice_common
        self.get_results = self.get_results_common
        # The default `dataset=None` must not crash on the `in` test below.
        name = dataset if dataset is not None else ""
        if "SKIN" in name:
            self.set_mode_skin()
        elif "Locuste" in name:
            self.set_mode_locuste()

    def reset(self):
        """Clear every per-image list and every accumulated pixel count."""
        self.ious, self.maes, self.dices, self.wfms, self.emes = [], [], [], [], []
        self.tps, self.fps, self.tns, self.fns = 0, 0, 0, 0

    @staticmethod
    def _counts(pred, gt):
        """Return (tp, tn, fp, fn) pixel counts after casting both masks to bool."""
        y_pred_bool = pred.astype(bool)
        y_true_bool = gt.astype(bool)
        tp = np.logical_and(y_true_bool, y_pred_bool).sum()
        tn = np.logical_and(~y_true_bool, ~y_pred_bool).sum()
        fp = np.logical_and(~y_true_bool, y_pred_bool).sum()
        fn = np.logical_and(y_true_bool, ~y_pred_bool).sum()
        return tp, tn, fp, fn

    def step_common(self, pred, GT):
        """Compute and store the five per-image metrics for one (pred, GT) pair."""
        iou       = self.get_iou(pred, GT)
        dice      = self.get_dice(pred, GT)
        mae       = self.compute_mae(pred, GT)
        fscore    = self.FbetaMeasure(pred.astype(bool), GT.astype(bool))
        e_measure = self.EMeasure(pred.astype(bool), GT.astype(bool))
        self.ious.append(iou)
        self.dices.append(dice)
        self.maes.append(mae)
        self.wfms.append(fscore)
        self.emes.append(e_measure)

    def step_skin(self, pred, gt):
        """Add the pixel counts of one (pred, gt) pair to the running totals."""
        tp, tn, fp, fn = self._counts(pred, gt)
        self.tps += tp
        self.tns += tn
        self.fps += fp
        self.fns += fn

    def step_locuste(self, pred, GT):
        """Same as `step_common`, using the Locuste IoU and Dice variants."""
        iou       = self.get_iou_locuste(pred, GT)
        dice      = self.get_dice_locuste(pred.astype(bool), GT.astype(bool))
        mae       = self.compute_mae(pred, GT)
        e_measure = self.EMeasure(pred.astype(bool), GT.astype(bool))
        fscore    = self.FbetaMeasure(pred.astype(bool), GT.astype(bool))
        self.ious.append(iou)
        self.dices.append(dice)
        self.maes.append(mae)
        self.wfms.append(fscore)
        self.emes.append(e_measure)

    def get_iou_common(self, pred, gt, beta=1):
        """Return TP / (TP + FN + FP); 1.0 when both masks are empty.

        `beta` is unused and kept only for signature compatibility.
        """
        tp, _, fp, fn = self._counts(pred, gt)
        if tp+fn+fp==0:
            # Nothing to segment and nothing predicted: perfect agreement.
            return 1.0
        return tp / (tp + fn + fp)

    def get_iou_locuste(self, pred, target):
        """Return the Jaccard score of the flattened boolean masks (scikit-learn)."""
        return jaccard_score(target.reshape(-1).astype(bool), pred.reshape(-1).astype(bool))

    def get_dice_common(self, pred, gt):
        """Return 2*TP / (2*TP + FN + FP); 1.0 when both masks are empty."""
        tp, _, fp, fn = self._counts(pred, gt)
        if tp+fn+fp==0:
            return 1.0
        return 2*tp / (2*tp + fn + fp)

    def _calConfusion(self, pred, GT):
        """Return (TP, FP, TN, FN); `GT` must be a boolean array (used as an index)."""
        TP=np.sum(pred[GT]==1)
        FP=np.sum(pred[~GT]==1)
        TN=np.sum(pred[~GT]==0)
        FN=np.sum(pred[GT]==0)
        return TP,FP,TN,FN

    def get_dice_locuste(self, y_pred, y_true):
        """Return Dice with a 1e-7 term in the denominator (0.0 for two empty masks)."""
        tp, _, fp, fn = self._counts(y_pred, y_true)
        return (2.0 * tp) / (2.0 * tp + fp + fn + 1e-7)

    def compute_mae(self, pred: np.ndarray, gt: np.ndarray) -> np.ndarray:
        """Return the mean absolute difference between `pred` and `gt`.

        Both inputs are converted to float64 first, because NumPy refuses to
        subtract two boolean arrays.
        """
        diff = np.asarray(pred, dtype=np.float64) - np.asarray(gt, dtype=np.float64)
        return np.mean(np.abs(diff))

    ## F-Measure
    def FbetaMeasure(self, pred, GT, beta= math.sqrt(0.3)):
        """Return the F-beta score of boolean `pred` against boolean `GT`.

        Returns 1.0 when there is no positive pixel in either mask. A
        prediction made only of false positives scores 0.0 (the original
        condition `TP+FN+FN` ignored FP and returned 1.0 in that case).
        """
        TP,FP,TN,FN=self._calConfusion(pred, GT)
        if TP+FN+FP==0:
            return 1.0
        P=TP/(TP+FP+1e-8) #precision
        R=TP/(TP+FN+1e-8) #recall
        return (beta**2+1)*P*R/((beta**2)*P+R+1e-8)

    ## E-Measure
    def _EnhancedAlignmnetTerm(self, align_Matrix):
        """Map the alignment matrix through ((x + 1) ** 2) / 4."""
        enhanced=((align_Matrix+1)**2)/4
        return enhanced

    def _AlignmentTerm(self, dGT, dpred):
        """Return the element-wise alignment of the mean-centred masks."""
        align_dpred=dpred-np.mean(dpred)
        align_dGT=dGT-np.mean(dGT)
        # eps keeps the division defined where both centred values are 0.
        align_matrix=2*(align_dGT*align_dpred)/(align_dGT**2+align_dpred**2+self.eps)
        return align_matrix

    def EMeasure(self, pred, GT):
        """Return the E-measure of boolean `pred` against boolean 2-D `GT`."""
        dGT,dpred=GT.astype(np.float64),pred.astype(np.float64)
        if np.sum(GT)==0:#completely black
            enhanced_matrix=1-dpred
        elif np.sum(~GT)==0:#completely white
            enhanced_matrix=dpred
        else:
            align_matrix=self._AlignmentTerm(dGT,dpred)
            enhanced_matrix=self._EnhancedAlignmnetTerm(align_matrix)
        rows,cols= GT.shape

        # score=np.sum(enhanced_matrix)/(rows*cols-1+self.eps)
        score=np.sum(enhanced_matrix)/(rows*cols+self.eps)
        return score

    def get_results_common(self) -> (float, float, float, float, float):
        """Return the means of (IoU, Dice, MAE, F-beta, E-measure) over all steps."""
        return np.array(self.ious).mean(), np.array(self.dices).mean(), np.array(self.maes).mean(), np.array(self.wfms).mean(), np.array(self.emes).mean()

    def get_results_skin(self):
        """Return (IoU, Dice, None, None, None) computed from the summed pixel counts.

        IoU is 0.0 when no positive pixel was ever counted, matching the value
        the smoothed Dice formula yields in that case.
        """
        denominator = self.tps + self.fns + self.fps
        iou = self.tps / denominator if denominator else 0.0
        dice = (2.0 * self.tps) / (2.0 * self.tps + self.fps + self.fns + 1e-7)
        # for skin dataset, we didn't need the other metrics. TODO: implement
        return iou, dice, None, None, None

    def set_mode_locuste(self):
        """Bind the Locuste variants of step / IoU / Dice."""
        self.step = self.step_locuste
        self.get_iou = self.get_iou_locuste
        self.get_dice = self.get_dice_locuste

    def set_mode_skin(self):
        """Bind the pixel-accumulating step and the matching result function."""
        self.step = self.step_skin
        self.get_results = self.get_results_skin
            
    

### Functions for the sampling of the checkpoints.

In [None]:

class Sampler:
    """Turn a segmentation mask into point prompts ("checkpoints").

    Sampling modes accepted by `sample`:
      - "A": per blob, the integer average of its pixel coordinates;
      - "B": per blob, its centre of mass (`scipy.ndimage.center_of_mass`);
      - "C": per blob, one pixel drawn at random inside it;
      - "D": the nodes of a regular grid (step `sampling_step`) that fall
        inside the mask, or inside the eroded mask when `border_mode == "on"`.

    Returned points are (x, y) pairs, i.e. (column, row); every label is 1.
    A blob is used only if its first pixel in flattened order has a mask
    value >= 1.0 and it contains more than `min_blob_count` pixels.
    """
    verbose = True
    sampling_step = None
    min_blob_count = None

    def __init__(self, verbose, sampling_step, min_blob_count):
        self.verbose        = verbose
        self.sampling_step  = sampling_step
        self.min_blob_count = min_blob_count

    def _foreground_blobs(self, mask_of_blobs, mask):
        """Yield (label, boolean blob mask, pixel count) for every blob kept for sampling."""
        blob_labels, blob_sample = np.unique(mask_of_blobs, return_index=True)
        gt_fl = mask.flatten()
        for bl, bs in zip(blob_labels, blob_sample):
            mask_bool = (mask_of_blobs==bl)
            count = mask_bool.sum()
            if gt_fl[bs]>=1.0 and count>self.min_blob_count: ## it's not a background blob or a false blob
                yield bl, mask_bool, count

    def sample_pixels(self, mask_of_blobs: np.ndarray, mask: np.ndarray) -> (np.ndarray, np.ndarray):
        """Mode A: one point per blob at the integer mean of its pixel coordinates.

        Falls back to the centre pixel of the image when no blob qualifies.
        """
        input_point, input_label = [], []
        for bl, mask_bool, count in self._foreground_blobs(mask_of_blobs, mask):
            # np.argwhere yields (row, col) pairs.
            row, col = np.argwhere(mask_bool).sum(0)/count
            row, col = int(row) % mask.shape[0], int(col) % mask.shape[1]
            input_point.append([col, row])
            input_label.append(1)
            if self.verbose:
                print(f"blob #{bl} drawn point: {[row, col]}")

        # no mask? pick the center pixel of image
        if len(input_point) == 0:
            input_point, input_label = [[mask.shape[1]//2, mask.shape[0]//2]], [1]

        return np.array(input_point), np.array(input_label)

    def sample_pixels_center_of_mass(self, mask_of_blobs: np.ndarray, mask: np.ndarray) -> (np.ndarray, np.ndarray):
        """Mode B: one point per blob at its centre of mass (float coordinates).

        Falls back to the centre pixel of the image when no blob qualifies.
        """
        input_point, input_label = [], []
        for bl, mask_bool, _ in self._foreground_blobs(mask_of_blobs, mask):
            row, col = ndimage.center_of_mass(mask_bool)
            input_point.append([col, row])
            input_label.append(1)
            if self.verbose:
                print(f"blob #{bl} drawn point: {[row, col]}")

        # no mask? pick the center pixel of image
        if len(input_point) == 0:
            input_point, input_label = [[mask.shape[1]//2, mask.shape[0]//2]], [1]
            if self.verbose:
                print(f"empty blob -> {input_point}")

        return np.array(input_point), np.array(input_label)

    def sample_pixels_random(self, mask_of_blobs: np.ndarray, mask: np.ndarray) -> (np.ndarray, np.ndarray):
        """Mode C: one point per blob, drawn uniformly among the blob's pixels.

        Falls back to a random pixel of the image when no blob qualifies.
        """
        input_point, input_label = [], []
        for bl, mask_bool, _ in self._foreground_blobs(mask_of_blobs, mask):
            indices = np.argwhere(mask_bool)
            random_index = np.random.choice(indices.shape[0])
            row, col = indices[random_index]
            input_point.append([col, row])
            input_label.append(1)
            if self.verbose:
                print(f"blob #{bl} drawn point: {[row, col]}")

        # no mask? sample a random point
        if len(input_point) == 0:
            input_point, input_label = [
                [np.random.randint(0, mask.shape[1]),
                 np.random.randint(0, mask.shape[0])]], [1]

        return np.array(input_point), np.array(input_label)

    def get_grid(self, mask, offset_px_x, offset_px_y):
        """Return an int array shaped like `mask`, 1 on grid nodes and 0 elsewhere.

        Nodes lie every `sampling_step` pixels, starting at
        (`offset_px_y`, `offset_px_x`).
        """
        row = np.zeros(mask.shape, dtype=int)
        col = np.zeros(mask.shape, dtype=int)

        for i in range(offset_px_y, row.shape[0], self.sampling_step):
            row[i, :] = 1
        for i in range(offset_px_x, col.shape[1], self.sampling_step):
            col[:, i] = 1
        # A node is where a selected row crosses a selected column.
        return row & col

    def sample_pixels_grid(self, mask_of_blobs: np.ndarray, mask: np.ndarray) -> (np.ndarray, np.ndarray):
        """Mode D (border off): every grid node that lies inside `mask`.

        The result is empty when no node falls inside the mask.
        """
        res = self.get_grid(mask, 0, 0)

        input_point = np.argwhere(res & mask.astype(np.int64))
        # argwhere gives (row, col); swap the columns to obtain (x, y).
        input_point[:, (0, 1)] = input_point[:, (1, 0)]

        input_label = [1 for _ in input_point]

        return np.array(input_point), np.array(input_label)

    def sample_pixels_eroded_grid(self, mask_of_blobs: np.ndarray, mask: np.ndarray) -> (np.ndarray, np.ndarray):
        """Mode D (border on): grid nodes that lie inside the eroded mask.

        For each grid offset, the erosion kernel starts at 10 px and shrinks
        one pixel at a time until every qualifying blob holds a node (or the
        kernel reaches 1 px). Offsets are advanced until at least one node is
        found. If none is found, the result of `sample_pixels_grid` is returned.
        """
        # Which blobs are ignored depends only on the inputs, so compute it
        # once instead of once per erosion size.
        blob_labels, blob_sample = np.unique(mask_of_blobs, return_index=True)
        gt_fl = mask.flatten()
        ignored = np.zeros(blob_labels.shape[0], dtype=bool)
        for i, (bl, bs) in enumerate(zip(blob_labels, blob_sample)):
            count = (mask_of_blobs==bl).sum()
            ignored[i] = not (gt_fl[bs]>=1.0 and count>self.min_blob_count)
        mask_u8 = mask.astype('uint8')

        input_point = []
        offset_px_x = 0
        offset_px_y = 0

        while len(input_point)==0 and offset_px_y < self.sampling_step:
            res = self.get_grid(mask, offset_px_x, offset_px_y)
            erode_size = 10

            while True:
                # Erode the mask
                kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (erode_size, erode_size))
                eroded_mask = cv2.erode(mask_u8, kernel)

                input_point = np.argwhere(res & eroded_mask.astype(np.int64))
                input_point[:, (0, 1)] = input_point[:, (1, 0)]

                # Per-blob state: -1 ignored, 0 still without a node, 1 has a node.
                blobs = np.zeros(blob_labels.shape[0], dtype=np.float32)
                blobs[ignored] = -1
                # Blob labels are used directly as indices into `blobs`.
                blobs[mask_of_blobs[input_point[:, 1], input_point[:, 0]]] = 1.0

                if not np.any(blobs==0.0) or erode_size==1:
                    break
                erode_size -= 1

            offset_px_x += 1
            if offset_px_x>self.sampling_step:
                offset_px_x = 0
                offset_px_y += 1

        input_label = [1 for _ in input_point]

        # still no point? fall back to the grid on the non-eroded mask
        if len(input_point) == 0:
            return self.sample_pixels_grid(mask_of_blobs, mask)

        return np.array(input_point), np.array(input_label)

    def sample(self, mode, border_mode, mask_of_blobs: np.ndarray, mask: np.ndarray):
        """Dispatch to the sampling method selected by `mode`.

        Args:
            mode: "A", "B", "C" or "D" (see the class docstring).
            border_mode: only used for mode "D"; "on" selects the eroded
                grid, any other value the plain grid.
            mask_of_blobs: array giving the blob ID of every pixel.
            mask: the mask to sample from.

        Returns:
            (points, labels) as NumPy arrays.

        Raises:
            ValueError: if `mode` is not one of the supported letters
                (the original silently returned None).
        """
        if mode=="A":
            return self.sample_pixels(mask_of_blobs, mask)
        if mode=="B":
            return self.sample_pixels_center_of_mass(mask_of_blobs, mask)
        if mode=="C":
            return self.sample_pixels_random(mask_of_blobs, mask)
        if mode=="D":
            if border_mode=="on":
                return self.sample_pixels_eroded_grid(mask_of_blobs, mask)
            return self.sample_pixels_grid(mask_of_blobs, mask)
        raise ValueError(f"Unknown sampling mode: {mode!r} (expected 'A', 'B', 'C' or 'D')")



## Parameters of the script

In [None]:
verbose           = False

## Data settings
## Both paths below are placeholders: replace them with your own locations.
## (Leaving the right-hand side empty, as before, is a SyntaxError.)
dataset_path     = "./dataset"  ## Put your PATH here! (root folder containing the datasets)
base_output_path = "./output"   ## Put your PATH here! (root folder where results are written)

def get_complete_output_path(bop, dataset_name, src_msk, model, create=False):
    """Build `<bop>/<dataset_name>/<src_msk>/<model>` and return it.

    Args:
        bop: base output path.
        dataset_name: name of the dataset.
        src_msk: name of the source of the masks.
        model: name of the model.
        create: when True, create the directory (and parents) if missing.
    """
    target_dir = os.path.join(bop, dataset_name, src_msk, model)
    if not create:
        return target_dir
    os.makedirs(target_dir, exist_ok=True)
    return target_dir

def get_min_blob_number_based_on_dataset(dataset):
    """Return the minimum pixel count a blob needs to be sampled.

    The value is 20 for the Portrait dataset and 10 for every other dataset.
    The comparison ignores case: the dataset is spelled "Portrait" in the
    dataset list, so the original exact match on "portrait" never succeeded.
    """
    return 20 if str(dataset).lower()=="portrait" else 10

## Check datasets health

In [None]:
datasets = [
    "CAMO", "Portrait", "Locuste", "Ribs",
    "SKIN/SKIN_COMPAQ", "SKIN/SKIN_ECU", "SKIN/SKIN_HANDGESTURE", "SKIN/SKIN_MCG",
    "SKIN/SKIN_Pratheepan", "SKIN/SKIN_Schmugge", "SKIN/SKIN_SFA",
    "SKIN/SKIN_uchile", "SKIN/SKIN_VMD", "SKIN/SKIN_VT-AAST",
    "Butterfly/FoldDA1_1", "Butterfly/FoldDA1_2", "Butterfly/FoldDA1_3", "Butterfly/FoldDA1_4",
    "COCO_val2017",
]

# Width of the first column of the report: length of the longest dataset name.
max_length = max(len(name) for name in datasets)

# One report line per dataset:
#   <name> | <files in imgs> | <files in gt> - <segmentator> (<.bmp count>) ...
# A missing "imgs" or "gt" folder is reported as a count of 0; a dataset with
# no "segmentator_*" folder lists none (only the oracle is available then).
for dataset in datasets:
    path = os.path.join(dataset_path, dataset)

    imgs_exists = os.path.isdir(os.path.join(path, "imgs"))
    gt_exists   = os.path.isdir(os.path.join(path, "gt"))
    segmentators = glob.glob(os.path.join(path, "segmentator_*"))
    segm_num = {}

    imgs_num = len(glob.glob(os.path.join(path, "imgs", "*"))) if imgs_exists else 0
    gt_num = len(glob.glob(os.path.join(path, "gt", "*"))) if gt_exists else 0

    print(f"{dataset.ljust(max_length)} | {imgs_num:8d} | {gt_num:8d}", end="", flush=True)

    for idx, segmentator in enumerate(segmentators):
        num = len(glob.glob(os.path.join(segmentator, "*.bmp")))
        # [12:] drops the leading "segmentator_" from the folder name.
        print(f" - {os.path.basename(segmentator)[12:]} ({num})", end="", flush=True)
    print()


In [None]:
def loadpaths(dataset_path, dataset_name, segmentator_name):
    """Collect the sorted file lists of one dataset.

    Args:
        dataset_path: root folder containing the datasets.
        dataset_name: sub-folder of the dataset inside `dataset_path`.
        segmentator_name: suffix of the `segmentator_<name>` folder to read.

    Returns:
        `(test_imgs, gt_masks, segmentator_bmasks, segmentator_rmasks)`:
        all files of `imgs`, all files of `gt`, the `.bmp` files and the
        `.png` files of the segmentator folder. Each list is sorted by the
        file name without extension, left-padded with zeros to 6 characters,
        so that e.g. "2" sorts before "10". A missing sub-folder yields an
        empty list.
        Returns None (after printing an error) when the dataset folder does
        not exist.
    """
    dataset_dir = os.path.join(dataset_path, dataset_name)
    if not os.path.isdir(dataset_dir):
        print("ERROR. provided dataset does not exist!")
        return None

    def _sorted_listing(pattern):
        # splitext removes the real extension whatever its length; the former
        # `path[:-4]` was only right for 3-character extensions.
        return sorted(
            glob.glob(pattern),
            key=lambda p: os.path.splitext(os.path.basename(p))[0].zfill(6),
        )

    orig_images_folder = os.path.join(dataset_dir, "imgs")
    gt_folder          = os.path.join(dataset_dir, "gt")
    segmentator_folder = os.path.join(dataset_dir, "segmentator_" + segmentator_name)

    ## Load input images ##
    test_imgs = _sorted_listing(os.path.join(orig_images_folder, '*'))

    ## Load GT masks ##
    gt_masks = _sorted_listing(os.path.join(gt_folder, '*'))

    print(os.path.join(segmentator_folder, '*.bmp'))

    ## Load the segmentator's binary masks (.bmp) ##
    segmentator_bmasks = _sorted_listing(os.path.join(segmentator_folder, '*.bmp'))

    ## Load the segmentator's .png masks ##
    segmentator_rmasks = _sorted_listing(os.path.join(segmentator_folder, '*.png'))

    return test_imgs, gt_masks, segmentator_bmasks, segmentator_rmasks


## Run SAM

Predict with `SamPredictor.predict`. The model returns
 - masks  (`masks.shape  # (number_of_masks) x H x W) ` )
 - quality predictions for those masks
 - low resolution mask logits that can be passed to the next iteration of prediction.

The `predict()` function accepts three parameters (among many):

 - `point_coords`: an np.ndarray of 2D pixels that will provide SAM the checkpoints/seeds of the object to segment
 - `point_labels`: is the corresponding pixel a pixel belonging to the object (1) or not (0) ?

 - With `multimask_output=True` (the default setting), SAM outputs 3 masks, where `scores` gives the model's own estimation of the quality of these masks. This setting is intended for ambiguous input prompts, and helps the model disambiguate different objects consistent with the prompt. When `False`, it will return a single mask. For ambiguous prompts such as a single point, it is recommended to use `multimask_output=True` even if only a single mask is desired; the best single mask can be chosen by picking the one with the highest score returned in `scores`. This will often result in a better mask.

In [None]:
def perform(predictor, model_type, source_mask, dataset, points_sampling_mode, sampling_step, border_mode):
    """Run the predictor on every image of `dataset` and print IoU/Dice summaries.

    For each image: sample point prompts from the source mask, ask `predictor`
    for a mask, then fuse that mask with the segmentator's raw mask. Three
    Metrics objects are updated (predictor alone, source mask alone, fusion)
    and their average IoU and Dice are printed at the end. Nothing is returned.

    Args:
        predictor: object exposing `set_image` and `predict`
            (a `SamPredictor` in this notebook).
        model_type: model name, used only to build the output folder.
        source_mask: name of the segmentator whose masks provide the prompts;
            "oracle" means the ground-truth masks are used instead.
        dataset: dataset name (sub-folder of the global `dataset_path`).
        points_sampling_mode: mode passed to `Sampler.sample`.
        sampling_step: grid step passed to the `Sampler` constructor.
        border_mode: border mode passed to `Sampler.sample`.
    """
    global verbose
    
    print(model_type, source_mask, dataset, points_sampling_mode, sampling_step, border_mode) 
    
    min_blob_count = get_min_blob_number_based_on_dataset(dataset)
    
    test_imgs, gt_masks, src_bmasks, src_rmasks = loadpaths(dataset_path, dataset, source_mask)
    
    if source_mask=="oracle":
        assert len(test_imgs) == len(gt_masks),\
                  f"unbalanced datasets! {len(test_imgs)} {len(gt_masks)}"
        # Oracle mode: the ground-truth files stand in for both segmentator outputs.
        src_bmasks = src_rmasks = gt_masks
    else:
        assert len(test_imgs) == len(gt_masks) == len(src_bmasks) == len(src_rmasks),\
                  f"unbalanced datasets! {len(test_imgs)} {len(gt_masks)} {len(src_rmasks)}"
    
    toiterate = zip(test_imgs, gt_masks, src_bmasks, src_rmasks)

    print("len of files:", len(test_imgs)) if verbose else None
    
    possampler = Sampler(verbose, sampling_step, min_blob_count)
    
    # The output folder is created here; the code that writes masks into it is commented out below.
    results_dir = get_complete_output_path(base_output_path, dataset, source_mask, model_type, create=True)
    
    print("results will be saved at for results_dir", results_dir)
    
    # metrics: predictor alone; source_mask_metrics: source mask alone; metrics_fusion: fused mask.
    metrics, source_mask_metrics, metrics_fusion = Metrics(dataset), Metrics(dataset), Metrics(dataset)
    
    for idx, paths in tqdm(enumerate(toiterate), total=len(test_imgs)):
        print(f" - img idx {str(idx+1).zfill(6)}/{len(test_imgs)}:") if verbose else None
        
        # Get paths
        img_path, gt_mask_path, src_bmask_path, src_rmask_path = paths
        log(paths) if verbose else None
        
        # Load images from disk using paths
        img           = read_img(img_path)
        gt_mask       = read_bmask(gt_mask_path)
        src_bmask = read_bmask(src_bmask_path)
        src_rmask = read_rmask(src_rmask_path)
        
        if source_mask=="oracle":
            assert img.shape[:2] == gt_mask.shape, f"Error: shape mismatch {img.shape[:2]} {gt_mask.shape}"
        else:
            assert img.shape[:2] == gt_mask.shape == \
                src_bmask.shape == src_rmask.shape[:2], f"Error: shape mismatch {img.shape[:2]} {gt_mask.shape} {src_bmask.shape} {src_rmask.shape[:2]}"
        
        if source_mask=="oracle":
            mask_to_sample = gt_mask
        else:
            mask_to_sample = src_bmask

        # Count the number of distinct labels (it corresponds to the number of blobs)
        mask_of_blobs = get_mask_of_blobs(mask_to_sample)
        
        # Sample the checkpoints (at least one for blob)
        unique_blobs = np.unique(mask_of_blobs)
        num_blobs = unique_blobs.shape[0]
        if num_blobs==1 and 0 in unique_blobs:
            # Only label 0 is present: no blob to prompt with, so the predictor
            # is skipped and an all-zero mask is used instead.
            # NOTE(review): `metrics.step` is not called on this path, so these
            # images do not count in the "SAM alone" averages — confirm intended.
            print("unique blob")
            input_point = input_label= np.array([])
            binary_mask = np.zeros_like(mask_to_sample)
            masks=[binary_mask]
            best_score_idx=0
        else:
            
            input_point, input_label = possampler.sample(points_sampling_mode, border_mode, mask_of_blobs, mask_to_sample.astype(np.int64))
            
            print("setting image on predictor", end="... ")  if verbose else None
            predictor.set_image(img)

            print("predicting", end="... ")  if verbose else None
            masks, scores, logits = predictor.predict(
                point_coords=input_point,
                point_labels=input_label,
                multimask_output=False,
                return_logits=True
            )
            torch.cuda.empty_cache()
            
            best_score_idx = np.argmax(scores)
            # return_logits=True was requested above, hence the threshold at 0 here.
            binary_mask = masks[best_score_idx] > 0
            
            metrics.step(binary_mask, gt_mask)
        
        source_mask_metrics.step(src_bmask, gt_mask)
        
        # instead of saving and then loading images, we can use this script to simulate the storage of .jpg images
        success, encoded_image = cv2.imencode(".jpg",  masks[best_score_idx]*255)
        if success:
            jpeg_data = np.array(encoded_image).tobytes()
        else:
            # NOTE(review): execution continues to imdecode even when encoding failed.
            print("Failed to encode the image as JPEG.")
        decoded_image = cv2.imdecode(encoded_image, cv2.IMREAD_COLOR)[:, :, 0]
        
        ## this is how we saved the produced masks
        # basename = os.path.basename(img_path)
        # out_mask_path = results_dir + "/" + basename[:basename.rfind(".")+1]+"jpg"
        # cv2.imwrite(out_mask_path, masks[best_score_idx]*255)
        
        ## this is how we loaded .jpg images
        # basename = os.path.basename(img_path)
        # in_mask_path = results_dir + "/" + basename[:basename.rfind(".")+1]+"jpg"
        # decoded_image = cv2.imread(in_mask_path)[:, :, 0]
        
        # The raw mask is read again as grayscale and scaled back to 0..255,
        # replacing the IMREAD_UNCHANGED version loaded above.
        src_rmask = read_bmask(src_rmask_path).astype(np.float64) * 255
        # Fusion: weighted mean (weight 1 for the inverted decoded mask, weight 2
        # for the raw source mask); pixels whose mean is below 128 form the fused mask.
        # NOTE(review): the polarity of the raw mask is not visible here — confirm.
        fused_mask = abs(255 - decoded_image.astype('uint8'))
        fused_mask = ((fused_mask + 2*src_rmask)/3).astype(np.uint8) < 128
        metrics_fusion.step(fused_mask, gt_mask)
        
        ## TO VISUALIZE computed mask, show the superposition with the original image
        ## and SAM's predicted score and IoU wrt Ground Truth mask
        # plt.clf()
        # fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(30, 30))
        # axes[0].imshow(img)
        # if input_point.shape[0]>0:
        #     axes[0].scatter([input_point[:, 0]], [input_point[:, 1]], color='red', marker='*', s=250, edgecolor='white', linewidth=1.25) # this is if you want the star
        # axes[1].imshow(gt_mask)
        # axes[2].imshow(binary_mask)
        # # axes[3].imshow(fused_mask)
        # fig.tight_layout()
        # plt.show()
        # break
        
    
    # Only IoU and Dice are printed; the remaining values returned by get_results are unused.
    iou_avg, dice_avg, mae_avg, wfm_avg, eme_avg  = metrics.get_results()
    
    print("SAM alone metrics:")
    print(f"average iou      : {iou_avg*100:5.2f}")
    print(f"average dice     : {dice_avg*100:5.2f}")
    print()
    
    iou_avg, dice_avg, mae_avg, wfm_avg, eme_avg  = source_mask_metrics.get_results()
    print(f"segmentator_{source_mask} metrics:")
    print(f"average iou      : {iou_avg*100:5.2f}")
    print(f"average dice     : {dice_avg*100:5.2f}")
    print()
    
    iou_avg, dice_avg, mae_avg, wfm_avg, eme_avg  = metrics_fusion.get_results()
    print("SAM-fusion performance:")
    print(f"average iou      : {iou_avg*100:5.2f}")
    print(f"average dice     : {dice_avg*100:5.2f}")


In [None]:
def perform_all(dataset_name, predictor, model_type, source_mask):
    """Run `perform` on one dataset for every enabled prompt configuration.

    Each row of `enabled_runs` supplies the last three positional arguments
    forwarded to `perform`; the first four are always
    (predictor, model_type, source_mask, dataset_name).

    Going by the notebook introduction, the first element of a row is the
    checkpoint-sampling method ("A".."D"). The second and third are
    presumably the grid step and the mask-erosion switch used by method "D"
    -- TODO confirm against `perform`'s signature.

    To enable another configuration, uncomment its row in the tuple below.
    """
    enabled_runs = (
        # ("A", None, None),
        # ("B", None, None),
        # ("C", None, None),
        # ("D", 10, "on"),
        # ("D", 30, "on"),
        # ("D", 50, "on"),
        # ("D", 100, "on"),
        # ("D", 10, "off"),
        # ("D", 30, "off"),
        ("D", 50, "off"),
        # ("D", 100, "off"),
    )

    for run_args in enabled_runs:
        perform(predictor, model_type, source_mask, dataset_name, *run_args)

## USE VIT-L

In [None]:
# device = "cuda" # ["cuda", "cpu"]
# sam_checkpoint = os.path.join("/home/fusaro/segment-anything/", "pretrained_models", "sam_vit_l_0b3195.pth")
# model_type = "vit_l"
# source_mask = "deeplab" # ["oracle", "deeplab", "pvtv2", "sota"] ### it depends, choose as needed!
# verbose=False

# print(f"creating sam {model_type} and moving it to device")
# sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
# sam.to(device=device)
# print("creating predictor")
# predictor = SamPredictor(sam)  


# # CHOOSE ONE (OR MANY)
# datasets = []
# # datasets.append("CAMO")
# # datasets.append("Portrait")
# datasets.append("Locuste")
# # datasets.append("Ribs")
# # datasets.append("SKIN/SKIN_COMPAQ")
# # datasets.append("SKIN/SKIN_ECU")
# # datasets.append("SKIN/SKIN_HANDGESTURE")
# # datasets.append("SKIN/SKIN_MCG")
# # datasets.append("SKIN/SKIN_Pratheepan")
# # datasets.append("SKIN/SKIN_Schmugge")
# # datasets.append("SKIN/SKIN_SFA")
# # datasets.append("SKIN/SKIN_uchile")
# # datasets.append("SKIN/SKIN_VMD")
# # datasets.append("SKIN/SKIN_VT-AAST")
# # datasets.append("Butterfly/FoldDA1_1")
# # datasets.append("Butterfly/FoldDA1_2")
# # datasets.append("Butterfly/FoldDA1_3")
# # datasets.append("Butterfly/FoldDA1_4")
# # datasets.append("COCO_val2017")

# for dataset in datasets:
#     perform_all(dataset, predictor, model_type, source_mask)


## USE VIT-H

In [None]:
# ---- Run settings for this cell -------------------------------------------
device = "cuda"  # one of: "cuda", "cpu"
verbose = False
sam_checkpoint = os.path.join(
    "/home/fusaro/segment-anything/",
    "pretrained_models",
    "sam_vit_h_4b8939.pth",
)
# "default" is presumably the registry key matching the ViT-H checkpoint
# above -- confirm against segment_anything's sam_model_registry.
model_type = "default"
# One of: "oracle", "deeplab", "pvtv2", "sota" -- it depends, choose as needed!
source_mask = "deeplab"

# NOTE(review): `predictor` is used by the loop at the bottom of this cell but
# is not created here while the lines below stay commented out; it must
# already exist from an earlier cell, otherwise the loop raises NameError.
# Uncomment to (re)build the SAM model and its predictor.
# print(f"creating sam {model_type} and moving it to device")
# sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
# sam.to(device=device)
# print("creating predictor")
# predictor = SamPredictor(sam)

# ---- Dataset selection: CHOOSE ONE (OR MANY) by uncommenting rows ---------
datasets = [
    # "CAMO",
    # "Portrait",
    # "Locuste",
    # "Ribs",
    # "SKIN/SKIN_COMPAQ",
    # "SKIN/SKIN_ECU",
    # "SKIN/SKIN_HANDGESTURE",
    # "SKIN/SKIN_MCG",
    # "SKIN/SKIN_Pratheepan",
    # "SKIN/SKIN_Schmugge",
    # "SKIN/SKIN_SFA",
    # "SKIN/SKIN_uchile",
    # "SKIN/SKIN_VMD",
    # "SKIN/SKIN_VT-AAST",
    # "Butterfly/FoldDA1_1",
    # "Butterfly/FoldDA1_2",
    # "Butterfly/FoldDA1_3",
    "Butterfly/FoldDA1_4",
    # "COCO_val2017",
]

# Evaluate every selected dataset with the enabled prompt configurations.
for dataset in datasets:
    perform_all(dataset, predictor, model_type, source_mask)
