# Setup

In [None]:
import json
from PIL import Image
import matplotlib.pyplot as plt
from matplotlib.patches import Polygon
from pycocotools import mask as maskUtils
import torch
import random
import sys
import numpy as np
import supervision as sv
import time
import cv2
from typing import Tuple, Optional

# Print NumPy floats with up to 15 decimal places (affects printing only).
np.set_printoptions(precision=15)

# IPython shell escape: show the GPU status reported by nvidia-smi.
!nvidia-smi

In [None]:
# import utils
import data_preprocess, class_agnostic_sam_predictor
from data_preprocess import preprocess_utils as preprocess
from class_agnostic_sam_predictor import predictor_utils as predict

# from preprocess_utils import *

In [None]:
# Directory of the scaled 512-px XMM-OM observations; the batch cell below lists
# it for `.png` files and reads the `.fits` file with the same name.
OM_dir = '/workspace/raid/OM_DeepLearning/XMM_OM_dataset/scaled_raw_512/'
OM_dir  # bare expression: the notebook displays the path as the cell output

# Import and visualize Roboflow annotations

This step assumes that you have downloaded the dataset locally, in COCO format. 

In [None]:
# Set to True to visualise the Roboflow COCO annotations (disabled by default).
show_roboflow_annotations = False

if show_roboflow_annotations:

    def display_masks(image_path, masks):
        """Show an image with its COCO segmentation masks overlaid.

        Args:
            image_path (str): path of the image to display.
            masks (list): COCO ``segmentation`` entries. Each one is either a
                list of polygons (flat ``[x1, y1, x2, y2, ...]`` lists, one per
                part) or an RLE dict (compressed or uncompressed ``counts``).
        """
        image = Image.open(image_path)
        width, height = image.size  # PIL reports (width, height)
        plt.imshow(image)
        ax = plt.gca()

        for mask in masks:
            if isinstance(mask, dict):
                # RLE segmentation. Test for the dict first: indexing it with
                # mask[0] would raise KeyError. Uncompressed RLE ("counts" is a
                # list of ints) must be encoded with frPyObjects before decoding.
                if isinstance(mask.get('counts'), list):
                    rle = maskUtils.frPyObjects(mask, height, width)
                else:
                    rle = mask
                binary_mask = maskUtils.decode(rle)
                ax.imshow(binary_mask, alpha=0.5, cmap='gray')
                continue

            # Polygon segmentation: draw every part, not only the first one.
            polygons = mask if mask and isinstance(mask[0], list) else [mask]
            for poly in polygons:
                if len(poly) < 6:  # fewer than 3 vertices: nothing to draw
                    continue
                polygon_points = np.array(poly).reshape(-1, 2)
                ax.add_patch(Polygon(polygon_points, edgecolor='g', facecolor='none'))
        plt.show()

    dir_train_path = './xmm_om_images_v4-contrast-512-5-7/train/'
    json_file_path = dir_train_path+'_annotations.coco.json'

    with open(json_file_path, 'r') as f:
        data = json.load(f)

    # Group annotations by image once instead of rescanning the whole list for
    # every image. Annotation keys: id, image_id, category_id, bbox, area,
    # segmentation, iscrowd.
    annotations_by_image = {}
    for anno in data['annotations']:
        annotations_by_image.setdefault(anno['image_id'], []).append(anno)

    for image_info in data['images']:
        image_path = dir_train_path + image_info['file_name']
        masks = [anno['segmentation'] for anno in annotations_by_image.get(image_info['id'], [])]
        display_masks(image_path, masks)

# Mobile SAM inference

Here, the SAM Automatic Mask Generator is used to predict masks on an input image. The generator places a $32\times32$ grid of points that serve as foreground point prompts for the mask decoder, and then filters the resulting masks to keep the best ones. However, some of its methods are decorated with `torch.no_grad()`, so this predictor cannot be trained as is. One idea would be to enable gradients <i>along the way</i>, but this is risky and out of scope for now.

In [None]:
# Make the local MobileSAM checkout importable and load its "vit_t" model for inference.
sys.path.append('/workspace/raid/OM_DeepLearning/MobileSAM-master/') # MobileSAM repo path
import mobile_sam
from mobile_sam import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor

mobile_sam_checkpoint = "/workspace/raid/OM_DeepLearning/MobileSAM-master/weights/mobile_sam.pt"
# GPU index 3 is hard-coded; falls back to the CPU when CUDA is unavailable.
device = "cuda:3" if torch.cuda.is_available() else "cpu"
mobile_sam_model = sam_model_registry["vit_t"](checkpoint=mobile_sam_checkpoint)
mobile_sam_model.to(device=device)
# eval() puts the model in inference mode; the trailing ';' suppresses the
# notebook's display of the returned module.
mobile_sam_model.eval();
device  # display the selected device as the cell output

In [None]:
import os

image_path = './roboflow_datasets/xmm_om_images_v4-contrast-512-5-7/train/S0037980401_L_png.rf.17cd9454f2c96e8a3e06676a49f2640b.jpg'
image_fits_path = '../XMM_OM_dataset/scaled_raw_512/S0037980401_L.fits'

# Mask built from the FITS data; passed to the predictor so masks on negative
# pixels can be removed (see the section notes on negative pixels).
mask_on_negative = predict.mask_and_plot_image(image_fits_path)

# Build the output name from the file stem so it works for any extension.
# A plain .replace(".png", ...) is a no-op on Roboflow's "*_png.rf.<hash>.jpg"
# names and would reuse the input's own file name.
image_stem = os.path.splitext(os.path.basename(image_path))[0]
output_path = f"{image_stem}_mobile_sam_nonnegative.png"

start_time = time.perf_counter()  # monotonic, high-resolution timer for durations
image, mobile_sam_result = predict.MobileSAM_predict(
    image_path,
    model=mobile_sam_model,
    predictor=SamPredictor,
    generator=SamAutomaticMaskGenerator,
    device=device,
    output_path=output_path,
    mask_on_negative=mask_on_negative)
end_time = time.perf_counter()
print(f"Mobile SAM predict time/img: {end_time-start_time} s")

plt.imshow(image)
plt.title('Mobile SAM pred. on image input')
plt.show()
plt.close()

# 🚀 Segment Anything Model (SAM)

## Install SAM model and dependencies

In [None]:
# %cd {HOME}

# import sys
# !{sys.executable} -m pip install 'git+https://github.com/facebookresearch/segment-anything.git'
# !pip install -q jupyter_bbox_widget roboflow dataclasses-json supervision

## Download SAM weights

In [None]:
# %cd {HOME}
# !mkdir {HOME}/weights
# %cd {HOME}/weights

# !wget -q https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth

## Load Model

In [None]:
import os
# Expect the SAM ViT-H checkpoint under <current working directory>/weights and
# report whether the file is present.
HOME = os.getcwd()

CHECKPOINT_PATH = os.path.join(HOME, "weights", "sam_vit_h_4b8939.pth")
print(CHECKPOINT_PATH, "; exist:", os.path.isfile(CHECKPOINT_PATH))

In [None]:
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor

# NOTE: the segment_anything import above rebinds sam_model_registry,
# SamAutomaticMaskGenerator and SamPredictor (previously MobileSAM's), so later
# cells use the segment_anything versions.
# GPU index 1 is hard-coded; falls back to the CPU when CUDA is unavailable.
DEVICE = torch.device(f'cuda:1' if torch.cuda.is_available() else 'cpu') 
sam = sam_model_registry["vit_h"](checkpoint=CHECKPOINT_PATH).to(device=DEVICE)
sam.eval();  # inference mode; ';' suppresses the notebook display of the module

## 🚀 The predictor function (can remove masks on negative pixels)

In [None]:
# Build (and plot, via plot_=True) the mask returned by mask_and_plot_image for one
# example observation; presumably it marks negative pixels — see the section title.
mask_on_negative = predict.mask_and_plot_image("../XMM_OM_dataset/scaled_raw_512/S0012850201_L.fits", plot_=True)

## AMG with extracted and deblended sources centroids

Generating a grid of 1024 points on an astronomical image can be redundant, since a typical field of view contains far fewer noticeable sources. We can therefore reduce inference time by prompting SAM with the centroids of deblended sources produced by a source-detection algorithm. This approach is faster but sometimes less accurate.

In [None]:
obs_id = 'S0801800201_L'
image_path = f"../XMM_OM_dataset/zscaled_512_stretched/{obs_id}.png"
mask_on_negative = predict.mask_and_plot_image(f"../XMM_OM_dataset/scaled_raw_512/{obs_id}.fits")

# Load the extracted-source centroids *before* starting the timer so JSON parsing
# is not counted as SAM inference time.
with open('extracted_sources_bboxes_points.json', 'r') as f:
    extracted_bboxes_points = json.load(f)

# 1) Baseline: SAM's automatic mask generator with its default grid of point prompts.
start_time = time.time()
_, _, annotated_image1 = predict.SAM_predictor(
    SamAutomaticMaskGenerator,
    sam, image_path,
    mask_on_negative=mask_on_negative,
    img_grid_points=None)
sam_grid_time = time.time() - start_time

# 2) Prompts from the centroids whose key is a substring of this image's file name.
#    They are divided by 255 here and multiplied back by 255 for plotting;
#    presumably SAM_predictor expects normalised [0, 1] coordinates — TODO confirm.
image_name = image_path.split("/")[-1]
img_points = np.array([point for filename, point in extracted_bboxes_points['points'].items()
                       if filename in image_name]) / 255

if len(img_points) == 0:
    # Without centroids there is no second run to compare against (the plot
    # below would otherwise hit undefined annotated_image2 / timing variables).
    print(f"No extracted-source centroids found for {image_name}; skipping comparison.")
else:
    start_time = time.time()
    _, _, annotated_image2 = predict.SAM_predictor(
        SamAutomaticMaskGenerator,
        sam,
        image_path,
        mask_on_negative=None,
        img_grid_points=img_points)
    sam_with_detected_sources_time = time.time() - start_time

    # The indexing below treats img_points[0] as an (n_points, 2) array of the
    # first matching key's centroids.
    fig, axs = plt.subplots(1, 3, figsize=(10, 10))
    axs[0].imshow(annotated_image1)
    axs[0].set_title(f'Original SAM inference \ntime: {round(sam_grid_time, 3)}s, {1024} grid points',
                     fontfamily='monospace', fontsize=10)

    axs[1].imshow(annotated_image2)
    axs[1].set_title(f'SAM inference \nwith deblended sources centroids \ntime: '
                     f'{round(sam_with_detected_sources_time, 3)}s, {img_points.shape[1]} grid points',
                     fontfamily='monospace', fontsize=10)

    axs[2].imshow(cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB))
    axs[2].scatter(img_points[0][:, 0]*255, img_points[0][:, 1]*255, s=10, c='red')
    axs[2].set_title('Extracted sources grid points',
                     fontfamily='monospace', fontsize=10)
    axs[2].set_aspect('equal', 'box')
    plt.tight_layout()
    # plt.savefig('plots/sam_grid_points_comparison.png', dpi=300)
    plt.show()
    plt.close()

In [None]:
generate_all_predictions = True

# One record per processed image: (file name, grid-prompt time in s,
# centroid-prompt time in s or None when no centroids were found).
inference_times = []
if generate_all_predictions:
    # The centroid file is the same for every image: load it once, outside the
    # timed regions.
    with open('extracted_sources_bboxes_points.json', 'r') as f:
        extracted_bboxes_points = json.load(f)

    dir_files = [f for f in os.listdir(OM_dir) if os.path.isfile(os.path.join(OM_dir, f))]
    for file_ in dir_files:
        if not file_.endswith('.png'):
            continue
        try:
            image_path = os.path.join(OM_dir, file_)
            mask_on_negative = predict.mask_and_plot_image(image_path.replace('.png', '.fits'))

            # 1) Baseline: default grid of point prompts.
            start_time = time.time()
            _, _, annotated_image1 = predict.SAM_predictor(
                SamAutomaticMaskGenerator,
                sam,
                image_path,
                mask_on_negative=mask_on_negative,
                img_grid_points=None)
            sam_grid_time = time.time() - start_time

            # 2) Centroid prompts whose key is a substring of this file name.
            img_points = np.array([point for filename, point in extracted_bboxes_points['points'].items()
                                   if filename in file_]) / 255
            if len(img_points) == 0:
                # Nothing to compare; do not reuse the previous image's results.
                print(f"{file_}: no extracted-source centroids, skipping comparison")
                inference_times.append((file_, sam_grid_time, None))
                continue

            start_time = time.time()
            # NOTE(review): mask_on_negative is not applied to the centroid-prompted
            # run — confirm this is intended.
            _, _, annotated_image2 = predict.SAM_predictor(
                SamAutomaticMaskGenerator,
                sam,
                image_path,
                mask_on_negative=None,
                img_grid_points=img_points)
            sam_with_detected_sources_time = time.time() - start_time
            inference_times.append((file_, sam_grid_time, sam_with_detected_sources_time))

            fig, axs = plt.subplots(1, 3, figsize=(10, 10))
            axs[0].imshow(annotated_image1)
            axs[0].set_title(f'Original SAM inference \ntime: {round(sam_grid_time, 3)}s, {1024} grid points',
                             fontfamily='monospace', fontsize=10)

            axs[1].imshow(annotated_image2)
            axs[1].set_title(f'SAM inference \nwith deblended sources centroids \ntime: '
                             f'{round(sam_with_detected_sources_time, 3)}s, {img_points.shape[1]} grid points',
                             fontfamily='monospace', fontsize=10)

            axs[2].imshow(cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB))
            axs[2].scatter(img_points[0][:, 0]*255, img_points[0][:, 1]*255, s=10, c='red')
            axs[2].set_title('Extracted sources grid points',
                             fontfamily='monospace', fontsize=10)
            axs[2].set_aspect('equal', 'box')
            plt.tight_layout()
            # plt.savefig('plots/sam_grid_points_comparison.png', dpi=300)
            plt.show()
        except Exception as e:
            # Best effort over the whole directory: report the failing file and
            # carry on with the next one.
            print(f"{file_}: {e}")
        finally:
            # Close the current figure even when an error interrupted plotting.
            plt.close()

    # with open('cell_SAM_predict_with_threshold.txt', 'w') as f:
    #     f.write(str(cap))

## 🚀 Generate annotation json file (COCO format)

In [None]:
import os
import json
import numpy as np
from datetime import datetime
from PIL import Image

def show_box(box, ax):
    """Draw an ``(x, y, w, h)`` box on a Matplotlib axis.

    The four box values are printed first; then an unfilled green rectangle
    with line width 2 is added to ``ax``.
    """
    left, top = box[0], box[1]
    print(left, top, box[2], box[3])
    box_w, box_h = box[2], box[3]
    outline = plt.Rectangle(
        (left, top), box_w, box_h,
        edgecolor='green', facecolor=(0, 0, 0, 0), lw=2,
    )
    ax.add_patch(outline)

def numpy_to_list(obj):
    """Recursively convert NumPy objects inside ``obj`` to plain Python types.

    Arrays become (nested) lists, and NumPy scalars such as ``np.float32``,
    ``np.int64`` or ``np.bool_`` become ``float``/``int``/``bool``, so the
    result can be passed to ``json.dump``. Dicts, lists and tuples are
    traversed and keep their container type. Anything else is returned as is.
    """
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, np.generic):
        # NumPy scalars are not JSON-serialisable; .item() gives the Python value.
        return obj.item()
    if isinstance(obj, dict):
        return {k: numpy_to_list(v) for k, v in obj.items()}
    if isinstance(obj, list):
        return [numpy_to_list(item) for item in obj]
    if isinstance(obj, tuple):
        return tuple(numpy_to_list(item) for item in obj)
    return obj

# NOTE(review): `categories` is only referenced by the commented-out full COCO
# layout below; nothing else in these cells uses it.
categories = []

# Only the 'annotations' section of a COCO file is produced; 'images' and
# 'categories' are not written.
# coco_style_annotations = {'categories': categories, 'images': [], 'annotations': []}
coco_style_annotations = {'annotations': []}

def get_SAM_annotations(IMAGE_PATH, mask_on_negative = None, output_mode="binary_mask"):
    """Run SAM's automatic mask generator on one image and annotate the result.

    Uses the module-level ``sam`` model.

    Args:
        IMAGE_PATH (str): path of the image file (read with OpenCV).
        mask_on_negative (array-like, optional): if not ``None``, passed to
            ``remove_masks`` (with ``threshold=50``) to filter the predicted
            masks, presumably dropping those on negative pixels. ``None`` keeps
            every mask.
        output_mode (str, optional): mask encoding requested from
            ``SamAutomaticMaskGenerator`` (default ``"binary_mask"``).

    Returns:
        tuple: ``(sam_result, detections, annotated_image)``. These are the SAM
        mask records, the ``supervision`` Detections built from them, and an
        RGB copy of the image with the masks drawn on it.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``IMAGE_PATH``.
    """
    image_bgr = cv2.imread(IMAGE_PATH)
    if image_bgr is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise FileNotFoundError(f"Could not read image: {IMAGE_PATH}")
    image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)

    mask_generator = SamAutomaticMaskGenerator(sam, output_mode=output_mode)
    sam_result = mask_generator.generate(image_rgb)
    if mask_on_negative is not None:
        # NOTE(review): `remove_masks` is not defined in the visible cells; it
        # must come from an imported helper module — confirm.
        sam_result = remove_masks(sam_result=sam_result, mask_on_negative=mask_on_negative, threshold=50)

    mask_annotator = sv.MaskAnnotator(color_lookup=sv.ColorLookup.INDEX)
    detections = sv.Detections.from_sam(sam_result=sam_result)
    annotated_image = mask_annotator.annotate(scene=image_rgb.copy(), detections=detections)

    return sam_result, detections, annotated_image

# Run SAM for image files in the directory and create the annotation json file in COCO format
def generate_json_file(input_dir, coco_style_annotations,
                       output_path='SAM_annotations_coco_style_v2.json'):
    """Annotate every ``.png`` in ``input_dir`` with SAM and dump COCO-style JSON.

    For each ``<name>.png``, the ``<name>.fits`` file in the same directory is
    passed to ``predict.mask_and_plot_image``. The resulting mask is given to
    ``get_SAM_annotations``.

    A predicted mask is dropped when its bbox is at most 2 px in both width and
    height, or when it spans at least 70 % of the image width or height.

    Args:
        input_dir (str): directory containing the ``.png``/``.fits`` pairs.
        coco_style_annotations (dict): dict with an ``'annotations'`` list; new
            entries are appended to it in place.
        output_path (str, optional): path of the JSON file that is written.

    Returns:
        dict: ``coco_style_annotations`` (the same object, now filled in).
    """
    for file_ in os.listdir(input_dir):
        if not file_.endswith('.png'):
            continue

        fits_path = os.path.join(input_dir, file_.replace('.png', '.fits'))
        mask_on_negative = predict.mask_and_plot_image(fits_path)
        sam_result_i, _, annotated_image_i = get_SAM_annotations(
            os.path.join(input_dir, file_), mask_on_negative.astype(int))
        sam_result_i = numpy_to_list(sam_result_i)
        # The annotated image is an RGB copy of the input, so its shape gives the
        # image size without reading the file a second time.
        height, width = annotated_image_i.shape[:2]

        image_id = file_.split('.')[0]
        mask_idx = 0  # per-image mask counter used to build unique annotation ids
        for annotation in sam_result_i:
            bbox = annotation['bbox']  # SAM boxes are [x, y, w, h]
            box_w, box_h = bbox[2], bbox[3]

            # Filter out very small/big objects: width is compared with the image
            # width and height with the image height.
            too_small = box_w <= 2 and box_h <= 2
            too_big = box_w / width >= 0.7 or box_h / height >= 0.7
            if too_small or too_big:
                continue

            coco_style_annotations['annotations'].append({
                'id': f'{image_id}_mask{mask_idx}',
                'image_id': image_id,
                'category_id': 0,
                'segmentation': annotation['segmentation'],
                'area': annotation['area'],
                'bbox': bbox,
                'iscrowd': 0,
            })
            mask_idx += 1

    with open(output_path, 'w') as f:
        json.dump(coco_style_annotations, f)
    return coco_style_annotations

# NOTE(review): this reads scaled_raw/ while the other cells use scaled_raw_512/ —
# confirm the intended input directory.
input_dir = '/workspace/raid/OM_DeepLearning/XMM_OM_dataset/scaled_raw/'
generate_json_file(input_dir, coco_style_annotations)

## 🚀 Generate annotation files (VOC XML format)

In [None]:
import glob
from roboflow import Roboflow

# Initialize Roboflow client
import os

# Read the API key from the ROBOFLOW_API_KEY environment variable so the
# credential is not stored in the notebook; the placeholder is only a fallback.
rf = Roboflow(api_key=os.environ.get("ROBOFLOW_API_KEY", "my_apy_key"))
upload_project = rf.workspace("my_username").project("xmm_om_images_v4-contrast-512-5") # error if the project doesn't exist

### Create and export SAM annotations in VOC format to Roboflow

In [None]:
import dataset, class_agnostic_sam_predictor
from class_agnostic_sam_predictor import predictor_utils as predict
from dataset import voc_annotate_and_Roboflow_export as voc

input_dir = './temp_images/'
input_fits_dir = '/workspace/raid/OM_DeepLearning/XMM_OM_dataset/scaled_raw_512/'
max_files = 2  # only the first two directory entries are processed (test run)

for k, file_ in enumerate(os.listdir(input_dir)):
    if k >= max_files:
        break

    mask_on_negative = predict.mask_and_plot_image(input_fits_dir+file_.replace('.png', '.fits'), plot_=False)
    sam_result_i, detections_i, annotated_image_i = predict.SAM_predictor(
        SamAutomaticMaskGenerator, sam, input_dir+file_,
        mask_on_negative=mask_on_negative, img_grid_points=None)
    if sam_result_i is None or detections_i is None or annotated_image_i is None:
        continue

    objects = []
    for annotation in sam_result_i:  # a mask over an image is a binary array with shape (img_h, img_w)
        polygon = voc.binary_image_to_polygon(annotation['segmentation'])
        if len(polygon) == 0:
            # No polygon could be extracted from this mask; skip it instead of
            # failing on polygon[0].
            continue
        # plot_polygon(polygon[0], annotated_image_i) # to see the masks polygons
        objects.append({
            'name': 'star',
            'bbox': annotation['bbox'],
            'segmentations': polygon[0]
        })

    print(file_)
    voc.create_annotation_SAM(filename=file_, width=512, height=512, depth=3, objects=objects)  # generating xml file for VOC format
    image_path = input_dir+file_
    annotation_filename = file_.replace(".png", ".xml")
    try:
        upload_project.upload(image_path, annotation_filename, overwrite=False)
    finally:
        # Remove the temporary XML even if the upload raises.
        os.remove(annotation_filename)

### Convert annotations from **COCO** format to VOC and upload them to Roboflow

In [None]:
from importlib import reload
import dataset_utils
reload(dataset_utils)
from dataset_utils import *

import json 
import dataset
from dataset import voc_annotate_and_Roboflow_export as voc

input_dir = "/workspace/raid/OM_DeepLearning/XMM_OM_code_git/roboflow_datasets/xmm_om_images_512_SG_SR_CR_only-10-COCO/train/"
json_file_path = input_dir+'_annotations.coco.json'

output_d = "/workspace/raid/OM_DeepLearning/VOC_xmm_om_images-contrast-512-v5-3/"

with open(json_file_path) as f:
    data_in = json.load(f)

# Lookups built once instead of rescanning the annotation list for every image.
class_categories = {cat['id']: cat['name'] for cat in data_in['categories']}
annotations_by_image = {}
for anno in data_in['annotations']:
    annotations_by_image.setdefault(anno['image_id'], []).append(anno)

max_images = 1  # only the first image is converted and uploaded (test run)

for im in data_in['images'][:max_images]:
    file_ = im['file_name']
    extension = "." + file_.split(".")[-1]
    # Roboflow-style names such as "X_png.rf.<hash>.jpg" become "X.png.rf.<hash>.jpg".
    out_name = file_.replace('_png', '.png')

    temp_img = cv2.imread(input_dir + file_)
    if temp_img is None:
        print(f"Could not read {input_dir + file_}; skipping.")
        continue
    h_img, w_img = temp_img.shape[:2]
    # cv2.imread returns BGR and cv2.imwrite expects BGR: write the pixels back
    # unchanged (converting to RGB first would swap red and blue on disk).
    cv2.imwrite(f"./{out_name}", temp_img)

    objects = []
    for anno in annotations_by_image.get(im['id'], []):
        segmentation = anno['segmentation']
        # Only polygon segmentations ([[x1, y1, x2, y2, ...], ...]) are handled;
        # anything else would otherwise reuse the previous mask's points.
        if not (isinstance(segmentation, list) and len(segmentation) > 0
                and isinstance(segmentation[0], list)):
            print(f"{file_}: skipping non-polygon segmentation of annotation {anno.get('id')}")
            continue
        points = segmentation[0]
        binary_m = create_mask(points, (h_img, w_img))  # COCO segmentations are polygon points, and must be converted to masks
        objects.append({
            'name': class_categories[anno['category_id']],
            'bbox': mask_to_bbox(binary_m),
            'segmentations': points
        })

    # generating xml file for VOC format, using the real image size
    voc.create_annotation(filename=out_name, width=w_img, height=h_img, depth=3, objects=objects)
    image_path = out_name
    annotation_filename = out_name.replace(extension, ".xml")

    # Replace the first three lines of the generated XML with a bare
    # <annotation> tag and an empty <folder> element.
    new_lines = ['<annotation>\n', '\t<folder></folder>\n']
    with open(annotation_filename, 'r') as file:
        lines = file.readlines()
    with open(annotation_filename, 'w') as file:
        file.writelines(new_lines + lines[3:])

    try:
        upload_project.upload(image_path, annotation_filename, overwrite=False)
    finally:
        # Clean up the temporary XML and image even if the upload raises.
        os.remove(annotation_filename)
        os.remove(image_path)

### Export images and their VOC annotations to Roboflow (2)

In [None]:
import requests
import os

import glob
from roboflow import Roboflow

# Read the API key from the ROBOFLOW_API_KEY environment variable so the
# credential is not stored in the notebook; the placeholder is only a fallback.
rf = Roboflow(api_key=os.environ.get("ROBOFLOW_API_KEY", "my_api_key"))
upload_project = rf.workspace("my_username").project("xmm_om_images_v4-contrast-512-3")

dataset_images_folder = '/workspace/raid/OM_DeepLearning/XMM_OM_dataset/zscaled_512_stretched/'
annotations_voc_dir = '/workspace/raid/OM_DeepLearning/XMM_OM_code_git/xmm_om_images_v4-contrast-512-1/train/'
annotations_ = [annot for annot in os.listdir(annotations_voc_dir) if annot.endswith('.xml')]

for image_name in os.listdir(dataset_images_folder):
    image_path = os.path.join(dataset_images_folder, image_name)
    if not os.path.isfile(image_path):
        continue
    # An annotation belongs to this image when its file name starts with the
    # image name with every '.' replaced by '_' (e.g. "X.png" -> "X_png...").
    voc_prefix = image_name.replace('.', '_')
    print(voc_prefix)
    matches = [annotation for annotation in annotations_ if annotation.startswith(voc_prefix)]
    if matches:
        upload_project.upload(image_path, os.path.join(annotations_voc_dir, matches[0]), overwrite=True)

print("Image upload complete.")