In [1]:
import re
import json
from fernet2 import Fernet2 
from base64 import urlsafe_b64encode
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.backends import default_backend
import cryptography.hazmat.primitives.asymmetric as asym
from cryptography.hazmat.primitives.serialization import load_pem_private_key, load_pem_public_key
from cryptography.hazmat.primitives.serialization import Encoding as encoding, PublicFormat as public_format

In [2]:
private_keyring = {
    # TO
    "ecc.sec224r1.1.sig.priv": "-----BEGIN EC PRIVATE KEY-----\nMHcCAQEEIBYSJJDwaERSb8pvpUSmyzwokTT6bhomM1uX2T2+qQhToAoGCCqGSM49\nAwEHoUQDQgAEw6pMRon2aMn9oNsPjcOfnRf/uEm7Ed64SIG+zSqvkdxPQxewBVLF\nO+iP8UJGsm0rEx29wrCnaFCUpOxeLGQ0bA==\n-----END EC PRIVATE KEY-----\n",
    "ecc.sec241r1.1.enc.priv": "-----BEGIN EC PRIVATE KEY-----\nMHcCAQEEIBYSJJDwaERSb8pvpUSmyzwokTT6bhomM1uX2T2+qQhToAoGCCqGSM49\nAwEHoUQDQgAEw6pMRon2aMn9oNsPjcOfnRf/uEm7Ed64SIG+zSqvkdxPQxewBVLF\nO+iP8UJGsm0rEx29wrCnaFCUpOxeLGQ0bA==\n-----END EC PRIVATE KEY-----\n",
    "ecc.sec241r1.1.sig.priv": "-----BEGIN EC PRIVATE KEY-----\nMHcCAQEEIC4cxNRgrZOm8lPyDKUbn5Q1LbfK6lkGt+mmwgvyednUoAoGCCqGSM49\nAwEHoUQDQgAEl1+7XxgkH0FwR/1WSnPFSPasQL3kj/XLxpbaCVCPhkQuRl7gddep\nfjSVf4w/eh52HKHDoJVPVnwaXcpUeaixlQ==\n-----END EC PRIVATE KEY-----\n",
    "rsa.2048.1.enc.priv":"<PEM encoded RSA PRIVATE KEY>",
    "rsa.2048.1.sig.priv":"<PEM encoded RSA PRIVATE KEY>"
} 

In [3]:
public_keyrings = {
    "<group1-name>": {
        "ecc.sec224r1.1.enc.pub": "<ECC public key with sec224 curve>",
        "ecc.sec224r1.1.sig.pub": "ECC public key with sec224 curve",
        "rsa.2048.1.enc.pub": "<PEM encoded RSA PUBLIC KEY>",
        "Dsa.2048.1.sig.pub": "<PEM encoded DSA key (only good for sigining)>",
        "rsa.2048.1.sig.pub": "<PEM encoded RSA PUBLIC KEY>"
    },
    "<group2-name>": {
        "ecc.sec224r1.1.enc.pub": "<ECC public key with sec224 curve>",
        "rsa.2048.1.sig.pub": "<PEM encoded RSA PUBLIC KEY>",
        "dsa.1024.1.sig.pub" : "<PEM encoded DSA key (only good for signing)>"
    }, 
    "asheesh_teja": {
        "dsa.1024.1.sig.pub": "-----BEGIN PUBLIC KEY-----\nMIIBtzCCASsGByqGSM44BAEwggEeAoGBAK0wjowa0YhHl_wB8jgfN6sl4HeiCfSf\nE99DY7McBs7Op9L8qMTQt151fiIcaaQeAOpzUqI6ofk-1iaSK0iKTQ63t9QLl1mz\nbknp0vMny4IW2PqSkE14OZtqaDzRsWJqKAIb91BdKGyqdFNmRfHddpaPuDNawOge\ng8yIe6P4QBN1AhUAnbg5PlIuvavIX2g_YIC9uGbm1xkCgYBW93q4kiXqdPki7a5j\ngYSGD9uul58q6h361gl_BQcTwvf2VJioffL7HqfDS--jmS8_cZCJ3VPeXqUvDCOz\nLKnwl9Fc3s7xG8Ks0R0PyLp3RikUKWv1CtT6GmS81JrzWvPrgKBWbIrIruddLtvb\nFKX0l4BHV751QLLU7mmcPbcSFgOBhQACgYEAiKIEcb55nZF-_E38puoOeGqvv2sy\nTnE9Prek5kpAzqA9Q9VT4m4SmKlFAbE6qC_7IxgQTjoKs301EZSWeA15z6vcnOo-\nr-N5Z8Gn7qwJDzCL3NJpRhTQgBVL_Xh4xpJS-MM1EoEEqKBem8gCGFM-TGLdrx-K\nMDbcPF_UfT46rY0=\n-----END PUBLIC KEY-----\n", 
        "ecc.secp224r1.1.enc.pub": "-----BEGIN PUBLIC KEY-----\nME4wEAYHKoZIzj0CAQYFK4EEACEDOgAElFmWnRvIgK53WG088jVdDFhMgvfP7SHS\nNzPpLe4Y7NdvjawlwfTb5k3eIvT0hTRra431odw19fc=\n-----END PUBLIC KEY-----\n"
    }
}

In [4]:
alg_switcher = {
    "ecdsa ": asym.ec.ECDSA,
}

curve_switcher = {
    "sec192r1": asym.ec.SECP192R1,
    "secp192r1": asym.ec.SECP192R1,    
    "sec224r1": asym.ec.SECP224R1,
    "secp224r1": asym.ec.SECP224R1,    
    "sec256k1": asym.ec.SECP256K1,
    "secp256k1": asym.ec.SECP256K1,
    "sec256r1": asym.ec.SECP256R1,
    "secp256r1": asym.ec.SECP256R1,
    "sec384r1": asym.ec.SECP384R1,
    "secp384r1": asym.ec.SECP384R1,
    "sec521r1": asym.ec.SECP521R1,
    "secp521r1": asym.ec.SECP521R1,
    "sec163k1": asym.ec.SECT163K1,
    "sect163k1": asym.ec.SECT163K1,
    "sec163r2": asym.ec.SECT163R2,
    "sect163r2": asym.ec.SECT163R2,
    "sec233k1": asym.ec.SECT233K1,
    "sect233k1": asym.ec.SECT233K1,
}

hash_switcher = {
    "sha256": hashes.SHA256,
}

header_to_alias_mapper = {
    "ecdsa": "ecc", 
}

def get_switch_case(switcher, key):
    return switcher.get(key)

In [5]:
class PKFernet(object):
    def __init__(self, priv_keyring={}, public_keyrings={}):
        assert type(priv_keyring) is dict and type(public_keyrings) is dict, "Invalid parameter types, please pass JSON keyrings"
        self.priv_keyring = priv_keyring
        self.pub_keyrings = public_keyrings
    
    def url_safe_pem(self, pem_key):
        key = re.findall("KEY-----\n(.*)==\n-----END", pem_key, re.DOTALL)[0]
#         key = re.findall("-----BEGIN EC PRIVATE KEY-----\n(.*)==\n-----END EC PRIVATE KEY-----\n", pem_key, re.DOTALL)[0]
        key.replace("+", "-").replace("/", "_")
        return key
    
    def restore_pem(self, url_safe_pem):
        return url_safe_pem.replace("-", "+").replace("_", "/")
    
    def export_pub_keys(self, key_alias_list=[]):
        if not key_alias_list:
            return self.pub_keyrings
        res = {}
        for k in key_alias_list:
            res[k] = self.pub_keyrings[k]
        return res
        
    def import_pub_keys(self, receiver_name, receiver_public_keyring, overwrite=False):
        if overwrite is False:
            assert receiver_name not in self.pub_keyrings, "A public keyring already exists for this user, pass overwrite=True to update existing keyring"
        self.pub_keyrings[receiver_name] = receiver_public_keyring
      
    def parse_header(self, header):
        arr = header.split('_')
        new_header = arr[0] + '.' + arr[2]
        # returns alg, hash, key_param, ver, usage, key_type
        return new_header.split('.')
    
    def header_to_alias(self, header):
        alg, _, key_param, ver, usage, key_type = self.parse_header(header)
        alg = get_switch_case(header_to_alias_mapper, alg)
        if alg == "ecc":
            key_param = key_param[:3] + key_param[4:]
    
        alias = '.'.join([alg, key_param, ver, usage, key_type])
        return alias
    
    def encrypt(self, msg, receiver_name, receiver_enc_pub_key_alias, sender_sign_header, adata='', sign_also=True):
        # ensure keyrings are populated with sender and receiver
#         and sender_sign_header in self.priv_keyring
        assert receiver_name in self.pub_keyrings, "Keys must exist for both the sender and receiver to encrypt."
        
        alg, key_param, ver, usage, key_type = receiver_enc_pub_key_alias.split(".")
#         alg_send, key_param_send, ver_send, usage_send, key_type_send = sender_sign_priv_key_alias.split(".")
#         TODO: ASSERT SENDER AND RECEIVER CAN AGREE UPON ALGORITHM
#         assert alg_rec == alg_send and key_param_rec == key_param_send and ver_rec == ver_send, "Sender and Receiver Exchange is not compatible."
        
        # generate ephemeral private key based on the given algorithm
        if alg == "ecc":
            curve = get_switch_case(curve_switcher, key_param)
            ephem_priv_key = asym.ec.generate_private_key(curve(), default_backend())
        elif alg == "dsa":
            ephem_priv_key = asym.dsa.generate_private_key(key_size=key_param, backend=default_backend())
        elif alg == "rsa":
            ephem_priv_key = asym.rsa.generate_private_key(public_exponent=65537, key_size=key_param, backend=default_backend())

        rec_pub_key = load_pem_public_key(self.pub_keyrings[receiver_name][receiver_enc_pub_key_alias], backend=default_backend())
        ephem_public_key = ephem_priv_key.public_key()
        rpk = ephem_public_key.public_bytes(encoding.PEM, public_format.SubjectPublicKeyInfo)
        
        # PARSE HEADER TO DECIDE WHICH ALG TO USE HERE
        # exchange using ephemeral private key and the receiver's public key   
        shared_key = ephem_priv_key.exchange(asym.ec.ECDH(), rec_pub_key)
        
        # sign if necessary using static sender's private key
        if sign_also:
            send_priv_alias = self.header_to_alias(sender_sign_header)
            send_priv_key = load_pem_private_key(self.priv_keyring[send_priv_alias], password=None, backend=default_backend())
            # TODO: PARSE HEADER TO DECIDE WHICH ALG TO USE HERE
            signer = send_priv_key.signer(asym.ec.ECDSA(hashes.SHA256()))
            signer.update(msg)
            signature = signer.finalize()
            msg += "|" + signature
        
        fern = Fernet2(urlsafe_b64encode(shared_key))
        f2_ctxt = fern.encrypt(msg, adata) # MUST DECODE ADATA?
        
        # TODO: MAKE URL SAFE, base 64 encode alg    self.url_safe_pem(rpk)
        return adata + "|" + urlsafe_b64encode(alg) + "|" + rpk + "|" + f2_ctxt

    def decrypt(self, ctx, sender_name, verfiy_also=True):
        # use static private key for decryption
        adata, alg, rpk, f2_ctxt = ctx.split('|')
        
        rec_priv_key = load_pem_private_key(self.priv_keyring[alg])
        
        # TODO: parse header and decide which algorithm, double check shared key and fernet isntantiations
        shared_key = rec_priv_key.exchange(asym.ec.ECDH(), rpk)

        fern = Fernet2(shared_key)
        msg = fern.decrypt(ctx, associated_data=adata)
        
        if verify_also:
            msg, sig = msg.split('|')
            
            send_pub_keyring = self.pub_keyrings[sender_name]
            
            for pub_alias in send_pub_keyring:
                if alg in pub_alias and "sig" in pub_alias:
                    send_pub_key = load_pem_public_key(send_pub_keyring[pub_alias], backend=default_backend())
                    
            # ASSERT HERE
            send_pub_key.verifier(sig, asym.ec.ECDSA(hashes.SHA256()))
            verifier.update(msg)
            verifier.verify() # throws error if verify is false
            
        return msg
    
        '''
        enc_priv_key.decrypt(
            ctx,
            padding.OAEP(
                mgf=padding.MGF1(algorithm=hashes.SHA1()),
                algorithm=hashes.SHA1(),
                label=None
            )
         )
        
        public_key = private_key.public_key()
        verifier = public_key.verifier(
            signature,
            padding.PSS(
                mgf=padding.MGF1(hashes.SHA256()),
                salt_length=padding.PSS.MAX_LENGTH
            ),
            hashes.SHA256()
        )
        verifier.update(message)
        verifier.verify()
        '''
        
# TODO: HANDLE PADDING
#     def pad(self, enc_priv_key):
#         enc_priv_key.decrypt(
#             ctx,
#             padding.OAEP(
#                 mgf=padding.MGF1(algorithm=hashes.SHA1()),
#                 algorithm=hashes.SHA1(),
#                 label=None
#             )
#          )

In [6]:
pf = PKFernet(private_keyring, public_keyrings)

In [7]:
pf.encrypt("MESSAGE", "asheesh_teja", "ecc.secp224r1.1.enc.pub", "ecdsa_with_sha256.secp224r1.1.sig.priv", "ADATA")

'ADATA|ZWNj|-----BEGIN PUBLIC KEY-----\nME4wEAYHKoZIzj0CAQYFK4EEACEDOgAEUX7Ko8cjwDlLorUre+pEGOus6S1VaSff\nPafQJZ0arTbhL/PobkFrLQ0XMrx9Zc0m4UoBNI8D+IE=\n-----END PUBLIC KEY-----\n|gSewybQhMzuzb0Ej49rdiV0DUcSEjLjgyxVdCCsoLslGjkFKVeFZ5F_-bbN4q6eMXWDzbKB6blaeFa6kq6gkbH9nVabZ8NHVb6sCeC0kdQAluFkJtLSRaapHAZ47vGD9z5hqCmgxH1anLqahfLugx5v3TsLakqzPJk-5IAeKZ4Fk'

In [8]:
#             asym.rsa.RSAPrivateKeyWithSerialization.private_bytes(encoding, format, encryption_algorithm)
#             private_bytes(encoding, format, encryption_algorithm)
#             asymmetric.rsa.RSAPrivateKeyWithSerialization[source]