Skip to content

Commit

Permalink
feat: add post notify subscription
Browse files Browse the repository at this point in the history
  • Loading branch information
JiaWeiXie committed Oct 10, 2023
1 parent 48ce1a5 commit e8074a8
Show file tree
Hide file tree
Showing 8 changed files with 577 additions and 5 deletions.
487 changes: 483 additions & 4 deletions poetry.lock

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions pyproject.toml
Expand Up @@ -11,6 +11,8 @@ django = "^4.2.5"
strawberry-graphql-django = "^0.17.4"
django-extensions = "^3.2.3"
pillow = "^10.0.1"
strawberry-graphql = {extras = ["channels"], version = "^0.209.6"}
daphne = "^4.0.0"


[tool.poetry.group.dev.dependencies]
Expand Down
17 changes: 17 additions & 0 deletions server/app/blog/graph/mutations.py
Expand Up @@ -3,6 +3,9 @@

import strawberry
import strawberry_django
from asgiref.sync import async_to_sync
from channels.layers import get_channel_layer
from django.conf import settings
from django.core.exceptions import ValidationError
from django.utils import timezone
from strawberry.file_uploads import Upload
Expand Down Expand Up @@ -59,6 +62,18 @@ def _handle_form_errors(
)


def notify_new_post(post: blog_models.Post) -> None:
channel_layer = get_channel_layer()
group_send = async_to_sync(channel_layer.group_send) # type: ignore
group_send(
settings.POSTS_CHANNEL,
{
"type": "chat.message",
"post_id": post.pk,
},
)


@strawberry.type
class Mutation:
@strawberry_django.mutation(
Expand Down Expand Up @@ -120,6 +135,8 @@ def publish_post(self, id: uuid.UUID) -> blog_types.Post: # noqa: A002
post.published = True
post.published_at = timezone.now()
post.save()

notify_new_post(post)
return typing.cast(blog_types.Post, post)

@strawberry_django.mutation(handle_django_errors=True)
Expand Down
42 changes: 42 additions & 0 deletions server/app/blog/graph/subscriptions.py
@@ -0,0 +1,42 @@
import typing

import strawberry
from channels.db import database_sync_to_async
from django.conf import settings
from strawberry.types import Info

from server.app.blog import models as blog_models
from server.app.blog.graph import types as blog_types


@database_sync_to_async
def _get_post(id: str) -> blog_models.Post | None: # noqa: A002
try:
return blog_models.Post.objects.get(id=id)
except blog_models.Post.DoesNotExist:
return None


@strawberry.type
class Subscription:
@strawberry.subscription
async def post_notify(
self,
info: Info,
) -> typing.AsyncGenerator[blog_types.PostNotification, None]:
ws = info.context["ws"]
channel_layer = ws.channel_layer
await channel_layer.group_add(settings.POSTS_CHANNEL, ws.channel_name)

async with ws.listen_to_channel(
"chat.message",
groups=[settings.POSTS_CHANNEL],
) as cm:
async for message in cm:
post = await _get_post(message["post_id"])
if post:
yield blog_types.PostNotification(
id=post.pk,
title=post.title,
publish_at=post.published_at,
)
7 changes: 7 additions & 0 deletions server/app/blog/graph/types.py
Expand Up @@ -140,3 +140,10 @@ class CreatePostResult:
strawberry.union("FormValidationError"),
]
] | None = strawberry.field(default=None)


@strawberry.type
class PostNotification(relay.Node):
id: relay.NodeID[uuid.UUID] # noqa: A003
title: str
publish_at: datetime.datetime
11 changes: 10 additions & 1 deletion server/asgi.py
Expand Up @@ -10,7 +10,16 @@
import os

from django.core.asgi import get_asgi_application
from strawberry.channels import GraphQLProtocolTypeRouter

os.environ.setdefault("DJANGO_SETTINGS_MODULE", "server.settings")

application = get_asgi_application()
django_asgi_app = get_asgi_application()

from server.schema import ws_schema # noqa: E402

application = GraphQLProtocolTypeRouter(
ws_schema,
django_application=django_asgi_app, # type: ignore
url_pattern="^wsgraphql",
)
6 changes: 6 additions & 0 deletions server/schema.py
Expand Up @@ -5,6 +5,7 @@
from server.app.authentication.graph import queries as auth_queries
from server.app.blog.graph import mutations as blog_mutations
from server.app.blog.graph import queries as blog_queries
from server.app.blog.graph import subscriptions as blog_subscriptions

__all__ = ("schema",)

Expand All @@ -23,6 +24,11 @@
auth_mutations.Mutation,
),
)
subscription = strawberry.tools.merge_types(
"Subscription",
(blog_subscriptions.Subscription,),
)


schema = strawberry.Schema(query=query, mutation=mutation)
ws_schema = strawberry.Schema(query=query, subscription=subscription)
10 changes: 10 additions & 0 deletions server/settings.py
Expand Up @@ -31,6 +31,7 @@
# Application definition

INSTALLED_APPS = [
"daphne",
"django.contrib.admin",
"django.contrib.auth",
"django.contrib.contenttypes",
Expand Down Expand Up @@ -72,6 +73,7 @@
]

WSGI_APPLICATION = "server.wsgi.application"
ASGI_APPLICATION = "server.asgi.application"


# Database
Expand Down Expand Up @@ -129,3 +131,11 @@
# https://docs.djangoproject.com/en/4.2/ref/settings/#default-auto-field

DEFAULT_AUTO_FIELD = "django.db.models.BigAutoField"

CHANNEL_LAYERS = {
"default": {
"BACKEND": "channels.layers.InMemoryChannelLayer",
},
}

POSTS_CHANNEL = "posts_notifications"

0 comments on commit e8074a8

Please sign in to comment.