Skip to content
This repository has been archived by the owner on Dec 15, 2020. It is now read-only.

Commit

Permalink
Fix all missing mypy linting warnings
Browse files Browse the repository at this point in the history
  • Loading branch information
Jackevansevo committed Apr 28, 2017
1 parent b74ada7 commit dd98e77
Show file tree
Hide file tree
Showing 13 changed files with 118 additions and 117 deletions.
1 change: 1 addition & 0 deletions .flake8
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
[flake8]
1 change: 0 additions & 1 deletion .travis.yml
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
language: python
python:
- "3.4"
- "3.5"
- "3.5-dev" # 3.5 development branch
- "3.6"
Expand Down
18 changes: 8 additions & 10 deletions basic_utils/core.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
import os
from collections import defaultdict
from functools import partial, reduce
from functools import reduce
from itertools import chain
from operator import attrgetter
from typing import List

from basic_utils.primitives import sentinel
from typing import Any, List, Sequence, Tuple

__all__ = [
'clear', 'getattrs', 'map_getattr', 'recursive_default_dict', 'rgetattr',
Expand Down Expand Up @@ -39,29 +37,29 @@ def to_string(objects: List[object], sep: str = ", ") -> str:
return sep.join(map(str, objects))


def getattrs(obj, keys):
def getattrs(obj: object, keys: Sequence[str]) -> Tuple[Any, ...]:
"""Supports getting multiple attributes from a model at once"""
return tuple(map(partial(getattr, obj), keys))
return tuple(getattr(obj, key) for key in keys)


def map_getattr(attr, object_seq):
def map_getattr(attr: str, object_seq: Sequence[object]) -> Tuple[Any, ...]:
"""
Returns a map to retrieve a single attribute from a sequence of objects
"""
return tuple(map(attrgetter(attr), object_seq))


def recursive_default_dict():
def recursive_default_dict() -> defaultdict:
"""Returns a default dict that points to itself"""
return defaultdict(recursive_default_dict)


def rgetattr(obj: object, attrs: str, default=sentinel):
def rgetattr(obj: object, attrs: str) -> Any:
"""Get a nested attribute within an object"""
return reduce(getattr, chain([obj], attrs.split('.')))


def rsetattr(obj, attr, val):
def rsetattr(obj: object, attr: str, val: Any) -> None:
"""Sets a nested attribute within an object"""
pre, _, post = attr.rpartition('.')
return setattr(rgetattr(obj, pre) if pre else obj, post, val)
43 changes: 23 additions & 20 deletions basic_utils/datastructures.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import weakref

from collections import namedtuple
from typing import Any, Callable, Iterable, List, Optional

Branch = namedtuple('Branch', 'obj, value')

Expand All @@ -17,18 +17,18 @@ class DepthFirstIterator(object):
Derived from Python Cookbook 3rd Edition by David Beazley
"""

def __init__(self, start_node):
def __init__(self, start_node: Any) -> None:
self._node = start_node
self._children_iter = None
self._child_iter = None

def __iter__(self):
def __iter__(self) -> Any:
return self

def __next__(self):
def __next__(self) -> Any:
# Return myself if just started; create an iterator for children
if self._children_iter is None:
self._children_iter = iter(self._node)
self._children_iter = iter(self._node) # type: ignore
return self._node
# If processing a child, return its next item,
elif self._child_iter:
Expand All @@ -44,10 +44,10 @@ def __next__(self):
return next(self)


def format_tree(node, prefix='', key=None):
def format_tree(node: 'Node', prefix: str='', key: Callable=None) -> Iterable:
children = node.children

def make_branch(s, v):
def make_branch(s: str, v: str) -> str:
return ''.join([prefix, s, SIDE * 2, ' ', v])

if children:
Expand All @@ -68,40 +68,43 @@ class Node:
Derived from Python Cookbook 3rd Edition by David Beazley
"""

def __init__(self, value):
def __init__(self, value: Any) -> None:
self.value = value
self._parent = None
self._children = []
self._children = [] # type: List[Node]

def __repr__(self):
def __repr__(self) -> str:
return 'Node({!r:})'.format(self.value)

@property
def children(self):
def children(self) -> List['Node']:
return self._children

@property
def parent(self):
return self._parent if self._parent is None else self._parent()
def parent(self) -> Optional['Node']:
if self._parent is None:
return self._parent
else:
return self._parent()

@parent.setter
def parent(self, node):
self._parent = weakref.ref(node)
def parent(self, node: 'Node') -> None:
self._parent = weakref.ref(node) # type: ignore

def add_child(self, child):
def add_child(self, child: 'Node') -> None:
self._children.append(child)
child.parent = self

def add_children(self, children):
def add_children(self, children: Iterable['Node']) -> None:
for child in children:
self._children.append(child)
child.parent = self

def __iter__(self):
def __iter__(self) -> Iterable['Node']:
return iter(self._children)

def build_tree(self):
def build_tree(self) -> str:
return "\n".join([self.value, "\n".join(format_tree(self))])

def depth_first(self):
def depth_first(self) -> 'DepthFirstIterator':
return DepthFirstIterator(self)
6 changes: 2 additions & 4 deletions basic_utils/dates.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,14 @@
from datetime import datetime, timedelta
from itertools import starmap
from typing import Iterator, NamedTuple, NewType, Tuple
from typing import Iterator, NamedTuple, Tuple

__all__ = ['dates_between', 'date_ranges_overlap']


DatePair = NewType('DatePair', Tuple[datetime, datetime])
DatePair = Tuple[datetime, datetime]

DateRange = NamedTuple('DateRange', [('start', datetime), ('end', datetime)])

# [TODO] Document functions


def dates_between(start: datetime, end: datetime) -> Iterator[datetime]:
"""Returns lazy sequence of dates between a start/end point"""
Expand Down
13 changes: 7 additions & 6 deletions basic_utils/helpers.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from functools import reduce
from itertools import chain
from operator import getitem
from typing import Any, Callable, Iterable, Sequence
from typing import Any, Callable, Iterable, Sequence, Tuple

__all__ = [
'butlast', 'concat', 'cons', 'dedupe', 'dict_subset', 'first', 'flatten',
Expand Down Expand Up @@ -104,7 +104,7 @@ def partial_flatten(seq: Iterable) -> Iterable:
return chain.from_iterable(seq)


def dedupe(seq: Sequence, key=None):
def dedupe(seq: Sequence, key: Callable=None) -> Iterable:
"""
Removes duplicates from a sequence while maintaining order
Expand All @@ -119,7 +119,7 @@ def dedupe(seq: Sequence, key=None):
seen.add(val)


def get_keys(obj, keys, default=None):
def get_keys(d: dict, keys: Sequence[str], default: Callable=None) -> Tuple:
"""
Returns multiple values for keys in a dictionary
Expand All @@ -129,10 +129,11 @@ def get_keys(obj, keys, default=None):
>>> get_keys(d, ('x', 'y', 'z'))
(24, 25, None)
"""
return tuple(obj.get(key, default) for key in keys)
return tuple(d.get(key, default) for key in keys)


def dict_subset(d: dict, keys, prune=False, default=None):
def dict_subset(d, keys, prune=False, default=None):
# type: (dict, Sequence[str], bool, Callable) -> dict
"""
Returns a new dictionary with a subset of key value pairs from the original
Expand All @@ -157,7 +158,7 @@ def get_in_dict(d: dict, keys: Sequence[str]) -> Any:
return reduce(getitem, keys, d)


def set_in_dict(d: dict, keys: Sequence[str], value) -> None:
def set_in_dict(d: dict, keys: Sequence[str], value: Any) -> None:
"""
Sets a value inside a nested dictionary
Expand Down
2 changes: 1 addition & 1 deletion basic_utils/primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def identity(x: Any) -> Any:
return x


def comp(*funcs):
def comp(*funcs: Callable) -> Callable:
"""
Takes a set of functions and returns a fn that is the composition
of those functions
Expand Down
3 changes: 3 additions & 0 deletions mypy.ini
Original file line number Diff line number Diff line change
@@ -1,2 +1,5 @@
[mypy]
ignore-missing-imports=True
strict_optional=True
disallow_untyped_defs=True
check_untyped_defs=True
2 changes: 1 addition & 1 deletion scripts/lint
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,5 @@ set -x

${PREFIX}flake8 basic_utils tests
${PREFIX}isort basic_utils tests --recursive --check-only
${PREFIX}mypy basic_utils tests --ignore-missing-imports
${PREFIX}mypy basic_utils tests
${PREFIX}python -m doctest basic_utils/[a-zA-Z]*.py
50 changes: 22 additions & 28 deletions tests/test_core.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from unittest.mock import MagicMock, mock_open, patch
from unittest.mock import MagicMock, Mock, mock_open, patch

import pytest # type: ignore

Expand All @@ -14,7 +14,7 @@
)


def test_slurp():
def test_slurp() -> None:
"""Tests that slurp reads in contents of a file as a string"""
data = "In the face of ambiguity, refuse the temptation to guess."
with patch("builtins.open", mock_open(read_data=data)) as mock_file:
Expand All @@ -27,7 +27,7 @@ def test_slurp():
('posix', 'clear'),
('nt', 'cls')
])
def test_clear(platform, expected):
def test_clear(platform: str, expected: str) -> None:
"""
Tests that os.system is called with the correct string corresponding to
the host OS name
Expand All @@ -38,32 +38,27 @@ def test_clear(platform, expected):
mock_os.system.assert_called_once_with(expected)


def test_to_string():
def test_to_string() -> None:
# Create two mock class instances which implement __str__
objectX, objectY = MagicMock(), MagicMock()
objectX.__str__.return_value = "Homer"
objectY.__str__.return_value = "Bart"
objectX = MagicMock(**{'__str__.return_value': 'Homer'})
objectY = MagicMock(**{'__str__.return_value': 'Bart'})
assert to_string([objectX, objectY]) == "Homer, Bart"
assert to_string([1, 2, 3]) == "1, 2, 3"


def test_getattrs():
# Create two mock class instances with a sample attribute
mock_obj = MagicMock()
mock_obj.name = 'Homer'
mock_obj.age = 39
assert getattrs(mock_obj, ('name', 'age')) == ('Homer', 39)
def test_getattrs() -> None:
# Create a single mock class instance with two sample attributes
mock_obj = Mock(forename='Homer', age=39)
assert getattrs(mock_obj, ('forename', 'age')) == ('Homer', 39)


def test_map_getattr():
def test_map_getattr() -> None:
# Create two mock class instances with a sample attribute
objectX, objectY = MagicMock(), MagicMock()
objectX.name = 'Homer'
objectY.name = 'Bart'
assert map_getattr('name', (objectX, objectY)) == ('Homer', 'Bart')
objectX, objectY = Mock(forename='Homer'), Mock(forename='Bart')
assert map_getattr('forename', (objectX, objectY)) == ('Homer', 'Bart')


def test_recursive_default_dict():
def test_recursive_default_dict() -> None:
"""
Tests that recursive data structure points to itself
"""
Expand All @@ -74,20 +69,19 @@ def test_recursive_default_dict():
class TestRecursiveGettersAndSetters:

@classmethod
def setup_class(cls):
cls.homer, cls.child = MagicMock(), MagicMock()
cls.homer.child = cls.child
cls.child.name = 'Bart'
def setup_class(cls) -> None:
cls.child = MagicMock(forename='Bart') # type: ignore
cls.homer = MagicMock(child=cls.child) # type: ignore

def test_rgetattr(self):
def test_rgetattr(self) -> None:
"""
Tests that rgetattr returns returns nested values within objects
"""
assert rgetattr(self.homer, 'child.name') == 'Bart'
assert rgetattr(self.homer, 'child.forename') == 'Bart' # type: ignore

def test_rsetattr(self):
def test_rsetattr(self) -> None:
"""
Tests that rsetattr sets the value of a nested attribute
"""
rsetattr(self.homer, 'child.name', 'Lisa')
assert self.child.name == 'Lisa'
rsetattr(self.homer, 'child.name', 'Lisa') # type: ignore
assert self.child.name == 'Lisa' # type: ignore
4 changes: 2 additions & 2 deletions tests/test_dates.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,14 @@
from basic_utils.dates import date_ranges_overlap, dates_between


def test_dates_between():
def test_dates_between() -> None:
start = datetime(2016, 1, 1)
end = datetime(2016, 1, 3)
expected = starmap(datetime, ((2016, 1, 1), (2016, 1, 2), (2016, 1, 3)))
assert tuple(dates_between(start, end)) == tuple(expected)


def test_ranges_overlap():
def test_ranges_overlap() -> None:
rangeX = (datetime(2012, 1, 15), datetime(2012, 5, 10))
rangeY = (datetime(2012, 3, 20), datetime(2012, 9, 15))
assert date_ranges_overlap(rangeX, rangeY)
Expand Down

0 comments on commit dd98e77

Please sign in to comment.