# Spatial Transform 

This notebook rotates paired flash and non-flash fingerprint images upright and then centers them on the detected core point.

In [1]:
import numpy as np
import cv2
import matplotlib.pyplot as plt
import PIL.Image as Image

In [2]:
def detect_core_point(img_gray, img_gabor=None, block_size=16, coherence_thresh=0.3):
    """
    Detect core point (x, y) in a contactless fingerprint image.

    The image is split into non-overlapping blocks. For every block a dominant
    gradient orientation and a coherence value are estimated from Sobel
    gradients. The Poincaré index is then evaluated on the closed ring of the
    8 neighbouring blocks, and the best-scoring block with an index above 0.45
    is reported as the core.

    Inputs:
        img_gray : 2D grayscale fingerprint image (uint8), assumed already enhanced
        img_gabor: optional 2D binary ridge map (0=ridges, 255=background)
        block_size: size of blocks for orientation field (e.g., 16)
        coherence_thresh: minimal coherence for reliable orientation blocks

    Returns:
        (x_core, y_core): pixel coordinates of detected core point (float)
        or None if no reliable core is found.
    """

    h, w = img_gray.shape

    # --- Compute orientation and coherence ---
    num_blk_x = w // block_size
    num_blk_y = h // block_size
    orient = np.zeros((num_blk_y, num_blk_x))
    coherence = np.zeros((num_blk_y, num_blk_x))

    for by in range(num_blk_y):
        for bx in range(num_blk_x):
            block = img_gray[by*block_size:(by+1)*block_size,
                             bx*block_size:(bx+1)*block_size]
            gx = cv2.Sobel(block, cv2.CV_64F, 1, 0, ksize=3)
            gy = cv2.Sobel(block, cv2.CV_64F, 0, 1, ksize=3)

            # Doubled-angle sums so that opposite gradient directions
            # reinforce each other instead of cancelling out.
            Vx = 2 * np.sum(gx * gy)
            Vy = np.sum(gx*gx - gy*gy)
            orient[by, bx] = 0.5 * np.arctan2(Vx, Vy)

            # Coherence is the length of the summed doubled-angle vector
            # relative to the total gradient energy (epsilon avoids 0/0).
            energy = np.sum(gx*gx + gy*gy) + 1e-8
            coherence[by, bx] = np.sqrt(Vx*Vx + Vy*Vy) / energy

    reliable = coherence > coherence_thresh

    # --- Compute Poincaré index ---
    # Closed walk over the 8 neighbours (centre excluded), turning from the
    # +x axis towards the +y axis of the image, with the first cell repeated
    # at the end to close the loop. A row-major flatten of the 3x3 window
    # would pass through the centre and jump between non-adjacent cells,
    # which does not measure the rotation of the field around the block.
    ring_dy = np.array([0, 1, 1, 1, 0, -1, -1, -1, 0])
    ring_dx = np.array([1, 1, 0, -1, -1, -1, 0, 1, 1])
    half_pi = np.pi / 2

    # Unvisited blocks keep 0, which never passes the candidate threshold.
    poincare = np.zeros((num_blk_y, num_blk_x))
    for by in range(1, num_blk_y - 1):
        for bx in range(1, num_blk_x - 1):
            if not reliable[by, bx]:
                continue
            loop = orient[by + ring_dy, bx + ring_dx]
            steps = loop[1:] - loop[:-1]
            # Orientations are only defined modulo pi: wrap each step into
            # (-pi/2, pi/2). Both masks are taken before modifying so a
            # step is never wrapped twice.
            too_high = steps >= half_pi
            too_low = steps <= -half_pi
            steps[too_high] -= np.pi
            steps[too_low] += np.pi
            poincare[by, bx] = steps.sum() / (2 * np.pi)

    # --- Candidate mask for core regions ---
    candidate_mask = (poincare > 0.45) & reliable
    if not np.any(candidate_mask):
        return None

    # --- Region restriction: ignore top/bottom ---
    mask_y, mask_x = np.mgrid[0:num_blk_y, 0:num_blk_x]
    y_frac = mask_y / num_blk_y
    region_mask = (y_frac > 0.20) & (y_frac < 0.85)

    # --- Compute score for each block ---
    # Coherent candidates close to the image centre are preferred.
    center_x, center_y = num_blk_x / 2, num_blk_y / 2
    dist_center = np.sqrt((mask_x - center_x)**2 + (mask_y - center_y)**2)
    score = coherence * candidate_mask * region_mask / (1 + 0.05 * dist_center)

    if not np.any(score):
        return None

    by_best, bx_best = np.unravel_index(np.argmax(score), coherence.shape)
    y_core = (by_best + 0.5) * block_size
    x_core = (bx_best + 0.5) * block_size

    # --- Optional refinement using binary Gabor map ---
    if img_gabor is not None:
        # Search a window of +/- two blocks around the coarse estimate for
        # the pixel whose local mean of the ridge map is lowest.
        win = block_size * 2
        x0 = int(max(0, x_core - win))
        y0 = int(max(0, y_core - win))
        x1 = int(min(w, x_core + win))
        y1 = int(min(h, y_core + win))
        sub = img_gabor[y0:y1, x0:x1]

        sub_norm = sub.astype(np.float32) / 255.0
        ksize = max(3, block_size // 4)
        mean_map = cv2.blur(sub_norm, (ksize, ksize))
        min_loc = np.unravel_index(np.argmin(mean_map), mean_map.shape)
        y_core = y0 + min_loc[0]
        x_core = x0 + min_loc[1]

    # Always hand back plain floats, with or without refinement.
    return (float(x_core), float(y_core))

def auto_upright_finger_canny(img_gray, img_flash, img_non_flash, debug=False):
    """
    Auto-rotates a fingertip image upright based on its outer contour.
    Detects and fits lines to the nearly vertical sides, ignoring curved fingertip and base.

    The finger is segmented as every pixel darker than 240 (so a near-white
    background is assumed). Lines are fitted to the left and right parts of
    the largest contour inside the 40%-80% vertical band, and the image is
    rotated by the mean absolute tilt of those lines from vertical, using the
    sign of the more strongly tilted side.

    Inputs:
        img_gray     : 2D grayscale image used to estimate the rotation
        img_flash    : companion image rotated with the same matrix
        img_non_flash: companion image rotated with the same matrix
        debug        : if True, show intermediate matplotlib figures

    Returns:
        (rotated_gray, rotated_flash, rotated_non_flash). When no contour or
        no side line can be found the three inputs are returned unchanged,
        in the same order.
    """

    # --- 1. Foreground mask (simple threshold; no Canny is actually run) ---
    mask = (img_gray < 240).astype(np.uint8) * 255  # pixels darker than near-white
    mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, np.ones((7, 7), np.uint8))

    if debug:
        # The blurred image is only ever displayed, so build it on demand.
        blur = cv2.GaussianBlur(img_gray, (19, 19), 0)
        plt.imshow(blur, cmap='gray')
        plt.title("Blurred Image")
        plt.show()
        plt.imshow(mask, cmap='gray')
        plt.title("Canny Edges")
        plt.show()

    # --- 2. Largest Contour ---
    contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
    if not contours:
        # Same tuple shape as the normal path so callers can always unpack.
        return img_gray, img_flash, img_non_flash
    contour = max(contours, key=cv2.contourArea)
    contour = contour.reshape(-1, 2)

    if debug:
        vis = np.zeros_like(img_gray)
        cv2.drawContours(vis, [contour], -1, 255, 2)
        plt.imshow(vis, cmap='gray')
        plt.title("Largest Contour")
        plt.show()

    # --- 3. Restrict to vertical middle region ---
    y_min, y_max = np.min(contour[:, 1]), np.max(contour[:, 1])
    y1 = int(y_min + 0.4 * (y_max - y_min))
    y2 = int(y_min + 0.8 * (y_max - y_min))
    mid_mask = (contour[:, 1] > y1) & (contour[:, 1] < y2)
    contour_mid = contour[mid_mask]

    if contour_mid.size == 0:
        contour_mid = contour

    # --- 4. Split into left and right sides ---
    x_min, _, box_w, _ = cv2.boundingRect(contour)
    x_mid = x_min + box_w // 2
    left_pts = contour_mid[contour_mid[:, 0] < x_mid]
    right_pts = contour_mid[contour_mid[:, 0] >= x_mid]

    # --- 5. Fit lines to side segments ---
    def fit_line_points(pts):
        """Fit a line to pts; return ((vx, vy, x0, y0), angle_deg) or (None, None)."""
        if len(pts) < 2:
            return None, None
        fitted = cv2.fitLine(pts, cv2.DIST_L2, 0, 0.01, 0.01)
        # ravel() first: float() on a 1-element array is deprecated in NumPy.
        vx, vy, x0, y0 = (float(v) for v in np.asarray(fitted).ravel())
        # The sign of the fitted direction is arbitrary. Make it point down
        # the image so the same line always yields the same angle in [0, 180].
        if vy < 0 or (vy == 0 and vx < 0):
            vx, vy = -vx, -vy
        angle = np.degrees(np.arctan2(vy, vx))
        return (vx, vy, x0, y0), angle

    line_left, angle_left = fit_line_points(left_pts)
    line_right, angle_right = fit_line_points(right_pts)

    valid_angles = [a for a in [angle_left, angle_right] if a is not None]
    if not valid_angles:
        return img_gray, img_flash, img_non_flash

    # Signed tilt from vertical for each side, in [-90, 90].
    tilts = [90 - ang for ang in valid_angles]
    # Rotate by the mean tilt magnitude, in the direction of the side that
    # deviates most from vertical (the first side wins a tie).
    dominant = max(tilts, key=abs)
    mean_angle = sum(abs(t) for t in tilts) / len(tilts)
    if dominant < 0:
        mean_angle = -mean_angle

    # --- 6. Visualization ---
    if debug:
        vis = cv2.cvtColor(img_gray, cv2.COLOR_GRAY2BGR)
        cv2.drawContours(vis, [contour], -1, (0, 255, 0), 1)
        cv2.rectangle(vis, (0, y1), (img_gray.shape[1], y2), (255, 255, 0), 2)

        def draw_line(vis, line_params, color):
            """Draw the fitted line across the full height of vis."""
            if line_params is None:
                return
            vx, vy, x0, y0 = line_params
            if vy == 0:
                # Horizontal fit: x cannot be expressed as a function of y.
                cv2.line(vis, (0, int(y0)), (vis.shape[1], int(y0)), color, 2)
                return
            y_top = 0
            y_bottom = vis.shape[0]
            x_top = int(x0 - (y0 - y_top) * (vx / vy))
            x_bottom = int(x0 + (y_bottom - y0) * (vx / vy))
            cv2.line(vis, (x_top, y_top), (x_bottom, y_bottom), color, 2)

        draw_line(vis, line_left, (255, 0, 0))   # Blue = left
        draw_line(vis, line_right, (0, 0, 255))  # Red = right

        # A side without a fit is reported as 0.0 in the title.
        if angle_left is None:
            angle_left = 0.0
        if angle_right is None:
            angle_right = 0.0

        plt.imshow(vis[..., ::-1])
        plt.title(f"Fitted Lines (mid region) | Left: {angle_left:.2f}°, Right: {angle_right:.2f}°")
        plt.show()

    # --- 7. Rotate image upright ---
    h, w = img_gray.shape
    center = (w // 2, h // 2)
    rot_mat = cv2.getRotationMatrix2D(center, -mean_angle, 1.0)
    # Areas uncovered by the rotation are filled with white.
    rotated = cv2.warpAffine(img_gray, rot_mat, (w, h),
                             flags=cv2.INTER_LINEAR, borderValue=255)
    rotated_flash = cv2.warpAffine(img_flash, rot_mat, (w, h),
                             flags=cv2.INTER_LINEAR, borderValue=(255, 255, 255))
    rotated_non_flash = cv2.warpAffine(img_non_flash, rot_mat, (w, h),
                             flags=cv2.INTER_LINEAR, borderValue=(255, 255, 255))

    if debug:
        plt.imshow(rotated, cmap='gray')
        plt.title(f"Rotated by {mean_angle:.2f}° (Upright)")
        plt.show()

    return rotated, rotated_flash, rotated_non_flash

def center_on_point(img_gray, img_flash, img_non_flash, x, y):
    """
    Shifts the images so that point (x, y) becomes the new center.

    The same integer translation is applied to all three images. Regions
    uncovered by the shift are filled with white (255).

    Inputs:
        img_gray     : grayscale image; its height/width define the output size
        img_flash    : companion image shifted by the same offset
        img_non_flash: companion image shifted by the same offset
        x, y         : point to move to the center, or None for no shift

    Returns:
        (shifted_gray, shifted_flash, shifted_non_flash, new_x, new_y).
        If x or y is None the three images and x, y are returned unchanged,
        in that same 5-tuple layout.
    """
    if x is None or y is None:
        # No translation if no point; keep the 5-tuple shape of the normal
        # path so callers can unpack the result unconditionally.
        return img_gray, img_flash, img_non_flash, x, y

    h, w = img_gray.shape[:2]
    cx, cy = w // 2, h // 2  # desired center position

    # Compute translation offsets (truncated to whole pixels)
    dx = int(cx - x)
    dy = int(cy - y)

    # Translation matrix
    M = np.float32([[1, 0, dx],
                    [0, 1, dy]])

    # Apply translation
    shifted = cv2.warpAffine(img_gray, M, (w, h), flags=cv2.INTER_LINEAR, borderValue=255)
    shifted_flash = cv2.warpAffine(img_flash, M, (w, h), flags=cv2.INTER_LINEAR, borderValue=(255, 255, 255))
    shifted_non_flash = cv2.warpAffine(img_non_flash, M, (w, h), flags=cv2.INTER_LINEAR, borderValue=(255, 255, 255))

    # New coordinates of the core point
    new_x = x + dx
    new_y = y + dy

    return shifted, shifted_flash, shifted_non_flash, new_x, new_y

# Example

In [None]:
# Placeholder locations of the two captures of the same finger; replace
# them with real file paths before running the example below.
img_flash_path, img_non_flash_path = (
    "path_to_flash_image",
    "path_to_non_flash_image",
)

In [None]:
# Load both captures as NumPy arrays: the helpers above rely on `.shape`,
# element-wise comparisons and OpenCV calls, none of which accept PIL images.
img_flash = np.array(Image.open(img_flash_path))
img_non_flash = np.array(Image.open(img_non_flash_path))
# The grayscale version of the flash image drives the geometry estimation.
img_gray = np.array(Image.open(img_flash_path).convert("L"))

img_gray_rot, img_flash_rot, img_non_flash_rot = auto_upright_finger_canny(
    img_gray, img_flash, img_non_flash, debug=False)
result = detect_core_point(img_gray_rot, None, block_size=8, coherence_thresh=0.3)

if result is None:
    # No core found: keep the upright images without centering them.
    x, y = None, None
    img_centered_gray, img_centered_flash, img_centered_non_flash = (
        img_gray_rot, img_flash_rot, img_non_flash_rot)
else:
    x, y = result
    (img_centered_gray, img_centered_flash, img_centered_non_flash,
     new_x, new_y) = center_on_point(img_gray_rot, img_flash_rot, img_non_flash_rot, x, y)

plt.figure(figsize=(12, 4))  # wide figure for 4 columns
plt.subplot(1, 4, 1)
plt.imshow(img_flash)
plt.title("Flash Before", fontsize=8)
plt.subplot(1, 4, 2)
plt.imshow(img_non_flash, cmap='gray')
plt.title("Non-Flash Before", fontsize=8)
plt.subplot(1, 4, 3)
plt.imshow(img_centered_flash)
plt.title("Flash After", fontsize=8)
plt.subplot(1, 4, 4)
plt.imshow(img_centered_non_flash)
plt.title("Non-Flash After", fontsize=8)
plt.tight_layout()
plt.show()