In [7]:
import struct
import binascii
from typing import List

class F2_32:
    def __init__(self, val: int):
        assert isinstance(val, int)
        self.val = val
        
    def __add__(self, other):
        return F2_32((self.val + other.val) & 0xffffffff)
    
    def __xor__(self, other):
        return F2_32(self.val ^ other.val)
    
    def __lshift__(self, nbit: int):
        left  = (self.val << nbit%32) & 0xffffffff
        right = (self.val & 0xffffffff) >> (32-(nbit%32))
        return F2_32(left | right)
    
    def __repr__(self):
        return hex(self.val)
    
    def __int__(self):
        return int(self.val)

def quarter_round(a: F2_32, b: F2_32, c: F2_32, d: F2_32):
    a += b 
    d ^= a
    d <<= 16
    c += d 
    b ^= c
    b <<= 12
    a += b 
    d ^= a 
    d <<= 8
    c += d 
    b ^= c 
    b <<= 7
    return a, b, c, d

def Qround(state: List[F2_32], idx1, idx2, idx3, idx4):
    state[idx1], state[idx2], state[idx3], state[idx4] = quarter_round(state[idx1], state[idx2], state[idx3], state[idx4])

def inner_block(state: List[F2_32]):
    Qround(state, 0, 4, 8, 12)
    Qround(state, 1, 5, 9, 13)
    Qround(state, 2, 6, 10, 14)
    Qround(state, 3, 7, 11, 15)
    Qround(state, 0, 5, 10, 15)
    Qround(state, 1, 6, 11, 12)
    Qround(state, 2, 7, 8, 13)
    Qround(state, 3, 4, 9, 14)
    return state

def serialize(state: List[F2_32]) -> List[bytes]:
    return b''.join([ struct.pack('<I', int(s)) for s in state ])

def xor(x: bytes, y: bytes):
    return bytes(a ^ b for a, b in zip(x, y))


In [8]:
import math
import binascii

def clamp(r: int) -> int:
    return r & 0x0ffffffc0ffffffc0ffffffc0fffffff

def le_bytes_to_num(byte) -> int:
    res = 0
    for i in range(len(byte) - 1, -1, -1):
        res <<= 8
        res += byte[i]
    return res

def num_to_16_le_bytes(num: int) -> bytes:
    res = []
    for i in range(16):
        res.append(num & 0xff)
        num >>= 8
    return bytearray(res)


In [9]:
def chacha20_block(key: bytes, counter: int, nonce: bytes) -> bytes:
    constants = [F2_32(x) for x in struct.unpack('<IIII', b'expand 32-byte k')]
    key       = [F2_32(x) for x in struct.unpack('<IIIIIIII', key)]
    counter   = [F2_32(counter)]
    nonce     = [F2_32(x) for x in struct.unpack('<III', nonce)]
    state = constants + key + counter + nonce
    initial_state = state[:]
    for i in range(10):
        state = inner_block(state)
    state = [ s + init_s for s, init_s in zip(state, initial_state) ]
    return serialize(state)

def chacha20_encrypt(key: bytes, counter: int, nonce: bytes, plaintext: bytes):
    encrypted_message = bytearray(0)

    for j in range(len(plaintext) // 64):
        key_stream = chacha20_block(key, counter + j, nonce)
        block = plaintext[j*64 : (j+1)*64]
        encrypted_message += xor(block, key_stream)

    if len(plaintext) % 64 != 0:
        j = len(plaintext) // 64
        key_stream = chacha20_block(key, counter + j, nonce)
        block = plaintext[j*64 : ]
        encrypted_message += xor(block, key_stream)

    return encrypted_message

def poly1305_mac(msg: bytes, key: bytes) -> bytes:
    r = le_bytes_to_num(key[0:16])
    r = clamp(r)
    s = le_bytes_to_num(key[16:32])
    a = 0  # a is the accumulator
    p = (1<<130) - 5
    for i in range(1, math.ceil(len(msg)/16) + 1):
        n = le_bytes_to_num(msg[(i-1)*16 : i*16] + b'\x01')
        a += n
        a = (r * a) % p
    a += s
    return num_to_16_le_bytes(a)



In [10]:
import struct
import binascii

def poly1305_key_gen(key: bytes, nonce: bytes) -> bytes:
    counter = 0
    block = chacha20_block(key, counter, nonce)
    return block[0:32]

def pad16(x: bytes) -> bytes:
    if len(x) % 16 == 0: return b''
    return b'\x00' * (16 - (len(x) % 16))

def num_to_8_le_bytes(num: int) -> bytes:
    return struct.pack('<Q', num)

def chacha20_aead_encrypt(aad: bytes, key: bytes, nonce: bytes, plaintext: bytes):
    otk = poly1305_key_gen(key, nonce)
    ciphertext = chacha20_encrypt(key, 1, nonce, plaintext)
    mac_data = aad + pad16(aad)
    mac_data += ciphertext + pad16(ciphertext)
    mac_data += num_to_8_le_bytes(len(aad))
    mac_data += num_to_8_le_bytes(len(ciphertext))
    tag = poly1305_mac(mac_data, otk)
    return (ciphertext, tag)

def chacha20_aead_decrypt(aad: bytes, key: bytes, nonce: bytes, ciphertext: bytes):
    otk = poly1305_key_gen(key, nonce)
    plaintext = chacha20_encrypt(key, 1, nonce, ciphertext)
    mac_data = aad + pad16(aad)
    mac_data += ciphertext + pad16(ciphertext)
    mac_data += num_to_8_le_bytes(len(aad))
    mac_data += num_to_8_le_bytes(len(ciphertext))
    tag = poly1305_mac(mac_data, otk)
    return (plaintext, tag)


def compare_const_time(a, b):
    if len(a) != len(b): return False
    result = 0
    for x, y in zip(a, b):
        result |= x ^ y
    return result == 0



In [11]:
def encrypt_and_tag(key, nonce, plaintext, aad):
    return chacha20_aead_encrypt(key=key, nonce=nonce, plaintext=plaintext, aad=aad)

def decrypt_and_verify(key, nonce, ciphertext, mac, aad):
    plaintext, tag = \
        chacha20_aead_decrypt(key=key, nonce=nonce, ciphertext=ciphertext, aad=aad)

    if not compare_const_time(tag, mac):
        return Exception('bad tag!')

    return plaintext

In [21]:
import time
import os, psutil

key = b"A" * 32
nonce = b"B" * 12
aad = b"C" * 12
plaintext = b'\xd9\x31\x32\x25\xf8\x84\x06\xe5' + \
                b'\xa5\x59\x09\xc5\xaf\xf5\x26\x9a' + \
                b'\x86\xa7\xa9\x53\x15\x34\xf7\xda' + \
                b'\x2e\x4c\x30\x3d\x8a\x31\x8a\x72' + \
                b'\x1c\x3c\x0c\x95\x95\x68\x09\x53' + \
                b'\x2f\xcf\x0e\x24\x49\xa6\xb5\x25' + \
                b'\xb1\x6a\xed\xf5\xaa\x0d\xe6\x57'
    
for i in range(1000):
    plaintext+=b'\xb1\x6a\xed\xf5\xaa\x0d\xe6\x57'
plaintext+=b'\xba\x63\x7b\x39'
    
start = time.time()
ciphertext, tag = encrypt_and_tag(key=key, nonce=nonce, plaintext=plaintext, aad=aad)
end = time.time()
print('Encryption time: ',(end-start)*1000)

start = time.time()
ciphertext = b's'
plaintext2 = decrypt_and_verify(key=key, nonce=nonce, ciphertext=ciphertext, aad=aad, mac=tag)
end = time.time()
print('Decryption time: ',(end-start)*1000)
process = psutil.Process(os.getpid())

print("Memory in MB:", process.memory_info().rss / 1024 ** 2)

Encryption time:  199.46622848510742
Decryption time:  1.8928050994873047
Memory in MB: 54.23828125


In [18]:
from Crypto.Cipher import AES
import scrypt, os, binascii
import time
import os, psutil

def encrypt_AES_GCM(msg, password):
    kdfSalt = os.urandom(16)
    secretKey = scrypt.hash(password, kdfSalt, N=16384, r=8, p=1, buflen=32)
    aesCipher = AES.new(secretKey, AES.MODE_GCM)
    ciphertext, authTag = aesCipher.encrypt_and_digest(msg)
    return (kdfSalt, ciphertext, aesCipher.nonce, authTag)

def decrypt_AES_GCM(encryptedMsg, password):
    (kdfSalt, ciphertext, nonce, authTag) = encryptedMsg
    secretKey = scrypt.hash(password, kdfSalt, N=16384, r=8, p=1, buflen=32)
    aesCipher = AES.new(secretKey, AES.MODE_GCM, nonce)
    plaintext = aesCipher.decrypt_and_verify(ciphertext, authTag)
    return plaintext

msg =   b'\xd9\x31\x32\x25\xf8\x84\x06\xe5' + \
        b'\xa5\x59\x09\xc5\xaf\xf5\x26\x9a' + \
        b'\x86\xa7\xa9\x53\x15\x34\xf7\xda' + \
        b'\x2e\x4c\x30\x3d\x8a\x31\x8a\x72' + \
        b'\x1c\x3c\x0c\x95\x95\x68\x09\x53' + \
        b'\x2f\xcf\x0e\x24\x49\xa6\xb5\x25' + \
        b'\xb1\x6a\xed\xf5\xaa\x0d\xe6\x57' 
    
for i in range(1000):
    msg+=b'\xb1\x6a\xed\xf5\xaa\x0d\xe6\x57'
msg+=b'\xba\x63\x7b\x39'
    
password = 'feffe9928665731c6d6a8f9467308308'
password = bytes.fromhex(password)

start = time.time()
encryptedMsg = encrypt_AES_GCM(msg, password)
end = time.time()
print('Encryption time: ',(end-start)*1000)

start = time.time()
decryptedMsg = decrypt_AES_GCM(encryptedMsg, password)
end = time.time()
print('Decryption time: ',(end-start)*1000)

process = psutil.Process(os.getpid())

print("Memory in MB:", process.memory_info().rss / 1024 ** 2)

Encryption time:  109.70497131347656
Decryption time:  58.83479118347168
Memory in MB: 54.19921875
