In [None]:
import numpy as np
import math

class GSW:
    def __init__(self, lambda_val: int, L: int):
        """
        Setup function for LWE-based encryption scheme.
        
        Parameters:
        - lambda_val: Security parameter (λ)
        - L: Circuit depth (depth of homomorphic operations)

        """
        self.n = lambda_val * L                     # Lattice dimension n based on lambda and L
        
        kappa = lambda_val + L                      # A basic relation, could be tuned depending on context
        self.q = 2 ** kappa

        # Define the error distribution chi as a Gaussian with mean 0 and standard deviation appropriate for LWE
        # Here we use a simple Gaussian; in practice, a discrete Gaussian or centered Gaussian would be used
        chi_mean = 0
        chi_std = self.n ** 1/L
        self.chi = lambda size: np.random.normal(chi_mean, 0.2, size)

        # Calculate the number of samples m, typically proportional to n * log2(q)
        self.m = int(self.n * kappa)                 # number of samples of LWE problem, typically proportional to n * log2(q)

        # Define ell and N based on q and n
        self.l = kappa + 1                          # Bits needed to represent elements in Z_q
        self.N = (self.n) * self.l              # Total dimension for storage
        print(f"n: {self.n}, q: {self.q}, m: {self.m}, l: {self.l}, N: {self.N}")

    def mod_operation(self, value:int):
        """
        value: int
        return: int -> Z_q
        """
        return value % self.q

    def bit_decomp_vec(self, values: np.ndarray[np.array]):
        """
        value: np.array -> Z_q ^ Nxm
        return: np.array -> {0, 1}^NxN
        """
        values = values.astype(int)
        values = self.mod_operation(values)
        return np.array([value >> i & 1 for value in values for i in range(self.l)])

    def bit_decomp(self, values: np.ndarray[np.array]):
        """
        value: np.array -> Z_q ^ Nxm
        return: np.array -> {0, 1}^NxN
        """
        return np.array([self.bit_decomp_vec(line) for line in values])
    
    def bit_comp_vec(self, values:np.ndarray[np.array], n = None):
        """
        value: np.array -> {0, 1}^NxN
        return: np.array -> Z_q^Nxm
        """
        if n is None:
            n = self.n
            
        values = values.astype(int)
        values = self.mod_operation(values)
        return self.mod_operation(np.array([
            sum([values[l] << l-(k*self.l)  for l in range(k*self.l, self.l+(k*self.l))]) 
            for k in range(n)
        ]))

    def bit_comp(self, values:np.ndarray[np.array]):
        """
        value: np.array -> {0, 1}^NxN
        return: np.array -> Z_q^Nxm
        """        
        return np.array([self.bit_comp_vec(line) for line in values])
    
    def flatten(self, value:np.array):
        """
        value: np.array -> {d+}^N
        return: np.array -> {0, 1}^N
        """
        return self.bit_decomp(self.bit_comp(value))
    
    def powersof2(self, values:np.array):
        """
        value: np.array -> Z_q^n
        return: np.array -> Z_q^N
        """
        values = self.mod_operation(values)
        return np.array([
            self.mod_operation(values[n] <<l) for n in range(values.size) for l in range(self.l)
        ])
        
    def keys_gen(self):
        """
        return: Z_q^n, Z_q^n, np.array -> Z_q^mxn
        """
        t = np.random.randint(0, self.q, self.n-1)
        s = np.insert(-t, 0, 1)
        v = self.powersof2(s)    
        B = np.random.randint(0, self.q, (self.m, self.n-1))  
        e = np.around(self.chi((self.m))).astype(int)
        b = B @ t.T  + e
        A = self.mod_operation(np.column_stack((b, B)))
        # print(f"e: {e.shape} -> {e}")
        
        return s, v, A, B, b
    
    def encrypt(self, pk:np.ndarray[np.array], u:int):
        In = np.eye(self.N)
        R = self.binaryRandomMatrix()
        RA = R @ pk
        RA_dec = self.bit_decomp(RA)
        C_z = u * In + RA_dec
        C = self.flatten(C_z)

        return C
    
    def bin_to_int(self, values:np.array):
        """
        value: np.array -> {0, 1}^N
        return: int
        """
        values = values.astype(int)
        return sum([values[i] << i for i in range(values.size)])
        
    def decript(self, sk:np.array, C:np.ndarray[np.array]):

        Cv = self.mod_operation(C @ sk.T)[0:self.l-1]
        g = np.array([1<<i for i in range(self.l-1)])
        bit = 0
        bits = np.array([])
        j = 0 
        
        for i in range(self.l-2, -1, -1):
            bit = int(round((Cv[i] - bit) / g[i])) >> j & 1
            bits = np.append(bits, bit)
            bit = bit << j
            j += 1
        # print("Bits: ", bits)
        
        # print(f"Cv: {Cv.shape} -> {Cv}\ng: {g.shape} -> {g}")
        
        return self.bin_to_int(bits)
        
    
    def binaryRandomMatrix(self):
        """
        return: np.array -> {0, 1}^NxN
        """
        return np.random.randint(0, 2, (self.N, self.m))
    
    def generateRandomMatrix(self):
        """
        return: np.array -> Z_q^NxN
        """
        return np.random.randint(0, self.q, (self.N, self.n))
    
    
gsw = GSW(8, 5)
s, v, A, B, b = gsw.keys_gen()
# print(f"s: {s.shape} -> {s}\nv: {v.shape} -> {v}\nA: {A.shape} -> \n{A}\nB: {B.shape} -> \n{B}\nb: {b.shape} -> {b}")
C1 = gsw.encrypt(A, 5)
C2 = gsw.encrypt(A, 3)
C3 = gsw.encrypt(A, 2)
m1 = gsw.decript(v, C1)
m2 = gsw.decript(v, C2)
print(f"m1 + m2 = {m1} + {m2} = {m1+m2} = {gsw.decript(v, C1 + C2)}")
print(f"m1 x m2 = {m1} x {m2} = {m1*m2} = {gsw.decript(v, C1 @ C2)}")




    

n: 40, q: 8192, m: 520, l: 14, N: 560
m1 + m2 = 5 + 3 = 8 = 8
m1 x m2 = 5 x 3 = 15 = 19


### Testing Matrix decomposition

In [99]:
gsw = GSW(5, 4)
matrix = gsw.generateRandomMatrix()
print(matrix)
print(matrix.shape)

decomp = gsw.bit_decomp(matrix)
print(decomp)
print(decomp.shape)

comp = gsw.bit_comp(decomp)
print(comp)
print(comp.shape)

n: 20, q: 512, m: 180, l: 10, N: 200
[[272  41 221 ... 365 130 270]
 [407 171 206 ...  45 292 398]
 [136 430 319 ... 132 497 328]
 ...
 [446 202  98 ... 271 128 169]
 [511 284  57 ... 336 181 398]
 [ 79 100 429 ... 166  57 101]]
(200, 20)
[[0 0 0 ... 0 1 0]
 [1 1 1 ... 1 1 0]
 [0 0 0 ... 0 1 0]
 ...
 [0 1 1 ... 1 0 0]
 [1 1 1 ... 1 1 0]
 [1 1 1 ... 0 0 0]]
(200, 200)
[[272  41 221 ... 365 130 270]
 [407 171 206 ...  45 292 398]
 [136 430 319 ... 132 497 328]
 ...
 [446 202  98 ... 271 128 169]
 [511 284  57 ... 336 181 398]
 [ 79 100 429 ... 166  57 101]]
(200, 20)
