Skip to content

Commit

Permalink
[Fixes #74] This implements configurability for the json.loads/ json.…
Browse files Browse the repository at this point in the history
…dumps

used.

First draft.. not as granular as it could be.. but at least it shouldn't
slow down the current applications.

Tests pass, adds "json_loads" and "json_dumps" parameters to the
socketio_manage() call.
  • Loading branch information
Alexandre Bourget committed Aug 4, 2012
1 parent 79d188f commit a8b6d93
Show file tree
Hide file tree
Showing 4 changed files with 65 additions and 27 deletions.
14 changes: 13 additions & 1 deletion socketio/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
log = logging.getLogger(__name__)


def socketio_manage(environ, namespaces, request=None, error_handler=None):
def socketio_manage(environ, namespaces, request=None, error_handler=None,
json_loads=None, json_dumps=None):
"""Main SocketIO management function, call from within your Framework of
choice's view.
Expand Down Expand Up @@ -35,6 +36,11 @@ def socketio_manage(environ, namespaces, request=None, error_handler=None):
The callable you pass in should have the same signature as the default
error handler.
The ``json_loads`` and ``json_dumps`` are overrides for the default
``json.loads`` and ``json.dumps`` function calls. Override these at
the top-most level here. This will affect all sockets created by this
socketio manager, and all namespaces inside.
This function will block the current "view" or "controller" in your
framework to do the recv/send on the socket, and dispatch incoming messages
to your namespaces.
Expand All @@ -45,6 +51,7 @@ def socketio_manage(environ, namespaces, request=None, error_handler=None):
def my_view(request):
socketio_manage(request.environ, {'': GlobalNamespace}, request)
NOTE: You must understand that this function is going to be called
*only once* per socket opening, *even though* you are using a long
polling mechanism. The subsequent calls (for long polling) will
Expand All @@ -67,6 +74,11 @@ def my_view(request):
if error_handler:
socket._set_error_handler(error_handler)

if json_loads:
socket._set_json_loads(json_loads)
if json_dumps:
socket._set_json_dumps(json_dumps)

receiver_loop = socket._spawn_receiver_loop()
watcher = socket._spawn_watcher()

Expand Down
21 changes: 21 additions & 0 deletions socketio/defaultjson.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
### default json loaders
try:
import simplejson as json
json_decimal_args = {"use_decimal": True} # pragma: no cover
except ImportError:
import json
import decimal

class DecimalEncoder(json.JSONEncoder):
def default(self, o):
if isinstance(o, decimal.Decimal):
return float(o)
return super(DecimalEncoder, self).default(o)
json_decimal_args = {"cls": DecimalEncoder}

def default_json_dumps(data):
return json.dumps(data, separators=(',', ':'),
**json_decimal_args)

def default_json_loads(data):
return json.loads(data)
33 changes: 9 additions & 24 deletions socketio/packet.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,4 @@
try:
import simplejson as json
json_decimal_args = {"use_decimal": True} # pragma: no cover
except ImportError:
import json
import decimal

class DecimalEncoder(json.JSONEncoder):
def default(self, o):
if isinstance(o, decimal.Decimal):
return float(o)
return super(DecimalEncoder, self).default(o)
json_decimal_args = {"cls": DecimalEncoder}

from socketio.defaultjson import default_json_dumps, default_json_loads

MSG_TYPES = {
'disconnect': 0,
Expand Down Expand Up @@ -45,7 +32,7 @@ def default(self, o):
'ackId', 'reason', 'advice', 'qs', 'id']


def encode(data):
def encode(data, json_dumps=default_json_dumps):
"""
Encode an attribute dict into a byte string.
"""
Expand All @@ -72,14 +59,13 @@ def encode(data):
if msg == '3':
payload = data['data']
if msg == '4':
payload = json.dumps(data['data'], separators=(',', ':'),
**json_decimal_args)
payload = json_dumps(data['data'])
if msg == '5':
d = {}
d['name'] = data['name']
if 'args' in data and data['args'] != []:
d['args'] = data['args']
payload = json.dumps(d, separators=(',', ':'), **json_decimal_args)
payload = json_dumps(d)
if 'id' in data:
msg += ':' + str(data['id'])
if data['ack'] == 'data':
Expand All @@ -98,8 +84,7 @@ def encode(data):
# '6:::' [id] '+' [data]
msg += '::' + data.get('endpoint', '') + ':' + str(data['ackId'])
if 'args' in data and data['args'] != []:
msg += '+' + json.dumps(data['args'], separators=(',', ':'),
**json_decimal_args)
msg += '+' + json_dumps(data['args'])

elif msg == '7':
# '7::' [endpoint] ':' [reason] '+' [advice]
Expand All @@ -117,7 +102,7 @@ def encode(data):
return msg


def decode(rawstr):
def decode(rawstr, json_loads=default_json_loads):
"""
Decode a rawstr packet arriving from the socket into a dict.
"""
Expand Down Expand Up @@ -163,11 +148,11 @@ def decode(rawstr):
decoded_msg['data'] = data

elif msg_type == "4": # json msg
decoded_msg['data'] = json.loads(data)
decoded_msg['data'] = json_loads(data)

elif msg_type == "5": # event
try:
data = json.loads(data)
data = json_loads(data)
except ValueError, e:
print("Invalid JSON event message", data)
decoded_msg['args'] = []
Expand All @@ -182,7 +167,7 @@ def decode(rawstr):
if '+' in data:
ackId, data = data.split('+')
decoded_msg['ackId'] = int(ackId)
decoded_msg['args'] = json.loads(data)
decoded_msg['args'] = json_loads(data)
else:
decoded_msg['ackId'] = int(data)
decoded_msg['args'] = []
Expand Down
24 changes: 22 additions & 2 deletions socketio/virtsocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from gevent.event import Event

from socketio import packet
from socketio.defaultjson import default_json_loads, default_json_dumps


def default_error_handler(socket, error_name, error_message, endpoint,
Expand Down Expand Up @@ -65,6 +66,9 @@ class Socket(object):
"""Use this to be explicit when specifying a Global Namespace (an endpoint
with no name, not '/chat' or anything."""

json_loads = staticmethod(default_json_loads)
json_dumps = staticmethod(default_json_dumps)

def __init__(self, server, error_handler=None):
self.server = weakref.proxy(server)
self.sessid = str(random.random())[2:]
Expand Down Expand Up @@ -116,6 +120,22 @@ def _set_error_handler(self, error_handler):
"""
self.error_handler = error_handler

def _set_json_loads(self, json_loads):
"""Change the default JSON decoder.
This should be a callable that accepts a single string, and returns
a well-formed object.
"""
self.json_loads = json_loads

def _set_json_dumps(self, json_dumps):
"""Change the default JSON decoder.
This should be a callable that accepts a single string, and returns
a well-formed object.
"""
self.json_dumps = json_dumps

def _get_next_msgid(self):
"""This retrieves the next value for the 'id' field when sending
an 'event' or 'message' or 'json' that asks the remote client
Expand Down Expand Up @@ -296,7 +316,7 @@ def remove_namespace(self, namespace):
def send_packet(self, pkt):
"""Low-level interface to queue a packet on the wire (encoded as wire
protocol"""
self.put_client_msg(packet.encode(pkt))
self.put_client_msg(packet.encode(pkt, self.json_dumps))

def spawn(self, fn, *args, **kwargs):
"""Spawn a new Greenlet, attached to this Socket instance.
Expand All @@ -320,7 +340,7 @@ def _receiver_loop(self):
if not rawdata:
continue # or close the connection ?
try:
pkt = packet.decode(rawdata)
pkt = packet.decode(rawdata, self.json_loads)
except (ValueError, KeyError, Exception), e:
self.error('invalid_packet',
"There was a decoding error when dealing with packet "
Expand Down

0 comments on commit a8b6d93

Please sign in to comment.