In [None]:
PYTORCH_NO_CUDA_MEMORY_CACHING=1

from pathlib import Path
import matplotlib.pyplot as plt
import cv2
import torch
from torch import cuda
import os
import numpy as np
import random
from PIL import Image
import matplotlib.colors as mcolors
import numpy.ma as ma

np.set_printoptions(precision=15)

# Ensure deterministic behavior (cannot control everything though)
torch.backends.cudnn.deterministic = True
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)
torch.cuda.manual_seed_all(42)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

!nvidia-smi

In [None]:
import json 
import cv2
import numpy as np
# import plotly.express as px
from matplotlib.path import Path
import math

def merge_dicts(dict1, dict2): # used to merge train & valid sets into one
    """Recursively merge ``dict2`` into ``dict1`` in place and return ``dict1``.

    Nested dicts are merged key by key; a nested dict in ``dict2`` whose key is
    missing from ``dict1`` (or maps to a non-dict there) starts from an empty
    dict instead of raising. Every non-dict value from ``dict2`` -- including
    lists -- overwrites the value in ``dict1``; lists are NOT concatenated.
    NOTE(review): for COCO-style annotation dicts this means list fields such as
    'images'/'annotations' end up as dict2's lists only -- confirm intent.
    """
    for key, value in dict2.items():
        if isinstance(value, dict):
            existing = dict1.get(key)
            if not isinstance(existing, dict):
                existing = {}
                dict1[key] = existing
            merge_dicts(existing, value)
        else:
            dict1[key] = value
    return dict1

# Dataset split directories; each holds one '_annotations.coco.json' file
# (COCO-style, per the file name) next to its images.
DATASET_ROOT = '/workspace/raid/OM_DeepLearning/XAMI/dog-2/'
ANNOTATION_FILE = '_annotations.coco.json'

input_dir = DATASET_ROOT + 'train/'
json_file_path = input_dir + ANNOTATION_FILE

valid_path = DATASET_ROOT + 'valid/'
valid_json_path = valid_path + ANNOTATION_FILE

test_path = DATASET_ROOT + 'test/'
test_json_path = test_path + ANNOTATION_FILE


def _read_json(json_path):
    """Parse and return the JSON document stored at ``json_path``."""
    with open(json_path, 'r') as handle:
        return json.load(handle)


training_data = _read_json(json_file_path)
test_data = _read_json(test_json_path)
valid_data = _read_json(valid_json_path)

In [None]:
from importlib import reload
# Re-execute the project helper modules each time this cell runs, so edits to
# their source take effect without a kernel restart, then expose their public
# names at top level. NOTE(review): the star imports hide where names such as
# get_coords_and_masks_from_json (used below) come from, and may shadow names
# defined earlier in the notebook.
import dataset_utils
reload(dataset_utils)
from dataset_utils import *

import predictor_utils
reload(predictor_utils)
from predictor_utils import *

In [None]:
image_ids = {}  # NOTE(review): not read anywhere in this notebook; kept in case other code expects it

# Training-split ground truth, via a helper presumably star-imported from dataset_utils.
ground_truth_masks, bbox_coords, classes, class_categories = get_coords_and_masks_from_json(input_dir, training_data) # type: ignore

# Mask keys appear to be "<image file name>_<suffix>": dropping the last
# "_"-separated part recovers the file name. dict.fromkeys de-duplicates in
# linear time while keeping first-seen order (a list membership test is quadratic).
image_keys = list(dict.fromkeys("_".join(key.split("_")[:-1]) for key in ground_truth_masks.keys()))
assert image_keys, "no ground-truth masks found for the training split"

image_paths = [input_dir + img_ for img_ in image_keys]
image_paths[0]

In [None]:
def show_mask(mask, ax, random_color=False):
    """Overlay one mask on ``ax`` as a semi-transparent RGBA image.

    Args:
        mask: array whose last two dimensions are (H, W); each pixel value
            scales the RGBA colour, so 0 pixels are fully transparent.
        ax: matplotlib Axes (anything exposing ``imshow``).
        random_color: use a random RGB colour instead of the default blue;
            alpha is 0.6 in both cases.
    """
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([30/255, 144/255, 255/255, 0.6])
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)

def show_masks(masks, ax, random_color=False):
    """Overlay every mask in ``masks`` on ``ax`` using :func:`show_mask`.

    With ``random_color`` each mask gets its own random colour.
    """
    for mask in masks:
        show_mask(mask, ax, random_color=random_color)

def show_box(box, ax):
    """Draw an (x0, y0, x1, y1) box on ``ax`` as an unfilled orange rectangle."""
    x0, y0 = box[0], box[1]
    w, h = box[2] - box[0], box[3] - box[1]
    ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='orange', facecolor=(0,0,0,0), lw=2))

In [None]:
import traceback
import cv2
import supervision as sv
import numpy as np
from torchvision.transforms.functional import resize

def any_sam_model_predictor(any_sam_model, AMG, data_set_gt_masks, model_name,  IMAGE_PATH, use_negative=None, mask_on_negative=False, show_plot=False):
    """Run a SAM-style automatic mask generator on one image and score it.

    Args:
        any_sam_model: model instance handed to ``AMG``.
        AMG: mask-generator class, built as ``AMG(model)`` or, when
            ``mask_on_negative`` is truthy, ``AMG(model, negative_mask=...)``.
        data_set_gt_masks: dict of ground-truth masks; entries whose key starts
            with the image file name are scored against the predictions.
        model_name: label shown in the plot title.
        IMAGE_PATH: path of the image to segment.
        use_negative: unused; kept for call compatibility.
        mask_on_negative: if truthy, give the generator a negative mask and drop
            predicted masks whose overlap with it exceeds 50 pixels (see remove_masks).
        show_plot: if True, plot the source image next to the annotated one.

    Returns:
        (annotated_image, iou_assoc_loss): uint8 annotated RGB image with every
        channel value that is 0 in the source kept at 0, and the
        ``compute_loss`` score of the kept predictions.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``IMAGE_PATH``.
    """
    image_name = IMAGE_PATH.split("/")[-1]
    gt_image_masks = np.array([mask for key, mask in data_set_gt_masks.items() if key.startswith(image_name)])

    with torch.no_grad():
        image_bgr = cv2.imread(IMAGE_PATH)
        if image_bgr is None:
            # cv2.imread reports an unreadable/missing file by returning None.
            raise FileNotFoundError(f"Could not read image: {IMAGE_PATH}")
        image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
        # Per-channel True where the pixel value is > 0. NOTE(review): this is
        # what gets called the "negative" mask below -- confirm the polarity.
        positive_pixels = image_rgb > 0

        if mask_on_negative:
            # Channel-first, resized to 1024x1024 with a leading batch dim -> (1, C, 1024, 1024).
            negative_mask = torch.from_numpy(positive_pixels).permute(2, 0, 1)
            negative_mask = resize(negative_mask, [1024, 1024], antialias=True).unsqueeze(0)
            mask_generator = AMG(any_sam_model, negative_mask=negative_mask)
        else:
            mask_generator = AMG(any_sam_model)

        sam_result = mask_generator.generate(image_rgb)

        if mask_on_negative:
            # remove_masks compares each (H, W) segmentation with this mask, so
            # pass a 2-D mask at image resolution (1.0 where every channel > 0)
            # plus img_shape, not the resized 4-D tensor given to the generator.
            img_negative_mask = np.mean(np.where(positive_pixels, 1, 0), axis=2)
            sam_result = remove_masks(sam_result=sam_result, mask_on_negative=img_negative_mask, threshold=50, img_shape=image_rgb.shape)

        # Masks covering half of the image or more are discarded (H*W so that
        # non-square images are handled too).
        max_area = image_rgb.shape[0] * image_rgb.shape[1] / 2
        predicted_masks = np.array([out_pred['segmentation'] for out_pred in sam_result if np.sum(out_pred['segmentation']) < max_area])

        mask_annotator = sv.MaskAnnotator(color_lookup=sv.ColorLookup.INDEX)
        detections = sv.Detections.from_sam(sam_result=sam_result)
        if detections.mask is not None and len(detections) > 0:
            # Index the Detections object itself so boxes, masks and the other
            # per-detection fields stay aligned (assigning only .mask desyncs them).
            keep = np.array([np.sum(det_mask) < max_area for det_mask in detections.mask], dtype=bool)
            detections = detections[keep]

        annotated_image = mask_annotator.annotate(scene=image_rgb.copy(), detections=detections)
        # Keep channel values that are 0 in the source at 0 in the output. The
        # image comes from cv2.imread (8-bit), so no rescaling is needed.
        annotated_image = np.where(positive_pixels, annotated_image, 0).astype(np.uint8)

        iou_assoc_loss = compute_loss(gt_image_masks, predicted_masks)
        if show_plot:
            # NOTE(review): supervision's plotting helpers may assume BGR
            # (OpenCV) channel order -- confirm colours render as intended.
            sv.plot_images_grid(
                images=[image_rgb, annotated_image],
                grid_size=(1, 2),
                titles=[f'source image\n{image_name.split(".")[0]}',
                        f'segmented image with {model_name}'])
    return annotated_image, iou_assoc_loss

In [None]:
from sklearn.metrics import jaccard_score
from scipy.special import expit
import numpy as np

def dice_loss_numpy(pred, target, area=None, smooth = 1): 
    """Smoothed Dice loss between two masks.

    Args:
        pred, target: masks with the same number of elements (both flattened).
        area: unused; kept for call compatibility.
        smooth: additive smoothing, so two empty masks give a loss of 0.

    Returns:
        ``1 - (2 * sum(pred * target) + smooth) / (sum(pred) + sum(target) + smooth)``.
    """
    pred_flat = np.asarray(pred).ravel()
    target_flat = np.asarray(target).ravel()

    intersection = np.sum(pred_flat * target_flat)
    union = np.sum(pred_flat) + np.sum(target_flat)

    dice = (2. * intersection + smooth) / (union + smooth)
    return 1 - dice

def compute_loss(gt_masks, pred_masks, unmatched_loss=0.5):
    """Mean Dice loss of each ground-truth mask against its best-overlapping prediction.

    For every GT mask, the prediction with the largest pixel overlap
    (intersection count -- not IoU) is selected and its Dice loss recorded.
    A GT mask that overlaps no prediction contributes ``unmatched_loss``.

    Returns:
        0 when there are no GT masks, 1 when there are GT masks but no
        predictions, otherwise the mean of the per-GT losses.
    """
    gt_masks = np.asarray(gt_masks)
    pred_masks = np.asarray(pred_masks)
    if gt_masks.size == 0:
        return 0
    if pred_masks.size == 0:
        return 1
    losses = []
    for gt_mask in gt_masks:
        gt_flat = gt_mask.flatten()
        best_overlap = 0
        mask_loss = unmatched_loss
        for pred_mask in pred_masks:
            overlap = np.sum(pred_mask.flatten() * gt_flat)
            if overlap > best_overlap:
                # Track the running best so the scored prediction is the
                # best-overlapping one, not merely the last overlapping one.
                best_overlap = overlap
                mask_loss = dice_loss_numpy(pred_mask, gt_mask)
        losses.append(mask_loss)
    return np.mean(losses)

**Original SAM (ViT-H) checkpoint**

In [None]:
# !wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth 

In [None]:
import os
from segment_anything import sam_model_registry as orig_sam_model_registry, \
                            SamAutomaticMaskGenerator as orig_SamAutomaticMaskGenerator, \
                            SamPredictor as orig_SamPredictor

HOME = os.getcwd()

origSAM_CHECKPOINT_PATH = os.path.join(HOME, "weights", "sam_vit_h_4b8939.pth")

# Preferred GPU on the shared multi-GPU host; fall back to the default CUDA
# device when that index does not exist here (otherwise .to() raises).
GPU_INDEX = 7
if torch.cuda.is_available():
    device = f"cuda:{GPU_INDEX}" if GPU_INDEX < torch.cuda.device_count() else "cuda"
else:
    device = "cpu"

sam = orig_sam_model_registry["vit_h"](checkpoint=origSAM_CHECKPOINT_PATH).to(device=device)
sam.eval();

In [None]:
# Raw XMM-OM observations used for a qualitative look at the predictors.
OM_dir = '/workspace/raid/OM_DeepLearning/XMM_OM_dataset/scaled_raw_512/'
imgs = [
    'S0802200201_M.png',
    'S0673730101_U.png',
    'S0673350201_U.png',
    'S0304201401_U.png',
    'S0100240801_U.png',
    'S0302884001_M.png',
]

In [None]:
# Plot the original SAM segmentation of each raw OM image; only the last
# image's annotated output and loss remain in these names afterwards.
for om_image_name in imgs:
    om_image_path = OM_dir + om_image_name
    annotated_image_orig_sam, annotated_image_orig_sam_loss = any_sam_model_predictor(
        sam,
        orig_SamAutomaticMaskGenerator,
        ground_truth_masks,
        'orig_SAM',
        om_image_path,
        mask_on_negative=None,
        show_plot=True,
    )

In [None]:
# Same check on the images in the gaussian_distrib directory (presumably a
# processed version of the same observations).
gaussian_solved = '/workspace/raid/OM_DeepLearning/XAMI/gaussian_distrib/'
for gaussian_image_name in imgs:
    annotated_image_orig_sam, annotated_image_orig_sam_loss = any_sam_model_predictor(
        sam, orig_SamAutomaticMaskGenerator, ground_truth_masks, 'orig_SAM',
        f'{gaussian_solved}{gaussian_image_name}', mask_on_negative=None, show_plot=True)

**original MobileSAM checkpoint**

In [None]:
import sys
sys.path.append('/workspace/raid/OM_DeepLearning/MobileSAM-master/')
import mobile_sam
from mobile_sam import sam_model_registry as orig_mobile_sam_registry, \
SamAutomaticMaskGenerator as orig_mobile_SamAutomaticMaskGenerator, \
SamPredictor as orig_mobile_SamPredictor

orig_mobile_sam_checkpoint = "/workspace/raid/OM_DeepLearning/MobileSAM-master/weights/mobile_sam.pt"
print("device:", device)

mobile_sam_model_orig = orig_mobile_sam_registry["vit_t"](checkpoint=orig_mobile_sam_checkpoint)
mobile_sam_model_orig.to(device)
mobile_sam_model_orig.eval();

# This is the original (not fine-tuned) MobileSAM: label the plot and store the
# result with orig_* names, matching the 'orig_SAM' label used for SAM.
annotated_image_orig_mobile_sam, annotated_image_orig_mobile_sam_loss = any_sam_model_predictor(
    mobile_sam_model_orig, orig_mobile_SamAutomaticMaskGenerator,
    ground_truth_masks, 'orig_MobileSAM', image_paths[1], mask_on_negative=False, show_plot=True)

**Fine-tuned MobileSAM checkpoint**

In [None]:
import sys

# Directory expected to contain the ft_mobile_sam package; spelled
# "OM_DeepLearning" like every other path in this notebook.
FT_MOBILE_SAM_DIR = '/workspace/raid/OM_DeepLearning/XAMI/mobile_sam/'
if FT_MOBILE_SAM_DIR not in sys.path:  # don't grow sys.path on every re-run
    sys.path.append(FT_MOBILE_SAM_DIR)
import ft_mobile_sam
from ft_mobile_sam import (
    sam_model_registry as ft_mobile_sam_registry,
    SamAutomaticMaskGenerator as ft_SamAutomaticMaskGenerator,
    SamPredictor as ft_SamPredictor,
)

ft_mobile_sam_checkpoint = "/workspace/raid/OM_DeepLearning/XAMI/mobile_sam_model_checkpoint.pth"
print("device:", device)

ft_mobile_sam_model = ft_mobile_sam_registry["vit_t"](checkpoint=ft_mobile_sam_checkpoint)
ft_mobile_sam_model.to(device)
ft_mobile_sam_model.eval();

In [None]:
# Absolute, like the other dataset paths here; the relative 'raid/...' form
# only resolved when the kernel's working directory happened to be /workspace.
img_paths = ['/workspace/raid/OM_DeepLearning/XAMI/dog-2/train']

In [None]:
# NOTE(review): amg_predict is defined in a later cell of this notebook; on a
# fresh kernel that cell must run first (TODO: move the definition above).
annotated_image_ft_mobile_sam, annotated_image_ft_mobile_sam_loss = amg_predict(
    ft_mobile_sam_model, ft_SamAutomaticMaskGenerator,
    ground_truth_masks, 'ft_MobileSAM', image_paths[1], mask_on_negative=False, show_plot=True)

---------------------

**Fine-tuned MobileSAM: automatic mask generation with negative-pixel filtering**

In [None]:
def remove_masks(sam_result, mask_on_negative, threshold, remove_big_masks=False, big_masks_threshold=None, img_shape=None):
    '''
    Drop masks from an automatic-mask-generator result.

    A mask is removed when the number of its pixels that coincide with
    ``mask_on_negative == 1`` exceeds ``threshold``, or -- when
    ``remove_big_masks`` is set -- when its area exceeds ``big_masks_threshold``.

    Args:
        sam_result: sequence of dicts, each holding a ``'segmentation'`` array.
        mask_on_negative: array compared element-wise with each segmentation
            (it must broadcast against it, e.g. the same (H, W) shape).
        threshold: maximum allowed number of overlapping pixels.
        remove_big_masks: also drop masks larger than ``big_masks_threshold``.
        big_masks_threshold: area limit in pixels; defaults to 1/5 of the image
            area, computed from ``img_shape``. With neither given, the big-mask
            check is skipped.
        img_shape: image shape ``(H, W, ...)``; only used for the default limit.

    Returns:
        list of the kept result dicts, in their original order.
    '''
    if remove_big_masks and big_masks_threshold is None and img_shape is not None:
        big_masks_threshold = img_shape[0] * img_shape[1] / 5
    check_big = remove_big_masks and big_masks_threshold is not None

    kept = []
    for segm_index, result in enumerate(sam_result):
        segmentation = result['segmentation']

        # remove masks that overlap the negative pixels by more than `threshold`
        overlap = np.sum((segmentation == 1) & (mask_on_negative == 1))
        if overlap > threshold:
            continue

        # remove masks whose area exceeds the big-mask limit
        if check_big:
            area = np.sum(segmentation)
            if area > big_masks_threshold:
                print(f"Removing mask {segm_index} with area {area}")
                continue

        kept.append(result)
    return kept

def amg_predict(any_sam_model, AMG, data_set_gt_masks, model_name,  IMAGE_PATH, use_negative=None, mask_on_negative=False, show_plot=False):
    """Run an automatic mask generator on one image, optionally filter masks, and score it.

    Args:
        any_sam_model: model instance handed to ``AMG``.
        AMG: mask-generator class, built as ``AMG(model)`` or, when
            ``mask_on_negative`` is truthy, ``AMG(model, negative_mask=...)``.
        data_set_gt_masks: dict of ground-truth masks; entries whose key starts
            with the image file name are scored against the predictions.
        model_name: label shown in the plot title.
        IMAGE_PATH: path of the image to segment.
        use_negative: unused; kept for call compatibility.
        mask_on_negative: if truthy, give the generator a negative mask and drop
            predicted masks whose overlap with it exceeds 50 pixels (see remove_masks).
        show_plot: if True, plot the source image next to the annotated one.

    Returns:
        (annotated_image, iou_assoc_loss): the annotated RGB image and the
        ``compute_loss`` score of the kept predictions.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``IMAGE_PATH``.
    """
    image_name = IMAGE_PATH.split("/")[-1]
    gt_image_masks = np.array([mask for key, mask in data_set_gt_masks.items() if key.startswith(image_name)])

    # NOTE(review): unlike any_sam_model_predictor this does not run under
    # torch.no_grad(); confirm whether the generator needs gradients here.
    image_bgr = cv2.imread(IMAGE_PATH)
    if image_bgr is None:
        # cv2.imread reports an unreadable/missing file by returning None.
        raise FileNotFoundError(f"Could not read image: {IMAGE_PATH}")
    image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
    positive_pixels = image_rgb > 0  # per-channel True where the pixel value is > 0

    if mask_on_negative:
        # Channel-first, resized to 1024x1024 with a leading batch dim -> (1, C, 1024, 1024).
        negative_mask = torch.from_numpy(positive_pixels).permute(2, 0, 1)
        negative_mask = resize(negative_mask, [1024, 1024], antialias=True).unsqueeze(0)
        mask_generator = AMG(any_sam_model, negative_mask=negative_mask)
    else:
        mask_generator = AMG(any_sam_model)

    sam_result = mask_generator.generate(image_rgb)

    if mask_on_negative:
        # 2-D mask at image resolution: 1.0 where every channel is > 0.
        img_negative_mask = np.mean(np.where(positive_pixels, 1, 0), axis=2)
        sam_result = remove_masks(sam_result=sam_result, mask_on_negative=img_negative_mask, threshold=50, img_shape=image_rgb.shape)

    # Masks covering half of the image or more are discarded (H*W so that
    # non-square images are handled too).
    max_area = image_rgb.shape[0] * image_rgb.shape[1] / 2
    predicted_masks = np.array([out_pred['segmentation'] for out_pred in sam_result if np.sum(out_pred['segmentation']) < max_area])

    mask_annotator = sv.MaskAnnotator(color_lookup=sv.ColorLookup.INDEX)
    detections = sv.Detections.from_sam(sam_result=sam_result)
    if detections.mask is not None and len(detections) > 0:
        # Index the Detections object itself so boxes, masks and the other
        # per-detection fields stay aligned (assigning only .mask desyncs them).
        keep = np.array([np.sum(det_mask) < max_area for det_mask in detections.mask], dtype=bool)
        detections = detections[keep]

    annotated_image = mask_annotator.annotate(scene=image_rgb.copy(), detections=detections)
    iou_assoc_loss = compute_loss(gt_image_masks, predicted_masks)

    if show_plot:
        sv.plot_images_grid(
            images=[image_bgr, annotated_image],
            grid_size=(1, 2),
            titles=[f'source image\n{image_name.split(".")[0]}',
                    f'segmented image with {model_name}'])
    return annotated_image, iou_assoc_loss

In [None]:
# amg_predict (defined in the cell above) builds the 2-D image-resolution mask
# and img_shape that remove_masks needs when mask_on_negative=True.
annotated_image_ft_mobile_sam, annotated_image_ft_mobile_sam_loss = amg_predict(
    ft_mobile_sam_model, ft_SamAutomaticMaskGenerator,
    ground_truth_masks, 'ft_MobileSAM', image_paths[0], mask_on_negative=True, show_plot=True)