diff --git a/employees/tests/test_unit_report_viewset.py b/employees/tests/test_unit_report_viewset.py index a4b2d2ba9..8b8ffb6b8 100644 --- a/employees/tests/test_unit_report_viewset.py +++ b/employees/tests/test_unit_report_viewset.py @@ -3,6 +3,7 @@ from django.contrib.auth.models import AnonymousUser from django.test import TestCase +from freezegun import freeze_time from rest_framework.reverse import reverse from rest_framework.test import APIRequestFactory @@ -317,6 +318,17 @@ def test_custom_report_list_create_serializer_method_should_return_serializer_wi self.assertTrue(new_project in serializer.fields["project"].queryset) self.assertTrue(self.project not in serializer.fields["project"].queryset) + def test_custom_report_list_create_serializer_method_should_return_serializer_with_date_field_containing_current_date( + self + ): + request = APIRequestFactory().get(path=self.url) + request.user = self.user + view = ReportList() + view.request = request + with freeze_time("2010-01-21"): + serializer = view._create_serializer() + self.assertEqual(serializer.fields["date"].initial, "2010-01-21") + def test_custom_report_list_view_should_add_user_to_project_selected_in_project_join_form_on_join(self): new_project = Project(name="New Project", start_date=datetime.datetime.now()) new_project.full_clean() diff --git a/employees/views.py b/employees/views.py index 117e7e2d1..7bf6c1903 100644 --- a/employees/views.py +++ b/employees/views.py @@ -1,3 +1,4 @@ +import datetime from typing import Any from typing import Dict from typing import Union @@ -43,6 +44,7 @@ def query_as_dict(query_set: QuerySet) -> Dict[str, Any]: class ReportList(APIView): + serializer_class = ReportSerializer renderer_classes = [renderers.TemplateHTMLRenderer] template_name = "employees/report_list.html" reports_dict = {} # type: Dict[str, Any] @@ -64,6 +66,7 @@ def _create_serializer(self) -> ReportSerializer: members__id=self.request.user.id ).order_by("name") reports_serializer.fields["task_activities"].queryset = TaskActivityType.objects.order_by("name") + reports_serializer.fields["date"].initial = str(datetime.datetime.now().date()) return reports_serializer def initial(self, request: HttpRequest, *args: Any, **kwargs: Any) -> None: @@ -126,28 +129,30 @@ def post(self, request: HttpRequest) -> Response: class ReportDetail(APIView): + serializer_class = ReportSerializer + model_class = Report renderer_classes = [renderers.TemplateHTMLRenderer] template_name = "employees/report_detail.html" permission_classes = (permissions.IsAuthenticated,) def _create_serializer(self, report: Report, data: Any = None) -> ReportSerializer: if data is None: - reports_serializer = ReportSerializer(report, context={"request": self.request}) + reports_serializer = self.serializer_class(report, context={"request": self.request}) else: - reports_serializer = ReportSerializer(report, data=data, context={"request": self.request}) + reports_serializer = self.serializer_class(report, data=data, context={"request": self.request}) reports_serializer.fields["project"].queryset = Project.objects.filter(members__id=report.author.pk).order_by( "name" ) return reports_serializer def get(self, _request: HttpRequest, pk: int) -> Response: - report = get_object_or_404(Report, pk=pk) + report = get_object_or_404(self.model_class, pk=pk) serializer = self._create_serializer(report) return Response({"serializer": serializer, "report": report, "UI_text": ReportDetailStrings}) def post(self, request: HttpRequest, pk: int) -> Union[Response, HttpResponseRedirectBase]: if "discard" not in request.POST: - report = get_object_or_404(Report, pk=pk) + report = get_object_or_404(self.model_class, pk=pk) serializer = self._create_serializer(report, request.data) if not serializer.is_valid(): return Response(