diff --git a/.github/workflows/test-fair-database.yml b/.github/workflows/test-fair-database.yml new file mode 100644 index 000000000..9652a8377 --- /dev/null +++ b/.github/workflows/test-fair-database.yml @@ -0,0 +1,54 @@ +name: Tests + +on: + [push, pull_request] + +defaults: + run: + shell: bash + +jobs: + unit-test: + + runs-on: ${{ matrix.os }} + strategy: + matrix: + os: [macos-latest, ubuntu-latest, windows-latest] + python-version: ['3.12'] + fail-fast: false + + steps: + + - name: Obtain SasData source from git + uses: actions/checkout@v4 + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + cache: 'pip' + cache-dependency-path: | + **/test.yml + **/requirements*.txt + + ### Installation of build-dependencies + + - name: Install Python dependencies + run: | + python -m pip install --upgrade pip + python -m pip install wheel setuptools + python -m pip install -r requirements.txt + python -m pip install -r sasdata/fair_database/requirements.txt + + ### Build and test sasdata + + - name: Build sasdata + run: | + # BUILD SASDATA + python -m pip install -e . + + ### Build documentation (if enabled) + + - name: Test with Django tests + run: | + python sasdata/fair_database/manage.py test sasdata.fair_database diff --git a/.gitignore b/.gitignore index ff18e7a00..a24aa7e04 100644 --- a/.gitignore +++ b/.gitignore @@ -24,6 +24,7 @@ **/build /dist .mplconfig +**/db.sqlite3 # INSTALL.md recommends a venv that should not be committed venv diff --git a/sasdata/data.py b/sasdata/data.py index 788f3d7f5..cc9ca4b31 100644 --- a/sasdata/data.py +++ b/sasdata/data.py @@ -1,4 +1,6 @@ +import json import typing +from typing import Any import h5py import numpy as np @@ -125,6 +127,40 @@ def save_h5(data: dict[str, typing.Self], path: str | typing.BinaryIO): sasentry.attrs["sasview_key"] = key data._save_h5(sasentry) + @staticmethod + def deserialise(data: str) -> "SasData": + json_data = json.loads(data) + return SasData.deserialise_json(json_data) + + @staticmethod + def deserialise_json(json_data: dict) -> "SasData": + name = json_data["name"] + data_contents = {} + dataset_type = json_data["dataset_type"] # TODO: update when DatasetType is more finalized + metadata = json_data["metadata"].deserialise_json() + for quantity in json_data["data_contents"]: + data_contents[quantity["label"]] = Quantity.deserialise_json(quantity) + return SasData(name, data_contents, dataset_type, metadata) + + def serialise(self) -> str: + return json.dumps(self._serialise_json()) + + # TODO: fix serializers eventually + def _serialise_json(self) -> dict[str, Any]: + data = [] + for d in self._data_contents: + quantity = self._data_contents[d] + quantity["label"] = d + data.append(quantity) + return { + "name": self.name, + "data_contents": data, + "dataset_type": None, # TODO: update when DatasetType is more finalized + "verbose": self._verbose, + "metadata": self.metadata.serialise_json(), + "mask": {}, + "model_requirements": {} + } class SasDataEncoder(MetadataEncoder): diff --git a/sasdata/data_backing.py b/sasdata/data_backing.py index 210ac33ca..fe8cf902b 100644 --- a/sasdata/data_backing.py +++ b/sasdata/data_backing.py @@ -36,6 +36,34 @@ def summary(self, indent_amount: int = 0, indent: str = " ") -> str: return s + @staticmethod + def deserialise_json(json_data: dict) -> "Dataset": + name = json_data["name"] + data = "" # TODO: figure out QuantityType serialisation + attributes = {} + for key in json_data["attributes"]: + value = json_data["attributes"][key] + if isinstance(value, dict): + attributes[key] = Dataset.deserialise_json(value) + else: + attributes[key] = value + return Dataset(name, data, attributes) + + def serialise_json(self): + content = { + "name": self.name, + "data": "", # TODO: figure out QuantityType serialisation + "attributes": {}, + "type": "dataset" + } + for key in self.attributes: + value = self.attributes[key] + if isinstance(value, (Group, Dataset)): + content["attributes"]["key"] = value.serialise_json() + else: + content["attributes"]["key"] = value + return content + @dataclass class Group: name: str @@ -48,6 +76,27 @@ def summary(self, indent_amount: int=0, indent=" "): return s + @staticmethod + def deserialise_json(json_data: dict) -> "Group": + name = json_data["name"] + children = {} + for key in json_data["children"]: + value = json_data["children"][key] + if value["type"] == "group": + children[key] = Group.deserialise_json(value) + else: + children[key] = Dataset.deserialise_json(value) + return Group(name, children) + + def serialise_json(self): + return { + "name": self.name, + "children": { + key: self.children[key].serialise_json() for key in self.children + }, + "type": "group" + } + class Function: """ Representation of a (data driven) function, such as I vs Q """ diff --git a/sasdata/fair_database/__init__.py b/sasdata/fair_database/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/sasdata/fair_database/data/__init__.py b/sasdata/fair_database/data/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/sasdata/fair_database/data/admin.py b/sasdata/fair_database/data/admin.py new file mode 100644 index 000000000..2134875a6 --- /dev/null +++ b/sasdata/fair_database/data/admin.py @@ -0,0 +1,11 @@ +from data import models +from django.contrib import admin + +admin.site.register(models.DataFile) +admin.site.register(models.Session) +admin.site.register(models.PublishedState) +admin.site.register(models.DataSet) +admin.site.register(models.MetaData) +admin.site.register(models.Quantity) +admin.site.register(models.OperationTree) +admin.site.register(models.ReferenceQuantity) diff --git a/sasdata/fair_database/data/apps.py b/sasdata/fair_database/data/apps.py new file mode 100644 index 000000000..b882be950 --- /dev/null +++ b/sasdata/fair_database/data/apps.py @@ -0,0 +1,6 @@ +from django.apps import AppConfig + + +class DataConfig(AppConfig): + default_auto_field = "django.db.models.BigAutoField" + name = "data" diff --git a/sasdata/fair_database/data/forms.py b/sasdata/fair_database/data/forms.py new file mode 100644 index 000000000..519556e2f --- /dev/null +++ b/sasdata/fair_database/data/forms.py @@ -0,0 +1,9 @@ +from data.models import DataFile +from django import forms + + +# Create the form class. +class DataFileForm(forms.ModelForm): + class Meta: + model = DataFile + fields = ["file", "is_public"] diff --git a/sasdata/fair_database/data/migrations/0001_initial.py b/sasdata/fair_database/data/migrations/0001_initial.py new file mode 100644 index 000000000..e8f7219a6 --- /dev/null +++ b/sasdata/fair_database/data/migrations/0001_initial.py @@ -0,0 +1,332 @@ +# Generated by Django 5.1.6 on 2025-04-23 18:08 + +import data.models +import django.core.files.storage +import django.db.models.deletion +from django.conf import settings +from django.db import migrations, models + + +class Migration(migrations.Migration): + initial = True + + dependencies = [ + migrations.swappable_dependency(settings.AUTH_USER_MODEL), + ] + + operations = [ + migrations.CreateModel( + name="DataFile", + fields=[ + ( + "id", + models.BigAutoField( + auto_created=True, + primary_key=True, + serialize=False, + verbose_name="ID", + ), + ), + ( + "is_public", + models.BooleanField( + default=False, help_text="opt in to make your data public" + ), + ), + ( + "file_name", + models.CharField( + blank=True, + default=None, + help_text="File name", + max_length=200, + null=True, + ), + ), + ( + "file", + models.FileField( + default=None, + help_text="This is a file", + storage=django.core.files.storage.FileSystemStorage(), + upload_to="uploaded_files", + ), + ), + ( + "current_user", + models.ForeignKey( + blank=True, + null=True, + on_delete=django.db.models.deletion.CASCADE, + related_name="+", + to=settings.AUTH_USER_MODEL, + ), + ), + ( + "users", + models.ManyToManyField( + blank=True, related_name="+", to=settings.AUTH_USER_MODEL + ), + ), + ], + options={ + "abstract": False, + }, + ), + migrations.CreateModel( + name="DataSet", + fields=[ + ( + "id", + models.BigAutoField( + auto_created=True, + primary_key=True, + serialize=False, + verbose_name="ID", + ), + ), + ( + "is_public", + models.BooleanField( + default=False, help_text="opt in to make your data public" + ), + ), + ("name", models.CharField(max_length=200)), + ( + "current_user", + models.ForeignKey( + blank=True, + null=True, + on_delete=django.db.models.deletion.CASCADE, + related_name="+", + to=settings.AUTH_USER_MODEL, + ), + ), + ("files", models.ManyToManyField(to="data.datafile")), + ( + "users", + models.ManyToManyField( + blank=True, related_name="+", to=settings.AUTH_USER_MODEL + ), + ), + ], + options={ + "abstract": False, + }, + ), + migrations.CreateModel( + name="MetaData", + fields=[ + ( + "id", + models.BigAutoField( + auto_created=True, + primary_key=True, + serialize=False, + verbose_name="ID", + ), + ), + ("title", models.CharField(default="Title", max_length=500)), + ("run", models.JSONField(default=data.models.empty_list)), + ("definition", models.TextField(blank=True, null=True)), + ("instrument", models.JSONField(blank=True, null=True)), + ("process", models.JSONField(default=data.models.empty_list)), + ("sample", models.JSONField(blank=True, null=True)), + ( + "dataset", + models.OneToOneField( + on_delete=django.db.models.deletion.CASCADE, + related_name="metadata", + to="data.dataset", + ), + ), + ], + ), + migrations.CreateModel( + name="Quantity", + fields=[ + ( + "id", + models.BigAutoField( + auto_created=True, + primary_key=True, + serialize=False, + verbose_name="ID", + ), + ), + ("value", models.JSONField()), + ("variance", models.JSONField()), + ("units", models.CharField(max_length=200)), + ("hash", models.IntegerField()), + ("label", models.CharField(max_length=50)), + ( + "dataset", + models.ForeignKey( + on_delete=django.db.models.deletion.CASCADE, + related_name="data_contents", + to="data.dataset", + ), + ), + ], + ), + migrations.CreateModel( + name="OperationTree", + fields=[ + ( + "id", + models.BigAutoField( + auto_created=True, + primary_key=True, + serialize=False, + verbose_name="ID", + ), + ), + ( + "operation", + models.CharField( + choices=[ + ("zero", "0 [Add.Id.]"), + ("one", "1 [Mul.Id.]"), + ("constant", "Constant"), + ("variable", "Variable"), + ("neg", "Neg"), + ("reciprocal", "Inv"), + ("add", "Add"), + ("sub", "Sub"), + ("mul", "Mul"), + ("div", "Div"), + ("pow", "Pow"), + ("transpose", "Transpose"), + ("dot", "Dot"), + ("matmul", "MatMul"), + ("tensor_product", "TensorProduct"), + ], + max_length=20, + ), + ), + ("parameters", models.JSONField(default=data.models.empty_dict)), + ("label", models.CharField(blank=True, max_length=10, null=True)), + ( + "child_operation", + models.ForeignKey( + blank=True, + null=True, + on_delete=django.db.models.deletion.CASCADE, + related_name="parent_operations", + to="data.operationtree", + ), + ), + ( + "quantity", + models.OneToOneField( + blank=True, + null=True, + on_delete=django.db.models.deletion.CASCADE, + related_name="operation_tree", + to="data.quantity", + ), + ), + ], + ), + migrations.CreateModel( + name="ReferenceQuantity", + fields=[ + ( + "id", + models.BigAutoField( + auto_created=True, + primary_key=True, + serialize=False, + verbose_name="ID", + ), + ), + ("value", models.JSONField()), + ("variance", models.JSONField()), + ("units", models.CharField(max_length=200)), + ("hash", models.IntegerField()), + ( + "derived_quantity", + models.ForeignKey( + on_delete=django.db.models.deletion.CASCADE, + related_name="references", + to="data.quantity", + ), + ), + ], + ), + migrations.CreateModel( + name="Session", + fields=[ + ( + "id", + models.BigAutoField( + auto_created=True, + primary_key=True, + serialize=False, + verbose_name="ID", + ), + ), + ( + "is_public", + models.BooleanField( + default=False, help_text="opt in to make your data public" + ), + ), + ("title", models.CharField(max_length=200)), + ( + "current_user", + models.ForeignKey( + blank=True, + null=True, + on_delete=django.db.models.deletion.CASCADE, + related_name="+", + to=settings.AUTH_USER_MODEL, + ), + ), + ( + "users", + models.ManyToManyField( + blank=True, related_name="+", to=settings.AUTH_USER_MODEL + ), + ), + ], + options={ + "abstract": False, + }, + ), + migrations.CreateModel( + name="PublishedState", + fields=[ + ( + "id", + models.BigAutoField( + auto_created=True, + primary_key=True, + serialize=False, + verbose_name="ID", + ), + ), + ("published", models.BooleanField(default=False)), + ("doi", models.URLField()), + ( + "session", + models.OneToOneField( + on_delete=django.db.models.deletion.CASCADE, + related_name="published_state", + to="data.session", + ), + ), + ], + ), + migrations.AddField( + model_name="dataset", + name="session", + field=models.ForeignKey( + blank=True, + null=True, + on_delete=django.db.models.deletion.CASCADE, + related_name="datasets", + to="data.session", + ), + ), + ] diff --git a/sasdata/fair_database/data/migrations/__init__.py b/sasdata/fair_database/data/migrations/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/sasdata/fair_database/data/models.py b/sasdata/fair_database/data/models.py new file mode 100644 index 000000000..3a36dc831 --- /dev/null +++ b/sasdata/fair_database/data/models.py @@ -0,0 +1,232 @@ +from django.contrib.auth.models import User +from django.core.files.storage import FileSystemStorage +from django.db import models + + +# method for empty list default value +def empty_list(): + return [] + + +# method for empty dictionary default value +def empty_dict(): + return {} + + +class Data(models.Model): + """Base model for data with access-related information.""" + + # owner of the data + current_user = models.ForeignKey( + User, blank=True, null=True, on_delete=models.CASCADE, related_name="+" + ) + + # users that have been granted view access to the data + users = models.ManyToManyField(User, blank=True, related_name="+") + + # is the data public? + is_public = models.BooleanField( + default=False, help_text="opt in to make your data public" + ) + + class Meta: + abstract = True + + +class DataFile(Data): + """Database model for file contents.""" + + # file name + file_name = models.CharField( + max_length=200, default=None, blank=True, null=True, help_text="File name" + ) + + # imported data + # user can either import a file path or actual file + file = models.FileField( + blank=False, + default=None, + help_text="This is a file", + upload_to="uploaded_files", + storage=FileSystemStorage(), + ) + + +class DataSet(Data): + """Database model for a set of data and associated metadata.""" + + # dataset name + name = models.CharField(max_length=200) + + # associated files + files = models.ManyToManyField(DataFile) + + # session the dataset is a part of, if any + session = models.ForeignKey( + "Session", + on_delete=models.CASCADE, + related_name="datasets", + blank=True, + null=True, + ) + + # TODO: update based on SasData class in data.py + # type of dataset + # dataset_type = models.JSONField() + + +class Quantity(models.Model): + """Database model for data quantities such as the ordinate and abscissae.""" + + # data value + value = models.JSONField() + + # variance of the data + variance = models.JSONField() + + # units + units = models.CharField(max_length=200) + + # hash value + hash = models.IntegerField() + + # label, e.g. Q or I(Q) + label = models.CharField(max_length=50) + + # data set the quantity is a part of + dataset = models.ForeignKey( + DataSet, on_delete=models.CASCADE, related_name="data_contents" + ) + + +class ReferenceQuantity(models.Model): + """ + Database models for quantities referenced by variables in an OperationTree. + + Corresponds to the references dictionary in the QuantityHistory class in + sasdata/quantity.py. ReferenceQuantities should be essentially the same as + Quantities but with no operations performed on them and therefore no + OperationTree. + """ + + # data value + value = models.JSONField() + + # variance of the data + variance = models.JSONField() + + # units + units = models.CharField(max_length=200) + + # hash value + hash = models.IntegerField() + + # Quantity whose OperationTree this is a reference for + derived_quantity = models.ForeignKey( + Quantity, + related_name="references", + on_delete=models.CASCADE, + ) + + +# TODO: update based on changes in sasdata/metadata.py +class MetaData(models.Model): + """Database model for scattering metadata""" + + # title + title = models.CharField(max_length=500, default="Title") + + # run + run = models.JSONField(default=empty_list) + + # definition + definition = models.TextField(blank=True, null=True) + + # instrument + instrument = models.JSONField(blank=True, null=True) + + # process + process = models.JSONField(default=empty_list) + + # sample + sample = models.JSONField(blank=True, null=True) + + # associated dataset + dataset = models.OneToOneField( + DataSet, on_delete=models.CASCADE, related_name="metadata" + ) + + +class OperationTree(models.Model): + """Database model for tree of operations performed on a DataSet.""" + + # possible operations + OPERATION_CHOICES = { + "zero": "0 [Add.Id.]", + "one": "1 [Mul.Id.]", + "constant": "Constant", + "variable": "Variable", + "neg": "Neg", + "reciprocal": "Inv", + "add": "Add", + "sub": "Sub", + "mul": "Mul", + "div": "Div", + "pow": "Pow", + "transpose": "Transpose", + "dot": "Dot", + "matmul": "MatMul", + "tensor_product": "TensorProduct", + } + + # operation + operation = models.CharField(max_length=20, choices=OPERATION_CHOICES) + + # parameters + parameters = models.JSONField(default=empty_dict) + + # label (a or b) if the operation is a parameter of a child operation + # maintains ordering of binary operation parameters + label = models.CharField(max_length=10, blank=True, null=True) + + # operation this operation is a parameter for, if any + child_operation = models.ForeignKey( + "self", + on_delete=models.CASCADE, + related_name="parent_operations", + blank=True, + null=True, + ) + + # quantity the operation produces + # only set for base of tree (the quantity's most recent operation) + quantity = models.OneToOneField( + Quantity, + on_delete=models.CASCADE, + blank=True, + null=True, + related_name="operation_tree", + ) + + +class Session(Data): + """Database model for a project save state.""" + + # title + title = models.CharField(max_length=200) + + +class PublishedState(models.Model): + """Database model for a project published state.""" + + # published + published = models.BooleanField(default=False) + + # TODO: update doi as needed when DOI generation is implemented + # doi + doi = models.URLField() + + # session + session = models.OneToOneField( + Session, on_delete=models.CASCADE, related_name="published_state" + ) diff --git a/sasdata/fair_database/data/serializers.py b/sasdata/fair_database/data/serializers.py new file mode 100644 index 000000000..dfb7ece52 --- /dev/null +++ b/sasdata/fair_database/data/serializers.py @@ -0,0 +1,529 @@ +from data import models +from django.core.exceptions import ObjectDoesNotExist +from fair_database import permissions +from rest_framework import serializers + +# TODO: more custom validation, particularly for specific nested dictionary structures +# TODO: custom update methods for nested structures + + +# Determine if an operation does not have parent operations +def constant_or_variable(operation: str): + return operation in ["zero", "one", "constant", "variable"] + + +# Determine if an operation has two parent operations +def binary(operation: str): + return operation in ["add", "sub", "mul", "div", "dot", "matmul", "tensor_product"] + + +class DataFileSerializer(serializers.ModelSerializer): + """Serialization, deserialization, and validation for the DataFile model.""" + + class Meta: + model = models.DataFile + fields = "__all__" + + # TODO: check partial updates + # Check that private data has an owner + def validate(self, data): + if not self.context["is_public"] and not data["current_user"]: + raise serializers.ValidationError("private data must have an owner") + return data + + +class AccessManagementSerializer(serializers.Serializer): + """ + Serialization, deserialization, and validation for granting and revoking + access to instances of any exposed model. + """ + + # The username of a user + username = serializers.CharField(max_length=200, required=False) + # Whether that user has access + access = serializers.BooleanField() + + +class MetaDataSerializer(serializers.ModelSerializer): + """Serialization, deserialization, and validation for the MetaData model.""" + + # associated dataset + dataset = serializers.PrimaryKeyRelatedField( + queryset=models.DataSet, required=False, allow_null=True + ) + + class Meta: + model = models.MetaData + fields = "__all__" + + # Serialize an entry in MetaData + def to_representation(self, instance): + data = super().to_representation(instance) + if "dataset" in data: + data.pop("dataset") + return data + + +class OperationTreeSerializer(serializers.ModelSerializer): + """Serialization, deserialization, and validation for the OperationTree model.""" + + # associated quantity, for root operation + quantity = serializers.PrimaryKeyRelatedField( + queryset=models.Quantity, required=False, allow_null=True + ) + # operation this operation is a parameter for, for non-root operations + child_operation = serializers.PrimaryKeyRelatedField( + queryset=models.OperationTree, required=False, allow_null=True + ) + # parameter label, for non-root operations + label = serializers.CharField(max_length=10, required=False) + + class Meta: + model = models.OperationTree + fields = ["operation", "parameters", "quantity", "label", "child_operation"] + + # Validate parent operations + def validate_parameters(self, value): + if "a" in value: + serializer = OperationTreeSerializer(data=value["a"]) + serializer.is_valid(raise_exception=True) + if "b" in value: + serializer = OperationTreeSerializer(data=value["b"]) + serializer.is_valid(raise_exception=True) + return value + + # Check that the operation has the correct parameters + def validate(self, data): + expected_parameters = { + "zero": [], + "one": [], + "constant": ["value"], + "variable": ["hash_value", "name"], + "neg": ["a"], + "reciprocal": ["a"], + "add": ["a", "b"], + "sub": ["a", "b"], + "mul": ["a", "b"], + "div": ["a", "b"], + "pow": ["a", "power"], + "transpose": ["a", "axes"], + "dot": ["a", "b"], + "matmul": ["a", "b"], + "tensor_product": ["a", "b", "a_index", "b_index"], + } + + for parameter in expected_parameters[data["operation"]]: + if parameter not in data["parameters"]: + raise serializers.ValidationError( + data["operation"] + " requires parameter " + parameter + ) + + return data + + # Serialize an OperationTree instance + def to_representation(self, instance): + data = {"operation": instance.operation, "parameters": instance.parameters} + for parent_operation in instance.parent_operations.all(): + data["parameters"][parent_operation.label] = self.to_representation( + parent_operation + ) + return data + + # Create an OperationTree instance + def create(self, validated_data): + parent_operation1 = None + parent_operation2 = None + if not constant_or_variable(validated_data["operation"]): + parent_operation1 = validated_data["parameters"].pop("a") + parent_operation1["label"] = "a" + if binary(validated_data["operation"]): + parent_operation2 = validated_data["parameters"].pop("b") + parent_operation2["label"] = "b" + operation_tree = models.OperationTree.objects.create(**validated_data) + if parent_operation1: + parent_operation1["child_operation"] = operation_tree + OperationTreeSerializer.create( + OperationTreeSerializer(), validated_data=parent_operation1 + ) + if parent_operation2: + parent_operation2["child_operation"] = operation_tree + OperationTreeSerializer.create( + OperationTreeSerializer(), validated_data=parent_operation2 + ) + return operation_tree + + +class ReferenceQuantitySerializer(serializers.ModelSerializer): + """Serialization, deserialization, and validation for the ReferenceQuantity model.""" + + # quantity whose operation tree this is a reference for + derived_quantity = serializers.PrimaryKeyRelatedField( + queryset=models.Quantity, required=False + ) + + class Meta: + model = models.ReferenceQuantity + fields = ["value", "variance", "units", "hash", "derived_quantity"] + + # serialize a ReferenceQuantity instance + def to_representation(self, instance): + data = super().to_representation(instance) + if "derived_quantity" in data: + data.pop("derived_quantity") + return data + + # create a ReferenceQuantity instance + def create(self, validated_data): + if "label" in validated_data: + validated_data.pop("label") + if "history" in validated_data: + validated_data.pop("history") + return models.ReferenceQuantity.objects.create(**validated_data) + + +class QuantitySerializer(serializers.ModelSerializer): + """Serialization, deserialization, and validation for the Quantity model.""" + + # associated operation tree + operation_tree = OperationTreeSerializer(read_only=False, required=False) + # references for the operation tree + references = ReferenceQuantitySerializer(many=True, read_only=False, required=False) + # quantity label + label = serializers.CharField(max_length=20) + # dataset this is a part of + dataset = serializers.PrimaryKeyRelatedField( + queryset=models.DataSet, required=False, allow_null=True + ) + # serialized JSON form of operation tree and references + history = serializers.JSONField(required=False, allow_null=True) + + class Meta: + model = models.Quantity + fields = [ + "value", + "variance", + "units", + "hash", + "operation_tree", + "references", + "label", + "dataset", + "history", + ] + + # validate references + def validate_history(self, value): + if "references" in value: + for ref in value["references"]: + serializer = ReferenceQuantitySerializer(data=ref) + serializer.is_valid(raise_exception=True) + + # TODO: should variable-only history be assumed to refer to the same Quantity and ignored? + # Extract operation tree from history + def to_internal_value(self, data): + if "history" in data: + data_copy = data.copy() + if "operation_tree" in data["history"]: + operations = data["history"]["operation_tree"] + if ( + "operation" in operations + and not operations["operation"] == "variable" + ): + data_copy["operation_tree"] = operations + return_data = super().to_internal_value(data_copy) + return_data["history"] = data["history"] + return return_data + else: + return super().to_internal_value(data_copy) + return super().to_internal_value(data) + + # Serialize a Quantity instance + def to_representation(self, instance): + data = super().to_representation(instance) + if "dataset" in data: + data.pop("dataset") + if "derived_quantity" in data: + data.pop("derived_quantity") + data["history"] = {} + data["history"]["operation_tree"] = data.pop("operation_tree") + data["history"]["references"] = data.pop("references") + return data + + # Create a Quantity instance + def create(self, validated_data): + operations_tree = None + references = None + if "operation_tree" in validated_data: + operations_tree = validated_data.pop("operation_tree") + if "history" in validated_data: + history = validated_data.pop("history") + if history and "references" in history: + references = history.pop("references") + quantity = models.Quantity.objects.create(**validated_data) + if operations_tree: + operations_tree["quantity"] = quantity + OperationTreeSerializer.create( + OperationTreeSerializer(), validated_data=operations_tree + ) + if references: + for ref in references: + ref["derived_quantity"] = quantity + ReferenceQuantitySerializer.create( + ReferenceQuantitySerializer(), validated_data=ref + ) + return quantity + + +class DataSetSerializer(serializers.ModelSerializer): + """Serialization, deserialization, and validation for the DataSet model.""" + + # associated metadata + metadata = MetaDataSerializer(read_only=False) + # associated files + files = serializers.PrimaryKeyRelatedField( + required=False, many=True, allow_null=True, queryset=models.DataFile.objects + ) + # quantities that make up the dataset + data_contents = QuantitySerializer(many=True, read_only=False) + # session the dataset is a part of, if any + session = serializers.PrimaryKeyRelatedField( + queryset=models.Session, required=False, allow_null=True + ) + # TODO: handle files better + + class Meta: + model = models.DataSet + fields = [ + "id", + "name", + "files", + "metadata", + "data_contents", + "is_public", + "current_user", + "users", + "session", + ] + + # Serialize a DataSet instance + def to_representation(self, instance): + data = super().to_representation(instance) + if "request" in self.context: + files = [ + file.id + for file in instance.files.all() + if ( + file.is_public + or permissions.has_access(self.context["request"], file) + ) + ] + data["files"] = files + return data + + # Check that files exist and user has access to them + def validate_files(self, value): + for file in value: + if not file.is_public and not permissions.has_access( + self.context["request"], file + ): + raise serializers.ValidationError( + "You do not have access to file " + str(file.id) + ) + return value + + # Check that private data has an owner + def validate(self, data): + if ( + not self.context["request"].user.is_authenticated + and "is_public" in data + and not data["is_public"] + ): + raise serializers.ValidationError("private data must have an owner") + if "current_user" in data and ( + data["current_user"] == "" or data["current_user"] is None + ): + if "is_public" in data: + if not data["is_public"]: + raise serializers.ValidationError("private data must have an owner") + else: + if not self.instance.is_public: + raise serializers.ValidationError("private data must have an owner") + return data + + # Create a DataSet instance + def create(self, validated_data): + files = [] + if self.context["request"].user.is_authenticated: + validated_data["current_user"] = self.context["request"].user + metadata_raw = validated_data.pop("metadata") + data_contents = validated_data.pop("data_contents") + if "files" in validated_data: + files = validated_data.pop("files") + dataset = models.DataSet.objects.create(**validated_data) + dataset.files.set(files) + metadata_raw["dataset"] = dataset + MetaDataSerializer.create(MetaDataSerializer(), validated_data=metadata_raw) + for d in data_contents: + d["dataset"] = dataset + QuantitySerializer.create(QuantitySerializer(), validated_data=d) + return dataset + + # TODO: account for updating other attributes + # Update a DataSet instance + def update(self, instance, validated_data): + if "metadata" in validated_data: + metadata_raw = validated_data.pop("metadata") + new_metadata = MetaDataSerializer.update( + MetaDataSerializer(), instance.metadata, validated_data=metadata_raw + ) + instance.metadata = new_metadata + instance.save() + return super().update(instance, validated_data) + + +class PublishedStateSerializer(serializers.ModelSerializer): + """Serialization, deserialization, and validation for the PublishedState model.""" + + # associated session + session = serializers.PrimaryKeyRelatedField( + queryset=models.Session.objects, required=False, allow_null=True + ) + + class Meta: + model = models.PublishedState + fields = "__all__" + + # check that session does not already have a published state + def validate_session(self, value): + try: + published = value.published_state + if published is not None: + raise serializers.ValidationError( + "Only one published state per session" + ) + except models.Session.published_state.RelatedObjectDoesNotExist: + return value + + # set a placeholder DOI + def to_internal_value(self, data): + data_copy = data.copy() + data_copy["doi"] = "http://127.0.0.1:8000/v1/data/session/" + return super().to_internal_value(data_copy) + + # create a PublishedState instance + def create(self, validated_data): + # TODO: generate DOI + validated_data["doi"] = ( + "http://127.0.0.1:8000/v1/data/session/" + + str(validated_data["session"].id) + + "/" + ) + return models.PublishedState.objects.create(**validated_data) + + +class PublishedStateUpdateSerializer(serializers.ModelSerializer): + """Serialization for PublishedState updates. Restricts changes to published field.""" + + class Meta: + model = models.PublishedState + fields = ["published"] + + +class SessionSerializer(serializers.ModelSerializer): + """Serialization, deserialization, and validation for the Session model.""" + + # datasets that make up the session + datasets = DataSetSerializer(read_only=False, many=True) + # associated published state, if any + published_state = PublishedStateSerializer(read_only=False, required=False) + + class Meta: + model = models.Session + fields = [ + "id", + "title", + "published_state", + "datasets", + "current_user", + "is_public", + "users", + ] + + # disallow private unowned sessions + def validate(self, data): + if ( + not self.context["request"].user.is_authenticated + and "is_public" in data + and not data["is_public"] + ): + raise serializers.ValidationError("private sessions must have an owner") + if "current_user" in data and data["current_user"] == "": + if "is_public" in data: + if not "is_public": + raise serializers.ValidationError( + "private sessions must have an owner" + ) + else: + if not self.instance.is_public: + raise serializers.ValidationError( + "private sessions must have an owner" + ) + return data + + # propagate is_public to datasets + def to_internal_value(self, data): + data_copy = data.copy() + if "is_public" in data: + if "datasets" in data: + for dataset in data_copy["datasets"]: + dataset["is_public"] = data["is_public"] + return super().to_internal_value(data_copy) + + # serialize a session instance + def to_representation(self, instance): + session = super().to_representation(instance) + for dataset in session["datasets"]: + dataset.pop("session") + return session + + # Create a Session instance + def create(self, validated_data): + published_state = None + if self.context["request"].user.is_authenticated: + validated_data["current_user"] = self.context["request"].user + if "published_state" in validated_data: + published_state = validated_data.pop("published_state") + datasets = validated_data.pop("datasets") + session = models.Session.objects.create(**validated_data) + if published_state: + published_state["session"] = session + PublishedStateSerializer.create( + PublishedStateSerializer(), validated_data=published_state + ) + for dataset in datasets: + dataset["session"] = session + DataSetSerializer.create( + DataSetSerializer(context=self.context), validated_data=dataset + ) + return session + + # update a session instance + def update(self, instance, validated_data): + if "is_public" in validated_data: + for dataset in instance.datasets.all(): + dataset.is_public = validated_data["is_public"] + dataset.save() + if "published_state" in validated_data: + pb_raw = validated_data.pop("published_state") + try: + PublishedStateUpdateSerializer.update( + PublishedStateUpdateSerializer(), + instance.published_state, + validated_data=pb_raw, + ) + except ObjectDoesNotExist: + pb_raw["session"] = instance + PublishedStateSerializer.create( + PublishedStateSerializer(), validated_data=pb_raw + ) + return super().update(instance, validated_data) diff --git a/sasdata/fair_database/data/test/__init__.py b/sasdata/fair_database/data/test/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/sasdata/fair_database/data/test/test_datafile.py b/sasdata/fair_database/data/test/test_datafile.py new file mode 100644 index 000000000..ac8dbb480 --- /dev/null +++ b/sasdata/fair_database/data/test/test_datafile.py @@ -0,0 +1,443 @@ +import os +import shutil + +from data.models import DataFile +from django.conf import settings +from django.contrib.auth.models import User +from django.db.models import Max +from django.test import TestCase +from rest_framework import status +from rest_framework.test import APIClient, APITestCase + + +# path to a file in example_data/1d_data +def find(filename): + return os.path.join( + os.path.dirname(__file__), "../../../example_data/1d_data", filename + ) + + +class TestLists(TestCase): + """Test get methods for DataFile.""" + + @classmethod + def setUpTestData(cls): + cls.public_test_data = DataFile.objects.create( + id=1, file_name="cyl_400_40.txt", is_public=True + ) + cls.public_test_data.file.save( + "cyl_400_40.txt", open(find("cyl_400_40.txt"), "rb") + ) + cls.user = User.objects.create_user( + username="testUser", password="secret", id=2 + ) + cls.private_test_data = DataFile.objects.create( + id=3, current_user=cls.user, file_name="cyl_400_20.txt", is_public=False + ) + cls.private_test_data.file.save( + "cyl_400_20.txt", open(find("cyl_400_20.txt"), "rb") + ) + cls.client_authenticated = APIClient() + cls.client_authenticated.force_authenticate(user=cls.user) + + # Test list public data + def test_does_list_public(self): + request = self.client_authenticated.get("/v1/data/file/") + self.assertEqual( + request.data, + {"public_data_ids": {1: "cyl_400_40.txt", 3: "cyl_400_20.txt"}}, + ) + + # Test list a user's private data + def test_does_list_user(self): + request = self.client_authenticated.get( + "/v1/data/file/", data={"username": "testUser"}, user=self.user + ) + self.assertEqual(request.data, {"user_data_ids": {3: "cyl_400_20.txt"}}) + + # Test list another user's public data + def test_list_other_user(self): + client_unauthenticated = APIClient() + request = client_unauthenticated.get( + "/v1/data/file/", data={"username": "testUser"}, user=self.user + ) + self.assertEqual(request.data, {"user_data_ids": {}}) + + # Test list a nonexistent user's data + def test_list_nonexistent_user(self): + request = self.client_authenticated.get( + "/v1/data/file/", data={"username": "fakeUser"} + ) + self.assertEqual(request.status_code, status.HTTP_404_NOT_FOUND) + + # Test loading a public data file + def test_does_load_data_info_public(self): + request = self.client_authenticated.get("/v1/data/file/1/") + self.assertEqual(request.status_code, status.HTTP_200_OK) + + # Test loading private data with authorization + def test_does_load_data_info_private(self): + request = self.client_authenticated.get("/v1/data/file/3/") + self.assertEqual(request.status_code, status.HTTP_200_OK) + + # Test loading data that does not exist + def test_load_data_info_nonexistent(self): + request = self.client_authenticated.get("/v1/data/file/5/") + self.assertEqual(request.status_code, status.HTTP_404_NOT_FOUND) + + @classmethod + def tearDownClass(cls): + cls.public_test_data.delete() + cls.private_test_data.delete() + cls.user.delete() + shutil.rmtree(settings.MEDIA_ROOT) + + +class TestingDatabase(APITestCase): + """Test non-get methods for DataFile.""" + + @classmethod + def setUpTestData(cls): + cls.user = User.objects.create_user( + username="testUser", password="secret", id=1 + ) + cls.data = DataFile.objects.create( + id=1, current_user=cls.user, file_name="cyl_400_20.txt", is_public=False + ) + cls.data.file.save("cyl_400_20.txt", open(find("cyl_400_20.txt"), "rb")) + cls.client_authenticated = APIClient() + cls.client_authenticated.force_authenticate(user=cls.user) + cls.client_unauthenticated = APIClient() + + # Test data upload creates data in database + def test_is_data_being_created(self): + file = open(find("cyl_400_40.txt"), "rb") + data = {"is_public": False, "file": file} + request = self.client_authenticated.post("/v1/data/file/", data=data) + max_id = DataFile.objects.aggregate(Max("id"))["id__max"] + self.assertEqual(request.status_code, status.HTTP_201_CREATED) + self.assertEqual( + request.data, + { + "current_user": "testUser", + "authenticated": True, + "file_id": max_id, + "file_alternative_name": "cyl_400_40.txt", + "is_public": False, + }, + ) + DataFile.objects.get(id=max_id).delete() + + # Test data upload w/out authenticated user + def test_is_data_being_created_no_user(self): + file = open(find("cyl_testdata.txt"), "rb") + data = {"is_public": True, "file": file} + request = self.client_unauthenticated.post("/v1/data/file/", data=data) + max_id = DataFile.objects.aggregate(Max("id"))["id__max"] + self.assertEqual(request.status_code, status.HTTP_201_CREATED) + self.assertEqual( + request.data, + { + "current_user": "", + "authenticated": False, + "file_id": max_id, + "file_alternative_name": "cyl_testdata.txt", + "is_public": True, + }, + ) + DataFile.objects.get(id=max_id).delete() + + # Test whether a user can overwrite data by specifying an in-use id + def test_no_data_overwrite(self): + file = open(find("apoferritin.txt")) + data = {"is_public": True, "file": file, id: 1} + request = self.client_authenticated.post("/v1/data/file/", data=data) + self.assertEqual(request.status_code, status.HTTP_201_CREATED) + self.assertEqual(DataFile.objects.get(id=1).file_name, "cyl_400_20.txt") + max_id = DataFile.objects.aggregate(Max("id"))["id__max"] + self.assertEqual( + request.data, + { + "current_user": "testUser", + "authenticated": True, + "file_id": max_id, + "file_alternative_name": "apoferritin.txt", + "is_public": True, + }, + ) + DataFile.objects.get(id=max_id).delete() + + # Test updating file + def test_does_file_upload_update(self): + file = open(find("cyl_testdata1.txt")) + data = {"file": file, "is_public": False} + request = self.client_authenticated.put("/v1/data/file/1/", data=data) + self.assertEqual( + request.data, + { + "current_user": "testUser", + "authenticated": True, + "file_id": 1, + "file_alternative_name": "cyl_testdata1.txt", + "is_public": False, + }, + ) + self.data.file.save("cyl_400_20.txt", open(find("cyl_400_20.txt"), "rb")) + self.data.file_name = "cyl_400_20.txt" + + # Test updating a public file + def test_public_file_upload_update(self): + data_object = DataFile.objects.create( + id=3, current_user=self.user, file_name="cyl_testdata2.txt", is_public=True + ) + data_object.file.save( + "cyl_testdata2.txt", open(find("cyl_testdata2.txt"), "rb") + ) + file = open(find("conalbumin.txt")) + data = {"file": file, "is_public": True} + request = self.client_authenticated.put("/v1/data/file/3/", data=data) + self.assertEqual( + request.data, + { + "current_user": "testUser", + "authenticated": True, + "file_id": 3, + "file_alternative_name": "conalbumin.txt", + "is_public": True, + }, + ) + data_object.delete() + + # Test file upload update fails when unauthorized + def test_unauthorized_file_upload_update(self): + file = open(find("cyl_400_40.txt")) + data = {"file": file, "is_public": False} + request = self.client_unauthenticated.put("/v1/data/file/1/", data=data) + self.assertEqual(request.status_code, status.HTTP_401_UNAUTHORIZED) + + # Test update nonexistent file fails + def test_file_upload_update_not_found(self): + file = open(find("cyl_400_40.txt")) + data = {"file": file, "is_public": False} + request = self.client_unauthenticated.put("/v1/data/file/5/", data=data) + self.assertEqual(request.status_code, status.HTTP_404_NOT_FOUND) + + # Test file download + def test_does_download(self): + request = self.client_authenticated.get( + "/v1/data/file/1/", data={"download": True} + ) + file_contents = b"".join(request.streaming_content) + test_file = open(find("cyl_400_20.txt"), "rb") + self.assertEqual(request.status_code, status.HTTP_200_OK) + self.assertEqual(file_contents, test_file.read()) + + # Test file download fails when unauthorized + def test_unauthorized_download(self): + request2 = self.client_unauthenticated.get( + "/v1/data/file/1/", data={"download": True} + ) + self.assertEqual(request2.status_code, status.HTTP_401_UNAUTHORIZED) + + # Test download nonexistent file + def test_download_nonexistent(self): + request = self.client_authenticated.get( + "/v1/data/file/5/", data={"download": True} + ) + self.assertEqual(request.status_code, status.HTTP_404_NOT_FOUND) + + # Test deleting a file + def test_delete(self): + DataFile.objects.create( + id=6, current_user=self.user, file_name="test.txt", is_public=False + ) + request = self.client_authenticated.delete("/v1/data/file/6/") + self.assertEqual(request.status_code, status.HTTP_200_OK) + self.assertFalse(DataFile.objects.filter(pk=6).exists()) + + # Test deleting a file fails when unauthorized + def test_delete_unauthorized(self): + request = self.client_unauthenticated.delete("/v1/data/file/1/") + self.assertEqual(request.status_code, status.HTTP_401_UNAUTHORIZED) + + @classmethod + def tearDownClass(cls): + cls.user.delete() + cls.data.delete() + shutil.rmtree(settings.MEDIA_ROOT) + + +class TestAccessManagement(TestCase): + """Test viewing and managing access for a file.""" + + @classmethod + def setUpTestData(cls): + cls.user1 = User.objects.create_user(username="testUser", password="secret") + cls.user2 = User.objects.create_user(username="testUser2", password="secret2") + cls.private_test_data = DataFile.objects.create( + id=1, current_user=cls.user1, file_name="cyl_400_40.txt", is_public=False + ) + cls.private_test_data.file.save( + "cyl_400_40.txt", open(find("cyl_400_40.txt"), "rb") + ) + cls.shared_test_data = DataFile.objects.create( + id=2, current_user=cls.user1, file_name="cyl_400_20.txt", is_public=False + ) + cls.shared_test_data.file.save( + "cyl_400_20.txt", open(find("cyl_400_20.txt"), "rb") + ) + cls.shared_test_data.users.add(cls.user2) + cls.client_owner = APIClient() + cls.client_owner.force_authenticate(cls.user1) + cls.client_other = APIClient() + cls.client_other.force_authenticate(cls.user2) + + # test viewing no one with access + def test_view_no_access(self): + request = self.client_owner.get("/v1/data/file/1/users/") + data = { + "file": 1, + "file_name": "cyl_400_40.txt", + "is_public": False, + "users": [], + } + self.assertEqual(request.status_code, status.HTTP_200_OK) + self.assertEqual(request.data, data) + + # test viewing list of users with access + def test_view_access(self): + request = self.client_owner.get("/v1/data/file/2/users/") + data = { + "file": 2, + "file_name": "cyl_400_20.txt", + "is_public": False, + "users": ["testUser2"], + } + self.assertEqual(request.status_code, status.HTTP_200_OK) + self.assertEqual(request.data, data) + + # test granting another user access to private data + def test_grant_access(self): + data = {"username": "testUser2", "access": True} + request1 = self.client_owner.put("/v1/data/file/1/users/", data=data) + request2 = self.client_other.get("/v1/data/file/1/") + self.assertEqual(request1.status_code, status.HTTP_200_OK) + self.assertEqual(request2.status_code, status.HTTP_200_OK) + self.assertEqual( + request1.data, + { + "username": "testUser2", + "file": 1, + "file_name": "cyl_400_40.txt", + "access": True, + }, + ) + + # test removing another user's access to private data + def test_remove_access(self): + data = {"username": "testUser2", "access": False} + request1 = self.client_other.get("/v1/data/file/2/") + request2 = self.client_owner.put("/v1/data/file/2/users/", data=data) + request3 = self.client_other.get("/v1/data/file/2/") + self.assertEqual(request1.status_code, status.HTTP_200_OK) + self.assertEqual(request2.status_code, status.HTTP_200_OK) + self.assertEqual(request3.status_code, status.HTTP_403_FORBIDDEN) + self.assertEqual( + request2.data, + { + "username": "testUser2", + "file": 2, + "file_name": "cyl_400_20.txt", + "access": False, + }, + ) + + # test removing access from a user that already lacks access + def test_remove_no_access(self): + data = {"username": "testUser2", "access": False} + request1 = self.client_other.get("/v1/data/file/1/") + request2 = self.client_owner.put("/v1/data/file/1/users/", data=data) + request3 = self.client_other.get("/v1/data/file/1/") + self.assertEqual(request1.status_code, status.HTTP_403_FORBIDDEN) + self.assertEqual(request2.status_code, status.HTTP_200_OK) + self.assertEqual(request3.status_code, status.HTTP_403_FORBIDDEN) + self.assertEqual( + request2.data, + { + "username": "testUser2", + "file": 1, + "file_name": "cyl_400_40.txt", + "access": False, + }, + ) + + # test owner's access cannot be removed + def test_cant_revoke_own_access(self): + data = {"username": "testUser", "access": False} + request1 = self.client_owner.put("/v1/data/file/1/users/", data=data) + request2 = self.client_owner.get("/v1/data/file/1/") + self.assertEqual(request1.status_code, status.HTTP_200_OK) + self.assertEqual(request2.status_code, status.HTTP_200_OK) + self.assertEqual( + request1.data, + { + "username": "testUser", + "file": 1, + "file_name": "cyl_400_40.txt", + "access": True, + }, + ) + + # test giving access to a user that already has access + def test_grant_existing_access(self): + data = {"username": "testUser2", "access": True} + request1 = self.client_other.get("/v1/data/file/2/") + request2 = self.client_owner.put("/v1/data/file/2/users/", data=data) + request3 = self.client_other.get("/v1/data/file/2/") + self.assertEqual(request1.status_code, status.HTTP_200_OK) + self.assertEqual(request2.status_code, status.HTTP_200_OK) + self.assertEqual(request3.status_code, status.HTTP_200_OK) + self.assertEqual( + request2.data, + { + "username": "testUser2", + "file": 2, + "file_name": "cyl_400_20.txt", + "access": True, + }, + ) + + # test that access is read-only for the file + def test_no_edit_access(self): + data = {"is_public": True} + request = self.client_other.put("/v1/data/file/2/", data=data) + self.assertEqual(request.status_code, status.HTTP_403_FORBIDDEN) + self.assertFalse(self.shared_test_data.is_public) + + # test that only the owner can view who has access + def test_only_view_access_to_owned_file(self): + request1 = self.client_other.get("/v1/data/file/1/users/") + request2 = self.client_other.get("/v1/data/file/2/users/") + self.assertEqual(request1.status_code, status.HTTP_403_FORBIDDEN) + self.assertEqual(request2.status_code, status.HTTP_403_FORBIDDEN) + + # test that only the owner can change access + def test_only_edit_access_to_owned_file(self): + data1 = {"username": "testUser2", "access": True} + data2 = {"username": "testUser1", "access": False} + request1 = self.client_other.put("/v1/data/file/1/users/", data=data1) + request2 = self.client_other.put("/v1/data/file/2/users/", data=data2) + request3 = self.client_other.get("/v1/data/file/1/") + request4 = self.client_owner.get("/v1/data/file/2/") + self.assertEqual(request1.status_code, status.HTTP_403_FORBIDDEN) + self.assertEqual(request2.status_code, status.HTTP_403_FORBIDDEN) + self.assertEqual(request3.status_code, status.HTTP_403_FORBIDDEN) + self.assertEqual(request4.status_code, status.HTTP_200_OK) + + @classmethod + def tearDownClass(cls): + cls.user1.delete() + cls.user2.delete() + cls.private_test_data.delete() + cls.shared_test_data.delete() + shutil.rmtree(settings.MEDIA_ROOT) diff --git a/sasdata/fair_database/data/test/test_dataset.py b/sasdata/fair_database/data/test/test_dataset.py new file mode 100644 index 000000000..aecfb7247 --- /dev/null +++ b/sasdata/fair_database/data/test/test_dataset.py @@ -0,0 +1,736 @@ +import os +import shutil + +from data.models import DataFile, DataSet, MetaData, OperationTree, Quantity +from django.conf import settings +from django.contrib.auth.models import User +from django.db.models import Max +from rest_framework import status +from rest_framework.test import APIClient, APITestCase + + +# path to a file in example_data/1d_data +def find(filename): + return os.path.join( + os.path.dirname(__file__), "../../../example_data/1d_data", filename + ) + + +class TestDataSet(APITestCase): + """Test HTTP methods of DataSetView.""" + + @classmethod + def setUpTestData(cls): + cls.empty_metadata = { + "title": "New Metadata", + "run": ["X"], + "description": "test", + "instrument": {}, + "process": {}, + "sample": {}, + } + cls.empty_data = [ + { + "value": 0, + "variance": 0, + "units": "no", + "hash": 0, + "label": "test", + "history": {"operation_tree": {}, "references": []}, + } + ] + cls.user1 = User.objects.create_user( + id=1, username="testUser1", password="secret" + ) + cls.user2 = User.objects.create_user( + id=2, username="testUser2", password="secret" + ) + cls.user3 = User.objects.create_user( + id=3, username="testUser3", password="secret" + ) + cls.public_dataset = DataSet.objects.create( + id=1, + current_user=cls.user1, + is_public=True, + name="Dataset 1", + ) + cls.private_dataset = DataSet.objects.create( + id=2, current_user=cls.user1, name="Dataset 2" + ) + cls.unowned_dataset = DataSet.objects.create( + id=3, is_public=True, name="Dataset 3" + ) + cls.private_dataset.users.add(cls.user3) + cls.auth_client1 = APIClient() + cls.auth_client2 = APIClient() + cls.auth_client3 = APIClient() + cls.auth_client1.force_authenticate(cls.user1) + cls.auth_client2.force_authenticate(cls.user2) + cls.auth_client3.force_authenticate(cls.user3) + + # Test a user can list their own private data + def test_list_private(self): + request = self.auth_client1.get("/v1/data/set/") + self.assertEqual(request.status_code, status.HTTP_200_OK) + self.assertEqual( + request.data, + {"dataset_ids": {1: "Dataset 1", 2: "Dataset 2", 3: "Dataset 3"}}, + ) + + # Test a user can see others' public but not private data in list + def test_list_public(self): + request = self.auth_client2.get("/v1/data/set/") + self.assertEqual(request.status_code, status.HTTP_200_OK) + self.assertEqual( + request.data, {"dataset_ids": {1: "Dataset 1", 3: "Dataset 3"}} + ) + + # Test a user can see private data they have been granted access to + def test_list_granted_access(self): + request = self.auth_client3.get("/v1/data/set/") + self.assertEqual(request.status_code, status.HTTP_200_OK) + self.assertEqual( + request.data, + {"dataset_ids": {1: "Dataset 1", 2: "Dataset 2", 3: "Dataset 3"}}, + ) + + # Test an unauthenticated user can list public data + def test_list_unauthenticated(self): + request = self.client.get("/v1/data/set/") + self.assertEqual(request.status_code, status.HTTP_200_OK) + self.assertEqual( + request.data, {"dataset_ids": {1: "Dataset 1", 3: "Dataset 3"}} + ) + + # Test a user can see all data listed by their username + def test_list_username(self): + request = self.auth_client1.get("/v1/data/set/", data={"username": "testUser1"}) + self.assertEqual(request.status_code, status.HTTP_200_OK) + self.assertEqual( + request.data, {"dataset_ids": {1: "Dataset 1", 2: "Dataset 2"}} + ) + + # Test a user can list public data by another user's username + def test_list_username_2(self): + request = self.auth_client1.get("/v1/data/set/", {"username": "testUser2"}) + self.assertEqual(request.status_code, status.HTTP_200_OK) + self.assertEqual(request.data, {"dataset_ids": {}}) + + # Test an unauthenticated user can list public data by a username + def test_list_username_unauthenticated(self): + request = self.client.get("/v1/data/set/", {"username": "testUser1"}) + self.assertEqual(request.status_code, status.HTTP_200_OK) + self.assertEqual(request.data, {"dataset_ids": {1: "Dataset 1"}}) + + # Test listing by a username that doesn't exist + def test_list_wrong_username(self): + request = self.auth_client1.get("/v1/data/set/", {"username": "fakeUser1"}) + self.assertEqual(request.status_code, status.HTTP_404_NOT_FOUND) + + # TODO: test listing by other parameters if functionality is added for that + + # Test creating a dataset with associated metadata + def test_dataset_created(self): + dataset = { + "name": "New Dataset", + "metadata": self.empty_metadata, + "data_contents": self.empty_data, + } + request = self.auth_client1.post("/v1/data/set/", data=dataset, format="json") + max_id = DataSet.objects.aggregate(Max("id"))["id__max"] + new_dataset = DataSet.objects.get(id=max_id) + new_metadata = new_dataset.metadata + self.assertEqual(request.status_code, status.HTTP_201_CREATED) + self.assertEqual( + request.data, + { + "dataset_id": max_id, + "name": "New Dataset", + "authenticated": True, + "current_user": "testUser1", + "is_public": False, + }, + ) + self.assertEqual(new_dataset.name, "New Dataset") + self.assertEqual(new_metadata.title, "New Metadata") + self.assertEqual(new_dataset.current_user.username, "testUser1") + new_dataset.delete() + new_metadata.delete() + + # Test creating a dataset while unauthenticated + def test_dataset_created_unauthenticated(self): + dataset = { + "name": "New Dataset", + "metadata": self.empty_metadata, + "is_public": True, + "data_contents": self.empty_data, + } + request = self.client.post("/v1/data/set/", data=dataset, format="json") + max_id = DataSet.objects.aggregate(Max("id"))["id__max"] + new_dataset = DataSet.objects.get(id=max_id) + new_metadata = new_dataset.metadata + self.assertEqual(request.status_code, status.HTTP_201_CREATED) + self.assertEqual( + request.data, + { + "dataset_id": max_id, + "name": "New Dataset", + "authenticated": False, + "current_user": "", + "is_public": True, + }, + ) + self.assertEqual(new_dataset.name, "New Dataset") + self.assertIsNone(new_dataset.current_user) + new_dataset.delete() + new_metadata.delete() + + # Test creating a database with associated files + def test_dataset_created_with_files(self): + file = DataFile.objects.create( + id=1, file_name="cyl_testdata.txt", is_public=True + ) + file.file.save("cyl_testdata.txt", open(find("cyl_testdata.txt"))) + dataset = { + "name": "Dataset with file", + "metadata": self.empty_metadata, + "is_public": True, + "data_contents": self.empty_data, + "files": [1], + } + request = self.client.post("/v1/data/set/", data=dataset, format="json") + max_id = DataSet.objects.aggregate(Max("id"))["id__max"] + new_dataset = DataSet.objects.get(id=max_id) + self.assertEqual(request.status_code, status.HTTP_201_CREATED) + self.assertEqual( + request.data, + { + "dataset_id": max_id, + "name": "Dataset with file", + "authenticated": False, + "current_user": "", + "is_public": True, + }, + ) + self.assertTrue(file in new_dataset.files.all()) + new_dataset.delete() + file.delete() + + # Test that a dataset cannot be associated with inaccessible files + def test_no_dataset_with_private_files(self): + file = DataFile.objects.create( + id=1, file_name="cyl_testdata.txt", is_public=False, current_user=self.user2 + ) + file.file.save("cyl_testdata.txt", open(find("cyl_testdata.txt"))) + dataset = { + "name": "Dataset with file", + "metadata": self.empty_metadata, + "is_public": True, + "data_contents": self.empty_data, + "files": [1], + } + request = self.client.post("/v1/data/set/", data=dataset, format="json") + file.delete() + self.assertEqual(request.status_code, status.HTTP_400_BAD_REQUEST) + + # Test that a dataset cannot be associated with nonexistent files + def test_no_dataset_with_nonexistent_files(self): + dataset = { + "name": "Dataset with file", + "metadata": self.empty_metadata, + "is_public": True, + "data_contents": self.empty_data, + "files": [2], + } + request = self.client.post("/v1/data/set/", data=dataset, format="json") + self.assertEqual(request.status_code, status.HTTP_400_BAD_REQUEST) + + # Test that a dataset cannot be created without metadata + def test_metadata_required(self): + dataset = { + "name": "No metadata", + "is_public": True, + "data_contents": self.empty_data, + } + request = self.auth_client1.post("/v1/data/set/", data=dataset, format="json") + self.assertEqual(request.status_code, status.HTTP_400_BAD_REQUEST) + + # Test that a private dataset cannot be created without an owner + def test_no_private_unowned_dataset(self): + dataset = { + "name": "Disallowed Dataset", + "metadata": self.empty_metadata, + "is_public": False, + "data_contents": self.empty_data, + } + request = self.client.post("/v1/data/set/", data=dataset, format="json") + self.assertEqual(request.status_code, status.HTTP_400_BAD_REQUEST) + + # Test whether a user can overwrite data by specifying an in-use id + def test_no_data_overwrite(self): + dataset = { + "id": 2, + "name": "Overwrite Dataset", + "metadata": self.empty_metadata, + "is_public": True, + "data_contents": self.empty_data, + } + request = self.auth_client2.post("/v1/data/set/", data=dataset, format="json") + max_id = DataSet.objects.aggregate(Max("id"))["id__max"] + self.assertEqual(request.status_code, status.HTTP_201_CREATED) + self.assertEqual(DataSet.objects.get(id=2).name, "Dataset 2") + self.assertEqual( + request.data, + { + "dataset_id": max_id, + "name": "Overwrite Dataset", + "authenticated": True, + "current_user": "testUser2", + "is_public": True, + }, + ) + DataSet.objects.get(id=max_id).delete() + + @classmethod + def tearDownClass(cls): + cls.public_dataset.delete() + cls.private_dataset.delete() + cls.unowned_dataset.delete() + cls.user1.delete() + cls.user2.delete() + cls.user3.delete() + shutil.rmtree(settings.MEDIA_ROOT) + + +class TestSingleDataSet(APITestCase): + """Tests for HTTP methods of SingleDataSetView.""" + + @classmethod + def setUpTestData(cls): + cls.user1 = User.objects.create_user( + id=1, username="testUser1", password="secret" + ) + cls.user2 = User.objects.create_user( + id=2, username="testUser2", password="secret" + ) + cls.user3 = User.objects.create_user( + id=3, username="testUser3", password="secret" + ) + cls.public_dataset = DataSet.objects.create( + id=1, + current_user=cls.user1, + is_public=True, + name="Dataset 1", + ) + cls.private_dataset = DataSet.objects.create( + id=2, current_user=cls.user1, name="Dataset 2" + ) + cls.unowned_dataset = DataSet.objects.create( + id=3, is_public=True, name="Dataset 3" + ) + cls.metadata = MetaData.objects.create( + id=1, + title="Metadata", + run=0, + definition="test", + instrument="none", + process="none", + sample="none", + dataset=cls.public_dataset, + ) + cls.file = DataFile.objects.create( + id=1, file_name="cyl_testdata.txt", is_public=False, current_user=cls.user1 + ) + cls.file.file.save("cyl_testdata.txt", open(find("cyl_testdata.txt"))) + cls.private_dataset.users.add(cls.user3) + cls.public_dataset.files.add(cls.file) + cls.auth_client1 = APIClient() + cls.auth_client2 = APIClient() + cls.auth_client3 = APIClient() + cls.auth_client1.force_authenticate(cls.user1) + cls.auth_client2.force_authenticate(cls.user2) + cls.auth_client3.force_authenticate(cls.user3) + + # TODO: change load return data + # Test successfully accessing a private dataset + def test_load_private_dataset(self): + request1 = self.auth_client1.get("/v1/data/set/2/") + request2 = self.auth_client3.get("/v1/data/set/2/") + self.assertEqual(request1.status_code, status.HTTP_200_OK) + self.assertEqual(request2.status_code, status.HTTP_200_OK) + self.assertEqual( + request1.data, + { + "id": 2, + "current_user": "testUser1", + "users": [3], + "is_public": False, + "name": "Dataset 2", + "files": [], + "metadata": None, + "data_contents": [], + "session": None, + }, + ) + + # Test successfully accessing a public dataset + def test_load_public_dataset(self): + request1 = self.client.get("/v1/data/set/1/") + request2 = self.auth_client2.get("/v1/data/set/1/") + request3 = self.auth_client1.get("/v1/data/set/1/") + self.assertEqual(request1.status_code, status.HTTP_200_OK) + self.assertEqual(request2.status_code, status.HTTP_200_OK) + self.assertEqual(request3.status_code, status.HTTP_200_OK) + self.assertDictEqual( + request1.data, + { + "id": 1, + "current_user": "testUser1", + "users": [], + "is_public": True, + "name": "Dataset 1", + "files": [], + "metadata": { + "id": 1, + "title": "Metadata", + "run": 0, + "definition": "test", + "instrument": "none", + "process": "none", + "sample": "none", + }, + "data_contents": [], + "session": None, + }, + ) + self.assertEqual(request1.data, request2.data) + self.assertEqual( + request3.data, + { + "id": 1, + "current_user": "testUser1", + "users": [], + "is_public": True, + "name": "Dataset 1", + "files": [1], + "metadata": { + "id": 1, + "title": "Metadata", + "run": 0, + "definition": "test", + "instrument": "none", + "process": "none", + "sample": "none", + }, + "data_contents": [], + "session": None, + }, + ) + + # Test successfully accessing an unowned public dataset + def test_load_unowned_dataset(self): + request1 = self.auth_client1.get("/v1/data/set/3/") + request2 = self.client.get("/v1/data/set/3/") + self.assertEqual(request1.status_code, status.HTTP_200_OK) + self.assertEqual(request2.status_code, status.HTTP_200_OK) + self.assertDictEqual( + request1.data, + { + "id": 3, + "current_user": None, + "users": [], + "is_public": True, + "name": "Dataset 3", + "files": [], + "metadata": None, + "data_contents": [], + "session": None, + }, + ) + + # Test unsuccessfully accessing a private dataset + def test_load_private_dataset_unauthorized(self): + request1 = self.auth_client2.get("/v1/data/set/2/") + request2 = self.client.get("/v1/data/set/2/") + self.assertEqual(request1.status_code, status.HTTP_403_FORBIDDEN) + self.assertEqual(request2.status_code, status.HTTP_401_UNAUTHORIZED) + + # Test only owner can change a private dataset + def test_update_private_dataset(self): + request1 = self.auth_client1.put("/v1/data/set/2/", data={"is_public": True}) + request2 = self.auth_client3.put("/v1/data/set/2/", data={"is_public": False}) + self.assertEqual(request1.status_code, status.HTTP_200_OK) + self.assertEqual(request2.status_code, status.HTTP_403_FORBIDDEN) + self.assertEqual( + request1.data, {"data_id": 2, "name": "Dataset 2", "is_public": True} + ) + self.assertTrue(DataSet.objects.get(id=2).is_public) + self.private_dataset.save() + self.assertFalse(DataSet.objects.get(id=2).is_public) + + # Test changing a public dataset + def test_update_public_dataset(self): + request1 = self.auth_client1.put( + "/v1/data/set/1/", data={"name": "Different name"} + ) + request2 = self.auth_client2.put("/v1/data/set/1/", data={"is_public": False}) + self.assertEqual(request1.status_code, status.HTTP_200_OK) + self.assertEqual(request2.status_code, status.HTTP_403_FORBIDDEN) + self.assertEqual( + request1.data, {"data_id": 1, "name": "Different name", "is_public": True} + ) + self.assertEqual(DataSet.objects.get(id=1).name, "Different name") + self.public_dataset.save() + + # TODO: test invalid updates if and when those are figured out + + # Test changing an unowned dataset + def test_update_unowned_dataset(self): + request1 = self.auth_client1.put("/v1/data/set/3/", data={"current_user": 1}) + request2 = self.client.put("/v1/data/set/3/", data={"name": "Different name"}) + self.assertEqual(request1.status_code, status.HTTP_403_FORBIDDEN) + self.assertEqual(request2.status_code, status.HTTP_401_UNAUTHORIZED) + + # Test updating metadata + def test_update_dataset_metadata(self): + new_metadata = { + "title": "Updated Metadata", + "run": ["X"], + "definition": "update test", + "instrument": "none", + "process": "none", + "sample": "none", + } + request = self.auth_client1.put( + "/v1/data/set/1/", data={"metadata": new_metadata}, format="json" + ) + dataset = DataSet.objects.get(id=1) + self.assertEqual(request.status_code, status.HTTP_200_OK) + self.assertEqual(dataset.metadata.title, "Updated Metadata") + self.assertEqual(dataset.metadata.id, 1) + self.assertEqual(len(MetaData.objects.all()), 1) + dataset.metadata.delete() + self.metadata = MetaData.objects.create( + id=1, + title="Metadata", + run=0, + definition="test", + instrument="none", + process="none", + sample="none", + dataset=self.public_dataset, + ) + + # Test partially updating metadata + def test_update_dataset_partial_metadata(self): + request = self.auth_client1.put( + "/v1/data/set/1/", + data={"metadata": {"title": "Different Title"}}, + format="json", + ) + dataset = DataSet.objects.get(id=1) + metadata = dataset.metadata + self.assertEqual(request.status_code, status.HTTP_200_OK) + self.assertEqual(metadata.title, "Different Title") + self.assertEqual(metadata.definition, "test") + self.assertEqual(metadata.id, 1) + metadata.title = "Metadata" + metadata.save() + + # Test updating a dataset's files + def test_update_dataset_files(self): + request = self.auth_client1.put("/v1/data/set/2/", data={"files": [1]}) + self.assertEqual(request.status_code, status.HTTP_200_OK) + self.assertEqual(len(DataSet.objects.get(id=2).files.all()), 1) + self.private_dataset.files.remove(self.file) + + # Test replacing a dataset's files + def test_update_dataset_replace_files(self): + file = DataFile.objects.create( + id=2, file_name="cyl_testdata1.txt", is_public=True, current_user=self.user1 + ) + file.file.save("cyl_testdata1.txt", open(find("cyl_testdata1.txt"))) + request = self.auth_client1.put("/v1/data/set/1/", data={"files": [2]}) + self.assertEqual(request.status_code, status.HTTP_200_OK) + self.assertEqual(len(DataSet.objects.get(id=1).files.all()), 1) + self.assertTrue(file in DataSet.objects.get(id=1).files.all()) + self.public_dataset.files.add(self.file) + self.public_dataset.files.remove(file) + + # Test updating a dataset to have no files + def test_update_dataset_clear_files(self): + request = self.auth_client1.put("/v1/data/set/1/", data={"files": [""]}) + self.assertEqual(request.status_code, status.HTTP_200_OK) + self.assertEqual(len(DataSet.objects.get(id=1).files.all()), 0) + self.public_dataset.files.add(self.file) + + # Test that a dataset cannot be updated to be private and unowned + def test_update_dataset_no_private_unowned(self): + request1 = self.auth_client1.put("/v1/data/set/2/", data={"current_user": ""}) + request2 = self.auth_client1.put( + "/v1/data/set/1/", data={"current_user": "", "is_public": False} + ) + public_dataset = DataSet.objects.get(id=1) + self.assertEqual(request1.status_code, status.HTTP_400_BAD_REQUEST) + self.assertEqual(request2.status_code, status.HTTP_400_BAD_REQUEST) + self.assertEqual(DataSet.objects.get(id=2).current_user, self.user1) + self.assertEqual(public_dataset.current_user, self.user1) + self.assertTrue(public_dataset.is_public) + + # Test deleting a dataset + def test_delete_dataset(self): + quantity = Quantity.objects.create( + id=1, + value=0, + variance=0, + units="none", + hash=0, + label="test", + dataset=self.private_dataset, + ) + neg = OperationTree.objects.create(id=1, operation="neg", quantity=quantity) + OperationTree.objects.create( + id=2, operation="zero", parameters={}, child_operation=neg + ) + request = self.auth_client1.delete("/v1/data/set/2/") + self.assertEqual(request.status_code, status.HTTP_200_OK) + self.assertEqual(request.data, {"success": True}) + self.assertRaises(DataSet.DoesNotExist, DataSet.objects.get, id=2) + self.assertRaises(Quantity.DoesNotExist, Quantity.objects.get, id=1) + self.assertRaises(OperationTree.DoesNotExist, OperationTree.objects.get, id=1) + self.assertRaises(OperationTree.DoesNotExist, OperationTree.objects.get, id=2) + self.private_dataset = DataSet.objects.create( + id=2, current_user=self.user1, name="Dataset 2" + ) + + # Test cannot delete a public dataset + def test_delete_public_dataset(self): + request = self.auth_client1.delete("/v1/data/set/1/") + self.assertEqual(request.status_code, status.HTTP_403_FORBIDDEN) + + # Test cannot delete an unowned dataset + def test_delete_unowned_dataset(self): + request = self.auth_client1.delete("/v1/data/set/3/") + self.assertEqual(request.status_code, status.HTTP_403_FORBIDDEN) + + # Test cannot delete another user's dataset + def test_delete_dataset_unauthorized(self): + request1 = self.auth_client2.delete("/v1/data/set/1/") + request2 = self.auth_client3.delete("/v1/data/set/2/") + self.assertEqual(request1.status_code, status.HTTP_403_FORBIDDEN) + self.assertEqual(request2.status_code, status.HTTP_403_FORBIDDEN) + + @classmethod + def tearDownClass(cls): + cls.public_dataset.delete() + cls.private_dataset.delete() + cls.unowned_dataset.delete() + cls.user1.delete() + cls.user2.delete() + cls.user3.delete() + cls.file.delete() + shutil.rmtree(settings.MEDIA_ROOT) + + +class TestDataSetAccessManagement(APITestCase): + """Tests for HTTP methods of DataSetUsersView.""" + + @classmethod + def setUpTestData(cls): + cls.user1 = User.objects.create_user(username="testUser1", password="secret") + cls.user2 = User.objects.create_user(username="testUser2", password="secret") + cls.private_dataset = DataSet.objects.create( + id=1, current_user=cls.user1, name="Dataset 1" + ) + cls.shared_dataset = DataSet.objects.create( + id=2, current_user=cls.user1, name="Dataset 2" + ) + cls.shared_dataset.users.add(cls.user2) + cls.client_owner = APIClient() + cls.client_other = APIClient() + cls.client_owner.force_authenticate(cls.user1) + cls.client_other.force_authenticate(cls.user2) + + # Test listing no users with access + def test_list_access_private(self): + request1 = self.client_owner.get("/v1/data/set/1/users/") + self.assertEqual(request1.status_code, status.HTTP_200_OK) + self.assertEqual( + request1.data, + {"data_id": 1, "name": "Dataset 1", "is_public": False, "users": []}, + ) + + # Test listing users with access + def test_list_access_shared(self): + request1 = self.client_owner.get("/v1/data/set/2/users/") + self.assertEqual(request1.status_code, status.HTTP_200_OK) + self.assertEqual( + request1.data, + { + "data_id": 2, + "name": "Dataset 2", + "is_public": False, + "users": ["testUser2"], + }, + ) + + # Test only owner can view access + def test_list_access_unauthorized(self): + request = self.client_other.get("/v1/data/set/2/users/") + self.assertEqual(request.status_code, status.HTTP_403_FORBIDDEN) + + # Test granting access to a dataset + def test_grant_access(self): + request1 = self.client_owner.put( + "/v1/data/set/1/users/", data={"username": "testUser2", "access": True} + ) + request2 = self.client_other.get("/v1/data/set/1/") + self.assertEqual(request1.status_code, status.HTTP_200_OK) + self.assertEqual(request2.status_code, status.HTTP_200_OK) + self.assertIn( # codespell:ignore + self.user2, DataSet.objects.get(id=1).users.all() + ) + self.assertEqual( + request1.data, + { + "username": "testUser2", + "data_id": 1, + "name": "Dataset 1", + "access": True, + }, + ) + self.private_dataset.users.remove(self.user2) + + # Test revoking access to a dataset + def test_revoke_access(self): + request1 = self.client_owner.put( + "/v1/data/set/2/users/", data={"username": "testUser2", "access": False} + ) + request2 = self.client_other.get("/v1/data/set/2/") + self.assertEqual(request1.status_code, status.HTTP_200_OK) + self.assertEqual(request2.status_code, status.HTTP_403_FORBIDDEN) + self.assertNotIn(self.user2, DataSet.objects.get(id=2).users.all()) + self.assertEqual( + request1.data, + { + "username": "testUser2", + "data_id": 2, + "name": "Dataset 2", + "access": False, + }, + ) + self.shared_dataset.users.add(self.user2) + + # Test only the owner can change access + def test_revoke_access_unauthorized(self): + request1 = self.client_other.put( + "/v1/data/set/2/users/", data={"username": "testUser2", "access": False} + ) + self.assertEqual(request1.status_code, status.HTTP_403_FORBIDDEN) + + @classmethod + def tearDownClass(cls): + cls.private_dataset.delete() + cls.shared_dataset.delete() + cls.user1.delete() + cls.user2.delete() diff --git a/sasdata/fair_database/data/test/test_operation_tree.py b/sasdata/fair_database/data/test/test_operation_tree.py new file mode 100644 index 000000000..90a26d81b --- /dev/null +++ b/sasdata/fair_database/data/test/test_operation_tree.py @@ -0,0 +1,798 @@ +from data.models import DataSet, MetaData, OperationTree, Quantity, ReferenceQuantity +from django.contrib.auth.models import User +from django.db.models import Max +from rest_framework import status +from rest_framework.test import APIClient, APITestCase + + +class TestCreateOperationTree(APITestCase): + """Tests for creating datasets with operation trees.""" + + @classmethod + def setUpTestData(cls): + cls.dataset = { + "name": "Test Dataset", + "metadata": { + "title": "test metadata", + "run": 1, + "definition": "test", + "instrument": {"source": {}, "collimation": {}, "detectors": {}}, + }, + "data_contents": [ + { + "label": "test", + "value": {"array_contents": [0, 0, 0, 0], "shape": (2, 2)}, + "variance": {"array_contents": [0, 0, 0, 0], "shape": (2, 2)}, + "units": "none", + "hash": 0, + } + ], + "is_public": True, + } + cls.user = User.objects.create_user( + id=1, username="testUser", password="sasview!" + ) + cls.client = APIClient() + cls.client.force_authenticate(cls.user) + + @staticmethod + def get_operation_tree(quantity): + return quantity.operation_tree + + # Test creating quantity with no operations performed (variable-only history) + def test_operation_tree_created_variable(self): + self.dataset["data_contents"][0]["history"] = { + "operation_tree": { + "operation": "variable", + "parameters": {"hash_value": 0, "name": "test"}, + }, + "references": [ + { + "label": "test", + "value": {"array_contents": [0, 0, 0, 0], "shape": (2, 2)}, + "variance": {"array_contents": [0, 0, 0, 0], "shape": (2, 2)}, + "units": "none", + "hash": 0, + "history": {}, + } + ], + } + request = self.client.post("/v1/data/set/", data=self.dataset, format="json") + max_id = DataSet.objects.aggregate(Max("id"))["id__max"] + new_dataset = DataSet.objects.get(id=max_id) + new_quantity = new_dataset.data_contents.get(hash=0) + self.assertEqual(request.status_code, status.HTTP_201_CREATED) + self.assertRaises( + Quantity.operation_tree.RelatedObjectDoesNotExist, + self.get_operation_tree, + quantity=new_quantity, + ) + self.assertEqual(len(new_quantity.references.all()), 0) + + # Test creating quantity with unary operation + def test_operation_tree_created_unary(self): + self.dataset["data_contents"][0]["history"] = { + "operation_tree": { + "operation": "reciprocal", + "parameters": { + "a": { + "operation": "variable", + "parameters": {"hash_value": 111, "name": "x"}, + } + }, + }, + "references": [ + {"value": 5, "variance": 0, "units": "none", "hash": 111, "history": {}} + ], + } + request = self.client.post("/v1/data/set/", data=self.dataset, format="json") + max_id = DataSet.objects.aggregate(Max("id"))["id__max"] + new_dataset = DataSet.objects.get(id=max_id) + new_quantity = new_dataset.data_contents.get(hash=0) + reciprocal = new_quantity.operation_tree + variable = reciprocal.parent_operations.all().get(label="a") + self.assertEqual(request.status_code, status.HTTP_201_CREATED) + self.assertEqual( + new_quantity.value, {"array_contents": [0, 0, 0, 0], "shape": [2, 2]} + ) + self.assertEqual(reciprocal.operation, "reciprocal") + self.assertEqual(variable.operation, "variable") + self.assertEqual(len(reciprocal.parent_operations.all()), 1) + self.assertEqual(reciprocal.parameters, {}) + self.assertEqual(len(ReferenceQuantity.objects.all()), 1) + self.assertEqual(len(new_quantity.references.all()), 1) + self.assertEqual(new_quantity.references.get(hash=111).value, 5) + + # Test creating quantity with binary operation + def test_operation_tree_created_binary(self): + self.dataset["data_contents"][0]["history"] = { + "operation_tree": { + "operation": "add", + "parameters": { + "a": { + "operation": "variable", + "parameters": {"hash_value": 111, "name": "x"}, + }, + "b": {"operation": "constant", "parameters": {"value": 5}}, + }, + }, + "references": [ + {"value": 5, "variance": 0, "units": "none", "hash": 111, "history": {}} + ], + } + request = self.client.post("/v1/data/set/", data=self.dataset, format="json") + max_id = DataSet.objects.aggregate(Max("id"))["id__max"] + new_dataset = DataSet.objects.get(id=max_id) + new_quantity = new_dataset.data_contents.get(hash=0) + add = new_quantity.operation_tree + variable = add.parent_operations.get(label="a") + constant = add.parent_operations.get(label="b") + self.assertEqual(request.status_code, status.HTTP_201_CREATED) + self.assertEqual(add.operation, "add") + self.assertEqual(add.parameters, {}) + self.assertEqual(variable.operation, "variable") + self.assertEqual(variable.parameters, {"hash_value": 111, "name": "x"}) + self.assertEqual(constant.operation, "constant") + self.assertEqual(constant.parameters, {"value": 5}) + self.assertEqual(len(add.parent_operations.all()), 2) + self.assertEqual(len(new_quantity.references.all()), 1) + self.assertEqual(new_quantity.references.get(hash=111).value, 5) + + # Test creating quantity with exponent + def test_operation_tree_created_pow(self): + self.dataset["data_contents"][0]["history"] = { + "operation_tree": { + "operation": "pow", + "parameters": { + "a": { + "operation": "variable", + "parameters": {"hash_value": 111, "name": "x"}, + }, + "power": 2, + }, + }, + "references": [ + {"value": 5, "variance": 0, "units": "none", "hash": 111, "history": {}} + ], + } + request = self.client.post("/v1/data/set/", data=self.dataset, format="json") + max_id = DataSet.objects.aggregate(Max("id"))["id__max"] + new_dataset = DataSet.objects.get(id=max_id) + new_quantity = new_dataset.data_contents.get(hash=0) + pow = new_quantity.operation_tree + self.assertEqual(request.status_code, status.HTTP_201_CREATED) + self.assertEqual(pow.operation, "pow") + self.assertEqual(pow.parameters, {"power": 2}) + self.assertEqual(len(new_quantity.references.all()), 1) + self.assertEqual(new_quantity.references.get(hash=111).value, 5) + + # Test creating a transposed quantity + def test_operation_tree_created_transpose(self): + self.dataset["data_contents"][0]["history"] = { + "operation_tree": { + "operation": "transpose", + "parameters": { + "a": { + "operation": "variable", + "parameters": {"hash_value": 111, "name": "x"}, + }, + "axes": [1, 0], + }, + }, + "references": [ + {"value": 5, "variance": 0, "units": "none", "hash": 111, "history": {}} + ], + } + request = self.client.post("/v1/data/set/", data=self.dataset, format="json") + max_id = DataSet.objects.aggregate(Max("id"))["id__max"] + new_dataset = DataSet.objects.get(id=max_id) + new_quantity = new_dataset.data_contents.get(hash=0) + transpose = new_quantity.operation_tree + variable = transpose.parent_operations.get() + self.assertEqual(request.status_code, status.HTTP_201_CREATED) + self.assertEqual(transpose.operation, "transpose") + self.assertEqual(transpose.parameters, {"axes": [1, 0]}) + self.assertEqual(variable.operation, "variable") + self.assertEqual(variable.parameters, {"hash_value": 111, "name": "x"}) + self.assertEqual(len(new_quantity.references.all()), 1) + self.assertEqual(new_quantity.references.get(hash=111).value, 5) + + # Test creating a quantity with multiple operations + def test_operation_tree_created_nested(self): + self.dataset["data_contents"][0]["history"] = { + "operation_tree": { + "operation": "neg", + "parameters": { + "a": { + "operation": "mul", + "parameters": { + "a": { + "operation": "constant", + "parameters": {"value": {"type": "int", "value": 7}}, + }, + "b": { + "operation": "variable", + "parameters": {"hash_value": 111, "name": "x"}, + }, + }, + }, + }, + }, + "references": [ + {"value": 5, "variance": 0, "units": "none", "hash": 111, "history": {}} + ], + } + request = self.client.post("/v1/data/set/", data=self.dataset, format="json") + max_id = DataSet.objects.aggregate(Max("id"))["id__max"] + new_dataset = DataSet.objects.get(id=max_id) + new_quantity = new_dataset.data_contents.get(hash=0) + negate = new_quantity.operation_tree + multiply = negate.parent_operations.get() + constant = multiply.parent_operations.get(label="a") + variable = multiply.parent_operations.get(label="b") + self.assertEqual(request.status_code, status.HTTP_201_CREATED) + self.assertEqual(negate.operation, "neg") + self.assertEqual(negate.parameters, {}) + self.assertEqual(multiply.operation, "mul") + self.assertEqual(multiply.parameters, {}) + self.assertEqual(constant.operation, "constant") + self.assertEqual(constant.parameters, {"value": {"type": "int", "value": 7}}) + self.assertEqual(variable.operation, "variable") + self.assertEqual(variable.parameters, {"hash_value": 111, "name": "x"}) + self.assertEqual(len(new_quantity.references.all()), 1) + self.assertEqual(new_quantity.references.get(hash=111).value, 5) + + # Test creating a quantity with tensordot + def test_operation_tree_created_tensor(self): + self.dataset["data_contents"][0]["history"] = { + "operation_tree": { + "operation": "tensor_product", + "parameters": { + "a": { + "operation": "variable", + "parameters": {"hash_value": 111, "name": "x"}, + }, + "b": {"operation": "constant", "parameters": {"value": 5}}, + "a_index": 1, + "b_index": 1, + }, + }, + "references": [ + {"value": 5, "variance": 0, "units": "none", "hash": 111, "history": {}} + ], + } + request = self.client.post("/v1/data/set/", data=self.dataset, format="json") + max_id = DataSet.objects.aggregate(Max("id"))["id__max"] + new_dataset = DataSet.objects.get(id=max_id) + new_quantity = new_dataset.data_contents.get(hash=0) + tensor = new_quantity.operation_tree + self.assertEqual(request.status_code, status.HTTP_201_CREATED) + self.assertEqual(tensor.operation, "tensor_product") + self.assertEqual(tensor.parameters, {"a_index": 1, "b_index": 1}) + self.assertEqual(len(new_quantity.references.all()), 1) + self.assertEqual(new_quantity.references.get(hash=111).value, 5) + + # Test creating a quantity with no history + def test_operation_tree_created_no_history(self): + if "history" in self.dataset["data_contents"][0]: + self.dataset["data_contents"][0].pop("history") + request = self.client.post( + "/v1/data/set/", data=self.dataset, format="json" + ) + max_id = DataSet.objects.aggregate(Max("id"))["id__max"] + new_dataset = DataSet.objects.get(id=max_id) + new_quantity = new_dataset.data_contents.get(hash=0) + self.assertEqual(request.status_code, status.HTTP_201_CREATED) + self.assertIsNone(new_quantity.operation_tree) + self.assertEqual(len(new_quantity.references.all()), 0) + + def tearDown(self): + DataSet.objects.all().delete() + MetaData.objects.all().delete() + Quantity.objects.all().delete() + OperationTree.objects.all().delete() + + @classmethod + def tearDownClass(cls): + cls.user.delete() + + +class TestCreateInvalidOperationTree(APITestCase): + """Tests for creating datasets with invalid operation trees.""" + + @classmethod + def setUpTestData(cls): + cls.dataset = { + "name": "Test Dataset", + "metadata": { + "title": "test metadata", + "run": 1, + "definition": "test", + "instrument": {"source": {}, "collimation": {}, "detectors": {}}, + }, + "data_contents": [ + { + "label": "test", + "value": {"array_contents": [0, 0, 0, 0], "shape": (2, 2)}, + "variance": {"array_contents": [0, 0, 0, 0], "shape": (2, 2)}, + "units": "none", + "hash": 0, + } + ], + "is_public": True, + } + cls.user = User.objects.create_user( + id=1, username="testUser", password="sasview!" + ) + cls.client = APIClient() + cls.client.force_authenticate(cls.user) + + # Test creating a quantity with an invalid operation + def test_create_operation_tree_invalid(self): + self.dataset["data_contents"][0]["history"] = { + "operation_tree": {"operation": "fix", "parameters": {}}, + "references": [], + } + request = self.client.post("/v1/data/set/", data=self.dataset, format="json") + self.assertEqual(request.status_code, status.HTTP_400_BAD_REQUEST) + self.assertEqual(len(DataSet.objects.all()), 0) + self.assertEqual(len(Quantity.objects.all()), 0) + self.assertEqual(len(OperationTree.objects.all()), 0) + + # Test creating a quantity with a nested invalid operation + def test_create_operation_tree_invalid_nested(self): + self.dataset["data_contents"][0]["history"] = { + "operation_tree": { + "operation": "reciprocal", + "parameters": { + "a": { + "operation": "fix", + "parameters": {}, + } + }, + }, + "references": [], + } + request = self.client.post("/v1/data/set/", data=self.dataset, format="json") + self.assertEqual(request.status_code, status.HTTP_400_BAD_REQUEST) + self.assertEqual(len(DataSet.objects.all()), 0) + self.assertEqual(len(Quantity.objects.all()), 0) + self.assertEqual(len(OperationTree.objects.all()), 0) + + # Test creating a unary operation with a missing parameter fails + def test_create_missing_parameter_unary(self): + self.dataset["data_contents"][0]["history"] = { + "operation_tree": {"operation": "neg", "parameters": {}}, + "references": {}, + } + request = self.client.post("/v1/data/set/", data=self.dataset, format="json") + self.assertEqual(request.status_code, status.HTTP_400_BAD_REQUEST) + self.assertEqual(len(DataSet.objects.all()), 0) + self.assertEqual(len(Quantity.objects.all()), 0) + self.assertEqual(len(OperationTree.objects.all()), 0) + + # Test creating a binary operation with a missing parameter fails + def test_create_missing_parameter_binary(self): + self.dataset["data_contents"][0]["history"] = { + "operation_tree": { + "operation": "add", + "parameters": { + "a": { + "operation": "variable", + "parameters": {"hash_value": 111, "name": "x"}, + } + }, + }, + "references": [], + } + request = self.client.post("/v1/data/set/", data=self.dataset, format="json") + self.assertEqual(request.status_code, status.HTTP_400_BAD_REQUEST) + self.assertEqual(len(DataSet.objects.all()), 0) + self.assertEqual(len(Quantity.objects.all()), 0) + self.assertEqual(len(OperationTree.objects.all()), 0) + + # TODO: should variable-only history be ignored? + # Test creating a variable with a missing parameter fails + def test_create_missing_parameter_variable(self): + self.dataset["data_contents"][0]["history"] = { + "operation_tree": { + "operation": "neg", + "parameters": { + "a": {"operation": "variable", "parameters": {"name": "x"}} + }, + }, + "references": [], + } + request = self.client.post("/v1/data/set/", data=self.dataset, format="json") + self.assertEqual(request.status_code, status.HTTP_400_BAD_REQUEST) + self.assertEqual(len(DataSet.objects.all()), 0) + self.assertEqual(len(Quantity.objects.all()), 0) + self.assertEqual(len(OperationTree.objects.all()), 0) + + # Test creating a constant with a missing parameter fails + def test_create_missing_parameter_constant(self): + self.dataset["data_contents"][0]["history"] = { + "operation_tree": { + "operation": "neg", + "parameters": {"a": {"operation": "constant", "parameters": {}}}, + }, + "references": [], + } + request = self.client.post("/v1/data/set/", data=self.dataset, format="json") + self.assertEqual(request.status_code, status.HTTP_400_BAD_REQUEST) + self.assertEqual(len(DataSet.objects.all()), 0) + self.assertEqual(len(Quantity.objects.all()), 0) + self.assertEqual(len(OperationTree.objects.all()), 0) + + # Test creating an exponent with a missing parameter fails + def test_create_missing_parameter_pow(self): + self.dataset["data_contents"][0]["history"] = { + "operation_tree": { + "operation": "pow", + "parameters": { + "a": { + "operation": "variable", + "parameters": {"hash_value": 111, "name": "x"}, + }, + }, + }, + "references": [], + } + request = self.client.post("/v1/data/set/", data=self.dataset, format="json") + self.assertEqual(request.status_code, status.HTTP_400_BAD_REQUEST) + self.assertEqual(len(DataSet.objects.all()), 0) + self.assertEqual(len(Quantity.objects.all()), 0) + self.assertEqual(len(OperationTree.objects.all()), 0) + + # Test creating a transpose with a missing parameter fails + def test_create_missing_parameter_transpose(self): + self.dataset["data_contents"][0]["history"] = { + "operation_tree": { + "operation": "transpose", + "parameters": { + "a": { + "operation": "variable", + "parameters": {"hash_value": 111, "name": "x"}, + }, + }, + }, + "references": [], + } + request = self.client.post("/v1/data/set/", data=self.dataset, format="json") + self.assertEqual(request.status_code, status.HTTP_400_BAD_REQUEST) + self.assertEqual(len(DataSet.objects.all()), 0) + self.assertEqual(len(Quantity.objects.all()), 0) + self.assertEqual(len(OperationTree.objects.all()), 0) + + # Test creating a tensor with a missing parameter fails + def test_create_missing_parameter_tensor(self): + self.dataset["data_contents"][0]["history"] = { + "operation_tree": { + "operation": "tensor_product", + "parameters": { + "a": { + "operation": "variable", + "parameters": {"hash_value": 111, "name": "x"}, + }, + "b": {"operation": "constant", "parameters": {"value": 5}}, + "b_index": 1, + }, + }, + "references": [], + } + request = self.client.post("/v1/data/set/", data=self.dataset, format="json") + self.assertEqual(request.status_code, status.HTTP_400_BAD_REQUEST) + self.assertEqual(len(DataSet.objects.all()), 0) + self.assertEqual(len(Quantity.objects.all()), 0) + self.assertEqual(len(OperationTree.objects.all()), 0) + + # TODO: Test variables have corresponding reference quantities + + @classmethod + def tearDownClass(cls): + cls.user.delete() + + +class TestGetOperationTree(APITestCase): + """Tests for retrieving datasets with operation trees.""" + + @classmethod + def setUpTestData(cls): + cls.user = User.objects.create_user( + id=1, username="testUser", password="sasview!" + ) + cls.dataset = DataSet.objects.create( + id=1, + current_user=cls.user, + name="Test Dataset", + is_public=True, + ) + cls.quantity = Quantity.objects.create( + id=1, + value=0, + variance=0, + label="test", + units="none", + hash=1, + dataset=cls.dataset, + ) + cls.variable = OperationTree.objects.create( + id=1, operation="variable", parameters={"hash_value": 111, "name": "x"} + ) + cls.constant = OperationTree.objects.create( + id=2, operation="constant", parameters={"value": 1} + ) + cls.ref_quantity = ReferenceQuantity.objects.create( + id=1, + value=5, + variance=0, + units="none", + hash=111, + derived_quantity=cls.quantity, + ) + cls.client = APIClient() + cls.client.force_authenticate(cls.user) + + # Test accessing a quantity with no operations performed + def test_get_operation_tree_none(self): + self.ref_quantity.delete() + request = self.client.get("/v1/data/set/1/") + self.ref_quantity = ReferenceQuantity.objects.create( + id=1, + value=5, + variance=0, + units="none", + hash=111, + derived_quantity=self.quantity, + ) + self.assertEqual(request.status_code, status.HTTP_200_OK) + self.assertEqual( + request.data["data_contents"][0], + { + "label": "test", + "value": 0, + "variance": 0, + "units": "none", + "hash": 1, + "history": { + "operation_tree": None, + "references": [], + }, + }, + ) + + # Test accessing quantity with unary operation + def test_get_operation_tree_unary(self): + inv = OperationTree.objects.create( + id=3, + operation="reciprocal", + quantity=self.quantity, + ) + self.variable.label = "a" + self.variable.child_operation = inv + self.variable.save() + request = self.client.get("/v1/data/set/1/") + self.variable.child_operation = None + self.variable.save() + inv.delete() + self.assertEqual(request.status_code, status.HTTP_200_OK) + self.assertEqual( + request.data["data_contents"][0], + { + "label": "test", + "value": 0, + "variance": 0, + "units": "none", + "hash": 1, + "history": { + "operation_tree": { + "operation": "reciprocal", + "parameters": { + "a": { + "operation": "variable", + "parameters": {"hash_value": 111, "name": "x"}, + } + }, + }, + "references": [ + { + "value": 5, + "variance": 0, + "units": "none", + "hash": 111, + } + ], + }, + }, + ) + + # Test accessing quantity with binary operation + def test_get_operation_tree_binary(self): + add = OperationTree.objects.create( + id=3, + operation="add", + quantity=self.quantity, + ) + self.variable.label = "a" + self.variable.child_operation = add + self.variable.save() + self.constant.label = "b" + self.constant.child_operation = add + self.constant.save() + request = self.client.get("/v1/data/set/1/") + self.variable.child_operation = None + self.constant.child_operation = None + self.variable.save() + self.constant.save() + add.delete() + self.assertEqual(request.status_code, status.HTTP_200_OK) + self.assertEqual( + request.data["data_contents"][0]["history"]["operation_tree"], + { + "operation": "add", + "parameters": { + "a": { + "operation": "variable", + "parameters": {"hash_value": 111, "name": "x"}, + }, + "b": { + "operation": "constant", + "parameters": {"value": 1}, + }, + }, + }, + ) + + # Test accessing a quantity with exponent + def test_get_operation_tree_pow(self): + power = OperationTree.objects.create( + id=3, + operation="pow", + parameters={"power": 2}, + quantity=self.quantity, + ) + self.variable.label = "a" + self.variable.child_operation = power + self.variable.save() + request = self.client.get("/v1/data/set/1/") + self.variable.child_operation = None + self.variable.save() + power.delete() + self.assertEqual(request.status_code, status.HTTP_200_OK) + self.assertEqual( + request.data["data_contents"][0]["history"]["operation_tree"], + { + "operation": "pow", + "parameters": { + "a": { + "operation": "variable", + "parameters": {"hash_value": 111, "name": "x"}, + }, + "power": 2, + }, + }, + ) + + # Test accessing a quantity with multiple operations + def test_get_operation_tree_nested(self): + neg = OperationTree.objects.create( + id=4, operation="neg", quantity=self.quantity + ) + multiply = OperationTree.objects.create( + id=3, operation="mul", child_operation=neg, label="a" + ) + self.constant.label = "a" + self.constant.child_operation = multiply + self.constant.save() + self.variable.label = "b" + self.variable.child_operation = multiply + self.variable.save() + request = self.client.get("/v1/data/set/1/") + self.constant.child_operation = None + self.variable.child_operation = None + self.constant.save() + self.variable.save() + neg.delete() + self.assertEqual(request.status_code, status.HTTP_200_OK) + self.assertEqual( + request.data["data_contents"][0]["history"]["operation_tree"], + { + "operation": "neg", + "parameters": { + "a": { + "operation": "mul", + "parameters": { + "a": { + "operation": "constant", + "parameters": {"value": 1}, + }, + "b": { + "operation": "variable", + "parameters": { + "hash_value": 111, + "name": "x", + }, + }, + }, + } + }, + }, + ) + + # Test accessing a transposed quantity + def test_get_operation_tree_transpose(self): + trans = OperationTree.objects.create( + id=3, + operation="transpose", + parameters={"axes": (1, 0)}, + quantity=self.quantity, + ) + self.variable.label = "a" + self.variable.child_operation = trans + self.variable.save() + request = self.client.get("/v1/data/set/1/") + self.variable.child_operation = None + self.variable.save() + trans.delete() + self.assertEqual(request.status_code, status.HTTP_200_OK) + self.assertEqual( + request.data["data_contents"][0]["history"]["operation_tree"], + { + "operation": "transpose", + "parameters": { + "a": { + "operation": "variable", + "parameters": {"hash_value": 111, "name": "x"}, + }, + "axes": [1, 0], + }, + }, + ) + + # Test accessing a quantity with tensordot + def test_get_operation_tree_tensordot(self): + tensor = OperationTree.objects.create( + id=3, + operation="tensor_product", + parameters={"a_index": 1, "b_index": 1}, + quantity=self.quantity, + ) + self.variable.label = "a" + self.variable.child_operation = tensor + self.variable.save() + self.constant.label = "b" + self.constant.child_operation = tensor + self.constant.save() + request = self.client.get("/v1/data/set/1/") + self.variable.child_operation = None + self.constant.child_operation = None + self.variable.save() + self.constant.save() + tensor.delete() + self.assertEqual(request.status_code, status.HTTP_200_OK) + self.assertEqual( + request.data["data_contents"][0]["history"]["operation_tree"], + { + "operation": "tensor_product", + "parameters": { + "a": { + "operation": "variable", + "parameters": {"hash_value": 111, "name": "x"}, + }, + "b": { + "operation": "constant", + "parameters": {"value": 1}, + }, + "a_index": 1, + "b_index": 1, + }, + }, + ) + + @classmethod + def tearDownClass(cls): + cls.user.delete() + cls.quantity.delete() + cls.dataset.delete() + cls.variable.delete() + cls.constant.delete() diff --git a/sasdata/fair_database/data/test/test_published_state.py b/sasdata/fair_database/data/test/test_published_state.py new file mode 100644 index 000000000..20072f3b1 --- /dev/null +++ b/sasdata/fair_database/data/test/test_published_state.py @@ -0,0 +1,582 @@ +from data.models import PublishedState, Session +from django.contrib.auth.models import User +from django.db.models import Max +from rest_framework import status +from rest_framework.test import APIClient, APITestCase + + +# TODO: account for non-placeholder doi +# Get the placeholder DOI for a session based on id +def doi_generator(id: int): + return "http://127.0.0.1:8000/v1/data/session/" + str(id) + "/" + + +class TestPublishedState(APITestCase): + """Test HTTP methods of PublishedStateView.""" + + @classmethod + def setUpTestData(cls): + cls.user1 = User.objects.create_user( + id=1, username="testUser1", password="secret" + ) + cls.user2 = User.objects.create_user( + id=2, username="testUser2", password="secret" + ) + cls.public_session = Session.objects.create( + id=1, current_user=cls.user1, title="Public Session", is_public=True + ) + cls.private_session = Session.objects.create( + id=2, current_user=cls.user1, title="Private Session", is_public=False + ) + cls.unowned_session = Session.objects.create( + id=3, title="Unowned Session", is_public=True + ) + cls.unpublished_session = Session.objects.create( + id=4, current_user=cls.user1, title="Publishable Session", is_public=True + ) + cls.public_ps = PublishedState.objects.create( + id=1, + doi=doi_generator(1), + published=True, + session=cls.public_session, + ) + cls.private_ps = PublishedState.objects.create( + id=2, + doi=doi_generator(2), + published=False, + session=cls.private_session, + ) + cls.unowned_ps = PublishedState.objects.create( + id=3, + doi=doi_generator(3), + published=True, + session=cls.unowned_session, + ) + cls.auth_client1 = APIClient() + cls.auth_client2 = APIClient() + cls.auth_client1.force_authenticate(cls.user1) + cls.auth_client2.force_authenticate(cls.user2) + + # Test listing published states including those of owned private sessions + def test_list_published_states_private(self): + request = self.auth_client1.get("/v1/data/published/") + self.assertEqual(request.status_code, status.HTTP_200_OK) + self.assertEqual( + request.data, + { + "published_state_ids": { + 1: { + "title": "Public Session", + "published": True, + "doi": doi_generator(1), + }, + 2: { + "title": "Private Session", + "published": False, + "doi": doi_generator(2), + }, + 3: { + "title": "Unowned Session", + "published": True, + "doi": doi_generator(3), + }, + } + }, + ) + + # Test listing published states of public sessions + def test_list_published_states_public(self): + request = self.auth_client2.get("/v1/data/published/") + self.assertEqual(request.status_code, status.HTTP_200_OK) + self.assertEqual( + request.data, + { + "published_state_ids": { + 1: { + "title": "Public Session", + "published": True, + "doi": doi_generator(1), + }, + 3: { + "title": "Unowned Session", + "published": True, + "doi": doi_generator(3), + }, + } + }, + ) + + # Test listing published states including sessions with access granted + def test_list_published_states_shared(self): + self.private_session.users.add(self.user2) + request = self.auth_client2.get("/v1/data/published/") + self.private_session.users.remove(self.user2) + self.assertEqual(request.status_code, status.HTTP_200_OK) + self.assertEqual( + request.data, + { + "published_state_ids": { + 1: { + "title": "Public Session", + "published": True, + "doi": doi_generator(1), + }, + 2: { + "title": "Private Session", + "published": False, + "doi": doi_generator(2), + }, + 3: { + "title": "Unowned Session", + "published": True, + "doi": doi_generator(3), + }, + } + }, + ) + + # Test listing published states while unauthenticated + def test_list_published_states_unauthenticated(self): + request = self.client.get("/v1/data/published/") + self.assertEqual(request.status_code, status.HTTP_200_OK) + self.assertEqual( + request.data, + { + "published_state_ids": { + 1: { + "title": "Public Session", + "published": True, + "doi": doi_generator(1), + }, + 3: { + "title": "Unowned Session", + "published": True, + "doi": doi_generator(3), + }, + } + }, + ) + + # Test listing a user's own published states + def test_list_user_published_states_private(self): + request = self.auth_client1.get( + "/v1/data/published/", data={"username": "testUser1"} + ) + self.assertEqual(request.status_code, status.HTTP_200_OK) + self.assertEqual( + request.data, + { + "published_state_ids": { + 1: { + "title": "Public Session", + "published": True, + "doi": doi_generator(1), + }, + 2: { + "title": "Private Session", + "published": False, + "doi": doi_generator(2), + }, + } + }, + ) + + # Test listing another user's published states + def test_list_user_published_states_public(self): + request = self.auth_client2.get( + "/v1/data/published/", data={"username": "testUser1"} + ) + self.assertEqual(request.status_code, status.HTTP_200_OK) + self.assertEqual( + request.data, + { + "published_state_ids": { + 1: { + "title": "Public Session", + "published": True, + "doi": doi_generator(1), + } + } + }, + ) + + # Test listing another user's published states with access granted + def test_list_user_published_states_shared(self): + self.private_session.users.add(self.user2) + request = self.auth_client2.get( + "/v1/data/published/", data={"username": "testUser1"} + ) + self.private_session.users.remove(self.user2) + self.assertEqual(request.status_code, status.HTTP_200_OK) + self.assertEqual( + request.data, + { + "published_state_ids": { + 1: { + "title": "Public Session", + "published": True, + "doi": doi_generator(1), + }, + 2: { + "title": "Private Session", + "published": False, + "doi": doi_generator(2), + }, + } + }, + ) + + # Test listing a user's published states while unauthenticated + def test_list_user_published_states_unauthenticated(self): + request = self.client.get("/v1/data/published/", data={"username": "testUser1"}) + self.assertEqual(request.status_code, status.HTTP_200_OK) + self.assertEqual( + request.data, + { + "published_state_ids": { + 1: { + "title": "Public Session", + "published": True, + "doi": doi_generator(1), + } + } + }, + ) + + # Test creating a published state for a private session + def test_published_state_created_private(self): + self.unpublished_session.is_public = False + self.unpublished_session.save() + published_state = {"published": True, "session": 4} + request = self.auth_client1.post("/v1/data/published/", data=published_state) + max_id = PublishedState.objects.aggregate(Max("id"))["id__max"] + new_ps = PublishedState.objects.get(id=max_id) + self.publishable_session = Session.objects.get(id=4) + self.assertEqual(request.status_code, status.HTTP_201_CREATED) + self.assertEqual( + request.data, + { + "published_state_id": max_id, + "session_id": 4, + "title": "Publishable Session", + "doi": doi_generator(4), + "published": True, + "current_user": "testUser1", + "is_public": False, + }, + ) + self.assertEqual(self.publishable_session.published_state, new_ps) + self.assertEqual(new_ps.session, self.publishable_session) + new_ps.delete() + self.unpublished_session.is_public = True + self.unpublished_session.save() + + # Test creating a published state for a public session + def test_published_state_created_public(self): + published_state = {"published": False, "session": 4} + request = self.auth_client1.post("/v1/data/published/", data=published_state) + max_id = PublishedState.objects.aggregate(Max("id"))["id__max"] + new_ps = PublishedState.objects.get(id=max_id) + self.publishable_session = Session.objects.get(id=4) + self.assertEqual(request.status_code, status.HTTP_201_CREATED) + self.assertEqual( + request.data, + { + "published_state_id": max_id, + "session_id": 4, + "title": "Publishable Session", + "doi": doi_generator(4), + "published": False, + "current_user": "testUser1", + "is_public": True, + }, + ) + self.assertEqual(self.publishable_session.published_state, new_ps) + self.assertEqual(new_ps.session, self.publishable_session) + new_ps.delete() + + # Test that you can't create a published state for an unowned session + def test_published_state_created_unowned(self): + self.unpublished_session.current_user = None + self.unpublished_session.save() + published_state = {"published": True, "session": 4} + request = self.auth_client1.post("/v1/data/published/", data=published_state) + self.assertEqual(request.status_code, status.HTTP_403_FORBIDDEN) + self.assertEqual(len(PublishedState.objects.all()), 3) + self.unpublished_session.current_user = self.user1 + self.unpublished_session.save() + + # Test that an unauthenticated user cannot create a published state + def test_published_state_created_unauthenticated(self): + published_state = {"published": True, "session": 4} + request = self.client.post("/v1/data/published/", data=published_state) + self.assertEqual(request.status_code, status.HTTP_401_UNAUTHORIZED) + self.assertEqual(len(PublishedState.objects.all()), 3) + + # Test that a user cannot create a published state for a session they don't own + def test_published_state_created_unauthorized(self): + published_state = {"published": True, "session": 4} + request = self.auth_client2.post("/v1/data/published/", data=published_state) + self.assertEqual(request.status_code, status.HTTP_403_FORBIDDEN) + self.assertEqual(len(PublishedState.objects.all()), 3) + + # Test that only one published state can be created per session + def test_no_duplicate_published_states(self): + published_state = {"published": True, "session": 1} + request = self.auth_client1.post("/v1/data/published/", data=published_state) + self.assertEqual(request.status_code, status.HTTP_400_BAD_REQUEST) + + @classmethod + def tearDownClass(cls): + cls.public_session.delete() + cls.private_session.delete() + cls.unowned_session.delete() + cls.user1.delete() + cls.user2.delete() + + +class TestSinglePublishedState(APITestCase): + """Test HTTP methods of SinglePublishedStateView.""" + + @classmethod + def setUpTestData(cls): + cls.user1 = User.objects.create_user( + id=1, username="testUser1", password="secret" + ) + cls.user2 = User.objects.create_user( + id=2, username="testUser2", password="secret" + ) + cls.public_session = Session.objects.create( + id=1, current_user=cls.user1, title="Public Session", is_public=True + ) + cls.private_session = Session.objects.create( + id=2, current_user=cls.user1, title="Private Session", is_public=False + ) + cls.unowned_session = Session.objects.create( + id=3, title="Unowned Session", is_public=True + ) + cls.public_ps = PublishedState.objects.create( + id=1, + doi=doi_generator(1), + published=True, + session=cls.public_session, + ) + cls.private_ps = PublishedState.objects.create( + id=2, + doi=doi_generator(2), + published=False, + session=cls.private_session, + ) + cls.unowned_ps = PublishedState.objects.create( + id=3, + doi=doi_generator(3), + published=True, + session=cls.unowned_session, + ) + cls.auth_client1 = APIClient() + cls.auth_client2 = APIClient() + cls.auth_client1.force_authenticate(cls.user1) + cls.auth_client2.force_authenticate(cls.user2) + + # Test viewing a published state of a public session + def test_get_public_published_state(self): + request1 = self.auth_client2.get("/v1/data/published/1/") + request2 = self.client.get("/v1/data/published/1/") + self.assertEqual(request1.status_code, status.HTTP_200_OK) + self.assertEqual(request2.status_code, status.HTTP_200_OK) + self.assertEqual( + request1.data, + { + "id": 1, + "doi": doi_generator(1), + "published": True, + "session": 1, + "title": "Public Session", + "current_user": "testUser1", + "is_public": True, + }, + ) + self.assertEqual(request1.data, request2.data) + + # Test viewing a published state of a private session + def test_get_private_published_state(self): + request = self.auth_client1.get("/v1/data/published/2/") + self.assertEqual(request.status_code, status.HTTP_200_OK) + self.assertEqual( + request.data, + { + "id": 2, + "doi": doi_generator(2), + "published": False, + "session": 2, + "title": "Private Session", + "current_user": "testUser1", + "is_public": False, + }, + ) + + # Test viewing a published state of an unowned session + def test_get_unowned_published_state(self): + request = self.auth_client1.get("/v1/data/published/3/") + self.assertEqual(request.status_code, status.HTTP_200_OK) + self.assertEqual( + request.data, + { + "id": 3, + "doi": doi_generator(3), + "published": True, + "session": 3, + "title": "Unowned Session", + "current_user": "", + "is_public": True, + }, + ) + + # Test viewing a published state of a session with access granted + def test_get_shared_published_state(self): + self.private_session.users.add(self.user2) + request = self.auth_client2.get("/v1/data/published/2/") + self.private_session.users.remove(self.user2) + self.assertEqual(request.status_code, status.HTTP_200_OK) + self.assertEqual( + request.data, + { + "id": 2, + "doi": doi_generator(2), + "published": False, + "session": 2, + "title": "Private Session", + "current_user": "testUser1", + "is_public": False, + }, + ) + + # Test a user can't view a published state of a private session they don't own + def test_get_private_published_state_unauthorized(self): + request1 = self.client.get("/v1/data/published/2/") + request2 = self.auth_client2.get("/v1/data/published/2/") + self.assertEqual(request1.status_code, status.HTTP_401_UNAUTHORIZED) + self.assertEqual(request2.status_code, status.HTTP_403_FORBIDDEN) + + # Test updating a published state of a public session + def test_update_public_published_state(self): + request = self.auth_client1.put( + "/v1/data/published/1/", data={"published": False} + ) + self.assertEqual(request.status_code, status.HTTP_200_OK) + self.assertEqual( + request.data, + { + "published_state_id": 1, + "session_id": 1, + "title": "Public Session", + "published": False, + "is_public": True, + }, + ) + self.assertFalse(PublishedState.objects.get(id=1).published) + self.public_ps.save() + + # Test updating a published state of a private session + def test_update_private_published_state(self): + request = self.auth_client1.put( + "/v1/data/published/2/", data={"published": True} + ) + self.assertEqual(request.status_code, status.HTTP_200_OK) + self.assertEqual( + request.data, + { + "published_state_id": 2, + "session_id": 2, + "title": "Private Session", + "published": True, + "is_public": False, + }, + ) + self.assertTrue(PublishedState.objects.get(id=2).published) + self.private_ps.save() + + # Test a user can't update the published state of an unowned session + def test_update_unowned_published_state(self): + request1 = self.auth_client1.put( + "/v1/data/published/3/", data={"published": False} + ) + request2 = self.client.put("/v1/data/published/3/", data={"published": False}) + self.assertEqual(request1.status_code, status.HTTP_403_FORBIDDEN) + self.assertEqual(request2.status_code, status.HTTP_401_UNAUTHORIZED) + self.assertTrue(PublishedState.objects.get(id=3).published) + + # Test a user can't update a public published state unauthorized + def test_update_public_published_state_unauthorized(self): + request1 = self.auth_client2.put( + "/v1/data/published/1/", data={"published": False} + ) + self.public_session.users.add(self.user2) + request2 = self.auth_client2.put( + "/v1/data/published/1/", data={"published": False} + ) + self.public_session.users.remove(self.user2) + request3 = self.client.put("/v1/data/published/1/", data={"published": False}) + self.assertEqual(request1.status_code, status.HTTP_403_FORBIDDEN) + self.assertEqual(request2.status_code, status.HTTP_403_FORBIDDEN) + self.assertEqual(request3.status_code, status.HTTP_401_UNAUTHORIZED) + self.assertTrue(PublishedState.objects.get(id=1).published) + + # Test a user can't update a private published state unauthorized + def test_update_private_published_state_unauthorized(self): + request1 = self.auth_client2.put( + "/v1/data/published/2/", data={"published": True} + ) + self.public_session.users.add(self.user2) + request2 = self.auth_client2.put( + "/v1/data/published/2/", data={"published": True} + ) + self.public_session.users.remove(self.user2) + request3 = self.client.put("/v1/data/published/2/", data={"published": True}) + self.assertEqual(request1.status_code, status.HTTP_403_FORBIDDEN) + self.assertEqual(request2.status_code, status.HTTP_403_FORBIDDEN) + self.assertEqual(request3.status_code, status.HTTP_401_UNAUTHORIZED) + self.assertFalse(PublishedState.objects.get(id=2).published) + + # Test deleting a published state of a private session + def test_delete_private_published_state(self): + request = self.auth_client1.delete("/v1/data/published/2/") + self.assertEqual(request.status_code, status.HTTP_200_OK) + self.assertEqual(len(PublishedState.objects.all()), 2) + self.assertEqual(len(Session.objects.all()), 3) + self.assertRaises(PublishedState.DoesNotExist, PublishedState.objects.get, id=2) + self.private_ps = PublishedState.objects.create( + id=2, + doi=doi_generator(2), + published=False, + session=self.private_session, + ) + + # Test a user can't delete a private published state unauthorized + def test_delete_private_published_state_unauthorized(self): + request1 = self.auth_client2.delete("/v1/data/published/2/") + self.private_session.users.add(self.user2) + request2 = self.auth_client2.delete("/v1/data/published/2/") + self.private_session.users.remove(self.user2) + request3 = self.client.delete("/v1/data/published/2/") + self.assertEqual(request1.status_code, status.HTTP_403_FORBIDDEN) + self.assertEqual(request2.status_code, status.HTTP_403_FORBIDDEN) + self.assertEqual(request3.status_code, status.HTTP_401_UNAUTHORIZED) + + # Test a user can't delete a published state of a public + def test_cant_delete_public_published_state(self): + request = self.auth_client1.delete("/v1/data/published/1/") + self.assertEqual(request.status_code, status.HTTP_403_FORBIDDEN) + + # Test a user can't delete an unowned published state + def test_delete_unowned_published_state(self): + request = self.auth_client1.delete("/v1/data/published/3/") + self.assertEqual(request.status_code, status.HTTP_403_FORBIDDEN) + + @classmethod + def tearDownClass(cls): + cls.public_session.delete() + cls.private_session.delete() + cls.unowned_session.delete() + cls.user1.delete() + cls.user2.delete() diff --git a/sasdata/fair_database/data/test/test_session.py b/sasdata/fair_database/data/test/test_session.py new file mode 100644 index 000000000..fc185f8fd --- /dev/null +++ b/sasdata/fair_database/data/test/test_session.py @@ -0,0 +1,700 @@ +from data.models import DataSet, PublishedState, Session +from django.contrib.auth.models import User +from django.db.models import Max +from rest_framework import status +from rest_framework.test import APIClient, APITestCase + + +class TestSession(APITestCase): + """Test HTTP methods of SessionView.""" + + @classmethod + def setUpTestData(cls): + cls.user1 = User.objects.create_user( + id=1, username="testUser1", password="secret" + ) + cls.user2 = User.objects.create_user( + id=2, username="testUser2", password="secret" + ) + cls.public_session = Session.objects.create( + id=1, current_user=cls.user1, title="Public Session", is_public=True + ) + cls.private_session = Session.objects.create( + id=2, current_user=cls.user1, title="Private Session", is_public=False + ) + cls.unowned_session = Session.objects.create( + id=3, title="Unowned Session", is_public=True + ) + cls.auth_client1 = APIClient() + cls.auth_client2 = APIClient() + cls.auth_client1.force_authenticate(cls.user1) + cls.auth_client2.force_authenticate(cls.user2) + + # Test listing sessions + def test_list_private(self): + request = self.auth_client1.get("/v1/data/session/") + self.assertEqual(request.status_code, status.HTTP_200_OK) + self.assertEqual( + request.data, + { + "session_ids": { + 1: "Public Session", + 2: "Private Session", + 3: "Unowned Session", + } + }, + ) + + # Test listing public sessions + def test_list_public(self): + request = self.auth_client2.get("/v1/data/session/") + self.assertEqual(request.status_code, status.HTTP_200_OK) + self.assertEqual( + request.data, {"session_ids": {1: "Public Session", 3: "Unowned Session"}} + ) + + # Test listing sessions while unauthenticated + def test_list_unauthenticated(self): + request = self.client.get("/v1/data/session/") + self.assertEqual(request.status_code, status.HTTP_200_OK) + self.assertEqual( + request.data, {"session_ids": {1: "Public Session", 3: "Unowned Session"}} + ) + + # Test listing a session with access granted + def test_list_granted_access(self): + self.private_session.users.add(self.user2) + request = self.auth_client2.get("/v1/data/session/") + self.assertEqual(request.status_code, status.HTTP_200_OK) + self.assertEqual( + request.data, + { + "session_ids": { + 1: "Public Session", + 2: "Private Session", + 3: "Unowned Session", + } + }, + ) + self.private_session.users.remove(self.user2) + + # Test listing by username + def test_list_username(self): + request = self.auth_client1.get( + "/v1/data/session/", data={"username": "testUser1"} + ) + self.assertEqual(request.status_code, status.HTTP_200_OK) + self.assertEqual( + request.data, {"session_ids": {1: "Public Session", 2: "Private Session"}} + ) + + # Test listing by another user's username + def test_list_other_username(self): + request = self.auth_client2.get( + "/v1/data/session/", data={"username": "testUser1"} + ) + self.assertEqual(request.status_code, status.HTTP_200_OK) + self.assertEqual(request.data, {"session_ids": {1: "Public Session"}}) + + # Test creating a public session + def test_session_created(self): + session = { + "title": "New session", + "datasets": [ + { + "name": "New dataset", + "metadata": { + "title": "New metadata", + "run": 0, + "description": "test", + "instrument": {}, + "process": {}, + "sample": {}, + }, + "data_contents": [], + } + ], + "is_public": True, + "published_state": {"published": False}, + } + request = self.auth_client1.post( + "/v1/data/session/", data=session, format="json" + ) + max_id = Session.objects.aggregate(Max("id"))["id__max"] + new_session = Session.objects.get(id=max_id) + new_dataset = new_session.datasets.get() + new_metadata = new_dataset.metadata + new_published_state = new_session.published_state + self.assertEqual(request.status_code, status.HTTP_201_CREATED) + self.assertEqual( + request.data, + { + "session_id": max_id, + "title": "New session", + "authenticated": True, + "current_user": "testUser1", + "is_public": True, + }, + ) + self.assertEqual(new_session.title, "New session") + self.assertEqual(new_dataset.name, "New dataset") + self.assertEqual(new_metadata.title, "New metadata") + self.assertEqual(new_session.current_user, self.user1) + self.assertEqual(new_dataset.current_user, self.user1) + self.assertTrue(all([new_session.is_public, new_dataset.is_public])) + self.assertFalse(new_published_state.published) + new_session.delete() + + # Test creating a private session + def test_session_created_private(self): + session = { + "title": "New session", + "datasets": [ + { + "name": "New dataset", + "metadata": { + "title": "New metadata", + "run": 0, + "description": "test", + "instrument": {}, + "process": {}, + "sample": {}, + }, + "data_contents": [], + } + ], + "is_public": False, + } + request = self.auth_client1.post( + "/v1/data/session/", data=session, format="json" + ) + max_id = Session.objects.aggregate(Max("id"))["id__max"] + new_session = Session.objects.get(id=max_id) + new_dataset = new_session.datasets.get() + self.assertEqual(request.status_code, status.HTTP_201_CREATED) + self.assertEqual( + request.data, + { + "session_id": max_id, + "title": "New session", + "authenticated": True, + "current_user": "testUser1", + "is_public": False, + }, + ) + self.assertEqual(new_session.current_user, self.user1) + self.assertEqual(new_dataset.current_user, self.user1) + self.assertFalse(any([new_session.is_public, new_dataset.is_public])) + new_session.delete() + + # Test creating a session while unauthenticated + def test_session_created_unauthenticated(self): + session = { + "title": "New session", + "datasets": [ + { + "name": "New dataset", + "metadata": { + "title": "New metadata", + "run": 0, + "description": "test", + "instrument": {}, + "process": {}, + "sample": {}, + }, + "data_contents": [], + } + ], + "is_public": True, + } + request = self.client.post("/v1/data/session/", data=session, format="json") + max_id = Session.objects.aggregate(Max("id"))["id__max"] + new_session = Session.objects.get(id=max_id) + new_dataset = new_session.datasets.get() + self.assertEqual(request.status_code, status.HTTP_201_CREATED) + self.assertEqual( + request.data, + { + "session_id": max_id, + "title": "New session", + "authenticated": False, + "current_user": "", + "is_public": True, + }, + ) + self.assertIsNone(new_session.current_user) + self.assertIsNone(new_dataset.current_user) + self.assertTrue(all([new_session.is_public, new_dataset.is_public])) + new_session.delete() + + # Test that a private session must have an owner + def test_no_private_unowned_session(self): + session = {"title": "New session", "datasets": [], "is_public": False} + request = self.client.post("/v1/data/session/", data=session, format="json") + self.assertEqual(request.status_code, status.HTTP_400_BAD_REQUEST) + + # Test post fails with dataset validation issue + def test_no_session_invalid_dataset(self): + session = { + "title": "New session", + "datasets": [ + { + "metadata": { + "title": "New metadata", + "run": 0, + "description": "test", + "instrument": {}, + "process": {}, + "sample": {}, + }, + "data_contents": [], + } + ], + "is_public": True, + } + request = self.auth_client1.post( + "/v1/data/session/", data=session, format="json" + ) + self.assertEqual(request.status_code, status.HTTP_400_BAD_REQUEST) + self.assertEqual(len(Session.objects.all()), 3) + self.assertEqual(len(DataSet.objects.all()), 0) + + @classmethod + def tearDownClass(cls): + cls.public_session.delete() + cls.private_session.delete() + cls.unowned_session.delete() + cls.user1.delete() + cls.user2.delete() + + +class TestSingleSession(APITestCase): + """Test HTTP methods of SingleSessionView.""" + + @classmethod + def setUpTestData(cls): + cls.user1 = User.objects.create_user( + id=1, username="testUser1", password="secret" + ) + cls.user2 = User.objects.create_user( + id=2, username="testUser2", password="secret" + ) + cls.public_session = Session.objects.create( + id=1, current_user=cls.user1, title="Public Session", is_public=True + ) + cls.private_session = Session.objects.create( + id=2, current_user=cls.user1, title="Private Session", is_public=False + ) + cls.unowned_session = Session.objects.create( + id=3, title="Unowned Session", is_public=True + ) + cls.public_dataset = DataSet.objects.create( + id=1, + current_user=cls.user1, + is_public=True, + name="Public Dataset", + session=cls.public_session, + ) + cls.private_dataset = DataSet.objects.create( + id=2, + current_user=cls.user1, + name="Private Dataset", + session=cls.private_session, + ) + cls.unowned_dataset = DataSet.objects.create( + id=3, is_public=True, name="Unowned Dataset", session=cls.unowned_session + ) + cls.private_published_state = PublishedState.objects.create( + id=2, + session=cls.private_session, + published=False, + doi="http://localhost:8000/v1/data/session/2/", + ) + cls.auth_client1 = APIClient() + cls.auth_client2 = APIClient() + cls.auth_client1.force_authenticate(cls.user1) + cls.auth_client2.force_authenticate(cls.user2) + + # Test loading another user's public session + def test_get_public_session(self): + request = self.auth_client2.get("/v1/data/session/1/") + self.assertEqual(request.status_code, status.HTTP_200_OK) + self.assertEqual( + request.data, + { + "id": 1, + "current_user": "testUser1", + "users": [], + "is_public": True, + "title": "Public Session", + "datasets": [ + { + "id": 1, + "current_user": 1, + "users": [], + "is_public": True, + "name": "Public Dataset", + "files": [], + "metadata": None, + "data_contents": [], + } + ], + "published_state": None, + }, + ) + + # Test loading a private session as the owner + def test_get_private_session(self): + request = self.auth_client1.get("/v1/data/session/2/") + self.assertEqual(request.status_code, status.HTTP_200_OK) + self.assertEqual( + request.data, + { + "id": 2, + "current_user": "testUser1", + "users": [], + "is_public": False, + "title": "Private Session", + "published_state": { + "id": 2, + "published": False, + "doi": "http://localhost:8000/v1/data/session/2/", + "session": 2, + }, + "datasets": [ + { + "id": 2, + "current_user": 1, + "users": [], + "is_public": False, + "name": "Private Dataset", + "files": [], + "metadata": None, + "data_contents": [], + } + ], + }, + ) + + # Test loading a private session as a user with granted access + def test_get_private_session_access_granted(self): + self.private_session.users.add(self.user2) + request = self.auth_client2.get("/v1/data/session/2/") + self.assertEqual(request.status_code, status.HTTP_200_OK) + self.private_session.users.remove(self.user2) + + # Test loading an unowned session + def test_get_unowned_session(self): + request = self.auth_client1.get("/v1/data/session/3/") + self.assertEqual(request.status_code, status.HTTP_200_OK) + self.assertEqual( + request.data, + { + "id": 3, + "current_user": None, + "users": [], + "is_public": True, + "title": "Unowned Session", + "published_state": None, + "datasets": [ + { + "id": 3, + "current_user": None, + "users": [], + "is_public": True, + "name": "Unowned Dataset", + "files": [], + "metadata": None, + "data_contents": [], + } + ], + }, + ) + + # Test loading another user's private session + def test_get_private_session_unauthorized(self): + request1 = self.auth_client2.get("/v1/data/session/2/") + request2 = self.client.get("/v1/data/session/2/") + self.assertEqual(request1.status_code, status.HTTP_403_FORBIDDEN) + self.assertEqual(request2.status_code, status.HTTP_401_UNAUTHORIZED) + + # Test updating a public session + def test_update_public_session(self): + request = self.auth_client1.put( + "/v1/data/session/1/", data={"is_public": False} + ) + session = Session.objects.get(id=1) + self.assertEqual(request.status_code, status.HTTP_200_OK) + self.assertEqual( + request.data, + {"session_id": 1, "title": "Public Session", "is_public": False}, + ) + self.assertFalse(session.is_public) + session.is_public = False + session.save() + + # Test creating a published state by updating a session + def test_update_session_new_published_state(self): + request = self.auth_client1.put( + "/v1/data/session/1/", + data={"published_state": {"published": False}}, + format="json", + ) + new_published_state = Session.objects.get(id=1).published_state + self.assertEqual(request.status_code, status.HTTP_200_OK) + self.assertFalse(new_published_state.published) + new_published_state.delete() + + # Test that another user's public session cannot be updated + def test_update_public_session_unauthorized(self): + request1 = self.auth_client2.put( + "/v1/data/session/1/", data={"is_public": False} + ) + request2 = self.client.put("/v1/data/session/1/", data={"is_public": False}) + session = Session.objects.get(id=1) + session.users.add(self.user2) + request3 = self.auth_client2.put( + "/v1/data/session/1/", data={"is_public": False} + ) + session.users.remove(self.user2) + self.assertEqual(request1.status_code, status.HTTP_403_FORBIDDEN) + self.assertEqual(request2.status_code, status.HTTP_401_UNAUTHORIZED) + self.assertEqual(request3.status_code, status.HTTP_403_FORBIDDEN) + self.assertTrue(Session.objects.get(id=1).is_public) + + # Test updating a private session + def test_update_private_session(self): + request1 = self.auth_client1.put( + "/v1/data/session/2/", data={"is_public": True} + ) + session = Session.objects.get(id=2) + self.assertEqual(request1.status_code, status.HTTP_200_OK) + self.assertEqual( + request1.data, + {"session_id": 2, "title": "Private Session", "is_public": True}, + ) + self.assertTrue(session.is_public) + self.assertTrue(session.datasets.get().is_public) + session.is_public = False + session.save() + + # Test updating a published state through its session + def test_update_session_published_state(self): + request = self.auth_client1.put( + "/v1/data/session/2/", + data={"published_state": {"published": True}}, + format="json", + ) + self.assertEqual(request.status_code, status.HTTP_200_OK) + self.assertTrue(PublishedState.objects.get(id=2).published) + self.private_published_state.save() + + # Test that another user's private session cannot be updated + def test_update_private_session_unauthorized(self): + request1 = self.auth_client2.put( + "/v1/data/session/2/", data={"is_public": True} + ) + request2 = self.client.put("/v1/data/session/2/", data={"is_public": True}) + session = Session.objects.get(id=2) + session.users.add(self.user2) + request3 = self.auth_client2.put( + "/v1/data/session/2/", data={"is_public": True} + ) + session.users.remove(self.user2) + self.assertEqual(request1.status_code, status.HTTP_403_FORBIDDEN) + self.assertEqual(request2.status_code, status.HTTP_401_UNAUTHORIZED) + self.assertEqual(request3.status_code, status.HTTP_403_FORBIDDEN) + self.assertFalse(Session.objects.get(id=2).is_public) + + # Test that an unowned session cannot be updated + def test_update_unowned_session(self): + request = self.auth_client1.put( + "/v1/data/session/3/", data={"is_public": False} + ) + self.assertEqual(request.status_code, status.HTTP_403_FORBIDDEN) + self.assertTrue(Session.objects.get(id=3).is_public) + + # Test deleting a private session + def test_delete_private_session(self): + request = self.auth_client1.delete("/v1/data/session/2/") + self.assertEqual(request.status_code, status.HTTP_200_OK) + self.assertRaises(Session.DoesNotExist, Session.objects.get, id=2) + self.assertRaises(DataSet.DoesNotExist, DataSet.objects.get, id=2) + self.assertRaises(PublishedState.DoesNotExist, PublishedState.objects.get, id=2) + self.private_session = Session.objects.create( + id=2, current_user=self.user1, title="Private Session", is_public=False + ) + self.private_dataset = DataSet.objects.create( + id=2, + current_user=self.user1, + name="Private Dataset", + session=self.private_session, + ) + self.private_published_state = PublishedState.objects.create( + id=2, + session=self.private_session, + published=False, + doi="http://localhost:8000/v1/data/session/2/", + ) + + # Test that another user's private session cannot be deleted + def test_delete_private_session_unauthorized(self): + request1 = self.auth_client2.delete("/v1/data/session/2/") + request2 = self.client.delete("/v1/data/session/2/") + self.private_session.users.add(self.user2) + request3 = self.auth_client2.delete("/v1/data/session/2/") + self.private_session.users.remove(self.user2) + self.assertEqual(request1.status_code, status.HTTP_403_FORBIDDEN) + self.assertEqual(request2.status_code, status.HTTP_401_UNAUTHORIZED) + self.assertEqual(request3.status_code, status.HTTP_403_FORBIDDEN) + + # Test that a public session cannot be deleted + def test_delete_public_session(self): + request = self.auth_client1.delete("/v1/data/session/1/") + self.assertEqual(request.status_code, status.HTTP_403_FORBIDDEN) + + # Test that an unowned session cannot be deleted + def test_delete_unowned_session(self): + request = self.auth_client1.delete("/v1/data/session/3/") + self.assertEqual(request.status_code, status.HTTP_403_FORBIDDEN) + + @classmethod + def tearDownClass(cls): + cls.public_session.delete() + cls.private_session.delete() + cls.unowned_session.delete() + cls.user1.delete() + cls.user2.delete() + + +class TestSessionAccessManagement(APITestCase): + """Test HTTP methods of SessionUsersView.""" + + @classmethod + def setUpTestData(cls): + cls.user1 = User.objects.create_user(username="testUser1", password="secret") + cls.user2 = User.objects.create_user(username="testUser2", password="secret") + cls.private_session = Session.objects.create( + id=1, current_user=cls.user1, title="Private Session", is_public=False + ) + cls.shared_session = Session.objects.create( + id=2, current_user=cls.user1, title="Shared Session", is_public=False + ) + cls.private_dataset = DataSet.objects.create( + id=1, + current_user=cls.user1, + name="Private Dataset", + session=cls.private_session, + ) + cls.shared_dataset = DataSet.objects.create( + id=2, + current_user=cls.user1, + name="Shared Dataset", + session=cls.shared_session, + ) + cls.shared_session.users.add(cls.user2) + cls.shared_dataset.users.add(cls.user2) + cls.client_owner = APIClient() + cls.client_other = APIClient() + cls.client_owner.force_authenticate(cls.user1) + cls.client_other.force_authenticate(cls.user2) + + # Test listing access to an unshared session + def test_list_access_private(self): + request = self.client_owner.get("/v1/data/session/1/users/") + self.assertEqual(request.status_code, status.HTTP_200_OK) + self.assertEqual( + request.data, + { + "session_id": 1, + "title": "Private Session", + "is_public": False, + "users": [], + }, + ) + + # Test listing access to a shared session + def test_list_access_shared(self): + request = self.client_owner.get("/v1/data/session/2/users/") + self.assertEqual(request.status_code, status.HTTP_200_OK) + self.assertEqual( + request.data, + { + "session_id": 2, + "title": "Shared Session", + "is_public": False, + "users": ["testUser2"], + }, + ) + + # Test that only the owner can view access + def test_list_access_unauthorized(self): + request1 = self.client_other.get("/v1/data/session/1/users/") + request2 = self.client_other.get("/v1/data/session/2/users/") + self.assertEqual(request1.status_code, status.HTTP_403_FORBIDDEN) + self.assertEqual(request2.status_code, status.HTTP_403_FORBIDDEN) + + # Test granting access to a session + def test_grant_access(self): + request1 = self.client_owner.put( + "/v1/data/session/1/users/", {"username": "testUser2", "access": True} + ) + request2 = self.client_other.get("/v1/data/session/1/") + request3 = self.client_other.get("/v1/data/set/1/") + self.assertEqual(request1.status_code, status.HTTP_200_OK) + self.assertEqual( + request1.data, + { + "username": "testUser2", + "session_id": 1, + "title": "Private Session", + "access": True, + }, + ) + self.assertEqual(request2.status_code, status.HTTP_200_OK) + self.assertEqual(request3.status_code, status.HTTP_200_OK) + self.assertIn(self.user2, self.private_session.users.all()) # codespell:ignore + self.assertIn(self.user2, self.private_dataset.users.all()) # codespell:ignore + self.private_session.users.remove(self.user2) + self.private_dataset.users.remove(self.user2) + + # Test revoking access to a session + def test_revoke_access(self): + request1 = self.client_owner.put( + "/v1/data/session/2/users/", {"username": "testUser2", "access": False} + ) + request2 = self.client_other.get("/v1/data/session/2/") + request3 = self.client_other.get("/v1/data/session/2/") + self.assertEqual(request1.status_code, status.HTTP_200_OK) + self.assertEqual(request2.status_code, status.HTTP_403_FORBIDDEN) + self.assertEqual(request3.status_code, status.HTTP_403_FORBIDDEN) + self.assertEqual( + request1.data, + { + "username": "testUser2", + "session_id": 2, + "title": "Shared Session", + "access": False, + }, + ) + self.assertNotIn(self.user2, self.shared_session.users.all()) + self.assertNotIn(self.user2, self.shared_dataset.users.all()) + self.shared_session.users.add(self.user2) + self.shared_dataset.users.add(self.user2) + + # Test that only the owner can change access + def test_revoke_access_unauthorized(self): + request1 = self.client_other.put( + "/v1/data/session/2/users/", {"username": "testUser2", "access": False} + ) + request2 = self.client_other.get("/v1/data/session/2/") + self.assertEqual(request1.status_code, status.HTTP_403_FORBIDDEN) + self.assertEqual(request2.status_code, status.HTTP_200_OK) + self.assertIn(self.user2, self.shared_session.users.all()) # codespell:ignore + + @classmethod + def tearDownClass(cls): + cls.private_session.delete() + cls.shared_session.delete() + cls.user1.delete() + cls.user2.delete() diff --git a/sasdata/fair_database/data/urls.py b/sasdata/fair_database/data/urls.py new file mode 100644 index 000000000..0e94f60c7 --- /dev/null +++ b/sasdata/fair_database/data/urls.py @@ -0,0 +1,49 @@ +from django.urls import path + +from . import views + +urlpatterns = [ + path("file/", views.DataFileView.as_view(), name="view and create files"), + path( + "file//", + views.SingleDataFileView.as_view(), + name="view, download, modify, delete files", + ), + path( + "file//users/", + views.DataFileUsersView.as_view(), + name="manage access to files", + ), + path("set/", views.DataSetView.as_view(), name="view and create datasets"), + path( + "set//", + views.SingleDataSetView.as_view(), + name="load, modify, delete datasets", + ), + path( + "set//users/", + views.DataSetUsersView.as_view(), + name="manage access to datasets", + ), + path("session/", views.SessionView.as_view(), name="view and create sessions"), + path( + "session//", + views.SingleSessionView.as_view(), + name="load, modify, delete sessions", + ), + path( + "session//users/", + views.SessionUsersView.as_view(), + name="manage access to sessions", + ), + path( + "published/", + views.PublishedStateView.as_view(), + name="view and create published states", + ), + path( + "published//", + views.SinglePublishedStateView.as_view(), + name="load, modify, delete published states", + ), +] diff --git a/sasdata/fair_database/data/views.py b/sasdata/fair_database/data/views.py new file mode 100644 index 000000000..fc1c547c1 --- /dev/null +++ b/sasdata/fair_database/data/views.py @@ -0,0 +1,707 @@ +import json +import os + +from data.forms import DataFileForm +from data.models import DataFile, DataSet, PublishedState, Session +from data.serializers import ( + AccessManagementSerializer, + DataFileSerializer, + DataSetSerializer, + PublishedStateSerializer, + PublishedStateUpdateSerializer, + SessionSerializer, +) +from django.contrib.auth.models import User +from django.http import ( + FileResponse, + Http404, + HttpResponse, + HttpResponseBadRequest, + HttpResponseForbidden, +) +from django.shortcuts import get_object_or_404 +from drf_spectacular.utils import extend_schema +from fair_database import permissions +from fair_database.permissions import DataPermission +from rest_framework import status +from rest_framework.response import Response +from rest_framework.views import APIView + +from sasdata.dataloader.loader import Loader + + +class DataFileView(APIView): + """ + View associated with the DataFile model. + + Functionality for viewing a list of files and uploading a new file. + """ + + # List of datafiles + @extend_schema( + description="Retrieve a list of accessible data files by id and filename." + ) + def get(self, request, version=None): + if "username" in request.GET: + search_user = get_object_or_404(User, username=request.GET["username"]) + data_list = {"user_data_ids": {}} + private_data = DataFile.objects.filter(current_user=search_user) + for x in private_data: + if permissions.check_permissions(request, x): + data_list["user_data_ids"][x.id] = x.file_name + else: + public_data = DataFile.objects.all() + data_list = {"public_data_ids": {}} + for x in public_data: + if permissions.check_permissions(request, x): + data_list["public_data_ids"][x.id] = x.file_name + return Response(data_list) + + # Create a datafile + @extend_schema(description="Upload a data file.") + def post(self, request, version=None): + form = DataFileForm(request.data, request.FILES) + if form.is_valid(): + form.save() + db = DataFile.objects.get(pk=form.instance.pk) + serializer = DataFileSerializer( + db, + data={ + "file_name": os.path.basename(form.instance.file.path), + "current_user": None, + "users": [], + }, + context={"is_public": db.is_public}, + ) + if request.user.is_authenticated: + serializer.initial_data["current_user"] = request.user.id + + if serializer.is_valid(raise_exception=True): + serializer.save() + return_data = { + "current_user": request.user.username, + "authenticated": request.user.is_authenticated, + "file_id": db.id, + "file_alternative_name": serializer.data["file_name"], + "is_public": serializer.data["is_public"], + } + return Response(return_data, status=status.HTTP_201_CREATED) + + # Create a datafile + @extend_schema(description="Upload a data file.") + def put(self, request, version=None): + return self.post(request, version) + + +class SingleDataFileView(APIView): + """ + View associated with a single DataFile. + + Functionality for viewing, modifying, or deleting a DataFile. + """ + + # Load the contents of a datafile or download the file to a device + @extend_schema( + description="Retrieve the contents of a data file or download a file." + ) + def get(self, request, data_id, version=None): + data = get_object_or_404(DataFile, id=data_id) + if "download" in request.GET and request.GET["download"]: + if not permissions.check_permissions(request, data): + if not request.user.is_authenticated: + return HttpResponse("Must be authenticated to download", status=401) + return HttpResponseForbidden("data is private") + try: + file = open(data.file.path, "rb") + except Exception as e: + return HttpResponseBadRequest(str(e)) + if file is None: + raise Http404("File not found.") + return FileResponse(file, as_attachment=True) + else: + loader = Loader() + if not permissions.check_permissions(request, data): + if not request.user.is_authenticated: + return HttpResponse("Must be authenticated to view", status=401) + return HttpResponseForbidden( + "Data is either not public or wrong auth token" + ) + data_list = loader.load(data.file.path) + contents = [str(data) for data in data_list] + return_data = {data.file_name: contents} + return Response(return_data) + + # Modify a datafile + @extend_schema(description="Make changes to a data file that you own.") + def put(self, request, data_id, version=None): + db = get_object_or_404(DataFile, id=data_id) + if not permissions.check_permissions(request, db): + if not request.user.is_authenticated: + return HttpResponse("must be authenticated to modify", status=401) + return HttpResponseForbidden("must be the data owner to modify") + form = DataFileForm(request.data, request.FILES, instance=db) + if form.is_valid(): + form.save() + serializer = DataFileSerializer( + db, + data={ + "file_name": os.path.basename(form.instance.file.path), + "current_user": request.user.id, + }, + context={"is_public": db.is_public}, + partial=True, + ) + if serializer.is_valid(raise_exception=True): + serializer.save() + return_data = { + "current_user": request.user.username, + "authenticated": request.user.is_authenticated, + "file_id": db.id, + "file_alternative_name": serializer.data["file_name"], + "is_public": serializer.data["is_public"], + } + return Response(return_data) + + # Delete a datafile + @extend_schema(description="Delete a data file that you own.") + def delete(self, request, data_id, version=None): + db = get_object_or_404(DataFile, id=data_id) + if not permissions.is_owner(request, db): + if not request.user.is_authenticated: + return HttpResponse("Must be authenticated to delete", status=401) + return HttpResponseForbidden("Must be the data owner to delete") + db.delete() + return Response(data={"success": True}) + + +class DataFileUsersView(APIView): + """ + View for the users that have access to a datafile. + + Functionality for accessing a list of users with access and granting or + revoking access. + """ + + # View users with access to a datafile + @extend_schema( + description="Retrieve a list of users that have been granted access to" + " a data file and the file's publicity status." + ) + def get(self, request, data_id, version=None): + db = get_object_or_404(DataFile, id=data_id) + if not permissions.is_owner(request, db): + if not request.user.is_authenticated: + return HttpResponse( + "Must be authenticated to manage access", status=401 + ) + return HttpResponseForbidden("Must be the data owner to manage access") + response_data = { + "file": db.pk, + "file_name": db.file_name, + "is_public": db.is_public, + "users": [user.username for user in db.users.all()], + } + return Response(response_data) + + # Grant or revoke access to a datafile + @extend_schema(description="Grant or revoke a user's access to a data file.") + def put(self, request, data_id, version=None): + db = get_object_or_404(DataFile, id=data_id) + if not permissions.is_owner(request, db): + if not request.user.is_authenticated: + return HttpResponse( + "Must be authenticated to manage access", status=401 + ) + return HttpResponseForbidden("Must be the data owner to manage access") + serializer = AccessManagementSerializer(data=request.data) + serializer.is_valid() + user = get_object_or_404(User, username=serializer.data["username"]) + if serializer.data["access"]: + db.users.add(user) + else: + db.users.remove(user) + response_data = { + "username": user.username, + "file": db.pk, + "file_name": db.file_name, + "access": (serializer.data["access"] or user == db.current_user), + } + return Response(response_data) + + +class DataSetView(APIView): + """ + View associated with the DataSet model. + + Functionality for viewing a list of datasets and creating a dataset. + """ + + permission_classes = [DataPermission] + + # get a list of accessible datasets + @extend_schema(description="Retrieve a list of accessible datasets by id and name.") + def get(self, request, version=None): + data_list = {"dataset_ids": {}} + data = DataSet.objects.all() + if "username" in request.GET: + user = get_object_or_404(User, username=request.GET["username"]) + data = DataSet.objects.filter(current_user=user) + for dataset in data: + if permissions.check_permissions(request, dataset): + data_list["dataset_ids"][dataset.id] = dataset.name + return Response(data=data_list) + + # TODO: enable uploading files as part of dataset creation, not just associating dataset with existing files + # create a dataset + @extend_schema(description="Upload a dataset.") + def post(self, request, version=None): + # TODO: revisit request data format + if isinstance(request.data, str): + serializer = DataSetSerializer( + data=json.loads(request.data), context={"request": request} + ) + else: + serializer = DataSetSerializer( + data=request.data, context={"request": request} + ) + if serializer.is_valid(raise_exception=True): + serializer.save() + db = serializer.instance + response = { + "dataset_id": db.id, + "name": db.name, + "authenticated": request.user.is_authenticated, + "current_user": request.user.username, + "is_public": db.is_public, + } + return Response(data=response, status=status.HTTP_201_CREATED) + + # create a dataset + @extend_schema(description="Upload a dataset.") + def put(self, request, version=None): + return self.post(request, version) + + +class SingleDataSetView(APIView): + """ + View associated with single datasets. + + Functionality for accessing a dataset in a format intended to be loaded + into SasView, modifying a dataset, or deleting a dataset. + """ + + permission_classes = [DataPermission] + + # get a specific dataset + @extend_schema(description="Retrieve a dataset.") + def get(self, request, data_id, version=None): + db = get_object_or_404(DataSet, id=data_id) + if not permissions.check_permissions(request, db): + if not request.user.is_authenticated: + return HttpResponse("Must be authenticated to view dataset", status=401) + return HttpResponseForbidden( + "You do not have permission to view this dataset." + ) + serializer = DataSetSerializer(db, context={"request": request}) + response_data = serializer.data + if db.current_user: + response_data["current_user"] = db.current_user.username + return Response(response_data) + + # edit a specific dataset + @extend_schema(description="Make changes to a dataset that you own.") + def put(self, request, data_id, version=None): + db = get_object_or_404(DataSet, id=data_id) + if not permissions.check_permissions(request, db): + if not request.user.is_authenticated: + return HttpResponse( + "Must be authenticated to modify dataset", status=401 + ) + return HttpResponseForbidden("Cannot modify a dataset you do not own") + serializer = DataSetSerializer( + db, request.data, context={"request": request}, partial=True + ) + clear_files = "files" in request.data and not request.data["files"] + if clear_files: + data_copy = request.data.copy() + data_copy.pop("files") + serializer = DataSetSerializer( + db, data_copy, context={"request": request}, partial=True + ) + if serializer.is_valid(raise_exception=True): + serializer.save() + if clear_files: + db.files.clear() + db.save() + data = {"data_id": db.id, "name": db.name, "is_public": db.is_public} + return Response(data) + + # delete a dataset + @extend_schema(description="Delete a dataset that you own.") + def delete(self, request, data_id, version=None): + db = get_object_or_404(DataSet, id=data_id) + if not permissions.check_permissions(request, db): + if not request.user.is_authenticated: + return HttpResponse( + "Must be authenticated to delete a dataset", status=401 + ) + return HttpResponseForbidden("Not authorized to delete") + db.delete() + return Response({"success": True}) + + +class DataSetUsersView(APIView): + """ + View for the users that have access to a dataset. + + Functionality for accessing a list of users with access and granting or + revoking access. + """ + + permission_classes = [DataPermission] + + # get a list of users with access to dataset data_id + @extend_schema( + description="Retrieve a list of users that have been granted access to" + " a dataset and the dataset's publicity status." + ) + def get(self, request, data_id, version=None): + db = get_object_or_404(DataSet, id=data_id) + if not permissions.is_owner(request, db): + if not request.user.is_authenticated: + return HttpResponse("Must be authenticated to view access", status=401) + return HttpResponseForbidden("Must be the dataset owner to view access") + response_data = { + "data_id": db.id, + "name": db.name, + "is_public": db.is_public, + "users": [user.username for user in db.users.all()], + } + return Response(response_data) + + # grant or revoke a user's access to dataset data_id + @extend_schema(description="Grant or revoke a user's access to a dataset.") + def put(self, request, data_id, version=None): + db = get_object_or_404(DataSet, id=data_id) + if not permissions.is_owner(request, db): + if not request.user.is_authenticated: + return HttpResponse( + "Must be authenticated to manage access", status=401 + ) + return HttpResponseForbidden("Must be the dataset owner to manage access") + serializer = AccessManagementSerializer(data=request.data) + serializer.is_valid(raise_exception=True) + user = get_object_or_404(User, username=serializer.data["username"]) + if serializer.data["access"]: + db.users.add(user) + else: + db.users.remove(user) + response_data = { + "username": user.username, + "data_id": db.id, + "name": db.name, + "access": serializer.data["access"], + } + return Response(response_data) + + +class SessionView(APIView): + """ + View associated with the Session model. + + Functionality for viewing a list of sessions and for creating a session. + """ + + # View a list of accessible sessions + @extend_schema( + description="Retrieve a list of accessible sessions by name and title." + ) + def get(self, request, version=None): + session_list = {"session_ids": {}} + sessions = Session.objects.all() + if "username" in request.GET: + user = get_object_or_404(User, username=request.GET["username"]) + sessions = Session.objects.filter(current_user=user) + for session in sessions: + if permissions.check_permissions(request, session): + session_list["session_ids"][session.id] = session.title + return Response(data=session_list) + + # Create a session + # TODO: revisit response data + @extend_schema(description="Upload a session.") + def post(self, request, version=None): + if isinstance(request.data, str): + serializer = SessionSerializer( + data=json.loads(request.data), context={"request": request} + ) + else: + serializer = SessionSerializer( + data=request.data, context={"request": request} + ) + if serializer.is_valid(raise_exception=True): + serializer.save() + db = serializer.instance + response = { + "session_id": db.id, + "title": db.title, + "authenticated": request.user.is_authenticated, + "current_user": request.user.username, + "is_public": db.is_public, + } + return Response(data=response, status=status.HTTP_201_CREATED) + + # Create a session + @extend_schema(description="Upload a session.") + def put(self, request, version=None): + return self.post(request, version) + + +class SingleSessionView(APIView): + """ + View associated with single sessions. + + Functionality for viewing, modifying, and deleting individual sessions. + """ + + # get a specific session + @extend_schema(description="Retrieve a session.") + def get(self, request, data_id, version=None): + db = get_object_or_404(Session, id=data_id) + if not permissions.check_permissions(request, db): + if not request.user.is_authenticated: + return HttpResponse("Must be authenticated to view session", status=401) + return HttpResponseForbidden( + "You do not have permission to view this session." + ) + serializer = SessionSerializer(db) + response_data = serializer.data + if db.current_user: + response_data["current_user"] = db.current_user.username + return Response(response_data) + + # modify a session + @extend_schema(description="Make changes to a session that you own.") + def put(self, request, data_id, version=None): + db = get_object_or_404(Session, id=data_id) + if not permissions.check_permissions(request, db): + if not request.user.is_authenticated: + return HttpResponse( + "Must be authenticated to modify session", status=401 + ) + return HttpResponseForbidden("Cannot modify a session you do not own") + serializer = SessionSerializer( + db, request.data, context={"request": request}, partial=True + ) + if serializer.is_valid(raise_exception=True): + serializer.save() + data = {"session_id": db.id, "title": db.title, "is_public": db.is_public} + return Response(data) + + # delete a session + @extend_schema(description="Delete a session that you own.") + def delete(self, request, data_id, version=None): + db = get_object_or_404(Session, id=data_id) + if not permissions.check_permissions(request, db): + if not request.user.is_authenticated: + return HttpResponse( + "Must be authenticated to delete a session", status=401 + ) + return HttpResponseForbidden("Not authorized to delete") + db.delete() + return Response({"success": True}) + + +class SessionUsersView(APIView): + """ + View for the users that have access to a session. + + Functionality for accessing a list of users with access and granting or + revoking access. + """ + + # view the users that have access to a specific session + @extend_schema( + description="Retrieve a list of users that have been granted access to" + " a session and the session's publicity status." + ) + def get(self, request, data_id, version=None): + db = get_object_or_404(Session, id=data_id) + if not permissions.is_owner(request, db): + if not request.user.is_authenticated: + return HttpResponse("Must be authenticated to view access", status=401) + return HttpResponseForbidden("Must be the session owner to view access") + response_data = { + "session_id": db.id, + "title": db.title, + "is_public": db.is_public, + "users": [user.username for user in db.users.all()], + } + return Response(response_data) + + # grant or revoke access to a session + @extend_schema(description="Grant or revoke a user's access to a data file.") + def put(self, request, data_id, version=None): + db = get_object_or_404(Session, id=data_id) + if not permissions.is_owner(request, db): + if not request.user.is_authenticated: + return HttpResponse( + "Must be authenticated to manage access", status=401 + ) + return HttpResponseForbidden("Must be the dataset owner to manage access") + serializer = AccessManagementSerializer(data=request.data) + serializer.is_valid(raise_exception=True) + user = get_object_or_404(User, username=serializer.data["username"]) + if serializer.data["access"]: + db.users.add(user) + for dataset in db.datasets.all(): + dataset.users.add(user) + else: + db.users.remove(user) + for dataset in db.datasets.all(): + dataset.users.remove(user) + response_data = { + "username": user.username, + "session_id": db.id, + "title": db.title, + "access": serializer.data["access"], + } + return Response(response_data) + + +class PublishedStateView(APIView): + """ + View associated with the PublishedState model. + + Functionality for viewing a list of session published states and for + creating a published state. + """ + + # View a list of accessible sessions' published states + @extend_schema( + description="Retrieve a list of published states of accessible sessions." + ) + def get(self, request, version=None): + ps_list = {"published_state_ids": {}} + published_states = PublishedState.objects.all() + if "username" in request.GET: + user = get_object_or_404(User, username=request.GET["username"]) + published_states = PublishedState.objects.filter(session__current_user=user) + for ps in published_states: + if permissions.check_permissions(request, ps.session): + ps_list["published_state_ids"][ps.id] = { + "title": ps.session.title, + "published": ps.published, + "doi": ps.doi, + } + return Response(data=ps_list) + + # Create a published state for an existing session + @extend_schema(description="Create a published state for an existing session.") + def post(self, request, version=None): + if isinstance(request.data, str): + serializer = PublishedStateSerializer( + data=json.loads(request.data), context={"request": request} + ) + else: + serializer = PublishedStateSerializer( + data=request.data, context={"request": request} + ) + if serializer.is_valid(raise_exception=True): + if not permissions.is_owner(request, serializer.validated_data["session"]): + if not request.user.is_authenticated: + return HttpResponse( + "Must be authenticated to create a published state for a session", + status=401, + ) + return HttpResponseForbidden( + "Must be the session owner to create a published state for a session" + ) + serializer.save() + db = serializer.instance + response = { + "published_state_id": db.id, + "session_id": db.session.id, + "title": db.session.title, + "doi": db.doi, + "published": db.published, + "current_user": request.user.username, + "is_public": db.session.is_public, + } + return Response(data=response, status=status.HTTP_201_CREATED) + + # Create a published state for an existing session + @extend_schema(description="Create a published state for an existing session.") + def put(self, request, version=None): + return self.post(request, version) + + +class SinglePublishedStateView(APIView): + """ + View associated with specific session published states. + + Functionality for viewing, modifying, and deleting individual published states. + """ + + # View a specific published state + @extend_schema(description="Retrieve a published state.") + def get(self, request, ps_id, version=None): + db = get_object_or_404(PublishedState, id=ps_id) + if not permissions.check_permissions(request, db.session): + if not request.user.is_authenticated: + return HttpResponse( + "Must be authenticated to view published state", status=401 + ) + return HttpResponseForbidden( + "You do not have permission to view this published state." + ) + serializer = PublishedStateSerializer(db) + response_data = serializer.data + response_data["title"] = db.session.title + if db.session.current_user: + response_data["current_user"] = db.session.current_user.username + else: + response_data["current_user"] = "" + response_data["is_public"] = db.session.is_public + return Response(response_data) + + # Modify a published state + @extend_schema( + description="Make changes to the published state of a session that you own." + ) + def put(self, request, ps_id, version=None): + db = get_object_or_404(PublishedState, id=ps_id) + if not permissions.check_permissions(request, db.session): + if not request.user.is_authenticated: + return HttpResponse( + "Must be authenticated to modify published state", status=401 + ) + return HttpResponseForbidden( + "Cannot modify a published state you do not own" + ) + serializer = PublishedStateUpdateSerializer( + db, request.data, context={"request": request}, partial=True + ) + if serializer.is_valid(raise_exception=True): + serializer.save() + data = { + "published_state_id": db.id, + "session_id": db.session.id, + "title": db.session.title, + "published": db.published, + "is_public": db.session.is_public, + } + return Response(data) + + # Delete a published state + @extend_schema(description="Delete the published state of a session that you own.") + def delete(self, request, ps_id, version=None): + db = get_object_or_404(PublishedState, id=ps_id) + if not permissions.check_permissions(request, db.session): + if not request.user.is_authenticated: + return HttpResponse( + "Must be authenticated to delete a published state", status=401 + ) + return HttpResponseForbidden("Not authorized to delete") + db.delete() + return Response({"success": True}) diff --git a/sasdata/fair_database/documentation.yaml b/sasdata/fair_database/documentation.yaml new file mode 100644 index 000000000..22a0488fb --- /dev/null +++ b/sasdata/fair_database/documentation.yaml @@ -0,0 +1,1172 @@ +openapi: 3.0.3 +info: + title: SasView Database + version: 0.1.0 + description: A database following the FAIR data principles for SasView, a small + angle scattering analysis application. +paths: + /{version}/data/file/: + get: + operationId: data_file_list + description: Retrieve a list of accessible data files by id and filename. + parameters: + - in: path + name: version + schema: + type: string + pattern: ^(v1)$ + required: true + tags: + - data + security: + - knoxApiToken: [] + - cookieAuth: [] + - {} + responses: + '200': + description: OK + content: + application/json: + schema: + $ref: "#components/schemas/DataFileList" + post: + operationId: data_file_create + description: Upload a data file. + parameters: + - in: path + name: version + schema: + type: string + pattern: ^(v1)$ + required: true + tags: + - data + security: + - knoxApiToken: [] + - cookieAuth: [] + - {} + requestBody: + required: true + content: + multipart/form-data: + schema: + $ref: "#components/schemas/DataFileCreate" + responses: + '201': + description: CREATED + content: + application/json: + schema: + $ref: "#components/schemas/DataFileCreated" + put: + operationId: data_file_create_2 + description: Upload a data file. + parameters: + - in: path + name: version + schema: + type: string + pattern: ^(v1)$ + required: true + tags: + - data + security: + - knoxApiToken: [] + - cookieAuth: [] + - {} + requestBody: + required: true + content: + multipart/form-data: + schema: + $ref: "components/schemas/DataFileCreate" + responses: + '201': + description: CREATED + content: + application/json: + schema: + $ref: "#components/schemas/DataFileCreated" + /{version}/data/file/{data_id}/: + get: + operationId: data_file_retrieve + description: Retrieve the contents of a data file or download a file. + parameters: + - in: path + name: data_id + schema: + type: integer + required: true + - in: path + name: version + schema: + type: string + pattern: ^(v1)$ + required: true + tags: + - data + security: + - knoxApiToken: [] + - cookieAuth: [] + - {} + requestBlock: + schema: + $ref: "#components/schemas/DataFileGet" + responses: + '200': + description: OK + content: + application/json: + schema: + oneOf: + - $ref: "#components/schemas/DataFile" + - $ref: "components/schemas/DataFileDownload" + put: + operationId: data_file_update_2 + description: Make changes to a data file that you own. + parameters: + - in: path + name: data_id + schema: + type: integer + required: true + - in: path + name: version + schema: + type: string + pattern: ^(v1)$ + required: true + tags: + - data + security: + - knoxApiToken: [] + - cookieAuth: [] + - {} + requestBody: + required: true + content: + multipart/form-data: + schema: + $ref: "components/schemas/DataFileCreate" + responses: + '200': + description: OK + content: + application/json: + schema: + $ref: "#components/schemas/DataFileCreated" + delete: + operationId: data_file_destroy + description: Delete a data file that you own. + parameters: + - in: path + name: data_id + schema: + type: integer + required: true + - in: path + name: version + schema: + type: string + pattern: ^(v1)$ + required: true + tags: + - data + security: + - knoxApiToken: [] + - cookieAuth: [] + - {} + responses: + '200': + description: OK + content: + application/json: + schema: + $ref: "#components/schemas/Delete" + /{version}/data/file/{data_id}/users/: + get: + operationId: data_file_users_retrieve + description: Retrieve a list of users that have been granted access to a data + file and the file's publicity status. + parameters: + - in: path + name: data_id + schema: + type: integer + required: true + - in: path + name: version + schema: + type: string + pattern: ^(v1)$ + required: true + tags: + - data + security: + - knoxApiToken: [] + - cookieAuth: [] + - {} + responses: + '200': + description: OK + content: + application/json: + schema: + allOf: + - $ref: "#components/schemas/UsersList" + - type: object + properties: + file: + type: integer + file_name: + type: string + put: + operationId: data_file_users_update + description: Grant or revoke a user's access to a data file. + parameters: + - in: path + name: data_id + schema: + type: integer + required: true + - in: path + name: version + schema: + type: string + pattern: ^(v1)$ + required: true + tags: + - data + security: + - knoxApiToken: [] + - cookieAuth: [] + - {} + requestBody: + required: true + content: + application/json: + schema: + $ref: "#components/schemas/ManageAccess" + responses: + '200': + description: OK + content: + application/json: + schema: + allOf: + - $ref: "#components/schemas/ManageAccess" + - type: object + properties: + file: + type: integer + file_name: + type: string + /{version}/data/published/: + get: + operationId: data_published_retrieve + description: Retrieve a list of published states of accessible sessions. + parameters: + - in: path + name: version + schema: + type: string + pattern: ^(v1)$ + required: true + tags: + - data + security: + - knoxApiToken: [] + - cookieAuth: [] + - {} + responses: + '200': + description: OK + post: + operationId: data_published_create + description: Create a published state for an existing session. + parameters: + - in: path + name: version + schema: + type: string + pattern: ^(v1)$ + required: true + tags: + - data + security: + - knoxApiToken: [] + - cookieAuth: [] + - {} + responses: + '201': + description: CREATED + put: + operationId: data_published_update + description: Create a published state for an existing session. + parameters: + - in: path + name: version + schema: + type: string + pattern: ^(v1)$ + required: true + tags: + - data + security: + - knoxApiToken: [] + - cookieAuth: [] + - {} + responses: + '201': + description: CREATED + /{version}/data/published/{ps_id}/: + get: + operationId: data_published_retrieve_2 + description: Retrieve a published state. + parameters: + - in: path + name: ps_id + schema: + type: integer + required: true + - in: path + name: version + schema: + type: string + pattern: ^(v1)$ + required: true + tags: + - data + security: + - knoxApiToken: [] + - cookieAuth: [] + - {} + responses: + '200': + description: OK + put: + operationId: data_published_update_2 + description: Make changes to the published state of a session that you own. + parameters: + - in: path + name: ps_id + schema: + type: integer + required: true + - in: path + name: version + schema: + type: string + pattern: ^(v1)$ + required: true + tags: + - data + security: + - knoxApiToken: [] + - cookieAuth: [] + - {} + responses: + '200': + description: OK + delete: + operationId: data_published_destroy + description: Delete the published state of a session that you own. + parameters: + - in: path + name: ps_id + schema: + type: integer + required: true + - in: path + name: version + schema: + type: string + pattern: ^(v1)$ + required: true + tags: + - data + security: + - knoxApiToken: [] + - cookieAuth: [] + - {} + responses: + '200': + description: OK + /{version}/data/session/: + get: + operationId: data_session_retrieve + description: Retrieve a list of accessible sessions by name and title. + parameters: + - in: path + name: version + schema: + type: string + pattern: ^(v1)$ + required: true + tags: + - data + security: + - knoxApiToken: [] + - cookieAuth: [] + - {} + responses: + '200': + description: OK + post: + operationId: data_session_create + description: Upload a session. + parameters: + - in: path + name: version + schema: + type: string + pattern: ^(v1)$ + required: true + tags: + - data + security: + - knoxApiToken: [] + - cookieAuth: [] + - {} + responses: + '201': + description: CREATED + put: + operationId: data_session_update + description: Upload a session. + parameters: + - in: path + name: version + schema: + type: string + pattern: ^(v1)$ + required: true + tags: + - data + security: + - knoxApiToken: [] + - cookieAuth: [] + - {} + responses: + '201': + description: CREATED + /{version}/data/session/{data_id}/: + get: + operationId: data_session_retrieve_2 + description: Retrieve a session. + parameters: + - in: path + name: data_id + schema: + type: integer + required: true + - in: path + name: version + schema: + type: string + pattern: ^(v1)$ + required: true + tags: + - data + security: + - knoxApiToken: [] + - cookieAuth: [] + - {} + responses: + '200': + description: OK + put: + operationId: data_session_update_2 + description: Make changes to a session that you own. + parameters: + - in: path + name: data_id + schema: + type: integer + required: true + - in: path + name: version + schema: + type: string + pattern: ^(v1)$ + required: true + tags: + - data + security: + - knoxApiToken: [] + - cookieAuth: [] + - {} + responses: + '200': + description: OK + delete: + operationId: data_session_destroy + description: Delete a session that you own. + parameters: + - in: path + name: data_id + schema: + type: integer + required: true + - in: path + name: version + schema: + type: string + pattern: ^(v1)$ + required: true + tags: + - data + security: + - knoxApiToken: [] + - cookieAuth: [] + - {} + responses: + '200': + description: OK + /{version}/data/session/{data_id}/users/: + get: + operationId: data_session_users_retrieve + description: Retrieve a list of users that have been granted access to a session + and the session's publicity status. + parameters: + - in: path + name: data_id + schema: + type: integer + required: true + - in: path + name: version + schema: + type: string + pattern: ^(v1)$ + required: true + tags: + - data + security: + - knoxApiToken: [] + - cookieAuth: [] + - {} + responses: + '200': + description: OK + put: + operationId: data_session_users_update + description: Grant or revoke a user's access to a data file. + parameters: + - in: path + name: data_id + schema: + type: integer + required: true + - in: path + name: version + schema: + type: string + pattern: ^(v1)$ + required: true + tags: + - data + security: + - knoxApiToken: [] + - cookieAuth: [] + - {} + responses: + '200': + description: OK + /{version}/data/set/: + get: + operationId: data_set_retrieve + description: Retrieve a list of accessible datasets by id and name. + parameters: + - in: path + name: version + schema: + type: string + pattern: ^(v1)$ + required: true + tags: + - data + security: + - knoxApiToken: [] + - cookieAuth: [] + responses: + '200': + description: OK + post: + operationId: data_set_create + description: Upload a dataset. + parameters: + - in: path + name: version + schema: + type: string + pattern: ^(v1)$ + required: true + tags: + - data + security: + - knoxApiToken: [] + - cookieAuth: [] + responses: + '201': + description: CREATED + put: + operationId: data_set_update + description: Upload a dataset. + parameters: + - in: path + name: version + schema: + type: string + pattern: ^(v1)$ + required: true + tags: + - data + security: + - knoxApiToken: [] + - cookieAuth: [] + responses: + '201': + description: CREATED + /{version}/data/set/{data_id}/: + get: + operationId: data_set_retrieve_2 + description: Retrieve a dataset. + parameters: + - in: path + name: data_id + schema: + type: integer + required: true + - in: path + name: version + schema: + type: string + pattern: ^(v1)$ + required: true + tags: + - data + security: + - knoxApiToken: [] + - cookieAuth: [] + responses: + '200': + description: OK + put: + operationId: data_set_update_2 + description: Make changes to a dataset that you own. + parameters: + - in: path + name: data_id + schema: + type: integer + required: true + - in: path + name: version + schema: + type: string + pattern: ^(v1)$ + required: true + tags: + - data + security: + - knoxApiToken: [] + - cookieAuth: [] + responses: + '200': + description: OK + delete: + operationId: data_set_destroy + description: Delete a dataset that you own. + parameters: + - in: path + name: data_id + schema: + type: integer + required: true + - in: path + name: version + schema: + type: string + pattern: ^(v1)$ + required: true + tags: + - data + security: + - knoxApiToken: [] + - cookieAuth: [] + responses: + '200': + description: OK + /{version}/data/set/{data_id}/users/: + get: + operationId: data_set_users_retrieve + description: Retrieve a list of users that have been granted access to a dataset + and the dataset's publicity status. + parameters: + - in: path + name: data_id + schema: + type: integer + required: true + - in: path + name: version + schema: + type: string + pattern: ^(v1)$ + required: true + tags: + - data + security: + - knoxApiToken: [] + - cookieAuth: [] + responses: + '200': + description: OK + put: + operationId: data_set_users_update + description: Grant or revoke a user's access to a dataset. + parameters: + - in: path + name: data_id + schema: + type: integer + required: true + - in: path + name: version + schema: + type: string + pattern: ^(v1)$ + required: true + tags: + - data + security: + - knoxApiToken: [] + - cookieAuth: [] + responses: + '200': + description: OK + /auth/login/: + post: + operationId: auth_login_create + description: |- + Check the credentials and return the REST Token + if the credentials are valid and authenticated. + Calls Django Auth login method to register User ID + in Django session framework + + Accept the following POST parameters: username, password + Return the REST Framework Token Object's key. + tags: + - auth + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/Login' + application/x-www-form-urlencoded: + schema: + $ref: '#/components/schemas/Login' + multipart/form-data: + schema: + $ref: '#/components/schemas/Login' + required: true + security: + - knoxApiToken: [] + - cookieAuth: [] + - {} + responses: + '200': + content: + application/json: + schema: + $ref: '#/components/schemas/Login' + description: '' + /auth/logout/: + post: + operationId: auth_logout_create + description: |- + Calls Django logout method and delete the Token object + assigned to the current User object. + + Accepts/Returns nothing. + tags: + - auth + security: + - knoxApiToken: [] + - cookieAuth: [] + - {} + responses: + '200': + content: + application/json: + schema: + $ref: '#/components/schemas/RestAuthDetail' + description: '' + /auth/password/change/: + post: + operationId: auth_password_change_create + description: |- + Calls Django Auth SetPasswordForm save method. + + Accepts the following POST parameters: new_password1, new_password2 + Returns the success/fail message. + tags: + - auth + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/PasswordChange' + application/x-www-form-urlencoded: + schema: + $ref: '#/components/schemas/PasswordChange' + multipart/form-data: + schema: + $ref: '#/components/schemas/PasswordChange' + required: true + security: + - knoxApiToken: [] + - cookieAuth: [] + responses: + '200': + content: + application/json: + schema: + $ref: '#/components/schemas/RestAuthDetail' + description: '' + /auth/register/: + post: + operationId: auth_register_create + description: |- + Registers a new user. + + Accepts the following POST parameters: username, email, password1, password2. + tags: + - auth + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/Register' + application/x-www-form-urlencoded: + schema: + $ref: '#/components/schemas/Register' + multipart/form-data: + schema: + $ref: '#/components/schemas/Register' + required: true + security: + - knoxApiToken: [] + - cookieAuth: [] + - {} + responses: + '201': + content: + application/json: + schema: + $ref: '#/components/schemas/Register' + description: '' + /auth/user/: + get: + operationId: auth_user_retrieve + description: |- + Reads and updates UserModel fields + Accepts GET, PUT, PATCH methods. + + Default accepted fields: username, first_name, last_name + Default display fields: pk, username, email, first_name, last_name + Read-only fields: pk, email + + Returns UserModel fields. + tags: + - auth + security: + - knoxApiToken: [] + - cookieAuth: [] + responses: + '200': + content: + application/json: + schema: + $ref: '#/components/schemas/UserDetails' + description: '' + put: + operationId: auth_user_update + description: |- + Reads and updates UserModel fields + Accepts GET, PUT, PATCH methods. + + Default accepted fields: username, first_name, last_name + Default display fields: pk, username, email, first_name, last_name + Read-only fields: pk, email + + Returns UserModel fields. + tags: + - auth + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/UserDetails' + application/x-www-form-urlencoded: + schema: + $ref: '#/components/schemas/UserDetails' + multipart/form-data: + schema: + $ref: '#/components/schemas/UserDetails' + required: true + security: + - knoxApiToken: [] + - cookieAuth: [] + responses: + '200': + content: + application/json: + schema: + $ref: '#/components/schemas/UserDetails' + description: '' + patch: + operationId: auth_user_partial_update + description: |- + Reads and updates UserModel fields + Accepts GET, PUT, PATCH methods. + + Default accepted fields: username, first_name, last_name + Default display fields: pk, username, email, first_name, last_name + Read-only fields: pk, email + + Returns UserModel fields. + tags: + - auth + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/PatchedUserDetails' + application/x-www-form-urlencoded: + schema: + $ref: '#/components/schemas/PatchedUserDetails' + multipart/form-data: + schema: + $ref: '#/components/schemas/PatchedUserDetails' + security: + - knoxApiToken: [] + - cookieAuth: [] + responses: + '200': + content: + application/json: + schema: + $ref: '#/components/schemas/UserDetails' + description: '' +components: + schemas: + Delete: + type: object + properties: + success: + type: boolean + UsersList: + type: object + properties: + is_public: + type: boolean + users: + type: array + items: + type: string + ManageAccess: + type: object + properties: + username: + type: string + access: + type: boolean + DataFileList: + type: object + properties: + data_ids: + type: object + additionalProperties: + filename: + type: string + DataFileCreate: + type: object + properties: + filename: + type: string + file: + type: string + format: binary + DataFileCreated: + type: object + properties: + current_user: + type: string + authenticated: + type: boolean + file_id: + type: integer + file_alternative_name: + type: string + is_public: + type: boolean + DataFileGet: + type: object + properties: + download: + type: boolean + DataFile: + type: object + properties: + filename: + type: object + additionalProperties: + type: array + items: + type: string + DataFileDownload: + type: string + format: binary + Login: + type: object + properties: + username: + type: string + email: + type: string + format: email + password: + type: string + required: + - password + DataSetList: + type: object + properties: + dataset_ids: + type: object + additionalProperties: + name: + type: string + DataSetCreate: + type: object + properties: + name: + type: string + is_public: + type: boolean + data_contents: + type: array + items: + type: object + properties: + label: + type: string + value: + type: object + additionalProperties: true + PasswordChange: + type: object + properties: + new_password1: + type: string + maxLength: 128 + new_password2: + type: string + maxLength: 128 + required: + - new_password1 + - new_password2 + PatchedUserDetails: + type: object + description: User model w/o password + properties: + pk: + type: integer + readOnly: true + title: ID + username: + type: string + description: Required. 150 characters or fewer. Letters, digits and @/./+/-/_ + only. + pattern: ^[\w.@+-]+$ + maxLength: 150 + email: + type: string + format: email + readOnly: true + title: Email address + first_name: + type: string + maxLength: 150 + last_name: + type: string + maxLength: 150 + Register: + type: object + properties: + username: + type: string + maxLength: 150 + minLength: 1 + email: + type: string + format: email + password1: + type: string + writeOnly: true + password2: + type: string + writeOnly: true + required: + - password1 + - password2 + - username + RestAuthDetail: + type: object + properties: + detail: + type: string + readOnly: true + required: + - detail + UserDetails: + type: object + description: User model w/o password + properties: + pk: + type: integer + readOnly: true + title: ID + username: + type: string + description: Required. 150 characters or fewer. Letters, digits and @/./+/-/_ + only. + pattern: ^[\w.@+-]+$ + maxLength: 150 + email: + type: string + format: email + readOnly: true + title: Email address + first_name: + type: string + maxLength: 150 + last_name: + type: string + maxLength: 150 + required: + - email + - pk + - username + securitySchemes: + cookieAuth: + type: apiKey + in: cookie + name: sessionid + knoxApiToken: + type: apiKey + in: header + name: Authorization + description: Token-based authentication with required prefix "Token" diff --git a/sasdata/fair_database/fair_database/__init__.py b/sasdata/fair_database/fair_database/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/sasdata/fair_database/fair_database/asgi.py b/sasdata/fair_database/fair_database/asgi.py new file mode 100644 index 000000000..a10c9b212 --- /dev/null +++ b/sasdata/fair_database/fair_database/asgi.py @@ -0,0 +1,16 @@ +""" +ASGI config for fair_database project. + +It exposes the ASGI callable as a module-level variable named ``application``. + +For more information on this file, see +https://docs.djangoproject.com/en/5.1/howto/deployment/asgi/ +""" + +import os + +from django.core.asgi import get_asgi_application + +os.environ.setdefault("DJANGO_SETTINGS_MODULE", "fair_database.settings") + +application = get_asgi_application() diff --git a/sasdata/fair_database/fair_database/create_example_session.py b/sasdata/fair_database/fair_database/create_example_session.py new file mode 100644 index 000000000..6c10d9905 --- /dev/null +++ b/sasdata/fair_database/fair_database/create_example_session.py @@ -0,0 +1,97 @@ +import requests + +session = { + "title": "Example Session", + "datasets": [ + { + "name": "Dataset 1", + "metadata": { + "title": "Metadata 1", + "run": 1, + "description": "test", + "instrument": {}, + "process": {}, + "sample": {}, + }, + "data_contents": [ + { + "value": 0, + "variance": 0, + "units": "no", + "hash": 0, + "label": "Quantity 1", + "history": {"operation_tree": {}, "references": []}, + } + ], + }, + { + "name": "Dataset 2", + "metadata": { + "title": "Metadata 2", + "run": 2, + "description": "test", + "instrument": {}, + "process": {}, + "sample": {}, + }, + "data_contents": [ + { + "label": "Quantity 2", + "value": {"array_contents": [0, 0, 0, 0], "shape": (2, 2)}, + "variance": {"array_contents": [0, 0, 0, 0], "shape": (2, 2)}, + "units": "none", + "hash": 0, + "history": { + "operation_tree": { + "operation": "neg", + "parameters": { + "a": { + "operation": "mul", + "parameters": { + "a": { + "operation": "constant", + "parameters": { + "value": {"type": "int", "value": 7} + }, + }, + "b": { + "operation": "variable", + "parameters": { + "hash_value": 111, + "name": "x", + }, + }, + }, + }, + }, + }, + "references": [ + { + "value": 5, + "variance": 0, + "units": "none", + "hash": 111, + "history": {}, + } + ], + }, + } + ], + }, + ], + "is_public": False, +} + +url = "http://127.0.0.1:8000/v1/data/session/" +login_data = {"email": "test@test.org", "username": "testUser", "password": "sasview!"} +response = requests.post("http://127.0.0.1:8000/auth/login/", data=login_data) +if response.status_code != 200: + register_data = { + "email": "test@test.org", + "username": "testUser", + "password1": "sasview!", + "password2": "sasview!", + } + response = requests.post("http://127.0.0.1:8000/auth/register/", data=register_data) +token = response.json()["token"] +requests.request("POST", url, json=session, headers={"Authorization": "Token " + token}) diff --git a/sasdata/fair_database/fair_database/permissions.py b/sasdata/fair_database/fair_database/permissions.py new file mode 100644 index 000000000..74be5f33a --- /dev/null +++ b/sasdata/fair_database/fair_database/permissions.py @@ -0,0 +1,29 @@ +from rest_framework.permissions import BasePermission + + +# check if a request is made by an object's owner +def is_owner(request, obj): + return request.user.is_authenticated and request.user == obj.current_user + + +# check if a request is made by a user with read access +def has_access(request, obj): + return is_owner(request, obj) or ( + request.user.is_authenticated and request.user in obj.users.all() + ) + + +class DataPermission(BasePermission): + # check if a request has the correct permissions for a specific object + def has_object_permission(self, request, view, obj): + if request.method == "GET": + return obj.is_public or has_access(request, obj) + elif request.method == "DELETE": + return not obj.is_public and is_owner(request, obj) + else: + return is_owner(request, obj) + + +# check if a request has the correct permissions for a specific object +def check_permissions(request, obj): + return DataPermission().has_object_permission(request, None, obj) diff --git a/sasdata/fair_database/fair_database/settings.py b/sasdata/fair_database/fair_database/settings.py new file mode 100644 index 000000000..f3ec69ed9 --- /dev/null +++ b/sasdata/fair_database/fair_database/settings.py @@ -0,0 +1,202 @@ +""" +Django settings for fair_database project. + +Generated by 'django-admin startproject' using Django 5.1.5. + +For more information on this file, see +https://docs.djangoproject.com/en/5.1/topics/settings/ + +For the full list of settings and their values, see +https://docs.djangoproject.com/en/5.1/ref/settings/ +""" + +import os +from pathlib import Path + +# Build paths inside the project like this: BASE_DIR / 'subdir'. +BASE_DIR = Path(__file__).resolve().parent.parent + + +# Quick-start development settings - unsuitable for production +# See https://docs.djangoproject.com/en/5.1/howto/deployment/checklist/ + +# SECURITY WARNING: keep the secret key used in production secret! +SECRET_KEY = "django-insecure--f-t5!pdhq&4)^&xenr^k0e8n%-h06jx9d0&2kft(!+1$xzig)" + +# SECURITY WARNING: don't run with debug turned on in production! +DEBUG = True + +ALLOWED_HOSTS = [] + + +# Application definition + +INSTALLED_APPS = [ + "data.apps.DataConfig", + "django.contrib.admin", + "django.contrib.auth", + "django.contrib.contenttypes", + "django.contrib.sessions", + "django.contrib.messages", + "django.contrib.staticfiles", + "django.contrib.sites", + "rest_framework", + "rest_framework.authtoken", + "allauth", + "allauth.account", + "allauth.socialaccount", + "allauth.socialaccount.providers.orcid", + "dj_rest_auth", + "dj_rest_auth.registration", + "knox", + "user_app.apps.UserAppConfig", + "drf_spectacular", +] + +SITE_ID = 1 + +MIDDLEWARE = [ + "django.middleware.security.SecurityMiddleware", + "django.contrib.sessions.middleware.SessionMiddleware", + "django.middleware.common.CommonMiddleware", + "django.middleware.csrf.CsrfViewMiddleware", + "django.contrib.auth.middleware.AuthenticationMiddleware", + "django.contrib.messages.middleware.MessageMiddleware", + "django.middleware.clickjacking.XFrameOptionsMiddleware", + "allauth.account.middleware.AccountMiddleware", +] + +ROOT_URLCONF = "fair_database.urls" + +TEMPLATES = [ + { + "BACKEND": "django.template.backends.django.DjangoTemplates", + "DIRS": [], + "APP_DIRS": True, + "OPTIONS": { + "context_processors": [ + "django.template.context_processors.debug", + "django.template.context_processors.request", + "django.contrib.auth.context_processors.auth", + "django.contrib.messages.context_processors.messages", + ], + }, + }, +] + +WSGI_APPLICATION = "fair_database.wsgi.application" + +# Authentication +AUTHENTICATION_BACKENDS = ( + "django.contrib.auth.backends.ModelBackend", + "allauth.account.auth_backends.AuthenticationBackend", +) + +REST_FRAMEWORK = { + "DEFAULT_AUTHENTICATION_CLASSES": [ + "knox.auth.TokenAuthentication", + "rest_framework.authentication.SessionAuthentication", + ], + "DEFAULT_SCHEMA_CLASS": "drf_spectacular.openapi.AutoSchema", +} + +REST_AUTH = { + "TOKEN_SERIALIZER": "user_app.serializers.KnoxSerializer", + "USER_DETAILS_SERIALIZER": "dj_rest_auth.serializers.UserDetailsSerializer", + "TOKEN_MODEL": "knox.models.AuthToken", + "TOKEN_CREATOR": "user_app.util.create_knox_token", +} + +SPECTACULAR_SETTINGS = { + "TITLE": "SasView Database", + "DESCRIPTION": "A database following the FAIR data principles for SasView," + " a small angle scattering analysis application.", + "VERSION": "0.1.0", + "SERVE_INCLUDE_SCHEMA": False, +} + +# allauth settings +HEADLESS_ONLY = True +ACCOUNT_EMAIL_VERIFICATION = "none" + +# to enable ORCID, register for credentials through ORCID and fill out client_id and secret +# https://info.orcid.org/documentation/integration-guide/ +# https://docs.allauth.org/en/latest/socialaccount/index.html +SOCIALACCOUNT_PROVIDERS = { + "orcid": { + "APPS": [ + { + "client_id": "", + "secret": "", + "key": "", + } + ], + "SCOPE": [ + "profile", + "email", + ], + "AUTH_PARAMETERS": {"access_type": "online"}, + # Base domain of the API. Default value: 'orcid.org', for the production API + "BASE_DOMAIN": "sandbox.orcid.org", # for the sandbox API + # Member API or Public API? Default: False (for the public API) + "MEMBER_API": False, + } +} + +# Database +# https://docs.djangoproject.com/en/5.1/ref/settings/#databases + +DATABASES = { + "default": { + "ENGINE": "django.db.backends.sqlite3", + "NAME": BASE_DIR / "db.sqlite3", + } +} + + +# Password validation +# https://docs.djangoproject.com/en/5.1/ref/settings/#auth-password-validators + +AUTH_PASSWORD_VALIDATORS = [ + { + "NAME": "django.contrib.auth.password_validation.UserAttributeSimilarityValidator", + }, + { + "NAME": "django.contrib.auth.password_validation.MinimumLengthValidator", + }, + { + "NAME": "django.contrib.auth.password_validation.CommonPasswordValidator", + }, + { + "NAME": "django.contrib.auth.password_validation.NumericPasswordValidator", + }, +] + + +# Internationalization +# https://docs.djangoproject.com/en/5.1/topics/i18n/ + +LANGUAGE_CODE = "en-us" + +TIME_ZONE = "UTC" + +USE_I18N = True + +USE_TZ = True + + +# Static files (CSS, JavaScript, Images) +# https://docs.djangoproject.com/en/4.2/howto/static-files/ + + +STATIC_ROOT = os.path.join(BASE_DIR, "static") +STATIC_URL = "/static/" + +# instead of doing this, create a create a new media_root +MEDIA_ROOT = os.path.join(BASE_DIR, "media") +MEDIA_URL = "/media/" + +# Default primary key field type +# https://docs.djangoproject.com/en/5.1/ref/settings/#default-auto-field + +DEFAULT_AUTO_FIELD = "django.db.models.BigAutoField" diff --git a/sasdata/fair_database/fair_database/test_permissions.py b/sasdata/fair_database/fair_database/test_permissions.py new file mode 100644 index 000000000..ffb55dbd7 --- /dev/null +++ b/sasdata/fair_database/fair_database/test_permissions.py @@ -0,0 +1,292 @@ +import os +import shutil + +from data.models import DataFile +from django.conf import settings +from django.contrib.auth.models import User +from rest_framework import status +from rest_framework.test import APITestCase + + +def find(filename): + return os.path.join( + os.path.dirname(__file__), "../../example_data/1d_data", filename + ) + + +def auth_header(response): + return {"Authorization": "Token " + response.data["token"]} + + +class DataListPermissionsTests(APITestCase): + """Test permissions of data views using user_app for authentication.""" + + @classmethod + def setUpTestData(cls): + cls.user = User.objects.create_user( + username="testUser", password="secret", id=1, email="email@domain.com" + ) + cls.user2 = User.objects.create_user( + username="testUser2", password="secret", id=2, email="email2@domain.com" + ) + cls.unowned_test_data = DataFile.objects.create( + id=1, file_name="cyl_400_40.txt", is_public=True + ) + cls.unowned_test_data.file.save( + "cyl_400_40.txt", open(find("cyl_400_40.txt"), "rb") + ) + cls.private_test_data = DataFile.objects.create( + id=2, current_user=cls.user, file_name="cyl_400_20.txt", is_public=False + ) + cls.private_test_data.file.save( + "cyl_400_20.txt", open(find("cyl_400_20.txt"), "rb") + ) + cls.public_test_data = DataFile.objects.create( + id=3, current_user=cls.user, file_name="cyl_testdata.txt", is_public=True + ) + cls.public_test_data.file.save( + "cyl_testdata.txt", open(find("cyl_testdata.txt"), "rb") + ) + cls.login_data_1 = { + "username": "testUser", + "password": "secret", + "email": "email@domain.com", + } + cls.login_data_2 = { + "username": "testUser2", + "password": "secret", + "email": "email2@domain.com", + } + + # Authenticated user can view list of data + def test_list_authenticated(self): + token = self.client.post("/auth/login/", data=self.login_data_1) + response = self.client.get("/v1/data/file/", headers=auth_header(token)) + response2 = self.client.get( + "/v1/data/file/", data={"username": "testUser"}, headers=auth_header(token) + ) + self.assertEqual( + response.data, + { + "public_data_ids": { + 1: "cyl_400_40.txt", + 2: "cyl_400_20.txt", + 3: "cyl_testdata.txt", + } + }, + ) + self.assertEqual( + response2.data, + {"user_data_ids": {2: "cyl_400_20.txt", 3: "cyl_testdata.txt"}}, + ) + + # Authenticated user cannot view other users' private data on list + def test_list_authenticated_2(self): + token = self.client.post("/auth/login/", data=self.login_data_2) + response = self.client.get("/v1/data/file/", headers=auth_header(token)) + response2 = self.client.get( + "/v1/data/file/", data={"username": "testUser"}, headers=auth_header(token) + ) + response3 = self.client.get( + "/v1/data/file/", data={"username": "testUser2"}, headers=auth_header(token) + ) + self.assertEqual( + response.data, + {"public_data_ids": {1: "cyl_400_40.txt", 3: "cyl_testdata.txt"}}, + ) + self.assertEqual(response2.status_code, status.HTTP_200_OK) + self.assertEqual(response2.data, {"user_data_ids": {3: "cyl_testdata.txt"}}) + self.assertEqual(response3.data, {"user_data_ids": {}}) + + # Unauthenticated user can view list of public data + def test_list_unauthenticated(self): + response = self.client.get("/v1/data/file/") + response2 = self.client.get("/v1/data/file/", data={"username": "testUser"}) + self.assertEqual( + response.data, + {"public_data_ids": {1: "cyl_400_40.txt", 3: "cyl_testdata.txt"}}, + ) + self.assertEqual(response2.status_code, status.HTTP_200_OK) + self.assertEqual(response2.data, {"user_data_ids": {3: "cyl_testdata.txt"}}) + + # Authenticated user can load public data and owned private data + def test_load_authenticated(self): + token = self.client.post("/auth/login/", data=self.login_data_1) + response = self.client.get("/v1/data/file/1/", headers=auth_header(token)) + response2 = self.client.get("/v1/data/file/2/", headers=auth_header(token)) + response3 = self.client.get("/v1/data/file/3/", headers=auth_header(token)) + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response2.status_code, status.HTTP_200_OK) + self.assertEqual(response3.status_code, status.HTTP_200_OK) + + # Authenticated user cannot load others' private data + def test_load_unauthorized(self): + token = self.client.post("/auth/login/", data=self.login_data_2) + response = self.client.get("/v1/data/file/2/", headers=auth_header(token)) + response2 = self.client.get("/v1/data/file/3/", headers=auth_header(token)) + self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) + self.assertEqual(response2.status_code, status.HTTP_200_OK) + + # Unauthenticated user can load public data only + def test_load_unauthenticated(self): + response = self.client.get("/v1/data/file/1/") + response2 = self.client.get("/v1/data/file/2/") + response3 = self.client.get("/v1/data/file/3/") + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response2.status_code, status.HTTP_401_UNAUTHORIZED) + self.assertEqual(response3.status_code, status.HTTP_200_OK) + + # Authenticated user can upload data + def test_upload_authenticated(self): + token = self.client.post("/auth/login/", data=self.login_data_1) + file = open(find("cyl_testdata1.txt"), "rb") + data = {"file": file, "is_public": False} + response = self.client.post( + "/v1/data/file/", data=data, headers=auth_header(token) + ) + self.assertEqual(response.status_code, status.HTTP_201_CREATED) + self.assertEqual( + response.data, + { + "current_user": "testUser", + "authenticated": True, + "file_id": 4, + "file_alternative_name": "cyl_testdata1.txt", + "is_public": False, + }, + ) + DataFile.objects.get(id=4).delete() + + # Unauthenticated user can upload public data only + def test_upload_unauthenticated(self): + file = open(find("cyl_testdata2.txt"), "rb") + file2 = open(find("cyl_testdata2.txt"), "rb") + data = {"file": file, "is_public": True} + data2 = {"file": file2, "is_public": False} + response = self.client.post("/v1/data/file/", data=data) + response2 = self.client.post("/v1/data/file/", data=data2) + self.assertEqual(response.status_code, status.HTTP_201_CREATED) + self.assertEqual( + response.data, + { + "current_user": "", + "authenticated": False, + "file_id": 4, + "file_alternative_name": "cyl_testdata2.txt", + "is_public": True, + }, + ) + self.assertEqual(response2.status_code, status.HTTP_400_BAD_REQUEST) + + # Authenticated user can update own data + def test_upload_put_authenticated(self): + token = self.client.post("/auth/login/", data=self.login_data_1) + data = {"is_public": False} + response = self.client.put( + "/v1/data/file/2/", data=data, headers=auth_header(token) + ) + response2 = self.client.put( + "/v1/data/file/3/", data=data, headers=auth_header(token) + ) + self.assertEqual( + response.data, + { + "current_user": "testUser", + "authenticated": True, + "file_id": 2, + "file_alternative_name": "cyl_400_20.txt", + "is_public": False, + }, + ) + self.assertEqual( + response2.data, + { + "current_user": "testUser", + "authenticated": True, + "file_id": 3, + "file_alternative_name": "cyl_testdata.txt", + "is_public": False, + }, + ) + DataFile.objects.get(id=3).is_public = True + + # Authenticated user cannot update unowned data + def test_upload_put_unauthorized(self): + token = self.client.post("/auth/login/", data=self.login_data_2) + file = open(find("cyl_400_40.txt")) + data = {"file": file, "is_public": False} + response = self.client.put( + "/v1/data/file/1/", data=data, headers=auth_header(token) + ) + response2 = self.client.put( + "/v1/data/file/2/", data=data, headers=auth_header(token) + ) + response3 = self.client.put( + "/v1/data/file/3/", data=data, headers=auth_header(token) + ) + self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) + self.assertEqual(response2.status_code, status.HTTP_403_FORBIDDEN) + self.assertEqual(response3.status_code, status.HTTP_403_FORBIDDEN) + + # Unauthenticated user cannot update data + def test_upload_put_unauthenticated(self): + file = open(find("cyl_400_40.txt")) + data = {"file": file, "is_public": False} + response = self.client.put("/v1/data/file/1/", data=data) + response2 = self.client.put("/v1/data/file/2/", data=data) + response3 = self.client.put("/v1/data/file/3/", data=data) + self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) + self.assertEqual(response2.status_code, status.HTTP_401_UNAUTHORIZED) + self.assertEqual(response3.status_code, status.HTTP_401_UNAUTHORIZED) + + # Authenticated user can download public and own data + def test_download_authenticated(self): + token = self.client.post("/auth/login/", data=self.login_data_1) + response = self.client.get( + "/v1/data/file/1/", data={"download": True}, headers=auth_header(token) + ) + response2 = self.client.get( + "/v1/data/file/2/", data={"download": True}, headers=auth_header(token) + ) + response3 = self.client.get( + "/v1/data/file/3/", data={"download": True}, headers=auth_header(token) + ) + b"".join(response.streaming_content) + b"".join(response2.streaming_content) + b"".join(response3.streaming_content) + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response2.status_code, status.HTTP_200_OK) + self.assertEqual(response3.status_code, status.HTTP_200_OK) + + # Authenticated user cannot download others' data + def test_download_unauthorized(self): + token = self.client.post("/auth/login/", data=self.login_data_2) + response = self.client.get( + "/v1/data/file/2/", data={"download": True}, headers=auth_header(token) + ) + response2 = self.client.get( + "/v1/data/file/3/", data={"download": True}, headers=auth_header(token) + ) + b"".join(response2.streaming_content) + self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) + self.assertEqual(response2.status_code, status.HTTP_200_OK) + + # Unauthenticated user cannot download private data + def test_download_unauthenticated(self): + response = self.client.get("/v1/data/file/1/", data={"download": True}) + response2 = self.client.get("/v1/data/file/2/", data={"download": True}) + response3 = self.client.get("/v1/data/file/3/", data={"download": True}) + b"".join(response.streaming_content) + b"".join(response3.streaming_content) + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response2.status_code, status.HTTP_401_UNAUTHORIZED) + self.assertEqual(response3.status_code, status.HTTP_200_OK) + + @classmethod + def tearDownClass(cls): + cls.user.delete() + cls.user2.delete() + cls.public_test_data.delete() + cls.private_test_data.delete() + cls.unowned_test_data.delete() + shutil.rmtree(settings.MEDIA_ROOT) diff --git a/sasdata/fair_database/fair_database/upload_example_data.py b/sasdata/fair_database/fair_database/upload_example_data.py new file mode 100644 index 000000000..79de203d4 --- /dev/null +++ b/sasdata/fair_database/fair_database/upload_example_data.py @@ -0,0 +1,46 @@ +import logging +import os +from glob import glob + +import requests + +EXAMPLE_DATA_DIR = os.environ.get("EXAMPLE_DATA_DIR", "../../example_data") + + +def parse_1D(): + dir_1d = os.path.join(EXAMPLE_DATA_DIR, "1d_data") + if not os.path.isdir(dir_1d): + logging.error(f"1D Data directory not found at: {dir_1d}") + return + for file_path in glob(os.path.join(dir_1d, "*")): + upload_file(file_path) + + +def parse_2D(): + dir_2d = os.path.join(EXAMPLE_DATA_DIR, "2d_data") + if not os.path.isdir(dir_2d): + logging.error(f"2D Data directory not found at: {dir_2d}") + return + for file_path in glob(os.path.join(dir_2d, "*")): + upload_file(file_path) + + +def parse_sesans(): + sesans_dir = os.path.join(EXAMPLE_DATA_DIR, "sesans_data") + if not os.path.isdir(sesans_dir): + logging.error(f"Sesans Data directory not found at: {sesans_dir}") + return + for file_path in glob(os.path.join(sesans_dir, "*")): + upload_file(file_path) + + +def upload_file(file_path): + url = "http://localhost:8000/v1/data/file/" + file = open(file_path, "rb") + requests.request("POST", url, data={"is_public": True}, files={"file": file}) + + +if __name__ == "__main__": + parse_1D() + parse_2D() + parse_sesans() diff --git a/sasdata/fair_database/fair_database/urls.py b/sasdata/fair_database/fair_database/urls.py new file mode 100644 index 000000000..56c88ce21 --- /dev/null +++ b/sasdata/fair_database/fair_database/urls.py @@ -0,0 +1,42 @@ +""" +URL configuration for fair_database project. + +The `urlpatterns` list routes URLs to views. For more information please see: + https://docs.djangoproject.com/en/5.1/topics/http/urls/ +Examples: +Function views + 1. Add an import: from my_app import views + 2. Add a URL to urlpatterns: path('', views.home, name='home') +Class-based views + 1. Add an import: from other_app.views import Home + 2. Add a URL to urlpatterns: path('', Home.as_view(), name='home') +Including another URLconf + 1. Import the include() function: from django.urls import include, path + 2. Add a URL to urlpatterns: path('blog/', include('blog.urls')) +""" + +from django.contrib import admin +from django.urls import include, path, re_path +from drf_spectacular.views import ( + SpectacularAPIView, + SpectacularRedocView, + SpectacularSwaggerView, +) + +urlpatterns = [ + re_path(r"^(?P(v1))/data/", include("data.urls")), + path("admin/", admin.site.urls), + path("accounts/", include("allauth.urls")), # needed for social auth + path("auth/", include("user_app.urls")), + path("api/schema/", SpectacularAPIView.as_view(), name="schema"), + path( + "api/schema/swagger-ui/", + SpectacularSwaggerView.as_view(url_name="schema"), + name="swagger-ui", + ), + path( + "api/schema/redoc/", + SpectacularRedocView.as_view(url_name="schema"), + name="redoc", + ), +] diff --git a/sasdata/fair_database/fair_database/wsgi.py b/sasdata/fair_database/fair_database/wsgi.py new file mode 100644 index 000000000..5dfc4819c --- /dev/null +++ b/sasdata/fair_database/fair_database/wsgi.py @@ -0,0 +1,16 @@ +""" +WSGI config for fair_database project. + +It exposes the WSGI callable as a module-level variable named ``application``. + +For more information on this file, see +https://docs.djangoproject.com/en/5.1/howto/deployment/wsgi/ +""" + +import os + +from django.core.wsgi import get_wsgi_application + +os.environ.setdefault("DJANGO_SETTINGS_MODULE", "fair_database.settings") + +application = get_wsgi_application() diff --git a/sasdata/fair_database/manage.py b/sasdata/fair_database/manage.py new file mode 100755 index 000000000..7d7e97246 --- /dev/null +++ b/sasdata/fair_database/manage.py @@ -0,0 +1,23 @@ +#!/usr/bin/env python +"""Django's command-line utility for administrative tasks.""" + +import os +import sys + + +def main(): + """Run administrative tasks.""" + os.environ.setdefault("DJANGO_SETTINGS_MODULE", "fair_database.settings") + try: + from django.core.management import execute_from_command_line + except ImportError as exc: + raise ImportError( + "Couldn't import Django. Are you sure it's installed and " + "available on your PYTHONPATH environment variable? Did you " + "forget to activate a virtual environment?" + ) from exc + execute_from_command_line(sys.argv) + + +if __name__ == "__main__": + main() diff --git a/sasdata/fair_database/requirements.txt b/sasdata/fair_database/requirements.txt new file mode 100644 index 000000000..22b32934b --- /dev/null +++ b/sasdata/fair_database/requirements.txt @@ -0,0 +1,8 @@ +#this requirements extends the base sasview requirements files +#to get both you will need to run this after base requirements files +django +djangorestframework +dj-rest-auth +django-allauth +django-rest-knox +drf-spectacular diff --git a/sasdata/fair_database/user_app/__init__.py b/sasdata/fair_database/user_app/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/sasdata/fair_database/user_app/admin.py b/sasdata/fair_database/user_app/admin.py new file mode 100644 index 000000000..846f6b406 --- /dev/null +++ b/sasdata/fair_database/user_app/admin.py @@ -0,0 +1 @@ +# Register your models here. diff --git a/sasdata/fair_database/user_app/apps.py b/sasdata/fair_database/user_app/apps.py new file mode 100644 index 000000000..83a29decf --- /dev/null +++ b/sasdata/fair_database/user_app/apps.py @@ -0,0 +1,6 @@ +from django.apps import AppConfig + + +class UserAppConfig(AppConfig): + default_auto_field = "django.db.models.BigAutoField" + name = "user_app" diff --git a/sasdata/fair_database/user_app/migrations/__init__.py b/sasdata/fair_database/user_app/migrations/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/sasdata/fair_database/user_app/models.py b/sasdata/fair_database/user_app/models.py new file mode 100644 index 000000000..6b2021999 --- /dev/null +++ b/sasdata/fair_database/user_app/models.py @@ -0,0 +1 @@ +# Create your models here. diff --git a/sasdata/fair_database/user_app/serializers.py b/sasdata/fair_database/user_app/serializers.py new file mode 100644 index 000000000..4993a7ab3 --- /dev/null +++ b/sasdata/fair_database/user_app/serializers.py @@ -0,0 +1,14 @@ +from dj_rest_auth.serializers import UserDetailsSerializer +from rest_framework import serializers + + +class KnoxSerializer(serializers.Serializer): + """ + Serializer for Knox authentication. + """ + + token = serializers.SerializerMethodField() + user = UserDetailsSerializer() + + def get_token(self, obj): + return obj["token"][1] diff --git a/sasdata/fair_database/user_app/tests.py b/sasdata/fair_database/user_app/tests.py new file mode 100644 index 000000000..8943ab40e --- /dev/null +++ b/sasdata/fair_database/user_app/tests.py @@ -0,0 +1,169 @@ +from django.contrib.auth.models import User +from django.test import TestCase +from rest_framework import status +from rest_framework.test import APIClient + + +# Create your tests here. +class AuthTests(TestCase): + """Tests for authentication endpoints.""" + + @classmethod + def setUpTestData(cls): + cls.client1 = APIClient() + cls.client2 = APIClient() + cls.register_data = { + "email": "email@domain.org", + "username": "testUser", + "password1": "sasview!", + "password2": "sasview!", + } + cls.login_data = { + "username": "testUser", + "email": "email@domain.org", + "password": "sasview!", + } + cls.login_data_2 = { + "username": "testUser2", + "email": "email2@domain.org", + "password": "sasview!", + } + cls.user = User.objects.create_user( + id=1, username="testUser2", password="sasview!", email="email2@domain.org" + ) + cls.client_authenticated = APIClient() + cls.client_authenticated.force_authenticate(user=cls.user) + + # Create an authentication header for a given token + def auth_header(self, response): + return {"Authorization": "Token " + response.data["token"]} + + # Test if registration successfully creates a new user and logs in + def test_register(self): + response = self.client1.post("/auth/register/", data=self.register_data) + user = User.objects.get(username="testUser") + response2 = self.client1.get("/auth/user/", headers=self.auth_header(response)) + self.assertEqual(response.status_code, status.HTTP_201_CREATED) + self.assertEqual(user.email, self.register_data["email"]) + self.assertEqual(response2.status_code, status.HTTP_200_OK) + user.delete() + + # Test if login successful + def test_login(self): + response = self.client1.post("/auth/login/", data=self.login_data_2) + response2 = self.client1.get("/auth/user/", headers=self.auth_header(response)) + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response2.status_code, status.HTTP_200_OK) + + # Test simultaneous login by multiple clients + def test_multiple_login(self): + response = self.client1.post("/auth/login/", data=self.login_data_2) + response2 = self.client2.post("/auth/login/", data=self.login_data_2) + response3 = self.client1.get("/auth/user/", headers=self.auth_header(response)) + response4 = self.client2.get("/auth/user/", headers=self.auth_header(response2)) + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response2.status_code, status.HTTP_200_OK) + self.assertEqual(response3.status_code, status.HTTP_200_OK) + self.assertEqual(response4.status_code, status.HTTP_200_OK) + self.assertNotEqual(response.content, response2.content) + + # Test get user information + def test_user_get(self): + response = self.client_authenticated.get("/auth/user/") + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual( + response.content, + b'{"pk":1,"username":"testUser2","email":"email2@domain.org","first_name":"","last_name":""}', + ) + + # Test changing username + def test_user_put_username(self): + data = {"username": "newName"} + response = self.client_authenticated.put("/auth/user/", data=data) + self.user.username = "testUser2" + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual( + response.content, + b'{"pk":1,"username":"newName","email":"email2@domain.org","first_name":"","last_name":""}', + ) + + # Test changing username and first and last name + def test_user_put_name(self): + data = {"username": "newName", "first_name": "Clark", "last_name": "Kent"} + response = self.client_authenticated.put("/auth/user/", data=data) + self.user.first_name = "" + self.user.last_name = "" + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual( + response.content, + b'{"pk":1,"username":"newName","email":"email2@domain.org","first_name":"Clark","last_name":"Kent"}', + ) + + # Test user info inaccessible when unauthenticated + def test_user_unauthenticated(self): + response = self.client1.get("/auth/user/") + self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) + self.assertEqual( + response.content, + b'{"detail":"Authentication credentials were not provided."}', + ) + + # Test logout is successful after login + def test_login_logout(self): + self.client1.post("/auth/login/", data=self.login_data_2) + response = self.client1.post("/auth/logout/") + response2 = self.client1.get("/auth/user/") + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.content, b'{"detail":"Successfully logged out."}') + self.assertEqual(response2.status_code, status.HTTP_401_UNAUTHORIZED) + + # Test logout is successful after registration + def test_register_logout(self): + self.client1.post("/auth/register/", data=self.register_data) + response = self.client1.post("/auth/logout/") + response2 = self.client1.get("/auth/user/") + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.content, b'{"detail":"Successfully logged out."}') + self.assertEqual(response2.status_code, status.HTTP_401_UNAUTHORIZED) + User.objects.get(username="testUser").delete() + + # Test multiple logins for the same account log out independently + def test_multiple_logout(self): + self.client1.post("/auth/login/", data=self.login_data_2) + token = self.client2.post("/auth/login/", data=self.login_data_2) + response = self.client1.post("/auth/logout/") + response2 = self.client2.get("/auth/user/", headers=self.auth_header(token)) + response3 = self.client2.post("/auth/logout/") + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response2.status_code, status.HTTP_200_OK) + self.assertEqual(response3.status_code, status.HTTP_200_OK) + + # Test login is successful after registering then logging out + def test_register_login(self): + register_response = self.client1.post( + "/auth/register/", data=self.register_data + ) + logout_response = self.client1.post("/auth/logout/") + login_response = self.client1.post("/auth/login/", data=self.login_data) + self.assertEqual(register_response.status_code, status.HTTP_201_CREATED) + self.assertEqual(logout_response.status_code, status.HTTP_200_OK) + self.assertEqual(login_response.status_code, status.HTTP_200_OK) + User.objects.get(username="testUser").delete() + + # Test password is successfully changed + def test_password_change(self): + data = { + "new_password1": "sasview?", + "new_password2": "sasview?", + "old_password": "sasview!", + } + self.login_data_2["password"] = "sasview?" + response = self.client_authenticated.post("/auth/password/change/", data=data) + login_response = self.client1.post("/auth/login/", data=self.login_data_2) + self.login_data_2["password"] = "sasview!" + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(login_response.status_code, status.HTTP_200_OK) + + @classmethod + def tearDownClass(cls): + cls.user.delete() diff --git a/sasdata/fair_database/user_app/urls.py b/sasdata/fair_database/user_app/urls.py new file mode 100644 index 000000000..e393cb4b6 --- /dev/null +++ b/sasdata/fair_database/user_app/urls.py @@ -0,0 +1,15 @@ +from dj_rest_auth.views import LogoutView, PasswordChangeView, UserDetailsView +from django.urls import path + +from .views import KnoxLoginView, KnoxRegisterView + +"""Urls for authentication. Orcid login not functional. See settings.py for ORCID activation.""" + +urlpatterns = [ + path("register/", KnoxRegisterView.as_view(), name="register"), + path("login/", KnoxLoginView.as_view(), name="login"), + path("logout/", LogoutView.as_view(), name="logout"), + path("user/", UserDetailsView.as_view(), name="view user information"), + path("password/change/", PasswordChangeView.as_view(), name="change password"), + # path("login/orcid/", OrcidLoginView.as_view(), name="orcid login"), +] diff --git a/sasdata/fair_database/user_app/util.py b/sasdata/fair_database/user_app/util.py new file mode 100644 index 000000000..dc7b35026 --- /dev/null +++ b/sasdata/fair_database/user_app/util.py @@ -0,0 +1,7 @@ +from knox.models import AuthToken + + +# create an authentication token +def create_knox_token(token_model, user, serializer): + token = AuthToken.objects.create(user=user) + return token diff --git a/sasdata/fair_database/user_app/views.py b/sasdata/fair_database/user_app/views.py new file mode 100644 index 000000000..3a033add4 --- /dev/null +++ b/sasdata/fair_database/user_app/views.py @@ -0,0 +1,39 @@ +from allauth.account import app_settings as allauth_settings +from allauth.account.utils import complete_signup +from allauth.socialaccount.providers.orcid.views import OrcidOAuth2Adapter +from dj_rest_auth.registration.views import RegisterView, SocialLoginView +from dj_rest_auth.views import LoginView +from rest_framework.response import Response +from user_app.serializers import KnoxSerializer +from user_app.util import create_knox_token + +# Login using knox tokens rather than django-rest-framework tokens. + + +class KnoxLoginView(LoginView): + def get_response(self): + serializer_class = self.get_response_serializer() + + data = {"user": self.user, "token": self.token} + serializer = serializer_class(instance=data, context={"request": self.request}) + + return Response(serializer.data, status=200) + + +# Registration using knox tokens rather than django-rest-framework tokens. +class KnoxRegisterView(RegisterView): + def get_response_data(self, user): + return KnoxSerializer({"user": user, "token": self.token}).data + + def perform_create(self, serializer): + user = serializer.save(self.request) + self.token = create_knox_token(None, user, None) + complete_signup( + self.request._request, user, allauth_settings.EMAIL_VERIFICATION, None + ) + return user + + +# For ORCID login +class OrcidLoginView(SocialLoginView): + adapter_class = OrcidOAuth2Adapter diff --git a/sasdata/metadata.py b/sasdata/metadata.py index d53c3102c..8c665d267 100644 --- a/sasdata/metadata.py +++ b/sasdata/metadata.py @@ -67,6 +67,33 @@ class Rot3: pitch: Quantity[float] | None yaw: Quantity[float] | None + @staticmethod + def deserialise_json(json_data: dict): + roll = None + pitch = None + yaw = None + if "roll" in json_data: + roll = Quantity.deserialise_json(json_data["roll"]) + if "pitch" in json_data: + pitch = Quantity.deserialise_json(json_data["pitch"]) + if "yaw" in json_data: + yaw = Quantity.deserialise_json(json_data["yaw"]) + return Rot3(roll=roll, pitch=pitch, yaw=yaw) + + def serialise_json(self): + data = { + "roll": None, + "pitch": None, + "yaw": None + } + if self.roll is not None: + data["roll"] = self.roll.serialise_json() + if self.pitch is not None: + data["pitch"] = self.pitch.serialise_json() + if self.yaw is not None: + data["yaw"] = self.yaw.serialise_json() + return data + @staticmethod def from_json(obj: dict) -> Quantity | None: if obj is None: @@ -567,6 +594,22 @@ def from_json(obj): raw=MetaNode.from_json(obj["raw"]), ) + def serialise_json(self): + serialized = { + "instrument": None, + "process": [p.serialise_json() for p in self.process], + "sample": None, + "title": self.title, + "run": self.run, + "definition": self.definition + } + if self.sample is not None: + serialized["sample"] = self.sample.serialise_json() + if self.instrument is not None: + serialized["instrument"] = self.instrument.serialise_json() + + return serialized + @property def id_header(self): """Generate a header for used in the unique_id for datasets""" diff --git a/sasdata/quantities/_units_base.py b/sasdata/quantities/_units_base.py index d030fe299..29aa26319 100644 --- a/sasdata/quantities/_units_base.py +++ b/sasdata/quantities/_units_base.py @@ -21,11 +21,11 @@ class DimensionError(Exception): class Dimensions: """ - Note that some SI Base units are not useful from the perspecive of the sasview project, and make things + Note that some SI Base units are not useful from the perspective of the sasview project, and make things behave badly. In particular: moles and angular measures are dimensionless, and candelas are really a weighted measure of power. - We do however track angle and amount, because its really useful for formatting units + We do however track angle and amount, because it's really useful for formatting units """ def __init__(self, diff --git a/sasdata/quantities/quantity.py b/sasdata/quantities/quantity.py index 213c7aadd..270d6d15b 100644 --- a/sasdata/quantities/quantity.py +++ b/sasdata/quantities/quantity.py @@ -1457,6 +1457,20 @@ def hash_data_via_numpy(*data: ArrayLike): QuantityType = TypeVar("QuantityType") +# TODO: change QuantityType serialisation for greater efficiency +def quantity_type_serialisation(var): + if isinstance(var, np.ndarray): + return {"array_contents": var.tobytes(), "shape": var.shape} + else: + return var + +def quantity_type_deserialisation(var): + if isinstance(var, dict): + array = np.frombuffer(var["array_contents"]) + return np.reshape(array, shape=var["shape"]) + else: + return var + class QuantityHistory: """Class that holds the information for keeping track of operations done on quantities""" @@ -1547,6 +1561,24 @@ def summary(self): return s + @staticmethod + def deserialise_json(json_data: dict) -> "QuantityHistory": + operation_tree = Operation.deserialise(json_data["operation_tree"]) + references = { + key: Quantity.deserialise_json(json_data["references"][key]) + for key in json_data["references"] + } + return QuantityHistory(operation_tree, references) + + def serialise_json(self): + return { + "operation_tree": self.operation_tree.serialise(), + "references": [ + ref.serialise_json_no_history() for ref in self.references.values() + ] + + } + class Quantity[QuantityType]: def __init__( @@ -1683,6 +1715,37 @@ def in_si_with_standard_error(self): else: return self.in_si(), None + @staticmethod + def deserialise_json(json_data: dict) -> "Quantity": + value = numerical_decode(json_data["value"]) + units_ = Unit.parse(json_data["units"]) + standard_error = numerical_decode(json_data["variance"]) ** 0.5 + hash_seed = json_data["hash_seed"] + history = QuantityHistory.deserialise_json(json_data["history"]) + quantity = Quantity(value, units_, standard_error, hash_seed) + quantity.history = history + return quantity + + def serialise_json(self): + return { + "value": numerical_encode(self.value), + "units": str(self.units), # Unit serialisation + "variance": numerical_encode(self._variance), + "hash_seed": self._hash_seed, # is this just a string? + "hash_value": self.hash_value, + "history": self.history.serialise_json() + } + + def serialise_json_no_history(self): + return { + "value": numerical_encode(self.value), + "units": str(self.units), # Unit serialisation + "variance": numerical_encode(self._variance), + "hash_seed": self._hash_seed, # is this just a string? + "hash_value": self.hash_value, + "history": {} + } + def explicitly_formatted(self, unit_string: str) -> str: """Returns quantity as a string with specific unit formatting @@ -1901,6 +1964,21 @@ def with_standard_error(self, standard_error: Quantity): f"Standard error units ({standard_error.units}) are not compatible with value units ({self.units})" ) + @staticmethod + def deserialise_json(json_data: dict) -> "NamedQuantity": + name = json_data["name"] + value = numerical_decode(json_data["value"]) + units_ = Unit.parse(json_data["units"]) + standard_error = numerical_decode(json_data["variance"]) ** 0.5 + history = QuantityHistory.deserialise_json(json_data["history"]) + quantity = NamedQuantity(name, value, units_, standard_error) + quantity.history = history + return quantity + + def serialise_json(self): + quantity = super().serialise_json() + quantity["name"] = self.name + return quantity @property def string_repr(self): return self.name @@ -1928,3 +2006,12 @@ def variance(self) -> Quantity: self._variance_cache = self.history.variance_propagate(self.units) return self._variance_cache + + + @staticmethod + def deserialise_json(json_data: dict) -> "DerivedQuantity": + value = numerical_decode(json_data["value"]) + units_ = Unit.parse(json_data["units"]) + history = QuantityHistory.deserialise_json(json_data["history"]) + quantity = DerivedQuantity(value, units_, history) + return quantity diff --git a/sasdata/quantities/unit_formatting.py b/sasdata/quantities/unit_formatting.py index e63921329..adcc7e6ba 100644 --- a/sasdata/quantities/unit_formatting.py +++ b/sasdata/quantities/unit_formatting.py @@ -1,4 +1,5 @@ + import numpy as np diff --git a/test/sasdataloader/reference/14250.txt b/test/sasdataloader/reference/14250.txt new file mode 100644 index 000000000..6f11aba78 --- /dev/null +++ b/test/sasdataloader/reference/14250.txt @@ -0,0 +1,51 @@ +sasentry01 + [FILE_ID_HERE/sasentry01/sasdata/I] [[0.0 ... 0.0]] ± [[0.0 ... 0.0]] cm⁻¹ + [FILE_ID_HERE/sasentry01/sasdata/Qx] [[-0.11925 ... 0.11925000000000001]] Å⁻¹ + [FILE_ID_HERE/sasentry01/sasdata/Qy] [[-0.11925 ... 0.11925000000000001]] Å⁻¹ +Metadata: + + High C high V 63, 900oC, 10h 1.65T_SANS, Run: 14250 + =================================================== + +Definition: High C high V 63, 900oC, 10h 1.65T_SANS +Process: + Name: Mantid_generated_NXcanSAS + Date: 2016-12-06T17:15:48 + Description: None + Term: None + Notes: None +Sample: + ID: None + Transmission: None + Thickness: None + Temperature: None + Position: None + Orientation: None +Aperture: + Name: None + Aperture size: None + Aperture distance: None +Collimation: + Length: None +Detector: + Name: None + Distance: None + Offset: None + Orientation: None + Beam center: None + Pixel size: None + Slit length: None +Source: + Radiation: Spallation Neutron Source + Shape: None + Wavelength: None + Min. Wavelength: None + Max. Wavelength: None + Wavelength Spread: None + Beam Size: None +Transmission Spectrum: + Name: None + Timestamp: None + Wavelengths: None + Transmission: None + diff --git a/test/sasdataloader/reference/33837.txt b/test/sasdataloader/reference/33837.txt new file mode 100644 index 000000000..26f36b96b --- /dev/null +++ b/test/sasdataloader/reference/33837.txt @@ -0,0 +1,50 @@ +sasentry01 + [FILE_ID_HERE/sasentry01/sasdata/I] [5.416094671273121 ... 0.33697913143947616] ± [0.6152247543248875 ... 0.19365125082205084] m + [FILE_ID_HERE/sasentry01/sasdata/Q] [0.0041600000000000005 ... 0.6189241619415587] m +Metadata: + + MH4_5deg_16T_SLOW, Run: 33837 + ============================= + +Definition: MH4_5deg_16T_SLOW +Process: + Name: Mantid_generated_NXcanSAS + Date: 11-May-2016 12:20:43 + Description: None + Term: None + Notes: None +Sample: + ID: None + Transmission: None + Thickness: None + Temperature: None + Position: None + Orientation: None +Aperture: + Name: None + Aperture size: None + Aperture distance: None +Collimation: + Length: None +Detector: + Name: None + Distance: None + Offset: None + Orientation: None + Beam center: None + Pixel size: None + Slit length: None +Source: + Radiation: Spallation Neutron Source + Shape: None + Wavelength: None + Min. Wavelength: None + Max. Wavelength: None + Wavelength Spread: None + Beam Size: None +Transmission Spectrum: + Name: None + Timestamp: None + Wavelengths: None + Transmission: None + diff --git a/test/sasdataloader/reference/33837_v3.txt b/test/sasdataloader/reference/33837_v3.txt new file mode 100644 index 000000000..f4d205b2a --- /dev/null +++ b/test/sasdataloader/reference/33837_v3.txt @@ -0,0 +1,50 @@ +sasentry01 + [FILE_ID_HERE/sasentry01/sasdata/I] [5.416094671273121 ... 0.33697913143947616] ± [0.6152247543248875 ... 0.19365125082205084] none + [FILE_ID_HERE/sasentry01/sasdata/Q] [0.0041600000000000005 ... 0.6189241619415587] Å⁻¹ +Metadata: + + MH4_5deg_16T_SLOW, Run: 33837 + ============================= + +Definition: MH4_5deg_16T_SLOW +Process: + Name: Mantid_generated_NXcanSAS + Date: 2016-07-04T10:34:34 + Description: None + Term: None + Notes: None +Sample: + ID: None + Transmission: None + Thickness: None + Temperature: None + Position: None + Orientation: None +Aperture: + Name: None + Aperture size: None + Aperture distance: None +Collimation: + Length: None +Detector: + Name: None + Distance: None + Offset: None + Orientation: None + Beam center: None + Pixel size: None + Slit length: None +Source: + Radiation: Spallation Neutron Source + Shape: None + Wavelength: None + Min. Wavelength: None + Max. Wavelength: None + Wavelength Spread: None + Beam Size: None +Transmission Spectrum: + Name: None + Timestamp: None + Wavelengths: None + Transmission: None + diff --git a/test/sasdataloader/reference/BAM.txt b/test/sasdataloader/reference/BAM.txt new file mode 100644 index 000000000..bf6328b1c --- /dev/null +++ b/test/sasdataloader/reference/BAM.txt @@ -0,0 +1,51 @@ +sasentry01 + [FILE_ID_HERE/sasentry01/data/I] [[-4132.585671758142 ... -4139.954861346877]] ± [[1000000000.0 ... 1000000000.0]] m⁻¹ + [FILE_ID_HERE/sasentry01/data/Imask] [[True ... True]] m⁻¹ + [FILE_ID_HERE/sasentry01/data/Q] [[[-0.10733919639695527 ... 0.0]]] Å⁻¹ +Metadata: + + Qais oriented iron test 1, Run: 12345 + ===================================== + +Definition: Qais oriented iron test 1 +Process: + Name: None + Date: None + Description: None + Term: None + Notes: None +Sample: + ID: None + Transmission: None + Thickness: None + Temperature: None + Position: None + Orientation: None +Aperture: + Name: None + Aperture size: None + Aperture distance: None +Collimation: + Length: None +Detector: + Name: None + Distance: None + Offset: None + Orientation: None + Beam center: None + Pixel size: None + Slit length: None +Source: + Radiation: None + Shape: None + Wavelength: None + Min. Wavelength: None + Max. Wavelength: None + Wavelength Spread: None + Beam Size: None +Transmission Spectrum: + Name: None + Timestamp: None + Wavelengths: None + Transmission: None + diff --git a/test/sasdataloader/reference/x25000_no_di.txt b/test/sasdataloader/reference/x25000_no_di.txt index b0cad2ee3..00d3c9dfd 100644 --- a/test/sasdataloader/reference/x25000_no_di.txt +++ b/test/sasdataloader/reference/x25000_no_di.txt @@ -5,12 +5,12 @@ sasentry01 I Metadata: - , Run: + , Run: ======= -Definition: +Definition: Sample: - ID: + ID: Transmission: None Thickness: None Temperature: None @@ -19,7 +19,7 @@ Sample: Collimation: Length: None Detector: - Name: + Name: Distance: None Offset: None Orientation: None