#### Script Description: Vanilla SAM applied on confocal images acquired by Sanchari. 
#### Input data: Input_Data/Sanchari_data/Sanchari Confocal Images_RGB
#### Output: results/sanchari_masks_pkl_SAM

In [None]:
import numpy as np
from numpy import array
import pandas as pd
import matplotlib.pyplot as plt
from PIL import Image, ImageDraw, ImageStat
from skimage import io, util, color
from scipy.optimize import linear_sum_assignment
import sys
import cv2
import pickle

In [None]:
# from google.colab import drive
# drive.mount('/content/drive')

In [None]:
!pip install --upgrade imagecodecs

In [None]:
# Confocal frame to segment, read from the mounted Google Drive.
IMAGE_PATH = "/content/drive/MyDrive/Sanchari Confocal Images/SCR 1 day water objective stack_z07_Ch1-T2_ORG.tif"
image = io.imread(IMAGE_PATH)

# NOTE(review): the SAM mask generator is documented to take an HxWx3 uint8
# array. An earlier revision of this cell had a (disabled) conversion for
# inputs with more than 3 channels -- confirm this TIFF loads as 3-channel
# 8-bit data before it is passed to SAM.

# skimage.io.imshow is deprecated in recent scikit-image releases; matplotlib
# is already imported as plt and renders the same array.
plt.imshow(image)
plt.show()

In [None]:
# with open(f'/content/drive/MyDrive/mask_ps_pickle/masks1.pkl', 'rb') as file:
#   masks2 = pickle.load(file)

In [None]:
# from skimage import filters
# import cv2
# sobel_x = cv2.Sobel(image, cv2.CV_32F, 1, 0, ksize=3)
# sobel_y = cv2.Sobel(image, cv2.CV_32F, 0, 1, ksize=3)
# sobel = cv2.magnitude(sobel_x, sobel_y)
# sobel = cv2.convertScaleAbs(sobel)
# io.imshow(sobel)

In [None]:
import os

# Directory the kernel is in when this cell runs; the weights folder is
# created underneath it.
# NOTE(review): a later cell does `%cd {HOME}/weights` and nothing visible
# changes back, so re-running this cell afterwards (without a kernel restart)
# would set HOME to the weights directory.
HOME = os.getcwd()
print("HOME:", HOME)

In [None]:
# Work from the directory recorded in HOME.
%cd {HOME}

import sys
# Install Segment Anything from GitHub using the interpreter that runs this
# kernel (sys.executable), so the package is importable in this session.
!{sys.executable} -m pip install 'git+https://github.com/facebookresearch/segment-anything.git'

In [None]:
!pip install -q jupyter_bbox_widget roboflow dataclasses-json supervision

In [None]:
%cd {HOME}
!mkdir {HOME}/weights
%cd {HOME}/weights

!wget -q https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth

In [None]:
import os

# Where the downloaded SAM ViT-H checkpoint is expected: <HOME>/weights/.
CHECKPOINT_PATH = os.path.join(HOME, "weights", "sam_vit_h_4b8939.pth")
# Print the path and whether the file is really there, so a failed download
# is visible before the model is loaded.
print(CHECKPOINT_PATH, "; exist:", os.path.isfile(CHECKPOINT_PATH))

In [None]:
import torch

# SAM backbone variant; selects the entry used from sam_model_registry.
MODEL_TYPE = "vit_h"
# Run on the first CUDA device when one is available, otherwise on the CPU.
DEVICE = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

In [None]:
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor

# Build the SAM variant named by MODEL_TYPE from the downloaded checkpoint,
# then move it to the selected device.
sam = sam_model_registry[MODEL_TYPE](checkpoint=CHECKPOINT_PATH)
sam = sam.to(device=DEVICE)

In [None]:
# Automatic (prompt-free) mask generator around the loaded SAM model, with
# the library's default settings.
mask_generator = SamAutomaticMaskGenerator(sam)

In [None]:
import supervision as sv
# Segment the whole image. Later cells read each result entry's 'area' and
# 'segmentation' fields.
# NOTE(review): presumably this needs an HxWx3 uint8 array -- confirm the
# loaded TIFF has that layout.
sam_result = mask_generator.generate(image)

In [None]:
# Convert the raw SAM output into a supervision Detections object.
detections = sv.Detections.from_sam(sam_result=sam_result)

# Overlay every mask on a copy of the input, coloured by detection index.
mask_annotator = sv.MaskAnnotator(color_lookup=sv.ColorLookup.INDEX)
annotated_image = mask_annotator.annotate(
    scene=image.copy(),
    detections=detections,
)

# Side-by-side view: untouched input on the left, overlay on the right.
sv.plot_images_grid(
    images=[image, annotated_image],
    titles=['source image', 'segmented image'],
    grid_size=(1, 2),
)

In [None]:
# Write the overlay to disk; the bounding-box widget cell later loads this
# same file.
# NOTE(review): the path is hard-coded to /content/weights rather than built
# from HOME, so it only exists when HOME is /content -- confirm.
io.imsave('/content/weights/annotated_image.jpg', annotated_image)

In [None]:
# Segmentation arrays of all SAM results, ordered from largest to smallest
# 'area'; this is the order in which the selection widget presents them.
masks = list(map(
    lambda record: record['segmentation'],
    sorted(sam_result, key=lambda record: record['area'], reverse=True),
))

In [None]:
import ipywidgets as widgets
from IPython.display import display, clear_output
import matplotlib.pyplot as plt


# State shared with the button callbacks defined below:
#   selected_images   -- masks the user has kept so far
#   current_image_idx -- position in `masks` of the mask currently shown
selected_images = list()
current_image_idx = 0

# Output widget that the mask preview is rendered into
output = widgets.Output()

# Function to display the current image
# def display_image(image_data):
#     with output:
#         clear_output(wait=True)
#         plt.imshow(image_data)
#         plt.axis('off')
#         plt.show()
def display_image(image_data):
    """Show `image_data` next to the reference image inside the `output` widget.

    The widget is cleared first, then a two-panel figure is drawn: the
    candidate passed in as `image_data` on the left and the module-level
    `image` on the right, both without axes.
    """
    with output:
        clear_output(wait=True)
        _, (candidate_ax, reference_ax) = plt.subplots(1, 2, figsize=(10, 5))

        # Left panel: the candidate that changes on every button click.
        candidate_ax.imshow(image_data)
        candidate_ax.axis('off')

        # Right panel: the fixed reference image.
        # (presumably cmap only takes effect if `image` is single-channel)
        reference_ax.imshow(image, cmap='gray')
        reference_ax.axis('off')

        plt.show()
# Create the Select / Discard / Quit / Next buttons (one widget per label)
select_button, discard_button, quit_button, next_button = (
    widgets.Button(description=label)
    for label in ("Select", "Discard", "Quit", "Next")
)

# Define actions for buttons
def _finish_selection():
    """Replace the preview in `output` with the completion message."""
    # Clear and print inside the Output widget: output produced by a button
    # callback outside the widget context is not reliably shown, and a bare
    # clear_output() does not clear the widget itself.
    with output:
        clear_output()
        print("Image selection is complete.")


def _show_current_or_finish():
    """Display the mask at `current_image_idx`, or finish when past the end."""
    if current_image_idx < len(masks):
        display_image(masks[current_image_idx])
    else:
        _finish_selection()


def _advance():
    """Step to the next mask (never beyond one past the last) and refresh."""
    global current_image_idx
    if current_image_idx < len(masks):
        current_image_idx += 1
    _show_current_or_finish()


def select_image(_):
    """Button callback: keep the current mask, then move on to the next one.

    Clicking after the last mask has been handled is a no-op apart from
    re-showing the completion message; previously it indexed past the end of
    `masks` and raised IndexError.
    """
    if current_image_idx >= len(masks):
        _finish_selection()
        return
    selected_images.append(masks[current_image_idx])
    _advance()


def discard_image(_):
    """Button callback: skip the current mask without keeping it."""
    _advance()


def quit_app(_):
    """Button callback: stop reviewing and show the completion message."""
    _finish_selection()


def next_image(_):
    """Button callback: move to the next mask (same effect as Discard)."""
    _advance()

# Attach each button to its callback
for button, handler in (
    (select_button, select_image),
    (discard_button, discard_image),
    (quit_button, quit_app),
    (next_button, next_image),
):
    button.on_click(handler)

# Render the first mask into the output widget before it is shown
display_image(masks[current_image_idx])

# Lay out the preview above a single row of buttons
button_box = widgets.HBox([select_button, discard_button, next_button, quit_button])
selection_panel = widgets.VBox([output, button_box])
display(selection_panel)


In [None]:
import numpy as np

# Union of every mask kept in the selection widget (`selected_images`).

# Start from an all-False mask with the image's height and width. Using bool
# from the start keeps the dtype the same whether or not any mask was
# selected (logical_or always yields bool).
empty_mask = np.zeros(image.shape[:2], dtype=bool)

# OR each selected mask into the accumulator in place.
for object_mask in selected_images:
    np.logical_or(empty_mask, object_mask, out=empty_mask)

plt.imshow(empty_mask)

In [None]:
# Prompt-driven predictor sharing the same SAM model; used below to segment
# inside boxes drawn by the user.
mask_predictor = SamPredictor(sam)

In [None]:
# helper function that loads an image before adding it to the widget

import base64

def encode_image(filepath):
    """Return the file at `filepath` as a base64 ``data:`` URI string.

    The MIME type is derived from the file extension (e.g. ``image/png`` for
    ``.png``, ``image/jpeg`` for ``.jpg``). When the extension is unknown or
    does not map to an image type, ``image/jpeg`` is used. Previously the
    prefix was always the non-standard ``image/jpg``, even for PNG input.

    Parameters
    ----------
    filepath : str
        Path of the image file to embed.

    Returns
    -------
    str
        ``"data:<mime>;base64,<payload>"``.

    Raises
    ------
    OSError
        If the file cannot be opened or read.
    """
    import mimetypes  # local import keeps the helper self-contained

    mime_type, _ = mimetypes.guess_type(filepath)
    if mime_type is None or not mime_type.startswith("image/"):
        mime_type = "image/jpeg"

    with open(filepath, 'rb') as f:
        image_bytes = f.read()

    encoded = base64.b64encode(image_bytes).decode('utf-8')
    return f"data:{mime_type};base64,{encoded}"

In [None]:
from ipywidgets import widgets
IS_COLAB = True

if IS_COLAB:
    # Import under an alias: a plain `from google.colab import output` would
    # rebind the notebook-level name `output`, which display_image() uses as
    # the ipywidgets Output widget for the mask preview.
    from google.colab import output as colab_output
    colab_output.enable_custom_widget_manager()

from jupyter_bbox_widget import BBoxWidget

# Box-drawing widget over the saved overlay; boxes drawn here are read from
# `widget.bboxes` by the prompt loop below.
widget = BBoxWidget()
widget.image = encode_image("/content/weights/annotated_image.jpg")
widget

In [None]:
# Masks predicted from user-drawn boxes. The prompt loop below appends to
# this list, so re-running that loop without re-running this cell adds the
# same masks again.
prompt_masks = []

In [None]:
# default_box is used when no box has been drawn on the image above
default_box = {'x': 68, 'y': 247, 'width': 555, 'height': 678, 'label': ''}

# Iterate the drawn boxes, or just the fallback when there are none.
# (Looping over range(len(widget.bboxes)) never ran for an empty list, so the
# fallback was unreachable.)
boxes_to_segment = widget.bboxes if widget.bboxes else [default_box]

# Loop-invariant setup, done once instead of once per box.
# NOTE(review): the predictor is given annotated_image (the coloured overlay),
# not the raw `image` -- kept as in the original; confirm this is intended.
mask_predictor.set_image(annotated_image)
box_annotator = sv.BoxAnnotator(color=sv.Color.red())
mask_annotator = sv.MaskAnnotator(color_lookup = sv.ColorLookup.INDEX)

for drawn_box in boxes_to_segment:
  # Widget boxes are x/y/width/height; the predictor takes [x0, y0, x1, y1].
  box = np.array([
      drawn_box['x'],
      drawn_box['y'],
      drawn_box['x'] + drawn_box['width'],
      drawn_box['y'] + drawn_box['height']
  ])

  mask_prompt, scores, logits = mask_predictor.predict(
      box=box,
      multimask_output=False
  )

  # multimask_output=False: keep the first (only) returned mask for this box.
  prompt_masks.append(mask_prompt[0])

  detections = sv.Detections(
      xyxy=sv.mask_to_xyxy(masks=mask_prompt),
      mask=mask_prompt
  )
  # detections = detections[detections.area == np.max(detections.area)]

  source_image = box_annotator.annotate(scene=annotated_image.copy(), detections=detections, skip_label=True)
  segmented_image = mask_annotator.annotate(scene=annotated_image.copy(), detections=detections)

  sv.plot_images_grid(
      images=[source_image, segmented_image],
      grid_size=(1, 2),
      titles=['source image', 'segmented image']
  )

In [None]:
# from skimage.transform import resize
# for i, mask in enumerate(selected_images):
#   selected_images[i] = resize(mask, (600, 600))

In [None]:
# Masks kept in the selection widget followed by the box-prompted masks;
# this is the list that is pickled and merged into one mask further down.
test1 = selected_images + prompt_masks

In [None]:
# NOTE(review): test1 already ends with prompt_masks, so this repeats those
# entries unless prompt_masks changed between cell runs; test2 is not used
# by any visible cell -- confirm whether it is still needed.
test2 = test1 + prompt_masks

In [None]:
# NOTE(review): appends prompt_masks once more on top of test2; not used by
# any visible cell -- confirm whether it is still needed.
test3 = test2 + prompt_masks

In [None]:
# NOTE(review): appends prompt_masks once more on top of test3; not used by
# any visible cell -- confirm whether it is still needed.
test4 = test3 + prompt_masks

In [None]:
# Quick check of the first selected mask's shape (raises IndexError when
# nothing was selected).
selected_images[0].shape

In [None]:
# Persist the selected + box-prompted masks (test1) to Google Drive.
# The path is a plain string: the original f-string had no placeholders.
# NOTE: only unpickle this file again from storage you trust.
with open("/content/drive/MyDrive/masks5.pkl", 'wb') as file:
    pickle.dump(test1, file)

In [None]:
import numpy as np

# Union of every mask in test1 (selected masks plus box-prompted masks).

# Start from an all-False mask with the image's height and width. Using bool
# from the start keeps the dtype the same whether or not test1 is empty
# (logical_or always yields bool).
empty_mask = np.zeros(image.shape[:2], dtype=bool)

# OR each mask into the accumulator in place.
for object_mask in test1:
    np.logical_or(empty_mask, object_mask, out=empty_mask)

plt.imshow(empty_mask)

In [None]:
# Shape of the last mask visited by a mask-union loop; relies on the loop
# variable left over from an earlier cell, so it is undefined if no such
# loop has iterated yet.
object_mask.shape

In [None]:
# Save the combined mask to Google Drive.
# NOTE(review): empty_mask is 2-D, so plt.imsave presumably writes it through
# the default colormap, and .jpg is lossy -- for a ground-truth mask consider
# a lossless grayscale file; confirm what downstream readers expect.
# NOTE(review): "Goundtruth_masks" looks like a typo of "Groundtruth", but it
# is a real path -- only change it together with the Drive folder name.
plt.imsave("/content/drive/MyDrive/Goundtruth_masks/mask_23.jpg", empty_mask)

In [None]:
import supervision as sv
# Overview of every selected mask, 16 per row.
# The row count follows the number of masks: a fixed 16x16 grid is mostly
# blank for a handful of masks, and supervision's plot_images_grid rejects
# more images than grid cells.
GRID_COLUMNS = 16
grid_rows = max(1, -(-len(selected_images) // GRID_COLUMNS))  # ceiling division

sv.plot_images_grid(
    images=selected_images,
    grid_size=(grid_rows, GRID_COLUMNS),
    size=(30, 30)
)

In [None]:
# detections = detections[detections.area == np.max(detections.area)]

# source_image = box_annotator.annotate(scene=annotated_image.copy(), detections=detections, skip_label=True)
# segmented_image = mask_annotator.annotate(scene=annotated_image.copy(), detections=detections)

# sv.plot_images_grid(
#     images=[source_image, segmented_image],
#     grid_size=(1, 2),
#     titles=['source image', 'segmented image']
# )

In [None]:
# import supervision as v

# sv.plot_images_grid(
#     images=masks,
#     grid_size=(16, 16),
#     size=(16, 16)
# )