diff --git a/api/experimentation/views.py b/api/experimentation/views.py index b82c8df0f90d..c2e1eccf1eaf 100644 --- a/api/experimentation/views.py +++ b/api/experimentation/views.py @@ -1,14 +1,16 @@ import logging from typing import Any +from django.db import IntegrityError from django.db.models import Q, QuerySet -from rest_framework import mixins, status +from rest_framework import mixins, serializers, status from rest_framework.decorators import action from rest_framework.permissions import IsAuthenticated from rest_framework.request import Request from rest_framework.response import Response from rest_framework.serializers import BaseSerializer +from app.pagination import CustomPagination from environments.views import NestedEnvironmentViewSet from experimentation.models import ( Experiment, @@ -101,7 +103,7 @@ class ExperimentViewSet( mixins.DestroyModelMixin, ): serializer_class = ExperimentSerializer - pagination_class = None + pagination_class = CustomPagination permission_classes = [IsAuthenticated, ExperimentPermission] model_class = Experiment lookup_field = "id" @@ -125,6 +127,10 @@ def get_queryset(self) -> "QuerySet[Experiment]": ) status_filter = self.request.query_params.get("status") if status_filter: + if status_filter not in ExperimentStatus.values: + raise serializers.ValidationError( + {"status": f"Invalid status '{status_filter}'."} + ) qs = qs.filter(status=status_filter) q = self.request.query_params.get("q") @@ -152,7 +158,13 @@ def create(self, request: Request, *args: object, **kwargs: object) -> Response: status=status.HTTP_409_CONFLICT, ) - self.perform_create(serializer) + try: + self.perform_create(serializer) + except IntegrityError: + return Response( + {"detail": "An active experiment already exists for this feature."}, + status=status.HTTP_409_CONFLICT, + ) return Response(serializer.data, status=status.HTTP_201_CREATED) def perform_create(self, serializer: BaseSerializer[Experiment]) -> None: @@ -162,12 +174,28 @@ def perform_create(self, serializer: BaseSerializer[Experiment]) -> None: ) def perform_update(self, serializer: BaseSerializer[Experiment]) -> None: + changed_fields = { + field + for field, value in serializer.validated_data.items() + if getattr(serializer.instance, field, None) != value + } + if not changed_fields: + return experiment: Experiment = serializer.save() create_experiment_audit_log( experiment, self._get_user(self.request), action="updated" ) def perform_destroy(self, instance: Experiment) -> None: + if instance.status == ExperimentStatus.RUNNING: + raise serializers.ValidationError( + { + "detail": ( + "Cannot delete a running experiment. " + "Pause or complete it first." + ) + } + ) create_experiment_audit_log( instance, self._get_user(self.request), action="deleted" ) diff --git a/api/tests/unit/experimentation/test_experiment_views.py b/api/tests/unit/experimentation/test_experiment_views.py index 54e39f4f896c..89b053be0550 100644 --- a/api/tests/unit/experimentation/test_experiment_views.py +++ b/api/tests/unit/experimentation/test_experiment_views.py @@ -3,7 +3,9 @@ from typing import TYPE_CHECKING import pytest +from django.db import IntegrityError from django.urls import reverse +from pytest_mock import MockerFixture from rest_framework import status from rest_framework.test import APIClient @@ -265,8 +267,9 @@ def test_get_list__with_experiments__returns_all( # Then assert response.status_code == status.HTTP_200_OK - assert len(response.json()) == 1 - assert response.json()[0]["id"] == experiment.id + results = response.json()["results"] + assert len(results) == 1 + assert results[0]["id"] == experiment.id def test_get_list__with_experiments__returns_nested_feature( @@ -284,9 +287,9 @@ def test_get_list__with_experiments__returns_nested_feature( # Then assert response.status_code == status.HTTP_200_OK - data = response.json() - assert len(data) == 1 - feature_data = data[0]["feature"] + results = response.json()["results"] + assert len(results) == 1 + feature_data = results[0]["feature"] assert isinstance(feature_data, dict) assert feature_data["id"] == multivariate_feature.id assert feature_data["name"] == multivariate_feature.name @@ -329,7 +332,7 @@ def test_get_list__empty__returns_200( # Then assert response.status_code == status.HTTP_200_OK - assert response.json() == [] + assert response.json()["results"] == [] @pytest.mark.parametrize( @@ -357,7 +360,7 @@ def test_get_list__filter_by_status__returns_filtered( # Then assert response.status_code == status.HTTP_200_OK - assert len(response.json()) == expected_count + assert len(response.json()["results"]) == expected_count def test_get_list__search_by_experiment_name__returns_matching( @@ -374,8 +377,9 @@ def test_get_list__search_by_experiment_name__returns_matching( # Then assert response.status_code == status.HTTP_200_OK - assert len(response.json()) == 1 - assert response.json()[0]["id"] == experiment.id + results = response.json()["results"] + assert len(results) == 1 + assert results[0]["id"] == experiment.id def test_get_list__search_by_feature_name__returns_matching( @@ -395,8 +399,9 @@ def test_get_list__search_by_feature_name__returns_matching( # Then assert response.status_code == status.HTTP_200_OK - assert len(response.json()) == 1 - assert response.json()[0]["id"] == experiment.id + results = response.json()["results"] + assert len(results) == 1 + assert results[0]["id"] == experiment.id def test_get_list__search_no_match__returns_empty( @@ -413,7 +418,7 @@ def test_get_list__search_no_match__returns_empty( # Then assert response.status_code == status.HTTP_200_OK - assert len(response.json()) == 0 + assert len(response.json()["results"]) == 0 def test_get_detail__exists__returns_200( @@ -670,3 +675,93 @@ def test_delete__valid_delete__creates_audit_log( ).last() assert audit is not None assert "deleted" in audit.log + + +def test_get_list__invalid_status__returns_400( + admin_client_new: APIClient, + environment: Environment, + enable_features: EnableFeaturesFixture, +) -> None: + # Given + enable_features(EXPERIMENT_FLAG) + + # When + response = admin_client_new.get(_list_url(environment), {"status": "garbage"}) + + # Then + assert response.status_code == status.HTTP_400_BAD_REQUEST + + +def test_delete__running_experiment__returns_400( + admin_client_new: APIClient, + environment: Environment, + experiment: Experiment, + enable_features: EnableFeaturesFixture, +) -> None: + # Given + enable_features(EXPERIMENT_FLAG) + experiment.status = ExperimentStatus.RUNNING + experiment.save() + + # When + response = admin_client_new.delete(_detail_url(environment, experiment)) + + # Then + assert response.status_code == status.HTTP_400_BAD_REQUEST + assert Experiment.objects.filter(id=experiment.id).exists() + + +def test_patch__no_change__skips_audit_log( + admin_client_new: APIClient, + environment: Environment, + experiment: Experiment, + enable_features: EnableFeaturesFixture, +) -> None: + # Given + enable_features(EXPERIMENT_FLAG) + audit_count_before = AuditLog.objects.filter( + related_object_type=RelatedObjectType.EXPERIMENT.name + ).count() + + # When + response = admin_client_new.patch( + _detail_url(environment, experiment), + data={"name": experiment.name}, + format="json", + ) + + # Then + assert response.status_code == status.HTTP_200_OK + audit_count_after = AuditLog.objects.filter( + related_object_type=RelatedObjectType.EXPERIMENT.name + ).count() + assert audit_count_after == audit_count_before + + +def test_post__concurrent_create_race__returns_409( + admin_client_new: APIClient, + environment: Environment, + multivariate_feature: Feature, + enable_features: EnableFeaturesFixture, + mocker: MockerFixture, +) -> None: + # Given + enable_features(EXPERIMENT_FLAG) + mocker.patch( + "experimentation.views.ExperimentViewSet.perform_create", + side_effect=IntegrityError(), + ) + + # When + response = admin_client_new.post( + _list_url(environment), + data={ + "feature": multivariate_feature.id, + "name": "Race", + "hypothesis": "Should 409", + }, + format="json", + ) + + # Then + assert response.status_code == status.HTTP_409_CONFLICT