From 7129c62f3706fa26f46d28f2425b9b1c7b09378f Mon Sep 17 00:00:00 2001 From: Isman Firmansyah Date: Wed, 14 Feb 2024 23:27:48 +0700 Subject: [PATCH] fix(docker-jans-certmanager): resolve key_ops_type created on key rotation (#7727) --- .../scripts/auth_handler.py | 78 +++++++++++++++---- 1 file changed, 65 insertions(+), 13 deletions(-) diff --git a/docker-jans-certmanager/scripts/auth_handler.py b/docker-jans-certmanager/scripts/auth_handler.py index df0c348f346..2e4c1e9bae3 100644 --- a/docker-jans-certmanager/scripts/auth_handler.py +++ b/docker-jans-certmanager/scripts/auth_handler.py @@ -210,8 +210,7 @@ def __init__(self, manager, dry_run, **opts): @property def allowed_key_algs(self): - algs = self.sig_keys.split() + self.enc_keys.split() - return algs + return self.sig_keys.split() + self.enc_keys.split() def get_merged_keys(self, exp_hours): # get previous JWKS @@ -231,6 +230,7 @@ def get_merged_keys(self, exp_hours): # a counter to determine value of `key_ops_type` to be passed to `KeyGenerator` ops_type_cnt = Counter(key_ops_from_jwk(jwk) for jwk in old_jwks) + logger.info(f"Detected key_ops_type={dict(ops_type_cnt)}") # if we have ssa key, use `connect` keys if ops_type_cnt["ssa"]: @@ -248,7 +248,7 @@ def get_merged_keys(self, exp_hours): if retcode != 0: logger.error(f"Unable to generate keys; reason={err.decode()}") - return + return "", "" new_jwks = deque(json.loads(out).get("keys", [])) @@ -259,26 +259,45 @@ def get_merged_keys(self, exp_hours): # counter for `connect` keys cnt = Counter( - jwk["alg"] for jwk in new_jwks + jwk["alg"].upper() for jwk in new_jwks if key_ops_from_jwk(jwk) == "connect" ) + # counter for `ssa` keys + cnt_ssa = Counter( + jwk["alg"].upper() for jwk in new_jwks + if key_ops_from_jwk(jwk) == "ssa" + ) + for jwk in old_jwks: + alg = jwk.get("alg", "").upper() + # exclude alg if it's not allowed - if jwk["alg"] not in self.allowed_key_algs: + if alg not in self.allowed_key_algs: continue ops_type = key_ops_from_jwk(jwk) + # exclude unsupported key_ops_type + if ops_type not in ["ssa", "connect"]: + continue + # cannot have more than 2 connect keys for same algorithm in new JWKS - if ops_type == "connect" and cnt[jwk["alg"]] > 1: + if ops_type == "connect" and cnt[alg] >= 2: + continue + + # cannot have more than 1 ssa keys for same algorithm in new JWKS + if ops_type == "ssa" and cnt_ssa[alg] >= 1: continue # insert old key to new keys new_jwks.appendleft(jwk) if ops_type == "connect": - cnt[jwk["alg"]] += 1 + cnt[alg] += 1 + + if ops_type == "ssa": + cnt_ssa[alg] += 1 # import key to new JKS keytool_import_key(old_jks_fn, jks_fn, jwk["kid"], jks_pass) @@ -353,6 +372,10 @@ def patch(self): web_keys = json.loads(config["jansConfWebKeys"]) except TypeError: web_keys = config["jansConfWebKeys"] + except json.decoder.JSONDecodeError as exc: + # probably corrupted JSON + logger.warning(f"Unable to load existing JWKS; reason={exc}. New JWKS will be created.") + web_keys = {"keys": []} with open("/etc/certs/auth-keys.old.json", "w") as f: f.write(json.dumps(web_keys, indent=2)) @@ -504,6 +527,10 @@ def prune(self): web_keys = json.loads(config["jansConfWebKeys"]) except TypeError: web_keys = config["jansConfWebKeys"] + except json.decoder.JSONDecodeError as exc: + # probably corrupted JSON + logger.warning(f"Unable to load existing JWKS; reason={exc}. Existing JWKS will be deleted.") + web_keys = {"keys": []} logger.info("Cleaning up keys (if any)") @@ -519,20 +546,32 @@ def prune(self): old_jwks = sorted(old_jwks, key=lambda k: k["exp"], reverse=True) cnt = Counter( - jwk["alg"] for jwk in new_jwks + jwk["alg"].upper() for jwk in new_jwks if key_ops_from_jwk(jwk) == "connect" ) + cnt_ssa = Counter( + jwk["alg"].upper() for jwk in new_jwks + if key_ops_from_jwk(jwk) == "ssa" + ) + for jwk in old_jwks: + alg = jwk.get("alg", "").upper() + # exclude alg if it's not allowed - if jwk["alg"] not in self.allowed_key_algs: + if alg not in self.allowed_key_algs: keytool_delete_key(jks_fn, jwk["kid"], jks_pass) continue ops_type = key_ops_from_jwk(jwk) - # cannot have more than 1 key for same algorithm in new JWKS - if ops_type == "connect" and cnt[jwk["alg"]]: + # cannot have more than 1 connect key for same algorithm in new JWKS + if ops_type == "connect" and cnt[alg] >= 1: + keytool_delete_key(jks_fn, jwk["kid"], jks_pass) + continue + + # cannot have more than 1 ssa key for same algorithm in new JWKS + if ops_type == "ssa" and cnt_ssa[alg] >= 1: keytool_delete_key(jks_fn, jwk["kid"], jks_pass) continue @@ -540,7 +579,10 @@ def prune(self): new_jwks.append(jwk) if ops_type == "connect": - cnt[jwk["alg"]] += 1 + cnt[alg] += 1 + + if ops_type == "ssa": + cnt_ssa[alg] += 1 web_keys["keys"] = new_jwks @@ -665,7 +707,17 @@ def resolve_enc_keys(keys: str) -> str: def key_ops_from_jwk(jwk): """Resolve key_ops_type first value.""" - return (jwk.get("key_ops_type") or ["connect"])[0] + types = jwk.get("key_ops_type") or [] + try: + type_ = types[0] + except IndexError: + if jwk["kid"].startswith("connect_"): + type_ = "connect" + elif jwk["kid"].startswith("ssa_"): + type_ = "ssa" + else: + type_ = "" + return type_.lower() def has_ext_jwks_uri(conf_dynamic, manager) -> bool: