Skip to content

Commit

Permalink
fix(docker-jans-certmanager): resolve key_ops_type created on key rot…
Browse files Browse the repository at this point in the history
…ation (#7727)
  • Loading branch information
iromli committed Feb 14, 2024
1 parent 8f5d6a9 commit 7129c62
Showing 1 changed file with 65 additions and 13 deletions.
78 changes: 65 additions & 13 deletions docker-jans-certmanager/scripts/auth_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,8 +210,7 @@ def __init__(self, manager, dry_run, **opts):

@property
def allowed_key_algs(self):
algs = self.sig_keys.split() + self.enc_keys.split()
return algs
return self.sig_keys.split() + self.enc_keys.split()

def get_merged_keys(self, exp_hours):
# get previous JWKS
Expand All @@ -231,6 +230,7 @@ def get_merged_keys(self, exp_hours):

# a counter to determine value of `key_ops_type` to be passed to `KeyGenerator`
ops_type_cnt = Counter(key_ops_from_jwk(jwk) for jwk in old_jwks)
logger.info(f"Detected key_ops_type={dict(ops_type_cnt)}")

# if we have ssa key, use `connect` keys
if ops_type_cnt["ssa"]:
Expand All @@ -248,7 +248,7 @@ def get_merged_keys(self, exp_hours):

if retcode != 0:
logger.error(f"Unable to generate keys; reason={err.decode()}")
return
return "", ""

new_jwks = deque(json.loads(out).get("keys", []))

Expand All @@ -259,26 +259,45 @@ def get_merged_keys(self, exp_hours):

# counter for `connect` keys
cnt = Counter(
jwk["alg"] for jwk in new_jwks
jwk["alg"].upper() for jwk in new_jwks
if key_ops_from_jwk(jwk) == "connect"
)

# counter for `ssa` keys
cnt_ssa = Counter(
jwk["alg"].upper() for jwk in new_jwks
if key_ops_from_jwk(jwk) == "ssa"
)

for jwk in old_jwks:
alg = jwk.get("alg", "").upper()

# exclude alg if it's not allowed
if jwk["alg"] not in self.allowed_key_algs:
if alg not in self.allowed_key_algs:
continue

ops_type = key_ops_from_jwk(jwk)

# exclude unsupported key_ops_type
if ops_type not in ["ssa", "connect"]:
continue

# cannot have more than 2 connect keys for same algorithm in new JWKS
if ops_type == "connect" and cnt[jwk["alg"]] > 1:
if ops_type == "connect" and cnt[alg] >= 2:
continue

# cannot have more than 1 ssa keys for same algorithm in new JWKS
if ops_type == "ssa" and cnt_ssa[alg] >= 1:
continue

# insert old key to new keys
new_jwks.appendleft(jwk)

if ops_type == "connect":
cnt[jwk["alg"]] += 1
cnt[alg] += 1

if ops_type == "ssa":
cnt_ssa[alg] += 1

# import key to new JKS
keytool_import_key(old_jks_fn, jks_fn, jwk["kid"], jks_pass)
Expand Down Expand Up @@ -353,6 +372,10 @@ def patch(self):
web_keys = json.loads(config["jansConfWebKeys"])
except TypeError:
web_keys = config["jansConfWebKeys"]
except json.decoder.JSONDecodeError as exc:
# probably corrupted JSON
logger.warning(f"Unable to load existing JWKS; reason={exc}. New JWKS will be created.")
web_keys = {"keys": []}

with open("/etc/certs/auth-keys.old.json", "w") as f:
f.write(json.dumps(web_keys, indent=2))
Expand Down Expand Up @@ -504,6 +527,10 @@ def prune(self):
web_keys = json.loads(config["jansConfWebKeys"])
except TypeError:
web_keys = config["jansConfWebKeys"]
except json.decoder.JSONDecodeError as exc:
# probably corrupted JSON
logger.warning(f"Unable to load existing JWKS; reason={exc}. Existing JWKS will be deleted.")
web_keys = {"keys": []}

logger.info("Cleaning up keys (if any)")

Expand All @@ -519,28 +546,43 @@ def prune(self):
old_jwks = sorted(old_jwks, key=lambda k: k["exp"], reverse=True)

cnt = Counter(
jwk["alg"] for jwk in new_jwks
jwk["alg"].upper() for jwk in new_jwks
if key_ops_from_jwk(jwk) == "connect"
)

cnt_ssa = Counter(
jwk["alg"].upper() for jwk in new_jwks
if key_ops_from_jwk(jwk) == "ssa"
)

for jwk in old_jwks:
alg = jwk.get("alg", "").upper()

# exclude alg if it's not allowed
if jwk["alg"] not in self.allowed_key_algs:
if alg not in self.allowed_key_algs:
keytool_delete_key(jks_fn, jwk["kid"], jks_pass)
continue

ops_type = key_ops_from_jwk(jwk)

# cannot have more than 1 key for same algorithm in new JWKS
if ops_type == "connect" and cnt[jwk["alg"]]:
# cannot have more than 1 connect key for same algorithm in new JWKS
if ops_type == "connect" and cnt[alg] >= 1:
keytool_delete_key(jks_fn, jwk["kid"], jks_pass)
continue

# cannot have more than 1 ssa key for same algorithm in new JWKS
if ops_type == "ssa" and cnt_ssa[alg] >= 1:
keytool_delete_key(jks_fn, jwk["kid"], jks_pass)
continue

# preserve the key
new_jwks.append(jwk)

if ops_type == "connect":
cnt[jwk["alg"]] += 1
cnt[alg] += 1

if ops_type == "ssa":
cnt_ssa[alg] += 1

web_keys["keys"] = new_jwks

Expand Down Expand Up @@ -665,7 +707,17 @@ def resolve_enc_keys(keys: str) -> str:

def key_ops_from_jwk(jwk):
"""Resolve key_ops_type first value."""
return (jwk.get("key_ops_type") or ["connect"])[0]
types = jwk.get("key_ops_type") or []
try:
type_ = types[0]
except IndexError:
if jwk["kid"].startswith("connect_"):
type_ = "connect"
elif jwk["kid"].startswith("ssa_"):
type_ = "ssa"
else:
type_ = ""
return type_.lower()


def has_ext_jwks_uri(conf_dynamic, manager) -> bool:
Expand Down

0 comments on commit 7129c62

Please sign in to comment.