-
Notifications
You must be signed in to change notification settings - Fork 0
/
agent.py
executable file
· 197 lines (173 loc) · 6.93 KB
/
agent.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
import numpy as np
from pysc2.agents import base_agent
from pysc2.lib import actions
from pysc2.lib.features import SCREEN_FEATURES
# Indices of the screen feature layers the agent reads.
_PLAYER_RELATIVE, _UNIT_TYPE, _SELECTED, _UNIT_HIT_POINTS, _UNIT_ENERGY = (
    SCREEN_FEATURES.player_relative.index,
    SCREEN_FEATURES.unit_type.index,
    SCREEN_FEATURES.selected.index,
    SCREEN_FEATURES.unit_hit_points.index,
    SCREEN_FEATURES.unit_energy.index,
)

# Values found in the player_relative layer (neutral = beacon/minerals).
_PLAYER_FRIENDLY, _PLAYER_NEUTRAL, _PLAYER_HOSTILE = 1, 3, 4

# pysc2 function ids used by the agent.
_NO_OP, _MOVE_SCREEN, _ATTACK_SCREEN = (
    actions.FUNCTIONS.no_op.id,
    actions.FUNCTIONS.Move_screen.id,
    actions.FUNCTIONS.Attack_screen.id,
)
_SELECT_ARMY, _SELECT_UNIT, _SELECT_CONTROL_GROUP = (
    actions.FUNCTIONS.select_army.id,
    actions.FUNCTIONS.select_unit.id,
    actions.FUNCTIONS.select_control_group.id,
)

# Single-value argument lists passed to FunctionCall. Kept as two distinct
# list objects on purpose (no aliasing).
_NOT_QUEUED = [0]  # "queued" argument: act immediately
_SELECT_ALL = [0]  # passed as select_army's argument
def make_default_args(arg_names):
    """Compute default arguments for one pysc2 function.

    Every non-spatial argument name maps to the default value ``[0]``. The
    spatial names "screen", "minimap" and "screen2" are skipped and only
    reported through the returned flag.

    Args:
        arg_names: Argument names of a single pysc2 function, in order.

    Returns:
        ``(defaults, has_spatial)``: a tuple of ``[0]`` lists, one per
        non-spatial argument, and whether any spatial argument was present.

    Raises:
        AssertionError: if a non-spatial name follows a spatial one.
    """
    spatial_names = ["screen", "minimap", "screen2"]
    defaults = []
    has_spatial = False
    for name in arg_names:
        if name not in spatial_names:
            # Spatial arguments must be the trailing ones.
            assert not has_spatial, "got %s argument after spatial argument" % name
            defaults.append([0])
        else:
            has_spatial = True
    return tuple(defaults), has_spatial
def convert_point_to_rectangle(point, delta, dim):
    """Return the two corners of a square of half-width *delta* around *point*.

    Each coordinate of both corners is clamped with ``max(0, min(c, dim - 1))``.

    Returns:
        ``(low_corner, high_corner)``: two lists with one entry per coordinate.
    """
    upper = dim - 1
    low_corner = [max(0, min(coord - delta, upper)) for coord in point]
    high_corner = [max(0, min(coord + delta, upper)) for coord in point]
    return low_corner, high_corner
def arg_names():
    """Return the argument names of every pysc2 function.

    The outer list is index-aligned with ``actions.FUNCTIONS``. It asserts
    that no function takes a "minimap2" argument, a name that
    make_default_args would not treat as spatial.
    """
    names_per_function = []
    for function in actions.FUNCTIONS:
        names_per_function.append([arg.name for arg in function.args])
    assert not any("minimap2" in names for names in names_per_function)
    return names_per_function
def find_rect_function_id():
    """Return the id of the only pysc2 function that takes a "screen2" argument.

    This is a lookup that stays correct if function ids change. The original
    author notes the value is 3.

    Raises:
        AssertionError: unless exactly one such function exists.
    """
    matches = [
        function.id
        for function, names in zip(actions.FUNCTIONS, arg_names())
        if "screen2" in names
    ]
    assert len(matches) == 1
    (rect_id,) = matches
    return rect_id
class MADDPG(base_agent.BaseAgent):
    """An agent specifically for solving the DefeatRoaches map.

    ``build_group`` puts each friendly unit into its own control group.
    ``get_obs`` then recalls the groups one by one and builds, for each unit,
    a 20-value relative-position observation. It also stores a
    distance-based shaping reward per unit in ``self.rew_d``.

    ``step`` issues ``self.action_id`` with ``self.args``. Nothing in this
    class sets them beyond the no_op defaults from ``__init__``. Presumably
    an external training loop assigns them (TODO confirm).
    """

    def __init__(self):
        super(MADDPG, self).__init__()
        # Per pysc2 function: default non-spatial args and a spatial flag.
        self.default_args, is_spatial = zip(*[make_default_args(k) for k in arg_names()])
        self.is_spatial = np.array(is_spatial)
        self.rect_select_action_id = find_rect_function_id()
        self.rect_delta = 5
        self.dim = 40
        # NOTE(review): the five attributes above are not read by any method
        # of this class; kept for external users.
        # Last distance total (see changeDistanceMatrix), one slot per agent index.
        self.distance = [0, 0, 0, 0, 0]
        # These are normally produced by build_group() / selected_units().
        # Initialise them so get_obs() and step() work before the first build.
        self.group = []
        self.num_units = 0
        self.rew_d = []
        # Action emitted by step(); defaults to a no_op.
        self.action_id = _NO_OP
        self.args = []

    def selected_units(self, obs):
        """Record the pixels of the current selection.

        Sets ``self.selected_units_y`` and ``self.selected_units_x`` to the row
        and column indices of the non-zero pixels in the "selected" screen
        layer. Sets ``self.num_units`` to their count, which is a pixel count
        and not necessarily a unit count.
        """
        selected = obs.observation["screen"][_SELECTED]
        self.selected_units_y, self.selected_units_x = selected.nonzero()
        self.num_units = len(self.selected_units_y)

    def select_unit(self, i, obs, env):
        """Recall control group ``i`` (select_control_group, act 0) if available.

        Returns:
            False if the resulting TimeStep is the last one (``obs.last()``).
            Otherwise the latest TimeStep, which is the unchanged ``obs`` when
            the action is not available.
        """
        if _SELECT_CONTROL_GROUP in obs.observation["available_actions"]:
            action = [actions.FunctionCall(_SELECT_CONTROL_GROUP, [[0], [i]])]
            obs = env.step(action)[0]
            if obs.last():
                return False
        return obs

    def build_group(self, obs, env):
        """Assign each of ``self.num_units`` units to its own control group.

        For every index ``i`` it does three things:
        - selects the unit at position ``i`` of the current selection
          (select_unit, act 0);
        - issues select_control_group with act 2 on group ``i``;
        - re-selects the whole army.

        If fewer groups were created than there are units, it re-measures the
        selection and retries.

        Returns:
            ``(obs_shape_n, obs)``: one ``(20,)`` shape per unit, and the latest
            TimeStep after every env step taken here.
        """
        obs_shape_n = []
        group = []
        for i in range(self.num_units):
            if _SELECT_UNIT in obs.observation["available_actions"]:
                action = [actions.FunctionCall(_SELECT_UNIT, [[0], [i]])]
                obs = env.step(action)[0]
            if _SELECT_CONTROL_GROUP in obs.observation["available_actions"]:
                action = [actions.FunctionCall(_SELECT_CONTROL_GROUP, [[2], [i]])]
                obs = env.step(action)[0]
                group.append(True)
            if _SELECT_ARMY in obs.observation["available_actions"]:
                action = [actions.FunctionCall(_SELECT_ARMY, [_SELECT_ALL])]
                obs = env.step(action)[0]
            obs_shape_n.append((20,))
        self.group = group
        if len(group) < self.num_units:
            # The env has advanced, so this attempt's obs is stale. Return the
            # retry's results rather than discarding them.
            self.selected_units(obs)
            return self.build_group(obs, env)
        return obs_shape_n, obs

    def get_obs(self, obs, env):
        """Build one 20-value observation per entry of ``self.group``.

        If exactly 5 friendly pixels are visible and no group entry is alive,
        the control groups are (re)built first. Then each alive entry ``i`` is
        recalled from control group ``i``.

        If the recall selects exactly one pixel, the observation holds the
        (dy, dx) offsets of up to 10 non-zero player_relative pixels from that
        pixel, zero-padded to 20 values. The reward
        ``-changeDistanceMatrix(obs, i) / 50`` is recorded for that entry.

        Otherwise the entry is marked dead (``self.group[i] = False``). Dead
        or unrecallable entries get 20 zeros and a reward of 0.

        Side effect: ``self.rew_d`` is set to the rewards, index-aligned with
        the returned observations.

        Returns:
            ``(obs_n, obs)``: the per-unit observation lists and the latest TimeStep.
        """
        player_relative = obs.observation["screen"][_PLAYER_RELATIVE]
        _, reaper_x = (player_relative == _PLAYER_FRIENDLY).nonzero()
        if len(reaper_x) == 5 and self.group.count(True) == 0:
            self.num_units = 5
            _, obs = self.build_group(obs, env)
        # NOTE(review): unit positions come from the obs passed in, while the
        # reference pixel comes from later TimeSteps. Units may have moved in between.
        unit_y, unit_x = player_relative.nonzero()
        num_visible = len(unit_x)
        obs_n = []
        rew_d = []
        for reaper, alive_reaper in enumerate(self.group):
            obs_info = [0] * 20
            reward = 0
            if alive_reaper and _SELECT_CONTROL_GROUP in obs.observation["available_actions"]:
                # build_group put unit i into control group i, so recall by
                # index. A counter of alive units would drift after a death.
                action = [actions.FunctionCall(_SELECT_CONTROL_GROUP, [[0], [reaper]])]
                obs = env.step(action)[0]
                self.selected_units(obs)
                if self.num_units == 1:
                    ref_y = self.selected_units_y[0]
                    ref_x = self.selected_units_x[0]
                    obs_info = []
                    for k in range(min(num_visible, 10)):
                        obs_info.append(unit_y[k] - ref_y)
                        obs_info.append(unit_x[k] - ref_x)
                    obs_info.extend([0] * (20 - len(obs_info)))
                    reward = -self.changeDistanceMatrix(obs, reaper) / 50
                else:
                    self.group[reaper] = False
            obs_n.append(obs_info)
            rew_d.append(reward)
        self.rew_d = rew_d
        return obs_n, obs

    def step(self, obs):
        """Return ``self.action_id`` with ``self.args`` if available, else a no_op."""
        super(MADDPG, self).step(obs)
        if self.action_id in obs.observation["available_actions"]:
            return actions.FunctionCall(self.action_id, self.args)
        # Never fall off the end: returning None is not a valid action.
        return actions.FunctionCall(_NO_OP, [])

    # reward_distance
    def changeDistanceMatrix(self, obs, reaper_i):
        """Recompute agent ``reaper_i``'s distance total and return the change.

        The total sums ``max(0, d - 5)`` over every hostile pixel of the
        player_relative layer. Here ``d`` is the Euclidean pixel distance to
        the first pixel of the "selected" layer equal to 1. The new total is
        stored in ``self.distance[reaper_i]``, and ``new - previous`` is
        returned.

        If nothing is selected there is no reference point. In that case the
        stored total is left unchanged and 0 is returned.
        """
        if reaper_i >= len(self.distance):
            # Grow the per-agent table instead of raising IndexError.
            self.distance.extend([0] * (reaper_i + 1 - len(self.distance)))
        selected = obs.observation["screen"][_SELECTED]
        selected_y, selected_x = (selected == 1).nonzero()
        if len(selected_y) == 0:
            return 0
        player_relative = obs.observation["screen"][_PLAYER_RELATIVE]
        army_y, army_x = (player_relative == _PLAYER_HOSTILE).nonzero()
        dy = army_y - selected_y[0]
        dx = army_x - selected_x[0]
        dist_sq = dy * dy + dx * dx
        # Hostile pixels within radius 5 (dist_sq <= 25) contribute nothing.
        excess = np.where(dist_sq <= 25, 0.0, np.sqrt(dist_sq) - 5)
        previous = self.distance[reaper_i]
        self.distance[reaper_i] = float(excess.sum())
        return self.distance[reaper_i] - previous