# Secret Key Regev Encryption and SimplePIR

In this tutorial we will first build the secret key Regev encryption scheme from the learning with errors (LWE) problem. Then we will use it to build a very simple private information retrieval (PIR) scheme.

This tutorial is part of a [presentation](https://docs.google.com/presentation/d/1ESZ2xZeyBYyzc-AWvZwc8o4ioOGGZtK9KlnsplzsfQ0/edit?usp=sharing) that I gave at [Zuzalu](https://zuzalu.city/).

```python
import random

class Matrix:
    """A rows x cols matrix with entries mod `mod`, stored row-major in a flat list."""

    def __init__(self, mod, rows, cols, mat):
        self.mod = mod
        self.rows = rows
        self.cols = cols
        self.mat = mat

    def scale(self, numerator, denominator, new_q):
        # Multiply each entry by numerator/denominator, round to the nearest
        # integer and reduce mod new_q. With numerator = denominator = 1 this
        # just switches the modulus.
        mat = [0] * (self.rows * self.cols)
        for index in range(len(self.mat)):
            mat[index] = round((numerator * self.mat[index]) / denominator) % new_q
        return Matrix(new_q, self.rows, self.cols, mat)

    def set_at(self, row, col, val):
        self.mat[row * self.cols + col] = val

    def get_at(self, row, col):
        return self.mat[row * self.cols + col]

    def __mul__(self, other):
        # Matrix product mod self.mod
        assert self.cols == other.rows
        mat = [0] * (self.rows * other.cols)
        for i in range(self.rows):
            for j in range(self.cols):
                for k in range(other.cols):
                    mat[i * other.cols + k] = (mat[i * other.cols + k] + self.mat[i * self.cols + j] * other.mat[j * other.cols + k]) % self.mod
        return Matrix(self.mod, self.rows, other.cols, mat)

    def __add__(self, other):
        # Entry-wise sum mod self.mod. Returns a new matrix instead of
        # modifying self, so operands such as ciphertexts stay intact.
        assert self.rows == other.rows and self.cols == other.cols
        mat = [(a + b) % self.mod for a, b in zip(self.mat, other.mat)]
        return Matrix(self.mod, self.rows, self.cols, mat)

    def __sub__(self, other):
        # Entry-wise difference mod self.mod, also returning a new matrix
        assert self.rows == other.rows and self.cols == other.cols
        mat = [(a - b) % self.mod for a, b in zip(self.mat, other.mat)]
        return Matrix(self.mod, self.rows, self.cols, mat)

    def __eq__(self, other):
        return self.rows == other.rows and self.cols == other.cols and self.mat == other.mat

    def __repr__(self):
        return "Rows:" + str(self.rows) + " Cols:" + str(self.cols) + " \nMatrix(" + str(self.mat) + ")"

def zero_matrix(mod, rows, cols):
    mat = [0 for _ in range(rows * cols)]
    return Matrix(mod, rows, cols, mat)

def random_matrix(mod, rows, cols):
    # uniformly random entries in [0, mod)
    mat = [random.randint(0, mod - 1) for _ in range(rows * cols)]
    return Matrix(mod, rows, cols, mat)

def sample_error(bound):
    # uniformly random error in [-bound, +bound]
    return random.randint(-bound, +bound)

def sample_error_matrix(bound, mod, rows, cols):
    # errors are reduced mod `mod`, so a negative error -x is stored as mod - x
    mat = [sample_error(bound) % mod for _ in range(cols * rows)]
    return Matrix(mod, rows, cols, mat)
```
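`scale` rounds using floating-point division and Python's `round`, which rounds ties to the nearest even integer. That is fine for the small moduli used here. A float cannot represent every integer above 2**53, though, so with larger moduli the result can be off. Below is a sketch of an exact integer alternative. `scale_round` is a hypothetical helper, not part of the tutorial's code. It computes round(numerator·x/denominator) with ties rounded up.

```python
def scale_round(x, numerator, denominator, new_mod):
    # floor(numerator * x / denominator + 1/2), computed with integers only,
    # then reduced mod new_mod
    return ((2 * numerator * x + denominator) // (2 * denominator)) % new_mod

# decode 20 * 7 - 3 = 137 (q = 1000, p = 50, delta = 20) back to 7
print(scale_round(137, 50, 1000, 50))
# a tie: 130 * 50 / 1000 = 6.5 rounds up to 7 here (Python's round(6.5) is 6)
print(scale_round(130, 50, 1000, 50))
```

A tie only occurs when an error is exactly ±Δ/2. Any ciphertext whose error stays below Δ/2 therefore decodes the same way under both rounding rules.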

# Secret Key Regev Encryption

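The LWE parameters should be chosen so that solving LWE is hard:

- n is the dimension of the secret vector s.
- m is the number of LWE samples (the rows of A).
- q is the ciphertext modulus.
- p is the plaintext modulus.
- $\sigma$ is the standard deviation of the error distribution.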

In practice you would use a tool like the [lattice estimator](https://github.com/malb/lattice-estimator) to estimate the security of your chosen parameters.

For the purposes of this tutorial we will work with a `bound` instead of sampling errors from a distribution with standard deviation $\sigma$. Each error entry is sampled uniformly at random from the range `[-bound, +bound]`.

Warning: None of the LWE parameters used throughout this tutorial are secure.
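For comparison with the `bound` approach, here is a minimal sketch of both samplers. The rounded Gaussian is our own assumption: it is a common simplification, not a true discrete Gaussian sampler. The value σ = 3.2 is only illustrative.

```python
import random

def sample_bounded_error(bound):
    # What this tutorial does: a uniform integer in [-bound, +bound]
    return random.randint(-bound, bound)

def sample_rounded_gaussian_error(sigma):
    # Assumed alternative: round a continuous Gaussian with standard
    # deviation sigma to the nearest integer. This only approximates a
    # discrete Gaussian, and Python's `random` module is not a
    # cryptographically secure source of randomness.
    return round(random.gauss(0.0, sigma))

random.seed(0)  # fixed seed only so the demo is reproducible
print([sample_bounded_error(3) for _ in range(10)])
print([sample_rounded_gaussian_error(3.2) for _ in range(10)])
```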

## Attempt 1

The idea behind Attempt 1 is simple. An LWE sample has the form B = As + e. To encrypt a message vector m0, let's just add it on top: B = As + e + m0. The ciphertext is the pair (B, A).

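```python
# LWE parameters
n = 10      # dimension of the secret s
m = 100     # number of LWE samples (rows of A), also the message length
p = 50      # plaintext modulus (used from Attempt 2 onward)
bound = 3   # errors are sampled uniformly from [-bound, +bound]
q = 1000    # ciphertext modulus
```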

```python
s = random_matrix(q, n, 1)
m0 = random_matrix(q, m, 1)
```

### Encrypt 

```python
A = random_matrix(q, m, n)
e = sample_error_matrix(bound, q, m, 1)
B = (A * s) + e + m0
# ciphertext `c` consists of two components, B and A.
c = (B, A)
```

### Decrypt

```python
# Given c and secret s
B = c[0]
A = c[1]
m1 = B - (A * s)

# m1 != m0
# This is because m1 = m0 + e: the message vector m0 got mixed up with the error vector e
assert(m1 != m0)

# However, if the decryptor knew `e` (which isn't the case) they could extract the message:
# since m1 = m0 + e, m1 - e == m0
assert((m1 - e) == m0)
```

## Attempt 2

In the last attempt the message vector got mixed up with the error vector. So how about scaling the message vector by a factor before adding it to As + e, and then scaling the result of B - (A * s) back down during decryption?

Scaling the message vector ensures that it lives in the most significant bits of each ciphertext entry. The error lives in the least significant bits, so the two no longer get mixed up.

However, because of the scaling, each message entry must use fewer bits than a ciphertext entry. So instead of taking each message value modulo q, we take it modulo p, with p < q. We set the scaling factor $\Delta$ to q/p (here $\Delta = 1000/50 = 20$).

Decryption computes B - As = $\Delta$·m0 + e, divides by $\Delta$ and rounds to the nearest integer. This recovers m0 exactly as long as every error entry satisfies $|e| < \Delta/2$ (here bound = 3 < 10).
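The encoding and decoding steps can be checked on a single scalar. Below is a minimal sketch using the tutorial's q = 1000 and p = 50, assuming p divides q. The helper names `encode`, `decode` and `noisy_roundtrip` are illustrative and not part of the tutorial's code. Decoding uses exact integer rounding (ties round up).

```python
Q = 1000        # ciphertext modulus, as in the tutorial
P = 50          # plaintext modulus, as in the tutorial
DELTA = Q // P  # scaling factor, 20; assumes P divides Q

def encode(msg):
    # put the message in the high-order part of a value mod Q
    return (DELTA * msg) % Q

def decode(x):
    # round(P * x / Q) mod P, computed with integers only
    return ((2 * P * x + Q) // (2 * Q)) % P

def noisy_roundtrip(msg, err):
    # encode, add an error, reduce mod Q, then decode
    return decode((encode(msg) + err) % Q)

print(noisy_roundtrip(7, 9), noisy_roundtrip(7, 11))
```

Here `noisy_roundtrip(7, 9)` returns 7, while `noisy_roundtrip(7, 11)` returns 8. Once the error passes Δ/2 = 10, rounding lands on the neighboring plaintext value.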

### Encrypt

```python
delta = q / p
m0 = random_matrix(p, m, 1)
# scale message vector by delta
m0_scaled = m0.scale(q, p, q)
A = random_matrix(q, m, n)
e = sample_error_matrix(bound, q, m, 1)
B = (A * s) + e + m0_scaled
c = (B, A)
```

### Decrypt

```python
B = c[0]
A = c[1]
m1 = B - (A * s)
# scale message by 1/delta
m0_r = m1.scale(p, q, p)
```

```python
# decryption works!
assert(m0_r == m0)
```


# Additive homomorphism

Secret Key Regev Encryption is additively homomorphic.

Suppose c0 encrypts m0 and c1 encrypts m1 under the same secret vector s. Add c0 and c1 component-wise and call the result c2. Decrypting c2 produces m2 = m0 + m1 (mod p), provided the combined error e0 + e1 stays below $\Delta/2$ in absolute value.
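How many fresh ciphertexts can be summed before decryption may fail? Below is a worst-case sketch. It assumes each ciphertext's error is drawn from [-bound, +bound] as in this tutorial, and that p divides q. `max_guaranteed_additions` is a hypothetical helper, not part of the tutorial's code.

```python
def max_guaranteed_additions(q, p, bound):
    # Each fresh ciphertext has error in [-bound, bound], so a sum of t
    # ciphertexts has error at most t * bound in absolute value.
    # Decoding is guaranteed while t * bound < delta / 2 = q / (2 * p),
    # i.e. while 2 * p * t * bound < q (integer arithmetic, no floats).
    assert bound > 0 and q % p == 0
    return (q - 1) // (2 * p * bound)

print(max_guaranteed_additions(1000, 50, 3))
```

With the current parameters (q = 1000, p = 50, bound = 3) it returns 3. The two-ciphertext sum in the next cells is therefore guaranteed to decrypt, since its worst-case error is 6 < Δ/2 = 10. This is a worst-case count: independent errors partly cancel, so more additions often still decrypt correctly, just without a guarantee.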

```python
delta = q / p
m0 = random_matrix(p, m, 1)
m0_scaled = m0.scale(q, p, q)
A0 = random_matrix(q, m, n)
e0 = sample_error_matrix(bound, q, m, 1)
B0 = (A0 * s) + e0 + m0_scaled
c0 = (B0, A0)

m1 = random_matrix(p, m, 1)
m1_scaled = m1.scale(q, p, q)
A1 = random_matrix(q, m, n)
e1 = sample_error_matrix(bound, q, m, 1)
B1 = (A1 * s) + e1 + m1_scaled
c1 = (B1, A1)
```

```python
# c2 = c0 + c1
c2 = (c0[0] + c1[0], c0[1] + c1[1])
```

```python
# decrypt
B = c2[0]
A = c2[1]
m0_r = B - (A * s)
# scale message by 1/delta
m0_r = m0_r.scale(p, q, p)
```

```python
assert(m0_r == (m0 + m1))
```

# Plaintext inner product

Suppose we have a plaintext vector k and a ciphertext c0 = (B0, A0) that encrypts a plaintext vector m0. Multiplying both components of c0 by $k^{T}$ gives a ciphertext c1 = ($k^{T}$B0, $k^{T}$A0) that encrypts the inner product of k and m0 (mod p).
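Here is a short derivation of why this works, writing $B_0 = A_0 s + e_0 + \Delta m_0$ as in the encryption code:

$$
\begin{aligned}
k^{T}B_0 - (k^{T}A_0)s &= k^{T}(A_0 s + e_0 + \Delta m_0) - k^{T}A_0 s \\
&= \Delta\, k^{T}m_0 + k^{T}e_0 \pmod q
\end{aligned}
$$

Scaling by $1/\Delta$ and rounding therefore recovers $k^{T}m_0 \bmod p$ as long as $|k^{T}e_0| < \Delta/2$. This relies on $p$ dividing $q$, so that $\Delta p = q$. The error term is now $k^{T}e_0$ instead of $e_0$, which is why the inner product needs more headroom than plain decryption.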


```python
# NOTICE that we redefine the LWE parameter values
n = 10
m = 100
p = 4
bound = 2
q = 1000

s = random_matrix(p, n, 1)

delta = q / p
m0 = random_matrix(p, m, 1)
m0_scaled = m0.scale(q, p, q)
A0 = random_matrix(q, m, n)
e0 = sample_error_matrix(bound, q, m, 1)
B0 = (A0 * s) + e0 + m0_scaled
c0 = (B0, A0)

# k transposed
k_T = random_matrix(p, 1, m)

# Switch k_T from modulus p to modulus q
k_T = k_T.scale(1, 1, q)
```

```python
c1 = (k_T * c0[0], k_T * c0[1])
```

```python
# decrypt
B = c1[0]
A = c1[1]
m0_r = B - (A * s)
# scale message by 1/delta
m0_r = m0_r.scale(p, q, p)
```

```python
# Switch k_T from modulus q back to p
k_T = k_T.scale(1, 1, p)

assert(m0_r == k_T * m0)
```

# Noise growth

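Notice that to make the plaintext inner product work we changed the LWE parameters. q stayed at 1000, but p went down from 50 to 4 and the error bound from 3 to 2. This raises $\Delta$ from 20 to 250 and keeps the noise growth $k^{T}e_0$ smaller than $\Delta/2$ in absolute value with high probability. Once $|k^{T}e_0| \geq \Delta/2$, decryption of c1 is no longer guaranteed to yield $k^{T}m_0$. Here $k^{T}e_0 \bmod q$ is viewed as a value centered around 0.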
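To see why the next parameter set fails, one can estimate the size of $k^{T}e_0$. Below is a rough sketch. It assumes, as in the code, that the entries of k are uniform in [0, p) and the entries of e0 are uniform in [-bound, +bound], all independent. The helper names are hypothetical.

```python
import math

def worst_case_noise(m, p, bound):
    # every k_i = p - 1 and every e_i = +/- bound with the same sign
    return m * (p - 1) * bound

def noise_std(m, p, bound):
    # E[k^2] for k uniform on {0, ..., p - 1}
    ek2 = sum(k * k for k in range(p)) / p
    # E[e^2] for e uniform on {-bound, ..., bound}, which has mean 0
    ee2 = sum(e * e for e in range(-bound, bound + 1)) / (2 * bound + 1)
    # k^T e is a sum of m independent mean-zero terms k_i * e_i
    return math.sqrt(m * ek2 * ee2)

for m, p, bound, q in [(100, 4, 2, 1000), (1000, 25, 3, 1000)]:
    half_delta = q / p / 2
    print(m, p, bound, "delta/2 =", half_delta,
          "worst case =", worst_case_noise(m, p, bound),
          "std ~", round(noise_std(m, p, bound), 1))
```

For (m, p, bound) = (100, 4, 2) the worst case is 600, above Δ/2 = 125, so correctness there is probabilistic. The typical noise, however, is far below 125 (standard deviation ≈ 26.5). For (1000, 25, 3) the standard deviation is ≈ 885.4 against Δ/2 = 20, so decryption is expected to fail, as it does in the next cell.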

```python
# Let's change the LWE parameters again
n = 10
m = 1000
p = 25
bound = 3
q = 1000

# encrypt
s = random_matrix(p, n, 1)
delta = q / p
m0 = random_matrix(p, m, 1)
m0_scaled = m0.scale(q, p, q)
A0 = random_matrix(q, m, n)
e0 = sample_error_matrix(bound, q, m, 1)
B0 = (A0 * s) + e0 + m0_scaled
c0 = (B0, A0)

# k transposed
k_T = random_matrix(p, 1, m)

# Switch k_T from modulus p to modulus q
k_T = k_T.scale(1, 1, q)

# noise growth k^T e0, computed mod q (so it is stored as a value in [0, q))
noise_growth = k_T * e0

# inner product
c1 = (k_T * c0[0], k_T * c0[1])

# decrypt
B = c1[0]
A = c1[1]
m0_r = B - (A * s)
# scale message by 1/delta
m0_r = m0_r.scale(p, q, p)

# Switch k_T from modulus q back to p
k_T = k_T.scale(1, 1, p)

# Decryption fails once the noise growth, viewed as a centered value mod q,
# reaches delta/2 in absolute value
print("Correct decryption:", m0_r == k_T * m0)
print("delta/2:", (q / p) / 2)
print("Noise growth:", noise_growth.mat[0])
```

```
Correct decryption: False
delta/2: 20.0
Noise growth: 467
```
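`noise_growth.mat[0]` is the residue of $k^{T}e_0$ in [0, q), because `Matrix.__mul__` reduces mod q. What decryption actually sees is that residue interpreted as a signed value near 0. Below is a small sketch of the usual centered representative (the helper name `centered` is introduced here):

```python
def centered(x, q):
    # map any integer to its representative mod q in (-q/2, q/2]
    x %= q
    return x - q if x > q // 2 else x

print(centered(467, 1000), centered(995, 1000))
```

`centered(467, 1000)` is 467, far above Δ/2 = 20. A residue such as 995 would instead correspond to a noise of -5, well within the budget.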


# SimplePIR

We now have all the tools to build a very simple [PIR scheme](https://eprint.iacr.org/2022/949). Before building real PIR, let's build a fake PIR, in which the query is sent in the clear, to develop intuition for how a real PIR scheme works.

```python
# Structure the db as a 2-dimensional matrix.
db = random_matrix(50, 10, 10)

# let's query the value at row = 5 and col = 5
q_r = 5
q_c = 5

# create a one-hot vector that is 1 at index q_c (i.e. the column in which the desired value lives) and 0 otherwise
qu = zero_matrix(50, 10, 1)
qu.set_at(q_c, 0, 1)

# compute db * qu
res = db * qu

# res is the `q_c`th column, from which you can read the `q_r`th row to find the desired value
assert(res.get_at(q_r, 0) == db.get_at(q_r, q_c))
```
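The fake PIR above addresses values by (row, col). If the database starts as a flat list of N records, one way to get there is a square, row-major layout, which is also how `Matrix` stores its entries. Below is a minimal sketch with the hypothetical helpers `square_layout` and `index_to_coords`, padding with zeros:

```python
import math

def square_layout(records):
    # pad a flat list of N records with zeros into a side x side
    # row-major matrix (list of rows), side = ceil(sqrt(N))
    side = math.isqrt(len(records))
    if side * side < len(records):
        side += 1
    padded = records + [0] * (side * side - len(records))
    return [padded[r * side:(r + 1) * side] for r in range(side)]

def index_to_coords(i, side):
    # row-major: record i lives at row i // side, column i % side
    return divmod(i, side)

rows = square_layout([11, 22, 33, 44, 55])
print(rows, index_to_coords(4, len(rows)))
```

A client that wants record i would put the 1 of its query vector at column `i % side` and read row `i // side` of the answer. In the encrypted version that follows, the number of columns must also equal the LWE parameter m.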


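Idea for real PIR: do the same thing, but this time encrypt the query vector using secret key Regev encryption. Multiplying db by both components of the encrypted query is exactly the plaintext inner product from above, applied once per row of db.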
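Write $D$ for the database matrix (`db` in the code) and $u$ for the one-hot query vector (`qu`); both symbols are introduced here. The encrypted query is $(B, A)$ with $B = As + e + \Delta u$. The server returns $(DB, DA)$, and decryption computes:

$$
\begin{aligned}
DB - (DA)s &= D(As + e + \Delta u) - DAs \\
&= \Delta\, Du + De \pmod q
\end{aligned}
$$

Since $u$ is 1 only at index `q_c`, $Du$ is column `q_c` of $D$. Each entry of $De$ is the inner product of one database row with $e$, so the noise condition from the inner-product section applies row by row.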

```python
# LWE parameters
n = 10
m = 100
p = 4
bound = 2
q = 1000

# The number of db columns must match `m`. Moreover, the LWE parameters must be chosen so that the inner products decrypt correctly.
# For simplicity we make db a square matrix
db = random_matrix(p, m, m)

# let's query the value at row = 5 and col = 5
q_r = 5
q_c = 5

# create the one-hot query vector
qu = zero_matrix(q, m, 1)
qu.set_at(q_c, 0, 1)

# encrypt the query vector
s = random_matrix(p, n, 1)
qu_scaled = qu.scale(q, p, q)
A = random_matrix(q, m, n)
e = sample_error_matrix(bound, q, m, 1)
B = (A * s) + e + qu_scaled
# encrypted query c_qu
c_qu = (B, A)

# Compute db * c_qu
# change db from modulo p to modulo q
db = db.scale(1, 1, q)
c_res = (db * c_qu[0], db * c_qu[1])

# decrypt c_res to recover the `q_c`th column
B = c_res[0]
A = c_res[1]
m_r = B - (A * s)
# scale message by 1/delta
m_r = m_r.scale(p, q, p)

# yay! We retrieved the correct value using an encrypted query. SimplePIR works!
assert(m_r.get_at(q_r, 0) == db.get_at(q_r, q_c))
```
