In [None]:
import json
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Polygon
from pycocotools import mask as maskUtils
import cv2
import supervision as sv
from PIL import Image
import os
# !pip install astropy
import astropy
from astropy.io import fits
from scipy.interpolate import interp1d
from astropy.visualization import ZScaleInterval, ImageNormalize
import torch
import random
np.set_printoptions(precision=15)

import zlib


def stable_seed(label):
    """Return a reproducible 32-bit seed derived from the string ``label``.

    The previous ``hash(label) % 2**32 - 1`` was not reproducible: ``hash()`` of a
    ``str`` is salted per interpreter process (PYTHONHASHSEED), so every kernel
    restart produced different seeds. It could also evaluate to -1, which
    ``np.random.seed`` rejects. CRC32 is deterministic and always lies in
    ``[0, 2**32 - 1]``, the range every seeding function below accepts.
    """
    return zlib.crc32(label.encode("utf-8"))


# Ensure deterministic behavior
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False  # cuDNN autotuning can pick non-deterministic algorithms
random.seed(stable_seed("setting random seeds"))
np.random.seed(stable_seed("improves reproducibility"))
torch.manual_seed(stable_seed("by removing stochasticity"))
torch.cuda.manual_seed_all(stable_seed("so runs are repeatable"))
# Show the GPUs visible to this kernel (and their current utilisation) before a device is chosen.
!nvidia-smi

In [None]:
#import my files
from importlib import reload
import astronomy_utils, predictor_utils
reload(astronomy_utils)
reload(predictor_utils)
from predictor_utils import *
from astronomy_utils import *

In [None]:
# Directory of the scaled OM images. The trailing slash matters: later cells build
# file paths as OM_dir + filename.
OM_dir = "/workspace/raid/OM_DeepLearning/XMM_OM_code/scaled_raw/"
OM_dir

# Import and visualize Roboflow annotations (COCO format)

In [None]:
# Set to True to overlay the Roboflow COCO annotations on their training images.
SHOW_ROBOFLOW_ANNOTATIONS = False


def display_masks(image_path, masks):
    """Show an image with its COCO segmentations overlaid.

    Args:
        image_path: path of the image to display.
        masks: list of COCO ``segmentation`` values. Each is either a polygon list
            ``[[x1, y1, x2, y2, ...], ...]`` (one or more rings) or an RLE dict with
            ``counts`` and ``size`` (``[height, width]``).
    """
    fig, ax = plt.subplots()
    with Image.open(image_path) as image:
        ax.imshow(np.asarray(image))

    for segmentation in masks:
        if isinstance(segmentation, list):  # polygon(s)
            # Draw every ring, not only the first one.
            for ring in segmentation:
                polygon_points = np.array(ring).reshape(-1, 2)
                ax.add_patch(Polygon(polygon_points, edgecolor='g', facecolor='none'))
        else:  # RLE dict (indexing it with [0] used to raise KeyError)
            rle = segmentation
            if isinstance(rle.get('counts'), list):
                # Uncompressed RLE must be converted before it can be decoded.
                height, width = rle['size']
                rle = maskUtils.frPyObjects(rle, height, width)
            ax.imshow(maskUtils.decode(rle), alpha=0.5, cmap='gray')
    plt.show()
    plt.close(fig)


if SHOW_ROBOFLOW_ANNOTATIONS:
    # Load the JSON file
    dir_train_path = '/workspace/raid/OM_DeepLearning/XMM_OM_code/OM_sky_images-6/train/'
    json_file_path = dir_train_path+'_annotations.coco.json'

    with open(json_file_path, 'r') as f:
        data = json.load(f) # dict_keys(['info', 'licenses', 'categories', 'images', 'annotations'])

    # Index annotations by image id once instead of rescanning the whole list per image.
    # annotation keys: ['id', 'image_id', 'category_id', 'bbox', 'area', 'segmentation', 'iscrowd']
    annotations_by_image = {}
    for anno in data['annotations']:
        annotations_by_image.setdefault(anno['image_id'], []).append(anno)

    for image_info in data['images']:
        image_path = dir_train_path + image_info['file_name']
        masks = [anno['segmentation'] for anno in annotations_by_image.get(image_info['id'], [])]
        display_masks(image_path, masks)

# MobileSAM
>
> According to its authors, MobileSAM is about 5 times smaller and 7 times faster than FastSAM.

In [None]:
import torch
import sys
import os
import numpy as np
import PIL
from PIL import Image
sys.path.append('/workspace/raid/OM_DeepLearning/MobileSAM-master/')

import cv2
import mobile_sam
from mobile_sam import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor

mobile_sam_checkpoint = "/workspace/raid/OM_DeepLearning/MobileSAM-master/weights/mobile_sam.pt"

# GPU index 7 was hard-coded originally (presumably the free card on the shared node).
# Clamp it to the last visible GPU so the notebook also runs on machines with fewer
# GPUs, and fall back to the CPU when CUDA is unavailable.
preferred_gpu_index = 7
if torch.cuda.is_available():
    device = f"cuda:{min(preferred_gpu_index, torch.cuda.device_count() - 1)}"
else:
    device = "cpu"

mobile_sam_model = sam_model_registry["vit_t"](checkpoint=mobile_sam_checkpoint)
mobile_sam_model.to(device=device)
mobile_sam_model.eval();
device

In [None]:
import supervision as sv
import math

def MobileSAM_predict(image_path, show_annots = True, mask_on_negative = None, img_grid_points = None):
    """Run MobileSAM on one image and optionally save an annotated automatic-mask overlay.

    The model's ``pixel_mean``/``pixel_std`` buffers are replaced with this image's own
    per-channel statistics before prediction, instead of SAM's standard values. This
    mutates the shared ``mobile_sam_model`` in place.

    Args:
        image_path: path of the image to segment (read with OpenCV).
        show_annots: if True, also run ``SamAutomaticMaskGenerator`` and save the annotated
            overlay as ``.../XMM_OM_dataset/mobile_sam/<name>_mobile_sam_nonnegative.png``.
        mask_on_negative: optional mask forwarded to ``remove_masks`` (threshold=50,
            remove_big_masks=True) to discard masks lying on negative pixels.
        img_grid_points: point grids for ``SamAutomaticMaskGenerator(point_grids=...)``.
            When None, the notebook-global ``img_points`` is used (the previous, implicit
            behaviour). If that is also missing or empty, the generator's default grid is used.

    Returns:
        ``predictor.predict()`` output:
        - masks: np.ndarray of shape CxHxW, where C is the number of masks and (H, W)
          is the original image size.
        - iou_predictions: np.ndarray of length C, the model's quality prediction per mask.
        - low_res_masks: np.ndarray of shape CxHxW with H=W=256. These are low-resolution
          logits that can be passed to a later iteration as mask input.

    Raises:
        FileNotFoundError: if ``image_path`` cannot be read.
    """
    image_bgr = cv2.imread(image_path)
    if image_bgr is None:
        raise FileNotFoundError(f"Could not read image: {image_path}")
    image = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)

    with torch.no_grad():
        # Per-image channel mean/std instead of the standard SAM normalisation constants.
        channels = image.reshape(-1, image.shape[-1])
        pixel_mean = torch.as_tensor(channels.mean(axis=0), dtype=torch.float, device=device)
        pixel_std = torch.as_tensor(channels.std(axis=0), dtype=torch.float, device=device)
        mobile_sam_model.register_buffer("pixel_mean", pixel_mean.view(-1, 1, 1), False)
        mobile_sam_model.register_buffer("pixel_std", pixel_std.view(-1, 1, 1), False)

        predictor = SamPredictor(mobile_sam_model)
        predictor.set_image(image)
        masks_np, iou_predictions_np, low_res_masks_np = predictor.predict()

        if show_annots:
            if img_grid_points is None:
                # Backward compatibility: earlier versions read this notebook global implicitly.
                img_grid_points = globals().get("img_points")
            if img_grid_points is not None and len(img_grid_points) > 0:
                mask_generator = SamAutomaticMaskGenerator(mobile_sam_model, points_per_side=None,
                                                           point_grids=img_grid_points)
            else:
                mask_generator = SamAutomaticMaskGenerator(mobile_sam_model)

            mobile_sam_result = mask_generator.generate(image)
            if mask_on_negative is not None:
                mobile_sam_result = remove_masks(sam_result=mobile_sam_result, mask_on_negative=mask_on_negative,
                                                 threshold=50, remove_big_masks=True, img_shape = image.shape)

            detections = sv.Detections.from_sam(mobile_sam_result)
            mask_annotator = sv.MaskAnnotator(color_lookup=sv.ColorLookup.INDEX)
            # supervision draws on numpy scenes in OpenCV (BGR) order. Convert back to RGB
            # before handing the array to PIL, which assumes RGB; previously the saved
            # PNG had its channels swapped.
            annotated_bgr = mask_annotator.annotate(scene=image_bgr.copy(), detections=detections)

            output_dir = "/workspace/raid/OM_DeepLearning/XMM_OM_dataset/mobile_sam/"
            os.makedirs(output_dir, exist_ok=True)
            output_path = output_dir + os.path.basename(image_path).replace(".png", "_mobile_sam_nonnegative.png")
            Image.fromarray(cv2.cvtColor(annotated_bgr, cv2.COLOR_BGR2RGB)).save(output_path)

        return masks_np, iou_predictions_np, low_res_masks_np

In [None]:
# Load the extracted source points here so this cell does not depend on a later cell
# having been run first.
with open('extracted_sources_bboxes_points.json', 'r') as f:
    extracted_bboxes_points = json.load(f)
# Preview a few entries instead of dumping the whole mapping.
dict(list(extracted_bboxes_points['points'].items())[:5])

In [None]:
import json
import time

image_path = '/workspace/raid/OM_DeepLearning/XMM_OM_code/scaled_raw/clahe_S0720251301_L.png'
image_name = os.path.basename(image_path)
# Where MobileSAM_predict saves its annotated overlay for this image.
mobile_sam_output_path = ('/workspace/raid/OM_DeepLearning/XMM_OM_dataset/mobile_sam/'
                          + image_name.replace(".png", "_mobile_sam_nonnegative.png"))

with open('extracted_sources_bboxes_points.json', 'r') as f:
    extracted_bboxes_points = json.load(f)

# Points of every entry whose key is a substring of this image's file name.
# NOTE(review): dividing by 255 assumes the points are pixel positions in a ~256-px image;
# SAM point grids must be normalised to [0, 1] — confirm against the image size.
img_points = np.array([point for filename, point in extracted_bboxes_points['points'].items()
                       if filename in image_name])
img_points = img_points/255.0

mask_on_negative = mask_and_plot_image(image_path.replace(".png", ".fits").replace('clahe_', ''))

# perf_counter is monotonic and high-resolution, so it is the right clock for durations.
start_time = time.perf_counter()
output_mobile_sam = MobileSAM_predict(image_path, mask_on_negative=mask_on_negative)
end_time = time.perf_counter()
print(f"Mobile SAM predict time/img: {end_time-start_time} s")

image = cv2.imread(mobile_sam_output_path)
if image is None:
    raise FileNotFoundError(f"MobileSAM overlay not found: {mobile_sam_output_path}")
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

fig, ax = plt.subplots()
ax.imshow(image)
ax.set_title('Mobile SAM pred. on image input')
plt.imsave('mobile_sam_entire_image.png', image)
plt.show()
plt.close(fig)
plt.close()

In [None]:
# Mean and standard deviation of the pixel values of `image`, the annotated overlay read above
image.mean(), image.std()

In [None]:
# Overlay the extracted source positions on the input image. The points were divided
# by 255 when loaded, so multiply back to draw them in image coordinates.
fig, ax = plt.subplots()
ax.imshow(cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB))
if len(img_points) > 0:  # no extracted sources -> nothing to scatter (img_points[0] would fail)
    ax.scatter(img_points[0][:, 0]*255, img_points[0][:, 1]*255, s=10, c='red')
ax.set(title='Extracted source positions', xlabel='x', ylabel='y')
plt.show()
plt.close(fig)

In [None]:
# Intentional stop for "Run All": the cells below load SAM, rerun predictions over the
# whole directory, export annotations and upload to Roboflow. This used to rely on the
# undefined name `stophere` raising NameError; an explicit error says why it stopped.
raise RuntimeError("Execution stopped on purpose; run the cells below manually if needed.")

# Install Segment Anything Model (SAM) and other dependencies

In [None]:
# %cd {HOME}

# import sys
# !{sys.executable} -m pip install 'git+https://github.com/facebookresearch/segment-anything.git'

In [None]:
# !pip install -q jupyter_bbox_widget roboflow dataclasses-json supervision

### Download SAM weights

In [None]:
# %cd {HOME}
# !mkdir {HOME}/weights
# %cd {HOME}/weights

# !wget -q https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth

# 🚀 SAM inference

## Load Model

In [None]:
import os

HOME = os.getcwd()

# SAM ViT-H weights are looked up under ./weights relative to the working directory.
CHECKPOINT_PATH = os.path.join(HOME, "weights", "sam_vit_h_4b8939.pth")
print(f"{CHECKPOINT_PATH} ; exist: {os.path.isfile(CHECKPOINT_PATH)}")

In [None]:
import torch
from torch import cuda
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor

available_gpus = [torch.cuda.device(i) for i in range(torch.cuda.device_count())] # find all available GPUs in the cluster

# GPU index 7 was hard-coded originally (presumably the free card on the shared node).
# Clamp it to the last visible GPU so the notebook also runs with fewer GPUs, else use the CPU.
PREFERRED_GPU_INDEX = 7
if torch.cuda.is_available():
    DEVICE = torch.device(f'cuda:{min(PREFERRED_GPU_INDEX, torch.cuda.device_count() - 1)}')
else:
    DEVICE = torch.device('cpu')
MODEL_TYPE = "vit_h"

sam = sam_model_registry[MODEL_TYPE](checkpoint=CHECKPOINT_PATH).to(device=DEVICE)
sam.eval();

# 🚀 The predictor function (optionally removes masks that lie on negative pixels)

In [None]:
from importlib import reload
import astronomy_utils, predictor_utils
reload(astronomy_utils)
reload(predictor_utils)
from predictor_utils import *
from astronomy_utils import *

In [None]:
# Per-image extracted source bounding boxes and points, used as SAM prompt grids below.
with open('extracted_sources_bboxes_points.json') as points_file:
    extracted_bboxes_points = json.load(points_file)

In [None]:
image_path = "/workspace/raid/OM_DeepLearning/XMM_OM_code/scaled_raw/clahe_S0720251301_L.png"
mask_on_negative = mask_and_plot_image("/workspace/raid/OM_DeepLearning/XMM_OM_dataset/scaled_raw/S0720251301_L.fits")

img_points = np.array([point for filename, point in extracted_bboxes_points['points'].items()
                       if filename in os.path.basename(image_path)])
img_points = img_points/255.0
if len(img_points) == 0:
    # No extracted sources for this image: pass None, as the batch run below does.
    img_points = None

# Baseline: SAM's default point grid.
start_time = time.perf_counter()
_, _, annotated_image1 = SAM_predictor(SamAutomaticMaskGenerator, sam, image_path, mask_on_negative=mask_on_negative, img_grid_points=None)
end_time = time.perf_counter()
print(f"original SAM predict time/img (default grid): {end_time-start_time} s")

# Same image, prompted only at the extracted source positions.
start_time2 = time.perf_counter()
_, _, annotated_image2 = SAM_predictor(SamAutomaticMaskGenerator, sam, image_path, mask_on_negative=mask_on_negative, img_grid_points=img_points)
end_time2 = time.perf_counter()
print(f"original SAM predict time/img (extracted-source grid): {end_time2-start_time2} s")

In [None]:
# fig, axs= plt.subplots(1, 3, figsize=(10, 10)) 

# axs[0].imshow(annotated_image1)
# axs[0].set_title(f'Original SAM inference \ntime: {round(end_time-start_time, 3)}s, {1024} grid points', \
#                  fontfamily='monospace', fontsize=10)

# axs[1].imshow(annotated_image2)
# axs[1].set_title(f'Original SAM inference\ntime: {round(end_time2-start_time2, 3)}s, {img_points.shape[1]} grid points', \
#                 fontfamily='monospace', fontsize=10)

# axs[2].imshow(cv2.cvtColor(cv2.imread(image_path2), cv2.COLOR_BGR2RGB))
# axs[2].scatter(img_points[0][:, 0]*255, img_points[0][:, 1]*255, s=10, c='red')
# axs[2].set_title(f'Extracted sources grid points', \
#                  fontfamily='monospace', fontsize=10)
# axs[2].set_aspect('equal', 'box')  
# plt.tight_layout()
# plt.savefig('plots/sam_grid_points_comparison.png', dpi=1000)
# plt.show()
# plt.close()

In [None]:
from importlib import reload
import astronomy_utils, predictor_utils
reload(astronomy_utils)
reload(predictor_utils)
from predictor_utils import *
from astronomy_utils import *

generate_all_predictions = True

# Baseline timing run: every CLAHE image with SAM_predictor's default point grid.
inference_times = []
if generate_all_predictions:
    # Sorted so files are processed in a reproducible order (os.listdir order is arbitrary).
    dir_files = sorted(f for f in os.listdir(OM_dir) if os.path.isfile(os.path.join(OM_dir, f)))
    for file_ in dir_files:
        if not (file_.endswith('.png') and 'clahe' in file_):
            continue
        try:
            mask_on_negative = mask_and_plot_image(OM_dir+file_.replace('.png', '.fits').replace('clahe_', ''))

            start_time = time.perf_counter()
            _, _, annotated_image = SAM_predictor(SamAutomaticMaskGenerator, sam, OM_dir+file_, mask_on_negative=mask_on_negative, img_grid_points=None)
            end_time = time.perf_counter()

            print(f"original SAM predict time/img: {end_time-start_time} s")
            inference_times.append(end_time-start_time)
        except Exception as e:
            # Best effort over the directory: report which file failed and continue.
            print(f"{file_}: {e!r}")

In [None]:
# Average per-image SAM inference time in seconds over the run above
np.asarray(inference_times).mean()

In [None]:
from importlib import reload
import astronomy_utils, predictor_utils
reload(astronomy_utils)
reload(predictor_utils)
from predictor_utils import *
from astronomy_utils import *

generate_all_predictions = True

# Timing run prompting SAM at the extracted source positions of each CLAHE image.
inference_times = []
if generate_all_predictions:
    # Sorted so files are processed in a reproducible order (os.listdir order is arbitrary).
    dir_files = sorted(f for f in os.listdir(OM_dir) if os.path.isfile(os.path.join(OM_dir, f)))
    for file_ in dir_files:
        if not (file_.endswith('.png') and 'clahe' in file_):
            continue
        try:
            mask_on_negative = mask_and_plot_image(OM_dir+file_.replace('.png', '.fits').replace('clahe_', ''))
            img_points = np.array([point for filename, point in extracted_bboxes_points['points'].items()
                                   if filename in file_])
            # No extracted sources for this image: pass None, as the baseline run does.
            img_points = img_points/255.0 if len(img_points) > 0 else None

            start_time = time.perf_counter()
            _, _, annotated_image = SAM_predictor(SamAutomaticMaskGenerator, sam, OM_dir+file_, mask_on_negative=mask_on_negative,
                                                  img_grid_points=img_points)
            end_time = time.perf_counter()

            print(f"original SAM predict time/img: {end_time-start_time} s")
            inference_times.append(end_time-start_time)
        except Exception as e:
            # Best effort over the directory: report which file failed and continue.
            print(f"{file_}: {e!r}")

In [None]:
# Average per-image SAM inference time in seconds with extracted-source grids
np.asarray(inference_times).mean()

In [None]:
# Intentional stop for "Run All": the cells below write annotation files and upload to
# Roboflow. This used to rely on the undefined name `stophere2` raising NameError.
raise RuntimeError("Execution stopped on purpose; run the annotation/upload cells manually.")

In [None]:
# # %%capture cap
# # print(f"Segmented images with threshold masks-on-negative.")
# generate_all_predictions = True

# if generate_all_predictions:
        
#     dir_files = [f for f in os.listdir(OM_dir) if os.path.isfile(os.path.join(OM_dir, f))]
#     # print(dir_files)
#     for file_ in dir_files:
#         try:
#             if '.png' in file_ and 'clahe' in file_:
#                 mask_on_negative = mask_and_plot_image(OM_dir+file_.replace('.png', '.fits').replace('clahe_', ''))
#                 SAM_predictor(OM_dir+file_, mask_on_negative)
#         except Exception as e:
#             print(e)
#             pass
    
#     # with open('cell_SAM_predict_with_threshold.txt', 'w') as f:
#     #     f.write(str(cap))


## 🚀 Generate a COCO-style annotation JSON file

In [None]:
# download Roboflow annotations, v5
# The API key is read from the environment; never paste it into the notebook. The key that
# was previously written here was committed in plain text and should be rotated.

# !pip install roboflow

# from roboflow import Roboflow
# rf = Roboflow(api_key=os.environ["ROBOFLOW_API_KEY"])
# project = rf.workspace("orij").project("om_sky_images")
# dataset = project.version(5).download("coco-segmentation")

In [None]:
import os
import json
import numpy as np
from datetime import datetime
from PIL import Image

def show_box(box, ax):
    """Draw an ``[x, y, w, h]`` box as an unfilled green rectangle.

    Args:
        box: sequence whose first four items are x0, y0 (corner), width and height.
        ax: matplotlib Axes to draw on.
    """
    # Leftover debug print of the box coordinates removed.
    x0, y0, w, h = box[0], box[1], box[2], box[3]
    ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0, 0, 0, 0), lw=2))

def numpy_to_list(obj):
    """Recursively convert NumPy values inside ``obj`` to plain Python objects.

    Arrays become (nested) lists and NumPy scalars (e.g. ``np.int64``, ``np.float32``,
    ``np.bool_``) become the matching Python scalar, so the result can be passed to
    ``json.dump``. Dicts, lists and tuples are traversed; other values are returned as-is.
    """
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, np.generic):
        # NumPy scalars are not JSON-serialisable; .item() returns the Python equivalent.
        return obj.item()
    if isinstance(obj, dict):
        return {k: numpy_to_list(v) for k, v in obj.items()}
    if isinstance(obj, list):
        return [numpy_to_list(item) for item in obj]
    if isinstance(obj, tuple):
        return tuple(numpy_to_list(item) for item in obj)
    return obj

# COCO category table: each id is the position of its name in this list.
_category_names = [
    "Artifacts-Sky-objects",
    "other",
    "ray",
    "read-out-streak",
    "smoke-ring",
    "star",
    "star-ghost",
    "star-loop",
    "star-ring",
]
categories = [{"id": idx, "name": name} for idx, name in enumerate(_category_names)]

# Container for the generated annotations; only the 'annotations' section is filled in.
# coco_style_annotations = {'categories': categories, 'images': [], 'annotations': []}
coco_style_annotations = {'annotations': []}


def get_SAM_annotations(IMAGE_PATH, mask_on_negative = None, output_mode="binary_mask", sam_model = None):
    """Run SAM's automatic mask generator on one image and annotate the result.

    Args:
        IMAGE_PATH (str): path of the image file (read with OpenCV).
        mask_on_negative (array-like, optional): if given, the SAM masks are filtered with
            ``remove_masks(..., threshold=50)`` to drop those lying on negative pixels.
        output_mode (str): forwarded to ``SamAutomaticMaskGenerator``
            (``"binary_mask"`` yields one boolean array per mask).
        sam_model (optional): SAM model to use; defaults to the notebook-global ``sam``.

    Returns:
        tuple: ``(sam_result, detections, annotated_image)`` — the list of mask dicts,
        the ``sv.Detections`` built from them, and the RGB image with the masks drawn.

    Raises:
        FileNotFoundError: if the image cannot be read.
    """
    if sam_model is None:
        sam_model = sam

    image_bgr = cv2.imread(IMAGE_PATH)
    if image_bgr is None:
        raise FileNotFoundError(f"Could not read image: {IMAGE_PATH}")
    image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)

    mask_generator = SamAutomaticMaskGenerator(sam_model, output_mode=output_mode)
    sam_result = mask_generator.generate(image_rgb)
    if mask_on_negative is not None:
        sam_result = remove_masks(sam_result=sam_result, mask_on_negative=mask_on_negative, threshold=50)

    mask_annotator = sv.MaskAnnotator(color_lookup=sv.ColorLookup.INDEX)
    detections = sv.Detections.from_sam(sam_result=sam_result)
    # supervision draws on numpy scenes in OpenCV (BGR) order: annotate the BGR image,
    # then convert so the returned array is consistently RGB.
    annotated_bgr = mask_annotator.annotate(scene=image_bgr.copy(), detections=detections)
    annotated_image = cv2.cvtColor(annotated_bgr, cv2.COLOR_BGR2RGB)

    return sam_result, detections, annotated_image


# Run SAM for .png files in the directory and create the annotation json file in ~COCO format

def generate_json_file(input_dir, coco_style_annotations, output_path='SAM_annotations_coco_style_v2.json'):
    """Run SAM on every ``.png`` in ``input_dir`` and write COCO-style annotations to JSON.

    For each image, the ``.fits`` file with the same name in ``input_dir`` is passed to
    ``mask_and_plot_image`` to build the negative-pixel mask used to filter SAM's masks.
    Every remaining mask becomes one annotation with ``category_id`` 0.

    Args:
        input_dir (str): directory with the images. It must end with a path separator,
            because file paths are built by string concatenation.
        coco_style_annotations (dict): dict with an ``'annotations'`` list; new entries
            are appended to it in place.
        output_path (str): JSON file to write (default keeps the previous file name).
    """
    for file_ in sorted(os.listdir(input_dir)):
        # Only real PNGs; the old substring test ("png" in name) also matched other files.
        if not file_.endswith('.png'):
            continue
        print(file_)
        image_id = file_.split('.')[0]

        mask_on_negative = mask_and_plot_image(input_dir+file_.replace('.png', '.fits'))
        if mask_on_negative is not None:
            mask_on_negative = mask_on_negative.astype(int)
        sam_result_i, _, _ = get_SAM_annotations(input_dir+file_, mask_on_negative)

        # Convert masks (numpy arrays) to nested lists so the whole structure is JSON-serialisable.
        for k, annotation in enumerate(numpy_to_list(sam_result_i)):
            coco_style_annotations['annotations'].append({
                    'id': f'{image_id}_mask{k}',
                    'image_id': image_id,
                    'category_id': 0,
                    'segmentation': annotation['segmentation'],
                    'area': annotation['area'],
                    'bbox': annotation['bbox'],
                    'iscrowd': 0,
                })

    with open(output_path, 'w') as f:
        json.dump(coco_style_annotations, f)


input_dir = '/workspace/raid/OM_DeepLearning/XMM_OM_dataset/scaled_raw/'
generate_json_file(input_dir, coco_style_annotations)

## 🚀 Generate VOC-format annotations and upload them to Roboflow

In [None]:
import glob
from roboflow import Roboflow

# Initialize Roboflow client. The API key comes from the environment instead of the
# notebook; the key previously committed here in plain text should be rotated.
rf = Roboflow(api_key=os.environ["ROBOFLOW_API_KEY"])
upload_project = rf.workspace("orij").project("xmm_om_images_v4-contrast-512")

### Create and export annotations in VOC format to Roboflow

In [None]:
#import my files
from importlib import reload
import astronomy_utils, predictor_utils, voc_annotate_and_Roboflow_export
reload(astronomy_utils)
reload(predictor_utils)
reload(voc_annotate_and_Roboflow_export)

from predictor_utils import *
from astronomy_utils import *
from voc_annotate_and_Roboflow_export import * # moved the files there

input_dir = '/workspace/raid/OM_DeepLearning/XMM_OM_dataset/zscaled_512_rescaled_SAM_stats/'
input_fits_dir = '/workspace/raid/OM_DeepLearning/XMM_OM_dataset/scaled_raw_512/'
# create_annotation() output is read from here as "<name>.xml" and deleted after upload.
annotations_out_dir = "/workspace/raid/OM_DeepLearning/XMM_OM_code/"

for file_ in sorted(os.listdir(input_dir)):
    # Skip anything that is not a PNG image; the original '.png' filter had been
    # commented out, so every directory entry was processed.
    if not file_.endswith('.png'):
        continue
    mask_on_negative = mask_and_plot_image(input_fits_dir+file_.replace('.png', '.fits'), plot_=False)

    sam_result_i, detections_i, annotated_image_i = SAM_predictor(SamAutomaticMaskGenerator, sam, input_dir+file_,
                                                                  mask_on_negative=mask_on_negative, img_grid_points=None)
    if sam_result_i is None or detections_i is None or annotated_image_i is None:
        continue

    objects = []
    for annotation in sam_result_i: # a mask over an image is a binary array with shape (img_h, img_w)
        polygon = binary_image_to_polygon(annotation['segmentation'])
        if len(polygon) == 0:
            continue  # no polygon extracted for this mask; polygon[0] would raise IndexError
        # plot_polygon(polygon[0], annotated_image_i) # to see the masks polygons
        objects.append({
            'name': 'star',
            'bbox': annotation['bbox'],
            'segmentations': polygon[0]
        })

    create_annotation(filename=file_, width=512, height=512, depth=3, objects=objects) # generating xml file for VOC format
    image_path = input_dir+file_
    annotation_filename = annotations_out_dir+file_.replace(".png", ".xml")
    try:
        upload_project.upload(image_path, annotation_filename, overwrite=False)
    finally:
        # Remove the temporary xml even when the upload fails.
        os.remove(annotation_filename)

In [None]:
# Peek at a single entry of the directory processed above
os.listdir(input_dir)[0:1]

#### Re-upload images with their exported VOC annotations to Roboflow (alternative workflow)

In [3]:
# !pip install roboflow

# from roboflow import Roboflow
# rf = Roboflow(api_key=os.environ["ROBOFLOW_API_KEY"])  # never paste the key itself here
# project = rf.workspace("orij").project("xmm_om_images-contrast-512-v5")
# dataset = project.version(1).download("voc")

In [15]:
import requests
import os

import glob
from roboflow import Roboflow

# API key from the environment; the key previously committed here should be rotated.
rf = Roboflow(api_key=os.environ["ROBOFLOW_API_KEY"])
upload_project = rf.workspace("orij").project("xmm_om_images-contrast-512-v5")

dataset_images_folder = '/workspace/raid/OM_DeepLearning/XMM_OM_dataset/zscaled_512_rescaled_SAM_stats/'
annotations_voc_dir = '/workspace/raid/OM_DeepLearning/XMM_OM_code_git/xmm_om_images-contrast-512-v5-1/train/'
annotations_ = [annot for annot in os.listdir(annotations_voc_dir) if annot.endswith('.xml')]

for image_name in os.listdir(dataset_images_folder):
    image_path = os.path.join(dataset_images_folder, image_name)
    if not os.path.isfile(image_path):
        continue
    # Annotation files are matched by prefix with '.' replaced by '_' (presumably the
    # exported "<stem>_png...xml" naming — confirm against the export).
    annotation_prefix = image_name.replace('.', '_')
    annotations_voc_filename = [annotation for annotation in annotations_ if annotation.startswith(annotation_prefix)]
    if not annotations_voc_filename:
        print(f"No VOC annotation found for {image_name}; skipped.")
        continue
    upload_project.upload(image_path, annotations_voc_dir+annotations_voc_filename[0], overwrite=True)

print("Image upload complete.")

loading Roboflow workspace...
loading Roboflow project...
Image upload complete.


## Automatic mask generation

To run automatic mask generation, provide a SAM model to the `SamAutomaticMaskGenerator` class.

In [None]:
# import os

# IMAGE_NAME = "/workspace/raid/OM_DeepLearning/XMM_OM_code/gaussian_.png"
# IMAGE_PATH = os.path.join(OM_dir,IMAGE_NAME)

### Generate masks with SAM

### The predictor function (older version, kept commented out for reference)

In [None]:
# # usually the pip install "opencv-python-headless<4.3" solves this problem:
# # AttributeError: partially initialized module 'cv2' has no attribute '_registerMatType' (most likely due to a circular import)
# # !pip install "opencv-python-headless<4.3"
# # !pip install jupyter-bbox-widget
# import cv2
# import supervision as sv
# from PIL import Image
# import os
# # !pip install astropy
# import astropy
# import numpy as np
# import matplotlib.pyplot as plt
# from astropy.io import fits
# from scipy.interpolate import interp1d
# from astropy.visualization import ZScaleInterval, ImageNormalize

# def SAM_predictor(IMAGE_PATH, remove_masks_on_negative = False):
    
#     image_bgr = cv2.imread(IMAGE_PATH)
#     # print(IMAGE_PATH)
#     annotated_image = None
#     try:
#         image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
#         # image_rgb = image_bgr
    
#         sam_result = mask_generator.generate(image_rgb)
#         output_file = IMAGE_PATH.replace("scaled_raw", "segmented_SAM").replace(".png", "_segmented.png")
        
#         if remove_masks_on_negative:
#             sam_to_keep = []
#             for segm_index in range(len(sam_result)):
#                 if (np.any((sam_result[segm_index]['segmentation']==1) & (mask_on_negative==1)))==0:
#                     sam_to_keep.append(sam_result[segm_index])
#             sam_result = sam_to_keep
#             output_file = IMAGE_PATH.replace("scaled_raw", "segmented_SAM").replace(".png", "_segmented_removed_negative.png")
            
#         mask_annotator = sv.MaskAnnotator(color_lookup=sv.ColorLookup.INDEX)
#         detections = sv.Detections.from_sam(sam_result=sam_result)
#         annotated_image = mask_annotator.annotate(scene=image_bgr.copy(), detections=detections)
#         image = Image.fromarray(annotated_image)
#         image.save(output_file)
    
#         sv.plot_images_grid(
#             images=[image_bgr, annotated_image],
#             grid_size=(1, 2),
#             titles=['source image', 'segmented image']
#         )
#     except Exception as e:
#         print("Exception:\n", e)
#         pass
        
#     return image_bgr, annotated_image

In [None]:
# SAM_predictor("/workspace/raid/OM_DeepLearning/colored_atrous_wavelet_decomposition.png")
# SAM_predictor(IMAGE_NAME)

# for file_ in os.listdir(OM_dir+"scaled_raw/"):
#     if "png" in file_:
#         SAM_predictor(OM_dir+"scaled_raw/"+file_)

In [None]:

# def save_4(image1, image2, image3, image4):
#     fig, axs = plt.subplots(2, 2, figsize=(8, 8))
    
#     # Plot each image in the corresponding subplot
#     axs[0, 0].imshow(image1, cmap="gray") 
#     axs[0, 0].axis('off')
#     axs[0, 0].set_title('source image w/o distribution')
    
#     axs[0, 1].imshow(image2, cmap="gray")
#     axs[0, 1].axis('off')
#     axs[0, 1].set_title('modified image w/o distribution')
    
#     axs[1, 0].imshow(image3, cmap="gray")
#     axs[1, 0].axis('off')
#     axs[1, 0].set_title('source image w/ distribution')
    
#     axs[1, 1].imshow(image4, cmap="gray")
#     axs[1, 1].axis('off')
#     axs[1, 1].set_title('modified image w/ distribution')
    
#     plt.tight_layout()
    
#     # Save the figure
#     plt.savefig('2x2_images_grid.png', bbox_inches='tight', pad_inches=0)
    
#     # plt.show()


In [None]:
# import matplotlib.pyplot as plt

# image_bgr, annotated_image = SAM_predictor("/workspace/raid/OM_DeepLearning/XMM_OM_dataset/scaled_raw/S0112231601_V.png")
# image_bgr_filled, annotated_image_filled = SAM_predictor("/workspace/raid/OM_DeepLearning/XMM_OM_code/S0112231601_V_gaussian.png")

# sv.plot_images_grid(
#     images=[image_bgr, annotated_image, image_bgr_filled, annotated_image_filled],
#     grid_size=(2, 2),
#     titles=['source image', 'segmented image', 'source image w/ distribution', 'segmented image w/ distribution']
# )

# save_4(image_bgr, annotated_image, image_bgr_filled, annotated_image_filled)

In [None]:
# # image_bgr, annotated_image = SAM_predictor("/workspace/raid/OM_DeepLearning/XMM_OM_dataset/scaled_raw/S0655570201_L.png")
# # image_bgr_filled, annotated_image_filled = SAM_predictor("/workspace/raid/OM_DeepLearning/XMM_OM_code/S0655570201_L_gaussian.png")

# sv.plot_images_grid(
#     images=[image_bgr, annotated_image, image_bgr_filled, annotated_image_filled],
#     grid_size=(2, 2),
#     titles=['source image', 'segmented image', 'source image w/ distribution', 'segmented image w/ distribution']
# )
# save_4(image_bgr, annotated_image, image_bgr_filled, annotated_image_filled)


### Output format

`SamAutomaticMaskGenerator` returns a `list` of masks, where each mask is a `dict` containing the following information about the mask:

* `segmentation` - `[np.ndarray]` - the mask, a `bool` array of shape `(H, W)` (with the default `output_mode="binary_mask"`)
* `area` - `[int]` - the area of the mask in pixels
* `bbox` - `[List[int]]` - the bounding box of the mask in `xywh` format
* `predicted_iou` - `[float]` - the model's own prediction of the mask's quality
* `point_coords` - `[List[List[float]]]` - the sampled input point that generated this mask
* `stability_score` - `[float]` - an additional measure of mask quality
* `crop_box` - `[List[int]]` - the crop of the image used to generate this mask, in `xywh` format

### Interaction with segmentation results

In [None]:
# masks = [
#     mask['segmentation']
#     for mask
#     in sorted(sam_result, key=lambda x: x['area'], reverse=True)
# ]

# sv.plot_images_grid(
#     images=masks,
#     grid_size=(9, int(len(masks) / 8)),
#     size=(16, 16)
# )