In [1]:
import matplotlib.pyplot as plt
import numpy as np
import torch
import sys
import glob
import os
import cv2
import time

from PIL import Image

In [2]:
# Pick the compute device: CUDA if available, otherwise MPS, otherwise the CPU.
device = torch.device(
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)
print(f"using device: {device}")

using device: cuda


In [9]:
# Plotting helpers for segment-anything-2 (SAM 2): masks, prompt points and boxes
def show_mask(mask, ax, random_color=False, borders = True):
    """Overlay one mask on a matplotlib axes as a translucent RGBA layer.

    Args:
        mask: array whose last two dimensions are (height, width); it is cast
            to uint8, so nonzero pixels are treated as "inside".
        ax: matplotlib axes; the overlay is drawn with ``ax.imshow``.
        random_color: use a random RGB colour instead of the default blue.
            The overlay alpha is 0.6 either way.
        borders: also draw the mask's outer contours, lightly simplified with
            ``cv2.approxPolyDP``, in white on top of the fill.
    """
    alpha = 0.6
    if random_color:
        rgba = np.concatenate([np.random.random(3), [alpha]])
    else:
        rgba = np.array([30 / 255, 144 / 255, 255 / 255, alpha])

    height, width = mask.shape[-2:]
    mask_u8 = mask.astype(np.uint8)
    # (H, W, 1) * (1, 1, 4) broadcasts to an (H, W, 4) float RGBA image.
    overlay = mask_u8.reshape(height, width, 1) * rgba.reshape(1, 1, -1)

    if borders:
        import cv2
        raw_contours, _hierarchy = cv2.findContours(
            mask_u8, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE
        )
        # Simplify each outline slightly so the drawn border looks smoother.
        smoothed = [
            cv2.approxPolyDP(contour, epsilon=0.01, closed=True)
            for contour in raw_contours
        ]
        overlay = cv2.drawContours(overlay, smoothed, -1, (1, 1, 1, 0.5), thickness=2)

    ax.imshow(overlay)

def show_points(coords, labels, ax, marker_size=100):
    """Scatter prompt points on ``ax``, coloured by label.

    Args:
        coords: array of points with x in column 0 and y in column 1.
        labels: array aligned with ``coords``. Points labelled 1 are drawn in
            green and points labelled 0 in red; any other label is not drawn.
        ax: matplotlib axes to draw on.
        marker_size: marker area passed to ``scatter`` as ``s``.
    """
    common = dict(marker='.', s=marker_size, edgecolor='black', linewidth=0.5)
    # Positive points first, then negative ones, so red is drawn on top.
    for wanted, colour in ((1, 'green'), (0, 'red')):
        chosen = coords[labels == wanted]
        ax.scatter(chosen[:, 0], chosen[:, 1], color=colour, **common)

def show_box(box, ax):
    """Outline the box ``[x0, y0, x1, y1]`` on ``ax`` as an unfilled green rectangle."""
    left, top = box[0], box[1]
    outline = plt.Rectangle(
        (left, top),
        box[2] - left,
        box[3] - top,
        edgecolor='green',
        facecolor=(0, 0, 0, 0),
        lw=2,
    )
    ax.add_patch(outline)

def show_masks(image, masks, scores, point_coords=None, box_coords=None, input_labels=None, borders=True):
    """Display each mask over ``image`` in its own 10x10 figure.

    Args:
        image: background image passed to ``imshow``.
        masks: iterable of masks; each one is drawn with ``show_mask``.
        scores: one score per mask. A "Mask i, Score: s" title is added only
            when more than one score is given.
        point_coords: optional prompt points drawn with ``show_points``;
            requires ``input_labels``.
        box_coords: optional single box drawn with ``show_box``.
        input_labels: labels for ``point_coords``.
        borders: forwarded to ``show_mask`` to draw mask outlines.

    Raises:
        AssertionError: if ``point_coords`` is given without ``input_labels``.
    """
    if point_coords is not None:
        # Check once, before any figure is created, so a bad call leaves no
        # half-drawn figure behind.
        assert input_labels is not None, "input_labels is required when point_coords is given"

    show_title = len(scores) > 1
    for i, (mask, score) in enumerate(zip(masks, scores)):
        fig, ax = plt.subplots(figsize=(10, 10))
        ax.imshow(image)
        show_mask(mask, ax, borders=borders)
        if point_coords is not None:
            show_points(point_coords, input_labels, ax)
        if box_coords is not None:
            show_box(box_coords, ax)
        if show_title:
            ax.set_title(f"Mask {i+1}, Score: {score:.3f}", fontsize=18)
        ax.axis('off')
        plt.show()
        # This is called once per segment per image in batch runs. Depending on
        # the backend, show() may leave the figure open, so close it explicitly
        # to keep figures from accumulating.
        plt.close(fig)

In [4]:
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor

# SAM 2 checkpoint and matching model config (Hiera-Large, per the file names).
# The checkpoint path can be overridden with the SAM2_CHECKPOINT environment
# variable; the default is the original machine-specific location.
sam2_checkpoint = os.environ.get(
    "SAM2_CHECKPOINT",
    "/home/wsl/bin/segment-anything-2/checkpoints/sam2_hiera_large.pt",
)
model_cfg = "sam2_hiera_l.yaml"

# Fail fast with an explicit message instead of an error from deep inside the
# model builder.
if not os.path.isfile(sam2_checkpoint):
    raise FileNotFoundError(f"SAM 2 checkpoint '{sam2_checkpoint}' was not found.")

sam2_model = build_sam2(model_cfg, sam2_checkpoint, device=device)

# Point-prompt predictor exposing set_image() / predict().
predictor = SAM2ImagePredictor(sam2_model)

  OLD_GPU, USE_FLASH_ATTN, MATH_KERNEL_ON = get_sdpa_settings()


In [2]:
# Batch-segment every wing scan in the "Hive" sub-directories of `input_dir`.
# Depends on earlier cells:
#   - `os`, `np`, `Image` (imports)
#   - `predictor` (model cell)
#   - `show_masks` (helper cell)
# The NameError recorded below this cell came from running it on a fresh kernel
# without the import cell; use Restart & Run All.

# Fixed prompt points, as (x, y) pixel coordinates, for each wing region.
# NOTE(review): a single set of points for every scan assumes all scans share
# the same framing (presumably why the inputs are "WingScansUniform") — confirm.
segments = {
    "marginal cell": [[530, 30]],
    "1st submarginal cell": [[450, 45]],
    "2nd submarginal cell": [[500, 75]],
    "3rd submarginal cell": [[575, 75]],
    "2nd medial cell": [[500, 125]],
    "Forewing lobe": [[550, 200], [650, 100]]
}

# Every prompt point, flattened in dict order. For each segment, its own points
# become positive prompts and all other segments' points become negative prompts.
all_points = [coord for coords in segments.values() for coord in coords]

input_dir = "/mnt/c/Projects/Master/Data/WingScansUniform/"
output_dir = "/mnt/c/Projects/Master/Data/WingScansSegmented/"

# Set to False to skip plotting; every image produces one figure per segment.
SHOW_MASKS = True

# Ensure the input directory exists
if not os.path.exists(input_dir):
    raise FileNotFoundError(f"Input directory '{input_dir}' was not found.")

# Create the output directory
os.makedirs(output_dir, exist_ok=True)

# The prompt points and per-segment labels are identical for every image, so
# build them once here instead of inside the per-image loop.
input_points = np.array(all_points)
segment_labels = {}
index_offset = 0
for key, coords in segments.items():
    labels = np.zeros(len(all_points), dtype=int)
    labels[index_offset:index_offset + len(coords)] = 1
    segment_labels[key] = labels
    index_offset += len(coords)

# Only real sub-directories, sorted so the processing order is reproducible.
all_directories = sorted(
    entry for entry in os.listdir(input_dir)
    if os.path.isdir(os.path.join(input_dir, entry))
)

for dirname in all_directories:
    if "Hive" not in dirname:
        print(f"Skipping directory: {dirname}")
        continue

    print(f"Processing directory: {dirname}")
    input_subdir = os.path.join(input_dir, dirname)
    output_subdir = os.path.join(output_dir, dirname)

    # Create the output directory
    os.makedirs(output_subdir, exist_ok=True)

    # Match the extension case-insensitively so ".JPG" scans are not skipped.
    jpg_files = sorted(
        file for file in os.listdir(input_subdir) if file.lower().endswith(".jpg")
    )
    for jpg_file in jpg_files:
        input_file = os.path.join(input_subdir, jpg_file)
        # TODO: nothing is written to output_file yet; masks are only displayed.
        output_file = os.path.join(output_subdir, jpg_file)

        # The context manager releases the file handle once the pixels are copied.
        with Image.open(input_file) as img:
            wing = np.array(img.convert("RGB"))

        # `wing` is the same for every segment. Set it on the predictor once
        # per image (set_image is where SAM 2 computes the image embedding)
        # rather than once per segment.
        predictor.set_image(wing)

        for key, input_labels in segment_labels.items():
            masks, scores, _ = predictor.predict(
                point_coords=input_points,
                point_labels=input_labels,
                multimask_output=False,
            )

            if SHOW_MASKS:
                show_masks(wing, masks, scores, point_coords=input_points, input_labels=input_labels)

NameError: name 'os' is not defined