Skip to content

Commit

Permalink
🐛 Fix JwtTokenManager.verify() after java-jwt upgrade
Browse files Browse the repository at this point in the history
  • Loading branch information
ujibang committed Apr 28, 2023
1 parent 6b16caf commit 58657ab
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 40 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@
import com.auth0.jwt.JWTVerifier;
import com.auth0.jwt.JWTCreator.Builder;
import com.auth0.jwt.algorithms.Algorithm;
import com.auth0.jwt.impl.NullClaim;
import com.auth0.jwt.interfaces.DecodedJWT;
import com.google.common.collect.Sets;
import io.undertow.security.idm.Account;
Expand All @@ -53,24 +52,25 @@
import org.restheart.plugins.RegisterPlugin;
import org.restheart.plugins.security.TokenManager;
import org.restheart.security.BaseAccount;
import org.restheart.security.FileRealmAccount;
import org.restheart.security.JwtAccount;
import org.restheart.security.PwdCredentialAccount;
import org.restheart.security.WithProperties;
import org.restheart.utils.Pair;
import org.restheart.utils.URLUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.nio.charset.Charset;
import java.time.Instant;
import java.time.temporal.ChronoUnit;
import java.util.Arrays;
import java.util.Base64;
import java.util.Date;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;

import static org.restheart.plugins.ConfigurablePlugin.argValue;

/**
*
* @author Andrea Di Cesare <andrea@softinstigate.com>
Expand Down Expand Up @@ -104,8 +104,8 @@ public class JwtTokenManager implements TokenManager {
public void init() throws ConfigurationException {
this.enabled = true;

this.srvURI = argValue(config, "srv-uri");
this.ttl = argValue(config, "ttl");
this.srvURI = arg(config, "srv-uri");
this.ttl = arg(config, "ttl");

if (ttl < 1) {
this.enabled = false;
Expand All @@ -118,7 +118,7 @@ public void init() throws ConfigurationException {
LOGGER.warn("You should really update the JWT key!");
}

this.algo = Algorithm.HMAC256((String) argValue(config, "key"));
this.algo = Algorithm.HMAC256((String) arg(config, "key"));
this.issuer = arg(config, "issuer");

jwtCache = CacheFactory.createLocalLoadingCache(MAX_CACHE_SIZE,
Expand Down Expand Up @@ -156,31 +156,32 @@ public Account verify(final String id, final Credential credential) {

var ca = new ComparableAccount(new BaseAccount(id, null));

var _cached = this.jwtCache.get(ca);

// first check if the very same token is in the cache
if (this.jwtCache.get(ca) != null && this.jwtCache.get(ca).isPresent() && Arrays.equals(rawToken, this.jwtCache.get(ca).get().raw())) {
if (_cached != null && _cached.isPresent() && Arrays.equals(rawToken, _cached.get().raw())) {
LOGGER.debug("jwt token in cache");
var cached = this.jwtCache.get(ca).get();
var roles = Sets.newHashSet(this.jwtCache.get(ca).get().roles());
if (cached.properties() == null) {
return new PwdCredentialAccount(id, rawToken, roles);
} else {
return new FileRealmAccount(id, rawToken, roles, cached.properties());
}
var cached = _cached.get();
var roles = Sets.newHashSet(cached.roles());

var jwtParts = new String(cached.raw()).split("\\.");

var jwtPayload = new String(Base64.getUrlDecoder().decode(jwtParts[1]), Charset.forName("UTF-8"));

return new JwtAccount(id, roles, jwtPayload);
} else {
LOGGER.trace("jwt token not in cache, let's verify it");
// if the token is not in the cache, verify it
try {
var decoded = this.verifier.verify(new String(rawToken));

if (id.equals(decoded.getSubject())) {
var roles = decoded.getClaim("roles").asArray(String.class);

var token = Token.fromJWT(decoded);
if (token.properties() == null) {
return new PwdCredentialAccount(id, rawToken, Sets.newHashSet(token.roles()));
} else {
return new FileRealmAccount(id, rawToken, Sets.newHashSet(roles), token.properties());
}
var _roles = decoded.getClaim("roles").asArray(String.class);
var roles = Sets.newHashSet(_roles);

var jwtPayload = new String(Base64.getUrlDecoder().decode(decoded.getPayload()), Charset.forName("UTF-8"));
this.jwtCache.put(ca, newToken(ca.wrapped, decoded.getExpiresAt()));
return new JwtAccount(id, roles, jwtPayload);
} else {
LOGGER.warn("invalid token from user {}, not matching id in token, was {}", id, decoded.getSubject());
return null;
Expand Down Expand Up @@ -212,17 +213,16 @@ public PasswordCredential get(Account account) {

var token = this.jwtCache.getLoading(new ComparableAccount(account)).get();

PwdCredentialAccount newTokenAccount = new PwdCredentialAccount(
account.getPrincipal().getName(),
token.raw(),
Sets.newTreeSet(account.getRoles()));
var newTokenAccount = new PwdCredentialAccount(account.getPrincipal().getName(), token.raw(), Sets.newTreeSet(account.getRoles()));

return newTokenAccount.getCredentials();
}

private Token newToken(Account account) {
var expires = Date.from(Instant.now().plus(ttl, ChronoUnit.MINUTES));
return newToken(account, Date.from(Instant.now().plus(ttl, ChronoUnit.MINUTES)));
}

private Token newToken(Account account, Date expires) {
var creator = audience != null
? JWT.create().withIssuer(issuer).withAudience(audience)
: JWT.create().withIssuer(issuer);
Expand Down Expand Up @@ -380,13 +380,16 @@ public static Token fromJWT(DecodedJWT jwt) {
var raw = jwt.getToken().toCharArray();
var expires = jwt.getExpiresAt();
var roles = jwt.getClaim("roles").asArray(String.class);
var _properties = jwt.getClaim("properties");

if (_properties instanceof NullClaim) {
return new Token(raw, expires, roles, null);
} else {
return new Token(raw, expires, roles, _properties.asMap());
}
var accountProperties = new HashMap<String, Object>();

jwt.getClaims().entrySet().stream()
.filter(e -> !e.getKey().equals("sub"))
.filter(e -> !e.getKey().equals("iss"))
.filter(e -> !e.getKey().equals("roles"))
.forEach(e -> accountProperties.put(e.getKey(), e.getValue()));

return new Token(raw, expires, roles, accountProperties);
}

public String getDateAsString() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,7 @@
@RegisterPlugin(
name = "extraJwtVerifier",
priority = 100,
description = "Adds an extra verifictation step "
+ "to the jwtAuthenticationMechanism")
description = "Adds an extra verifictation step to the jwtAuthenticationMechanism")
public class ExtraJwtVerifier implements Initializer {
private static final Logger LOGGER = LoggerFactory.getLogger(ExtraJwtVerifier.class);

Expand Down Expand Up @@ -80,14 +79,16 @@ public void init() {

var extra = extraClaim.asMap();

if (extra == null) {
throw new JWTVerificationException("extra claim is empty");
}

if (!extra.containsKey("a")) {
throw new JWTVerificationException("extra claim does not have "
+ "'a' property");
throw new JWTVerificationException("extra claim does not have 'a' property");
}

if (!extra.containsKey("b")) {
throw new JWTVerificationException("extra claim does not have "
+ "'b' property");
throw new JWTVerificationException("extra claim does not have 'b' property");
}
});
}
Expand Down

0 comments on commit 58657ab

Please sign in to comment.