In [132]:
import base64

from cryptography.hazmat.primitives import serialization,hashes
from cryptography.hazmat.primitives.asymmetric.x25519 import X25519PrivateKey
from cryptography.hazmat.primitives.asymmetric.ed25519 import \
        Ed25519PublicKey, Ed25519PrivateKey
from cryptography.hazmat.primitives.kdf.hkdf import HKDF
from cryptography.hazmat.backends import default_backend


In [133]:
from Cryptodome.Cipher import AES


In [134]:
def b64(msg):
    """Return ``msg`` (bytes) as a single-line base64 string for display.

    Uses ``base64.b64encode`` rather than ``encodebytes``: the latter inserts a
    newline every 76 output characters, which split longer ciphertexts across
    lines (only the trailing newline was removed by ``strip``).
    """
    return base64.b64encode(msg).decode('ascii')

In [135]:
def hkdf(inp, len):
    """Derive ``len`` bytes of key material from ``inp`` using HKDF-SHA256.

    HKDF is the HMAC-based key derivation function; here it is run with an
    empty salt and empty info string.  ``len`` keeps its original name so
    keyword callers keep working, even though it shadows the builtin.
    """
    derived_length = len
    kdf = HKDF(
        algorithm=hashes.SHA256(),
        length=derived_length,
        salt=b'',
        info=b'',
        backend=default_backend(),
    )
    return kdf.derive(inp)

In [136]:
class SymmRatchet(object):
    """Symmetric KDF chain.

    Each step runs HKDF over the current state (optionally extended with extra
    input), keeps the first 32 bytes as the new state and hands out the rest
    as a 32-byte message key and a 16-byte IV.
    """

    def __init__(self, key):
        # Chain state; replaced by a 32-byte value on every call to next().
        self.state = key

    def next(self, inp=b''):
        """Advance the chain once and return ``(key, iv)``.

        ``inp`` is appended to the state before derivation (the root ratchet
        passes DH outputs here; message chains pass nothing).
        """
        material = hkdf(self.state + inp, 80)
        new_state = material[:32]
        message_key = material[32:64]
        iv = material[64:80]
        self.state = new_state
        return message_key, iv

In [137]:
class Bob(object):
    """Bob's side of the X3DH handshake and Double Ratchet demo.

    All keys are X25519.  The signed pre-key is not actually signed in this
    demo.
    """

    def __init__(self):
        self.IKb = X25519PrivateKey.generate()   # identity key
        self.SPKb = X25519PrivateKey.generate()  # signed pre-key (unsigned here)
        self.OPKb = X25519PrivateKey.generate()  # one-time pre-key
        self.dhratchet = X25519PrivateKey.generate()  # current DH ratchet key pair
        # Raw bytes of the Alice ratchet public key we last ratcheted on.  A DH
        # ratchet step is only performed when Alice presents a different key.
        self._alice_public_raw = None

    def x3dh(self, alice):
        """Compute the X3DH shared secret ``self.sk`` (32 bytes).

        Reads Alice's private key objects directly (a simulation shortcut; a
        real peer would only see public keys).  The DH outputs are concatenated
        in the same order as in ``Alice.x3dh`` so both sides derive the same key.
        """
        dh1 = self.SPKb.exchange(alice.IKa.public_key())
        dh2 = self.IKb.exchange(alice.EKa.public_key())
        dh3 = self.OPKb.exchange(alice.EKa.public_key())
        dh4 = self.SPKb.exchange(alice.EKa.public_key())
        self.sk = hkdf(dh1 + dh2 + dh3 + dh4, 32)
        print('[Bob] shared key:', b64(self.sk))

    def init_ratchet(self):
        """Seed the root chain from ``self.sk`` and derive initial recv/send chains."""
        self.root_ratchet = SymmRatchet(self.sk)
        # recv is derived before send, the mirror image of Alice.init_ratchet,
        # so Bob's recv chain equals Alice's send chain and vice versa.
        self.recv_ratchet = SymmRatchet(self.root_ratchet.next()[0])
        self.send_ratchet = SymmRatchet(self.root_ratchet.next()[0])

    def dh_ratchet(self, alice_public):
        """Perform a DH ratchet step against Alice's ratchet public key.

        First derives a new recv chain from the current ratchet key, then
        rotates to a fresh ratchet key pair and derives a new send chain.
        """
        dh_recv = self.dhratchet.exchange(alice_public)
        shared_recv = self.root_ratchet.next(dh_recv)[0]
        self.recv_ratchet = SymmRatchet(shared_recv)
        print('[Bob] shared recv seed:', b64(shared_recv))

        self.dhratchet = X25519PrivateKey.generate()
        dh_send = self.dhratchet.exchange(alice_public)
        shared_send = self.root_ratchet.next(dh_send)[0]
        self.send_ratchet = SymmRatchet(shared_send)
        print('[Bob] shared send seed:', b64(shared_send))

        # Remember which key we ratcheted on so repeated messages under the
        # same key continue the existing recv chain instead of re-ratcheting.
        self._alice_public_raw = alice_public.public_bytes(
            serialization.Encoding.Raw, serialization.PublicFormat.Raw)

    def send(self, alice, msg):
        """Encrypt ``msg`` with the next send-chain key (AES-CBC) and deliver it to Alice."""
        key, iv = self.send_ratchet.next()
        cipher = AES.new(key, AES.MODE_CBC, iv).encrypt(pad(msg))
        print('[Bob] Sending cipher text to Alice:', b64(cipher))
        alice.recv(cipher, self.dhratchet.public_key())

    def recv(self, cipher, alice_public_key):
        """Decrypt a message from Alice, ratcheting only if her ratchet key changed."""
        alice_raw = alice_public_key.public_bytes(
            serialization.Encoding.Raw, serialization.PublicFormat.Raw)
        # Ratcheting on an unchanged key would derive a recv chain Alice never
        # derived, so consecutive messages from her would fail to decrypt.
        if alice_raw != self._alice_public_raw:
            self.dh_ratchet(alice_public_key)
        key, iv = self.recv_ratchet.next()
        msg = unpad(AES.new(key, AES.MODE_CBC, iv).decrypt(cipher))
        print('[Bob] Decrypted message:', msg)


In [138]:
class Alice(object):
    """Alice's side of the X3DH handshake and Double Ratchet demo (X25519 keys)."""

    def __init__(self):
        self.IKa = X25519PrivateKey.generate()  # identity key
        self.EKa = X25519PrivateKey.generate()  # ephemeral key for X3DH
        # DH ratchet key pair; created by the first dh_ratchet call.
        self.dhratchet = None
        # Raw bytes of the Bob ratchet public key we last ratcheted on.  A DH
        # ratchet step is only performed when Bob presents a different key.
        self._bob_public_raw = None

    def x3dh(self, bob):
        """Compute the X3DH shared secret ``self.sk`` (32 bytes).

        Reads Bob's private key objects directly (a simulation shortcut).  The
        DH outputs are concatenated in the same order as in ``Bob.x3dh`` so both
        sides derive the same key.
        """
        dh1 = self.IKa.exchange(bob.SPKb.public_key())
        dh2 = self.EKa.exchange(bob.IKb.public_key())
        dh3 = self.EKa.exchange(bob.OPKb.public_key())
        dh4 = self.EKa.exchange(bob.SPKb.public_key())
        self.sk = hkdf(dh1 + dh2 + dh3 + dh4, 32)
        print('[Alice] shared key:', b64(self.sk))

    def init_ratchet(self):
        """Seed the root chain from ``self.sk`` and derive initial send/recv chains."""
        self.root_ratchet = SymmRatchet(self.sk)
        # send is derived before recv, the mirror image of Bob.init_ratchet.
        self.send_ratchet = SymmRatchet(self.root_ratchet.next()[0])
        self.recv_ratchet = SymmRatchet(self.root_ratchet.next()[0])

    def dh_ratchet(self, bob_public):
        """Perform a DH ratchet step against Bob's ratchet public key.

        On the very first call there is no ratchet key yet, so only a new send
        chain is derived; afterwards a new recv chain is derived first.
        """
        if self.dhratchet is not None:
            dh_recv = self.dhratchet.exchange(bob_public)
            shared_recv = self.root_ratchet.next(dh_recv)[0]
            self.recv_ratchet = SymmRatchet(shared_recv)
            print('[Alice] shared recv seed:', b64(shared_recv))

        self.dhratchet = X25519PrivateKey.generate()
        dh_send = self.dhratchet.exchange(bob_public)
        shared_send = self.root_ratchet.next(dh_send)[0]
        self.send_ratchet = SymmRatchet(shared_send)
        print('[Alice] shared send seed:', b64(shared_send))

        # Remember which key we ratcheted on (including the initial setup call)
        # so later messages under the same key continue the current recv chain.
        self._bob_public_raw = bob_public.public_bytes(
            serialization.Encoding.Raw, serialization.PublicFormat.Raw)

    def send(self, bob, msg):
        """Encrypt ``msg`` with the next send-chain key (AES-CBC) and deliver it to Bob.

        Raises RuntimeError if no DH ratchet step has been performed yet.
        """
        if self.dhratchet is None:
            # Fail before advancing the send chain; there is no ratchet public
            # key to attach to the message yet.
            raise RuntimeError('call dh_ratchet(bob_public_key) before send')
        key, iv = self.send_ratchet.next()
        cipher = AES.new(key, AES.MODE_CBC, iv).encrypt(pad(msg))
        print('[Alice] Sending cipher text to Bob:', b64(cipher))
        bob.recv(cipher, self.dhratchet.public_key())

    def recv(self, cipher, bob_public_key):
        """Decrypt a message from Bob, ratcheting only if his ratchet key changed."""
        bob_raw = bob_public_key.public_bytes(
            serialization.Encoding.Raw, serialization.PublicFormat.Raw)
        # Ratcheting on an unchanged key would derive a recv chain Bob never
        # derived, so consecutive messages from him would fail to decrypt.
        if bob_raw != self._bob_public_raw:
            self.dh_ratchet(bob_public_key)
        key, iv = self.recv_ratchet.next()
        msg = unpad(AES.new(key, AES.MODE_CBC, iv).decrypt(cipher))
        print('[Alice] Decrypted message:', msg)

In [139]:
def pad(msg):
    """Apply PKCS#7 padding for a 16-byte block size.

    Appends between 1 and 16 bytes, each holding the pad length, so an
    already block-aligned message gains a full extra block.
    """
    fill = 16 - len(msg) % 16
    return msg + bytes((fill,)) * fill


In [140]:
def unpad(cipher):
    """Remove PKCS#7 padding (16-byte block size) added by ``pad``.

    ``cipher`` is the decrypted plaintext still carrying its padding (the name
    is kept for backward compatibility).  Returns the unpadded bytes.

    Raises ValueError on empty input or malformed padding (pad length 0 or
    greater than 16, longer than the data, or inconsistent pad bytes).  The
    previous version returned b'' for a trailing 0x00 byte and garbage for
    other corrupt inputs, e.g. after decrypting with the wrong key.
    """
    block_size = 16
    if not cipher:
        raise ValueError('cannot unpad empty input')
    num = cipher[-1]
    if not 1 <= num <= block_size or num > len(cipher):
        raise ValueError('invalid padding length')
    if cipher[-num:] != bytes([num]) * num:
        raise ValueError('invalid padding bytes')
    return cipher[:-num]

In [141]:
alice = Alice()
bob = Bob()

# X3DH: each party derives the shared secret independently.
alice.x3dh(bob)
bob.x3dh(alice)
assert alice.sk == bob.sk, "X3DH shared secrets do not match"

# Seed the root and initial send/recv chains from the shared secret.
alice.init_ratchet()
bob.init_ratchet()

# Alice's initial DH ratchet step against Bob's current ratchet public key
# gives her a send chain before Bob has sent anything.
alice.dh_ratchet(bob.dhratchet.public_key())

[Alice] shared key: DOhPp2QWyRGya9dcvVqa2iMMAIah3eGz1vMPSzorlbg=
[Bob] shared key: DOhPp2QWyRGya9dcvVqa2iMMAIah3eGz1vMPSzorlbg=
[Alice] shared send seed: G0/T1hHCr5DSn95vnr8Pvd1L9KLrMdX62poWKhPm8Ps=


In [142]:
# Alice's message carries her ratchet public key; Bob ratchets, decrypts and
# replies under his new ratchet key, which makes Alice ratchet in turn.
alice.send(bob, b"Hello Bob!")

bob.send(alice, b"Hello Alice!! How are you ?")

[Alice] Sending cipher text to Bob: WObd+LTGi4YgFTacINAlDA==
[Bob] shared recv seed: G0/T1hHCr5DSn95vnr8Pvd1L9KLrMdX62poWKhPm8Ps=
[Bob] shared send seed: ZlGFR46Uq7lQFk1NMdeQD1mijP5lb1VbDStCXliIyYk=
[Bob] Decrypted message: b'Hello Bob!'
[Bob] Sending cipher text to Alice: +rEUyKZ10YSWQCxMQyaO3upNNk/p3hCqiD2DceniAVk=
[Alice] shared recv seed: ZlGFR46Uq7lQFk1NMdeQD1mijP5lb1VbDStCXliIyYk=
[Alice] shared send seed: Mlm32NXi1hSAh7GZEhWXoC/OHbj48gLF/D9HWOrpnqQ=
[Alice] Decrypted message: b'Hello Alice!! How are you ?'
