In [None]:
# Copyright (c) Meta Platforms, Inc. and affiliates.

# BibTex Citation

# @article{kirillov2023segany,
#   title={Segment Anything},
#   author={Kirillov, Alexander and Mintun, Eric and Ravi, Nikhila and Mao, Hanzi and Rolland, Chloe and Gustafson, Laura and Xiao, Tete and Whitehead, Spencer and Berg, Alexander C. and Lo, Wan-Yen and Doll{\'a}r, Piotr and Girshick, Ross},
#   journal={arXiv:2304.02643},
#   year={2023}
# }

# Setup

Necessary imports and helper functions for displaying points, boxes, and masks.

In [None]:
import numpy as np
import torch
import matplotlib.pyplot as plt
import cv2

In [None]:
def show_mask(mask, ax, random_color=False):
    """Draw ``mask`` on ``ax`` as a translucent colored overlay.

    Args:
        mask: Array whose last two dimensions are (height, width); truthy
            pixels are painted, falsy pixels stay fully transparent.
        ax: Matplotlib axes (anything with an ``imshow`` method).
        random_color: If True, use a random RGB color instead of the default
            blue. Alpha is 0.6 in both cases.
    """
    alpha = np.array([0.6])
    rgb = np.random.random(3) if random_color else np.array([30/255, 144/255, 255/255])
    rgba = np.concatenate([rgb, alpha], axis=0)

    height, width = mask.shape[-2:]
    # Broadcast the (H, W, 1) mask against a (1, 1, 4) RGBA color to get an RGBA image.
    overlay = mask.reshape(height, width, 1) * rgba.reshape(1, 1, -1)
    ax.imshow(overlay)

def show_points(coords, labels, ax, marker_size=375):
    """Scatter prompt points on ``ax``: label 1 as green stars, label 0 as red stars.

    Args:
        coords: Array of shape (N, 2) holding (x, y) pixel coordinates.
        labels: Array of length N; points labelled 1 are drawn green, 0 red.
        ax: Matplotlib axes (anything with a ``scatter`` method).
        marker_size: Marker area passed to ``scatter`` as ``s``.
    """
    # Positive points are drawn first, then negative points.
    for label_value, colour in ((1, 'green'), (0, 'red')):
        selected = coords[labels == label_value]
        ax.scatter(selected[:, 0], selected[:, 1], color=colour, marker='*',
                   s=marker_size, edgecolor='white', linewidth=1.25)

def show_box(box, ax):
    """Outline a box given as (x0, y0, x1, y1) on ``ax`` with a green, unfilled rectangle."""
    x_min, y_min, x_max, y_max = box[0], box[1], box[2], box[3]
    outline = plt.Rectangle((x_min, y_min), x_max - x_min, y_max - y_min,
                            edgecolor='green', facecolor=(0,0,0,0), lw=2)
    ax.add_patch(outline)


## Sequential Renaming of Folder

Run the following blocks if you haven't already done so; they rename each image in the folder to a sequential number.

In [None]:
# Path to the directory containing the image sequence (the .tif frames).
# Must be set before running the cells below: the renaming cells pass it to
# os.listdir, and an empty string is not a valid directory there.
directory_path = ''

In [None]:
# import os
# os.getcwd()
# for i, filename in enumerate(os.listdir(dir)):
#     if filename != str(i) + '.tif':
#     os.rename(dir + '//' + filename,dir + '//' + str(i) + '.tif')

The cells below rename the files in the folder sequentially, using a 4-digit numbering scheme (e.g., 0000, 0001, 0002, ...).

This code may not match your file-naming convention and may need to be adapted to your specific needs.

In [None]:
import os

# Word to remove from the filenames
word_to_remove = "Camera_15_46_55_"

# List all files in the directory
image_files = [f for f in os.listdir(directory_path) if os.path.isfile(os.path.join(directory_path, f))]

# Strip the word from every filename. The new names are collected so that
# image_files reflects the files as they exist on disk afterwards; the next
# cell sorts image_files numerically and would fail on the old, prefixed names.
renamed_files = []
for image_file in image_files:
    # Split the filename and extension, then remove the word from the base name
    base_filename, file_extension = os.path.splitext(image_file)
    new_filename = base_filename.replace(word_to_remove, "") + file_extension

    # Build the full path to the old and new file
    old_file_path = os.path.join(directory_path, image_file)
    new_file_path = os.path.join(directory_path, new_filename)

    # os.rename silently replaces an existing target on POSIX (and raises on
    # Windows). Refuse explicitly so that two names collapsing to the same
    # target cannot destroy a frame.
    if new_filename != image_file and os.path.exists(new_file_path):
        raise FileExistsError(f"Refusing to overwrite existing file: {new_file_path}")

    os.rename(old_file_path, new_file_path)
    renamed_files.append(new_filename)

image_files = renamed_files

In [None]:
import os

# Re-list the folder so the names reflect any renaming done in the previous
# cell. A list captured before that renaming would still hold the old names.
image_files = [f for f in os.listdir(directory_path) if os.path.isfile(os.path.join(directory_path, f))]

# Sort the image files by their numerical order (every base name must be an integer)
image_files.sort(key=lambda x: int(os.path.splitext(x)[0]))

# Rename to a zero-padded sequence 0000, 0001, ... keeping each extension.
renamed_files = []
for counter, image_file in enumerate(image_files):
    # Get the file extension (e.g., .jpg, .png)
    file_extension = os.path.splitext(image_file)[-1]

    # Construct the new filename with a numbered list
    new_filename = f"{counter:04d}{file_extension}"  # Use 4 digits for the number

    # Build the full path to the old and new file
    old_file_path = os.path.join(directory_path, image_file)
    new_file_path = os.path.join(directory_path, new_filename)

    # Never silently replace a different existing file (os.rename does so on POSIX).
    if new_filename != image_file and os.path.exists(new_file_path):
        raise FileExistsError(f"Refusing to overwrite existing file: {new_file_path}")

    os.rename(old_file_path, new_file_path)
    renamed_files.append(new_filename)

image_files = renamed_files

Load the renamed, numbered image sequence from the folder into a list, in numeric order.

In [None]:
# Sort folder into an ordered list
import glob
import re
import os


def natsort(s):
    """Natural-sort key: digit runs compare as integers, so '2.tif' sorts before '10.tif'."""
    return [int(t) if t.isdigit() else t.lower() for t in re.split(r'(\d+)', s)]


images = []
# os.path.join keeps the pattern portable; an empty directory_path yields
# '*.tif' (current directory) rather than a drive-root pattern.
for image_path in sorted(glob.glob(os.path.join(directory_path, '*.tif')), key=natsort):
    img = cv2.imread(image_path)
    # cv2.imread reports an unreadable file by returning None instead of raising.
    if img is None:
        raise IOError(f"Could not read image: {image_path}")
    # OpenCV loads channels in BGR order. matplotlib's imshow and
    # SamPredictor.set_image (default image_format="RGB") expect RGB.
    images.append(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))

## Set SAM Checkpoint

First, make sure Segment Anything is installed. (pip install git+https://github.com/facebookresearch/segment-anything.git)

https://github.com/facebookresearch/segment-anything


Load the SAM model and predictor. Change the path below to point to the SAM checkpoint.

If not done already, a Model Checkpoint needs to be downloaded. This can be done on Meta's SAM github: https://github.com/facebookresearch/segment-anything#model-checkpoints

In [None]:
import sys
# Only needed if segment_anything is not installed as a package: put the
# directory that contains the `segment_anything` package here.
sys.path.append('')
from segment_anything import sam_model_registry, SamPredictor

# Path to the downloaded SAM checkpoint file. model_type must match the
# checkpoint ("vit_h" for sam_vit_h_4b8939.pth).
sam_checkpoint = "sam_vit_h_4b8939.pth"
model_type = "vit_h"

# Use the GPU when one is available; ViT-H inference on the CPU is slow.
device = "cuda" if torch.cuda.is_available() else "cpu"

sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)

predictor = SamPredictor(sam)

Display the first image in the folder. Note the pixel coordinates on the axes; you will use them to place prompt points in the steps below.

In [None]:
# Show the first frame with axes visible so that pixel coordinates can be read
# off for the prompt point(s) used below.
first_frame = images[0]
plt.imshow(first_frame)
plt.axis('on')

## Segmentation

In the block below, choose the point(s) on the image that lie on the object you want segmented.

For example: 

    input_point = np.array([[50,100]])
    input_label = np.array([1])

This would set one point on the image at (50,100) and highlight the predicted segment.

Also:

    input_point = np.array([[50,100],[100,200]])
    input_label = np.array([1,1])

This would choose two points and run the same prediction model.



For each image, the model displays three candidate masks. Choose the best one; that mask is what gets saved to your output folder. Saving more than one candidate per frame is possible, but each one then needs its own distinct filename.


In [None]:
# Change this number to choose different images within your input folder
im_num = 0

predictor.set_image(images[im_num])

# Set coordinate and number of points for segmentation
input_point = np.array([[200,200]])
input_label = np.array([1])

masks, scores, logits = predictor.predict(
    point_coords=input_point,
    point_labels=input_label,
    multimask_output=True,
)

# One figure per candidate mask. The title shows the index to enter as
# best_img later, together with the model's score for that mask. The prompt
# point(s) are drawn so their placement can be checked against the image.
for i, (mask, score) in enumerate(zip(masks, scores)):
    plt.figure(figsize=(10,10))
    plt.imshow(images[im_num])
    show_mask(mask, plt.gca())
    show_points(input_point, input_label, plt.gca())
    plt.title(f"Mask index {i}, score: {score:.3f}", fontsize=18)
    plt.axis('off')
plt.show()

Note which mask gives the best segmentation; you will enter its index in the next block. The first mask is 0, the second is 1, and the third is 2.

Feel free to run multiple images to confirm your choice.

This next block iterates over the entire image sequence. Make sure to choose the correct number to save only the desired mask files.

This will run for a long time (about 30-45 seconds per image). 

In [None]:
# Declare which candidate mask you would like to move forward with (0, 1, or 2),
# as judged from the preview step above. The same index is used for every frame.
best_img = 0

# multimask_output=True produces three candidate masks, so any other value
# could never select a valid candidate.
if best_img not in (0, 1, 2):
    raise ValueError(f"best_img must be 0, 1, or 2, got {best_img!r}")

In [None]:
# Output folder Path
# Folder should already exist. The save loop writes one mask overlay per
# processed frame into it, named with a zero-padded frame number (e.g. 0000.tif).

output = ''

Set input_point and input_label to the values that gave the desired result above.

These prompts will be used for every image in the sequence. Make sure the point(s) used as the segmentation prompt give correct results throughout the entire sequence, especially for applications with large deformations. If necessary, change the range of the for loop and run the code in sections with different input prompts.

In [None]:
# Prompt reused for every frame: one (x, y) pixel coordinate per row, with one
# label per point (1 = positive/foreground, 0 = negative/background).
input_point, input_label = np.array([[200, 200]]), np.array([1])

The block below defines the output frame width `w`, height `l` (both in pixels), and DPI. The width and height must match the size of your images.

In [None]:
# Output frame size in pixels, taken from the loaded frames: l is the height
# (rows) and w the width (columns). The save loop builds a figure of
# w/dpi x l/dpi inches and saves it at dpi, so each saved frame is w x l
# pixels and lines up with its source image.
# NOTE: assumes every frame has the same size as the first one.
l, w = images[0].shape[:2]
dpi = 96

In [None]:
import os

# Frame range to process, as a slice of `images`. To re-run a sub-range,
# change these values. Output filenames follow the frame index
# automatically, so no separate counter needs to be adjusted.
start_frame = 0
end_frame = None  # None = through the last frame

# Fail fast, before any of the slow per-frame predictions, if the output
# folder is missing.
if not os.path.isdir(output):
    raise FileNotFoundError(f"Output folder does not exist: {output!r}")

for count, image in enumerate(images[start_frame:end_frame], start=start_frame):
    predictor.set_image(image)

    masks, _, _ = predictor.predict(
        point_coords=input_point,
        point_labels=input_label,
        multimask_output=True,
    )

    # Only the selected candidate is saved, so only it is rendered. Drawing
    # all three would open figures that are never saved or closed.
    # (Saving several candidates per frame requires a distinct filename for each.)
    fig = plt.figure(figsize=(w/dpi, l/dpi))
    ax = fig.add_axes([0, 0, 1, 1])
    ax.imshow(image)
    show_mask(masks[best_img], ax)
    ax.axis('off')
    ax.set(xlim=[-0.5, w - 0.5], ylim=[l - 0.5, -0.5], aspect=1)
    fig.savefig(os.path.join(output, f'{count:04d}.tif'), dpi=dpi, transparent=True)

    # Close every figure. Otherwise one figure per frame accumulates in
    # memory and all of them are rendered inline when the cell finishes.
    plt.close(fig)

After the for loop finishes, you may notice a few bad frames. You can go back and inspect the three candidate masks for any individual image, then change the mask index (0, 1, 2) to the best one. The range of the for loop can be changed to re-run only the selected images.