Skip to content

Commit

Permalink
fix(security): publish messages only to authenticated clients
Browse files Browse the repository at this point in the history
  • Loading branch information
Eino Gourdin committed Oct 13, 2022
1 parent d320547 commit c630baa
Showing 1 changed file with 21 additions and 10 deletions.
31 changes: 21 additions & 10 deletions python/src/wslink/backends/aiohttp/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,13 +218,13 @@ def create_webserver(server_config):

class WslinkHandler(object):
def __init__(self, protocol=None, web_app=None):
self._valid_token = False
self.serverProtocol = protocol
self.web_app = web_app
self.functionMap = {}
self.attachmentsReceived = {}
self.attachmentsRecvQueue = []
self.connections = {}
self.authentified_client_ids = set()
self.attachment_atomic = asyncio.Lock()

# Build the rpc method dictionary, assuming we were given a serverprotocol
Expand Down Expand Up @@ -291,6 +291,7 @@ async def handleWsRequest(self, request):
await self.onClose(client_id)

del self.connections[client_id]
self.authentified_client_ids.discard(client_id)

logging.info("client {0} disconnected".format(client_id))

Expand Down Expand Up @@ -354,7 +355,7 @@ async def handleSystemMessage(self, rpcid, methodName, args, client_id):
and ("secret" in args[0])
and await self.validateToken(args[0]["secret"], client_id)
):
self._valid_token = True
self.authentified_client_ids.add(client_id)
await self.sendWrappedMessage(
rpcid,
{"clientID": "c{0}".format(client_id)},
Expand Down Expand Up @@ -382,7 +383,7 @@ async def onMessage(self, msg, client_id):
payload = msg.data

if isBinary:
if self._valid_token:
if self.isClientAuthenticated(client_id):
# assume all binary messages are attachments
try:
key = self.attachmentsRecvQueue.pop(0)
Expand Down Expand Up @@ -426,7 +427,7 @@ async def onMessage(self, msg, client_id):
return

# Prevent any further processing if token is not valid
if not self._valid_token:
if not self.isClientAuthenticated(client_id):
await self.sendWrappedError(
rpcid,
pub.AUTHENTICATION_ERROR,
Expand Down Expand Up @@ -539,6 +540,19 @@ async def validateToken(self, token, client_id):
return True
return token == self.serverProtocol.secret

def isClientAuthenticated(self, client_id):
return client_id in self.authentified_client_ids

def getAuthenticatedWebsockets(self, client_id=None):
if client_id:
if self.isClientAuthenticated(client_id):
return [self.connections.get(client_id)]
else:
return []
else:
return [self.connections[c] for c in self.connections if self.isClientAuthenticated(c)]


async def sendWrappedMessage(self, rpcid, content, method="", client_id=None):
wrapper = {
"wslink": "1.0",
Expand All @@ -559,11 +573,7 @@ async def sendWrappedMessage(self, rpcid, content, method="", client_id=None):
)
return

websockets = (
[self.connections.get(client_id)]
if client_id
else [self.connections[c] for c in self.connections]
)
websockets = self.getAuthenticatedWebsockets(client_id)

# Check if any attachments in the map go with this message
attachments = pub.publishManager.getAttachmentMap()
Expand Down Expand Up @@ -631,7 +641,8 @@ async def sendWrappedError(self, rpcid, code, message, data=None, client_id=None
def publish(self, topic, data, client_id=None):
client_list = [client_id] if client_id else [c_id for c_id in self.connections]
for client in client_list:
pub.publishManager.publish(topic, data, client_id=client)
if self.isClientAuthenticated(client):
pub.publishManager.publish(topic, data, client_id=client)

def addAttachment(self, payload):
return pub.publishManager.addAttachment(payload)
Expand Down

0 comments on commit c630baa

Please sign in to comment.