In [27]:
from dataclasses import dataclass

import cv2
import numpy as np
from PIL import Image
from paddleocr import PaddleOCR


In [28]:
dir = './pictures'
fps = [
    # 'demo.png',
    'paper.jpg',
    # 'board.png',
    # 'sudoku.png',
    # 'sudoku_num.png',
    # 'sudoku_num_dark.png',
    # 'sudoku_big_dark.png'
]

In [29]:
Pixels = int

@dataclass(frozen=True)
class SudokuBounds:
    top_left_x: Pixels
    top_left_y: Pixels
    width: Pixels
    cell_width: Pixels


@dataclass(frozen=True)
class CellCoordinates:
    x0: Pixels
    y0: Pixels
    x1: Pixels
    y1: Pixels

In [30]:
def show(*args):
    for i, j in enumerate(args):
        cv2.imshow(str(i), j)
    cv2.waitKey(0)
    cv2.destroyAllWindows()

In [31]:
def transform(pts, img):  # TODO: Spline transform, remove this
    pts = np.float32(pts)
    top_l, top_r, bot_l, bot_r = pts[0], pts[1], pts[2], pts[3]

    def pythagoras(pt1, pt2):
        return np.sqrt((pt2[0] - pt1[0]) ** 2 + (pt2[1] - pt1[1]) ** 2)

    width = int(max(pythagoras(bot_r, bot_l), pythagoras(top_r, top_l)))
    height = int(max(pythagoras(top_r, bot_r), pythagoras(top_l, bot_l)))
    square = max(width, height) // 9 * 9  # Making the image dimensions divisible by 9

    dim = np.array(([0, 0], [square - 1, 0], [square - 1, square - 1], [0, square - 1]), dtype='float32')
    matrix = cv2.getPerspectiveTransform(pts, dim)
    warped = cv2.warpPerspective(img, matrix, (square, square))
    show(warped)
    return warped

In [32]:
def find_sudoku_bounds(image):
    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    thresh = cv2.adaptiveThreshold(gray, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY, 11, 2)
    
    thresh = cv2.bitwise_not(thresh)
    
    kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (2, 2))
    dilated = cv2.dilate(thresh, kernel, iterations=1)
    
    # edges = cv2.Canny(dilated, 50, 100)
    contours, hierarchy = cv2.findContours(dilated, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    
    # contours_hierarchy = list(filter(lambda c_h: (print(c_h[1]), c_h[1][3] == -1), zip(contours, hierarchy[0])))
    # contours = [item[0] for item in contours_hierarchy]
    contours = sorted(contours, key=cv2.contourArea, reverse=True)[:9]
    
    if len(contours) > 1:  # TODO：分写处理文件
        contour_area0 = cv2.contourArea(contours[0])
        contour_area1 = cv2.contourArea(contours[1])
        if contour_area1 < 0.25 * contour_area0:
            contours = [contours[0]]
            largest_contour = np.squeeze(contours[0])

            sums = [sum(i) for i in largest_contour]
            differences = [i[0] - i[1] for i in largest_contour]
        
            top_left = np.argmin(sums)
            top_right = np.argmax(differences)
            bottom_left = np.argmax(sums)
            bottom_right = np.argmin(differences)
        
            corners = [largest_contour[top_left], largest_contour[top_right], largest_contour[bottom_left],
                       largest_contour[bottom_right]]
            return transform(corners, image)

    color = (0, 0, 255)  # (B, G, R)
    
    # 绘制轮廓
    # contours_n = len(contours)
    for i, contour in enumerate(contours):
        res = cv2.drawContours(image.copy(), [contour], -1, color, 3)
        # cv2.imshow(f'Contour {i+1}/{contours_n}', res)
        # cv2.waitKey(0)

    min_x = min_y = 10000
    max_x = max_y = 0
    
    for contour in contours:
        perimeter = cv2.arcLength(contour, True)
        approx = cv2.approxPolyDP(contour, 0.005 * perimeter, True)
        
        for a in approx:
            x, y = a[0][0], a[0][1]
            if x < min_x:
                min_x = x
            elif x > max_x:
                max_x = x
                
            if y < min_y:
                min_y = y
            elif y > max_y:
                max_y = y
                

    top_left_x = min_x
    top_left_y = min_y
    square_width = max(max_x-min_x, max_y-min_y)
    
    cropped_image = np.array(image)[top_left_y:top_left_y + square_width, top_left_x:top_left_x + square_width]
    return cropped_image
    

In [33]:
def crop_to_sudoku(screenshot):
    handle_img = find_sudoku_bounds(screenshot)
    gray = cv2.cvtColor(handle_img, cv2.COLOR_BGR2GRAY)
    thresh = cv2.adaptiveThreshold(gray, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY, 11, 2)
    
    show(handle_img, gray, thresh)
    
    return Image.fromarray(thresh)

In [34]:
def get_sudoku_cells_coordinates(cell_width):
    border_offset: Pixels = 5
    cells = []

    for row in range(9):
        for col in range(9):
            x0 = col * cell_width + border_offset
            y0 = row * cell_width + border_offset
            x1 = x0 + cell_width - border_offset
            y1 = y0 + cell_width - border_offset
            cells.append(CellCoordinates(x0, y0, x1, y1))

    return cells

In [35]:
OCR = PaddleOCR(lang='en', use_gpu=True)

[2023/12/18 13:27:38] ppocr DEBUG: Namespace(alpha=1.0, alphacolor=(255, 255, 255), benchmark=False, beta=1.0, binarize=False, cls_batch_num=6, cls_image_shape='3, 48, 192', cls_model_dir='C:\\Users\\Fadegentle/.paddleocr/whl\\cls\\ch_ppocr_mobile_v2.0_cls_infer', cls_thresh=0.9, cpu_threads=10, crop_res_save_dir='./output', det=True, det_algorithm='DB', det_box_type='quad', det_db_box_thresh=0.6, det_db_score_mode='fast', det_db_thresh=0.3, det_db_unclip_ratio=1.5, det_east_cover_thresh=0.1, det_east_nms_thresh=0.2, det_east_score_thresh=0.8, det_limit_side_len=960, det_limit_type='max', det_model_dir='C:\\Users\\Fadegentle/.paddleocr/whl\\det\\en\\en_PP-OCRv3_det_infer', det_pse_box_thresh=0.85, det_pse_min_area=16, det_pse_scale=1, det_pse_thresh=0, det_sast_nms_thresh=0.2, det_sast_score_thresh=0.5, draw_img_save_dir='./inference_results', drop_score=0.5, e2e_algorithm='PGNet', e2e_char_dict_path='./ppocr/utils/ic15_dict.txt', e2e_limit_side_len=768, e2e_limit_type='max', e2e_model

In [36]:
def recognize_digit(cell_image):
    recognition_data = OCR.ocr(np.array(cell_image), det=False, rec=True, cls=False)
    char = (recognition_data[0][0][0] if recognition_data else '').strip().replace(':', '8')
    
    if char.isdigit():
        print(recognition_data, char)
    else:
        return ''

    return char
    
    

In [37]:
def recognize_sudoku(path):
    # full_image = capture_full_screen()
    image = cv2.imread(path)
    img = crop_to_sudoku(image)
    cell_width = img.size[0] // 9
    cells_coordinates = get_sudoku_cells_coordinates(cell_width)

    digits = []
    for cell in cells_coordinates:
        cell_img = img.crop((cell.x0, cell.y0, cell.x1, cell.y1))
        digit = recognize_digit(cell_img)
        digits.append(digit)

    grid = []
    for i in range(0, len(digits), 9):
        grid.append(digits[i:i + 9])

    return grid

In [38]:
def prettier(puzzle):
    print('*-----------*')
    for r, i in enumerate(puzzle):
        if r in [3, 6, 9]:
            print('|---+---+---|')
        for c, j in enumerate(i):
            if c in [0, 3, 6]:
                print('|', end='')
            print(j if j else '.', end='')
        print('|')
    print('*-----------*')

In [39]:
for fp in fps:
    puzzle = recognize_sudoku(dir+fp)
    print(f'{fp} 识别结果为：')
    # print(puzzle)
    prettier(puzzle)


[[('6', 0.9410324692726135)]] 6
[[('4', 0.9235805869102478)]] 4
[[('7', 0.6330481171607971)]] 7
[[('7', 0.9717089533805847)]] 7
[[('6', 0.8688789010047913)]] 6
[[('9', 0.9238210320472717)]] 9
[[('5', 0.96817946434021)]] 5
[[('8', 0.908908486366272)]] 8
[[('7', 0.9093402028083801)]] 7
[[('2', 0.9338260293006897)]] 2
[[('9', 0.8680936098098755)]] 9
[[('3', 0.9880919456481934)]] 3
[[('8', 0.9728391170501709)]] 8
[[('5', 0.9840741157531738)]] 5
[[('4', 0.9827737808227539)]] 4
[[('3', 0.9924730062484741)]] 3
[[('1', 0.5723848342895508)]] 1
[[('7', 0.8570005893707275)]] 7
[[('5', 0.9940098524093628)]] 5
[[('2', 0.9972352385520935)]] 2
[[('3', 0.9947984218597412)]] 3
[[('2', 0.9907593727111816)]] 2
[[('8', 0.9599323868751526)]] 8
[[('2', 0.9890355467796326)]] 2
[[('3', 0.9817011952400208)]] 3
[[('1', 0.36261889338493347)]] 1
paper.jpg 识别结果为：
*-----------*
|...|6.4|7..|
|7.6|...|..9|
|...|..5|.8.|
|---+---+---|
|.7.|.2.|.93|
|8..|...|..5|
|43.|.1.|.7.|
|---+---+---|
|.5.|2..|...|
|3..|...|2.8|