In [73]:
import os
import hashlib
import base64
from cryptography.hazmat.primitives.ciphers.aead import AESGCM
from cryptography.hazmat.primitives.kdf.hkdf import HKDF
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.asymmetric import dh
import pyDH
from typing import Tuple
from Crypto.Cipher import AES
import hkdf

In [75]:
def get_dh_obj():
    """Return a new ``pyDH.DiffieHellman`` key-pair object.

    Used for the initial key pairs and by ``DoubleRatchet.update_key_pair``.
    NOTE(review): presumably each call draws a fresh random private key
    (pyDH implementation detail — confirm).
    """
    key_pair = pyDH.DiffieHellman()
    return key_pair

def b64enc(b):
    """Encode the bytes-like ``b`` as a standard Base64 text string."""
    return str(base64.b64encode(b), "utf-8")

def b64dec(b):
    """Decode standard Base64 (``str`` or ``bytes``) back into raw bytes."""
    return base64.standard_b64decode(b)

def conv_bytes_append(a:int, b:int) -> bytes:
    """Pack two integers in ``range(256)`` into a 2-byte string, ``a`` first.

    Raises ValueError if either value is outside 0..255.
    """
    return bytes((a, b))

def conv_bytes(a:int) -> bytes:
    """Big-endian encoding of a non-negative int using the fewest bytes.

    ``0`` encodes to ``b""`` (``int.from_bytes(b"", "big") == 0``, so it still
    round-trips). Negative values raise OverflowError.
    """
    byte_count = -(-a.bit_length() // 8)  # ceil(bit_length / 8) without floats
    return a.to_bytes(byte_count, byteorder='big')

def str_to_bytes(s:str) -> bytes:
    """Encode text as UTF-8 bytes."""
    encoded = s.encode(encoding='utf-8')
    return encoded

def bytes_to_str(b:bytes) -> str:
    """Decode UTF-8 bytes into text (raises UnicodeDecodeError on invalid input)."""
    text = b.decode(encoding='utf-8')
    return text

def kdf_rk(rk:bytes, dh_output:str) -> Tuple[bytes, bytes]:
    """Root-chain KDF step: mix a fresh DH output into the root key.

    Uses HKDF-SHA256 with the current root key as salt and the DH shared
    secret as input key material, matching the Double Ratchet KDF_RK
    construction. The previous version fed the secret DH output in as the
    HKDF ``info`` (context) string and a constant label as key material,
    which inverts HKDF's intended roles.

    Args:
        rk: current root key, used as the HKDF salt. NOTE(review): a ``None``
            root key (``init_session`` not called) is passed straight to the
            hkdf package, which presumably falls back to an all-zero salt.
        dh_output: DH shared secret. A ``str`` (what ``pyDH.gen_shared_key``
            is passed through as elsewhere in this notebook) is UTF-8 encoded;
            ``bytes``/``bytearray`` is used as-is.

    Returns:
        ``(new_root_key, chain_key)``, 32 bytes each.
    """
    if isinstance(dh_output, (bytes, bytearray)):
        ikm = bytes(dh_output)
    else:
        ikm = str_to_bytes(dh_output)
    prk = hkdf.Hkdf(rk, ikm, hashlib.sha256)
    # The fixed label now serves as HKDF context (info), where it belongs.
    output = prk.expand(b"DoubleRatchet", 64)
    return (output[:32], output[32:])

def kdf_ck(ck:bytes) -> Tuple[bytes, bytes]:
    """Symmetric KDF step: split 64 bytes of HKDF-SHA256 output into two keys.

    ``ck`` is used as the HKDF salt with a fixed label as key material.
    Returns ``(next_key, output_key)``, 32 bytes each.
    ``DoubleRatchet.update_root`` stores the first half as the new root key
    and returns the second.
    """
    okm = hkdf.Hkdf(ck, b"DoubleRatchet", hashlib.sha256).expand(b"ChainKey", 64)
    next_key, output_key = okm[:32], okm[32:]
    return next_key, output_key


In [76]:

class DoubleRatchet():
    """Simplified Double Ratchet key schedule between two pyDH parties.

    Every ``send``/``recv`` call returns a 32-byte message key. When the
    conversation direction flips (or on the first message), a DH ratchet step
    mixes a fresh DH output into the root key. Consecutive messages in the
    same direction only advance the symmetric chain.

    Assumption (not enforced): both peers call ``init_session`` with the same
    root key and process messages in the order they were produced.
    """

    def __init__(self, our_dh_obj, their_public_key):
        """Store our pyDH key pair and the peer's public key; keys are derived later."""
        self.our_dh_obj = our_dh_obj
        self.their_public_key = their_public_key
        self.root_key = None
        self.recv_key = None
        self.send_key = None
        # One of None (nothing exchanged yet), "send" or "recv".
        self.last_done = None
        self.root_chain = None
        # Built by refresh_chains(); declared here so every attribute exists after __init__.
        self.recv_chain = None
        self.send_chain = None

    def init_session(self, root_key):
        """Install the shared root key both peers start from."""
        self.root_key = root_key
        self.root_chain = hkdf.Hkdf(self.root_key, b"DoubleRatchet", hashlib.sha256)

    def refresh_chains(self):
        """Rebuild both HKDF chain objects from the current send/receive chain keys."""
        self.recv_chain = hkdf.Hkdf(self.recv_key, b"DoubleRatchet", hashlib.sha256)
        self.send_chain = hkdf.Hkdf(self.send_key, b"DoubleRatchet", hashlib.sha256)

    def chain_step(self, chain):
        """Advance the "send" or "receive" chain once and return a 32-byte message key.

        The first 32 bytes of the 64-byte expansion become the new chain key.
        The last 32 bytes are returned.

        Raises:
            ValueError: if ``chain`` is neither "send" nor "receive".
        """
        if chain == "send":
            output = self.send_chain.expand(b"common_key", 64)
            self.send_key = output[:32]
        elif chain == "receive":
            output = self.recv_chain.expand(b"common_key", 64)
            self.recv_key = output[:32]
        else:
            raise ValueError("Invalid chain")
        # Re-key the HKDF objects so the next step starts from the new chain key.
        self.refresh_chains()
        return output[32:]

    def update_key_pair(self):
        """Replace our DH key pair with a new one and return its public key."""
        self.our_dh_obj = get_dh_obj()
        return self.our_dh_obj.gen_public_key()

    def update_root(self):
        """Advance the root key with ``kdf_ck`` and return the derived output key."""
        self.root_key, output = kdf_ck(self.root_key)
        self.root_chain = hkdf.Hkdf(self.root_key, b"DoubleRatchet", hashlib.sha256)
        return output

    def _check_state(self):
        """Raise ValueError if ``last_done`` holds anything but None/"send"/"recv"."""
        if self.last_done not in (None, "send", "recv"):
            raise ValueError(f"Invalid ratchet state: {self.last_done!r}")

    def _dh_ratchet(self, direction):
        """DH ratchet step: fold a new DH output into the root key and reset one chain.

        ``direction`` is "send" or "recv" and selects which chain key is replaced.
        """
        dh_output = self.our_dh_obj.gen_shared_key(self.their_public_key)
        self.root_key, chain_key = kdf_rk(self.root_key, dh_output)
        if direction == "send":
            self.send_key = chain_key
        else:
            self.recv_key = chain_key
        self.refresh_chains()

    def send(self, message):
        """Derive the key for an outgoing message.

        ``message`` is not read; only the key schedule advances.

        Returns:
            ``(message_key, public_key)``. ``public_key`` is the DH public key
            the peer needs on the first message and after each direction
            change. It is ``None`` while continuing the same sending chain.

        Raises:
            ValueError: if ``last_done`` holds an unexpected value.
        """
        self._check_state()
        if self.last_done == "send":
            # Same direction as last time: only the symmetric chain advances.
            return (self.chain_step("send"), None)
        if self.last_done == "recv":
            # Direction flipped: switch to a new key pair before ratcheting.
            public_key = self.update_key_pair()
        else:
            # First message of the session: ratchet with the initial key pair.
            public_key = self.our_dh_obj.gen_public_key()
        self.last_done = "send"
        self._dh_ratchet("send")
        return (self.chain_step("send"), public_key)

    def recv(self, message, public_key):
        """Derive the key for an incoming message.

        ``message`` is not read. A non-None ``public_key`` replaces the stored
        peer public key before any DH ratchet step.

        Returns:
            The 32-byte message key, which matches the sender's.

        Raises:
            ValueError: if ``last_done`` holds an unexpected value.
        """
        self._check_state()
        if public_key is not None:
            self.their_public_key = public_key
        if self.last_done != "recv":
            # First message or direction flip: ratchet with the peer's (new) key.
            self.last_done = "recv"
            self._dh_ratchet("recv")
        return self.chain_step("receive")

    def aes_encrypt(self, message, key):
        """Encrypt and authenticate ``message`` with AES-EAX under ``key``.

        A ``str`` message is UTF-8 encoded first. Returns
        ``nonce (16 bytes) || tag (16 bytes) || ciphertext`` so that
        ``aes_decrypt`` can rebuild the cipher and verify integrity. The
        previous version dropped both the random nonce and the tag, so its
        output could never be decrypted.
        """
        if isinstance(message, str):
            message = message.encode("utf-8")
        cipher = AES.new(key, AES.MODE_EAX)
        ciphertext, tag = cipher.encrypt_and_digest(message)
        return cipher.nonce + tag + ciphertext

    def aes_decrypt(self, message, key):
        """Inverse of ``aes_encrypt``: split nonce/tag/ciphertext, decrypt and verify.

        Raises:
            ValueError: if ``message`` is shorter than nonce+tag, or if
                authentication fails (wrong key or tampered data).
        """
        if len(message) < 32:
            raise ValueError("Ciphertext too short")
        nonce, tag, ciphertext = message[:16], message[16:32], message[32:]
        cipher = AES.new(key, AES.MODE_EAX, nonce=nonce)
        return cipher.decrypt_and_verify(ciphertext, tag)

In [77]:
def test():
    """Smoke test: Alice's first send key must equal Bob's first receive key."""
    root_key = os.urandom(32)
    alice_obj = get_dh_obj()
    bob_obj = get_dh_obj()
    alice_pub = alice_obj.gen_public_key()
    bob_pub = bob_obj.gen_public_key()
    alice = DoubleRatchet(alice_obj, bob_pub)
    bob = DoubleRatchet(bob_obj, alice_pub)
    # Previously root_key was generated but never installed, so both sides
    # ratcheted from a None root key.
    alice.init_session(root_key)
    bob.init_session(root_key)
    # send() returns a message *key* plus Alice's public key, not ciphertext.
    (send_key, alice_sent_pub) = alice.send(b"Hello")
    print("alice send key: ", b64enc(send_key))
    print("alice's public key: ", alice_sent_pub)
    recv_key = bob.recv(b"Hello", alice_sent_pub)
    print("bob recv key: ", b64enc(recv_key))
    assert send_key == recv_key, "Alice and Bob derived different message keys"

# test()

In [79]:
def _exchange(sender, receiver, sender_name, receiver_name, message):
    """Run one message hop: derive both ends' keys, print them, and check they match."""
    send_key, send_pub = sender.send(message)
    recv_key = receiver.recv(message, send_pub)
    print(f"{sender_name} send enc: ", b64enc(send_key))
    print(f"{sender_name} send pub: ", send_pub)
    print(f"{receiver_name} recv dec: ", b64enc(recv_key))
    # A correctly synchronized ratchet yields the same key on both ends.
    assert send_key == recv_key, f"{sender_name}->{receiver_name} message keys diverged"
    return send_key


def main():
    """Demo of the ratchet, followed by a standalone pyDH agreement check.

    Runs one Alice -> Bob message and then two Bob -> Alice messages.
    """
    root_key = os.urandom(32)
    alice_obj = get_dh_obj()
    bob_obj = get_dh_obj()
    alice_public_key = alice_obj.gen_public_key()
    print("Type of alice_public_key: ", type(alice_public_key))
    bob_public_key = bob_obj.gen_public_key()
    alice = DoubleRatchet(alice_obj, bob_public_key)
    bob = DoubleRatchet(bob_obj, alice_public_key)
    alice.init_session(root_key)
    bob.init_session(root_key)

    _exchange(alice, bob, "Alice", "Bob", "yo")
    # Direction flips: Bob's reply carries a new public key (DH ratchet step).
    _exchange(bob, alice, "Bob", "Alice", "yo")
    # Same direction again: only the symmetric chain advances, no new public key.
    _exchange(bob, alice, "Bob", "Alice", "yo")

    # Standalone check that pyDH agreement is symmetric and that identical
    # shared secrets expand to identical keys (the formerly commented-out assert).
    check_alice = get_dh_obj()
    check_bob = get_dh_obj()
    alice_shared = check_alice.gen_shared_key(check_bob.gen_public_key())
    bob_shared = check_bob.gen_shared_key(check_alice.gen_public_key())

    check_root = os.urandom(32)
    alice_recv_key = hkdf.Hkdf(check_root, b"DoubleRatchet", hashlib.sha256).expand(
        str_to_bytes(alice_shared), 32)
    bob_send_key = hkdf.Hkdf(check_root, b"DoubleRatchet", hashlib.sha256).expand(
        str_to_bytes(bob_shared), 32)

    print("alice_recv_key: ", b64enc(alice_recv_key))
    assert alice_recv_key == bob_send_key, "pyDH shared secrets disagree"
    

# Inside a Jupyter kernel __name__ is "__main__", so executing this cell runs the demo.
if "__main__" == __name__:
    main()
    

Type of alice_public_key:  <class 'int'>
Alice send enc:  ZxPm6PKC4uNNTO1aEbylRsD3ebJoWCe3aCEBTI6NbyM=
Alice send pub:  759829537399857337082628991244613382442231488805646333041603530464074264997881642182972233628605194309198273844612269833836446194657376378811929174635474111849678011964165040992536182387716843552936204132626514807477392911498520469545472300988171612915783205627294025887218297407475329889388214000650413339711790202564612670784766668357927974423210623318489122391480262780249900222418713482731971581795320692890746795271317528926869447008835883165335226824605638660170692969317968329524722482227022662968906702307852334965724437153254410527209161915531390329007806798413537032674019783551866598218091831892839748299
Bob recv dec:  ZxPm6PKC4uNNTO1aEbylRsD3ebJoWCe3aCEBTI6NbyM=
Bob send enc:  7Ts4Y9dW1hHe10CWuOBXBI4CKUGq5zh2Ev+o1yJf1hM=
Bob send pub:  276727632125382789666903181604098108112232990187315406626382785822658847521585799895974460569101624468995267586228099519753554246