Skip to content

Commit

Permalink
Relax logger check for Windows (#1615)
Browse files Browse the repository at this point in the history
* Relax logger check for Windows

* Update tests
  • Loading branch information
araffin committed Jul 21, 2023
1 parent 61e1060 commit a730b9b
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 6 deletions.
3 changes: 2 additions & 1 deletion docs/misc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ Changelog
==========


Release 2.1.0a1 (WIP)
Release 2.1.0a2 (WIP)
--------------------------

Breaking Changes:
Expand All @@ -25,6 +25,7 @@ New Features:

Bug Fixes:
^^^^^^^^^^
- Relaxed check in logger, that was causing issue on Windows with colorama

Deprecations:
^^^^^^^^^^^^^
Expand Down
6 changes: 4 additions & 2 deletions stable_baselines3/common/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,8 +164,10 @@ def __init__(self, filename_or_file: Union[str, TextIO], max_length: int = 36):
if isinstance(filename_or_file, str):
self.file = open(filename_or_file, "w")
self.own_file = True
elif isinstance(filename_or_file, TextIOBase):
self.file = filename_or_file
elif isinstance(filename_or_file, TextIOBase) or hasattr(filename_or_file, "write"):
# Note: in theory `TextIOBase` check should be sufficient,
# in practice, libraries don't always inherit from it, see GH#1598
self.file = filename_or_file # type: ignore[assignment]
self.own_file = False
else:
raise ValueError(f"Expected file or str, got {filename_or_file}")
Expand Down
2 changes: 1 addition & 1 deletion stable_baselines3/version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
2.1.0a1
2.1.0a2
5 changes: 3 additions & 2 deletions tests/test_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -437,8 +437,9 @@ def test_ep_buffers_stats_window_size(algo, stats_window_size):
assert model.ep_success_buffer.maxlen == stats_window_size


def test_human_output_format_custom_test_io():
class DummyTextIO(TextIOBase):
@pytest.mark.parametrize("base_class", [object, TextIOBase])
def test_human_output_format_custom_test_io(base_class):
class DummyTextIO(base_class):
def __init__(self) -> None:
super().__init__()
self.lines = [[]]
Expand Down

0 comments on commit a730b9b

Please sign in to comment.