In [None]:
# This notebook was run on Google Colab, so the dataset paths are set to /content/data/.
# If you are running it on your local machine, change the paths accordingly.
# Start the notebook from the directory you want to use as home: the current
# working directory is captured in HOME and reused by the cells below.
import os
HOME = os.getcwd()
print("HOME:", HOME)

In [None]:
# Print the NVIDIA driver / GPU status for this runtime.
!nvidia-smi

In [None]:
# Install model if you did not do it yet
!git clone https://github.com/facebookresearch/segment-anything-2.git
!pip install -e . -q

In [None]:
# Install the helper packages: supervision is imported further down as `sv`.
# NOTE(review): jupyter_bbox_widget is not imported by any visible cell — confirm it is still needed.
!pip install -q supervision jupyter_bbox_widget

In [None]:
#can install the model weights but I included them in the repo
!mkdir -p {HOME}/checkpoints
!wget -q https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_tiny.pt -P {HOME}/checkpoints
!wget -q https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_small.pt -P {HOME}/checkpoints
!wget -q https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_base_plus.pt -P {HOME}/checkpoints
!wget -q https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_large.pt -P {HOME}/checkpoints

In [None]:
#install roboflow and download the dataset
%cd {HOME}
!pip install roboflow

from roboflow import Roboflow
rf = Roboflow(api_key="your_api_key_here")
project = rf.workspace("roboflow_workspacename").project("projectname")
version = project.version(1)
dataset = version.download("sam2")

In [None]:
# Rename the downloaded dataset folder to "data".
import os

# Guarded so the cell can be re-run: once the folder has been moved, a second
# os.rename would fail because the source directory no longer exists.
if not os.path.exists("/content/data"):
    os.rename("/content/wing_segment-1", "/content/data")

In [None]:
# Change the working directory to the cloned segment-anything-2 checkout.
# NOTE(review): presumably required so that the `sam2` package and its config
# files resolve from the checkout — confirm.
# Then import the necessary libraries.
%cd {HOME}/segment-anything-2
import cv2
import torch
import base64

import numpy as np
import supervision as sv

from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor
from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator

In [None]:
# Enable the CUDA-specific speed-ups only when a CUDA device is present.
# Without this guard torch.cuda.get_device_properties(0) raises on a CPU-only
# machine, even though the model cell below explicitly falls back to the CPU.
if torch.cuda.is_available():
    # The context manager is entered without a matching __exit__, so bfloat16
    # autocast stays active for every following cell.
    torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()

    # Devices with a compute-capability major version of 8 or higher: allow
    # TF32 for matmul and cuDNN operations.
    if torch.cuda.get_device_properties(0).major >= 8:
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True

In [None]:
# Save the config file in /content/segment-anything-2/sam2 before running this cell.
# Build the fine-tuned model.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Placeholder: replace "finetuned_weight path" with the path of your fine-tuned
# checkpoint relative to HOME. The "/" separator was missing, which glued the
# file name directly onto the HOME directory name.
CHECKPOINT = f"{HOME}/finetuned_weight path"
CONFIG = "sam2.1_hiera_b+ (1).yaml"

sam2_model = build_sam2(CONFIG, CHECKPOINT, device=DEVICE, apply_postprocessing=False)

In [None]:
# Build the base model and its mask generator.
checkpoint_base = f"{HOME}/checkpoints/sam2_hiera_large.pt"
model_cfg_base = "sam2_hiera_l.yaml"
# Use the same DEVICE as the fine-tuned model instead of a hard-coded "cuda",
# so both models are placed consistently (including the CPU fallback).
sam2_base = build_sam2(model_cfg_base, checkpoint_base, device=DEVICE)
mask_generator_base = SAM2AutomaticMaskGenerator(sam2_base)

In [None]:
# Switch back to HOME, then wrap the fine-tuned model in an automatic mask
# generator. No extra arguments are passed, so the library defaults apply.
%cd {HOME}
mask_generator = SAM2AutomaticMaskGenerator(sam2_model)

In [None]:
# Load an image from the dataset and segment it with the fine-tuned generator.
import cv2
IMAGE_PATH = f"{HOME}/YOUR_IMAGE_PATH_HERE.jpg"  # Replace with your image path

image_bgr = cv2.imread(IMAGE_PATH)
# cv2.imread does not raise on a missing/unreadable file; it returns None,
# which would otherwise surface as an unrelated AttributeError on `.shape`.
if image_bgr is None:
    raise FileNotFoundError(f"Could not read image: {IMAGE_PATH}")

# Downscale to a quarter of the original width and height
# (cv2.resize expects the target size as (width, height)).
image_bgr = cv2.resize(image_bgr, (image_bgr.shape[1] // 4, image_bgr.shape[0] // 4))
# OpenCV loads images as BGR; the generator is fed the RGB version.
image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)

sam2_result = mask_generator.generate(image_rgb)

In [None]:
# Overlay the masks produced by the fine-tuned model on the source image and
# show the original and the annotated version side by side.
detections = sv.Detections.from_sam(sam_result=sam2_result)
mask_annotator = sv.MaskAnnotator(color_lookup=sv.ColorLookup.INDEX)

# Annotate a copy so that image_bgr itself stays untouched.
annotated_image = mask_annotator.annotate(
    scene=image_bgr.copy(),
    detections=detections,
)

sv.plot_images_grid(
    images=[image_bgr, annotated_image],
    titles=['source image', 'segmented image'],
    grid_size=(1, 2),
)

In [None]:
# Visualize the results of the base and the fine-tuned model side by side
# for every image of the validation set.
import os
import random
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt

# Single definition of the validation folder, so it only has to be replaced once.
VALIDATION_DIR = "Your path to the validation set"  # Replace with your validation set path

validation_set = os.listdir(VALIDATION_DIR)

# Filter for image files (e.g., .jpg, .png); sorted because os.listdir returns
# entries in arbitrary order and the figures should appear in a stable order.
image_files = sorted(img for img in validation_set if img.endswith((".jpg", ".png")))

# The annotator construction does not depend on the image, so build both
# annotators once instead of on every loop iteration.
mask_annotator = sv.MaskAnnotator(color_lookup=sv.ColorLookup.INDEX)
base_annotator = sv.MaskAnnotator(color_lookup=sv.ColorLookup.INDEX)

for image_file in image_files:
    image_path = os.path.join(VALIDATION_DIR, image_file)
    opened_image = np.array(Image.open(image_path).convert("RGB"))

    # Fine-tuned model
    result = mask_generator.generate(opened_image)
    detections = sv.Detections.from_sam(sam_result=result)
    annotated_image = opened_image.copy()
    annotated_image = mask_annotator.annotate(annotated_image, detections=detections)

    # Base model
    base_result = mask_generator_base.generate(opened_image)
    base_detections = sv.Detections.from_sam(sam_result=base_result)
    base_annotated_image = opened_image.copy()
    base_annotated_image = base_annotator.annotate(base_annotated_image, detections=base_detections)

    # Plot the images (they are RGB, which is what imshow expects)
    fig, axes = plt.subplots(1, 2, figsize=(10, 5))  # Adjust figsize as needed
    axes[0].imshow(annotated_image)
    axes[0].set_title("Fine-tuned SAM")
    axes[0].axis("off")
    axes[1].imshow(base_annotated_image)
    axes[1].set_title("Base SAM")
    axes[1].axis("off")
    plt.tight_layout()
    plt.show()
    # Release the figure; otherwise one figure per image stays alive in memory.
    plt.close(fig)

In [None]:
#generate binary masks for visualization
import os
import random
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import supervision as sv

validation_set = os.listdir("/content/data/valid")
image_files = [img for img in validation_set if img.endswith((".jpg", ".png"))]

# Make sure you are using the original mask_generator_2 object
for image_file in image_files:
    image_path = os.path.join("/content/data/valid", image_file)
    opened_image = np.array(Image.open(image_path).convert("RGB"))

    # Use mask_generator_2 to generate results, and store results in a new variable
    result = mask_generator_2.generate(opened_image)

    # Get binary masks
    masks = [
        mask['segmentation']
        for mask in sorted(result, key=lambda x: x['area'], reverse=True)
    ]

    # Plot the binary masks
    sv.plot_images_grid(
        images=masks[:16],  # Adjust the number of masks to display
        grid_size=(4, 4),  # Adjust grid size as needed
        size=(12, 12)     # Adjust figure size as needed
    )
    plt.show()  # Display the plot for each image

In [None]:
# extra parameters for the mask generator
mask_generator_2 = SAM2AutomaticMaskGenerator(
    model=sam2_model,
    points_per_side=64,
    points_per_batch=128,
    pred_iou_thresh=0.7,
    stability_score_thresh=0.92,
    stability_score_offset=0.7,
    crop_n_layers=1,
    box_nms_thresh=0.7,
)

In [None]:
# Run the tuned generator once on the most recently loaded image.
# The output (a list of mask dictionaries) is stored under its own name:
# assigning it to `mask_generator` replaced the generator object with a list,
# so every later `mask_generator.generate(...)` call failed.
sam2_result_2 = mask_generator_2.generate(opened_image)

In [None]:
# Run the mask generator over every validation image and hand the resulting
# masks to save_masks_to_json, which writes them below `output_dir`.
# NOTE(review): save_masks_to_json is not defined in any visible cell —
# confirm it is defined or imported before this cell runs.
validation_set = os.listdir("/content/data/valid")
image_files = [
    name
    for name in validation_set
    if name.endswith(".jpg") or name.endswith(".png")
]
output_dir = "mask_jsons"  # Create a directory to store JSON files
os.makedirs(output_dir, exist_ok=True)

for image_file in image_files:
    image_path = os.path.join("/content/data/valid", image_file)
    opened_image = np.array(Image.open(image_path).convert("RGB"))

    # generate() returns a list of dictionaries, one per mask; only the
    # 'segmentation' entry of each dictionary is exported.
    result = mask_generator.generate(opened_image)
    masks = [entry['segmentation'] for entry in result]

    # Save masks to JSON
    save_masks_to_json(masks, image_path, output_dir)