diff --git a/.env b/.env new file mode 100644 index 00000000..d847fa48 --- /dev/null +++ b/.env @@ -0,0 +1,3 @@ +SECRET_KEY=aawinjwol;egbfnjek bnl +DATABASE_URL=sqlite:///test.db +DEBUG=False \ No newline at end of file diff --git a/.gitignore b/.gitignore index 50d5b466..de284e96 100644 --- a/.gitignore +++ b/.gitignore @@ -9,3 +9,7 @@ htmlcov/ .mypy_cache/ coverage.xml examples/ +example/ +.idea/ +build/ +site/ \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 200b51fe..3e3bdfe9 100644 --- a/requirements.txt +++ b/requirements.txt @@ -19,3 +19,8 @@ mkdocstrings==0.17.0 # Packaging twine==3.7.1 wheel==0.37.0 +pydantic[dotenv]==1.9.0 +python-jose==3.3.0 +anyio==3.5.0 +wtforms==3.0.1 +typer==0.4.0 \ No newline at end of file diff --git a/setup.cfg b/setup.cfg index 43a2c04e..b47a334b 100644 --- a/setup.cfg +++ b/setup.cfg @@ -32,6 +32,9 @@ install_requires = sqlalchemy >=1.4, <1.5 wtforms >=3, <4 python-multipart + typer + pydantic[dotenv] + python-jose [options.package_data] sqladmin = py.typed @@ -66,3 +69,7 @@ exclude_lines = pragma: no cover pragma: nocover if TYPE_CHECKING: + +[options.entry_points] +console_scripts = + sqladmin = sqladmin_cli.__init__:main \ No newline at end of file diff --git a/sqladmin/application.py b/sqladmin/application.py index 71041703..dc3ea81b 100644 --- a/sqladmin/application.py +++ b/sqladmin/application.py @@ -1,20 +1,29 @@ from typing import TYPE_CHECKING, List, Type, Union +import anyio from jinja2 import ChoiceLoader, FileSystemLoader, PackageLoader +from sqlalchemy import select from sqlalchemy.engine import Engine -from sqlalchemy.ext.asyncio import AsyncEngine +from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession +from sqlalchemy.orm import Session, sessionmaker from starlette.applications import Starlette +from starlette.authentication import requires from starlette.exceptions import HTTPException +from starlette.middleware.authentication import AuthenticationMiddleware from starlette.requests import Request from starlette.responses import RedirectResponse, Response from starlette.routing import Mount, Route, Router from starlette.staticfiles import StaticFiles from starlette.templating import Jinja2Templates +from sqladmin.auth.hashers import make_password +from sqladmin.auth.middlewares import BasicAuthBackend +from sqladmin.auth.models import User +from sqladmin.auth.utils.token import create_access_token + if TYPE_CHECKING: from sqladmin.models import ModelAdmin - __all__ = [ "Admin", ] @@ -34,6 +43,7 @@ def __init__( base_url: str = "/admin", title: str = "Admin", logo_url: str = None, + language: str = None, ) -> None: self.app = app self.engine = engine @@ -41,6 +51,7 @@ def __init__( self._model_admins: List[Type["ModelAdmin"]] = [] self.templates = Jinja2Templates("templates") + self.templates.env.add_extension("jinja2.ext.i18n") self.templates.env.loader = ChoiceLoader( [ FileSystemLoader("templates"), @@ -130,6 +141,7 @@ def __init__( base_url: str = "/admin", title: str = "Admin", logo_url: str = None, + language: str = None, ) -> None: """ Args: @@ -137,20 +149,36 @@ def __init__( engine: SQLAlchemy engine instance. base_url: Base URL for Admin interface. title: Admin title. - logo: URL of logo to be displayed instead of title. + logo_url: URL of logo to be displayed instead of title. + language: Now it can write "zh_CN" or None. """ assert isinstance(engine, (Engine, AsyncEngine)) super().__init__( - app=app, engine=engine, base_url=base_url, title=title, logo_url=logo_url + app=app, + engine=engine, + base_url=base_url, + title=title, + logo_url=logo_url, + language=language, + ) + if isinstance(engine, Engine): + LocalSession = sessionmaker(bind=self.engine, class_=Session) + self.session = LocalSession() + self._sync = True + else: + LocalSession = sessionmaker(bind=self.engine, class_=AsyncSession) + self.session = LocalSession() + self._sync = False + app.add_middleware( + AuthenticationMiddleware, backend=BasicAuthBackend(self.session, self._sync) ) - statics = StaticFiles(packages=["sqladmin"]) router = Router( routes=[ Mount("/statics", app=statics, name="statics"), - Route("/", endpoint=self.index, name="index"), + Route("/", endpoint=self.index, name="index", methods=["GET", "POST"]), Route("/{identity}/list", endpoint=self.list, name="list"), Route("/{identity}/detail/{pk}", endpoint=self.detail, name="detail"), Route( @@ -165,20 +193,26 @@ def __init__( name="create", methods=["GET", "POST"], ), + Route( + "/login", + endpoint=self.login, + name="login", + methods=["GET", "POST"], + ), ] ) self.app.mount(base_url, app=router, name="admin") self.templates.env.globals["model_admins"] = self.model_admins + @requires("authenticated", redirect="admin:login") async def index(self, request: Request) -> Response: """Index route which can be overriden to create dashboards.""" - return self.templates.TemplateResponse("index.html", {"request": request}) + @requires("authenticated", redirect="admin:login") async def list(self, request: Request) -> Response: """List route to display paginated Model instances.""" - model_admin = self._find_model_admin(request.path_params["identity"]) page = int(request.query_params.get("page", 1)) @@ -202,6 +236,7 @@ async def list(self, request: Request) -> Response: return self.templates.TemplateResponse("list.html", context) + @requires("authenticated", redirect="admin:login") async def detail(self, request: Request) -> Response: """Detail route.""" @@ -222,9 +257,9 @@ async def detail(self, request: Request) -> Response: return self.templates.TemplateResponse("detail.html", context) + @requires("authenticated", redirect="admin:login") async def delete(self, request: Request) -> Response: """Delete route.""" - identity = request.path_params["identity"] model_admin = self._find_model_admin(identity) if not model_admin.can_delete: @@ -238,9 +273,9 @@ async def delete(self, request: Request) -> Response: return Response(content=request.url_for("admin:list", identity=identity)) + @requires("authenticated", redirect="admin:login") async def create(self, request: Request) -> Response: """Create model endpoint.""" - identity = request.path_params["identity"] model_admin = self._find_model_admin(identity) if not model_admin.can_create: @@ -272,3 +307,53 @@ async def create(self, request: Request) -> Response: request.url_for("admin:list", identity=identity), status_code=302, ) + + async def login(self, request: Request) -> Response: + context = { + "request": request, + "errinfo": "", + "username_err": False, + "password_err": False, + } + + if request.method == "GET": + return self.templates.TemplateResponse("login.html", context) + form = await request.form() + username = form.get("username") + raw_password = form.get("password") + + if not username: + context["username_err"] = True + return self.templates.TemplateResponse("login.html", context) + if not raw_password: + context["password_err"] = True + return self.templates.TemplateResponse("login.html", context) + if self._sync: + res = await anyio.to_thread.run_sync( + self.session.execute, + select(User.password) + .where(User.username == username, User.is_active == True) # noqa + .limit(1), + ) + else: + res = await self.session.execute( + select(User.password) + .where(User.username == username, User.is_active == True) # noqa + .limit(1) + ) + password = res.scalar_one_or_none() + if password is not None: + if make_password(raw_password) == password: + request.cookies.setdefault( + "access_token", + ) + res = RedirectResponse( + request.url_for( + "admin:index", + ), + ) + access_token = create_access_token({"username": username}) + res.set_cookie("access_token", access_token) + return res + context["errinfo"] = "e" + return self.templates.TemplateResponse("login.html", context) diff --git a/sqladmin/auth/__init__.py b/sqladmin/auth/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/sqladmin/auth/hashers.py b/sqladmin/auth/hashers.py new file mode 100644 index 00000000..98165b3d --- /dev/null +++ b/sqladmin/auth/hashers.py @@ -0,0 +1,24 @@ +import binascii +import hashlib + +from sqladmin.conf import settings + +SECRET_KEY = settings.SECRET_KEY + + +def make_password(raw_password: str) -> str: + password = hashlib.pbkdf2_hmac( + "sha256", raw_password.encode("utf-8"), SECRET_KEY.encode("utf-8"), 16 + ) + return binascii.hexlify(password).decode() + + +def verify_password(raw_password: str, password: str) -> bool: + random_salt = SECRET_KEY.encode("utf-8") + raw_password_bytes = hashlib.pbkdf2_hmac( + "sha256", raw_password.encode("utf-8"), random_salt, 16 + ) + if binascii.hexlify(raw_password_bytes).decode() == password: + return True + else: + return False diff --git a/sqladmin/auth/middlewares.py b/sqladmin/auth/middlewares.py new file mode 100644 index 00000000..321487b6 --- /dev/null +++ b/sqladmin/auth/middlewares.py @@ -0,0 +1,48 @@ +import typing + +import anyio +from sqlalchemy import select +from sqlalchemy.orm import sessionmaker +from starlette.authentication import AuthCredentials, AuthenticationBackend, SimpleUser +from starlette.requests import HTTPConnection + +from sqladmin.auth.models import User +from sqladmin.auth.utils.token import decode_access_token + + +class BasicAuthBackend(AuthenticationBackend): + def __init__(self, session: sessionmaker, _sync: bool): + self.session = session + self._sync = _sync + + async def authenticate( + self, conn: HTTPConnection + ) -> typing.Optional[typing.Tuple[AuthCredentials, SimpleUser]]: + access_token = conn.cookies.get("access_token") + if access_token: + try: + data = decode_access_token(access_token) + username = data["username"] + if self._sync: + res = await anyio.to_thread.run_sync( + self.session.execute, + select(User.username) + .where( + User.username == username, User.is_active == True # noqa + ) + .limit(1), + ) + else: + res = await self.session.execute( + select(User.username) + .where( + User.username == username, User.is_active == True # noqa + ) # noqa + .limit(1) + ) + if not res.scalar_one_or_none(): + return None + return AuthCredentials(["authenticated"]), SimpleUser(data["username"]) + except Exception as e: # noqa + print(e) + return None diff --git a/sqladmin/auth/models.py b/sqladmin/auth/models.py new file mode 100644 index 00000000..fa2b1ab1 --- /dev/null +++ b/sqladmin/auth/models.py @@ -0,0 +1,15 @@ +from sqlalchemy import Boolean, Column, Integer, String +from sqlalchemy.ext.declarative import declarative_base + +Base = declarative_base() + + +class User(Base): # type: ignore + __tablename__ = "auth_users" + + id = Column(Integer, primary_key=True) + username = Column(String(length=128), unique=True) + email = Column(String(length=128)) + password = Column(String(length=128)) + is_active = Column(Boolean, default=True) + # is_superuser = Column(Boolean) diff --git a/sqladmin/auth/utils/__init__.py b/sqladmin/auth/utils/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/sqladmin/auth/utils/password.py b/sqladmin/auth/utils/password.py new file mode 100644 index 00000000..e69de29b diff --git a/sqladmin/auth/utils/token.py b/sqladmin/auth/utils/token.py new file mode 100644 index 00000000..7bd123b8 --- /dev/null +++ b/sqladmin/auth/utils/token.py @@ -0,0 +1,23 @@ +from datetime import datetime, timedelta + +from jose import jwt + +from sqladmin.conf import settings + +SECRET_KEY = settings.SECRET_KEY +ALGORITHM = settings.ALGORITHM +EXPIRES_DELTA = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES) + + +def create_access_token(data: dict) -> str: + to_encode = data.copy() + expire = datetime.utcnow() + EXPIRES_DELTA + to_encode.update({"exp": expire}) + encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM) + return encoded_jwt + + +def decode_access_token( + token: str, +) -> dict: + return jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM]) diff --git a/sqladmin/conf.py b/sqladmin/conf.py new file mode 100644 index 00000000..a0ff51fe --- /dev/null +++ b/sqladmin/conf.py @@ -0,0 +1,46 @@ +from pydantic import BaseSettings, validator + + +class Settings(BaseSettings): + SECRET_KEY: str + ALGORITHM: str = "HS256" + ACCESS_TOKEN_EXPIRE_MINUTES: int = 60 * 24 + DEBUG: bool = True + DATABASE_URL: str # db connect url + + @validator("DEBUG", pre=True) + def get_debug(cls, v: str) -> bool: + if isinstance(v, str): + if v != "True": + return False + return True + + # sentry's config + + # SENTRY_DSN: Optional[HttpUrl] = None + # SENTRY_ENVIROMENT: str = "development" + # + # @validator("SENTRY_DSN", pre=True) + # def sentry_dsn_can_be_blank(cls, v: str) -> Optional[str]: + # if v and len(v) > 0: + # return v + # return None + + class Config: + case_sensitive = True + env_file = ".env" # default env file + + # init sentry + # def __init__(self): + # super(Settings, self).__init__() + # + # if self.SENTRY_DSN: + # import sentry_sdk + # + # sentry_sdk.init( + # dsn=self.SENTRY_DSN, + # environment=self.SENTRY_ENVIROMENT, + # ) + + +settings = Settings() diff --git a/sqladmin/statics/img/logo.svg b/sqladmin/statics/img/logo.svg new file mode 100644 index 00000000..109341aa --- /dev/null +++ b/sqladmin/statics/img/logo.svg @@ -0,0 +1,4 @@ + + + + diff --git a/sqladmin/statics/js/logout.js b/sqladmin/statics/js/logout.js new file mode 100644 index 00000000..8e3d1233 --- /dev/null +++ b/sqladmin/statics/js/logout.js @@ -0,0 +1,12 @@ +//设置cookie +function setCookie(cname, cvalue, exdays) { + var d = new Date(); + d.setTime(d.getTime() + (exdays * 24 * 60 * 60 * 1000)); + var expires = "expires=" + d.toUTCString(); + document.cookie = cname + "=" + cvalue + "; " + expires; +} + +function logout() { + setCookie("access_token", "", -1); + self.location = "/admin/login"; +} \ No newline at end of file diff --git a/sqladmin/templates/layout.html b/sqladmin/templates/layout.html index 8dc34c2d..317fb465 100644 --- a/sqladmin/templates/layout.html +++ b/sqladmin/templates/layout.html @@ -1,55 +1,61 @@ {% extends "base.html" %} {% block body %} -
- +
+
+ +
+
+
+
+ {% block content %} {% endblock %} +
+
+
+ {% block footer %} + {% endblock %}
-
- {% block footer %} - {% endblock %} - - {% endblock %} +{% block head %} + +{% endblock %} \ No newline at end of file diff --git a/sqladmin/templates/login.html b/sqladmin/templates/login.html new file mode 100644 index 00000000..77f70891 --- /dev/null +++ b/sqladmin/templates/login.html @@ -0,0 +1,68 @@ +{% extends "base.html" %} +{% block body %} +
+
+
+ +
+
+
+

Login to your account

+
+ + {% if username_err==true %} + +
Username can not be null.
+ {% else %} + + {% endif %} +
+
+ +
+ {% if password_err==true %} + +
Password can not be null.
+ {% else %} + + {% endif %} + + + + + +
+
+
+ +
+
+ {% if errinfo!="" %} + Username or password is error! + {% endif %} +
+ +
+
+
+
+ +{% endblock %} diff --git a/sqladmin_cli/__init__.py b/sqladmin_cli/__init__.py new file mode 100644 index 00000000..94477862 --- /dev/null +++ b/sqladmin_cli/__init__.py @@ -0,0 +1,33 @@ +import typer +from sqlalchemy import create_engine +from sqlalchemy.orm import Session + +from sqladmin.auth.hashers import make_password +from sqladmin.auth.models import User +from sqladmin.conf import settings + +app = typer.Typer() + + +@app.command() +def createmanager(username: str, password: str) -> None: + """ + create manager + """ + engine = create_engine(settings.DATABASE_URL, echo=settings.DEBUG) + user = User( + username=username, is_active=True, + password=make_password(password) + ) + session = Session(engine) + session.add(user) + session.commit() + print(f"create {username} success.") + + +def main() -> None: + app() + + +if __name__ == "__main__": + main() diff --git a/tests/__init__.py b/tests/__init__.py index e69de29b..f07776eb 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -0,0 +1,5 @@ +from sqladmin.auth.utils.token import create_access_token + + +def get_test_token(username: str) -> str: + return create_access_token({"username": username}) diff --git a/tests/test_admin_async.py b/tests/test_admin_async.py index 778bd6d7..e9f0a7e1 100644 --- a/tests/test_admin_async.py +++ b/tests/test_admin_async.py @@ -9,9 +9,14 @@ from starlette.testclient import TestClient from sqladmin import Admin, ModelAdmin +from sqladmin.auth.hashers import make_password +from sqladmin.auth.models import Base as AdminBase, User as AdminUser +from sqladmin.conf import settings +from tests import get_test_token from tests.common import TEST_DATABASE_URI_ASYNC pytestmark = pytest.mark.anyio +settings.DATABASE_URL = TEST_DATABASE_URI_ASYNC Base = declarative_base() # type: Any @@ -57,7 +62,17 @@ def __str__(self) -> str: async def prepare_database() -> AsyncGenerator[None, None]: async with engine.begin() as conn: await conn.run_sync(Base.metadata.create_all) - + await conn.run_sync(AdminBase.metadata.create_all) + + res = await session.execute( + select(AdminUser).where(AdminUser.username == "root_async").limit(1) + ) + if not res.scalar_one_or_none(): + user = AdminUser( + username="root_async", is_active=True, password=make_password("root") + ) + session.add(user) + await session.commit() yield async with engine.begin() as conn: @@ -83,6 +98,7 @@ class AddressAdmin(ModelAdmin, model=Address): async def test_root_view() -> None: with TestClient(app) as client: + client.cookies.setdefault("access_token", get_test_token("root_async")) response = client.get("/admin") assert response.status_code == 200 @@ -92,6 +108,7 @@ async def test_root_view() -> None: async def test_invalid_list_page() -> None: with TestClient(app) as client: + client.cookies.setdefault("access_token", get_test_token("root_async")) response = client.get("/admin/example/list") assert response.status_code == 404 @@ -104,6 +121,7 @@ async def test_list_view_single_page() -> None: await session.commit() with TestClient(app) as client: + client.cookies.setdefault("access_token", get_test_token("root_async")) response = client.get("/admin/user/list") assert response.status_code == 200 @@ -129,6 +147,7 @@ async def test_list_view_multi_page() -> None: await session.commit() with TestClient(app) as client: + client.cookies.setdefault("access_token", get_test_token("root_async")) response = client.get("/admin/user/list") assert response.status_code == 200 @@ -142,6 +161,7 @@ async def test_list_view_multi_page() -> None: assert response.text.count('
  • ') == 1 with TestClient(app) as client: + client.cookies.setdefault("access_token", get_test_token("root_async")) response = client.get("/admin/user/list?page=3") assert response.status_code == 200 @@ -152,6 +172,7 @@ async def test_list_view_multi_page() -> None: assert response.text.count('
  • ') == 2 with TestClient(app) as client: + client.cookies.setdefault("access_token", get_test_token("root_async")) response = client.get("/admin/user/list?page=5") assert response.status_code == 200 @@ -177,6 +198,7 @@ async def test_list_page_permission_actions() -> None: await session.commit() with TestClient(app) as client: + client.cookies.setdefault("access_token", get_test_token("root_async")) response = client.get("/admin/user/list") assert response.status_code == 200 @@ -184,6 +206,7 @@ async def test_list_page_permission_actions() -> None: assert response.text.count('') == 10 with TestClient(app) as client: + client.cookies.setdefault("access_token", get_test_token("root_async")) response = client.get("/admin/address/list") assert response.status_code == 200 @@ -194,6 +217,7 @@ async def test_list_page_permission_actions() -> None: async def test_unauthorized_detail_page() -> None: with TestClient(app) as client: + client.cookies.setdefault("access_token", get_test_token("root_async")) response = client.get("/admin/address/detail/1") assert response.status_code == 401 @@ -201,6 +225,7 @@ async def test_unauthorized_detail_page() -> None: async def test_not_found_detail_page() -> None: with TestClient(app) as client: + client.cookies.setdefault("access_token", get_test_token("root_async")) response = client.get("/admin/user/detail/1") assert response.status_code == 404 @@ -217,6 +242,7 @@ async def test_detail_page() -> None: await session.commit() with TestClient(app) as client: + client.cookies.setdefault("access_token", get_test_token("root_async")) response = client.get("/admin/user/detail/1") assert response.status_code == 200 @@ -244,12 +270,14 @@ async def test_column_labels() -> None: await session.commit() with TestClient(app) as client: + client.cookies.setdefault("access_token", get_test_token("root_async")) response = client.get("/admin/user/list") assert response.status_code == 200 assert response.text.count("Email") == 1 with TestClient(app) as client: + client.cookies.setdefault("access_token", get_test_token("root_async")) response = client.get("/admin/user/detail/1") assert response.status_code == 200 @@ -258,6 +286,7 @@ async def test_column_labels() -> None: async def test_delete_endpoint_unauthorized_response() -> None: with TestClient(app) as client: + client.cookies.setdefault("access_token", get_test_token("root_async")) response = client.delete("/admin/address/delete/1") assert response.status_code == 401 @@ -265,6 +294,7 @@ async def test_delete_endpoint_unauthorized_response() -> None: async def test_delete_endpoint_not_found_response() -> None: with TestClient(app) as client: + client.cookies.setdefault("access_token", get_test_token("root_async")) response = client.delete("/admin/user/delete/1") assert response.status_code == 404 @@ -285,6 +315,7 @@ async def test_delete_endpoint() -> None: assert result.scalar_one() == 1 with TestClient(app) as client: + client.cookies.setdefault("access_token", get_test_token("root_async")) response = client.delete("/admin/user/delete/1") assert response.status_code == 200 @@ -297,6 +328,7 @@ async def test_create_endpoint_unauthorized_response() -> None: admin._model_admins[1].can_create = False with TestClient(app) as client: + client.cookies.setdefault("access_token", get_test_token("root_async")) response = client.get("/admin/address/create") assert response.status_code == 401 @@ -306,6 +338,7 @@ async def test_create_endpoint_unauthorized_response() -> None: async def test_create_endpoint_get_form() -> None: with TestClient(app) as client: + client.cookies.setdefault("access_token", get_test_token("root_async")) response = client.get("/admin/user/create") assert response.status_code == 200 @@ -326,6 +359,7 @@ async def test_create_endpoint_get_form() -> None: async def test_create_endpoint_post_form() -> None: data: dict = {"date_of_birth": "Wrong Date Format"} with TestClient(app) as client: + client.cookies.setdefault("access_token", get_test_token("root_async")) response = client.post("/admin/user/create", data=data) assert response.status_code == 400 @@ -335,6 +369,7 @@ async def test_create_endpoint_post_form() -> None: data = {"name": "SQLAlchemy"} with TestClient(app) as client: + client.cookies.setdefault("access_token", get_test_token("root_async")) response = client.post("/admin/user/create", data=data) stmt = select(func.count(User.id)) @@ -351,6 +386,7 @@ async def test_create_endpoint_post_form() -> None: data = {"user": user.id} with TestClient(app) as client: + client.cookies.setdefault("access_token", get_test_token("root_async")) response = client.post("/admin/address/create", data=data) stmt = select(func.count(Address.id)) @@ -366,6 +402,7 @@ async def test_create_endpoint_post_form() -> None: data = {"name": "SQLAdmin", "addresses": [address.id]} with TestClient(app) as client: + client.cookies.setdefault("access_token", get_test_token("root_async")) response = client.post("/admin/user/create", data=data) stmt = select(func.count(User.id)) @@ -378,3 +415,75 @@ async def test_create_endpoint_post_form() -> None: user = result.scalar_one() assert user.name == "SQLAdmin" assert user.addresses == [address] + + +async def test_login() -> None: + with TestClient(app) as client: + response = client.post( + "/admin/login", data={"username": "root_async", "password": "root"} + ) + + assert response.status_code == 307 + assert len(response.cookies.get("access_token")) > 0 + with TestClient(app) as client: + response = client.post( + "/admin/login", data={"username": "root", "password": "root2"} + ) + + assert response.status_code == 200 + assert ( + response.text.count( + 'Username or password is error!' + ) + == 1 + ) + with TestClient(app) as client: + response = client.post( + "/admin/login", data={"username": "root", "password": ""} + ) + + assert response.status_code == 200 + assert response.text.count("Password can not be null.") == 1 + with TestClient(app) as client: + response = client.post( + "/admin/login", data={"username": "", "password": "root"} + ) + + assert response.status_code == 200 + assert response.text.count("Username can not be null.") == 1 + + +def check_redirect(url: str, method: str = "GET") -> None: + with TestClient(app) as client: + response = client.get(url) + assert response.status_code == 200 or response.status_code == 307 + assert ( + response.text.count( + '

    Login to your account

    ' + ) + == 1 + ) + + +def test_redirect() -> None: + check_redirect("/admin") + check_redirect("/admin/address/detail/1") + check_redirect("/admin/address/list") + check_redirect("/admin/address/create") + with TestClient(app) as client: + response = client.delete("/admin/address/delete/1") + assert response.status_code == 303 + + +async def test_expire_time() -> None: + token = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJ1c2VybmFtZSI6InJvb3QiLCJleHAiOjE2NDMyNzYyMDl9._iFrRq5TqlomtM0bCr-p0L-VSst-bP9rZDf4s9Zc-1c" # noqa + with TestClient(app) as client: + client.cookies.setdefault("access_token", token) + response = client.get("/admin") + assert response.status_code == 200 or response.status_code == 307 + assert ( + response.text.count( + '

    Login to your account

    ' + ) + == 1 + ) diff --git a/tests/test_admin_sync.py b/tests/test_admin_sync.py index 7c0807ee..8555b9d4 100644 --- a/tests/test_admin_sync.py +++ b/tests/test_admin_sync.py @@ -17,10 +17,14 @@ from starlette.testclient import TestClient from sqladmin import Admin, ModelAdmin +from sqladmin.auth.hashers import make_password, verify_password +from sqladmin.auth.models import Base as AdminBase, User as AdminUser +from sqladmin.conf import settings +from tests import get_test_token from tests.common import TEST_DATABASE_URI_SYNC Base = declarative_base() # type: Any - +settings.DATABASE_URL = TEST_DATABASE_URI_SYNC engine = create_engine( TEST_DATABASE_URI_SYNC, connect_args={"check_same_thread": False} ) @@ -62,8 +66,19 @@ def __str__(self) -> str: @pytest.fixture(autouse=True, scope="function") def prepare_database() -> Generator[None, None, None]: Base.metadata.create_all(engine) + AdminBase.metadata.create_all(engine) + res = session.execute( + select(AdminUser).where(AdminUser.username == "root_sync").limit(1) + ) + if not res.scalar_one_or_none(): + user = AdminUser( + username="root_sync", is_active=True, password=make_password("root") + ) + session.add(user) + session.commit() yield Base.metadata.drop_all(engine) + AdminBase.metadata.drop_all(engine) class UserAdmin(ModelAdmin, model=User): @@ -85,6 +100,7 @@ class AddressAdmin(ModelAdmin, model=Address): def test_root_view() -> None: with TestClient(app) as client: + client.cookies.setdefault("access_token", get_test_token("root_sync")) response = client.get("/admin") assert response.status_code == 200 @@ -94,6 +110,8 @@ def test_root_view() -> None: def test_invalid_list_page() -> None: with TestClient(app) as client: + client.cookies.setdefault("access_token", get_test_token("root_sync")) + client.cookies.setdefault("access_token", get_test_token("root_sync")) response = client.get("/admin/example/list") assert response.status_code == 404 @@ -106,6 +124,7 @@ def test_list_view_single_page() -> None: session.commit() with TestClient(app) as client: + client.cookies.setdefault("access_token", get_test_token("root_sync")) response = client.get("/admin/user/list") assert response.status_code == 200 @@ -131,6 +150,7 @@ def test_list_view_multi_page() -> None: session.commit() with TestClient(app) as client: + client.cookies.setdefault("access_token", get_test_token("root_sync")) response = client.get("/admin/user/list") assert response.status_code == 200 @@ -144,6 +164,7 @@ def test_list_view_multi_page() -> None: assert response.text.count('
  • ') == 1 with TestClient(app) as client: + client.cookies.setdefault("access_token", get_test_token("root_sync")) response = client.get("/admin/user/list?page=3") assert response.status_code == 200 @@ -154,6 +175,7 @@ def test_list_view_multi_page() -> None: assert response.text.count('
  • ') == 2 with TestClient(app) as client: + client.cookies.setdefault("access_token", get_test_token("root_sync")) response = client.get("/admin/user/list?page=5") assert response.status_code == 200 @@ -179,6 +201,7 @@ def test_list_page_permission_actions() -> None: session.commit() with TestClient(app) as client: + client.cookies.setdefault("access_token", get_test_token("root_sync")) response = client.get("/admin/user/list") assert response.status_code == 200 @@ -186,6 +209,7 @@ def test_list_page_permission_actions() -> None: assert response.text.count('') == 10 with TestClient(app) as client: + client.cookies.setdefault("access_token", get_test_token("root_sync")) response = client.get("/admin/address/list") assert response.status_code == 200 @@ -196,6 +220,7 @@ def test_list_page_permission_actions() -> None: def test_unauthorized_detail_page() -> None: with TestClient(app) as client: + client.cookies.setdefault("access_token", get_test_token("root_sync")) response = client.get("/admin/address/detail/1") assert response.status_code == 401 @@ -203,6 +228,7 @@ def test_unauthorized_detail_page() -> None: def test_not_found_detail_page() -> None: with TestClient(app) as client: + client.cookies.setdefault("access_token", get_test_token("root_sync")) response = client.get("/admin/user/detail/1") assert response.status_code == 404 @@ -219,6 +245,7 @@ def test_detail_page() -> None: session.commit() with TestClient(app) as client: + client.cookies.setdefault("access_token", get_test_token("root_sync")) response = client.get("/admin/user/detail/1") assert response.status_code == 200 @@ -246,12 +273,14 @@ def test_column_labels() -> None: session.commit() with TestClient(app) as client: + client.cookies.setdefault("access_token", get_test_token("root_sync")) response = client.get("/admin/user/list") assert response.status_code == 200 assert response.text.count("Email") == 1 with TestClient(app) as client: + client.cookies.setdefault("access_token", get_test_token("root_sync")) response = client.get("/admin/user/detail/1") assert response.status_code == 200 @@ -260,6 +289,7 @@ def test_column_labels() -> None: def test_delete_endpoint_unauthorized_response() -> None: with TestClient(app) as client: + client.cookies.setdefault("access_token", get_test_token("root_sync")) response = client.delete("/admin/address/delete/1") assert response.status_code == 401 @@ -267,6 +297,7 @@ def test_delete_endpoint_unauthorized_response() -> None: def test_delete_endpoint_not_found_response() -> None: with TestClient(app) as client: + client.cookies.setdefault("access_token", get_test_token("root_sync")) response = client.delete("/admin/user/delete/1") assert response.status_code == 404 @@ -281,6 +312,7 @@ def test_delete_endpoint() -> None: assert session.query(User).count() == 1 with TestClient(app) as client: + client.cookies.setdefault("access_token", get_test_token("root_sync")) response = client.delete("/admin/user/delete/1") assert response.status_code == 200 @@ -291,6 +323,7 @@ def test_create_endpoint_unauthorized_response() -> None: admin._model_admins[1].can_create = False with TestClient(app) as client: + client.cookies.setdefault("access_token", get_test_token("root_sync")) response = client.get("/admin/address/create") assert response.status_code == 401 @@ -300,6 +333,7 @@ def test_create_endpoint_unauthorized_response() -> None: def test_create_endpoint_get_form() -> None: with TestClient(app) as client: + client.cookies.setdefault("access_token", get_test_token("root_sync")) response = client.get("/admin/user/create") assert response.status_code == 200 @@ -320,6 +354,7 @@ def test_create_endpoint_get_form() -> None: def test_create_endpoint_post_form() -> None: data: dict = {"birthdate": "Wrong Date Format"} with TestClient(app) as client: + client.cookies.setdefault("access_token", get_test_token("root_sync")) response = client.post("/admin/user/create", data=data) assert response.status_code == 400 @@ -329,6 +364,7 @@ def test_create_endpoint_post_form() -> None: data = {"name": "SQLAlchemy"} with TestClient(app) as client: + client.cookies.setdefault("access_token", get_test_token("root_sync")) response = client.post("/admin/user/create", data=data) stmt = select(func.count(User.id)) @@ -343,6 +379,7 @@ def test_create_endpoint_post_form() -> None: data = {"user": user.id} with TestClient(app) as client: + client.cookies.setdefault("access_token", get_test_token("root_sync")) response = client.post("/admin/address/create", data=data) stmt = select(func.count(Address.id)) @@ -356,6 +393,7 @@ def test_create_endpoint_post_form() -> None: data = {"name": "SQLAdmin", "addresses": [address.id]} with TestClient(app) as client: + client.cookies.setdefault("access_token", get_test_token("root_sync")) response = client.post("/admin/user/create", data=data) stmt = select(func.count(User.id)) @@ -366,3 +404,88 @@ def test_create_endpoint_post_form() -> None: user = session.execute(stmt).scalar_one() assert user.name == "SQLAdmin" assert user.addresses == [address] + + +def test_login() -> None: + user = AdminUser( + username="root", + is_active=True, + email="test@email.com", + password=make_password("root"), + ) + session.add(user) + session.commit() + assert not verify_password("root1", "root2") + pd = make_password("root1") + assert verify_password("root1", pd) + with TestClient(app) as client: + response = client.post( + "/admin/login", data={"username": "root", "password": "root"} + ) + + assert response.status_code == 307 + assert len(response.cookies.get("access_token")) > 0 + with TestClient(app) as client: + response = client.post( + "/admin/login", data={"username": "root", "password": "root1"} + ) + + assert response.status_code == 200 + assert ( + response.text.count( + 'Username or password is error!' + ) + == 1 + ) + with TestClient(app) as client: + response = client.post( + "/admin/login", data={"username": "root", "password": ""} + ) + + assert response.status_code == 200 + assert response.text.count("Password can not be null.") == 1 + with TestClient(app) as client: + response = client.post( + "/admin/login", data={"username": "", "password": "root"} + ) + + assert response.status_code == 200 + assert response.text.count("Username can not be null.") == 1 + + +def check_redirect( + url: str, +) -> None: + with TestClient(app) as client: + response = client.get(url) + assert response.status_code == 200 or response.status_code == 307 + assert ( + response.text.count( + '

    Login to your account

    ' + ) + == 1 + ) + + +def test_redirect() -> None: + check_redirect("/admin") + check_redirect("/admin/address/detail/1") + check_redirect("/admin/address/list") + check_redirect("/admin/address/create") + with TestClient(app) as client: + response = client.delete("/admin/address/delete/1") + assert response.status_code == 303 + + +def test_expire_time() -> None: + token = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJ1c2VybmFtZSI6InJvb3QiLCJleHAiOjE2NDMyNzYyMDl9._iFrRq5TqlomtM0bCr-p0L-VSst-bP9rZDf4s9Zc-1c" # noqa + with TestClient(app) as client: + client.cookies.setdefault("access_token", token) + response = client.get("/admin") + assert response.status_code == 200 or response.status_code == 307 + assert ( + response.text.count( + '

    Login to your account

    ' + ) + == 1 + ) diff --git a/tests/test_application.py b/tests/test_application.py index e7765e77..70e9f885 100644 --- a/tests/test_application.py +++ b/tests/test_application.py @@ -1,17 +1,23 @@ -from typing import Any +from typing import Any, Generator -from sqlalchemy import create_engine +import pytest +from sqlalchemy import create_engine, select from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm import Session, sessionmaker from starlette.applications import Starlette from starlette.testclient import TestClient from sqladmin import Admin +from sqladmin.auth.hashers import make_password +from sqladmin.auth.models import Base as AdminBase, User as AdminUser +from tests import get_test_token from tests.common import TEST_DATABASE_URI_SYNC Base = declarative_base() # type: Any -engine = create_engine(TEST_DATABASE_URI_SYNC) +engine = create_engine( + TEST_DATABASE_URI_SYNC, connect_args={"check_same_thread": False} +) LocalSession = sessionmaker(bind=engine) @@ -20,10 +26,30 @@ app = Starlette() +@pytest.fixture(autouse=True, scope="function") +def prepare_database() -> Generator[None, None, None]: + Base.metadata.create_all(engine) + AdminBase.metadata.create_all(engine) + res = session.execute( + select(AdminUser).where(AdminUser.username == "root_sync").limit(1) + ) + if not res.scalar_one_or_none(): + user = AdminUser( + username="root_sync", is_active=True, password=make_password("root") + ) + session.add(user) + session.commit() + yield + Base.metadata.drop_all(engine) + AdminBase.metadata.drop_all(engine) + + def test_application_title() -> None: Admin(app=app, engine=engine) with TestClient(app) as client: + client.cookies.setdefault("access_token", get_test_token("root_sync")) + response = client.get("/admin") assert response.status_code == 200 @@ -39,6 +65,7 @@ def test_application_logo() -> None: ) with TestClient(app) as client: + client.cookies.setdefault("access_token", get_test_token("root_sync")) response = client.get("/dashboard") assert response.status_code == 200