/
agent.py
141 lines (120 loc) · 4.49 KB
/
agent.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
"""
Methods for agent interaction using sapai-gym and Stable-Baselines3.
"""
from sb3_contrib import MaskablePPO
from sb3_contrib.common.envs import InvalidActionEnvDiscrete
from sapai_gym import SuperAutoPetsEnv
from sb3_contrib.common.maskable.evaluation import evaluate_policy
from sb3_contrib.common.maskable.utils import get_action_masks
from sapai_gym.opponent_gen.opponent_generators import random_opp_generator, biggest_numbers_horizontal_opp_generator
from sapai_gym.ai import baselines
from sapai import *
from sapai.shop import *
from .image_detection import *
from .actions import *
import keyboard
import matplotlib.pyplot as plt
import pyautogui as gui
def pause():
    """Block until the space bar is pressed.

    Keeps waiting on ``keyboard.read_key()`` and returns as soon as it
    reports ``'space'``; every other key is ignored.
    """
    while keyboard.read_key() != 'space':
        pass
def time_pause(time: float) -> None:
    """Wait for ``time`` seconds by delegating to ``matplotlib.pyplot.pause``.

    Args:
        time: Duration in seconds. Annotated as ``float`` because the callers
            in this module pass fractional values such as ``0.5`` and ``3.0``.
    """
    plt.pause(time)
def get_action_name(k: int) -> str:
    """Return the action name whose index range contains ``k``.

    ``SuperAutoPetsEnv.ACTION_BASE_NUM`` is walked in its own order as
    ``(name, base)`` pairs. The result is the name that precedes the first
    entry whose base value is greater than ``k``; when no such entry exists
    the last name in the table is returned.
    (Presumably each base is the first action index of that action type and
    the table is in ascending order -- verify against sapai-gym.)

    Args:
        k: Non-negative action index.

    Returns:
        The matching key of ``ACTION_BASE_NUM``.

    Raises:
        AssertionError: If ``k`` is negative.
    """
    name_val = list(SuperAutoPetsEnv.ACTION_BASE_NUM.items())
    assert k >= 0
    # Pair every entry with its successor; zip stops at the shorter slice.
    for (start_name, _), (_, end_val) in zip(name_val, name_val[1:]):
        if k < end_val:
            return start_name
    # k is at or past the last base value. An explicit fallback (instead of a
    # for/else reusing the loop variable) also works for a one-entry table.
    return name_val[-1][0]
def remove_nothing(pet_list):
    """Return a new list holding every entry of ``pet_list`` except ``'nothing'``.

    The relative order of the remaining entries is kept and ``pet_list``
    itself is not modified.
    """
    return [pet for pet in pet_list if pet != 'nothing']
def opponent_generator(num_turns):
    """Return the teams to fight against in the gym.

    Delegates to ``biggest_numbers_horizontal_opp_generator``.

    NOTE(review): ``num_turns`` is accepted but never used; the underlying
    generator is always called with the constant 25 -- confirm this is
    intentional.
    """
    return biggest_numbers_horizontal_opp_generator(25)
def create_new_PPO():
    """Train, evaluate and save a fresh ``MaskablePPO`` model.

    The model is built on an ``InvalidActionEnvDiscrete`` environment,
    trained for 5000 timesteps, evaluated over 20 episodes (the evaluation
    result is discarded) and written to ``ppo_sapai_070822``.
    """
    training_env = InvalidActionEnvDiscrete(dim=80, n_invalid_actions=60)
    ppo = MaskablePPO(
        "MlpPolicy",
        training_env,
        gamma=0.4,
        seed=32,
        verbose=1,
    )
    ppo.learn(5000)
    evaluate_policy(
        ppo,
        training_env,
        n_eval_episodes=20,
        reward_threshold=90,
        warn=False,
    )
    ppo.save("ppo_sapai_070822")
def run(model_path):
    """Drive the on-screen game with a trained ``MaskablePPO`` model.

    Every loop iteration:

    1. reads the shop from the screen with ``find_the_animals`` and mirrors
       it into a simulated ``SuperAutoPetsEnv``;
    2. asks the model for a masked, deterministic action;
    3. performs that action through the ``SuperAutoPetsMouse`` action table;
    4. steps the simulated environment with the same action;
    5. after an ``end_turn`` action, keeps clicking until ``find_paw``
       reports a match.

    The loop has no exit condition. It runs until the process is interrupted
    or an exception escapes; the environment is closed in either case.

    Args:
        model_path: Path handed to ``MaskablePPO.load``.
    """
    import os  # local import: this module does not import ``os`` explicitly

    interface = SuperAutoPetsMouse()
    action_dict = interface.actionDict()
    model = MaskablePPO.load(model_path)
    env = SuperAutoPetsEnv(opponent_generator, valid_actions_only=True)
    env.reset()

    # Loop-invariant: the "SAP_res" directory next to this file. The empty
    # last component makes os.path.join append the platform's separator.
    res_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "SAP_res", "")

    try:
        while True:
            time_pause(0.5)
            pets, _ = find_the_animals(directory=res_dir)
            pets = remove_nothing(pets)
            print(pets)
            env.player.shop = Shop(pets)
            # Top up the simulated lives whenever they drop to 3 or fewer.
            if env.player.lives <= 3:
                env.player.lives += 3

            action_masks = get_action_masks(env)
            obs = env._encode_state()
            action, _states = model.predict(obs, action_masks=action_masks, deterministic=True)
            s = env._avail_actions()
            time_pause(0.5)
            print("Action")
            print(action)
            action_name = get_action_name(action)
            print(action_name)
            chosen = s[action]
            action_args = chosen[1:]
            print(chosen[0])
            print(action_args)

            if env._is_valid_action(action):
                if action_name == 'buy_food':
                    slot_types = [shop_slot.slot_type for shop_slot in env.player.shop]
                    num_pets = slot_types.count("pet")
                    num_food = slot_types.count("food")
                    # NOTE(review): ``%`` binds tighter than ``-``, so the
                    # original ``num_pets - num_food % 2`` evaluates as written
                    # below. Confirm ``(num_pets - num_food) % 2`` was not meant.
                    action_dict[action_name](action_args, num_pets - (num_food % 2))
                elif action_name == 'roll':
                    action_dict[action_name]()
                else:
                    action_dict[action_name](action_args)

            # Keep the simulated game in step with the on-screen one. The step
            # result (including ``done``) is not used.
            env.step(action)

            if action_name == 'end_turn':
                # Give the battle time to start, then keep clicking until the
                # paw is found again.
                time_pause(3.0)
                battle_finished = False
                while not battle_finished:
                    print("click event occured")
                    gui.click(1780, 200)
                    if find_paw():
                        print("battle is done!")
                        battle_finished = True
                    elif find_arena():
                        # Game over: start a new game and keep clicking.
                        print("Game is over! Start new game 8)")
                        gui.click(600, 400)
                        gui.click(1780, 200)

            print(chosen[0])
    finally:
        # Reached on interruption or error; the loop never ends on its own.
        env.close()
if __name__ == "__main__":
    import sys

    # run() requires the path of a saved model. Take it from the command line
    # and fail with a usage message before waiting for the key press.
    if len(sys.argv) < 2:
        sys.exit(f"usage: {sys.argv[0]} MODEL_PATH")
    pause()
    run(sys.argv[1])