In [None]:
# Colab setup: mount Google Drive (skip this cell when running locally, e.g. in VS Code)
from google.colab import drive
drive.mount('/content/drive')
# Set base directory

Mounted at /content/drive


# STEP 1: Image Preprocessing

In [None]:
import cv2
import numpy as np
import os
import matplotlib.pyplot as plt

# === Your Exact: Needs Rotation Detection (Unchanged) ===

def needs_rotation(gray):
    """Estimate whether a grayscale image is skewed enough to need rotating.

    Canny edges are fed to the standard Hough transform; only line angles
    strictly between 20 and 160 degrees are kept, and the correction is
    derived from their median (median - 90 when the median exceeds 45
    degrees, otherwise the median itself).

    Returns:
        (flag, angle): ``flag`` is truthy when the absolute correction is
        larger than 2 degrees, ``angle`` is the correction in degrees.
        ``(False, 0)`` is returned when no usable line is found.
    """
    edge_map = cv2.Canny(gray, 50, 150)
    hough = cv2.HoughLines(edge_map, 1, np.pi/180, 150)
    if hough is None:
        return False, 0

    kept = []
    for entry in hough:
        for _rho, theta in entry:
            degrees = (theta*180)/np.pi
            if not (20 < degrees < 160):
                continue
            kept.append(degrees)

    if not kept:
        return False, 0

    median_deg = np.median(kept)
    if median_deg > 45:
        correction = median_deg - 90
    else:
        correction = median_deg
    return abs(correction) > 2, correction

# === Rotate Function (Unchanged) ===

def rotate(gray, angle, reshape=False, mode='nearest'):
    """Rotate a 2-D image about its centre without changing the canvas size.

    Args:
        gray: 2-D image array (its shape is unpacked as ``(h, w)``).
        angle: rotation angle in degrees, passed to ``cv2.getRotationMatrix2D``.
        reshape: only echoed in the log line; the output size is always the
            input size.
        mode: ``'nearest'`` selects nearest-neighbour interpolation, any
            other value selects bilinear.

    Returns:
        The rotated image; border pixels are replicated.
    """
    rows, cols = gray.shape
    pivot = (cols // 2, rows // 2)
    affine = cv2.getRotationMatrix2D(pivot, angle, 1.0)
    if mode == 'nearest':
        interpolation = cv2.INTER_NEAREST
    else:
        interpolation = cv2.INTER_LINEAR
    result = cv2.warpAffine(
        gray, affine, (cols, rows),
        flags=interpolation, borderMode=cv2.BORDER_REPLICATE,
    )
    print(f"  -> Rotated by {angle:.2f}° (reshape={reshape}, mode={mode}; size preserved).")
    return result

# === 1. Detection: Washed Out / Faded Areas (Unchanged) ===

def check_for_washout(img_gray, ink_threshold=70, contrast_threshold=40):
    """Report whether an image looks washed out or low in contrast.

    Two statistics are computed: the 5th-percentile intensity (the
    "darkest ink level") and the standard deviation of all pixels.

    Returns:
        True when the 5th percentile is above ``ink_threshold`` or the
        standard deviation is below ``contrast_threshold``; False otherwise.
    """
    ink_level = np.percentile(img_gray, 5)
    spread = np.std(img_gray)
    print(f"  -> Contrast Analysis: Darkest Ink Level={ink_level:.2f}, Contrast(StdDev)={spread:.2f}")

    # Ink check comes first: even the darkest pixels are too bright.
    if ink_level > ink_threshold:
        print(f"  -> Detection: Image is WASHED OUT (Ink is too light: {ink_level:.2f} > {ink_threshold})")
        return True

    # Otherwise look at the overall spread of intensities.
    if spread < contrast_threshold:
        print(f"  -> Detection: Image is LOW CONTRAST (StdDev: {spread:.2f} < {contrast_threshold})")
        return True

    print("  -> Detection: Contrast looks okay.")
    return False

# === 2. Noise Measurement Helper (Unchanged) ===

def get_noise_ratio(img_gray, noise_threshold=10):
    """Return the fraction of pixels that a 3x3 median filter changes noticeably.

    A pixel counts as noisy when its absolute difference from the
    median-filtered image exceeds ``noise_threshold``.
    """
    smoothed = cv2.medianBlur(img_gray, 3)
    noisy = cv2.absdiff(img_gray, smoothed) > noise_threshold
    return np.sum(noisy) / img_gray.size

# === Laplacian Variance (Unchanged) ===
def get_laplacian_variance(img_gray, blur_size=3):
    """Return (and log) the variance of the image's Laplacian response.

    ``blur_size`` is forwarded as the ``ksize`` of ``cv2.Laplacian``.
    """
    response = cv2.Laplacian(img_gray, cv2.CV_64F, ksize=blur_size)
    spread = response.var()
    print(f"  -> Laplacian Variance (noise sharpness): {spread:.2f}")
    return spread

# === Apply Gaussian Blur (Unchanged) ===
def apply_gaussian_blur_if_needed(img_gray, variance_threshold=100):
    """Apply a light 3x3 Gaussian blur when the Laplacian variance is high.

    Returns:
        (image, applied): the blurred image and True when the variance
        exceeds ``variance_threshold``; otherwise the input image and False.
    """
    lap_var = get_laplacian_variance(img_gray)

    if lap_var > variance_threshold:
        smoothed = cv2.GaussianBlur(img_gray, (3, 3), sigmaX=0.5)
        print(f"  -> High noise detected (var={lap_var:.2f} > {variance_threshold}). Applied Gaussian blur.")
        return smoothed, True

    print(f"  -> Noise levels okay (var={lap_var:.2f} <= {variance_threshold}). Skipping blur.")
    return img_gray, False

# === Sobel Variance (Unchanged) ===
def get_sobel_variance(img_gray, ksize=3):
    """Return (and log) the variance of the Sobel gradient magnitude.

    The x and y derivatives are computed in 64-bit float and combined as
    the Euclidean magnitude before the variance is taken.
    """
    grad_x = cv2.Sobel(img_gray, cv2.CV_64F, 1, 0, ksize=ksize)
    grad_y = cv2.Sobel(img_gray, cv2.CV_64F, 0, 1, ksize=ksize)
    magnitude = np.sqrt(grad_x**2 + grad_y**2)
    spread = magnitude.var()
    print(f"  -> Sobel Variance (edge sharpness): {spread:.2f}")
    return spread

# === Apply Unsharp Sharpen (Unchanged) ===
def apply_unsharp_sharpen_if_needed(img_gray, variance_threshold=50, amount=1.0):
    """Apply unsharp masking when the Sobel variance indicates soft edges.

    Args:
        img_gray: grayscale image to inspect and possibly sharpen.
        variance_threshold: sharpening is applied only when the Sobel
            variance is below this value.
        amount: scales the sharpening strength. The image is weighted by
            ``1 + 0.5 * amount`` and the blurred copy by ``-0.5 * amount``,
            so the default ``1.0`` reproduces the previously hard-coded
            1.5 / -0.5 weights. (The parameter used to be ignored even
            though it was reported in the log line.)

    Returns:
        (image, applied): the sharpened image and True, or the input image
        and False when the edges are already sharp enough.
    """
    variance = get_sobel_variance(img_gray)
    if variance < variance_threshold:
        blurred = cv2.GaussianBlur(img_gray, (5, 5), sigmaX=1.0)
        # Weights sum to 1 so overall brightness is preserved.
        image_weight = 1.0 + 0.5 * amount
        blur_weight = -0.5 * amount
        sharpened = cv2.addWeighted(img_gray, image_weight, blurred, blur_weight, 0)
        print(f"  -> Soft edges detected (var={variance:.2f} < {variance_threshold}). Applied unsharp sharpening (amount={amount}).")
        return sharpened, True

    print(f"  -> Edges sharp enough (var={variance:.2f} >= {variance_threshold}). Skipping sharpening.")
    return img_gray, False

# === Binary Thresholding (Unchanged) ===
def apply_binary_thresholding_if_needed(img_gray, entropy_threshold=5.0, block_size=11, C=2):
    """Binarize the image when its histogram entropy says it is not yet B&W-like.

    Args:
        img_gray: 8-bit grayscale image.
        entropy_threshold: binarization happens only when the Shannon
            entropy (bits) of the 256-bin histogram exceeds this value.
        block_size: neighbourhood size for the adaptive threshold.
        C: constant subtracted by the adaptive threshold.

    Returns:
        (image, applied): the inverted binary image (ink = 255) and True, or
        the input image and False when the entropy is already low.

    The adaptive result is kept only when the ink density (fraction of
    foreground pixels) lies in [0.05, 0.5]; otherwise Otsu thresholding is
    used instead.
    """
    hist = cv2.calcHist([img_gray], [0], None, [256], [0, 256])
    prob = hist / hist.sum()
    entropy = -np.sum(prob * np.log2(prob + 1e-10))
    print(f"  -> Image Entropy (B&W check): {entropy:.2f}")

    if entropy <= entropy_threshold:
        print(f"  -> Image already B&W-like (entropy={entropy:.2f} <= {entropy_threshold}). Skipping binary.")
        return img_gray, False

    binary = cv2.adaptiveThreshold(img_gray, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
                                   cv2.THRESH_BINARY_INV, block_size, C)
    # THRESH_BINARY_INV maps dark ink to 255, so the ink is the non-zero
    # pixels. Counting zeros (as before) measured the background, which is
    # nearly always > 0.5 and forced the Otsu fallback on almost every image.
    ink_pixels = np.sum(binary == 255)
    density = ink_pixels / binary.size
    if density < 0.05 or density > 0.5:
        _, binary = cv2.threshold(img_gray, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)
        print("  -> Ink density out of range; fell back to Otsu thresholding.")
    print(f"  -> Applied adaptive binary thresholding (initial density: {density:.4f}).")
    return binary, True

# === Optimized: Repair Lines (Unchanged from Prior) ===

def repair_sudoku_lines_if_needed(binary_img, fragment_threshold=8, min_length=30):
    """Close gaps in the grid lines of a binary Sudoku image when they look fragmented.

    Args:
        binary_img: single-channel binary image with foreground = 255.
        fragment_threshold: repair is triggered when more short line
            fragments than this are found.
        min_length: ``minLineLength`` for the probabilistic Hough transform.

    Returns:
        (image, repaired): the repaired image and True, or the input image
        and False when too few fragments were detected.
    """
    print("  -> Checking for interrupted Sudoku lines (optimized repair)...")

    # Keep only large connected components; small blobs are digits or noise.
    contours, _ = cv2.findContours(binary_img, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
    line_mask = np.zeros_like(binary_img)
    large_contours_count = 0
    for cnt in contours:
        if cv2.contourArea(cnt) > 300:
            cv2.drawContours(line_mask, [cnt], -1, 255, -1)
            large_contours_count += 1
    binary_lines = cv2.bitwise_and(binary_img, line_mask)
    print(f"  -> Filtered contours (area>300): {large_contours_count}/{len(contours)} total.")

    dil_kernel = np.ones((3, 3), np.uint8)
    binary_lines_dil = cv2.dilate(binary_lines, dil_kernel, iterations=1)

    edges = cv2.Canny(binary_lines_dil, 50, 150, apertureSize=3)
    lines = cv2.HoughLinesP(edges, rho=1, theta=np.pi/180, threshold=25,
                            minLineLength=min_length, maxLineGap=30)

    if lines is None:
        lines = []
        print("  -> No lines detected by Hough.")

    print(f"  -> Detected {len(lines)} line segments.")

    # A segment is a "fragment" when it is clearly horizontal or vertical
    # but shorter than a third of one expected grid cell.
    img_h, img_w = binary_img.shape
    expected_len = max(img_h, img_w) / 9
    short_threshold = expected_len / 3
    h_fragments = 0
    v_fragments = 0
    for line in lines:
        x1, y1, x2, y2 = line[0]
        dx, dy = abs(x2 - x1), abs(y2 - y1)
        length = np.sqrt((x2 - x1)**2 + (y2 - y1)**2)
        if dx >= dy and dx > 15:
            if length < short_threshold:
                h_fragments += 1
        elif dy >= dx and dy > 15:
            if length < short_threshold:
                v_fragments += 1
    fragments = h_fragments + v_fragments
    print(f"  -> H fragments: {h_fragments}, V fragments: {v_fragments} (short thr: {short_threshold:.1f}).")

    if fragments <= fragment_threshold:
        print(f"  -> Low fragments ({fragments} <= {fragment_threshold}). Skipping repair.")
        return binary_img, False

    h_kernel_size = max(1, img_w // 10)
    v_kernel_size = max(1, img_h // 10)

    # getStructuringElement takes ksize as (width, height). The horizontal
    # kernel is therefore (h_kernel_size, 1): one row, h_kernel_size columns.
    # The previous (1, h_kernel_size) built a vertical column sized from the
    # image width, i.e. the two kernels were swapped.
    h_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (h_kernel_size, 1))
    repaired_h = cv2.morphologyEx(binary_lines, cv2.MORPH_CLOSE, h_kernel, iterations=3)

    v_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (1, v_kernel_size))
    repaired_v = cv2.morphologyEx(repaired_h, cv2.MORPH_CLOSE, v_kernel, iterations=3)

    # Opening removes specks created by the closings; dilation thickens lines.
    open_kernel = np.ones((3, 3), np.uint8)
    repaired = cv2.morphologyEx(repaired_v, cv2.MORPH_OPEN, open_kernel, iterations=2)
    repaired = cv2.dilate(repaired, np.ones((3, 3), np.uint8), iterations=2)

    if np.sum(repaired > 200) / repaired.size > 0.015:
        repaired = cv2.medianBlur(repaired, 3)
        print("  -> Applied median fallback for noise.")

    # Merge the repaired lines back with the original content (digits etc.).
    repaired = cv2.bitwise_or(repaired, binary_img)

    print(f"  -> Triggered repair ({fragments} > {fragment_threshold}). Used kernels H:1x{h_kernel_size}, V:{v_kernel_size}x1.")
    return repaired, True

# === New: Extract Sudoku Board (Adapted from Reference) ===

def extract_sudoku_board(img_color, img_gray, img_binary=None, epsilon=0.015):
    """Warp the largest quadrilateral contour to a top-down rectangle.

    The gray image is median-blurred and adaptively thresholded; the largest
    external contour is approximated with ``epsilon * perimeter`` tolerance
    and must reduce to exactly four points.

    Args:
        img_color: colour image to warp.
        img_gray: grayscale image used for detection and also warped.
        img_binary: optional binary image warped with the same transform.
        epsilon: polygon-approximation tolerance as a fraction of perimeter.

    Returns:
        (color, gray, binary, extracted). On success the three images are
        the warped versions (``binary`` is the warped gray image when
        ``img_binary`` is None) and ``extracted`` is True. On failure the
        inputs are returned unchanged (``binary`` falls back to
        ``img_gray`` when ``img_binary`` is None) with ``extracted`` False.
    """
    print("  -> Extracting Sudoku board (perspective warp)...")

    # `img_binary or img_gray` raises ValueError for ndarrays (ambiguous
    # truth value), so the fallback is chosen with an explicit None check.
    fallback_binary = img_binary if img_binary is not None else img_gray

    # Prep: blur + adaptive threshold on gray
    blur = cv2.medianBlur(img_gray, 3)
    thresh = cv2.adaptiveThreshold(blur, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
                                   cv2.THRESH_BINARY_INV, 11, 3)

    # findContours returns 2 or 3 values depending on the OpenCV version.
    cnts = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    cnts = cnts[0] if len(cnts) == 2 else cnts[1]
    cnts = sorted(cnts, key=cv2.contourArea, reverse=True)

    if not cnts:
        print("  -> No contours found. Skipping board extraction.")
        return img_color, img_gray, fallback_binary, False

    # Largest contour approx
    c = cnts[0]
    peri = cv2.arcLength(c, True)
    approx = cv2.approxPolyDP(c, epsilon * peri, True)

    if len(approx) != 4:
        print(f"  -> Largest contour approx has {len(approx)} points (need 4). Skipping board extraction.")
        return img_color, img_gray, fallback_binary, False

    # Order corners clockwise TL, TR, BR, BL: x+y is smallest at TL and
    # largest at BR; y-x is smallest at TR and largest at BL.
    corners = approx.reshape(4, 2)
    s = corners.sum(axis=1)
    diff = np.diff(corners, axis=1)
    top_l = corners[np.argmin(s)]
    top_r = corners[np.argmin(diff)]
    bottom_r = corners[np.argmax(s)]
    bottom_l = corners[np.argmax(diff)]
    ordered_corners = np.array([top_l, top_r, bottom_r, bottom_l], dtype="float32")

    # Output size: the longer of each pair of opposite edges.
    tl, tr, br, bl = ordered_corners.astype("float64")
    width = max(int(np.linalg.norm(br - bl)), int(np.linalg.norm(tr - tl)))
    height = max(int(np.linalg.norm(tr - br)), int(np.linalg.norm(tl - bl)))

    if width < 1 or height < 1:
        # A collapsed quad cannot be warped; treat it like "no board found".
        print(f"  -> Degenerate board size (w:{width}, h:{height}). Skipping board extraction.")
        return img_color, img_gray, fallback_binary, False

    # Dst points: TL, TR, BR, BL
    dimensions = np.array([[0, 0], [width - 1, 0], [width - 1, height - 1], [0, height - 1]], dtype="float32")

    matrix = cv2.getPerspectiveTransform(ordered_corners, dimensions)

    warped_color = cv2.warpPerspective(img_color, matrix, (width, height))
    warped_gray = cv2.warpPerspective(img_gray, matrix, (width, height))
    if img_binary is not None:
        warped_binary = cv2.warpPerspective(img_binary, matrix, (width, height))
    else:
        warped_binary = warped_gray

    print(f"  -> Board extracted: {warped_color.shape[:2]} (w:{width}, h:{height}).")
    return warped_color, warped_gray, warped_binary, True

# === 3. Iterative Denoising (Unchanged) ===

def iterative_denoise(img_gray, max_kernel=19, target_ratio=0.01):
    """Median-filter with growing kernels until the noise ratio is low enough.

    Kernel sizes 3, 5, 7, ... up to ``max_kernel`` are each applied to the
    original image (not cumulatively). The first result whose noise ratio
    is below ``target_ratio`` is returned with its kernel size.

    Returns:
        (image, kernel). If no kernel reaches the target, the last filtered
        image is returned together with ``max_kernel``.
    """
    size = 3
    fallback = img_gray.copy()
    print("  -> Starting Iterative Denoising...")

    while size <= max_kernel:
        candidate = cv2.medianBlur(img_gray, size)
        ratio = get_noise_ratio(candidate)
        print(f"    -> Testing Kernel {size}: Noise Ratio = {ratio:.4f}")

        if ratio < target_ratio:
            print(f"    -> Success! Noise below {target_ratio} using Kernel {size}.")
            return candidate, size

        # Remember the latest attempt and move on to the next odd size.
        fallback, size = candidate, size + 2

    print(f"    -> Warning: Reached Max Kernel ({max_kernel}) without fully cleaning. Using result anyway.")
    return fallback, max_kernel

# === 4. Lighting Fix (Unchanged) ===

def check_if_lighting_fix_needed(img_gray, dark_threshold=110, shadow_variance=60):
    """Decide whether the image is too dark or unevenly lit.

    Returns True when the mean intensity is below ``dark_threshold``, or
    when a 20x20 thumbnail spans more than ``255 - dark_threshold`` grey
    levels while its minimum is below ``shadow_variance``; False otherwise.
    """
    mean_level = np.mean(img_gray)
    if mean_level < dark_threshold:
        print(f"  -> Detection: Image is too dark (Avg: {mean_level:.2f})")
        return True

    # Coarse thumbnail: large light/dark regions survive, fine detail does not.
    small = cv2.resize(img_gray, (20, 20))
    lowest = np.min(small)
    highest = np.max(small)

    wide_range = (highest - lowest) > (255 - dark_threshold)
    if wide_range and lowest < shadow_variance:
        print("  -> Detection: Uneven shadows detected.")
        return True
    return False

def fix_lighting_and_shadows(img_gray):
    """Normalize the background and boost local contrast of a grayscale image.

    Steps: estimate the background (7x7 dilation followed by a 21-pixel
    median blur), take the inverted absolute difference from it, min-max
    stretch the result to 0..255, then apply CLAHE.
    """
    print("  -> Applying Contrast Stretching & Background Normalization...")
    spread_kernel = np.ones((7, 7), np.uint8)
    background = cv2.medianBlur(cv2.dilate(img_gray, spread_kernel), 21)
    inverted_diff = 255 - cv2.absdiff(img_gray, background)
    stretched = inverted_diff.copy()
    cv2.normalize(inverted_diff, stretched, alpha=0, beta=255, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_8UC1)
    equalizer = cv2.createCLAHE(clipLimit=3.8, tileGridSize=(10,10))
    return equalizer.apply(stretched)

# === 5. Updated Pipeline (With Board Extraction as Final Step) ===

def process_image_pipeline(image_path, output_folder):
    """Run the enhancement pipeline on one image and save the grayscale result.

    Stages, in order: rotation correction, median denoising (when the noise
    ratio exceeds 0.01), optional Gaussian blur, lighting/contrast fix,
    optional unsharp sharpening. Binary thresholding and line repair are
    then run for their console diagnostics only: their output is not
    written to disk.

    Args:
        image_path: path of the image to read.
        output_folder: directory for the enhanced image; created if missing.
            The output keeps the input's base file name.

    Returns:
        None. Read and write failures are reported on the console.
    """
    img_color = cv2.imread(image_path)
    if img_color is None:
        print(f"Error: Could not read {image_path}")
        return

    # Only the grayscale image is carried through the pipeline; the colour
    # copy was previously re-derived after every stage but never read.
    current_gray = cv2.cvtColor(img_color, cv2.COLOR_BGR2GRAY)

    # --- STEP 0: ROTATION ---
    print("  -> Checking for rotation/skew...")
    should_rotate, angle = needs_rotation(current_gray)
    if should_rotate:
        current_gray = rotate(current_gray, angle, reshape=False, mode='nearest')

    # --- STEP A: NOISE & BLUR ---
    if get_noise_ratio(current_gray) > 0.01:
        current_gray, _ = iterative_denoise(current_gray, max_kernel=19, target_ratio=0.005)

    current_gray, _ = apply_gaussian_blur_if_needed(current_gray, variance_threshold=100)

    # --- STEP B: LIGHTING & WASHOUT ---
    # Both checks always run so that both diagnostics are printed.
    is_dark_or_shadowed = check_if_lighting_fix_needed(current_gray)
    is_washed_out = check_for_washout(current_gray)
    if is_dark_or_shadowed or is_washed_out:
        current_gray = fix_lighting_and_shadows(current_gray)

    # --- STEP D: SHARPENING ---
    current_gray, _ = apply_unsharp_sharpen_if_needed(current_gray, variance_threshold=50)

    # --- STEP E: BINARY & REPAIR (diagnostic output only) ---
    binary_img, binary_flag = apply_binary_thresholding_if_needed(current_gray, entropy_threshold=5.0)
    if binary_flag:
        repair_sudoku_lines_if_needed(binary_img, fragment_threshold=8)

    # --- STEP G: SAVE ---
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(output_folder, exist_ok=True)

    enhanced_path = os.path.join(output_folder, os.path.basename(image_path))
    # imwrite returns False on failure; do not claim success in that case.
    if not cv2.imwrite(enhanced_path, current_gray):
        print(f"Error: Could not write {enhanced_path}")
        return

    print(f"  -> Saved enhanced to: {enhanced_path}\n")

def process_path(path, output_folder,
                 exts=(".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff")):
    """
    Run the image pipeline on a single file or on every image in a folder.

    A file path is processed directly. A directory is scanned
    non-recursively for names ending (case-insensitively) in one of `exts`,
    and the matches are processed in sorted order. Anything else is
    reported as an error. Returns None in all cases.
    """
    if os.path.isfile(path):
        print(f"Processing single file: {os.path.basename(path)}")
        process_image_pipeline(path, output_folder)
    elif os.path.isdir(path):
        # Keep only entries whose name carries a supported image extension.
        names = sorted(
            entry for entry in os.listdir(path)
            if entry.lower().endswith(exts)
        )
        if not names:
            print(f"No image files found in: {path}")
            return

        print(f"Found {len(names)} images in: {path}")
        for name in names:
            print(f"Processing: {name}")
            process_image_pipeline(os.path.join(path, name), output_folder)
    else:
        print(f"Error: {path} is neither a file nor a directory")


# --- RUNNER (Updated for 16.jpg) ---


# Source images are read from the Drive "input" folder; the enhanced
# grayscale copies are written to "processed_3" under the same file names.
input_dir = "/content/drive/MyDrive/Computer Vision Final/input"
output_dir = "/content/drive/MyDrive/Computer Vision Final/processed_3/"

# Run the preprocessing pipeline over every supported image in input_dir.
process_path(input_dir, output_dir)


# Updated Helper (Added lines_repaired variants)
def show_processed_images(folder, file_list, actions=("original_", "denoised_", "gaussian_blur_", "fixed_contrast_", "denoisedK5_fixed_contrast_gaussian_blur_", "sharpened_", "binary_", "lines_repaired_")):
    """Plot the saved variants of each image on one matplotlib figure.

    For every file in ``file_list`` and every prefix in ``actions``, the
    image ``<folder>/<prefix><basename>`` is drawn in the next subplot if it
    exists on disk; missing variants are skipped. The figure is only
    populated here: the caller is responsible for ``plt.show()``.

    Args:
        folder: directory containing the processed images.
        file_list: image paths or names; only the base name is used.
        actions: file-name prefixes to look for. The default is a tuple so
            it cannot be mutated across calls.

    Returns:
        None.
    """
    if not file_list:
        # A figure height of 5 * 0 is invalid; there is nothing to draw.
        return

    plt.figure(figsize=(15, 5 * len(file_list)))
    img_id = 1
    for fname in file_list:
        base = os.path.basename(fname)
        for action in actions:
            full_path = os.path.join(folder, f"{action}{base}")
            if not os.path.exists(full_path):
                continue
            img = cv2.imread(full_path, cv2.IMREAD_GRAYSCALE)
            plt.subplot(len(file_list), len(actions), img_id)
            plt.imshow(img, cmap='gray')
            # Strip the trailing underscore of the prefix for the label.
            plt.title(f"{base}\n[{action[:-1]}]" if action != "original_" else f"{base}\n[Original]")
            plt.axis('off')
            img_id += 1

plt.show()

Found 17 images in: /content/drive/MyDrive/Computer Vision Final/input
Processing: 01.jpg
  -> Checking for rotation/skew...
  -> Rotated by -2.00° (reshape=False, mode=nearest; size preserved).
  -> Laplacian Variance (noise sharpness): 1639.90
  -> High noise detected (var=1639.90 > 100). Applied Gaussian blur.
  -> Contrast Analysis: Darkest Ink Level=88.00, Contrast(StdDev)=29.83
  -> Detection: Image is WASHED OUT (Ink is too light: 88.00 > 70)
  -> Applying Contrast Stretching & Background Normalization...
  -> Sobel Variance (edge sharpness): 10422.21
  -> Edges sharp enough (var=10422.21 >= 50). Skipping sharpening.
  -> Image Entropy (B&W check): 4.10
  -> Image already B&W-like (entropy=4.10 <= 5.0). Skipping binary.
  -> Saved enhanced to: /content/drive/MyDrive/Computer Vision Final/processed_3/01.jpg

Processing: 02.jpg
  -> Checking for rotation/skew...
  -> Rotated by -2.00° (reshape=False, mode=nearest; size preserved).
  -> Laplacian Variance (noise sharpness): 1952.76

# STEP 2: Detect board frame

In [None]:
import cv2
import numpy as np
import os
from pathlib import Path


def process_sudoku_image(input_path, output_folder, save_intermediates=False):
    """
    Detect the Sudoku grid in one image and compute a top-down warp of it.

    The function always writes two visualizations to ``output_folder``:
    ``<name>_hough_lines.jpg`` and ``<name>_corners.jpg``. The warped
    colour and grayscale images are computed but are not written to disk
    or returned by this function.

    Parameters:
    -----------
    input_path : str
        Path to input Sudoku image
    output_folder : str
        Folder to save processed images (created if missing)
    save_intermediates : bool
        Whether to also save the threshold, horizontal, vertical and
        grid-mask images (default: False)

    Returns:
    --------
    bool : True if successful, False otherwise (unreadable image, no Hough
        lines, no corners, or a failed perspective transform)
    """

    # Create output folder if it doesn't exist
    os.makedirs(output_folder, exist_ok=True)

    # Get input filename
    filename = os.path.basename(input_path)
    name_without_ext = os.path.splitext(filename)[0]

    print(f"\n{'='*60}")
    print(f"Processing: {filename}")
    print(f"{'='*60}")

    # -------------------------
    # 1. LOAD + PREPROCESSING
    # -------------------------
    img = cv2.imread(input_path)
    if img is None:
        print(f"ERROR: Could not read image from {input_path}")
        return False

    original = img.copy()
    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    blurred = cv2.GaussianBlur(gray, (5, 5), 0)

    # Adaptive threshold (inverted: dark strokes become 255)
    thresh = cv2.adaptiveThreshold(
        blurred, 255,
        cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
        cv2.THRESH_BINARY_INV,
        11, 2
    )

    if save_intermediates:
        cv2.imwrite(os.path.join(output_folder, f"{name_without_ext}_01_threshold.jpg"), thresh)

    # -------------------------
    # 2. MORPHOLOGICAL OPS
    # -------------------------

    # Extract horizontal lines (erode then dilate with a 40x1 kernel)
    horizontal = thresh.copy()
    h_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (40, 1))
    horizontal = cv2.erode(horizontal, h_kernel, iterations=1)
    horizontal = cv2.dilate(horizontal, h_kernel, iterations=1)

    # Extract vertical lines (erode then dilate with a 1x40 kernel)
    vertical = thresh.copy()
    v_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (1, 40))
    vertical = cv2.erode(vertical, v_kernel, iterations=1)
    vertical = cv2.dilate(vertical, v_kernel, iterations=1)

    # Combine them to get grid mask. Each input is weighted 0.5, so pixels
    # present in only one of the two masks end up at roughly half intensity.
    grid_mask = cv2.addWeighted(horizontal, 0.5, vertical, 0.5, 0)

    if save_intermediates:
        cv2.imwrite(os.path.join(output_folder, f"{name_without_ext}_02_horizontal.jpg"), horizontal)
        cv2.imwrite(os.path.join(output_folder, f"{name_without_ext}_03_vertical.jpg"), vertical)
        cv2.imwrite(os.path.join(output_folder, f"{name_without_ext}_04_grid_mask.jpg"), grid_mask)

    # -------------------------
    # 3. HOUGH TRANSFORM
    # -------------------------

    lines = cv2.HoughLinesP(
        grid_mask,
        rho=1,
        theta=np.pi/180,
        threshold=150,
        minLineLength=100,
        maxLineGap=10
    )

    if lines is None:
        print(f"WARNING: No lines detected for {filename}")
        return False

    print(f"Detected {len(lines)} line segments")

    # ALWAYS save Hough lines visualization
    hough_img = original.copy()
    for line in lines:
        x1, y1, x2, y2 = line[0]
        cv2.line(hough_img, (x1, y1), (x2, y2), (0, 255, 0), 2)

    hough_output_path = os.path.join(output_folder, f"{name_without_ext}_hough_lines.jpg")
    cv2.imwrite(hough_output_path, hough_img)
    print(f"✓ Saved Hough lines: {hough_output_path}")

    # -------------------------
    # 4. HELPER FUNCTIONS
    # -------------------------

    def cluster_lines(lines_list, threshold=20):
        """Merge sorted 1-D positions into clusters and return each cluster's integer mean.

        A position joins the current cluster when it is less than
        `threshold` away from the cluster's most recent member.
        """
        if len(lines_list) == 0:
            return []

        lines_list = sorted(lines_list)
        clusters = []
        current_cluster = [lines_list[0]]

        for line in lines_list[1:]:
            if line - current_cluster[-1] < threshold:
                current_cluster.append(line)
            else:
                clusters.append(int(np.mean(current_cluster)))
                current_cluster = [line]
        # Flush the final cluster
        clusters.append(int(np.mean(current_cluster)))
        return clusters

    def find_grid_corners_from_hough(lines, img_shape):
        """Return the grid's bounding corners from Hough segments, or None.

        The result is an axis-aligned rectangle built from the outermost
        horizontal and vertical line clusters. `img_shape` is accepted but
        not used.
        """
        h_lines = []
        v_lines = []

        for line in lines:
            x1, y1, x2, y2 = line[0]
            angle = np.abs(np.arctan2(y2 - y1, x2 - x1) * 180 / np.pi)

            # Horizontal lines (within 10 degrees of the x axis)
            if angle < 10 or angle > 170:
                y_avg = (y1 + y2) // 2
                h_lines.append(y_avg)
            # Vertical lines (within 10 degrees of the y axis)
            elif 80 < angle < 100:
                x_avg = (x1 + x2) // 2
                v_lines.append(x_avg)

        # Cluster lines to get main grid lines
        h_clusters = cluster_lines(h_lines, threshold=20)
        v_clusters = cluster_lines(v_lines, threshold=20)

        # At least two lines per direction are needed to bound a rectangle
        if len(h_clusters) < 2 or len(v_clusters) < 2:
            return None

        # Get outer boundaries (first and last lines)
        top = h_clusters[0]
        bottom = h_clusters[-1]
        left = v_clusters[0]
        right = v_clusters[-1]

        # Create corner points (TL, TR, BR, BL)
        corners = np.array([
            [left, top],
            [right, top],
            [right, bottom],
            [left, bottom]
        ], dtype=np.float32)

        return corners

    def order_points(pts):
        """Order points in clockwise order starting from top-left"""
        rect = np.zeros((4, 2), dtype="float32")

        # x + y is smallest at the top-left and largest at the bottom-right
        s = pts.sum(axis=1)
        rect[0] = pts[np.argmin(s)]
        rect[2] = pts[np.argmax(s)]

        # y - x is smallest at the top-right and largest at the bottom-left
        diff = np.diff(pts, axis=1)
        rect[1] = pts[np.argmin(diff)]
        rect[3] = pts[np.argmax(diff)]

        return rect

    def perspective_transform(image, corners):
        """Warp `image` so the quad `corners` fills a square; return (warped, matrix)."""
        rect = order_points(corners)
        (tl, tr, br, bl) = rect

        widthA = np.sqrt(((br[0] - bl[0]) ** 2) + ((br[1] - bl[1]) ** 2))
        widthB = np.sqrt(((tr[0] - tl[0]) ** 2) + ((tr[1] - tl[1]) ** 2))
        maxWidth = max(int(widthA), int(widthB))

        heightA = np.sqrt(((tr[0] - br[0]) ** 2) + ((tr[1] - br[1]) ** 2))
        heightB = np.sqrt(((tl[0] - bl[0]) ** 2) + ((tl[1] - bl[1]) ** 2))
        maxHeight = max(int(heightA), int(heightB))

        # The output is square, with the side set by the longest edge
        size = max(maxWidth, maxHeight)

        dst = np.array([
            [0, 0],
            [size - 1, 0],
            [size - 1, size - 1],
            [0, size - 1]
        ], dtype="float32")

        M = cv2.getPerspectiveTransform(rect, dst)
        warped = cv2.warpPerspective(image, M, (size, size))

        return warped, M

    # -------------------------
    # 5. FIND CORNERS
    # -------------------------

    sudoku_corners = find_grid_corners_from_hough(lines, original.shape)

    # Fallback: Try contour detection on the grid mask; the largest contour
    # that approximates to four points wins
    if sudoku_corners is None:
        print("Hough method failed, trying contour detection...")
        contours, _ = cv2.findContours(
            grid_mask,
            cv2.RETR_EXTERNAL,
            cv2.CHAIN_APPROX_SIMPLE
        )

        contours = sorted(contours, key=cv2.contourArea, reverse=True)

        for cnt in contours:
            perimeter = cv2.arcLength(cnt, True)
            approx = cv2.approxPolyDP(cnt, 0.02 * perimeter, True)

            if len(approx) == 4:
                sudoku_corners = approx.reshape(4, 2).astype(np.float32)
                print("Using contour corners as fallback")
                break

    if sudoku_corners is None:
        print(f"ERROR: Could not detect grid corners for {filename}")
        return False

    print(f"Corners detected: {sudoku_corners.tolist()}")

    # ALWAYS save corners visualization
    corners_img = original.copy()
    ordered = order_points(sudoku_corners)

    # Draw corner points with numbers
    for i, corner in enumerate(ordered):
        cv2.circle(corners_img, tuple(corner.astype(int)), 15, (0, 0, 255), -1)
        cv2.putText(corners_img, str(i), tuple(corner.astype(int)),
                   cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 0), 2)

    # Draw lines connecting corners (bounding box)
    for i in range(4):
        pt1 = tuple(ordered[i].astype(int))
        pt2 = tuple(ordered[(i+1)%4].astype(int))
        cv2.line(corners_img, pt1, pt2, (0, 255, 0), 3)

    corners_output_path = os.path.join(output_folder, f"{name_without_ext}_corners.jpg")
    cv2.imwrite(corners_output_path, corners_img)
    print(f"✓ Saved detected corners: {corners_output_path}")

    # -------------------------
    # 6. PERSPECTIVE TRANSFORM
    # -------------------------

    # NOTE(review): `warped`, `warped_gray` and `transform_matrix` are only
    # used for the size printout below; nothing writes them to
    # output_folder. Confirm whether the warped image was meant to be saved.
    try:
        warped, transform_matrix = perspective_transform(original, sudoku_corners)
        warped_gray, _ = perspective_transform(gray, sudoku_corners)
        print(f"✓ Perspective transform successful! Output size: {warped.shape[:2]}")
    except Exception as e:
        print(f"ERROR: Perspective transform failed: {e}")
        return False

    # Save other intermediate steps if requested
    if save_intermediates:
        # Additional intermediate visualizations already saved above
        pass

    print(f"✓ Processing complete for {filename}\n")
    return True


def process_sudoku_batch(input_source, output_folder, save_intermediates=False,
                         exts=(".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff")):
    """
    Process multiple Sudoku images in batch.

    Parameters
    ----------
    input_source : str, os.PathLike, or iterable of those
        - Single path to a file: that image is processed
        - Single path to a directory: every file in it (non-recursive) whose
          name ends in one of `exts` is processed, in sorted order
        - Iterable: treated as a list of image paths
    output_folder : str
        Folder to save all processed images
    save_intermediates : bool
        Whether to save intermediate processing steps
    exts : tuple[str]
        Lower-case file extensions accepted in the directory case

    Returns
    -------
    dict : {'successful': [paths], 'failed': [paths], 'total': int}.
        A single path that is neither a file nor a directory yields empty
        lists and a total of 0.
    """

    # Normalize input_source into a list of string paths. pathlib.Path (and
    # any other os.PathLike) is handled like a str; previously it fell into
    # the iterable branch and raised TypeError.
    if isinstance(input_source, (str, os.PathLike)):
        source = os.fspath(input_source)
        if os.path.isdir(source):
            input_paths = sorted(
                os.path.join(source, f)
                for f in os.listdir(source)
                if f.lower().endswith(exts)
            )
            print(f"Found {len(input_paths)} images in directory: {source}")
        elif os.path.isfile(source):
            input_paths = [source]
        else:
            print(f"ERROR: {source} is neither a file nor a directory")
            return {
                'successful': [],
                'failed': [],
                'total': 0
            }
    else:
        # Already an iterable of paths; convert PathLike entries to str so
        # downstream OpenCV calls always receive plain strings.
        input_paths = [
            os.fspath(p) if isinstance(p, os.PathLike) else p
            for p in input_source
        ]

    results = {
        'successful': [],
        'failed': [],
        'total': len(input_paths)
    }

    for input_path in input_paths:
        if not os.path.exists(input_path):
            print(f"WARNING: File not found: {input_path}")
            results['failed'].append(input_path)
            continue

        success = process_sudoku_image(input_path, output_folder, save_intermediates)
        bucket = 'successful' if success else 'failed'
        results[bucket].append(input_path)

    # Print summary
    print(f"\n{'='*60}")
    print("BATCH PROCESSING COMPLETE")
    print(f"{'='*60}")
    print(f"Total images: {results['total']}")
    print(f"Successful: {len(results['successful'])}")
    print(f"Failed: {len(results['failed'])}")

    if results['failed']:
        print("\nFailed images:")
        for path in results['failed']:
            print(f"  - {os.path.basename(path)}")

    return results


# -------------------------
# USAGE EXAMPLES
# -------------------------



# This step reads the enhanced images written by STEP 1 ("processed_3") and
# writes its Hough-line and corner visualizations to "warped_batch".
# Note that `input_dir` is rebound here to a different folder than in STEP 1.
input_dir = "/content/drive/MyDrive/Computer Vision Final/processed_3"
output_directory = "/content/drive/MyDrive/Computer Vision Final/warped_batch/"

# `results` holds the 'successful' / 'failed' path lists and the 'total' count.
results = process_sudoku_batch(input_dir, output_directory, save_intermediates=False)



Found 17 images in directory: /content/drive/MyDrive/Computer Vision Final/processed_3

Processing: 01.jpg
Detected 157 line segments
✓ Saved Hough lines: /content/drive/MyDrive/Computer Vision Final/warped_batch/01_hough_lines.jpg
Corners detected: [[37.0, 45.0], [960.0, 45.0], [960.0, 952.0], [37.0, 952.0]]
✓ Saved detected corners: /content/drive/MyDrive/Computer Vision Final/warped_batch/01_corners.jpg
✓ Perspective transform successful! Output size: (923, 923)
✓ Processing complete for 01.jpg


Processing: 02.jpg
Detected 181 line segments
✓ Saved Hough lines: /content/drive/MyDrive/Computer Vision Final/warped_batch/02_hough_lines.jpg
Corners detected: [[34.0, 55.0], [959.0, 55.0], [959.0, 948.0], [34.0, 948.0]]
✓ Saved detected corners: /content/drive/MyDrive/Computer Vision Final/warped_batch/02_corners.jpg
✓ Perspective transform successful! Output size: (925, 925)
✓ Processing complete for 02.jpg


Processing: 03.jpg
Detected 126 line segments
✓ Saved Hough lines: /content/dr

# STEP 3: Reduce to a square


In [None]:

import cv2
import numpy as np
import os
from pathlib import Path

# Images to refine, where the squared results are written, and where the
# previous step left its "*_corners.jpg" / "*_hough_lines.jpg" overlays.
REFINE_INPUT_DIR = '/content/drive/MyDrive/Computer Vision Final/processed_3'
REFINE_OUTPUT_DIR = '/content/drive/MyDrive/Computer Vision Final/final_square'
PREVIOUS_OUTPUT_DIR = '/content/drive/MyDrive/Computer Vision Final/warped_batch'
# Created at import/cell-run time so later writes have a target directory.
os.makedirs(REFINE_OUTPUT_DIR, exist_ok=True)

# Side length (pixels) of the square produced by warp_to_square.
TARGET_SIZE = 450
# Minimum Hough segment length as a fraction of the shorter image side.
MIN_LINE_LENGTH_FACTOR = 0.4
# Number of grid lines expected per direction (select/complete helpers use it).
GRID_EXPECTED_LINES = 10
# estimate_rotation only trusts a correction whose magnitude (degrees) lies
# strictly between these two bounds.
ROTATION_MIN_CORRECTION = 2.0
ROTATION_MAX_CORRECTION = 12.0
# estimate_rotation needs at least this many vertical angles ...
VERTICAL_MIN_COUNT = 6
# ... and their standard deviation (degrees) must stay below this value.
VERTICAL_STD_MAX = 2.5
# Largest allowed max-gap / min-gap ratio in select_best_uniform_10.
UNIFORMITY_RATIO_LIMIT = 2.2  # slightly stricter than spacing_uniform fallback

# -----------------------------
# DETECT GRID LINES (morphology + Hough)
# -----------------------------
def detect_grid_lines(gray):
    """Detect long grid-line segments in a grayscale image.

    The image is blurred and adaptively thresholded (inverted), then opened
    with a 40x1 and a 1x40 rectangle to keep only long horizontal and
    vertical strokes. The two results are blended into one mask on which a
    probabilistic Hough transform runs; if that finds nothing, Hough is
    retried on the raw threshold image with looser settings.

    Args:
        gray: single-channel image (its ``shape`` must unpack to two values).

    Returns:
        ``(lines, grid_mask, thresh)`` where ``lines`` is the HoughLinesP
        result (``None`` when both attempts find nothing).
    """
    smoothed = cv2.GaussianBlur(gray, (5, 5), 0)
    thresh = cv2.adaptiveThreshold(
        smoothed, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY_INV, 11, 2
    )

    def _open_with(kernel_size):
        # Erode then dilate with the same kernel: strokes shorter than the
        # kernel disappear, longer ones are restored.
        kernel = cv2.getStructuringElement(cv2.MORPH_RECT, kernel_size)
        eroded = cv2.erode(thresh, kernel, iterations=1)
        return cv2.dilate(eroded, kernel, iterations=1)

    horizontal = _open_with((40, 1))
    vertical = _open_with((1, 40))
    grid_mask = cv2.addWeighted(horizontal, 0.5, vertical, 0.5, 0)

    rows, cols = gray.shape
    min_line_len = int(min(rows, cols) * MIN_LINE_LENGTH_FACTOR)
    lines = cv2.HoughLinesP(
        grid_mask, 1, np.pi/180, 120, minLineLength=min_line_len, maxLineGap=15
    )
    if lines is not None:
        return lines, grid_mask, thresh

    # Nothing on the cleaned mask: retry on the unfiltered threshold image.
    retry = cv2.HoughLinesP(
        thresh, 1, np.pi/180, 150, minLineLength=min_line_len, maxLineGap=20
    )
    return retry, grid_mask, thresh

# -----------------------------
# CLASSIFY LINES
# -----------------------------
def classify_lines(lines):
    """Split Hough segments into horizontal and vertical groups.

    Args:
        lines: iterable of segments where ``seg[0]`` is ``(x1, y1, x2, y2)``
            (the HoughLinesP layout), or ``None``.

    Returns:
        ``(horizontal_positions, vertical_positions, vertical_angles)``:
        mid-point y of each near-horizontal segment (within 10 degrees),
        mid-point x of each near-vertical segment (within 10 degrees of 90),
        and the angle of each vertical segment mapped to around 90 degrees.
        Segments at any other angle are ignored.
    """
    horizontal_positions, vertical_positions, vertical_angles = [], [], []
    if lines is None:
        return horizontal_positions, vertical_positions, vertical_angles

    for segment in lines:
        x1, y1, x2, y2 = segment[0]
        angle = np.degrees(np.arctan2(y2 - y1, x2 - x1))
        magnitude = abs(angle)

        if magnitude < 10 or magnitude > 170:
            horizontal_positions.append((y1 + y2)//2)
            continue
        if not (80 < magnitude < 100):
            continue

        vertical_positions.append((x1 + x2)//2)
        # A segment drawn bottom-to-top has a negative angle; shift it so both
        # drawing directions land near +90.
        vertical_angles.append(angle + 180 if angle < 0 else angle)

    return horizontal_positions, vertical_positions, vertical_angles

# -----------------------------
# ROTATION ESTIMATION
# -----------------------------
def estimate_rotation(vertical_angles):
    """Estimate the skew of the grid from its vertical-line angles.

    Args:
        vertical_angles: angles in degrees, expected to cluster around 90.

    Returns:
        ``(correction, reliable)``. ``correction`` is the mean angle minus 90.
        ``reliable`` is true only when the angles agree (std below
        ``VERTICAL_STD_MAX``) and the correction magnitude lies strictly
        between ``ROTATION_MIN_CORRECTION`` and ``ROTATION_MAX_CORRECTION``.
        With fewer than ``VERTICAL_MIN_COUNT`` angles it returns ``(0.0, False)``.
    """
    if len(vertical_angles) < VERTICAL_MIN_COUNT:
        return 0.0, False

    angles = np.array(vertical_angles)
    correction = angles.mean() - 90.0
    consistent = angles.std() < VERTICAL_STD_MAX
    in_range = ROTATION_MIN_CORRECTION < abs(correction) < ROTATION_MAX_CORRECTION
    return correction, (consistent and in_range)

# -----------------------------
# APPLY ROTATION
# -----------------------------
def apply_rotation(image, angle):
    """Rotate ``image`` by ``angle`` degrees about its centre, keeping its size.

    Pixels that fall outside the source are filled by replicating the border;
    interpolation is bilinear.
    """
    height, width = image.shape[:2]
    pivot = (width/2, height/2)
    rotation = cv2.getRotationMatrix2D(pivot, angle, 1.0)
    return cv2.warpAffine(
        image,
        rotation,
        (width, height),
        flags=cv2.INTER_LINEAR,
        borderMode=cv2.BORDER_REPLICATE,
    )

# -----------------------------
# CLUSTER POSITIONS
# -----------------------------
def cluster_positions(positions, threshold):
    """Merge nearby 1-D positions into cluster centres.

    Positions are sorted; a value joins the current cluster when it lies
    within ``threshold`` of the previous value (chain linkage), otherwise it
    starts a new cluster.

    Returns:
        One ``int`` per cluster (the truncated mean of its members), in
        ascending order; ``[]`` for empty input.
    """
    if not positions:
        return []

    groups = []
    for value in sorted(positions):
        if groups and value - groups[-1][-1] <= threshold:
            groups[-1].append(value)
        else:
            groups.append([value])
    return [int(np.mean(group)) for group in groups]

# -----------------------------
# SELECT BEST 10 UNIFORMLY SPACED LINES (avoid margins)
# -----------------------------
def select_best_uniform_10(sorted_positions):
    """Pick the most evenly spaced run of ``GRID_EXPECTED_LINES`` positions.

    Every window of consecutive positions is scored by
    ``ratio * (1 + std/mean)`` of its gaps, where ``ratio`` is largest gap
    over smallest gap; the lowest score wins (first one on ties). Windows
    containing a zero gap are skipped.

    Args:
        sorted_positions: positions in ascending order.

    Returns:
        The winning window, or ``None`` when there are too few positions, no
        window qualifies, or the winner's gap ratio exceeds
        ``UNIFORMITY_RATIO_LIMIT``.
    """
    window_count = len(sorted_positions) - GRID_EXPECTED_LINES + 1
    if window_count < 1:
        return None

    chosen = None
    chosen_ratio = None
    lowest_score = 1e9
    for start in range(window_count):
        candidate = sorted_positions[start:start + GRID_EXPECTED_LINES]
        gaps = np.diff(candidate)
        if len(gaps) == 0:
            continue
        smallest = np.min(gaps)
        if smallest == 0:
            continue
        ratio = np.max(gaps) / smallest
        # Blend the extreme-gap ratio with relative spread for stability.
        score = ratio * (1 + np.std(gaps) / (np.mean(gaps)+1e-6))
        if score < lowest_score:
            lowest_score = score
            chosen = candidate
            chosen_ratio = ratio

    if chosen is None:
        return None
    # Even the best window may still be too uneven to trust.
    if chosen_ratio > UNIFORMITY_RATIO_LIMIT:
        return None
    return chosen

# -----------------------------
# COMPLETE LINES TO EXACT 10
# -----------------------------
def complete_lines(line_positions, image_extent):
    """Return exactly ``GRID_EXPECTED_LINES`` line positions along one axis.

    Args:
        line_positions: detected line coordinates (duplicates are ignored).
        image_extent: image size in pixels along the same axis.

    Returns:
        The de-duplicated, sorted input when it already has the expected
        count. Otherwise evenly spaced integer positions between the first
        and last detected line. When fewer than two lines were detected, or
        the detected span covers less than half the image, the positions
        span the whole image from ``0`` to ``image_extent - 1``.
    """
    unique = sorted(set(line_positions))
    if len(unique) == GRID_EXPECTED_LINES:
        return unique

    # Both fallbacks use the last valid pixel index, so every returned
    # position stays inside the image.
    full_span = (0, image_extent - 1)
    first, last = (unique[0], unique[-1]) if len(unique) >= 2 else full_span
    if last - first < image_extent * 0.5:
        first, last = full_span

    step = (last - first) / (GRID_EXPECTED_LINES - 1)
    return [int(round(first + i * step)) for i in range(GRID_EXPECTED_LINES)]

# -----------------------------
# SPACING VALIDATION
# -----------------------------
def spacing_uniform(line_positions, max_ratio=2.5):
    """Tell whether consecutive line positions are roughly evenly spaced.

    Args:
        line_positions: line coordinates in any order.
        max_ratio: the largest gap divided by the smallest gap must be
            strictly below this value. The default is the limit this check
            has always used.

    Returns:
        ``True`` when spacing is acceptable; ``False`` when there are fewer
        than two positions, when two positions coincide (zero gap), or when
        the gap ratio reaches ``max_ratio``.
    """
    gaps = np.diff(sorted(line_positions))
    if gaps.size == 0:
        return False
    smallest = gaps.min()
    if smallest == 0:
        return False
    # Convert to a plain bool so callers never see a NumPy scalar.
    return bool(gaps.max() / smallest < max_ratio)

# -----------------------------
# Fallbacks: contour & bounding box
# -----------------------------
def contour_outer_corners(thresh):
    """Return the corners of the largest four-sided external contour.

    External contours are visited from largest to smallest area; the first
    one whose polygon approximation (tolerance 2% of its perimeter) has four
    vertices is returned as a ``(4, 2)`` float32 array. Returns ``None`` when
    no contour qualifies.
    """
    found = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    # findContours returns 2 or 3 values depending on the OpenCV version.
    contours = found[0] if len(found) == 2 else found[1]

    for contour in sorted(contours, key=cv2.contourArea, reverse=True):
        perimeter = cv2.arcLength(contour, True)
        polygon = cv2.approxPolyDP(contour, 0.02 * perimeter, True)
        if len(polygon) != 4:
            continue
        return polygon.reshape(4,2).astype(np.float32)
    return None


def bounding_box_corners(mask):
    """Return the bounding box of the non-zero pixels of a 2-D mask.

    Returns:
        A ``(4, 2)`` float32 array of ``(x, y)`` corners ordered top-left,
        top-right, bottom-right, bottom-left, or ``None`` for an empty mask.
    """
    rows, cols = np.nonzero(mask > 0)
    if cols.size == 0 or rows.size == 0:
        return None

    x_min, x_max = int(cols.min()), int(cols.max())
    y_min, y_max = int(rows.min()), int(rows.max())
    box = [
        [x_min, y_min],
        [x_max, y_min],
        [x_max, y_max],
        [x_min, y_max],
    ]
    return np.array(box, dtype=np.float32)

# -----------------------------
# DERIVE CORNERS FROM COMPLETED LINES
# -----------------------------
def derive_corners(h_lines, v_lines):
    """Build the outer grid corners from the first and last line of each axis.

    Args:
        h_lines: y positions of the horizontal lines (first = top).
        v_lines: x positions of the vertical lines (first = left).

    Returns:
        ``(4, 2)`` float32 array of ``(x, y)``: top-left, top-right,
        bottom-right, bottom-left.
    """
    left, right = v_lines[0], v_lines[-1]
    top, bottom = h_lines[0], h_lines[-1]
    quad = [[left, top], [right, top], [right, bottom], [left, bottom]]
    return np.array(quad, dtype=np.float32)

# -----------------------------
# ORDER POINTS & WARP TO SQUARE
# -----------------------------
def order_points(pts):
    """Order four ``(x, y)`` points as TL, TR, BR, BL.

    Top-left has the smallest ``x + y`` and bottom-right the largest;
    top-right has the smallest ``y - x`` and bottom-left the largest.

    Args:
        pts: array of shape ``(4, 2)``.

    Returns:
        A new ``(4, 2)`` float32 array in the order described above.
    """
    coord_sum = pts.sum(axis=1)
    coord_diff = np.diff(pts, axis=1)  # y - x for every point
    picks = (
        np.argmin(coord_sum),   # top-left
        np.argmin(coord_diff),  # top-right
        np.argmax(coord_sum),   # bottom-right
        np.argmax(coord_diff),  # bottom-left
    )

    ordered = np.zeros((4, 2), dtype="float32")
    for slot, index in enumerate(picks):
        ordered[slot] = pts[index]
    return ordered


def warp_to_square(image, corners, size=TARGET_SIZE):
    """Perspective-warp the quadrilateral ``corners`` onto a square image.

    Args:
        image: source image.
        corners: four ``(x, y)`` points in any order (ordered internally).
        size: side length of the output square in pixels.

    Returns:
        ``(warped, M)``: the ``size`` x ``size`` image and the 3x3
        perspective matrix that produced it.
    """
    far = size-1
    src_quad = order_points(corners)
    dst_quad = np.array([[0,0],[far,0],[far,far],[0,far]], dtype=np.float32)
    M = cv2.getPerspectiveTransform(src_quad, dst_quad)
    return cv2.warpPerspective(image, M, (size, size)), M

# -----------------------------
# ENHANCE WARPED IMAGE
# -----------------------------
def enhance_warped(warped_gray):
    """Boost contrast and edges of a warped grayscale grid image.

    Steps: CLAHE equalisation, Canny edges (slightly dilated) blended in at
    15% weight, then an unsharp-mask style sharpen using a 3x3 Gaussian blur.
    Returns the sharpened image.
    """
    equalized = cv2.createCLAHE(clipLimit=3.0, tileGridSize=(8,8)).apply(warped_gray)

    edge_map = cv2.Canny(equalized, 40, 120)
    edge_map = cv2.dilate(edge_map, np.ones((2,2), np.uint8), iterations=1)
    reinforced = cv2.addWeighted(equalized, 1.0, edge_map, 0.15, 0)

    # Weights 1.25 and -0.25 subtract a blurred copy to sharpen.
    softened = cv2.GaussianBlur(reinforced, (3,3), 0)
    return cv2.addWeighted(reinforced, 1.25, softened, -0.25, 0)

# -----------------------------
# EXTRACT CORNERS FROM PREVIOUS CORNERS IMAGE (RED CIRCLES)
# -----------------------------
def extract_corners_from_previous(stem):
    """Recover four corner points from a previously saved corners overlay.

    Looks for ``<stem>_corners.jpg`` in ``PREVIOUS_OUTPUT_DIR``, masks the
    red hue ranges in HSV, and takes the centroid of every red blob whose
    contour area is strictly between 150 and 1500 pixels.

    Returns:
        A ``(4, 2)`` float32 array of ``(x, y)`` centroids when exactly four
        blobs are found; ``None`` if the file is missing or unreadable or the
        blob count differs.
    """
    overlay_path = os.path.join(PREVIOUS_OUTPUT_DIR, f"{stem}_corners.jpg")
    overlay = cv2.imread(overlay_path) if os.path.exists(overlay_path) else None
    if overlay is None:
        return None

    hsv = cv2.cvtColor(overlay, cv2.COLOR_BGR2HSV)
    # Red wraps around the hue axis, so it needs two ranges.
    low_red = cv2.inRange(hsv, np.array([0, 70, 50]), np.array([10, 255, 255]))
    high_red = cv2.inRange(hsv, np.array([160, 70, 50]), np.array([180, 255, 255]))
    red_mask = cv2.medianBlur(cv2.bitwise_or(low_red, high_red), 5)

    found = cv2.findContours(red_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    blobs = found[0] if len(found) == 2 else found[1]

    centers = []
    for blob in blobs:
        if not (150 < cv2.contourArea(blob) < 1500):
            continue
        moments = cv2.moments(blob)
        if moments['m00'] <= 0:
            continue
        cx = int(moments['m10']/moments['m00'])
        cy = int(moments['m01']/moments['m00'])
        centers.append([cx, cy])

    return np.array(centers, dtype=np.float32) if len(centers) == 4 else None

# -----------------------------
# PARSE HOUGH OVERLAY (GREEN LINES)
# -----------------------------
def extract_lines_from_hough_overlay(stem):
    """Re-detect line segments from a previously saved Hough overlay image.

    Looks for ``<stem>_hough_lines.jpg`` in ``PREVIOUS_OUTPUT_DIR``, masks
    the green hue range in HSV, dilates the mask with a 3x3 kernel, and runs
    a probabilistic Hough transform on it.

    Returns:
        The HoughLinesP result (may be ``None`` when no segment is found), or
        ``None`` if the overlay is missing or unreadable.
    """
    overlay_path = os.path.join(PREVIOUS_OUTPUT_DIR, f"{stem}_hough_lines.jpg")
    overlay = cv2.imread(overlay_path) if os.path.exists(overlay_path) else None
    if overlay is None:
        return None

    hsv = cv2.cvtColor(overlay, cv2.COLOR_BGR2HSV)
    green_mask = cv2.inRange(hsv, np.array([35, 60, 40]), np.array([85, 255, 255]))
    green_mask = cv2.dilate(green_mask, np.ones((3,3), np.uint8), iterations=1)
    return cv2.HoughLinesP(green_mask, 1, np.pi/180, 60, minLineLength=40, maxLineGap=15)

# -----------------------------
# HOUGH OVERLAY CORNERS (with LS outer fits)
# -----------------------------
def corners_from_hough_overlay(lines, img_shape):
    """Derive the four outer grid corners from Hough line segments.

    Segments are split into near-horizontal and near-vertical groups (within
    12 degrees), their mid-positions are clustered, and the outermost
    cluster on each side is taken as a border. A least-squares line is
    fitted through the segment end points near each border, and the four
    pairwise intersections are the corners.

    Args:
        lines: HoughLinesP-style segments (``seg[0]`` is ``x1, y1, x2, y2``),
            or ``None``.
        img_shape: image shape; index 0 is the height, index 1 the width.

    Returns:
        ``(4, 2)`` float32 array ordered top-left, top-right, bottom-right,
        bottom-left, or ``None`` when there are too few segments, a border
        cannot be fitted, two border lines are parallel, or a corner lies
        outside the image.
    """
    if lines is None or len(lines) < 4:
        return None

    h_pos, v_pos, h_segments, v_segments = [], [], [], []
    for seg in lines:
        x1, y1, x2, y2 = seg[0]
        ang = abs(np.degrees(np.arctan2(y2 - y1, x2 - x1)))
        if ang < 12 or ang > 168:
            h_pos.append((y1 + y2) / 2.0)
            h_segments.append([(x1, y1), (x2, y2)])
        elif 78 < ang < 102:
            v_pos.append((x1 + x2) / 2.0)
            v_segments.append([(x1, y1), (x2, y2)])
    if len(h_pos) < 2 or len(v_pos) < 2:
        return None

    h_thresh = max(10, img_shape[0] // 80)
    v_thresh = max(10, img_shape[1] // 80)
    h_clusters = cluster_positions(h_pos, h_thresh)
    v_clusters = cluster_positions(v_pos, v_thresh)

    # With more clusters than grid lines, prefer the evenly spaced inner run
    # so page margins or frames are not mistaken for the grid border.
    if len(h_clusters) > GRID_EXPECTED_LINES:
        sel = select_best_uniform_10(h_clusters)
        if sel is not None:
            h_clusters = sel
    if len(v_clusters) > GRID_EXPECTED_LINES:
        sel = select_best_uniform_10(v_clusters)
        if sel is not None:
            v_clusters = sel
    if len(h_clusters) < 2 or len(v_clusters) < 2:
        return None

    top_y, bottom_y = h_clusters[0], h_clusters[-1]
    left_x, right_x = v_clusters[0], v_clusters[-1]

    def collect(seg_list, target, axis='y', tol=15):
        # End points whose coordinate on `axis` lies within `tol` of target.
        pts = []
        for seg in seg_list:
            for (x, y) in seg:
                val = y if axis == 'y' else x
                if abs(val - target) <= tol:
                    pts.append((x, y))
        return pts

    def fit_line(points, mode='horizontal'):
        # Least-squares fit: y = m*x + b for horizontal borders and
        # x = m*y + b for vertical ones (avoids an infinite slope).
        if len(points) < 2:
            return None
        pts = np.array(points, dtype=np.float32)
        x = pts[:, 0]
        y = pts[:, 1]
        if mode == 'horizontal':
            A = np.vstack([x, np.ones_like(x)]).T
            m, b = np.linalg.lstsq(A, y, rcond=None)[0]
            return ('h', m, b)
        A = np.vstack([y, np.ones_like(y)]).T
        m, b = np.linalg.lstsq(A, x, rcond=None)[0]
        return ('v', m, b)

    def line_to_abc(line):
        # Convert to the implicit form a*x + b*y + c = 0.
        kind, m, b = line
        if kind == 'h':
            return m, -1.0, b
        return 1.0, -m, -b

    def intersect(l1, l2):
        a1, b1, c1 = line_to_abc(l1)
        a2, b2, c2 = line_to_abc(l2)
        det = a1*b2 - a2*b1
        if abs(det) < 1e-8:
            return None
        x = (-c1*b2 + c2*b1) / det
        y = (-a1*c2 + a2*c1) / det
        return np.array([x, y], dtype=np.float32)

    top_line = fit_line(collect(h_segments, top_y, 'y'), 'horizontal')
    bottom_line = fit_line(collect(h_segments, bottom_y, 'y'), 'horizontal')
    left_line = fit_line(collect(v_segments, left_x, 'x'), 'vertical')
    right_line = fit_line(collect(v_segments, right_x, 'x'), 'vertical')
    if any(ln is None for ln in (top_line, bottom_line, left_line, right_line)):
        return None

    tl = intersect(top_line, left_line)
    tr = intersect(top_line, right_line)
    br = intersect(bottom_line, right_line)
    bl = intersect(bottom_line, left_line)
    # Identity test on purpose: `None in (tl, ...)` would compare NumPy
    # arrays with `==` and raise "truth value of an array is ambiguous".
    if any(pt is None for pt in (tl, tr, br, bl)):
        return None

    corners = np.array([tl, tr, br, bl], dtype=np.float32)
    xs = corners[:, 0]
    ys = corners[:, 1]
    inside = (
        0 <= xs.min() < img_shape[1]
        and 0 <= xs.max() <= img_shape[1]
        and 0 <= ys.min() < img_shape[0]
        and 0 <= ys.max() <= img_shape[0]
    )
    if not inside:
        return None
    return corners

# -----------------------------
# MAIN REFINEMENT (INCLUDES INNER SELECTION)
# -----------------------------
def _corner_span_stats(corners, image_shape):
    """Return ``(area_ratio, aspect)`` of the axis-aligned extent of corners."""
    xs = corners[:, 0]
    ys = corners[:, 1]
    w_span = xs.max() - xs.min()
    h_span = ys.max() - ys.min()
    area_ratio = (w_span * h_span) / (image_shape[0] * image_shape[1])
    aspect = w_span / max(1, h_span)
    return area_ratio, aspect


def refine_sudoku_image(input_path, output_dir):
    """Locate the grid in one image and save it warped to a square.

    Corner sources are tried in this order:
      1. red corner markers from the previous step's ``*_corners.jpg``;
      2. green lines from the previous step's ``*_hough_lines.jpg``;
      3. fresh line detection on the image itself (with optional de-skew),
         falling back to the largest four-sided contour and then to the
         bounding box of the grid mask when the lines look unreliable.

    The warped grayscale result is written to
    ``<output_dir>/<stem>_square_gray.jpg``.

    Args:
        input_path: path of the image to process.
        output_dir: existing directory that receives the output file.

    Returns:
        ``True`` when the output was written; ``False`` when the image cannot
        be read, no corners can be determined, or the output cannot be
        written.
    """
    base_name = os.path.basename(input_path)
    stem = os.path.splitext(base_name)[0]
    img = cv2.imread(input_path)
    if img is None:
        print(f"[FAIL] Cannot read {input_path}")
        return False
    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)

    corners = None
    reuse_success = False
    hough_success = False
    # Only the fallback path fills these; they drive the debug drawing below.
    h_completed = v_completed = None

    # 1. Reuse corners (red circles)
    reused_corners = extract_corners_from_previous(stem)
    if reused_corners is not None:
        area_ratio, aspect = _corner_span_stats(reused_corners, gray.shape)
        if area_ratio > 0.25 and 0.6 < aspect < 1.6:
            corners = reused_corners
            reuse_success = True
            print(f"[REUSE] {base_name}: accepted (areaRatio={area_ratio:.2f}, aspect={aspect:.2f})")
        else:
            print(f"[REUSE-REJECT] {base_name}: sanity failed")

    # 2. Hough overlay parse
    if not reuse_success:
        hough_lines = extract_lines_from_hough_overlay(stem)
        if hough_lines is not None:
            overlay_corners = corners_from_hough_overlay(hough_lines, gray.shape)
            if overlay_corners is not None:
                area_ratio, aspect = _corner_span_stats(overlay_corners, gray.shape)
                if area_ratio > 0.25 and 0.6 < aspect < 1.6:
                    corners = overlay_corners
                    hough_success = True
                    print(f"[HOUGH-OVERLAY] {base_name}: accepted (areaRatio={area_ratio:.2f}, aspect={aspect:.2f})")
                else:
                    print(f"[HOUGH-REJECT] {base_name}: sanity failed")
            else:
                print(f"[HOUGH-NO-CORNERS] {base_name}")
        else:
            print(f"[HOUGH-NO-LINES] {base_name}")

    # 3. Fallback detection path
    if not (reuse_success or hough_success):
        lines_pass1, grid_mask1, thresh1 = detect_grid_lines(gray)
        h_pos1, v_pos1, v_angles1 = classify_lines(lines_pass1)
        correction, reliable = estimate_rotation(v_angles1)
        if reliable:
            # De-skew first, then detect again on the straightened image.
            img = apply_rotation(img, -correction)
            gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
            print(f"[ROTATE] {base_name}: {-correction:.2f}° applied")
            lines, grid_mask, thresh = detect_grid_lines(gray)
        else:
            lines, grid_mask, thresh = lines_pass1, grid_mask1, thresh1
            print(f"[NO-ROTATE] {base_name}: correction={correction:.2f} reliable={reliable} count={len(v_angles1)}")

        horiz, vert, _ = classify_lines(lines)
        h_thresh = max(10, gray.shape[0]//80)
        v_thresh = max(10, gray.shape[1]//80)
        h_clusters = cluster_positions(horiz, h_thresh)
        v_clusters = cluster_positions(vert, v_thresh)

        # Inner selection before completion
        if len(h_clusters) > GRID_EXPECTED_LINES:
            sel = select_best_uniform_10(h_clusters)
            if sel is not None:
                h_clusters = sel
        if len(v_clusters) > GRID_EXPECTED_LINES:
            sel = select_best_uniform_10(v_clusters)
            if sel is not None:
                v_clusters = sel

        h_completed = h_clusters if len(h_clusters) == GRID_EXPECTED_LINES else complete_lines(h_clusters, gray.shape[0])
        v_completed = v_clusters if len(v_clusters) == GRID_EXPECTED_LINES else complete_lines(v_clusters, gray.shape[1])
        spacing_ok = spacing_uniform(h_completed) and spacing_uniform(v_completed)
        corners = derive_corners(h_completed, v_completed)

        area_lines = (max(v_completed)-min(v_completed))*(max(h_completed)-min(h_completed))
        full_area = gray.shape[0]*gray.shape[1]
        area_ratio = area_lines/full_area
        aspect = (max(v_completed)-min(v_completed))/max(1,(max(h_completed)-min(h_completed)))
        unreliable = (area_ratio < 0.35) or (aspect < 0.75 or aspect > 1.35) or (not spacing_ok)
        if unreliable:
            contour_c = contour_outer_corners(thresh)
            if contour_c is not None:
                corners = contour_c
                print(f"[FALLBACK-CONTOUR] {base_name}")
            else:
                bbox_c = bounding_box_corners(grid_mask)
                if bbox_c is not None:
                    corners = bbox_c
                    print(f"[FALLBACK-BOX] {base_name}")
                else:
                    print(f"[FAIL] {base_name}: no reliable corners")
                    return False
        else:
            print(f"[LINES] {base_name}: accepted (areaRatio={area_ratio:.2f}, aspect={aspect:.2f}, spacing_ok={spacing_ok})")

    # Warp
    warped_color, M = warp_to_square(img, corners, size=TARGET_SIZE)
    warped_gray = cv2.cvtColor(warped_color, cv2.COLOR_BGR2GRAY)
    # NOTE(review): `enhanced` and `grid_vis` are computed but never saved or
    # returned; kept so they remain available for interactive debugging.
    enhanced = enhance_warped(warped_gray)

    # Visualization
    grid_vis = img.copy()
    if h_completed is not None:
        last_col = grid_vis.shape[1] - 1
        last_row = grid_vis.shape[0] - 1
        for y in h_completed:
            cv2.line(grid_vis, (0, y), (last_col, y), (0, 255, 0), 2)
        for x in v_completed:
            # Vertical line: constant x, from the top row to the bottom row.
            cv2.line(grid_vis, (x, 0), (x, last_row), (255, 0, 0), 2)
    for c in corners:
        cv2.circle(grid_vis, tuple(int(v) for v in c), 10, (0, 0, 255), -1)

    out_path = os.path.join(output_dir, f"{stem}_square_gray.jpg")
    # cv2.imwrite reports failure through its return value, not an exception.
    if not cv2.imwrite(out_path, warped_gray):
        print(f"[FAIL] {base_name}: could not write {out_path}")
        return False
    print(f"[OK] {base_name}: reuse={reuse_success} hough_used={hough_success} warped={warped_color.shape[:2]}")
    return True

# -----------------------------
# BATCH REFINEMENT
# -----------------------------
def refine_batch(input_dir, output_dir):
    """Run ``refine_sudoku_image`` on every ``*.jpg`` directly in ``input_dir``.

    Prints a one-line success/failure summary when done and returns nothing.
    """
    paths = list(Path(input_dir).glob('*.jpg'))
    outcomes = [bool(refine_sudoku_image(str(path), output_dir)) for path in paths]
    success = sum(outcomes)
    fail = len(outcomes) - success
    print(f"\nBatch refinement complete: success={success}, failed={fail}, total={len(paths)}")

# Entry point for step 3: refine every image in REFINE_INPUT_DIR and write the
# squared results to REFINE_OUTPUT_DIR.
if __name__ == '__main__':
    refine_batch(REFINE_INPUT_DIR, REFINE_OUTPUT_DIR)


[REUSE] 01.jpg: accepted (areaRatio=0.85, aspect=1.02)
[OK] 01.jpg: reuse=True hough_used=False warped=(450, 450)
[REUSE] 02.jpg: accepted (areaRatio=0.84, aspect=1.03)
[OK] 02.jpg: reuse=True hough_used=False warped=(450, 450)
[REUSE] 03.jpg: accepted (areaRatio=0.86, aspect=0.99)
[OK] 03.jpg: reuse=True hough_used=False warped=(450, 450)
[REUSE] 04.jpg: accepted (areaRatio=0.87, aspect=1.00)
[OK] 04.jpg: reuse=True hough_used=False warped=(450, 450)
[REUSE] 05.jpg: accepted (areaRatio=0.77, aspect=1.01)
[OK] 05.jpg: reuse=True hough_used=False warped=(450, 450)
[REUSE] 06.jpg: accepted (areaRatio=0.75, aspect=1.04)
[OK] 06.jpg: reuse=True hough_used=False warped=(450, 450)
[REUSE] 07.jpg: accepted (areaRatio=0.80, aspect=0.96)
[OK] 07.jpg: reuse=True hough_used=False warped=(450, 450)
[REUSE] 08.jpg: accepted (areaRatio=0.84, aspect=0.98)
[OK] 08.jpg: reuse=True hough_used=False warped=(450, 450)
[REUSE] 09.jpg: accepted (areaRatio=0.84, aspect=1.03)
[OK] 09.jpg: reuse=True hough_use

# Step 4: Consolidated Grid - Draw Black Lines on Original

In [None]:
import cv2
import numpy as np
import os
import matplotlib.pyplot as plt

# Configuration: step 4 reads the squared images written by step 3 and writes
# its own results to a separate folder, created here if it does not exist.
INPUT_DIR = '/content/drive/MyDrive/Computer Vision Final/final_square'
OUTPUT_DIR = '/content/drive/MyDrive/Computer Vision Final/consolidated_grid'
os.makedirs(OUTPUT_DIR, exist_ok=True)


# -----------------------------
# LINE DETECTION (SAME AS STEP 2)
# -----------------------------
def detect_and_draw_grid_lines(image):
    """
    Detect grid line segments and paint them in black on a copy of the image.

    The image is converted to grayscale if needed, blurred, adaptively
    thresholded (inverted) and opened with 40x1 / 1x40 rectangles to isolate
    long horizontal and vertical strokes; a probabilistic Hough transform on
    the blended mask yields the segments.

    Returns ``(result, line_count)``: a BGR copy of the input with every
    detected segment drawn 2 px wide in black, and the number of segments.
    """
    is_color = len(image.shape) == 3
    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) if is_color else image.copy()

    blurred = cv2.GaussianBlur(gray, (5, 5), 0)
    # Inverted threshold: dark ink becomes white foreground.
    thresh = cv2.adaptiveThreshold(
        blurred, 255,
        cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
        cv2.THRESH_BINARY_INV,
        11, 2
    )

    # Erode then dilate with a long thin kernel keeps only strokes at least
    # as long as the kernel in that direction.
    stroke_masks = []
    for kernel_size in ((40, 1), (1, 40)):
        kernel = cv2.getStructuringElement(cv2.MORPH_RECT, kernel_size)
        eroded = cv2.erode(thresh, kernel, iterations=1)
        stroke_masks.append(cv2.dilate(eroded, kernel, iterations=1))
    horizontal, vertical = stroke_masks

    grid_mask = cv2.addWeighted(horizontal, 0.5, vertical, 0.5, 0)

    lines = cv2.HoughLinesP(
        grid_mask,
        rho=1,
        theta=np.pi/180,
        threshold=150,
        minLineLength=100,
        maxLineGap=10
    )

    # Draw on a colour copy so the caller always receives a BGR image.
    result = image.copy()
    if len(result.shape) == 2:
        result = cv2.cvtColor(result, cv2.COLOR_GRAY2BGR)

    segments = [] if lines is None else [line[0] for line in lines]
    for x1, y1, x2, y2 in segments:
        cv2.line(result, (x1, y1), (x2, y2), (0, 0, 0), 2)  # Black lines

    return result, len(segments)


# -----------------------------
# BATCH PROCESSING
# -----------------------------
def process_consolidated_grid(input_dir, output_dir):
    """
    Draw detected grid lines in black on every ``.jpg`` in ``input_dir``.

    Each result is saved as ``<name>_consolidated.jpg`` in ``output_dir``.
    Images that cannot be loaded or saved are reported and skipped.

    Returns a dict mapping each successfully saved input filename to
    ``{'line_count': int, 'output': path}``, or ``None`` when ``input_dir``
    contains no ``.jpg`` files.
    """
    input_files = sorted(f for f in os.listdir(input_dir) if f.endswith('.jpg'))

    if not input_files:
        print(f"No images found in {input_dir}")
        return

    print(f"Processing {len(input_files)} images for consolidated grid...")
    print("=" * 70)

    results = {}

    for idx, filename in enumerate(input_files, 1):
        print(f"\n[{idx}/{len(input_files)}] Processing: {filename}")

        # Load image
        img_path = os.path.join(input_dir, filename)
        image = cv2.imread(img_path)

        if image is None:
            print(f"  ❌ Failed to load image")
            continue

        # Detect and draw grid lines
        print(f"  → Detecting and drawing grid lines...")
        result, line_count = detect_and_draw_grid_lines(image)
        print(f"     Drew {line_count} line segments")

        # Save result
        base_name = os.path.splitext(filename)[0]
        output_path = os.path.join(output_dir, f"{base_name}_consolidated.jpg")
        # cv2.imwrite signals failure by returning False, so check it instead
        # of reporting a file that was never written.
        if not cv2.imwrite(output_path, result):
            print(f"  ❌ Failed to save: {output_path}")
            continue

        print(f"  ✓ Saved: {output_path}")

        results[filename] = {
            'line_count': line_count,
            'output': output_path
        }

    print("\n" + "=" * 70)
    print("✓ Consolidated grid processing complete!")
    return results



# -----------------------------
# EXECUTE PIPELINE
# -----------------------------
# Banner so the step is easy to find in the cell output.
print("Starting Step 4: Consolidated Grid - Draw Black Lines on Original")
print("=" * 70)

# Process all images; `results` maps each input filename to its line count
# and output path (it is None when the input folder has no .jpg files).
results = process_consolidated_grid(INPUT_DIR, OUTPUT_DIR)


Starting Step 4: Consolidated Grid - Draw Black Lines on Original
Processing 17 images for consolidated grid...

[1/17] Processing: 01_square_gray.jpg
  → Detecting and drawing grid lines...
     Drew 61 line segments
  ✓ Saved: /content/drive/MyDrive/Computer Vision Final/consolidated_grid/01_square_gray_consolidated.jpg

[2/17] Processing: 02_square_gray.jpg
  → Detecting and drawing grid lines...
     Drew 62 line segments
  ✓ Saved: /content/drive/MyDrive/Computer Vision Final/consolidated_grid/02_square_gray_consolidated.jpg

[3/17] Processing: 03_square_gray.jpg
  → Detecting and drawing grid lines...
     Drew 58 line segments
  ✓ Saved: /content/drive/MyDrive/Computer Vision Final/consolidated_grid/03_square_gray_consolidated.jpg

[4/17] Processing: 04_square_gray.jpg
  → Detecting and drawing grid lines...
     Drew 41 line segments
  ✓ Saved: /content/drive/MyDrive/Computer Vision Final/consolidated_grid/04_square_gray_consolidated.jpg

[5/17] Processing: 05_square_gray.jpg
 

# Step 5: Clean Digits (Remove Grid Lines and Noise)
Read from Step 4 (consolidated_grid) and keep only digit-sized components, producing black digits on a white background

In [None]:
import cv2
import numpy as np
import os

# Preprocess a Sudoku cell: 1) Produce a black digit on a white background. 2) Remove tiny noise and ignore grid lines.
# 3) Keep components that are likely digits.

def preprocess_sudoku_cell(cell_path, min_area=50, max_area=1000):
    """Load an image and keep only digit-sized dark components.

    Pixels darker than or equal to 180 are treated as ink. Connected
    components (8-connectivity) whose pixel count lies in
    ``[min_area, max_area]`` are kept; anything smaller or larger is dropped.
    The kept ink is dilated with a 2x2 kernel and the result is returned as
    black ink on a white background.

    Args:
        cell_path: path of the image to read (loaded as grayscale).
        min_area: smallest component area, in pixels, that is kept.
        max_area: largest component area, in pixels, that is kept.

    Raises:
        ValueError: if the image cannot be read.
    """
    gray = cv2.imread(cell_path, cv2.IMREAD_GRAYSCALE)
    if gray is None:
        raise ValueError(f"Cannot read image: {cell_path}")

    # Inverted binary: ink becomes white (255) on black.
    _, ink = cv2.threshold(gray, 180, 255, cv2.THRESH_BINARY_INV)

    label_count, label_map, stats, _ = cv2.connectedComponentsWithStats(ink, connectivity=8)

    # Label 0 is the background, so areas start at label 1.
    areas = stats[1:label_count, cv2.CC_STAT_AREA]
    keep_labels = np.flatnonzero((areas >= min_area) & (areas <= max_area)) + 1

    cleaned = np.zeros_like(ink)
    cleaned[np.isin(label_map, keep_labels)] = 255

    thickened = cv2.dilate(cleaned, np.ones((2,2), np.uint8), iterations=1)

    # Back to black ink on white.
    return cv2.bitwise_not(thickened)

# Process the images and save outputs in output directory.
# Input is the step-4 folder; the output folder is created if missing.
input_dir = '/content/drive/MyDrive/Computer Vision Final/consolidated_grid'
output_dir = '/content/drive/MyDrive/Computer Vision Final/sudoku_preprocessed_final'
os.makedirs(output_dir, exist_ok=True)

# Only the files step 4 produced from step-3 squares are picked up.
image_files = [f for f in os.listdir(input_dir) if f.endswith('_square_gray_consolidated.jpg')]

# NOTE(review): preprocess_sudoku_cell raises ValueError on an unreadable
# file, which stops the whole loop; the imwrite result is not checked.
for filename in image_files:
    img_path = os.path.join(input_dir, filename)
    processed = preprocess_sudoku_cell(img_path)
    base = os.path.splitext(filename)[0]
    # Saved as "<base>_clean.png" next to the other cleaned images.
    cv2.imwrite(os.path.join(output_dir, f"{base}_clean.png"), processed)
    print(f"✓ Processed: {filename}")

✓ Processed: 01_square_gray_consolidated.jpg
✓ Processed: 02_square_gray_consolidated.jpg
✓ Processed: 03_square_gray_consolidated.jpg
✓ Processed: 04_square_gray_consolidated.jpg
✓ Processed: 05_square_gray_consolidated.jpg
✓ Processed: 06_square_gray_consolidated.jpg
✓ Processed: 07_square_gray_consolidated.jpg
✓ Processed: 08_square_gray_consolidated.jpg
✓ Processed: 09_square_gray_consolidated.jpg
✓ Processed: 10_square_gray_consolidated.jpg
✓ Processed: 11_square_gray_consolidated.jpg
✓ Processed: 12_square_gray_consolidated.jpg
✓ Processed: 13_square_gray_consolidated.jpg
✓ Processed: 14_square_gray_consolidated.jpg
✓ Processed: 15 copy_square_gray_consolidated.jpg
✓ Processed: 15_square_gray_consolidated.jpg
✓ Processed: 16_square_gray_consolidated.jpg


# Step 5.5: Define Sudoku Solver

In [None]:
# -----------------------------
# ADDED: Sudoku solver (appended to preserve original file)
# -----------------------------
def find_empty(grid):
    """Return ``(row, col)`` of the first empty (0) cell, scanning row by row.

    Returns ``None`` when the 9x9 grid has no empty cell.
    """
    for index in range(81):
        r, c = divmod(index, 9)
        if grid[r][c] == 0:
            return r, c
    return None

def valid(grid, r, c, val):
    """Tell whether ``val`` may be placed at ``(r, c)``.

    The placement is allowed only when ``val`` does not already appear in
    row ``r``, in column ``c``, or in the 3x3 box containing the cell.
    """
    if val in grid[r]:
        return False
    if any(grid[i][c] == val for i in range(9)):
        return False

    # Top-left cell of the enclosing 3x3 box.
    top, left = r - r % 3, c - c % 3
    return all(
        grid[i][j] != val
        for i in range(top, top + 3)
        for j in range(left, left + 3)
    )

def solve_sudoku(grid):
    """Solve the grid in place by recursive backtracking.

    Returns ``True`` with ``grid`` fully filled when a solution exists;
    returns ``False`` with every cell it tried reset to 0 otherwise.
    """
    target = find_empty(grid)
    if not target:
        # No empty cell left: the grid is complete.
        return True

    row, col = target
    for candidate in range(1, 10):
        if not valid(grid, row, col, candidate):
            continue
        grid[row][col] = candidate
        if solve_sudoku(grid):
            return True
        # Dead end: undo the placement and try the next digit.
        grid[row][col] = 0
    return False

def _first_duplicate(values):
    """Return the first value seen twice in ``values``, or ``None``."""
    seen = set()
    for value in values:
        if value in seen:
            return value
        seen.add(value)
    return None


def is_puzzle_legal(grid):
    """
    Quick validation: check for conflicts in rows, columns, and 3x3 boxes.

    Empty cells (0) are ignored. Rows are checked first, then columns, then
    boxes; the first conflict found is printed, naming the digit that is
    actually repeated and where.

    Returns True if no conflicts found (puzzle is legal), False otherwise.
    """
    # Check rows
    for r in range(9):
        dup = _first_duplicate(x for x in grid[r] if x != 0)
        if dup is not None:
            print(f"  ✗ Illegal puzzle: duplicate {dup} in row {r}")
            return False

    # Check columns
    for c in range(9):
        dup = _first_duplicate(grid[r][c] for r in range(9) if grid[r][c] != 0)
        if dup is not None:
            print(f"  ✗ Illegal puzzle: duplicate {dup} in column {c}")
            return False

    # Check 3x3 boxes
    for box_row in range(3):
        for box_col in range(3):
            box_vals = (
                grid[box_row*3 + i][box_col*3 + j]
                for i in range(3)
                for j in range(3)
            )
            dup = _first_duplicate(v for v in box_vals if v != 0)
            if dup is not None:
                print(f"  ✗ Illegal puzzle: duplicate {dup} in box ({box_row},{box_col})")
                return False

    return True

def try_solve_grid(sudoku, out_path=None):
    """
    Validate and solve a Sudoku grid without modifying the input.

    sudoku: 2D Python list (9x9), 0 = empty.
    out_path: optional path to save the solved grid as JSON; if None,
        nothing is written. Missing parent directories are created.

    Returns (True, solved_grid) on success, or (False, None) when the grid
    is not 9x9, contains conflicts, or has no solution.
    """
    # Work on an int-converted copy so the caller's grid stays untouched.
    grid_copy = [[int(x) for x in row] for row in sudoku]
    if len(grid_copy) != 9 or any(len(row) != 9 for row in grid_copy):
        print("  ✗ Skipping solver: grid is not 9x9")
        return False, None

    # Legal check before solving: conflicts make the search pointless.
    print("  → Checking puzzle legality...")
    if not is_puzzle_legal(grid_copy):
        print("  ✗ Skipping solver: puzzle has conflicts")
        return False, None

    num_empty = sum(1 for row in grid_copy for cell in row if cell == 0)
    print(f"  ✓ Puzzle is legal ({81-num_empty}/81 digits filled)")

    if not solve_sudoku(grid_copy):
        print("Solver: puzzle not solvable or incomplete detections.")
        return False, None

    if out_path is None:
        print("Solver: solved (no output path provided, not saved)")
        return True, grid_copy

    import json
    import os

    # A bare filename has an empty dirname, and os.makedirs("") raises, so
    # only create the directory when there is one.
    out_dir = os.path.dirname(out_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    with open(out_path, "w") as f:
        json.dump(grid_copy, f, indent=2)
    print(f"Solver: solved and wrote to {out_path}")
    return True, grid_copy


# Step 6: Draw Clean Sudoku Grid
Read from Step 5 output and create a clean 9x9 Sudoku grid visualization

In [None]:
import cv2
import numpy as np
import os
from matplotlib import pyplot as plt

# Complete Sudoku OCR with Cell Extraction and CV-Based Recognition


def load_templates_from_directory(templates_dir):
    """
    Load the digit templates (1-9) and run them through the same steps that
    are applied to cell ROIs:
    1. Read tem_<digit>_centered_resized.png as grayscale
    2. Otsu-threshold with inversion (digit becomes white on black)
    3. Morphological open + close to clean the binary image
    4. Keep only the largest contour, drawn filled, cropped with padding
    5. Translate the content to the middle of the crop
    6. Resize to 200x200

    Templates are stored as white digit on black background. A digit whose
    file is missing or unusable is reported and left out of the result.

    Returns: dictionary mapping digit -> 200x200 template image
    """
    loaded = {}
    open_kernel = np.ones((2, 2), np.uint8)
    close_kernel = np.ones((3, 3), np.uint8)

    for digit in range(1, 10):
        template_path = os.path.join(templates_dir, f"tem_{digit}_centered_resized.png")

        if not os.path.exists(template_path):
            print(f"  ✗ Template file not found: {template_path}")
            continue

        raw = cv2.imread(template_path, cv2.IMREAD_GRAYSCALE)
        if raw is None:
            print(f"  ✗ Failed to load template {digit}")
            continue

        # Binarize (inverted so the digit is white) and clean up speckles/gaps
        _, binary = cv2.threshold(raw, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)
        cleaned = cv2.morphologyEx(binary, cv2.MORPH_OPEN, open_kernel, iterations=1)
        cleaned = cv2.morphologyEx(cleaned, cv2.MORPH_CLOSE, close_kernel, iterations=1)

        contours, _ = cv2.findContours(cleaned, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
        if not contours:
            print(f"  ✗ No contours found in template {digit}")
            continue

        # Largest contour is taken as the digit; draw it filled on a blank mask
        outline = max(contours, key=cv2.contourArea)
        mask = np.zeros_like(cleaned)
        cv2.drawContours(mask, [outline], -1, 255, -1)

        # Padded bounding box, clipped to the mask
        bx, by, bw, bh = cv2.boundingRect(outline)
        pad = max(2, min(bw, bh) // 10)
        left = max(0, bx - pad)
        upper = max(0, by - pad)
        width = min(mask.shape[1] - left, bw + 2*pad)
        height = min(mask.shape[0] - upper, bh + 2*pad)
        roi = mask[upper:upper+height, left:left+width]

        coords = cv2.findNonZero(roi)
        if coords is None:
            print(f"  ✗ Failed to find content in template {digit}")
            continue

        # Shift the (padded) content box so its center lands on the crop center
        rx, ry, rw, rh = cv2.boundingRect(coords)
        margin = max(2, min(rw, rh) // 8)
        box_x = max(0, rx - margin)
        box_y = max(0, ry - margin)
        box_w = min(roi.shape[1] - box_x, rw + 2 * margin)
        box_h = min(roi.shape[0] - box_y, rh + 2 * margin)

        roi_h, roi_w = roi.shape
        dx = roi_w // 2 - (box_x + box_w // 2)
        dy = roi_h // 2 - (box_y + box_h // 2)

        shift = np.float32([[1, 0, dx], [0, 1, dy]])
        centered = cv2.warpAffine(roi, shift, (roi_w, roi_h),
                                  borderMode=cv2.BORDER_CONSTANT,
                                  borderValue=0)

        # Standard template size (200x200), white digit on black background
        loaded[digit] = cv2.resize(centered, (200, 200), interpolation=cv2.INTER_CUBIC)
        print(f"  ✓ Loaded & preprocessed template {digit}")

    return loaded


def extract_and_save_cells(image, output_cells_dir, image_name, scale=4.0):
    """
    Split the image into a 9x9 grid, enlarge every cell, and save each one
    as cell_r<row>_c<col>.png inside output_cells_dir.

    A small margin is trimmed from each cell edge so most of the cell
    content is kept before later centering.

    Args:
        image: grid image; split evenly into 9 rows and 9 columns
            (integer division, so leftover pixels at the far edges are dropped).
        output_cells_dir: directory for the cell images (created if missing).
        image_name: currently unused; kept so existing callers keep working.
        scale: enlargement factor applied to both axes of every cell
            (default 4.0, the factor this code has always used).

    Returns: dictionary of enlarged cell images keyed by (row, col)
    """
    os.makedirs(output_cells_dir, exist_ok=True)

    h, w = image.shape[:2]
    cell_h = h // 9
    cell_w = w // 9

    cells = {}
    saved = 0

    # Very small margins (1/35 of the cell, at least 2 px) so digits near
    # the cell edges are still fully included
    margin_h = max(2, cell_h // 35)
    margin_w = max(2, cell_w // 35)

    for row in range(9):
        for col in range(9):
            y1 = row * cell_h + margin_h
            y2 = (row + 1) * cell_h - margin_h
            x1 = col * cell_w + margin_w
            x2 = (col + 1) * cell_w - margin_w

            cell = image[y1:y2, x1:x2]

            # Enlarge the cell for better recognition
            enlarged_cell = cv2.resize(cell, None, fx=scale, fy=scale,
                                       interpolation=cv2.INTER_CUBIC)

            # cv2.imwrite reports failure through its return value instead of
            # raising, so check it rather than assuming the file was written
            cell_path = os.path.join(output_cells_dir, f"cell_r{row}_c{col}.png")
            if cv2.imwrite(cell_path, enlarged_cell):
                saved += 1
            else:
                print(f"  ✗ Failed to write cell image: {cell_path}")

            cells[(row, col)] = enlarged_cell

    print(f"  → Saved {saved} cells to: {output_cells_dir}")
    return cells


def preprocess_for_digit_detection(cell):
    """
    Turn a cell image into (enhanced grayscale, inverted binary).

    Three binarizations of the contrast-enhanced image are produced (Otsu,
    adaptive Gaussian, adaptive mean); the one whose foreground density is
    inside (0.06, 0.40) and closest to 0.18 is returned. If none falls in
    that range the Otsu result is used. Returns (None, None) for a missing
    or empty cell.
    """
    if cell is None or cell.size == 0:
        return None, None

    # Work on a single-channel image
    gray = cv2.cvtColor(cell, cv2.COLOR_BGR2GRAY) if len(cell.shape) == 3 else cell.copy()

    # Denoise, then boost local contrast with CLAHE
    smooth = cv2.fastNlMeansDenoising(gray, None, h=10, templateWindowSize=7, searchWindowSize=21)
    enhanced = cv2.createCLAHE(clipLimit=2.5, tileGridSize=(4, 4)).apply(smooth)

    # Candidate binarizations, in priority order for ties
    _, otsu = cv2.threshold(enhanced, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)
    adaptive_gauss = cv2.adaptiveThreshold(enhanced, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
                                           cv2.THRESH_BINARY_INV, 13, 3)
    adaptive_mean = cv2.adaptiveThreshold(enhanced, 255, cv2.ADAPTIVE_THRESH_MEAN_C,
                                          cv2.THRESH_BINARY_INV, 13, 3)
    candidates = [otsu, adaptive_gauss, adaptive_mean]

    # Rank the candidates whose foreground density looks plausible
    ranked = []
    for candidate in candidates:
        fill = np.sum(candidate > 0) / candidate.size
        if 0.06 < fill < 0.40:
            ranked.append((abs(fill - 0.18), candidate))

    if ranked:
        # min() keeps the earliest candidate on ties
        chosen = min(ranked, key=lambda item: item[0])[1]
    else:
        chosen = candidates[0]

    return enhanced, chosen


def clean_binary(binary):
    """
    Return a cleaned copy of a binary cell image: speckles removed (open),
    small gaps filled (close), and a thin frame around the edge zeroed out.
    """
    # Open with 2x2 to drop noise, then close with 3x3 to bridge gaps
    opened = cv2.morphologyEx(binary, cv2.MORPH_OPEN, np.ones((2, 2), np.uint8), iterations=1)
    result = cv2.morphologyEx(opened, cv2.MORPH_CLOSE, np.ones((3, 3), np.uint8), iterations=1)

    # Blank a frame of 1/40 of the shorter side (at least 1 px); kept thin so
    # digits that touch the cell edge are preserved
    rows, cols = result.shape
    frame = max(1, min(rows, cols) // 40)
    for edge in (np.s_[:frame, :], np.s_[-frame:, :], np.s_[:, :frame], np.s_[:, -frame:]):
        result[edge] = 0

    return result


def is_cell_empty(binary):
    """
    Decide whether a binary cell image contains no digit.

    The cell counts as empty when its foreground density is below 1.5% or
    above 65%, when it has no foreground component, or when the largest
    component covers less than 0.8% or more than 90% of the cell.
    """
    total = binary.size
    fill = np.sum(binary > 0) / total

    if fill < 0.015 or fill > 0.65:
        return True

    # Label 0 is the background, so a count of 1 means no foreground blobs
    component_count, _, stats, _ = cv2.connectedComponentsWithStats(binary, connectivity=8)
    if component_count <= 1:
        return True

    blob_areas = stats[1:, cv2.CC_STAT_AREA]
    if blob_areas.size == 0:
        return True

    biggest = blob_areas.max()
    too_small = biggest < total * 0.008
    too_large = biggest > total * 0.90
    return True if (too_small or too_large) else False


def find_content_bounds(binary):
    """
    Bounding box of all non-zero pixels in the cell.
    Returns: (x, y, width, height), or None when the image has no content
    """
    nonzero = cv2.findNonZero(binary)
    if nonzero is None or len(nonzero) == 0:
        return None

    # One rectangle around every foreground pixel
    return tuple(cv2.boundingRect(nonzero))


def center_content(cell, binary):
    """
    Shift the cell so that its content (located via `binary`) sits in the
    middle of the image. The result is grayscale, has the same size as the
    cell, and uncovered areas are filled with white (255).

    Returns None when the cell is missing/empty or has no content.
    """
    if cell is None or cell.size == 0:
        return None
    if is_cell_empty(binary):
        return None

    bounds = find_content_bounds(binary)
    if bounds is None:
        return None
    bx, by, bw, bh = bounds

    cell_h, cell_w = cell.shape[:2]

    # Grow the content box by a safety margin, clipped to the cell
    margin = max(2, min(bw, bh) // 8)
    box_x = max(0, bx - margin)
    box_y = max(0, by - margin)
    box_w = min(cell_w - box_x, bw + 2 * margin)
    box_h = min(cell_h - box_y, bh + 2 * margin)

    # Offset that moves the box center onto the cell center
    dx = cell_w // 2 - (box_x + box_w // 2)
    dy = cell_h // 2 - (box_y + box_h // 2)
    shift = np.float32([[1, 0, dx], [0, 1, dy]])

    gray = cv2.cvtColor(cell, cv2.COLOR_BGR2GRAY) if len(cell.shape) == 3 else cell.copy()

    return cv2.warpAffine(gray, shift, (cell_w, cell_h),
                          borderMode=cv2.BORDER_CONSTANT,
                          borderValue=255)


def extract_digit_component(binary):
    """
    Isolate the main digit blob of a binary cell image.

    Returns (roi, holes): roi is a padded crop of a mask on which the
    selected contour is drawn filled, and holes is the number of
    sufficiently large contours whose hierarchy parent is that contour.
    Returns (None, 0) when no contour passes the size/shape filters.
    """
    contours, hierarchy = cv2.findContours(binary, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
    if not contours:
        return None, 0

    total_px = binary.size

    # Keep contours with plausible area, minimum size and aspect ratio
    candidates = []
    for idx, contour in enumerate(contours):
        area = cv2.contourArea(contour)
        if not (total_px * 0.01 < area < total_px * 0.85):
            continue
        _, _, bw, bh = cv2.boundingRect(contour)
        aspect = float(bw) / bh if bh > 0 else 0
        if bw > 4 and bh > 6 and 0.1 < aspect < 2.0:
            candidates.append((area, idx))

    if not candidates:
        return None, 0

    # Largest surviving contour is the digit
    _, main_idx = max(candidates, key=lambda item: item[0])
    digit_contour = contours[main_idx]

    mask = np.zeros_like(binary)
    cv2.drawContours(mask, [digit_contour], -1, 255, -1)

    # Padded bounding box, clipped to the mask
    bx, by, bw, bh = cv2.boundingRect(digit_contour)
    pad = max(2, min(bw, bh) // 10)
    left = max(0, bx - pad)
    upper = max(0, by - pad)
    width = min(mask.shape[1] - left, bw + 2*pad)
    height = min(mask.shape[0] - upper, bh + 2*pad)
    roi = mask[upper:upper+height, left:left+width]

    # A hole is a direct child of the digit contour that is not tiny
    holes = 0
    if hierarchy is not None:
        holes = sum(
            1
            for idx, node in enumerate(hierarchy[0])
            if node[3] == main_idx and cv2.contourArea(contours[idx]) > total_px * 0.002
        )

    return roi, holes


def compute_digit_features(roi):
    """
    Describe a digit ROI by its aspect ratio and by the foreground density
    of its horizontal thirds (top/mid/bot) and vertical thirds
    (left/center/right). Returns None for a missing or empty ROI.
    """
    if roi is None or roi.size == 0:
        return None

    rows, cols = roi.shape
    row_step = max(1, rows // 3)
    col_step = max(1, cols // 3)

    def fill(region):
        # Fraction of non-zero pixels; an empty slice counts as 0
        if region.size > 0:
            return np.sum(region > 0) / region.size
        return 0

    bands = (
        ('top', roi[:row_step, :]),
        ('mid', roi[row_step:2*row_step, :]),
        ('bot', roi[2*row_step:, :]),
        ('left', roi[:, :col_step]),
        ('center', roi[:, col_step:2*col_step]),
        ('right', roi[:, 2*col_step:]),
    )

    features = {'ar': float(cols) / rows if rows > 0 else 0}
    for label, region in bands:
        features[label] = fill(region)
    return features


def match_template(roi, templates):
    """
    Score a digit ROI against every template.

    The ROI content is centered, resized to the shape of the first template
    and compared with each template using normalized cross-correlation
    (cv2.TM_CCOEFF_NORMED). ROI and templates are expected in the same
    format (white digit on black background), so nothing is inverted here.

    Args:
        roi: single-channel image containing the digit.
        templates: dictionary mapping digit -> template image.
            NOTE(review): all templates are assumed to share the shape of
            the first one - confirm if templates come from another source.

    Returns: dictionary mapping digit -> match score; empty when the ROI is
    missing/empty, has no content, or no templates are available.
    """
    # Without an ROI or without templates there is nothing to compare.
    # Returning {} (instead of failing on the first-template lookup) matches
    # how the caller already treats "no template evidence".
    if roi is None or roi.size == 0 or not templates:
        return {}

    # Find content bounds in ROI
    coords = cv2.findNonZero(roi)
    if coords is None or len(coords) == 0:
        return {}

    x, y, w, h = cv2.boundingRect(coords)

    # Content box with safety padding, clipped to the ROI
    padding = max(2, min(w, h) // 8)
    x_padded = max(0, x - padding)
    y_padded = max(0, y - padding)
    w_padded = min(roi.shape[1] - x_padded, w + 2 * padding)
    h_padded = min(roi.shape[0] - y_padded, h + 2 * padding)

    # Translation that moves the content box center onto the ROI center
    roi_h, roi_w = roi.shape
    shift_x = roi_w // 2 - (x_padded + w_padded // 2)
    shift_y = roi_h // 2 - (y_padded + h_padded // 2)
    translation_matrix = np.float32([[1, 0, shift_x], [0, 1, shift_y]])

    centered_roi = cv2.warpAffine(roi, translation_matrix, (roi_w, roi_h),
                                   borderMode=cv2.BORDER_CONSTANT,
                                   borderValue=0)

    # Resize the centered ROI to the template size (taken from the first
    # template); cv2.resize expects (width, height)
    template_h, template_w = next(iter(templates.values())).shape[:2]
    resized = cv2.resize(centered_roi, (template_w, template_h),
                        interpolation=cv2.INTER_CUBIC)

    # Image and template have equal size, so matchTemplate yields a 1x1
    # result holding the single correlation score
    scores = {}
    for digit, template in templates.items():
        result = cv2.matchTemplate(resized, template, cv2.TM_CCOEFF_NORMED)
        scores[digit] = result[0][0]

    return scores


def classify_by_rules(features, holes):
    """
    Guess a digit from shape features and hole count.

    Returns (digit, confidence). Missing features give (0, 0.0). Any hole
    count other than 1 or 2 is handled by the "no holes" rules.
    """
    if features is None:
        return 0, 0.0

    ar, top, mid, bot, left, right = (
        features[key] for key in ('ar', 'top', 'mid', 'bot', 'left', 'right')
    )

    # Two holes: only 8
    if holes == 2:
        return 8, 0.98

    # One hole: narrow shapes are 9 (top-heavy) or 6; wider ones are 4
    # (dense middle band) or 0
    if holes == 1:
        if ar < 0.55:
            return (9, 0.92) if top > bot * 1.1 else (6, 0.92)
        return (4, 0.88) if mid > (top + bot) / 2 * 1.1 else (0, 0.85)

    # No holes from here on. A strongly top-heavy shape is a 7 regardless
    # of aspect ratio, so test it first.
    if top > bot * 1.8:
        return 7, 0.92

    # Very narrow: 1
    if ar < 0.35:
        return 1, 0.95

    # Narrow: 7 with a relaxed top-heavy threshold, otherwise 1
    if ar < 0.60:
        return (7, 0.87) if top > bot * 1.3 else (1, 0.85)

    # Wider digits (ar >= 0.60), checked in priority order
    if top > bot * 1.5 and mid < (top + bot) / 2:
        return 7, 0.85   # 7 with a looser aspect-ratio constraint
    if top > bot * 1.1 and left > right * 1.1:
        return 5, 0.82   # top-heavy and left-heavy
    if bot > top * 1.1 and left > right:
        return 6, 0.78   # backup for a 6 whose hole was not detected
    if right > left * 1.15:
        return 3, 0.78   # right-heavy
    if bot > top * 1.2:
        return 2, 0.78   # bottom-heavy
    return 3, 0.65


def recognize_digit_cv(cell, templates):
    """
    Recognize the digit in one cell image using CV algorithms only.

    Pipeline: preprocess -> clean -> empty check -> center content ->
    preprocess and clean again -> extract main component -> compute
    features -> template matching + rule-based classification -> fusion.

    The fusion trusts template matching more the higher its best score is,
    and falls back to the rule-based result as the score drops.

    Args:
        cell: cell image (color or grayscale).
        templates: dictionary mapping digit -> template image, passed
            straight to match_template.

    Returns:
        The recognized digit, or 0 when the cell is missing, judged empty,
        or no digit component can be extracted. The rule-based classifier
        can itself return 0, in which case 0 may also be returned from the
        fusion step.
    """
    if cell is None or cell.size == 0:
        return 0

    # Initial preprocess to check if empty (the enhanced image is not used)
    enhanced, binary = preprocess_for_digit_detection(cell)
    if binary is None:
        return 0

    # Clean
    binary = clean_binary(binary)

    # Check if empty
    if is_cell_empty(binary):
        return 0

    # Center the content in the cell (only if not empty)
    centered_cell = center_content(cell, binary)
    if centered_cell is None:
        return 0

    # Re-preprocess the centered cell so the binary matches the shifted image
    enhanced, binary = preprocess_for_digit_detection(centered_cell)
    if binary is None:
        return 0

    # Clean again
    binary = clean_binary(binary)

    # Extract digit component from centered image
    roi, holes = extract_digit_component(binary)
    if roi is None:
        return 0

    # Get features
    features = compute_digit_features(roi)

    # Template matching (empty dict when no scores could be computed)
    template_scores = match_template(roi, templates)

    # Rule-based classification
    rule_digit, rule_conf = classify_by_rules(features, holes)

    # Fusion logic - prioritize template matching more strongly
    if template_scores:
        # Rank templates by score; second/third fall back to (0, 0) when
        # fewer than two/three templates were scored
        sorted_scores = sorted(template_scores.items(), key=lambda x: x[1], reverse=True)
        best_template = sorted_scores[0][0]
        best_score = sorted_scores[0][1]
        second_best = sorted_scores[1] if len(sorted_scores) > 1 else (0, 0)
        third_best = sorted_scores[2] if len(sorted_scores) > 2 else (0, 0)

        # Very strong template match (> 0.60) - always trust it
        if best_score > 0.60:
            return best_template

        # Strong template match (0.50 - 0.60]
        elif best_score > 0.50:
            # If much better than second best, trust it
            if best_score - second_best[1] > 0.10:
                return best_template
            # If close to second best but agrees with rules, trust it
            elif best_template == rule_digit:
                return best_template
            # Otherwise still prefer template unless rules are very confident
            elif rule_conf > 0.90:
                return rule_digit
            else:
                return best_template

        # Good template match (0.40 - 0.50]
        elif best_score > 0.40:
            # Clear winner and decent confidence
            if best_score - second_best[1] > 0.12:
                return best_template
            # Agreement with rules
            elif best_template == rule_digit:
                return best_template
            # Check if second or third best agrees with rules
            elif second_best[0] == rule_digit and second_best[1] > 0.35:
                return rule_digit
            elif third_best[0] == rule_digit and third_best[1] > 0.32 and rule_conf > 0.85:
                return rule_digit
            # Template score is better than rule confidence
            elif best_score > rule_conf + 0.10:
                return best_template
            # High rule confidence
            elif rule_conf > 0.85:
                return rule_digit
            else:
                return best_template

        # Moderate template match (0.30 - 0.40]
        elif best_score > 0.30:
            # High rule confidence, go with rules
            if rule_conf > 0.80:
                return rule_digit
            # Check if any top 3 templates agree with rules
            elif rule_digit in [sorted_scores[i][0] for i in range(min(3, len(sorted_scores)))]:
                # Find the matching score
                for digit, score in sorted_scores[:3]:
                    if digit == rule_digit:
                        # If it's reasonably high, trust the agreement
                        if score > 0.25:
                            return rule_digit
                        break
                # Agreement was too weak to count: keep the best template
                return best_template
            # Template is clearly better
            elif best_score > rule_conf + 0.15:
                return best_template
            else:
                return rule_digit

        # Weak template match (<= 0.30) - trust rules if confident
        else:
            if rule_conf > 0.70:
                return rule_digit
            elif best_score > 0.20:
                return best_template
            else:
                return rule_digit

    # No template scores available: rules are the only evidence
    return rule_digit


def save_templates_as_images(templates, output_dir):
    """
    Write every digit template to output_dir as template_<digit>.png so the
    templates can be inspected visually.
    """
    os.makedirs(output_dir, exist_ok=True)

    for digit in templates:
        # Resize to a fixed 112x112 preview (nearest-neighbour keeps edges crisp)
        preview = cv2.resize(templates[digit], (112, 112), interpolation=cv2.INTER_NEAREST)
        cv2.imwrite(os.path.join(output_dir, f"template_{digit}.png"), preview)

    print(f"  ✓ Saved {len(templates)} templates to: {output_dir}")


def _draw_sudoku_grid_lines(canvas, cell_size, minor_color):
    """
    Draw the 10 horizontal and 10 vertical lines of a 9x9 grid onto canvas
    (in place). Every third line is a thick black box border; the others are
    thin and use minor_color.
    """
    size = cell_size * 9
    for i in range(10):
        is_box_border = i % 3 == 0
        thickness = 3 if is_box_border else 1
        color = (0, 0, 0) if is_box_border else minor_color
        offset = i * cell_size
        cv2.line(canvas, (0, offset), (size, offset), color, thickness)
        cv2.line(canvas, (offset, 0), (offset, size), color, thickness)


def process_sudoku_image(image_path, output_base_dir, image_index, templates):
    """
    Process a single Sudoku image:
    1. Extract and save all 81 cells
    2. Recognize digits using CV
    3. Save the Sudoku array as text, plus an annotated and an empty grid image
    4. Try to solve the detected grid (solution saved as JSON on success)

    Args:
        image_path: path of the image to process.
        output_base_dir: base directory; results go to a per-image
            subfolder named image_<index>_<image name>.
        image_index: number used in the subfolder name and log header.
        templates: dictionary mapping digit -> template image.

    Returns:
        None when the image cannot be loaded; otherwise a dictionary with
        'image_name', 'output_dir', 'cells_dir', 'array', 'num_digits',
        'array_path', 'grid_path', and the solver outcome in 'solved'
        (bool) and 'solved_grid' (solved 9x9 list or None).
    """
    # Create output directories
    image_name = os.path.splitext(os.path.basename(image_path))[0]
    image_output_dir = os.path.join(output_base_dir, f"image_{image_index:02d}_{image_name}")
    cells_dir = os.path.join(image_output_dir, "cells")

    os.makedirs(image_output_dir, exist_ok=True)
    os.makedirs(cells_dir, exist_ok=True)

    # Header first, so a load failure below is attributed to this image
    print(f"\n[Image {image_index}] {image_name}")

    # Load image
    image = cv2.imread(image_path)
    if image is None:
        print(f"  ✗ Failed to load image")
        return None

    # Extract and save all cells
    cells = extract_and_save_cells(image, cells_dir, image_name)

    # Recognize digits
    print(f"  → Recognizing digits using CV algorithms...")
    sudoku_array = np.zeros((9, 9), dtype=int)

    for row in range(9):
        for col in range(9):
            sudoku_array[row, col] = recognize_digit_cv(cells[(row, col)], templates)

    num_digits = int(np.count_nonzero(sudoku_array))
    print(f"  → Detected {num_digits} digits")

    # One text line per row, '.' for empty cells; reused for console and file
    grid_lines = [' '.join(str(d) if d != 0 else '.' for d in row) for row in sudoku_array]

    # Print array
    print(f"  → Sudoku Grid:")
    for line in grid_lines:
        print(f"     {line}")

    # Save array as text file
    array_path = os.path.join(image_output_dir, "sudoku_array.txt")
    with open(array_path, 'w') as f:
        f.write(f"Sudoku Grid from: {image_name}\n")
        f.write(f"Detected {num_digits} digits\n")
        f.write("=" * 50 + "\n\n")
        f.write("Grid (. = empty):\n")
        for line in grid_lines:
            f.write(line + '\n')
        f.write("\n" + "=" * 50 + "\n\n")
        f.write("Python array:\n")
        f.write(str(sudoku_array.tolist()) + "\n")

    print(f"  ✓ Saved array: {array_path}")

    # Draw and save the input image with grid lines on top
    cell_size = 60
    size = cell_size * 9

    grid_img = cv2.resize(image, (size, size))
    if len(grid_img.shape) == 2:
        # Colored lines need a 3-channel canvas
        grid_img = cv2.cvtColor(grid_img, cv2.COLOR_GRAY2BGR)
    _draw_sudoku_grid_lines(grid_img, cell_size, (100, 100, 100))

    grid_path = os.path.join(image_output_dir, "sudoku_grid.jpg")
    cv2.imwrite(grid_path, grid_img)
    print(f"  ✓ Saved grid: {grid_path}")

    # Create empty grid template on a white canvas
    empty_grid = np.ones((size, size, 3), dtype=np.uint8) * 255
    _draw_sudoku_grid_lines(empty_grid, cell_size, (150, 150, 150))

    empty_path = os.path.join(image_output_dir, "empty_grid.jpg")
    cv2.imwrite(empty_path, empty_grid)
    print(f"  ✓ Saved empty grid: {empty_path}")

    print("  → Solving detected Sudoku...")
    solved_json_path = os.path.join(image_output_dir, "sudoku_solved.json")
    solved_ok, solved_grid = try_solve_grid(sudoku_array.tolist(), solved_json_path)
    if solved_ok:
        print("  ✓ Sudoku solved during pipeline")
    else:
        print("  ✗ Sudoku could not be solved")

    return {
        'image_name': image_name,
        'output_dir': image_output_dir,
        'cells_dir': cells_dir,
        'array': sudoku_array,
        'num_digits': num_digits,
        'array_path': array_path,
        'grid_path': grid_path,
        # Solver outcome, previously computed but discarded
        'solved': solved_ok,
        'solved_grid': solved_grid,
    }


def process_all_sudoku_images(input_dir, output_dir, templates):
    """
    Run process_sudoku_image on every '*_clean.png' file in input_dir
    (sorted by name, numbered from 1) and print a summary of all grids.

    Returns: list of the per-image result dictionaries; images that
    produced no result are left out. Empty list when no image matches.
    """
    image_files = sorted(name for name in os.listdir(input_dir) if name.endswith('_clean.png'))

    if not image_files:
        print(f"No images found in {input_dir}")
        return []

    rule = "=" * 70

    print(rule)
    print("SUDOKU OCR WITH CELL EXTRACTION")
    print(rule)
    print(f"Input: {input_dir}")
    print(f"Output: {output_dir}")
    print(f"Found {len(image_files)} images to process")
    print(rule)

    results = []
    for number, filename in enumerate(image_files, 1):
        outcome = process_sudoku_image(os.path.join(input_dir, filename), output_dir, number, templates)
        if outcome:
            results.append(outcome)

    print("\n" + rule)
    print(f"✓ PROCESSING COMPLETE")
    print(f"  Processed: {len(results)}/{len(image_files)} images")
    print(f"  Output directory: {output_dir}")
    print(rule)

    # Per-image summary
    print("\nSUMMARY OF ALL GRIDS:")
    print(rule)
    for number, entry in enumerate(results, 1):
        print(f"\n[{number}] {entry['image_name']}")
        print(f"    Digits detected: {entry['num_digits']}/81")
        print(f"    Output: {entry['output_dir']}")
        print(f"    Cells: {entry['cells_dir']}")
        print(f"    Grid:")
        for grid_row in entry['array']:
            print(f"      {' '.join(str(d) if d != 0 else '.' for d in grid_row)}")

    return results


# Digit templates are read from the templates_centered_resized folder on Drive
templates_dir = '/content/drive/MyDrive/Computer Vision Final/templates_centered_resized'

print("=" * 70)
print("Loading templates from template_centered_resized directory...")
print("=" * 70)

templates = load_templates_from_directory(templates_dir)

if templates:
    print(f"\n✓ Successfully loaded {len(templates)} templates")
    print("=" * 70)

    # Run the complete OCR pipeline over the preprocessed images
    input_dir = '/content/drive/MyDrive/Computer Vision Final/sudoku_preprocessed_final'
    output_dir = '/content/drive/MyDrive/Computer Vision Final/sudoku_ocr_results'

    results = process_all_sudoku_images(input_dir, output_dir, templates)
else:
    # Without templates the recognition step cannot run
    print("\n✗ No templates loaded! Make sure to run the template creation cell first.")

Loading templates from template_centered_resized directory...
  ✓ Loaded & preprocessed template 1
  ✓ Loaded & preprocessed template 2
  ✓ Loaded & preprocessed template 3
  ✓ Loaded & preprocessed template 4
  ✓ Loaded & preprocessed template 5
  ✓ Loaded & preprocessed template 6
  ✓ Loaded & preprocessed template 7
  ✓ Loaded & preprocessed template 8
  ✓ Loaded & preprocessed template 9

✓ Successfully loaded 9 templates
SUDOKU OCR WITH CELL EXTRACTION
Input: /content/drive/MyDrive/Computer Vision Final/sudoku_preprocessed_final
Output: /content/drive/MyDrive/Computer Vision Final/sudoku_ocr_results
Found 17 images to process

[Image 1] 01_square_gray_consolidated_clean
  → Saved 81 cells to: /content/drive/MyDrive/Computer Vision Final/sudoku_ocr_results/image_01_01_square_gray_consolidated_clean/cells
  → Recognizing digits using CV algorithms...
  → Detected 26 digits
  → Sudoku Grid:
     8 . . . . . . 4 .
     . . 3 6 . . . . .
     . 7 . . 9 . 2 . 3
     . 5 . . . 7 . . .
  