In [1]:
import pqlattice as pq
import numpy as np
import math
import sys
import os

# Add the directory two levels above the current working directory to the import
# path. Presumably this is what lets `from sagemath import sage_client` in the
# next cell resolve to a repository-local package — confirm.
sys.path.append(os.path.abspath("../.."))

In [2]:
# Connect to a Sage server. The `sage.hnf(...)` and `sage.shortest_vector(...)`
# calls in later cells go through this client (presumably evaluated server-side).
from sagemath import sage_client
sage = sage_client.connect()

Connected to Sage server


In [3]:
def construct_lwe_basis(A, q):
    """Return a square basis of the q-ary lattice spanned by the columns of A and q*Z^m.

    The generating set is the rows of ``A.T`` stacked on top of ``q * I_m``.
    Because it contains ``q * I_m``, the generator matrix has full column
    rank m, so its row Hermite normal form has exactly m non-zero rows; those
    rows form the returned basis.

    Parameters
    ----------
    A : integer matrix with m rows (unpacked as ``m, _ = A.shape``)
        e.g. an LWE sample matrix, possibly with columns removed.
    q : int
        The modulus.

    Returns
    -------
    The first m rows of the HNF computed by ``sage.hnf`` (an m x m matrix whose
    rows are the basis vectors).

    Raises
    ------
    ValueError
        If any HNF row after the first m is non-zero, i.e. the non-zero rows
        were not placed first and ``H[:m]`` would not be a basis.
    """
    m, _ = A.shape
    Im = q * pq.as_integer(np.identity(m))
    G = np.vstack((A.T, Im))
    H = sage.hnf(G)

    # Assumes a row-style HNF with zero rows at the bottom; check it rather than
    # returning a singular "basis" that only fails later in the inverse.
    basis, remainder = H[:m], H[m:]
    if np.any(remainder != 0):
        raise ValueError(
            "expected the HNF's non-zero rows to come first; H[:m] is not a basis"
        )
    return basis

def construct_sis_basis(A, q):
    """Return q times the dual basis of the lattice built by ``construct_lwe_basis(A, q)``.

    If the rows of B form a basis of L (generated by the rows of ``A.T`` and
    ``q * I_m``), then the rows of ``q * B^{-T}`` form a basis of ``q * L*``.
    Since ``q * Z^m`` is contained in L, this matrix is exactly integral, and
    every vector x it generates satisfies ``x . y ≡ 0 (mod q)`` for all y in L.
    In particular ``x @ A ≡ 0 (mod q)``.

    Parameters
    ----------
    A : integer matrix with m rows
    q : int
        The modulus.

    Returns
    -------
    numpy int array ``B_dual`` with ``B_dual @ B_primal.T == q * I``.

    Raises
    ------
    ArithmeticError
        If the floating-point inverse followed by rounding did not recover the
        exact integer matrix q * B^{-T}.
    """
    B_primal = construct_lwe_basis(A, q)
    m = B_primal.shape[0]

    # The inverse is computed in floating point and rounded back to integers.
    B_inv = np.linalg.inv(B_primal.astype(float))
    B_dual = np.round(q * B_inv.T).astype(int)

    # Exact check of the rounding: q * B^{-T} @ B^T == q * I holds exactly in integers.
    if not np.array_equal(B_dual @ B_primal.T, q * np.identity(m, dtype=int)):
        raise ArithmeticError(
            "rounding q * inv(B).T did not produce an exact integer dual basis"
        )
    return B_dual

In [7]:
n = 14       # number of secret components (columns of A recovered one by one)
sigma = 2    # error parameter passed to pq.random.LWE (presumably the Gaussian std-dev)
q = 1000     # modulus
m = 50       # number of LWE samples (rows of A)
secret_dist = "ternary"

# Candidate values for each secret coefficient, keyed by secret distribution.
SECRET_SUPPORT = {
    "binary": [0, 1],
    "ternary": [-1, 0, 1],
}

# Fail loudly: an empty candidate list would make the recovery loop
# produce meaningless guesses instead of an error.
if secret_dist not in SECRET_SUPPORT:
    raise ValueError(
        f"unsupported secret_dist {secret_dist!r}; expected one of {sorted(SECRET_SUPPORT)}"
    )
possible_values = list(SECRET_SUPPORT[secret_dist])

In [9]:
# Build a random LWE instance from the parameters above.
# NOTE(review): the meaning of the positional argument 80 is not visible here
# (a seed? a security level?) — confirm against pq.random.LWE and name it.
lwe = pq.random.LWE(n, q, sigma, secret_dist, 80)
secret = lwe.secret
# Draw m samples. The recovery loop indexes A as A[:, i] for i < n and
# construct_lwe_basis unpacks A's shape as (m, ...), so A is expected to be
# (m, n); b is presumably A @ secret + e (mod q) — confirm.
A, b = lwe.sample_matrix(m)

In [10]:
recovered_secret = []
ambiguous_indices = []

print("Recovering secret's components")
for i in range(n):
    # Dual attack on coordinate i. v is a short vector of q times the dual of the
    # lattice built from A without column i, so v . A[:, j] ≡ 0 (mod q) for j != i.
    # Assuming b = A s + e (mod q), this leaves
    # v . b ≡ s_i * (v . A[:, i]) + v . e (mod q).
    A_punctured = np.delete(A, i, axis=1)
    G_dual = construct_sis_basis(A_punctured, q)
    v = sage.shortest_vector(G_dual)

    target_column = A[:, i]
    interaction = np.dot(v, target_column)  # coefficient of s_i in v . b
    projection = np.dot(v, b)

    # Score each candidate. For the right guess the residual z ≡ v . e (mod q)
    # is near 0 (or q) when v and e are short, so cos(2*pi*z/q) is near 1.
    scores = {}
    for guess in possible_values:
        z = (projection - guess * interaction) % q
        scores[guess] = math.cos(2 * math.pi * z / q)

    # max() returns the first candidate with the top score. This is the same
    # tie-breaking as a strict ">" scan, but it raises on an empty candidate list
    # instead of silently guessing 0.
    best_guess = max(possible_values, key=scores.__getitem__)
    max_score = scores[best_guess]

    # If v . A[:, i] maps several candidates to the same residual
    # (e.g. v . A[:, i] ≡ 0 mod q), their scores tie and the pick is arbitrary.
    n_tied = sum(math.isclose(s, max_score, abs_tol=1e-12) for s in scores.values())
    is_ambiguous = n_tied > 1
    if is_ambiguous:
        ambiguous_indices.append(i)

    recovered_secret.append(best_guess)

    is_correct = (best_guess == secret[i])
    status = "OK" if is_correct else "FAIL"
    if is_ambiguous:
        status += " (ambiguous: tied scores)"
    print(f"s[{i:2}]: Guessed {best_guess:>2} (Score: {max_score:.4f}) -> {status}")

print()
print("Real secret:")
print(f"{pq.as_integer(secret)}")
print("Recovered secret:")
print(f"{pq.as_integer(recovered_secret)}")
if ambiguous_indices:
    print(f"Ambiguous components (tied scores, guess arbitrary): {ambiguous_indices}")
accuracy = (pq.as_integer(recovered_secret) == pq.as_integer(secret)).sum() / len(secret)
print()
print(f"accuracy: {accuracy*100:.2f}")

Recovering secret's components
s[ 0]: Guessed  1 (Score: 0.9913) -> OK
s[ 1]: Guessed  0 (Score: 0.9471) -> OK
s[ 2]: Guessed  1 (Score: 0.9620) -> OK
s[ 3]: Guessed  1 (Score: 0.9967) -> OK
s[ 4]: Guessed  1 (Score: 0.9867) -> OK
s[ 5]: Guessed  0 (Score: 0.9997) -> OK
s[ 6]: Guessed -1 (Score: 0.9921) -> OK
s[ 7]: Guessed  0 (Score: 0.9980) -> OK
s[ 8]: Guessed  1 (Score: 0.9980) -> OK
s[ 9]: Guessed  0 (Score: 0.9990) -> OK
s[10]: Guessed  1 (Score: 0.9972) -> OK
s[11]: Guessed  1 (Score: 1.0000) -> OK
s[12]: Guessed  1 (Score: 0.9967) -> OK
s[13]: Guessed  0 (Score: 0.9956) -> OK

Real secret:
[1 0 1 1 1 0 -1 0 1 0 1 1 1 0]
Recovered secret:
[1 0 1 1 1 0 -1 0 1 0 1 1 1 0]

accuracy: 100.00
