Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement app typing (strict) #513

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
66 changes: 33 additions & 33 deletions aiohttp_jinja2/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
Awaitable,
Callable,
Dict,
Final,
Iterable,
Mapping,
Optional,
Expand All @@ -22,35 +23,35 @@
from aiohttp.abc import AbstractView

if sys.version_info >= (3, 8):
from typing import Protocol
from typing import Protocol, TypedDict
else:
from typing_extensions import Protocol
from typing_extensions import Protocol, TypedDict

from .helpers import GLOBAL_HELPERS
from .typedefs import Filters
from .typedefs import AppState as AppState, ContextProcessor, Filters

__version__ = "1.5"

__all__ = ("setup", "get_env", "render_template", "render_string", "template")


APP_CONTEXT_PROCESSORS_KEY = "aiohttp_jinja2_context_processors"
APP_KEY = "aiohttp_jinja2_environment"
REQUEST_CONTEXT_KEY = "aiohttp_jinja2_context"
APP_KEY: Final = "_aiohttp_jinja2_environment"
APP_CONTEXT_PROCESSORS_KEY: Final = "_aiohttp_jinja2_context_processors"
REQUEST_CONTEXT_KEY: Final = "_aiohttp_jinja2_context"

_TemplateReturnType = Awaitable[Union[web.StreamResponse, Mapping[str, Any]]]
_SimpleTemplateHandler = Callable[[web.Request], _TemplateReturnType]
_ContextProcessor = Callable[[web.Request], Awaitable[Dict[str, Any]]]
_SimpleTemplateHandler = Callable[[web.Request[_T]], _TemplateReturnType]

_T = TypeVar("_T")
_AbstractView = TypeVar("_AbstractView", bound=AbstractView)
_U = TypeVar("_U")
_AbstractView = TypeVar("_AbstractView", bound=AbstractView[Any])


class _TemplateWrapper(Protocol):
@overload
def __call__(
self, func: _SimpleTemplateHandler
) -> Callable[[web.Request], Awaitable[web.StreamResponse]]:
self, func: _SimpleTemplateHandler[_T]
) -> Callable[[web.Request[_T]], Awaitable[web.StreamResponse]]:
...

@overload
Expand All @@ -61,16 +62,16 @@ def __call__(

@overload
def __call__(
self, func: Callable[[_T, web.Request], _TemplateReturnType]
) -> Callable[[_T, web.Request], Awaitable[web.StreamResponse]]:
self, func: Callable[[_T, web.Request[_U]], _TemplateReturnType]
) -> Callable[[_T, web.Request[_U]], Awaitable[web.StreamResponse]]:
...


def setup(
app: web.Application,
app: web.Application[AppState],
*args: Any,
app_key: str = APP_KEY,
context_processors: Iterable[_ContextProcessor] = (),
context_processors: Iterable[ContextProcessor] = (),
filters: Optional[Filters] = None,
default_helpers: bool = True,
**kwargs: Any,
Expand All @@ -81,23 +82,23 @@ def setup(
env.globals.update(GLOBAL_HELPERS)
if filters is not None:
env.filters.update(filters)
app[app_key] = env
app.state["_aiohttp_jinja2_environment"] = env
if context_processors:
app[APP_CONTEXT_PROCESSORS_KEY] = context_processors
app.state["_aiohttp_jinja2_context_processors"] = context_processors
app.middlewares.append(context_processors_middleware)

env.globals["app"] = app

return env


def get_env(app: web.Application, *, app_key: str = APP_KEY) -> jinja2.Environment:
return cast(jinja2.Environment, app.get(app_key))
def get_env(app: web.Application[AppState]) -> jinja2.Environment:
return app.state["_aiohttp_jinja2_environment"]


def _render_string(
template_name: str,
request: web.Request,
request: web.Request[AppState],
context: Mapping[str, Any],
app_key: str,
) -> Tuple[jinja2.Template, Mapping[str, Any]]:
Expand Down Expand Up @@ -128,7 +129,7 @@ def _render_string(

def render_string(
template_name: str,
request: web.Request,
request: web.Request[AppState],
context: Mapping[str, Any],
*,
app_key: str = APP_KEY,
Expand All @@ -139,7 +140,7 @@ def render_string(

async def render_string_async(
template_name: str,
request: web.Request,
request: web.Request[AppState],
context: Mapping[str, Any],
*,
app_key: str = APP_KEY,
Expand All @@ -163,7 +164,7 @@ def _render_template(

def render_template(
template_name: str,
request: web.Request,
request: web.Request[AppState],
context: Optional[Mapping[str, Any]],
*,
app_key: str = APP_KEY,
Expand All @@ -177,7 +178,7 @@ def render_template(

async def render_template_async(
template_name: str,
request: web.Request,
request: web.Request[AppState],
context: Optional[Mapping[str, Any]],
*,
app_key: str = APP_KEY,
Expand All @@ -200,8 +201,8 @@ def template(
) -> _TemplateWrapper:
@overload
def wrapper(
func: _SimpleTemplateHandler,
) -> Callable[[web.Request], Awaitable[web.StreamResponse]]:
func: _SimpleTemplateHandler[_T],
) -> Callable[[web.Request[_T]], Awaitable[web.StreamResponse]]:
...

@overload
Expand All @@ -212,8 +213,8 @@ def wrapper(

@overload
def wrapper(
func: Callable[[_T, web.Request], _TemplateReturnType]
) -> Callable[[_T, web.Request], Awaitable[web.StreamResponse]]:
func: Callable[[_T, web.Request[_U]], _TemplateReturnType]
) -> Callable[[_T, web.Request[_U]], Awaitable[web.StreamResponse]]:
...

def wrapper(
Expand Down Expand Up @@ -257,17 +258,16 @@ async def wrapped(*args: Any) -> web.StreamResponse: # type: ignore[misc]


@web.middleware
async def context_processors_middleware(
request: web.Request,
handler: Callable[[web.Request], Awaitable[web.StreamResponse]],
async def context_processors_middleware( # type: ignore[misc]
request: web.Request[AppState],
handler: Callable[[web.Request[Any]], Awaitable[web.StreamResponse]],
) -> web.StreamResponse:

if REQUEST_CONTEXT_KEY not in request:
request[REQUEST_CONTEXT_KEY] = {}
for processor in request.config_dict[APP_CONTEXT_PROCESSORS_KEY]:
request[REQUEST_CONTEXT_KEY].update(await processor(request))
return await handler(request)


async def request_processor(request: web.Request) -> Dict[str, web.Request]:
async def request_processor(request: web.Request[_T]) -> Dict[str, web.Request[_T]]:
return {"request": request}
6 changes: 4 additions & 2 deletions aiohttp_jinja2/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,13 @@
from aiohttp import web
from yarl import URL

from .typedefs import AppState

if sys.version_info >= (3, 8):
from typing import TypedDict

class _Context(TypedDict, total=False):
app: web.Application
app: web.Application[AppState]


else:
Expand Down Expand Up @@ -70,7 +72,7 @@ def static_url(context: _Context, static_file_path: str) -> str:
"""
app = context["app"]
try:
static_url = app["static_root_url"]
static_url = app.state["static_root_url"]
except KeyError:
raise RuntimeError(
"app does not define a static root url "
Expand Down
24 changes: 23 additions & 1 deletion aiohttp_jinja2/typedefs.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,26 @@
from typing import Callable, Iterable, Mapping, Tuple, Union
from typing import (
Any,
Awaitable,
Callable,
Dict,
Iterable,
Mapping,
Tuple,
TypedDict,
Union,
)

import jinja2
from aiohttp import web

ContextProcessor = Callable[[web.Request[Any]], Awaitable[Dict[str, Any]]]
Filter = Callable[..., str]
Filters = Union[Iterable[Tuple[str, Filter]], Mapping[str, Filter]]


class AppState(TypedDict, total=False):
"""App config used by aiohttp-jinja2."""

_aiohttp_jinja2_context_processors: Iterable[ContextProcessor]
_aiohttp_jinja2_environment: jinja2.Environment
static_root_url: str
6 changes: 6 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from aiohttp.web import Application, Request

from aiohttp_jinja2 import AppState

_App = Application[AppState]
_Request = Request[AppState]
28 changes: 18 additions & 10 deletions tests/test_context_processors.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,24 +5,28 @@

import aiohttp_jinja2

from .conftest import _App, _Request


async def test_context_processors(aiohttp_client):
@aiohttp_jinja2.template("tmpl.jinja2")
async def func(request):
return {"bar": 2}

app = web.Application(middlewares=[aiohttp_jinja2.context_processors_middleware])
app: _App = web.Application(
middlewares=[aiohttp_jinja2.context_processors_middleware]
)
aiohttp_jinja2.setup(
app,
loader=jinja2.DictLoader(
{"tmpl.jinja2": "foo: {{ foo }}, bar: {{ bar }}, path: {{ request.path }}"}
),
)

async def processor(request: web.Request) -> Dict[str, Union[str, int]]:
async def processor(request: _Request) -> Dict[str, Union[str, int]]:
return {"foo": 1, "bar": "should be overwriten"}

app["aiohttp_jinja2_context_processors"] = (
app.state[aiohttp_jinja2.APP_CONTEXT_PROCESSORS_KEY] = (
aiohttp_jinja2.request_processor,
processor,
)
Expand All @@ -42,7 +46,9 @@ async def test_nested_context_processors(aiohttp_client):
async def func(request):
return {"bar": 2}

subapp = web.Application(middlewares=[aiohttp_jinja2.context_processors_middleware])
subapp: _App = web.Application(
middlewares=[aiohttp_jinja2.context_processors_middleware]
)
aiohttp_jinja2.setup(
subapp,
loader=jinja2.DictLoader(
Expand All @@ -56,20 +62,22 @@ async def func(request):
async def subprocessor(request):
return {"foo": 1, "bar": "should be overwriten"}

subapp["aiohttp_jinja2_context_processors"] = (
subapp.state[aiohttp_jinja2.APP_CONTEXT_PROCESSORS_KEY] = (
aiohttp_jinja2.request_processor,
subprocessor,
)

subapp.router.add_get("/", func)

app = web.Application(middlewares=[aiohttp_jinja2.context_processors_middleware])
app: _App = web.Application(
middlewares=[aiohttp_jinja2.context_processors_middleware]
)
aiohttp_jinja2.setup(app, loader=jinja2.DictLoader({}))

async def processor(request):
return {"baz": 5}

app["aiohttp_jinja2_context_processors"] = (
app.state[aiohttp_jinja2.APP_CONTEXT_PROCESSORS_KEY] = (
aiohttp_jinja2.request_processor,
processor,
)
Expand All @@ -89,7 +97,7 @@ async def test_context_is_response(aiohttp_client):
async def func(request):
raise web.HTTPForbidden()

app = web.Application()
app: _App = web.Application()
aiohttp_jinja2.setup(app, loader=jinja2.DictLoader({"tmpl.jinja2": "template"}))

app.router.add_route("GET", "/", func)
Expand All @@ -107,7 +115,7 @@ async def func(request):
async def processor(request):
return {"foo": 1, "bar": "should be overwriten"}

app = web.Application()
app: _App = web.Application()
aiohttp_jinja2.setup(
app,
loader=jinja2.DictLoader(
Expand Down Expand Up @@ -139,7 +147,7 @@ async def func(request):
async def processor(request):
return {"foo": 1}

app = web.Application()
app: _App = web.Application()
aiohttp_jinja2.setup(
app,
loader=jinja2.DictLoader({"tmpl.jinja2": "foo: {{ foo }}"}),
Expand Down
4 changes: 3 additions & 1 deletion tests/test_jinja_filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@

import aiohttp_jinja2

from .conftest import _App


async def test_jinja_filters(aiohttp_client):
@aiohttp_jinja2.template("tmpl.jinja2")
Expand All @@ -12,7 +14,7 @@ async def index(request):
def add_2(value):
return value + 2

app = web.Application()
app: _App = web.Application()
aiohttp_jinja2.setup(
app,
loader=jinja2.DictLoader({"tmpl.jinja2": "{{ 5|add_2 }}"}),
Expand Down