# Import frameworks

In [8]:
import torch
import torchvision
import numpy as np
import matplotlib.pyplot as plt
import cv2
import sys

sys.path.append("..")
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor

# Read the image and load the SAM model

In [9]:
# Load the source photo. The BGR -> RGB conversion below is needed because
# OpenCV decodes images in BGR channel order.
image_path = 'img1.jpg'
image_original = cv2.imread(image_path)
if image_original is None:
    # cv2.imread reports a missing or unreadable file by returning None rather
    # than raising; fail here with a clear message instead of a cryptic
    # cvtColor assertion further down.
    raise FileNotFoundError(f"Could not read image file: {image_path!r}")
image_original = cv2.cvtColor(image_original, cv2.COLOR_BGR2RGB)

# Pre-processing: light Gaussian smoothing, then CLAHE contrast enhancement
# applied independently to each of the three colour channels.
image = cv2.GaussianBlur(image_original, (5, 5), 0)
clahe = cv2.createCLAHE(clipLimit=4, tileGridSize=(16, 16))
image = np.stack([clahe.apply(image[:, :, i]) for i in range(3)], axis=2)
# cv2.resize takes the target size as (width, height), so the result is
# 768 rows x 1024 columns; the aspect ratio of the input is not preserved.
image_resized = cv2.resize(image, (1024, 768))

# SAM checkpoint and the matching model variant key in sam_model_registry.
sam_checkpoint = 'sam_vit_h_4b8939.pth'
model_type = "vit_h"
# Prefer the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)

Using device: cpu


Sam(
  (image_encoder): ImageEncoderViT(
    (patch_embed): PatchEmbed(
      (proj): Conv2d(3, 1280, kernel_size=(16, 16), stride=(16, 16))
    )
    (blocks): ModuleList(
      (0-31): 32 x Block(
        (norm1): LayerNorm((1280,), eps=1e-06, elementwise_affine=True)
        (attn): Attention(
          (qkv): Linear(in_features=1280, out_features=3840, bias=True)
          (proj): Linear(in_features=1280, out_features=1280, bias=True)
        )
        (norm2): LayerNorm((1280,), eps=1e-06, elementwise_affine=True)
        (mlp): MLPBlock(
          (lin1): Linear(in_features=1280, out_features=5120, bias=True)
          (lin2): Linear(in_features=5120, out_features=1280, bias=True)
          (act): GELU(approximate='none')
        )
      )
    )
    (neck): Sequential(
      (0): Conv2d(1280, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (1): LayerNorm2d()
      (2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (3): LayerNorm2d

# Filtering SAM masks

# The function "filter_sinter_stones" filters a list of masks (generated by SAM) and keeps only those that meet specific criteria, so that the resulting masks represent meaningful agglomerate (sinter) stones rather than noise, small fragments, or redundant overlapping regions.

In [10]:
def filter_sinter_stones(masks, area_range=(100, 5000), min_aspect_ratio=0.5, min_iou_with_largest=0.3):
    """Keep only the masks that plausibly correspond to individual sinter stones.

    A mask is kept when all of the following hold:

    * its ``'area'`` lies inside ``area_range`` (inclusive on both ends);
    * the aspect ratio of its bounding box (short side / long side, so a value
      in ``[0, 1]``) is at least ``min_aspect_ratio``;
    * its IoU with every mask of larger area is at most
      ``min_iou_with_largest``.

    Args:
        masks: List of mask dicts, each providing ``'area'`` (number),
            ``'bbox'`` (indexable, width at index 2 and height at index 3)
            and ``'segmentation'`` (boolean array; all masks must share the
            same shape).
        area_range: ``(min_area, max_area)`` bounds for ``'area'``.
        min_aspect_ratio: Minimum allowed short-side / long-side ratio.
        min_iou_with_largest: Despite the name, this is the *maximum* IoU a
            mask may have with any larger mask before it is treated as a
            redundant duplicate and dropped.

    Returns:
        A new list with the surviving mask dicts (the same dict objects as in
        ``masks``), ordered by decreasing area.

    Notes:
        Redundancy is evaluated against *all* larger masks, including ones
        that are themselves rejected by the area or aspect-ratio criteria.
    """
    filtered_masks = []
    if not masks:
        return filtered_masks

    # Largest first, so that sorted_masks[:i] is exactly "all larger masks".
    sorted_masks = sorted(masks, key=lambda x: x['area'], reverse=True)

    for i, mask in enumerate(sorted_masks):
        # Cheap scalar checks first: they let us skip the expensive pixel-wise
        # IoU comparisons for masks that would be rejected anyway.
        if not (area_range[0] <= mask['area'] <= area_range[1]):
            continue

        bbox = mask['bbox']
        width = bbox[2]
        height = bbox[3]
        if width > 0 and height > 0:
            aspect_ratio = min(width / height, height / width)
        else:
            # Degenerate box (e.g. a one-pixel-wide mask reported with zero
            # extent): treat it as maximally elongated instead of dividing
            # by zero.
            aspect_ratio = 0.0
        if aspect_ratio < min_aspect_ratio:
            continue

        current_seg = mask['segmentation']
        is_redundant = False
        for larger_mask in sorted_masks[:i]:
            larger_seg = larger_mask['segmentation']
            intersection = np.logical_and(larger_seg, current_seg).sum()
            union = np.logical_or(larger_seg, current_seg).sum()
            iou = intersection / union if union > 0 else 0
            if iou > min_iou_with_largest:
                is_redundant = True
                break

        if not is_redundant:
            filtered_masks.append(mask)

    return filtered_masks

# The function "merge_overlapping_masks" takes a list of masks and merges masks that overlap significantly into a single mask.

In [11]:
def merge_overlapping_masks(masks, iou_threshold):
    """Merge masks whose IoU exceeds ``iou_threshold`` into single masks.

    The masks are processed in order. Each not-yet-consumed mask becomes a
    "seed"; every later mask whose IoU with the seed's *original* segmentation
    is greater than ``iou_threshold`` is absorbed into it (segmentations are
    OR-ed together and the bounding boxes are replaced by their union). The
    merge is a single pass and therefore not transitive: a mask that only
    overlaps an absorbed mask, but not the seed itself, stays separate.

    Args:
        masks: List of mask dicts with ``'segmentation'`` (boolean arrays of
            identical shape) and ``'bbox'`` in ``[x, y, width, height]`` form.
        iou_threshold: IoU strictly above which two masks are merged.

    Returns:
        A new list of new mask dicts. The input list and the input dicts are
        left untouched. For merged entries ``'segmentation'``, ``'bbox'`` and
        ``'area'`` describe the merged region; all other keys are copied from
        the seed mask.
    """
    if not masks:
        return []

    merged_masks = []
    remaining_masks = list(masks)

    while remaining_masks:
        seed_mask = remaining_masks.pop(0)
        seed_seg = seed_mask['segmentation']
        merged_seg = seed_seg
        merged_bbox = seed_mask['bbox']

        # Collect the indices of all later masks that overlap the seed enough.
        overlapping = []
        for i, other_mask in enumerate(remaining_masks):
            other_seg = other_mask['segmentation']
            intersection = np.logical_and(seed_seg, other_seg).sum()
            union = np.logical_or(seed_seg, other_seg).sum()
            iou = intersection / union if union > 0 else 0

            if iou > iou_threshold:
                overlapping.append(i)

        # Pop from the back so that earlier indices stay valid while removing.
        for idx in sorted(overlapping, reverse=True):
            other_mask = remaining_masks.pop(idx)
            merged_seg = np.logical_or(merged_seg, other_mask['segmentation'])
            other_bbox = other_mask['bbox']
            x_min = min(merged_bbox[0], other_bbox[0])
            y_min = min(merged_bbox[1], other_bbox[1])
            x_max = max(merged_bbox[0] + merged_bbox[2], other_bbox[0] + other_bbox[2])
            y_max = max(merged_bbox[1] + merged_bbox[3], other_bbox[1] + other_bbox[3])
            merged_bbox = [x_min, y_min, x_max - x_min, y_max - y_min]

        # Work on a shallow copy so the caller's mask dicts are not modified.
        merged_mask = dict(seed_mask)
        merged_mask['segmentation'] = merged_seg
        merged_mask['bbox'] = merged_bbox
        if overlapping:
            # The pixel count changed; keep 'area' consistent with the mask.
            merged_mask['area'] = int(np.count_nonzero(merged_seg))
        merged_masks.append(merged_mask)

    return merged_masks

# The function "exclude_text_regions" discards masks that overlap regions of the image containing text.

In [12]:
def exclude_text_regions(masks, image_shape, text_regions):
    """Drop every mask that touches one of the given text rectangles.

    Args:
        masks: Iterable of mask dicts with a 2-D ``'segmentation'`` array.
        image_shape: Shape of the image the rectangles refer to; only the
            first two entries ``(height, width)`` are used.
        text_regions: Iterable of ``(x, y, w, h)`` rectangles in image
            coordinates. They are rescaled to each mask's own resolution
            before testing, so masks need not match the image size.

    Returns:
        A new list with the mask dicts that have no set pixel inside any of
        the rectangles. Rectangles (or parts of them) lying outside the mask
        are ignored.
    """
    height, width = image_shape[:2]
    filtered_masks = []

    for mask in masks:
        m = mask['segmentation']
        mask_height, mask_width = m.shape
        overlaps_text = False

        for x, y, w, h in text_regions:
            # Rescale the rectangle from image coordinates to mask coordinates.
            mask_x, mask_y = int(x * mask_width / width), int(y * mask_height / height)
            mask_w, mask_h = int(w * mask_width / width), int(h * mask_height / height)

            # Clamp BOTH ends of each slice to [0, size]. Clamping only the
            # start would let a rectangle that ends off the top/left edge
            # produce a negative stop index, which NumPy interprets as
            # "counted from the end" and would wrongly select most of the mask.
            row_start = max(0, mask_y)
            row_stop = max(0, min(mask_height, mask_y + mask_h))
            col_start = max(0, mask_x)
            col_stop = max(0, min(mask_width, mask_x + mask_w))

            if np.any(m[row_start:row_stop, col_start:col_stop]):
                overlaps_text = True
                break

        if not overlaps_text:
            filtered_masks.append(mask)

    return filtered_masks

In [13]:
def color_segmentation(masks, base_image):
    """Tint each mask onto a copy of ``base_image`` with a cycling colour.

    Every mask's ``'segmentation'`` is resized (nearest neighbour) to the
    image size, turned into a semi-transparent colour layer and blended onto
    the running result with ``cv2.addWeighted``. Six colours are reused
    cyclically. ``base_image`` itself is not modified.

    Returns:
        The blended image (a copy of ``base_image`` when ``masks`` is empty).
    """
    palette = [[1, 0, 0], [0, 1, 0], [0, 0, 1], [1, 1, 0], [0, 1, 1], [1, 0, 1]]
    canvas = base_image.copy()

    for index, entry in enumerate(masks):
        # cv2.resize expects the target size as (width, height).
        target_size = (canvas.shape[1], canvas.shape[0])
        binary = cv2.resize(
            entry['segmentation'].astype(np.uint8),
            target_size,
            interpolation=cv2.INTER_NEAREST,
        )

        rgb = palette[index % len(palette)]
        tint = np.zeros_like(canvas, dtype=np.float32)
        for channel in range(3):
            tint[:, :, channel] = binary * rgb[channel]

        # Scale the 0/1 layer to a dim 8-bit colour before blending it in.
        tint = np.uint8(tint * 0.3 * 255)
        canvas = cv2.addWeighted(canvas, 1, tint, 0.7, 0)

    return canvas

def draw_bounding_boxes(image, masks, color=(0, 255, 0), thickness=1):
    """Draw each distinct mask bounding box onto a copy of ``image``.

    Boxes are read from ``mask['bbox']`` as ``(x, y, width, height)`` and
    truncated to integers; a box that was already drawn (same four integers)
    is skipped. ``image`` itself is not modified.

    Returns:
        The copy of ``image`` with the rectangles drawn on it.
    """
    canvas = image.copy()
    drawn = set()

    for entry in masks:
        box = entry['bbox']
        key = (int(box[0]), int(box[1]), int(box[2]), int(box[3]))
        if key in drawn:
            continue

        left, top, box_w, box_h = key
        top_left = (left, top)
        bottom_right = (left + box_w, top + box_h)
        cv2.rectangle(canvas, top_left, bottom_right, color, thickness)
        drawn.add(key)

    return canvas

# Mask generation with SAM and parameters

In [16]:
# Build the automatic mask generator around the `sam` model loaded earlier.
# NOTE(review): the meaning of each tuning parameter is defined by the
# segment_anything library, not by this notebook — confirm against the
# installed version before changing values.
mask_generator1_ = SamAutomaticMaskGenerator(
    model=sam,
    points_per_side=64,
    pred_iou_thresh=0.7,
    stability_score_thresh=0.7,
    crop_n_layers=1,
    crop_n_points_downscale_factor=2,
    min_mask_region_area=50,  # Increased from 1 to reduce small masks
)

# Run the generator on the pre-processed, resized image; the result is the
# list of mask records consumed by the filtering helpers in later cells.
masks_original = mask_generator1_.generate(image_resized)
print(f"Number of masks generated: {len(masks_original)}")

Number of masks generated: 662


In [18]:
# Rectangles (x, y, w, h) in image_resized coordinates that contain text
# overlays: one in the top-left corner and one near the bottom-right corner.
text_regions = [
    (0, 0, 300, 50),
    (image_resized.shape[1] - 300, image_resized.shape[0] - 70, 300, 50),
]

# Post-process the raw SAM masks: drop masks touching text, keep stone-like
# masks, then fuse near-duplicates.
filtered_masks = exclude_text_regions(masks_original, image_resized.shape, text_regions)
sinter_stone_masks = merge_overlapping_masks(
    filter_sinter_stones(
        filtered_masks,
        area_range=(100, 5000),
        min_aspect_ratio=0.5,
        min_iou_with_largest=0.3,
    ),
    iou_threshold=0.7,
)

# Generate segmented image and add bounding boxes
segmented_image = color_segmentation(sinter_stone_masks, image_resized)
segmented_image_with_boxes = draw_bounding_boxes(
    segmented_image,
    sinter_stone_masks,
    color=(0, 255, 0),
    thickness=2,
)

# Save the result to disk; the figure is closed afterwards, so nothing is
# rendered inline.
fig, ax = plt.subplots(figsize=(50, 25))
ax.imshow(segmented_image_with_boxes)
ax.axis('off')
fig.savefig('segmented_image_with_one_box_per_stone6.png')
plt.close(fig)