/
websocket.py
248 lines (176 loc) · 5.35 KB
/
websocket.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
from socket import *
from struct import *
import re
import base64
import hashlib
import copy
#
# TODO tls support
#
class WebSocketError(Exception):
    """Raised by WebSocket for misuse of a connection.

    For example, calling a method while the object is in the wrong state.
    """
class WebSocket():
    """BSD-style interface to websockets (RFC 6455).

    Only the server role is implemented:

    - Call listen(), then accept() to obtain a new WebSocket per client
      connection, and exchange messages on it with send()/recv().
    - The listening object moves INITIAL -> LISTENING.
    - Accepted connections start OPEN and become CLOSED after close(), after
      a CLOSE frame is received, or when the peer drops the connection.

    Socket errors are caught as (IOError, OSError). The module-level name
    `error` is not used because `from struct import *` rebinds it to
    struct.error. socket.error is an IOError subclass on Python 2 and an
    alias of OSError on Python 3.
    """
    # constant appended to the client's key in the opening handshake
    WS_GUID = '258EAFA5-E914-47DA-95CA-C5AB0DC85B11'
    # bufsize argument for socket.recv
    RECV_SIZE = 4096
    # largest HTTP request head accept() buffers before rejecting the client
    MAX_HANDSHAKE_SIZE = 65536
    # largest payload a control frame (CLOSE/PING/PONG) may carry
    MAX_CONTROL_PAYLOAD = 125
    # states
    INITIAL, LISTENING, OPEN, CLOSED = range(4)
    # opcodes
    CONT = 0x0
    TEXT = 0x1
    BIN = 0x2
    # 0x3-7 reserved, non-control
    CLOSE = 0x8
    PING = 0x9
    PONG = 0xa
    # 0xb-f reserved, control

    def __init__(self):
        self._socket = socket()
        self._state = WebSocket.INITIAL
        self._server = False
        # bytes read from the socket but not yet consumed as frame data
        self._buf = bytearray()

    def _assert_state(self, state, message=''):
        """Raise WebSocketError(message) unless the object is in `state`."""
        if self._state != state:
            raise WebSocketError(message)

    def listen(self, port=80, host='', backlog=5):
        """Bind to (host, port) and start listening.

        host '' means all interfaces; backlog is passed to socket.listen().
        """
        self._assert_state(WebSocket.INITIAL, 'Cannot listen while in state: ' + str(self._state))
        self._socket.bind((host, port))
        self._socket.listen(backlog)
        self._server = True
        self._state = WebSocket.LISTENING

    def accept(self):
        """Loop until a connection is established.

        Returns a new OPEN WebSocket for that client. Clients whose handshake
        is invalid get 400 Bad Request and are dropped. So are clients whose
        connection fails part-way through the handshake.
        """
        self._assert_state(WebSocket.LISTENING, 'Cannot accept connections without calling listen()')
        while True:
            conn, _addr = self._socket.accept()
            try:
                client = self._handshake(conn)
            except (IOError, OSError):
                # the peer went away mid-handshake; keep serving others
                client = None
            if client is not None:
                return client
            conn.close()

    def close(self):
        """Close the connection, sending an empty CLOSE frame first if it is OPEN.

        Safe to call in any state; afterwards the object is CLOSED.
        """
        self._close(b'')

    # TODO fragmenting?
    def send(self, frame, type_=None):
        """Send `frame` to the remote end as one unfragmented message.

        type_ must be WebSocket.TEXT (the default) or WebSocket.BIN; anything
        else raises TypeError. For TEXT, a text (unicode) frame is encoded as
        UTF-8. A byte string is sent unchanged and is assumed to already be
        UTF-8.
        """
        self._assert_state(WebSocket.OPEN, 'Cannot send without establishing connection.')
        if type_ is None:
            type_ = WebSocket.TEXT
        if type_ not in (WebSocket.TEXT, WebSocket.BIN):
            raise TypeError('Frame type must be text or binary.')
        if type_ == WebSocket.TEXT and not isinstance(frame, (bytes, bytearray)):
            frame = frame.encode('utf-8')
        self._send_frame(type_, frame)

    def pong(self, data):
        """Send a PONG control frame carrying the byte string `data`.

        Raises ValueError if data exceeds the 125-byte control-frame limit.
        """
        self._assert_state(WebSocket.OPEN, 'Cannot send without establishing connection.')
        if len(data) > WebSocket.MAX_CONTROL_PAYLOAD:
            raise ValueError('Control frame payload must be at most 125 bytes.')
        self._send_frame(WebSocket.PONG, data)

    def recv(self):
        """Receive one complete message from the remote end.

        Returns (data, type): data is a byte string; type is WebSocket.TEXT or
        WebSocket.BIN. Fragmented messages are reassembled.

        Control frames are handled in between fragments:
        - PING is answered with a PONG carrying the same payload.
        - PONG is ignored.
        - CLOSE echoes the status code back, closes the connection, and
          returns (payload, WebSocket.CLOSE).

        Raises WebSocketError if the peer drops the connection mid-message
        or sends a malformed control frame.
        """
        self._assert_state(WebSocket.OPEN, 'Cannot receive without establishing connection.')
        parts = []
        type_ = None
        while True:
            fin, op, payload = self._recv_frame()
            if op == WebSocket.PING:
                self.pong(payload)
                continue
            if op == WebSocket.PONG:
                continue
            if op == WebSocket.CLOSE:
                # echo the 2-byte status code (if any), then shut down
                self._close(payload[:2])
                return payload, WebSocket.CLOSE
            if op in (WebSocket.BIN, WebSocket.TEXT):
                type_ = op
            parts.append(payload)
            if fin:
                return b''.join(parts), type_

    def _handshake(self, conn):
        """Run the server side of the opening handshake on `conn`.

        Returns a new OPEN WebSocket wrapping conn. Returns None after
        replying 400 Bad Request to an invalid request. Socket errors
        propagate.
        """
        received = self._recv_handshake(conn)
        parsed = self._parse_handshake(received[0]) if received is not None else None
        if parsed is None:
            self._bad_handshake(conn)
            return None
        _resource, hdr = parsed
        self._establish(conn, hdr)
        # A shallow copy keeps the class (and any subclass attributes) without
        # copying the listening socket; per-connection fields are replaced.
        client = copy.copy(self)
        client._socket = conn
        client._state = WebSocket.OPEN
        # bytes the client sent after the request head belong to the frame stream
        client._buf = bytearray(received[1])
        return client

    def _recv_handshake(self, conn):
        """Read an HTTP request head from `conn`.

        Returns (head, rest):
        - head: the bytes before the blank line that ends the head.
        - rest: any bytes received after that blank line.

        Returns None if the peer closes first or the head grows past
        MAX_HANDSHAKE_SIZE.
        """
        data = b''
        while True:
            end = re.search(br'\r?\n\r?\n', data)
            if end:
                return data[:end.start()], data[end.end():]
            if len(data) > WebSocket.MAX_HANDSHAKE_SIZE:
                return None
            chunk = conn.recv(WebSocket.RECV_SIZE)
            if not chunk:
                return None
            data += chunk

    def _send_frame(self, opcode, payload):
        """Write `payload` as a single final (FIN) frame with `opcode`."""
        # TODO implement client send with masking
        fin_op = 0x80 | opcode
        size = len(payload)
        if size <= 125:
            hdr = pack('>BB', fin_op, size)
        elif size <= 0xffff:
            # length byte 126: a 16-bit extended length follows
            hdr = pack('>BBH', fin_op, 126, size)
        else:
            # length byte 127: a 64-bit extended length follows
            hdr = pack('>BBQ', fin_op, 127, size)
        self._socket.sendall(hdr + bytes(payload))

    def _recv_frame(self):
        """Read one frame; returns (fin, opcode, unmasked payload bytes)."""
        first, second = bytearray(self._recv_exact(2))
        fin = bool(first & 0x80)
        op = first & 0x0f
        masked = bool(second & 0x80)
        size = second & 0x7f
        # TODO: RFC 6455 requires failing the connection when a server gets
        # an unmasked frame or a client gets a masked one.
        if size == 126:
            size = unpack('>H', self._recv_exact(2))[0]
        elif size == 127:
            size = unpack('>Q', self._recv_exact(8))[0]
        # control opcodes have the high bit set; they may not be fragmented
        if op & 0x8 and (size > WebSocket.MAX_CONTROL_PAYLOAD or not fin):
            raise WebSocketError('Malformed control frame.')
        key = bytearray(self._recv_exact(4)) if masked else None
        payload = self._recv_exact(size)
        if masked:
            payload = bytes(self._mask(key, payload))
        return fin, op, payload

    def _recv_exact(self, n):
        """Return exactly n bytes, reading from the socket as needed.

        Surplus bytes stay in self._buf for the next frame. If the peer closes
        first, the connection is marked CLOSED and WebSocketError is raised.
        """
        while len(self._buf) < n:
            chunk = self._socket.recv(WebSocket.RECV_SIZE)
            if not chunk:
                self._state = WebSocket.CLOSED
                self._socket.close()
                raise WebSocketError('Connection closed by remote end.')
            self._buf.extend(chunk)
        out = bytes(self._buf[:n])
        del self._buf[:n]
        return out

    def _mask(self, key, data):
        """XOR `data` with the 4-int masking `key` and return a bytearray.

        Masking is its own inverse, so this both masks and unmasks.
        """
        return bytearray(b ^ key[i % 4] for i, b in enumerate(bytearray(data)))

    def _close(self, payload):
        """Send a CLOSE frame with `payload` if still OPEN, then release the socket."""
        if self._state == WebSocket.OPEN:
            try:
                self._send_frame(WebSocket.CLOSE, payload)
            except (IOError, OSError):
                pass  # best effort: the peer may already be gone
        self._state = WebSocket.CLOSED
        self._socket.close()

    def _parse_handshake(self, handshake):
        """Validate a client opening handshake.

        `handshake` is the raw HTTP request as bytes or text. Returns
        (resource, headers) if it is a valid websocket upgrade request, where
        headers maps lower-cased field names to stripped values. Returns None
        otherwise.
        """
        if isinstance(handshake, bytes):
            handshake = handshake.decode('latin-1')
        lines = handshake.split('\n')
        m = re.match(r'GET\s+(\S+)\s+HTTP/1\.[1-9]', lines[0])
        if not m:
            return None
        # header field names are case-insensitive; values may have optional whitespace
        hdr = {}
        for line in lines[1:]:
            name, sep, value = line.partition(':')
            if sep:
                hdr[name.strip().lower()] = value.strip()
        if 'host' not in hdr:
            return None
        if hdr.get('upgrade', '').lower() != 'websocket':
            return None
        # Connection is a token list, e.g. "keep-alive, Upgrade"
        tokens = [tok.strip().lower() for tok in hdr.get('connection', '').split(',')]
        if 'upgrade' not in tokens:
            return None
        # the key must be the base64 encoding of 16 bytes: 22 chars + "=="
        if not re.match(r'[A-Za-z0-9+/]{22}==\Z', hdr.get('sec-websocket-key', '')):
            return None
        if hdr.get('sec-websocket-version') != '13':
            return None
        return m.group(1), hdr

    def _establish(self, socket, hdr):
        """Send the 101 Switching Protocols response for a validated handshake.

        `hdr` is the header dict returned by _parse_handshake.
        """
        key = hdr['sec-websocket-key']
        digest = hashlib.sha1((key + WebSocket.WS_GUID).encode('ascii')).digest()
        accept = base64.b64encode(digest).decode('ascii')
        response = '\r\n'.join([
            'HTTP/1.1 101 Switching Protocols',
            'Upgrade: websocket',
            'Connection: Upgrade',
            'Sec-WebSocket-Accept: ' + accept,
        ]) + '\r\n\r\n'
        socket.sendall(response.encode('ascii'))

    def _bad_handshake(self, socket):
        """Send 400 Bad Request in response to a bogus handshake (best effort)."""
        try:
            socket.sendall(b'HTTP/1.1 400 Bad Request\r\n\r\n')
        except (IOError, OSError):
            pass  # the connection is being dropped anyway