diff --git a/src/Microsoft.AspNetCore.Authentication.Cookies/CookieAuthenticationHandler.cs b/src/Microsoft.AspNetCore.Authentication.Cookies/CookieAuthenticationHandler.cs index 343cf1b3a..b77a51ef4 100644 --- a/src/Microsoft.AspNetCore.Authentication.Cookies/CookieAuthenticationHandler.cs +++ b/src/Microsoft.AspNetCore.Authentication.Cookies/CookieAuthenticationHandler.cs @@ -85,7 +85,7 @@ private void CheckForRefresh(AuthenticationTicket ticket) } } - private void RequestRefresh(AuthenticationTicket ticket) + private void RequestRefresh(AuthenticationTicket ticket, ClaimsPrincipal replacedPrincipal = null) { var issuedUtc = ticket.Properties.IssuedUtc; var expiresUtc = ticket.Properties.ExpiresUtc; @@ -97,14 +97,15 @@ private void RequestRefresh(AuthenticationTicket ticket) _refreshIssuedUtc = currentUtc; var timeSpan = expiresUtc.Value.Subtract(issuedUtc.Value); _refreshExpiresUtc = currentUtc.Add(timeSpan); - _refreshTicket = CloneTicket(ticket); + _refreshTicket = CloneTicket(ticket, replacedPrincipal); } } - private AuthenticationTicket CloneTicket(AuthenticationTicket ticket) + private AuthenticationTicket CloneTicket(AuthenticationTicket ticket, ClaimsPrincipal replacedPrincipal) { + var principal = replacedPrincipal ?? ticket.Principal; var newPrincipal = new ClaimsPrincipal(); - foreach (var identity in ticket.Principal.Identities) + foreach (var identity in principal.Identities) { newPrincipal.AddIdentity(identity.Clone()); } @@ -183,7 +184,7 @@ protected override async Task HandleAuthenticateAsync() if (context.ShouldRenew) { - RequestRefresh(result.Ticket); + RequestRefresh(result.Ticket, context.Principal); } return AuthenticateResult.Success(new AuthenticationTicket(context.Principal, context.Properties, Scheme.Name)); diff --git a/test/Microsoft.AspNetCore.Authentication.Test/CookieTests.cs b/test/Microsoft.AspNetCore.Authentication.Test/CookieTests.cs index 945ec82ee..766d1e2e5 100644 --- a/test/Microsoft.AspNetCore.Authentication.Test/CookieTests.cs +++ b/test/Microsoft.AspNetCore.Authentication.Test/CookieTests.cs @@ -900,6 +900,80 @@ public async Task CookieCanBeRenewedByValidator() Assert.Null(FindClaimValue(transaction5, ClaimTypes.Name)); } + [Fact] + public async Task CookieCanBeReplacedByValidator() + { + var server = CreateServer(o => + { + o.Events = new CookieAuthenticationEvents + { + OnValidatePrincipal = ctx => + { + ctx.ShouldRenew = true; + ctx.ReplacePrincipal(new ClaimsPrincipal(new ClaimsIdentity(new GenericIdentity("Alice2", "Cookies2")))); + return Task.FromResult(0); + } + }; + }, + context => + context.SignInAsync("Cookies", + new ClaimsPrincipal(new ClaimsIdentity(new GenericIdentity("Alice", "Cookies"))))); + + var transaction1 = await SendAsync(server, "http://example.com/testpath"); + + var transaction2 = await SendAsync(server, "http://example.com/me/Cookies", transaction1.CookieNameValue); + Assert.NotNull(transaction2.SetCookie); + Assert.Equal("Alice2", FindClaimValue(transaction2, ClaimTypes.Name)); + } + + [Fact] + public async Task CookieCanBeUpdatedByValidatorDuringRefresh() + { + var replace = false; + var server = CreateServer(o => + { + o.ExpireTimeSpan = TimeSpan.FromMinutes(10); + o.Events = new CookieAuthenticationEvents + { + OnValidatePrincipal = ctx => + { + if (replace) + { + ctx.ShouldRenew = true; + ctx.ReplacePrincipal(new ClaimsPrincipal(new ClaimsIdentity(new GenericIdentity("Alice2", "Cookies2")))); + ctx.Properties.Items["updated"] = "yes"; + } + return Task.FromResult(0); + } + }; + }, + context => + context.SignInAsync("Cookies", + new ClaimsPrincipal(new ClaimsIdentity(new GenericIdentity("Alice", "Cookies"))))); + + var transaction1 = await SendAsync(server, "http://example.com/testpath"); + + var transaction2 = await SendAsync(server, "http://example.com/me/Cookies", transaction1.CookieNameValue); + Assert.Equal("Alice", FindClaimValue(transaction2, ClaimTypes.Name)); + + var transaction3 = await SendAsync(server, "http://example.com/me/Cookies", transaction1.CookieNameValue); + Assert.Equal("Alice", FindClaimValue(transaction2, ClaimTypes.Name)); + Assert.Null(FindPropertiesValue(transaction3, "updated")); + + replace = true; + + var transaction4 = await SendAsync(server, "http://example.com/me/Cookies", transaction1.CookieNameValue); + Assert.NotNull(transaction4.SetCookie); + Assert.Equal("Alice2", FindClaimValue(transaction4, ClaimTypes.Name)); + Assert.Equal("yes", FindPropertiesValue(transaction4, "updated")); + + replace = false; + + var transaction5 = await SendAsync(server, "http://example.com/me/Cookies", transaction4.CookieNameValue); + Assert.Equal("Alice2", FindClaimValue(transaction5, ClaimTypes.Name)); + Assert.Equal("yes", FindPropertiesValue(transaction4, "updated")); + } + [Fact] public async Task CookieCanBeRenewedByValidatorWithSlidingExpiry() { @@ -1730,6 +1804,16 @@ private static string FindClaimValue(Transaction transaction, string claimType) return claim.Attribute("value").Value; } + private static string FindPropertiesValue(Transaction transaction, string key) + { + var property = transaction.ResponseElement.Elements("extra").SingleOrDefault(elt => elt.Attribute("type").Value == key); + if (property == null) + { + return null; + } + return property.Attribute("value").Value; + } + private static async Task GetAuthData(TestServer server, string url, string cookie) { var request = new HttpRequestMessage(HttpMethod.Get, url);