Permalink
Browse files

SSL: Authenticate continuation should complete only if the actual aut…

…henticate task completes. This prevents double disconnect coming when AuthenticateAsServerAsync fails.

Fix local IP address check to take into account IPv6 loopback addresses.
  • Loading branch information...
1 parent a196289 commit a38e4cc921fd92df612e0dea0738c92e0f2ea1c3 @marius-klimantavicius marius-klimantavicius committed Nov 24, 2016
Showing with 38 additions and 16 deletions.
  1. +14 −1 Nowin/IpIsLocalChecker.cs
  2. +23 −12 Nowin/SslTransportHandler.cs
  3. +1 −3 NowinTests/NowinTestsBase.cs
@@ -32,7 +32,20 @@ public IpIsLocalChecker()
public bool IsLocal(IPAddress address)
{
- return _dict.ContainsKey(address);
+ if (_dict.ContainsKey(address))
+ return true;
+
+ if (IPAddress.IsLoopback(address))
+ return true;
+
+ if (address.IsIPv4MappedToIPv6)
+ {
+ var ip4 = address.MapToIPv4();
+ if (_dict.ContainsKey(ip4))
+ return true;
+ }
+
+ return false;
}
}
}
@@ -194,11 +194,8 @@ public void FinishAccept(byte[] buffer, int offset, int length, IPEndPoint remot
_authenticateTask = _ssl.AuthenticateAsServerAsync(_serverParameters.Certificate, _serverParameters.ClientCertificateRequired, _serverParameters.Protocols, false).ContinueWith((t, selfObject) =>
{
var self = (SslTransportHandler)selfObject;
- if (t.IsFaulted || t.IsCanceled)
- self.Callback.StartDisconnect();
- else
- _next.SetRemoteCertificate(_ssl.RemoteCertificate);
- }, this);
+ self._next.SetRemoteCertificate(_ssl.RemoteCertificate);
+ }, this, TaskContinuationOptions.OnlyOnRanToCompletion);
_next.FinishAccept(_recvBuffer, _recvOffset, 0, remoteEndPoint, localEndPoint);
}
catch (Exception)
@@ -248,18 +245,32 @@ public void StartReceive(byte[] buffer, int offset, int length)
}
else
{
- _ssl.ReadAsync(self._recvBuffer, self._recvOffset, self._recvLength).ContinueWith((t2, selfObject2) =>
+ try
{
- var self2 = (SslTransportHandler)selfObject2;
- if (t2.IsFaulted || t2.IsCanceled || t2.Result == 0)
- self._next.FinishReceive(null, 0, -1);
- else
- self._next.FinishReceive(self2._recvBuffer, self2._recvOffset, t2.Result);
- }, self);
+ self._ssl.ReadAsync(self._recvBuffer, self._recvOffset, self._recvLength).ContinueWith((t2, selfObject2) =>
+ {
+ var self2 = (SslTransportHandler)selfObject2;
+ if (t2.IsFaulted || t2.IsCanceled || t2.Result == 0)
+ self2._next.FinishReceive(null, 0, -1);
+ else
+ self2._next.FinishReceive(self2._recvBuffer, self2._recvOffset, t2.Result);
+ }, self);
+ }
+ catch (Exception)
+ {
+ self._next.FinishReceive(null, 0, -1);
+ }
}
}, this);
return;
}
+
+ if (_authenticateTask.IsCanceled || _authenticateTask.IsFaulted)
+ {
+ _next.FinishReceive(null, 0, -1);
+ return;
+ }
+
_ssl.ReadAsync(buffer, offset, length).ContinueWith((t, selfObject) =>
{
var self = (SslTransportHandler)selfObject;
@@ -419,11 +419,9 @@ public void EnvironmentEmptyGetRequest()
Assert.True(env.TryGetValue("server.IsLocal", out ignored));
Assert.Equal(true, env["server.IsLocal"]);
+ // Don't check for actual IP address as it might be IPv6 local address
Assert.True(env.TryGetValue("server.RemoteIpAddress", out ignored));
- Assert.Equal("127.0.0.1", env["server.RemoteIpAddress"]);
-
Assert.True(env.TryGetValue("server.LocalIpAddress", out ignored));
- Assert.Equal("127.0.0.1", env["server.LocalIpAddress"]);
Assert.True(env.TryGetValue("server.RemotePort", out ignored));
Assert.True(env.TryGetValue("server.LocalPort", out ignored));

0 comments on commit a38e4cc

Please sign in to comment.