In [19]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor

In [24]:
import torch
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
from skimage.transform import resize  # Pour redimensionner les masques
import os
from datetime import datetime


def initialize_sam(checkpoint_path, device='cpu'):
    """Initialize SAM model with optimized parameters."""
    sam = sam_model_registry["vit_h"](checkpoint=checkpoint_path)
    sam.to(device=device)
    
    # Optimized parameters for small anomaly detection
    mask_generator = SamAutomaticMaskGenerator(
        model=sam,
        points_per_side=51,
        pred_iou_thresh=0.88,  # Increased for better precision
        stability_score_thresh=0.90,  # Increased for more stable masks
        crop_n_layers=1,
        crop_n_points_downscale_factor=2,  # Increased for speed
        min_mask_region_area=16,  # Reduced to catch smaller anomalies
        output_mode="binary_mask"
    )
    return mask_generator


def process_image(image_path, target_size=(1024, 1024)):
    """Load and preprocess image with efficient resizing."""
    with Image.open(image_path).convert('RGB') as img:
        img_resized = img.resize(target_size, Image.Resampling.LANCZOS)
        return np.array(img_resized).astype(np.float32), img


def analyze_regions(masks, threshold=300):
    """Analyze regions with optimized filtering."""
    regions = []
    for idx, mask in enumerate(masks):
        if mask['area'] < threshold:  # Only process small regions
            y_indices, x_indices = np.where(mask['segmentation'])
            if len(x_indices) > 0 and len(y_indices) > 0:
                centroid = (np.mean(x_indices), np.mean(y_indices))
                regions.append({
                    'id': idx,
                    'area_pixels': mask['area'],
                    'centroid': centroid,
                    'mask': mask['segmentation']
                })
    return regions


def visualize_results(original_image, regions, save_path=None):
    """Create and save scientific visualization with resized masks."""
    original_size = original_image.size[::-1]  # (height, width)
    fig, axes = plt.subplots(1, 3, figsize=(15, 5), dpi=150)
    
    # Original image
    axes[0].imshow(original_image)
    axes[0].set_title(f'Original\n({len(regions)} anomalies detected)')
    axes[0].axis('off')
    
    # Create mask visualization
    mask_overlay = np.zeros((*original_size, 4))  # Adjusted for the original size
    red_color = np.array([1, 0, 0, 0.7])  # Red with 0.7 opacity
    
    for region in regions:
        # Resize the mask to the original image size
        resized_mask = resize(region['mask'], original_size, mode='constant', preserve_range=True).astype(bool)
        mask_overlay[resized_mask] = red_color
    
    # Mask only
    axes[1].imshow(mask_overlay)
    axes[1].set_title('Detected Anomalies')
    axes[1].axis('off')
    
    # Overlay
    axes[2].imshow(original_image)
    axes[2].imshow(mask_overlay)
    axes[2].set_title('Overlay')
    axes[2].axis('off')
    
    plt.tight_layout()
    
    if save_path:
        plt.savefig(save_path, bbox_inches='tight', dpi=300)
        plt.close()
    else:
        plt.show()


def main():
    # Paths
    sam_checkpoint = "/Users/armandbryan/Documents/challenges/Computer Vision Projects Expo 2024/models/sam_vit_h_4b8939.pth"
    image_path = '/Users/armandbryan/Documents/challenges/Computer Vision Projects Expo 2024/datasets/aptos2019-blindness-detection/test_images/0a2b5e1a0be8.png'
    
    output_dir = 'anomaly_detection_results'
    os.makedirs(output_dir, exist_ok=True)
    mask_generator = initialize_sam(sam_checkpoint)
    
   
    img_array, original_image = process_image(image_path)
    masks = mask_generator.generate(img_array)
    
    # Analyze regions
    regions = analyze_regions(masks)
    
    # Save results
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    save_path = os.path.join(output_dir, f'anomaly_detection_{timestamp}.png')
    
    # Visualize and save results
    visualize_results(original_image, regions, save_path)
    
    # Print analysis
    print(f"\nDetected {len(regions)} anomalies:")
    for i, region in enumerate(regions):
        print(f"Anomaly {i+1}: Area = {region['area_pixels']:.1f} pixels, "
              f"Center = ({region['centroid'][0]:.1f}, {region['centroid'][1]:.1f})")


if __name__ == "__main__":
    main()


Detected 42 anomalies:
Anomaly 1: Area = 286.0 pixels, Center = (506.5, 570.7)
Anomaly 2: Area = 137.0 pixels, Center = (484.2, 795.3)
Anomaly 3: Area = 256.0 pixels, Center = (378.4, 546.0)
Anomaly 4: Area = 254.0 pixels, Center = (542.5, 645.9)
Anomaly 5: Area = 219.0 pixels, Center = (407.5, 564.8)
Anomaly 6: Area = 217.0 pixels, Center = (370.0, 410.4)
Anomaly 7: Area = 104.0 pixels, Center = (425.6, 404.7)
Anomaly 8: Area = 143.0 pixels, Center = (560.5, 674.6)
Anomaly 9: Area = 72.0 pixels, Center = (397.5, 677.7)
Anomaly 10: Area = 152.0 pixels, Center = (396.8, 399.7)
Anomaly 11: Area = 163.0 pixels, Center = (372.9, 663.6)
Anomaly 12: Area = 142.0 pixels, Center = (477.0, 710.3)
Anomaly 13: Area = 137.0 pixels, Center = (282.9, 383.4)
Anomaly 14: Area = 51.0 pixels, Center = (453.4, 439.0)
Anomaly 15: Area = 222.0 pixels, Center = (339.4, 672.4)
Anomaly 16: Area = 282.0 pixels, Center = (471.8, 590.8)
Anomaly 17: Area = 90.0 pixels, Center = (475.4, 762.9)
Anomaly 18: Area = 