Skip to content

Commit

Permalink
Websocket Queue #64
Browse files Browse the repository at this point in the history
  • Loading branch information
AliRn76 committed Jan 24, 2024
2 parents 9c4ba74 + 192f8a9 commit 8779ee5
Show file tree
Hide file tree
Showing 10 changed files with 147 additions and 102 deletions.
3 changes: 3 additions & 0 deletions docs/docs/release_notes.md
@@ -1,3 +1,6 @@
### 3.8.0
- Handle WebSocket connections when we have multiple workers with `multiprocessing.Manager`

### 3.7.0
- Add `ModelSerializer`

Expand Down
16 changes: 11 additions & 5 deletions docs/docs/websocket.md
Expand Up @@ -50,9 +50,15 @@ urls = {
from panther.websocket import send_message_to_websocket
await send_message_to_websocket(connection_id='7e82d57c9ec0478787b01916910a9f45', data='New Message From WS')
```
8. If you want to use `webscoket` in `multi-tread` or `multi-instance` backend, you should add `RedisMiddleware` in your `configs` or it won't work well.
8. If you want to use `webscoket` in a backend with `multiple workers`, we recommend you to add `RedisMiddleware` in your `configs`
[[Adding Redis Middleware]](https://pantherpy.github.io/middlewares/#redis-middleware)
9. If you want to close a connection:
9. If you don't want to add `RedisMiddleware` and you still want to use `websocket` in `multi-thread`,
you have to use `--preload` option while running the project like below:
```bash
gunicorn -w 10 -k uvicorn.workers.UvicornWorker main:app --preload
```

10. If you want to close a connection:
- In websocket class scope: You can close connection with `self.close()` method which takes 2 args, `code` and `reason`:
```python
from panther import status
Expand All @@ -65,7 +71,7 @@ urls = {
await close_websocket_connection(connection_id='7e82d57c9ec0478787b01916910a9f45', code=status.WS_1008_POLICY_VIOLATION, reason='')
```

10. `Path Variables` will be passed to `connect()`:
11. `Path Variables` will be passed to `connect()`:
```python
from panther.websocket import GenericWebsocket

Expand All @@ -77,6 +83,6 @@ urls = {
'/ws/<user_id>/<room_id>/': UserWebsocket
}
```
11. WebSocket Echo Example -> [Https://GitHub.com/PantherPy/echo_websocket](https://github.com/PantherPy/echo_websocket)
12. Enjoy.
12. WebSocket Echo Example -> [Https://GitHub.com/PantherPy/echo_websocket](https://github.com/PantherPy/echo_websocket)
13. Enjoy.

2 changes: 1 addition & 1 deletion example/model_serializer_example.py
@@ -1,6 +1,6 @@
from pydantic import Field

from panther import status, Panther
from panther import Panther, status
from panther.app import API
from panther.db import Model
from panther.request import Request
Expand Down
2 changes: 1 addition & 1 deletion panther/__init__.py
@@ -1,6 +1,6 @@
from panther.main import Panther # noqa: F401

__version__ = '3.7.0'
__version__ = '3.8.0'


def version():
Expand Down
101 changes: 71 additions & 30 deletions panther/base_websocket.py
Expand Up @@ -3,67 +3,106 @@
import asyncio
import contextlib
import logging
from typing import TYPE_CHECKING
from multiprocessing import Manager
from typing import TYPE_CHECKING, Literal

import orjson as json

from panther import status
from panther._utils import generate_ws_connection_id
from panther.base_request import BaseRequest
from panther.configs import config
from panther.db.connection import redis
from panther.utils import Singleton

if TYPE_CHECKING:
from redis import Redis


logger = logging.getLogger('panther')


class PubSub:
def __init__(self, manager):
self._manager = manager
self._subscribers = self._manager.list()

def subscribe(self):
queue = self._manager.Queue()
self._subscribers.append(queue)
return queue

def publish(self, msg):
for queue in self._subscribers:
queue.put(msg)


class WebsocketConnections(Singleton):
def __init__(self):
def __init__(self, manager: Manager = None):
self.connections = {}
self.connections_count = 0
self.manager = manager

def __call__(self, r: Redis | None):
if r:
subscriber = r.pubsub()
subscriber.subscribe('websocket_connections')
logger.info("Subscribed to 'websocket_connections' channel")
for channel_data in subscriber.listen():
# Check Type of PubSub Message
match channel_data['type']:
# Subscribed
case 'subscribe':
continue

# Message Received
case 'message':
loaded_data = json.loads(channel_data['data'].decode())
if (
isinstance(loaded_data, dict)
and (connection_id := loaded_data.get('connection_id'))
and (data := loaded_data.get('data'))
and (action := loaded_data.get('action'))
and (connection := self.connections.get(connection_id))
):
# Check Action of WS
match action:
case 'send':
logger.debug(f'Sending Message to {connection_id}')
asyncio.run(connection.send(data=data))
case 'close':
with contextlib.suppress(RuntimeError):
asyncio.run(connection.close(code=data['code'], reason=data['reason']))
# We are trying to disconnect the connection between a thread and a user
# from another thread, it's working, but we have to find another solution it
#
# Error:
# Task <Task pending coro=<Websocket.close()>> got Future
# <Task pending coro=<WebSocketCommonProtocol.transfer_data()>>
# attached to a different loop
case _:
logger.debug(f'Unknown Message Action: {action}')
case _:
logger.debug(f'Unknown Channel Type: {channel_data["type"]}')
self._handle_received_message(received_message=loaded_data)

case unknown_type:
logger.debug(f'Unknown Channel Type: {unknown_type}')
else:
self.pubsub = PubSub(manager=self.manager)
queue = self.pubsub.subscribe()
logger.info("Subscribed to 'websocket_connections' queue")
while True:
received_message = queue.get()
self._handle_received_message(received_message=received_message)

def _handle_received_message(self, received_message):
if (
isinstance(received_message, dict)
and (connection_id := received_message.get('connection_id'))
and connection_id in self.connections
and 'action' in received_message
and 'data' in received_message
):
# Check Action of WS
match received_message['action']:
case 'send':
asyncio.run(self.connections[connection_id].send(data=received_message['data']))
case 'close':
with contextlib.suppress(RuntimeError):
asyncio.run(self.connections[connection_id].close(
code=received_message['data']['code'],
reason=received_message['data']['reason']
))
# We are trying to disconnect the connection between a thread and a user
# from another thread, it's working, but we have to find another solution for it
#
# Error:
# Task <Task pending coro=<Websocket.close()>> got Future
# <Task pending coro=<WebSocketCommonProtocol.transfer_data()>>
# attached to a different loop
case unknown_action:
logger.debug(f'Unknown Message Action: {unknown_action}')

def publish(self, connection_id: str, action: Literal['send', 'close'], data: any):
publish_data = {'connection_id': connection_id, 'action': action, 'data': data}

if redis.is_connected:
redis.publish('websocket_connections', json.dumps(publish_data))
else:
self.pubsub.publish(publish_data)

async def new_connection(self, connection: Websocket) -> None:
await connection.connect(**connection.path_variables)
Expand Down Expand Up @@ -106,6 +145,7 @@ async def receive(self, data: str | bytes) -> None:
pass

async def send(self, data: any = None) -> None:
logger.debug(f'Sending WS Message to {self.connection_id}')
if data:
if isinstance(data, bytes):
await self.send_bytes(bytes_data=data)
Expand All @@ -121,6 +161,7 @@ async def send_bytes(self, bytes_data: bytes) -> None:
await self.asgi_send({'type': 'websocket.send', 'bytes': bytes_data})

async def close(self, code: int = status.WS_1000_NORMAL_CLOSURE, reason: str = '') -> None:
logger.debug(f'Closing WS Connection {self.connection_id}')
self.is_connected = False
config['websocket_connections'].remove_connection(self)
await self.asgi_send({'type': 'websocket.close', 'code': code, 'reason': reason})
Expand Down
35 changes: 23 additions & 12 deletions panther/main.py
Expand Up @@ -5,6 +5,7 @@
import types
from collections.abc import Callable
from logging.config import dictConfig
from multiprocessing import Manager
from pathlib import Path
from threading import Thread

Expand Down Expand Up @@ -52,14 +53,6 @@ def __init__(self, name: str, configs=None, urls: dict | None = None, startup: C
# Print Info
print_info(config)

# Start Websocket Listener (Redis Required)
if config['has_ws']:
Thread(
target=config['websocket_connections'],
daemon=True,
args=(self.ws_redis_connection,),
).start()

def load_configs(self) -> None:

# Check & Read The Configs File
Expand Down Expand Up @@ -98,8 +91,7 @@ def load_configs(self) -> None:
self._create_ws_connections_instance()

def _create_ws_connections_instance(self):
from panther.base_websocket import Websocket
from panther.websocket import WebsocketConnections
from panther.base_websocket import Websocket, WebsocketConnections

# Check do we have ws endpoint
for endpoint in config['flat_urls'].values():
Expand All @@ -111,7 +103,6 @@ def _create_ws_connections_instance(self):

# Create websocket connections instance
if config['has_ws']:
config['websocket_connections'] = WebsocketConnections()
# Websocket Redis Connection
for middleware in config['http_middlewares']:
if middleware.__class__.__name__ == 'RedisMiddleware':
Expand All @@ -120,6 +111,10 @@ def _create_ws_connections_instance(self):
else:
self.ws_redis_connection = None

# Don't create Manager() if we are going to use Redis for PubSub
manager = None if self.ws_redis_connection else Manager()
config['websocket_connections'] = WebsocketConnections(manager=manager)

async def __call__(self, scope: dict, receive: Callable, send: Callable) -> None:
"""
1.
Expand All @@ -138,6 +133,7 @@ async def __call__(self, scope: dict, receive: Callable, send: Callable) -> None
if scope['type'] == 'lifespan':
message = await receive()
if message["type"] == "lifespan.startup":
await self.handle_ws_listener()
await self.handle_startup()
return

Expand Down Expand Up @@ -262,6 +258,15 @@ async def handle_http(self, scope: dict, receive: Callable, send: Callable) -> N
body=response.body,
)

async def handle_ws_listener(self):
# Start Websocket Listener (Redis/ Queue)
if config['has_ws']:
Thread(
target=config['websocket_connections'],
daemon=True,
args=(self.ws_redis_connection,),
).start()

async def handle_startup(self):
if startup := config['startup'] or self._startup:
if is_function_async(startup):
Expand All @@ -272,7 +277,13 @@ async def handle_startup(self):
def handle_shutdown(self):
if shutdown := config['shutdown'] or self._shutdown:
if is_function_async(shutdown):
asyncio.run(shutdown())
try:
asyncio.run(shutdown())
except ModuleNotFoundError:
# Error: import of asyncio halted; None in sys.modules
# And as I figured it out, it only happens when we running with
# gunicorn and Uvicorn workers (-k uvicorn.workers.UvicornWorker)
pass
else:
shutdown()

Expand Down
3 changes: 2 additions & 1 deletion panther/response.py
Expand Up @@ -35,9 +35,10 @@ def body(self) -> bytes:

@property
def headers(self) -> dict:
content_length = 0 if self.body == b'null' else len(self.body)
return {
'content-type': self.content_type,
'content-length': len(self.body),
'content-length': content_length,
'access-control-allow-origin': '*',
} | (self._headers or {})

Expand Down
39 changes: 26 additions & 13 deletions panther/serializer.py
Expand Up @@ -3,31 +3,44 @@


class ModelSerializer:
def __new__(cls, *args, **kwargs):
def __new__(cls, *args, model=None, **kwargs):
# Check `metaclass`
if len(args) == 0:
msg = f"you should not inherit the 'ModelSerializer', you should use it as 'metaclass' -> {cls.__name__}"
address = f'{cls.__module__}.{cls.__name__}'
msg = f"you should not inherit the 'ModelSerializer', you should use it as 'metaclass' -> {address}"
raise TypeError(msg)

model_name = args[0]
if 'model' not in kwargs:
msg = f"'model' required while using 'ModelSerializer' metaclass -> {model_name}"
data = args[2]
address = f'{data["__module__"]}.{model_name}'

# Check `model`
if model is None:
msg = f"'model' required while using 'ModelSerializer' metaclass -> {address}"
raise AttributeError(msg)
# Check `fields`
if 'fields' not in data:
msg = f"'fields' required while using 'ModelSerializer' metaclass. -> {address}"
raise AttributeError(msg) from None

model_fields = kwargs['model'].model_fields
model_fields = model.model_fields
field_definitions = {}
if 'fields' not in args[2]:
msg = f"'fields' required while using 'ModelSerializer' metaclass. -> {model_name}"
raise AttributeError(msg) from None
for field_name in args[2]['fields']:

# Collect `fields`
for field_name in data['fields']:
if field_name not in model_fields:
msg = f"'{field_name}' is not in '{kwargs['model'].__name__}' -> {model_name}"
msg = f"'{field_name}' is not in '{model.__name__}' -> {address}"
raise AttributeError(msg) from None

field_definitions[field_name] = (model_fields[field_name].annotation, model_fields[field_name])
for required in args[2].get('required_fields', []):

# Change `required_fields
for required in data.get('required_fields', []):
if required not in field_definitions:
msg = f"'{required}' is in 'required_fields' but not in 'fields' -> {model_name}"
msg = f"'{required}' is in 'required_fields' but not in 'fields' -> {address}"
raise AttributeError(msg) from None
field_definitions[required][1].default = PydanticUndefined

# Create Model
return create_model(
__model_name=model_name,
**field_definitions
Expand Down

0 comments on commit 8779ee5

Please sign in to comment.