diff --git a/api/segments/serializers.py b/api/segments/serializers.py index 2bf44a4778e1..bd9397a402a1 100644 --- a/api/segments/serializers.py +++ b/api/segments/serializers.py @@ -115,12 +115,14 @@ class Meta: def validate(self, attrs: dict[str, Any]) -> dict[str, Any]: attrs = super().validate(attrs) - project = self.instance.project if self.instance else attrs["project"] # type: ignore[union-attr] + metadata = attrs.get("metadata", []) + + # TODO: Make "project" read-only — https://github.com/Flagsmith/flagsmith-workflows/issues/102 + project_pk = self.context["view"].kwargs["project_pk"] + project = attrs["project"] = Project.objects.get(pk=project_pk) organisation = project.organisation - self._validate_required_metadata( - organisation, attrs.get("metadata", []), project=project - ) + self._validate_required_metadata(organisation, metadata, project) self._validate_segment_rules_conditions_limit(attrs["rules"]) self._validate_project_segment_limit(project) return attrs diff --git a/api/segments/views.py b/api/segments/views.py index 91328d8b9aa8..a6016abb4c96 100644 --- a/api/segments/views.py +++ b/api/segments/views.py @@ -1,5 +1,5 @@ import logging -from typing import Any +from typing import TYPE_CHECKING, Any from common.projects.permissions import VIEW_PROJECT from django.utils.decorators import method_decorator @@ -21,6 +21,7 @@ SegmentAssociatedFeatureStateSerializer, ) from features.versioning.models import EnvironmentFeatureVersion +from projects.models import Project from .models import Segment from .permissions import SegmentPermissions @@ -31,6 +32,9 @@ ) from .services import delete_segment +if TYPE_CHECKING: + from users.models import FFAdminUser + logger = logging.getLogger() @@ -88,15 +92,16 @@ class SegmentViewSet(viewsets.ModelViewSet): # type: ignore[type-arg] permission_classes = [SegmentPermissions] pagination_class = CustomPagination + def get_project(self) -> Project: + user: "FFAdminUser" = self.request.user # type: ignore[assignment] + projects = user.get_permitted_projects(permission_key=VIEW_PROJECT) + return get_object_or_404(projects, pk=self.kwargs["project_pk"]) + def get_queryset(self): # type: ignore[no-untyped-def] if getattr(self, "swagger_fake_view", False): return Segment.objects.none() - permitted_projects = self.request.user.get_permitted_projects( # type: ignore[union-attr] - permission_key=VIEW_PROJECT - ) - project = get_object_or_404(permitted_projects, pk=self.kwargs["project_pk"]) - + project = self.get_project() queryset = Segment.live_objects.filter(project=project, is_system_segment=False) if self.action == "list": diff --git a/api/tests/unit/segments/test_unit_segments_views.py b/api/tests/unit/segments/test_unit_segments_views.py index 5557498e14df..a9a13ca53582 100644 --- a/api/tests/unit/segments/test_unit_segments_views.py +++ b/api/tests/unit/segments/test_unit_segments_views.py @@ -205,7 +205,7 @@ def test_segments_limit_ignores_old_segment_versions( with_project_permissions: WithProjectPermissionsCallable, ) -> None: # Given - with_project_permissions([MANAGE_SEGMENTS]) # type: ignore[call-arg] + with_project_permissions([MANAGE_SEGMENTS, VIEW_PROJECT]) # type: ignore[call-arg] # let's reduce the max segments allowed to 2 project.max_segments_allowed = 2 @@ -1884,3 +1884,28 @@ def test_create_segment__required_metadata_on_other_project__returns_201( # Then assert response.status_code == status.HTTP_201_CREATED + + +def test_create_segment__body_project_differs_from_url__does_not_create_in_other_project( + admin_client: APIClient, + project: Project, +) -> None: + # Given + other_org = Organisation.objects.create(name="Other Org") + other_project = Project.objects.create(name="Other Project", organisation=other_org) + + # When + response = admin_client.post( + f"/api/v1/projects/{project.id}/segments/", + data={ + "name": "a_wild_pokemon", + "project": other_project.id, + "rules": [{"type": "ALL", "rules": [], "conditions": []}], + }, + format="json", + ) + + # Then + assert response.status_code == status.HTTP_201_CREATED + assert response.json()["project"] == project.id + assert not Segment.objects.filter(project=other_project).exists()