In [4]:
import os
import cv2
import torch
import numpy as np
from segment_anything import sam_model_registry, SamPredictor

class PetSegmentationPointPrompt:
    def __init__(self, model_type="vit_h", checkpoint_path="../model/sam_checkpoints/sam_vit_h_4b8939.pth"):
        # Device configuration
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        print(f"Using device: {self.device}")
        
        # Load SAM model
        self.sam = sam_model_registry[model_type](checkpoint=checkpoint_path)
        self.sam.to(device=self.device)
        
        # Initialize predictor
        self.predictor = SamPredictor(self.sam)

    def extract_pet_mask(self, image_path):
        """
        Extract pet mask using center point prompt
        """
        # Read image
        image = cv2.imread(image_path)
        image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        
        # Set image in predictor
        self.predictor.set_image(image_rgb)
        
        # Get image center point
        height, width = image.shape[:2]
        center_point = np.array([[width//2, height//2]])
        
        # Generate mask from point prompt
        point_label = np.array([1])  # 1 indicates foreground point
        
        masks, scores, logits = self.predictor.predict(
            point_coords=center_point,
            point_labels=point_label,
            multimask_output=True
        )
        
        areas = np.sum(masks, axis=(1, 2))
        # Combine areas and scores for ranking
        # Normalize areas to 0-1 range
        normalized_areas = (areas - np.min(areas)) / (np.max(areas) - np.min(areas))
        # Combine normalized areas and scores
        area_wight = 0.6
        combined_scores = area_wight * normalized_areas + (1 - area_wight) * scores
        # Select best mask based on score
        best_mask_idx = np.argmax(combined_scores)
        best_mask = masks[best_mask_idx]
        
        # Create visualization
        mask_visualization = self.create_mask_visualization(image_rgb, best_mask)
        
        return best_mask, mask_visualization
    
    def create_mask_visualization(self, image, mask):
        visualization = image.copy()
        color = np.array([30, 144, 255])  # Dodger Blue
        
        mask_overlay = visualization.copy()
        mask_overlay[mask] = color
        
        alpha = 0.5
        visualization = cv2.addWeighted(visualization, 1 - alpha, mask_overlay, alpha, 0)
        return visualization

    def process_demo_dataset(self, input_dir, output_dir, num_images=10):
        # Create output directories
        os.makedirs(output_dir, exist_ok=True)
        os.makedirs(os.path.join(output_dir, 'masks'), exist_ok=True)
        os.makedirs(os.path.join(output_dir, 'visualizations'), exist_ok=True)
        
        # Get image files
        image_files = [f for f in os.listdir(input_dir) 
                      if f.lower().endswith(('.jpg', '.jpeg', '.png'))]
        demo_files = image_files[:num_images]
        
        results = []
        
        for filename in demo_files:
            input_path = os.path.join(input_dir, filename)
            
            try:
                # Extract pet mask using point prompt
                pet_mask, mask_visualization = self.extract_pet_mask(input_path)
                
                if pet_mask is not None:
                    # Original image
                    image = cv2.imread(input_path)
                    
                    # Create masked image
                    masked_image = image.copy()
                    masked_image[~pet_mask] = [255, 255, 255]  # White background
                    
                    # Save outputs
                    mask_output_path = os.path.join(output_dir, 'masks', filename)
                    cv2.imwrite(mask_output_path, (pet_mask * 255).astype(np.uint8))
                    
                    masked_image_path = os.path.join(output_dir, filename)
                    cv2.imwrite(masked_image_path, masked_image)
                    
                    viz_path = os.path.join(output_dir, 'visualizations', filename)
                    cv2.imwrite(viz_path, cv2.cvtColor(mask_visualization, cv2.COLOR_RGB2BGR))
                    
                    results.append({
                        'filename': filename,
                        'mask_path': mask_output_path,
                        'masked_image_path': masked_image_path,
                        'visualization_path': viz_path
                    })
                    
                    print(f"Processed {filename}")
                    
            except Exception as e:
                print(f"Error processing {filename}: {e}")
                import traceback
                traceback.print_exc()
                
        return results

def main():
    # Initialize segmentor
    segmentor = PetSegmentationPointPrompt()
    
    # Process demo dataset
    demo_results = segmentor.process_demo_dataset(
        input_dir='../dataset/trainset',
        output_dir='../dataset/preprocessed_trainset_point_prompt_demo',
        num_images=15
    )

if __name__ == "__main__":
    main()

Using device: cuda


  state_dict = torch.load(f)


Processed A*5wfcQaqfSZtvQ0hKmwjDrwAAAQAAAQ.jpg
Processed A*RUTcRqGoadQAAAAAAAAAAAAAAQAAAQ.jpg
Processed A*yhhMRIVoNgIAAAAAAAAAAAAAAQAAAQ.jpg
Processed A*7DHLR7A9-PcAAAAAAAAAAAAAAQAAAQ.jpg
Processed A*B0UFQY5-NWAAAAAAAAAAAAAAAQAAAQ.jpg
Processed A*Ilm_R4vQBwAAAAAAAAAAAAAAAQAAAQ.jpg
Processed A*AcLoQ62XAiEAAAAAAAAAAAAAAQAAAQ.jpg
Processed A*7LqiQ6pid-MAAAAAAAAAAAAAAQAAAQ.jpg
Processed A*zO1xSLQ2HekAAAAAAAAAAAAAAQAAAQ.jpg
Processed A*mn91SptgaVkAAAAAAAAAAAAAAQAAAQ.jpg
Processed A*Vj18TaEJwkYAAAAAAAAAAAAAAQAAAQ.jpg
Processed A*XAoET6NgEiEAAAAAAAAAAAAAAQAAAQ.jpg
Processed A*wsZqTbPRGS8AAAAAAAAAAAAAAQAAAQ.jpg
Processed A*6rpyT7SHRbEAAAAAAAAAAAAAAQAAAQ.jpg
Processed A*qpEoQKj_Er0AAAAAAAAAAAAAAQAAAQ.jpg
