# Solving the system of inequalities with an ILP solver, based on good signatures

## (assuming $t_0$ is known)
-------------------------

## Libraries Used
------------------------

Below, choose which solver to use: `scipy` or `lpsolve`.

In [None]:
%run -i ../Helper_functions.py

In [None]:
# Number of keys tested for the attack; must be the same as for ./PQCgenKAT_Sign_Modified
NB_Keys_tested = 1

# Total number of signatures; must be the same as for ./PQCgenSign_keyKAT
TOTAL_NB_Signs = 1250000

# Folder where the lp files are stored.
# NOTE(review): `__file__` is presumably set by the `%run -i ../Helper_functions.py`
# cell, which makes this resolve to `<parent directory>/Lps` — confirm when running
# this code outside of Jupyter.
dir_lps = f"{os.path.abspath(os.path.join(__file__, '..'))}/Lps"

# Adapt file type according to solver used
if SOLVER == "scipy":
    # Each variable (one coefficient of an s2 polynomial) is bounded to [-ETA, ETA]
    bounds = [(-dilithium.ETA, dilithium.ETA) for _ in range(dilithium.N)]
    extension_ = "npz"
elif SOLVER == "lpsolve":
    extension_ = "lp"
    # Maximum runtime for lpsolve for each polynomial (in sec.)
    MAX_TIMEOUT = 30 * 60
else:
    # Stop here: continuing would leave `extension_` undefined, and the solving
    # loop would then crash with a confusing NameError.
    raise ValueError(f"Wrong solver found ({SOLVER!r}), expected 'scipy' or 'lpsolve'")

In [None]:
# Load the public/secret keys of every tested key; both are indexed by key number.
(PK, SK) = open_pk_sk(NB_Keys_tested)

In [None]:
# One slot per tested key; the solving cell replaces each 0 with the list of recovered s2 polynomials.
ALL_S2_FOUND = dict.fromkeys(range(NB_Keys_tested), 0)

In [None]:
%%time

for key_targeted in range(NB_Keys_tested):
    # Open corresponding pk/sk 
    pk, sk = PK[key_targeted], SK[key_targeted]
    
    rho, t1 = dilithium.unpack_pk(pk)

    Antt = dilithium.polyvec_matrix_expand(rho)
    A = dilithium.Antt2Aintt(Antt)

    # Just to compare to the correct values, we unpack the sk
    _, Key, tr, s1, s2, t0 = dilithium.unpack_sk(sk)
    s2_found = []
    
    vec_success = True
    
    for poly_targeted in range(0, dilithium.K):
        poly_success  = True
        lps_file_name = f"{dir_lps}/Dilithium{dilithium.MODE}/sk_{key_targeted}_poly{poly_targeted}.{extension_}"
        
        if SOLVER == "scipy":
            np_file = print_np_file_infos(f"{lps_file_name}")
            res = linprog(c = np_file["c"], A_ub = np_file["A"], b_ub = np_file["b"], bounds = bounds)
            potential_s2 = np.round(res.x).astype(int)
            if not np.array_equal(s2[poly_targeted], potential_s2):
                print(f"\n>>> At least one wrong coefficient found of s2[{poly_targeted}] :(")
                poly_success = False
                vec_success = False
                break
        elif SOLVER == "lpsolve":
            ordered_potential_s2 = [0 for _ in range(dilithium.N)]
            print(f"Loading: {lps_file_name}")
            lp_handle = lps.lpsolve(b'read_LP', lps_file_name.encode())
            lps.lpsolve(b'set_verbose', lp_handle, lps.IMPORTANT)
            lps.lpsolve(b"set_timeout", lp_handle, MAX_TIMEOUT)
            lps.lpsolve(b'solve',lp_handle)
            potential_s2 = lps.lpsolve(b"get_variables", lp_handle)[0]
            coeff_names = lps.lpsolve(b'get_col_name', lp_handle)

            for i in range(dilithium.N):
                coeff_name = coeff_names[i]
                index = int(coeff_name[1:])
                
                if potential_s2[i] != s2[poly_targeted][index]:
                    print(f"\n>>> Wrong coefficient found s2[{poly_targeted}][{index}] = {potential_s2[i]} vs. real s2[{poly_targeted}][{index}] = {s2[poly_targeted][index]}")
                    poly_success = False
                    vec_success = False
                    break
                else:
                    ordered_potential_s2[index] = np.round(potential_s2[i]).astype(int)
                    
            if poly_success == False:
                break             
        else:
            raise ValueError(f"Solver {SOLVER}, used is incorrect")
        s2_found.append(ordered_potential_s2)
    ALL_S2_FOUND[key_targeted] = s2_found
    if vec_success:
        print(f"\n>>> For key#{key_targeted}, the {dilithium.K} polynomials of s2 are correctly found!")
    else:
        print(f"\n>>> For key#{key_targeted}, at least one of the {dilithium.K} polynomials of s2 are not correct!")

In [None]:
# The solving loop never releases the lpsolve models it loads; presumably
# free_all_lps() does that. The scipy path needs no cleanup.
if SOLVER in ("lpsolve",):
    free_all_lps()