Skip to content

Commit

Permalink
Allow logging lists (#355)
Browse files Browse the repository at this point in the history
* allow logging lists

* Update docs/misc/changelog.rst

Co-Authored-By: Antonin RAFFIN <antonin.raffin@ensta.org>

* Move imports to top level

* Add return type
  • Loading branch information
dwiel authored and araffin committed Jul 18, 2019
1 parent 2bc3c87 commit dc31d83
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 17 deletions.
2 changes: 1 addition & 1 deletion docs/misc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ New Features:
- Add support for continuous action spaces to `action_probability`, computing the PDF of a Gaussian
policy in addition to the existing support for categorical stochastic policies.
- Add flag to `action_probability` to return log-probabilities.
- Add support for Python lists and NumPy arrays in ``logger.writekvs``. (@dwiel)

Bug Fixes:
^^^^^^^^^^
Expand Down Expand Up @@ -50,7 +51,6 @@ Breaking Changes:

- **breaking change** removed ``stable_baselines.ddpg.memory`` in favor of ``stable_baselines.deepq.replay_buffer`` (see fix below)


**Breaking Change:** The DDPG replay buffer was unified with the DQN/SAC replay buffer. As a result,
loading a DDPG model trained with stable_baselines<2.6.0 raises an import error.
You can fix that using:
Expand Down
51 changes: 36 additions & 15 deletions stable_baselines/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,11 @@
import warnings
from collections import defaultdict

import tensorflow as tf
from tensorflow.python import pywrap_tensorflow
from tensorflow.core.util import event_pb2
from tensorflow.python.util import compat

DEBUG = 10
INFO = 20
WARN = 30
Expand Down Expand Up @@ -124,8 +129,12 @@ def __init__(self, filename):
def writekvs(self, kvs):
    """
    Write a dictionary of key/values to the file as a single JSON line, then flush.

    Values exposing a ``dtype`` attribute (numpy scalars and arrays) are converted
    to JSON-friendly Python objects first: a value holding exactly one element is
    written as a float, any other array is written as a list (or nested lists).
    All other values are handed to ``json.dumps`` unchanged.

    :param kvs: (dict) the key/values to log; the dictionary itself is not modified
    """
    serializable = {}
    for key, value in kvs.items():
        if hasattr(value, 'dtype'):
            if value.shape == () or value.size == 1:
                # Exactly one element, whatever the number of dimensions:
                # serialize as a float. ``item()`` extracts the Python scalar, so
                # float() is never applied to a multi-dimensional array, and an
                # array of shape (1, N) is no longer mistaken for a scalar.
                value = float(value.item())
            else:
                # Several (or zero) elements: serialize as a list or nested lists.
                value = value.tolist()
        serializable[key] = value
    self.file.write(json.dumps(serializable) + '\n')
    self.file.flush()

Expand Down Expand Up @@ -180,6 +189,29 @@ def close(self):
self.file.close()


def summary_val(key, value):
    """
    Build a ``tf.Summary.Value`` holding a single scalar.

    :param key: (str) used as the ``tag`` of the summary value
    :param value: (float) cast to float and stored as ``simple_value``
    :return: (tf.Summary.Value)
    """
    scalar = float(value)
    return tf.Summary.Value(tag=key, simple_value=scalar)


def valid_float_value(value):
    """
    Returns True if the value can be successfully cast into a float

    :param value: (Any) the value to check
    :return: (bool)
    """
    try:
        float(value)
    except (TypeError, ValueError, OverflowError):
        # TypeError: unsupported type (e.g. list, None, multi-element array).
        # ValueError: a string that does not spell a number (e.g. "abc").
        # OverflowError: an int too large to be represented as a float.
        return False
    return True


class TensorBoardOutputFormat(KVWriter):
def __init__(self, folder):
"""
Expand All @@ -192,22 +224,11 @@ def __init__(self, folder):
self.step = 1
prefix = 'events'
path = os.path.join(os.path.abspath(folder), prefix)
import tensorflow as tf
from tensorflow.python import pywrap_tensorflow
from tensorflow.core.util import event_pb2
from tensorflow.python.util import compat
self._tf = tf
self.event_pb2 = event_pb2
self.pywrap_tensorflow = pywrap_tensorflow
self.writer = pywrap_tensorflow.EventsWriter(compat.as_bytes(path))

def writekvs(self, kvs):
def summary_val(key, value):
kwargs = {'tag': key, 'simple_value': float(value)}
return self._tf.Summary.Value(**kwargs)

summary = self._tf.Summary(value=[summary_val(k, v) for k, v in kvs.items()])
event = self.event_pb2.Event(wall_time=time.time(), summary=summary)
summary = tf.Summary(value=[summary_val(k, v) for k, v in kvs.items() if valid_float_value(v)])
event = event_pb2.Event(wall_time=time.time(), summary=summary)
event.step = self.step # is there any reason why you'd want to specify the step?
self.writer.WriteEvent(event)
self.writer.Flush()
Expand Down
11 changes: 10 additions & 1 deletion tests/test_logger.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,18 @@
import pytest
import numpy as np

from stable_baselines.logger import make_output_format, read_tb, read_csv, read_json, _demo


# Values written through every output format in the tests: plain scalars,
# a Python list, a 1-D array, a 0-d array and a nested single-element array.
KEY_VALUES = {
    "test": 1,
    "b": -3.14,
    "8": 9.9,
    "l": [1, 2],
    "a": np.array([1, 2, 3]),
    "f": np.array(1),
    "g": np.array([[[1]]]),
}
LOG_DIR = '/tmp/openai_baselines/'


Expand Down

0 comments on commit dc31d83

Please sign in to comment.