Skip to content

Commit

Permalink
Merge pull request #18 from gfoidl/allocation-tweaks
Browse files Browse the repository at this point in the history
Merge SocketAsyncEventArgs into SocketAwaitable and use cached instances
  • Loading branch information
mycroes committed Jun 18, 2023
2 parents bea5b4c + 0a6f371 commit 16abe22
Show file tree
Hide file tree
Showing 7 changed files with 52 additions and 43 deletions.
24 changes: 18 additions & 6 deletions src/AdsClient/AmsSocket.cs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ public AmsSocket(string host, int port = 48898)
public bool Connected => TcpClient.Connected;

private AmsSocketConnection connection;
private SocketAwaitable socketAwaitable;

public void Close()
{
Expand Down Expand Up @@ -56,22 +57,33 @@ public async Task ConnectAsync(IIncomingMessageHandler messageHandler, Cancellat

public async Task SendAsync(byte[] message)
{
using var args = new SocketAsyncEventArgs();
args.SetBuffer(message, 0, message.Length);
var sa = Interlocked.Exchange(ref socketAwaitable, null) ?? new SocketAwaitable();
sa.SetBuffer(message, 0, message.Length);

await Socket.SendAsync(new SocketAwaitable(args));
await Socket.SendAwaitable(sa);

if (Interlocked.CompareExchange(ref socketAwaitable, sa, null) != null)
{
sa.Dispose();
}
}

public async Task SendAsync(ArraySegment<byte> buffer)
{
using var args = new SocketAsyncEventArgs();
args.SetBuffer(buffer.Array, buffer.Offset, buffer.Count);
var sa = Interlocked.Exchange(ref socketAwaitable, null) ?? new SocketAwaitable();
sa.SetBuffer(buffer.Array, buffer.Offset, buffer.Count);

await Socket.SendAsync(new SocketAwaitable(args));
await Socket.SendAwaitable(sa);

if (Interlocked.CompareExchange(ref socketAwaitable, sa, null) != null)
{
sa.Dispose();
}
}

void IDisposable.Dispose()
{
Close();
TcpClient?.Dispose();
}
}
Expand Down
31 changes: 16 additions & 15 deletions src/AdsClient/AmsSocketConnection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,22 +6,24 @@

namespace Viscon.Communication.Ads;

internal class AmsSocketConnection
internal sealed class AmsSocketConnection
{
private const int ReceiveTaskTimeout = 3000;

private readonly Socket socket;
private readonly IIncomingMessageHandler messageHandler;
private readonly Task receiveTask;
private readonly SocketAwaitable socketAwaitable = new SocketAwaitable();
private readonly byte[] headerBuffer = new byte[AmsHeaderHelper.AmsTcpHeaderSize];

public AmsSocketConnection(Socket socket, IIncomingMessageHandler messageHandler)
{
this.socket = socket;
this.messageHandler = messageHandler;
receiveTask = Task.Run(ReceiveLoop);
receiveTask = ReceiveLoop();
}

private bool closed;
private volatile bool closed;

public void Close()
{
Expand All @@ -30,6 +32,7 @@ public void Close()
socket.Close();

receiveTask.Wait(ReceiveTaskTimeout);
socketAwaitable.Dispose();
}

private async Task<byte[]> GetAmsMessage(byte[] tcpHeader)
Expand All @@ -49,7 +52,6 @@ private Task GetMessage(byte[] response)

private async Task Listen()
{

try
{
var buffer = await ListenForHeader();
Expand Down Expand Up @@ -78,33 +80,32 @@ private async Task Listen()

private async Task<byte[]> ListenForHeader()
{
var buffer = new byte[AmsHeaderHelper.AmsTcpHeaderSize];
await ReceiveAsync(buffer);

return buffer;
await ReceiveAsync(headerBuffer);
return headerBuffer;
}

private async Task ReceiveAsync(byte[] buffer)
{
using var args = new SocketAsyncEventArgs();
args.SetBuffer(buffer, 0, buffer.Length);
var awaitable = new SocketAwaitable(args);
var sa = socketAwaitable;
sa.SetBuffer(buffer, 0, buffer.Length);

do
{
args.SetBuffer(args.Offset + args.BytesTransferred, args.Count - args.BytesTransferred);
await socket.ReceiveAsync(awaitable);
sa.SetBuffer(sa.Offset + sa.BytesTransferred, sa.Count - sa.BytesTransferred);
await socket.ReceiveAwaitable(sa);

if (args.BytesTransferred == 0)
if (sa.BytesTransferred == 0)
{
messageHandler.HandleException(new Exception("Remote host closed the connection."));
Close();
}
} while (args.Count != args.BytesTransferred);
} while (socketAwaitable.BytesTransferred != buffer.Length);
}

private async Task ReceiveLoop()
{
await Task.Yield();

while (!closed)
{
await Listen();
Expand Down
7 changes: 3 additions & 4 deletions src/AdsClient/Helpers/IdGenerator.cs
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading;

namespace Viscon.Communication.Ads.Helpers
{
internal class IdGenerator
{
private uint id;
private int id;

public uint Next()
{
return (uint)Interlocked.Increment(ref Unsafe.As<uint, int>(ref id));
return (uint)Interlocked.Increment(ref id);
}
}
}
2 changes: 1 addition & 1 deletion src/AdsClient/Internal/Assertions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ internal class Assertions
{
public static void AssertDataLength(ReadOnlySpan<byte> buffer, int length, int offset)
{
if (length + offset != buffer.Length)
if (length != buffer.Length - offset)
{
throw new Exception(
$"Received {buffer.Length} bytes of data, but length indicates {length} bytes remaining at offset {offset}, resulting in a expected total of {length + offset} bytes.");
Expand Down
19 changes: 7 additions & 12 deletions src/AdsClient/Internal/SocketAwaitable.cs
Original file line number Diff line number Diff line change
Expand Up @@ -11,22 +11,17 @@ namespace Viscon.Communication.Ads.Internal;
/// <remarks>
/// Based on https://devblogs.microsoft.com/pfxteam/awaiting-socket-operations/.
/// </remarks>
internal sealed class SocketAwaitable : INotifyCompletion
internal sealed class SocketAwaitable : SocketAsyncEventArgs, INotifyCompletion
{
private static readonly Action Sentinel = () => { };

public bool WasCompleted;
public volatile bool WasCompleted;
private Action? continuation;

Check warning on line 19 in src/AdsClient/Internal/SocketAwaitable.cs

View workflow job for this annotation

GitHub Actions / run_test

The annotation for nullable reference types should only be used in code within a '#nullable' annotations context.

Check warning on line 19 in src/AdsClient/Internal/SocketAwaitable.cs

View workflow job for this annotation

GitHub Actions / create_nuget

The annotation for nullable reference types should only be used in code within a '#nullable' annotations context.
public readonly SocketAsyncEventArgs EventArgs;

public SocketAwaitable(SocketAsyncEventArgs eventArgs)
protected override void OnCompleted(SocketAsyncEventArgs _)
{
EventArgs = eventArgs ?? throw new ArgumentNullException(nameof(eventArgs));
eventArgs.Completed += delegate
{
var prev = continuation ?? Interlocked.CompareExchange(ref continuation, Sentinel, null);
prev?.Invoke();
};
var prev = continuation ?? Interlocked.CompareExchange(ref continuation, Sentinel, null);
prev?.Invoke();
}

internal void Reset()
Expand All @@ -53,9 +48,9 @@ public void OnCompleted(Action continuation)

public void GetResult()
{
if (EventArgs.SocketError != SocketError.Success)
if (SocketError != SocketError.Success)
{
throw new SocketException((int)EventArgs.SocketError);
throw new SocketException((int)SocketError);
}
}
}
8 changes: 4 additions & 4 deletions src/AdsClient/Internal/SocketExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -10,18 +10,18 @@ namespace Viscon.Communication.Ads.Internal;
/// </remarks>
internal static class SocketExtensions
{
public static SocketAwaitable ReceiveAsync(this Socket socket, SocketAwaitable awaitable)
public static SocketAwaitable ReceiveAwaitable(this Socket socket, SocketAwaitable awaitable)
{
awaitable.Reset();
if (!socket.ReceiveAsync(awaitable.EventArgs))
if (!socket.ReceiveAsync(awaitable))
awaitable.WasCompleted = true;
return awaitable;
}

public static SocketAwaitable SendAsync(this Socket socket, SocketAwaitable awaitable)
public static SocketAwaitable SendAwaitable(this Socket socket, SocketAwaitable awaitable)
{
awaitable.Reset();
if (!socket.SendAsync(awaitable.EventArgs))
if (!socket.SendAsync(awaitable))
awaitable.WasCompleted = true;
return awaitable;
}
Expand Down
4 changes: 3 additions & 1 deletion src/AdsClient/Internal/TypeExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ internal static class TypeExtensions
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static ref byte GetOffset(this ref byte source, int offset)
{
return ref Unsafe.Add(ref source, offset);
// The cast to uint is in order to avoid a sign-extending move
// in the machine code.
return ref Unsafe.Add(ref source, (uint)offset);
}
}

0 comments on commit 16abe22

Please sign in to comment.