Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add type hints #70

Merged
merged 5 commits into from
Oct 19, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ coverage-xml:

.PHONY: type-check
type-check:
mypy
mypy $(type-check)

.PHONY: install
install:
Expand Down
30 changes: 4 additions & 26 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -119,41 +119,19 @@ plugins =
[mypy.plugins.django-stubs]
django_settings_module = tests.settings

# Initially ignore errors by default, and only fail for the modules marked
# below. Once we cover a significant portion of the codebase we should switch
# strategy to raise errors by default and mark modules that we haven't covered
# yet.
[mypy-bananas.*]
[mypy-bananas.admin.*]
ignore_errors = True

[mypy-bananas.query.*]
ignore_errors = True

[mypy-tests.*]
disallow_untyped_defs = False
ignore_errors = True

[mypy-bananas.url]
ignore_errors = False

[mypy-tests.test_db_url]
ignore_errors = False

[mypy-bananas.secrets]
ignore_errors = False

[mypy-tests.test_secrets]
ignore_errors = False

[mypy-bananas.environment]
ignore_errors = False

[mypy-test.support]
ignore_missing_imports = True

[mypy-bananas.drf.*]
ignore_errors = False

[mypy-tests.drf.fenced_api]
ignore_errors = False

[mypy-drf_yasg.*]
ignore_missing_imports = True

Expand Down
16 changes: 7 additions & 9 deletions src/bananas/__init__.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,23 @@
VERSION = (2, 0, 0, "final", 0)


def get_version(version=None):
def get_version() -> str:
"""Derives a PEP386-compliant version number from VERSION."""
if version is None:
version = VERSION
assert len(version) == 5
assert version[3] in ("alpha", "beta", "rc", "final")
assert len(VERSION) == 5
assert VERSION[3] in ("alpha", "beta", "rc", "final")

# Now build the two parts of the version number:
# main = X.Y[.Z]
# sub = .devN - for pre-alpha releases
# | {a|b|c}N - for alpha, beta and rc releases

parts = 2 if version[2] == 0 else 3
main = ".".join(str(x) for x in version[:parts])
parts = 2 if VERSION[2] == 0 else 3
main = ".".join(str(x) for x in VERSION[:parts])

sub = ""
if version[3] != "final":
if VERSION[3] != "final":
mapping = {"alpha": "a", "beta": "b", "rc": "c"}
sub = mapping[version[3]] + str(version[4])
sub = mapping[VERSION[3]] + str(VERSION[4])

return main + sub

Expand Down
15 changes: 10 additions & 5 deletions src/bananas/management/commands/show_urls.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
import sys
from collections import OrderedDict
from typing import Callable, Dict, List, Optional, Tuple, Union

from django.core.management import BaseCommand
from django.urls import URLPattern, URLResolver, get_resolver


def collect_urls(urls=None, namespace=None, prefix=None):
def collect_urls(
urls: Union[URLResolver, URLPattern, Tuple[str, Callable], None] = None,
namespace: Optional[str] = None,
prefix: Optional[list] = None,
) -> List["OrderedDict"]:
if urls is None:
urls = get_resolver(urlconf=None)
prefix = prefix or []
Expand All @@ -27,18 +32,18 @@ def collect_urls(urls=None, namespace=None, prefix=None):
("name", urls.name),
("pattern", prefix + [pattern]),
("lookup_str", lookup_str),
("default_args", dict(urls.default_args)),
("default_args", dict(urls.default_args or {})),
]
)
]
else: # pragma: no cover
raise NotImplementedError(repr(urls))


def show_urls():
def show_urls() -> None:
all_urls = collect_urls()

max_lengths = {}
max_lengths: Dict[str, int] = {}
for u in all_urls:
for k in ["pattern", "default_args"]:
u[k] = str(u[k])
Expand All @@ -59,5 +64,5 @@ def show_urls():


class Command(BaseCommand):
def handle(self, *args, **kwargs):
def handle(self, *args: object, **kwargs: object) -> None:
show_urls()
5 changes: 3 additions & 2 deletions src/bananas/management/commands/syncpermissions.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,12 @@ class Command(BaseCommand):

help = "Create admin permissions"

def handle(self, *args, **options):
def handle(self, *args: object, **options: object) -> None:
if args: # pragma: no cover
raise CommandError("Command doesn't accept any arguments")
return self.handle_noargs(**options)

def handle_noargs(self, *args, **options):
def handle_noargs(self, *args: object, **options: object) -> None:
from django.contrib import admin as django_admin
from django.contrib.contenttypes.models import ContentType

Expand All @@ -22,6 +22,7 @@ def handle_noargs(self, *args, **options):
for model, _ in admin.site._registry.items():
if issubclass(getattr(model, "View", object), admin.AdminView):
meta = model._meta
assert isinstance(meta.object_name, str)

ct, created = ContentType.objects.get_or_create(
app_label=meta.app_label, model=meta.object_name.lower()
Expand Down
75 changes: 48 additions & 27 deletions src/bananas/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,26 @@
import os
import uuid
from itertools import chain
from typing import Any, Dict, Mapping, Optional, Sized

from django.core.exceptions import ValidationError
from django.db import models
from django.utils.translation import gettext_lazy as _
from typing_extensions import Final

MISSING = object()

class Missing:
...

class ModelDict(dict):

_nested = None
MISSING: Final = Missing()

def __getattr__(self, item):

class ModelDict(Dict[str, Any]):

_nested: Optional[Dict[str, "ModelDict"]] = None

def __getattr__(self, item: str) -> Any:
"""
Try to to get attribute as key item.
Fallback on prefixed nested keys.
Expand All @@ -30,7 +37,7 @@ def __getattr__(self, item):
except KeyError:
return self.__getattribute__(item)

def __getnested__(self, item):
def __getnested__(self, item: str) -> "ModelDict":
"""
Find existing items prefixed with given item
and return a new ModelDict containing matched keys,
Expand All @@ -41,12 +48,12 @@ def __getnested__(self, item):
"""
# Ensure _nested cache
if self._nested is None:
self._nested = {}
self._nested: Dict[str, ModelDict] = {}

# Try to get previously accessed/cached nested item
value = self._nested.get(item, MISSING)

if value is not MISSING:
if not isinstance(value, Missing):
# Return previously accessed nested item
return value

Expand All @@ -66,7 +73,7 @@ def __getnested__(self, item):
# Item not a nested key, raise
raise KeyError(item)

def expand(self):
def expand(self) -> "ModelDict":
keys = list(self)
for key in keys:
field, __, nested_key = key.partition("__")
Expand All @@ -80,7 +87,8 @@ def expand(self):
return ModelDict(self)

@classmethod
def from_model(cls, model, *fields, **named_fields):
# Ignore types until no longer work-in-progress.
def from_model(cls, model, *fields, **named_fields): # type: ignore[no-untyped-def]
"""
Work-in-progress constructor,
consuming fields and values from django model instance.
Expand All @@ -89,7 +97,9 @@ def from_model(cls, model, *fields, **named_fields):

if not (fields or named_fields):
# Default to all fields
fields = [f.attname for f in model._meta.concrete_fields]
fields = [ # type: ignore[assignment]
f.attname for f in model._meta.concrete_fields
]

not_found = object()

Expand Down Expand Up @@ -148,7 +158,7 @@ class TimeStampedModel(models.Model):
class Meta:
abstract = True

def save(self, *args, **kwargs):
def save(self, *args: Any, **kwargs: Any) -> None:
if "update_fields" in kwargs and "date_modified" not in kwargs["update_fields"]:
update_fields = list(kwargs["update_fields"])
update_fields.append("date_modified")
Expand All @@ -170,7 +180,7 @@ class Meta:
class SecretField(models.CharField):
description = _("Generates and stores a random key.")

default_error_messages = {
default_error_messages: Mapping[str, str] = {
"random-is-none": _("%(cls)s.get_random_bytes returned None"),
"random-too-short": _(
"Too few random bytes received from "
Expand All @@ -181,23 +191,34 @@ class SecretField(models.CharField):
}

def __init__(
self, verbose_name=None, num_bytes=32, min_bytes=32, auto=True, **kwargs
self,
verbose_name: Optional[str] = None,
num_bytes: int = 32,
min_bytes: int = 32,
auto: bool = True,
**kwargs: Any,
):
self.num_bytes, self.auto, self.min_length = num_bytes, auto, min_bytes

field_length = self.get_field_length(self.num_bytes)

defaults = {"max_length": field_length}
defaults.update(kwargs)

if self.auto:
defaults["editable"] = False
defaults["blank"] = True
defaults: Mapping[str, object] = {
"max_length": field_length,
**kwargs,
**(
{
"editable": False,
"blank": True,
}
if self.auto
else {}
),
}

super().__init__(verbose_name, **defaults)

@staticmethod
def get_field_length(num_bytes):
def get_field_length(num_bytes: int) -> int:
"""
Return the length of hexadecimal byte representation of ``n`` bytes.

Expand All @@ -206,20 +227,20 @@ def get_field_length(num_bytes):
"""
return num_bytes * 2

def pre_save(self, model_instance, add):
def pre_save(self, model_instance: models.Model, add: bool) -> Any:
if self.auto and add:
value = self.get_random_str()
setattr(model_instance, self.attname, value)
return value
else:
return super().pre_save(model_instance, add)

def get_random_str(self):
def get_random_str(self) -> str:
random = self.get_random_bytes()
self._check_random_bytes(random)
return binascii.hexlify(random).decode("utf8")

def _check_random_bytes(self, random):
def _check_random_bytes(self, random: Optional[Sized]) -> None:
if random is None:
raise ValidationError(
self.error_messages["random-is-none"],
Expand All @@ -234,13 +255,13 @@ def _check_random_bytes(self, random):
params={"num_bytes": len(random), "min_length": self.min_length},
)

def get_random_bytes(self):
def get_random_bytes(self) -> bytes:
return os.urandom(self.num_bytes)


class URLSecretField(SecretField):
@staticmethod
def get_field_length(num_bytes):
def get_field_length(num_bytes: int) -> int:
"""
Get the maximum possible length of a base64 encoded bytearray of
length ``length``.
Expand All @@ -251,7 +272,7 @@ def get_field_length(num_bytes):
return math.ceil(num_bytes / 3.0) * 4

@staticmethod
def y64_encode(s):
def y64_encode(s: bytes) -> bytes:
"""
Implementation of Y64 non-standard URL-safe base64 variant.

Expand All @@ -263,7 +284,7 @@ def y64_encode(s):
first_pass = base64.urlsafe_b64encode(s)
return first_pass.translate(bytes.maketrans(b"+/=", b"._-"))

def get_random_str(self):
def get_random_str(self) -> str:
random = self.get_random_bytes()
self._check_random_bytes(random)
return self.y64_encode(random).decode("utf-8")