Skip to content

Commit

Permalink
improve test coverage
Browse files Browse the repository at this point in the history
  • Loading branch information
alex-oleshkevich committed Mar 20, 2024
1 parent 3ef98d2 commit da2a076
Show file tree
Hide file tree
Showing 15 changed files with 414 additions and 46 deletions.
3 changes: 2 additions & 1 deletion ohmyadmin/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from starlette_babel.locale import get_language
from starlette_flash import flash

from ohmyadmin.authentication import AnonymousAuthPolicy, AuthPolicy
from ohmyadmin.authentication import AnonymousAuthPolicy, AuthPolicy, LoginRequiredMiddleware
from ohmyadmin.pages import Page
from ohmyadmin.templating import media_url, static_url, url_matches
from ohmyadmin.theme import Theme
Expand Down Expand Up @@ -54,6 +54,7 @@ def __init__(
AuthenticationMiddleware,
backend=self.auth_policy.get_authentication_backend(),
),
Middleware(LoginRequiredMiddleware, exclude_paths=["/login", "/static", "/site.webmanifest"]),
],
)

Expand Down
26 changes: 26 additions & 0 deletions ohmyadmin/authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,10 @@
import wtforms
from starlette.authentication import AuthCredentials, AuthenticationBackend, BaseUser, UnauthenticatedUser
from starlette.requests import HTTPConnection, Request
from starlette.responses import RedirectResponse
from starlette.types import ASGIApp, Receive, Scope, Send
from starlette_babel import gettext_lazy as _
from starlette_flash import flash

SESSION_KEY = "_ohmyadmin_user_id_"

Expand Down Expand Up @@ -71,3 +74,26 @@ async def authenticate(self, request: Request, identity: str, password: str) ->

async def load_user(self, conn: HTTPConnection, user_id: str) -> BaseUser | None:
return AnonymousUser()


class LoginRequiredMiddleware:
def __init__(self, app: ASGIApp, exclude_paths: list[str]) -> None:
self.app = app
self.exclude_paths = exclude_paths

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if scope["type"] not in ["http", "websocket"]: # pragma: no cover
await self.app(scope, receive, send)
return

conn = HTTPConnection(scope)
if conn.user.is_authenticated or any([excluded_path in conn.url.path for excluded_path in self.exclude_paths]):
await self.app(scope, receive, send)
return

if scope["type"] == "http":
flash(Request(scope)).error(_("You need to be logged in to access this page.", domain="ohmyadmin"))

redirect_to = conn.url_for("ohmyadmin.login").include_query_params(next=conn.url.path)
response = RedirectResponse(url=redirect_to, status_code=302)
await response(scope, receive, send)
6 changes: 3 additions & 3 deletions ohmyadmin/htmx.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def retarget(response: R, target: str) -> R:
return response


def trigger(response: R, event: str, data: typing.Any = None, stage: TriggerStage = "immediate") -> R:
def trigger(response: R, event: str, data: typing.Any = None, *, stage: TriggerStage = "immediate") -> R:
hx_event = {
"immediate": "hx-trigger",
"after-swap": "hx-trigger-after-swap",
Expand All @@ -91,7 +91,7 @@ def refresh(response: R) -> R:
return trigger(response, "refresh")


def toast(response: R, message: str, category: ToastCategory = "success", stage: TriggerStage = "immediate") -> R:
def toast(response: R, message: str, category: ToastCategory = "success", *, stage: TriggerStage = "immediate") -> R:
return trigger(response, "toast", {"message": message, "category": category}, stage=stage)


Expand Down Expand Up @@ -129,7 +129,7 @@ def reswap(self, target: SwapTarget) -> typing.Self:
return reswap(self, target)

def trigger(self, event: str, data: typing.Any = None, stage: TriggerStage = "immediate") -> typing.Self:
return trigger(self, event, data, stage)
return trigger(self, event, data, stage=stage)


def response(
Expand Down
6 changes: 3 additions & 3 deletions ohmyadmin/templating.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,14 @@
from starlette.requests import Request


def static_url(request: Request, path: str) -> str:
def static_url(request: Request, path: str) -> URL:
if path.startswith("http://") or path.startswith("https://"):
return path
return URL(path)

url = request.url_for("ohmyadmin.static", path=path)
if request.app.debug:
url = url.include_query_params(_ts=time.time())
return str(url)
return url


def media_url(request: Request, path: str) -> URL:
Expand Down
3 changes: 1 addition & 2 deletions ohmyadmin/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,7 @@ def get_attribute(self, selector: str, attribute: str, default: None = None) ->

def get_attribute(self, selector: str, attribute: str, default: str | None = None) -> str | None:
value = self.find_node_or_raise(selector).get(attribute, default)
if not isinstance(value, str):
raise TypeError("Attribute must be a string.")
assert not isinstance(value, list)
return value

def has_attribute(self, selector: str, attribute: str) -> bool:
Expand Down
2 changes: 1 addition & 1 deletion tests/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,6 @@ async def authenticate(self, request: Request, identity: str, password: str) ->
return None

async def load_user(self, conn: HTTPConnection, user_id: str) -> BaseUser | None:
if str(self.user.id) == user_id:
if str(self.user.id) == str(user_id):
return self.user
return None
16 changes: 8 additions & 8 deletions tests/components/_test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,26 +22,26 @@ def compose(self, request: Request) -> Component:
return HelloComponent(child=WorldComponent())


def test_component(template_dir: pathlib.Path, http_get: Request) -> None:
def test_component(template_dir: pathlib.Path, request: Request) -> None:
(template_dir / "world.html").write_text("world")
component = WorldComponent()
assert component.render(http_get) == "world"
assert component.render(request) == "world"


def test_compose_component(template_dir: pathlib.Path, http_get: Request) -> None:
def test_compose_component(template_dir: pathlib.Path, request: Request) -> None:
(template_dir / "world.html").write_text("world")
(template_dir / "hello.html").write_text(
"{%- import 'ohmyadmin/components.html' as components -%}"
"hello {{ components.render_component(request, component.child) }}"
)
component = TwoComponents()
assert str(component.render(http_get)) == "hello world"
assert str(component.render(request)) == "hello world"


def test_builder(template_dir: pathlib.Path, http_get: Request) -> None:
def test_builder(template_dir: pathlib.Path, request: Request) -> None:
(template_dir / "world.html").write_text("world")
component = Builder(builder=lambda: WorldComponent())
assert component.render(http_get) == "world"
assert component.render(request) == "world"


@pytest.mark.parametrize(
Expand All @@ -51,7 +51,7 @@ def test_builder(template_dir: pathlib.Path, http_get: Request) -> None:
(False, "hello world"),
),
)
def test_when(template_dir: pathlib.Path, http_get: Request, expression: bool, expected: str) -> None:
def test_when(template_dir: pathlib.Path, request: Request, expression: bool, expected: str) -> None:
(template_dir / "world.html").write_text("world")
(template_dir / "hello.html").write_text(
"{%- import 'ohmyadmin/components.html' as components -%}"
Expand All @@ -62,4 +62,4 @@ def test_when(template_dir: pathlib.Path, http_get: Request, expression: bool, e
when_true=WorldComponent(),
when_false=TwoComponents(),
)
assert component.render(http_get) == expected
assert component.render(request) == expected
8 changes: 4 additions & 4 deletions tests/components/_test_layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,18 +10,18 @@ def render(self, request: Request) -> str:
return "CHILD"


def test_column(http_get: Request) -> None:
def test_column(request: Request) -> None:
component = layout.Column(children=[Child()], gap=5, colspan=5)
html = component.render(http_get)
html = component.render(request)
selector = MarkupSelector(html)
assert selector.find_node(".column-layout")
assert selector.has_class(".column-layout", "gap-5", "col-span-5")
assert selector.has_text(".column-layout", "CHILD")


def test_grid(http_get: Request) -> None:
def test_grid(request: Request) -> None:
component = layout.Grid(children=[Child()], columns=5, gap=5, colspan=5)
html = component.render(http_get)
html = component.render(request)
selector = MarkupSelector(html)
assert selector.find_node(".grid-layout")
assert selector.has_class(".grid-layout", "gap-5", "col-span-5", "grid-cols-5")
Expand Down
44 changes: 38 additions & 6 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,33 +73,65 @@ def app(app_f: AppFactory, ohmyadmin: OhMyAdmin) -> Starlette:


class RequestFactory(typing.Protocol): # pragma: no cover:
def __call__(self, method: str = "get", type: str = "http") -> Request: ...
def __call__(
self,
method: str = "get",
path: str = "/",
*,
headers: typing.Sequence[tuple[bytes, bytes]] = tuple(),
type: str = "http",
state: dict[str, typing.Any] | None = None,
session: dict[str, typing.Any] | None = None,
) -> Request: ...


@pytest.fixture
def request_f(ohmyadmin: OhMyAdmin) -> RequestFactory:
def request_f(ohmyadmin: OhMyAdmin, app: Starlette) -> RequestFactory:
def factory(
method: str = "get",
path: str = "/",
*,
headers: typing.Sequence[tuple[bytes, bytes]] = tuple(),
type: str = "http",
state: dict[str, typing.Any] | None = None,
session: dict[str, typing.Any] | None = None,
) -> Request:
state = state or {}
state.update(
{
"ohmyadmin": ohmyadmin,
}
)
scope = {
"app": app,
"path": path,
"type": type,
"method": method,
"state": {
"ohmyadmin": ohmyadmin,
},
"state": state,
"headers": headers,
"router": app.router,
}
if session:
scope["session"] = session
return Request(scope)

return factory


@pytest.fixture
def http_get(request_f: RequestFactory) -> Request:
def http_request(request_f: RequestFactory) -> Request:
return request_f(method="get")


@pytest.fixture
def client(app: Starlette) -> typing.Generator[TestClient, None, None]:
with TestClient(app, follow_redirects=False) as client:
yield client


@pytest.fixture
def auth_client(client: TestClient, user: User) -> typing.Generator[TestClient, None, None]:
response = client.post("/admin/login", data={"identity": user.email, "password": "password"})
assert response.status_code == 302
assert "location" in response.headers
yield client
18 changes: 11 additions & 7 deletions tests/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,18 @@

@dataclasses.dataclass
class User(BaseUser):
id: int
first_name: str
last_name: str
email: str
is_active: bool
birthdate: datetime.date
password: str
id: int = 1
first_name: str = "John"
last_name: str = "Doe"
email: str = "john.doe@localhost"
is_active: bool = True
birthdate: datetime.date = datetime.date(1990, 1, 1)
password: str = "password"

@property
def identity(self) -> str:
return self.id

@property
def is_authenticated(self) -> bool:
return True
14 changes: 10 additions & 4 deletions tests/test_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
from tests.conftest import AppFactory


def test_welcome_view(client: TestClient) -> None:
response = client.get("/admin/")
def test_welcome_view(auth_client: TestClient) -> None:
response = auth_client.get("/admin/")
assert response.status_code == 200


Expand Down Expand Up @@ -103,8 +103,14 @@ def test_static_files(client: TestClient) -> None:
assert response.headers["content-type"] == "image/png"


async def test_media_files(client: TestClient, file_storage: FileStorage) -> None:
async def test_media_files(auth_client: TestClient, file_storage: FileStorage) -> None:
await file_storage.write("text.txt", b"")
response = client.get("/admin/media/text.txt")
response = auth_client.get("/admin/media/text.txt")
assert response.status_code == 200
assert response.headers["content-type"] == "text/plain; charset=utf-8"


async def test_requires_auth(client: TestClient) -> None:
response = client.get("/admin/")
assert response.status_code == 302
assert response.headers["location"] == "http://testserver/admin/login?next=%2Fadmin%2F"
49 changes: 42 additions & 7 deletions tests/test_authentication.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from starlette.requests import Request
from starlette.authentication import AuthenticationBackend, BaseUser
from starlette.requests import HTTPConnection, Request
from starlette.testclient import TestClient

from ohmyadmin.authentication import AnonymousAuthPolicy
from ohmyadmin.app import OhMyAdmin
from ohmyadmin.authentication import AnonymousAuthPolicy, AuthPolicy, SESSION_KEY, SessionAuthBackend
from tests.auth import AuthTestPolicy
from tests.conftest import RequestFactory
from tests.models import User
Expand All @@ -13,9 +15,9 @@ def test_login_page() -> None: ...
def test_logout_page() -> None: ...


async def test_auth_policy(http_get: Request, user: User) -> None:
async def test_auth_policy(request: Request, user: User) -> None:
policy = AuthTestPolicy(user)
assert await policy.load_user(http_get, str(user.id)) == user
assert await policy.load_user(request, str(user.id)) == user


async def test_anonymous_auth_policy(request_f: RequestFactory) -> None:
Expand Down Expand Up @@ -53,10 +55,43 @@ async def test_login_redirects_to_next_url(client: TestClient, user: User) -> No
assert response.headers["location"] == "/admin"


async def test_logout(client: TestClient, user: User) -> None:
response = client.get("/admin/logout")
async def test_logout(auth_client: TestClient, user: User) -> None:
response = auth_client.get("/admin/logout")
assert response.status_code == 405

response = client.post("/admin/logout")
response = auth_client.post("/admin/logout")
assert response.status_code == 302
assert response.headers["location"] == "http://testserver/admin/login"


async def test_session_backend(ohmyadmin: OhMyAdmin, request_f: RequestFactory) -> None:
class TestPolicy(AuthPolicy):
async def authenticate(
self, request: Request, identity: str, password: str
) -> BaseUser | None: # pragma: no cover
return None

async def load_user(self, conn: HTTPConnection, user_id: str) -> BaseUser | None: # pragma: no cover
if user_id == "1":
return User(id=1)
return None

def get_authentication_backend(
self,
) -> AuthenticationBackend: # pragma: no cover
return SessionAuthBackend()

ohmyadmin.auth_policy = TestPolicy()
request = request_f(state={"app": ohmyadmin}, session={SESSION_KEY: "2"})
backend = SessionAuthBackend()
result = await backend.authenticate(request)
assert result
creds, user = result
assert not user.is_authenticated

request = request_f(state={"app": ohmyadmin}, session={SESSION_KEY: "1"})
backend = SessionAuthBackend()
result = await backend.authenticate(request)
assert result
creds, user = result
assert user.is_authenticated

0 comments on commit da2a076

Please sign in to comment.