In [1]:
# -*- coding: utf-8 -*-
# Use a convolutional neural network to classify tile images
import os

import cv2
from PIL import Image
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms

# Mapping from the CNN output index (int) to the tile name (str).
# Indices 0-26 are ranks 1-9 of the 'm', 'p' and 's' suits (in that order),
# 27-33 are '1z'-'7z', 34-36 are '0m'/'0p'/'0s' (presumably the red fives --
# TODO confirm), and 37 is the back of a tile.
classes = dict(enumerate(
    [f'{rank}{suit}' for suit in 'mps' for rank in range(1, 10)]
    + [f'{rank}z' for rank in range(1, 8)]
    + ['0m', '0p', '0s']
    + ['back']   # back side of a tile
))

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")


def CV2PIL(img):
    """Convert an OpenCV BGR image array into a PIL image in RGB channel order."""
    rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    return Image.fromarray(rgb)


# Preprocessing applied to every PIL image before it reaches the network:
# resize to 32x32, convert to a tensor, then normalise each of the three
# channels with mean 0.5 and standard deviation 0.5.
transform = transforms.Compose(
    [
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)


class TileNet(nn.Module):
    """CNN that scores a tile image against 38 classes.

    Two convolution + max-pool stages are followed by three fully connected
    layers. The first fully connected layer expects 26 * 5 * 5 features, which
    is what the two stages produce for a 3-channel 32x32 input
    (32 -> 28 -> 14 -> 10 -> 5).
    """

    def __init__(self):
        super().__init__()
        # Attribute names and creation order are kept stable so that a saved
        # state_dict keeps matching this module.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=10, kernel_size=5)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(in_channels=10, out_channels=26, kernel_size=5)
        self.fc1 = nn.Linear(in_features=26 * 5 * 5, out_features=300)
        self.fc2 = nn.Linear(in_features=300, out_features=124)
        self.fc3 = nn.Linear(in_features=124, out_features=38)

    def forward(self, x):
        """Return the raw (un-normalised) class scores for a batch of images."""
        stage1 = self.pool(F.relu(self.conv1(x)))
        stage2 = self.pool(F.relu(self.conv2(stage1)))
        # Flatten each sample's feature maps into one vector for the linear layers.
        flat = stage2.view(-1, 26 * 5 * 5)
        hidden = F.relu(self.fc2(F.relu(self.fc1(flat))))
        return self.fc3(hidden)


class Classify:
    """Callable tile classifier backed by a ``TileNet`` with saved weights.

    Calling an instance with a BGR image array returns the tile name (one of
    the values of ``classes``).
    """

    def __init__(self, model_path=None):
        """Load the network weights and run one warm-up inference.

        Args:
            model_path: Optional path to the saved state_dict. When omitted,
                the original default location is used.
        """
        self.model = TileNet()
        if model_path is None:
            # NOTE: '__file__' is a quoted string literal here, so dirname()
            # yields '' and the default path is resolved relative to the
            # current working directory (a bare __file__ is not defined
            # inside a notebook).
            model_path = os.path.join(os.path.dirname('__file__'),
                                      '../ModelTrain/recogition/tile.model')
        # If the model was trained on a GPU but is run on a CPU, the stored
        # tensors have to be mapped onto the CPU.
        self.map_location = device
        # NOTE(review): torch.load unpickles the file; only load model files
        # that come from a trusted source.
        state_dict = torch.load(model_path, map_location=self.map_location)
        self.model.load_state_dict(state_dict)
        self.model.to(device)
        # Inference only: put the module in evaluation mode.
        self.model.eval()
        # Run one dummy inference up front (the original comment calls this
        # "load cache") so that the first real call does not pay for it.
        self.__call__(np.ones((32, 32, 3), dtype=np.uint8))

    def __call__(self, img: np.ndarray):
        """Classify a single tile image.

        Args:
            img: Image array in OpenCV BGR channel order; it is converted to a
                PIL RGB image and passed through ``transform``.

        Returns:
            The tile name looked up in ``classes`` for the highest-scoring
            class.
        """
        tensor = transform(CV2PIL(img))
        # Add the batch dimension expected by the network: (C, H, W) -> (1, C, H, W).
        batch = tensor.unsqueeze(0).to(device)
        with torch.no_grad():
            scores = self.model(batch)
        tile_id = scores.argmax(dim=1)[0].item()
        return classes[tile_id]

In [2]:
# Report which device (CUDA or CPU) was selected for inference.
print(device)

cpu


In [4]:
# Recognition demo: classify a single tile image.
classify = Classify()
# A raw string keeps every backslash literal. The previous non-raw literal
# produced the same path only by relying on invalid escape sequences such as
# "\P", "\S" and "\d", which newer Python versions warn about.
# NOTE(review): this is a machine-specific absolute path.
img_path = r'D:\Project\SoulPlay\Data\recogition\data0\7m\be035502-c84f-11ec-9335-e0d55e4c11ff.png'
img = cv2.imread(img_path)
if img is None:
    # cv2.imread signals an unreadable file by returning None instead of raising.
    raise FileNotFoundError(f"Could not read image: {img_path}")

print(classify(img))

7m


In [18]:
import os
import cv2
import numpy as np

class MahjongProcessor:
    """Locate mahjong tiles in an image, classify each one and show the annotated result."""

    def __init__(self):
        # Initialise the tile classifier.
        self.classifier = Classify()

    def get_mahjongs_position(self, img, min_area=300):
        """Segment the tile regions of a BGR image.

        Args:
            img: Image array as returned by ``cv2.imread`` (3 channels).
            min_area: Contours whose area is below this value are ignored.

        Returns:
            A list of (4, 2) int64 arrays, one per detected tile, holding the
            (x, y) corner points of the contour's minimum-area rectangle.
        """
        mask = np.zeros(img.shape[:2], dtype=np.uint8)

        # Colour thresholds (tune for the actual screen): a pixel is kept when
        # all three channels fall into the same band, for the bands
        # [190, 210], [210, 230] and [230, 250].
        for i in range(19, 25, 2):
            low = i * 10
            high = low + 20
            band = cv2.inRange(img, (low, low, low), (high, high, high))
            # The band limits are inclusive and therefore overlap; OR-ing keeps
            # the mask strictly 0/255 instead of relying on uint8 wraparound.
            mask = cv2.bitwise_or(mask, band)

        # Morphological closing to clean up the segmentation.
        kernel = np.ones((3, 3), np.uint8)
        mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel, iterations=2)

        # Find contours.
        contours, hierarchy = cv2.findContours(mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
        if hierarchy is None:
            # No contour was found at all.
            return []

        boxes = []
        for contour, relations in zip(contours, hierarchy[0]):
            # relations[3] is the parent contour index; keep only outermost
            # contours (no parent) that are large enough.
            if relations[3] != -1 or cv2.contourArea(contour) < min_area:
                continue
            rect = cv2.minAreaRect(contour)
            boxes.append(np.int64(cv2.boxPoints(rect)))
        return boxes

    @staticmethod
    def _crop_tile(img, box, padding):
        """Cut the axis-aligned bounding rectangle of ``box`` out of ``img``.

        The rectangle is shrunk by ``padding`` pixels on every side and
        clamped to the image, because the corner points of a rotated
        rectangle may lie outside it (negative indices would otherwise wrap
        around when slicing).

        Returns:
            ``(tile, (x, y))`` where ``(x, y)`` is the box's top-left corner
            as plain non-negative ints, or ``None`` when the crop is empty.
        """
        height, width = img.shape[:2]
        xs = [int(point[0]) for point in box]
        ys = [int(point[1]) for point in box]
        x_min, y_min = min(xs), min(ys)

        left = min(max(x_min + padding, 0), width)
        right = min(max(max(xs) - padding, 0), width)
        top = min(max(y_min + padding, 0), height)
        bottom = min(max(max(ys) - padding, 0), height)

        tile = img[top:bottom, left:right]
        if tile.size == 0:
            return None
        return tile, (max(x_min, 0), max(y_min, 0))

    def process_image(self, img_path, padding=1):
        """Run the whole pipeline on one image file and display the result.

        Reads the image, segments the tiles, classifies every tile, writes
        each predicted name at the tile's top-left corner and shows the
        annotated image in an OpenCV window until a key is pressed.

        Args:
            img_path: Path of the image to process.
            padding: Number of pixels trimmed from every side of a tile's
                bounding rectangle before classification.

        Returns:
            None. If the image cannot be read a message is printed and the
            method returns early.
        """
        img = cv2.imread(img_path)
        if img is None:
            print("无法读取图像")
            return

        # Step 1: segment the tile regions.
        boxes = self.get_mahjongs_position(img)

        # Step 2: crop and classify every tile. Labels are drawn only after
        # this loop so that they never end up inside a later crop.
        results = []
        for box in boxes:
            cropped = self._crop_tile(img, box, padding)
            if cropped is None:
                continue
            tile_img, origin = cropped

            # Bring the crop to a fixed 40x80 (width x height) size; the
            # classifier's own transform resizes it again before inference.
            tile_img = cv2.resize(tile_img, (40, 80))

            tile_name = self.classifier(tile_img)
            results.append((origin, tile_name))

        # Step 3: annotate the image with the predicted names.
        for pos, name in results:
            cv2.putText(img, name, pos, cv2.FONT_HERSHEY_SIMPLEX,
                        0.7, (0, 0, 255), 2)

        # Show the result.
        cv2.imshow("Result", img)
        cv2.waitKey(0)
        cv2.destroyAllWindows()


In [19]:
if __name__ == "__main__":
    # Build the processor (which constructs a Classify instance) and run the
    # segment / classify / annotate pipeline on the sample image.
    processor = MahjongProcessor()
    processor.process_image("./test/spilt1.png")

In [None]:
# 输出图片分割结果
