Skip to content

Commit

Permalink
Check limits of messages (CVE-2022-25304) (#1040)
Browse files Browse the repository at this point in the history
* check message limits on recv

* add ErrorMessage handling

* add to large chunk test

* client disconnect on ErrorMessage

* change default transport limits
  • Loading branch information
schroeder- committed Sep 20, 2022
1 parent c66f34c commit 01c7acf
Show file tree
Hide file tree
Showing 5 changed files with 122 additions and 17 deletions.
14 changes: 10 additions & 4 deletions asyncua/client/ua_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,14 @@
Low level binary client
"""
import asyncio
import copy
import logging
from typing import Awaitable, Callable, Dict, List, Optional, Union

from asyncua import ua
from ..ua.ua_binary import struct_from_binary, uatcp_to_binary, struct_to_binary, nodeid_from_binary, header_from_binary
from ..ua.uaerrors import BadTimeout, BadNoSubscription, BadSessionClosed, BadUserAccessDenied, UaStructParsingError
from ..common.connection import SecureConnection
from ..common.connection import SecureConnection, TransportLimits


class UASocketProtocol(asyncio.Protocol):
Expand All @@ -20,7 +21,7 @@ class UASocketProtocol(asyncio.Protocol):
OPEN = 'open'
CLOSED = 'closed'

def __init__(self, timeout: float = 1, security_policy: ua.SecurityPolicy = ua.SecurityPolicy()):
def __init__(self, timeout: float = 1, security_policy: ua.SecurityPolicy = ua.SecurityPolicy(), limits: TransportLimits = None):
"""
:param timeout: Timeout in seconds
:param security_policy: Security policy (optional)
Expand All @@ -34,7 +35,12 @@ def __init__(self, timeout: float = 1, security_policy: ua.SecurityPolicy = ua.S
self._request_id = 0
self._request_handle = 0
self._callbackmap: Dict[int, asyncio.Future] = {}
self._connection = SecureConnection(security_policy)
if limits is None:
limits = TransportLimits(65535, 65535, 0, 0)
else:
limits = copy.deep_copy(limits) # Make a copy because the limits can change in the session
self._connection = SecureConnection(security_policy, limits)

self.state = self.INITIALIZED
self.closed: bool = False
# needed to pass params from asynchronous request to synchronous data receive callback, as well as
Expand Down Expand Up @@ -103,7 +109,7 @@ def _process_received_message(self, msg: Union[ua.Message, ua.Acknowledge, ua.Er
self._call_callback(0, msg)
elif isinstance(msg, ua.ErrorMessage):
self.logger.fatal("Received an error: %r", msg)
self._call_callback(0, ua.UaStatusCodeError(msg.Error.value))
self.disconnect_socket()
else:
raise ua.UaError(f"Unsupported message type: {msg}")

Expand Down
66 changes: 61 additions & 5 deletions asyncua/common/connection.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from dataclasses import dataclass
import hashlib
from datetime import datetime, timedelta
import logging
Expand All @@ -15,6 +16,56 @@ class InvalidSignature(Exception): # type: ignore
logger = logging.getLogger('asyncua.uaprotocol')


@dataclass
class TransportLimits:
'''
Limits of the tcp transport layer to prevent excessive resource usage
'''
max_recv_buffer: int = 65535
max_send_buffer: int = 65535
max_chunk_count: int = ((100 * 1024 * 1024) // 65535) + 1 # max_message_size / max_recv_buffer
max_message_size: int = 100 * 1024 * 1024 # 100mb

@staticmethod
def _select_limit(hint: ua.UInt32, limit: int) -> ua.UInt32:
if limit <= 0:
return hint
elif limit < hint:
return hint
return ua.UInt32(limit)

def check_max_msg_size(self, sz: int) -> bool:
if self.max_message_size == 0:
return True
return self.max_message_size <= sz

def check_max_chunk_count(self, sz: int) -> bool:
if self.max_chunk_count == 0:
return True
return self.max_chunk_count <= sz

def create_acknowledge_limits(self, msg: ua.Hello) -> ua.Acknowledge:
ack = ua.Acknowledge()
ack.ReceiveBufferSize = min(msg.ReceiveBufferSize, self.max_recv_buffer)
ack.SendBufferSize = min(msg.SendBufferSize, self.max_send_buffer)
ack.MaxChunkCount = self._select_limit(msg.MaxChunkCount, self.max_chunk_count)
ack.MaxMessageSize = self._select_limit(msg.MaxMessageSize, self.max_message_size)
self.update_limits(ack)
return ack

def create_hello_limits(self, msg: ua.Hello) -> ua.Hello:
msg.ReceiveBufferSize = self.max_recv_buffer
msg.SendBufferSize = self.max_send_buffer
msg.MaxChunkCount = self.max_chunk_count
msg.MaxMessageSize = self.max_chunk_count

def update_limits(self, msg: ua.Acknowledge) -> None:
self.max_chunk_count = msg.MaxChunkCount
self.max_recv_buffer = msg.ReceiveBufferSize
self.max_send_buffer = msg.SendBufferSize
self.max_message_size = msg.MaxMessageSize


class MessageChunk:
"""
Message Chunk, as described in OPC UA specs Part 6, 6.7.2.
Expand Down Expand Up @@ -139,7 +190,7 @@ class SecureConnection:
"""
Common logic for client and server
"""
def __init__(self, security_policy):
def __init__(self, security_policy, limits: TransportLimits):
self._sequence_number = 0
self._peer_sequence_number = None
self._incoming_parts = []
Expand All @@ -152,7 +203,7 @@ def __init__(self, security_policy):
self.local_nonce = 0
self.remote_nonce = 0
self._allow_prev_token = False
self._max_chunk_size = 65536
self._limits = limits

def set_channel(self, params, request_type, client_nonce):
"""
Expand Down Expand Up @@ -257,7 +308,7 @@ def message_to_binary(self, message, message_type=ua.MessageType.SecureMessage,
chunks = MessageChunk.message_to_chunks(
self.security_policy,
message,
self._max_chunk_size,
self._limits.max_send_buffer,
message_type=message_type,
channel_id=self.security_token.ChannelId,
request_id=request_id,
Expand Down Expand Up @@ -353,11 +404,10 @@ def receive_from_header_and_body(self, header, body):
return self._receive(chunk)
if header.MessageType == ua.MessageType.Hello:
msg = struct_from_binary(ua.Hello, body)
self._max_chunk_size = msg.ReceiveBufferSize
return msg
if header.MessageType == ua.MessageType.Acknowledge:
msg = struct_from_binary(ua.Acknowledge, body)
self._max_chunk_size = msg.SendBufferSize
self._limits.update_limits(msg)
return msg
if header.MessageType == ua.MessageType.Error:
msg = struct_from_binary(ua.ErrorMessage, body)
Expand All @@ -366,8 +416,14 @@ def receive_from_header_and_body(self, header, body):
raise ua.UaError(f"Unsupported message type {header.MessageType}")

def _receive(self, msg):
if msg.MessageHeader.packet_size > self._limits.max_recv_buffer:
self._incoming_parts = []
raise ua.UaStatusCodeError(ua.StatusCodes.BadRequestTooLarge)
self._check_incoming_chunk(msg)
self._incoming_parts.append(msg)
if not self._limits.check_max_chunk_count(len(self._incoming_parts)):
self._incoming_parts = []
raise ua.UaStatusCodeError(ua.StatusCodes.BadRequestTooLarge)
if msg.MessageHeader.ChunkType == ua.ChunkType.Intermediate:
return None
if msg.MessageHeader.ChunkType == ua.ChunkType.Abort:
Expand Down
17 changes: 15 additions & 2 deletions asyncua/server/binary_server_asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@
"""
import logging
import asyncio
import math
from typing import Optional

from ..common.connection import TransportLimits
from ..ua.ua_binary import header_from_binary
from ..common.utils import Buffer, NotEnoughData
from .uaprocessor import UaProcessor
Expand All @@ -18,7 +20,7 @@ class OPCUAProtocol(asyncio.Protocol):
Instantiated for every connection.
"""

def __init__(self, iserver: InternalServer, policies, clients, closing_tasks):
def __init__(self, iserver: InternalServer, policies, clients, closing_tasks, limits: TransportLimits):
self.peer_name = None
self.transport = None
self.processor = None
Expand All @@ -28,6 +30,7 @@ def __init__(self, iserver: InternalServer, policies, clients, closing_tasks):
self.clients = clients
self.closing_tasks = closing_tasks
self.messages = asyncio.Queue()
self.limits = limits
self._task = None

def __str__(self):
Expand All @@ -39,7 +42,7 @@ def connection_made(self, transport):
self.peer_name = transport.get_extra_info('peername')
logger.info('New connection from %s', self.peer_name)
self.transport = transport
self.processor = UaProcessor(self.iserver, self.transport)
self.processor = UaProcessor(self.iserver, self.transport, self.limits)
self.processor.set_policies(self.policies)
self.iserver.asyncio_transports.append(transport)
self.clients.append(self)
Expand Down Expand Up @@ -119,6 +122,15 @@ def __init__(self, internal_server: InternalServer, hostname, port):
self.clients = []
self.closing_tasks = []
self.cleanup_task = None
# Use accectable limits
buffer_sz = 65535
max_msg_sz = 16 * 1024 * 1024 # 16mb simular to the opc ua c stack so this is a good default
self.limits = TransportLimits(
max_recv_buffer=buffer_sz,
max_send_buffer=buffer_sz,
max_chunk_count=math.ceil(buffer_sz / max_msg_sz), # Round up to allow max msg size
max_message_size=max_msg_sz
)

def set_policies(self, policies):
self._policies = policies
Expand All @@ -130,6 +142,7 @@ def _make_protocol(self):
policies=self._policies,
clients=self.clients,
closing_tasks=self.closing_tasks,
limits=self.limits
)

async def start(self):
Expand Down
18 changes: 12 additions & 6 deletions asyncua/server/uaprocessor.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import copy
import time
import logging
from typing import Deque, Optional
Expand All @@ -6,7 +7,7 @@
from asyncua import ua
from ..ua.ua_binary import nodeid_from_binary, struct_from_binary, struct_to_binary, uatcp_to_binary
from .internal_server import InternalServer, InternalSession
from ..common.connection import SecureConnection
from ..common.connection import SecureConnection, TransportLimits
from ..common.utils import ServiceError

_logger = logging.getLogger(__name__)
Expand All @@ -25,7 +26,7 @@ class UaProcessor:
Processor for OPC UA messages. Implements the OPC UA protocol for the server side.
"""

def __init__(self, internal_server: InternalServer, transport):
def __init__(self, internal_server: InternalServer, transport, limits: TransportLimits):
self.iserver: InternalServer = internal_server
self.name = transport.get_extra_info('peername')
self.sockname = transport.get_extra_info('sockname')
Expand All @@ -35,7 +36,8 @@ def __init__(self, internal_server: InternalServer, transport):
self._publish_requests: Deque[PublishRequestData] = deque()
# used when we need to wait for PublishRequest
self._publish_results: Deque[ua.PublishResult] = deque()
self._connection = SecureConnection(ua.SecurityPolicy())
self._limits = copy.deepcopy(limits) # Copy limits because they get overriden
self._connection = SecureConnection(ua.SecurityPolicy(), self._limits)

def set_policies(self, policies):
self._connection.set_policy_factories(policies)
Expand Down Expand Up @@ -89,6 +91,12 @@ async def forward_publish_response(self, result: ua.PublishResult):
async def process(self, header, body):
try:
msg = self._connection.receive_from_header_and_body(header, body)
except ua.uaerrors.BadRequestTooLarge as e:
_logger.warning("Recived request that exceed the transport limits")
err = ua.ErrorMessage(ua.StatusCode(e.code), str(e))
data = uatcp_to_binary(ua.MessageType.Error, err)
self._transport.write(data)
return True
except ua.uaerrors.BadUserAccessDenied:
_logger.warning("Unauthenticated user attempted to connect")
return False
Expand All @@ -101,9 +109,7 @@ async def process(self, header, body):
elif header.MessageType == ua.MessageType.SecureMessage:
return await self.process_message(msg.SequenceHeader(), msg.body())
elif isinstance(msg, ua.Hello):
ack = ua.Acknowledge()
ack.ReceiveBufferSize = msg.ReceiveBufferSize
ack.SendBufferSize = msg.SendBufferSize
ack = self._limits.create_acknowledge_limits(msg)
data = uatcp_to_binary(ua.MessageType.Acknowledge, ack)
self._transport.write(data)
elif isinstance(msg, ua.ErrorMessage):
Expand Down
24 changes: 24 additions & 0 deletions tests/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -665,6 +665,30 @@ async def test_server_read_write_attribute_value(server: Server):
assert dv.Value.Value == 5
await server.delete_nodes([node])


@pytest.fixture(scope="function")
def restore_transport_limits_server(server: Server):
# Restore limits after test
max_recv = server.bserver.limits.max_recv_buffer
max_chunk_count = server.bserver.limits.max_chunk_count
yield server
server.bserver.limits.max_recv_buffer = max_recv
server.bserver.limits.max_chunk_count = max_chunk_count


async def test_message_limits(restore_transport_limits_server: Server):
server = restore_transport_limits_server
server.bserver.limits.max_recv_buffer = 1024
server.bserver.limits.max_chunk_count = 10
client = Client(server.endpoint.geturl())
# This should trigger a timeout error because the message is to large
with pytest.raises(asyncio.TimeoutError):
async with client:
test_string = 'a' * (1024 * 1024 * 1024)
n = client.get_node(ua.NodeId())
await n.write_value(test_string, ua.VariantType.String)


"""
class TestServerCaching(unittest.TestCase):
def runTest(self):
Expand Down

0 comments on commit 01c7acf

Please sign in to comment.