-
Notifications
You must be signed in to change notification settings - Fork 0
/
mutations.py
151 lines (132 loc) · 4.92 KB
/
mutations.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
import typing
import uuid

import strawberry
import strawberry_django
from asgiref.sync import async_to_sync
from channels.layers import get_channel_layer
from django.conf import settings
from django.core.exceptions import PermissionDenied, ValidationError
from django.utils import timezone
from strawberry.file_uploads import Upload
from strawberry.permission import BasePermission
from strawberry.types import Info
from strawberry.utils.str_converters import to_camel_case
from strawberry_django.permissions import (
    HasPerm,
    IsStaff,
)

from server.app.blog import forms as blog_forms
from server.app.blog import models as blog_models
from server.app.blog.graph import types as blog_types
class IsAuthor(BasePermission):
    """Allow the operation only when the requesting user authored the post.

    The guarded resolver's ``data`` argument must carry, under ``"id"``, a
    global ID that resolves to a ``blog_models.Post``.
    """

    message = "You must be the author of this post to perform this action."

    def has_permission(self, source: typing.Any, info: Info, **kwargs) -> bool:
        """Return True iff ``info.context.request.user`` is the post's author."""
        # NOTE(review): ``data`` is subscripted here, while ``update_post`` uses
        # attribute access (``data.id``). This assumes permissions receive the
        # raw, dict-shaped GraphQL arguments; confirm for the strawberry version.
        global_id = kwargs["data"]["id"]
        target = global_id.resolve_node_sync(info, ensure_type=blog_models.Post)
        requester = info.context.request.user
        return bool(target.author == requester)
def _handle_form_errors(
    errors: dict[str, list[ValidationError]],
) -> typing.Iterator[
    blog_types.ValidationError
    | blog_types.InvalidChoiceError
    | blog_types.DuplicateError
]:
    """Translate Django form errors into GraphQL error objects.

    Args:
        errors: Mapping of form field name to the ``ValidationError``
            instances raised for it, e.g. ``form.errors.as_data()``.

    Yields:
        One error object per ``ValidationError``, with the field name
        converted to camelCase and the message interpolated with its params:

        * ``code == "unique"`` -> ``blog_types.DuplicateError``
        * ``code == "invalid_choice"`` with a ``"value"`` param ->
          ``blog_types.InvalidChoiceError`` carrying the rejected value
        * anything else -> ``blog_types.ValidationError``
    """
    for field, field_errors in errors.items():
        # Every error on this form field reports the same GraphQL field name.
        gql_field = to_camel_case(field)
        for err in field_errors:
            code = getattr(err, "code", "invalid")
            params = err.params
            # Interpolate params into the message template, as Django does
            # when rendering errors.
            message = err.message % params if params else err.message
            if code == "unique":
                yield blog_types.DuplicateError(field=gql_field, message=message)
            elif (
                code == "invalid_choice"
                and isinstance(params, dict)
                and "value" in params
            ):
                yield blog_types.InvalidChoiceError(
                    field=gql_field,
                    message=message,
                    value=params["value"],
                )
            else:
                # Also covers ``invalid_choice`` errors raised without a
                # ``value`` param (e.g. ModelChoiceField before Django 4.0, or
                # custom validators). There is no value to report for them.
                yield blog_types.ValidationError(field=gql_field, message=message)
def notify_new_post(post: blog_models.Post) -> None:
    """Broadcast a post to the ``settings.POSTS_CHANNEL`` channel-layer group.

    Sends a ``"chat.message"`` event carrying the post's primary key. The
    async ``group_send`` is called from synchronous code via ``async_to_sync``.

    Args:
        post: The post to announce (called after a post is first published).
    """
    channel_layer = get_channel_layer()
    group_send = async_to_sync(channel_layer.group_send)  # type: ignore
    group_send(
        settings.POSTS_CHANNEL,
        {
            "type": "chat.message",
            # Channel-layer messages may only hold basic types (str, int, ...).
            # Post pks are UUIDs in this app (lookups use ``uuid.UUID``), and
            # serializing layers such as channels_redis reject UUID objects,
            # so send the string form.
            "post_id": str(post.pk),
        },
    )
@strawberry.type
class Mutation:
    """Blog post mutations: update, create, publish and cover-image upload."""

    @strawberry_django.mutation(
        handle_django_errors=True,
        permission_classes=[IsAuthor],
    )
    def update_post(
        self,
        data: blog_types.PostInputPartial,
        info: Info,
    ) -> blog_types.Post:
        """Partially update a post; only its author may call this (``IsAuthor``).

        Scalar fields are applied only when the client provided a non-null
        value. Falsy values such as ``False``, ``0`` and ``""`` are applied.
        ``tags`` / ``categories`` replace the post's relations when provided
        as a list; an empty list clears them.
        """
        post = data.id.resolve_node_sync(info, ensure_type=blog_models.Post)

        # Resolve relations before writing anything, so an invalid tag or
        # category ID does not leave the post half-updated.
        tags = None
        if isinstance(data.tags, list):
            tags = [
                tag_id.resolve_node_sync(info, ensure_type=blog_models.Tag)
                for tag_id in data.tags
            ]
        categories = None
        if isinstance(data.categories, list):
            categories = [
                category_id.resolve_node_sync(info, ensure_type=blog_models.Category)
                for category_id in data.categories
            ]

        for field, value in vars(data).items():
            if field in ("id", "tags", "categories"):
                continue
            # Skip fields the client omitted (UNSET) or sent as null. Do not
            # use truthiness here: it would drop False / 0 / "" updates.
            if value is strawberry.UNSET or value is None:
                continue
            if hasattr(post, field):
                setattr(post, field, value)
        post.save()

        if tags is not None:
            post.tags.set(tags)
        if categories is not None:
            post.categories.set(categories)
        return typing.cast(blog_types.Post, post)

    @strawberry_django.mutation
    def create_post(self, data: blog_types.PostInput) -> blog_types.CreatePostResult:
        """Validate ``data`` with ``PostForm`` and create a post.

        Returns a ``CreatePostResult`` holding either the saved post, or
        ``post=None`` plus the translated form errors.
        """
        # NOTE(review): no permission check or author assignment is visible
        # here. Confirm that PostForm / the model handle authorship.
        form = blog_forms.PostForm(vars(data))
        if not form.is_valid():
            return blog_types.CreatePostResult(
                post=None,
                errors=list(_handle_form_errors(form.errors.as_data())),
            )
        post = form.save()
        return blog_types.CreatePostResult(post=post)

    @strawberry_django.input_mutation(
        handle_django_errors=True,
        extensions=[
            HasPerm(["blog.publish_post", "blog.view_post"], any_perm=False),
            IsStaff(),
        ],
    )
    def publish_post(self, id: uuid.UUID) -> blog_types.Post:  # noqa: A002
        """Publish a post (staff holding both listed permissions only).

        An already-published post is returned unchanged. A post published for
        the first time gets ``published_at`` stamped and is broadcast via
        ``notify_new_post``.
        """
        post = blog_models.Post.objects.get(pk=id)
        if not post.published:
            post.published = True
            post.published_at = timezone.now()
            post.save()
            notify_new_post(post)
        return typing.cast(blog_types.Post, post)

    @strawberry_django.mutation(handle_django_errors=True)
    def upload_post_cover_image(
        self,
        post_id: uuid.UUID,
        file: Upload,
        info: Info,
    ) -> blog_types.Post:
        """Replace a post's cover image; only the post's author may do this.

        Raises:
            PermissionDenied: the requesting user is not the post's author.
                ``handle_django_errors`` reports this to the client.
        """
        post = blog_models.Post.objects.get(pk=post_id)
        # Same rule as IsAuthor / update_post. Without this check, any caller
        # could overwrite any post's cover image.
        if post.author != info.context.request.user:
            raise PermissionDenied(IsAuthor.message)
        post.cover_image = file  # type: ignore
        post.save()
        return typing.cast(blog_types.Post, post)