Skip to content

Commit

Permalink
Emit warning when reserved fields specified
Browse files Browse the repository at this point in the history
  • Loading branch information
avylove committed Jul 25, 2020
1 parent 7431b33 commit 27cdf89
Show file tree
Hide file tree
Showing 10 changed files with 120 additions and 35 deletions.
5 changes: 3 additions & 2 deletions enlighten/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,9 @@

from enlighten.counter import Counter, StatusBar, SubCounter
from enlighten._manager import Manager, get_manager
from enlighten._util import Justify
from enlighten._util import EnlightenWarning, Justify


__version__ = '1.6.0'
__all__ = ('Counter', 'Justify', 'Manager', 'StatusBar', 'SubCounter', 'get_manager')
__all__ = ('Counter', 'EnlightenWarning', 'Justify', 'Manager',
'StatusBar', 'SubCounter', 'get_manager')
25 changes: 15 additions & 10 deletions enlighten/_basecounter.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,16 +43,19 @@ def __repr__(self):

return '%s(%s)' % (self.__class__.__name__, ', '.join(params))

def __init__(self, **kwargs):
def __init__(self, keywords=None, **kwargs):

self.count = self.start_count = kwargs.get('count', 0)
if keywords is not None:
kwargs = keywords

self.count = self.start_count = kwargs.pop('count', 0)
self._color = None

self.manager = kwargs.get('manager', None)
self.manager = kwargs.pop('manager', None)
if self.manager is None:
raise TypeError('manager must be specified')

self.color = kwargs.get('color', None)
self.color = kwargs.pop('color', None)

@property
def color(self):
Expand Down Expand Up @@ -131,15 +134,17 @@ class PrintableCounter(BaseCounter):

__slots__ = ('enabled', '_fill', 'last_update', 'leave', 'min_delta', '_pinned', 'start')

def __init__(self, **kwargs):
def __init__(self, keywords=None, **kwargs):

super(PrintableCounter, self).__init__(**kwargs)
if keywords is not None: # pragma: no branch
kwargs = keywords
super(PrintableCounter, self).__init__(keywords=kwargs)

self.enabled = kwargs.get('enabled', True)
self.enabled = kwargs.pop('enabled', True)
self._fill = u' '
self.fill = kwargs.get('fill', u' ')
self.leave = kwargs.get('leave', True)
self.min_delta = kwargs.get('min_delta', 0.1)
self.fill = kwargs.pop('fill', u' ')
self.leave = kwargs.pop('leave', True)
self.min_delta = kwargs.pop('min_delta', 0.1)
self._pinned = False
self.last_update = self.start = time.time()

Expand Down
23 changes: 21 additions & 2 deletions enlighten/_counter.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,13 @@
"""

import platform
import re
import sys
import time
import warnings

from enlighten._basecounter import BaseCounter, PrintableCounter
from enlighten._util import format_time
from enlighten._util import EnlightenWarning, format_time

COUNTER_FMT = u'{desc}{desc_pad}{count:d} {unit}{unit_pad}' + \
u'[{elapsed}, {rate:.2f}{unit_pad}{unit}/s]{fill}'
Expand All @@ -40,6 +42,11 @@
except (AttributeError, TypeError): # pragma: no cover(Non-standard Terminal)
pass

# Reserved fields
COUNTER_FIELDS = {'count', 'desc', 'desc_pad', 'elapsed', 'rate', 'unit', 'unit_pad',
'bar', 'eta', 'len_total', 'percentage', 'total', 'fill'}
RE_SUBCOUNTER_FIELDS = re.compile(r'(?:count|percentage|eta|rate)_\d+')


class SubCounter(BaseCounter):
"""
Expand Down Expand Up @@ -391,7 +398,7 @@ class can be called directly. The output stream will default to :py:data:`sys.st
# pylint: disable=too-many-arguments
def __init__(self, **kwargs):

super(Counter, self).__init__(**kwargs)
super(Counter, self).__init__(keywords=kwargs)

# Accept additional_fields for backwards compatibility
self.fields = kwargs.pop('fields', kwargs.pop('additional_fields', {}))
Expand Down Expand Up @@ -502,6 +509,18 @@ def format(self, width=None, elapsed=None):

fields = self.fields.copy()
fields.update(self._fields)

# Warn on reserved fields
reserved_fields = (set(fields) & COUNTER_FIELDS) | set(
match.group() for match in (RE_SUBCOUNTER_FIELDS.match(key) for key in fields) if match
)
if reserved_fields:
warnings.warn(
'Ignoring reserved fields specified as user-defined fields: %s' %
', '.join(reserved_fields),
EnlightenWarning, stacklevel=2
)

fields.update({'bar': u'{0}',
'count': self.count,
'desc': self.desc or u'',
Expand Down
3 changes: 2 additions & 1 deletion enlighten/_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,7 @@ def _add_counter(self, counter_class, *args, **kwargs):
"""

position = kwargs.pop('position', None)
autorefresh = kwargs.pop('autorefresh', False)

# List of counters to refresh due to new position
toRefresh = []
Expand All @@ -211,7 +212,7 @@ def _add_counter(self, counter_class, *args, **kwargs):

# Create counter
new = counter_class(*args, **kwargs)
if kwargs.pop('autorefresh', False):
if autorefresh:
self.autorefresh.append(new)

# Get pinned counters
Expand Down
19 changes: 17 additions & 2 deletions enlighten/_statusbar.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,13 @@
"""

import time
import warnings

from enlighten._basecounter import PrintableCounter
from enlighten._util import format_time, Justify
from enlighten._util import EnlightenWarning, format_time, Justify


STATUS_FIELDS = {'elapsed', 'fill'}


class StatusBar(PrintableCounter):
Expand Down Expand Up @@ -132,7 +136,8 @@ class StatusBar(PrintableCounter):
__slots__ = ('fields', '_justify', 'status_format', '_static', '_fields')

def __init__(self, *args, **kwargs):
super(StatusBar, self).__init__(**kwargs)

super(StatusBar, self).__init__(keywords=kwargs)

self.fields = kwargs.pop('fields', {})
self._justify = None
Expand Down Expand Up @@ -185,6 +190,16 @@ def format(self, width=None, elapsed=None):
else:
fields = self.fields.copy()
fields.update(self._fields)

# Warn on reserved fields
reserved_fields = (set(fields) & STATUS_FIELDS)
if reserved_fields:
warnings.warn(
'Ignoring reserved fields specified as user-defined fields: %s' %
', '.join(reserved_fields),
EnlightenWarning, stacklevel=2
)

elapsed = elapsed if elapsed is not None else self.elapsed
fields['elapsed'] = format_time(elapsed)
fields['fill'] = u'{0}'
Expand Down
6 changes: 6 additions & 0 deletions enlighten/_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,12 @@
BASESTRING = str


class EnlightenWarning(Warning):
"""
Generic warning class for Enlighten
"""


def format_time(seconds):
"""
Args:
Expand Down
5 changes: 3 additions & 2 deletions enlighten/counter.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,10 @@ class Counter(_Counter): # pylint: disable=missing-docstring
def __init__(self, **kwargs):

manager = kwargs.get('manager', None)
stream = kwargs.pop('stream', sys.stdout)

if manager is None:
manager = get_manager(stream=kwargs.get('stream', sys.stdout),
counter_class=self.__class__, set_scroll=False)
manager = get_manager(stream=stream, counter_class=self.__class__, set_scroll=False)
manager.counters[self] = 1
kwargs['manager'] = manager

Expand Down
2 changes: 2 additions & 0 deletions tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,10 @@

if sys.version_info[0] < 3:
from StringIO import StringIO
PY2 = True
else:
from io import StringIO
PY2 = False

# pylint: enable=import-error

Expand Down
48 changes: 35 additions & 13 deletions tests/test_counter.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,11 @@

import time

from enlighten import Counter, Manager
from enlighten import Counter, EnlightenWarning, Manager
import enlighten._counter
from enlighten._manager import NEEDS_UNICODE_HELP

from tests import TestCase, mock, MockManager, MockTTY, MockCounter
from tests import TestCase, mock, MockManager, MockTTY, MockCounter, PY2, unittest


# pylint: disable=missing-docstring, protected-access, too-many-public-methods
Expand Down Expand Up @@ -396,15 +396,15 @@ def test_auto_offset(self):
blueBarFormat = self.manager.term.blue(barFormat)
self.assertNotEqual(len(barFormat), len(blueBarFormat))

ctr = self.manager.counter(stream=self.tty.stdout, total=10, desc='Test',
unit='ticks', count=10, bar_format=barFormat)
ctr = self.manager.counter(total=10, desc='Test', unit='ticks',
count=10, bar_format=barFormat)
formatted1 = ctr.format(width=80)
self.assertEqual(len(formatted1), 80)
barLen1 = formatted1.count(BLOCK)

offset = len(self.manager.term.blue(''))
ctr = self.manager.counter(stream=self.tty.stdout, total=10, desc='Test',
unit='ticks', count=10, bar_format=blueBarFormat)
ctr = self.manager.counter(total=10, desc='Test', unit='ticks',
count=10, bar_format=blueBarFormat)
formatted2 = ctr.format(width=80)
self.assertEqual(len(formatted2), 80 + offset)
barLen2 = formatted2.count(BLOCK)
Expand All @@ -420,27 +420,27 @@ def test_offset(self):
u'[{elapsed}<{eta}, {rate:.2f}{unit_pad}{unit}/s]'
barFormat = self.manager.term.blue(barFormat)

ctr = self.manager.counter(stream=self.tty.stdout, total=10, desc='Test',
unit='ticks', count=10, bar_format=barFormat, offset=0)
ctr = self.manager.counter(total=10, desc='Test', unit='ticks',
count=10, bar_format=barFormat, offset=0)
formatted1 = ctr.format(width=80)
self.assertEqual(len(formatted1), 80)
barLen1 = formatted1.count(BLOCK)

offset = len(self.manager.term.blue(''))
ctr = self.manager.counter(stream=self.tty.stdout, total=10, desc='Test',
unit='ticks', count=10, bar_format=barFormat, offset=offset)
ctr = self.manager.counter(total=10, desc='Test', unit='ticks',
count=10, bar_format=barFormat, offset=offset)
formatted2 = ctr.format(width=80)
self.assertEqual(len(formatted2), 80 + offset)
barLen2 = formatted2.count(BLOCK)

self.assertTrue(barLen2 == barLen1 + offset)

# Test in counter format
ctr = self.manager.counter(stream=self.tty.stdout, total=10, count=50, offset=0)
ctr = self.manager.counter(total=10, count=50, offset=0)
formatted = ctr.format(width=80)
self.assertEqual(len(formatted), 80)

ctr = self.manager.counter(stream=self.tty.stdout, total=10, count=50, offset=10)
ctr = self.manager.counter(total=10, count=50, offset=10)
formatted = ctr.format(width=80)
self.assertEqual(len(formatted), 90)

Expand Down Expand Up @@ -692,7 +692,7 @@ def test_additional_fields_no_overwrite(self):
"""

bar_format = ctr_format = u'{arg1:s} {count:d}'
additional_fields = {'arg1': 'hello', 'count': 100000}
additional_fields = {'arg1': 'hello'}

ctr = Counter(stream=self.tty.stdout, total=10, count=1, bar_format=bar_format,
fields=additional_fields)
Expand Down Expand Up @@ -762,3 +762,25 @@ def test_fill(self):
ctr_format = u'{fill}HI{fill}'
ctr = Counter(stream=self.tty.stdout, count=1, counter_format=ctr_format, fill=u'-')
self.assertEqual(ctr.format(), u'-' * 39 + 'HI' + u'-' * 39)

@unittest.skipIf(PY2, 'Skip warnings tests in Python 2')
def test_reserve_fields(self):
"""
When reserved fields are used, a warning is raised
"""

ctr = Counter(stream=self.tty.stdout, total=10, count=1, fields={'elapsed': 'reserved'})
with self.assertWarnsRegex(EnlightenWarning, 'Ignoring reserved fields'):
ctr.format()

ctr = Counter(stream=self.tty.stdout, total=10, fields={'elapsed': 'reserved'})
with self.assertWarnsRegex(EnlightenWarning, 'Ignoring reserved fields'):
ctr.format()

ctr = Counter(stream=self.tty.stdout, total=10, count=1, elapsed='reserved')
with self.assertWarnsRegex(EnlightenWarning, 'Ignoring reserved fields'):
ctr.format()

ctr = Counter(stream=self.tty.stdout, total=10, elapsed='reserved')
with self.assertWarnsRegex(EnlightenWarning, 'Ignoring reserved fields'):
ctr.format()
19 changes: 16 additions & 3 deletions tests/test_statusbar.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@
Test module for enlighten._statusbar
"""

from enlighten import Justify
from enlighten import EnlightenWarning, Justify

from tests import TestCase, MockManager, MockTTY, MockStatusBar
from tests import TestCase, MockManager, MockTTY, MockStatusBar, PY2, unittest


class TestStatusBar(TestCase):
Expand Down Expand Up @@ -145,9 +145,22 @@ def test_fill_uneven(self):
Extra fill should be equal
"""

print(self.manager.term.width)
sbar = self.manager.status_bar(
status_format=u'{fill}Helloooo!{fill}Woooorld!{fill}', fill='-'
)
self.assertEqual(sbar.format(),
u'-' * 20 + 'Helloooo!' + u'-' * 21 + 'Woooorld!' + u'-' * 21)

@unittest.skipIf(PY2, 'Skip warnings tests in Python 2')
def test_reserve_fields(self):
"""
When reserved fields are used, a warning is raised
"""

with self.assertWarnsRegex(EnlightenWarning, 'Ignoring reserved fields'):
self.manager.status_bar(status_format=u'Stage: {stage}, Fill: {fill}', stage=1,
fields={'fill': 'Reserved field'})

with self.assertWarnsRegex(EnlightenWarning, 'Ignoring reserved fields'):
self.manager.status_bar(status_format=u'Stage: {stage}, elapsed: {elapsed}', stage=1,
elapsed='Reserved field')

0 comments on commit 27cdf89

Please sign in to comment.