Skip to content

Commit

Permalink
Support concurrent refresh of refresh tokens (#39631)
Browse files Browse the repository at this point in the history
Co-authored-by: Jay Modi jaymode@users.noreply.github.com

This change adds support for the concurrent refresh of access
tokens as described in #36872
In short it allows subsequent client requests to refresh the same token that
come within a predefined window of 60 seconds to be handled as duplicates
of the original one and thus receive the same response with the same newly
issued access token and refresh token.
In order to support that, two new fields are added in the token document. One
contains the instant (in epoqueMillis) when a given refresh token is refreshed
and one that contains a pointer to the token document that stores the new
refresh token and access token that was created by the original refresh.
A side effect of this change, that was however also a intended enhancement
for the token service, is that we needed to stop encrypting the string
representation of the UserToken while serializing. ( It was necessary as we
correctly used a new IV for every time we encrypted a token in serialization, so
subsequent serializations of the same exact UserToken would produce
different access token strings)

This change also handles the serialization/deserialization BWC logic:

    In mixed clusters we keep creating tokens in the old format and
    consume only old format tokens
    In upgraded clusters, we start creating tokens in the new format but
    still remain able to consume old format tokens (that could have been
    created during the rolling upgrade and are still valid)
    When reading/writing TokensInvalidationResult objects, we take into
    consideration that pre 7.1.0 these contained an integer field that carried
    the attempt count

Resolves #36872
  • Loading branch information
jkakavas committed Mar 4, 2019
1 parent daa86f5 commit 578c019
Show file tree
Hide file tree
Showing 12 changed files with 704 additions and 299 deletions.
Expand Up @@ -7,6 +7,7 @@
package org.elasticsearch.xpack.core.security.authc.support;

import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.Version;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
Expand All @@ -32,10 +33,9 @@ public class TokensInvalidationResult implements ToXContentObject, Writeable {
private final List<String> invalidatedTokens;
private final List<String> previouslyInvalidatedTokens;
private final List<ElasticsearchException> errors;
private final int attemptCount;

public TokensInvalidationResult(List<String> invalidatedTokens, List<String> previouslyInvalidatedTokens,
@Nullable List<ElasticsearchException> errors, int attemptCount) {
@Nullable List<ElasticsearchException> errors) {
Objects.requireNonNull(invalidatedTokens, "invalidated_tokens must be provided");
this.invalidatedTokens = invalidatedTokens;
Objects.requireNonNull(previouslyInvalidatedTokens, "previously_invalidated_tokens must be provided");
Expand All @@ -45,18 +45,19 @@ public TokensInvalidationResult(List<String> invalidatedTokens, List<String> pre
} else {
this.errors = Collections.emptyList();
}
this.attemptCount = attemptCount;
}

public TokensInvalidationResult(StreamInput in) throws IOException {
this.invalidatedTokens = in.readStringList();
this.previouslyInvalidatedTokens = in.readStringList();
this.errors = in.readList(StreamInput::readException);
this.attemptCount = in.readVInt();
if (in.getVersion().before(Version.V_8_0_0)) {
in.readVInt();
}
}

public static TokensInvalidationResult emptyResult() {
return new TokensInvalidationResult(Collections.emptyList(), Collections.emptyList(), Collections.emptyList(), 0);
return new TokensInvalidationResult(Collections.emptyList(), Collections.emptyList(), Collections.emptyList());
}


Expand All @@ -72,10 +73,6 @@ public List<ElasticsearchException> getErrors() {
return errors;
}

public int getAttemptCount() {
return attemptCount;
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject()
Expand All @@ -100,6 +97,8 @@ public void writeTo(StreamOutput out) throws IOException {
out.writeStringCollection(invalidatedTokens);
out.writeStringCollection(previouslyInvalidatedTokens);
out.writeCollection(errors, StreamOutput::writeException);
out.writeVInt(attemptCount);
if (out.getVersion().before(Version.V_8_0_0)) {
out.writeVInt(5);
}
}
}
Expand Up @@ -199,6 +199,13 @@
"refreshed" : {
"type" : "boolean"
},
"refresh_time": {
"type": "date",
"format": "epoch_millis"
},
"superseded_by": {
"type": "keyword"
},
"invalidated" : {
"type" : "boolean"
},
Expand Down
Expand Up @@ -29,8 +29,7 @@ public void testSerialization() throws IOException {
TokensInvalidationResult result = new TokensInvalidationResult(Arrays.asList(generateRandomStringArray(20, 15, false)),
Arrays.asList(generateRandomStringArray(20, 15, false)),
Arrays.asList(new ElasticsearchException("foo", new IllegalArgumentException("this is an error message")),
new ElasticsearchException("bar", new IllegalArgumentException("this is an error message2"))),
randomIntBetween(0, 5));
new ElasticsearchException("bar", new IllegalArgumentException("this is an error message2"))));
InvalidateTokenResponse response = new InvalidateTokenResponse(result);
try (BytesStreamOutput output = new BytesStreamOutput()) {
response.writeTo(output);
Expand All @@ -47,8 +46,7 @@ public void testSerialization() throws IOException {
}

result = new TokensInvalidationResult(Arrays.asList(generateRandomStringArray(20, 15, false)),
Arrays.asList(generateRandomStringArray(20, 15, false)),
Collections.emptyList(), randomIntBetween(0, 5));
Arrays.asList(generateRandomStringArray(20, 15, false)), Collections.emptyList());
response = new InvalidateTokenResponse(result);
try (BytesStreamOutput output = new BytesStreamOutput()) {
response.writeTo(output);
Expand All @@ -68,8 +66,7 @@ public void testToXContent() throws IOException {
List previouslyInvalidatedTokens = Arrays.asList(generateRandomStringArray(20, 15, false));
TokensInvalidationResult result = new TokensInvalidationResult(invalidatedTokens, previouslyInvalidatedTokens,
Arrays.asList(new ElasticsearchException("foo", new IllegalArgumentException("this is an error message")),
new ElasticsearchException("bar", new IllegalArgumentException("this is an error message2"))),
randomIntBetween(0, 5));
new ElasticsearchException("bar", new IllegalArgumentException("this is an error message2"))));
InvalidateTokenResponse response = new InvalidateTokenResponse(result);
XContentBuilder builder = XContentFactory.jsonBuilder();
response.toXContent(builder, ToXContent.EMPTY_PARAMS);
Expand Down
Expand Up @@ -63,7 +63,7 @@ protected void doExecute(Task task, SamlAuthenticateRequest request, ActionListe
final Map<String, Object> tokenMeta = (Map<String, Object>) result.getMetadata().get(SamlRealm.CONTEXT_TOKEN_DATA);
tokenService.createUserToken(authentication, originatingAuthentication,
ActionListener.wrap(tuple -> {
final String tokenString = tokenService.getUserTokenString(tuple.v1());
final String tokenString = tokenService.getAccessTokenAsString(tuple.v1());
final TimeValue expiresIn = tokenService.getExpirationDelay();
listener.onResponse(
new SamlAuthenticateResponse(authentication.getUser().principal(), tokenString, tuple.v2(), expiresIn));
Expand Down
Expand Up @@ -89,7 +89,7 @@ private void createToken(CreateTokenRequest request, Authentication authenticati
boolean includeRefreshToken, ActionListener<CreateTokenResponse> listener) {
try {
tokenService.createUserToken(authentication, originatingAuth, ActionListener.wrap(tuple -> {
final String tokenStr = tokenService.getUserTokenString(tuple.v1());
final String tokenStr = tokenService.getAccessTokenAsString(tuple.v1());
final String scope = getResponseScopeValue(request.getScope());

final CreateTokenResponse response =
Expand Down
Expand Up @@ -31,7 +31,7 @@ public TransportRefreshTokenAction(TransportService transportService, ActionFilt
@Override
protected void doExecute(Task task, CreateTokenRequest request, ActionListener<CreateTokenResponse> listener) {
tokenService.refreshToken(request.getRefreshToken(), ActionListener.wrap(tuple -> {
final String tokenStr = tokenService.getUserTokenString(tuple.v1());
final String tokenStr = tokenService.getAccessTokenAsString(tuple.v1());
final String scope = getResponseScopeValue(request.getScope());

final CreateTokenResponse response =
Expand Down

0 comments on commit 578c019

Please sign in to comment.