Skip to content

Commit

Permalink
ruff: Isort and more (#143)
Browse files Browse the repository at this point in the history
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
Borda and pre-commit-ci[bot] authored Jun 28, 2023
1 parent 086de04 commit de1efa9
Show file tree
Hide file tree
Showing 14 changed files with 75 additions and 48 deletions.
14 changes: 4 additions & 10 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,6 @@ repos:
args: [--py37-plus]
name: Upgrade code

- repo: https://github.com/PyCQA/isort
rev: 5.12.0
hooks:
- id: isort

- repo: https://github.com/psf/black
rev: 23.3.0
hooks:
Expand All @@ -54,15 +49,14 @@ repos:
hooks:
- id: yesqa
additional_dependencies:
- flake8-docstrings
- pep8-naming
- flake8-comprehensions
- flake8-pytest-style
- flake8-return
- flake8-simplify
- flake8-bandit
- flake8-builtins
- flake8-bugbear

- repo: https://github.com/charliermarsh/ruff-pre-commit
rev: v0.0.270
rev: v0.0.272
hooks:
- id: ruff
args: ["--fix"]
31 changes: 30 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -74,14 +74,28 @@ ignore_missing_imports = true
select = [
"E", "W", # see: https://pypi.org/project/pycodestyle
"F", # see: https://pypi.org/project/pyflakes
"I", #see: https://pypi.org/project/isort/
"D", # see: https://pypi.org/project/pydocstyle
"N", # see: https://pypi.org/project/pep8-naming
"S", # see: https://pypi.org/project/flake8-bandit
]
extend-select = [
"A", # see: https://pypi.org/project/flake8-builtins
"B", # see: https://pypi.org/project/flake8-bugbear
"C4", # see: https://pypi.org/project/flake8-comprehensions
"PT", # see: https://pypi.org/project/flake8-pytest-style
"RET", # see: https://pypi.org/project/flake8-return
"SIM", # see: https://pypi.org/project/flake8-simplify
"YTT", # see: https://pypi.org/project/flake8-2020
"ANN", # see: https://pypi.org/project/flake8-annotations
"TID", # see: https://pypi.org/project/flake8-tidy-imports/
"T10", # see: https://pypi.org/project/flake8-debugger
"Q", # see: https://pypi.org/project/flake8-quotes
"RUF", # Ruff-specific rules
"EXE", # see: https://pypi.org/project/flake8-executable
"ISC", # see: https://pypi.org/project/flake8-implicit-str-concat
"PIE", # see: https://pypi.org/project/flake8-pie
"PLE", # see: https://pypi.org/project/pylint/
]
ignore = [
"E731",
Expand All @@ -100,20 +114,35 @@ extend-select = [
ignore-init-module-imports = true

[tool.ruff.per-file-ignores]
"setup.py" = ["D100", "SIM115"]
"setup.py" = ["ANN202", "D100", "SIM115"]
"__about__.py" = ["D100"]
"__init__.py" = ["D100"]
"src/**" = [
"ANN101", # Missing type annotation for `self` in method
"ANN102", # Missing type annotation for `cls` in classmethod
"ANN401", # Dynamically typed expressions (typing.Any)
"B905", # `zip()` without an explicit `strict=` parameter
"D100", # Missing docstring in public module
"D107", # Missing docstring in `__init__`
]
"tests/**" = [
"ANN001", # Missing type annotation for function argument
"ANN101", # Missing type annotation for `self` in method
"ANN201", # Missing return type annotation for public function
"ANN202", # Missing return type annotation for private function
"ANN204", # Missing return type annotation for special method
"ANN401", # Dynamically typed expressions (typing.Any)
"B905", # `zip()` without an explicit `strict=` parameter
"D100", # Missing docstring in public module
"D101", # Missing docstring in public class
"D102", # Missing docstring in public method
"D103", # Missing docstring in public function
"D104", # Missing docstring in public package
"D105", # Missing docstring in magic method
"D107", # Missing docstring in `__init__`
"S101", # Use of `assert` detected
"S311", # Standard pseudo-random generators are not suitable for cryptographic purposes
"B028", # No explicit `stacklevel` keyword argument found
]

[tool.ruff.pydocstyle]
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
_PATH_REQUIRE = os.path.join(_PATH_ROOT, "requirements")


def _load_py_module(fname, pkg="lightning_utilities"):
def _load_py_module(fname: str, pkg: str = "lightning_utilities"):
spec = spec_from_file_location(os.path.join(pkg, fname), os.path.join(_PATH_SOURCE, pkg, fname))
py = module_from_spec(spec)
spec.loader.exec_module(py)
Expand Down
2 changes: 1 addition & 1 deletion src/lightning_utilities/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import os

from lightning_utilities.__about__ import * # noqa: F401, F403
from lightning_utilities.__about__ import * # noqa: F403
from lightning_utilities.core.apply_func import apply_to_collection
from lightning_utilities.core.enums import StrEnum
from lightning_utilities.core.imports import compare_version, module_available
Expand Down
5 changes: 3 additions & 2 deletions src/lightning_utilities/core/apply_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# http://www.apache.org/licenses/LICENSE-2.0
#
import dataclasses
from collections import defaultdict, OrderedDict
from collections import OrderedDict, defaultdict
from copy import deepcopy
from typing import Any, Callable, List, Mapping, Optional, Sequence, Tuple, Union

Expand Down Expand Up @@ -185,7 +185,8 @@ def apply_to_collections(
is_namedtuple_ = is_namedtuple(data1)
is_sequence = isinstance(data1, Sequence) and not isinstance(data1, str)
if (is_namedtuple_ or is_sequence) and data2 is not None:
assert len(data1) == len(data2), "Sequence collections have different sizes."
if len(data1) != len(data2):
raise ValueError("Sequence collections have different sizes.")
out = [
apply_to_collections(v1, v2, dtype, function, *args, wrong_dtype=wrong_dtype, **kwargs)
for v1, v2 in zip(data1, data2)
Expand Down
2 changes: 1 addition & 1 deletion src/lightning_utilities/core/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def try_from_str(cls, value: str, source: Literal["key", "value", "any"] = "key"
try:
return cls.from_str(value, source)
except ValueError:
warnings.warn(
warnings.warn( # noqa: B028
UserWarning(f"Invalid string: expected one of {cls._allowed_matches(source)}, but got {value}.")
)
return None
Expand Down
2 changes: 1 addition & 1 deletion src/lightning_utilities/install/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""Generic Installation tools."""

from lightning_utilities.install.requirements import load_requirements, Requirement
from lightning_utilities.install.requirements import Requirement, load_requirements

__all__ = ["load_requirements", "Requirement"]
9 changes: 6 additions & 3 deletions src/lightning_utilities/install/requirements.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ class _RequirementWithComment(Requirement):
def __init__(self, *args: Any, comment: str = "", pip_argument: Optional[str] = None, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
self.comment = comment
assert pip_argument is None or pip_argument # sanity check that it's not an empty str
if not (pip_argument is None or pip_argument): # sanity check that it's not an empty str
raise RuntimeError(f"wrong pip argument: {pip_argument}")
self.pip_argument = pip_argument
self.strict = self.strict_string in comment.lower()

Expand Down Expand Up @@ -109,8 +110,10 @@ def load_requirements(path_dir: str, file_name: str = "base.txt", unfreeze: str
>>> load_requirements(path_req, "docs.txt", unfreeze="major") # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
['sphinx<6.0,>=4.0', ...]
"""
assert unfreeze in {"none", "major", "all"}
if unfreeze not in {"none", "major", "all"}:
raise ValueError(f'unsupported option of "{unfreeze}"')
path = Path(path_dir) / file_name
assert path.exists(), (path_dir, file_name, path)
if not path.exists():
raise FileNotFoundError(f"missing file for {(path_dir, file_name, path)}")
text = path.read_text()
return [req.adjust(unfreeze) for req in _parse_requirements(text)]
31 changes: 17 additions & 14 deletions tests/unittests/core/test_apply_func.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
import dataclasses
import numbers
from collections import defaultdict, namedtuple, OrderedDict
from collections import OrderedDict, defaultdict, namedtuple
from dataclasses import InitVar
from typing import Any, ClassVar, List, Optional

import pytest
from lightning_utilities.core.apply_func import apply_to_collection, apply_to_collections
from unittests.mocks import torch

from lightning_utilities.core.apply_func import apply_to_collection, apply_to_collections
_TENSOR_0 = torch.tensor(0)
_TENSOR_1 = torch.tensor(1)


@dataclasses.dataclass
Expand All @@ -29,10 +31,10 @@ class ModelExample:
label: torch.Tensor
some_constant: int = dataclasses.field(init=False)

def __post_init__(self): # noqa: D105
def __post_init__(self):
self.some_constant = 7

def __eq__(self, o: object) -> bool: # noqa: D105
def __eq__(self, o: object) -> bool:
if not isinstance(o, ModelExample):
return NotImplemented

Expand Down Expand Up @@ -64,11 +66,11 @@ class WithInitVar:
dummy: Any
override: InitVar[Optional[Any]] = None

def __post_init__(self, override: Optional[Any]): # noqa: D105
def __post_init__(self, override: Optional[Any]):
if override is not None:
self.dummy = override

def __eq__(self, o: object) -> bool: # noqa: D105
def __eq__(self, o: object) -> bool:
if not isinstance(o, WithInitVar):
return NotImplemented
if isinstance(self.dummy, torch.Tensor):
Expand All @@ -79,15 +81,16 @@ def __eq__(self, o: object) -> bool: # noqa: D105

@dataclasses.dataclass
class WithClassAndInitVar:
class_var: ClassVar[torch.Tensor] = torch.tensor(0)
class_var: ClassVar[torch.Tensor] = _TENSOR_0
dummy: Any
override: InitVar[Optional[Any]] = torch.tensor(1)
override: InitVar[Optional[Any]] = _TENSOR_1

def __post_init__(self, override: Optional[Any]): # noqa: D105
def __post_init__(self, override: Optional[Any]):
if override is not None:
self.dummy = override

def __eq__(self, o: object) -> bool: # noqa: D105
def __eq__(self, o: object) -> bool:
"""Equal."""
if not isinstance(o, WithClassAndInitVar):
return NotImplemented
if isinstance(self.dummy, torch.Tensor):
Expand Down Expand Up @@ -206,7 +209,7 @@ def _assert_dataclass_reduction(actual, expected, dataclass_type: str = ""):

# custom mappings
class _CustomCollection(dict):
def __init__(self, initial_dict):
def __init__(self, initial_dict) -> None:
super().__init__(initial_dict)

to_reduce = _CustomCollection({"a": 1, "b": 2, "c": 3})
Expand Down Expand Up @@ -262,7 +265,7 @@ def fn(a, b):
assert reduced == [1, 2, 3, 4]

# different sizes
with pytest.raises(AssertionError, match="Sequence collections have different sizes"):
with pytest.raises(ValueError, match="Sequence collections have different sizes"):
apply_to_collections([[1, 2], [3]], [4], int, fn)

def fn(a, b):
Expand Down Expand Up @@ -323,7 +326,7 @@ def fn(a, b):
def test_apply_to_collection_frozen_dataclass():
@dataclasses.dataclass(frozen=True)
class Foo:
input: int
var: int

foo = Foo(0)
with pytest.raises(ValueError, match="frozen dataclass was passed"):
Expand All @@ -333,7 +336,7 @@ class Foo:
def test_apply_to_collection_allow_frozen_dataclass():
@dataclasses.dataclass(frozen=True)
class Foo:
input: int
var: int

foo = Foo(0)
result = apply_to_collection(foo, int, lambda x: x + 1, allow_frozen=True)
Expand Down
5 changes: 2 additions & 3 deletions tests/unittests/core/test_imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,13 @@
import re

import pytest

from lightning_utilities.core.imports import (
ModuleAvailableCache,
RequirementCache,
compare_version,
get_dependency_min_version_spec,
lazy_import,
module_available,
ModuleAvailableCache,
RequirementCache,
requires,
)

Expand Down
6 changes: 3 additions & 3 deletions tests/unittests/core/test_overrides.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from functools import partial, wraps
from typing import Any, Callable
from unittest.mock import Mock

import pytest

from lightning_utilities.core.overrides import is_overridden


Expand Down Expand Up @@ -36,14 +36,14 @@ def bar(self):
assert is_overridden("training_step", LightningModule(), parent=BoringModel)

class WrappedModel(TestModel):
def __new__(cls, *args, **kwargs):
def __new__(cls, *args: Any, **kwargs: Any):
obj = super().__new__(cls)
obj.foo = cls.wrap(obj.foo)
obj.bar = cls.wrap(obj.bar)
return obj

@staticmethod
def wrap(fn):
def wrap(fn) -> Callable:
@wraps(fn)
def wrapper():
fn()
Expand Down
1 change: 0 additions & 1 deletion tests/unittests/core/test_rank_zero.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import pytest

from lightning_utilities.core.rank_zero import rank_prefixed_message, rank_zero_only


Expand Down
12 changes: 6 additions & 6 deletions tests/unittests/mocks.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Iterable
from typing import Any, Iterable

from lightning_utilities.core.imports import package_available

Expand All @@ -7,7 +7,7 @@
else:
# minimal torch implementation to avoid installing torch in testing CI
class TensorMock:
def __init__(self, data):
def __init__(self, data) -> None:
self.data = data

def __add__(self, other):
Expand All @@ -32,7 +32,7 @@ def __iter__(self):
"""Iterate."""
return iter(self.data)

def __repr__(self):
def __repr__(self) -> str:
"""Return object representation."""
return repr(self.data)

Expand All @@ -44,15 +44,15 @@ class TorchMock:
Tensor = TensorMock

@staticmethod
def tensor(data):
def tensor(data: Any) -> TensorMock:
return TensorMock(data)

@staticmethod
def equal(a, b):
def equal(a: Any, b: Any) -> bool:
return a == b

@staticmethod
def arange(*args):
def arange(*args: Any) -> TensorMock:
return TensorMock(list(range(*args)))

torch = TorchMock()
1 change: 0 additions & 1 deletion tests/unittests/test/test_warnings.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from re import escape

import pytest

from lightning_utilities.test.warning import no_warning_call


Expand Down

0 comments on commit de1efa9

Please sign in to comment.