# In [None]:
from time import sleep
import sys
import cv2
import win32gui
import csv
import os
import glob
import re
import mysql.connector
from mysql.connector import Error
import difflib
from datetime import datetime
from PIL import ImageGrab, ImageEnhance
from paddleocr import PaddleOCR

# Module-level PaddleOCR engine used by split_image_into_regions_reverse().
# It is configured for Chinese text and loads the PP-OCRv4 mobile models
# from the local model directories given below.
ocr = PaddleOCR(
    use_doc_orientation_classify=False,
    use_doc_unwarping=False,
    use_textline_orientation=False,
    text_det_box_thresh=0.6,
    lang='ch',
    # Use the lightweight models
    text_detection_model_name='PP-OCRv4_mobile_det',  # detection model
    text_recognition_model_name='PP-OCRv4_mobile_rec',  # recognition model
    text_detection_model_dir='C:\\Users\\16528\\.paddlex\\official_models\\PP-OCRv4_mobile_det',  # detection model
    text_recognition_model_dir='C:\\Users\\16528\\.paddlex\\official_models\\PP-OCRv4_mobile_rec',  # recognition model
)

custom_dict_path = "C:\\Users\\16528\\Desktop\\custom_words.txt"  # path of the custom word list
# ===== Configuration =====
TARGET_TITLE = "短线精灵"  # title of the target window
SAVE_FILE = "C:\\Users\\16528\\PycharmProjects\\PythonProject\\save_file"

INTERVAL = 5  # seconds between capture rounds
DEBUG_DIR = "debug_images"
GRID_ROWS = 1  # number of rows (1 row); not referenced in the visible code
GRID_COLS = 5  # number of columns the cropped screenshot is split into (5)
GRID_PADDING = 0  # pixel padding between regions; not referenced in the visible code
CROP_TOP = 60  # top edge of the vertical crop
CROP_BOTTOM = 1008  # bottom edge of the vertical crop
CROP_LEFT = 0  # left edge of the horizontal crop
CROP_RIGHT = 1920  # right edge of the horizontal crop
# Database connection parameters
host = "localhost"
database = "public"
user = "root"
password = "root"

# Current system date, used as the name of the per-day output directory
date_str = datetime.now().strftime("%Y%m%d")

# Create the per-day sub-directory.
# NOTE: save_path and folder_path are built from the same arguments, so they
# point at the same directory.
save_path = os.path.join(SAVE_FILE, date_str)
folder_path = os.path.join(SAVE_FILE, date_str)
os.makedirs(save_path, exist_ok=True)
os.makedirs(folder_path, exist_ok=True)

# Paths of the raw record files
TXT_FILE = os.path.join(save_path, "大笔买入.txt")
CSV_FILE = os.path.join(save_path, "大笔买入.csv")
# Paths of the statistics report and of the previous "big amount" snapshot
output_path = os.path.join(folder_path, "统计结果.txt")
last_file_path=os.path.join(folder_path, "last_big_amount.txt")

def find_window_rect(title):
    """Look up a top-level window by its exact title.

    :param title: window title passed to ``win32gui.FindWindow``.
    :return: the value of ``win32gui.GetWindowRect`` for the window (per the
        pywin32 documentation a ``(left, top, right, bottom)`` tuple), or
        ``None`` after printing a message when no window handle is found.
    """
    handle = win32gui.FindWindow(None, title)
    if handle != 0:
        return win32gui.GetWindowRect(handle)
    print(f"❌ 找不到窗口: {title}")
    return None

# ====== Load the custom word list ======
# One entry per line of the file; each entry is stripped of surrounding
# whitespace and blank entries are dropped.
with open(custom_dict_path, "r", encoding="utf-8") as f:
    custom_words = [w.strip() for w in f.read().splitlines() if w.strip()]

def import_csv_to_mysql(all_rows_csv, host, database, user, password):
    """Insert OCR rows into the ``short_term_elf`` MySQL table.

    Each row supplies the four placeholders of the INSERT statement
    (time, stock_name, description, amount); ``trade_date`` is filled in by
    MySQL's ``CURDATE()``. Empty-string and ``None`` values are sent as NULL.
    Database errors are printed, not raised.

    :param all_rows_csv: iterable of 4-item rows.
    :param host: MySQL host name.
    :param database: schema name.
    :param user: login name.
    :param password: login password.
    """
    # Pre-bind both handles so the cleanup below is safe even when connect()
    # itself fails (previously that raised UnboundLocalError in `finally`).
    connection = None
    cursor = None
    try:
        connection = mysql.connector.connect(
            host=host,
            database=database,
            user=user,
            password=password
        )

        if connection.is_connected():
            cursor = connection.cursor()
            insert_query = """INSERT INTO short_term_elf (
                time, stock_name, description, amount,trade_date
            ) VALUES (%s, %s, %s, %s, CURDATE() )"""

            inserted = 0
            for row in all_rows_csv:
                # Map empty values to NULL
                processed_row = [
                    None if value == '' or value is None else value
                    for value in row
                ]
                cursor.execute(insert_query, processed_row)
                # rowcount only describes the last statement, so accumulate it
                # to report the real total for the whole batch.
                inserted += cursor.rowcount

            connection.commit()
            print(f"成功导入 {inserted} 条记录到数据库")

    except Error as e:
        print(f"数据库错误: {e}")
    finally:
        if cursor is not None:
            cursor.close()
        if connection is not None and connection.is_connected():
            connection.close()
            print("MySQL连接已关闭")


def preprocess_image(img):
    """Prepare a PIL image for OCR of small text.

    The image is converted to mode ``"L"`` (greyscale), its contrast is
    enhanced by a factor of 1.7 and its sharpness by a factor of 1.3.

    :param img: PIL image.
    :return: the processed PIL image.
    """
    grayscale = img.convert("L")
    contrasted = ImageEnhance.Contrast(grayscale).enhance(1.7)
    return ImageEnhance.Sharpness(contrasted).enhance(1.3)

def load_existing_data(file_path):
    """Return the set of non-blank, whitespace-stripped lines of a text file.

    A missing file yields an empty set. The file is decoded as ``utf-8-sig``,
    so a leading BOM is ignored.
    """
    if not os.path.exists(file_path):
        return set()
    with open(file_path, 'r', encoding='utf-8-sig') as handle:
        stripped_lines = (raw.strip() for raw in handle)
        return {text for text in stripped_lines if text}

def save_unique_txt(rows):
    """Append to ``TXT_FILE`` every row that is not already stored in it.

    :param rows: iterable of text lines (without trailing newline).
    :return: number of lines appended (0 when nothing was new).
    """
    already_saved = load_existing_data(TXT_FILE)
    pending = [line for line in rows if line not in already_saved]
    if pending:
        with open(TXT_FILE, 'a', encoding='utf-8') as out:
            out.writelines(line + "\n" for line in pending)
    return len(pending)

def _csv_line_key(row):
    """Return *row* as ``csv.writer`` stores it, stripped like a loaded line."""
    import io  # local import: the file's import header does not provide io

    buffer = io.StringIO()
    csv.writer(buffer).writerow(row)
    return buffer.getvalue().strip()


def save_unique_csv(rows):
    """Append new rows to ``CSV_FILE`` and import them into MySQL.

    A header row is written first when the file does not exist yet. A row is
    new when its CSV-serialised form is not among the lines already in the
    file. Comparing the serialised form (not a plain ``",".join``) keeps the
    check correct for fields that csv quotes, such as values containing a
    comma; previously such rows were re-appended on every scan.

    :param rows: iterable of rows (sequences of field values).
    :return: number of rows appended.
    """
    existing = load_existing_data(CSV_FILE)
    is_new_file = not os.path.exists(CSV_FILE)
    unique_rows = []  # rows actually appended in this call
    with open(CSV_FILE, 'a', newline='', encoding='utf-8-sig') as f:
        writer = csv.writer(f)
        if is_new_file:
            writer.writerow(["时间", "股票", "描述", "金额"])
        for row in rows:
            if _csv_line_key(row) not in existing:
                writer.writerow(row)
                unique_rows.append(row)
    # Import after the file has been closed so the rows are on disk first.
    if unique_rows:
        import_csv_to_mysql(unique_rows, host, database, user, password)
    return len(unique_rows)

def sort_rows_by_time(txt_rows, csv_rows):
    """Sort both row representations by their ``HH:MM:SS`` timestamp.

    :param txt_rows: lines such as ``"09:31:25| stock | description | amount"``;
        the timestamp is the text before the first ``|``.
    :param csv_rows: rows such as ``["09:31:25", "stock", "description", "amount"]``;
        the timestamp is the first field.
    :return: ``(sorted_txt_rows, sorted_csv_rows)`` as new lists.
    :raises ValueError: if a timestamp does not match ``%H:%M:%S``.
    """
    def _clock(text):
        return datetime.strptime(text.strip(), "%H:%M:%S")

    ordered_txt = list(txt_rows)
    ordered_txt.sort(key=lambda line: _clock(line.split("|")[0]))

    ordered_csv = list(csv_rows)
    ordered_csv.sort(key=lambda fields: _clock(fields[0]))

    return ordered_txt, ordered_csv

import akshare as ak

def get_market_cap_map(df):
    """Build a ``{name: circulating market cap / 10000}`` dictionary.

    :param df: DataFrame with the columns ``"名称"`` and ``"流通市值"``.
    :return: dict mapping each ``"名称"`` value to its ``"流通市值"`` value
        divided by 10000 (the original comment describes this as a
        conversion to units of 10,000 yuan).
    """
    scaled_caps = df["流通市值"] / 10000
    return dict(zip(df["名称"], scaled_caps))



def match_ocr_lines_to_dict(ocr_lines: list, custom_words: list, cutoff=0.6):
    """Fuzzy-match every OCR text line against the custom word list.

    :param ocr_lines: list of OCR result strings, one per line.
    :param custom_words: list of known words to snap to.
    :param cutoff: similarity threshold in [0, 1]; higher is stricter.
    :return: list of the same length; each item is the closest word from
        ``custom_words`` or, when nothing reaches the cutoff, the original text.
    :raises ValueError: if ``ocr_lines`` is not a list.
    """
    if not isinstance(ocr_lines, list):
        raise ValueError("ocr_lines 必须是列表类型")

    def _best_or_original(text):
        candidates = difflib.get_close_matches(text, custom_words, n=1, cutoff=cutoff)
        return candidates[0] if candidates else text

    return [_best_or_original(text) for text in ocr_lines]

def split_image_into_regions_reverse(cropped):
    """OCR the cropped screenshot region by region, right to left, and save new rows.

    The image is split into ``GRID_COLS`` vertical regions. Each region is cut
    into four equal-width columns (time, stock, description, amount). Every
    column is written to a debug PNG in the current working directory and
    that same file is passed to the OCR engine. Rows whose description
    contains ``卖出`` get a leading minus sign on the amount.

    Scanning stops early once any collected row is already present in
    ``TXT_FILE``. Collected rows are then sorted by time and handed to
    ``save_unique_txt`` / ``save_unique_csv``.

    :param cropped: image array indexed as ``[rows, columns, ...]``.
    :return: ``False`` when a region's four columns yield different line
        counts (nothing is saved in that case), otherwise ``None``.
    """
    width = cropped.shape[1]
    col_width = width // GRID_COLS
    found_data = False
    all_rows_txt = []
    all_rows_csv = []

    # Debug file per column, in time / stock / description / amount order.
    # Relative names are kept on purpose: the file that is written is exactly
    # the file that is OCR'd, whatever the working directory is.
    debug_files = ("col1_time.png", "col2_stock.png", "col3_desc.png", "col4_money.png")

    # Walk the regions from right to left
    for i in reversed(range(GRID_COLS)):
        start_col = i * col_width
        end_col = width if i == GRID_COLS - 1 else (i + 1) * col_width
        region = cropped[:, start_col:end_col]

        # Split the region into four columns
        w = region.shape[1]
        bounds = [0, int(w * 0.25), int(w * 0.50), int(w * 0.75), w]
        recognised = []
        for k, file_name in enumerate(debug_files):
            cv2.imwrite(file_name, region[:, bounds[k]:bounds[k + 1]])
            recognised.append(ocr.predict(file_name)[0]['rec_texts'])

        time_lines, stock_lines, desc_lines, money_lines = recognised
        stock_lines = match_ocr_lines_to_dict(stock_lines, custom_words, 0.7)

        # All four columns are indexed in lockstep below, so all four
        # (including the description column) must have the same line count.
        if not (len(time_lines) == len(stock_lines) == len(desc_lines) == len(money_lines)):
            print(f"错误：第 {i} 区域的四列的行数不一致！")
            return False

        rows = len(time_lines)
        if rows == 0 or (rows == 1 and time_lines[0].strip() == ""):
            continue  # no data here, keep scanning towards the left
        found_data = True

        # Assemble the rows
        for j in range(rows):
            money_val = money_lines[j]
            # Sell records are stored as negative amounts
            if "卖出" in desc_lines[j] and not money_val.startswith("-"):
                money_val = "-" + money_val

            time_val = time_lines[j].replace(" ", "")
            # Replaces the character 人 with 入 in the description
            # (presumably a recurring OCR misread).
            desc_val = desc_lines[j].replace("人", "入")
            all_rows_txt.append(f"{time_val} | {stock_lines[j]} | {desc_val} | {money_val}")
            all_rows_csv.append([time_val, stock_lines[j], desc_val, money_val])

        # Stop as soon as already-saved data shows up
        existing_txt = load_existing_data(TXT_FILE)
        if any(r in existing_txt for r in all_rows_txt):
            print(f"⚠️ 第 {i} 区域检测到重复数据 → 停止本轮扫描")
            break
        print(f"本轮扫描第 {i} 区域检测已完成")

    if not found_data:
        print("❌ 本轮扫描未识别到任何数据")
    else:
        # Sort by time before saving
        all_rows_txt, all_rows_csv = sort_rows_by_time(all_rows_txt, all_rows_csv)
        added_txt = save_unique_txt(all_rows_txt)
        added_csv = save_unique_csv(all_rows_csv)
        print(f"💾 本轮新增 TXT {added_txt} 行, CSV {added_csv} 行（已按时间排序）")

def output_big_amount_from_file(market_cap_map, file_path, threshold=3000, save_path="last_big_amount.txt"):
    """Report stocks whose total amount exceeds ``threshold`` and compare with the last run.

    Lines of ``file_path`` containing ``成交总金额为<number>万元`` are scanned.
    For each total above ``threshold`` the share of the stock's market-cap
    value is computed as ``total / market_cap_map[name] * 100`` (rounded to
    two decimals); only shares above 0.3 are kept. Results are printed sorted
    by that share, descending, with markers showing the change against the
    amounts stored in ``save_path``. ``save_path`` is then overwritten with
    the current ``name,amount,pct`` lines.

    Stocks without a usable entry in ``market_cap_map`` (missing, zero or
    NaN) are skipped instead of aborting the whole report.

    :param market_cap_map: dict mapping stock name to its market-cap value.
    :param file_path: statistics text file to scan.
    :param threshold: minimum total amount (exclusive).
    :param save_path: file holding the previous result; rewritten on each call.
    """
    # 1. Load the amounts of the previous run
    last_result = {}
    if os.path.exists(save_path):
        with open(save_path, "r", encoding="utf-8") as f:
            for line in f:
                parts = line.strip().split(",")
                if len(parts) != 3:
                    continue
                name, amount, _ = parts
                try:
                    last_result[name] = float(amount)
                except ValueError:
                    continue

    # 2. Collect the stocks above the threshold
    big_amount_stocks = []
    with open(file_path, "r", encoding="utf-8") as f:
        for line in f:
            match_amount = re.search(r"成交总金额为([\d.]+)万元", line)
            if not match_amount:
                continue
            try:
                total_amount = float(match_amount.group(1))
            except ValueError:
                continue  # e.g. "1.2.3" matches the pattern but is not a number
            if total_amount <= threshold:
                continue
            # The stock name is the first run of characters at the start of
            # the line that contains no whitespace and no comma.
            match_name = re.match(r"([^\s，,]+)", line.strip())
            if not match_name:
                continue
            stock_name = match_name.group(1)
            market_cap = market_cap_map.get(stock_name)
            # Skip unknown names and unusable caps (None, 0, NaN): OCR'd
            # names are not guaranteed to exist in the market data.
            if not market_cap or market_cap != market_cap:
                continue
            pct = round(total_amount / market_cap * 100, 2)
            if pct > 0.3:
                big_amount_stocks.append((stock_name, total_amount, pct))

    # Sort by percentage, descending
    big_amount_stocks.sort(key=lambda x: x[2], reverse=True)

    # 3. Print the report
    if big_amount_stocks:
        print(f"\n💰 成交总金额 > {threshold} 万元的股票：")
        formatted = []
        printStockName = []
        for name, amount, pct in big_amount_stocks:
            if name in last_result:
                cz = amount - last_result[name]
                if amount > last_result[name]:
                    amount = f"{amount}🔺🔺🔺{cz}"
                elif amount < last_result[name]:
                    amount = f"{amount}🔽🔽🔽{cz}"
                elif amount == last_result[name]:
                    amount = f"{amount}➖"
            else:
                amount = f"{amount}🔺"   # not present in the previous run
            if pct > 0.7:
                pct = f"{pct}%❤️"
            formatted.append(f"{name}：{amount}: {pct}")
            printStockName.append(f"{name}")
        print("  |".join(formatted))
        print(f"\n 满足条件的股票名称{printStockName}")
    else:
        print(f"\n没有成交总金额大于 {threshold} 万元的股票")

    # 4. Persist the current result, overwriting the previous one
    with open(save_path, "w", encoding="utf-8") as f:
        for name, amount, pct in big_amount_stocks:
            f.write(f"{name},{amount},{pct}\n")

    sys.stdout.flush()  # make sure the report is flushed

def _tj_minute_of_day(time_str):
    """Return ``hour * 60 + minute`` from the first two ``:`` fields, or None."""
    parts = time_str.split(":")
    try:
        return int(parts[0]) * 60 + int(parts[1])
    except (IndexError, ValueError):
        return None


def _tj_parse_amount(amount_str):
    """Parse a signed amount: ints stay ints, decimals become floats, garbage gives None."""
    for convert in (int, float):
        try:
            return convert(amount_str)
        except ValueError:
            continue
    return None


def _tj_total(values):
    """Sum *values* and round away binary float noise (int sums stay int)."""
    return round(sum(values), 4)


def _tj_period_stats(details, start_min, end_min):
    """Return (count, net, buy count, buy amount, sell count, sell amount) for a minute window."""
    amounts = [a for minute, _, a in details if start_min <= minute <= end_min]
    buys = [a for a in amounts if a > 0]
    sells = [abs(a) for a in amounts if a < 0]
    return (len(amounts), _tj_total(amounts),
            len(buys), _tj_total(buys),
            len(sells), _tj_total(sells))


def tj():
    """Aggregate the day's records per stock and write the statistics report.

    Every ``*.txt`` file in ``folder_path`` is scanned for lines of the form
    ``time | stock | description | <amount>万``. Positive amounts count as
    buys and negative amounts as sells. For each stock the report written to
    ``output_path`` lists the overall totals, the totals of three windows
    (before 09:45, 09:45-14:45, after 14:45) and every single record. Stocks
    are ordered by net amount, descending.

    Records are skipped when the amount is not a number or the timestamp has
    no parseable hour and minute; previously such lines crashed the run.
    Stock names are stripped so the same stock always lands in one bucket.
    """
    # Columns: time (digits/colons) | stock name | description | signed amount
    pattern = re.compile(
        r"^\s*([\d:]+)\s*\|\s*([\w\u4e00-\u9fa5\s-]+)\s*\|\s*([\u4e00-\u9fa5\s]+)\s*\|\s*(-?[\d\.]+)万"
    )

    stock_data = {}

    # Scan every txt file of the day's directory
    txt_files = glob.glob(os.path.join(folder_path, "*.txt"))

    for file_path in txt_files:
        with open(file_path, 'r', encoding='utf-8') as f:
            for line in f:
                if "|" not in line:
                    continue
                match = pattern.match(line)
                if not match:
                    continue
                time_str, stock_name, _note, amount_str = match.groups()
                minute = _tj_minute_of_day(time_str)
                amount = _tj_parse_amount(amount_str)
                if minute is None or amount is None:
                    continue  # garbled OCR output: cannot be classified
                # The greedy name group swallows the blanks before the next "|"
                stock_name = stock_name.strip()

                info = stock_data.setdefault(
                    stock_name, {"count": 0, "total": 0, "details": []}
                )
                # Signed accumulation: buys add, sells subtract
                info["count"] += 1
                info["total"] += amount
                info["details"].append((minute, time_str, amount))

    # Order by net amount, descending
    sorted_data = sorted(stock_data.items(), key=lambda x: x[1]["total"], reverse=True)

    with open(output_path, 'w', encoding='utf-8') as out_file:
        for stock, info in sorted_data:
            details = info["details"]

            # Overall buy / sell totals
            buy_values = [a for _, _, a in details if a > 0]
            sell_values = [abs(a) for _, _, a in details if a < 0]
            buy_count = len(buy_values)
            buy_amount = _tj_total(buy_values)
            sell_count = len(sell_values)
            sell_amount = _tj_total(sell_values)

            # Total amount = bought amount - sold amount
            total_trades = info["count"]
            total_amount = round(buy_amount - sell_amount, 4)

            early_stats = _tj_period_stats(details, 0, 9 * 60 + 44)
            mid_stats = _tj_period_stats(details, 9 * 60 + 45, 14 * 60 + 45)
            late_stats = _tj_period_stats(details, 14 * 60 + 46, 24 * 60)

            details_str = "，".join(
                [f"{t}{'买入' if a>0 else '卖出'}{abs(a)}万元" for _, t, a in details]
            )

            out_file.write(
                f"{stock} ，成交{total_trades}次，成交总金额为{total_amount}万元；"
                f"买入{buy_count}次金额{buy_amount}万元，卖出{sell_count}次金额{sell_amount}万元\n"
                f"    早盘：成交{early_stats[0]}次，成交金额{early_stats[1]}万元；买入{early_stats[2]}次金额{early_stats[3]}万元；卖出{early_stats[4]}次金额{early_stats[5]}万元\n"
                f"    中盘：成交{mid_stats[0]}次，成交金额{mid_stats[1]}万元；买入{mid_stats[2]}次金额{mid_stats[3]}万元；卖出{mid_stats[4]}次金额{mid_stats[5]}万元\n"
                f"    尾盘：成交{late_stats[0]}次，成交金额{late_stats[1]}万元；买入{late_stats[2]}次金额{late_stats[3]}万元；卖出{late_stats[4]}次金额{late_stats[5]}万元\n"
                f"    详情：分别在{details_str}\n\n"
            )

    print(f"\n统计完成，结果已保存到: {output_path}")

def parse_text_to_csv(input_file_path, output_file_path=None):
    """Convert the per-stock statistics text report into a CSV table.

    The input is split into records on blank lines. A record must start with
    ``<name> ，``; the overall figures, the 早盘 / 中盘 / 尾盘 figures and the
    ``详情：`` text are extracted with regular expressions. Figures that
    cannot be found are written as ``'0'``. Amounts may be negative and may
    carry a decimal part.

    :param input_file_path: UTF-8 statistics report.
    :param output_file_path: target CSV; defaults to the input path with a
        ``.csv`` extension.
    :return: the path of the CSV file that was written.
    """
    # Default to the input file name with a .csv extension
    if output_file_path is None:
        output_file_path = os.path.splitext(input_file_path)[0] + '.csv'

    with open(input_file_path, 'r', encoding='utf-8') as f:
        content = f.read()

    # Records are separated by an empty line
    records = content.strip().split('\n\n')

    headers = [
        '股票名称', '总成交次数', '总成交金额(万元)', '总买入次数', '总买入金额(万元)',
        '总卖出次数', '总卖出金额(万元)',
        '早盘成交次数', '早盘成交金额(万元)', '早盘买入次数', '早盘买入金额(万元)',
        '早盘卖出次数', '早盘卖出金额(万元)',
        '中盘成交次数', '中盘成交金额(万元)', '中盘买入次数', '中盘买入金额(万元)',
        '中盘卖出次数', '中盘卖出金额(万元)',
        '尾盘成交次数', '尾盘成交金额(万元)', '尾盘买入次数', '尾盘买入金额(万元)',
        '尾盘卖出次数', '尾盘卖出金额(万元)',
        '详情'
    ]

    # Signed amount with an optional decimal part. An integer-only pattern
    # silently turned values such as 69.5 into '0' or picked up a figure from
    # a later line.
    num = r'-?\d+(?:\.\d+)?'

    def first_group(pattern, text):
        """Return group 1 of the first match of *pattern*, or '0'."""
        found = re.search(pattern, text)
        return found.group(1) if found else '0'

    all_data = []

    for record in records:
        if not record.strip():
            continue

        # The stock name is the non-blank text before the first "，"
        stock_name_match = re.match(r'(\S+)\s*，', record)
        if not stock_name_match:
            continue

        data = [stock_name_match.group(1)]

        # Overall figures: the first occurrence of each pattern in the record
        data.append(first_group(r'成交(\d+)次', record))
        data.append(first_group(rf'成交总金额为({num})万元', record))
        data.append(first_group(r'买入(\d+)次', record))
        data.append(first_group(rf'买入.*?金额({num})万元', record))
        data.append(first_group(r'卖出(\d+)次', record))
        data.append(first_group(rf'卖出.*?金额({num})万元', record))

        # Figures of each trading window
        for period in ['早盘', '中盘', '尾盘']:
            # Exact layout first ...
            period_pattern = (
                f'{period}：.*?成交(\\d+)次，成交金额({num})万元；'
                f'买入(\\d+)次金额({num})万元；卖出(\\d+)次金额({num})万元'
            )
            match = re.search(period_pattern, record)
            if match:
                data.extend(match.groups())
                continue
            # ... then a looser layout that tolerates extra text between fields
            loose_pattern = (
                f'{period}：.*?成交(\\d+)次.*?成交金额({num})万元.*?'
                f'买入(\\d+)次.*?金额({num})万元.*?卖出(\\d+)次.*?金额({num})万元'
            )
            loose_match = re.search(loose_pattern, record, re.DOTALL)
            if loose_match:
                data.extend(loose_match.groups())
            else:
                data.extend(['0'] * 6)

        # Detail text: everything after "详情：" up to the end of that line
        detail_match = re.search(r'详情：(.*?)(?=\n|$)', record, re.DOTALL)
        detail = detail_match.group(1).strip() if detail_match else ''
        data.append(detail)

        all_data.append(data)

    with open(output_file_path, 'w', newline='', encoding='utf-8-sig') as f:
        writer = csv.writer(f)
        writer.writerow(headers)
        writer.writerows(all_data)

    print(f"成功处理 {len(all_data)} 条记录，结果已保存到 {output_file_path}")
    return output_file_path

if __name__ == "__main__":
    # Entry point: fetch the A-share spot data once, then capture, OCR and
    # summarise the screen every INTERVAL seconds until interrupted.
    print("📡 开始获取a股实时行情数据...")
    df = ak.stock_zh_a_spot_em()
    market_cap_map = get_market_cap_map(df)
    # The raw screenshots are saved here; create it up front so the first
    # save cannot fail on a missing directory.
    os.makedirs(DEBUG_DIR, exist_ok=True)
    print("📡 开始实时监控窗口内容...")
    try:
        while True:
            # Only capture while the target window exists
            rect = find_window_rect(TARGET_TITLE)

            if rect:
                print(f"🔍 窗口位置: {rect}")
                # NOTE: the window rectangle is only reported; the capture
                # itself always uses a fixed 1920x1080 region at the origin.
                rect = (0, 0, 1920, 1080)
                screenshot = preprocess_image(ImageGrab.grab(bbox=rect))
                # Save the preprocessed screenshot, then reload it with
                # OpenCV to get an array that can be sliced.
                timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
                raw_path = f"{DEBUG_DIR}/{timestamp}_raw.png"
                screenshot.save(raw_path)
                img = cv2.imread(raw_path)
                # Crop to the configured data area
                cropped = img[CROP_TOP:CROP_BOTTOM, CROP_LEFT:CROP_RIGHT]
                split_image_into_regions_reverse(cropped)
                tj()
                output_big_amount_from_file(market_cap_map, output_path, 1500, last_file_path)
                parse_text_to_csv(output_path)
                print(f"下一次截屏采集图像时，停止{INTERVAL}秒")
            # Sleep on every iteration, including when the window is missing;
            # otherwise the loop spins at full speed printing the same error.
            sleep(INTERVAL)

    except KeyboardInterrupt:
        print("\n🛑 程序已手动停止")
