Skip to content
This repository has been archived by the owner on Aug 5, 2019. It is now read-only.

Commit

Permalink
SSL: Authenticate continuation should complete only if the actual aut…
Browse files Browse the repository at this point in the history
…henticate task completes. This prevents double disconnect coming when AuthenticateAsServerAsync fails.

Fix local IP address check to take into account IPv6 loopback addresses.
  • Loading branch information
marius-klimantavicius committed Nov 24, 2016
1 parent a196289 commit a38e4cc
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 16 deletions.
15 changes: 14 additions & 1 deletion Nowin/IpIsLocalChecker.cs
Original file line number Original file line Diff line number Diff line change
Expand Up @@ -32,7 +32,20 @@ public IpIsLocalChecker()


public bool IsLocal(IPAddress address) public bool IsLocal(IPAddress address)
{ {
return _dict.ContainsKey(address); if (_dict.ContainsKey(address))
return true;

if (IPAddress.IsLoopback(address))
return true;

if (address.IsIPv4MappedToIPv6)
{
var ip4 = address.MapToIPv4();
if (_dict.ContainsKey(ip4))
return true;
}

return false;
} }
} }
} }
35 changes: 23 additions & 12 deletions Nowin/SslTransportHandler.cs
Original file line number Original file line Diff line number Diff line change
Expand Up @@ -194,11 +194,8 @@ public void FinishAccept(byte[] buffer, int offset, int length, IPEndPoint remot
_authenticateTask = _ssl.AuthenticateAsServerAsync(_serverParameters.Certificate, _serverParameters.ClientCertificateRequired, _serverParameters.Protocols, false).ContinueWith((t, selfObject) => _authenticateTask = _ssl.AuthenticateAsServerAsync(_serverParameters.Certificate, _serverParameters.ClientCertificateRequired, _serverParameters.Protocols, false).ContinueWith((t, selfObject) =>
{ {
var self = (SslTransportHandler)selfObject; var self = (SslTransportHandler)selfObject;
if (t.IsFaulted || t.IsCanceled) self._next.SetRemoteCertificate(_ssl.RemoteCertificate);
self.Callback.StartDisconnect(); }, this, TaskContinuationOptions.OnlyOnRanToCompletion);
else
_next.SetRemoteCertificate(_ssl.RemoteCertificate);
}, this);
_next.FinishAccept(_recvBuffer, _recvOffset, 0, remoteEndPoint, localEndPoint); _next.FinishAccept(_recvBuffer, _recvOffset, 0, remoteEndPoint, localEndPoint);
} }
catch (Exception) catch (Exception)
Expand Down Expand Up @@ -248,18 +245,32 @@ public void StartReceive(byte[] buffer, int offset, int length)
} }
else else
{ {
_ssl.ReadAsync(self._recvBuffer, self._recvOffset, self._recvLength).ContinueWith((t2, selfObject2) => try
{ {
var self2 = (SslTransportHandler)selfObject2; self._ssl.ReadAsync(self._recvBuffer, self._recvOffset, self._recvLength).ContinueWith((t2, selfObject2) =>
if (t2.IsFaulted || t2.IsCanceled || t2.Result == 0) {
self._next.FinishReceive(null, 0, -1); var self2 = (SslTransportHandler)selfObject2;
else if (t2.IsFaulted || t2.IsCanceled || t2.Result == 0)
self._next.FinishReceive(self2._recvBuffer, self2._recvOffset, t2.Result); self2._next.FinishReceive(null, 0, -1);
}, self); else
self2._next.FinishReceive(self2._recvBuffer, self2._recvOffset, t2.Result);
}, self);
}
catch (Exception)
{
self._next.FinishReceive(null, 0, -1);
}
} }
}, this); }, this);
return; return;
} }

if (_authenticateTask.IsCanceled || _authenticateTask.IsFaulted)
{
_next.FinishReceive(null, 0, -1);
return;
}

_ssl.ReadAsync(buffer, offset, length).ContinueWith((t, selfObject) => _ssl.ReadAsync(buffer, offset, length).ContinueWith((t, selfObject) =>
{ {
var self = (SslTransportHandler)selfObject; var self = (SslTransportHandler)selfObject;
Expand Down
4 changes: 1 addition & 3 deletions NowinTests/NowinTestsBase.cs
Original file line number Original file line Diff line number Diff line change
Expand Up @@ -419,11 +419,9 @@ public void EnvironmentEmptyGetRequest()
Assert.True(env.TryGetValue("server.IsLocal", out ignored)); Assert.True(env.TryGetValue("server.IsLocal", out ignored));
Assert.Equal(true, env["server.IsLocal"]); Assert.Equal(true, env["server.IsLocal"]);
// Don't check for actual IP address as it might be IPv6 local address
Assert.True(env.TryGetValue("server.RemoteIpAddress", out ignored)); Assert.True(env.TryGetValue("server.RemoteIpAddress", out ignored));
Assert.Equal("127.0.0.1", env["server.RemoteIpAddress"]);
Assert.True(env.TryGetValue("server.LocalIpAddress", out ignored)); Assert.True(env.TryGetValue("server.LocalIpAddress", out ignored));
Assert.Equal("127.0.0.1", env["server.LocalIpAddress"]);
Assert.True(env.TryGetValue("server.RemotePort", out ignored)); Assert.True(env.TryGetValue("server.RemotePort", out ignored));
Assert.True(env.TryGetValue("server.LocalPort", out ignored)); Assert.True(env.TryGetValue("server.LocalPort", out ignored));
Expand Down

0 comments on commit a38e4cc

Please sign in to comment.