<a href="https://colab.research.google.com/github/AltmannPeter/privacy-key-management/blob/main/ARKG_HDK.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
%%capture
!pip install ecpy

In [2]:
from ecpy.curves    import Point, Curve
import secrets
from hashlib        import sha256
from math           import ceil

from cryptography.hazmat.primitives             import hashes
from cryptography.hazmat.primitives.asymmetric  import ec
from cryptography.hazmat.primitives.kdf.hkdf    import HKDF
from cryptography.hazmat.primitives             import serialization

In ARKG-based HDK, you perform a KEM between the WSCA (Wallet Secure Cryptographic Application) and the WSCD (Wallet Secure Cryptographic Device) to derive the shared secret. You therefore begin with two separate key pairs: the `ROOT` key pair, which is protected by the WSCD, and the `KEM` key pair, which is generated in the WSCA. Together, these two key pairs are called the HDK `seed`.

The examples below use RFC 9180 for the key generation and the KEM.

In [3]:
# Hex-encoded input keying material (ikm*) and expected values for the checks
# in test_derive_keypair / test_encap.
# NOTE(review): ikmE/ikmR/skEm/pkEm/ss look like RFC 9180 test-vector values
# for DHKEM(P-256, HKDF-SHA256) — confirm against the RFC appendix.
# M = master, P = parent, R = recipient, E = ephemeral
ikmM = '24c431f05a924bb8872918d399cffc913599055fea3cb5c01ff833ed06ef7209'
ikmP = 'ee8d9c4b099ba2a8a0f9f9656c7404636fdc386bd4e9124c1c562675ee91cafb'
ikmR = '668b37171f1072f3cf12ea8a236a45df23fc13b82af3609ad1e354f6ef817550'
ikmE = '4270e54ffd08d79d5928020af4686d8f6b7d35dbe470265f1f5aa22816ce860e'
# Expected private scalar / uncompressed public key derived from ikmE.
skEm = '4995788ef4b9d6132b249ce59a77281493eb39af373d236a1fe415cb0c2d7beb'
pkEm = '04a92719c6195d5085104f469a8b9814d5838ff72b60501e2c4466e5e67b325ac98536d7b61a1af4b78e5b7f951c0900be863c403ce65c9bfcb9382657222d18c4'
# Expected shared secret of encap(pk(ikmR), ikmE).
ss   = 'c0d26aeab536609a572b07695d933b589dcf363ff9d93c93adea537aeabb8cb8'

# Blinding mode: 'ADD' selects additive blinding in the HDK helpers; any other
# value selects multiplicative blinding.
MODE = 'ADD'
# Prefix of the domain separation tag used when deriving blinding factors.
DST_HDK_DERIVE_KEY = b'HDK-derive-blinding-factor'

# RFC 9180
# Function docstrings should be written based on RFC 9180
Nsecret = 32  # length in bytes of the KEM shared secret (see encap)
Nsk = 32      # length in bytes of a candidate private key (see derive_keypair)
crv = Curve.get_curve('secp256r1')
G = crv.generator
N = crv.order
salt = b""    # HKDF salt used by extract_and_expand
kem_id = 0x0010  # RFC 9180 KEM identifier for DHKEM(P-256, HKDF-SHA256)
# Byte order is given explicitly: the default only exists on Python 3.11+.
suite_id = b"KEM" + kem_id.to_bytes(2, "big")
bitmask = 0xFF  # AND-ed with the first candidate byte in derive_keypair

# Shorthands for the public key serialisation formats used throughout.
SEC1 = serialization.Encoding.X962
UCOMP = serialization.PublicFormat.UncompressedPoint
COMP = serialization.PublicFormat.CompressedPoint

class DeriveKeyPairError(Exception):
  """Raised when derive_keypair finds no valid private scalar in 256 attempts."""


def derive_keypair(ikm):
  """
  Deterministically derive a P-256 key pair from input keying material,
  following the rejection-sampling loop of RFC 9180 DeriveKeyPair.

  ikm is the input keying material as bytes.
  Returns a (private_key, public_key) tuple of `cryptography` EC key objects.
  Raises DeriveKeyPairError if none of the 256 candidates is in [1, order - 1].
  """
  label_extract = b"dkp_prk"
  label_expand = b"candidate"

  # Counters 0..255 are tried in order; the counter is the one-byte `info`
  # that makes each candidate distinct.
  for counter in range(256):
    candidate = bytearray(
        extract_and_expand(ikm, label_extract, label_expand, counter.to_bytes(1, "big"), Nsk)
    )
    candidate[0] &= bitmask
    sk = int.from_bytes(candidate, "big")
    if 0 < sk < crv.order:
      prv = ec.derive_private_key(sk, ec.SECP256R1())
      return prv, prv.public_key()

  raise DeriveKeyPairError("no valid private key candidate found for the given ikm")


def test_derive_keypair():
  """Check that the key derived from ikmE has the private scalar recorded in skEm."""
  prv, _ = derive_keypair(bytes.fromhex(ikmE))
  derived_scalar = prv.private_numbers().private_value
  assert derived_scalar == int(skEm, 16)


def ECDH(a, B):
  """
  Perform an ECDH exchange between private key `a` and public key `B`.
  Returns whatever `a.exchange` produces for that pair.
  """
  algorithm = ec.ECDH()
  shared = a.exchange(algorithm, B)
  return shared


def extract_and_expand(ikm, label_extract, label_expand, info, L):
  """
  Labeled HKDF-SHA256 extract-then-expand in the style of RFC 9180
  (LabeledExtract followed by LabeledExpand).

  ikm is the input keying material (bytes).
  label_extract / label_expand are the byte labels mixed into each step.
  info is the context bytes appended to the expand label.
  L is the number of output bytes; it is also encoded as a two-byte
  big-endian prefix of the expand info.
  Returns L bytes of derived key material.
  """
  labeled_ikm = b"HPKE-v1" + suite_id + label_extract + ikm
  # Byte order is explicit so this does not depend on the Python 3.11+ default.
  labeled_info = L.to_bytes(2, "big") + b"HPKE-v1" + suite_id + label_expand + info
  return HKDF(
        algorithm=hashes.SHA256(),
        length=L,
        salt=salt,
        info=labeled_info
        ).derive(labeled_ikm)


def encap(pub_recipient, seed, label_extract=b'eae_prk'):
  """
  Encapsulate towards pub_recipient using a key pair derived from seed.

  pub_recipient is the recipient's EC public key.
  seed is hex-encoded input keying material for the sender key pair.
  label_extract is the extract label passed on to extract_and_expand.
  Returns (shared_secret, enc), where enc is the sender public key in
  uncompressed SEC1 form.
  """
  sender_prv, sender_pub = derive_keypair(bytes.fromhex(seed))
  dh = ECDH(sender_prv, pub_recipient)

  # The KEM context binds both public keys: sender first, then recipient.
  enc = sender_pub.public_bytes(SEC1, UCOMP)
  kem_context = enc + pub_recipient.public_bytes(SEC1, UCOMP)

  shared_secret = extract_and_expand(
      dh, label_extract, b'shared_secret', kem_context, Nsecret
  )
  return shared_secret, enc


def test_encap():
  """Check encap against the recorded shared secret (ss) and encapsulated key (pkEm)."""
  _, pub_recipient = derive_keypair(bytes.fromhex(ikmR))
  shared_secret, enc = encap(pub_recipient, ikmE)
  assert shared_secret.hex() == ss
  assert enc.hex() == pkEm

# Run the key-derivation and encapsulation checks when the cell executes;
# a failed check raises AssertionError.
test_derive_keypair()
test_encap()

In [4]:
# From RFC 9380 using SHA-256
# Function docstrings should be written based on RFC 9380

def strxor(b1, b2):
    """
    XOR two byte strings position by position.
    The result is as long as the shorter input (pairs come from zip).
    """
    out = bytearray()
    for left, right in zip(b1, b2):
        out.append(left ^ right)
    return bytes(out)


def expand_message_xmd(msg, DST, len_in_bytes):
    """
    https://www.rfc-editor.org/rfc/rfc9380.html#name-expand_message_xmdsha-256

    Expand msg into len_in_bytes pseudorandom bytes using SHA-256.

    msg is the input message (bytes).
    DST is the domain separation tag (bytes, at most 255 bytes long).
    len_in_bytes is the requested output length (at most 65535, and at most
    255 hash blocks).
    Returns exactly len_in_bytes bytes.
    Raises ValueError when any of the limits above is exceeded.
    """
    b_in_bytes = sha256().digest_size
    s_in_bytes = sha256().block_size
    ell = ceil(len_in_bytes / b_in_bytes)
    if ell > 255 or len_in_bytes > 65535 or len(DST) > 255:
        raise ValueError("Input values out of range.")

    # Integer encodings are big-endian; the byte order is given explicitly so
    # the code does not depend on the default added in Python 3.11.
    DST_prime = DST + int.to_bytes(len(DST), 1, "big")
    Z_pad = b"\x00" * s_in_bytes
    l_i_b_str = int.to_bytes(len_in_bytes, 2, "big")
    msg_prime = Z_pad + msg + l_i_b_str + int.to_bytes(0, 1, "big") + DST_prime

    b_0 = sha256(msg_prime).digest()
    blocks = [sha256(b_0 + int.to_bytes(1, 1, "big") + DST_prime).digest()]
    # Each further block hashes (b_0 XOR previous block) with its index.
    for i in range(2, ell + 1):
        chained = strxor(b_0, blocks[-1])
        blocks.append(
            sha256(chained + int.to_bytes(i, 1, "big") + DST_prime).digest()
        )

    return b"".join(blocks)[:len_in_bytes]


def test_expand_msg_xmd():
  """Check expand_message_xmd against two recorded 32-byte outputs."""
  dst = b"QUUX-V01-CS02-with-expander-SHA256-128"
  expected_1 = "68a985b87eb6b46952128911f2a4412bbc302a9d759667f87f7a21d803f07235"
  expected_2 = "d8ccab23b5985ccea865c6c97b6e5b8350e794e603b4b97902f53a8a0d605615"

  digest_empty = expand_message_xmd(b"", dst, 0x20)
  digest_abc = expand_message_xmd(b"abc", dst, 0x20)

  assert digest_empty.hex() == expected_1
  assert digest_abc.hex() == expected_2

# Run the expand_message_xmd check when the cell executes.
test_expand_msg_xmd()


def hash_to_field(msg, count, DST, order=0xFFFFFFFF00000001000000000000000000000000FFFFFFFFFFFFFFFFFFFFFFFF):
    """
    https://www.rfc-editor.org/rfc/rfc9380.html#name-p256_xmdsha-256_sswu_ro_

    Hash msg to `count` integers in the range [0, order - 1].

    msg is the input message (bytes).
    count is the number of field elements to produce.
    DST is the domain separation tag (bytes).
    order is the modulus each element is reduced by. The default is the
    P-256 base-field prime, not the curve's group order; pass the group
    order explicitly when a scalar is wanted.
    Returns a list of `count` integers.
    """
    m, L, p = 1, 48, order

    len_in_bytes = count * m * L
    uniform_bytes = expand_message_xmd(
        msg=msg,
        DST=DST,
        len_in_bytes=len_in_bytes,
    )

    # Consecutive L-byte chunks are read as big-endian integers; the byte
    # order is explicit so this does not depend on the Python 3.11+ default.
    return [
        int.from_bytes(uniform_bytes[offset : offset + L], "big") % p
        for offset in range(0, len_in_bytes, L)
    ]


def test_hash_to_field():
  """Check hash_to_field (count=2, default modulus) against recorded outputs."""
  DST = b"QUUX-V01-CS02-with-P256_XMD:SHA-256_SSWU_RO_"
  # (message, (expected element 0, expected element 1))
  vectors = (
      (
          b"",
          (
              0xad5342c66a6dd0ff080df1da0ea1c04b96e0330dd89406465eeba11582515009,
              0x8c0f1d43204bd6f6ea70ae8013070a1518b43873bcd850aafa0a9e220e2eea5a,
          ),
      ),
      (
          b"abc",
          (
              0xafe47f2ea2b10465cc26ac403194dfb68b7f5ee865cda61e9f3e07a537220af1,
              0x379a27833b0bfe6f7bdca08e1e83c760bf9a338ab335542704edcd69ce9e46e0,
          ),
      ),
  )

  for msg, (expected_1, expected_2) in vectors:
    assert hash_to_field(msg, 2, DST)[0] == expected_1
    assert hash_to_field(msg, 2, DST)[1] == expected_2

# Run the hash_to_field check when the cell executes.
test_hash_to_field()

The remaining cells implement the HDK (Hierarchical Deterministic Keys) key derivation itself.

The derivation path for HDK supports both branching and nested shared secrets as follows:

```
- root_key/kem_i/derived_key_j/shared_secret_k/derived_key_l/...
```

Each KEM is a `shared_secret` between an issuer and the WSCA. Each derived key is the output of the `derive_pub_child` / `derive_prv_child` functions below at index $j$.

This derivation path is intended to model the EUDIW context where the user's EUDIW will request attestations from one issuer, and possibly use that attestation to request batch attestations from another issuer. The WSCA needs to store information that enables recreating the following tree:

```
root_key
  │
  ├─ shared_secret_A       <-- KEM for Issuer A (from root)   
  │   ├─ derived_key_A1    <-- first child
  │   └─ derived_key_A2    <-- second child
  │       │
  │       └─ shared_secret_B       <-- KEM for Issuer B (from A2)
  │           ├─ derived_key_B1    <-- first nested child key
  │           └─ ⋮
  │
  ├─ shared_secret_C       <-- KEM for Issuer C (from root)
  │   ├─ derived_key_C1    <-- first child
  │   ├─ derived_key_C2    <-- second child
  │   ├─ ⋮
  ↓   ↓
```

The WSCA can recreate the entire structure locally as the KEM is between the WSCA and the WSCD. Each shared secret is then associated with an Issuer.

In [5]:
# Helper functions
def prf(msg, DST):
  """
  Pseudorandom function built on hash_to_field (RFC 9380, Section 5.2):
  https://www.rfc-editor.org/rfc/rfc9380.html

  msg is the input bytes (here: a shared secret).
  DST is the domain separation tag (bytes).
  Returns a single integer in [0, N - 1], where N is the order of the
  curve's generator.
  """
  # The output is used as a scalar (blinding factor), so it is reduced modulo
  # the group order N rather than hash_to_field's default field prime.
  return hash_to_field(
      msg=msg,
      count=1,
      DST=DST,
      order=N
  )[0]


def combine_bf(bf1, bf2, mode):
  """
  Merge two blinding factors into one, modulo the group order N.
  With mode 'ADD' the factors are summed; with any other mode they are
  multiplied.
  """
  if mode == 'ADD':
    combined = bf1 + bf2
  else:
    combined = bf1 * bf2
  return combined % N


def blind_pub_key(pub_parent, bf, mode):
  """
  Blind a public key using the blinding factor bf.

  pub_parent is a `cryptography` EC public key on P-256.
  bf is the integer blinding factor.
  mode 'ADD' returns pub_parent + bf*G; any other mode returns bf*pub_parent.
  Returns the blinded key as a `cryptography` EC public key.
  """
  # `cryptography` has no point arithmetic, so the key is round-tripped
  # through its compressed encoding into an ecpy point and back.
  pub_parent_point = crv.decode_point(pub_parent.public_bytes(SEC1, COMP))
  if mode == 'ADD':
    pub_child_point = pub_parent_point + bf * G
  else:
    pub_child_point = pub_parent_point * bf

  # Uses the canonical class name (as blind_key does) rather than the
  # legacy `...WithSerialization` alias.
  return ec.EllipticCurvePublicKey.from_encoded_point(
      ec.SECP256R1(),
      bytes(crv.encode_point(pub_child_point, True))
  )


def blind_prv_key(prv_parent, bf, mode):
  """
  Blind a private key with the blinding factor bf, modulo the group order N.
  With mode 'ADD' bf is added to the private scalar; with any other mode the
  scalar is multiplied by bf.
  Returns the blinded key as a P-256 private key object.
  """
  parent_scalar = prv_parent.private_numbers().private_value
  if mode == 'ADD':
    child_scalar = (parent_scalar + bf) % N
  else:
    child_scalar = (parent_scalar * bf) % N
  return ec.derive_private_key(child_scalar, ec.SECP256R1())


def blind_key(key_parent, bf, mode):
  """
  Blind either kind of EC key with the blinding factor bf.
  Private keys go to blind_prv_key and public keys to blind_pub_key;
  mode is passed through unchanged.
  Raises TypeError for any other key type.
  """
  # Private keys are checked first, as in the dispatch order of the branches.
  if isinstance(key_parent, ec.EllipticCurvePrivateKey):
    return blind_prv_key(key_parent, bf, mode)
  if isinstance(key_parent, ec.EllipticCurvePublicKey):
    return blind_pub_key(key_parent, bf, mode)
  raise TypeError(f"Unsupported key type: {type(key_parent)}")


def derive_pub_child(key, mode, shared_secret, index=0):
  """
  Derive a child public key from a parent public key and a shared secret.

  key is the parent EC public key.
  mode determines whether to add or multiply blinding factors.
  shared_secret is the bytes fed to prf to obtain the blinding factor.
  index is the index of the child key; it is encoded as 4 big-endian bytes.
  Output is a public key.
  """
  # The blinding factor is bound to the parent key and the child index
  # through the domain separation tag. The byte order is explicit so the
  # encoding does not depend on the Python 3.11+ default.
  dst_suffix = key.public_bytes(SEC1, COMP) + int.to_bytes(index, 4, "big")
  bf = prf(shared_secret, DST_HDK_DERIVE_KEY + dst_suffix)
  return blind_key(key, bf, mode)


def derive_prv_child(key, mode, shared_secret=None, bf=None, index=0):
  """
  Derive a child private key from either a shared secret or a blinding factor.

  key is the parent EC private key.
  mode determines whether to add or multiply blinding factors.
  shared_secret, when given, is used to compute the blinding factor from the
  parent's public key and index; it takes precedence over bf.
  bf is an explicit blinding factor, used only when shared_secret is None;
  index is ignored in that case.
  index is the index of the child key; it is encoded as 4 big-endian bytes.
  Output is a (private_key, blinding_factor) tuple.
  Raises ValueError if neither shared_secret nor bf is provided.
  """
  if shared_secret is None and bf is None:
    raise ValueError("Either shared_secret or bf must be provided.")
  if shared_secret is not None:
    pub = key.public_key()
    # Same tag construction as derive_pub_child, so both sides arrive at the
    # same blinding factor. The byte order is explicit so the encoding does
    # not depend on the Python 3.11+ default.
    dst_suffix = pub.public_bytes(SEC1, COMP) + int.to_bytes(index, 4, "big")
    bf = prf(shared_secret, DST_HDK_DERIVE_KEY + dst_suffix)
  return blind_key(key, bf, mode), bf

**Illustrative example**

1. Generate a root key pair (this would be the WSCD key).
2. Compute the shared secret between the WSCD and the WSCA-generated ephemeral key.
3. Use the shared secret as input to the key derivation to:
  * Derive two public child keys with index 1 and 2. The issuer can do this step given the shared secret.
  * Derive two private child keys with index 1 and 2 and their corresponding blinding factors. This step can only be done by the user who controls the WSCD.
4. Use derived key 2 as a parent key for a new key derivation to demonstrate how key derivation can be nested.
5. Demonstrate how the WSCD only needs to protect the root key, and how you can go from the root to any offspring in a single step.

In [6]:
# WSCD side: the root key pair, derived from the master keying material.
prv_root, pub_root = derive_keypair(bytes.fromhex(ikmM))

# WSCA side: RFC 9180 encap towards the root public key. Only the shared
# secret is kept; the encapsulated key is discarded.
ss_A, _ = encap(pub_root, ikmE)

# Child 1: the public derivation needs only the parent public key and the
# shared secret, the private derivation needs the parent private key.
pub_A1 = derive_pub_child(key=pub_root, mode='ADD', shared_secret=ss_A, index=1)
prv_A1, bf_prv_A1 = derive_prv_child(key=prv_root, mode='ADD', shared_secret=ss_A, index=1)

# Child 2: same parent and shared secret, different index.
pub_A2 = derive_pub_child(key=pub_root, mode='ADD', shared_secret=ss_A, index=2)
prv_A2, bf_prv_A2 = derive_prv_child(key=prv_root, mode='ADD', shared_secret=ss_A, index=2)

# The returned blinding factor reproduces the child private key from the root.
assert blind_prv_key(prv_root, bf_prv_A2, 'ADD').private_numbers() == prv_A2.private_numbers()

# Public and private derivations agree on the child public keys.
assert pub_A1.public_bytes(SEC1, COMP) == prv_A1.public_key().public_bytes(SEC1, COMP)
assert pub_A2.public_bytes(SEC1, COMP) == prv_A2.public_key().public_bytes(SEC1, COMP)

In [7]:
# Nested derivation: child A2 acts as the parent of a new KEM and a new child.
# ikmE is reused for the second encapsulation only to keep the example short.
ss_B, _ = encap(pub_A2, ikmE)

pub_B1 = derive_pub_child(key=pub_A2, mode='ADD', shared_secret=ss_B, index=1)
prv_B1, bf_prv_B1 = derive_prv_child(key=prv_A2, mode='ADD', shared_secret=ss_B, index=1)

# The blinding factor takes A2 to B1, and both derivations agree on B1's
# public key.
assert blind_prv_key(prv_A2, bf_prv_B1, 'ADD').private_numbers() == prv_B1.private_numbers()
assert pub_B1.public_bytes(SEC1, COMP) == prv_B1.public_key().public_bytes(SEC1, COMP)

# Any node is reachable from the root in one step: combine the blinding
# factors along the path and apply the result to the root key directly.
bf_root_to_B1 = combine_bf(bf_prv_A2, bf_prv_B1, "ADD")
prv_B1_from_root, _ = derive_prv_child(key=prv_root, mode='ADD', bf=bf_root_to_B1)
assert prv_B1.private_numbers() == prv_B1_from_root.private_numbers()