Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add the login function #29

Closed
wants to merge 12 commits into from
Closed
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 2 additions & 0 deletions .env
@@ -0,0 +1,2 @@
SECRET_KEY=aawinjwol;egbfnjek bnl
DATABASE_URL=sqlite:///example.db
1 change: 1 addition & 0 deletions requirements.txt
Expand Up @@ -19,3 +19,4 @@ mkdocstrings==0.17.0
# Packaging
twine==3.7.1
wheel==0.37.0
pydantic==1.9.0
137 changes: 128 additions & 9 deletions sqladmin/application.py
@@ -1,6 +1,10 @@
import gettext
import os
from typing import TYPE_CHECKING, List, Type, Union

import anyio
from jinja2 import ChoiceLoader, FileSystemLoader, PackageLoader
from sqlalchemy import select
from sqlalchemy.engine import Engine
from sqlalchemy.ext.asyncio import AsyncEngine
from starlette.applications import Starlette
Expand All @@ -11,10 +15,13 @@
from starlette.staticfiles import StaticFiles
from starlette.templating import Jinja2Templates

from sqladmin.auth.hashers import make_password
from sqladmin.auth.models import User
from sqladmin.auth.utils.token import create_access_token, decode_access_token

if TYPE_CHECKING:
from sqladmin.models import ModelAdmin


__all__ = [
"Admin",
]
Expand All @@ -34,13 +41,26 @@ def __init__(
base_url: str = "/admin",
title: str = "Admin",
logo_url: str = None,
language: str = None,
) -> None:
self.app = app
self.engine = engine
self.base_url = base_url
self._model_admins: List[Type["ModelAdmin"]] = []

self.templates = Jinja2Templates("templates")
self.templates.env.add_extension("jinja2.ext.i18n")
if language:
translation = gettext.translation(
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd rather do this in a separate changes to keep small PR and make review easier. What do you think?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll pay attention next time.

If you think I submit too many contents this time, I can submit them separately.

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

that would be great. If you could focus this PR on the login functionality and do other PRs for translation and other things I'd really appreciate that.
PRs are always appreciated :)

"lang",
os.path.dirname(__file__) + "/translations",
languages=[language],
)
self.templates.env.install_gettext_translations( # type: ignore
translation, newstyle=True
) # type: ignore
else:
self.templates.env.install_null_translations(newstyle=True) # type: ignore
self.templates.env.loader = ChoiceLoader(
[
FileSystemLoader("templates"),
Expand Down Expand Up @@ -100,6 +120,17 @@ class UserAdmin(ModelAdmin, model=User):
self._model_admins.append(model)


def check_token(request: Request) -> bool:
token = request.cookies.get("access_token")
if token:
try:
decode_access_token(token)
return True
except: # noqa
pass
return False


class Admin(BaseAdmin):
"""Main entrypoint to admin interface.

Expand Down Expand Up @@ -130,27 +161,34 @@ def __init__(
base_url: str = "/admin",
title: str = "Admin",
logo_url: str = None,
language: str = None,
) -> None:
"""
Args:
app: Starlette or FastAPI application.
engine: SQLAlchemy engine instance.
base_url: Base URL for Admin interface.
title: Admin title.
logo: URL of logo to be displayed instead of title.
logo_url: URL of logo to be displayed instead of title.
language: Now it can write "zh_CN" or None.
"""

assert isinstance(engine, (Engine, AsyncEngine))
super().__init__(
app=app, engine=engine, base_url=base_url, title=title, logo_url=logo_url
app=app,
engine=engine,
base_url=base_url,
title=title,
logo_url=logo_url,
language=language,
)

statics = StaticFiles(packages=["sqladmin"])

router = Router(
routes=[
Mount("/statics", app=statics, name="statics"),
Route("/", endpoint=self.index, name="index"),
Route("/", endpoint=self.index, name="index", methods=["GET", "POST"]),
Route("/{identity}/list", endpoint=self.list, name="list"),
Route("/{identity}/detail/{pk}", endpoint=self.detail, name="detail"),
Route(
Expand All @@ -165,6 +203,12 @@ def __init__(
name="create",
methods=["GET", "POST"],
),
Route(
"/login",
endpoint=self.login,
name="login",
methods=["GET", "POST"],
),
]
)
self.app.mount(base_url, app=router, name="admin")
Expand All @@ -173,12 +217,22 @@ def __init__(

async def index(self, request: Request) -> Response:
"""Index route which can be overriden to create dashboards."""

if not check_token(request):
return RedirectResponse(
url=request.url_for(
"admin:login",
),
)
return self.templates.TemplateResponse("index.html", {"request": request})

async def list(self, request: Request) -> Response:
"""List route to display paginated Model instances."""

if not check_token(request):
return RedirectResponse(
request.url_for(
"admin:login",
),
)
model_admin = self._find_model_admin(request.path_params["identity"])

page = int(request.query_params.get("page", 1))
Expand All @@ -204,7 +258,12 @@ async def list(self, request: Request) -> Response:

async def detail(self, request: Request) -> Response:
"""Detail route."""

if not check_token(request):
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A @login_required would be easier here.

return RedirectResponse(
request.url_for(
"admin:login",
),
)
model_admin = self._find_model_admin(request.path_params["identity"])
if not model_admin.can_view_details:
return self._unathorized_response(request)
Expand All @@ -224,7 +283,12 @@ async def detail(self, request: Request) -> Response:

async def delete(self, request: Request) -> Response:
"""Delete route."""

if not check_token(request):
return RedirectResponse(
request.url_for(
"admin:login",
),
)
identity = request.path_params["identity"]
model_admin = self._find_model_admin(identity)
if not model_admin.can_delete:
Expand All @@ -240,7 +304,12 @@ async def delete(self, request: Request) -> Response:

async def create(self, request: Request) -> Response:
"""Create model endpoint."""

if not check_token(request):
return RedirectResponse(
request.url_for(
"admin:login",
),
)
identity = request.path_params["identity"]
model_admin = self._find_model_admin(identity)
if not model_admin.can_create:
Expand Down Expand Up @@ -272,3 +341,53 @@ async def create(self, request: Request) -> Response:
request.url_for("admin:list", identity=identity),
status_code=302,
)

async def login(self, request: Request) -> Response:
context = {
"request": request,
"errinfo": "",
"username_err": False,
"password_err": False,
}

if request.method == "GET":
return self.templates.TemplateResponse("login.html", context)
form = await request.form()
username = form.get("username")
raw_password = form.get("password")

if not username:
context["username_err"] = True
return self.templates.TemplateResponse("login.html", context)
if not raw_password:
context["password_err"] = True
return self.templates.TemplateResponse("login.html", context)
if isinstance(self.engine, Engine):
res = await anyio.to_thread.run_sync(
self.engine.execute,
select(User.password)
.where(User.username == username, User.is_active == True) # noqa
.limit(1),
)
else:
res = await self.engine.execute(
select(User.password)
.where(User.username == username, User.is_active == True) # noqa
.limit(1)
)
password = res.scalar_one_or_none()
if password is not None:
if make_password(raw_password) == password:
request.cookies.setdefault(
"access_token",
)
res = RedirectResponse(
request.url_for(
"admin:index",
),
)
access_token = create_access_token({"username": username})
res.set_cookie("access_token", access_token)
return res
context["errinfo"] = "e"
return self.templates.TemplateResponse("login.html", context)
Empty file added sqladmin/auth/__init__.py
Empty file.
24 changes: 24 additions & 0 deletions sqladmin/auth/hashers.py
@@ -0,0 +1,24 @@
import binascii
import hashlib

from sqladmin.conf import settings

SECRET_KEY = settings.SECRET_KEY


def make_password(raw_password: str) -> str:
password = hashlib.pbkdf2_hmac(
"sha256", raw_password.encode("utf-8"), SECRET_KEY.encode("utf-8"), 16
)
return binascii.hexlify(password).decode()


def verify_password(raw_password: str, password: str) -> bool:
random_salt = SECRET_KEY.encode("utf-8")
raw_password_bytes = hashlib.pbkdf2_hmac(
"sha256", raw_password.encode("utf-8"), random_salt, 16
)
if binascii.hexlify(raw_password_bytes).decode() == password:
return True
else:
return False
29 changes: 29 additions & 0 deletions sqladmin/auth/models.py
@@ -0,0 +1,29 @@
from sqlalchemy import Boolean, Column, Integer, String
from sqlalchemy.ext.declarative import declarative_base

from sqladmin.auth.hashers import make_password, verify_password

Base = declarative_base()


class User(Base): # type: ignore
aminalaee marked this conversation as resolved.
Show resolved Hide resolved
__tablename__ = "auth_users"

id = Column(Integer, primary_key=True)
username = Column(String(length=128), unique=True)
email = Column(String(length=128))
password = Column(String(length=128))
is_active = Column(Boolean, default=True)

# is_superuser = Column(Boolean)

def set_password(self, raw_password: str) -> None:
self.password = make_password(
raw_password,
)

def verify_password(self, raw_password: str) -> bool:
return verify_password(
raw_password,
self.password,
)
Empty file added sqladmin/auth/utils/__init__.py
Empty file.
23 changes: 23 additions & 0 deletions sqladmin/auth/utils/token.py
@@ -0,0 +1,23 @@
from datetime import datetime, timedelta

from jose import jwt

from sqladmin.conf import settings

SECRET_KEY = settings.SECRET_KEY
ALGORITHM = settings.ALGORITHM
EXPIRES_DELTA = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)


def create_access_token(data: dict) -> str:
to_encode = data.copy()
expire = datetime.utcnow() + EXPIRES_DELTA
to_encode.update({"exp": expire})
encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
return encoded_jwt


def decode_access_token(
token: str,
) -> dict:
return jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
46 changes: 46 additions & 0 deletions sqladmin/conf.py
@@ -0,0 +1,46 @@
from pydantic import BaseSettings, validator


class Settings(BaseSettings):
SECRET_KEY: str
ALGORITHM: str = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES: int = 60 * 24
DEBUG: bool = True
DATABASE_URL: str # db connect url

@validator("DEBUG", pre=True)
def get_debug(cls, v: str) -> bool:
if isinstance(v, str):
if v != "True":
return False
return True

# sentry's config

# SENTRY_DSN: Optional[HttpUrl] = None
# SENTRY_ENVIROMENT: str = "development"
#
# @validator("SENTRY_DSN", pre=True)
# def sentry_dsn_can_be_blank(cls, v: str) -> Optional[str]:
# if v and len(v) > 0:
# return v
# return None

class Config:
case_sensitive = True
env_file = ".env" # default env file

# init sentry
# def __init__(self):
# super(Settings, self).__init__()
#
# if self.SENTRY_DSN:
# import sentry_sdk
#
# sentry_sdk.init(
# dsn=self.SENTRY_DSN,
# environment=self.SENTRY_ENVIROMENT,
# )


settings = Settings()
4 changes: 4 additions & 0 deletions sqladmin/statics/img/logo.svg
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.