In [None]:
"""
danielsinkin97@gmail.com

This module computes and visualizes signed Histogram of Oriented Gradients (HOG) energy
for individual 8x8 cells in a grayscale image. It smooths the image, computes Sobel
gradients, converts them to magnitudes and angles, bins the signed orientations over
[0°, 360°) into a fixed number of bins, and renders a four-panel visualization:
(1) the full image with the selected cell highlighted, (2) the 8x8 cell patch,
(3) the magnitude-weighted orientation histogram, and (4) a normalized polar
"rose" plot showing directionality. An interactive UI with sliders lets you choose
which cell to inspect.
"""

from pathlib import Path

import ipywidgets as widgets
import matplotlib.pyplot as plt
import numpy as np
from IPython.display import display
from matplotlib.patches import Rectangle

from computer_vision.util.images import (
    load_image_as_array,
    plot_grayscale,
    rgb_to_grayscale,
)
from computer_vision.src.filter import apply_filter, get_filter


# Input image shown by the demo, resolved relative to the working directory.
image_fp = Path("data") / "fennec.png"

# Sobel derivative kernels along x and y, looked up once at import time.
sobel_x, sobel_y = (get_filter(kernel_name) for kernel_name in ("sobel_x", "sobel_y"))


def compute_signed_bins_for_cell(
    magnitude: np.ndarray,
    angles_deg: np.ndarray,
    patch_y: int,
    patch_x: int,
    cell_size: int = 8,
    num_bins: int = 9,
) -> tuple[np.ndarray, tuple[float, ...], int, int, slice, slice]:
    """
    Aggregate gradient magnitudes into signed orientation bins for a single cell.

    Parameters
    ----------
    magnitude : np.ndarray
        Gradient magnitude image of shape (H, W).
    angles_deg : np.ndarray
        Gradient angles in degrees, same shape as `magnitude`. Values may be any real
        numbers; they will be wrapped into [0, 360).
    patch_y : int
        Cell row index (top to bottom) in cell coordinates. Out-of-range values are
        clamped to the nearest valid cell.
    patch_x : int
        Cell column index (left to right) in cell coordinates. Out-of-range values
        are clamped to the nearest valid cell.
    cell_size : int, optional
        Cell side length in pixels. Must be positive.
    num_bins : int, optional
        Number of equal-width orientation bins spanning [0°, 360°). Must be positive.

    Returns
    -------
    bin_magnitudes : np.ndarray
        Length-`num_bins` float array with magnitude sums per orientation bin.
    bin_edges : tuple[float, ...]
        The `num_bins + 1` bin edges over [0°, 360°].
    py : int
        Clamped cell row index actually used.
    px : int
        Clamped cell column index actually used.
    sy : slice
        Pixel slice for rows of the selected cell.
    sx : slice
        Pixel slice for columns of the selected cell.

    Raises
    ------
    ValueError
        If `magnitude` is not 2-D or `angles_deg` has a different shape, if
        `cell_size` or `num_bins` is not positive, or if the image is smaller than
        a single cell (so no cell can be selected).
    """
    magnitude = np.asarray(magnitude)
    angles_deg = np.asarray(angles_deg)
    if magnitude.ndim != 2:
        raise ValueError(f"magnitude must be 2-D (H, W), got shape {magnitude.shape}.")
    if angles_deg.shape != magnitude.shape:
        raise ValueError(
            f"angles_deg shape {angles_deg.shape} does not match "
            f"magnitude shape {magnitude.shape}."
        )
    if cell_size <= 0:
        raise ValueError(f"cell_size must be positive, got {cell_size}.")
    if num_bins <= 0:
        raise ValueError(f"num_bins must be positive, got {num_bins}.")

    H, W = magnitude.shape
    n_cell_rows = H // cell_size
    n_cell_cols = W // cell_size
    # Without this guard np.clip(.., 0, -1) would yield index -1 and an empty slice.
    if n_cell_rows == 0 or n_cell_cols == 0:
        raise ValueError(
            f"Image of shape {magnitude.shape} is smaller than one "
            f"{cell_size}x{cell_size} cell."
        )

    # Only complete cells are addressable; trailing partial rows/cols are ignored.
    py = int(np.clip(patch_y, 0, n_cell_rows - 1))
    px = int(np.clip(patch_x, 0, n_cell_cols - 1))

    sy = slice(py * cell_size, (py + 1) * cell_size)
    sx = slice(px * cell_size, (px + 1) * cell_size)

    mag_patch = magnitude[sy, sx]
    ang_patch = np.mod(angles_deg[sy, sx], 360.0)

    bin_edges = np.linspace(0.0, 360.0, num_bins + 1, endpoint=True)
    bin_width = 360.0 / num_bins

    # The trailing `% num_bins` folds values that round up to exactly 360.0 back
    # into bin 0 (np.mod of a tiny negative angle can return 360.0).
    bins = np.floor(ang_patch / bin_width).astype(int) % num_bins
    bin_magnitudes = np.bincount(
        bins.ravel(), weights=mag_patch.ravel(), minlength=num_bins
    ).astype(float)

    return bin_magnitudes, tuple(bin_edges), py, px, sy, sx


def plot_hog_bin_arrows(
    bin_magnitudes,
    bin_edges,
    ax: plt.Axes | None = None,
    title: str = "HoG orientation energy (0..360°, normalized)",
    zero_tol: float = 0.0,
    guide_alpha: float = 0.25,
    guide_linestyle: str = "--",
    guide_linewidth: float = 1.0,
):
    """
    Render a normalized polar "rose" with arrows at bin centers proportional to energy.

    Each bin gets a dashed radial guide at its center angle. Bins whose raw energy
    exceeds `zero_tol` additionally get an arrow from the origin. Arrow lengths are
    scaled so the largest drawn bin reaches the unit circle.

    Parameters
    ----------
    bin_magnitudes : array-like
        Magnitude per orientation bin (1-D).
    bin_edges : array-like
        Bin edges over [0°, 360°] with length `len(bin_magnitudes) + 1`.
    ax : plt.Axes | None, optional
        Axes to draw on. If None, a new figure and axes are created and shown.
    title : str, optional
        Title for the plot.
    zero_tol : float, optional
        Bins whose raw (un-normalized) magnitude is ≤ `zero_tol` are treated as zero:
        they are neither drawn nor used to compute the normalization scale.
    guide_alpha : float, optional
        Transparency for radial guide lines.
    guide_linestyle : str, optional
        Line style for radial guides.
    guide_linewidth : float, optional
        Line width for radial guides.

    Returns
    -------
    None

    Raises
    ------
    ValueError
        If the inputs are not 1-D or `bin_edges` does not have exactly one more
        entry than `bin_magnitudes`.
    """
    mags = np.asarray(bin_magnitudes, dtype=float)
    edges = np.asarray(bin_edges, dtype=float)
    if mags.ndim != 1 or edges.ndim != 1:
        raise ValueError("bin_magnitudes and bin_edges must both be 1-D.")
    if len(edges) != len(mags) + 1:
        raise ValueError(
            "Expected len(bin_edges) == len(bin_magnitudes) + 1, got "
            f"{len(edges)} and {len(mags)}."
        )

    centers_rad = np.deg2rad(0.5 * (edges[:-1] + edges[1:]))

    # Threshold on the raw magnitudes, then scale the survivors so the largest arrow
    # has unit length. The same mask decides what is drawn below, so `zero_tol`
    # always refers to raw energy rather than to the normalized lengths.
    draw_mask = mags > zero_tol
    mags_n = np.zeros_like(mags)
    if np.any(draw_mask):
        scale = mags[draw_mask].max()
        if scale > 0:  # guard against division by zero for a negative zero_tol
            mags_n[draw_mask] = mags[draw_mask] / scale
    draw_mask &= mags_n > 0.0  # never draw zero-length (or reversed) arrows

    created_fig = ax is None
    if created_fig:
        _, ax = plt.subplots(figsize=(5, 5))

    ax.set_aspect("equal", adjustable="box")
    unit_circle = plt.Circle((0.0, 0.0), 1.0, fill=False)
    ax.add_artist(unit_circle)

    ax.set_xlim(-1.1, 1.1)
    ax.set_ylim(-1.1, 1.1)
    fixed_ticks = [-1.0, -0.5, 0.0, 0.5, 1.0]
    ax.set_xticks(fixed_ticks)
    ax.set_yticks(fixed_ticks)

    # Radial guides mark every bin center, including empty bins.
    for th in centers_rad:
        ax.plot(
            [0.0, np.cos(th)],
            [0.0, np.sin(th)],
            linestyle=guide_linestyle,
            linewidth=guide_linewidth,
            alpha=guide_alpha,
            color="black",
        )

    head_w = 0.03
    head_l = 0.06
    for th, length in zip(centers_rad[draw_mask], mags_n[draw_mask]):
        ax.arrow(
            0.0,
            0.0,
            length * np.cos(th),
            length * np.sin(th),
            length_includes_head=True,
            head_width=head_w,
            # Keep the head no longer than the arrow itself for tiny bins.
            head_length=min(head_l, length),
        )

    ax.set_xlabel("x")
    ax.set_ylabel("y")
    ax.set_title(title)

    if created_fig:
        plt.show()


def visualize_hog_cell_row(
    image: np.ndarray,
    magnitude: np.ndarray,
    angles_deg: np.ndarray,
    patch_y: int,
    patch_x: int,
    cell_size: int = 8,
    num_bins: int = 9,
    zero_tol: float = 0.0,
):
    """
    Draw a four-panel visualization for a selected cell.

    Panels, left to right: the full image with the selected cell outlined in red,
    the cell patch itself, the magnitude-weighted orientation histogram, and the
    normalized rose plot drawn by `plot_hog_bin_arrows`.

    Parameters
    ----------
    image : np.ndarray
        Grayscale image (H, W), normalized to [0, 1] recommended.
    magnitude : np.ndarray
        Gradient magnitude, same shape as `image`.
    angles_deg : np.ndarray
        Gradient angles in degrees, same shape as `image`.
    patch_y : int
        Cell row index in cell coordinates (clamped to the valid range).
    patch_x : int
        Cell column index in cell coordinates (clamped to the valid range).
    cell_size : int, optional
        Cell side length in pixels.
    num_bins : int, optional
        Number of orientation bins across [0°, 360°).
    zero_tol : float, optional
        Threshold for suppressing near-zero arrows in the rose plot.

    Returns
    -------
    None

    Raises
    ------
    ValueError
        If `image` is not 2-D.
    """
    if image.ndim != 2:
        raise ValueError("Expected a grayscale image (H, W).")

    bin_mags, bin_edges, py, px, sy, sx = compute_signed_bins_for_cell(
        magnitude, angles_deg, patch_y, patch_x, cell_size=cell_size, num_bins=num_bins
    )

    fig, axes = plt.subplots(1, 4, figsize=(18, 4))
    ax_img, ax_patch, ax_hist, ax_rose = axes

    ax_img.imshow(image, cmap="gray")
    # imshow centers pixel (r, c) at (c, r), so pixel boundaries lie on
    # half-integers; shift by 0.5 so the outline hugs the cell's pixels exactly.
    rect = Rectangle(
        (sx.start - 0.5, sy.start - 0.5),
        width=sx.stop - sx.start,
        height=sy.stop - sy.start,
        fill=False,
        edgecolor=(1.0, 0.0, 0.0),
        linewidth=2.0,
    )
    ax_img.add_patch(rect)
    ax_img.set_title(f"Full image (Patch(y={py}, x={px}))")
    ax_img.axis("off")

    ax_patch.imshow(image[sy, sx], cmap="gray")
    ax_patch.set_title(f"Cell patch ({cell_size}x{cell_size})")
    ax_patch.axis("off")

    edges = np.asarray(bin_edges, dtype=float)
    centers = 0.5 * (edges[:-1] + edges[1:])
    bin_width = edges[1] - edges[0]
    ax_hist.bar(centers, bin_mags, width=0.9 * bin_width, align="center")
    ax_hist.set_xlim(0, 360)
    # Tick exactly at the bin edges (avoids float-step np.arange drift).
    ax_hist.set_xticks(edges)
    ax_hist.set_xlabel("Orientation (deg)")
    ax_hist.set_ylabel("Magnitude sum")
    ax_hist.set_title(f"Histogram ({num_bins} bins)")

    plot_hog_bin_arrows(
        bin_magnitudes=bin_mags,
        bin_edges=bin_edges,
        ax=ax_rose,
        title="Gradient Contributions",
        zero_tol=zero_tol,
    )

    plt.tight_layout()
    plt.show()


def interactive_hog_four_panel(
    image: np.ndarray,
    magnitude: np.ndarray,
    angles_deg: np.ndarray,
    cell_size: int = 8,
    num_bins: int = 9,
    zero_tol: float = 0.0,
):
    """
    Show a row/column slider pair that re-renders the four-panel cell view on change.

    The slider ranges cover every complete `cell_size` x `cell_size` cell of `image`.
    Each slider change calls `visualize_hog_cell_row` with the selected cell.

    Parameters
    ----------
    image : np.ndarray
        Grayscale image (H, W).
    magnitude : np.ndarray
        Gradient magnitude, same shape as `image`.
    angles_deg : np.ndarray
        Gradient angles in degrees, same shape as `image`.
    cell_size : int, optional
        Cell side length in pixels.
    num_bins : int, optional
        Number of orientation bins across [0°, 360°).
    zero_tol : float, optional
        Threshold for suppressing near-zero arrows in the rose plot.

    Returns
    -------
    None
    """
    n_cell_rows, n_cell_cols = (extent // cell_size for extent in image.shape)

    def _render(patch_y, patch_x):
        # Everything except the selected cell is fixed for the widget's lifetime.
        return visualize_hog_cell_row(
            image=image,
            magnitude=magnitude,
            angles_deg=angles_deg,
            patch_y=patch_y,
            patch_x=patch_x,
            cell_size=cell_size,
            num_bins=num_bins,
            zero_tol=zero_tol,
        )

    sliders = {
        "patch_y": widgets.IntSlider(
            description="patch_y", min=0, max=n_cell_rows - 1, step=1, value=0
        ),
        "patch_x": widgets.IntSlider(
            description="patch_x", min=0, max=n_cell_cols - 1, step=1, value=0
        ),
    }

    controls = widgets.HBox(list(sliders.values()))
    rendered = widgets.interactive_output(_render, sliders)

    display(controls, rendered)


def _center_crop_square(image: np.ndarray) -> np.ndarray:
    """
    Return the largest square region centered in `image`.

    The longer axis is trimmed symmetrically. When the difference is odd, the extra
    pixel is dropped from the bottom/right edge, so both sides of the result are
    always exactly ``min(H, W)`` pixels long.
    """
    height, width = image.shape[:2]
    side = min(height, width)
    top = (height - side) // 2
    left = (width - side) // 2
    return image[top : top + side, left : left + side]


def main() -> None:
    """
    Load the image, smooth it, compute Sobel gradients, derive magnitudes and angles,
    and launch the interactive four-panel HOG visualization for 8x8 cells.

    Pipeline:
    1. Grayscale conversion and 15x15 Gaussian smoothing.
    2. Scaling by 1/255.
    3. A centered square crop.
    4. Square-root gamma compression.
    5. Sobel gradients, gradient magnitude (`np.hypot`) and signed angle in
       degrees (`np.arctan2`, range [-180°, 180°]). `compute_signed_bins_for_cell`
       later wraps the angle into [0°, 360°).

    Returns
    -------
    None
    """
    image = rgb_to_grayscale(load_image_as_array(image_fp))
    image = apply_filter(image, get_filter("gaussian_15x15"), pad_same_size=True)

    # NOTE(review): assumes pixel values span [0, 255] — confirm against
    # load_image_as_array. Out-of-place division also works for integer arrays.
    image = image / 255.0

    # Center-crop to a square (works for landscape and portrait, odd or even
    # size differences), then apply a square-root gamma that brightens dark values.
    image = np.sqrt(_center_crop_square(image))

    grad_x = apply_filter(image, sobel_x, pad_same_size=True)
    grad_y = apply_filter(image, sobel_y, pad_same_size=True)
    magnitude = np.hypot(grad_x, grad_y)
    # np.arctan2 is available in every NumPy version (np.atan2 only in NumPy >= 2).
    angles_deg = np.degrees(np.arctan2(grad_y, grad_x))

    interactive_hog_four_panel(
        image=image,
        magnitude=magnitude,
        angles_deg=angles_deg,
        cell_size=8,
        num_bins=9,
        zero_tol=0.0,
    )


if __name__ == "__main__":
    # Launch the interactive visualization only when executed directly, not on import.
    main()