Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement 553 Version Not Supported #420

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
14 changes: 14 additions & 0 deletions optimade/server/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,17 @@ def __init__(
) -> None:
super().__init__(status_code=status_code, detail=detail, headers=headers)
self.title = title


class VersionNotSupported(HTTPException):
"""553 Version Not Supported"""

def __init__(
self,
status_code: int = 553,
detail: str = None,
headers: dict = None,
title: str = "Version Not Supported",
) -> None:
super().__init__(status_code=status_code, detail=detail, headers=headers)
self.title = title
3 changes: 2 additions & 1 deletion optimade/server/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

from .entry_collections import MongoCollection
from .config import CONFIG
from .middleware import EnsureQueryParamIntegrity
from .middleware import EnsureQueryParamIntegrity, CheckWronglyVersionedBaseUrls
from .routers import info, links, references, structures, landing, versions
from .routers.utils import get_providers, BASE_URL_PREFIXES

Expand Down Expand Up @@ -58,6 +58,7 @@ def load_entries(endpoint_name: str, endpoint_collection: MongoCollection):
# Add various middleware
app.add_middleware(CORSMiddleware, allow_origins=["*"])
app.add_middleware(EnsureQueryParamIntegrity)
app.add_middleware(CheckWronglyVersionedBaseUrls)


# Add various exception handlers
Expand Down
12 changes: 8 additions & 4 deletions optimade/server/main_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,13 @@
from optimade import __api_version__, __version__
import optimade.server.exception_handlers as exc_handlers

from .config import CONFIG
from .middleware import EnsureQueryParamIntegrity
from .routers import index_info, links, versions
from .routers.utils import BASE_URL_PREFIXES
from optimade.server.config import CONFIG
from optimade.server.middleware import (
EnsureQueryParamIntegrity,
CheckWronglyVersionedBaseUrls,
)
from optimade.server.routers import index_info, links, versions
from optimade.server.routers.utils import BASE_URL_PREFIXES


if CONFIG.debug: # pragma: no cover
Expand Down Expand Up @@ -61,6 +64,7 @@
# Add various middleware
app.add_middleware(CORSMiddleware, allow_origins=["*"])
app.add_middleware(EnsureQueryParamIntegrity)
app.add_middleware(CheckWronglyVersionedBaseUrls)


# Add various exception handlers
Expand Down
40 changes: 38 additions & 2 deletions optimade/server/middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,15 @@
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request

from optimade.server.exceptions import BadRequest


class EnsureQueryParamIntegrity(BaseHTTPMiddleware):
"""Ensure all query parameters are followed by an equal sign (`=`)"""

@staticmethod
def check_url(url_query: str):
"""Check parsed URL query part for parameters not followed by `=`"""
from optimade.server.exceptions import BadRequest

queries_amp = set(url_query.split("&"))
queries = set()
for query in queries_amp:
Expand All @@ -29,3 +29,39 @@ async def dispatch(self, request: Request, call_next):
self.check_url(parsed_url.query)
response = await call_next(request)
return response


class CheckWronglyVersionedBaseUrls(BaseHTTPMiddleware):
"""If a non-supported versioned base URL is supplied return `553 Version Not Supported`"""

@staticmethod
def check_url(parsed_url: urllib.parse.ParseResult):
"""Check URL path for versioned part"""
import re

from optimade.server.exceptions import VersionNotSupported
from optimade.server.routers.utils import get_base_url, BASE_URL_PREFIXES

base_url = get_base_url(parsed_url)
optimade_path = f"{parsed_url.scheme}://{parsed_url.netloc}{parsed_url.path}"[
len(base_url) :
]
if re.match(r"^/v[0-9]+", optimade_path):
for version_prefix in BASE_URL_PREFIXES.values():
if optimade_path.startswith(f"{version_prefix}/"):
break
else:
version_prefix = re.findall(r"(/v[0-9]+(\.[0-9]+){0,2})", optimade_path)
raise VersionNotSupported(
detail=(
f"The parsed versioned base URL {version_prefix[0][0]!r} from {urllib.parse.urlunparse(parsed_url)!r} is not supported by this implementation. "
f"Supported versioned base URLs are: {', '.join(BASE_URL_PREFIXES.values())}"
)
)

async def dispatch(self, request: Request, call_next):
parsed_url = urllib.parse.urlparse(str(request.url))
if parsed_url.path:
self.check_url(parsed_url)
response = await call_next(request)
return response
27 changes: 17 additions & 10 deletions tests/server/conftest.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import pytest
from typing import Union

from optimade.server.config import CONFIG
import pytest


@pytest.fixture(scope="session")
Expand Down Expand Up @@ -59,6 +59,7 @@ def inner(request: str, server: str = "regular") -> dict:
def check_response(get_good_response):
"""Fixture to check "good" response"""
from typing import List
from optimade.server.config import CONFIG

def inner(
request: str,
Expand Down Expand Up @@ -91,23 +92,29 @@ def inner(
@pytest.fixture
def check_error_response(client, index_client):
"""General method for testing expected erroneous response"""
from .utils import OptimadeTestClient

def inner(
request: str,
expected_status: int = None,
expected_title: str = None,
expected_detail: str = None,
server: str = "regular",
server: Union[str, OptimadeTestClient] = "regular",
):
response = None
if server == "regular":
used_client = client
elif server == "index":
used_client = index_client
if isinstance(server, str):
if server == "regular":
used_client = client
elif server == "index":
used_client = index_client
else:
pytest.fail(
f"Wrong value for 'server': {server}. It must be either 'regular' or 'index'."
)
elif isinstance(server, OptimadeTestClient):
used_client = server
else:
pytest.fail(
f"Wrong value for 'server': {server}. It must be either 'regular' or 'index'."
)
pytest.fail("'server' must be either a string or an OptimadeTestClient.")

try:
response = used_client.get(request)
Expand Down
78 changes: 75 additions & 3 deletions tests/server/test_middleware.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import pytest
import urllib

from optimade.server.exceptions import BadRequest
from optimade.server.exceptions import BadRequest, VersionNotSupported


def test_regular_CORS_request(both_clients):
Expand All @@ -25,7 +26,7 @@ def test_preflight_CORS_request(both_clients):
), f"{response_header} header not found in response headers: {response.headers}"


def test_wrong_html_form(check_error_response):
def test_wrong_html_form(check_error_response, both_clients):
"""Using a parameter without equality sign `=` or values should result in a `400 Bad Request` response"""
from optimade.server.query_params import EntryListingQueryParams

Expand All @@ -37,10 +38,11 @@ def test_wrong_html_form(check_error_response):
expected_status=400,
expected_title="Bad Request",
expected_detail="A query parameter without an equal sign (=) is not supported by this server",
server=both_clients,
)


def test_wrong_html_form_one_wrong(check_error_response):
def test_wrong_html_form_one_wrong(check_error_response, both_clients):
"""Using a parameter without equality sign `=` or values should result in a `400 Bad Request` response

This should hold true, no matter the chosen (valid) parameter separator (either & or ;).
Expand All @@ -52,6 +54,7 @@ def test_wrong_html_form_one_wrong(check_error_response):
expected_status=400,
expected_title="Bad Request",
expected_detail="A query parameter without an equal sign (=) is not supported by this server",
server=both_clients,
)


Expand Down Expand Up @@ -79,3 +82,72 @@ def test_empty_parameters(both_clients):
query_part
)
assert expected_result == parsed_set_of_queries


def test_wrong_version(both_clients):
"""If a non-supported versioned base URL is passed, `553 Version Not Supported` should be returned"""
from optimade.server.config import CONFIG
from optimade.server.middleware import CheckWronglyVersionedBaseUrls

version = "/v0"
urls = (
f"{CONFIG.base_url}{version}/info",
f"{CONFIG.base_url}{version}",
)

for url in urls:
with pytest.raises(VersionNotSupported):
CheckWronglyVersionedBaseUrls(both_clients.app).check_url(
urllib.parse.urlparse(url)
)


def test_wrong_version_json_response(check_error_response, both_clients):
"""If a non-supported versioned base URL is passed, `553 Version Not Supported` should be returned

A specific JSON response should also occur.
"""
from optimade.server.routers.utils import BASE_URL_PREFIXES

version = "/v0"
request = f"{version}/info"
with pytest.raises(VersionNotSupported):
check_error_response(
request,
expected_status=553,
expected_title="Version Not Supported",
expected_detail=(
f"The parsed versioned base URL {version!r} from '{both_clients.base_url}{request}' is not supported by this implementation. "
f"Supported versioned base URLs are: {', '.join(BASE_URL_PREFIXES.values())}"
),
server=both_clients,
)


def test_multiple_versions_in_path(both_clients):
"""If another version is buried in the URL path, only the OPTIMADE versioned URL path part should be recognized."""
from optimade.server.config import CONFIG
from optimade.server.middleware import CheckWronglyVersionedBaseUrls
from optimade.server.routers.utils import BASE_URL_PREFIXES

non_valid_version = "/v0.5"
org_base_url = CONFIG.base_url

try:
CONFIG.base_url = f"https://example.org{non_valid_version}/my_database/optimade"

for valid_version_prefix in BASE_URL_PREFIXES.values():
url = f"{CONFIG.base_url}{valid_version_prefix}/info"
CheckWronglyVersionedBaseUrls(both_clients.app).check_url(
urllib.parse.urlparse(url)
)

# Test also that the a non-valid OPTIMADE version raises
url = f"{CONFIG.base_url}/v0/info"
with pytest.raises(VersionNotSupported):
CheckWronglyVersionedBaseUrls(both_clients.app).check_url(
urllib.parse.urlparse(url)
)
finally:
if org_base_url:
CONFIG.base_url = org_base_url
12 changes: 6 additions & 6 deletions tests/server/test_server_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,12 +42,12 @@ def test_as_type_with_validator(client):
import unittest

test_urls = {
f"{client.base_url}structures": "structures",
f"{client.base_url}structures/mpf_1": "structure",
f"{client.base_url}references": "references",
f"{client.base_url}references/dijkstra1968": "reference",
f"{client.base_url}info": "info",
f"{client.base_url}links": "links",
f"{client.base_url}/structures": "structures",
f"{client.base_url}/structures/mpf_1": "structure",
f"{client.base_url}/references": "references",
f"{client.base_url}/references/dijkstra1968": "reference",
f"{client.base_url}/info": "info",
f"{client.base_url}/links": "links",
}
with unittest.mock.patch(
"requests.get", unittest.mock.Mock(side_effect=client.get)
Expand Down
12 changes: 6 additions & 6 deletions tests/server/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,12 +42,12 @@ def __init__(
root_path=root_path,
)
if version:
if not version.startswith("v"):
if not version.startswith("v") and not version.startswith("/v"):
version = f"/v{version}"
if re.match(r"v[0-9](.[0-9]){0,2}", version) is None:
if re.match(r"/v[0-9](.[0-9]){0,2}", version) is None:
warnings.warn(
f"Invalid version passed to client: '{version}'. "
f"Will use the default: 'v{__api_version__.split('.')[0]}'"
f"Invalid version passed to client: {version!r}. "
f"Will use the default: '/v{__api_version__.split('.')[0]}'"
)
version = f"/v{__api_version__.split('.')[0]}"
self.version = version
Expand Down Expand Up @@ -213,9 +213,9 @@ def inner(version: str = None, server: str = "regular") -> OptimadeTestClient:

if version:
return OptimadeTestClient(
app, base_url="http://example.org/", version=version
app, base_url="http://example.org", version=version
)
return OptimadeTestClient(app, base_url="http://example.org/")
return OptimadeTestClient(app, base_url="http://example.org")

return inner

Expand Down
2 changes: 1 addition & 1 deletion tests/test_config.json
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
{
"debug": false,
"default_db": "test_server",
"base_url": "http://localhost:5000",
"base_url": "http://example.org",
"implementation": {
"name": "Example implementation",
"source_url": "https://github.com/Materials-Consortia/optimade-python-tools",
Expand Down