Skip to content

Commit

Permalink
Merge pull request #218 from opensafely-core/check-version-number
Browse files Browse the repository at this point in the history
Check the results version was generated with a supported version of ACRO
  • Loading branch information
ghickman committed Jul 14, 2023
2 parents 3a1108f + b861cd5 commit 018413e
Show file tree
Hide file tree
Showing 8 changed files with 212 additions and 12 deletions.
17 changes: 17 additions & 0 deletions sacro/middleware.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
from django.conf import settings
from django.http import HttpResponseForbidden

from sacro.errors import error
from sacro.versioning import IncorrectVersionError


class AppTokenMiddleware:
def __init__(self, get_response):
Expand All @@ -13,3 +16,17 @@ def __call__(self, request):

response = self.get_response(request)
return response


class ErrorHandlerMiddleware:
def __init__(self, get_response):
self.get_response = get_response

def __call__(self, request):
return self.get_response(request)

def process_exception(self, request, exception):
if not isinstance(exception, IncorrectVersionError):
raise

return error(request, message=str(exception))
5 changes: 5 additions & 0 deletions sacro/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@

MIDDLEWARE = [
"sacro.middleware.AppTokenMiddleware",
"sacro.middleware.ErrorHandlerMiddleware",
"django.middleware.security.SecurityMiddleware",
"whitenoise.middleware.WhiteNoiseMiddleware",
"django.contrib.sessions.middleware.SessionMiddleware",
Expand Down Expand Up @@ -171,3 +172,7 @@

# Insert Whitenoise Middleware.
STATICFILES_STORAGE = "whitenoise.storage.CompressedStaticFilesStorage"


# PROJECT SETTINGS
ACRO_SUPPORTED_VERSION = "0.4.x"
71 changes: 71 additions & 0 deletions sacro/versioning.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
import functools

from django.conf import settings


class IncorrectVersionError(Exception):
def __init__(self, *args, used, supported, **kwargs):
super().__init__(*args, **kwargs)
self.used = used
self.supported = supported

def __str__(self):
return f"Unsupported ACRO output. This viewer supports ACRO version {self.supported}, but your results were generated with version {self.used}."


class UnsupportedVersionFormatError(Exception):
def __init__(self, *args, version, **kwargs):
super().__init__(*args, **kwargs)
self.version = version


@functools.total_ordering
class Version:
"""Utility class to parse and compare version strings"""

def __init__(self, version: str) -> None:
try:
major, minor, *_ = version.split(".")

# check major and minor are valid numbers
int(major)
int(minor)

self.major = major
self.minor = minor
except ValueError:
msg = f"Expected version to be in format 1.2.3, got {version}"
raise UnsupportedVersionFormatError(msg, version=version)

self.original = version

def __eq__(self, other: "Version") -> bool:
return self.major == other.major and self.minor == other.minor

def __gt__(self, other: "Version") -> bool:
if self.major > other.major:
return True

if self.major == other.major and self.minor > other.minor:
return True

return False

def __repr__(self):
return f"Version: {self.original}"

def __str__(self):
return self.original


def check_version(version: str) -> None:
"""
Check the given version against the supported version in settings
We don't care about bugfix versions so the Version class ignores them.
"""
supported = Version(settings.ACRO_SUPPORTED_VERSION)
used = Version(version)

if used < supported:
raise IncorrectVersionError(used=used, supported=supported)
7 changes: 6 additions & 1 deletion sacro/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from django.views.decorators.http import require_GET, require_POST

from sacro.adapters import local_audit, zipfile
from sacro.versioning import check_version


logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -110,7 +111,11 @@ def get_outputs(data):
if not path.exists(): # pragma: no cover
raise Http404

return Outputs(path)
outputs = Outputs(path)

check_version(outputs.version)

return outputs


@require_GET
Expand Down
17 changes: 17 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import shutil
from pathlib import Path

import pytest

from sacro import views


@pytest.fixture
def TEST_PATH():
return Path("outputs/results.json")


@pytest.fixture
def test_outputs(tmp_path, TEST_PATH):
shutil.copytree(TEST_PATH.parent, tmp_path, dirs_exist_ok=True)
return views.Outputs(tmp_path / TEST_PATH.name)
21 changes: 21 additions & 0 deletions tests/test_middleware.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import json
import shutil

from django.test import Client


def test_error_handling_middleware(client, tmp_path, TEST_PATH):
shutil.copytree(TEST_PATH.parent, tmp_path, dirs_exist_ok=True)
path = tmp_path / TEST_PATH.name

# change the version number
data = json.load(path.open())
data["version"] = "0.3.0"
json.dump(data, path.open("w"))

response = Client().get(f"/?path={path}")
assert response.status_code == 500
assert (
"Unsupported ACRO output. This viewer supports ACRO version 0.4.x, but your results were generated with version 0.3.0."
in response.rendered_content
)
74 changes: 74 additions & 0 deletions tests/test_versioning.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
import pytest
from django.test import override_settings

from sacro.versioning import (
IncorrectVersionError,
UnsupportedVersionFormatError,
Version,
check_version,
)


@override_settings(ACRO_SUPPORTED_VERSION="0.4.0")
def test_check_version(monkeypatch):
assert check_version("0.4.0") is None
assert check_version("0.4.2") is None

with pytest.raises(IncorrectVersionError):
check_version("0.3.0")


@pytest.mark.parametrize("version", ["test", "v0.4.0"])
def test_version_init_with_unexpected_format(version):
with pytest.raises(UnsupportedVersionFormatError):
Version(version)


@pytest.mark.parametrize("version", ["0.4.0", "0.4"])
def test_version_init_success(version):
Version(version)


def test_version_rich_comparison_eq():
assert Version("0.4.0") == Version("0.4.0")

# check bugfix numbers are ignored
assert Version("0.4.0") == Version("0.4.2")


def test_version_rich_comparison_ge():
assert Version("0.4.0") >= Version("0.3.0")
assert Version("0.4.0") >= Version("0.4.0")
assert Version("1.0.0") >= Version("0.3.0")


def test_version_rich_comparison_gt():
assert Version("0.4.0") > Version("0.3.0")
assert Version("1.0.0") > Version("0.3.0")


def test_version_rich_comparison_le():
assert Version("0.3.0") <= Version("0.4.0")
assert Version("0.4.0") <= Version("0.4.0")
assert Version("0.3.0") <= Version("1.0.0")


def test_version_rich_comparison_lt():
assert Version("0.3.0") < Version("0.4.0")
assert Version("0.3.0") < Version("1.0.0")


def test_version_rich_comparison_ne():
assert Version("1.4.0") != Version("0.4.0")
assert Version("0.3.0") != Version("0.4.0")

# check bugfix numbers are ignored
assert Version("0.3.2") != Version("0.4.2")


def test_version_repr():
assert repr(Version("0.7.0")) == "Version: 0.7.0"


def test_version_str():
assert str(Version("0.7.0")) == "0.7.0"
12 changes: 1 addition & 11 deletions tests/test_views.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import io
import json
import shutil
import zipfile
from pathlib import Path
from urllib.parse import urlencode
Expand All @@ -13,15 +12,6 @@
from sacro import views


TEST_PATH = Path("outputs/results.json")


@pytest.fixture
def test_outputs(tmp_path):
shutil.copytree(TEST_PATH.parent, tmp_path, dirs_exist_ok=True)
return views.Outputs(tmp_path / TEST_PATH.name)


def test_outputs_annotation(test_outputs):
assert test_outputs.version == "0.4.0"
for metadata in test_outputs.values():
Expand Down Expand Up @@ -56,7 +46,7 @@ def test_index(test_outputs):


@override_settings(DEBUG=True)
def test_index_no_path():
def test_index_no_path(TEST_PATH):
request = RequestFactory().get(path="/")

response = views.index(request)
Expand Down

0 comments on commit 018413e

Please sign in to comment.