# Key switch and Mod Switch

In [9]:
import numpy as np
from note_include.elem.Ring  import Ring
from note_include.elem.LWE   import LWE
from note_include.elem.RLWE  import RLWE
from note_include.elem.RLWEp import RLWEp
from note_include.elem.RGSW  import RGSW
from note_include.utils.types import RGSWctxt, RLWEctxt, RLWEpctxt, LWEctxt

# Print arrays in full: threshold=inf disables the "..." summarising of long
# arrays and linewidth=inf keeps each array on a single line.
np.set_printoptions(threshold=np.inf, linewidth=np.inf)



In [10]:
# Parameters shared by the cells below.
# q    = 7681  # alternative modulus; it used to be assigned here and was overwritten on the next line
q    = 2048    # modulus passed to the LWE / RLWE / RGSW constructors
n    = 16      # dimension passed to the LWE / RLWE / RGSW constructors
# Q    = 12289 # For N = 1024
# N    = 1024  
B    = 2       # decomposition base handed to RGSW
B_ks = 256     # decomposition base handed to KSKgen / KeySwitch
# Number of base-B digits handed to RGSW.
# NOTE(review): the ratio is computed in floating point, so when q is an exact
# power of B the ceil may overshoot by one digit -- confirm if that matters.
ddd  = int(np.ceil(np.log(q) / np.log(B)))
std  = 3.2     # third constructor argument; presumably the noise standard deviation

In [11]:
def crange(coeffs, q):
    """
        Shift residues into the centred range around zero.

        Entries that already lie in [0, q//2] are returned unchanged;
        every other entry has q subtracted from it.
        `coeffs` must support element-wise comparison (e.g. a numpy array).
    """
    upper = q // 2
    keep = (coeffs >= 0) & (coeffs <= upper)
    return np.where(keep, coeffs, coeffs - q)

---
## KeySwitch

In [12]:
# Noise-free run (std = 0) so the printed values can be checked exactly.
std = 0

lwe  = LWE(n, q, std)
rlwe = RLWE(n, q, std)
rgsw = RGSW(n, q, std, B, ddd)
s, _ = rlwe.keygen()

# RGSW-encrypt the monomial x^(n/2 - u).
monomial = np.zeros(n)
u = 5
monomial[(n//2) - u] = 1
monomial = Ring(n, q, monomial)

rgsw_key = rgsw.encrypt(monomial, s)

# Accumulator polynomial: 0, 1, ..., n/2 - 1 followed by 0, -1, ..., -(n/2 - 1).
# The half-length is derived from n instead of being hard-coded as 8, so the
# polynomial always has exactly n coefficients.
m1 = [i for i in range(n // 2)]
m2 = [-i for i in range(n // 2)]
m  = np.concatenate((m1, m2))

poly_acc  = Ring(n, q, m)
poly_zero = Ring(n, q, np.zeros(n))

# (poly_zero, poly_acc) is used as an RLWE ciphertext whose first component is zero.
result = rgsw.mult_rlwe((poly_zero, poly_acc), rgsw_key)
print(result)

a, b = result

# Decrypt the full polynomial first: the lines below overwrite a.coeffs in place.
ptxt = rlwe.decrypt(result, s)
print(crange(np.array(ptxt.coeffs), q))

# Turn the RLWE ciphertext into an LWE ciphertext of its constant term:
# the mask becomes (a_0, -a_{n-1}, ..., -a_1) and the body is b's constant coefficient.
# Note: this mutates `a` (and therefore `result`) and `s` in place and rebinds `b`.
a.coeffs = [a.coeffs[0]] + [-x for x in reversed(a.coeffs[1:])]
a.coeffs = np.array(a.coeffs, dtype=int)
s.coeffs = np.array(s.coeffs, dtype=int)
b = int(b.coeffs[0])

msg = lwe.decrypt((a.coeffs, b), s.coeffs)
print(msg)

[R(n=16, q=2048, coeffs= 100.0 + 601.0x + 1206.0x^2 + 1212.0x^3 + 980.0x^4 + 1066.0x^5 + 114.0x^6 + 729.0x^7 + 156.0x^8 + 1868.0x^9 + 622.0x^10 + 113.0x^11 + 6.0x^12 + 1876.0x^13 + 47.0x^14 + 960.0x^15), R(n=16, q=2048, coeffs= 876.0 + 1179.0x + 210.0x^2 + 1464.0x^3 + 693.0x^4 + 857.0x^5 + 223.0x^6 + 224.0x^7 + 990.0x^8 + 1516.0x^9 + 1128.0x^10 + 661.0x^11 + 1464.0x^12 + 978.0x^13 + 815.0x^14 + 1508.0x^15)]
[ 5.  6.  7.  0.  1.  2.  3.  4.  5.  6.  7.  0. -1. -2. -3. -4.]
5


In [13]:
def KSKgen(lwe_sk:list[int], rlwe_sk:Ring, B:int, lweCC:LWE):
    """
    Build the key-switching table for the coefficients of `rlwe_sk`.

    lwe_sk  : key handed to lweCC.encrypt for every table entry.
    rlwe_sk : ring element; its modulus `q` and its `coeffs` are read.
    B       : decomposition base.
    lweCC   : object providing encrypt(message, key).

    Returns a nested list indexed as table[i][v][j], where entry (i, v, j) is
    lweCC.encrypt((v * B**j * rlwe_sk.coeffs[i]) % q, lwe_sk), with
    v in range(B) and j in range(ceil(log(q) / log(B))).
    """
    modulus = rlwe_sk.q
    num_digits = int(np.ceil(np.log(modulus) / np.log(B)))

    # Same evaluation order as three nested loops: coefficient, digit value, digit position.
    return [
        [
            [
                lweCC.encrypt((digit * (B**power) * coeff) % modulus, lwe_sk)
                for power in range(num_digits)
            ]
            for digit in range(B)
        ]
        for coeff in rlwe_sk.coeffs
    ]

def KeySwitch(ctxt:RLWEctxt, ksk:list[list[LWEctxt]], B:int, lweCC:LWE) -> LWEctxt:
    """
    Extract the constant-term LWE ciphertext from an RLWE ciphertext and
    subtract the matching key-switching-table entries from it.

    ctxt  : pair (a, b) of ring elements; a.q, a.coeffs and b.coeffs[0] are read.
    ksk   : table indexed as ksk[i][v][j] (coefficient index, digit value,
            digit position); each entry is an LWE ciphertext (mask, body).
    B     : decomposition base; must be the base the table was built with.
    lweCC : object providing sub(ct1, ct2) on LWE ciphertexts.

    Returns the accumulated LWE ciphertext.
    """
    a, b = ctxt
    modulus = a.q
    dd = int(np.ceil(np.log(modulus) / np.log(B)))

    b_0 = int(b.coeffs[0])  # extract constant term

    # Mask of the constant term in the negacyclic ring: (a_0, -a_{n-1}, ..., -a_1).
    # Every entry -- including a_0 -- is reduced into [0, q) so that the digit
    # decomposition below is well defined even for negative (centred) inputs.
    a_coeffs = [a.coeffs[0] % modulus] + [(-x) % modulus for x in reversed(a.coeffs[1:])]

    # Start from the trivial ciphertext (0, b_0). Its mask length is taken from
    # the table entries, so the target LWE dimension may differ from a.n.
    lwe_dim = len(ksk[0][0][0][0])
    tmp_ct = (np.zeros(lwe_dim), b_0)

    for i, _a in enumerate(a_coeffs):
        remaining = int(_a)
        for j in range(dd):
            # Peel off the j-th base-B digit of the i-th mask coefficient.
            remaining, v = divmod(remaining, B)
            tmp_ct = lweCC.sub(tmp_ct, ksk[i][v][j])

    return tmp_ct

In [14]:
# Same experiment as the noise-free cell, now with noise and an explicit key switch.
std = 3.2

lwe  = LWE(n, q, std)
rlwe = RLWE(n, q, std)
rgsw = RGSW(n, q, std, B, ddd)

s_rlwe, _ = rlwe.keygen()
s_lwe     = lwe.keygen()

# Key-switching table from the RLWE key to the LWE key, in base B_ks.
ksk      = KSKgen(s_lwe, s_rlwe, B_ks, lwe)

# Accumulator polynomial: 0, 100, ..., followed by 0, -100, ...
# The half-length is derived from n instead of being hard-coded as 8, so the
# polynomial always has exactly n coefficients.
m1 = [i * 100 for i in range(n // 2)]
m2 = [-i * 100 for i in range(n // 2)]
m  = np.concatenate((m1, m2))

# RGSW-encrypt the monomial x^(n/2 - u).
u = 5
monomial = np.zeros(n)
monomial[(n//2) - u] = 1
monomial = Ring(n, q, monomial)
rgsw_key = rgsw.encrypt(monomial, s_rlwe)

poly_acc  = Ring(n, q, m)
poly_zero = Ring(n, q, np.zeros(n))

# (poly_zero, poly_acc) is used as an RLWE ciphertext whose first component is zero.
result = rgsw.mult_rlwe((poly_zero, poly_acc), rgsw_key)

# Key-switch the constant term to the LWE key and decrypt it with that key.
switched_key_result = KeySwitch(result, ksk, B_ks, lwe)
ptxt = lwe.decrypt(switched_key_result, s_lwe)
print("Ideal result         : ", u * 100)
print("Key switching result : ", ptxt)

# Shape of the key-switching table (uncomment to inspect):
# print(len(ksk))          # s_i
# print(len(ksk[0]))       # v
# print(len(ksk[0][0]))    # j
# print(len(ksk[0][0][0])) # tuple

Ideal result         :  500
Key switching result :  431.0


---
## ModSwitch

Modulus switching introduces quite a lot of noise.

In [26]:
def ModSwitch(ctxt:RLWEctxt, Q:int, q:int) -> RLWEctxt:
    """
    Rescale both components of a ciphertext from modulus Q to modulus q.

    Every coefficient c is replaced by round(c * q / Q) mod q, using np.round
    (which rounds halves to the nearest even value).

    Returns a pair of new Ring objects built with modulus q.
    """
    def _rescaled(poly):
        # Scale, round and reduce one polynomial, coefficient by coefficient.
        scaled = []
        for coef in poly.coeffs:
            scaled.append((np.round(coef*q/Q)) % q)
        return Ring(poly.n, q, scaled)

    mask_poly, body_poly = ctxt
    return (_rescaled(mask_poly), _rescaled(body_poly))

def skModSwitch(sk:Ring, Q:int, q:int) -> Ring:
    """
    Rescale the coefficients of a key polynomial from modulus Q to modulus q.

    Every coefficient c is replaced by round(c * q / Q) mod q, using np.round.
    The dimension of the result is taken from `sk` itself (not from any
    notebook-level variable).

    NOTE(review): rescaling rounds small coefficients towards zero (e.g. 1 -> 0
    when q/Q = 1/2) -- confirm that scaling the key is really what is wanted.

    Returns a new Ring with modulus q.
    """
    s_coeffs = [(np.round(coef*q/Q)) % q for coef in sk.coeffs]

    return Ring(sk.n, q, s_coeffs)

In [42]:
# Modulus-switching demo: encrypt under the larger modulus Q, rescale the
# ciphertext to q with ModSwitch, then decrypt and compare with the input.
std  = 3.2
# Q    = 12289 # For N = 1024
Q    = q * 2 # larger modulus: twice the notebook-wide q
N    = 1024  

rlwe = RLWE(N, Q, std)
s, _ = rlwe.keygen()
msg  = [i for i in range(N)]
poly = Ring(N, Q, msg)
ctxt = rlwe.encrypt(poly, s)

mod_switched_ctxt = ModSwitch(ctxt, Q, q)
# NOTE(review): the ciphertext is now modulo q, but it is decrypted with the
# RLWE instance and key that were created for Q -- confirm decrypt handles this.
# NOTE(review): ModSwitch multiplies every coefficient by q/Q = 1/2, so the
# unscaled message is presumably halved as well -- verify against the output.
result = rlwe.decrypt(mod_switched_ctxt, s)

# Prints all N values of both lists (print options above disable truncation).
print(msg)
print(crange(np.array(result.coeffs), q))



[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221,