Skip to content

Commit

Permalink
feat(single-mode): recognize training type
Browse files Browse the repository at this point in the history
  • Loading branch information
NateScarlet committed Jun 24, 2021
1 parent de12cba commit eca43ae
Show file tree
Hide file tree
Showing 4 changed files with 85 additions and 4 deletions.
12 changes: 12 additions & 0 deletions auto_derby/mathtools.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@

from typing import Tuple

import cast_unknown as cast
import numpy as np


def linear_interpolate(a: float, b: float, pos: float) -> float:
return a + (b - a) * pos
Expand Down Expand Up @@ -57,3 +60,12 @@ def vector4(
) -> Tuple[int, int, int, int]:
l, t, r, b = (self.vector(i, from_) for i in rect)
return l, t, r, b




def distance(a: Tuple[int, ...], b: Tuple[int, ...]) -> float:
assert len(a) == len(b), f"length must be same, got len(a)={len(a)} len(b)={len(b)}"
return cast.instance(
np.sqrt(np.sum((np.array(a) - np.array(b)) ** 2, axis=0)), float
)
8 changes: 8 additions & 0 deletions auto_derby/mathtools_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,11 @@ def test_resize_proxy_issue71():

res = rp.vector(-50, 466)
assert res == -116, res


def test_distance():
assert mathtools.distance((0,), (1,)) == 1
assert mathtools.distance((0,), (-1,)) == 1
assert mathtools.distance((0, 0), (1, 1)) == 1.4142135623730951
assert mathtools.distance((0, 0), (-1, -1)) == 1.4142135623730951
assert mathtools.distance((0, 0, 0), (1, 1, 1)) == 1.7320508075688772
60 changes: 56 additions & 4 deletions auto_derby/single_mode/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from __future__ import annotations

import os
from typing import Tuple, Type
from typing import Dict, Tuple, Type

import cv2
import numpy as np
Expand All @@ -18,6 +18,7 @@

class g:
training_class: Type[Training]
target_levels: Dict[int, int] = {}


def _gradient(colors: Tuple[Tuple[Tuple[int, int, int], int], ...]) -> np.ndarray:
Expand Down Expand Up @@ -125,12 +126,27 @@ def _recognize_level(rgb_color: Tuple[int, ...]) -> int:


class Training:
TYPE_SPEED: int = 1
TYPE_STAMINA: int = 2
TYPE_POWER: int = 3
TYPE_GUTS: int = 4
TYPE_WISDOM: int = 5

ALL_TYPES = (
TYPE_SPEED,
TYPE_STAMINA,
TYPE_POWER,
TYPE_GUTS,
TYPE_WISDOM,
)

@staticmethod
def new() -> Training:
return g.training_class()

def __init__(self):
self.level = 0
self.type = 0

self.speed: int = 0
self.stamina: int = 0
Expand All @@ -144,6 +160,8 @@ def __init__(self):

@classmethod
def from_training_scene(cls, img: Image) -> Training:
rp = mathtools.ResizeProxy(img.width)

self = cls.new()
self.confirm_position = next(
template.match(
Expand All @@ -153,8 +171,24 @@ def from_training_scene(cls, img: Image) -> Training:
),
)
)[1]

rp = mathtools.ResizeProxy(img.width)
radius = rp.vector(30, 540)
for t, center in zip(
Training.ALL_TYPES,
(
rp.vector2((78, 850), 540),
rp.vector2((171, 850), 540),
rp.vector2((268, 850), 540),
rp.vector2((367, 850), 540),
rp.vector2((461, 850), 540),
),
):
if mathtools.distance(self.confirm_position, center) < radius:
self.type = t
break
else:
raise ValueError(
"unknown type for confirm position: %s" % self.confirm_position
)

self.level = _recognize_level(
tuple(cast.list_(img.getpixel(rp.vector2((10, 200), 540)), int))
Expand Down Expand Up @@ -264,7 +298,25 @@ def score(self, ctx: Context) -> float:
(7000, 1.0),
),
)
return (spd + sta + pow + per + int_ + skill) * success_rate

target_level = g.target_levels.get(self.type, self.level)
target_level_score = 0
if self.level < target_level:
target_level_score += mathtools.interpolate(
ctx.turn_count(),
(
(0, 5),
(24, 3),
(48, 2),
(72, 0),
),
)
elif self.level > target_level:
target_level_score -= (self.level - target_level) * 5

return (
spd + sta + pow + per + int_ + skill + target_level_score
) * success_rate


g.training_class = Training
9 changes: 9 additions & 0 deletions auto_derby/single_mode/training_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ def test_update_by_training_scene():
)

training = Training.from_training_scene(img)
assert training.type == training.TYPE_SPEED
assert training.level == 5
assert training.speed == 26
assert training.stamina == 0
Expand All @@ -33,6 +34,7 @@ def test_update_by_training_scene_2():
)

training = Training.from_training_scene(img)
assert training.type == training.TYPE_WISDOM
assert training.level == 3
assert training.speed == 6
assert training.stamina == 0
Expand All @@ -50,6 +52,7 @@ def test_update_by_training_scene_3():
)

training = Training.from_training_scene(img)
assert training.type == training.TYPE_GUTS
assert training.level == 5
assert training.speed == 6
assert training.stamina == 0
Expand All @@ -67,6 +70,7 @@ def test_update_by_training_scene_4():
)

training = Training.from_training_scene(img)
assert training.type == training.TYPE_GUTS
assert training.level == 5
assert training.speed == 7
assert training.stamina == 0
Expand All @@ -80,6 +84,7 @@ def test_update_by_training_scene_5():
with _test.screenshot("training_scene_5.png") as img:

training = Training.from_training_scene(img)
assert training.type == training.TYPE_WISDOM
assert training.level == 2
assert training.speed == 2
assert training.stamina == 0
Expand All @@ -97,6 +102,7 @@ def test_update_by_training_scene_issue9():
)

training = Training.from_training_scene(img)
assert training.type == training.TYPE_SPEED
assert training.level == 1
assert training.speed == 12
assert training.stamina == 0
Expand All @@ -114,6 +120,7 @@ def test_update_by_training_scene_issue24():
)

training = Training.from_training_scene(img)
assert training.type == training.TYPE_STAMINA
assert training.level == 1
assert training.speed == 0
assert training.stamina == 9
Expand All @@ -131,6 +138,7 @@ def test_update_by_training_scene_issue51():
)

training = Training.from_training_scene(img)
assert training.type == training.TYPE_SPEED
assert training.level == 5
assert training.speed == 21
assert training.stamina == 0
Expand All @@ -144,6 +152,7 @@ def test_update_by_training_scene_issue55():
img = PIL.Image.open(_TEST_DATA_PATH / "training_scene_issue55.png").convert("RGB")

training = Training.from_training_scene(img)
assert training.type == training.TYPE_SPEED
assert training.level == 5
assert training.speed == 30
assert training.stamina == 0
Expand Down

0 comments on commit eca43ae

Please sign in to comment.