python/unitytrainers/bc/trainer.py: 42 changes (19 additions, 23 deletions)
@@ -55,10 +55,10 @@ def __init__(self, sess, env, brain_name, trainer_parameters, training, seed):
         self.training_buffer = Buffer()
         self.is_continuous_action = (env.brains[brain_name].vector_action_space_type == "continuous")
         self.is_continuous_observation = (env.brains[brain_name].vector_observation_space_type == "continuous")
-        self.use_observations = (env.brains[brain_name].number_visual_observations > 0)
-        if self.use_observations:
+        self.use_visual_observations = (env.brains[brain_name].number_visual_observations > 0)
+        if self.use_visual_observations:
             logger.info('Cannot use observations with imitation learning')
-        self.use_states = (env.brains[brain_name].vector_observation_space_size > 0)
+        self.use_vector_observations = (env.brains[brain_name].vector_observation_space_size > 0)
         self.summary_path = trainer_parameters['summary_path']
         if not os.path.exists(self.summary_path):
             os.makedirs(self.summary_path)
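Note on this hunk: the constructor now derives two separately named flags, one for camera (visual) observations and one for the numeric (vector) observation space, replacing the ambiguous use_observations / use_states pair. A minimal sketch of how such flags fall out of a brain's spaces, using a hypothetical BrainStub stand-in for the Unity brain parameters (not part of this diff):

# Illustrative only: BrainStub and its field values are invented for the example.
from dataclasses import dataclass

@dataclass
class BrainStub:
    number_visual_observations: int
    vector_observation_space_size: int

brain = BrainStub(number_visual_observations=0, vector_observation_space_size=8)
use_visual_observations = brain.number_visual_observations > 0    # any camera inputs?
use_vector_observations = brain.vector_observation_space_size > 0  # any numeric state vector?
print(use_visual_observations, use_vector_observations)  # False True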
@@ -144,16 +144,15 @@ def take_action(self, all_brain_info: AllBrainInfo):
         agent_brain = all_brain_info[self.brain_name]
         feed_dict = {self.model.dropout_rate: 1.0, self.model.sequence_length: 1}

-        if self.use_observations:
+        if self.use_visual_observations:
             for i, _ in enumerate(agent_brain.visual_observations):
                 feed_dict[self.model.visual_in[i]] = agent_brain.visual_observations[i]
-        if self.use_states:
+        if self.use_vector_observations:
             feed_dict[self.model.vector_in] = agent_brain.vector_observations
         if self.use_recurrent:
             if agent_brain.memories.shape[1] == 0:
                 agent_brain.memories = np.zeros((len(agent_brain.agents), self.m_size))
             feed_dict[self.model.memory_in] = agent_brain.memories
-        if self.use_recurrent:
             agent_action, memories = self.sess.run(self.inference_run_list, feed_dict)
             return agent_action, memories, None, None
         else:
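Note on this hunk: besides the flag renames, take_action drops the duplicated if self.use_recurrent: guard, so the recurrent inference run now sits inside the single block that also prepares the memory input. A minimal sketch of the memory zero-initialization that block performs, with invented n_agents and m_size values:

# Illustrative only: n_agents and m_size are example values, not trainer configuration.
import numpy as np

n_agents, m_size = 3, 16
memories = np.zeros((n_agents, 0))      # no recurrent state has been produced yet
if memories.shape[1] == 0:              # first step: allocate an all-zero state per agent
    memories = np.zeros((n_agents, m_size))
print(memories.shape)  # (3, 16)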
@@ -192,11 +191,11 @@ def add_experiences(self, curr_info: AllBrainInfo, next_info: AllBrainInfo, take
                     info_teacher_record, next_info_teacher_record = "true", "true"
                 if info_teacher_record == "true" and next_info_teacher_record == "true":
                     if not stored_info_teacher.local_done[idx]:
-                        if self.use_observations:
+                        if self.use_visual_observations:
                             for i, _ in enumerate(stored_info_teacher.visual_observations):
                                 self.training_buffer[agent_id]['visual_observations%d' % i]\
                                     .append(stored_info_teacher.visual_observations[i][idx])
-                        if self.use_states:
+                        if self.use_vector_observations:
                             self.training_buffer[agent_id]['vector_observations']\
                                 .append(stored_info_teacher.vector_observations[idx])
                         if self.use_recurrent:
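Note on this hunk: add_experiences appends the teacher agent's visual and vector observations into per-agent lists keyed by agent_id. A minimal sketch of that accumulation pattern, using a plain dict-of-lists as a stand-in for the project's Buffer class (illustrative only; the shapes and agent id are invented):

# Illustrative only: defaultdict-of-lists as a toy replacement for Buffer.
from collections import defaultdict
import numpy as np

training_buffer = defaultdict(lambda: defaultdict(list))
agent_id = "agent-0"
training_buffer[agent_id]['vector_observations'].append(np.random.rand(8))   # one step of vector obs
training_buffer[agent_id]['actions'].append(np.array([0.1, -0.2]))           # matching action
print(len(training_buffer[agent_id]['vector_observations']))  # 1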
@@ -276,41 +275,38 @@ def update_model(self):
         """
         Uses training_buffer to update model.
         """

         self.training_buffer.update_buffer.shuffle()
         batch_losses = []
         for j in range(
                 min(len(self.training_buffer.update_buffer['actions']) // self.n_sequences, self.batches_per_epoch)):
             _buffer = self.training_buffer.update_buffer
             start = j * self.n_sequences
             end = (j + 1) * self.n_sequences
-            batch_states = np.array(_buffer['vector_observations'][start:end])
-            batch_actions = np.array(_buffer['actions'][start:end])

             feed_dict = {self.model.dropout_rate: 0.5,
                          self.model.batch_size: self.n_sequences,
                          self.model.sequence_length: self.sequence_length}
             if self.is_continuous_action:
-                feed_dict[self.model.true_action] = batch_actions.reshape([-1, self.brain.vector_action_space_size])
-            else:
-                feed_dict[self.model.true_action] = batch_actions.reshape([-1])
-            if not self.is_continuous_observation:
-                feed_dict[self.model.vector_in] = batch_states.reshape([-1, self.brain.num_stacked_vector_observations])
+                feed_dict[self.model.true_action] = np.array(_buffer['actions'][start:end]).\
+                    reshape([-1, self.brain.vector_action_space_size])
             else:
-                feed_dict[self.model.vector_in] = batch_states.reshape([-1, self.brain.vector_observation_space_size *
-                                                                        self.brain.num_stacked_vector_observations])
-            if self.use_observations:
+                feed_dict[self.model.true_action] = np.array(_buffer['actions'][start:end]).reshape([-1])
+            if self.use_vector_observations:
+                if not self.is_continuous_observation:
+                    feed_dict[self.model.vector_in] = np.array(_buffer['vector_observations'][start:end])\
+                        .reshape([-1, self.brain.num_stacked_vector_observations])
+                else:
+                    feed_dict[self.model.vector_in] = np.array(_buffer['vector_observations'][start:end])\
+                        .reshape([-1, self.brain.vector_observation_space_size * self.brain.num_stacked_vector_observations])
+            if self.use_visual_observations:
                 for i, _ in enumerate(self.model.visual_in):
                     _obs = np.array(_buffer['visual_observations%d' % i][start:end])
-                    (_batch, _seq, _w, _h, _c) = _obs.shape
-                    feed_dict[self.model.visual_in[i]] = _obs.reshape([-1, _w, _h, _c])
+                    feed_dict[self.model.visual_in[i]] = _obs
             if self.use_recurrent:
                 feed_dict[self.model.memory_in] = np.zeros([self.n_sequences, self.m_size])

             loss, _ = self.sess.run([self.model.loss, self.model.update], feed_dict=feed_dict)
             batch_losses.append(loss)
         if len(batch_losses) > 0:
             self.stats['losses'].append(np.mean(batch_losses))
         else:
             self.stats['losses'].append(0)
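Note on this hunk: update_model now slices mini-batches directly from the update buffer, reshapes vector observations by the stacked observation size, and feeds visual observations without the old flatten-by-(_w, _h, _c) step. A minimal sketch of the slicing and reshape arithmetic, assuming invented shapes: the update_buffer dict, obs_dim, and the 84x84x3 image size are made up for the example, while n_sequences, sequence_length, start, and end mirror the trainer's names:

# Illustrative only: toy buffer and shapes, not the trainer's real data.
import numpy as np

n_sequences, sequence_length, obs_dim = 4, 8, 12
update_buffer = {'vector_observations': np.random.rand(20, sequence_length, obs_dim)}
j = 1
start, end = j * n_sequences, (j + 1) * n_sequences              # rows 4..7 of the shuffled buffer
batch = np.array(update_buffer['vector_observations'][start:end])
flat_vector_obs = batch.reshape([-1, obs_dim])                   # (n_sequences * sequence_length, obs_dim)

visual_obs = np.random.rand(n_sequences, sequence_length, 84, 84, 3)
(_batch, _seq, _w, _h, _c) = visual_obs.shape
flat_visual_obs = visual_obs.reshape([-1, _w, _h, _c])           # the flatten step the new code no longer applies
print(flat_vector_obs.shape, flat_visual_obs.shape)              # (32, 12) (32, 84, 84, 3)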