diff --git a/apps/common/migrations/0010_userprofile.py b/apps/common/migrations/0010_userprofile.py new file mode 100644 index 00000000..1901aa86 --- /dev/null +++ b/apps/common/migrations/0010_userprofile.py @@ -0,0 +1,49 @@ +# Generated by Django 5.0.2 on 2024-11-22 06:07 + +import django.db.models.deletion +import django.utils.timezone +import model_utils.fields +from django.conf import settings +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ("auth", "0012_alter_user_first_name_max_length"), + ("common", "0009_countryclassifiedproductaliases_aliases"), + ] + + operations = [ + migrations.CreateModel( + name="UserProfile", + fields=[ + ( + "created", + model_utils.fields.AutoCreatedField( + default=django.utils.timezone.now, editable=False, verbose_name="created" + ), + ), + ( + "modified", + model_utils.fields.AutoLastModifiedField( + default=django.utils.timezone.now, editable=False, verbose_name="modified" + ), + ), + ( + "user", + models.OneToOneField( + on_delete=django.db.models.deletion.CASCADE, + primary_key=True, + serialize=False, + to=settings.AUTH_USER_MODEL, + ), + ), + ("profile_data", models.JSONField(blank=True, default=dict, null=True)), + ], + options={ + "verbose_name": "user profile", + "verbose_name_plural": "user profiles", + }, + ), + ] diff --git a/apps/common/models.py b/apps/common/models.py index e20c12db..c19e0afb 100644 --- a/apps/common/models.py +++ b/apps/common/models.py @@ -8,6 +8,7 @@ import operator from functools import reduce +from django.contrib.auth.models import User from django.core import validators from django.core.cache import cache from django.core.exceptions import ObjectDoesNotExist, ValidationError @@ -950,3 +951,19 @@ class Meta: fields=["country", "product"], name="common_countryclassified_country_code_product_code_uniq" ) ] + + +class UserProfile(Model): + """ + A profile to store data associated with a user to enable a customized user experience + """ + + user = models.OneToOneField(User, on_delete=CASCADE, primary_key=True, unique=True) + profile_data = models.JSONField(default=dict, null=True, blank=True) + + def __str__(self): + return f"user_profile: {str(self.user)}" + + class Meta: + verbose_name = _("user profile") + verbose_name_plural = _("user profiles") diff --git a/apps/common/serializers.py b/apps/common/serializers.py index c79b413a..7d13730e 100644 --- a/apps/common/serializers.py +++ b/apps/common/serializers.py @@ -1,6 +1,7 @@ +from django.contrib.auth.models import User from rest_framework import serializers -from .models import ClassifiedProduct, Country, Currency, UnitOfMeasure +from .models import ClassifiedProduct, Country, Currency, UnitOfMeasure, UserProfile class CountrySerializer(serializers.ModelSerializer): @@ -61,3 +62,37 @@ class Meta: "kcals_per_unit", "aliases", ] + + +class UserSerializer(serializers.ModelSerializer): + class Meta: + model = User + fields = ["id", "username", "first_name", "last_name"] + + +class CurrentUserSerializer(serializers.ModelSerializer): + permissions = serializers.ListField(source="get_all_permissions", read_only=True) + groups = serializers.SerializerMethodField() + + def get_groups(self, user): + return user.groups.values_list("name", flat=True) + + class Meta: + model = User + fields = [ + "id", + "username", + "first_name", + "last_name", + "email", + "permissions", + "groups", + "is_staff", + "is_superuser", + ] + + +class UserProfileSerializer(serializers.ModelSerializer): + class Meta: + model = UserProfile + fields = ("user", "profile_data") diff --git a/apps/common/tests/factories.py b/apps/common/tests/factories.py index a2ce0ec0..2d97f412 100644 --- a/apps/common/tests/factories.py +++ b/apps/common/tests/factories.py @@ -60,6 +60,14 @@ def groups(self, create, extracted, **kwargs): self.groups.add(group) +class UserProfileFactory(factory.django.DjangoModelFactory): + class Meta: + model = "common.UserProfile" + django_get_or_create = ("user",) + + user = factory.SubFactory(UserFactory) + + class GroupFactory(factory.django.DjangoModelFactory): class Meta: model = "auth.Group" diff --git a/apps/common/tests/test_viewsets.py b/apps/common/tests/test_viewsets.py index 4f51565d..285c3c13 100644 --- a/apps/common/tests/test_viewsets.py +++ b/apps/common/tests/test_viewsets.py @@ -11,6 +11,7 @@ CurrencyFactory, UnitOfMeasureFactory, UserFactory, + UserProfileFactory, ) @@ -190,3 +191,59 @@ def test_search_fields(self): self.assertEqual(response.status_code, 200) result = json.loads(response.content.decode("utf-8")) self.assertEqual(len(result), 1) + + +class UserViewSetTestCase(APITestCase): + def setUp(self): + self.user = UserFactory(username="testuser", password="password123", first_name="Test", last_name="User") + self.client.force_authenticate(user=self.user) + self.url = reverse("user-list") + + def test_get_current_user(self): + response = self.client.get(f"{self.url}current/") + self.assertEqual(response.status_code, 200) + self.assertEqual(response.data["username"], self.user.username) + + def test_search_users(self): + UserFactory(username="searchuser", password="password123", first_name="Search", last_name="User") + response = self.client.get(self.url, {"search": "Search"}) + self.assertEqual(response.status_code, 200) + self.assertEqual(len(response.data), 1) + self.assertEqual(response.data[0]["first_name"], "Search") + + +class UserProfileViewSetTestCase(APITestCase): + def setUp(self): + self.user = UserFactory(username="testuser", password="password123") + self.profile = UserProfileFactory(user=self.user) + self.client.force_authenticate(user=self.user) + self.url = reverse("userprofile-list") + + def test_get_current_profile(self): + response = self.client.get(f"{self.url}current/") + self.assertEqual(response.status_code, 200) + self.assertEqual(response.data["user"], self.user.id) + + def test_superuser_access_profiles(self): + superuser = UserFactory(username="admin", password="password123", is_superuser=True) + self.client.force_authenticate(user=superuser) + response = self.client.get(f"{self.url}{self.profile.user.id}/") + self.assertEqual(response.status_code, 200) + self.assertEqual(response.data["user"], self.user.id) + + def test_queryset_filters(self): + other_user = UserFactory(username="otheruser", password="password123") + UserProfileFactory(user=other_user) + + # Current user profile only + response = self.client.get(f"{self.url}?pk=current") + self.assertEqual(response.status_code, 200) + self.assertEqual(len(response.data), 1) + self.assertEqual(response.data[0]["user"], self.user.id) + + # Superuser access to all profiles + superuser = UserFactory(username="admin", password="password123", is_superuser=True) + self.client.force_authenticate(user=superuser) + response = self.client.get(self.url) + self.assertEqual(response.status_code, 200) + self.assertGreaterEqual(len(response.data), 2) diff --git a/apps/common/viewsets.py b/apps/common/viewsets.py index 5b32ffed..e36f4547 100644 --- a/apps/common/viewsets.py +++ b/apps/common/viewsets.py @@ -1,18 +1,24 @@ +from django.contrib.auth.models import User from django.utils.text import format_lazy from django.utils.translation import gettext_lazy as _ from django_filters import rest_framework as filters from rest_framework import viewsets +from rest_framework.decorators import action from rest_framework.exceptions import NotAcceptable from rest_framework.pagination import PageNumberPagination +from rest_framework.permissions import BasePermission, IsAuthenticated from .fields import translation_fields from .filters import MultiFieldFilter -from .models import ClassifiedProduct, Country, Currency, UnitOfMeasure +from .models import ClassifiedProduct, Country, Currency, UnitOfMeasure, UserProfile from .serializers import ( ClassifiedProductSerializer, CountrySerializer, CurrencySerializer, + CurrentUserSerializer, UnitOfMeasureSerializer, + UserProfileSerializer, + UserSerializer, ) @@ -323,3 +329,64 @@ class ClassifiedProductViewSet(BaseModelViewSet): *translation_fields("description"), *translation_fields("common_name"), ) + + +class CurrentUserOnly(BasePermission): + def has_permission(self, request, view): + if request.user.is_superuser: + return True + elif view.kwargs == {"pk": "current"}: + # Even anonymous users can see their current user record + return True + elif request.query_params.get("pk") == "current": + # List views seem to use query_params rather than kwargs + return True + return False + + +class UserViewSet(BaseModelViewSet): + """ + Allows users to be viewed or edited. + """ + + queryset = User.objects.all() + permission_classes = [CurrentUserOnly] + serializer_class = UserSerializer + search_fields = ["username", "first_name", "last_name"] + + def get_object(self): + pk = self.kwargs.get("pk") + + if pk == "current": + self.serializer_class = CurrentUserSerializer + return self.request.user if self.request.user.id else User.get_anonymous() + + return super().get_object() + + @action(detail=True, methods=["get"]) + def current(self, request, *args, **kwargs): + return self.retrieve(request, *args, **kwargs) + + +class UserProfileViewSet(BaseModelViewSet): + queryset = UserProfile.objects.all() + serializer_class = UserProfileSerializer + permission_classes = [CurrentUserOnly, IsAuthenticated] + + def get_object(self): + pk = self.kwargs.get("pk") + if pk == "current": + return self.request.user.userprofile if self.request.user.id else None + return super().get_object() + + def get_queryset(self): + queryset = super().get_queryset() + pk = self.request.query_params.get("pk") or self.kwargs.get("pk") + + if pk == "current": + return queryset.filter(user=self.request.user.id) + elif pk: + # Superusers can access profiles without using pk=current. + return queryset.filter(user=pk) + else: + return queryset diff --git a/hea/urls.py b/hea/urls.py index 1cc6733e..20c04098 100644 --- a/hea/urls.py +++ b/hea/urls.py @@ -50,6 +50,8 @@ CountryViewSet, CurrencyViewSet, UnitOfMeasureViewSet, + UserProfileViewSet, + UserViewSet, ) from metadata.viewsets import ( HazardCategoryViewSet, @@ -67,6 +69,8 @@ router.register(r"currency", CurrencyViewSet) router.register(r"unitofmeasure", UnitOfMeasureViewSet) router.register(r"classifiedproduct", ClassifiedProductViewSet) +router.register(r"user", UserViewSet) +router.register(r"userprofile", UserProfileViewSet) # Metadata router.register(r"livelihoodcategory", LivelihoodCategoryViewSet)