Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion autowsgr/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
"""AutoWSGR - 战舰少女R 自动化框架(v2)"""

__version__ = '2.1.9.post5'
__version__ = '2.1.9.post6'
28 changes: 22 additions & 6 deletions autowsgr/ui/choose_ship_page.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

from __future__ import annotations

import re
import time
from typing import TYPE_CHECKING

Expand Down Expand Up @@ -56,6 +57,7 @@
_SCROLL_FROM_Y: float = 0.55
_SCROLL_TO_Y: float = 0.30
_OCR_MAX_ATTEMPTS: int = 3
_SHIP_ALIAS_SUFFIX_RE = re.compile(r'\s*[((][^()()]*[))]\s*$')

PAGE_SIGNATURE = PixelSignature(
name='choose_ship_page',
Expand Down Expand Up @@ -329,14 +331,16 @@ def _click_ship_in_list(
Parameters
----------
name:
目标舰船名 (精确名称)。
目标舰船名。
匹配时会先做舰名归一化(如去除“·改”与尾部括号别名)后再比较。

Returns
-------
str | None
匹配并点击成功时返回舰船名;失败返回 ``None``。
"""
assert self._ctx.ocr is not None
normalized_target = self._normalize_ship_name(name)

for attempt in range(_OCR_MAX_ATTEMPTS):
screen = self._ctrl.screenshot()
Expand All @@ -359,20 +363,25 @@ def _click_ship_in_list(
raw_levels = []

hits = [self._normalize_hit_entry(hit) for hit in raw_hits]
level_map: dict[float, list[int | None]] = {}
level_map: dict[float, dict[str, list[int | None]]] = {}
for entry in raw_levels:
_, level, row_key = self._normalize_level_entry(entry)
level_map.setdefault(row_key, []).append(level)
level_name, level, row_key = self._normalize_level_entry(entry)
normalized_level_name = self._normalize_ship_name(level_name)
row_levels = level_map.setdefault(row_key, {})
row_levels.setdefault(normalized_level_name, []).append(level)

for matched, cx, cy, row_key in hits:
if matched != name:
normalized_matched = self._normalize_ship_name(matched)
if normalized_matched != normalized_target:
continue

level = None
if use_level_filter:
row_levels = level_map.get(row_key)
if row_levels:
level = row_levels.pop(0)
name_levels = row_levels.get(normalized_matched)
if name_levels:
level = name_levels.pop(0)
if not self._is_level_in_range(level, min_level, max_level):
_log.warning(
"[UI] 命中 '{}', 但等级 {} 不满足范围 [{}, {}]",
Expand Down Expand Up @@ -406,3 +415,10 @@ def _click_ship_in_list(
time.sleep(0.5)

return None

@staticmethod
def _normalize_ship_name(name: str) -> str:
normalized = name.strip()
normalized = normalized.removesuffix('·改')
normalized = _SHIP_ALIAS_SUFFIX_RE.sub('', normalized)
return normalized.strip()
190 changes: 169 additions & 21 deletions autowsgr/ui/utils/ship_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,9 @@
#: Legacy 选船列表左侧裁剪宽度 (px@1280)
LEGACY_LIST_WIDTH: int = 1048

_LEVEL_PATTERN = re.compile(r'[Ll][Vv]\.?\s*(\d+)')
_LEVEL_PATTERN = re.compile(r'[Ll][Vv]\.?\s*([0-9ILilOo]{1,6})')
_LEVEL_NOISY_PATTERN = re.compile(r'(?:[LlIi1O0][VvYy])[\.:]?\s*([0-9ILilOo]{1,6})')
_MAX_LEVEL_VALUE = 200


def to_legacy_format(screen: np.ndarray) -> tuple[np.ndarray, float, float]:
Expand Down Expand Up @@ -169,15 +171,124 @@ def recognize_ships_in_list(

def _parse_level(text: str) -> int | None:
"""从 OCR 文本中提取 ``Lv.XX`` 格式等级数字。"""
m = _LEVEL_PATTERN.search(text)
compact = text.strip().replace(' ', '')

m = _LEVEL_PATTERN.search(compact)
if m:
try:
return int(m.group(1))
except ValueError:
return None
level = _coerce_level_digits(m.group(1))
if level is not None:
return level

m2 = _LEVEL_NOISY_PATTERN.search(compact)
if m2:
level = _coerce_level_digits(m2.group(1))
if level is not None:
return level

return None
Comment on lines 172 to 188
Copy link

Copilot AI Apr 14, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

新增的等级解析/噪声矫正逻辑(_parse_level() + _coerce_level_digits(),包括 i -> 1、前 3 位截断、_LEVEL_NOISY_PATTERN 等)目前没有对应的单元测试覆盖;仓库里已有 testing/vision/test_ocr.py 这类 OCR 相关单测框架。建议补充一些纯字符串输入的单测用例,覆盖典型 OCR 噪声(如 Lv.i5Iv:O3Lv.051Lv.110544)以及超界值被拒绝的情况,避免后续调整 regex/映射时回归。

Copilot uses AI. Check for mistakes.


def _coerce_level_digits(raw_digits: str) -> int | None:
"""将 OCR 提取出的数字串映射为合法等级值。"""
trans = str.maketrans(
{
'I': '1',
'i': '1',
'l': '1',
Copy link

Copilot AI Apr 14, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

_LEVEL_PATTERN/_LEVEL_NOISY_PATTERN 的捕获组允许包含大写 L(常见把数字 1 识别为 L),但 _coerce_level_digits()trans 只映射了 I/i/l -> 1,未处理 L。这会导致像 Lv.L5 这样的输入在 translate 后变成 L5,最终只提取到 5 或直接返回 None,降低等级解析命中率。建议把 'L': '1' 也加入映射表以与正则允许字符保持一致。

Suggested change
'l': '1',
'l': '1',
'L': '1',

Copilot uses AI. Check for mistakes.
'O': '0',
'o': '0',
}
)
normalized = raw_digits.translate(trans)
digits = ''.join(ch for ch in normalized if ch.isdigit())
if not digits:
return None

candidates: list[int] = []

# 先尝试前 3 位(常见误读: 1046 -> 104, 110544 -> 110)
if len(digits) >= 3:
candidates.append(int(digits[:3]))
if len(digits) >= 2:
candidates.append(int(digits[:2]))
candidates.append(int(digits[:1]))

# 兼容前导 0 的场景(如 051 -> 51)
if digits.startswith('0') and len(digits) >= 3:
candidates.insert(0, int(digits[1:3]))

seen_vals: set[int] = set()
for value in candidates:
if value in seen_vals:
continue
seen_vals.add(value)
if 1 <= value <= _MAX_LEVEL_VALUE:
return value

return None


def _center_x(bbox: tuple[int, int, int, int] | None, width: int) -> float:
if bbox is None:
return width / 2
x1, _, x2, _ = bbox
return (x1 + x2) / 2


def _probe_level_near_name(
ocr: OCREngine,
screen: np.ndarray,
*,
y_start: int,
y_end: int,
name_x: float,
max_x: int,
) -> int | None:
"""在同一 y 行按舰名 x 位置裁剪区域,二次识别等级。"""
h, w = screen.shape[:2]
row_h = max(1, y_end - y_start)

x_pad = max(70, int(w * 0.045))
x0 = max(0, int(name_x - x_pad))
x1 = min(max_x, int(name_x + x_pad))

y0 = max(0, y_start - int(row_h * 1.6))
y1 = min(h, y_end + int(row_h * 0.4))

if x1 <= x0 or y1 <= y0:
return None

roi = screen[y0:y1, x0:x1]
if roi.size == 0:
return None

parsed_levels: list[int] = []

def collect_levels(img: np.ndarray) -> None:
results = ocr.recognize(img, allowlist='LlVvIiYy0Oo1.:-/0123456789')
for r in results:
text = r.text.strip()
if not text:
continue
level = _parse_level(text)
if level is not None:
parsed_levels.append(level)

collect_levels(roi)

gray = cv2.cvtColor(roi, cv2.COLOR_RGB2GRAY)
up = cv2.resize(gray, None, fx=3, fy=3, interpolation=cv2.INTER_CUBIC)
norm = cv2.normalize(up, None, 0, 255, cv2.NORM_MINMAX)
binary = cv2.threshold(norm, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)[1]
binary_rgb = cv2.cvtColor(binary, cv2.COLOR_GRAY2RGB)
collect_levels(binary_rgb)

if not parsed_levels:
return None

return max(parsed_levels)


def read_ship_levels(
ocr: OCREngine,
screen: np.ndarray,
Expand Down Expand Up @@ -230,43 +341,80 @@ def read_ship_levels(
for y_start_720, y_end_720 in rows:
y_start = max(0, int((y_start_720 - 1) * scale_y))
y_end = min(h, int((y_end_720 + 1) * scale_y))
row_key = round((y_start + y_end) / 2 / h, 4)

row_img = list_area_native[y_start:y_end]
results = ocr.recognize(row_img)

row_name: str | None = None
row_level: int | None = None
name_hits: list[tuple[str, float]] = []
local_level_hits: list[tuple[int, float]] = []

for r in results:
text = r.text.strip()
if not text:
continue

# 尝试匹配等级
if row_level is None:
level = _parse_level(text)
if level is not None:
row_level = level
x_center = _center_x(r.bbox, row_img.shape[1])

# 尝试匹配舰船名
if row_name is None:
name = _fuzzy_match(text, SHIPNAMES)
if name is not None and name not in seen:
row_name = name
level = _parse_level(text)
if level is not None:
local_level_hits.append((level, x_center))

if row_name is not None:
name = _fuzzy_match(text, SHIPNAMES)
if name is not None:
name_hits.append((name, x_center))

if not name_hits:
continue

name_hits.sort(key=lambda item: item[1])
local_level_hits.sort(key=lambda item: item[1])
max_pair_dist = max(80.0, row_img.shape[1] * 0.12)

for row_name, name_x in name_hits:
if deduplicate_by_name and row_name in seen:
continue

row_level: int | None = None

best_level: int | None = None
best_dist = float('inf')
for candidate_level, candidate_x in local_level_hits:
dist = abs(candidate_x - name_x)
if dist < best_dist:
best_dist = dist
best_level = candidate_level

if best_level is not None and best_dist <= max_pair_dist:
row_level = best_level

if row_level is None:
probe_level = _probe_level_near_name(
ocr,
screen,
y_start=y_start,
y_end=y_end,
name_x=name_x,
max_x=list_w_native,
)
if probe_level is not None:
row_level = probe_level

if deduplicate_by_name:
seen.add(row_name)
row_key = round((y_start + y_end) / 2 / h, 4)
_log.debug(
'[选船列表] 等级识别命中: name={} level={} row_key={}',
row_name,
row_level if row_level is not None else 'None',
row_key,
)
if include_row_key:
found.append((row_name, row_level, row_key))
else:
found.append((row_name, row_level))

_log.debug(
'[选船列表] 等级识别: {}',
[(n, lv) for n, lv in found],
[(entry[0], entry[1]) for entry in found],
)
return found
Loading