In [1]:
from Crypto.Cipher import AES
import os
from hashlib import sha256
import hmac
from tqdm import tqdm

In [2]:
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from cryptography.hazmat.backends import default_backend

In [10]:
class AuthenticEncryptor:
    """Encrypt-then-MAC authenticated encryption with AES-CTR and HMAC-SHA256.

    Wire format produced by encryption and expected by decryption::

        IV (16 bytes) || AES-CTR ciphertext || HMAC-SHA256(k_i, IV || ciphertext) (32 bytes)

    Parameters
    ----------
    is_encrypt : bool
        ``True`` to build an encryptor, ``False`` to build a decryptor.
    k_e : bytes
        AES key used for CTR-mode encryption.
    k_i : bytes
        HMAC-SHA256 key used for the integrity tag.
    chunk_size : int, optional
        Number of bytes handed to ``add_block`` per call. CTR is a stream
        mode and HMAC is incremental, so the output does not depend on this
        value; larger chunks only cut Python-level loop overhead.
    """

    def __init__(self, is_encrypt, k_e, k_i, chunk_size=64 * 1024):
        if chunk_size <= 0:
            raise ValueError('chunk_size must be positive')
        self.is_encrypt = is_encrypt
        self.k_e = k_e
        self.k_i = k_i
        self.aes_block_size = 16   # AES block size; also the CTR IV length
        self.tag_size = 32         # HMAC-SHA256 digest length
        self.chunk_size = chunk_size
        self.backend = default_backend()

    def add_block(self, block, is_final=False):
        """Process one chunk of the stream set up by ``process_data``.

        Encrypting: returns the ciphertext for ``block``. The ciphertext is
        also fed into the running HMAC, and on the final call the tag is
        appended to the returned bytes.
        Decrypting: returns the plaintext for ``block`` (the MAC has already
        been verified by ``process_data``).

        Requires ``self.aes_ctr`` (and ``self.tag`` when encrypting) to have
        been initialised by ``process_data``.
        """
        if self.is_encrypt:
            ct = self.aes_ctr.update(block)
            if is_final:
                ct += self.aes_ctr.finalize()
            # Encrypt-then-MAC: authenticate every ciphertext byte that goes out.
            self.tag.update(ct)
            if is_final:
                ct += self.tag.digest()
            return ct

        pt = self.aes_ctr.update(block)
        if is_final:
            pt += self.aes_ctr.finalize()
        return pt

    def process_data(self, data):
        """Encrypt or decrypt ``data`` in one call, depending on ``is_encrypt``.

        Returns
        -------
        bytes
            ``IV || ciphertext || tag`` when encrypting (fresh random IV).
        bytearray
            The recovered plaintext when decrypting.

        Raises
        ------
        ValueError
            When decrypting input too short to hold an IV and a tag, or whose
            HMAC tag does not verify.
        """
        if self.is_encrypt:
            return self._encrypt(data)
        return self._decrypt(data)

    def _feed(self, payload):
        """Run ``payload`` through ``add_block`` chunk by chunk; return the outputs."""
        n = len(payload)
        step = self.chunk_size
        # max(n, 1): an empty payload still needs one is_final call so the
        # cipher is finalized and, when encrypting, the tag is emitted.
        return [self.add_block(payload[i:i + step], i + step >= n)
                for i in range(0, max(n, 1), step)]

    def _encrypt(self, data):
        """Return ``IV || AES-CTR(data) || HMAC tag`` under a fresh random IV."""
        iv = os.urandom(self.aes_block_size)
        self.tag = hmac.new(self.k_i, iv, digestmod=sha256)
        self.aes_ctr = Cipher(algorithms.AES(self.k_e), modes.CTR(iv),
                              backend=self.backend).encryptor()
        # Join once at the end: repeated ``bytes +=`` copies the whole buffer
        # every time and is quadratic in the message length.
        self.ciphertext = b''.join([iv, *self._feed(data)])
        return self.ciphertext

    def _decrypt(self, data):
        """Verify the HMAC tag on ``data`` and return the decrypted payload."""
        if len(data) < self.aes_block_size + self.tag_size:
            raise ValueError('Ciphertext too short')
        tag = data[-self.tag_size:]
        authenticated = data[:-self.tag_size]   # IV || ciphertext
        expected = hmac.new(self.k_i, authenticated, digestmod=sha256).digest()
        # Constant-time comparison: ``!=`` can leak how many leading bytes match.
        if not hmac.compare_digest(tag, expected):
            raise ValueError('Unsuccessful MAC check')

        iv = authenticated[:self.aes_block_size]
        self.aes_ctr = Cipher(algorithms.AES(self.k_e), modes.CTR(iv),
                              backend=self.backend).decryptor()
        self.plaintext = bytearray().join(self._feed(authenticated[self.aes_block_size:]))
        return self.plaintext

In [11]:
# Two independent random keys: a 128-bit AES key (encryption) and a
# 256-bit key for the HMAC-SHA256 integrity tag.
k_e, k_i = os.urandom(16), os.urandom(32)

In [12]:
aut_e = AuthenticEncryptor(is_encrypt=True, k_e=k_e, k_i=k_i)  # encrypting side

In [13]:
# Random test payload: 100 MB (10**8 bytes, decimal megabytes).
data = os.urandom(100 * 1000 * 1000)

In [14]:
cipher = aut_e.process_data(data=data)  # IV || ciphertext || HMAC tag

In [15]:
aut_d = AuthenticEncryptor(is_encrypt=False, k_e=k_e, k_i=k_i)  # decrypting side, same keys

In [16]:
# Round trip must reproduce the plaintext (bytes vs bytearray compares by content).
assert data == aut_d.process_data(cipher)