In [1]:
# -*- coding: utf-8 -*-
# Use a convolutional neural network to classify tile images
import os

import cv2
from PIL import Image
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms

# Mapping from the CNN's output index (int) to the tile name (str).
#   0-26  : ranks 1-9 of the m, p and s suits (in that suit order)
#   27-33 : honour tiles 1z-7z
#   34-36 : 0m, 0p, 0s (presumably the red fives, in the usual 0 = red-5 notation)
#   37    : 'back' -- the back face of a tile
classes = dict(enumerate(
    [f"{rank}{suit}" for suit in "mps" for rank in range(1, 10)]
    + [f"{rank}z" for rank in range(1, 8)]
    + [f"0{suit}" for suit in "mps"]
    + ["back"]
))

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


def CV2PIL(img):
    """Convert an OpenCV image (BGR channel order ndarray) into an RGB PIL Image."""
    rgb_array = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    return Image.fromarray(rgb_array)


# Preprocessing applied to every tile before it is fed to TileNet:
# resize to the 32x32 input the network expects, convert to a float tensor
# in [0, 1], then shift/scale each RGB channel to roughly [-1, 1].
transform = transforms.Compose(
    [
        transforms.Resize(size=(32, 32)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)


class TileNet(nn.Module):
    """Small LeNet-style CNN: a 3x32x32 tile image -> 38 class logits.

    The attribute names (conv1, pool, conv2, fc1, fc2, fc3) are part of the
    contract: they become the keys of the state_dict that ``Classify`` loads
    from disk, so they must not be renamed.
    """

    def __init__(self):
        super().__init__()
        # 3x32x32 -> conv(5) -> 10x28x28 -> pool -> 10x14x14
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=10, kernel_size=5)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        # 10x14x14 -> conv(5) -> 26x10x10 -> pool -> 26x5x5
        self.conv2 = nn.Conv2d(in_channels=10, out_channels=26, kernel_size=5)
        # Fully connected head on the flattened 26*5*5 feature map.
        self.fc1 = nn.Linear(in_features=26 * 5 * 5, out_features=300)
        self.fc2 = nn.Linear(in_features=300, out_features=124)
        self.fc3 = nn.Linear(in_features=124, out_features=38)

    def forward(self, x):
        """Return raw (un-normalised) class scores for a batch of images."""
        for conv in (self.conv1, self.conv2):
            x = self.pool(F.relu(conv(x)))
        x = x.view(-1, 26 * 5 * 5)
        for fc in (self.fc1, self.fc2):
            x = F.relu(fc(x))
        return self.fc3(x)


class Classify:
    """Tile-image classifier backed by a trained ``TileNet``.

    Calling an instance with a BGR image (as returned by ``cv2.imread``)
    returns the predicted label from ``classes`` (e.g. ``'1m'`` or ``'back'``).
    """

    # Default weights file. It is resolved relative to the current working
    # directory: there is no reliable ``__file__`` inside a notebook, and the
    # earlier ``os.path.dirname('__file__')`` (a quoted string) always
    # evaluated to '' anyway.
    DEFAULT_MODEL_PATH = '../ModelTrain/recogition/tile.model'

    def __init__(self, model_path=None):
        """Load the trained weights onto ``device`` and run one warm-up call.

        :param model_path: path to the saved TileNet ``state_dict``; defaults
            to ``DEFAULT_MODEL_PATH``.
        """
        if model_path is None:
            model_path = self.DEFAULT_MODEL_PATH
        self.model = TileNet()
        # map_location lets weights that were saved on a GPU be loaded on a
        # CPU-only machine (and vice versa).
        self.map_location = device
        # NOTE(review): consider torch.load(..., weights_only=True) (PyTorch >= 1.13)
        # so that a tampered model file cannot execute code when it is loaded.
        state_dict = torch.load(model_path, map_location=self.map_location)
        self.model.load_state_dict(state_dict)
        self.model.to(device)
        # Inference only: switch the network to evaluation mode.
        self.model.eval()
        # One dummy prediction up front (originally labelled "load cache").
        # Presumably this makes one-time initialisation happen here rather than
        # on the first real call.
        self.__call__(np.ones((32, 32, 3), dtype=np.uint8))

    def __call__(self, img: np.ndarray) -> str:
        """Classify a single tile image.

        :param img: BGR image array, as returned by ``cv2.imread``.
        :return: the predicted tile name from ``classes``.
        :raises ValueError: if ``img`` is None (e.g. ``cv2.imread`` failed).
        """
        if img is None:
            raise ValueError("img is None (was the image read successfully?)")
        batch = transform(CV2PIL(img)).unsqueeze(0).to(device)  # (1, C, H, W)
        with torch.no_grad():
            _, predicted = torch.max(self.model(batch), 1)
        return classes[predicted[0].item()]

In [2]:
# Show which device (CUDA GPU or CPU) the classifier will run on.
print(device)

cpu


In [None]:
# Recognise a single sample tile image.
classify = Classify()
# Raw string: in a normal string literal, sequences like "\P" and "\o" are
# invalid escapes (a SyntaxWarning on Python 3.12+). This is the same path.
img_path = r'D:\Project\SoulPlay\Data\recogition\output_final\self_discard_Onphone\0.png'
img = cv2.imread(img_path)
# cv2.imread returns None instead of raising when the file is missing or unreadable.
if img is None:
    raise FileNotFoundError(f"Could not read image: {img_path}")

print(classify(img))

1p


In [13]:
import os
import cv2
import numpy as np
from tqdm import tqdm
from concurrent.futures import ThreadPoolExecutor

class BatchClassifier:
    """Run the tile classifier over many image files using a thread pool."""

    def __init__(self):
        # One shared classifier instance, used by every worker thread.
        self.classifier = Classify()

    def process_single_image(self, img_path):
        """Classify one image file; never raises.

        Intended to be called from worker threads that share
        ``self.classifier``. NOTE(review): this relies on concurrent
        inference with the same model being safe; confirm for the
        torch/OpenCV versions in use.

        :param img_path: path of the image file.
        :return: ``(filename, tile_name)`` on success, otherwise
            ``(filename, "error: <message>")``.
        """
        filename = os.path.basename(img_path)
        try:
            img = cv2.imread(img_path)
            if img is None:
                return filename, "error: 无法读取图像"
            return filename, self.classifier(img)
        except Exception as e:
            return filename, f"error: {str(e)}"

    def process_folder(self, input_folder, output_file="results.csv", max_workers=4):
        """Classify every image in a folder using a thread pool and write a CSV report.

        :param input_folder: folder to scan (not recursive). Only regular
            files ending in .jpg/.jpeg/.png/.bmp/.webp (any case) are processed.
        :param output_file: CSV file to write. Its columns are
            ``filename,tile_name`` and rows are sorted by filename.
        :param max_workers: maximum number of worker threads.
        """
        # Local imports keep this notebook cell self-contained.
        import csv
        from concurrent.futures import as_completed

        image_exts = ('.jpg', '.jpeg', '.png', '.bmp', '.webp')
        image_files = [
            os.path.join(input_folder, name)
            for name in os.listdir(input_folder)
            if name.lower().endswith(image_exts)
            and os.path.isfile(os.path.join(input_folder, name))
        ]

        results = []
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            futures = [executor.submit(self.process_single_image, path) for path in image_files]
            # as_completed advances the progress bar whenever any image finishes,
            # instead of blocking on futures in submission order.
            for future in tqdm(
                as_completed(futures),
                total=len(futures),
                desc="Processing Images",
                unit="image",
            ):
                try:
                    results.append(future.result())
                except Exception as e:
                    print(f"处理任务时发生未捕获的异常: {str(e)}")

        results.sort(key=lambda row: row[0])  # sort by filename
        # csv.writer quotes fields that contain commas, quotes or newlines
        # (e.g. exception text in "error: ..." rows), so every row stays
        # well-formed. newline='' is required by the csv module.
        with open(output_file, 'w', encoding='utf-8', newline='') as f:
            writer = csv.writer(f)
            writer.writerow(["filename", "tile_name"])
            writer.writerows(results)

        print(f"处理完成！结果已保存至 {output_file}")



In [22]:
# Usage example: classify every image in a folder and write a CSV report.
if __name__ == "__main__":
    batch_classifier = BatchClassifier()
    batch_classifier.process_folder(
        # Raw string: backslashes in Windows paths must not be read as escape
        # sequences (e.g. "\P", "\o" are invalid escapes). This is the same path.
        input_folder=r"D:\Project\SoulPlay\Data\recogition\output_final\hand_tiles_normal",
        output_file="./results.csv",
        max_workers=4,  # tune to the machine's capacity
    )

Processing Images: 100%|██████████| 11/11 [00:00<00:00, 524.19image/s]

处理完成！结果已保存至 ./results.csv





In [32]:
import os
import json
import cv2
import numpy as np
from tqdm import tqdm
from pathlib import Path
from concurrent.futures import ThreadPoolExecutor
from typing import List, Dict

class GameStateGenerator(BatchClassifier):
    """Build a game-state dict (hand tiles and dora) from cropped tile images."""

    def __init__(self,
                 hand_tiles_dir=r"../Data/recogition/output_final/phone1_hand_tiles",
                 dora_indicator_dir=r"../Data/recogition/output_final/phone1_dora_indicator"):
        """
        :param hand_tiles_dir: folder containing one image per hand tile.
        :param dora_indicator_dir: folder containing dora-indicator images;
            the most recently modified file is used.
        """
        super().__init__()
        self.hand_tiles_dir = hand_tiles_dir
        self.dora_indicator_dir = dora_indicator_dir

    @staticmethod
    def _natural_sort_key(path: str):
        """Sort key that compares digit runs numerically, so "2.png" < "10.png"."""
        import re
        name = os.path.basename(path)
        # With a capturing group, re.split puts the digit runs at odd indices.
        # int/str elements therefore line up position by position when compared.
        parts = re.split(r"(\d+)", name)
        return [int(part) if i % 2 else part for i, part in enumerate(parts)], name

    def get_sorted_hand_tiles(self) -> List[str]:
        """Return hand-tile image paths sorted by file name in natural order.

        Plain string sorting would put "10.png" before "2.png" and scramble
        the order of a hand with ten or more tiles.
        """
        paths = [str(p) for p in Path(self.hand_tiles_dir).glob("*") if p.is_file()]
        return sorted(paths, key=self._natural_sort_key)

    def process_hand_tiles(self, max_workers: int = 4) -> List[str]:
        """Classify the hand-tile images in parallel, keeping file order.

        :param max_workers: maximum number of worker threads.
        :return: recognised tile names in sorted file order. Images that fail
            ("error..." results) or are classified as "back" are dropped.
        """
        hand_images = self.get_sorted_hand_tiles()
        valid_tiles = []

        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            futures = [executor.submit(self.process_single_image, path) for path in hand_images]
            # Iterate in submission order (not as_completed) so the result
            # keeps the tiles' order.
            for future in tqdm(futures, desc="识别手牌", unit="张"):
                _filename, tile_name = future.result()
                if "error" not in tile_name and tile_name != "back":
                    valid_tiles.append(tile_name)

        return valid_tiles

    def get_dora_indicator_path(self) -> str:
        """Return the most recently modified file in the dora-indicator folder, or "" if there is none."""
        files = [p for p in Path(self.dora_indicator_dir).glob("*") if p.is_file()]
        newest = max(files, key=lambda p: p.stat().st_mtime, default=None)
        return str(newest) if newest is not None else ""

    def calculate_real_dora(self, indicator_tile: str) -> str:
        """Map a dora indicator tile to the actual dora tile.

        - Suited tiles (m/p/s) advance one rank, and 9 wraps to 1. Rank 0
          (0m/0p/0s, the red fives in this notation; 0 is not otherwise a
          valid rank) counts as a 5, so its dora is the 6 of the same suit.
        - Winds cycle 1z -> 2z -> 3z -> 4z -> 1z.
        - Dragons cycle 5z -> 6z -> 7z -> 5z.

        :return: the dora tile name, or "unknown" for empty, "back" or
            unrecognised input.
        """
        if not indicator_tile or indicator_tile == "back":
            return "unknown"

        # Split "<rank><suit>", e.g. "9m" -> ("9", "m").
        num_str, tile_type = indicator_tile[:-1], indicator_tile[-1]
        if tile_type not in ("m", "p", "s", "z"):
            return "unknown"
        try:
            num = int(num_str)
        except ValueError:
            return "unknown"

        if tile_type in ("m", "p", "s"):  # suited tiles
            if num == 0:  # red five behaves like an ordinary 5
                num = 5
            if 1 <= num <= 9:
                return f"{num % 9 + 1}{tile_type}"
            return "unknown"

        # Honour tiles (z).
        wind_order = {1: 2, 2: 3, 3: 4, 4: 1}    # East -> South -> West -> North -> East
        dragon_order = {5: 6, 6: 7, 7: 5}        # White -> Green -> Red -> White
        if num in wind_order:
            return f"{wind_order[num]}z"
        if num in dragon_order:
            return f"{dragon_order[num]}z"
        return "unknown"

    def recognize_dora(self) -> List[str]:
        """Classify the newest dora-indicator image and return the dora.

        :return: a one-element list with the dora tile, or [] if no image is
            found, it cannot be read, it is unrecognised, or classification fails.
        """
        dora_path = self.get_dora_indicator_path()
        if not dora_path:
            return []

        try:
            img = cv2.imread(dora_path)
            if img is None:
                return []

            indicator_tile = self.classifier(img)
            real_dora = self.calculate_real_dora(indicator_tile)
            return [real_dora] if real_dora != "unknown" else []
        except Exception as e:
            print(f"宝牌识别失败: {str(e)}")
            return []

    def generate_game_state(self) -> Dict:
        """Assemble the game-state dict (hand tiles are recognised first, then the dora)."""
        return {
            "id": -1,
            "state": "GameStart",
            "seatList": [1, 2, 3, 17457800],  # placeholder: replace with real game data
            "tiles": self.process_hand_tiles(),
            "doras": self.recognize_dora(),  # actual dora, not the indicator
        }

    def save_game_state(self, output_path: str):
        """Generate the game state and write it to ``output_path`` as UTF-8 JSON."""
        game_state = self.generate_game_state()

        with open(output_path, 'w', encoding='utf-8') as f:
            json.dump(game_state, f, indent=2, ensure_ascii=False)

        print(f"游戏状态已保存至：{os.path.abspath(output_path)}")

# Usage example: sanity-check the dora rule, then build and save the game state.
if __name__ == "__main__":
    generator = GameStateGenerator()

    # Expected output, one per line: 1m, 6s, 1z
    for indicator in ("9m", "5s", "4z"):
        print(generator.calculate_real_dora(indicator))

    generator.save_game_state(output_path="game_state.json")

1m
6s
1z


识别手牌: 100%|██████████| 10/10 [00:00<00:00, 7259.09张/s]

游戏状态已保存至：d:\Project\SoulPlay\IMGProcess\game_state.json



