In [5]:
import torch
import cv2
import numpy as np
import matplotlib.pyplot as plt
from segment_anything import sam_model_registry, SamPredictor
from segment_anything.utils.transforms import ResizeLongestSide

In [6]:
# Pick the best available backend instead of hard-coding 'mps', which only
# exists on Apple-silicon machines with a recent PyTorch build.
if torch.cuda.is_available():
    device = 'cuda'
elif torch.backends.mps.is_available():
    device = 'mps'
else:
    device = 'cpu'

In [2]:
# Load SAM model.
# Model type can be "vit_b", "vit_l", or "vit_h"; the checkpoint file must match it.
model_type = "vit_h"
sam_checkpoint = "../../models/annotation/sam_vit_h_4b8939.pth"
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device)  # backend chosen in the config cell (cuda / mps / cpu)

# Predictor wrapper: set_image() computes the image embedding once, after which
# many prompts can be decoded against it.
predictor = SamPredictor(sam)

In [13]:
def generate_grid_points(image_size, points_per_side):
    """Build a regular grid of (x, y) prompt points covering an image.

    Points are placed at the centres of a ``points_per_side x points_per_side``
    grid of equal cells, mirroring SAM's own ``build_point_grid``. Every point
    therefore lies strictly inside the image; none sits on the border or at
    ``x == width`` / ``y == height``.

    Args:
        image_size: ``(height, width)`` of the image, e.g. ``image.shape[:2]``.
        points_per_side: number of points along each axis (must be >= 1).

    Returns:
        float32 tensor of shape ``(points_per_side**2, 2)`` holding ``(x, y)``
        pixel coordinates, ordered x-major (all y values for the first x, then
        the next x, ...).

    Raises:
        ValueError: if ``points_per_side`` is less than 1.
    """
    if points_per_side < 1:
        raise ValueError(f"points_per_side must be >= 1, got {points_per_side}")

    height, width = image_size
    # Normalised cell centres: 1/(2n), 3/(2n), ..., 1 - 1/(2n)
    centres = (torch.arange(points_per_side, dtype=torch.float32) + 0.5) / points_per_side
    x_points = centres * width
    y_points = centres * height

    # Every (x, y) combination -> (n*n, 2)
    return torch.cartesian_prod(x_points, y_points)

In [14]:
# Load the image (you can replace this with your own image path)
image_path = r'../../data/split/train/images/F1_1_1_2.ts-frames_frame-1635.png'
image_bgr = cv2.imread(image_path)
if image_bgr is None:  # cv2.imread returns None instead of raising on a missing/unreadable file
    raise FileNotFoundError(f"Could not read image: {image_path}")
image = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)

points_per_side = 32    # grid density: points_per_side**2 prompts in total
points_per_batch = 64   # prompts decoded per forward pass; lower it if memory is tight
pred_iou_thresh = 0.88  # drop masks whose predicted IoU (the model's own quality score) is below this

image_size = image.shape[:2]
grid_points = generate_grid_points(image_size, points_per_side)

# Compute the image embedding once; every prompt below reuses it.
predictor.set_image(image)

# SamPredictor.predict() treats all the points it receives as ONE prompt for a
# single object, so passing the whole grid returns a single mask. Instead, send
# each grid point as its own prompt: a batch of B prompts with one point each.
kept_masks, kept_scores = [], []
for start in range(0, grid_points.shape[0], points_per_batch):
    batch_points = grid_points[start:start + points_per_batch]
    # predict_torch expects coordinates already mapped into the model's resized input frame
    coords = predictor.transform.apply_coords_torch(batch_points, predictor.original_size)
    coords = coords.to(predictor.device)[:, None, :]                                   # (B, 1, 2)
    labels = torch.ones(coords.shape[:2], dtype=torch.int, device=predictor.device)  # (B, 1); 1 = foreground
    # A single click is ambiguous, so request 3 candidates and keep the best-scored one.
    batch_masks, batch_iou, _ = predictor.predict_torch(coords, labels, multimask_output=True)
    best_iou, best_idx = batch_iou.max(dim=1)                                            # (B,)
    rows = torch.arange(best_idx.shape[0], device=best_idx.device)
    best_masks = batch_masks[rows, best_idx]                                             # (B, H, W)
    keep = best_iou > pred_iou_thresh
    kept_masks.append(best_masks[keep].cpu())
    kept_scores.append(best_iou[keep].cpu())

# One full-resolution boolean mask per surviving grid point. There is no
# de-duplication, so neighbouring points on the same object give near-identical masks.
masks = torch.cat(kept_masks)    # (K, H, W), on CPU
scores = torch.cat(kept_scores)  # (K,)

KeyboardInterrupt: 

In [ ]:
def visualize_grid_points(grid_points, image=None):
    """Scatter-plot prompt points, optionally on top of the image they belong to.

    Args:
        grid_points: (N, 2) CPU tensor of (x, y) pixel coordinates, such as the
            output of ``generate_grid_points``.
        image: optional image array drawn underneath the points. If omitted, only
            the points are plotted, with the y-axis inverted so the layout matches
            image coordinates (origin at the top-left).
    """
    x_coords = grid_points[:, 0].numpy()
    y_coords = grid_points[:, 1].numpy()

    fig, ax = plt.subplots()
    if image is not None:
        ax.imshow(image)
    ax.scatter(x_coords, y_coords, s=10, c='red', marker='o')  # s is size, c is color, marker is shape
    if image is None:
        # Image rows grow downward; match that so the grid isn't shown upside-down.
        ax.invert_yaxis()
    ax.set(title=f'{len(x_coords)} grid prompt points', xlabel='x (px)', ylabel='y (px)')
    plt.show()

# Call the function to visualize the grid over the loaded image
visualize_grid_points(grid_points.cpu(), image)

In [None]:
# SamPredictor.predict returns NumPy arrays (which have no .cpu()), while
# predict_torch returns tensors; normalise to NumPy before plotting.
masks_np = masks.cpu().numpy() if torch.is_tensor(masks) else np.asarray(masks)

if masks_np.shape[0] > 0:
    fig, ax = plt.subplots()
    ax.imshow(masks_np[0], cmap='gray')
    ax.set_title('Generated Mask')
    ax.set_axis_off()
    plt.show()

print(f"Generated {masks_np.shape[0]} masks")