diff --git a/Sustainsys.Saml2.Owin/Saml2AuthenticationHandler.cs b/Sustainsys.Saml2.Owin/Saml2AuthenticationHandler.cs index e68879b79..a65a61282 100644 --- a/Sustainsys.Saml2.Owin/Saml2AuthenticationHandler.cs +++ b/Sustainsys.Saml2.Owin/Saml2AuthenticationHandler.cs @@ -60,12 +60,10 @@ protected async override Task AuthenticateCoreAsync() [System.Diagnostics.CodeAnalysis.SuppressMessage("Microsoft.Naming", "CA2204:Literals should be spelled correctly", MessageId = "ReturnUrl")] private AuthenticationTicket CreateErrorAuthenticationTicket(HttpRequestData httpRequestData, Exception ex) { - AuthenticationProperties authProperties = null; - if (httpRequestData.StoredRequestState != null) - { - authProperties = new AuthenticationProperties( - httpRequestData.StoredRequestState.RelayData); + var authProperties = new AuthenticationProperties(); + if (httpRequestData.StoredRequestState?.ReturnUrl != null) + { // ReturnUrl is removed from AuthProps dictionary to save space, need to put it back. authProperties.RedirectUri = httpRequestData.StoredRequestState.ReturnUrl.OriginalString; } @@ -83,10 +81,7 @@ private AuthenticationTicket CreateErrorAuthenticationTicket(HttpRequestData htt redirectUrl = httpRequestData.ApplicationUrl; } - authProperties = new AuthenticationProperties - { - RedirectUri = redirectUrl.OriginalString - }; + authProperties.RedirectUri = redirectUrl.OriginalString; } // The Google middleware adds this, so let's follow that example. diff --git a/Tests/Owin.Tests/Saml2AuthenticationMiddlewareTests.cs b/Tests/Owin.Tests/Saml2AuthenticationMiddlewareTests.cs index 9c8952cce..82ac9a6e7 100644 --- a/Tests/Owin.Tests/Saml2AuthenticationMiddlewareTests.cs +++ b/Tests/Owin.Tests/Saml2AuthenticationMiddlewareTests.cs @@ -978,6 +978,78 @@ public async Task Saml2AuthenticationMiddleware_AcsRedirectsToAuthPropsReturnUri context.Authentication.AuthenticationResponseGrant.Should().BeNull(); } + [TestMethod] + public async Task Saml2AuthenticationMiddleware_AcsRedirectsToAuthProps_StoredRequestStateWithNoReturnUrl() + { + var context = OwinTestHelpers.CreateOwinContext(); + context.Request.Method = "POST"; + + var authProps = new AuthenticationProperties(); + + var state = new StoredRequestState(new EntityId("https://idp.example.com"), + null, + new Saml2Id("InResponseToId"), + authProps.Dictionary); + + var relayState = SecureKeyGenerator.CreateRelayState(); + + var cookieData = HttpRequestData.ConvertBinaryData( + CreateAppBuilder().CreateDataProtector( + typeof(Saml2AuthenticationMiddleware).FullName) + .Protect(state.Serialize())); + + context.Request.Headers["Cookie"] = $"{StoredRequestState.CookieNameBase}{relayState}={cookieData}"; + + var response = + @" + + https://idp.example.com + + + + + + https://idp.example.com + + SomeUser + + + + + "; + + // No signature, that's an error. + var bodyData = new KeyValuePair[] { + new KeyValuePair("SAMLResponse", + Convert.ToBase64String(Encoding.UTF8.GetBytes(response))), + new KeyValuePair("RelayState",relayState) + }; + + var encodedBodyData = new FormUrlEncodedContent(bodyData); + + context.Request.Body = encodedBodyData.ReadAsStreamAsync().Result; + context.Request.ContentType = encodedBodyData.Headers.ContentType.ToString(); + context.Request.Host = new HostString("localhost"); + context.Request.Path = new PathString("/Saml2/Acs"); + + var middleware = new Saml2AuthenticationMiddleware(null, CreateAppBuilder(), + new Saml2AuthenticationOptions(true) + { + SignInAsAuthenticationType = "AuthType" + }); + + await middleware.Invoke(context); + + context.Response.StatusCode.Should().Be(302); + context.Response.Headers["Location"].Should().Be("http://localhost/LoggedIn?error=access_denied"); + context.Authentication.AuthenticationResponseGrant.Should().BeNull(); + } + [TestMethod] public async Task Saml2AuthenticationMiddleware_AcsWorks() {