Improving a bit the doc + cosmetic modif in Toy_env.py
VinF committed Feb 26, 2017
1 parent cf386d1 commit fb718a0
Showing 3 changed files with 40 additions and 10 deletions.
34 changes: 32 additions & 2 deletions deer/agent.py
@@ -31,7 +31,8 @@ class NeuralAgent(object):
replay_memory_size : int
Size of the replay memory. Default: 1000000
replay_start_size : int
Number of observations (=number of time steps taken) in the replay memory before starting learning. Default: minimum possible according to environment.inputDimensions().
Number of observations (=number of time steps taken) in the replay memory before starting learning.
Default: minimum possible according to environment.inputDimensions().
batch_size : int
Number of tuples taken into account for each iteration of gradient descent. Default: 32
random_state : numpy random number generator
@@ -44,7 +45,8 @@ class NeuralAgent(object):
test_policy : object from class Policy
Policy followed when in modes other than training (validation and test modes)
only_full_history : boolean
Whether we wish to train the neural network only on full histories or we wish to fill with zeroes the observations before the beginning of the episode
Whether we wish to train the neural network only on full histories or to fill the observations before
the beginning of the episode with zeroes
"""

def __init__(self, environment, q_network, replay_memory_size=1000000, replay_start_size=None, batch_size=32, random_state=np.random.RandomState(), exp_priority=0, train_policy=None, test_policy=None, only_full_history=True):
@@ -180,6 +182,10 @@ def summarizeTestPerformance(self):
self._environment.summarizePerformance(self._tmp_dataset)

def train(self):
"""
This function selects a random batch of data (with self._dataset.randomBatch) and performs a
Q-learning iteration (with self._network.train).
"""
# We make sure that the number of elements in the replay memory
# is strictly greater than self._replay_start_size before taking
# a random batch and performing training
@@ -244,6 +250,20 @@ def setNetwork(self, fname, nEpoch=-1):
self._network.setAllParams(all_params)

def run(self, n_epochs, epoch_length):
"""
This function encapsulates the whole process of the learning.
It starts by calling the controllers method "onStart",
Then it runs a given number of epochs where an epoch is made up of one or many episodes (called with
agent._runEpisode) and where an epoch ends up after the number of steps reaches the argument "epoch_length".
It ends up by calling the controllers method "end".
Parameters
-----------
n_epochs : number of epochs
int
epoch_length : maximum number of steps for a given epoch
int
"""
for c in self._controllers: c.onStart(self)
i = 0
while i < n_epochs or self._mode_epochs_length > 0:
@@ -265,6 +285,15 @@ def run(self, n_epochs, epoch_length):
for c in self._controllers: c.onEnd(self)

def _runEpisode(self, maxSteps):
"""
This function runs an episode of learning. An episode ends up when the environment method "inTerminalState"
returns True (or when the number of steps reaches the argument "maxSteps")
Parameters
-----------
maxSteps : maximum number of steps before automatically ending the episode
int
"""
self._in_episode = True
initState = self._environment.reset(self._mode)
inputDims = self._environment.inputDimensions()
@@ -283,6 +312,7 @@ def _runEpisode(self, maxSteps):
self._state[i][-1] = obs[i]

V, action, reward = self._step()

self._Vs_on_last_episode.append(V)
if self._mode != -1:
self._total_mode_reward += reward
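For reference, the control flow described by the new run() and _runEpisode() docstrings can be sketched in plain Python as follows. This is only an illustration of the documented behaviour, not the library's actual implementation: the names ending in an underscore are stand-ins, and treating mode -1 as the training mode is an assumption inferred from the self._mode != -1 test above.

# Sketch of the loop structure described by run() and _runEpisode() above.
# Illustrative only: names ending in "_" are stand-ins, not deer API.
def run_sketch(agent_, controllers_, environment_, n_epochs, epoch_length):
    for c in controllers_:
        c.onStart(agent_)
    for _ in range(n_epochs):
        steps_left = epoch_length
        while steps_left > 0:  # an epoch is made up of one or more episodes
            steps_left -= run_episode_sketch(agent_, environment_, steps_left)
    for c in controllers_:
        c.onEnd(agent_)

def run_episode_sketch(agent_, environment_, max_steps):
    environment_.reset(-1)  # assumption: -1 denotes the training mode
    steps = 0
    while steps < max_steps:
        agent_._step()  # pick an action, observe, store the transition, possibly train
        steps += 1
        if environment_.inTerminalState():
            break
    return steps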
14 changes: 7 additions & 7 deletions deer/base_classes/Environment.py
@@ -22,7 +22,7 @@ class Environment(object):
"""

def reset(self, mode):
"""Reset the environment and put it in mode [mode].
"""Resets the environment and put it in mode [mode]. This function is called when beginning every new episode.
The [mode] can be used, for instance, to discriminate between an agent that is training and one that is trying to get a
validation or generalization score. The mode the environment is in should always be redefined by resetting the
@@ -37,7 +37,7 @@ def reset(self, mode):
raise NotImplementedError()

def act(self, action):
"""Apply the agent action [action] on the environment.
"""Applies the agent action [action] on the environment.
Parameters
-----------
@@ -49,7 +49,7 @@ def inputDimensions(self):
raise NotImplementedError()

def inputDimensions(self):
"""Get the shape of the input space for this environment.
"""Gets the shape of the input space for this environment.
This returns a list whose length is the number of subjects observed on the environment. Each element of the
list is a tuple whose content and size depend on the type of data observed: the first integer is always the
@@ -65,14 +65,14 @@ def inputDimensions(self):
raise NotImplementedError()

def nActions(self):
"""Get the number of different actions that can be taken on this environment.
"""Gets the number of different actions that can be taken on this environment.
It can be either an integer in the case of a finite discrete number of actions
or it can be a list of pairs [min_action_value, max_action_value] for a continuous action space"""

raise NotImplementedError()

def inTerminalState(self):
"""Tell whether the environment reached a terminal state after the last transition (i.e. the last transition
"""Tells whether the environment reached a terminal state after the last transition (i.e. the last transition
that occurred was terminal).
As the majority of the control tasks considered have no end (control is meant to be operated continuously), by default
@@ -88,7 +88,7 @@ def inTerminalState(self):
return False

def observe(self):
"""Get a list of punctual observations on all subjects composing this environment.
"""Gets a list of punctual observations on all subjects composing this environment.
This returns a list where element i is a punctual observation on subject i. You will notice that the history
of observations on this subject is not returned; only the very last observation. Each element is thus either
@@ -115,7 +115,7 @@ def summarizePerformance(self, test_data_set):
pass

def observationType(self, subject):
"""Get the most inner type (np.uint8, np.float32, ...) of [subject].
"""Gets the most inner type (np.uint8, np.float32, ...) of [subject].
Parameters
-----------
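To make the interface documented in this file concrete, a minimal toy subclass might look like the sketch below. It implements only the abstract methods listed above (reset, act, inputDimensions, nActions, inTerminalState, observe); the 1-D "walk to position 10" dynamics are invented for illustration, the import assumes Environment is re-exported by the deer.base_classes package shown in this diff, and the exact shapes the agent expects (here one scalar subject with a history of length 1) may need adjusting for a real task.

import numpy as np
from deer.base_classes import Environment

class LineWalkEnv(Environment):
    """Toy 1-D environment: the agent walks left or right and is rewarded at position 10."""

    def __init__(self):
        self._position = 0.0
        self._mode = -1

    def reset(self, mode):
        # Called at the beginning of every new episode; [mode] discriminates
        # training from validation/test runs.
        self._mode = mode
        self._position = 0.0
        return [np.array([self._position])]

    def act(self, action):
        # Applies the agent action (0 = left, 1 = right) and returns the reward.
        self._position += 1.0 if action == 1 else -1.0
        return 1.0 if self._position >= 10.0 else 0.0

    def inputDimensions(self):
        # One observed subject: a history of length 1 holding one scalar.
        return [(1, 1)]

    def nActions(self):
        # Finite discrete action space with two actions.
        return 2

    def inTerminalState(self):
        # Terminal once the goal position is reached.
        return self._position >= 10.0

    def observe(self):
        # Only the very last observation on each subject, not its history.
        return [np.array([self._position])]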
2 changes: 1 addition & 1 deletion deer/experiment/base_controllers.py
@@ -357,7 +357,7 @@ def onEpochEnd(self, agent):


class TrainerController(Controller):
"""A controller that make the agent train on its current database periodically.
"""A controller that makes the agent train on its current database periodically.
Parameters
----------
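Putting the pieces together, a hedged end-to-end usage sketch based on the signatures visible in this commit: MyQNetwork stands in for whatever Q-network implementation you pair with NeuralAgent, attach() is assumed to be the method that registers controllers (only the self._controllers iteration is visible here), and TrainerController is assumed to be constructible with its defaults.

import numpy as np
from deer.agent import NeuralAgent
from deer.experiment.base_controllers import TrainerController

env = LineWalkEnv()          # the toy Environment subclass sketched above
network = MyQNetwork(env)    # hypothetical Q-network implementation (not part of this diff)

agent = NeuralAgent(env, network,
                    replay_memory_size=1000000,
                    batch_size=32,
                    random_state=np.random.RandomState(0))

# The TrainerController makes the agent call train() periodically, i.e. run one
# Q-learning iteration on a random batch from the replay memory, once the memory
# holds more than replay_start_size observations.
agent.attach(TrainerController())

# Ten epochs of at most 500 steps each; each epoch is one or more episodes.
agent.run(n_epochs=10, epoch_length=500)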
