# Decrypting and Decompressing a Network Stream in One Go

In [None]:
%load_ext cython

In [None]:
import asyncio

In [None]:
def run_async(coro):
    """Run *coro* to completion on a private event loop and return its result.

    ``asyncio.get_event_loop()`` is deprecated when no loop is running and, on
    Python 3.14+, raises RuntimeError if no current loop has been set. So a
    fresh loop is created for every call and always closed afterwards, even
    when the coroutine raises.

    Like the previous implementation, this cannot be called from code that is
    already running inside an event loop.
    """
    loop = asyncio.new_event_loop()
    try:
        return loop.run_until_complete(coro)
    finally:
        loop.close()

In [None]:
import os
import zlib
import hashlib
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from cryptography.hazmat.backends import default_backend
backend = default_backend()

async def py_unpack(stream, receive_callback, key, iv, tag=None, decompress=True):
    """Pure-Python reference: AES-GCM-decrypt, and optionally inflate, a stream.

    Every chunk yielded by the async iterable *stream* is decrypted. When
    *decompress* is true, the plaintext is then fed through a zlib
    decompressor. The result of each per-chunk step is handed to
    ``receive_callback``, even when it is empty. Output from finalising the
    decryptor (and from flushing the decompressor) is delivered only if
    non-empty.

    ``decryptor.finalize()`` is where the cryptography package checks the GCM
    *tag*. NOTE(review): with ``tag=None``, cryptography refuses to finalize a
    GCM decryption, so the default appears unusable — confirm intent.
    """
    decryptor = Cipher(algorithms.AES(key), modes.GCM(iv, tag), backend=backend).decryptor()
    # Bound methods are hoisted into locals so the per-chunk loops avoid
    # attribute lookups (this function is benchmarked against cy_unpack).
    update = decryptor.update

    if decompress:
        inflater = zlib.decompressobj()
        inflate = inflater.decompress
        async for chunk in stream:
            receive_callback(inflate(update(chunk)))
        leftover = inflate(decryptor.finalize())
        if leftover:
            receive_callback(leftover)
        leftover = inflater.flush()
    else:
        # Decryption only: hand each decrypted chunk straight through.
        async for chunk in stream:
            receive_callback(update(chunk))
        leftover = decryptor.finalize()

    if leftover:
        receive_callback(leftover)

In [None]:
%%cython -3 -a
## COPY PYTHON CODE FROM ABOVE

In [None]:
# distutils: library_dirs=['/usr/local/Cellar/openssl/1.0.2h/lib']
# distutils: include_dirs=['/usr/local/Cellar/openssl/1.0.2h/include']

In [None]:
%%cython
# cython: c_string_encoding=ascii
# cython: language_level=3
# cython: profile=True
# cython: binding=True
# distutils: libraries=['ssl', 'crypto', 'z']
# distutils: library_dirs=['/usr/local/Cellar/openssl/1.0.2h/lib']
# distutils: include_dirs=['/usr/local/Cellar/openssl/1.0.2h/include']

# Compile-time constants (Cython DEF).
# CHUNK_SIZE: largest slice of ciphertext passed to EVP_DecryptUpdate per call
#   in cy_unpack, and thus the size of the AES output buffer.
# COMP_BUFFER_SIZE: size of the zlib output buffer that cy_unpack hands to
#   decompress_chunk as out_length.
# NOTE(review): DEF is deprecated in Cython 3.x; `cdef enum` is the usual
# replacement — TODO confirm against the Cython version in use.
DEF CHUNK_SIZE = 2048
DEF COMP_BUFFER_SIZE = CHUNK_SIZE * 10

from libc.string cimport memset
from cpython.mem cimport PyMem_Malloc, PyMem_Free


### openssl declarations

# Minimal subset of <openssl/evp.h> used in this cell. Declared `nogil`, so
# the functions may be called without the GIL, although the code below calls
# them while holding it.
cdef extern from "openssl/evp.h" nogil:
    # Opaque OpenSSL types; only ever handled through pointers here.
    ctypedef struct EVP_CIPHER_CTX
    ctypedef struct EVP_CIPHER
    ctypedef struct ENGINE
    # Control codes for EVP_CIPHER_CTX_ctrl: set the expected GCM tag and the
    # GCM IV length.
    enum:
        EVP_CTRL_GCM_SET_TAG
        EVP_CTRL_GCM_SET_IVLEN

    # NOTE(review): deprecated since OpenSSL 1.1.0, where library init became
    # automatic; still needed for the 1.0.2 build referenced in the distutils
    # paths — TODO confirm.
    void OpenSSL_add_all_algorithms()

    # Per the OpenSSL documentation, the int-returning EVP_* calls below return
    # 1 on success and 0 on failure. For GCM decryption, EVP_DecryptFinal_ex
    # returning 0 is how an authentication-tag mismatch is reported.
    const EVP_CIPHER* EVP_aes_256_gcm()
    EVP_CIPHER_CTX* EVP_CIPHER_CTX_new()
    void EVP_CIPHER_CTX_free(EVP_CIPHER_CTX *a)
    int EVP_CIPHER_CTX_ctrl(EVP_CIPHER_CTX *ctx, int type, int arg, void *ptr)

    int EVP_DecryptInit_ex(EVP_CIPHER_CTX *ctx, const EVP_CIPHER *cipher, ENGINE *impl,
        const unsigned char *key, const unsigned char *iv)
    int EVP_DecryptUpdate(EVP_CIPHER_CTX *ctx, unsigned char *out,
        int *outl, const unsigned char *inptr, int inl)
    # EVP_DecryptFinal (without _ex) is declared but not used in this cell.
    int EVP_DecryptFinal(EVP_CIPHER_CTX *ctx, unsigned char *outm, int *outl)
    int EVP_DecryptFinal_ex(EVP_CIPHER_CTX *ctx, unsigned char *outm, int *outl)


### zlib declarations

# Subset of <zlib.h> used by decompress_chunk and cy_unpack.
cdef extern from "zlib.h":
    # Only the z_stream fields that Cython code touches need declaring; the C
    # compiler uses the real struct from zlib.h. The real struct also has
    # allocator fields (zalloc/zfree/opaque), which is why cy_unpack zeroes the
    # whole struct with memset before inflateInit.
    ctypedef struct z_stream:
        unsigned char    *next_in;  #/* next input byte */
        unsigned int     avail_in;  #/* number of bytes available at next_in */
        unsigned long    total_in;  #/* total number of input bytes read so far */

        unsigned char    *next_out; #/* next output byte should be put there */
        unsigned int     avail_out; #/* remaining free space at next_out */
        unsigned long    total_out; #/* total number of bytes output so far */

        const char *msg;    #/* last error message, NULL if no error */

        int     data_type;  #/* best guess about the data type: binary or text */
        unsigned long   adler;      #/* adler32 value of the uncompressed data */

    # Flush modes and return codes; Z_NO_FLUSH is declared but unused in this cell.
    enum:
        Z_SYNC_FLUSH
        Z_NO_FLUSH
        Z_OK
        Z_STREAM_END
    # inflateInit is a macro in zlib.h; Cython just emits the call by name.
    int inflateInit(z_stream* strm)
    int inflate(z_stream* strm, int flush)
    int inflateEnd(z_stream* strm)


### decryption setup helper

cdef int init_decryption(EVP_CIPHER_CTX **_ctx,
                         bytes key, bytes iv, bytes aad=None) except -1:
    """Create an AES-256-GCM decryption context and store it in ``_ctx[0]``.

    *key* must be exactly 32 bytes (AES-256) and *iv* must be non-empty. The
    IV length given to OpenSSL is ``len(iv)``, so any GCM IV size works: 12
    bytes is OpenSSL's default, and the demo data uses 16. *aad*, if given, is
    fed in as additional authenticated data.

    On success the caller owns ``_ctx[0]`` and must release it with
    EVP_CIPHER_CTX_free(). The return value is the byte count OpenSSL reports
    for the AAD update (0 without AAD).

    On failure ``_ctx[0]`` is NULL and an exception is raised: ValueError for
    bad key/IV sizes, MemoryError if the context cannot be allocated, and
    RuntimeError if an OpenSSL call fails.
    """
    cdef int offset = 0
    cdef EVP_CIPHER_CTX *ctx

    _ctx[0] = NULL
    # OpenSSL reads exactly 32 key bytes and len(iv) IV bytes through raw
    # pointers. Validate the sizes here instead of risking a buffer over-read.
    if key is None or len(key) != 32:
        raise ValueError("AES-256-GCM requires a 32 byte key")
    if not iv:
        raise ValueError("GCM requires a non-empty IV")

    # Create and initialise the context
    ctx = EVP_CIPHER_CTX_new()
    if not ctx:
        raise MemoryError()
    _ctx[0] = ctx

    try:
        # Select the cipher first; key and IV follow once the IV length is known.
        if not EVP_DecryptInit_ex(ctx, EVP_aes_256_gcm(), NULL, NULL, NULL):
            raise RuntimeError("EVP_DecryptInit_ex failed (cipher setup)")
        # Use the actual IV length (OpenSSL's default would be 12 bytes).
        if not EVP_CIPHER_CTX_ctrl(ctx, EVP_CTRL_GCM_SET_IVLEN, len(iv), NULL):
            raise RuntimeError("setting the GCM IV length failed")
        # Initialise key and IV
        if not EVP_DecryptInit_ex(ctx, NULL, NULL, key, iv):
            raise RuntimeError("EVP_DecryptInit_ex failed (key/IV)")
        # Provide any AAD data. This can be called zero or more times as required
        if aad:
            if not EVP_DecryptUpdate(ctx, NULL, &offset, aad, len(aad)):
                raise RuntimeError("feeding AAD to EVP_DecryptUpdate failed")
    except:
        # Never hand a half-initialised context back to the caller.
        EVP_CIPHER_CTX_free(ctx)
        _ctx[0] = NULL
        raise

    assert offset >= 0
    return offset


### decompression helper

# Z_BUF_ERROR is declared here, in addition to the zlib declarations above,
# because decompress_chunk must recognise it: zlib documents it as non-fatal
# ("no progress was possible").
cdef extern from "zlib.h":
    enum:
        Z_BUF_ERROR

cdef int decompress_chunk(z_stream *zstream, unsigned char *in_buffer,
                          unsigned char* out_buffer, receive_callback,
                          int in_length, int out_length) except -1:
    """Inflate *in_length* bytes from *in_buffer* and forward all output.

    Output is produced in pieces of at most *out_length* bytes, using
    *out_buffer* as scratch space. Each piece is passed to
    ``receive_callback(bytes)``. The loop repeats while inflate() fills the
    whole output buffer, because more output may still be pending.

    Raises RuntimeError when inflate() reports a real error.
    """
    cdef int ret
    zstream.next_in = in_buffer
    zstream.avail_in = in_length

    zstream.avail_out = 0
    while zstream.avail_out == 0:
        zstream.next_out = out_buffer
        zstream.avail_out = out_length
        ret = inflate(zstream, Z_SYNC_FLUSH)
        if ret == Z_BUF_ERROR and zstream.avail_in == 0:
            # No progress was possible: all input has been consumed and nothing
            # is pending. This happens, for example, when the previous round
            # filled the output buffer exactly, or when there was no input.
            # zlib documents this as non-fatal, so this chunk is simply done.
            break
        if ret != Z_OK and ret != Z_STREAM_END:
            raise RuntimeError("Decompression failed with error %s: %s" % (
                ret, (zstream.msg.decode('utf8') if zstream.msg is not NULL else 'unknown error')))
        # pass decompressed data on to receiver
        receive_callback(out_buffer[:out_length - zstream.avail_out])
    return 0


# global OpenSSL init at module load time
# Runs once, when the compiled module is imported, before any cy_unpack call.
# NOTE(review): presumably only required for OpenSSL < 1.1.0 — TODO confirm.
OpenSSL_add_all_algorithms()


### coroutine to decrypt + decompress

async def cy_unpack(stream, receive_callback, bytes key, bytes iv, bytes tag=None, bint decompress=True):
    """Decrypt an AES-256-GCM byte stream and optionally zlib-inflate it.

    Reads ``bytes`` chunks from the async iterable *stream* and decrypts them
    in slices of at most CHUNK_SIZE bytes. Each decrypted piece (inflated
    first, if *decompress* is true) is passed to ``receive_callback(bytes)``.

    Once the stream is exhausted, the GCM tag *tag* is verified. A mismatching
    tag raises RuntimeError("Decryption failed: final"). So does a missing
    tag: none is set, and OpenSSL's final step then fails.

    The plaintext has already been handed to *receive_callback* by the time
    the tag is checked. Callers must discard it if this coroutine raises.

    Raises:
        ValueError: bad key/IV sizes (from init_decryption).
        MemoryError: buffer allocation or inflateInit() failed.
        RuntimeError: an OpenSSL or zlib call failed, including tag mismatch.
    """
    cdef EVP_CIPHER_CTX *ctx = NULL
    cdef int length
    cdef int ret
    cdef Py_ssize_t data_length
    cdef bytes data
    cdef z_stream zstream
    cdef unsigned char *c_data
    cdef unsigned char *mem_buffers = NULL
    cdef unsigned char *aes_buffer
    cdef unsigned char *zlib_buffer

    # zlib requires the allocator fields (zalloc/zfree/opaque) to be zeroed
    # before inflateInit().
    memset(&zstream, 0, sizeof(zstream))
    length = init_decryption(&ctx, key, iv)
    try:
        # One allocation holds both scratch buffers:
        # [zlib output: COMP_BUFFER_SIZE][AES output: CHUNK_SIZE + 32 spare bytes]
        mem_buffers = <unsigned char*>PyMem_Malloc(COMP_BUFFER_SIZE + CHUNK_SIZE + 32)
        if not mem_buffers:
            raise MemoryError()
        zlib_buffer = mem_buffers
        aes_buffer = mem_buffers + COMP_BUFFER_SIZE

        assert length == 0  # not using AAD
        if decompress and inflateInit(&zstream) != Z_OK:
            raise MemoryError()

        # process stream one chunk at a time
        async for data in stream:

            # decrypt data in CHUNK_SIZE bytes chunks
            data_length = len(data)
            c_data = data

            while data_length > 0:
                length = 0
                # decrypt
                if not EVP_DecryptUpdate(
                        ctx, aes_buffer, &length, c_data,
                        CHUNK_SIZE if data_length > CHUNK_SIZE else data_length):
                    raise RuntimeError("Decryption failed: chunk")
                c_data += CHUNK_SIZE
                data_length -= CHUNK_SIZE
                if decompress:
                    decompress_chunk(&zstream, aes_buffer, zlib_buffer, receive_callback, length, COMP_BUFFER_SIZE)
                else:
                    receive_callback(aes_buffer[:length])

        if tag:
            # Pass the real tag length. A hard-coded 16 would over-read shorter
            # tags; OpenSSL itself rejects lengths outside 1..16.
            if not EVP_CIPHER_CTX_ctrl(
                    ctx, EVP_CTRL_GCM_SET_TAG, len(tag), <char*>tag):
                raise RuntimeError("Decryption failed: tag")
        length = 0
        # EVP_DecryptFinal_ex returns 1 on success and 0 when the GCM tag does
        # not verify (or was never set). Anything but a positive result means
        # the data is not authentic.
        ret = EVP_DecryptFinal_ex(ctx, aes_buffer, &length)
        if ret <= 0:
            raise RuntimeError("Decryption failed: final")

        if length > 0:
            if decompress:
                # decompress final chunk
                decompress_chunk(
                    &zstream, aes_buffer, zlib_buffer, receive_callback, length, COMP_BUFFER_SIZE)
            else:
                receive_callback(aes_buffer[:length])

    finally:
        if decompress:
            inflateEnd(&zstream)
        EVP_CIPHER_CTX_free(ctx)
        PyMem_Free(mem_buffers)


In [None]:
# generate encrypted data
import os
import zlib
import hashlib
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from cryptography.hazmat.backends import default_backend
backend = default_backend()

# Random 32-byte AES key and 16-byte GCM IV, shared by both test ciphertexts below.
key = os.urandom(32)
iv = os.urandom(16)
cipher = Cipher(algorithms.AES(key), modes.GCM(iv), backend=backend)

# Test payload: 56789 blocks, each a random 17-byte pattern repeated 11 times
# (187 * 56789 = 10,619,543 bytes), which gives zlib redundancy to exploit.
message = b''.join(os.urandom(17) * 11 for _ in range(56789))

# NOTE(review): both ciphertexts below use the same key AND the same IV.
# GCM nonce reuse leaks the XOR of the plaintexts and enables tag forgery.
# That is tolerable only because this is throw-away benchmark data; never copy
# the pattern. Fixing it needs a second IV for the e_* data, which the later
# decryption cells would then have to pass instead of `iv`.

# ct / c_enc_tag: ciphertext and tag of the zlib-compressed (level 9) message.
encryptor = cipher.encryptor()
ct = encryptor.update(zlib.compress(message, 9)) + encryptor.finalize()
c_enc_tag = encryptor.tag

# et / e_enc_tag: ciphertext and tag of the raw (uncompressed) message.
encryptor = cipher.encryptor()
et = encryptor.update(message) + encryptor.finalize()
e_enc_tag = encryptor.tag

print("Input message length:", len(message))
print("Input message MD5sum:", hashlib.md5(message).hexdigest())
print("Crypto message length:", len(et))
print("Compressed crypto message length:", len(ct))

In [None]:
# statistics
import hashlib
from collections import Counter

def chunk_counter():
    """Create a ``(count_chunk, report)`` pair that collects output-chunk statistics.

    ``count_chunk(chunk)`` is meant to be used as a receive callback. It
    records ``len(chunk)`` in a size histogram and feeds the bytes into a
    running MD5 digest.

    ``report()`` prints the following for everything received so far:
    - the number of chunks;
    - the total length;
    - the average chunk size (0 when nothing was received);
    - the six most common chunk sizes;
    - the MD5 hex digest.
    """
    sizes = Counter()
    md5 = hashlib.md5()
    update_hash = md5.update  # bound once so each callback skips the attribute lookup

    def count_chunk(chunk):
        sizes[len(chunk)] += 1
        update_hash(chunk)

    def report():
        size_sum = sum(size * count for size, count in sizes.items())
        chunk_count = sum(sizes.values())
        print("Output message chunks:", chunk_count)
        print("Output message length:", size_sum)
        # Guard against ZeroDivisionError when no chunk was received.
        print("Average size per chunk:", size_sum // chunk_count if chunk_count else 0)
        print("Chunk size distribution:", ', '.join(
            '%d x %d B' % (count, size) for size, count in sizes.most_common(6)))
        print("Output message MD5sum:", md5.hexdigest())

    return count_chunk, report

In [None]:
def as_chunks(data, chunk_size=1024):
    """Lazily yield consecutive slices of *data*, each *chunk_size* long.

    The final slice may be shorter. Nothing is computed until the first item
    is requested.
    """
    offsets = range(0, len(data), chunk_size)
    yield from (data[start:start + chunk_size] for start in offsets)

class DataIter:
    """Expose an ordinary iterable of chunks as an async iterator.

    Stands in for a network stream when feeding ``py_unpack`` / ``cy_unpack``.
    """

    def __init__(self, chunks):
        self._data = iter(chunks)

    def __aiter__(self):
        # A plain method returning the iterator itself: required since
        # Python 3.5.2 / Cython 0.24.1 (in Python 3.5.0 it was "async def").
        return self

    async def __anext__(self):
        for chunk in self._data:
            return chunk
        raise StopAsyncIteration

In [None]:
# Pre-split both ciphertexts into fixed-size pieces, imitating network reads.
c_chunks = list(as_chunks(ct, chunk_size=512))
e_chunks = list(as_chunks(et, chunk_size=1024))
print("Compressed input chunks:", len(c_chunks))
print("Uncompressed input chunks:", len(e_chunks))

In [None]:
# Cython implementation: decrypt + decompress, then summarise the output.
count_chunk, report = chunk_counter()
print("Compressed input length:", len(ct))
run_async(cy_unpack(DataIter(c_chunks), count_chunk, key, iv, c_enc_tag, decompress=True))
report()

In [None]:
# Pure-Python implementation: decrypt + decompress, then summarise the output.
# The output length and MD5 should match the Cython run.
print("Compressed input length:", len(ct))
# Bind a fresh `report`; reusing the one from an earlier cell would print that
# cell's statistics instead of this run's.
count_chunk, report = chunk_counter()
run_async(py_unpack(DataIter(c_chunks), count_chunk, key, iv, c_enc_tag, decompress=True))
report()

In [None]:
# Cython implementation, decryption only. `et` is the encryption of the raw
# (uncompressed) message, so the input is labelled accordingly.
print("Uncompressed input length:", len(et))
count_chunk, report = chunk_counter()
run_async(cy_unpack(DataIter(e_chunks), count_chunk, key, iv, e_enc_tag, decompress=False))
report()

In [None]:
# Pure-Python implementation, decryption only. `et` is the encryption of the
# raw (uncompressed) message, so the input is labelled accordingly.
print("Uncompressed input length:", len(et))
count_chunk, report = chunk_counter()
run_async(py_unpack(DataIter(e_chunks), count_chunk, key, iv, e_enc_tag, decompress=False))
report()

In [None]:
# Benchmark: pure-Python decryption only. `len` serves as a near-free callback,
# so the timing mostly reflects the unpacking itself.
%timeit run_async(py_unpack(DataIter(e_chunks), len, key, iv, e_enc_tag, decompress=False))

In [None]:
# Benchmark: Cython decryption only, with the same input and `len` callback as
# the py_unpack timing.
%timeit run_async(cy_unpack(DataIter(e_chunks), len, key, iv, e_enc_tag, decompress=False))

In [None]:
# Presumably the cy/py time ratio for decryption only, with both timings typed
# in by hand from an earlier %timeit run. The numbers are hard-coded, so they
# will not update when the benchmarks are re-run.
19.6 / 66.5

In [None]:
# Benchmark: pure-Python decryption + decompression of the compressed ciphertext.
%timeit run_async(py_unpack(DataIter(c_chunks), len, key, iv, c_enc_tag, decompress=True))

In [None]:
# Benchmark: Cython decryption + decompression of the compressed ciphertext.
%timeit run_async(cy_unpack(DataIter(c_chunks), len, key, iv, c_enc_tag, decompress=True))

In [None]:
# Presumably the cy/py time ratio for decryption + decompression, with both
# timings typed in by hand from an earlier %timeit run. The numbers are
# hard-coded and will not track re-runs.
16.2/28.1

In [None]:
%%prun -s time
# Profile one Cython decryption-only run, sorted by internal time. The Cython
# cell is compiled with `profile=True`, so its functions should appear here.
run_async(cy_unpack(DataIter(e_chunks), len, key, iv, e_enc_tag, decompress=False))

In [None]:
%%prun -s time
# Profile one pure-Python decryption-only run, sorted by internal time.
run_async(py_unpack(DataIter(e_chunks), len, key, iv, e_enc_tag, decompress=False))