In [None]:
import os
from pathlib import Path

import cv2
import matplotlib.pyplot as plt
import numpy as np
from sklearn.cluster import DBSCAN


class AutoSubplotSplitter:
    def __init__(
        self,
        min_subplot_area=50000,
        padding=10,
        debug=False,
        num_subplots=0,
        expand_factor=0.3,
    ):
        self.min_subplot_area = min_subplot_area
        self.padding = padding
        self.debug = debug
        self.num_subplots = num_subplots
        self.expand_factor = (
            expand_factor  # How much to expand grid cells (0.3 = 30%)
        )

    def detect_subplots(self, image_path):
        """Automatically detect and extract subplots from an image"""
        # Expand path and check if file exists
        image_path = os.path.expanduser(image_path)
        image_path = str(Path(image_path).resolve())

        if not os.path.exists(image_path):
            raise FileNotFoundError(f"Image file not found: {image_path}")

        # Load image
        img = cv2.imread(image_path)
        if img is None:
            raise ValueError(f"Could not load image from: {image_path}")

        gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)

        if self.debug:
            print(f"Image shape: {img.shape}")
            if self.num_subplots > 0:
                print(f"Target number of subplots: {self.num_subplots}")
            cv2.imwrite("debug_original.png", img)

        # Run all detection methods and collect results
        methods = [
            ("whitespace", self._detect_by_whitespace_improved),
            ("lines", self._detect_by_lines_improved),
            ("contours", self._detect_by_contours_improved),
        ]

        all_candidates = []

        for method_name, method_func in methods:
            if self.debug:
                print(f"\nTrying method: {method_name}")

            subplots = method_func(gray, img)

            if self.debug:
                print(f"Method {method_name} found {len(subplots)} subplots")

            if subplots:
                all_candidates.append((method_name, subplots))

        # If no specific number requested, use original logic
        if self.num_subplots <= 0:
            for method_name, subplots in all_candidates:
                if 2 <= len(subplots) <= 12:
                    if self.debug:
                        print(f"Using method: {method_name}")
                    return subplots

            # Fallback
            return all_candidates[0][1] if all_candidates else []

        # Specific number requested - find best selection
        best_selection = self._select_best_subplot_arrangement(
            all_candidates, img
        )

        if self.debug:
            print(f"Final selection: {len(best_selection)} subplots")

        return best_selection

    def _detect_by_whitespace_improved(self, gray, img):
        """Improved whitespace detection for subplot boundaries"""
        h, w = gray.shape

        # Create projection profiles
        h_profile = np.mean(gray, axis=1)  # Row averages
        v_profile = np.mean(gray, axis=0)  # Column averages

        if self.debug:
            plt.figure(figsize=(12, 4))
            plt.subplot(1, 2, 1)
            plt.plot(h_profile)
            plt.title("Horizontal Profile")
            plt.subplot(1, 2, 2)
            plt.plot(v_profile)
            plt.title("Vertical Profile")
            plt.savefig("debug_profiles.png")
            plt.close()

        # Find significant gaps (whitespace between subplots)
        h_gaps = self._find_significant_gaps(h_profile, axis="horizontal")
        v_gaps = self._find_significant_gaps(v_profile, axis="vertical")

        if self.debug:
            print(f"Horizontal gaps: {h_gaps}")
            print(f"Vertical gaps: {v_gaps}")

        # Create subplot grid
        h_boundaries = [0, *h_gaps, h]
        v_boundaries = [0, *v_gaps, w]

        # Remove boundaries that are too close
        h_boundaries = self._filter_close_boundaries(
            h_boundaries, min_distance=h // 10
        )
        v_boundaries = self._filter_close_boundaries(
            v_boundaries, min_distance=w // 10
        )

        subplots = []
        for i in range(len(h_boundaries) - 1):
            for j in range(len(v_boundaries) - 1):
                y1, y2 = int(h_boundaries[i]), int(h_boundaries[i + 1])
                x1, x2 = int(v_boundaries[j]), int(v_boundaries[j + 1])

                # Check minimum size
                area = (y2 - y1) * (x2 - x1)
                if area < self.min_subplot_area:
                    continue

                # Check aspect ratio (plots shouldn't be too thin)
                aspect_ratio = (x2 - x1) / (y2 - y1)
                if aspect_ratio < 0.3 or aspect_ratio > 3.0:
                    continue

                subplot = img[y1:y2, x1:x2]
                subplots.append(
                    {
                        "image": subplot,
                        "bbox": (x1, y1, x2, y2),
                        "method": "whitespace_improved",
                    }
                )

        return subplots

    def _find_significant_gaps(
        self, profile, axis="horizontal", min_gap_size=None
    ):
        """Find significant whitespace gaps in intensity profile"""
        if min_gap_size is None:
            min_gap_size = len(profile) // 20  # At least 5% of image dimension

        # Normalize profile
        profile_norm = (profile - np.min(profile)) / (
            np.max(profile) - np.min(profile)
        )

        # Find peaks (whitespace)
        # Use adaptive threshold based on profile statistics
        threshold = np.mean(profile_norm) + 0.5 * np.std(profile_norm)

        gaps = []
        in_gap = False
        gap_start = 0

        for i, value in enumerate(profile_norm):
            if value > threshold and not in_gap:
                in_gap = True
                gap_start = i
            elif value <= threshold and in_gap:
                gap_end = i
                gap_size = gap_end - gap_start
                if gap_size > min_gap_size:
                    # Use center of gap as boundary
                    gaps.append((gap_start + gap_end) // 2)
                in_gap = False

        return gaps

    def _filter_close_boundaries(self, boundaries, min_distance):
        """Remove boundaries that are too close together"""
        if len(boundaries) <= 2:
            return boundaries

        filtered = [boundaries[0]]
        for boundary in boundaries[1:]:
            if boundary - filtered[-1] >= min_distance:
                filtered.append(boundary)

        return filtered

    def _detect_by_lines_improved(self, gray, img):
        """Derive a subplot grid from long, near-axis-aligned line segments.

        Canny edges are fed to the probabilistic Hough transform. Segments
        within 15 degrees of horizontal contribute their mean y, segments
        within 15 degrees of vertical contribute their mean x. The positions
        are clustered with ``_cluster_lines`` and handed to
        ``_create_subplots_from_boundaries``.

        Args:
            gray: 2-D grayscale image array.
            img: Image array the crops are taken from.

        Returns:
            A list of subplot dicts tagged ``"lines_improved"``; empty when
            the Hough transform finds no segments.
        """
        short_side = min(gray.shape)

        edges = cv2.Canny(gray, 30, 100, apertureSize=3)
        if self.debug:
            cv2.imwrite("debug_edges.png", edges)

        # Vote threshold and minimum length scale with the image size so
        # that only long lines survive.
        segments = cv2.HoughLinesP(
            edges,
            1,
            np.pi / 180,
            threshold=max(100, short_side // 4),
            minLineLength=short_side // 3,
            maxLineGap=20,
        )
        if segments is None:
            return []

        min_length = short_side // 4
        h_lines, v_lines = [], []

        for segment in segments:
            x1, y1, x2, y2 = segment[0]
            dx = x2 - x1
            dy = y2 - y1

            # Skip anything shorter than a quarter of the short side
            if np.sqrt(dx**2 + dy**2) < min_length:
                continue

            tilt = abs(np.arctan2(dy, dx) * 180 / np.pi)

            if tilt < 15 or tilt > 165:
                # Roughly horizontal: remember its vertical position
                h_lines.append((y1 + y2) / 2)
            elif 75 < tilt < 105:
                # Roughly vertical: remember its horizontal position
                v_lines.append((x1 + x2) / 2)

        # Merge near-duplicate detections of the same line
        h_positions = self._cluster_lines(h_lines, eps=gray.shape[0] // 20)
        v_positions = self._cluster_lines(v_lines, eps=gray.shape[1] // 20)

        if self.debug:
            print(
                f"Line-based boundaries - H: {h_positions}, V: {v_positions}"
            )

        return self._create_subplots_from_boundaries(
            h_positions, v_positions, img, "lines_improved"
        )

    def _detect_by_contours_improved(self, gray, img):
        """Find subplot candidates as large blobs in a thresholded image.

        The grayscale image is adaptively thresholded (inverted), closed and
        opened with a 15x15 kernel, and its external contours extracted.
        A contour is kept when its area lies between ``min_subplot_area``
        and 80% of the image, its bounding box has a width/height ratio
        strictly between 0.4 and 2.5, and the box spans at least 15% of the
        image in both directions. The nine largest survivors are cropped
        with ``padding`` pixels of margin, clipped to the image.

        Args:
            gray: 2-D grayscale image array.
            img: Image array the crops are taken from.

        Returns:
            A list of subplot dicts tagged ``"contours_improved"``, largest
            contour first.
        """
        binary = cv2.adaptiveThreshold(
            gray,
            255,
            cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
            cv2.THRESH_BINARY_INV,
            11,
            2,
        )

        # Close then open with a large kernel so small elements fuse into
        # blobs and leftover specks disappear.
        kernel = np.ones((15, 15), np.uint8)
        for operation in (cv2.MORPH_CLOSE, cv2.MORPH_OPEN):
            binary = cv2.morphologyEx(binary, operation, kernel)

        if self.debug:
            cv2.imwrite("debug_binary.png", binary)

        contours, _ = cv2.findContours(
            binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
        )

        gray_h, gray_w = gray.shape[0], gray.shape[1]
        img_area = gray_h * gray_w

        candidates = []
        for contour in contours:
            area = cv2.contourArea(contour)
            if area < self.min_subplot_area or area > img_area * 0.8:
                continue

            x, y, w, h = cv2.boundingRect(contour)

            aspect_ratio = w / h
            if aspect_ratio <= 0.4 or aspect_ratio >= 2.5:
                continue

            if w < gray_w * 0.15 or h < gray_h * 0.15:
                continue

            candidates.append((x, y, w, h, area))

        # Largest first, capped at nine regions
        largest = sorted(candidates, key=lambda box: box[4], reverse=True)[:9]

        img_h, img_w = img.shape[0], img.shape[1]
        pad = self.padding

        subplots = []
        for x, y, w, h, _ in largest:
            x1 = max(0, x - pad)
            y1 = max(0, y - pad)
            x2 = min(img_w, x + w + pad)
            y2 = min(img_h, y + h + pad)
            subplots.append(
                {
                    "image": img[y1:y2, x1:x2],
                    "bbox": (x1, y1, x2, y2),
                    "method": "contours_improved",
                }
            )

        return subplots

    def _create_subplots_from_boundaries(
        self, h_positions, v_positions, img, method_name
    ):
        """Create subplots from boundary positions"""
        h, w = img.shape[:2]

        # Ensure positions are lists and not None
        if h_positions is None:
            h_positions = []
        if v_positions is None:
            v_positions = []

        h_boundaries = [0, *sorted(h_positions), h]
        v_boundaries = [0, *sorted(v_positions), w]

        # Filter boundaries that are too close
        h_boundaries = self._filter_close_boundaries(h_boundaries, h // 15)
        v_boundaries = self._filter_close_boundaries(v_boundaries, w // 15)

        if self.debug:
            print(f"Final boundaries - H: {h_boundaries}, V: {v_boundaries}")

        subplots = []
        for i in range(len(h_boundaries) - 1):
            for j in range(len(v_boundaries) - 1):
                y1, y2 = int(h_boundaries[i]), int(h_boundaries[i + 1])
                x1, x2 = int(v_boundaries[j]), int(v_boundaries[j + 1])

                area = (y2 - y1) * (x2 - x1)
                if area < self.min_subplot_area:
                    continue

                aspect_ratio = (x2 - x1) / (y2 - y1)
                if not (0.3 < aspect_ratio < 3.0):
                    continue

                subplot = img[y1:y2, x1:x2]
                subplots.append(
                    {
                        "image": subplot,
                        "bbox": (x1, y1, x2, y2),
                        "method": method_name,
                    }
                )

        return subplots

    def _cluster_lines(self, positions, eps=20):
        """Cluster line positions to find distinct boundaries"""
        if not positions:
            return []

        positions = np.array(positions).reshape(-1, 1)

        if len(positions) < 2:
            return []

        # Use DBSCAN to cluster nearby lines
        clustering = DBSCAN(eps=eps, min_samples=1).fit(positions)

        # Get centroid of each cluster
        unique_labels = set(clustering.labels_)
        centroids = []

        for label in unique_labels:
            if label != -1:  # Ignore noise
                cluster_points = positions[clustering.labels_ == label]
                centroids.append(np.mean(cluster_points))

    def _select_best_subplot_arrangement(self, all_candidates, img):
        """Select the best subplot arrangement from all detection methods"""
        if not all_candidates:
            return []

        best_selection = []
        best_score = float("inf")

        # Try each method's results
        for method_name, subplots in all_candidates:
            if not subplots:
                continue

            # If this method gives exactly what we want, great!
            if len(subplots) == self.num_subplots:
                if self.debug:
                    print(
                        f"Method {method_name} gives exact count: {len(subplots)}"
                    )
                return subplots

            # If too many subplots, select the best ones
            if len(subplots) > self.num_subplots:
                selected = self._select_best_subplots(
                    subplots, self.num_subplots
                )
                score = self._score_subplot_arrangement(selected, img)

                if self.debug:
                    print(
                        f"Method {method_name}: selected {len(selected)} from {len(subplots)}, score: {score:.2f}"
                    )

                if score < best_score:
                    best_score = score
                    best_selection = selected

            # If too few subplots, see if we can intelligently split
            elif len(subplots) < self.num_subplots:
                if (
                    len(subplots) >= self.num_subplots // 2
                ):  # Only try splitting if reasonably close
                    expanded = self._smart_expand_subplots(
                        subplots, img, self.num_subplots
                    )
                    if len(expanded) == self.num_subplots:
                        score = self._score_subplot_arrangement(expanded, img)

                        if self.debug:
                            print(
                                f"Method {method_name}: expanded to {len(expanded)}, score: {score:.2f}"
                            )

                        if score < best_score:
                            best_score = score
                            best_selection = expanded

        # If we still don't have a good selection, try the largest set and select from it
        if not best_selection:
            largest_set = max(all_candidates, key=lambda x: len(x[1]))
            method_name, subplots = largest_set

            if len(subplots) >= self.num_subplots:
                best_selection = self._select_best_subplots(
                    subplots, self.num_subplots
                )
                if self.debug:
                    print(
                        f"Using largest set from {method_name}: {len(best_selection)} subplots"
                    )

            # Last resort: create a simple grid if detection completely failed
            if not best_selection and self.num_subplots > 0:
                if self.debug:
                    print("All methods failed, falling back to expanded grid")
                best_selection = self._create_simple_grid(
                    img, self.num_subplots, self.expand_factor
                )

        return best_selection

    def _score_subplot_arrangement(self, subplots, img):
        """Score a subplot arrangement based on quality metrics"""
        if not subplots:
            return float("inf")

        h, w = img.shape[:2]
        total_score = 0

        for subplot in subplots:
            x1, y1, x2, y2 = subplot["bbox"]

            # Size score (prefer reasonable sizes)
            area = (x2 - x1) * (y2 - y1)
            size_score = abs(area - (h * w / len(subplots))) / (h * w)

            # Aspect ratio score (prefer reasonable aspect ratios)
            aspect_ratio = (x2 - x1) / (y2 - y1)
            aspect_score = (
                abs(aspect_ratio - 1.0) if 0.5 <= aspect_ratio <= 2.0 else 1.0
            )

            # Position score (slight preference for grid-like arrangements)

            position_score = 0  # Could add grid regularity scoring here

            subplot_score = (
                size_score + aspect_score * 0.3 + position_score * 0.1
            )
            total_score += subplot_score

        # Coverage score (prefer arrangements that cover the image well)
        coverage_score = self._calculate_coverage_score(subplots, img)

        # Overlap penalty
        overlap_penalty = self._calculate_overlap_penalty(subplots)

        final_score = (
            total_score + coverage_score * 0.5 + overlap_penalty * 2.0
        )

        return final_score

    def _calculate_coverage_score(self, subplots, img):
        """Calculate how well the subplots cover the image"""
        h, w = img.shape[:2]
        total_image_area = h * w

        total_subplot_area = sum(
            (x2 - x1) * (y2 - y1)
            for x1, y1, x2, y2 in [s["bbox"] for s in subplots]
        )

        # Prefer arrangements that cover a reasonable portion of the image
        coverage_ratio = total_subplot_area / total_image_area
        ideal_coverage = 0.7  # 70% coverage is reasonable

        return abs(coverage_ratio - ideal_coverage)

    def _calculate_overlap_penalty(self, subplots):
        """Penalize overlapping subplots"""
        penalty = 0

        for i, subplot1 in enumerate(subplots):
            for j, subplot2 in enumerate(subplots[i + 1 :], i + 1):
                x1a, y1a, x2a, y2a = subplot1["bbox"]
                x1b, y1b, x2b, y2b = subplot2["bbox"]

                # Check for overlap
                overlap_x = max(0, min(x2a, x2b) - max(x1a, x1b))
                overlap_y = max(0, min(y2a, y2b) - max(y1a, y1b))
                overlap_area = overlap_x * overlap_y

                if overlap_area > 0:
                    area1 = (x2a - x1a) * (y2a - y1a)
                    area2 = (x2b - x1b) * (y2b - y1b)
                    min_area = min(area1, area2)
                    overlap_ratio = overlap_area / min_area
                    penalty += overlap_ratio

        return penalty

    def _smart_expand_subplots(self, subplots, img, target_count):
        """Intelligently expand subplot count by splitting the most suitable ones"""
        if len(subplots) >= target_count:
            return subplots

        current_subplots = subplots.copy()
        splits_needed = target_count - len(subplots)

        if self.debug:
            print(
                f"Need to split {splits_needed} subplots to reach {target_count}"
            )

        # Sort subplots by area (largest first)
        subplot_areas = [
            (i, (s["bbox"][2] - s["bbox"][0]) * (s["bbox"][3] - s["bbox"][1]))
            for i, s in enumerate(current_subplots)
        ]
        subplot_areas.sort(key=lambda x: x[1], reverse=True)

        splits_made = 0
        for subplot_idx, area in subplot_areas:
            if splits_made >= splits_needed:
                break

            subplot_to_split = current_subplots[subplot_idx]
            x1, y1, x2, y2 = subplot_to_split["bbox"]
            width = x2 - x1
            height = y2 - y1

            # Only split if subplot is large enough
            if width > 100 and height > 100:
                # Create 2 new subplots
                if width > height:
                    # Split vertically
                    mid_x = x1 + width // 2

                    left_subplot = {
                        "image": img[y1:y2, x1:mid_x],
                        "bbox": (x1, y1, mid_x, y2),
                        "method": "smart_split_vertical",
                    }
                    right_subplot = {
                        "image": img[y1:y2, mid_x:x2],
                        "bbox": (mid_x, y1, x2, y2),
                        "method": "smart_split_vertical",
                    }

                    current_subplots.extend([left_subplot, right_subplot])
                else:
                    # Split horizontally
                    mid_y = y1 + height // 2

                    top_subplot = {
                        "image": img[y1:mid_y, x1:x2],
                        "bbox": (x1, y1, x2, mid_y),
                        "method": "smart_split_horizontal",
                    }
                    bottom_subplot = {
                        "image": img[mid_y:y2, x1:x2],
                        "bbox": (x1, mid_y, x2, y2),
                        "method": "smart_split_horizontal",
                    }

                    current_subplots.extend([top_subplot, bottom_subplot])

                splits_made += 1

                # Update indices after modification
                subplot_areas = [
                    (i if i < subplot_idx else i + 1, area)
                    for i, area in subplot_areas
                    if i != subplot_idx
                ]

        return current_subplots

    def _select_best_subplots(self, subplots, target_count):
        """Select the best subplots based on multiple quality criteria"""
        if len(subplots) <= target_count:
            return subplots

        if self.debug:
            print(
                f"Selecting {target_count} best subplots from {len(subplots)} candidates"
            )

        # Score each subplot
        scored_subplots = []

        for i, subplot in enumerate(subplots):
            x1, y1, x2, y2 = subplot["bbox"]
            area = (x2 - x1) * (y2 - y1)
            aspect_ratio = (x2 - x1) / (y2 - y1)

            # Area score (prefer medium-sized subplots)
            area_score = area

            # Aspect ratio score (prefer reasonable ratios)
            if 0.5 <= aspect_ratio <= 2.0:
                aspect_score = 100  # Good aspect ratio
            else:
                aspect_score = 50 / (
                    1 + abs(aspect_ratio - 1.0)
                )  # Penalize extreme ratios

            # Could add more sophisticated positioning logic here
            position_score = 50  # Neutral for now

            total_score = area_score + aspect_score + position_score
            scored_subplots.append((total_score, i, subplot))

        # Sort by score (highest first) and take the best ones
        scored_subplots.sort(key=lambda x: x[0], reverse=True)

        selected = [
            subplot for _, _, subplot in scored_subplots[:target_count]
        ]

        if self.debug:
            print(
                f"Selected subplots with scores: {[score for score, _, _ in scored_subplots[:target_count]]}"
            )

        return selected

    def _create_simple_grid(self, img, target_count, expand_factor=0.3):
        """Create a simple grid with expanded boundaries to capture legends/labels"""
        h, w = img.shape[:2]

        # Find best grid dimensions
        best_rows, best_cols = 1, target_count
        min_aspect_diff = float("inf")

        for rows in range(1, target_count + 1):
            if target_count % rows == 0:
                cols = target_count // rows
                aspect_diff = abs(
                    (w / cols) / (h / rows) - 1.0
                )  # How far from square
                if aspect_diff < min_aspect_diff:
                    min_aspect_diff = aspect_diff
                    best_rows, best_cols = rows, cols

        # If no perfect division, try close approximations
        if best_rows == 1 and target_count > 4:
            # Try common layouts
            if target_count <= 6:
                best_rows, best_cols = 2, 3
            elif target_count <= 9:
                best_rows, best_cols = 3, 3
            else:
                best_rows = int(target_count**0.5)
                best_cols = (target_count + best_rows - 1) // best_rows

        if self.debug:
            print(
                f"Creating {best_rows}x{best_cols} grid for {target_count} subplots with {expand_factor * 100}% expansion"
            )

        subplots = []
        cell_h = (
            h / best_rows
        )  # Use float division for more precise calculations
        cell_w = w / best_cols

        count = 0
        for i in range(best_rows):
            for j in range(best_cols):
                if count >= target_count:
                    break

                # Calculate core cell boundaries
                core_y1 = i * cell_h
                core_y2 = (i + 1) * cell_h
                core_x1 = j * cell_w
                core_x2 = (j + 1) * cell_w

                # Calculate expansion amounts (30% of cell size on each side by default)
                expand_h = cell_h * expand_factor
                expand_w = cell_w * expand_factor

                # Expand boundaries by expansion factor
                y1 = max(0, int(core_y1 - expand_h))
                y2 = min(h, int(core_y2 + expand_h))
                x1 = max(0, int(core_x1 - expand_w))
                x2 = min(w, int(core_x2 + expand_w))

                if self.debug:
                    print(
                        f"Subplot {count + 1}: core=({int(core_x1)},{int(core_y1)},{int(core_x2)},{int(core_y2)}) "
                        f"expanded=({x1},{y1},{x2},{y2})"
                    )

                subplot = img[y1:y2, x1:x2]
                subplots.append(
                    {
                        "image": subplot,
                        "bbox": (x1, y1, x2, y2),
                        "method": f"expanded_grid_{best_rows}x{best_cols}_exp{int(expand_factor * 100)}%",
                    }
                )

                count += 1

            if count >= target_count:
                break

        return subplots

    def visualize_detection(self, image_path, subplots):
        """Write ``debug_detection.png`` with every detected box drawn in.

        Each bounding box is outlined and labelled with its 1-based index,
        cycling through eight colours. This is a debugging aid: if the image
        cannot be read, a message is printed and nothing is written.

        Args:
            image_path: Path of the image to draw on.
            subplots: List of subplot dicts with a ``"bbox"`` entry of the
                form ``(x1, y1, x2, y2)``.
        """
        img = cv2.imread(image_path)
        if img is None:
            # imread reports failure by returning None instead of raising
            print(f"Could not load image for visualization: {image_path}")
            return

        img_display = img.copy()

        # Colours for the boxes, reused cyclically
        colors = [
            (255, 0, 0),
            (0, 255, 0),
            (0, 0, 255),
            (255, 255, 0),
            (255, 0, 255),
            (0, 255, 255),
            (128, 0, 128),
            (255, 165, 0),
        ]

        for i, subplot_data in enumerate(subplots):
            x1, y1, x2, y2 = subplot_data["bbox"]
            color = colors[i % len(colors)]
            cv2.rectangle(img_display, (x1, y1), (x2, y2), color, 3)
            # Index label just inside the top-left corner of the box
            cv2.putText(
                img_display,
                f"{i + 1}",
                (x1 + 10, y1 + 30),
                cv2.FONT_HERSHEY_SIMPLEX,
                1,
                color,
                2,
            )

        cv2.imwrite("debug_detection.png", img_display)
        print("Saved detection visualization as 'debug_detection.png'")

    def save_subplots(self, subplots, output_prefix="subplot"):
        """Write each subplot image to ``<output_prefix>_NN.png``.

        ``NN`` is the zero-based, two-digit position of the subplot in
        ``subplots``. A file is only reported and returned when OpenCV
        confirms it was written.

        Args:
            subplots: List of subplot dicts with ``"image"``, ``"method"``
                and ``"bbox"`` entries.
            output_prefix: Leading part of every output filename; it may
                include a directory.

        Returns:
            The filenames that were written successfully, in order.
        """
        saved_files = []

        for i, subplot_data in enumerate(subplots):
            filename = f"{output_prefix}_{i:02d}.png"

            # imwrite signals failure through its return value
            if not cv2.imwrite(filename, subplot_data["image"]):
                print(f"Failed to save {filename}")
                continue

            saved_files.append(filename)
            print(
                f"Saved {filename} - Method: {subplot_data['method']}, "
                f"BBox: {subplot_data['bbox']}"
            )

        return saved_files


# Usage example
def auto_split_subplots(
    image_path,
    output_prefix="subplot",
    debug=True,
    num_subplots=0,
    expand_factor=0.3,
):
    """Detect the subplots in an image and save each one as a PNG.

    Args:
        image_path: Path to the image; ``~`` is expanded and the path is
            resolved before use.
        output_prefix: Leading part of the output filenames.
        debug: Enables the splitter's diagnostics and also writes a
            visualization of the detected boxes.
        num_subplots: Desired number of subplots; 0 lets the detectors
            decide.
        expand_factor: Passed to the splitter as its ``expand_factor``.

    Returns:
        The list returned by ``save_subplots``; an empty list when the image
        cannot be loaded or nothing is detected.
    """
    image_path = str(Path(os.path.expanduser(image_path)).resolve())

    # A minimum area of 20000 and padding of 15 are used here instead of
    # the class defaults.
    splitter = AutoSubplotSplitter(
        min_subplot_area=20000,
        padding=15,
        debug=debug,
        num_subplots=num_subplots,
        expand_factor=expand_factor,
    )

    print(f"Processing: {image_path}")
    if num_subplots > 0:
        print(
            f"Target: {num_subplots} subplots with {int(expand_factor * 100)}% expansion"
        )

    try:
        subplots = splitter.detect_subplots(image_path)
    except (FileNotFoundError, ValueError) as e:
        # Missing or unreadable input is reported, not raised
        print(f"Error: {e}")
        return []

    if not subplots:
        print("No subplots detected!")
        return []

    print(f"Detected {len(subplots)} subplots")

    # At this point there is at least one subplot to draw
    if debug:
        splitter.visualize_detection(image_path, subplots)

    return splitter.save_subplots(subplots, output_prefix)

In [None]:
# Split the figure into four panels with a tight 10% expansion factor;
# the outputs are written as test_00.png, test_01.png, ...
files = auto_split_subplots(
    image_path="~/Downloads/nh3_cracking_1.png",
    output_prefix="test",
    num_subplots=4,
    expand_factor=0.1,
    debug=True,
)