In [None]:
!pip install git+https://github.com/facebookresearch/segment-anything.git
!pip install torch torchvision
!pip install opencv-python pycocotools matplotlib

Collecting git+https://github.com/facebookresearch/segment-anything.git
  Cloning https://github.com/facebookresearch/segment-anything.git to /tmp/pip-req-build-2txrop7t
  Running command git clone --filter=blob:none --quiet https://github.com/facebookresearch/segment-anything.git /tmp/pip-req-build-2txrop7t
  Resolved https://github.com/facebookresearch/segment-anything.git to commit dca509fe793f601edb92606367a655c15ac00fdf
  Preparing metadata (setup.py) ... done
Building wheels for collected packages: segment_anything
  Building wheel for segment_anything (setup.py) ... done
  Created wheel for segment_anything: filename=segment_anything-1.0-py3-none-any.whl size=36592 sha256=85a3a39e28e437aea2c0215838844e94deac836165262da0f3a7a00eae4c2837
  Stored in directory: /tmp/pip-ephem-wheel-cache-i34e182g/wheels/10/cf/59/9ccb2f0a1bcc81d4fbd0e501680b5d088d690c6cfbc02dc99d
Successfully built segment_anything
Installing collected packages: segment_anything
  Attempting 

In [None]:
!wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth

--2024-10-31 13:38:45--  https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth
Resolving dl.fbaipublicfiles.com (dl.fbaipublicfiles.com)... 3.163.189.108, 3.163.189.96, 3.163.189.51, ...
Connecting to dl.fbaipublicfiles.com (dl.fbaipublicfiles.com)|3.163.189.108|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1249524607 (1.2G) [binary/octet-stream]
Saving to: ‘sam_vit_l_0b3195.pth’


2024-10-31 13:38:58 (89.9 MB/s) - ‘sam_vit_l_0b3195.pth’ saved [1249524607/1249524607]



In [None]:
import os

import cv2
import ipywidgets as widgets
import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision
from IPython.display import clear_output, display
from PIL import Image
from segment_anything import SamPredictor, sam_model_registry

In [None]:
# Model configuration. The registry key has to name the same architecture as
# the checkpoint file; the previous cell downloads the ViT-L weights into the
# working directory, so both settings point at ViT-L.
SAM_MODEL_TYPE = "vit_l"
SAM_CHECKPOINT = "sam_vit_l_0b3195.pth"

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
sam = sam_model_registry[SAM_MODEL_TYPE](checkpoint=SAM_CHECKPOINT)
sam.to(device)

# Shared predictor used by ImageAnnotator for every click.
predictor = SamPredictor(sam)

In [19]:
class ImageAnnotator:
    """Interactive click-to-segment annotator that exports YOLO polygon labels.

    A click on the left-hand axes prompts the module-level ``predictor``
    (a ``SamPredictor``) with that single point and stores the returned mask
    under the currently selected class. Three buttons toggle the class,
    clear the session and export the collected masks.

    Mouse clicks only reach ``on_click`` when Matplotlib is running with an
    interactive backend (for example ``%matplotlib widget``).
    """

    def __init__(self):
        # Clicked (x, y) positions in image pixel coordinates, one per mask.
        self.points = []
        # Class id (index into ``class_names``) of each clicked point.
        self.labels = []
        # Class assigned to the next click; starts on 'puce'.
        self.current_class = 1
        # List of (boolean HxW mask, class id) tuples, one per click.
        self.masks = []
        # RGB image and its path; both are set by ``setup_interface``.
        self.image = None
        self.image_path = None
        self.class_names = ['void', 'puce']
        self.mask_colors = {
            0: [0, 0, 1],    # Blue for void
            1: [1, 0, 0]     # Red for puce
        }

    def setup_interface(self, image_path):
        """Load ``image_path``, hand it to the predictor and build the UI.

        Any points and masks from a previously loaded image are discarded,
        because they refer to a different image.

        Args:
            image_path: Path of the image file to annotate.

        Raises:
            OSError: If OpenCV cannot read the image.
        """
        image_bgr = cv2.imread(image_path)
        if image_bgr is None:
            raise OSError(f"Failed to load image from {image_path}")

        # OpenCV loads BGR; the display and the predictor call below use RGB.
        self.image = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
        # Remembered so that save_masks knows where to write its output.
        self.image_path = image_path

        # Start from a clean session for the new image.
        self.points = []
        self.labels = []
        self.masks = []

        # Set image for predictor
        predictor.set_image(self.image)

        # One figure with two side-by-side panels.
        plt.close('all')
        self.fig = plt.figure(figsize=(18, 8))
        self.ax1 = self.fig.add_subplot(121)  # Original image with points
        self.ax2 = self.fig.add_subplot(122)  # Mask visualization

        self.plot_images()

        # Control buttons
        button_layout = widgets.Layout(width='150px')
        self.class_button = widgets.Button(
            description=f'Current Class: {self.class_names[self.current_class]}',
            button_style='info',
            layout=button_layout
        )
        self.class_button.on_click(self.toggle_class)

        self.clear_button = widgets.Button(
            description='Clear Points',
            button_style='warning',
            layout=button_layout
        )
        self.clear_button.on_click(self.clear_points)

        self.save_button = widgets.Button(
            description='Save Masks',
            button_style='success',
            layout=button_layout
        )
        self.save_button.on_click(self.save_masks)

        display(widgets.HBox([self.class_button, self.clear_button, self.save_button]))

        # Connect mouse event
        self.fig.canvas.mpl_connect('button_press_event', self.on_click)

    def toggle_class(self, _):
        """Switch the active class between 'void' (0) and 'puce' (1)."""
        self.current_class = 1 - self.current_class
        self.class_button.description = f'Current Class: {self.class_names[self.current_class]}'

    def clear_points(self, _):
        """Forget every point and mask collected so far and redraw."""
        self.points = []
        self.labels = []
        self.masks = []
        self.plot_images()

    def on_click(self, event):
        """Segment the object under a click on the left image.

        The clicked point is sent to SAM on its own as a foreground prompt.
        The annotation class is kept separately from SAM's point label: SAM
        interprets a point label of 0 as "background", so forwarding the
        class id would turn every 'void' click into a negative prompt.
        """
        if event.inaxes is not self.ax1:  # Only respond to clicks on the left image
            return
        if event.xdata is None or event.ydata is None:
            return

        point = [event.xdata, event.ydata]
        class_id = self.current_class

        masks, _scores, _logits = predictor.predict(
            point_coords=np.array([point]),
            point_labels=np.array([1]),  # 1 = foreground prompt for SAM
            multimask_output=False
        )

        # Record the click only after prediction succeeded so that points,
        # labels and masks always stay the same length.
        self.points.append(point)
        self.labels.append(class_id)
        self.masks.append((masks[0], class_id))
        self.plot_images()

    def plot_images(self):
        """Redraw the clicked points (left) and the blended masks (right)."""
        self.ax1.clear()
        self.ax2.clear()

        # Left panel: original image with one marker per click.
        self.ax1.imshow(self.image)
        for point, label in zip(self.points, self.labels):
            self.ax1.plot(point[0], point[1], 'o',
                          color=self.mask_colors[label], markersize=10)
        self.ax1.set_title('Click to add points (Red: Puce, Blue: Void)')

        # Accumulate one colour layer per mask.
        mask_overlay = np.zeros_like(self.image, dtype=float)
        for mask, label in self.masks:
            mask_overlay[mask] += self.mask_colors[label]

        # Overlapping masks can push a channel above 1; clamp it back.
        mask_overlay = np.clip(mask_overlay, 0, 1)

        # Right panel: 50/50 blend of image and overlay on masked pixels only.
        combined_img = self.image / 255.0
        mask_pixels = mask_overlay.sum(axis=-1) > 0
        combined_img[mask_pixels] = combined_img[mask_pixels] * 0.5 + mask_overlay[mask_pixels] * 0.5

        self.ax2.imshow(combined_img)
        self.ax2.set_title('Segmentation Masks')

        # Redraw this annotator's own figure rather than whichever is current.
        self.fig.canvas.draw_idle()

    @staticmethod
    def _contour_to_yolo_line(label, contour, width, height):
        """Format one contour as a YOLO segmentation row, or None if degenerate.

        Args:
            label: Class index written at the start of the row.
            contour: Array of pixel coordinates reshaped here to (N, 2).
            width: Image width used to normalise x values.
            height: Image height used to normalise y values.

        Returns:
            ``"<label> x1 y1 x2 y2 ..."`` with coordinates divided by the
            image size, or ``None`` when the contour has fewer than three
            points and therefore cannot describe a polygon.
        """
        # reshape (not squeeze) keeps a (N, 2) layout even when N == 1.
        polygon = np.asarray(contour).reshape(-1, 2)
        if len(polygon) < 3:
            return None
        normalized_polygon = polygon / [width, height]
        coords_str = ' '.join(map(str, normalized_polygon.flatten().tolist()))
        return f"{label} {coords_str}"

    def save_masks(self, _):
        """Export the collected masks next to the source image.

        Writes ``labels/<image name>.txt`` with one YOLO polygon row per
        external contour of each mask, and ``masks/<image name>_void.png`` /
        ``masks/<image name>_puce.png`` holding the union of the masks of
        each class. Both folders are created beside the image if missing.
        """
        if not self.masks:
            print("No masks to save!")
            return

        # Prepare directories
        base_dir = os.path.dirname(self.image_path)
        labels_dir = os.path.join(base_dir, 'labels')
        masks_dir = os.path.join(base_dir, 'masks')
        os.makedirs(labels_dir, exist_ok=True)
        os.makedirs(masks_dir, exist_ok=True)

        # Get image dimensions
        height, width = self.image.shape[:2]

        # Per-class union of every mask drawn so far.
        void_mask = np.zeros((height, width), dtype=bool)
        puce_mask = np.zeros((height, width), dtype=bool)

        # Base filename
        base_name = os.path.splitext(os.path.basename(self.image_path))[0]
        label_file_path = os.path.join(labels_dir, f'{base_name}.txt')

        # One row per contour: class_index x1 y1 x2 y2 ...
        with open(label_file_path, 'w') as f:
            for mask, label in self.masks:
                if label == 0:  # void
                    void_mask |= mask
                else:  # puce
                    puce_mask |= mask

                contours, _hierarchy = cv2.findContours(
                    mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
                )
                for contour in contours:
                    line = self._contour_to_yolo_line(label, contour, width, height)
                    if line is not None:
                        f.write(line + "\n")

        # Save the per-class masks as 0/255 images.
        cv2.imwrite(os.path.join(masks_dir, f'{base_name}_void.png'), void_mask.astype(np.uint8) * 255)
        cv2.imwrite(os.path.join(masks_dir, f'{base_name}_puce.png'), puce_mask.astype(np.uint8) * 255)

        print(f"YOLO-format annotations saved in {label_file_path}")
        print(f"Masks saved in {masks_dir}")

In [None]:
# Build the annotator, then open the interactive view for one image
annotator = ImageAnnotator()
annotator.setup_interface(
    image_path='/content/drive/MyDrive/Void_detection_on_X_ray/valid/025_JPG.rf.b2cdc2d984adff593dc985f555b8d280.jpg'
)