## Setup & Imports

In [49]:
!pip install ipywidgets
!pip install torch torchvision pandas scikit-learn scikit-image


[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.2.1[0m[39;49m -> [0m[32;49m25.3[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m

[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.2.1[0m[39;49m -> [0m[32;49m25.3[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m


In [None]:
import os
import sys
from pathlib import Path
import numpy as np
import cv2
import matplotlib.pyplot as plt
from PIL import Image
import ipywidgets as widgets
from IPython.display import display, clear_output
from skimage.filters import threshold_niblack
import warnings
warnings.filterwarnings('ignore')

# Get the current notebook directory and project root
notebook_dir = Path(os.getcwd())
project_root = notebook_dir  # All files are in the same directory

# Add pipeline modules to path
pipeline_dir = notebook_dir  # All modules are in this directory
sys.path.insert(0, str(pipeline_dir))

# Import pipeline modules
import importlib.util


def _load_pipeline_module(module_name, filename):
    """Load ``pipeline_dir / filename`` and return it as a module named *module_name*.

    The pipeline files have names such as ``1.cnn_inference.py`` that are not
    valid identifiers, so they are loaded through ``importlib`` rather than with
    an ``import`` statement. The module object is returned only after it has
    executed without error, so a failed load never leaves a half-initialized
    module bound to a global name.

    Raises:
        ImportError: if no module spec/loader can be created for the file.
        Exception: anything raised while executing the module (for example
            ``FileNotFoundError`` when the file does not exist).
    """
    module_file = pipeline_dir / filename
    spec = importlib.util.spec_from_file_location(module_name, module_file)
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot load module {module_name!r} from {module_file}")
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module


# Load modules
cnn_inference = _load_pipeline_module("cnn_inference", "1.cnn_inference.py")
segmentation = _load_pipeline_module("segmentation", "2.segmentation.py")
dan_segmentation = _load_pipeline_module("dan_segmentation", "2.2.dansegmentation.py")

# Import Color Wheel
colorwheel = _load_pipeline_module("colorwheel", "3.colorwheel.py")

# Voronoi analysis is optional: voronoi_v7.py may be absent from this folder.
# If you need Voronoi analysis, please add voronoi_v7.py to the AFM_Web folder.
# voronoi_v7 stays None unless the module loads completely, which is what the
# `voronoi_v7 is None` check in run_voronoi_step relies on.
voronoi_v7 = None
try:
    voronoi_v7 = _load_pipeline_module("voronoi_v7", "voronoi_v7.py")
    print("Voronoi module loaded successfully!")
except Exception as e:
    print(f"Voronoi module not available: {e}")
    print("Voronoi analysis will be disabled.")

if voronoi_v7 is None:
    print("Core modules loaded successfully (Voronoi analysis disabled).")
else:
    print("All modules loaded successfully!")

All modules loaded successfully!


## Pipeline Functions

In [None]:
# Modular pipeline functions
from skimage.measure import label, regionprops
import skimage
from skimage import color

def run_cnn_classification_step(image_path):
    """Step 1: Classify image using CNN (on binarized image) - SILENT"""
    img_path = Path(image_path)
    if not img_path.exists():
        raise FileNotFoundError(f"Image not found: {img_path}")
    
    # Create temporary binarized image for classification
    img = cv2.imread(str(img_path), cv2.IMREAD_GRAYSCALE)
    _, binary_img = cv2.threshold(img, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    
    temp_binary_path = str(img_path).replace('.png', '_temp_binary.png')
    cv2.imwrite(temp_binary_path, binary_img)
    
    # Load CNN model and classify - model is in the same directory
    model_path = pipeline_dir / "cnn_classifier.pth"
    cnn_model = cnn_inference.load_model(str(model_path))
    result = cnn_inference.predict_image(cnn_model, str(temp_binary_path))
    
    # Clean up temporary file
    if Path(temp_binary_path).exists():
        Path(temp_binary_path).unlink()
    
    return result

def run_segmentation_step(image_path, threshold=0.5, denoise=0, sharpen=0, invert=False):
    """Step 2: U-Net segmentation - SILENT"""
    mask_path = segmentation.segment_image(
        image_path=image_path,
        model_path=str(pipeline_dir / 'best_quality_unet.pt'),
        output_dir='segmentation_output',
        threshold=threshold,
        denoise=denoise,
        sharpen=sharpen,
        invert=False
    )
    
    # Optionally invert mask
    if invert:
        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        mask = 255 - mask
        mask_path_inv = mask_path.replace('_mask.png', '_mask_inverted.png')
        cv2.imwrite(mask_path_inv, mask)
        mask_path = mask_path_inv
    
    return mask_path

def extract_dots_step(mask_path, min_circularity=0.6, max_aspect_ratio=1.8, min_area=15, max_area=400):
    """Step 3: Extract dots only - SILENT"""
    img = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
    _, binary = cv2.threshold(img, 127, 255, cv2.THRESH_BINARY)
    
    # Morphological opening
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
    opened = cv2.morphologyEx(binary, cv2.MORPH_OPEN, kernel, iterations=1)
    
    # Find connected components
    labels_img = label(opened > 0)
    props = regionprops(labels_img)
    
    # Filter by shape and create 5x5 dots
    output = np.zeros_like(binary)
    stats = {'total': len(props), 'kept': 0}
    
    for prop in props:
        area = prop.area
        perimeter = prop.perimeter
        circularity = 4 * np.pi * area / (perimeter ** 2) if perimeter > 0 else 0
        aspect = prop.major_axis_length / prop.minor_axis_length if prop.minor_axis_length > 0 else float('inf')
        
        if (min_area <= area <= max_area and circularity >= min_circularity and aspect <= max_aspect_ratio):
            cy, cx = map(int, prop.centroid)
            # Create consistent 5x5 dots
            y1, y2 = max(cy-2, 0), min(cy+3, output.shape[0])
            x1, x2 = max(cx-2, 0), min(cx+3, output.shape[1])
            output[y1:y2, x1:x2] = 255
            stats['kept'] += 1
    
    output_path = mask_path.replace('.png', '_DOTS_ONLY.png')
    cv2.imwrite(output_path, output)
    
    return output_path, stats

def run_voronoi_step(mask_path, image_size=1.0, threshold_edge=0.025, max_size=1024, auto_detect_features=True):
    """Step 4: Voronoi analysis - SILENT

    The mask is loaded as grayscale and oriented so that features are black on
    a white background (the orientation this step prepares for
    ``voronoi_v7.analyze_image``). It is then scaled to [0, 1], downsampled so
    its longest side is at most ``max_size``, and analyzed with outputs saved
    under ``voronoi_outputs/``.

    Args:
        mask_path: Path to the mask image
        image_size: Physical size of image in micrometers
        threshold_edge: Edge detection threshold
        max_size: Maximum image dimension
        auto_detect_features: If True, treat the minority phase as the features
            and invert only when that phase is white; if False, always invert.

    Returns:
        The result of ``voronoi_v7.analyze_image``, or
        ``{'error': 'Voronoi module not available'}`` when the module is not loaded.
    """
    if voronoi_v7 is None:
        print("Voronoi module not available. Please add voronoi_v7.py to the AFM_Web folder.")
        return {'error': 'Voronoi module not available'}

    pixels = np.array(Image.open(mask_path).convert("L"))

    if not auto_detect_features:
        # Manual inversion (old behavior)
        pixels = 255 - pixels
        print("  Voronoi: Manual inversion applied")
    else:
        n_white = np.sum(pixels >= 128)
        n_black = np.sum(pixels < 128)
        # The minority phase is taken to be the features; the analysis wants
        # black features (0) on a white background (1).
        if n_white < n_black:
            pixels = 255 - pixels
            print("  Voronoi: Auto-detected WHITE features on BLACK background → Inverting")
        else:
            print("  Voronoi: Auto-detected BLACK features on WHITE background → No inversion")

    data = pixels.astype(float) / 255.0

    # Downsample so the longest side does not exceed max_size.
    longest_side = max(data.shape)
    if longest_side > max_size:
        from skimage.transform import resize
        factor = max_size / longest_side
        target_shape = tuple(int(dim * factor) for dim in data.shape)
        data = resize(data, target_shape, anti_aliasing=True, preserve_range=True)

    image_name = Path(mask_path).stem
    output_dir = 'voronoi_outputs'
    # Pre-create voronoi_outputs/<image_name>/ (parents included).
    os.makedirs(os.path.join(output_dir, image_name), exist_ok=True)

    return voronoi_v7.analyze_image(
        image_data=data,
        image_name=image_name,
        image_size=image_size,
        save_image=True,
        show_image=False,
        save_location=output_dir,
        threshold_edge=threshold_edge,
    )

def run_dan_binarization_step(image_path, method='adaptive', adaptive_method='gaussian',
                               block_size=11, C=2, niblack_window=25, k=0.1,
                               blur='none', equalize=False, clahe=False):
    """Step 5a: Dan's binarization - SILENT

    Thin wrapper around ``dan_segmentation.binarize_image`` that forwards every
    option unchanged and writes into the ``dan_binarized`` directory.

    Returns:
        The path returned by ``dan_segmentation.binarize_image``.
    """
    options = dict(
        method=method,
        adaptive_method=adaptive_method,
        block_size=block_size,
        C=C,
        niblack_window=niblack_window,
        k=k,
        blur=blur,
        equalize=equalize,
        clahe=clahe,
    )
    return dan_segmentation.binarize_image(image_path=image_path, output_dir='dan_binarized', **options)

def run_dan_spacing_step(binary_path, image_size_um=2.0, invert=False):
    """Step 5b: Dan's spacing analysis - SILENT

    Delegates to ``dan_segmentation.analyze_spacing``, saving its visualizations
    under ``dan_spacing_output``, and returns whatever that call returns.
    """
    return dan_segmentation.analyze_spacing(
        image_path=binary_path,
        image_size_um=image_size_um,
        invert=invert,
        output_dir='dan_spacing_output',
        save_viz=True,
    )

def run_colorwheel_step(mask_path, num_clusters=8):
    """Step 6: Color wheel analysis for line orientations - Uses MASK not original - SILENT

    Results are written to ``colorwheel_output/<mask stem>`` and the value
    returned by ``colorwheel.analyze_image`` is passed back unchanged.
    """
    return colorwheel.analyze_image(
        image_path=mask_path,
        output_dir=f'colorwheel_output/{Path(mask_path).stem}',
        num_clusters=num_clusters,
    )


print("Pipeline functions ready")

Pipeline functions ready


## Automated Pipeline with Controls

In [None]:
# ===== INTERACTIVE STEP-BY-STEP PIPELINE =====

# Shared, mutable pipeline state; the step callbacks read from and write to it.
state = {"step": 0}

# ===== STEP 1: IMAGE INPUT & CLASSIFICATION =====
step1_container = widgets.Output()

# Path of the image to analyze, pre-filled with the bundled "dots" test image.
image_path_widget = widgets.Text(
    description="Image Path:",
    placeholder="Enter image path...",
    value=str(pipeline_dir / "Cnn_classifier_test" / "dots.png"),
    style={"description_width": "100px"},
    layout=widgets.Layout(width="600px"),
)

classify_button = widgets.Button(
    layout=widgets.Layout(width="200px", height="40px"),
    button_style="primary",
    description="Classify Image",
)

# Receives the classification figure drawn by on_classify.
step1_output = widgets.Output()

# State entries derived from a classified image. They are cleared whenever a new
# image is classified, so later steps cannot pick up the previous image's results.
_DOWNSTREAM_STATE_KEYS = (
    "segmentation_mask",
    "dots_mask",
    "dot_stats",
    "voronoi_results",
    "colorwheel_results",
)


def on_classify(b):
    """Classify the image in ``image_path_widget`` and reveal the matching next step.

    Stores the CNN result and the input path in ``state`` and discards results
    that belonged to a previously classified image. Shows the image titled with
    its predicted class and confidence, then shows step 2 (dots/lines/mixed) or
    step 3 (any other class, e.g. irregular).

    Args:
        b: The clicked button (unused; required by ``Button.on_click``).
    """
    with step1_output:
        clear_output(wait=True)
        try:
            image_path = image_path_widget.value
            # Run CNN classification (silent)
            cnn_result = run_cnn_classification_step(image_path)

            for key in _DOWNSTREAM_STATE_KEYS:
                state.pop(key, None)
            state["cnn_result"] = cnn_result
            state["input_image"] = image_path
            state["step"] = 1

            # Panels from a previous image's analysis may still be visible; hide
            # them until on_segment / on_extract_dots reveal the ones that apply.
            dot_extract_container.layout.display = "none"
            voronoi_container.layout.display = "none"
            colorwheel_container.layout.display = "none"

            predicted_class = cnn_result["predicted_class"]
            confidence = cnn_result["confidence"] * 100

            # Show classification result with image
            fig, ax = plt.subplots(1, 1, figsize=(6, 6))
            ax.imshow(Image.open(image_path), cmap="gray")
            ax.set_title(
                f"Classification: {predicted_class.upper()} ({confidence:.1f}%)",
                fontsize=14,
                fontweight="bold",
            )
            ax.axis("off")
            plt.tight_layout()
            plt.show()

            # Enable next step
            if predicted_class in ["dots", "lines", "mixed"]:
                step2_container.layout.display = "block"
                step3_container.layout.display = "none"
            else:  # irregular
                step2_container.layout.display = "none"
                step3_container.layout.display = "block"

        except Exception as e:
            print(f"Error: {e}")


classify_button.on_click(on_classify)

# ===== STEP 2: DOTS/LINES/MIXED PATH =====
# Hidden until on_classify predicts dots, lines or mixed.
step2_container = widgets.Output(layout=widgets.Layout(display="none"))

# Preprocessing controls (forwarded to the U-Net segmentation step)
seg_denoise = widgets.IntSlider(description="Denoise:", value=0, min=0, max=30, step=5)
seg_sharpen = widgets.IntSlider(description="Sharpen:", value=0, min=0, max=10, step=1)

# Postprocessing controls
seg_threshold = widgets.FloatSlider(description="Threshold:", value=0.5, min=0.1, max=0.9, step=0.05)
seg_remove_noise = widgets.IntSlider(description="Remove Noise:", value=0, min=0, max=5, step=1)
seg_invert = widgets.Checkbox(description="Invert Mask", value=False)

segment_button = widgets.Button(
    layout=widgets.Layout(width="200px", height="40px"),
    button_style="info",
    description="Segment Image",
)

# Receives the original-vs-mask figure drawn by on_segment.
step2_output = widgets.Output()


def on_segment(b):
    """Segment the classified image with the U-Net and reveal the analysis panels.

    Uses the segmentation widgets' current values and optionally applies a
    morphological opening ("Remove Noise") to the written mask. Stores the mask
    path in ``state``, shows it next to the original image, and then shows the
    dot-extraction and/or color-wheel panels according to the predicted class.

    Args:
        b: The clicked button (unused; required by ``Button.on_click``).
    """
    # predicted class -> display values for (dot extraction, voronoi, color wheel).
    # Voronoi always starts hidden; it is revealed after dot extraction.
    panel_routes = {
        "dots": ("block", "none", "none"),
        "lines": ("none", "none", "block"),
        "mixed": ("block", "none", "block"),
    }

    with step2_output:
        clear_output(wait=True)
        try:
            mask_path = run_segmentation_step(
                state["input_image"],
                threshold=seg_threshold.value,
                denoise=seg_denoise.value,
                sharpen=seg_sharpen.value,
                invert=seg_invert.value,
            )

            ksize = seg_remove_noise.value
            if ksize > 0:
                # Opening with an elliptical kernel removes bright specks smaller than it.
                mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
                noise_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (ksize, ksize))
                cv2.imwrite(mask_path, cv2.morphologyEx(mask, cv2.MORPH_OPEN, noise_kernel))

            state["segmentation_mask"] = mask_path
            state["step"] = 2

            fig, axes = plt.subplots(1, 2, figsize=(12, 6))
            panels = (
                (state["input_image"], "Original Image"),
                (mask_path, "U-Net Segmentation"),
            )
            for ax, (path, title) in zip(axes, panels):
                ax.imshow(Image.open(path), cmap="gray")
                ax.set_title(title, fontsize=12, fontweight="bold")
                ax.axis("off")
            plt.tight_layout()
            plt.show()

            route = panel_routes.get(state["cnn_result"]["predicted_class"])
            if route is not None:
                (dot_extract_container.layout.display,
                 voronoi_container.layout.display,
                 colorwheel_container.layout.display) = route

        except Exception as e:
            print(f"Error: {e}")


segment_button.on_click(on_segment)

# Dot extraction controls (hidden until on_segment reveals them for dots/mixed)
dot_extract_container = widgets.VBox(layout=widgets.Layout(display="none"))

dot_min_area = widgets.IntSlider(description="Min Area:", value=4, min=1, max=50, step=1)
dot_max_area = widgets.IntSlider(description="Max Area:", value=200, min=50, max=500, step=10)
dot_circularity = widgets.FloatSlider(description="Min Circular:", value=0.7, min=0.1, max=1.0, step=0.05)
dot_aspect_ratio = widgets.FloatSlider(description="Max Aspect:", value=2.5, min=1.0, max=5.0, step=0.1)

extract_dots_button = widgets.Button(
    layout=widgets.Layout(width="200px", height="40px"),
    button_style="info",
    description="Extract Dots",
)

# Receives the before/after figure drawn by on_extract_dots.
step2_dot_output = widgets.Output()


# Minimum number of extracted dots before the Voronoi controls are offered.
_MIN_DOTS_FOR_VORONOI = 4


def on_extract_dots(b):
    """Extract round dot features from the segmentation mask and show before/after.

    Stores the dots-only mask path and the extraction counts in ``state``. The
    Voronoi controls are shown only when at least ``_MIN_DOTS_FOR_VORONOI`` dots
    were kept; otherwise they are hidden.

    Args:
        b: The clicked button (unused; required by ``Button.on_click``).
    """
    with step2_dot_output:
        clear_output(wait=True)
        try:
            # Extract dots (silent)
            dots_path, dot_stats = extract_dots_step(
                state["segmentation_mask"],
                min_circularity=dot_circularity.value,
                max_aspect_ratio=dot_aspect_ratio.value,
                min_area=dot_min_area.value,
                max_area=dot_max_area.value,
            )
            state["dots_mask"] = dots_path
            state["dot_stats"] = dot_stats
            # Voronoi results computed from an earlier extraction no longer apply.
            state.pop("voronoi_results", None)
            state["step"] = 3

            # Show before/after
            fig, axes = plt.subplots(1, 2, figsize=(12, 6))

            mask_before = cv2.imread(state["segmentation_mask"], cv2.IMREAD_GRAYSCALE)
            axes[0].imshow(mask_before, cmap="gray")
            axes[0].set_title("Before Extraction", fontsize=12, fontweight="bold")
            axes[0].axis("off")

            mask_after = cv2.imread(dots_path, cv2.IMREAD_GRAYSCALE)
            axes[1].imshow(mask_after, cmap="gray")
            axes[1].set_title(
                f'Extracted Dots ({dot_stats["kept"]} features)',
                fontsize=12,
                fontweight="bold",
            )
            axes[1].axis("off")

            plt.tight_layout()
            plt.show()

            if dot_stats["kept"] >= _MIN_DOTS_FOR_VORONOI:
                voronoi_container.layout.display = "block"
            else:
                # Hide controls left visible by an earlier successful extraction so
                # Voronoi cannot be run on a mask with too few dots.
                voronoi_container.layout.display = "none"
                print(
                    f"Only {dot_stats['kept']} dots found; at least "
                    f"{_MIN_DOTS_FOR_VORONOI} are needed for Voronoi analysis."
                )

        except Exception as e:
            print(f"Error: {e}")



extract_dots_button.on_click(on_extract_dots)

dot_extract_container.children = (
    widgets.HTML("<h4>Dot Extraction Parameters</h4>"),
    dot_min_area, dot_max_area, dot_circularity, dot_aspect_ratio,
    extract_dots_button,
    step2_dot_output,
)

# Voronoi controls (revealed by on_extract_dots when at least 4 dots are kept)
voronoi_container = widgets.VBox(layout=widgets.Layout(display="none"))

voronoi_image_size = widgets.FloatText(description="Image Size (μm):", value=1.0)
voronoi_threshold_edge = widgets.FloatSlider(
    description="Edge Threshold:",
    value=0.025, min=0.01, max=0.1, step=0.005,
    readout_format=".3f",
)
voronoi_max_size = widgets.IntSlider(description="Max Size (px):", value=1024, min=512, max=2048, step=256)
voronoi_auto_detect = widgets.Checkbox(
    description="Auto-detect features",
    value=True,
    tooltip="Automatically detect which phase is features (minority pixels)",
)

run_voronoi_button = widgets.Button(
    layout=widgets.Layout(width="200px", height="40px"),
    button_style="success",
    description="Run Voronoi Analysis",
)

# Receives the overlay figure and summary printed by on_run_voronoi.
step2_voronoi_output = widgets.Output()


def on_run_voronoi(b):
    with step2_voronoi_output:
        clear_output(wait=True)
        try:
            # Run voronoi analysis (silent) - uses extracted dots mask
            voronoi_results = run_voronoi_step(
                state["dots_mask"],
                image_size=voronoi_image_size.value,
                threshold_edge=voronoi_threshold_edge.value,
                max_size=voronoi_max_size.value,
                auto_detect_features=voronoi_auto_detect.value,
            )
            state["voronoi_results"] = voronoi_results

            # Show Voronoi visualization
            image_name = Path(state["dots_mask"]).stem
            overlay_path = (
                Path("voronoi_outputs")
                / image_name
                / f"{image_name}_voronoi_overlay.png"
            )
            # Display overlay image if it exists
            # output_dir = Path(f"voronoi_output/{Path(state['segmentation_mask']).stem}")
            # overlay_path = output_dir / "overlay_voronoi.png"

            if overlay_path.exists():
                fig, ax = plt.subplots(1, 1, figsize=(10, 10))
                img = Image.open(overlay_path)
                ax.imshow(img)
                ax.set_title("Voronoi Tessellation", fontsize=14, fontweight="bold")
                ax.axis("off")
                plt.tight_layout()
                plt.show()

            # Display key results
            print("=" * 50)
            print("VORONOI ANALYSIS RESULTS")
            print("=" * 50)
            print(f"Periodicity: {voronoi_results.get('periodicity', 'N/A'):.2f} nm")
            print(f"Block Ratio: {voronoi_results.get('block_ratio', 'N/A'):.4f}")
            print(f"Mean Morphology: {voronoi_results.get('mean_morph', 'N/A'):.2f}")
            print(f"Number of Dots: {state['dot_stats']['kept']}")
            print("=" * 50)

        except Exception as e:
            print(f"Error: {e}")


run_voronoi_button.on_click(on_run_voronoi)

voronoi_container.children = (
    widgets.HTML("<h4>Voronoi Analysis Parameters</h4>"),
    voronoi_image_size, voronoi_threshold_edge, voronoi_max_size, voronoi_auto_detect,
    run_voronoi_button,
    step2_voronoi_output,
)

# Color Wheel controls (for lines/mixed features)
colorwheel_container = widgets.VBox(layout=widgets.Layout(display="none"))

colorwheel_num_clusters = widgets.IntSlider(description="Num Clusters:", value=8, min=2, max=12, step=1)
# Read by on_run_colorwheel: when checked, the mask is inverted before analysis.
colorwheel_invert = widgets.Checkbox(description="Invert Image", value=True)

run_colorwheel_button = widgets.Button(
    layout=widgets.Layout(width="200px", height="40px"),
    button_style="success",
    description="Run Color Wheel",
)

# Receives the grain-mask grid and summary printed by on_run_colorwheel.
step2_colorwheel_output = widgets.Output()


def on_run_colorwheel(b):
    with step2_colorwheel_output:
        clear_output(wait=True)
        try:
            # Color wheel works on SEGMENTATION MASK (after U-Net)
            if "segmentation_mask" not in state:
                print("No segmentation mask available. Please run segmentation first.")
                return
            
            # Read segmentation mask
            mask_img = cv2.imread(state["segmentation_mask"], cv2.IMREAD_GRAYSCALE)
            
            # Apply invert if requested
            if colorwheel_invert.value:
                mask_img = 255 - mask_img
            
            # Show the mask being sent to color wheel
            # Hide it for now to reduce output clutter
            # fig, ax = plt.subplots(1, 1, figsize=(6, 6))
            # ax.imshow(mask_img, cmap='gray')
            # ax.set_title(f"Mask Input to Color Wheel {'(Inverted)' if colorwheel_invert.value else ''}", 
            #             fontsize=12, fontweight="bold")
            # ax.axis('off')
            # plt.tight_layout()
            # plt.show()
            
            # Save temporary mask for color wheel processing
            temp_dir = Path('colorwheel_temp')
            temp_dir.mkdir(exist_ok=True)
            temp_img_path = temp_dir / f"{Path(state['segmentation_mask']).stem}_colorwheel_input.png"
            cv2.imwrite(str(temp_img_path), mask_img)
            
            input_mask = str(temp_img_path)
            
            # Run color wheel analysis (silent)
            colorwheel_results = run_colorwheel_step(
                input_mask,
                num_clusters=colorwheel_num_clusters.value
            )
            state["colorwheel_results"] = colorwheel_results

            # Collect all grain mask images
            grain_masks = colorwheel_results.get("grain_masks", [])
            output_dir = Path(colorwheel_results["output_directory"])
            
            # Find all Mask_*.tiff files
            mask_files = sorted(output_dir.glob("Mask_*.tiff"))
            
            if not mask_files:
                print("No grain masks generated")
                return
            
            # Display results in 2x2 grid (or adjust based on number of masks)
            num_masks = len(mask_files)
            if num_masks == 0:
                print("No valid masks found")
                return
            
            # Create grid layout (2x2 for 4 masks, adjust as needed)
            ncols = 2
            nrows = (num_masks + 1) // 2
            
            fig, axes = plt.subplots(nrows, ncols, figsize=(12, 6 * nrows))
            
            # Flatten axes for easy iteration
            if nrows == 1 and ncols == 1:
                axes = [axes]
            elif nrows == 1 or ncols == 1:
                axes = axes.flatten()
            else:
                axes = axes.flatten()
            
            for idx, mask_file in enumerate(mask_files):
                if idx < len(axes):
                    img = Image.open(mask_file)
                    axes[idx].imshow(img)
                    axes[idx].set_title(f"Orientation Mask {idx}", fontsize=12, fontweight="bold")
                    axes[idx].axis("off")
            
            # Hide unused subplots
            for idx in range(len(mask_files), len(axes)):
                axes[idx].axis("off")
            
            plt.tight_layout()
            plt.show()

            # Display key results
            print("=" * 50)
            print("COLOR WHEEL ANALYSIS RESULTS")
            print("=" * 50)
            print(f"Number of Clusters: {colorwheel_results.get('num_clusters', 'N/A')}")
            print(f"Output Directory: {colorwheel_results.get('output_directory', 'N/A')}")
            print(f"Number of Grain Masks: {len(mask_files)}")
            print(f"GPU Accelerated: {colorwheel_results.get('gpu_accelerated', False)}")
            print("=" * 50)

        except Exception as e:
            import traceback
            print(f"Error: {e}")
            print(traceback.format_exc())


run_colorwheel_button.on_click(on_run_colorwheel)

# colorwheel_invert is read by on_run_colorwheel, so it must be part of the panel;
# otherwise the user can never change it from its default (True).
colorwheel_container.children = [
    widgets.HTML("<h4>Color Wheel Analysis (Lines/Mixed)</h4>"),
    colorwheel_num_clusters,
    colorwheel_invert,
    run_colorwheel_button,
    step2_colorwheel_output,
]

# ===== STEP 3: IRREGULAR PATH =====
# Hidden until on_classify predicts a class other than dots/lines/mixed.
step3_container = widgets.Output(layout=widgets.Layout(display="none"))

# Dan's binarization method widgets; this width is shared by all of them.
common_layout = widgets.Layout(width="320px")

use_adaptive = widgets.Checkbox(description="Use Adaptive", value=True)
use_niblack = widgets.Checkbox(description="Use Niblack", value=True)

# How the adaptive and Niblack masks are merged when both are enabled.
combine_method = widgets.Dropdown(
    description="Combine",
    options=["adaptive", "niblack", "AND", "OR", "weighted"],
    value="OR",
)
alpha_weight = widgets.FloatSlider(description="Alpha (weighted only)", value=0.5, min=0.0, max=1.0, step=0.05)

niblack_k = widgets.FloatSlider(description="k", value=0.1, min=-1.0, max=1.0, step=0.05)
niblack_window = widgets.IntSlider(description="Block size", value=25, min=3, max=111, step=2)

# Values are OpenCV's adaptive-threshold method constants (MEAN_C == 0, GAUSSIAN_C == 1).
adaptive_method = widgets.ToggleButtons(
    description="Adaptive Method",
    options=[("MEAN", cv2.ADAPTIVE_THRESH_MEAN_C), ("GAUSSIAN", cv2.ADAPTIVE_THRESH_GAUSSIAN_C)],
)
adaptive_block_size = widgets.IntSlider(description="Block size", value=11, min=3, max=101, step=2)
adaptive_C = widgets.IntSlider(description="C", value=2, min=-50, max=30, step=1)

blur_method = widgets.Dropdown(
    description="Preprocess Blur",
    options=["none", "gaussian", "median", "bilateral"],
    value="none",
)
blur_ksize = widgets.IntSlider(description="Blur kernel size", value=5, min=3, max=21, step=2)
dan_equalize = widgets.Checkbox(description="Equalize", value=False)
dan_clahe = widgets.Checkbox(description="CLAHE", value=False)

# === Conditional Display Logic ===
def toggle_niblack_widgets(change):
    """Enable the Niblack parameter widgets only while Niblack is selected.

    Args:
        change: traitlets change dict; ``change['new']`` is the new checkbox value.
    """
    disabled = not change['new']
    for widget in (niblack_k, niblack_window):
        widget.disabled = disabled

def toggle_adaptive_widgets(change):
    """Enable the adaptive-threshold widgets only while adaptive is selected.

    Args:
        change: traitlets change dict; ``change['new']`` is the new checkbox value.
    """
    disabled = not change['new']
    for widget in (adaptive_method, adaptive_block_size, adaptive_C):
        widget.disabled = disabled

def toggle_alpha_slider(change):
    """Enable the alpha slider only when the "weighted" combine method is chosen.

    Args:
        change: traitlets change dict; ``change['new']`` is the selected method.
    """
    is_weighted = change['new'] == "weighted"
    alpha_weight.disabled = not is_weighted

# Re-run each toggle whenever its controlling widget's value changes.
use_niblack.observe(toggle_niblack_widgets, names=['value'])
use_adaptive.observe(toggle_adaptive_widgets, names=['value'])
combine_method.observe(toggle_alpha_slider, names=['value'])

# Zero-height HTML placeholders for layout spacing (three separate widgets).
niblack_placeholder, adaptive_placeholder, preprocess_placeholder = (
    widgets.HTML(value="", layout=widgets.Layout(height="0px")) for _ in range(3)
)

# Apply the initial enabled/disabled state as if each control had just changed.
toggle_niblack_widgets(dict(new=use_niblack.value))
toggle_adaptive_widgets(dict(new=use_adaptive.value))
toggle_alpha_slider(dict(new=combine_method.value))

# Give every binarization control the shared width and let labels size to their text.
for w in (
    use_adaptive, use_niblack, combine_method, alpha_weight,
    niblack_k, niblack_window,
    adaptive_method, adaptive_block_size, adaptive_C,
    blur_method, blur_ksize,
    dan_equalize, dan_clahe,
):
    w.layout = common_layout
    if hasattr(w, 'style'):
        w.style.description_width = 'initial'

# Interactive preview function (re-runs automatically whenever a bound widget changes)
def interactive_preview_binary(
    use_adapt,
    use_nib,
    comb_method,
    alpha_w,
    nib_k,
    nib_win,
    adapt_method,
    adapt_block,
    adapt_c,
    blur_m,
    blur_k,
    equalize,
    clahe,
):
    """Show a live preview of Dan's binarization with the current parameters.

    Reads ``state["input_image"]`` as grayscale and optionally blurs,
    equalizes and/or applies CLAHE to it. It then thresholds with adaptive
    and/or Niblack thresholding: pixels at/below their local threshold become
    255, the rest 0. The two masks are merged according to ``comb_method``,
    and the result is stored in ``state["preview_binary"]`` (consumed by the
    "Save Binarization" button). Finally it plots three panels: the
    unprocessed image, the processed image with the mask overlay, and the
    unprocessed image with the mask overlay.

    If the image cannot be read, or no mask can be produced, any previous
    ``state["preview_binary"]`` is discarded so a stale mask cannot be saved.

    Args:
        use_adapt: Enable ``cv2.adaptiveThreshold``.
        use_nib: Enable ``skimage.filters.threshold_niblack``.
        comb_method: How to merge the masks when both methods are enabled:
            "AND", "OR", "adaptive", "niblack" or "weighted".
        alpha_w: Weight of the adaptive mask for "weighted" blending. The
            Niblack mask gets ``1 - alpha_w``; the blend is re-thresholded at 127.
        nib_k: Niblack ``k`` parameter.
        nib_win: Niblack ``window_size``.
        adapt_method: ``adaptiveMethod`` argument of ``cv2.adaptiveThreshold``
            (presumably cv2.ADAPTIVE_THRESH_MEAN_C / _GAUSSIAN_C; confirm
            against the dropdown definition).
        adapt_block: ``blockSize`` for ``cv2.adaptiveThreshold``.
        adapt_c: Constant ``C`` for ``cv2.adaptiveThreshold``.
        blur_m: "none", "gaussian", "median" or "bilateral".
        blur_k: Kernel size (bilateral: neighbourhood diameter).
        equalize: Apply global histogram equalization.
        clahe: Apply CLAHE (clipLimit=2.0, 8x8 tiles).
    """
    if "input_image" not in state:
        print("No image loaded. Please classify an image first.")
        return

    # Keep the unmodified image for the "Original" panels. cv2.imread returns
    # None (it does not raise) when the file cannot be read.
    original = cv2.imread(state["input_image"], cv2.IMREAD_GRAYSCALE)
    if original is None:
        state.pop("preview_binary", None)
        print(f"Could not read image: {state['input_image']}")
        return

    # Preprocess. Each OpenCV call returns a new array, so `original` stays intact.
    img = original
    if blur_m == "gaussian":
        img = cv2.GaussianBlur(img, (blur_k, blur_k), 0)
    elif blur_m == "median":
        img = cv2.medianBlur(img, blur_k)
    elif blur_m == "bilateral":
        img = cv2.bilateralFilter(img, blur_k, 75, 75)

    if equalize:
        img = cv2.equalizeHist(img)

    if clahe:
        clahe_obj = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
        img = clahe_obj.apply(img)

    # Binarize - match Dan's implementation (inverted: foreground = 255)
    thresh_adaptive, thresh_niblack = None, None

    if use_adapt:
        thresh_adaptive = cv2.adaptiveThreshold(
            img, 255, adapt_method, cv2.THRESH_BINARY_INV, adapt_block, adapt_c
        )

    if use_nib:
        t_niblack = threshold_niblack(img, window_size=nib_win, k=nib_k)
        thresh_niblack = (img < t_niblack).astype(np.uint8) * 255

    # Combine or select
    final = None
    if use_adapt and not use_nib:
        final = thresh_adaptive
    elif use_nib and not use_adapt:
        final = thresh_niblack
    elif use_adapt and use_nib:
        if comb_method == "AND":
            final = cv2.bitwise_and(thresh_adaptive, thresh_niblack)
        elif comb_method == "OR":
            final = cv2.bitwise_or(thresh_adaptive, thresh_niblack)
        elif comb_method == "adaptive":
            final = thresh_adaptive
        elif comb_method == "niblack":
            final = thresh_niblack
        elif comb_method == "weighted":
            blend = (
                alpha_w * thresh_adaptive.astype(np.float32) +
                (1 - alpha_w) * thresh_niblack.astype(np.float32)
            )
            final = (blend > 127).astype(np.uint8) * 255

    if final is None:
        # Drop any earlier preview so "Save Binarization" cannot write a mask
        # that no longer matches the current settings.
        state.pop("preview_binary", None)
        if use_adapt and use_nib:
            print(f"Unknown combine method: {comb_method!r}")
        else:
            print("Must enable at least one method!")
        return

    # Store for saving
    state["preview_binary"] = final

    # Semi-transparent magenta RGBA overlay on the foreground (255) pixels.
    overlay = np.zeros((*final.shape, 4), dtype=np.float32)
    overlay[final == 255] = [0.988, 0.102, 0.973, 0.75]

    # (background image, draw overlay?, title): each title matches what is drawn.
    panels = (
        (original, False, "Original Gray"),
        (img, True, "Processed + Overlay"),
        (original, True, "Original + Overlay"),
    )
    _, axes = plt.subplots(1, 3, figsize=(18, 6))
    for ax, (background, show_overlay, title) in zip(axes, panels):
        ax.imshow(background, cmap="gray")
        if show_overlay:
            ax.imshow(overlay)
        ax.set_title(title, fontsize=12, fontweight="bold")
        ax.axis("off")

    plt.tight_layout()
    plt.show()


# Create interactive widget. widgets.interactive calls interactive_preview_binary
# with the current widget values and re-runs it when any of them changes; each
# keyword binds a function parameter to one of the existing widgets. Only its
# output area (children[-1]) is placed in the UI. The bound controls are laid
# out manually in the Step 3 columns instead.
interactive_preview = widgets.interactive(
    interactive_preview_binary,
    use_adapt=use_adaptive,
    use_nib=use_niblack,
    comb_method=combine_method,
    alpha_w=alpha_weight,
    nib_k=niblack_k,
    nib_win=niblack_window,
    adapt_method=adaptive_method,
    adapt_block=adaptive_block_size,
    adapt_c=adaptive_C,
    blur_m=blur_method,
    blur_k=blur_ksize,
    equalize=dan_equalize,
    clahe=dan_clahe,
)

# Output area for messages printed by the "Save Binarization" callback.
step3_output = widgets.Output()

# Callback for the "Save Binarization" button: persist the current preview mask.
def on_binarize(b):
    """Save the latest binarization preview and reveal the spacing controls.

    Writes ``state["preview_binary"]`` (set by interactive_preview_binary) to
    ``dan_binarized/<input stem>_binary.png``. On success it records that path
    in ``state["dan_binary"]`` and makes ``spacing_container`` visible. All
    messages go to ``step3_output``. Unexpected errors are printed, not raised,
    so the UI stays responsive.

    Args:
        b: The clicked ``widgets.Button`` (unused).
    """
    with step3_output:
        clear_output(wait=True)
        try:
            # Check if preview exists
            if "preview_binary" not in state:
                print("Please preview binarization first!")
                return

            # Save the binarized image
            output_dir = "dan_binarized"
            os.makedirs(output_dir, exist_ok=True)

            filename = Path(state["input_image"]).stem
            binary_path = os.path.join(output_dir, f"{filename}_binary.png")

            # cv2.imwrite signals failure by returning False rather than raising;
            # don't record (or announce) a file that was never written.
            if not cv2.imwrite(binary_path, state["preview_binary"]):
                print(f"Error: could not write binarized image to {binary_path}")
                return

            state["dan_binary"] = binary_path

            print(f"Binarized image saved: {binary_path}")

            # Enable spacing analysis
            spacing_container.layout.display = "block"

        except Exception as e:
            print(f"Error: {e}")



# "Save Binarization" button; its click handler is on_binarize.
binarize_button = widgets.Button(
    description="Save Binarization",
    button_style="success",
    layout=widgets.Layout(width="200px", height="40px"),
)

binarize_button.on_click(on_binarize)

# Spacing analysis controls. The container starts hidden (display="none");
# on_binarize switches it to "block" after saving a binarization.
spacing_container = widgets.VBox(layout=widgets.Layout(display="none"))

# Image size in μm (per the label), forwarded to run_dan_spacing_step as
# `image_size_um`. The checkbox value is forwarded as `invert`.
spacing_image_size = widgets.FloatText(value=2.0, description="Image Size (μm):")
spacing_invert = widgets.Checkbox(value=False, description="Invert Features")

run_spacing_button = widgets.Button(
    description="Analyze Spacing",
    button_style="warning",
    layout=widgets.Layout(width="200px", height="40px"),
)

# Output area for the spacing-analysis plots and printed results.
step3_spacing_output = widgets.Output()


def on_run_spacing(b):
    """Run Dan's spacing analysis on the current binary mask and show the results.

    Uses the mask saved by "Save Binarization" (``state["dan_binary"]``) when
    present, otherwise ``state["segmentation_mask"]``. Results are stored in
    ``state["spacing_results"]``. The histogram and Voronoi-overlay PNGs are
    looked up in ``dan_spacing_output/``, assumed to be where
    run_dan_spacing_step writes them. Whichever exist are plotted, and a note
    is printed for any that are missing. All messages go to
    ``step3_spacing_output``. Unexpected errors are printed, not raised.

    Args:
        b: The clicked ``widgets.Button`` (unused).
    """
    with step3_spacing_output:
        clear_output(wait=True)
        try:
            # Determine which binary to use
            if "dan_binary" in state:
                binary_path = state["dan_binary"]
            elif "segmentation_mask" in state:
                binary_path = state["segmentation_mask"]
            else:
                print("No binary mask available!")
                return

            # Run spacing analysis (silent)
            spacing_results = run_dan_spacing_step(
                binary_path,
                image_size_um=spacing_image_size.value,
                invert=spacing_invert.value,
            )
            state["spacing_results"] = spacing_results

            # Plot files are named after the binary mask's stem.
            filename_stem = Path(binary_path).stem
            plot_dir = Path("dan_spacing_output")
            plots = [
                (plot_dir / f"{filename_stem}_spacing_histogram.png", "Spacing Histogram"),
                (plot_dir / f"{filename_stem}_voronoi_overlay.png", "Voronoi Skeleton Overlay"),
            ]
            available = [(path, title) for path, title in plots if path.exists()]
            missing = [path for path, _ in plots if not path.exists()]

            if available:
                # squeeze=False keeps `axes` 2-D even when only one plot exists.
                _, axes = plt.subplots(
                    1, len(available), figsize=(8 * len(available), 8), squeeze=False
                )
                for ax, (path, title) in zip(axes[0], available):
                    # Load the pixels eagerly so the file handle is closed right away.
                    with Image.open(path) as plot_img:
                        ax.imshow(np.asarray(plot_img))
                    ax.set_title(title, fontsize=12, fontweight="bold")
                    ax.axis("off")

                plt.tight_layout()
                plt.show()

            for path in missing:
                print(f"Spacing plot not found: {path}")

            # Display results
            print("=" * 50)
            print("SPACING ANALYSIS RESULTS")
            print("=" * 50)
            print(f"Mean Spacing: {spacing_results['mean_spacing_nm']:.2f} nm")
            print(f"Median Spacing: {spacing_results['median_spacing_nm']:.2f} nm")
            print(f"Std Deviation: {spacing_results['std_spacing_nm']:.2f} nm")
            print("=" * 50)

        except Exception as e:
            print(f"Error: {e}")


# Register the spacing-analysis callback.
run_spacing_button.on_click(on_run_spacing)

# Contents of the spacing panel (the container is created hidden and revealed
# by on_binarize).
spacing_container.children = [
    widgets.HTML("<h4>Spacing Analysis Parameters</h4>"),
    spacing_image_size,
    spacing_invert,
    run_spacing_button,
    step3_spacing_output,
]

# ===== MAIN UI LAYOUT =====

# Build step 2 content (will be hidden initially)
step2_content = widgets.VBox(
    [
        widgets.HTML("<h3>Step 2: Segmentation & Analysis</h3>"),
        widgets.HTML("<h4>U-Net Segmentation Parameters</h4>"),
        seg_denoise,
        seg_sharpen,
        seg_threshold,
        seg_remove_noise,
        seg_invert,
        segment_button,
        step2_output,
        dot_extract_container,
        voronoi_container,
        colorwheel_container,
    ]
)

# Build step 3 content (will be hidden initially) - 3-column layout like Dan's
# Left column: method toggles, combine options and Niblack parameters.
ui_left_dan = widgets.VBox([
    use_adaptive,
    use_niblack,
    combine_method,
    alpha_weight,
    niblack_placeholder,
    niblack_k,
    niblack_window
], layout=widgets.Layout(padding="10px"))

# Middle column: adaptive-threshold parameters.
ui_middle_dan = widgets.VBox([
    adaptive_placeholder,
    adaptive_method,
    adaptive_block_size,
    adaptive_C
], layout=widgets.Layout(padding="10px"))

# Right column: preprocessing (blur, equalization, CLAHE).
ui_right_dan = widgets.VBox([
    preprocess_placeholder,
    blur_method,
    blur_ksize,
    dan_equalize,
    dan_clahe
], layout=widgets.Layout(padding="10px"))

step3_content = widgets.VBox(
    [
        widgets.HTML("<h3>Step 3: Dan's Binarization & Spacing Analysis</h3>"),
        widgets.HTML("<h4>Binarization Parameters (Interactive Preview)</h4>"),
        widgets.HBox([ui_left_dan, ui_middle_dan, ui_right_dan], 
                     layout=widgets.Layout(justify_content='space-between')),
        # The widgets bound to the interactive are already shown in the three
        # columns above, so only its output area is added here.
        interactive_preview.children[-1],  # Only show the output, not the widget controls
        binarize_button,
        step3_output,
        spacing_container,
    ]
)

# Main layout: three step sections (ui_left / ui_middle / ui_right); the final
# VBox stacks them vertically rather than as side-by-side columns.
ui_left = widgets.VBox(
    [
        widgets.HTML("<h3>Step 1: Image Input & Classification</h3>"),
        image_path_widget,
        classify_button,
        step1_output,
    ]
)

# step2_container / step3_container are defined elsewhere in the notebook.
ui_middle = widgets.VBox([step2_container])
ui_right = widgets.VBox([step3_container])

# Display step2/step3 content inside their containers.
# NOTE(review): the `with container: display(...)` pattern assumes these are
# widgets.Output instances; confirm at their definition.
with step2_container:
    display(step2_content)

with step3_container:
    display(step3_content)

# Main UI (sections stacked top-to-bottom)
ui = widgets.VBox([ui_left, ui_middle, ui_right])
display(ui)

VBox(children=(VBox(children=(HTML(value='<h3>Step 1: Image Input & Classification</h3>'), Text(value='/home/n…