Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 9 additions & 4 deletions ws4py/client/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

class WebSocketBaseClient(WebSocket):
def __init__(self, url, protocols=None, extensions=None,
heartbeat_freq=None, ssl_options=None, headers=None):
heartbeat_freq=None, ssl_options=None, headers=None, exclude_headers=None):
"""
A websocket client that implements :rfc:`6455` and provides a simple
interface to communicate with a websocket server.
Expand Down Expand Up @@ -78,6 +78,8 @@ def __init__(self, url, protocols=None, extensions=None,
self.resource = None
self.ssl_options = ssl_options or {}
self.extra_headers = headers or []
self.exclude_headers = exclude_headers or []
self.exclude_headers = [x.lower() for x in self.exclude_headers]

if self.scheme == "wss":
# Prevent check_hostname requires server_hostname (ref #187)
Expand Down Expand Up @@ -211,7 +213,7 @@ def connect(self):
# default port is now 443; upgrade self.sender to send ssl
self.sock = ssl.wrap_socket(self.sock, **self.ssl_options)
self._is_secure = True

self.sock.connect(self.bind_addr)

self._write(self.handshake_request)
Expand Down Expand Up @@ -257,14 +259,15 @@ def handshake_headers(self):
('Sec-WebSocket-Key', self.key.decode('utf-8')),
('Sec-WebSocket-Version', str(max(WS_VERSION)))
]

if self.protocols:
headers.append(('Sec-WebSocket-Protocol', ','.join(self.protocols)))

if self.extra_headers:
headers.extend(self.extra_headers)

if not any(x for x in headers if x[0].lower() == 'origin'):
if not any(x for x in headers if x[0].lower() == 'origin') and \
'origin' not in self.exclude_headers:

scheme, url = self.url.split(":", 1)
parsed = urlsplit(url, scheme="http")
Expand All @@ -277,6 +280,8 @@ def handshake_headers(self):
origin = origin + ':' + str(parsed.port)
headers.append(('Origin', origin))

headers = [x for x in headers if x[0].lower() not in self.exclude_headers]

return headers

@property
Expand Down
4 changes: 2 additions & 2 deletions ws4py/client/geventclient.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
__all__ = ['WebSocketClient']

class WebSocketClient(WebSocketBaseClient):
def __init__(self, url, protocols=None, extensions=None, heartbeat_freq=None, ssl_options=None, headers=None):
def __init__(self, url, protocols=None, extensions=None, heartbeat_freq=None, ssl_options=None, headers=None, exclude_headers=None):
"""
WebSocket client that executes the
:meth:`run() <ws4py.websocket.WebSocket.run>` into a gevent greenlet.
Expand Down Expand Up @@ -41,7 +41,7 @@ def outgoing():
gevent.joinall(greenlets)
"""
WebSocketBaseClient.__init__(self, url, protocols, extensions, heartbeat_freq,
ssl_options=ssl_options, headers=headers)
ssl_options=ssl_options, headers=headers, exclude_headers=exclude_headers)
self._th = Greenlet(self.run)

self.messages = Queue()
Expand Down
4 changes: 2 additions & 2 deletions ws4py/client/threadedclient.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

class WebSocketClient(WebSocketBaseClient):
def __init__(self, url, protocols=None, extensions=None, heartbeat_freq=None,
ssl_options=None, headers=None):
ssl_options=None, headers=None, exclude_headers=None):
"""
.. code-block:: python

Expand All @@ -32,7 +32,7 @@ def received_message(self, m):

"""
WebSocketBaseClient.__init__(self, url, protocols, extensions, heartbeat_freq,
ssl_options, headers=headers)
ssl_options, headers=headers, exclude_headers=exclude_headers)
self._th = threading.Thread(target=self.run, name='WebSocketClient')
self._th.daemon = True

Expand Down
4 changes: 2 additions & 2 deletions ws4py/client/tornadoclient.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

class TornadoWebSocketClient(WebSocketBaseClient):
def __init__(self, url, protocols=None, extensions=None,
io_loop=None, ssl_options=None, headers=None):
io_loop=None, ssl_options=None, headers=None, exclude_headers=None):
"""
.. code-block:: python

Expand All @@ -32,7 +32,7 @@ def closed(self, code, reason=None):
ioloop.IOLoop.instance().start()
"""
WebSocketBaseClient.__init__(self, url, protocols, extensions,
ssl_options=ssl_options, headers=headers)
ssl_options=ssl_options, headers=headers, exclude_headers=exclude_headers)
if self.scheme == "wss":
self.sock = ssl.wrap_socket(self.sock, do_handshake_on_connect=False, **self.ssl_options)
self._is_secure = True
Expand Down