Skip to content

Commit

Permalink
Check DiscoveryOnly before the message is parsed. (#2112)
Browse files Browse the repository at this point in the history
A discovery only channel only accepts discovery service calls. The type of service call should be checked before the message is parsed. 
- Add a check for discovery message before the chunks are decoded.
- Add a check that discovery service call are below a restricted message size (e.g. 64k).
- Add a test case.
  • Loading branch information
mregen committed Apr 11, 2023
1 parent 8851b04 commit b5d4dd7
Show file tree
Hide file tree
Showing 9 changed files with 243 additions and 51 deletions.
53 changes: 37 additions & 16 deletions Stack/Opc.Ua.Bindings.Https/Stack/Https/HttpsTransportListener.cs
Original file line number Diff line number Diff line change
Expand Up @@ -361,14 +361,26 @@ public async Task SendAsync(HttpContext context)
}
}

if (endpoint == null &&
input.TypeId != DataTypeIds.GetEndpointsRequest &&
input.TypeId != DataTypeIds.FindServersRequest)
if (endpoint == null)
{
message = "Connection refused, invalid security policy.";
Utils.LogError(message);
await WriteResponseAsync(context.Response, message, HttpStatusCode.Unauthorized).ConfigureAwait(false);
return;
ServiceResultException serviceResultException = null;
if (input.TypeId != DataTypeIds.GetEndpointsRequest &&
input.TypeId != DataTypeIds.FindServersRequest &&
input.TypeId != DataTypeIds.FindServersOnNetworkRequest)
{
serviceResultException = new ServiceResultException(StatusCodes.BadSecurityPolicyRejected, "Channel can only be used for discovery.");
}
else if (length > TcpMessageLimits.DefaultDiscoveryMaxMessageSize)
{
serviceResultException = new ServiceResultException(StatusCodes.BadSecurityPolicyRejected, "Discovery Channel message size exceeded.");
}

if (serviceResultException != null)
{
IServiceResponse serviceResponse = EndpointBase.CreateFault(null, serviceResultException);
await WriteServiceResponseAsync(context, serviceResponse, ct).ConfigureAwait(false);
return;
}
}

// note: do not use Task.Factory.FromAsync here
Expand All @@ -381,15 +393,8 @@ public async Task SendAsync(HttpContext context)

IServiceResponse output = m_callback.EndProcessRequest(result);

byte[] response = BinaryEncoder.EncodeMessage(output, m_quotas.MessageContext);
context.Response.ContentLength = response.Length;
context.Response.ContentType = context.Request.ContentType;
context.Response.StatusCode = (int)HttpStatusCode.OK;
#if NETSTANDARD2_1 || NET5_0_OR_GREATER || NETCOREAPP3_1_OR_GREATER
await context.Response.Body.WriteAsync(response.AsMemory(0, response.Length), ct).ConfigureAwait(false);
#else
await context.Response.Body.WriteAsync(response, 0, response.Length, ct).ConfigureAwait(false);
#endif
await WriteServiceResponseAsync(context, output, ct).ConfigureAwait(false);

return;
}
catch (Exception e)
Expand Down Expand Up @@ -438,6 +443,22 @@ public async Task SendAsync(HttpContext context)
Start();
}

/// <summary>
/// Encodes a service response and writes it back.
/// </summary>
private async Task WriteServiceResponseAsync(HttpContext context, IServiceResponse response, CancellationToken ct)
{
byte[] encodedResponse = BinaryEncoder.EncodeMessage(response, m_quotas.MessageContext);
context.Response.ContentLength = encodedResponse.Length;
context.Response.ContentType = context.Request.ContentType;
context.Response.StatusCode = (int)HttpStatusCode.OK;
#if NETSTANDARD2_1 || NET5_0_OR_GREATER || NETCOREAPP3_1_OR_GREATER
await context.Response.Body.WriteAsync(encodedResponse.AsMemory(0, encodedResponse.Length), ct).ConfigureAwait(false);
#else
await context.Response.Body.WriteAsync(encodedResponse, 0, encodedResponse.Length, ct).ConfigureAwait(false);
#endif
}

private static Task WriteResponseAsync(HttpResponse response, string message, HttpStatusCode status)
{
response.ContentLength = message.Length;
Expand Down
4 changes: 2 additions & 2 deletions Stack/Opc.Ua.Core/Stack/Server/EndpointBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -395,7 +395,7 @@ protected ServiceDefinition FindService(ExpandedNodeId requestTypeId)
/// <param name="request">The request.</param>
/// <param name="exception">The exception.</param>
/// <returns>A fault message.</returns>
protected static ServiceFault CreateFault(IServiceRequest request, Exception exception)
public static ServiceFault CreateFault(IServiceRequest request, Exception exception)
{
DiagnosticsMasks diagnosticsMask = DiagnosticsMasks.ServiceNoInnerStatus;

Expand Down Expand Up @@ -452,7 +452,7 @@ protected static ServiceFault CreateFault(IServiceRequest request, Exception exc
/// <param name="request">The request.</param>
/// <param name="exception">The exception.</param>
/// <returns>A fault message.</returns>
protected static Exception CreateSoapFault(IServiceRequest request, Exception exception)
public static Exception CreateSoapFault(IServiceRequest request, Exception exception)
{
ServiceFault fault = CreateFault(request, exception);

Expand Down
5 changes: 5 additions & 0 deletions Stack/Opc.Ua.Core/Stack/Tcp/TcpMessageType.cs
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,11 @@ public static class TcpMessageLimits
/// </summary>
public const int DefaultMaxMessageSize = DefaultMaxChunkCount * DefaultMaxBufferSize;

/// <summary>
/// The default maximum message size for the discovery channel.
/// </summary>
public const int DefaultDiscoveryMaxMessageSize = DefaultMaxBufferSize;

/// <summary>
/// How long a connection will remain in the server after it goes into a faulted state.
/// </summary>
Expand Down
85 changes: 68 additions & 17 deletions Stack/Opc.Ua.Core/Stack/Tcp/TcpServerChannel.cs
Original file line number Diff line number Diff line change
Expand Up @@ -433,24 +433,25 @@ private bool ProcessHelloMessage(ArraySegment<byte> messageChunk)

try
{
MemoryStream ostrm = new MemoryStream(buffer, 0, SendBufferSize);
BinaryEncoder encoder = new BinaryEncoder(ostrm, Quotas.MessageContext);

encoder.WriteUInt32(null, TcpMessageType.Acknowledge);
encoder.WriteUInt32(null, 0);
encoder.WriteUInt32(null, 0); // ProtocolVersion
encoder.WriteUInt32(null, (uint)ReceiveBufferSize);
encoder.WriteUInt32(null, (uint)SendBufferSize);
encoder.WriteUInt32(null, (uint)MaxRequestMessageSize);
encoder.WriteUInt32(null, (uint)MaxRequestChunkCount);
using (MemoryStream ostrm = new MemoryStream(buffer, 0, SendBufferSize))
using (BinaryEncoder encoder = new BinaryEncoder(ostrm, Quotas.MessageContext))
{
encoder.WriteUInt32(null, TcpMessageType.Acknowledge);
encoder.WriteUInt32(null, 0);
encoder.WriteUInt32(null, 0); // ProtocolVersion
encoder.WriteUInt32(null, (uint)ReceiveBufferSize);
encoder.WriteUInt32(null, (uint)SendBufferSize);
encoder.WriteUInt32(null, (uint)MaxRequestMessageSize);
encoder.WriteUInt32(null, (uint)MaxRequestChunkCount);

int size = encoder.Close();
UpdateMessageSize(buffer, 0, size);
int size = encoder.Close();
UpdateMessageSize(buffer, 0, size);

// now ready for the open or bind request.
State = TcpChannelState.Opening;
// now ready for the open or bind request.
State = TcpChannelState.Opening;

BeginWriteMessage(new ArraySegment<byte>(buffer, 0, size), null);
BeginWriteMessage(new ArraySegment<byte>(buffer, 0, size), null);
}
buffer = null;
}
finally
Expand Down Expand Up @@ -956,11 +957,37 @@ private bool ProcessRequestMessage(uint messageType, ArraySegment<byte> messageC
// check if it is necessary to wait for more chunks.
if (!TcpMessageType.IsFinal(messageType))
{
SaveIntermediateChunk(requestId, messageBody, true);
bool firstChunk = SaveIntermediateChunk(requestId, messageBody, true);

// validate the type is allowed with a discovery channel
if (DiscoveryOnly)
{
if (firstChunk)
{
if (!ValidateDiscoveryServiceCall(token, requestId, messageBody, out chunksToProcess))
{
ChannelClosed();
}
}
else if (GetSavedChunksTotalSize() > TcpMessageLimits.DefaultDiscoveryMaxMessageSize)
{
chunksToProcess = GetSavedChunks(0, messageBody, true);
SendServiceFault(token, requestId, ServiceResult.Create(StatusCodes.BadSecurityPolicyRejected, "Discovery Channel message size exceeded."));
ChannelClosed();
}
}

return true;
}

// Utils.LogTrace("ChannelId {0}: ProcessRequestMessage RequestId {1}", ChannelId, requestId);
if (DiscoveryOnly && GetSavedChunksTotalSize() == 0)
{
if (!ValidateDiscoveryServiceCall(token, requestId, messageBody, out chunksToProcess))
{
return true;
}
}

// get the chunks to process.
chunksToProcess = GetSavedChunks(requestId, messageBody, true);
Expand All @@ -977,7 +1004,7 @@ private bool ProcessRequestMessage(uint messageType, ArraySegment<byte> messageC
// ensure that only discovery requests come through unsecured.
if (DiscoveryOnly)
{
if (!(request is GetEndpointsRequest || request is FindServersRequest))
if (!(request is GetEndpointsRequest || request is FindServersRequest || request is FindServersOnNetworkRequest))
{
SendServiceFault(token, requestId, ServiceResult.Create(StatusCodes.BadSecurityPolicyRejected, "Channel can only be used for discovery."));
return true;
Expand Down Expand Up @@ -1083,6 +1110,30 @@ protected override void DoMessageLimitsExceeded()
base.DoMessageLimitsExceeded();
ChannelClosed();
}

/// <summary>
/// Validate the type of message before it is decoded.
/// </summary>
private bool ValidateDiscoveryServiceCall(ChannelToken token, uint requestId, ArraySegment<byte> messageBody, out BufferCollection chunksToProcess)
{
chunksToProcess = null;
using (var decoder = new BinaryDecoder(messageBody.AsMemory().ToArray(), Quotas.MessageContext))
{
// read the type of the message before more chunks are processed.
NodeId typeId = decoder.ReadNodeId(null);

if (typeId != ObjectIds.GetEndpointsRequest_Encoding_DefaultBinary &&
typeId != ObjectIds.FindServersRequest_Encoding_DefaultBinary &&
typeId != ObjectIds.FindServersOnNetworkRequest_Encoding_DefaultBinary)
{
chunksToProcess = GetSavedChunks(0, messageBody, true);
SendServiceFault(token, requestId, ServiceResult.Create(StatusCodes.BadSecurityPolicyRejected, "Channel can only be used for discovery."));
return false;
}
return true;
}
}

#endregion

#region Private Fields
Expand Down
22 changes: 19 additions & 3 deletions Stack/Opc.Ua.Core/Stack/Tcp/UaSCBinaryChannel.cs
Original file line number Diff line number Diff line change
Expand Up @@ -258,10 +258,12 @@ protected bool VerifySequenceNumber(uint sequenceNumber, string context)
/// <summary>
/// Saves an intermediate chunk for an incoming message.
/// </summary>
protected void SaveIntermediateChunk(uint requestId, ArraySegment<byte> chunk, bool isServerContext)
protected bool SaveIntermediateChunk(uint requestId, ArraySegment<byte> chunk, bool isServerContext)
{
bool firstChunk = false;
if (m_partialMessageChunks == null)
{
firstChunk = true;
m_partialMessageChunks = new BufferCollection();
}

Expand All @@ -280,14 +282,16 @@ protected void SaveIntermediateChunk(uint requestId, ArraySegment<byte> chunk, b
if (chunkOrSizeLimitsExceeded)
{
DoMessageLimitsExceeded();
return;
return firstChunk;
}

if (requestId != 0)
{
m_partialRequestId = requestId;
m_partialMessageChunks.Add(chunk);
}

return firstChunk;
}

/// <summary>
Expand All @@ -301,12 +305,24 @@ protected BufferCollection GetSavedChunks(uint requestId, ArraySegment<byte> chu
return savedChunks;
}

/// <summary>
/// Returns total length of the chunks saved for message.
/// </summary>
protected int GetSavedChunksTotalSize()
{
if (m_partialMessageChunks != null)
{
return m_partialMessageChunks.TotalSize;
}
return 0;
}

/// <summary>
/// Code executed when the message limits are exceeded.
/// </summary>
protected virtual void DoMessageLimitsExceeded()
{
Utils.LogError("ChannelId {0}: - Message limits exceeded while building up message. Channel will be closed", ChannelId);
Utils.LogError("ChannelId {0}: - Message limits exceeded while building up message. Channel will be closed.", ChannelId);
}
#endregion

Expand Down
13 changes: 7 additions & 6 deletions Stack/Opc.Ua.Core/Types/Encoders/BinaryEncoder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -179,13 +179,14 @@ public static byte[] EncodeMessage(IEncodeable message, IServiceMessageContext c
if (context == null) throw new ArgumentNullException(nameof(context));

// create encoder.
BinaryEncoder encoder = new BinaryEncoder(context);

// encode message
encoder.EncodeMessage(message);
using (BinaryEncoder encoder = new BinaryEncoder(context))
{
// encode message
encoder.EncodeMessage(message);

// close encoder.
return encoder.CloseAndReturnBuffer();
// close encoder.
return encoder.CloseAndReturnBuffer();
}
}

/// <summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,8 +105,7 @@ public async Task OneTimeSetUpAsync(TextWriter writer = null)

m_clientFixture = new ClientFixture();
await m_clientFixture.LoadClientConfiguration(m_pkiRoot).ConfigureAwait(false);
m_clientFixture.Config.TransportQuotas.MaxMessageSize =
m_clientFixture.Config.TransportQuotas.MaxBufferSize = 4 * 1024 * 1024;
m_clientFixture.Config.TransportQuotas.MaxMessageSize = 4 * 1024 * 1024;
m_url = new Uri(m_uriScheme + "://localhost:" + m_serverFixture.Port.ToString());
try
{
Expand Down
Loading

0 comments on commit b5d4dd7

Please sign in to comment.