In [297]:
from Crypto.Cipher import AES
from Crypto.Util.Padding import pad, unpad
from Crypto.Random import get_random_bytes
import base64
import random

In [298]:
def key_generator_AES(kp):
    """
    Build a 16-byte AES key with only kp significant bytes.

    The first kp bytes are random; the remaining 16 - kp bytes are zeros.
    """
    significant = get_random_bytes(kp)
    zero_fill = bytes(16 - kp)
    return significant + zero_fill

def bytes_to_bin(by):
    """
    Convert bytes to a '0b'-prefixed binary string.

    Every byte is written on exactly 8 bits: when its value needs fewer
    bits, it is left-padded with zeros.
    """
    return '0b' + ''.join(bin(octet)[2:].zfill(8) for octet in by)

In [225]:
def double_AES(msg, key1, key2,  mode = 'MODE_ECB'):
    """
    Double AES encryption: encrypt msg with key1, then encrypt that
    ciphertext with key2 (both via simple_enc_AES).

    mode : currently ignored (not forwarded to simple_enc_AES)
    """
    return simple_enc_AES(simple_enc_AES(msg, key1), key2)

    
def simple_enc_AES(msg, key, mode = 'MODE_ECB'):
    """
    AES encryption (ECB mode) of msg with key, using PKCS#7 padding.

    msg : bytes, or str (encoded to bytes with str.encode(), i.e. UTF-8)
    key : AES key
    mode : currently ignored, ECB is always used
    Returns the ciphertext (bytes), whose length is a multiple of the AES
    block size.
    """
    if not isinstance(msg, bytes):
        msg = msg.encode()

    cipher = AES.new(key, AES.MODE_ECB)
    # Pad to the AES block size (16 bytes), not to the key length: with a
    # 24-byte key, padding to 24 bytes would give data ECB cannot encrypt.
    return cipher.encrypt(pad(msg, AES.block_size))

def simple_dec_AES(msg, key, mode='MODE_ECB'):
    """
    AES decryption (ECB mode) of msg with key.

    mode : currently ignored, ECB is always used
    The padding is deliberately NOT removed: with an incorrect key, unpad
    would raise a padding error, and the attack only needs the raw
    decrypted bytes.
    """
    # Removing the padding here would be: unpad(plaintext, len(key))
    return AES.new(key, AES.MODE_ECB).decrypt(msg)

In [299]:
# Secret keys with kp significant bytes, and two plaintext/ciphertext pairs
# produced by double encryption with the same key pair.
kp = 4
M1 = "Voici le message 1"
M2 = "Voici le message 2"
key1, key2 = key_generator_AES(kp), key_generator_AES(kp)
C1, C2 = double_AES(M1, key1, key2), double_AES(M2, key1, key2)

In [300]:
def trail(f, msg, kp, l, max_it=None):
    """
    Walk one chain from a random key until a distinguished point is reached.

    Start from a fresh key x0 (kp significant bytes, zero-filled to 16
    bytes). Repeatedly apply  x -> f(msg, x)[:kp] + zeros  until the kp
    significant bytes, read as a big-endian integer, have their l low bits
    all equal to 0.

    f : encryption OR decryption function, called as f(msg, key)
    msg : plaintext OR double ciphertext
    kp : number of significant bytes of the key
    l : number of low bits that must be 0 (stop condition)
    max_it : step budget before the chain is abandoned. Defaults to
             (20 * 2**l) // 3, i.e. 20 / (1 / 2**l) divided by 3 to keep the
             run time reasonable.

    Returns (x0, xd, d), where xd is the distinguished point reached after
    d steps, or None if the budget ran out (possible cycle).
    """
    if max_it is None:
        max_it = (20 * 2**l) // 3
    mask_l = (1 << l) - 1
    padding = bytes(16 - kp)

    x0 = key_generator_AES(kp)
    current = x0
    for d in range(max_it):
        if int.from_bytes(current[:kp], 'big') & mask_l == 0:  # stop condition
            return (x0, current, d)
        # Only compute the next key when the current one is not
        # distinguished: no AES call is wasted on the last point.
        current = f(msg, current)[:kp] + padding
    return None


In [301]:
# One trail on the encryption side, 4 significant key bytes, 5 zero bits.
kp, l = 4, 5
print(trail(simple_enc_AES, M1, kp, l))

(b'\x85\xc6,\xf4\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00', b'2[\xc1 \x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00', 1)


In [302]:
def F(b):
    """
    Map a bit to the step function of the corresponding side.

    b : 0 or 1
    Returns simple_enc_AES for b == 0, simple_dec_AES for b == 1, and None
    for any other value. The random choice of b is made by the caller.
    """
    if b == 0:
        return simple_enc_AES
    return simple_dec_AES if b == 1 else None

In [303]:
def new_step(f, M, kp, x):
    """
    Go from x_i to x_{i+1}.

    Apply f to (M, x), keep the first kp bytes of the result and zero-fill
    back to 16 bytes.
    """
    significant = f(M, x)[:kp]
    return significant + bytes(16 - kp)

In [272]:
def remonter(F, A, B, M, C, kp, b):
    """
    Walk two trails that end on the same distinguished point back to the
    point where they merge, and return the collision.

    F : maps a bit (0 or 1) to the step function
    A, B : triplets (x0, xd, d) with the same xd; A was built with F(b) on
           couple[b], B with F(1-b) on couple[1-b], where couple = [M, C]
    kp : number of significant key bytes
    b : 0 or 1

    Returns ((x, F(b)), (y, F(1-b))) with x != y such that one step from x
    (with F(b)) and one step from y (with F(1-b)) give the same key, i.e.
    the two function outputs agree on their first kp bytes.
    Returns None (after printing a message) when, once aligned, both
    walkers are already on the same key: one trail lies on the other and
    there is no collision.
    """
    couple = [M, C]
    f_a, msg_a = F(b), couple[b]
    f_b, msg_b = F(1 - b), couple[1 - b]
    x, y = A[0], B[0]

    # Align both walkers so that each is the same number of steps away from
    # xd. At most one of these two loops runs (range() of a negative is empty).
    for _ in range(A[2] - B[2]):
        x = new_step(f_a, msg_a, kp, x)
    for _ in range(B[2] - A[2]):
        y = new_step(f_b, msg_b, kp, y)

    if x == y:
        print('pb : x==y et fhash(x)==fhash(y)')
        return None

    # Step both walkers together, remembering the previous keys: when they
    # first coincide, the previous keys form the collision.
    while x != y:
        prev_x, prev_y = x, y
        x = new_step(f_a, msg_a, kp, prev_x)
        y = new_step(f_b, msg_b, kp, prev_y)

    return ((prev_x, F(b)), (prev_y, F(1 - b)))

In [273]:
def collision_detection(F, M, C, kp, l):
    """
    Detect a single collision between the encryption and decryption sides.

    Each iteration picks a random side b (0: F(0) applied to M, 1: F(1)
    applied to C) and runs one trail. Distinguished points are stored in a
    dictionary keyed by (xd, b). When a trail ends on a distinguished point
    already reached from the opposite side, both trails are walked back
    with remonter.

    Returns the collision produced by remonter: ((x, F(b)), (y, F(1-b))).
    If remonter finds no collision (one trail lies on the other), the search
    goes on with the table kept instead of discarding the work done.
    """
    dico = {}
    couple = [M, C]
    while True:
        b = random.randint(0, 1)
        res = trail(F(b), couple[b], kp, l)
        if res is None:  # trail abandoned: step budget exhausted
            continue

        x0, xd, d = res
        other = dico.get((xd, 1 - b))
        if other is not None:
            A = (x0, xd, d)               # side b
            B = (other[0], xd, other[1])  # side 1-b
            collision = remonter(F, A, B, M, C, kp, b)
            if collision is not None:
                return collision
        dico[(xd, b)] = (x0, d)


In [274]:
collision_detection(F, M1, C1, kp=kp, l=l)

((b'S$\x0c\x12\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00',
  <function __main__.simple_enc_AES(msg, key, mode='MODE_ECB')>),
 (b'RHg\xe3\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00',
  <function __main__.simple_dec_AES(msg, key, mode='MODE_ECB')>))

In [275]:
def collision_detection_multiple(F,M, C, kp, l, nb_col):
    """
    Detect several distinct collisions.

    nb_col : number of distinct collisions wanted
    Returns a list of nb_col collisions as produced by collision_detection.
    A collision and its mirrored form ((y, g), (x, f)) count as the same
    one, so only the first of them is kept.
    """
    liste = []
    seen = set()  # every collision found, in both orientations
    while len(liste) < nb_col:
        col = collision_detection(F, M, C, kp, l)
        if col is None or col in seen:  # nothing found, or already found
            continue
        liste.append(col)
        seen.add(col)
        seen.add((col[1], col[0]))
    return liste

In [277]:
import time

# Time the search for `col` distinct collisions. perf_counter is monotonic,
# and the timer stops before printing so only the search is measured.
l = 5
col = 3
t1 = time.perf_counter()
collisions = collision_detection_multiple(F, M1, C1, kp, l, col)
t2 = time.perf_counter()
print(collisions)
print(t2 - t1)

25.613345861434937


In [294]:
def verification(M,C, list_col_keys):
    """
    Find the golden collision among candidate key pairs.

    M, C : a plaintext and its double encryption
    list_col_keys : collisions ((key, function), (key, function)); the
                    function whose __name__ is "simple_enc_AES" is the
                    encryption side, the other one the decryption side

    For each candidate, encrypt M with the encryption-side key and decrypt C
    with the decryption-side key, then remove the padding (unpad, block size
    16). The candidate is the golden collision when both results are equal.

    Returns True as soon as a golden collision is found, False otherwise.
    Candidates whose decryption has invalid padding (ValueError) are
    skipped; the other non-matching candidates are printed.
    """
    for (key_a, func_a), (key_b, func_b) in list_col_keys:
        try:
            if func_a.__name__ == "simple_enc_AES":
                tmp1 = func_a(M, key_a)
                tmp2 = unpad(func_b(C, key_b), 16)
            else:
                tmp1 = unpad(func_a(C, key_a), 16)
                tmp2 = func_b(M, key_b)
        except ValueError:
            # wrong key: the decrypted data has invalid padding
            continue

        if tmp1 == tmp2:
            return True
        # tmp1 and tmp2 are bytes: print them as separate values instead of
        # concatenating them with str (which raises TypeError).
        print(tmp1, tmp2, sep='\n', end='\n\n\n')
    return False
    

In [295]:
verification(M=M2, C=C2, list_col_keys=collisions)