Skip to content

Commit

Permalink
Fix Colab logger error (#1484)
Browse files Browse the repository at this point in the history
* fix HumanOutputFormat

* update version

* update changelog

* TextIO annotation, TextIOBase isinstance

* update changelog

* test for HumanOutputFormat with custom TextIO

* rm extra test line

* Update tests/test_logger.py

Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>

---------

Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
  • Loading branch information
qgallouedec and araffin committed May 5, 2023
1 parent 63a0bb9 commit 9cebedc
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 4 deletions.
3 changes: 2 additions & 1 deletion docs/misc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
Changelog
==========

Release 2.0.0a7 (WIP)
Release 2.0.0a8 (WIP)
--------------------------

**Gymnasium support**
Expand All @@ -21,6 +21,7 @@ Breaking Changes:
- Removed deprecated ``stack_observation_space`` method of ``StackedObservations``
- Renamed environment output observations in ``evaluate_policy`` to prevent shadowing the input observations during callbacks (@npit)
- Upgraded wrappers and custom environment to Gymnasium
- Refined the ``HumanOutputFormat`` file check: now it verifies if the object is an instance of ``io.TextIOBase`` instead of only checking for the presence of a ``write`` method.

New Features:
^^^^^^^^^^^^^
Expand Down
4 changes: 2 additions & 2 deletions stable_baselines3/common/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import tempfile
import warnings
from collections import defaultdict
from io import TextIOWrapper
from io import TextIOBase
from typing import Any, Dict, List, Mapping, Optional, Sequence, TextIO, Tuple, Union

import numpy as np
Expand Down Expand Up @@ -164,7 +164,7 @@ def __init__(self, filename_or_file: Union[str, TextIO], max_length: int = 36):
if isinstance(filename_or_file, str):
self.file = open(filename_or_file, "w")
self.own_file = True
elif isinstance(filename_or_file, TextIOWrapper): # equivalent to `isinstance(..., TextIO)` (not supported)
elif isinstance(filename_or_file, TextIOBase):
self.file = filename_or_file
self.own_file = False
else:
Expand Down
2 changes: 1 addition & 1 deletion stable_baselines3/version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
2.0.0a7
2.0.0a8
39 changes: 39 additions & 0 deletions tests/test_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import os
import sys
import time
from io import TextIOBase
from typing import Sequence
from unittest import mock

Expand Down Expand Up @@ -434,3 +435,41 @@ def test_ep_buffers_stats_window_size(algo, stats_window_size):
model.learn(total_timesteps=10)
assert model.ep_info_buffer.maxlen == stats_window_size
assert model.ep_success_buffer.maxlen == stats_window_size


def test_human_output_format_custom_test_io():
class DummyTextIO(TextIOBase):
def __init__(self) -> None:
super().__init__()
self.lines = [[]]

def write(self, t: str) -> int:
self.lines[-1].append(t)

def flush(self) -> None:
self.lines.append([])

def close(self) -> None:
pass

def get_printed(self) -> str:
return "\n".join(["".join(line) for line in self.lines])

dummy_text_io = DummyTextIO()
output = HumanOutputFormat(dummy_text_io)
output.write({"key1": "value1", "key2": 42}, {"key1": None, "key2": None})
output.write({"key1": "value2", "key2": 43}, {"key1": None, "key2": None})
printed = dummy_text_io.get_printed()
desired_printed = """-----------------
| key1 | value1 |
| key2 | 42 |
-----------------
-----------------
| key1 | value2 |
| key2 | 43 |
-----------------
"""

assert printed == desired_printed

0 comments on commit 9cebedc

Please sign in to comment.