In [1]:
import re
import json
from base64 import urlsafe_b64decode, urlsafe_b64encode

from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.kdf.hkdf import HKDF
import cryptography.hazmat.primitives.asymmetric as asym
from cryptography.hazmat.primitives.asymmetric import dsa, ec, padding, rsa
from cryptography.hazmat.primitives.serialization import load_pem_private_key

from fernet2 import Fernet2

In [2]:
# Private keyring for this party: {key alias: PEM-encoded private key}.
# Aliases follow <alg>.<param>.<version>.<usage>.<type>, e.g. "ecc.sec241.1.enc.priv".
# SECURITY NOTE(review): real EC private keys are committed here in plain text.
# Load them from a file or environment variable kept out of version control,
# and treat these two keys as compromised.
# NOTE(review): both EC PEMs carry the P-256 (secp256r1) curve OID, so the
# "sec241" part of the aliases looks like a typo for "sec256r1" (it matches no
# curve_switcher entry). The load_pem_private_key cell below looks this alias up,
# so rename both together.
private_keyring = {
    "ecc.sec241.1.enc.priv": "-----BEGIN EC PRIVATE KEY-----\nMHcCAQEEIBYSJJDwaERSb8pvpUSmyzwokTT6bhomM1uX2T2+qQhToAoGCCqGSM49\nAwEHoUQDQgAEw6pMRon2aMn9oNsPjcOfnRf/uEm7Ed64SIG+zSqvkdxPQxewBVLF\nO+iP8UJGsm0rEx29wrCnaFCUpOxeLGQ0bA==\n-----END EC PRIVATE KEY-----\n",
    "ecc.sec241.1.sig.priv": "-----BEGIN EC PRIVATE KEY-----\nMHcCAQEEIC4cxNRgrZOm8lPyDKUbn5Q1LbfK6lkGt+mmwgvyednUoAoGCCqGSM49\nAwEHoUQDQgAEl1+7XxgkH0FwR/1WSnPFSPasQL3kj/XLxpbaCVCPhkQuRl7gddep\nfjSVf4w/eh52HKHDoJVPVnwaXcpUeaixlQ==\n-----END EC PRIVATE KEY-----\n",
    # Placeholders: replace with real PEM text before using the RSA aliases.
    "rsa.2048.1.enc.priv":"<PEM encoded RSA PRIVATE KEY>",
    "rsa.2048.1.sig.priv":"<PEM encoded RSA PRIVATE KEY>"
}

In [3]:
# Placeholder public keyrings: {party name: {key alias: PEM-encoded public key}}.
# Aliases follow <alg>.<param>.<version>.<usage>.<type>. They are all lowercase
# because algorithm dispatch compares against "ecc"/"rsa"/"dsa". ECC params use the
# same spelling as curve_switcher (e.g. "sec224r1") so get_curve() can resolve them.
public_keyrings = {
    "<group1-name>": {
        "ecc.sec224r1.1.enc.pub": "<ECC public key with sec224r1 curve>",
        "ecc.sec224r1.1.sig.pub": "<ECC public key with sec224r1 curve>",
        "rsa.2048.1.enc.pub": "<PEM encoded RSA PUBLIC KEY>",
        "dsa.2048.1.sig.pub": "<PEM encoded DSA key (only good for signing)>",
        "rsa.2048.1.sig.pub": "<PEM encoded RSA PUBLIC KEY>"
    },
    "<group2-name>": {
        "ecc.sec224r1.1.enc.pub": "<ECC public key with sec224r1 curve>",
        "rsa.2048.1.sig.pub": "<PEM encoded RSA PUBLIC KEY>",
        "dsa.1024.1.sig.pub": "<PEM encoded DSA key (only good for signing)>"
    }
}

In [4]:
# Maps the <param> field of an "ecc" key alias to a cryptography curve class.
# Values are classes, not instances: callers instantiate them (e.g. `curve()`).
# Prime-field (SECP*) and binary-field (SECT*) curves share the "sec" prefix here.
curve_switcher = {
    "sec192r1": asym.ec.SECP192R1,
    "sec224r1": asym.ec.SECP224R1,
    "sec256k1": asym.ec.SECP256K1,
    "sec256r1": asym.ec.SECP256R1,
    "sec384r1": asym.ec.SECP384R1,
    "sec521r1": asym.ec.SECP521R1,
    "sec163k1": asym.ec.SECT163K1,
    "sec163r2": asym.ec.SECT163R2,
    "sec233k1": asym.ec.SECT233K1,
}

def get_curve(curve_alias):
    """Return the curve class registered in `curve_switcher` for `curve_alias`.

    Unknown aliases yield None rather than raising, so callers must check the
    result before instantiating it.
    """
    if curve_alias in curve_switcher:
        return curve_switcher[curve_alias]
    return None

In [5]:
class PKFernet(object):
    """Hybrid public-key encryption built on Fernet2.

    To encrypt, a fresh ephemeral key pair is generated on the receiver's
    elliptic curve. ECDH with the receiver's public encryption key gives a shared
    secret, HKDF-SHA256 turns that into a Fernet2 key, and the message (optionally
    signed with this party's signing key) is encrypted under `adata`.

    Key aliases follow ``<alg>.<param>.<version>.<usage>.<type>``. A private key
    and its public counterpart share an alias except for the trailing
    ``priv``/``pub``.

    Ciphertext layout (every field URL-safe, joined by ``|``)::

        b64(adata) | receiver enc pub alias | sender sig pub alias ('' if unsigned)
                   | b64(DER ephemeral public key) | Fernet2 token
    """

    def __init__(self, priv_keyring=None, public_keyrings=None):
        """Store this party's keyrings.

        :param priv_keyring: ``{alias: PEM private key}`` owned by this party.
        :param public_keyrings: ``{party name: {alias: PEM public key}}``.
        Both default to a fresh empty dict. A shared ``{}`` default would leak
        keyrings imported on one instance into every other instance.
        The dicts passed in are stored by reference, not copied.
        """
        if priv_keyring is None:
            priv_keyring = {}
        if public_keyrings is None:
            public_keyrings = {}
        assert isinstance(priv_keyring, dict) and isinstance(public_keyrings, dict), "Invalid parameter types, please pass JSON keyrings"
        self.priv_keyring = priv_keyring
        self.pub_keyrings = public_keyrings

    @staticmethod
    def url_safe_pem(pem_key):
        """Return the base64 payload of a PEM block in URL-safe form.

        Drops the BEGIN/END armour, all line breaks and the trailing '='
        padding, then maps '+' -> '-' and '/' -> '_'.

        :raises ValueError: if `pem_key` contains no PEM block.
        """
        match = re.search(r"-----BEGIN [^-]+-----(.*?)-----END [^-]+-----", pem_key, re.DOTALL)
        if match is None:
            raise ValueError("pem_key is not a PEM-encoded block")
        body = "".join(match.group(1).split())
        return body.rstrip("=").replace("+", "-").replace("/", "_")

    @staticmethod
    def restore_pem(url_safe_pem):
        """Invert url_safe_pem: restore the standard base64 alphabet and '=' padding."""
        body = url_safe_pem.replace("-", "+").replace("_", "/")
        return body + "=" * (-len(body) % 4)

    def export_pub_keys(self, key_alias_list=None):
        """Return stored public keyrings, optionally restricted to some entries.

        With no (or an empty) `key_alias_list` the whole `pub_keyrings` mapping
        is returned as-is. Otherwise a new dict with just the listed entries is
        returned. An unknown entry raises KeyError.
        """
        # NOTE(review): entries are looked up by party name in pub_keyrings even
        # though the parameter name suggests key aliases -- confirm intent.
        if not key_alias_list:
            return self.pub_keyrings
        return {name: self.pub_keyrings[name] for name in key_alias_list}

    def import_pub_keys(self, receiver_name, receiver_public_keyring, overwrite=False):
        """Register `receiver_public_keyring` under `receiver_name`.

        :raises AssertionError: if a keyring already exists and `overwrite` is False.
        """
        if overwrite is False:
            assert receiver_name not in self.pub_keyrings, "A public keyring already exists for this user, pass overwrite=True to update existing keyring"
        self.pub_keyrings[receiver_name] = receiver_public_keyring

    def encrypt(self, msg, receiver_name, receiver_enc_pub_key_alias, sender_sign_header, adata='', sign_also=True):
        """Encrypt `msg` to `receiver_name`, optionally signing it first.

        :param msg: plaintext, ``str`` (UTF-8 encoded) or ``bytes``.
        :param receiver_name: entry in `pub_keyrings` holding the receiver's keys.
        :param receiver_enc_pub_key_alias: alias of the receiver's EC encryption key.
        :param sender_sign_header: alias of this party's signing key in `priv_keyring`.
        :param adata: associated data, authenticated by Fernet2 and sent in the clear.
        :param sign_also: sign `msg` with `sender_sign_header` before encrypting.
        :returns: ciphertext ``str`` (layout in the class docstring).
        :raises AssertionError: when a required key is missing.
        :raises NotImplementedError: for a non-``ecc`` encryption key.
        """
        assert receiver_name in self.pub_keyrings and sender_sign_header in self.priv_keyring, "Keys must exist for both the sender and receiver to encrypt."
        assert receiver_enc_pub_key_alias in self.pub_keyrings[receiver_name], "Unknown encryption key alias for this receiver."

        alg = receiver_enc_pub_key_alias.split(".")[0]
        if alg != "ecc":
            # DSA keys can only sign; RSA would need a separate key-wrapping scheme.
            raise NotImplementedError("Only 'ecc' encryption keys are supported, got %r" % alg)

        receiver_pub_key = self._load_public_key(receiver_name, receiver_enc_pub_key_alias)
        if not isinstance(receiver_pub_key, ec.EllipticCurvePublicKey):
            raise ValueError("Key %r is not an elliptic-curve public key" % receiver_enc_pub_key_alias)

        # ECDH needs both keys on the same curve, so take the curve from the
        # receiver's key itself rather than from the free-form alias text.
        ephem_priv_key = ec.generate_private_key(receiver_pub_key.curve, default_backend())
        shared_secret = ephem_priv_key.exchange(ec.ECDH(), receiver_pub_key)
        fernet_key = self._derive_fernet_key(shared_secret)

        payload = self._to_bytes(msg)
        sig_alias = ""
        if sign_also:
            signature = self._sign(self._load_private_key(sender_sign_header), payload)
            # base64 keeps '|' out of the signature, so decrypt can split on the last '|'.
            payload = payload + b"|" + urlsafe_b64encode(signature)
            sig_alias = self._swap_key_type(sender_sign_header, "pub")

        adata_bytes = self._to_bytes(adata)
        token = Fernet2(fernet_key).encrypt(payload, adata_bytes)
        if isinstance(token, bytes):
            token = token.decode("ascii")

        ephem_der = ephem_priv_key.public_key().public_bytes(
            serialization.Encoding.DER,
            serialization.PublicFormat.SubjectPublicKeyInfo,
        )
        # adata is base64-encoded so a '|' inside it cannot break the field split.
        return "|".join([
            urlsafe_b64encode(adata_bytes).decode("ascii"),
            receiver_enc_pub_key_alias,
            sig_alias,
            urlsafe_b64encode(ephem_der).decode("ascii"),
            token,
        ])

    def decrypt(self, ctx, sender_name, verfiy_also=True):
        """Decrypt a ciphertext produced by encrypt and optionally verify it.

        :param ctx: ciphertext produced by encrypt (``str`` or ASCII ``bytes``).
        :param sender_name: entry in `pub_keyrings` whose signing key verifies the message.
        :param verfiy_also: (sic) require a signature and check it.
        :returns: plaintext as ``bytes``.
        :raises ValueError: malformed ciphertext, or unsigned ciphertext when
            `verfiy_also` is set.
        :raises AssertionError: when a needed key is missing from the keyrings.
        Signature failures surface as the exception raised by the key's
        ``verify`` method.
        """
        if isinstance(ctx, bytes):
            ctx = ctx.decode("ascii")
        fields = ctx.split("|")
        if len(fields) != 5:
            raise ValueError("Malformed PKFernet ciphertext")
        adata_field, enc_alias, sig_alias, ephem_field, token = fields

        # Refuse unsigned input when verification is wanted; otherwise stripping
        # the signature would silently bypass the check.
        if verfiy_also and not sig_alias:
            raise ValueError("Ciphertext is not signed but verification was requested")

        own_enc_alias = self._swap_key_type(enc_alias, "priv")
        assert own_enc_alias in self.priv_keyring, "No private key for %s" % own_enc_alias
        ephem_pub_key = serialization.load_der_public_key(
            urlsafe_b64decode(ephem_field.encode("ascii")), backend=default_backend())
        shared_secret = self._load_private_key(own_enc_alias).exchange(ec.ECDH(), ephem_pub_key)

        adata = urlsafe_b64decode(adata_field.encode("ascii"))
        fern = Fernet2(self._derive_fernet_key(shared_secret))
        payload = fern.decrypt(token.encode("ascii"), associated_data=adata)

        if not sig_alias:
            return payload
        msg, _, sig_field = payload.rpartition(b"|")
        if verfiy_also:
            assert sender_name in self.pub_keyrings, "No public keyring for sender %s" % sender_name
            self._verify(self._load_public_key(sender_name, sig_alias), urlsafe_b64decode(sig_field), msg)
        return msg

    @staticmethod
    def _to_bytes(value):
        """Encode text as UTF-8; pass bytes through unchanged."""
        if isinstance(value, bytes):
            return value
        return value.encode("utf-8")

    @staticmethod
    def _swap_key_type(alias, key_type):
        """Return `alias` with its trailing priv/pub component replaced by `key_type`."""
        parts = alias.split(".")
        if len(parts) != 5:
            raise ValueError("Key alias must look like alg.param.ver.usage.type, got %r" % alias)
        parts[-1] = key_type
        return ".".join(parts)

    @staticmethod
    def _derive_fernet_key(shared_secret):
        """Derive a Fernet2 key from a raw ECDH secret via HKDF-SHA256.

        Raw ECDH output is not uniformly random, so it is not used directly.
        Returns 32 derived bytes, URL-safe base64 encoded. That is the format
        `cryptography`'s Fernet expects; presumably Fernet2 too -- TODO confirm.
        """
        derived = HKDF(
            algorithm=hashes.SHA256(),
            length=32,
            salt=None,
            info=b"PKFernet ECDH key",
            backend=default_backend(),
        ).derive(shared_secret)
        return urlsafe_b64encode(derived)

    def _load_private_key(self, alias):
        """Parse the unencrypted PEM private key stored under `alias`."""
        return load_pem_private_key(self._to_bytes(self.priv_keyring[alias]), password=None, backend=default_backend())

    def _load_public_key(self, owner, alias):
        """Parse the PEM public key stored for `owner` under `alias`."""
        return serialization.load_pem_public_key(self._to_bytes(self.pub_keyrings[owner][alias]), backend=default_backend())

    @staticmethod
    def _pss():
        """RSA-PSS padding with SHA-256 MGF1 and maximum salt length."""
        return padding.PSS(mgf=padding.MGF1(hashes.SHA256()), salt_length=padding.PSS.MAX_LENGTH)

    @classmethod
    def _sign(cls, private_key, data):
        """Sign `data` with SHA-256 using ECDSA, RSA-PSS or DSA, chosen by key type."""
        if isinstance(private_key, ec.EllipticCurvePrivateKey):
            return private_key.sign(data, ec.ECDSA(hashes.SHA256()))
        if isinstance(private_key, rsa.RSAPrivateKey):
            return private_key.sign(data, cls._pss(), hashes.SHA256())
        if isinstance(private_key, dsa.DSAPrivateKey):
            return private_key.sign(data, hashes.SHA256())
        raise ValueError("Unsupported signing key type: %s" % type(private_key).__name__)

    @classmethod
    def _verify(cls, public_key, signature, data):
        """Check `signature` over `data`; the key's verify() raises on mismatch."""
        if isinstance(public_key, ec.EllipticCurvePublicKey):
            public_key.verify(signature, data, ec.ECDSA(hashes.SHA256()))
        elif isinstance(public_key, rsa.RSAPublicKey):
            public_key.verify(signature, data, cls._pss(), hashes.SHA256())
        elif isinstance(public_key, dsa.DSAPublicKey):
            public_key.verify(signature, data, hashes.SHA256())
        else:
            raise ValueError("Unsupported verification key type: %s" % type(public_key).__name__)
        
# TODO: HANDLE PADDING
#     def pad(self, enc_priv_key):
#         enc_priv_key.decrypt(
#             ctx,
#             padding.OAEP(
#                 mgf=padding.MGF1(algorithm=hashes.SHA1()),
#                 algorithm=hashes.SHA1(),
#                 label=None
#             )
#          )

In [6]:
# Instance wired to the demo keyrings above (stored by reference, not copied).
pf = PKFernet(priv_keyring=private_keyring, public_keyrings=public_keyrings)

In [7]:
# Sanity check that the stored PEM parses. load_pem_private_key takes bytes,
# but the keyring holds str (rejected with TypeError on Python 3), so encode first.
priv_key = load_pem_private_key(
    private_keyring["ecc.sec241.1.enc.priv"].encode("ascii"),
    password=None,
    backend=default_backend(),
)

In [8]:
# Scratch cell: generates a 1024-bit DSA key. DSA keys can only sign, so this
# cannot serve as an ephemeral key for a key exchange.
# NOTE(review): 1024-bit DSA is below currently recommended key sizes.
ephem_private_key = asym.dsa.generate_private_key(1024, default_backend())

In [None]:
#             asym.rsa.RSAPrivateKeyWithSerialization.private_bytes(encoding, format, encryption_algorithm)
#             private_bytes(encoding, format, encryption_algorithm)
#             asymmetric.rsa.RSAPrivateKeyWithSerialization[source]