Author - Muntashir Bin Solaiman
Last modified - 22-03-2025

ModelLoader Class - Responsible for loading and setting up the model.

In [None]:
import torch
from segment_anything import sam_model_registry

class ModelLoader:
    """Load a Segment Anything (SAM) model from a checkpoint onto a device.

    Args:
        model_type: Key into ``sam_model_registry`` (e.g. ``"vit_h"``).
        checkpoint_path: Path to the checkpoint file for ``model_type``.
        device: Preferred torch device string. It is only used when CUDA is
            available; otherwise the model is placed on ``'cpu'``.

    Raises:
        KeyError: If ``model_type`` is not a key of ``sam_model_registry``
            (the message lists the valid keys).
    """

    def __init__(self, model_type, checkpoint_path, device='cuda:0'):
        # Fall back to CPU on machines without CUDA so the notebook still runs.
        self.device = torch.device(device if torch.cuda.is_available() else 'cpu')
        self.model_type = model_type
        self.checkpoint_path = checkpoint_path
        self.sam_model = None
        self._load_model()

    def _load_model(self):
        """Build the model from the checkpoint and move it to ``self.device``."""
        try:
            build_model = sam_model_registry[self.model_type]
        except KeyError:
            # Re-raise the same exception type with an actionable message.
            raise KeyError(
                f"Unknown SAM model_type {self.model_type!r}; "
                f"expected one of {sorted(sam_model_registry)}"
            ) from None
        self.sam_model = build_model(checkpoint=self.checkpoint_path)
        self.sam_model.to(device=self.device)

    @property
    def model(self):
        """The loaded model (alias of ``get_model()``)."""
        return self.sam_model

    def get_model(self):
        """Return the loaded model."""
        return self.sam_model


MaskGenerator Class - Handles the generation of segmentation masks from images.

In [None]:
from segment_anything import SamAutomaticMaskGenerator

class MaskGenerator:
    """Thin wrapper around SAM's automatic (prompt-free) mask generator.

    Args:
        model: A loaded SAM model (for example ``ModelLoader.get_model()``).
        **generator_kwargs: Optional keyword arguments forwarded unchanged to
            ``SamAutomaticMaskGenerator`` (e.g. ``points_per_side``,
            ``pred_iou_thresh``), so the generator can be tuned without
            editing this class. With none given, the library defaults apply.
    """

    def __init__(self, model, **generator_kwargs):
        self.mask_generator = SamAutomaticMaskGenerator(model, **generator_kwargs)

    def generate_masks(self, image):
        """Return the masks SAM generates for ``image``.

        ``image`` is passed straight to ``SamAutomaticMaskGenerator.generate``.
        Callers in this notebook pass an RGB array converted from OpenCV's BGR.
        SAM documents the expected input as an HWC uint8 array.
        """
        return self.mask_generator.generate(image)


ImageProcessor Class - Manages the processing of individual images, including loading, converting formats, and passing them through the mask generator.

In [None]:
import cv2

class ImageProcessor:
    """Load images from disk and run them through a ``MaskGenerator``.

    Args:
        mask_generator: Object exposing ``generate_masks(image_rgb)``.
    """

    def __init__(self, mask_generator):
        self.mask_generator = mask_generator

    def process_image(self, image_path):
        """Load ``image_path`` and generate masks for it.

        Args:
            image_path: Path to the image file (``str`` or path-like).

        Returns:
            ``(image_bgr, masks)``: the image as loaded by OpenCV (BGR channel
            order) and the result of ``mask_generator.generate_masks`` on the
            RGB-converted image.

        Raises:
            OSError: If OpenCV cannot read the file. ``cv2.imread`` returns
                ``None`` instead of raising for missing or undecodable images,
                so this check turns that into a clear, path-specific error.
        """
        image_bgr = cv2.imread(str(image_path))
        if image_bgr is None:
            raise OSError(
                f"cv2.imread could not read image {str(image_path)!r} "
                "(missing file or unsupported format)"
            )
        # OpenCV loads BGR; the mask generator is fed RGB.
        image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
        masks = self.mask_generator.generate_masks(image_rgb)
        return image_bgr, masks


BoundingBoxDrawer Class - Draws bounding boxes around detected regions in images.

In [None]:
import matplotlib.pyplot as plt
import matplotlib.patches as patches

# class BoundingBoxDrawer:
#     @staticmethod
#     def draw_bounding_boxes(image_bgr, anns):
#         """Draws bounding boxes around the detected regions."""
#         fig, ax = plt.subplots(1, figsize=(12, 12))
#         ax.imshow(cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB))  # Show original image

#         # Draw bounding boxes
#         for ann in anns.xyxy:
#             x1, y1, x2, y2 = ann
#             ax.add_patch(patches.Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False, color='red', linewidth=2))

#         plt.show()


In [None]:
class BoundingBoxDrawer:
    """Draw axis-aligned bounding boxes over an image with matplotlib."""

    @staticmethod
    def _boxes_by_area(boxes):
        """Return ``boxes`` (each indexable as x1, y1, x2, y2) sorted largest-area first."""
        return sorted(boxes, key=lambda b: (b[2] - b[0]) * (b[3] - b[1]), reverse=True)

    @staticmethod
    def draw_bounding_boxes(anns, image_bgr, figsize=(12, 12), color='red', linewidth=2):
        """Show ``image_bgr`` with one rectangle per box in ``anns``.

        Args:
            anns: Either an object with an ``xyxy`` attribute (such as
                ``sv.Detections``) or a sequence of ``(x1, y1, x2, y2)`` boxes.
            image_bgr: Image in OpenCV BGR order; converted to RGB for display.
            figsize: Matplotlib figure size.
            color: Rectangle edge colour.
            linewidth: Rectangle edge width.

        Does nothing when there are no boxes.
        """
        # Iterating sv.Detections yields per-detection tuples, not bare boxes,
        # so use its ``xyxy`` array when present; plain box sequences pass through.
        boxes = getattr(anns, "xyxy", anns)
        if len(boxes) == 0:
            return

        _, ax = plt.subplots(1, figsize=figsize)
        ax.imshow(cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB))

        for x1, y1, x2, y2 in BoundingBoxDrawer._boxes_by_area(boxes):
            ax.add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1,
                                       fill=False, color=color, linewidth=linewidth))

        plt.show()


AnnotationDisplay Class - Handles displaying annotations and visualizing the results.

In [None]:
# class AnnotationDisplay:
#     @staticmethod
#     def show_annotations(detections, image_bgr):
#         """Displays the annotations on the image."""
#         if len(detections.xyxy) == 0:
#             return
        
#         areas = (detections.xyxy[:, 2] - detections.xyxy[:, 0]) * (detections.xyxy[:, 3] - detections.xyxy[:, 1])
#         sorted_anns = sorted(zip(detections.xyxy, areas), key=lambda x: x[1], reverse=True)

#         BoundingBoxDrawer.draw_bounding_boxes(image_bgr, sorted_anns)


In [None]:
class AnnotationDisplay:
    """Turn SAM output into detections and display their bounding boxes."""

    def __init__(self):
        # Not used by show_annotations, which only draws bounding boxes.
        self.mask_annotator = sv.MaskAnnotator(color_lookup=sv.ColorLookup.CLASS)

    def show_annotations(self, masks, image_bgr):
        """Draw a bounding box for every detected region on ``image_bgr``.

        Args:
            masks: Either the raw list returned by SAM's mask generator or an
                ``sv.Detections`` already built from it. ``from_sam`` cannot
                re-convert a ``Detections`` object, so one is used as-is.
            image_bgr: The image in OpenCV BGR order.
        """
        if isinstance(masks, sv.Detections):
            detections = masks
        else:
            detections = sv.Detections.from_sam(masks)
        # Give the drawer the (N, 4) xyxy box array rather than the Detections
        # object, whose iteration does not yield plain boxes.
        BoundingBoxDrawer.draw_bounding_boxes(detections.xyxy, image_bgr)


SegmentationPipeline Class - A high-level class that integrates the model loading, mask generation, image processing, and annotation display.

In [None]:
import supervision as sv

class SegmentationPipeline:
    """Run pre-built segmentation components over a list of image paths.

    Note: a second ``class SegmentationPipeline`` statement later in this
    notebook rebinds this name once that cell executes.

    Args:
        model_loader: Loader holding the SAM model (stored, not used by ``run``).
        mask_generator: Mask generator (stored, not used by ``run``).
        image_processor: Object with ``process_image(path) -> (image_bgr, masks)``.
        annotation_display: Object with ``show_annotations(masks, image_bgr)``.
    """

    def __init__(self, model_loader, mask_generator, image_processor, annotation_display):
        self.model_loader = model_loader
        self.mask_generator = mask_generator
        self.image_processor = image_processor
        self.annotation_display = annotation_display

    def run(self, image_paths):
        """Segment each image in order and display its annotations."""
        for image_path in image_paths:
            image_bgr, masks = self.image_processor.process_image(image_path)
            # Pass the raw SAM masks: AnnotationDisplay.show_annotations calls
            # sv.Detections.from_sam itself, and from_sam expects the raw mask
            # list, not an already-built Detections object.
            self.annotation_display.show_annotations(masks, image_bgr)
        print("Segmentation complete for all images.")


In [None]:
class SegmentationPipeline:
    """Segment a batch of images with SAM and display their bounding boxes.

    This is the later of two ``class SegmentationPipeline`` statements in the
    notebook, so it is the definition the name ends up bound to. It therefore
    accepts both construction styles used in the file:

    * Config style (first argument is a ``str``):
      ``SegmentationPipeline(model_type, checkpoint_path, device, image_paths)``
      builds its own ``ModelLoader``, ``MaskGenerator``, ``ImageProcessor`` and
      ``AnnotationDisplay``. Run it with ``run_pipeline()``.
    * Component style (first argument is a pre-built loader):
      ``SegmentationPipeline(model_loader, mask_generator, image_processor,
      annotation_display)``. Run it with ``run(image_paths)``.
    """

    def __init__(self, model_type, checkpoint_path, device='cuda:0', image_paths=None):
        if isinstance(model_type, str):
            self.model_loader = ModelLoader(model_type, checkpoint_path, device)
            self.mask_generator = MaskGenerator(self.model_loader.get_model())
            self.image_processor = ImageProcessor(self.mask_generator)
            self.annotation_display = AnnotationDisplay()
            self.image_paths = [] if image_paths is None else list(image_paths)
        else:
            # Component style: the four positional slots hold pre-built parts.
            self.model_loader = model_type
            self.mask_generator = checkpoint_path
            self.image_processor = device
            self.annotation_display = image_paths
            self.image_paths = []

    def run(self, image_paths):
        """Segment each image in ``image_paths`` and display its annotations."""
        for image_path in image_paths:
            image_bgr, masks = self.image_processor.process_image(image_path)
            # Raw SAM masks: AnnotationDisplay.show_annotations converts them itself.
            self.annotation_display.show_annotations(masks, image_bgr)
        print("Segmentation complete for all images.")

    def run_pipeline(self):
        """Process the image paths given at construction time."""
        self.run(self.image_paths)


DetectionResults Class - Stores and manages the detection results for each image (for example, segmentation masks and bounding boxes), keyed by image path.

In [None]:
class DetectionResults:
    """In-memory store of segmentation results, keyed by image path."""

    def __init__(self):
        # image_path -> the masks object supplied by the caller
        self.results = dict()

    def add_results(self, image_path, masks):
        """Record ``masks`` for ``image_path``, replacing any earlier entry."""
        self.results.update({image_path: masks})

    def get_results(self):
        """Return the live path -> masks mapping (the same dict, not a copy)."""
        stored = self.results
        return stored


Example Usage

In [None]:
if __name__ == "__main__":
    # Images to segment. These are hard-coded absolute paths; adjust them
    # for your environment.
    image_paths = [
        "/data/shared/CSIT_Placement_2025_3D_Reef/CBHE_BA2D_P1/images/frame_00001.JPG",
        "/data/shared/CSIT_Placement_2025_3D_Reef/CBHE_BA2D_P1/images/frame_00002.JPG",
    ]

    # Build the components bottom-up: model -> mask generator -> image processor.
    model_loader = ModelLoader(
        model_type="vit_h",
        checkpoint_path="/home/ad/22021468/sam_checkpoints/sam_vit_h_4b8939.pth",
    )
    mask_generator = MaskGenerator(model=model_loader.get_model())
    image_processor = ImageProcessor(mask_generator=mask_generator)
    annotation_display = AnnotationDisplay()

    # Wire the components together (component-style construction) and run.
    segmentation_pipeline = SegmentationPipeline(
        model_loader, mask_generator, image_processor, annotation_display
    )
    segmentation_pipeline.run(image_paths)
    
