feat: Rewrite GraphQL subscriptions to Django Channels and use Load Balancer to handle websocket connections (#488)

* Use Django Channels for GraphQL subscriptions

* Configure Redis as the Django Channels layer backend

* Add Redis configuration for channel layers and implement authentication middleware

* Show an error when the permission_checker decorator is used on a Subscription class

* Change the gunicorn worker class to uvicorn to support the ASGI app.

* Delete wsgi.py

* Update backend and frontend tests

* Add missing Redis cluster to AWS infra

* Make sure we use wss protocol in AWS

* Fix CF distribution to handle websockets

* Move scripts from package.json to project.json to avoid issues with passing env variables

* Fix docs deployment

* Fix failing tests and build steps

* Reconnect websockets after the user logs into the app.

* Fix sonar-project properties
pziemkowski committed Mar 4, 2024
1 parent 6067b41 commit e13baf4
Showing 76 changed files with 1,316 additions and 2,089 deletions.
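
The changes listed above boil down to serving GraphQL subscriptions over Django Channels, with Redis acting as the channel layer and the new REDIS_CONNECTION variable pointing at it. The actual settings module is not part of this diff excerpt, but a minimal sketch of that wiring, assuming channels_redis as the layer backend and config.asgi.application as the ASGI entry point, could look like:

# Settings sketch -- illustrative only, not a file from this commit.
import os

# Route ASGI traffic (HTTP and WebSocket) through the Channels application.
ASGI_APPLICATION = "config.asgi.application"  # assumed dotted path

CHANNEL_LAYERS = {
    "default": {
        # channels_redis is the standard Redis-backed channel layer for Channels.
        "BACKEND": "channels_redis.core.RedisChannelLayer",
        "CONFIG": {
            # REDIS_CONNECTION is the URL introduced in .env.shared and the Dockerfile below.
            "hosts": [os.environ.get("REDIS_CONNECTION", "redis://redis:6379")],
        },
    },
}
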
4 changes: 0 additions & 4 deletions .versionrc.js
@@ -74,10 +74,6 @@ module.exports = {
filename: './packages/internal/docs/package.json',
type: 'json',
},
{
filename: './packages/internal/local-ws-server/package.json',
type: 'json',
},
{
filename: './packages/internal/status-dashboard/package.json',
type: 'json',
18 changes: 7 additions & 11 deletions docker-compose.local.yml
@@ -7,6 +7,9 @@ volumes:

web_backend_staticfiles: {}

redis_cache:
driver: local

services:
db:
volumes:
@@ -28,7 +31,6 @@ services:
- localstack
- mailcatcher
- workers
- localwsserver

workers:
volumes:
@@ -49,6 +51,10 @@ services:
ports:
- "3005:3005"

redis:
volumes:
- redis_cache:/data

localstack:
image: localstack/localstack:2.3.0
ports:
@@ -78,13 +84,3 @@ services:
- "1080:1080"
- "1025:1025"
restart: always

localwsserver:
build: ./packages/internal/local-ws-server
ports:
- "8080:8080"
environment:
- BE_ENDPOINT_URL=http://backend:5001/
volumes:
- ./packages/internal/local-ws-server/:/app/
- /app/node_modules/
14 changes: 14 additions & 0 deletions docker-compose.yml
@@ -29,9 +29,12 @@ services:
condition: service_started
db:
condition: service_healthy
redis:
condition: service_healthy
restart: unless-stopped
links:
- db
- redis
environment:
- AWS_ACCESS_KEY_ID=foo
- AWS_SECRET_ACCESS_KEY=bar
@@ -49,6 +52,17 @@
links:
- db

redis:
image: redis:7.2.4-alpine
restart: always
ports:
- '6379:6379'
healthcheck:
test: [ "CMD", "redis-cli", "--raw", "incr", "ping" ]
interval: 10s
timeout: 5s
retries: 5

stripemock:
image: stripe/stripe-mock:v0.170.0
ports:
1 change: 1 addition & 0 deletions packages/backend/.env.shared
@@ -15,6 +15,7 @@ TASKS_LOCAL_URL=http://workers:3005
TASKS_BASE_HANDLER=common.tasks.TaskLocalInvoke

DB_CONNECTION={"dbname":"backend","username":"backend","password":"backend","host":"db","port":5432}
REDIS_CONNECTION=redis://redis:6379

AWS_ENDPOINT_URL=http://localstack:4566
WORKERS_EVENT_BUS_NAME=local-workers
1 change: 1 addition & 0 deletions packages/backend/.test.env
@@ -11,6 +11,7 @@ DJANGO_DEFAULT_FILE_STORAGE=common.tests.storages.MockS3Boto3Storage
HASHID_FIELD_SALT=9q#3t$5gs9ob682b@(6^fdv2kg*0ztr(3doa((w&kyq!d8rbt^y

DB_CONNECTION='{"dbname":"backend","username":"backend","password":"backend","host":"db","port":5432}'
REDIS_CONNECTION=redis://redis:6379

WORKERS_EVENT_BUS_NAME=local-workers

1 change: 1 addition & 0 deletions packages/backend/Dockerfile
@@ -41,6 +41,7 @@ ENV HASHID_FIELD_SALT='' \
DJANGO_PARENT_HOST='' \
DJANGO_SECRET_KEY='build' \
DB_CONNECTION='{"dbname":"build","username":"build","password":"build","host":"db","port":5432}' \
REDIS_CONNECTION=redis://redis:6379 \
WORKERS_EVENT_BUS_NAME='' \
PYTHONPATH=/pkgs/__pypackages__/3.11/lib

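
With REDIS_CONNECTION baked into the image defaults above, a quick way to confirm the channel-layer backend is reachable from inside the backend container is a one-off ping with the redis Python client. This is a hedged sketch: recent channels_redis releases pull redis-py in as a dependency, but check that it is importable in this image before relying on it.

import os

import redis  # redis-py client

client = redis.Redis.from_url(os.environ["REDIS_CONNECTION"])
# PING returns True when the redis service from docker-compose.yml is up and reachable.
print(client.ping())
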
69 changes: 0 additions & 69 deletions packages/backend/apps/demo/tests/test_schema.py
@@ -167,38 +167,6 @@ def test_create_new_item_sends_notification(self, graphene_client, user_factory,
}
assert notification.issuer == user

def test_create_new_item_sends_notification_through_websockets(
self, mocker, graphene_client, user_factory, graph_ql_subscription_factory, input_data
):
post_to_connection = mocker.patch("apps.websockets.apigateway.post_to_connection")
user = user_factory()
admin = user_factory(admin=True)
graph_ql_subscription_factory(
connection__connection_id="conn-id",
connection__user=admin,
operation_name="notificationsListSubscription",
relay_id="1",
query=self.NOTIFICATIONS_SUBSCRIPTION,
)

graphene_client.force_authenticate(user)
graphene_client.mutate(self.CREATE_MUTATION, variable_values={"input": input_data})

assert Notification.objects.count() == 1
notification = Notification.objects.first()
notification_global_id = to_global_id("NotificationType", str(notification.id))
post_to_connection.assert_called_once_with(
{
"id": "1",
"type": "next",
"payload": {
"data": {"notificationCreated": {"edges": [{"node": {"id": notification_global_id}}]}},
"errors": None,
},
},
"conn-id",
)


class TestUpdateCrudDemoItemMutation:
UPDATE_MUTATION = """
@@ -296,43 +264,6 @@ def test_update_existing_item_sends_notification_to_admins_skipping_creator_if_h
assert Notification.objects.filter(user=admins[0], type=constants.Notification.CRUD_ITEM_UPDATED.value).exists()
assert Notification.objects.filter(user=admins[1], type=constants.Notification.CRUD_ITEM_UPDATED.value).exists()

def test_update_existing_item_sends_notification_through_websocket_to_admin_with_open_subscription(
self, mocker, graphene_client, crud_demo_item, user_factory, graph_ql_subscription_factory, input_data_factory
):
post_to_connection = mocker.patch("apps.websockets.apigateway.post_to_connection")
user = user_factory()
crud_demo_item.user = user
crud_demo_item.save()
admins = user_factory.create_batch(2, admin=True)
input_data = input_data_factory(crud_demo_item)
graph_ql_subscription_factory(
connection__connection_id="conn-id",
connection__user=admins[0],
operation_name="notificationsListSubscription",
relay_id="1",
query=self.NOTIFICATIONS_SUBSCRIPTION,
)

graphene_client.force_authenticate(user)
graphene_client.mutate(
self.UPDATE_MUTATION,
variable_values={"input": input_data},
)

notification = Notification.objects.get(user=admins[0], type=constants.Notification.CRUD_ITEM_UPDATED.value)
notification_global_id = to_global_id("NotificationType", str(notification.id))
post_to_connection.assert_called_once_with(
{
"id": "1",
"type": "next",
"payload": {
"data": {"notificationCreated": {"edges": [{"node": {"id": notification_global_id}}]}},
"errors": None,
},
},
"conn-id",
)


class TestDeleteCrudDemoItemMutation:
DELETE_MUTATION = """
44 changes: 33 additions & 11 deletions packages/backend/apps/notifications/schema.py
@@ -1,15 +1,17 @@
import channels_graphql_ws
import graphene
from apps.users.models import User
from apps.users.services.users import get_user_avatar_url
from channels.db import database_sync_to_async
from common.acl.policies import IsAuthenticatedFullAccess
from common.graphql import mutations
from common.graphql.acl import permission_classes
from graphene import relay
from graphene.types.generic import GenericScalar
from graphene_django import DjangoObjectType

from common.graphql import mutations
from . import models
from . import serializers
from . import services
from apps.users.models import User

from apps.users.services.users import get_user_avatar_url, get_user_from_resolver


class HasUnreadNotificationsMixin:
@@ -31,11 +33,11 @@ class Meta:

@staticmethod
def resolve_first_name(parent, info):
return get_user_from_resolver(info).profile.first_name
return parent.profile.first_name

@staticmethod
def resolve_last_name(parent, info):
return get_user_from_resolver(info).profile.last_name
return parent.profile.last_name

@staticmethod
def resolve_avatar(parent, info):
@@ -102,9 +104,29 @@ def resolve_has_unread_notifications(root, info):
return models.Notification.objects.filter(user=info.context.user, read_at=None).exists()


class Subscription(graphene.ObjectType):
notification_created = graphene.relay.ConnectionField(NotificationConnection)
class NotificationCreatedSubscription(channels_graphql_ws.Subscription):
"""Simple GraphQL subscription."""

# Leave only latest 64 messages in the server queue.
notification_queue_limit = 64

notification = graphene.Field(NotificationType)

@staticmethod
def subscribe(root, info):
return [str(info.context.channels_scope['user'].id)]

@staticmethod
@database_sync_to_async
def get_response(id: str):
notification = models.Notification.objects.prefetch_related('issuer', 'issuer__profile').get(id=id)
return NotificationCreatedSubscription(notification)

@staticmethod
def resolve_notification_created(root, info):
return root
async def publish(payload, info):
return await NotificationCreatedSubscription.get_response(id=payload['id'])


@permission_classes(IsAuthenticatedFullAccess)
class Subscription(graphene.ObjectType):
notification_created = NotificationCreatedSubscription.Field()
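
For reference, a client document matching the new subscription shape could look like the sketch below, stored as a Python string constant in the same style the test suite uses for its queries. The camelCase names assume graphene's default field-name conversion, and only id is selected because NotificationType's full field list is not shown in this excerpt.

# Hypothetical client-side subscription document; names assume default graphene casing.
NOTIFICATION_CREATED_SUBSCRIPTION = """
subscription notificationCreatedSubscription {
  notificationCreated {
    notification {
      id
    }
  }
}
"""
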
7 changes: 2 additions & 5 deletions packages/backend/apps/notifications/signals.py
@@ -1,13 +1,10 @@
from django.db.models.signals import post_save
from django.dispatch import receiver

from apps.websockets import utils
from . import models, constants
from . import models, schema


@receiver(post_save, sender=models.Notification)
def notify_about_entry(sender, instance: models.Notification, created, update_fields, **kwargs):
if created:
utils.send_subscriptions_messages(
instance.user, constants.Subscription.NOTIFICATIONS_LIST_SUBSCRIPTION.value, root_value=[instance]
)
schema.NotificationCreatedSubscription.broadcast(payload={'id': str(instance.id)}, group=str(instance.user.id))
47 changes: 47 additions & 0 deletions packages/backend/apps/users/authentication.py
@@ -1,4 +1,6 @@
from channels.db import database_sync_to_async
from django.conf import settings
from django.http import parse_cookie
from rest_framework import HTTP_HEADER_ENCODING
from rest_framework_simplejwt import authentication

@@ -19,3 +21,48 @@ def get_header(self, request):

def get_raw_token(self, header):
return header


class JSONWebTokenChannelsAuthentication(authentication.JWTAuthentication):
def get_header(self, scope):
for name, value in scope.get("headers", []):
if name == b"cookie":
cookies = parse_cookie(value.decode("latin1"))
break
else:
# No cookie header found - add an empty default.
cookies = {}

"""
Extracts the header containing the JSON web token from the given
request.
"""
header = cookies.get(settings.ACCESS_TOKEN_COOKIE)

if isinstance(header, str):
# Work around django test client oddness
header = header.encode(HTTP_HEADER_ENCODING)

return header

def get_raw_token(self, header):
return header


class JSONWebTokenCookieMiddleware:
def __init__(self, app):
self.app = app

@database_sync_to_async
def authenticate(self, scope):
auth_backend = JSONWebTokenChannelsAuthentication()
return auth_backend.authenticate(scope)

async def __call__(self, scope, receive, send):
scope = dict(scope)
result = await self.authenticate(scope)
if result is not None:
user, _ = result
scope["user"] = user

return await self.app(scope, receive, send)
13 changes: 0 additions & 13 deletions packages/backend/apps/websockets/admin.py

This file was deleted.

21 changes: 0 additions & 21 deletions packages/backend/apps/websockets/apigateway.py

This file was deleted.

9 changes: 9 additions & 0 deletions packages/backend/apps/websockets/consumers.py
@@ -0,0 +1,9 @@
import channels_graphql_ws

from config.schema import schema as graphql_schema


class DefaultGraphqlWsConsumer(channels_graphql_ws.GraphqlWsConsumer):
"""Channels WebSocket consumer which provides GraphQL API."""

schema = graphql_schema
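
The consumer above still has to be mounted in the ASGI application that replaced wsgi.py. That file is not shown in this excerpt, so the following is only a plausible sketch: the graphql/ path, the config.settings module name, and the import ordering are assumptions, not the commit's actual code.

# config/asgi.py -- illustrative sketch, not the file from this commit.
import os

from django.core.asgi import get_asgi_application

os.environ.setdefault("DJANGO_SETTINGS_MODULE", "config.settings")  # assumed module

# Initialize Django before importing anything that touches models or the GraphQL schema.
django_asgi_app = get_asgi_application()

from channels.routing import ProtocolTypeRouter, URLRouter  # noqa: E402
from django.urls import path  # noqa: E402

from apps.users.authentication import JSONWebTokenCookieMiddleware  # noqa: E402
from apps.websockets.consumers import DefaultGraphqlWsConsumer  # noqa: E402

application = ProtocolTypeRouter(
    {
        # Plain HTTP keeps going through Django's regular ASGI handler.
        "http": django_asgi_app,
        # WebSocket connections are authenticated from the access-token cookie,
        # then handed to the channels_graphql_ws consumer.
        "websocket": JSONWebTokenCookieMiddleware(
            URLRouter([path("graphql/", DefaultGraphqlWsConsumer.as_asgi())])
        ),
    }
)

Running this under the uvicorn worker class mentioned in the commit message is what lets gunicorn speak ASGI and keep the websocket connections open.
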
18 changes: 0 additions & 18 deletions packages/backend/apps/websockets/graphql.py

This file was deleted.
