# Segment Anything PIV

## Installation & Current Directory

Import packages already installed.

In [None]:
## required packages
import os

import numpy as np
import matplotlib.pyplot as plt
import cv2
import tifffile # saving images to TIF
import torch
from PIL import Image, ImageEnhance


If using google colab, you can connect to your google drive folders and files with drive.mount:

In [None]:
## mounting google drive (only needed when running in Google Colab)
from google.colab import drive
drive.mount('/content/drive')
%cd /content/drive/MyDrive/

# ^ replace the path after %cd with your own Colab folder; it becomes the current working directory

Install segment anything. Additionally, you need to manually download the model checkpoints. On the [Segment Anything GitHub page](https://github.com/facebookresearch/segment-anything), scroll down to the "Model Checkpoints" section. Then download any one of the models and **place it in the current directory specified above** (using %cd).

In [None]:
# installation of segment anything, directly from its GitHub repository
!pip install git+https://github.com/facebookresearch/segment-anything.git;

Install packages required for segment anything.

In [None]:
# installing requirements for segment anything
# (opencv-python is the package that provides the cv2 module imported at the top of this notebook)
!pip install opencv-python pycocotools matplotlib

Install packages that haven't been pre-installed.

In [None]:
# install not pre-installed packages
!pip install tifffile cv

## Manual Segmentation and Image Processing (No Classes/Functions)

###Image Processing

#### Video to Frames
If the images that you need to mask are in a video, use this to extract those images. Specify the path to the video, analyze all frames and then save each frame to an image file.

In [None]:

video_path = '/content/drive/MyDrive/<path_to_your_video.mp4>'
frames_dir = 'video_frame'  # folder (relative to the current directory) that receives the frames

# cv2.imwrite does not create missing folders, so make sure the target exists first
os.makedirs(frames_dir, exist_ok=True)

video = cv2.VideoCapture(video_path)
if not video.isOpened():
    raise FileNotFoundError(f'Could not open video: {video_path}')

# Initialize frame count
frame_count = 0

try:
    # Read the video frames and save them as images
    while True:
        # Read the next frame
        ret, frame = video.read()

        # Break the loop if no frame is captured
        if not ret:
            break

        # Save the frame as an image; imwrite reports failure through its return value
        image_path = os.path.join(frames_dir, f'frame_{frame_count:04d}.jpg')
        if not cv2.imwrite(image_path, frame):
            raise IOError(f'Could not write frame to {image_path}')

        # Increment frame count
        frame_count += 1
finally:
    # Release the video file even if reading or writing failed
    video.release()

print(f'Saved {frame_count} frames to {frames_dir}')

#### Pre-Processing

Define the path to the image folder and the filename used for testing, and define the path where the processed images will be stored (they are created by the batch code block further below). You can change one parameter (brightness, sharpness, or contrast) and compare those images with each other, or create one image with all of these changes combined. The images are then plotted and compared with the original.

Test pre-processing before saving anything.

In [None]:

# folder that contains the original images
PATH_TO_IMG = ''

# folder that receives the processed images (used by the batch cell further down)
PATH_TO_PROCESSED_IMG = ''

# name of the single image inside PATH_TO_IMG that is used for this test
filename = ''

img_path = os.path.join(PATH_TO_IMG,filename)
img = Image.open(img_path)

#for changing parameters individually, creating images for each, and then comparing individual images

# each enhancer wraps the original image, so the three results below are independent of each other
contrast = ImageEnhance.Contrast(img)
img_processed_contrast = contrast.enhance(4) #the numbers defined in .enhance() control the strength of the effect. 1 is the default

brightness = ImageEnhance.Brightness(img)
img_processed_brightness = brightness.enhance(0.5)

sharpness = ImageEnhance.Sharpness(img)
img_processed_sharpness = sharpness.enhance(4)

#change any parameter and make single image with those changes

def process_img(image,contrast,brightness,sharpness):
  """Return a copy of image with contrast, brightness and sharpness applied in that order.

  Each argument after image is the factor handed to the matching
  ImageEnhance enhancer; every step works on the result of the previous one.

  NOTE: a class with the same name is defined later in this notebook; once
  that cell has run, this cell must be re-run before calling the function.
  """
  steps = (
    (ImageEnhance.Contrast, contrast),
    (ImageEnhance.Brightness, brightness),
    (ImageEnhance.Sharpness, sharpness),
  )

  result = image
  for make_enhancer, factor in steps:
    result = make_enhancer(result).enhance(factor)

  return result

# strength of the contrast, brightness and sharpness effects
cont = br = sharp = 1

img_processed = process_img(img, cont, br, sharp)

# 4-row layout: original in row 1, combined result in row 2,
# the remaining rows are free for the individually processed images
plt.figure(figsize=(20, 20))
plt.subplot(411)
plt.imshow(img, cmap='gray')
plt.subplot(412)
plt.imshow(img_processed, cmap='gray')


To pre-process all images in a folder after testing. The paths to the image folder and the processed-images folder are the ones specified above.

In [None]:
#set the contrast, brightness and sharpness values
cont, br, sharp = 1,1,1

# create output folder if it does not exist (exist_ok avoids the separate existence check)
os.makedirs(PATH_TO_PROCESSED_IMG, exist_ok=True)

# loop through all the files in the folder
for file in os.listdir(PATH_TO_IMG):
  img_path = os.path.join(PATH_TO_IMG,file)

  # sub-folders cannot be opened as images, so skip anything that is not a file
  if not os.path.isfile(img_path):
    continue

  # open the image file; the with-block closes it again once it has been processed
  with Image.open(img_path) as img:
    #use function to edit images
    img_processed = process_img(img,cont,br,sharp)

  out_path = os.path.join(PATH_TO_PROCESSED_IMG,file)
  # save the edited image with the same filename
  img_processed.save(out_path)

#### Cropping of Images
Used in case the region of interest is smaller than complete image.

Specify the image path and the path to save the cropped image. The roi variable then crops the image using a rectangle with boundaries specified in xyxy format (left side, top side, right side, and bottom side). To figure out where each point is, plot the uncropped image and use plt.scatter to plot the point(s). Then plot the cropped image.

Test cropping of images before saving anything.

In [None]:
# cropping of images, testing for one image without saving

PATH_TO_IMG  = ''  # folder with the uncropped images
PATH_TO_CROP = ''  # folder for the cropped images (used by the batch cell below)
filename = '' # image used for testing

# set the region to be cropped in the format (left, upper, right, lower)
roi  = (270, 400, 1150, 900)

img_path = os.path.join(PATH_TO_IMG, filename)
img = Image.open(img_path)

# crop the image
img_crop = img.crop(roi)

# show the cropped region to check the roi before cropping the whole folder
plt.imshow(img_crop, cmap = 'gray')


To crop all images within a folder after testing. Path to the image folder and cropped images folder specified above.

In [None]:
# cropping of all images inside a folder

# create output folder if it does not exist (exist_ok avoids the separate existence check)
os.makedirs(PATH_TO_CROP, exist_ok=True)

# loop through all the files in the folder
for filename in os.listdir(PATH_TO_IMG):
    img_path = os.path.join(PATH_TO_IMG, filename)

    # sub-folders cannot be opened as images, so skip anything that is not a file
    if not os.path.isfile(img_path):
        continue

    # open the image file; the with-block closes it again after cropping
    with Image.open(img_path) as img:
        # crop the image
        img_crop = img.crop(roi)

    out_path = os.path.join(PATH_TO_CROP, filename)

    # save the cropped image with the same filename
    img_crop.save(out_path)

### Manual Segmentation (No Classes/Functions)

#### Parameters

Import segment anything, and specify the model and the corresponding checkpoint being used (see above for download instructions). Additionally, specify the device used for masking, the folder containing the images, and the folder to save the masks to. These settings are used by the mask generator and predictor code blocks that aren't functions or classes.

In [None]:
# segment anything
from segment_anything import SamAutomaticMaskGenerator, sam_model_registry, SamPredictor
model_type = 'default' # 'default', "vit_h", "vit_l", "vit_b"

# checkpoint file for every model type
# put the downloaded checkpoints into /content/drive/MyDrive/Colab Notebooks; whatever the path you set in 2nd code block
CHECKPOINTS = {
  'default': 'sam_vit_h_4b8939.pth',
  'vit_h': 'sam_vit_h_4b8939.pth',
  'vit_l': 'sam_vit_l_0b3195.pth',
  'vit_b': 'sam_vit_b_01ec64.pth',
}

# fail here with a clear message instead of leaving `checkpoint` undefined for a mistyped model type
if model_type not in CHECKPOINTS:
  raise ValueError(f"unknown model_type {model_type!r}; choose one of {sorted(CHECKPOINTS)}")
checkpoint = CHECKPOINTS[model_type]

device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # if available use GPU for speed-up

# image folder and folder for masking
PATH_TO_IMG  = ''
PATH_TO_MASK = ''


#### Mask Generator
* creation of masks for the whole image

Automatically creates mask for images without specification of any particular places/objects to mask. Recommended for simple images with one mask.

In [None]:
#set up sam masking
sam = sam_model_registry[model_type](checkpoint=checkpoint)
sam.to(device)
mask_generator = SamAutomaticMaskGenerator(sam)

# create output folder if it does not exist
os.makedirs(PATH_TO_MASK, exist_ok=True)

for filename in os.listdir(PATH_TO_IMG):
  img_path = os.path.join(PATH_TO_IMG, filename)

  # sub-folders cannot be opened as images, so skip anything that is not a file
  if not os.path.isfile(img_path):
    continue

  # read image
  with Image.open(img_path) as pil_img:
    img = np.array(pil_img.convert('RGB'))# use of PIL to ensure compatible with PyTorch

  # masking
  # (named `masks` so that the `mask` class defined further down in this notebook is not overwritten)
  masks = mask_generator.generate(img)

  # the generator can return an empty list; there is nothing to save in that case
  if not masks:
    print(f'no mask found for {img_path}, skipping')
    continue

  # output image: the first generated mask
  out = masks[0]['segmentation']
  # invert the mask; remove this line in case masked and unmasked areas come out flipped
  out = np.logical_not(out)

  # swap whatever extension the image has for .tif (not only .jpg)
  outname = 'mask_' + os.path.splitext(filename)[0] + '.tif'

  # save binary mask to output folder with same filename
  output_path = os.path.join(PATH_TO_MASK, outname)
  tifffile.imwrite(output_path, out)

#### Predictor
* creation of multiple masks from images and bounding boxes
* recommended for images that need manual masking

Preview a specified image for masking using the Predictor. Specify image path, plot the image as well as a specific point used for the Predictor masking.

In [None]:
# name of the image inside PATH_TO_IMG that should be previewed
put_img_file_here = 'img_00250.jpg'
img_file = os.path.join(PATH_TO_IMG, put_img_file_here)
img = Image.open(img_file)

x,y = 2000,1750   #specify location of point to observe
plt.imshow(img, cmap = 'gray')
# draw the point on top of the image to check that it lies on the object to mask
plt.scatter(x,y)
print(img_file)

Mask using the Predictor. Use input_point to mask specific objects in the image, input_label to give these points a marker, and input_box to specify a box to mask.

For the point, you can specify numerous points on the same object via np.array([[x1,y1],[x2,y2],...,[xn,yn]]). If you decide to do this, you must label each point, for example, np.array([0,0,0,1, etc.]).

You can use a combination of box and point for further specification on what to mask.

You can also mask additional objects in the image.

In [None]:
#640, 150 and 225, 150
input_point = np.array([[400,300]]) #specify point to mask an object
input_point_2 = np.array([[225,150]]) #point for the optional second object
input_label = np.array([1]) #0 or 1 for negative or positive input
input_label_2 = np.array([1]) #label for the optional second object

input_box = np.array([0,290,925,460])   #box to mask, in xyxy format (left side, top side, right side, and bottom side of box)

#setup mask

PATH_TO_IMG = '/content/drive/MyDrive/2023UGSRP_FSI/Image_Seg./ct/test1/one_img'

sam = sam_model_registry[model_type](checkpoint=checkpoint)
sam.to(device)
predictor = SamPredictor(sam)

# create output folder if it does not exist
os.makedirs(PATH_TO_MASK, exist_ok=True)

for filename in os.listdir(PATH_TO_IMG):
  img_path = os.path.join(PATH_TO_IMG, filename)

  # sub-folders cannot be opened as images, so skip anything that is not a file
  if not os.path.isfile(img_path):
    continue

  # read image
  with Image.open(img_path) as pil_img:
    img = np.array(pil_img.convert('RGB')) # use of PIL to ensure compatible with PyTorch

  # masking
  predictor.set_image(img)

  # (named `obj_mask` so that the `mask` class defined further down in this notebook is not overwritten)
  obj_mask ,_, _ = predictor.predict(
      point_coords=input_point,
      point_labels=input_label,
      #box = input_box,           #optional box
      multimask_output=False,
  )
  #To mask a second object
  # obj_mask_2 ,_, _ = predictor.predict(
  #    point_coords=input_point_2,
  #    point_labels=input_label_2,
  #    #box = input_box_2,
  #    multimask_output=False,
  #)

  out = obj_mask # | obj_mask_2
  out = np.logical_not(out)     #flips unmasked and masked portions of images

  # swap whatever extension the image has for .tif (not only .jpg)
  outname = 'mask_' + os.path.splitext(filename)[0] + '.tif'

  # save binary mask to output folder with same filename
  output_path = os.path.join(PATH_TO_MASK, outname)
  tifffile.imwrite(output_path, out)
  print(img_path)

## Automatic Segmentation and Image Processing (Uses Classes/Functions)

Contains a Dataset class to specify and obtain path to images, an Image Processing class to process images and the Masking class to mask images.

### Dataset Class


Set image path and optional mask path and path to process images. Use the class functions to retrieve the image path, mask path and processed image paths.

In [None]:
class dataset():
  """Hold the folders used by the image-processing and masking classes.

  All paths are strings. img_path is the folder with the original images,
  processed_img_path receives the edited images and mask_path receives the
  masks. The mask and processed-image folders are created on first access
  through their getters.
  """

  def __init__(self, img_path,mask_path='',processed_img_path = '',):
    self.image_path = img_path

    self.processed_img_path = processed_img_path

    self.mask_path = mask_path

  @staticmethod
  def _ensure_folder(path, label):
    """Create path (including parents) if needed and return it.

    Raises ValueError when the path was left empty, because os.makedirs('')
    would otherwise fail with an error that does not name the missing setting.
    """
    if not path:
      raise ValueError(f'{label} was not set when the dataset was created')
    # exist_ok makes a separate existence check unnecessary
    os.makedirs(path, exist_ok=True)
    return path

  def get_img_path(self):
    """Return the folder with the original images (never created automatically)."""
    return self.image_path

  def get_processed_img_path(self):
    """Return the folder for edited images, creating it if it doesn't exist."""
    return self._ensure_folder(self.processed_img_path, 'processed_img_path')

  def get_mask_path(self):
    """Return the folder for masks, creating it if it doesn't exist."""
    return self._ensure_folder(self.mask_path, 'mask_path')

### Image Processing Class

Edit images by cropping and/or changing the contrast, brightness and/or sharpness. For each function, define the input path for the images to edit, the editing parameters, and output path to store the edited images. The paths should be obtained via the dataset class.



In [None]:
#input and output paths defined with dataset class
class process_img():
  """Edit every image of a folder and save the results under the same file names.

  Each public method takes the folder with the images to edit (input_path),
  the editing parameter, and the folder that receives the edited images
  (output_path). input_path and output_path may be the same folder, in which
  case the images are overwritten in place.
  """

  @staticmethod
  def _apply(input_path, output_path, edit):
    """Run edit(image) on every file in input_path and save the result in output_path."""
    # an empty output_path means the current directory, which needs no creation
    if output_path:
      os.makedirs(output_path, exist_ok=True)

    for file in os.listdir(input_path):
      img_path = os.path.join(input_path, file)

      # sub-folders cannot be opened as images, so skip anything that is not a file
      if not os.path.isfile(img_path):
        continue

      # close the source file before saving, so in-place editing never writes to an open file
      with Image.open(img_path) as img:
        edited = edit(img)

      edited.save(os.path.join(output_path, file))

#define input folder with original images, output folder to put edited images, and strength of effect (value > 0)
  def contrast(self,input_path,contrast_value,output_path):
    self._apply(input_path, output_path,
                lambda img: ImageEnhance.Contrast(img).enhance(contrast_value))

#define input folder with original images, output folder to put edited images, and strength of effect (value > 0)
  def brightness(self,input_path,brightness_value,output_path):
    self._apply(input_path, output_path,
                lambda img: ImageEnhance.Brightness(img).enhance(brightness_value))

#define input folder with original images, output folder to put edited images, and strength of effect (value > 0)
  def sharpness(self,input_path,sharpness_value,output_path):
    self._apply(input_path, output_path,
                lambda img: ImageEnhance.Sharpness(img).enhance(sharpness_value))

#define input folder with original images, output folder to put edited images, and crop region roi as (left, upper, right, lower)
  def crop(self,input_path,roi,output_path):
    self._apply(input_path, output_path, lambda img: img.crop(roi))


Example using Dataset class and Image Processing class.

In [None]:
from time import sleep
d = dataset('Flag_and_Foil/one_img',processed_img_path='Flag_and_Foil/one_img_test')
roi  = (270, 400, 1150, 900)

# named input_dir/output_dir because `input` would shadow the Python builtin of the same name
input_dir = d.get_img_path()

output_dir = d.get_processed_img_path()

p = process_img()
# crop from the original folder into the processed folder, then keep editing the processed images in place
p.crop(input_dir,roi,output_dir)
sleep(8) #gives google drive enough time to update files
p.contrast(output_dir,4,output_dir)
sleep(8)
p.brightness(output_dir,2,output_dir)

### Masking Class

#### How to Use the Predictor Functions within the Masking Class:



For using **only** points to mask, any number of points can be used for a singular object and must be of the form:

In [None]:
# one object described by two points: the outer list has one entry per object
input_point = [[[200,200],[200,200]]]
input_point[0]

[[200, 200], [200, 200]]

Here, we use 2 points for one object. Calling the first index of the list will give the above output, **such that len(input_point) is 1**. You can also specify multiple points for various objects:


In [1]:
# two objects, each described by two points
input_point = [[[200,200],[200,200]],[[300,300],[300,300]]]
input_point[0],input_point[1],len(input_point)

([[200, 200], [200, 200]], [[300, 300], [300, 300]], 2)


Here, we use 2 points for each object. The first index of the input_point list labels points for the first object, and the second index for the second object, etc. **Calling len(input_point) should give the number of objects (2 in this case)**.

For each point, there must be a specified label assigned. 1 is for positive inputs (points on the object), and 0 is for negative inputs (points that should be excluded from the mask). See the GitHub page for more info. For the points specified above, the input label should look like this:

In [None]:
# one label per point, grouped per object in the same way as input_point
input_label = [[1,1],[1,1]]

Where we assign a value of 1 for each point. **Again, len(input_label) should be the number of objects.**

The same idea applies for using **only** boxes:

In [2]:
# one box per object; the outer list has one entry per object
input_box = [[1,2,3,4],[5,6,7,8]]
input_box[0], input_box[1],len(input_box)

([1, 2, 3, 4], [5, 6, 7, 8], 2)

Again, the first index of the input_box list is for the first object, the second index for the second object, and so on. **Using len(input_box) should once again give the number of objects (2 in this case).**

 However, if, for example, you have two objects and you want to use only points for the first object and only boxes for the second, you must place a [0] in the corresponding lists to indicate this:

In [None]:
# object 1 uses only a point, object 2 uses only a box; [0] marks the unused prompt
input_point = [[[200,200]],[0]]
input_label = [[1],[0]]
input_box = [[0],[1,2,3,4]]
# all three lists must have one entry per object
len(input_point) == len(input_box) == len(input_label) == 2

True

As shown, a [0] is placed in the second index of input_point and input_label because we don't want to use a point for the second object. Similarly, a [0] is placed in the first index of input_box because we don't want to use a box to mask the first object.


In general, the [0] is placed in the n_th entry of the list where n is the number of the object where we don't want either a point or box masking for that object.


It is done this way to keep track of the different combinations of points/boxes that are being used for each object being masked, and such that len() of the point, label and box lists are all equal.



The code is below, with an example of masking 2 objects, using 1 point each. The paths to the image and mask folders must be specified. Additionally, model_type has to be specified, and there is an option to invert the mask (change black to white and vice versa). A value of invert=True (or 1) inverts the mask, and invert=False (or 0) keeps it as is. The default is invert=False.

#### Preview Point on Image

Preview a specified image for masking using the Predictor. Specify image path, plot the image as well as a specific point used for the Predictor masking.

In [None]:
PATH_TO_IMG = ''  # folder that contains the image to preview
put_img_file_here = ''  # name of the image inside PATH_TO_IMG
img_file = os.path.join(PATH_TO_IMG, put_img_file_here)
img = Image.open(img_file)

x,y = 2000,1750   #specify location of point to observe
plt.imshow(img, cmap = 'gray')
# draw the point on top of the image to check that it lies on the object to mask
plt.scatter(x,y)
print(img_file)

#### Mask Class

Mask images. Set the model type, then choose to mask using points, boxes or a combination. An optional parameter to invert the colors of the mask is available (just set invert=True).

You can either use the mask_generator (the first function) or the mask_predictor (the second, third and fourth functions).

In [None]:
class mask():
  """Create binary masks for every image in a folder with Segment Anything (SAM).

  mask_generator segments the images automatically. mask_point, mask_box and
  mask_box_point use the SAM predictor with point and/or box prompts; the
  masks of all prompted objects are merged with a logical OR. Every result is
  written with tifffile as 'mask_<image name>.tif' into the mask folder.
  """

  # checkpoint file (downloaded from the github page) expected in the current directory
  _CHECKPOINTS = {
    'default': 'sam_vit_h_4b8939.pth',
    'vit_h': 'sam_vit_h_4b8939.pth',
    'vit_l': 'sam_vit_l_0b3195.pth',
    'vit_b': 'sam_vit_b_01ec64.pth',
  }

  def __init__(self,model_type,checkpoint=None):
    """Load the SAM model and build the predictor and the automatic generator.

    model_type: 'default', 'vit_h', 'vit_l' or 'vit_b'.
    checkpoint: optional path to the checkpoint file; when omitted, the
      standard file name for model_type is used.
    Raises ValueError for an unknown model_type.
    """
    # validate first so a typo fails with a clear message before anything is loaded
    if model_type not in self._CHECKPOINTS:
      raise ValueError(
        f"unknown model_type {model_type!r}; choose one of {sorted(self._CHECKPOINTS)}")

    from segment_anything import SamAutomaticMaskGenerator, sam_model_registry, SamPredictor

    self.model_type = model_type
    if checkpoint is None:
      checkpoint = self._CHECKPOINTS[model_type]

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # if available use GPU for speed-up

    sam = sam_model_registry[model_type](checkpoint=checkpoint)
    sam.to(device)
    self.predictor = SamPredictor(sam)
    self.generator = SamAutomaticMaskGenerator(sam)

  @staticmethod
  def _prepare_output(path_to_mask):
    """Create the mask folder if needed (an empty path means the current directory)."""
    if path_to_mask:
      os.makedirs(path_to_mask, exist_ok=True)

  @staticmethod
  def _image_files(folder):
    """Yield (filename, full path) for every regular file in folder, skipping sub-folders."""
    for filename in os.listdir(folder):
      img_path = os.path.join(folder, filename)
      if os.path.isfile(img_path):
        yield filename, img_path

  @staticmethod
  def _read_rgb(img_path):
    """Read an image as an RGB numpy array and close the file afterwards."""
    with Image.open(img_path) as pil_img:
      return np.array(pil_img.convert('RGB')) # use of PIL to ensure compatible with PyTorch

  @staticmethod
  def _save_mask(out,path_to_mask,filename,invert):
    """Convert out to bool, optionally invert it, and save it as mask_<name>.tif."""
    out = np.asarray(out,dtype=bool)
    if invert:
      out = np.logical_not(out)

    # swap whatever extension the image has for .tif (not only .jpg)
    outname = 'mask_' + os.path.splitext(filename)[0] + '.tif'
    tifffile.imwrite(os.path.join(path_to_mask, outname), out)

  @staticmethod
  def _is_placeholder(value):
    """Return True when value is the [0] marker meaning 'no point / no box for this object'."""
    arr = np.asarray(value)
    return arr.shape == (1,) and bool(arr[0] == 0)

  @staticmethod
  def _check_lengths(**prompt_lists):
    """Raise ValueError unless all prompt lists are non-empty and have one entry per object."""
    lengths = {name: len(values) for name, values in prompt_lists.items()}
    if 0 in lengths.values() or len(set(lengths.values())) != 1:
      raise ValueError(
        f'every prompt list needs one entry per object, got lengths {lengths}')

  def _mask_folder(self,path_to_img,path_to_mask,prompts,invert):
    """Mask every image in path_to_img with the given prompts.

    prompts holds one dict of SamPredictor.predict keyword arguments per object.
    """
    self._prepare_output(path_to_mask)

    for filename, img_path in self._image_files(path_to_img):
      self.predictor.set_image(self._read_rgb(img_path))

      # combine the masks of all objects into one
      out = None
      for prompt in prompts:
        obj_mask, _, _ = self.predictor.predict(multimask_output=False, **prompt)
        obj_mask = np.asarray(obj_mask,dtype=bool)
        out = obj_mask if out is None else np.logical_or(out, obj_mask)

      self._save_mask(out, path_to_mask, filename, invert)

#for automatically creating masks without specifying points/boxes
  def mask_generator(self,PATH_TO_IMG,PATH_TO_MASK,invert=True):
    """Save the first automatically generated mask of every image.

    The mask is inverted by default; pass invert=False in case masked and
    unmasked areas come out flipped. Images for which the generator finds
    nothing are reported and skipped.
    """
    self._prepare_output(PATH_TO_MASK)

    for filename, img_path in self._image_files(PATH_TO_IMG):
      masks = self.generator.generate(self._read_rgb(img_path))

      if not masks:
        print(f'no mask found for {img_path}, skipping')
        continue

      self._save_mask(masks[0]['segmentation'], PATH_TO_MASK, filename, invert)

#for masking using only points, paths defined as strings; points and labels defined as lists
  def mask_point(self,path_to_img,path_to_mask,input_point,input_label,invert=False):
    """Mask objects from point prompts; input_point and input_label hold one entry per object."""
    self._check_lengths(input_point=input_point, input_label=input_label)

    prompts = [
      {'point_coords': np.array(points), 'point_labels': np.array(labels)}
      for points, labels in zip(input_point, input_label)
    ]
    self._mask_folder(path_to_img, path_to_mask, prompts, invert)

#for masking using only boxes; paths defined as strings; box defined with list in xyxy format
  def mask_box(self,path_to_img,path_to_mask,input_box,invert=False):
    """Mask objects from box prompts; input_box holds one xyxy box per object."""
    self._check_lengths(input_box=input_box)

    # each box is passed to the predictor as a length-4 array
    prompts = [{'box': np.array(box)} for box in input_box]
    self._mask_folder(path_to_img, path_to_mask, prompts, invert)

#for masking using combination of points and boxes; paths defined as strings; points, labels and boxes defined as stated above
  def mask_box_point(self,path_to_img,path_to_mask,input_point,input_label,input_box,invert=False):
    """Mask objects from points and/or boxes.

    All three lists hold one entry per object; a [0] entry in input_point
    (with a [0] in input_label) or in input_box means that this kind of
    prompt is not used for that object. Raises ValueError when an object has
    neither a point nor a box.
    """
    self._check_lengths(input_point=input_point, input_label=input_label, input_box=input_box)

    prompts = []
    for number, (points, labels, box) in enumerate(zip(input_point, input_label, input_box), start=1):
      prompt = {}
      if not self._is_placeholder(points):
        prompt['point_coords'] = np.array(points)
        prompt['point_labels'] = np.array(labels)
      if not self._is_placeholder(box):
        prompt['box'] = np.array(box)
      if not prompt:
        raise ValueError(f'object {number} has neither a point nor a box prompt')
      prompts.append(prompt)

    self._mask_folder(path_to_img, path_to_mask, prompts, invert)



Example using Dataset class, Image Processing class, and Mask class:

In [None]:
# folders for the original images, the cropped images and the masks
data = dataset('/content/drive/MyDrive/4 3.8f',processed_img_path='Fish/crop',mask_path='Fish/mask_fish_box_point')
img = data.get_img_path() # note: this is the path of the image folder, not an image
mask_path = data.get_mask_path()
processed_img_path = data.get_processed_img_path()
box = [[405,340,1175,550]] # one box for the single object
input_label = [[1,1]] # one label per point below
point_tail = [[[1020,415],[1020,445]]] # two points on the same object

In [None]:
# crop every image of the original folder and store the result in the processed-image folder
p = process_img()
p.crop(img,[0,0,1600,965],processed_img_path)

In [None]:
# load SAM with the 'default' model type (checkpoint sam_vit_h_4b8939.pth)
mask_fish = mask('default')

In [None]:
# mask the cropped images using both the points and the box for the single object
mask_fish.mask_box_point(processed_img_path,mask_path,point_tail,input_label,box)