# Uncompressing Dilithium's public key
-------------------------

## Preliminaries

### Libraries and functions used

In [None]:
# Loading the functions used for the attack
# `-i` runs the script inside this notebook's namespace, so the names it defines are
# available in the following cells. Presumably this is where `dilithium`, `open_keys`,
# `run_my_process`, `Antt2Aintt`, `format_results`, `np` and `os` come from: they are
# used below but not imported in this notebook. TODO confirm against Helper_functions.py
%run -i ../Helper_functions.py

In [None]:
# Importing useful libraries
import time

from collections import Counter

# Only to display markdown tables
from IPython.display import display, Markdown, Latex

### Compiling the codes

If necessary, we can generate the NIST KAT files.  
The keys they contain are the targets of the attack.  
Alternatively, you can use the ones provided in the folder `../nistkat/`.

In [None]:
# Location of the Dilithium reference implementation, relative to this notebook.
# Keep the trailing "/": file names are appended to it directly.
path_to_dilithium = "../../dilithium/ref/"

In [None]:
# make target (and resulting executable) that generates the NIST KAT file
# for the current Dilithium security level
nistkat_rule = "nistkat/PQCgenKAT_sign{}".format(dilithium.MODE)

In [None]:
%%bash -s "$path_to_dilithium" "$nistkat_rule"
# $1: directory of the Dilithium reference code
# $2: make target, which is also the path of the KAT generator executable
# Stop at the first failure so that a failed cd/build never runs a stale or missing binary
set -e
cd "$1"
make "$2"
./"$2"

We compile the code used to generate the desired number of random signatures for a given key.  
Later we can execute it like this:
```bash
./sign_rdm_msg_and_save2 0 10000
```
This produces 10000 signatures for key 0 of the Dilithium2 KAT.

In [None]:
# Directory holding the additional C tools (signing and LP solving).
# Keep the trailing "/": executable names are appended to it directly.
path_to_c = "../C_functions/"

In [None]:
# make target / executable that signs random messages and stores the signatures,
# built for the current Dilithium security level
sign_rule = "sign_rdm_msg_and_save{}".format(dilithium.MODE)

Some compiler warnings may appear.

In [None]:
%%bash -s "$path_to_c" "$sign_rule"
# $1: directory of the added C functions
# $2: make target for the signing executable
# Abort if the directory is missing instead of running make in the wrong place
set -e
cd "$1"
make "$2"

We compile the code used to formulate and solve LP instances.  
Later we can execute it like this:
```bash
./build_solve_t0_lp2 0 10000 100
```
This builds and solves an LP problem with 10000 inequalities for key 0 of the Dilithium2 KAT, using the radius C=100.

In [None]:
# Directory of the lpsolve55 library, passed to make as LPLIB_PATH.
# The value below matches the default installation layout;
# replace it with your own path if lpsolve55 was installed elsewhere.
lplib_path = "../.."

In [None]:
# make target / executable that builds and solves the LP instances,
# built for the current Dilithium security level
atk_rule = "build_solve_t0_lp{}".format(dilithium.MODE)

Some compiler warnings may appear.

In [None]:
%%bash -s "$path_to_c" "$atk_rule" "$lplib_path"
# $1: directory of the added C functions
# $2: make target for the LP build/solve executable
# $3: location of the lpsolve55 library
# Abort if the directory is missing instead of running make in the wrong place
set -e
cd "$1"
make LPLIB_PATH="$3" "$2"

## Main Attack

### Parameters setup

In [None]:
# How many keys to attack, taken in order from the NIST KAT response file
# (the default KAT file holds 100 keys, so keep 1 <= NB_Keys_tested <= 100)
NB_Keys_tested = 1

# True when the secret keys are available: they are then loaded as well and
# used to measure the error of the recovered t0
known_sk = True

In [None]:
# Load the first NB_Keys_tested keys of the KAT response file.
# open_keys returns (public keys, secret keys) when include_sk is True,
# and only the public keys otherwise.
keys_file_path = f'{path_to_dilithium}PQCsignKAT_Dilithium{dilithium.MODE}.rsp'
loaded_keys = open_keys(NB_Keys_tested,
                        include_sk = known_sk,
                        keys_file_name = keys_file_path)
if known_sk:
    PK, SK = loaded_keys
else:
    PK = loaded_keys

In [None]:
# Set to False to reuse signatures that were already generated and saved
Collect_signs = True

# How many random messages to sign for each targeted key
Nb_signs = 300_000

### Sub-function to generate the useful inequalities and solve the corresponding LP

In [None]:
def build_solve_lp(process_, key_index, nb_ineq, *args):
    """
    Call the C executable that builds and solves the LP problem with the given parameters.

    Parameters
    ----------
    process_     str: The executable, here build_solve_t0_lp{dilithium.MODE}
    key_index    int: The key targeted
    nb_ineq      int: Number of inequalities to collect
    *args           : Either (C,), the radius of an interval centered in 0,
                      or (C_low, C_up), the bounds when the interval is not centered in 0.
                      Every value is forwarded to the executable as a string.

    Returns
    ----------
    output list(str): Output from the executable, followed by two appended lines:
                      the elapsed time and the number of inequalities

    Raises
    ----------
    ValueError: if args does not contain exactly one or two values
    """
    if len(args) not in (1, 2):
        raise ValueError(
            f"Expected one radius C or two bounds (C_low, C_up), got {len(args)} values"
        )

    # The radius shown in the log is the upper bound when two bounds are given
    C = args[-1]
    # Use the parameter, not a notebook global, so the function works on its own
    print(f"Solving LP for {C: 4.1f} and {nb_ineq: 6d} ineq", end = " ")

    start_time = time.time()
    output = run_my_process(process_, f"{key_index}", f"{nb_ineq}", *(f"{c}" for c in args))
    total_seconds = time.time() - start_time

    # Split the full elapsed time; timedelta.seconds would silently drop whole days
    hours, remainder = divmod(int(total_seconds), 3600)
    minutes, seconds = divmod(remainder, 60)
    time_str = f"  Time: {hours} h {minutes} min {seconds} sec\n"
    print(f"({hours}h{minutes}m{seconds}s)")
    output.append(time_str)
    output.append(f"Nb Inequalities: {nb_ineq}\n")

    return output

### Algorithm 6: Heuristically recovering $\mathbf{t}_0$

In [None]:
# Radii C, halving at each step: 2^(D-1), 2^(D-2), ..., 2^0 (as floats)
List_C         = [2 ** (dilithium.D - 1) / 2 ** i for i in range(dilithium.D)]
# Number of inequalities used with each radius (the same for every radius)
Nb_Inequations = [50000] * len(List_C)

For a list of radii and numbers of inequalities, this part will:
- produce the desired number of signatures and store them in: `Attack_t0/signs/Dilithium{dilithium.MODE}/key{key_targeted}`
- build the LP problems for each polynomial of $\mathbf{t}_0$ and store them in: `Attack_t0/lps/Dilithium{dilithium.MODE}/key{key_targeted}`
- solve the LP problems for each polynomial of $\mathbf{t}_0$
- update the $\mathbf{t}_0$ guess stored in: `Attack_t0/Guess/Dilithium{dilithium.MODE}/key{key_targeted}`
- repeat until every radius has been processed with its number of inequalities
- display a summary `.md` file and save it in: `Attack_t0/Sum_ups/Dilithium{dilithium.MODE}/key{key_targeted}`


In [None]:
# Algorithm 6, run for each targeted key: optionally generate signatures, then solve
# the LP instances radius by radius, and finally write and display a summary.
# This is a top-level cell: every name assigned here (pk, t0, A, key_results,
# output, ...) remains in the notebook namespace after the loop ends.
all_keys_results = []  # one entry per key: that key's list of per-radius outputs
for key_targeted in range(NB_Keys_tested):
    print(f">>> Uncompressing t0 for key#{key_targeted}:")
    if Collect_signs:
        # Run the signing executable for this key. Its output is not inspected here;
        # `output` is reassigned by the LP loop below.
        print(f"Signing {Nb_signs} random messages:", end = " ")
        output = run_my_process(f"{path_to_c}{sign_rule}", f"{key_targeted}",f"{Nb_signs}")
        print(u"\u2705")
    
    # Opening corresponding pk/sk
    if known_sk:
        pk, sk = PK[key_targeted], SK[key_targeted]
        # Of the unpacked secret values, only t0 is used in this cell (to measure the error)
        _, Key, tr, s1, s2, t0 = dilithium.unpack_sk(sk)
    else:
        pk = PK[key_targeted]
        
    rho, t1 = dilithium.unpack_pk(pk)
    # Expand the public matrix from the seed rho. Antt2Aintt presumably converts it
    # out of the NTT domain (TODO confirm). A is not used further in this cell.
    Antt = dilithium.polyvec_matrix_expand(rho)
    A = Antt2Aintt(Antt)
    
    key_results = []  # one output list per radius processed

    t0_guess_file_path = f"../Guess/Dilithium{dilithium.MODE}/key{key_targeted}/t0_guess_file.bin"
    os.makedirs(os.path.dirname(t0_guess_file_path), exist_ok = True)
    
    # We start with a guess with all the coeffs to 0
    # (K x N float64 values written as raw bytes, overwriting any previous guess file)
    t0_guess = np.zeros((dilithium.K, dilithium.N)).astype(np.float64)
    t0_guess_C = np.array(t0_guess).astype(np.float64)
    with open(t0_guess_file_path, 'wb') as t0_guess_file:
        t0_guess_file.write(t0_guess_C.tobytes())
    
    # One LP round per radius. The guess file is re-read after each round, so the
    # LP executable presumably updates it in place (confirm against the C code).
    for iteration in range(len(List_C)):
        C = List_C[iteration]
        Nb_ineq = Nb_Inequations[iteration] 
        if iteration == 0:
            # First (largest) radius: pass two bounds (C-1, C) instead of a single
            # radius centered in 0
            output = build_solve_lp(f"{path_to_c}{atk_rule}", key_targeted, Nb_ineq, C-1, C)
        else:
            output = build_solve_lp(f"{path_to_c}{atk_rule}", key_targeted, Nb_ineq, C)
        
        if known_sk:
            # Reload the current guess as a K x N float64 array
            t0_guess_updated = np.fromfile(t0_guess_file_path, dtype=np.float64)
            t0_guess_updated = t0_guess_updated.reshape((dilithium.K, dilithium.N))

            # Rounded coefficient-wise difference between the real t0 and the guess
            ERROR = [[np.round(t0[i][j] - t0_guess_updated[i][j]) for j in range(dilithium.N)] for i in range(dilithium.K)]
            # Largest absolute coefficient error for each of the K polynomials
            norm_inf_error = np.linalg.norm(ERROR, axis = 1, ord = np.inf)
            print("  Inf Norm between t0 and t0_found: ", norm_inf_error)

            output.append(f"Min error: {np.min(norm_inf_error)}\n")
            output.append(f"Max error: {np.max(norm_inf_error)}\n")         

        key_results.append(output)

    # Build the markdown summary of this key's rounds (only the third value returned
    # by format_results is used here), save it to disk, and display it
    sum_up_file_path = f"../Sum_ups/Dilithium{dilithium.MODE}/key{key_targeted}/results.md"
    temps_test, signs_test, mkdwn_test = format_results(key_results, include_sk = known_sk)
    os.makedirs(os.path.dirname(sum_up_file_path), exist_ok = True)
    with open(sum_up_file_path, "w") as sum_up_file:
        sum_up_file.write(mkdwn_test)
    all_keys_results.append(key_results)
    display(Markdown(mkdwn_test))