# SEM Particle Analysis - Test Notebook

This notebook tests the refactored `sem_particle_analysis` package on a folder of SEM images.

## What this notebook does:
1. Loads the refactored Python package
2. Processes SEM images from `/Users/sanjaypradeep/Downloads/Trial Images` (set via `IMAGE_FOLDER` in the Configuration cell)
3. Automatically detects scale bars
4. Segments particles using SAM
5. Analyzes particle sizes
6. Saves results to CSV

## Installation
If you haven't installed the package yet:
```bash
cd sem_particle_analysis
pip install -e .
```

In [None]:
# Import the refactored package
import sys
import os

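# Fallback when the package is not installed: put the project folder
# (the one `pip install -e .` is run from) on sys.path. This assumes the
# notebook is started from that folder's parent directory.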
package_path = os.path.join(os.getcwd(), 'sem_particle_analysis')
if package_path not in sys.path:
    sys.path.insert(0, package_path)

from sem_particle_analysis import (
    SAMModel,
    ScaleDetector,
    ParticleSegmenter,
    ParticleAnalyzer,
    ResultsManager
)
from sem_particle_analysis.utils import (
    load_image,
    find_images_in_folder,
    visualize_masks,
    visualize_comparison,
    plot_size_distribution,
    print_summary
)

import matplotlib.pyplot as plt
import numpy as np

print("✓ Packages imported successfully!")

## Configuration

In [None]:
# Configuration (the paths below are machine-specific; edit them for your setup)
SAM_CHECKPOINT = "/Users/sanjaypradeep/segment-anything/models/sam_vit_h_4b8939.pth"
MODEL_TYPE = "vit_h"  # Options: "vit_h" (best), "vit_l", "vit_b" (fastest)
IMAGE_FOLDER = "/Users/sanjaypradeep/Downloads/Trial Images"
OUTPUT_CSV = "batch_analysis_results.csv"

# Analysis parameters
MIN_PARTICLE_AREA = 50  # Minimum particle area in pixels
CROP_PERCENT = 7.0  # Percentage to crop from bottom (for scale bar removal)

print("Configuration:")
print(f"  SAM Model: {MODEL_TYPE}")
print(f"  Image folder: {IMAGE_FOLDER}")
print(f"  Output CSV: {OUTPUT_CSV}")
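
Because the SAM checkpoint is large and both paths above are machine-specific, it can help to verify them before loading the model. The following is a minimal sketch; `check_paths` is a hypothetical helper and not part of `sem_particle_analysis`.

```python
from pathlib import Path

def check_paths(checkpoint, image_folder):
    """Return a list of problems; an empty list means both paths look usable."""
    problems = []
    if not Path(checkpoint).is_file():
        problems.append(f"SAM checkpoint not found: {checkpoint}")
    if not Path(image_folder).is_dir():
        problems.append(f"Image folder not found: {image_folder}")
    return problems

# Example with paths that do not exist, so both checks report a problem
for problem in check_paths("missing/sam_checkpoint.pth", "missing/images"):
    print("  ERROR:", problem)
```

In this notebook the call would be `check_paths(SAM_CHECKPOINT, IMAGE_FOLDER)`, run before the model is initialized.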

## Initialize SAM Model

This loads the Segment Anything Model once and reuses it for all images.

In [None]:
# Initialize SAM model
print("Loading SAM model...")
sam_model = SAMModel(SAM_CHECKPOINT, model_type=MODEL_TYPE)
print("✓ SAM model loaded successfully!")

## Find Images in Folder

In [None]:
# Find all images in the folder
image_paths = find_images_in_folder(IMAGE_FOLDER)
print(f"Found {len(image_paths)} images in '{IMAGE_FOLDER}'")

if len(image_paths) == 0:
    print("ERROR: No images found in the specified folder!")
else:
    print("\nFirst 5 images:")
    for i, path in enumerate(image_paths[:5], 1):
        print(f"  {i}. {os.path.basename(path)}")
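
`find_images_in_folder` comes from the package and its implementation is not shown in this notebook. As a rough sketch of what such a helper typically does, assuming it filters by file extension and returns sorted paths (the extension set and the name `list_image_files` are illustrative, not the package's):

```python
from pathlib import Path

# Assumed extension set; the package's real list may differ.
IMAGE_EXTENSIONS = {".tif", ".tiff", ".png", ".jpg", ".jpeg", ".bmp"}

def list_image_files(folder):
    """Return sorted paths of image files directly inside `folder`."""
    root = Path(folder)
    if not root.is_dir():
        raise FileNotFoundError(f"Not a folder: {folder}")
    return sorted(
        str(p) for p in root.iterdir()
        # Compare suffixes case-insensitively so "IMG.TIF" is also found
        if p.is_file() and p.suffix.lower() in IMAGE_EXTENSIONS
    )
```

Sorting keeps the processing order stable between runs, which makes the `[i/N]` progress lines in the batch loop easier to compare.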

## Initialize Components

In [None]:
# Initialize components
scale_detector = ScaleDetector(use_gpu=False)
segmenter = ParticleSegmenter(sam_model)
results_manager = ResultsManager(OUTPUT_CSV)

print("✓ All components initialized!")

## Process Single Image (Example)

Let's process the first image as an example to see the full workflow.

In [None]:
# Process first image as example
if len(image_paths) > 0:
    example_image_path = image_paths[0]
    print(f"Processing example image: {os.path.basename(example_image_path)}")
    
    # 1. Load image
    image = load_image(example_image_path)
    print(f"  Image size: {image.shape[1]} x {image.shape[0]} pixels")
    
    # Display original image
    plt.figure(figsize=(10, 8))
    plt.imshow(image)
    plt.title(f"Original Image: {os.path.basename(example_image_path)}")
    plt.axis('off')
    plt.tight_layout()
    plt.show()

In [None]:
# 2. Detect scale bar
print("\nDetecting scale bar...")
try:
    scale_info = scale_detector.detect_scale_bar(
        image,
        region_width=0.5,
        region_height=0.06,
        vertical_offset=0.0,
        threshold=250
    )
    
    print("  Scale bar detected:")
    print(f"    Physical length: {scale_info['scale_nm']:.1f} nm")
    print(f"    Pixel length: {scale_info['pixel_length']} pixels")
    print(f"    Conversion: {scale_info['conversion']:.3f} nm/pixel")
    print(f"    OCR text: '{scale_info['ocr_text']}'")
    
    conversion_factor = scale_info['conversion']
    
except Exception as e:
    print(f"  Warning: Scale bar detection failed: {e}")
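    # Fallback: 100 nm over 50 px (2 nm/pixel) is hard-coded, so sizes are
    # only meaningful if it matches this image's actual scale bar.
    print("  Using manual scale entry...")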
    conversion_factor = scale_detector.set_manual_scale(scale_nm=100.0, pixel_length=50)
    print(f"  Manual conversion set to: {conversion_factor:.3f} nm/pixel")
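
The conversion factor printed above is in nm/pixel: the physical length of the scale bar divided by its length in pixels. The sketch below works through that arithmetic for the manual fallback values; the helper names are hypothetical and not part of the package.

```python
def nm_per_pixel(scale_nm, pixel_length):
    """Conversion factor from a scale bar of known physical length."""
    if pixel_length <= 0:
        raise ValueError("pixel_length must be positive")
    return scale_nm / pixel_length

def length_to_nm(length_px, conversion):
    """Convert a length in pixels to nanometres."""
    return length_px * conversion

def area_to_nm2(area_px, conversion):
    """Convert an area in pixels to nm^2 (areas scale with the factor squared)."""
    return area_px * conversion ** 2

conversion = nm_per_pixel(scale_nm=100.0, pixel_length=50)
print(conversion)                    # 2.0 nm/pixel
print(length_to_nm(30, conversion))  # 60.0 nm
print(area_to_nm2(50, conversion))   # 200.0 nm^2
```

Note that a 50-pixel area threshold corresponds to a different physical area for every image whose scale differs.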

In [None]:
# 3. Crop scale bar from image
cropped_image = scale_detector.crop_scale_bar(image, crop_percent=CROP_PERCENT)
print(f"\nCropped image size: {cropped_image.shape[1]} x {cropped_image.shape[0]} pixels")

# Display cropped image
plt.figure(figsize=(10, 8))
plt.imshow(cropped_image)
plt.title("Cropped Image (Scale Bar Removed)")
plt.axis('off')
plt.tight_layout()
plt.show()
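
How `crop_scale_bar` works internally is not shown here. Assuming it simply removes the bottom `crop_percent` percent of rows, as the configuration comment suggests, the operation can be sketched with NumPy slicing (`crop_bottom` is a hypothetical name):

```python
import numpy as np

def crop_bottom(image, crop_percent):
    """Drop the bottom `crop_percent` percent of rows from an image array."""
    if not 0 <= crop_percent < 100:
        raise ValueError("crop_percent must be in [0, 100)")
    height = image.shape[0]
    rows_to_drop = int(round(height * crop_percent / 100.0))
    # Slicing returns a view, so no pixel data is copied
    return image[: height - rows_to_drop]

demo = np.zeros((1000, 1280, 3), dtype=np.uint8)
print(crop_bottom(demo, 7.0).shape)  # (930, 1280, 3)
```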

In [None]:
# 4. Segment particles with SAM
print("\nSegmenting particles with SAM...")
masks, scores = segmenter.segment_image(cropped_image, multimask_output=True)

# Visualize mask candidates
fig = visualize_masks(cropped_image, masks, scores, figsize=(18, 6))
plt.show()

print(f"\nGenerated {len(masks)} mask candidates")
print(f"Confidence scores: {scores}")

In [None]:
# 5. Select best mask
selected_mask = segmenter.select_mask()  # Auto-selects highest score
binary_mask = segmenter.get_binary_mask(invert=True)  # Particles = 1, background = 0

print(f"Selected mask {segmenter.selected_mask_index} with score {scores[segmenter.selected_mask_index]:.3f}")
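
The two calls above pick the highest-scoring candidate and invert it so that particles become 1 and background 0. A minimal NumPy sketch of that logic follows, assuming the candidates are boolean arrays of shape `(n, H, W)`; `pick_best_mask` is a hypothetical stand-in, not the package's implementation.

```python
import numpy as np

def pick_best_mask(masks, scores, invert=False):
    """Return (index, binary_mask) for the candidate with the highest score."""
    masks = np.asarray(masks, dtype=bool)
    best = int(np.argmax(scores))
    chosen = masks[best]
    if invert:
        # Flip foreground and background
        chosen = ~chosen
    return best, chosen.astype(np.uint8)

masks = np.array([
    [[1, 1], [0, 0]],
    [[1, 0], [0, 0]],
    [[1, 1], [1, 0]],
], dtype=bool)
scores = [0.81, 0.93, 0.75]

index, binary = pick_best_mask(masks, scores, invert=True)
print(index)            # 1
print(binary.tolist())  # [[0, 1], [1, 1]]
```

Whether inversion is needed depends on which region SAM returns as foreground, so it is worth checking the mask visualization above before trusting `invert=True` on a new image set.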

In [None]:
# 6. Analyze particles
print("\nAnalyzing particles...")
analyzer = ParticleAnalyzer(conversion_factor=conversion_factor)
num_particles, regions = analyzer.analyze_mask(
    binary_mask,
    min_area=MIN_PARTICLE_AREA,
    min_size=30,
    remove_border=True,
    border_buffer=4
)

print(f"\nDetected {num_particles} particles")
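
The parameters passed to `analyze_mask` suggest three steps: label connected regions, discard regions below a minimum area, and discard regions near the image border. The package's actual implementation is not shown, so the following is only a sketch of that idea, using a plain flood fill with 4-connectivity. `label_and_filter` is a hypothetical name, and the `min_size` parameter is not reproduced because its meaning is not documented here.

```python
from collections import deque
import numpy as np

def label_and_filter(binary_mask, min_area=50, border_buffer=4):
    """Label 4-connected regions, dropping small or near-border ones.

    Returns (labeled, areas): `labeled` is 0 for background and 1..n for kept
    regions; `areas` lists the pixel area of each kept region in label order.
    """
    mask = np.asarray(binary_mask).astype(bool)
    height, width = mask.shape
    visited = np.zeros(mask.shape, dtype=bool)
    labeled = np.zeros(mask.shape, dtype=np.int32)
    areas = []

    for r0 in range(height):
        for c0 in range(width):
            if not mask[r0, c0] or visited[r0, c0]:
                continue
            # Breadth-first flood fill collects one connected region
            queue = deque([(r0, c0)])
            visited[r0, c0] = True
            pixels = []
            while queue:
                r, c = queue.popleft()
                pixels.append((r, c))
                for nr, nc in ((r - 1, c), (r + 1, c), (r, c - 1), (r, c + 1)):
                    if (0 <= nr < height and 0 <= nc < width
                            and mask[nr, nc] and not visited[nr, nc]):
                        visited[nr, nc] = True
                        queue.append((nr, nc))

            rows = [p[0] for p in pixels]
            cols = [p[1] for p in pixels]
            near_border = (
                min(rows) < border_buffer or min(cols) < border_buffer
                or max(rows) >= height - border_buffer
                or max(cols) >= width - border_buffer
            )
            if len(pixels) < min_area or near_border:
                continue
            areas.append(len(pixels))
            for r, c in pixels:
                labeled[r, c] = len(areas)

    return labeled, areas

demo = np.zeros((20, 20), dtype=np.uint8)
demo[6:12, 6:12] = 1   # 36-pixel particle away from the border: kept
demo[0:5, 0:5] = 1     # touches the border: removed
demo[15, 10] = 1       # single pixel, below min_area: removed
labeled, areas = label_and_filter(demo, min_area=10, border_buffer=2)
print(areas)  # [36]
```

Removing border regions avoids measuring particles that are cut off by the image edge, at the cost of a slightly smaller sample per image.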

In [None]:
# 7. Get measurements
measurements = analyzer.get_measurements(in_nm=True)

# Print summary
print_summary(measurements, title="Example Image Analysis")
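
The notebook reports particle area and diameter. One common way to derive a diameter from an irregular region is the equivalent circular diameter, $d = 2\sqrt{A/\pi}$. Whether the package uses this exact definition is an assumption; the sketch below only illustrates the formula.

```python
import math

def equivalent_diameter(area):
    """Diameter of the circle with the same area (units follow the input)."""
    if area < 0:
        raise ValueError("area must be non-negative")
    return 2.0 * math.sqrt(area / math.pi)

conversion = 2.0   # nm/pixel, the notebook's manual fallback value
area_px = 314      # example particle area in pixels
area_nm2 = area_px * conversion ** 2

print(round(area_nm2, 1))                       # 1256.0
print(round(equivalent_diameter(area_nm2), 1))  # 40.0
```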

In [None]:
# 8. Visualize results
fig = visualize_comparison(cropped_image, analyzer.labeled_mask, analyzer.regions, figsize=(16, 8))
plt.show()

In [None]:
# 9. Plot size distributions
fig = plot_size_distribution(measurements, bins=20, figsize=(12, 5))
if fig:
    plt.show()
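
`plot_size_distribution` handles the plotting, but the underlying binning can be inspected directly with `numpy.histogram`. This sketch assumes a plain list of diameters in nm is available; how to extract that list from `measurements` depends on the package's data layout, which is not shown here. `size_histogram` is a hypothetical helper.

```python
import numpy as np

def size_histogram(diameters_nm, bins=20):
    """Return (counts, bin_edges, stats) for a list of particle diameters."""
    diameters = np.asarray(diameters_nm, dtype=float)
    if diameters.size == 0:
        raise ValueError("no particles to summarize")
    counts, edges = np.histogram(diameters, bins=bins)
    stats = {
        "count": int(diameters.size),
        "mean": float(diameters.mean()),
        "median": float(np.median(diameters)),
        # Sample standard deviation needs at least two particles
        "std": float(diameters.std(ddof=1)) if diameters.size > 1 else 0.0,
    }
    return counts, edges, stats

counts, edges, stats = size_histogram([40.0, 42.0, 44.0, 60.0], bins=4)
print(counts.tolist())  # [3, 0, 0, 1]
print(stats["mean"])    # 46.5
```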

## Batch Process All Images

Now let's process all images in the folder automatically.

In [None]:
# Batch processing function
def process_image_auto(image_path, sam_model, scale_detector, segmenter,
                       crop_percent=7.0, min_area=50):
    """
    Process a single image automatically.

    Returns (filename, measurements) on success, or (None, None) if
    processing fails. `sam_model` is not used directly here because
    `segmenter` already wraps it; it is kept so existing calls still work.
    """
    # Resolve the name before the try block so the error handler can use it
    filename = os.path.basename(image_path)
    try:
        print(f"\nProcessing: {filename}")
        
        # Load image
        image = load_image(image_path)
        
        # Detect scale
        try:
            scale_info = scale_detector.detect_scale_bar(image)
            conversion = scale_info['conversion']
            print(f"  Scale: {scale_info['scale_nm']:.1f} nm = {scale_info['pixel_length']} px")
        except Exception as e:
            # Fallback: 100 nm over 50 px (2 nm/pixel) is hard-coded, so sizes
            # for this image are only valid if that matches its scale bar
            print(f"  Warning: Scale detection failed ({e}). Using manual scale.")
            conversion = scale_detector.set_manual_scale(scale_nm=100.0, pixel_length=50)
        
        # Crop image
        cropped_image = scale_detector.crop_scale_bar(image, crop_percent=crop_percent)
        
        # Segment with SAM and keep the highest-scoring mask
        masks, scores = segmenter.segment_image(cropped_image, multimask_output=True)
        segmenter.select_mask()
        binary_mask = segmenter.get_binary_mask(invert=True)
        
        # Analyze particles
        analyzer = ParticleAnalyzer(conversion_factor=conversion)
        num_particles, regions = analyzer.analyze_mask(
            binary_mask,
            min_area=min_area,
            min_size=30,
            remove_border=True,
            border_buffer=4
        )
        
        measurements = analyzer.get_measurements(in_nm=True)
        print(f"  Detected {num_particles} particles")
        
        return filename, measurements
        
    except Exception as e:
        print(f"  ERROR: Failed to process {filename}: {e}")
        return None, None

print("Batch processing function defined")

In [None]:
# Process all images
print(f"\n{'='*60}")
print(f"BATCH PROCESSING {len(image_paths)} IMAGES")
print(f"{'='*60}")

successful = 0
failed = 0

for i, image_path in enumerate(image_paths, 1):
    print(f"\n[{i}/{len(image_paths)}]", end=" ")
    
    filename, measurements = process_image_auto(
        image_path,
        sam_model,
        scale_detector,
        segmenter,
        crop_percent=CROP_PERCENT,
        min_area=MIN_PARTICLE_AREA
    )
    
    if measurements is not None:
        # Save results
        results_manager.add_result(filename, measurements)
        successful += 1
    else:
        failed += 1

print(f"\n\n{'='*60}")
print("BATCH PROCESSING COMPLETE")
print(f"{'='*60}")
print(f"Successfully processed: {successful}/{len(image_paths)}")
print(f"Failed: {failed}/{len(image_paths)}")
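
Each successful image is handed to `results_manager.add_result`, which stores it in the output CSV. The CSV layout used by `ResultsManager` is not shown in this notebook, so the following sketch is only an illustration of the append-one-row-per-image pattern using the standard library. The column names and the helper `append_summary_row` are assumptions.

```python
import csv
import statistics
from pathlib import Path

# Assumed columns; the package's real CSV schema may differ.
FIELDS = ["filename", "num_particles", "mean_diameter_nm", "std_diameter_nm"]

def append_summary_row(csv_path, filename, diameters_nm):
    """Append one per-image summary row, writing the header for a new file."""
    path = Path(csv_path)
    write_header = not path.exists() or path.stat().st_size == 0
    row = {
        "filename": filename,
        "num_particles": len(diameters_nm),
        # Leave statistics blank when there are too few particles
        "mean_diameter_nm": statistics.fmean(diameters_nm) if diameters_nm else "",
        "std_diameter_nm": statistics.stdev(diameters_nm) if len(diameters_nm) > 1 else "",
    }
    with path.open("a", newline="") as handle:
        writer = csv.DictWriter(handle, fieldnames=FIELDS)
        if write_header:
            writer.writeheader()
        writer.writerow(row)
    return row
```

Appending after every image means a crash partway through a long batch still leaves the completed rows on disk.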

## View Results Summary

In [None]:
# Print summary of all results
results_manager.print_summary()

In [None]:
# View results dataframe
results_df = results_manager.get_results()
print("\nResults DataFrame:")
display(results_df)

## Export Results

In [None]:
# Results are automatically saved to the CSV file
print(f"Results saved to: {OUTPUT_CSV}")

# To also export to a different file, uncomment the line below:
# results_manager.export_results("my_custom_results.csv")

## Analysis Complete!

If the batch summary above reports no failures, the refactored package has analyzed every SEM image in the folder and saved the measurements to the CSV file.

### Next Steps:
1. Review the results in `batch_analysis_results.csv`
2. Use the stored measurements for further statistical analysis
3. Adjust the parameters (`MIN_PARTICLE_AREA`, `CROP_PERCENT`) in the Configuration cell if needed and rerun

### Key Features Demonstrated:
- ✓ Automated scale bar detection
- ✓ SAM-based particle segmentation  
- ✓ Particle size measurements (area, diameter)
- ✓ Batch processing of multiple images
- ✓ Results storage and management
- ✓ Visualization tools