# Interactive SEM Particle Analysis

This notebook reproduces the **complete interactive workflow** of the original notebook, now built on the refactored `sem_particle_analysis` package.

## Features:
- ✅ Interactive image selection with progress tracking
- ✅ Interactive scale bar detection with sliders
- ✅ SAM-based particle segmentation
- ✅ **Interactive particle refinement** with:
  - Click-to-delete particles
  - Click-to-add particles with SAM
  - Merge mode for combining particles
  - Live SAM refinement with point prompts (+/- clicks)
  - Edge clearing with buffer
  - Dual-view visualization (original + mask)
- ✅ Results management and CSV export

## Setup

Make sure the matplotlib widget backend is enabled (it is provided by the `ipympl` package):
```python
%matplotlib widget
```

In [None]:
# Enable interactive matplotlib
%matplotlib widget

In [None]:
# Import the refactored package
import sys
import os

# Add the package to path (if not installed)
package_path = os.path.join(os.getcwd(), 'sem_particle_analysis')
if package_path not in sys.path:
    sys.path.insert(0, package_path)

from sem_particle_analysis import (
    SAMModel,
    ScaleDetector,
    ParticleSegmenter,
    ParticleAnalyzer,
    ResultsManager,
    InteractiveRefiner  # <-- New interactive refinement class
)
from sem_particle_analysis.utils import (
    load_image,
    find_images_in_folder,
    visualize_masks,
    print_summary
)

import matplotlib.pyplot as plt
import numpy as np
import ipywidgets as widgets
from IPython.display import display, clear_output
import hashlib

print("✓ Packages imported successfully!")

## Configuration

In [None]:
# Configuration
SAM_CHECKPOINT = "/Users/sanjaypradeep/segment-anything/models/sam_vit_h_4b8939.pth"
MODEL_TYPE = "vit_h"
IMAGE_FOLDER = "/Users/sanjaypradeep/Downloads/Trial Images"
OUTPUT_CSV = "interactive_analysis_results.csv"

print("Configuration:")
print(f"  SAM Model: {MODEL_TYPE}")
print(f"  Image folder: {IMAGE_FOLDER}")
print(f"  Output CSV: {OUTPUT_CSV}")

## Step 1: Initialize SAM Model

In [None]:
# Initialize SAM model (do this once)
print("Loading SAM model...")
sam_model = SAMModel(SAM_CHECKPOINT, model_type=MODEL_TYPE)
print("✓ SAM model loaded successfully!")

## Step 2: Initialize Components

In [None]:
# Initialize components
scale_detector = ScaleDetector(use_gpu=False)
segmenter = ParticleSegmenter(sam_model)
results_manager = ResultsManager(OUTPUT_CSV)

print("✓ All components initialized!")

## Step 3: Interactive Image Selection

Browse through images with:
- **Proceed**: Continue with current image
- **Skip**: Move to next image
- **Jump**: Go to specific image number

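Duplicate detection in the selection cell works by hashing each image's raw pixel bytes with MD5 and remembering the hashes of images you have already proceeded with or skipped. The standalone sketch below illustrates the same idea together with the progress label arithmetic. `image_fingerprint` and `progress_text` are hypothetical helper names, and folding the array shape and dtype into the digest is an extra assumption (the notebook hashes `tobytes()` alone) that stops two arrays with identical bytes but different shapes from colliding.

```python
import hashlib

import numpy as np


def image_fingerprint(image: np.ndarray) -> str:
    """MD5 hex digest of an image's shape, dtype and pixel bytes (hypothetical helper)."""
    h = hashlib.md5()
    h.update(str(image.shape).encode())
    h.update(str(image.dtype).encode())
    # tobytes() serialises in C order regardless of how the array is laid out in memory
    h.update(image.tobytes())
    return h.hexdigest()


def progress_text(current_idx: int, total: int) -> str:
    """Mirror update_progress(): current_idx is 0-based, the displayed count is 1-based."""
    done = min(current_idx + 1, total)
    pct = (done / total) * 100 if total else 0
    return f"{done}/{total} ({pct:.1f}%)"


seen = set()
first = np.zeros((4, 6, 3), dtype=np.uint8)
seen.add(image_fingerprint(first))
duplicate = first.copy()
print(image_fingerprint(duplicate) in seen)  # True: identical pixels, shape and dtype
print(progress_text(2, 8))  # 3/8 (37.5%)
```

Note that any pixel-level change (re-saving with different compression, for instance) produces a different digest, so this only catches exact duplicates.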
In [None]:
# Find all images
image_paths = find_images_in_folder(IMAGE_FOLDER)
print(f"Found {len(image_paths)} images")

# Image selection state
current_idx = -1
seen_hashes = set()
total_images = len(image_paths)
current_image = None
current_image_path = None

# Create UI widgets
progress_bar = widgets.IntProgress(value=0, min=0, max=total_images, description='Progress:')
progress_label = widgets.Label(value=f"0/{total_images} (0.0%)")
image_output = widgets.Output()

btn_proceed = widgets.Button(description="Proceed", button_style='success')
btn_skip = widgets.Button(description="Skip", button_style='warning')
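# BoundedIntText enforces min/max (plain IntText has no such traits); max >= 1 keeps it valid for an empty folder
jump_input = widgets.BoundedIntText(value=1, min=1, max=max(total_images, 1), description='Jump to #:')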
btn_jump = widgets.Button(description="Jump", button_style='info')

def update_progress():
    done = min(current_idx + 1, total_images)
    progress_bar.value = done
    pct = (done / total_images) * 100 if total_images else 0
    progress_label.value = f"{done}/{total_images} ({pct:.1f}%)"

def load_image_by_index(target_idx):
    global current_idx, current_image, current_image_path
    
    if target_idx < 0 or target_idx >= total_images:
        with image_output:
            clear_output()
            print(f"Invalid index. Choose between 1 and {total_images}")
        return False
    
    current_idx = target_idx
    update_progress()
    
    current_image_path = image_paths[current_idx]
    current_image = load_image(current_image_path)
    img_hash = hashlib.md5(current_image.tobytes()).hexdigest()
    
    with image_output:
        clear_output(wait=True)
        fig, ax = plt.subplots(figsize=(8, 6))
        ax.imshow(current_image)
        ax.set_title(f"{os.path.basename(current_image_path)} (#{current_idx + 1}/{total_images})")
        ax.axis('off')
        plt.tight_layout()
        plt.show()
        
        if img_hash in seen_hashes:
            print("⚠️  Duplicate image (seen before)")
        else:
            print("✓ New image ready for analysis")
        print(f"Dimensions: {current_image.shape[1]} x {current_image.shape[0]} pixels")
    
    return True

def load_next():
    global current_idx
    while True:
        current_idx += 1
        if current_idx >= total_images:
            with image_output:
                clear_output()
                print("✓ All images processed!")
            btn_proceed.disabled = True
            btn_skip.disabled = True
            btn_jump.disabled = True
            return
        if load_image_by_index(current_idx):
            break

def on_proceed(b):
    img_hash = hashlib.md5(current_image.tobytes()).hexdigest()
    seen_hashes.add(img_hash)
    with image_output:
        print(f"\n✓ PROCEEDING with: {os.path.basename(current_image_path)}")
        print("Continue to the next cell for scale bar detection.")
    btn_proceed.disabled = True
    btn_skip.disabled = True
    btn_jump.disabled = True

def on_skip(b):
    img_hash = hashlib.md5(current_image.tobytes()).hexdigest()
    seen_hashes.add(img_hash)
    load_next()

def on_jump(b):
    target = jump_input.value - 1
    if load_image_by_index(target):
        jump_input.value = current_idx + 1

# Connect events
btn_proceed.on_click(on_proceed)
btn_skip.on_click(on_skip)
btn_jump.on_click(on_jump)

# Display UI
display(widgets.VBox([
    progress_bar,
    progress_label,
    widgets.HBox([jump_input, btn_jump]),
    image_output,
    widgets.HBox([btn_proceed, btn_skip])
]))

# Load first image
load_next()

## Step 4: Interactive Scale Bar Detection

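Adjust the sliders to position the detection region, then click **Detect Scale Bar**. Click **Accept & Continue** to store the nm/pixel conversion factor and crop the bottom of the image (controlled by **Bottom Crop %**); if automatic detection fails, use **Manual Entry** instead.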

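The detection step thresholds the selected strip, measures the bar in pixels, reads the physical length with OCR, and reports the conversion factor in nm per pixel. The internals of `ScaleDetector.detect_scale_bar` are not shown in this notebook, so the following is only a sketch under assumptions: the bar is taken to be the longest horizontal run of pixels at or above the threshold, and `longest_bright_run` and `nm_per_pixel` are hypothetical names.

```python
import numpy as np


def longest_bright_run(gray_region: np.ndarray, threshold: int):
    """Find the longest horizontal run of pixels >= threshold.

    Returns (left, right, row) with `right` inclusive, or None if no pixel passes.
    Hypothetical stand-in for the bar-finding step, not the package's code.
    """
    best = None
    best_len = 0
    bright = gray_region >= threshold
    for row_idx, row in enumerate(bright):
        # Pad with False so every run has a detectable rising and falling edge
        padded = np.concatenate(([False], row, [False]))
        edges = np.flatnonzero(np.diff(padded.astype(np.int8)))
        starts, ends = edges[0::2], edges[1::2]  # ends are exclusive
        for s, e in zip(starts, ends):
            if e - s > best_len:
                best_len = e - s
                best = (int(s), int(e - 1), row_idx)
    return best


def nm_per_pixel(scale_nm: float, pixel_length: int) -> float:
    """Conversion factor: physical bar length divided by its length in pixels."""
    if pixel_length <= 0:
        raise ValueError("pixel_length must be positive")
    return scale_nm / pixel_length
```

For example, a 100 nm bar spanning 50 pixels gives 2.0 nm/pixel, which is presumably the same arithmetic the **Manual Entry** path applies to its two fields.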
In [None]:
# Scale bar detection with interactive controls
if current_image is None:
    print("ERROR: No image loaded. Please run the previous cell first.")
else:
    print(f"Processing scale bar for: {os.path.basename(current_image_path)}")
    
    # Create interactive sliders
    width_slider = widgets.FloatSlider(value=0.5, min=0.1, max=1.0, step=0.05, description='Region Width:')
    height_slider = widgets.FloatSlider(value=0.06, min=0.02, max=0.2, step=0.01, description='Region Height:')
    offset_slider = widgets.FloatSlider(value=0.0, min=0.0, max=0.3, step=0.02, description='Vertical Offset:')
    threshold_slider = widgets.IntSlider(value=250, min=100, max=255, step=10, description='Threshold:')
    crop_slider = widgets.FloatSlider(value=7.0, min=0.0, max=20.0, step=1.0, description='Bottom Crop %:')
    
    btn_detect = widgets.Button(description="Detect Scale Bar", button_style='primary')
    btn_accept = widgets.Button(description="Accept & Continue", button_style='success', disabled=True)
    btn_manual = widgets.Button(description="Manual Entry", button_style='warning')
    
    detection_output = widgets.Output()
    status_output = widgets.Output()
    
    scale_info = None
    cropped_image = None
    
    def detect_scale():
        global scale_info
        try:
            with status_output:
                clear_output()
                print("Detecting scale bar...")
            
            scale_info = scale_detector.detect_scale_bar(
                current_image,
                region_width=width_slider.value,
                region_height=height_slider.value,
                vertical_offset=offset_slider.value,
                threshold=int(threshold_slider.value)
            )
            
            with detection_output:
                clear_output(wait=True)
                fig, axes = plt.subplots(1, 3, figsize=(15, 5))
                
                # Full image with region
                axes[0].imshow(current_image)
                x0, y0, box_w, box_h = scale_info['region']
                rect = plt.Rectangle((x0, y0), box_w, box_h, fill=False, edgecolor='red', linewidth=2)
                axes[0].add_patch(rect)
                axes[0].set_title("Full Image with Scale Bar Region")
                axes[0].axis('off')
                
                # Detection region
                axes[1].imshow(scale_info['binary_image'], cmap='gray', vmin=0, vmax=255)
                left, right, y_line = scale_info['line_coords']
                axes[1].plot([left, right], [y_line, y_line], 'r-', linewidth=3)
                axes[1].text(right + 5, y_line, f"{scale_info['pixel_length']} px", color='red', fontsize=12, fontweight='bold', va='center')
                axes[1].text(5, 15, f"OCR: {scale_info['ocr_text']}", color='yellow', fontsize=10, bbox=dict(facecolor='black', alpha=0.7, pad=4))
                axes[1].set_title(f"Scale Bar Detection (Threshold: {scale_info['threshold']})")
                axes[1].axis('off')
                
                # Cropped preview
                crop_height = int(current_image.shape[0] * (1 - crop_slider.value / 100))
                preview = current_image[:crop_height, :]
                axes[2].imshow(preview)
                axes[2].set_title(f"Analysis Image (Bottom {crop_slider.value:.1f}% Removed)")
                axes[2].axis('off')
                
                plt.tight_layout()
                plt.show()
                
                print("SCALE BAR DETECTION RESULTS:")
                print(f"  Pixel length: {scale_info['pixel_length']} pixels")
                print(f"  Physical length: {scale_info['scale_nm']:.1f} nm")
                print(f"  Conversion factor: {scale_info['conversion']:.3f} nm/pixel")
            
            with status_output:
                clear_output()
                print("✓ Scale bar detected successfully!")
            
            btn_accept.disabled = False
            
        except Exception as e:
            with status_output:
                clear_output()
                print(f"❌ ERROR: {str(e)}")
                print("Try adjusting parameters or use Manual Entry.")
            btn_accept.disabled = True
    
    def accept_scale():
        global cropped_image, conversion_factor
        conversion_factor = scale_info['conversion']
        cropped_image = scale_detector.crop_scale_bar(current_image, crop_percent=crop_slider.value)
        
        with status_output:
            clear_output()
            print("✓ SCALE BAR ACCEPTED!")
            print(f"Conversion: {conversion_factor:.3f} nm/pixel")
            print("Ready to proceed to segmentation!")
        
        btn_detect.disabled = True
        btn_accept.disabled = True
        btn_manual.disabled = True
    
    def manual_entry():
        # Only set_manual() assigns the globals, so no global declaration is needed here
        
        manual_scale = widgets.FloatText(description='Scale (nm):', value=100.0)
        manual_pixels = widgets.IntText(description='Length (pixels):', value=50)
        btn_set = widgets.Button(description='Set Manual Scale', button_style='success')
        
        def set_manual(b):
            global conversion_factor, cropped_image
            conversion_factor = scale_detector.set_manual_scale(manual_scale.value, manual_pixels.value)
            cropped_image = scale_detector.crop_scale_bar(current_image, crop_percent=crop_slider.value)
            with status_output:
                clear_output()
                print("✓ MANUAL SCALE SET!")
                print(f"Conversion: {conversion_factor:.3f} nm/pixel")
        
        btn_set.on_click(set_manual)
        
        with detection_output:
            clear_output()
            display(widgets.VBox([
                widgets.HTML("<h4>Manual Scale Bar Entry</h4>"),
                manual_scale,
                manual_pixels,
                btn_set
            ]))
    
    btn_detect.on_click(lambda b: detect_scale())
    btn_accept.on_click(lambda b: accept_scale())
    btn_manual.on_click(lambda b: manual_entry())
    
    # Display interface
    display(widgets.VBox([
        widgets.HTML("<h3>Scale Bar Region Controls</h3>"),
        width_slider,
        height_slider,
        offset_slider,
        threshold_slider,
        crop_slider,
        widgets.HBox([btn_detect, btn_accept, btn_manual]),
        status_output,
        detection_output
    ]))
    
    # Run initial detection
    detect_scale()

## Step 5: SAM Particle Segmentation

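SAM generates several mask candidates, each with a quality score. The highest-scoring mask is selected automatically; to use a different one, pick it from the dropdown and click **Confirm Selection**.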

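The dropdown defaults to the highest-scoring candidate (`np.argmax(scores)`), and Step 6 requests the mask inverted so that particles are 1, which implies the chosen SAM mask covers the background. A hypothetical numpy sketch of that select-and-invert step follows; `pick_mask` is an illustrative name, and its match with `get_binary_mask(invert=True)` is an assumption, not confirmed by the package.

```python
import numpy as np


def pick_mask(masks: np.ndarray, scores: np.ndarray, index=None, invert=True) -> np.ndarray:
    """Select one candidate from an (N, H, W) mask stack and return a uint8 particle mask.

    index=None picks the highest-scoring candidate, like the dropdown default.
    invert=True swaps foreground and background so particles become 1
    (assumed to mirror get_binary_mask(invert=True)).
    """
    if index is None:
        index = int(np.argmax(scores))
    chosen = masks[index].astype(bool)
    if invert:
        chosen = ~chosen
    return chosen.astype(np.uint8)
```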
In [None]:
# Segment particles with SAM
if cropped_image is None:
    print("ERROR: No cropped image available. Please run scale detection first.")
else:
    print(f"Segmenting particles in cropped image: {cropped_image.shape}")
    
    masks, scores = segmenter.segment_image(cropped_image, multimask_output=True)
    
    # Visualize masks
    fig = visualize_masks(cropped_image, masks, scores, figsize=(18, 6))
    plt.show()
    
    # Create mask selector
    mask_selector = widgets.Dropdown(
        options=[(f"Mask {i} (score: {scores[i]:.3f})", i) for i in range(len(masks))],
        value=int(np.argmax(scores)),
        description='Select Mask:'
    )
    
    confirm_button = widgets.Button(description='Confirm Selection', button_style='success')
    output = widgets.Output()
    
    selected_mask = None
    
    def on_confirm(b):
        global selected_mask
        selected_mask = segmenter.select_mask(mask_selector.value)
        with output:
            clear_output()
            print(f"✓ Selected mask {segmenter.selected_mask_index} with score {scores[segmenter.selected_mask_index]:.3f}")
            print("Proceed to the next cell for particle analysis and refinement.")
    
    confirm_button.on_click(on_confirm)
    
    display(widgets.VBox([mask_selector, confirm_button, output]))
    
    # Auto-select best mask
    selected_mask = segmenter.select_mask()

## Step 6: Initial Particle Analysis

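`analyze_mask` receives `min_area=50` along with `min_size`, `remove_border` and `border_buffer`, but the notebook does not document what happens inside. One plausible reading of the area filter, sketched with hypothetical helpers: label connected particles (4-connectivity is an assumption; 8-connectivity is equally common), drop components smaller than `min_area` pixels, and convert pixel areas to nm² by multiplying by the square of the nm/pixel conversion factor. `min_size` is left out because its meaning is not stated.

```python
from collections import deque

import numpy as np


def label_components(binary: np.ndarray) -> np.ndarray:
    """Label 4-connected foreground regions 1..N, leaving background at 0."""
    binary = np.asarray(binary, dtype=bool)
    labels = np.zeros(binary.shape, dtype=np.int32)
    h, w = binary.shape
    next_label = 0
    for y in range(h):
        for x in range(w):
            if binary[y, x] and labels[y, x] == 0:
                # Start a new particle and flood-fill its 4-connected pixels
                next_label += 1
                labels[y, x] = next_label
                queue = deque([(y, x)])
                while queue:
                    cy, cx = queue.popleft()
                    for ny, nx in ((cy - 1, cx), (cy + 1, cx), (cy, cx - 1), (cy, cx + 1)):
                        if 0 <= ny < h and 0 <= nx < w and binary[ny, nx] and labels[ny, nx] == 0:
                            labels[ny, nx] = next_label
                            queue.append((ny, nx))
    return labels


def filter_by_area(labels: np.ndarray, min_area: int = 50) -> list:
    """Return the labels whose pixel count is at least min_area."""
    counts = np.bincount(labels.ravel())
    return [lab for lab in range(1, len(counts)) if counts[lab] >= min_area]


def area_nm2(pixel_count: int, nm_per_px: float) -> float:
    """Each pixel covers nm_per_px * nm_per_px square nanometres."""
    return pixel_count * nm_per_px ** 2
```

The pure-Python flood fill keeps the sketch dependency-free; on full-size SEM images a vectorised labeller such as `scipy.ndimage.label` or `skimage.measure.label` is far faster.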
In [None]:
# Analyze particles
if selected_mask is None:
    print("ERROR: No mask selected. Please run the previous cell.")
else:
    # Get binary mask (inverted so particles = 1)
    binary_mask = segmenter.get_binary_mask(invert=True)
    
    # Create analyzer
    analyzer = ParticleAnalyzer(conversion_factor=conversion_factor)
    
    # Analyze mask
    num_particles, regions = analyzer.analyze_mask(
        binary_mask,
        min_area=50,
        min_size=30,
        remove_border=True,
        border_buffer=4
    )
    
    print(f"\n✓ Detected {num_particles} particles")
    
    # Get measurements
    measurements = analyzer.get_measurements(in_nm=True)
    print_summary(measurements, title="Initial Analysis")

## Step 7: Interactive Particle Refinement

### üéØ THIS IS THE KEY INTERACTIVE CELL! üéØ

## How to use:

### Select/Delete Mode (default):
- **Left-click** on a particle to queue it for deletion (turns yellow)
- **Right-click** to queue adding a particle at that location
- Click **"Update"** to apply queued operations
- Click **"Clear queued ops"** to cancel

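Deletions are queued rather than applied immediately, so a stray click can be undone with **Clear queued ops** before **Update** commits anything. A minimal sketch of that queue on a labelled particle image follows; `DeleteQueue` is a hypothetical class, and it assumes clicks arrive as matplotlib data coordinates, where `imshow` centres pixel (row, col) at (x=col, y=row). Queued additions need SAM and are not modelled.

```python
from typing import Optional

import numpy as np


class DeleteQueue:
    """Hypothetical queue of pending deletions on a labelled particle image."""

    def __init__(self, labels: np.ndarray):
        self.labels = labels.copy()  # work on a copy so the caller's array is untouched
        self.pending = set()

    def click(self, x: Optional[float], y: Optional[float]) -> int:
        """Queue the particle under a click and return its label (0 = nothing queued).

        matplotlib reports x/y as None for clicks outside the axes.
        """
        if x is None or y is None:
            return 0
        row, col = int(round(y)), int(round(x))
        h, w = self.labels.shape
        if not (0 <= row < h and 0 <= col < w):
            return 0
        lab = int(self.labels[row, col])
        if lab > 0:
            self.pending.add(lab)
        return lab

    def clear(self) -> None:
        """Counterpart of 'Clear queued ops'."""
        self.pending.clear()

    def update(self) -> np.ndarray:
        """Counterpart of 'Update': erase every queued particle, then empty the queue."""
        if self.pending:
            self.labels[np.isin(self.labels, list(self.pending))] = 0
        self.pending.clear()
        return self.labels
```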
### Merge Mode:
- Toggle **"Merge mode"** ON
- **Left-click** multiple particles to select for merging (turn cyan)
- Click **"Merge selected"** to combine them into one particle

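Merging can be modelled as relabelling: every selected particle takes a single shared label, here the smallest one. The `merge_labels` helper is hypothetical, and the notebook does not say whether the refiner also fills the gap between non-touching particles, so this sketch only relabels.

```python
import numpy as np


def merge_labels(labels: np.ndarray, selected) -> np.ndarray:
    """Return a copy in which every label in `selected` becomes the smallest of them.

    Background (0) is ignored; pixels between non-touching particles are not filled.
    """
    chosen = sorted({int(s) for s in selected if s > 0})
    out = labels.copy()
    if len(chosen) < 2:
        return out  # nothing to merge
    target = chosen[0]
    out[np.isin(out, chosen[1:])] = target
    return out
```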
### SAM Refinement Mode:
- Switch to **"Refine with SAM"** mode
- **Left-click** to add positive points (green +)
- **Right-click** to add negative points (red ×)
- SAM will refine the segmentation in real time
- Click **"Apply SAM to mask"** when satisfied

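Positive and negative clicks map onto SAM's point prompts. In the `segment_anything` package, `SamPredictor.predict` takes `point_coords` as an N×2 array of (x, y) pixel coordinates and `point_labels` as a length-N array in which 1 marks foreground and 0 background. The sketch below converts a list of clicks into those arrays; `clicks_to_prompts` is a hypothetical helper, it assumes matplotlib's button numbering (1 = left, 3 = right), and how `InteractiveRefiner` builds its prompts internally is not shown here.

```python
import numpy as np


def clicks_to_prompts(clicks):
    """Convert (x, y, button) clicks into SAM point prompts.

    Left button (1) gives label 1 (positive), right button (3) gives label 0
    (negative); other buttons are ignored. Coordinates stay in (x, y) order.
    """
    coords, labels = [], []
    for x, y, button in clicks:
        if button == 1:
            labels.append(1)
        elif button == 3:
            labels.append(0)
        else:
            continue
        coords.append((float(x), float(y)))
    return np.array(coords, dtype=np.float32).reshape(-1, 2), np.array(labels, dtype=np.int32)


# With a real predictor (not run here; predictor.set_image(image) must be called first):
# masks, scores, logits = predictor.predict(point_coords=coords,
#                                           point_labels=labels,
#                                           multimask_output=False)
```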
### Other Controls:
- **"Clear Edges"** with buffer: Remove particles near borders
- **"Finish"**: Save current results to CSV

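**Clear Edges** removes particles that come within the buffer distance of the image border; such particles are usually cut off by the frame, and their truncated areas would bias the size distribution downward. A hypothetical numpy sketch, assuming "within the buffer" means a pixel whose distance to the nearest edge is at most `buffer`:

```python
import numpy as np


def clear_edges(labels: np.ndarray, buffer: int = 4) -> np.ndarray:
    """Return a copy with every particle that comes within `buffer` pixels of an edge removed."""
    h, w = labels.shape
    rows = np.arange(h)[:, None]
    cols = np.arange(w)[None, :]
    # Distance from each pixel to the nearest image edge (0 on the outermost ring)
    dist = np.minimum(np.minimum(rows, h - 1 - rows), np.minimum(cols, w - 1 - cols))
    edge_labels = np.unique(labels[(dist <= buffer) & (labels > 0)])
    out = labels.copy()
    out[np.isin(out, edge_labels)] = 0
    return out
```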
### Views:
- **Left**: Original image with particle outlines and labels
- **Right**: Binary mask (white = particles, black = background)

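Both views can be derived from a labelled particle image: the outline overlay marks particle pixels that have a 4-neighbour with a different label, and the mask view maps particles to white and background to black. A hypothetical sketch of those two arrays follows; the refiner's actual drawing code is not shown in this notebook.

```python
import numpy as np


def label_outlines(labels: np.ndarray) -> np.ndarray:
    """Boolean mask of particle pixels that have a 4-neighbour with a different label."""
    padded = np.pad(labels, 1, mode="constant", constant_values=0)
    center = padded[1:-1, 1:-1]
    differs = np.zeros(labels.shape, dtype=bool)
    for dy, dx in ((-1, 0), (1, 0), (0, -1), (0, 1)):
        # Shifted view of the padded image, aligned with `center`
        neighbour = padded[1 + dy:padded.shape[0] - 1 + dy, 1 + dx:padded.shape[1] - 1 + dx]
        differs |= neighbour != center
    return differs & (labels > 0)


def mask_view(labels: np.ndarray) -> np.ndarray:
    """Right-hand view: particles white (255), background black (0)."""
    return np.where(labels > 0, 255, 0).astype(np.uint8)
```

The outline array can be drawn over the original with `ax.imshow(np.ma.masked_where(~outlines, outlines), cmap='autumn')`; masked pixels render transparent, so only the outlines appear on top.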
In [None]:
# Create interactive refiner
if 'analyzer' not in globals() or analyzer.mask is None:
    print("ERROR: No analyzer available. Please run the previous cell.")
else:
    # Define callback to save results
    def save_results(measurements):
        filename = os.path.basename(current_image_path) if current_image_path else "unknown"
        results_manager.add_result(filename, measurements)
        print(f"\n✓ Results saved for '{filename}'")
    
    # Create interactive refiner
    refiner = InteractiveRefiner(
        image=cropped_image,
        analyzer=analyzer,
        segmenter=segmenter,
        results_callback=save_results
    )
    
    # Display the interactive interface
    print("\n" + "="*80)
    print("INTERACTIVE PARTICLE REFINEMENT")
    print("="*80)
    print("\nInstructions:")
    print("  • Select/Delete mode: Left-click=queue delete, Right-click=queue add, then 'Update'")
    print("  • Merge mode: Left-click particles to merge, then 'Merge selected'")
    print("  • SAM mode: Left-click=positive (+), Right-click=negative (−)")
    print("  • Click 'Finish' when done to save results\n")
    
    refiner.display()

## Step 8: View Final Results

In [None]:
# Get final measurements after refinement
if 'refiner' in locals():
    final_measurements = refiner.get_measurements(in_nm=True)
    print_summary(final_measurements, title="Final Refined Analysis")
    
    # Get final mask
    final_mask = refiner.get_final_mask()
    print(f"\nFinal mask shape: {final_mask.shape}")
    print(f"Total pixels segmented: {final_mask.sum()}")

## Step 9: View All Saved Results

In [None]:
# Print summary of all results
results_manager.print_summary()

# View results dataframe
results_df = results_manager.get_results()
print("\nResults DataFrame:")
display(results_df)

## Process Next Image

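To process the next image, re-run the **Step 3** (Image Selection) cell and use **Jump** to go to the next image number: re-running that cell resets the index to the first image and clears the duplicate history. Then work through the remaining steps again.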

All results are automatically saved to the CSV file!

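Each **Finish** passes the measurements to `results_manager.add_result`, and the `ResultsManager` created with `OUTPUT_CSV` handles persistence; its internals are not shown. A hypothetical sketch of an append-only CSV writer that lets results accumulate across images, writing the header only when the file is new or empty. The per-particle `area_nm2` key and the summary columns are assumptions for illustration.

```python
import csv
import os


def append_result(csv_path, filename, measurements):
    """Append one summary row for an image; write the header only for a new or empty file.

    `measurements` is assumed to be a list of per-particle dicts with an 'area_nm2' key.
    """
    areas = [m["area_nm2"] for m in measurements]
    row = {
        "filename": filename,
        "n_particles": len(areas),
        "mean_area_nm2": sum(areas) / len(areas) if areas else 0.0,
    }
    new_file = not os.path.exists(csv_path) or os.path.getsize(csv_path) == 0
    with open(csv_path, "a", newline="") as fh:
        writer = csv.DictWriter(fh, fieldnames=list(row))
        if new_file:
            writer.writeheader()
        writer.writerow(row)
    return row
```

Opening in append mode means a kernel restart never overwrites earlier rows, at the cost of possible duplicate rows if the same image is finished twice.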
## Summary

This notebook provides the **complete interactive workflow** from the original notebook, now fully refactored and modular:

✅ **All original features preserved:**
- Interactive image browsing
- Adjustable scale bar detection
- Multiple SAM mask candidates
- Click-to-delete particles
- Click-to-add particles
- Merge particles
- Live SAM refinement with +/− points
- Edge clearing
- Results management

✅ **Benefits of refactored code:**
- Clean, reusable API
- Modular components
- Easy to extend
- Properly documented
- Can be used in scripts or notebooks