## Imports


In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
from collections import Counter
from functools import partial

import matplotlib.pyplot as plt
import numpy as np
from loguru import logger as lg
from rich import print as rprint
from rich.console import Console
from rich import get_console

# Rich Jupyter fix: take Rich's global console (the one returned by
# get_console()) and clear its `is_jupyter` flag.
# NOTE(review): presumably this makes output render as plain console text
# instead of via Rich's Jupyter display path, and presumably this is the
# console `rprint` writes to — confirm against the Rich version in use.
console: Console = get_console()
console.is_jupyter = False

In [None]:
from snap_fit.config.aruco.aruco_board_config import ArucoBoardConfig
from snap_fit.config.aruco.aruco_detector_config import ArucoDetectorConfig
from snap_fit.params.snap_fit_params import get_snap_fit_paths
from snap_fit.puzzle.sheet_aruco import SheetAruco
from snap_fit.puzzle.sheet_manager import SheetManager

## Helper: Load Puzzle Data


In [None]:
def load_puzzle_sheets(puzzle_name: str) -> SheetManager:
    """Build a SheetManager holding every sheet image of one puzzle.

    Sheet images are matched with ``*.png`` inside
    ``<data_fol>/<puzzle_name>/sheets`` and parsed by ``SheetAruco.load_sheet``
    (called with ``min_area=5_000``), using a 7x5 ArUco board whose markers
    have length 100 and separation 100.

    Args:
        puzzle_name: Puzzle folder name, e.g. ``"sample_puzzle_v1"``.

    Returns:
        The populated SheetManager.
    """
    # ArUco board geometry shared by all sample puzzles loaded here.
    detector_config = ArucoDetectorConfig(
        board=ArucoBoardConfig(
            markers_x=7,
            markers_y=5,
            marker_length=100,
            marker_separation=100,
        )
    )
    load_sheet = partial(SheetAruco(detector_config).load_sheet, min_area=5_000)

    sheets_dir = get_snap_fit_paths().data_fol / puzzle_name / "sheets"
    lg.info(f"Loading data from {sheets_dir}")

    manager = SheetManager()
    manager.add_sheets(folder_path=sheets_dir, pattern="*.png", loader_func=load_sheet)

    sheet_count = len(manager.get_sheets_ls())
    piece_count = len(manager.get_pieces_ls())
    lg.info(f"Loaded {sheet_count} sheets with {piece_count} pieces")
    return manager

## Load v1 and v2 Puzzle Data


In [None]:
manager_v1 = load_puzzle_sheets("sample_puzzle_v1")

In [None]:
manager_v2 = load_puzzle_sheets("sample_puzzle_v2")

In [None]:
# Compute expected segment counts for a ROWS x COLS puzzle (6x8 here)
ROWS, COLS = 6, 8

total_pieces = ROWS * COLS  # 48 for 6x8
total_segments = total_pieces * 4  # four sides per piece; 192 for 6x8

# EDGE segments lie on the puzzle boundary: COLS on the top and bottom rows,
# ROWS on the left and right columns.
perimeter_segments = 2 * (ROWS + COLS)  # 28 for 6x8

# Internal segments should be IN or OUT: each internal edge is shared by two
# neighbouring pieces and contributes one IN and one OUT segment.
internal_horizontal_edges = ROWS * (COLS - 1)  # left/right neighbour pairs; 42
internal_vertical_edges = (ROWS - 1) * COLS  # top/bottom neighbour pairs; 40
total_internal_edges = internal_horizontal_edges + internal_vertical_edges  # 82

expected_in = total_internal_edges
expected_out = total_internal_edges
expected_edge = perimeter_segments
expected_weird = 0  # ideally none!

# Sanity check: the categories must partition all segments
# (2(R + C) + 2[R(C - 1) + (R - 1)C] == 4RC for any grid).
assert expected_edge + expected_in + expected_out + expected_weird == total_segments

# Title and percentages are derived from the values above so they stay correct
# when ROWS / COLS / expected_weird are changed.
rprint(f"[bold]Expected Segment Distribution for {ROWS}x{COLS} Puzzle[/bold]")
rprint(f"  Total segments: {total_segments}")
rprint(
    f"  EDGE (boundary): {expected_edge} ({100 * expected_edge / total_segments:.1f}%)"
)
rprint(f"  IN:   {expected_in} ({100 * expected_in / total_segments:.1f}%)")
rprint(f"  OUT:  {expected_out} ({100 * expected_out / total_segments:.1f}%)")
rprint(
    f"  WEIRD (target): {expected_weird} "
    f"({100 * expected_weird / total_segments:.1f}%)"
)

## Analyze Shape Distribution


In [None]:
def analyze_shapes(manager: SheetManager, label: str) -> Counter:
    """Tally and print the shape of every segment of every piece in ``manager``.

    Args:
        manager: Loaded sheets whose pieces are inspected.
        label: Heading printed above the per-shape table.

    Returns:
        Counter mapping each segment shape to its number of occurrences.
    """
    counts = Counter(
        segment.shape
        for piece in manager.get_pieces_ls()
        for segment in piece.segments.values()
    )
    n_segments = counts.total()

    rprint(f"\n[bold]{label}[/bold]")
    rprint(f"  Total segments: {n_segments}")
    # most_common() lists shapes from most to least frequent.
    for shape, count in counts.most_common():
        share = 100 * count / n_segments
        rprint(f"  {shape.name:>6}: {count:>3} ({share:5.1f}%)")

    return counts

In [None]:
counts_v1 = analyze_shapes(manager_v1, "sample_puzzle_v1")
counts_v2 = analyze_shapes(manager_v2, "sample_puzzle_v2")

## Visualize WEIRD Segments

Let's look at segments classified as WEIRD to understand the failure mode.


In [None]:
from snap_fit.image.segment import SegmentShape


def get_weird_segments(manager: SheetManager, max_count: int = 10) -> list[tuple]:
    """Collect up to ``max_count`` WEIRD segments for visualization.

    Pieces and their segments are scanned in iteration order, and scanning
    stops as soon as ``max_count`` matches have been found.

    Args:
        manager: Loaded sheets whose pieces are scanned.
        max_count: Maximum number of segments to return; a value <= 0
            returns an empty list.

    Returns:
        List of ``(piece_id, seg_id, segment)`` tuples.
    """
    from itertools import islice

    # islice rejects negative counts, and a non-positive limit means there is
    # nothing to collect.
    if max_count <= 0:
        return []

    weird = (
        (piece.piece_id, seg_id, seg)
        for piece in manager.get_pieces_ls()
        for seg_id, seg in piece.segments.items()
        if seg.shape == SegmentShape.WEIRD
    )
    # The generator is lazy, so islice stops the scan after max_count matches.
    return list(islice(weird, max_count))


weird_v1 = get_weird_segments(manager_v1, max_count=6)
rprint(f"Found {len(weird_v1)} WEIRD segments to analyze")

In [None]:
def plot_segment_points(seg, ax, title: str):
    """Draw a segment's point sequence on ``ax``, marking its first and last points.

    Args:
        seg: Segment whose ``points`` hold (x, y) pairs, either in OpenCV
            contour layout (N, 1, 2) or already (N, 2); its ``shape.name`` is
            shown in the title.
        ax: Matplotlib axes to draw on.
        title: First line of the subplot title; the shape name goes below it.
    """
    # reshape rather than squeeze: squeeze() turns a single-point (1, 1, 2)
    # contour into shape (2,), which breaks the column indexing below.
    pts = np.asarray(seg.points).reshape(-1, 2)

    ax.plot(pts[:, 0], pts[:, 1], "b-", linewidth=2, label="Segment")
    if len(pts):
        # Green = first point, red = last point, so the traversal direction is visible.
        ax.scatter(pts[0, 0], pts[0, 1], c="green", s=100, zorder=5, label="Start")
        ax.scatter(pts[-1, 0], pts[-1, 1], c="red", s=100, zorder=5, label="End")

    ax.set_title(f"{title}\nShape: {seg.shape.name}")
    ax.set_aspect("equal")
    ax.legend(fontsize=8)
    ax.grid(True, alpha=0.3)

In [None]:
# Visualize WEIRD segments
if weird_v1:
    fig, axes = plt.subplots(2, 3, figsize=(15, 10))
    axes = axes.flatten()

    for i, (piece_id, seg_id, seg) in enumerate(weird_v1[:6]):
        plot_segment_points(seg, axes[i], f"Piece {piece_id}, Seg {seg_id}")

    plt.tight_layout()
    plt.show()

## Analyze Point Distribution (Understanding the Classification)

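The current algorithm counts the points lying more than `flat_th=20` from the start–end line on each side, and flags a segment as WEIRD when both sides have more than `count_th=5` such points.
Let's look at the actual distributions.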


In [None]:
def analyze_segment_distribution(seg, flat_th: float = 20, count_th: int = 5):
    """Measure how a segment's points deviate from its start-end line.

    The points are translated so the first point is the origin, then rotated
    so the last point lies on the positive x-axis (mimicking ``_compute_shape``).
    In that frame each point's y value is its signed perpendicular distance
    from the start-end line.

    Args:
        seg: Segment with a ``points`` array of (x, y) pairs (OpenCV contour
            layout (N, 1, 2) or already (N, 2)) and a ``shape`` attribute that
            is copied into the result.
        flat_th: A point counts as off the line when its |y| exceeds this.
            Defaults to 20, the threshold of the current classifier.
        count_th: A side counts as populated when strictly more than this
            many points lie beyond ``flat_th`` on it. Defaults to 5.

    Returns:
        Dict with the keys ``shape``, ``y_vals``, ``mean_y``, ``std_y``,
        ``min_y``, ``max_y``, ``out_count``, ``in_count``, ``is_weird``,
        ``signed_area`` and ``length``. Returns None when the segment has
        fewer than two points or its endpoints (nearly) coincide.
    """
    # reshape rather than squeeze: squeeze() collapses a single-point contour
    # to shape (2,), which would break the row/column indexing below.
    pts = np.asarray(seg.points).reshape(-1, 2)
    if len(pts) < 2:
        return None

    start, end = pts[0], pts[-1]
    direction = end - start
    length = np.linalg.norm(direction)
    if length < 1e-6:
        # Coincident endpoints: there is no line to measure against.
        return None

    # Rotate by -angle so the start->end direction maps onto +x.
    angle = np.arctan2(direction[1], direction[0])
    cos_a, sin_a = np.cos(-angle), np.sin(-angle)

    centered = pts - start
    rotated = np.column_stack(
        [
            centered[:, 0] * cos_a - centered[:, 1] * sin_a,
            centered[:, 0] * sin_a + centered[:, 1] * cos_a,
        ]
    )

    # Y values are signed perpendicular distances from the start-end line.
    y_vals = rotated[:, 1]

    out_count = (y_vals < -flat_th).sum()  # Points below line
    in_count = (y_vals > flat_th).sum()  # Points above line

    # np.trapz is deprecated since NumPy 2.0 in favour of np.trapezoid; use the
    # new name when available. Integration uses unit spacing (point index), so
    # this "area" also scales with point density, not just geometry.
    trapezoid = getattr(np, "trapezoid", None) or np.trapz

    return {
        "shape": seg.shape,
        "y_vals": y_vals,
        "mean_y": np.mean(y_vals),
        "std_y": np.std(y_vals),
        "min_y": np.min(y_vals),
        "max_y": np.max(y_vals),
        "out_count": out_count,
        "in_count": in_count,
        # WEIRD: both sides have more than count_th points beyond flat_th.
        "is_weird": out_count > count_th and in_count > count_th,
        "signed_area": trapezoid(y_vals),
        "length": length,
    }

In [None]:
# Analyze all segments: one stats dict per non-degenerate v1 segment, tagged
# with the piece and segment it came from.
all_stats_v1 = []
for piece in manager_v1.get_pieces_ls():
    for seg_id, seg in piece.segments.items():
        stats = analyze_segment_distribution(seg)
        if not stats:
            # analyze_segment_distribution returns None for degenerate segments.
            continue
        stats.update(piece_id=piece.piece_id, seg_id=seg_id)
        all_stats_v1.append(stats)

rprint(f"Analyzed {len(all_stats_v1)} segments")

In [None]:
# Compare distributions by shape type
from snap_fit.image.segment import SegmentShape

fig, axes = plt.subplots(2, 2, figsize=(12, 10))

# One panel per shape class, each drawn in its own colour.
shape_types = [SegmentShape.IN, SegmentShape.OUT, SegmentShape.EDGE, SegmentShape.WEIRD]
colors = ["blue", "green", "orange", "red"]

for ax, shape, color in zip(axes.ravel(), shape_types, colors):
    shape_stats = [s for s in all_stats_v1 if s["shape"] == shape]

    if shape_stats:
        ax.scatter(
            [s["mean_y"] for s in shape_stats],
            [s["signed_area"] for s in shape_stats],
            c=color,
            alpha=0.6,
            s=50,
        )
        # Dashed reference lines at zero mean displacement / zero area.
        ax.axhline(y=0, color="gray", linestyle="--", alpha=0.5)
        ax.axvline(x=0, color="gray", linestyle="--", alpha=0.5)

    ax.set(
        title=f"{shape.name} (n={len(shape_stats)})",
        xlabel="Mean Y (perpendicular displacement)",
        ylabel="Signed Area",
    )
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

## Compare Classification Approaches

Let's test alternative classification methods on the existing data.


In [None]:
def classify_by_mean(stats, threshold=10):
    """Option B: Classify by mean displacement.

    Returns "OUT" when ``stats["mean_y"] < -threshold``, "IN" when it is
    ``> threshold``, and "EDGE" otherwise (boundaries included).
    NOTE(review): classify_by_area is also labelled Option B — confirm the
    intended option letters.
    """
    displacement = stats["mean_y"]
    # The OUT test comes first; the order only matters for negative thresholds.
    if displacement < -threshold:
        return "OUT"
    if displacement > threshold:
        return "IN"
    return "EDGE"


def classify_by_area(stats, threshold=500):
    """Option B: Classify by signed area.

    Returns "OUT" when ``stats["signed_area"] < -threshold``, "IN" when it is
    ``> threshold``, and "EDGE" otherwise (boundaries included).
    """
    area = stats["signed_area"]
    # The OUT test comes first; the order only matters for negative thresholds.
    if area < -threshold:
        return "OUT"
    if area > threshold:
        return "IN"
    return "EDGE"


def classify_adaptive(stats):
    """Option A: Adaptive thresholds based on segment stats.

    The distance threshold is ``max(10, 1.5 * std(y))`` and the count
    threshold is ``max(3, 5% of the points)``. A side is populated when
    strictly more than the count threshold of points lie beyond the distance
    threshold on it.

    Returns:
        "WEIRD" if both sides are populated, "OUT" if only the negative side
        is, "IN" if only the positive side is, otherwise "EDGE".
    """
    y = stats["y_vals"]
    distance_th = max(10, np.std(y) * 1.5)
    min_points = max(3, len(y) * 0.05)

    below = bool((y < -distance_th).sum() > min_points)
    above = bool((y > distance_th).sum() > min_points)

    if below and above:
        return "WEIRD"
    if below:
        return "OUT"
    if above:
        return "IN"
    return "EDGE"

In [None]:
# Compare classification methods on the v1 segment statistics.
methods = {
    "Current": lambda s: s["shape"].name,
    "Mean (th=10)": partial(classify_by_mean, threshold=10),
    "Mean (th=15)": partial(classify_by_mean, threshold=15),
    "Area (th=50)": partial(classify_by_area, threshold=50),
    "Area (th=500)": partial(classify_by_area, threshold=500),
    "Area (th=1000)": partial(classify_by_area, threshold=1000),
    "Adaptive": classify_adaptive,
}

rprint("\n[bold]Classification Method Comparison (v1)[/bold]")
if all_stats_v1:
    for method_name, classify_fn in methods.items():
        counts = Counter(classify_fn(s) for s in all_stats_v1)
        total = sum(counts.values())
        # Counter yields 0 for labels a method never produced.
        n_weird = counts["WEIRD"]
        weird_pct = 100 * n_weird / total if total > 0 else 0
        rprint(
            f"  {method_name:15s}: WEIRD={n_weird:>3} ({weird_pct:5.1f}%), "
            f"IN={counts['IN']:>3}, OUT={counts['OUT']:>3}, EDGE={counts['EDGE']:>3}"
        )
else:
    rprint("[red]No stats available - rerun previous cells[/red]")

## Summary & Next Steps

Based on this exploration:

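- Current WEIRD count: ??? _(fill in from the "Current" row of the method comparison above)_
- Best alternative method: ??? _(fill in after comparing the rows above)_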

**Decision needed**: Which approach to implement?
