# DST40 Brute Force, GPU Version

Author: [Kaci Amaouche](mailto:amaouchekaci28@gmail.com)

In this [Jupyter](https://jupyter.org/) notebook, we present a GPU implementation of a brute-force attack on DST40 using Numba's CUDA support. If you are not familiar with Jupyter, you can take a quick look at the [Notebook Basics](https://jupyter-notebook.readthedocs.io/en/stable/examples/Notebook/Notebook%20Basics.html) guide (~5 min).

This notebook contains:

* Environment setup
* Implementation

## 1 - Environment Setup

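To run these scripts, you will need:
1. This notebook (the `.ipynb` file)
1. Python >= 3.8
1. NumPy and Numba (used for the CUDA kernels)
1. A CUDA-capable NVIDIA GPU (the timings below were measured on Google Colab)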
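
Before running the notebook, it can help to check these prerequisites. The code cells below import NumPy and Numba and launch CUDA kernels, so a CUDA-capable GPU is required. A minimal sketch of such a check follows; the `check_environment` helper is hypothetical and not part of the original notebook. It only relies on `importlib.util.find_spec` and `numba.cuda.is_available()`.

```python
import importlib.util
import sys


def check_environment():
    """Hypothetical helper: report which prerequisites of this notebook are available."""
    report = {
        "python>=3.8": sys.version_info >= (3, 8),
        "numpy": importlib.util.find_spec("numpy") is not None,
        "numba": importlib.util.find_spec("numba") is not None,
        "cuda_gpu": False,
    }
    if report["numba"]:
        try:
            # True only if Numba finds a usable CUDA driver and GPU
            from numba import cuda
            report["cuda_gpu"] = bool(cuda.is_available())
        except Exception:  # e.g. Numba installed but incompatible with the installed NumPy
            pass
    return report


if __name__ == "__main__":
    for name, ok in check_environment().items():
        print(f"{name:12s} {'OK' if ok else 'MISSING'}")
```

If `cuda_gpu` is reported as missing (for example on a Colab runtime without a GPU accelerator), the kernel launches further down will fail.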


## 2 - Implementation

We start by importing the required libraries.

In [None]:
import numpy as np
from numba import cuda
import math
from numba import njit, prange, config
from typing import Callable, List

Next, we define the encryption algorithm. It is almost identical to the CPU version, except that a few manipulations are needed to obtain equivalent code, because Numba's CUDA target does not support Python lists: the key and challenge bits are extracted into individual variables, and the lookup tables are written as chains of comparisons.
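
The if-chains in `dst40_round1` are truth tables written without lists. A more compact list-free alternative, sketched below as our own suggestion rather than the author's implementation, packs each truth table into a single integer and reads entry `x` with a shift and a mask. The tuples are copied from the if-chains and are only used on the host to build the constants; the names `pack_bits`, `pack_2bit`, `lookup1`, `lookup2` and `FA` to `FH` are hypothetical.

```python
# Truth tables copied from the if-chains in dst40_round1:
# each tuple lists the inputs for which the filter function outputs 1.
FA_ONES = (1, 3, 4, 6, 7, 8, 9, 12, 18, 19, 23, 24, 26, 28, 29, 31)     # 5-bit input
FB_ONES = (1, 2, 9, 10, 12, 13, 14, 15, 16, 17, 18, 19, 21, 22, 29, 30)  # 5-bit input
FC_ONES = (2, 4, 5, 6, 8, 10, 11, 12, 18, 19, 20, 21, 25, 27, 29, 31)    # 5-bit input
FD_ONES = (1, 3, 4, 5, 10, 11, 12, 14, 18, 19, 20, 22, 25, 27, 28, 29)   # 5-bit input
FE_ONES = (1, 3, 6, 7, 8, 9, 12, 14)  # 4-bit input
FG_ONES = (1, 2, 3, 6, 9, 12, 13, 14)  # 4-bit input
# fh maps a 4-bit input to a 2-bit output (index = input value).
FH_OUT = (0, 0, 2, 3, 3, 1, 2, 1, 1, 2, 1, 3, 3, 2, 0, 0)


def pack_bits(ones):
    """Pack a 1-bit truth table into an int: bit i is set iff input i maps to 1."""
    table = 0
    for i in ones:
        table |= 1 << i
    return table


def pack_2bit(outputs):
    """Pack a table of 2-bit outputs into an int, using bits 2*i and 2*i+1 for input i."""
    table = 0
    for i, value in enumerate(outputs):
        table |= value << (2 * i)
    return table


# Plain integer constants that could replace the if-chains in device code.
FA = pack_bits(FA_ONES)
FB = pack_bits(FB_ONES)
FC = pack_bits(FC_ONES)
FD = pack_bits(FD_ONES)
FE = pack_bits(FE_ONES)
FG = pack_bits(FG_ONES)
FH = pack_2bit(FH_OUT)


def lookup1(table, x):
    """Evaluate a packed 1-bit truth table at input x."""
    return (table >> x) & 1


def lookup2(table, x):
    """Evaluate a packed 2-bit output table at input x."""
    return (table >> (2 * x)) & 3
```

For instance, `fa1 = lookup1(FA, fa11)` would replace the first if-chain, and `fh1 = lookup2(FH, fh11)` the final one. Numba treats module-level integer globals as compile-time constants, so such constants should be usable from `@njit` and `@cuda.jit` code; whether this is faster on the GPU than the if-chains is untested.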

In [None]:
@njit(parallel=True)
def bit(number, position):
    # Return the bit at the specified position
    return (number >> position) & 1

@njit(parallel=True)
def dst40_round1(hash: int, key: int) -> int:
    """
    Compute 2 bits of dst-40 keystream
    :param hash: 1st LFSR state (the one initialized with the challenge)
    :param key: 2nd LFSR state (the one initialized with the secret key)
    """

    key_bit0, key_bit1, key_bit2, key_bit3, key_bit4, key_bit5, key_bit6, key_bit7, key_bit8, key_bit9, key_bit10, key_bit11, key_bit12, key_bit13, key_bit14, key_bit15, key_bit16, key_bit17, key_bit18, key_bit19, key_bit20, key_bit21, key_bit22, key_bit23, key_bit24, key_bit25, key_bit26, key_bit27, key_bit28, key_bit29, key_bit30, key_bit31, key_bit32, key_bit33, key_bit34, key_bit35, key_bit36, key_bit37, key_bit38, key_bit39=bit(key,0),bit(key,1),bit(key,2),bit(key,3),bit(key,4),bit(key,5),bit(key,6),bit(key,7),bit(key,8),bit(key,9),bit(key,10),bit(key,11),bit(key,12),bit(key,13),bit(key,14),bit(key,15),bit(key,16),bit(key,17),bit(key,18),bit(key,19),bit(key,20),bit(key,21),bit(key,22),bit(key,23),bit(key,24),bit(key,25),bit(key,26),bit(key,27),bit(key,28),bit(key,29),bit(key,30),bit(key,31),bit(key,32),bit(key,33),bit(key,34),bit(key,35),bit(key,36),bit(key,37),bit(key,38),bit(key,39)
    hash_bit0, hash_bit1, hash_bit2, hash_bit3, hash_bit4, hash_bit5, hash_bit6, hash_bit7, hash_bit8, hash_bit9, hash_bit10, hash_bit11, hash_bit12, hash_bit13, hash_bit14, hash_bit15, hash_bit16, hash_bit17, hash_bit18, hash_bit19, hash_bit20, hash_bit21, hash_bit22, hash_bit23, hash_bit24, hash_bit25, hash_bit26, hash_bit27, hash_bit28, hash_bit29, hash_bit30, hash_bit31, hash_bit32, hash_bit33, hash_bit34, hash_bit35, hash_bit36, hash_bit37, hash_bit38, hash_bit39=bit(hash,0),bit(hash,1),bit(hash,2),bit(hash,3),bit(hash,4),bit(hash,5),bit(hash,6),bit(hash,7),bit(hash,8),bit(hash,9),bit(hash,10),bit(hash,11),bit(hash,12),bit(hash,13),bit(hash,14),bit(hash,15),bit(hash,16),bit(hash,17),bit(hash,18),bit(hash,19),bit(hash,20),bit(hash,21),bit(hash,22),bit(hash,23),bit(hash,24),bit(hash,25),bit(hash,26),bit(hash,27),bit(hash,28),bit(hash,29),bit(hash,30),bit(hash,31),bit(hash,32),bit(hash,33),bit(hash,34),bit(hash,35),bit(hash,36),bit(hash,37),bit(hash,38),bit(hash,39)
    fa11= (key_bit39 << 4) | (key_bit31 << 3) | (hash_bit39 <<2) | (hash_bit31 <<1) | (hash_bit23)
    fb22= (key_bit38 << 4) | (key_bit30 << 3) | (hash_bit38 <<2) | (hash_bit30 <<1) | (hash_bit22)
    fc33= (key_bit23 << 4) | (key_bit15 << 3) | (key_bit7 <<2) | (hash_bit15 <<1) | (hash_bit7)
    fd44= (key_bit22 << 4) | (key_bit14 << 3) | (key_bit6 <<2) | (hash_bit14 <<1) | (hash_bit6)

    fa55=(key_bit37 << 4) | (key_bit29 << 3) | (hash_bit37 <<2) | (hash_bit29 <<1) | (hash_bit21)
    fb66=(key_bit36 << 4) | (key_bit28 << 3) | (hash_bit36 <<2) | (hash_bit28 <<1) | (hash_bit20)
    fc77=(key_bit21 << 4) | (key_bit13 << 3) | (key_bit5 <<2) | (hash_bit13 <<1) | (hash_bit5)
    fd88=(key_bit20 << 4) | (key_bit12 << 3) | (key_bit4 <<2) | (hash_bit12 <<1) | (hash_bit4)

    fa99=(key_bit35 << 4) | (key_bit27 << 3) | (hash_bit35 <<2) | (hash_bit27 <<1) | (hash_bit19)
    fb1010=(key_bit34 << 4) | (key_bit26 << 3) | (hash_bit34 <<2) | (hash_bit26 <<1) | (hash_bit18)
    fc1111=(key_bit19 << 4) | (key_bit11 << 3) | (key_bit3 <<2) | (hash_bit11 <<1) | (hash_bit3)
    fd1212=(key_bit18 << 4) | (key_bit10 << 3) | (key_bit2 <<2) | (hash_bit10 <<1) | (hash_bit2)

    fa1313=(key_bit33 << 4) | (key_bit25 << 3) | (hash_bit33 <<2) | (hash_bit25 <<1) | (hash_bit17)
    fb1414=(key_bit32 << 4) | (key_bit24 << 3) | (hash_bit32 <<2) | (hash_bit24 <<1) | (hash_bit16)
    fe1515=(key_bit17 << 3) | (key_bit9 << 2) | (key_bit1 <<1) | (hash_bit9) 
    fe1616=(key_bit16 << 3) | (key_bit8 << 2) | (key_bit0 <<1) | (hash_bit8) 

    if fa11== 1 or fa11== 3 or fa11== 4 or fa11==6 or fa11==7 or fa11==8 or fa11==9 or fa11==12 or fa11==18 or fa11==19 or fa11==23 or fa11==24 or fa11==26 or fa11==28 or fa11==29 or fa11==31:
         fa1=1
    else:
         fa1=0
    if fb22==1 or fb22==2 or fb22==9 or fb22==10 or fb22==12 or fb22==13 or fb22==14 or fb22==15 or fb22==16 or fb22==17 or fb22==18 or fb22==19 or fb22==21 or fb22==22 or fb22==29 or fb22==30:
         fb2=1
    else:
         fb2=0
    if fc33==2 or fc33==4 or fc33==5 or fc33==6 or fc33==8 or fc33==10 or fc33==11 or fc33==12 or fc33==18 or fc33==19 or fc33==20 or fc33==21 or fc33==25 or fc33==27 or fc33==29 or fc33==31:
         fc3=1
    else:
         fc3=0
    if fd44==1 or fd44==3 or fd44==4 or fd44==5 or fd44==10 or fd44==11 or fd44==12 or fd44==14 or fd44==18 or fd44==19 or fd44==20 or fd44==22 or fd44==25 or fd44==27 or fd44==28 or fd44==29:
         fd4=1
    else:
         fd4=0
    if fa55== 1 or fa55== 3 or fa55== 4 or fa55==6 or fa55==7 or fa55==8 or fa55==9 or fa55==12 or fa55==18 or fa55==19 or fa55==23 or fa55==24 or fa55==26 or fa55==28 or fa55==29 or fa55==31:
        fa5=1
    else:
         fa5=0
    if fb66==1 or fb66==2 or fb66==9 or fb66==10 or fb66==12 or fb66==13 or fb66==14 or fb66==15 or fb66==16 or fb66==17 or fb66==18 or fb66==19 or fb66==21 or fb66==22 or fb66==29 or fb66==30:
        fb6=1
    else:
         fb6=0
    if fc77==2 or fc77==4 or fc77==5 or fc77==6 or fc77==8 or fc77==10 or fc77==11 or fc77==12 or fc77==18 or fc77==19 or fc77==20 or fc77==21 or fc77==25 or fc77==27 or fc77==29 or fc77==31:
         fc7=1
    else:
         fc7=0
    if fd88==1 or fd88==3 or fd88==4 or fd88==5 or fd88==10 or fd88==11 or fd88==12 or fd88==14 or fd88==18 or fd88==19 or fd88==20 or fd88==22 or fd88==25 or fd88==27 or fd88==28 or fd88==29:
         fd8=1
    else:
         fd8=0
    if fa99== 1 or fa99== 3 or fa99== 4 or fa99==6 or fa99==7 or fa99==8 or fa99==9 or fa99==12 or fa99==18 or fa99==19 or fa99==23 or fa99==24 or fa99==26 or fa99==28 or fa99==29 or fa99==31:
         fa9=1
    else:
         fa9=0
    if fb1010==1 or fb1010==2 or fb1010==9 or fb1010==10 or fb1010==12 or fb1010==13 or fb1010==14 or fb1010==15 or fb1010==16 or fb1010==17 or fb1010==18 or fb1010==19 or fb1010==21 or fb1010==22 or fb1010==29 or fb1010==30:
         fb10=1
    else:
         fb10=0
    if fc1111==2 or fc1111==4 or fc1111==5 or fc1111==6 or fc1111==8 or fc1111==10 or fc1111==11 or fc1111==12 or fc1111==18 or fc1111==19 or fc1111==20 or fc1111==21 or fc1111==25 or fc1111==27 or fc1111==29 or fc1111==31:
         fc11=1
    else:
         fc11=0
    if fd1212==1 or fd1212==3 or fd1212==4 or fd1212==5 or fd1212==10 or fd1212==11 or fd1212==12 or fd1212==14 or fd1212==18 or fd1212==19 or fd1212==20 or fd1212==22 or fd1212==25 or fd1212==27 or fd1212==28 or fd1212==29:
         fd12=1
    else:
         fd12=0
    if fa1313== 1 or fa1313== 3 or fa1313== 4 or fa1313==6 or fa1313==7 or fa1313==8 or fa1313==9 or fa1313==12 or fa1313==18 or fa1313==19 or fa1313==23 or fa1313==24 or fa1313==26 or fa1313==28 or fa1313==29 or fa1313==31:
         fa13=1
    else:
         fa13=0
    if fb1414==1 or fb1414==2 or fb1414==9 or fb1414==10 or fb1414==12 or fb1414==13 or fb1414==14 or fb1414==15 or fb1414==16 or fb1414==17 or fb1414==18 or fb1414==19 or fb1414==21 or fb1414==22 or fb1414==29 or fb1414==30:
         fb14=1
    else:
         fb14=0
    if fe1515==1 or fe1515==3 or fe1515==6 or fe1515==7 or fe1515==8 or fe1515==9 or fe1515==12 or fe1515==14:
         fe15=1
    else:
         fe15=0
    if fe1616==1 or fe1616==3 or fe1616==6 or fe1616==7 or fe1616==8 or fe1616==9 or fe1616==12 or fe1616==14:
         fe16=1
    else:
         fe16=0
    fg11=(fa1 << 3) | (fb2 << 2) | (fc3 << 1) | fd4
    fg22=(fa5 << 3) | (fb6 << 2) | (fc7 << 1) | fd8
    fg33= (fa9 << 3) | (fb10 << 2) | (fc11 << 1) | fd12
    fg44=(fa13 << 3) | (fb14 << 2) | (fe15 << 1) | fe16 

    if fg11== 1 or fg11==2 or fg11==3 or fg11==6 or fg11==9 or fg11==12 or fg11==13 or fg11==14:
         fg1=1
    else:
         fg1=0
    if fg22== 1 or fg22==2 or fg22==3 or fg22==6 or fg22==9 or fg22==12 or fg22==13 or fg22==14:
         fg2=1
    else:
         fg2=0
    if fg33== 1 or fg33==2 or fg33==3 or fg33==6 or fg33==9 or fg33==12 or fg33==13 or fg33==14:
         fg3=1
    else:
         fg3=0
    if fg44== 1 or fg44==2 or fg44==3 or fg44==6 or fg44==9 or fg44==12 or fg44==13 or fg44==14:
         fg4=1
    else:
         fg4=0
    
    fh11=(fg1 <<3) | (fg2 << 2) | (fg3 << 1) | fg4

    if fh11==0 or fh11==1 or fh11==14 or fh11==15:
        fh1=0
    elif fh11==2 or fh11==6 or fh11==9 or fh11==13:
        fh1=2
    elif fh11==3 or fh11==4 or fh11==11 or fh11==12:
        fh1=3
    else:  # fh11 in {5, 7, 8, 10}
        fh1=1
    res = fh1 ^((hash_bit1<<1) | hash_bit0 )
    return res

@njit(parallel=True)
def dst40_encode(challenge, key):
    """
    DST-40 encryption
    :param challenge: 40-bit challenge
    :param key: 40-bit key
    :return: 24-bit response (keystream) as an int
    """

    hash40 = challenge
    key40 = key

    cnt = 0
    for _ in range(192):
        hash40 = (dst40_round1(hash40, key40) << 38) | (hash40 >> 2)

        # The key register (the LFSR initialized with the secret key) is clocked once
        # every three rounds: cnt cycles through 0, 1, 2 and the shift happens when cnt == 1
        if cnt == 1:
            key40 = ((bit(key40, 0) ^ bit(key40, 2) ^ bit(key40, 19) ^ bit(key40, 21)) << 39) | (key40 >> 1)

        cnt += 1
        if cnt == 3:
            cnt = 0

    # Keep the 24 most significant bits of the 40-bit register
    return hash40 >> 16
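
The key register update in `dst40_encode` is a 40-bit LFSR with feedback taps at bits 0, 2, 19 and 21, clocked once every three rounds. The sketch below isolates that logic exactly as written in this notebook; it is not checked against any external DST40 specification, and `clock_key_lfsr` and `key_clock_rounds` are hypothetical names.

```python
def bit(number, position):
    """Return the bit of `number` at `position` (same as the notebook's helper)."""
    return (number >> position) & 1


def clock_key_lfsr(key40):
    """One step of the key LFSR as written in dst40_encode:
    feedback = k0 ^ k2 ^ k19 ^ k21 enters at bit 39, the rest shifts right by one."""
    feedback = bit(key40, 0) ^ bit(key40, 2) ^ bit(key40, 19) ^ bit(key40, 21)
    return (feedback << 39) | (key40 >> 1)


def key_clock_rounds(n_rounds=192):
    """0-based rounds on which dst40_encode clocks the key register
    (cnt cycles 0, 1, 2 and the register is shifted when cnt == 1)."""
    rounds = []
    cnt = 0
    for r in range(n_rounds):
        if cnt == 1:
            rounds.append(r)
        cnt = (cnt + 1) % 3
    return rounds
```

Over the 192 rounds, the key register is therefore clocked 64 times, on rounds 1, 4, 7, and so on (0-based). For example, `clock_key_lfsr(1)` feeds a 1 back into bit 39 and returns `1 << 39`.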

Finally, we can define the kernel and the function that launches it. Note that in a real attack, MAX_THREAD should be set to 2^40; for the example run with a 32-bit key, however, I set it to 2^32.
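
Each CUDA thread tests one candidate key. The comment in the kernel describes the key tested by a thread as its global thread ID plus the number of keys per launch times the launch index. The CPU-only sketch below reproduces that index arithmetic so that coverage of the key space can be checked on small sizes; `launch_plan`, `candidate_key` and `enumerate_keys` are hypothetical helpers, not part of the notebook.

```python
def launch_plan(key_bits, keys_per_launch, threads_per_block):
    """Return (number of kernel launches, blocks per grid) needed to test
    every key of a 2**key_bits key space."""
    key_space = 1 << key_bits
    n_launches = -(-key_space // keys_per_launch)  # integer ceiling division
    blocks_per_grid = keys_per_launch // threads_per_block
    return n_launches, blocks_per_grid


def candidate_key(call, block_idx, block_dim, thread_idx, keys_per_launch):
    """Key tested by thread (block_idx, thread_idx) during kernel launch number `call`."""
    return call * keys_per_launch + block_idx * block_dim + thread_idx


def enumerate_keys(key_bits, keys_per_launch, threads_per_block):
    """Emulate all launches on the CPU and list the keys that would be tested (small sizes only)."""
    n_launches, blocks = launch_plan(key_bits, keys_per_launch, threads_per_block)
    return [
        candidate_key(call, b, threads_per_block, t, keys_per_launch)
        for call in range(n_launches)
        for b in range(blocks)
        for t in range(threads_per_block)
    ]
```

With the notebook's settings, `launch_plan(40, 2**32, 1024)` gives 256 launches of 4194304 blocks, and `launch_plan(32, 2**32, 1024)` a single launch. If `keys_per_launch` does not divide the key space, the last launch overshoots it and the kernel would need a bounds check.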

In [None]:
MAX_THREAD = 2**32  # number of candidate keys tested by one kernel launch
KEY_SIZE = 2**40    # size of the DST40 key space (40-bit keys)

@cuda.jit
def kernel_dst_40_bits_key(keystream1, keystream2, key_found, iv1, iv2, progress, call_number):
    # Increment the number of threads (candidate keys) processed so far
    cuda.atomic.add(progress, 0, 1)

    # Candidate key = global thread ID + MAX_THREAD (keys per launch) * index of the current launch
    tested_key = call_number[0] * MAX_THREAD + cuda.blockIdx.x * cuda.blockDim.x + cuda.threadIdx.x

    # The key is found if it reproduces the captured keystream for both IVs
    if dst40_encode(iv1, tested_key) == keystream1:
        if dst40_encode(iv2, tested_key) == keystream2:
            key_found[0] = tested_key  # recovered key
            key_found[1] = 1           # "found" flag (0 is a valid key, so it cannot mean "not found")


def gpu_dst_40_brute_force_40_bits_keys(initialization_vector: List[int], keystream: List[int]) -> int:
    """Return the key matching both (IV, keystream) pairs, or -1 if none is found."""
    # Managed (unified) memory, readable from the host without explicit copies.
    # It is not guaranteed to be zero-initialized, so clear it first.
    d_key = cuda.managed_array(2, dtype=np.int64, strides=None, order='C', stream=0, attach_global=True)
    progress = cuda.managed_array(1, dtype=np.uint64, strides=None, order='C', stream=0, attach_global=True)
    d_key[0] = 0
    d_key[1] = 0
    progress[0] = 0

    # Scalars are passed to the kernel by value (typed as int64 by Numba)
    iv1, iv2 = int(initialization_vector[0]), int(initialization_vector[1])
    ks1, ks2 = int(keystream[0]), int(keystream[1])

    # Compute the number of blocks and threads
    threads_per_block = 1024
    blocks_per_grid = MAX_THREAD // threads_per_block
    number_call_kernel = math.ceil(KEY_SIZE / MAX_THREAD)

    # The kernel is launched number_call_kernel times to cover the whole key space.
    # For example, for a 2**33 key space with 1024 threads per block and 4194304 blocks
    # (1024 * 4194304 = 2**32 keys per launch), the kernel is launched twice.
    for call in range(number_call_kernel):
        # int64 rather than uint64: in Numba, mixing uint64 and int64 promotes to float64
        d_kernel_call_number = cuda.to_device(np.array([call], dtype=np.int64))
        kernel_dst_40_bits_key[blocks_per_grid, threads_per_block](ks1, ks2, d_key, iv1, iv2,
                                                                 progress, d_kernel_call_number)
        cuda.synchronize()
        # Stop as soon as a key matching both pairs has been found
        if d_key[1] == 1:
            return int(d_key[0])
    return -1

In [None]:
P1,P2=2**37-1235478,2**38-14257531
key=2**32-175
C1,C2=dst40_encode(P1,key),dst40_encode(P2,key)
gpu_dst_40_brute_force_40_bits_keys([P1,P2],[C1,C2])
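
The kernel only accepts a key that reproduces both responses, C1 and C2. Since `dst40_encode` returns `hash40 >> 16`, a 24-bit value, a single pair cannot single out a 40-bit key. Assuming, as a modelling simplification, that wrong keys produce independent, uniformly random responses, the expected number of false positives is (2^40 - 1) / 2^(24·n) for n pairs. The hypothetical helper below just evaluates that formula.

```python
def expected_false_positives(key_bits, response_bits, n_pairs):
    """Expected number of wrong keys matching all n_pairs responses, assuming
    each wrong key yields independent, uniformly random responses."""
    wrong_keys = (1 << key_bits) - 1
    return wrong_keys / 2 ** (response_bits * n_pairs)


one_pair = expected_false_positives(40, 24, 1)
two_pairs = expected_false_positives(40, 24, 2)
print(f"1 pair: {one_pair:.1f} wrong keys expected, 2 pairs: {two_pairs:.6f}")
```

With one pair this gives about 65536 wrong candidates; with two pairs it drops to about 2^-8, which suggests why two challenge/response pairs are used.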

For a 32-bit key, the algorithm takes about 47 seconds to find the correct key, whereas it would take about 6 hours for a 40-bit key, since the number of candidate keys doubles with every additional key bit. Note that this benchmark was run on Google Colab; it would be much faster on a powerful NVIDIA card.