In [11]:
from cryptography.hazmat.primitives.ciphers import (
        Cipher, algorithms, modes
    )
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.ciphers.aead import AESGCM
from cryptography.exceptions import InvalidTag
from Crypto.Cipher import AES
from Crypto.Random import get_random_bytes
from Crypto.Util.number import long_to_bytes, bytes_to_long
from bitstring import BitArray, Bits
import binascii
import sys

import time
import base64

# 16-byte all-zero block; encrypting it with a key yields that key's GHASH subkey H.
ALL_ZEROS = bytes(16)
# GCM processes data in 128-bit (16-byte) blocks.
GCM_BITS_PER_BLOCK = 16 * 8


def check_correctness(keyset, nonce, ct):
    """Verify that ``ct`` decrypts and authenticates under every key in ``keyset``.

    Each key builds an ``AESGCM`` object and decrypts ``ct`` with ``nonce`` and
    no associated data. Every failing key is reported by its index; a summary
    line is printed only when all keys succeed.

    Args:
        keyset: sequence of AES keys (bytes).
        nonce: GCM nonce passed to ``AESGCM.decrypt``.
        ct: ciphertext with the tag appended, as ``AESGCM.decrypt`` expects.

    Returns:
        True if every key decrypted ``ct``, False if at least one raised
        ``InvalidTag``. Callers that ignore the result are unaffected.
    """
    all_ok = True
    for idx, key in enumerate(keyset):
        aesgcm = AESGCM(key)
        try:
            aesgcm.decrypt(nonce, ct, None)
        except InvalidTag:
            # Keep going so that every failing key gets reported, not just the first.
            print('key %s failed' % idx)
            all_ok = False

    if all_ok:
        print("All %s keys decrypted the ciphertext" % len(keyset))
    return all_ok



def pad(a):
    """Right-pad a list of '0'/'1' strings with '0' up to GCM_BITS_PER_BLOCK entries.

    A list that already has GCM_BITS_PER_BLOCK or more entries is returned
    unchanged (same object, not truncated); otherwise a new, longer list is
    returned and the input is left untouched.
    """
    missing = GCM_BITS_PER_BLOCK - len(a)
    if missing <= 0:
        return a
    return a + ['0'] * missing



def bytes_to_element(val, field, a):
    """Map a byte string to an element of the GHASH field GF(2^128).

    Bits are read in ``BitArray`` order (most significant bit of the first
    byte first), and bit ``i`` of that sequence becomes the coefficient of
    ``a**i``. So the first bit of ``val`` is the constant term, which is
    GCM's bit-reflected field encoding.

    Args:
        val: bytes to convert. Expected to be at most 16 bytes; shorter input
            behaves as if right-padded with zero bits.
        field: the Sage finite field GF(2^128) used for GHASH.
        a: the generator of ``field``.

    Returns:
        The field element equal to the sum of ``a**i`` over all set bits ``i``.
    """
    # Public API instead of the private ``field._cache.fetch_int(0)``.
    result = field.zero()
    # Running power a**i, so no exponentiation is recomputed per set bit.
    power = field.one()
    for bit in BitArray(val):
        if bit:
            result += power
        power *= a
    return result



def multi_collide_gcm(keyset, nonce, tag, first_block=None, use_magma=False):
    """Build one AES-GCM ciphertext that authenticates under every key in ``keyset``.

    This is SageMath code. It relies on the Sage preparser (``P.<x> = ...``
    generator syntax, and ``^`` as exponentiation), so it must run in a Sage
    kernel.

    For each key K it computes:
      * the GHASH subkey H = E_K(0^128), and
      * the tag mask E_K(J0).
    It then solves for ciphertext blocks so that GHASH plus mask equals
    ``tag`` for all keys at once. Because GHASH is a polynomial in H, this
    reduces to Lagrange interpolation over GF(2^128) through one point
    (H, y) per key. No associated data is used.

    Args:
        keyset: list of AES keys (bytes). Interpolation needs a distinct H
            per key, so keys are presumably expected to be distinct.
            TODO confirm.
        nonce: GCM nonce. The J0 construction below assumes a 12-byte nonce.
        tag: 16-byte tag that every key must accept. It is appended to the
            output.
        first_block: optional fixed first ciphertext block, at most 16 bytes.
            Shorter values are treated as zero-padded on the right. When it
            is given, the output has len(keyset) + 1 ciphertext blocks.
        use_magma: if True, interpolate through Sage's ``magma`` interface
            instead of ``R.lagrange_polynomial``.

    Returns:
        bytes: ciphertext blocks followed by ``tag``, in the layout that
        ``AESGCM.decrypt(nonce, ct, None)`` expects.
    """

    # Build GF(2^128) with the GHASH reduction polynomial x^128 + x^7 + x^2 + x + 1
    # (GFghash, generator `a`), plus a polynomial ring R over it for interpolation.
    P.<x> = PolynomialRing(GF(2))
    p = x^128 + x^7 + x^2 + x + 1
    GFghash.<a> = GF(2^128,'x',modulus=p)
    if use_magma:
        # Define the field on the Magma side as well.
        # NOTE(review): assumes Magma's IrreducibleLowTermGF2Polynomial(128) equals p above. Confirm.
        t = "p:=IrreducibleLowTermGF2Polynomial(128); GFghash<a> := ext<GF(2) | p>;"
        magma.eval(t)
    else:
        R = PolynomialRing(GFghash, 'x')

    # GHASH length block L = bit length of the AAD (upper 64 bits) || bit length of
    # the ciphertext (lower 64 bits). There is no AAD. The ciphertext has one
    # block per key, plus first_block when it is given.
    if first_block is not None:
        ctbitlen = (len(keyset) + 1) * GCM_BITS_PER_BLOCK
    else:
        ctbitlen = len(keyset) * GCM_BITS_PER_BLOCK
    adbitlen = 0
    lens = (adbitlen << 64) | ctbitlen
    lens_byte = int(lens).to_bytes(16,byteorder='big')
    lens_bf = bytes_to_element(lens_byte, GFghash, a)

    # Pre-counter block J0 = nonce || 0^31 || 1 (the 12-byte-nonce construction).
    # Despite the name, it is not incremented: E_K(J0) is the mask that is added
    # to GHASH to form the tag.
    nonce_plus = int((int.from_bytes(nonce,'big') << 32) | 1).to_bytes(16,'big')

    # Encode the fixed first ciphertext block and the target tag as field elements.
    if first_block is not None:
        block_bf = bytes_to_element(first_block, GFghash, a)
    tag_bf = bytes_to_element(tag, GFghash, a)
    keyset_len = len(keyset)

    # Interpolation points: I/V are x/y lists for Magma; pairs are (x, y) tuples for Sage.
    if use_magma:
        I = []
        V = []
    else:
        pairs = []

    for k in keyset:
        # compute H = E_K(0^128), the GHASH subkey for this key
        aes = AES.new(k, AES.MODE_ECB)
        H = aes.encrypt(ALL_ZEROS)
        h_bf = bytes_to_element(H, GFghash, a)

        # compute P = E_K(J0), the tag mask for this key. Rebinding P shadows the
        # GF(2) polynomial ring above, which is no longer used at this point.
        P = aes.encrypt(nonce_plus)
        p_bf = bytes_to_element(P, GFghash, a)

        # b collects every known term of the tag equation. Addition in GF(2^128)
        # is XOR, so moving terms across the equation needs no sign change.
        if first_block is not None:
            # b = L*H + P + T + C1*H^(keyset_len+2); C1 (first_block) carries the highest power
            b = (lens_bf * h_bf) + p_bf + tag_bf + (block_bf * h_bf^(keyset_len+2))
        else:
            # b = L*H + P + T
            b = (lens_bf * h_bf) + p_bf + tag_bf

        # The unknown blocks contribute sum_i c_i * H^(i+2) = b. Dividing by H^2
        # gives the interpolation point (H, y) with y = b * H^-2.
        y =  b * h_bf^-2
        if use_magma:
            I.append(h_bf)
            V.append(y)
        else:
            pairs.append((h_bf, y))

    # Interpolate f with f(H) = y for every key. The coefficient c_i of x^i is the
    # ciphertext block that gets multiplied by H^(i+2). f.list() is lowest degree
    # first, so reversing puts the coefficients in ciphertext-block order.
    # NOTE(review): if f's leading coefficient happens to be zero, f.list() is
    # shorter than keyset_len and the output will not match ctbitlen.
    if use_magma:
        f = magma("Interpolation(%s,%s)" % (I,V)).sage()
    else:
        f = R.lagrange_polynomial(pairs)
    coeffs = f.list()
    coeffs.reverse()

    # Serialize field elements back to bytes. .polynomial().list() yields the a^0
    # bit first (the inverse of bytes_to_element), and pad() restores trailing
    # zero bits up to 128.
    if first_block is not None:
        ct = list(map(str, block_bf.polynomial().list()))
        ct_pad = pad(ct)
        ct = Bits(bin=''.join(ct_pad))
    else:
        # Relies on bitstring accepting '' + Bits.
        # NOTE(review): with no coeffs (empty keyset) ct stays '' and ct.bytes fails.
        ct = ''
    
    for i in range(len(coeffs)):
        ct_i = list(map(str, coeffs[i].polynomial().list()))
        ct_pad = pad(ct_i)
        ct_i = Bits(bin=''.join(ct_pad))
        ct += ct_i
    ct = ct.bytes
    
    # GCM wire format: ciphertext || tag
    return ct+tag



if __name__ == '__main__':
    # Build one multi-key-collision ciphertext per chunk of candidate keys,
    # verify each one, and save the base64 ciphertexts (one per line) to ct.txt.
    print("Search starting")
    startTime = time.time()

    # Candidate keys: every 3-letter uppercase string, NUL-padded to 16 bytes (26**3 = 17576 keys).
    UPPERCASE = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
    keyLst = [
        (a + b + c).encode('utf-8').ljust(16, b'\0')
        for a in UPPERCASE
        for b in UPPERCASE
        for c in UPPERCASE
    ]

    # 17576 keys -> 16 chunks (the last has 1076 keys). Binary search within a
    # chunk needs ceil(log2(1100)) = 11 guesses, under the 30-guess cap.
    chunkSize = 1100
    keysets = [keyLst[i:i+chunkSize] for i in range(0, len(keyLst), chunkSize)]

    first_block = b'\x01'
    nonce = b'\x01'*12
    tag = b'\x00'*16

    ctLst = []
    for i, keyset in enumerate(keysets):
        blockTime = time.time()
        ct = multi_collide_gcm(keyset, nonce, tag, first_block=first_block)
        check_correctness(keyset, nonce, ct)
        ct = base64.b64encode(ct)
        ctLst.append(ct)
        print(f"Block {i} found in {time.time() - blockTime} seconds")

    # Decode the base64 bytes so the file holds plain base64 text, not "b'...'"
    # reprs. The file is opened only after the long search, inside `with`, so it
    # is always closed and is not truncated if the search fails part-way.
    with open('ct.txt', 'w') as f:
        f.write('\n'.join(ct.decode('ascii') for ct in ctLst))

    print(f"Code executed in {time.time() - startTime} seconds")

Search starting
All 1100 keys decrypted the ciphertext
Block 0 found in 49.691373348236084 seconds
All 1100 keys decrypted the ciphertext
Block 1 found in 47.659080505371094 seconds
All 1100 keys decrypted the ciphertext
Block 2 found in 48.33413743972778 seconds
All 1100 keys decrypted the ciphertext
Block 3 found in 49.359739542007446 seconds
All 1100 keys decrypted the ciphertext
Block 4 found in 44.085745334625244 seconds
All 1100 keys decrypted the ciphertext
Block 5 found in 48.49023175239563 seconds
All 1100 keys decrypted the ciphertext
Block 6 found in 46.744627475738525 seconds
All 1100 keys decrypted the ciphertext
Block 7 found in 48.1559100151062 seconds
All 1100 keys decrypted the ciphertext
Block 8 found in 45.3368866443634 seconds
All 1100 keys decrypted the ciphertext
Block 9 found in 44.80050539970398 seconds
All 1100 keys decrypted the ciphertext
Block 10 found in 46.19463062286377 seconds
All 1100 keys decrypted the ciphertext
Block 11 found in 48.20659422874451 sec

In [13]:
# Get the strings to copy in: one "<index> <nonce_b64>,<ct_b64>" entry per chunk,
# each followed by blank-line padding (10 newlines in total).
# Double quotes outside / single quotes inside keep the f-string valid before Python 3.12.
nonce_b64 = base64.b64encode(nonce).decode('utf-8')
finalStr = ''.join(
    f"{i} {nonce_b64},{ct.decode('utf-8')}" + '\n' * 10
    for i, ct in enumerate(ctLst)
)

with open('ctFormat.txt', 'w') as f:
    f.write(finalStr)

In [16]:
# binary search: first step
# Split the chunk whose ciphertext was accepted into two halves and build one
# multi-collision ciphertext per half.
n = 4 # replace with the keyset for which the ct decrypted correctly under
newKeyset = keysets[n]

# Split into half and generate ct that works for either block
half = len(newKeyset) // 2
newKeysets = [newKeyset[:half], newKeyset[half:]]

# Loop-invariant; double quotes around the f-string keep it valid before Python 3.12.
nonce_b64 = base64.b64encode(nonce).decode('utf-8')
print("Search starting")
for i, keyset in enumerate(newKeysets):
    ct = multi_collide_gcm(keyset, nonce, tag, first_block=first_block)
    check_correctness(keyset, nonce, ct)
    ct = base64.b64encode(ct)
    print(f"{i} {nonce_b64},{ct.decode('utf-8')}" + '\n' * 10)
    

Search starting
All 550 keys decrypted the ciphertext
0 AQEBAQEBAQEBAQEB,AQAAAAAAAAAAAAAAAAAAAAG/rRp0iiQHdKMg7dOzJn+L+2jvVv8jJ0ox3Q06px3ga3Hhcuj9cYt4LOSFQTBh5fx7fFwfed5sQFxlhtJFkvAph6UUhZ9HXVkWjRamYGksL89hyoMek348fdLSzaEbqtndNK2soBzd3cXVKDUwRLgJKqmdg/1Z3iLMYWoUMIX01jZTKyGRbB5Y11Tx9AtPoijftLePHzjYHbriE4d6ui838o+Edwvq8MAobfxf1bZWQp5b1QxbPzjQPeP5O5jYOKH8ZDpwqFrEJWKJHy3cW4YJctiQ/3r+FexLHy/g+6XZh51eGEauXcFJqPVMf0HQm9GBb8NBSCzGFBk+GaJLSrvkSY2CHX5vSxQtBftd2OEIhI11kO31ONvlhSUJC5RlHPUKzydHs+/aWeZHFF8f8wBINS5tjx9uB2Jdi9zzKuxB5y5OuzNpYUMJoiE61mwo1zJik4zJ2VO+wT69S14EwpUeEcSXI7poobRGyN/QGbe2l23AyWQi6MrT6dj0zpD3DGT2f42AwIWsl7QQ2kotLczs7EkpQcqcylzt1aDLgqAlPxLznBR6Y3oj3f9ymiv0qZFavKvwSi27uWUBac7yEEMC7rno3VfhlE3qrZnUHv/6nV5QT0h3mPSCVxoFgDCmz2UN+W/9yLk9WShHUD8wav+FNbEdrnGmgSFwfwa0M/TKI8cpbDQXNmrElXr2eEC797tZCoR+0SEm1TAVPpY4BcNuGSKBt+jVasiFGFf+zwH8HJ0s2nF+TNGdzOxZp/zNV1rgLCe3zmUxOr4ZXdlOzmhMNQ8f981k3HkMC+6h/jyx9gsI96Ka+z4h+OvvNkoOjMHEbSaWBVfoDQqxHK6r/mThgDvNw0fCvtamOb+zKsQsJyBOehXT6FE2MjCV0m4AfgIjzaS9a1z

In [25]:
# binary search (with new keyset)
# Each run performs one more halving. It reads `newKeysets` from the previous
# step and then overwrites it, so set `n` before every run; re-running without
# changing `n` keeps narrowing.
n = 1 # replace with the index (0 or 1) of the half whose ct decrypted correctly
newKeyset = newKeysets[n]

# Split into half and generate ct that works for either block
half = len(newKeyset) // 2
newKeysets = [newKeyset[:half], newKeyset[half:]]
# Show the candidate keys only once a half is down to a single key (the last
# guess), so earlier steps don't spam large key lists.
if any(len(part) == 1 for part in newKeysets):
    print(newKeysets)

# Loop-invariant; double quotes around the f-string keep it valid before Python 3.12.
nonce_b64 = base64.b64encode(nonce).decode('utf-8')
print("Search starting")
for i, keyset in enumerate(newKeysets):
    ct = multi_collide_gcm(keyset, nonce, tag, first_block=first_block)
    check_correctness(keyset, nonce, ct)
    ct = base64.b64encode(ct)
    print(f"{i} {nonce_b64},{ct.decode('utf-8')}" + '\n' * 10)
    

[[b'HCD\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00'], [b'HCE\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00']]
Search starting
All 1 keys decrypted the ciphertext
0 AQEBAQEBAQEBAQEB,AQAAAAAAAAAAAAAAAAAAAGHKK/B09XQvK209XaTWSz0AAAAAAAAAAAAAAAAAAAAA










All 1 keys decrypted the ciphertext
1 AQEBAQEBAQEBAQEB,AQAAAAAAAAAAAAAAAAAAADBN5hYoe5chgjucAiMYKUcAAAAAAAAAAAAAAAAAAAAA












In [2]:
pip install cryptography
pip install pycryptodome
pip install bitstring
pip install base64
pip install timeit

Collecting pycryptodome
  Downloading pycryptodome-3.22.0-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (3.4 kB)
Downloading pycryptodome-3.22.0-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (2.3 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.3/2.3 MB[0m [31m3.7 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
[?25hInstalling collected packages: pycryptodome
Successfully installed pycryptodome-3.22.0

[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m24.2[0m[39;49m -> [0m[32;49m25.1.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m
Note: you may need to restart the kernel to use updated packages.
