In [1]:
import torch

print(f"CUDA is available: {torch.cuda.is_available()}")
print(f"CUDA device count: {torch.cuda.device_count()}")
if torch.cuda.is_available():
    print(f"Current CUDA device: {torch.cuda.current_device()}")
    print(f"Device name: {torch.cuda.get_device_name(0)}")

CUDA is available: True
CUDA device count: 8
Current CUDA device: 0
Device name: NVIDIA RTX 6000 Ada Generation


In [1]:
import os  
import cv2  
import torch  
import numpy as np  
import supervision as sv  
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator  

class PetSegmentationPreprocessor:  
    def __init__(self, model_type="vit_h", checkpoint_path="../model/sam_checkpoints/sam_vit_h_4b8939.pth"):  
        # Device configuration  
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'  
        print(f"Using device: {self.device}")  
        
        # Load SAM model  
        self.sam = sam_model_registry[model_type](checkpoint=checkpoint_path)  
        self.sam.to(device=self.device)  
        
        # Initialize mask generator with optimized parameters  
        self.mask_generator = SamAutomaticMaskGenerator(  
            model=self.sam,  
            points_per_side=32,  # Increased for more detailed segmentation  
            pred_iou_thresh=0.86,  # Higher threshold for quality masks  
            stability_score_thresh=0.92,  # Stricter stability criterion  
            crop_n_layers=1,  
            crop_n_points_downscale_factor=2,  
            min_mask_region_area=100,  # Minimum area to consider as a valid mask  
        )  

    def extract_pet_mask(self, image_path):  
        """  
        Extract pet mask from the image  
        Returns:   
        - Best pet mask   
        - Visualization of masks  
        """  
        # Read image  
        image = cv2.imread(image_path)  
        image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  

        # Generate masks  
        masks = self.mask_generator.generate(image_rgb)  
        
        # Sort masks by area (descending) and score (descending)  
        sorted_masks = sorted(  
            masks,   
            key=lambda x: (x['area'], x['predicted_iou']),   
            reverse=True  
        )  

        # Custom mask visualization  
        def create_mask_visualization(image, masks):  
            # Create a copy of the image to annotate  
            visualization = image.copy()  
            
            # Draw masks with random colors  
            for mask in masks:  
                # Generate a random color  
                color = np.random.randint(0, 255, 3).tolist()  
                
                # Create a mask overlay  
                mask_overlay = visualization.copy()  
                mask_overlay[mask['segmentation']] = color  
                
                # Blend the overlay  
                alpha = 0.5  # Transparency factor  
                visualization = cv2.addWeighted(visualization, 1 - alpha, mask_overlay, alpha, 0)  
            
            return visualization  

        # Visualize masks  
        mask_visualization = create_mask_visualization(image_rgb, sorted_masks)  

        # Select the most promising mask (largest area with high confidence)  
        if sorted_masks:  
            best_mask = sorted_masks[0]['segmentation']  
            return best_mask, mask_visualization  
        
        return None, image_rgb  

    def process_demo_dataset(self, input_dir, output_dir, num_images=10):  
        # Create output directories  
        os.makedirs(output_dir, exist_ok=True)  
        os.makedirs(os.path.join(output_dir, 'masks'), exist_ok=True)  
        os.makedirs(os.path.join(output_dir, 'visualizations'), exist_ok=True)  

        # Get image files  
        image_files = [f for f in os.listdir(input_dir)   
                       if f.lower().endswith(('.jpg', '.jpeg', '.png'))]  
        
        # Limit to first n images  
        demo_files = image_files[:num_images]  

        # Results tracking  
        results = []  

        for filename in demo_files:  
            input_path = os.path.join(input_dir, filename)  
            
            try:  
                # Extract pet mask  
                pet_mask, mask_visualization = self.extract_pet_mask(input_path)  
                
                if pet_mask is not None:  
                    # Original image  
                    image = cv2.imread(input_path)  
                    
                    # Create masked image  
                    masked_image = image.copy()  
                    masked_image[~pet_mask] = [255, 255, 255]  # White background  
                    
                    # Save outputs  
                    mask_output_path = os.path.join(output_dir, 'masks', filename)  
                    cv2.imwrite(mask_output_path, (pet_mask * 255).astype(np.uint8))  
                    
                    masked_image_path = os.path.join(output_dir, filename)  
                    cv2.imwrite(masked_image_path, masked_image)  
                    
                    # Save mask visualization  
                    viz_path = os.path.join(output_dir, 'visualizations', filename)  
                    cv2.imwrite(viz_path, mask_visualization)  
                    
                    # Track results  
                    results.append({  
                        'filename': filename,  
                        'mask_path': mask_output_path,  
                        'masked_image_path': masked_image_path,  
                        'visualization_path': viz_path  
                    })  
                    
                    print(f"Processed {filename}")  
                else:  
                    print(f"No mask found for {filename}")  

            except Exception as e:  
                print(f"Error processing {filename}: {e}")  
                # Optional: log the full traceback  
                import traceback  
                traceback.print_exc()  

        return results  

def main():  
    # Initialize segmentor  
    segmentor = PetSegmentationPreprocessor()  
    
    # Process demo dataset  
    demo_results = segmentor.process_demo_dataset(  
        input_dir='../dataset/trainset',  
        output_dir='../dataset/preprocessed_trainset_demo',  
        num_images=15  
    )  

if __name__ == "__main__":  
    main()

Using device: cuda




Processed A*5wfcQaqfSZtvQ0hKmwjDrwAAAQAAAQ.jpg
Processed A*RUTcRqGoadQAAAAAAAAAAAAAAQAAAQ.jpg
Processed A*yhhMRIVoNgIAAAAAAAAAAAAAAQAAAQ.jpg
Processed A*7DHLR7A9-PcAAAAAAAAAAAAAAQAAAQ.jpg
Processed A*B0UFQY5-NWAAAAAAAAAAAAAAAQAAAQ.jpg
Processed A*Ilm_R4vQBwAAAAAAAAAAAAAAAQAAAQ.jpg
Processed A*AcLoQ62XAiEAAAAAAAAAAAAAAQAAAQ.jpg
Processed A*7LqiQ6pid-MAAAAAAAAAAAAAAQAAAQ.jpg
Processed A*zO1xSLQ2HekAAAAAAAAAAAAAAQAAAQ.jpg
Processed A*mn91SptgaVkAAAAAAAAAAAAAAQAAAQ.jpg
Processed A*Vj18TaEJwkYAAAAAAAAAAAAAAQAAAQ.jpg
Processed A*XAoET6NgEiEAAAAAAAAAAAAAAQAAAQ.jpg
Processed A*wsZqTbPRGS8AAAAAAAAAAAAAAQAAAQ.jpg
Processed A*6rpyT7SHRbEAAAAAAAAAAAAAAQAAAQ.jpg
Processed A*qpEoQKj_Er0AAAAAAAAAAAAAAQAAAQ.jpg
