diff --git a/bananas/admin/api/schemas/yasg.py b/bananas/admin/api/schemas/yasg.py index c22504a..7214bb6 100644 --- a/bananas/admin/api/schemas/yasg.py +++ b/bananas/admin/api/schemas/yasg.py @@ -2,20 +2,38 @@ from django.urls.exceptions import NoReverseMatch from django.utils.translation import ugettext as _ from drf_yasg import openapi -from drf_yasg.generators import OpenAPISchemaGenerator +from drf_yasg.generators import EndpointEnumerator, OpenAPISchemaGenerator from drf_yasg.inspectors.view import SwaggerAutoSchema from drf_yasg.views import get_schema_view from rest_framework import permissions, viewsets from rest_framework.authentication import SessionAuthentication from rest_framework.routers import SimpleRouter from rest_framework.schemas.generators import is_custom_action -from rest_framework.versioning import URLPathVersioning from ..versioning import BananasVersioning from .base import BananasBaseRouter +class BananasEndpointEnumerator(EndpointEnumerator): + def should_include_endpoint( + self, path, callback, app_name="", namespace="", url_name=None + ): + # Fall back to check namespace on the resolver match + request = self.request + if ( + not namespace + and getattr(request, "version", None) + and getattr(request, "resolver_match", None) + ): + namespace = request.resolver_match.namespace or "" + return super().should_include_endpoint( + path, callback, app_name, namespace, url_name + ) + + class BananasOpenAPISchemaGenerator(OpenAPISchemaGenerator): + endpoint_enumerator_class = BananasEndpointEnumerator + def get_schema(self, *args, **kwargs): schema = super().get_schema(*args, **kwargs) api_settings = getattr(settings, "ADMIN", {}).get("API", {}) @@ -120,5 +138,6 @@ def get_schema_view(self): permission_classes=(permissions.AllowAny,), patterns=self.urls, ) - view.versioning_class = URLPathVersioning + view.versioning_class = BananasVersioning + return view diff --git a/tests/separate_api.py b/tests/separate_api.py new file mode 100644 index 0000000..0544ab0 --- /dev/null +++ b/tests/separate_api.py @@ -0,0 +1,14 @@ +from django.conf.urls import include, url +from rest_framework.routers import DefaultRouter +from rest_framework.viewsets import ViewSet + + +class SomeThingAPI(ViewSet): + def list(self, request): + pass + + +separate_router = DefaultRouter() +separate_router.register(r"some-thing", SomeThingAPI, "some-thing") + +urlpatterns = [url("", include(separate_router.urls))] diff --git a/tests/settings.py b/tests/settings.py index 2611988..9dade82 100644 --- a/tests/settings.py +++ b/tests/settings.py @@ -1,51 +1,52 @@ -SECRET_KEY = 'bananas' -LANGUAGE_CODE = 'en' +SECRET_KEY = "bananas" +LANGUAGE_CODE = "en" -DATABASES = { - 'default': { - 'ENGINE': 'django.db.backends.sqlite3', - 'NAME': ':memory:', - } -} +DATABASES = {"default": {"ENGINE": "django.db.backends.sqlite3", "NAME": ":memory:"}} INSTALLED_APPS = [ - 'bananas', - 'django.contrib.auth', - 'django.contrib.contenttypes', - 'django.contrib.sessions', - 'django.contrib.messages', - 'django.contrib.staticfiles', - 'django.contrib.admin', - 'tests' + "bananas", + "django.contrib.auth", + "django.contrib.contenttypes", + "django.contrib.sessions", + "django.contrib.messages", + "django.contrib.staticfiles", + "django.contrib.admin", + "tests", ] -ROOT_URLCONF = 'tests.urls' +ROOT_URLCONF = "tests.urls" MIDDLEWARE = MIDDLEWARE_CLASSES = [ - 'django.middleware.security.SecurityMiddleware', - 'django.contrib.sessions.middleware.SessionMiddleware', - 'django.middleware.common.CommonMiddleware', - 'django.middleware.csrf.CsrfViewMiddleware', - 'django.contrib.auth.middleware.AuthenticationMiddleware', - 'django.contrib.messages.middleware.MessageMiddleware', - 'django.middleware.clickjacking.XFrameOptionsMiddleware', + "django.middleware.security.SecurityMiddleware", + "django.contrib.sessions.middleware.SessionMiddleware", + "django.middleware.common.CommonMiddleware", + "django.middleware.csrf.CsrfViewMiddleware", + "django.contrib.auth.middleware.AuthenticationMiddleware", + "django.contrib.messages.middleware.MessageMiddleware", + "django.middleware.clickjacking.XFrameOptionsMiddleware", ] -STATIC_URL = '/static/' -MEDIA_URL = '/media/' +STATIC_URL = "/static/" +MEDIA_URL = "/media/" TEMPLATES = [ { - 'BACKEND': 'django.template.backends.django.DjangoTemplates', - 'DIRS': [], - 'APP_DIRS': True, - 'OPTIONS': { - 'context_processors': [ - 'django.template.context_processors.debug', - 'django.template.context_processors.request', - 'django.contrib.auth.context_processors.auth', - 'django.contrib.messages.context_processors.messages', - ], + "BACKEND": "django.template.backends.django.DjangoTemplates", + "DIRS": [], + "APP_DIRS": True, + "OPTIONS": { + "context_processors": [ + "django.template.context_processors.debug", + "django.template.context_processors.request", + "django.contrib.auth.context_processors.auth", + "django.contrib.messages.context_processors.messages", + ] }, - }, + } ] + +REST_FRAMEWORK = { + "DEFAULT_VERSIONING_CLASS": "rest_framework.versioning.AcceptHeaderVersioning", + "DEFAULT_VERSION": 1.0, + "ALLOWED_VERSIONS": [1.0], +} diff --git a/tests/test_admin.py b/tests/test_admin.py index fcea678..b0c9637 100644 --- a/tests/test_admin.py +++ b/tests/test_admin.py @@ -5,7 +5,7 @@ import django from django.contrib.auth.models import AnonymousUser, Group, Permission, User from django.core.management import call_command -from django.test import TestCase +from django.test import TestCase, override_settings from bananas import admin, compat @@ -224,7 +224,9 @@ def test_autorized_schema(self): response = self.client.get(url) self.assertEqual(response.status_code, 200) - data = response.json() + self.check_valid_schema(response.json()) + + def check_valid_schema(self, data): self.assertNotIn("/bananas/login/", data["paths"]) self.assertIn("/bananas/logout/", data["paths"]) action = data["paths"]["/bananas/logout/"]["post"] diff --git a/tests/test_commands.py b/tests/test_commands.py index b8631d3..230468e 100644 --- a/tests/test_commands.py +++ b/tests/test_commands.py @@ -14,11 +14,11 @@ def test_show_urls(self): admin_api_url_count = 13 if django.VERSION < (1, 9): - n_urls = 23 + admin_api_url_count + n_urls = 27 + admin_api_url_count elif django.VERSION < (2, 0): - n_urls = 25 + admin_api_url_count + n_urls = 29 + admin_api_url_count else: - n_urls = 27 + admin_api_url_count + n_urls = 31 + admin_api_url_count self.assertEqual(len(urls), n_urls) diff --git a/tests/urls.py b/tests/urls.py index 8379aa0..bc3838f 100644 --- a/tests/urls.py +++ b/tests/urls.py @@ -8,8 +8,12 @@ if django.VERSION >= (1, 10): from bananas.admin import api from .admin_api import FooAPI, HamAPI + from . import separate_api api.register(FooAPI) api.register(HamAPI) - urlpatterns += [url(r"^api/", include("bananas.admin.api.urls"))] + urlpatterns += [ + url(r"^api/bananas", include("bananas.admin.api.urls")), + url(r"^api/separate", include(separate_api)), + ]