Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remote ajax refs #292

Merged
merged 12 commits into from Sep 5, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
91 changes: 91 additions & 0 deletions sqladmin/ajax.py
@@ -0,0 +1,91 @@
from typing import TYPE_CHECKING, Any, Dict, List

from sqlalchemy import String, cast, inspect, or_, select, text

from sqladmin.helpers import get_primary_key

if TYPE_CHECKING:
from sqladmin.models import ModelView


DEFAULT_PAGE_SIZE = 10


class QueryAjaxModelLoader:
def __init__(
self,
name: str,
model: type,
model_admin: "ModelView",
**options: Any,
):
self.name = name
self.model = model
self.model_admin = model_admin
self.fields = options.get("fields", {})
self.order_by = options.get("order_by")

if not self.fields:
raise ValueError(
"AJAX loading requires `fields` to be specified for "
f"{self.model}.{self.name}"
)

self._cached_fields = self._process_fields()
self.pk = get_primary_key(self.model)

def _process_fields(self) -> list:
remote_fields = []

for field in self.fields:
if isinstance(field, str):
attr = getattr(self.model, field, None)

if not attr:
raise ValueError(f"{self.model}.{field} does not exist.")

remote_fields.append(attr)
else:
remote_fields.append(field)

return remote_fields

def format(self, model: type) -> Dict[str, Any]:
if not model:
return {}

return {"id": getattr(model, self.pk.name), "text": str(model)}

async def get_list(self, term: str, limit: int = DEFAULT_PAGE_SIZE) -> List[Any]:
stmt = select(self.model)

# no type casting to string if a ColumnAssociationProxyInstance is given
filters = [
cast(field, String).ilike("%%%s%%" % term) for field in self._cached_fields
]

stmt = stmt.filter(or_(*filters))

if self.order_by:
stmt = stmt.order_by(self.order_by)

stmt = stmt.limit(limit)
result = await self.model_admin._run_query(stmt)
return result


def create_ajax_loader(
*,
model_admin: "ModelView",
name: str,
options: dict,
) -> QueryAjaxModelLoader:
mapper = inspect(model_admin.model)

try:
attr = mapper.relationships[name]
except KeyError:
raise ValueError(f"{model_admin.model}.{name} is not a relation.")

remote_model = attr.mapper.class_
return QueryAjaxModelLoader(name, remote_model, model_admin, **options)
34 changes: 28 additions & 6 deletions sqladmin/application.py
Expand Up @@ -9,12 +9,13 @@
from starlette.exceptions import HTTPException
from starlette.middleware import Middleware
from starlette.requests import Request
from starlette.responses import RedirectResponse, Response
from starlette.responses import JSONResponse, RedirectResponse, Response
from starlette.routing import Mount, Route
from starlette.staticfiles import StaticFiles
from starlette.templating import Jinja2Templates

from sqladmin._types import ENGINE_TYPE
from sqladmin.ajax import QueryAjaxModelLoader
from sqladmin.authentication import AuthenticationBackend, login_required
from sqladmin.models import BaseView, ModelView

Expand Down Expand Up @@ -120,6 +121,7 @@ class UserAdmin(ModelView, model=User):

# Set database engine from Admin instance
view.engine = self.engine
view.ajax_lookup_url = f"{self.base_url}/{view.identity}/ajax/lookup"

if isinstance(view.engine, Engine):
view.sessionmaker = sessionmaker(
Expand Down Expand Up @@ -321,10 +323,10 @@ def http_exception(request: Request, exc: Exception) -> Response:
methods=["GET", "POST"],
),
Route(
"/{identity}/export/{export_type}",
endpoint=self.export,
name="export",
methods=["GET"],
"/{identity}/export/{export_type}", endpoint=self.export, name="export"
),
Route(
"/{identity}/ajax/lookup", endpoint=self.ajax_lookup, name="ajax_lookup"
),
Route("/login", endpoint=self.login, name="login", methods=["GET", "POST"]),
Route("/logout", endpoint=self.logout, name="logout", methods=["GET"]),
Expand Down Expand Up @@ -514,13 +516,33 @@ async def logout(self, request: Request) -> Response:
await self.authentication_backend.logout(request)
return RedirectResponse(request.url_for("admin:index"), status_code=302)

async def ajax_lookup(self, request: Request) -> Response:
"""Ajax lookup route."""

identity = request.path_params["identity"]
model_view = self._find_model_view(identity)

name = request.query_params.get("name")
term = request.query_params.get("term")

if not name or not term:
raise HTTPException(status_code=400)

try:
loader: QueryAjaxModelLoader = model_view._form_ajax_refs[name]
except KeyError:
raise HTTPException(status_code=400)

data = [loader.format(m) for m in await loader.get_list(term)]
return JSONResponse({"results": data})


def expose(
path: str,
*,
methods: List[str] = ["GET"],
identity: str = None,
include_in_schema: bool = True
include_in_schema: bool = True,
) -> Callable[..., Any]:
"""Expose View with information."""

Expand Down
90 changes: 88 additions & 2 deletions sqladmin/fields.py
@@ -1,13 +1,16 @@
import json
import operator
from typing import Any, Callable, Generator, List, Optional, Tuple, Union
from typing import Any, Callable, Generator, List, Optional, Set, Tuple, Union

from sqlalchemy import inspect
from wtforms import Form, ValidationError, fields, widgets
from wtforms import Form, SelectFieldBase, ValidationError, fields, widgets

from sqladmin import widgets as sqladmin_widgets
from sqladmin.ajax import QueryAjaxModelLoader

__all__ = [
"AjaxSelectField",
"AjaxSelectMultipleField",
"DateField",
"DateTimeField",
"JSONField",
Expand Down Expand Up @@ -263,3 +266,86 @@ def pre_validate(self, form: Form) -> None:
for v in self.data:
if v not in pk_list: # pragma: no cover
raise ValidationError(self.gettext("Not a valid choice"))


class AjaxSelectField(SelectFieldBase):
widget = sqladmin_widgets.AjaxSelect2Widget()
separator = ","

def __init__(
self,
loader: QueryAjaxModelLoader,
label: str = None,
validators: list = None,
allow_blank: bool = False,
**kwargs: Any,
) -> None:
kwargs.pop("data", None) # Handled by JS side
self.loader = loader
self.allow_blank = allow_blank
super().__init__(label, validators, **kwargs)

@property
def data(self) -> Any:
if self._formdata:
self.data = self._formdata

return self._data

@data.setter
def data(self, data: Any) -> None:
self._data = data
self._formdata = None

def process_formdata(self, valuelist: list) -> None:
if valuelist:
if self.allow_blank and valuelist[0] == "__None":
self.data = None
else:
self._data = None
self._formdata = valuelist[0]

def pre_validate(self, form: Form) -> None:
if not self.allow_blank and self.data is None:
raise ValidationError("Not a valid choice")


class AjaxSelectMultipleField(SelectFieldBase):
widget = sqladmin_widgets.AjaxSelect2Widget(multiple=True)
separator = ","

def __init__(
self,
loader: QueryAjaxModelLoader,
label: str = None,
validators: list = None,
default: list = None,
allow_blank: bool = False,
**kwargs: Any,
) -> None:
kwargs.pop("data", None) # Handled by JS side
self.loader = loader
self.allow_blank = allow_blank
default = default or []
self._formdata: Set[Any] = set()

super().__init__(label, validators, default=default, **kwargs)

@property
def data(self) -> Any:
if self._formdata:
self.data = self._formdata

return self._data

@data.setter
def data(self, data: Any) -> None:
self._data = data
self._formdata = set()

def process_formdata(self, valuelist: list) -> None:
self._formdata = set()

for field in valuelist:
for n in field.split(self.separator):
self._formdata.add(n)
22 changes: 21 additions & 1 deletion sqladmin/forms.py
Expand Up @@ -35,8 +35,11 @@

from sqladmin._types import ENGINE_TYPE, MODEL_ATTR_TYPE
from sqladmin._validators import CurrencyValidator, TimezoneValidator
from sqladmin.ajax import QueryAjaxModelLoader
from sqladmin.exceptions import NoConverterFound
from sqladmin.fields import (
AjaxSelectField,
AjaxSelectMultipleField,
DateField,
DateTimeField,
JSONField,
Expand All @@ -45,7 +48,7 @@
SelectField,
TimeField,
)
from sqladmin.helpers import get_direction, get_primary_key
from sqladmin.helpers import get_direction, get_primary_key, is_relationship

if sys.version_info >= (3, 8):
from typing import Protocol
Expand Down Expand Up @@ -256,6 +259,7 @@ async def convert(
form_include_pk: bool,
label: Optional[str] = None,
override: Optional[Type[Field]] = None,
form_ajax_refs: Dict[str, QueryAjaxModelLoader] = {},
) -> Optional[UnboundField]:

kwargs = await self._prepare_kwargs(
Expand All @@ -274,6 +278,19 @@ async def convert(
assert issubclass(override, Field)
return override(**kwargs)

loader = form_ajax_refs.get(prop.key)
multiple = (
is_relationship(prop)
and prop.direction.name in ("ONETOMANY", "MANYTOMANY")
and prop.uselist
)

if loader:
if multiple:
return AjaxSelectMultipleField(loader, **kwargs)
else:
return AjaxSelectField(loader, **kwargs)

converter = self.get_converter(prop=prop)
return converter(model=model, prop=prop, kwargs=kwargs)

Expand Down Expand Up @@ -477,6 +494,7 @@ async def get_model_form(
form_widget_args: Dict[str, Dict[str, Any]] = None,
form_class: Type[Form] = Form,
form_overrides: Dict[str, Type[Field]] = None,
form_ajax_refs: Dict[str, QueryAjaxModelLoader] = None,
form_include_pk: bool = False,
) -> Type[Form]:
type_name = model.__name__ + "Form"
Expand All @@ -486,6 +504,7 @@ async def get_model_form(
form_widget_args = form_widget_args or {}
column_labels = column_labels or {}
form_overrides = form_overrides or {}
form_ajax_refs = form_ajax_refs or {}

attributes = []
names = only or mapper.attrs.keys()
Expand All @@ -509,6 +528,7 @@ async def get_model_form(
label=label,
override=override,
form_include_pk=form_include_pk,
form_ajax_refs=form_ajax_refs,
)
if field is not None:
field_dict[name] = field
Expand Down
4 changes: 4 additions & 0 deletions sqladmin/helpers.py
Expand Up @@ -140,3 +140,7 @@ def get_column_python_type(column: Column) -> type:
return column.type.python_type
except NotImplementedError:
return str


def is_relationship(attr: MODEL_ATTR_TYPE) -> bool:
return isinstance(attr, RelationshipProperty)