## **Mask exporting from SAM**

This notebook contains scripts necessary for using Segment Anything (SAM) to perform non-semantic segmentation on scan photos, exporting several types of segmented images:

(for each image)
1. Overall segmented image with mask overlays in random colors
2. Separate binary masks for each segmented region
3. Separate RGB copies of the original photo, each outlining a single segmented region

The scripts provided here perform some rudimentary clean-up: overlapping regions are resolved by subtracting smaller regions from the larger regions they overlap, and very small regions are culled.

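### Requirements
- A machine with a CUDA-enabled GPU is strongly recommended (SAM ViT-H is slow to run without one)
- The Python packages listed in `requirements.txt` (installed by the Setup cell below)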

## Setup

In [None]:
import sys
import subprocess

# Install every dependency pinned in requirements.txt (path resolved relative to the
# current working directory). Running pip via sys.executable installs into the same
# interpreter that runs this notebook's kernel.
pip_install_cmd = [sys.executable, "-m", "pip", "install", "-r", "requirements.txt"]
subprocess.check_call(pip_install_cmd)

# Manual alternative to requirements.txt, kept for reference:
#   pip install opencv-python supervision matplotlib pillow
#   pip install git+https://github.com/facebookresearch/segment-anything.git
#   pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu128
# Replace "cu128" with your CUDA version, or with "cpu" for the CPU-only build.



In [None]:
# Import necessary libraries
import torch
import cv2
import numpy as np
from segment_anything import SamAutomaticMaskGenerator, sam_model_registry
import matplotlib.pyplot as plt
from PIL import Image
import supervision as sv
import math
import os
import requests

In [None]:
def download_file(url, target_directory, filename=None, timeout=60):
    """
    Downloads a file from a given URL to a specified directory.

    The payload is streamed into a temporary ``<filename>.part`` file that is
    only renamed to its final name once the transfer completes. An interrupted
    download therefore never leaves a truncated file at the final path, which
    matters because callers skip the download when that path already exists.

    Args:
        url (str): The URL of the file to download.
        target_directory (str): The path to the directory where the file will be
            saved. Created if it does not exist.
        filename (str, optional): The name to save the file as. If None,
                                  the filename will be extracted from the URL.
        timeout (float, optional): Seconds to wait for the server to connect or
            to send more data before giving up. Defaults to 60.

    Returns:
        str | None: The path of the downloaded file, or None if the download failed.
    """
    os.makedirs(target_directory, exist_ok=True)

    if filename is None:
        filename = url.split('/')[-1]  # Extract filename from URL

    save_path = os.path.join(target_directory, filename)
    partial_path = save_path + ".part"

    try:
        # Using the response as a context manager releases the streamed connection.
        with requests.get(url, stream=True, timeout=timeout) as response:
            response.raise_for_status()  # Raise an exception for bad status codes (4xx or 5xx)
            with open(partial_path, 'wb') as f:
                for chunk in response.iter_content(chunk_size=8192):
                    f.write(chunk)
        # Only a complete download is moved to the final name.
        os.replace(partial_path, save_path)
    except requests.exceptions.RequestException as e:
        print(f"Error downloading file: {e}")
        return None
    finally:
        # A leftover partial file exists only if the transfer failed; discard it.
        if os.path.exists(partial_path):
            os.remove(partial_path)

    print(f"File downloaded successfully to: {save_path}")
    return save_path

# Download the SAM ViT-H checkpoint once; later runs reuse the local copy.
sam_model_url = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
download_folder = "C:/Repos/GL-Material-Classification/Notebooks"  # machine-specific: change for your setup
output_filename = sam_model_url.rsplit('/', 1)[-1]
sam_model_path = "/".join([download_folder, output_filename])
if not os.path.exists(sam_model_path):
    download_file(sam_model_url, download_folder, output_filename)

In [None]:
# Load SAM model
# Fall back to the CPU when no CUDA device is available instead of crashing.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# "vit_h" is the architecture matching the sam_vit_h_4b8939.pth checkpoint
# (the registry's "default" entry points to the same builder).
sam = sam_model_registry["vit_h"](checkpoint=sam_model_path)

print(torch.__version__)

sam.to(device=device)
print(str(next(sam.parameters()).device))

mask_generator = SamAutomaticMaskGenerator(sam)

# Define Variables

### File locations

In [None]:
# Batch processing: photos to segment are read from batch_input_dir and all
# results are written beneath batch_output_dir.
batch_input_dir = 'C:/tmp_img_process/batch-in'
batch_output_dir = 'C:/tmp_img_process/batch-out'

# Single-image test: the input photo and where each kind of output goes.
input_image_path = 'C:/tmp_img_process/test.jpg'
overlay_images_output_dir = 'C:/tmp_img_process/overlays'  # one outlined copy of the photo per mask
output_image_path = 'C:/tmp_img_process/composite_image.jpg'  # all masks filled with random colors
test_output_masks_dir = 'C:/tmp_img_process/masks'  # one binary PNG per mask
overlay_image_path = 'C:/tmp_img_process/overlay_image.jpg'  # all mask outlines drawn on the photo


### Settings

In [None]:

# Outline width, in pixels, for every mask boundary that gets drawn.
line_thickness = 3
# Number of 3x3 erosion passes applied before tracing a mask (about half the line
# width), so outlines sit inside their region and neighbouring outlines overlap less.
erosion_iterations = line_thickness // 2
# Masks covering fewer pixels than this are discarded.
min_mask_area = 10000

# Definitions

In [None]:
def cull_small_masks(masks, min_area=None):
  """Return only the masks whose 'area' is at least the minimum area.

  Args:
    masks: iterable of mask dicts, each carrying an 'area' entry.
    min_area: pixel threshold; when None, the notebook-level `min_mask_area`
      setting is used (the original behaviour).

  Returns:
    A new list with the surviving masks, in their original order.
  """
  if min_area is None:
    min_area = min_mask_area
  return [mask for mask in masks if mask["area"] >= min_area]
def recalculate_mask_areas(masks):
  """Refresh, in place, each mask's 'area' from its current 'segmentation' array."""
  for entry in masks:
    segmentation = entry['segmentation']
    entry["area"] = np.sum(segmentation)

def extract_masks_only(masks):
  """Return a list holding just the 'segmentation' value of each mask, in order."""
  return [entry['segmentation'] for entry in masks]
def order_masks_by_size_descending(masks):
  """Sort the list in place so the largest 'area' comes first (ties keep their order)."""
  masks.sort(key=lambda entry: entry['area'], reverse=True)

def get_mask_area(mask):
  """Return the pixel area stored under the mask dict's 'area' key."""
  area = mask['area']
  return area

def update_mask_area(mask):
  """Recompute the mask's 'area' as the sum of its 'segmentation' array (in place)."""
  current_segmentation = mask['segmentation']
  mask['area'] = np.sum(current_segmentation)

def get_binary_mask(mask):
  """Return the mask dict's 'segmentation' entry (the per-pixel mask array)."""
  segmentation = mask['segmentation']
  return segmentation

def set_binary_mask(mask, binary_mask):
  """Store `binary_mask` as the mask dict's 'segmentation' entry (mutates `mask`)."""
  mask.update(segmentation=binary_mask)

def handle_mask_intersections_by_subtraction(masks):
  """Make masks non-overlapping by removing from each mask every pixel of the masks after it.

  Each mask keeps only the pixels that no LATER mask in the list covers. When the
  list is sorted by area descending, each pixel therefore ends up in the last
  (smallest) mask that originally covered it. The mask dicts are modified in place:
  'segmentation' is replaced only when pixels are actually removed, and 'area' is
  recomputed for every mask. Any other keys in the dicts are left untouched.

  NOTE(review): the original author flagged that this approach "does not produce the
  desired result and needs to be refined"; the subtraction rule itself is unchanged.

  Args:
    masks: iterable of mask dicts with a per-pixel 'segmentation' array.

  Returns:
    A new list holding the same (modified) dicts, in the input order.
  """
  masks_list = list(masks)
  # Running union of the ORIGINAL segmentations of every mask after the current one.
  # Walking the list backwards needs one subtraction per mask, i.e. O(n) full-image
  # operations instead of comparing every pair of masks (O(n^2)).
  covered_later = None
  for mask in reversed(masks_list):
    original = np.asarray(mask['segmentation'], dtype=bool)
    if covered_later is None:
      covered_later = original.copy()
    else:
      # Only replace the array when something is actually removed.
      if np.any(original & covered_later):
        mask['segmentation'] = original & ~covered_later
      covered_later |= original
    mask['area'] = np.sum(mask['segmentation'])
  return masks_list

def segment_image(input_image_path):
  """Run the global SAM `mask_generator` on one image file.

  Args:
    input_image_path: path of the image to segment.

  Returns:
    (masks, height, width): the list returned by `mask_generator.generate` for the
    RGB version of the image, and the image's pixel dimensions.

  Raises:
    FileNotFoundError: if OpenCV cannot read the file.
  """
  image_bgr = cv2.imread(input_image_path)
  if image_bgr is None:
    # cv2.imread reports failure by returning None rather than raising.
    raise FileNotFoundError(f"Could not read image: {input_image_path}")
  # OpenCV loads BGR; the mask generator is given RGB.
  image = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
  height, width = image.shape[:2]

  masks = mask_generator.generate(image)
  return masks, height, width

def generate_composite_image(masks, height, width, output_image_path):
  """Paint every mask in a random colour onto a black canvas, save it, and return it.

  Args:
    masks: mask dicts whose 'segmentation' arrays are (height, width) boolean masks.
    height, width: canvas size in pixels.
    output_image_path: file the composite is written to.

  Returns:
    The RGB composite as a (height, width, 3) uint8 array.

  Raises:
    OSError: if OpenCV could not write the file. cv2.imwrite reports this by
      returning False, for example when the target directory does not exist.
  """
  composite_image = np.zeros((height, width, 3), dtype=np.uint8)

  # randint's upper bound is exclusive: 256 lets each channel reach 255.
  colors = np.random.randint(0, 256, (len(masks), 3), dtype=np.uint8)

  for mask, color in zip(masks, colors):
    # Later masks in the list paint over earlier ones where they overlap.
    composite_image[mask['segmentation']] = color

  # OpenCV writes BGR, so convert from RGB first.
  written = cv2.imwrite(output_image_path, cv2.cvtColor(composite_image, cv2.COLOR_RGB2BGR))
  if not written:
    raise OSError(f"Could not write composite image to {output_image_path}")

  return composite_image


def copy_input_img(input_image_path):
  """Read an image file and return it as an RGB array together with its size.

  Returns:
    (image_rgb, height, width), where image_rgb is a fresh array the caller may draw on.

  Raises:
    FileNotFoundError: if OpenCV cannot read the file (cv2.imread returns None
      instead of raising).
  """
  image = cv2.imread(input_image_path)
  if image is None:
    raise FileNotFoundError(f"Could not read image: {input_image_path}")
  image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
  height, width, _ = image_rgb.shape
  return image_rgb.copy(), height, width

def draw_single_boundary(boundaries_image, width, height, binary_mask, color, line_thickness, erosion_iterations):
  """Draw the outer contour(s) of one mask onto `boundaries_image`, in place.

  The mask is resized (nearest neighbour) to (height, width) if needed, then eroded
  with a 3x3 kernel `erosion_iterations` times before its external contours are
  traced. This pulls the outline inside the region so neighbouring outlines clip
  each other less.

  Args:
    boundaries_image: image array that cv2.drawContours draws on (modified in place).
    width, height: expected mask size in pixels.
    binary_mask: per-pixel mask for one region.
    color: drawing colour passed straight to cv2.drawContours.
    line_thickness: contour line width in pixels.
    erosion_iterations: number of erosion passes applied before tracing.
  """
  mask_u8 = binary_mask.astype(np.uint8)
  if mask_u8.shape != (height, width):
    # cv2.resize takes (width, height); nearest-neighbour keeps the values binary.
    mask_u8 = cv2.resize(mask_u8, (width, height), interpolation=cv2.INTER_NEAREST)

  kernel = np.ones((3, 3), np.uint8)
  shrunk = cv2.erode(mask_u8, kernel, iterations=erosion_iterations)

  outer_contours, _hierarchy = cv2.findContours(shrunk, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
  cv2.drawContours(boundaries_image, outer_contours, -1, color, line_thickness)

def write_image(output_image_path, image):
  """Save an RGB image array to disk, converting to OpenCV's BGR order first.

  Raises:
    OSError: if OpenCV could not write the file. cv2.imwrite reports this by
      returning False instead of raising, for example when the target directory
      does not exist.
  """
  written = cv2.imwrite(output_image_path, cv2.cvtColor(image, cv2.COLOR_RGB2BGR))
  if not written:
    raise OSError(f"Could not write image to {output_image_path}")

def overlay_boundaries_composite(masks, input_image_path, output_image_path, line_thickness=2, erosion_iterations=1):
  """Outline every mask, each in a random colour, on a copy of the input photo and save it.

  Args:
    masks: mask dicts carrying a 'segmentation' array.
    input_image_path: photo the outlines are drawn on.
    output_image_path: file the overlay is written to.
    line_thickness: outline width in pixels.
    erosion_iterations: erosion passes applied to each mask before tracing it.

  Returns:
    The RGB overlay image array that was written.
  """
  boundaries_image, height, width = copy_input_img(input_image_path)

  # randint's upper bound is exclusive: 256 lets each channel reach 255.
  colors = np.random.randint(0, 256, (len(masks), 3), dtype=np.uint8)

  for mask_data, color in zip(masks, colors):
    # tolist() turns the uint8 triple into plain Python ints for OpenCV's colour argument.
    draw_single_boundary(boundaries_image, width, height, get_binary_mask(mask_data), color.tolist(), line_thickness, erosion_iterations)

  write_image(output_image_path, boundaries_image)

  return boundaries_image

def overlay_boundaries_single(masks, input_image_path, output_images_dir, line_thickness=2, erosion_iterations=1):
  """Write one copy of the input photo per mask, each showing only that mask's outline in red.

  Files are named 000.jpg, 001.jpg, ... in the order of `masks`. The output
  directory is created if needed.

  Args:
    masks: mask dicts carrying a 'segmentation' array.
    input_image_path: photo the outlines are drawn on.
    output_images_dir: directory receiving the per-mask images.
    line_thickness: outline width in pixels.
    erosion_iterations: erosion passes applied to each mask before tracing it.
  """
  # cv2.imwrite does not create missing folders; it would silently write nothing.
  os.makedirs(output_images_dir, exist_ok=True)
  if not masks:
    return

  color = [255, 0, 0]  # red, since the images are kept in RGB order until write_image
  # Read and decode the photo once; every mask gets its own clean copy to draw on.
  base_image, height, width = copy_input_img(input_image_path)
  for i, mask_data in enumerate(masks):
    boundaries_image = base_image.copy()
    draw_single_boundary(boundaries_image, width, height, get_binary_mask(mask_data), color, line_thickness, erosion_iterations)
    write_image(os.path.join(output_images_dir, f"{i:03d}.jpg"), boundaries_image)

def save_masks_single(masks, output_images_dir):
  """Save each mask as an 8-bit grayscale PNG (255 inside the mask, 0 outside).

  Files are named 000.png, 001.png, ... in the order of `masks`. The output
  directory is created if needed.
  """
  os.makedirs(output_images_dir, exist_ok=True)
  for i, mask_data in enumerate(masks):
    # Truthy mask pixels -> 255, everything else -> 0, as 8-bit values.
    img = np.where(get_binary_mask(mask_data), 255, 0).astype(np.uint8)
    # A 2-D uint8 array is stored as mode 'L' (grayscale). Passing mode= explicitly
    # is redundant and deprecated in recent Pillow releases.
    Image.fromarray(img).save(os.path.join(output_images_dir, f"{i:03d}.png"))

def single_image_process(input_image_path, output_dir=None):
  """Segment one photo and write all of its outputs.

  Pipeline: segment, sort by area (descending), cull small masks, subtract
  overlaps, recompute areas, cull again. It then writes:
    <output_dir>/<name>_segment.jpg    all masks filled with random colours
    <output_dir>/<name>_overlay.jpg    all mask outlines on the photo
    <output_dir>/<name>/outlines/NNN.jpg  one outline image per mask
    <output_dir>/<name>/masks/NNN.png     one binary mask per mask

  Args:
    input_image_path: photo to process.
    output_dir: root output directory; defaults to the notebook's `batch_output_dir`.
  """
  if output_dir is None:
    output_dir = batch_output_dir
  # splitext only strips the final extension, so "scan.v1.jpg" and "scan.v2.jpg"
  # keep distinct names instead of both becoming "scan" and overwriting each other.
  name = os.path.splitext(os.path.basename(input_image_path))[0]

  masks, height, width = segment_image(input_image_path)
  order_masks_by_size_descending(masks)
  masks = cull_small_masks(masks)
  masks = handle_mask_intersections_by_subtraction(masks)
  recalculate_mask_areas(masks)
  # Subtraction can shrink masks below the threshold, so cull a second time.
  masks = cull_small_masks(masks)

  boundaries_dir = os.path.join(output_dir, name, "outlines")
  masks_dir = os.path.join(output_dir, name, "masks")
  # Create every directory before writing anything: cv2.imwrite does not create
  # missing folders and would silently skip the composite/overlay images.
  os.makedirs(boundaries_dir, exist_ok=True)
  os.makedirs(masks_dir, exist_ok=True)

  generate_composite_image(masks, height, width, os.path.join(output_dir, f"{name}_segment.jpg"))
  overlay_boundaries_composite(masks, input_image_path, os.path.join(output_dir, f"{name}_overlay.jpg"), line_thickness, erosion_iterations)
  overlay_boundaries_single(masks, input_image_path, boundaries_dir, line_thickness, erosion_iterations)
  save_masks_single(masks, masks_dir)

def plot_masks(masks, column_ct):
  """Display every mask's binary segmentation in a grid `column_ct` columns wide."""
  binary_masks = [get_binary_mask(mask_data) for mask_data in masks]
  row_ct = math.ceil(len(masks) / column_ct)
  sv.plot_images_grid(
    images=binary_masks,
    grid_size=(row_ct, column_ct),
    size=(16, 16),
  )


# Single Image Test

In [None]:
# Run the full pipeline on the single test image. The "prelim" results (size-culled
# only) are kept in memory for display below. Their files on disk are overwritten by
# the "final" results (after overlap subtraction and a second cull), which use the
# same output paths.
masks, height, width = segment_image(input_image_path)
print(f"number of initial masks: {len(masks)}")
if masks:
    print(masks[0]['area'])
    print(masks[-1]['area'])
order_masks_by_size_descending(masks)
if masks:
    # After sorting, the first area should be the largest and the last the smallest.
    print(masks[0]['area'])
    print(masks[-1]['area'])
masks = cull_small_masks(masks)
print(f"number of masks after size cull 1: {len(masks)}")
composite_image_prelim = generate_composite_image(masks, height, width, output_image_path)
overlay_image_prelim = overlay_boundaries_composite(masks, input_image_path, overlay_image_path, line_thickness, erosion_iterations)
if masks:
    print(masks[0])
print(f"number of big masks: {len(masks)}")
masks = handle_mask_intersections_by_subtraction(masks)
print(f"number of masks after subtraction: {len(masks)}")
recalculate_mask_areas(masks)
masks = cull_small_masks(masks)
print(f"number of masks after size cull 2: {len(masks)}")
print(f"number of final masks: {len(masks)}")
composite_image_final = generate_composite_image(masks, height, width, output_image_path)
overlay_image_final = overlay_boundaries_composite(masks, input_image_path, overlay_image_path, line_thickness, erosion_iterations)
# cv2.imwrite does not create missing folders, so both per-mask output dirs must exist.
os.makedirs(overlay_images_output_dir, exist_ok=True)
overlay_boundaries_single(masks, input_image_path, overlay_images_output_dir, line_thickness, erosion_iterations)
os.makedirs(test_output_masks_dir, exist_ok=True)
save_masks_single(masks, test_output_masks_dir)

In [None]:


# Display the preliminary boundary overlay (size-culled masks, before overlap subtraction)
plt.imshow(overlay_image_prelim)
plt.axis('off')
plt.show()

In [None]:
# Display the preliminary composite image (size-culled masks filled with random colors, before overlap subtraction)
plt.imshow(composite_image_prelim)
plt.axis('off')
plt.show()


In [None]:
# Display the final boundary overlay (after overlap subtraction and the second size cull)
plt.imshow(overlay_image_final)
plt.axis('off')
plt.show()

In [None]:
# Display the final composite image (after overlap subtraction and the second size cull)
plt.imshow(composite_image_final)
plt.axis('off')
plt.show()


In [None]:
# Show the binary segmentation of every final mask from the single-image test, 4 per row.
plot_masks(masks, 4)

# Batch Process

In [None]:
# Segment every JPEG in batch_input_dir. The extension match is case-insensitive, so
# camera-style ".JPG" and ".jpeg" files are included. Files are sorted so the
# processing order and progress messages are reproducible.
files = sorted(
    os.path.join(batch_input_dir, file)
    for file in os.listdir(batch_input_dir)
    if file.lower().endswith((".jpg", ".jpeg"))
)
c = len(files)
for i, file in enumerate(files, start=1):
    single_image_process(file)
    print(f"completed image {i} of {c}")