## Levels


### parms_id (MS-SEAL-specific)
A 256-bit hash that uniquely identifies a set of encryption parameters. It is printed below as four 64-bit words.
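The parms_id values printed later in this notebook are lists of four unsigned 64-bit integers, and 4 × 64 = 256 bits. The sketch below packs one of them into a single integer. The word order (least-significant word first) is an assumption made only for illustration.

```python
# parms_id of the key level, copied from the In [4] output below
parms_id = [2796926214238341906, 7385196832706630708,
            1778331432907072121, 9574839751679865602]

# Every word fits in an unsigned 64-bit integer
assert all(0 <= w < 2**64 for w in parms_id)

# Assumption: word 0 is the least-significant word (ordering chosen for illustration only)
packed = 0
for i, word in enumerate(parms_id):
    packed |= word << (64 * i)

total_bits = 64 * len(parms_id)
print(total_bits, packed.bit_length() <= total_bits)
```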

### modulus switching chain (MS-SEAL-specific)
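When a SEALContext is created, a 'modulus switching chain' is also generated. This is a chain of parameter sets in which the size of the **coefficient modulus decreases** as you go down the chain.  
The chain is used to track whether a computation is still valid. As the computation proceeds, a ciphertext moves down the chain. The chain ends at the point where a further parameter set would become invalid.  
Each link in the chain has a chain index and its own unique parms_id.  
Convention used here: earlier in the chain = larger chain index = higher level. In other words, level == chain index.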

In [1]:
from fase import seal
import numpy as np

In [2]:
parms = seal.EncryptionParameters(seal.scheme_type.bfv)
poly_modulus_degree = 8192
parms.set_poly_modulus_degree(poly_modulus_degree)
coeff_modulus_bits = [50, 30, 30, 50, 50]
parms.set_coeff_modulus(seal.CoeffModulus.Create(poly_modulus_degree, coeff_modulus_bits))

parms.set_plain_modulus(seal.PlainModulus.Batching(poly_modulus_degree, 20))

context = seal.SEALContext(parms)

In [3]:
print(f"Max bit count = {seal.CoeffModulus.MaxBitCount(8192)} > {np.sum(coeff_modulus_bits)}")

Max bit count = 218 > 210
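The limit printed above comes from SEAL's default 128-bit security table. Below is a minimal sketch of the same check. It assumes the standard 128-bit bounds that SEAL ships for each poly_modulus_degree; the dictionary is hard-coded here rather than queried from the library.

```python
# Assumed: maximum total coeff_modulus bit count for 128-bit security,
# per poly_modulus_degree (values from SEAL's default HE-standard table)
MAX_BITS_128 = {1024: 27, 2048: 54, 4096: 109, 8192: 218, 16384: 438, 32768: 881}

def check_coeff_modulus(poly_modulus_degree, bit_sizes):
    """Return (total bits, allowed bits, whether the choice is within the bound)."""
    total = sum(bit_sizes)
    limit = MAX_BITS_128[poly_modulus_degree]
    return total, limit, total <= limit

print(check_coeff_modulus(8192, [50, 30, 30, 50, 50]))  # (210, 218, True)
print(check_coeff_modulus(8192, [60, 60, 60, 60]))      # (240, 218, False)
```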


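The order of the primes in coeff_modulus matters a great deal. The last prime is the special prime, and each step down the chain drops the last remaining prime.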

              special prime +---------+
                                      |
                                      v
    coeff_modulus: { 50, 30, 30, 50, 50 }  +---+  Level 4 (all keys; `key level')
                                               |
                                               |
        coeff_modulus: { 50, 30, 30, 50 }  +---+  Level 3 (highest `data level')
                                               |
                                               |
            coeff_modulus: { 50, 30, 30 }  +---+  Level 2
                                               |
                                               |
                coeff_modulus: { 50, 30 }  +---+  Level 1
                                               |
                                               |
                    coeff_modulus: { 50 }  +---+  Level 0 (lowest level)

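The special prime is used by the keys (for key switching, e.g. relinearization). It is recommended to make it the largest prime in coeff_modulus, at least as large as any of the others.  
Ciphertexts start one level down (level 3, the highest data level), so they never carry the special prime.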
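The diagram can be regenerated from the bit sizes alone. This sketch assumes, as the diagram shows, that each step down the chain drops the last remaining prime. It ignores the extra validity checks SEAL performs when building the real chain, which can end the chain earlier. The helper name is hypothetical.

```python
def modulus_chain(bit_sizes):
    """Hypothetical helper: list (level, primes) from the key level down to level 0,
    dropping the last prime at each step."""
    chain = []
    primes = list(bit_sizes)
    while primes:
        chain.append((len(primes) - 1, list(primes)))
        primes = primes[:-1]
    return chain

coeff_modulus_bits = [50, 30, 30, 50, 50]
chain = modulus_chain(coeff_modulus_bits)
for level, primes in chain:
    tag = " (key level)" if level == len(coeff_modulus_bits) - 1 else ""
    print(f"Level {level}: {primes}{tag}")

# Recommendation from the text: the special prime (last) is at least as large as the others
special = coeff_modulus_bits[-1]
print("special prime large enough:", special >= max(coeff_modulus_bits[:-1]))
```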

In [4]:
context_data = context.key_context_data()
print("----> Level (chain index): ",context_data.chain_index())
print(" ...... key_context_data()")
print("      parms_id: ", context_data.parms_id())
print("      coeff_modulus primes: ")

----> Level (chain index):  4
 ...... key_context_data()
      parms_id:  [2796926214238341906, 7385196832706630708, 1778331432907072121, 9574839751679865602]
      coeff_modulus primes: 


In [7]:
while context_data:
    print("----> Level (chain index): ",context_data.chain_index())
    if context_data.parms_id() == context.first_parms_id():
        print(".... first_context_data()")
    elif context_data.parms_id() == context.last_parms_id():
        print(".... last_context_data()")
    else:
        print("      parms_id: ", context_data.parms_id())
        print("      coeff_modulus primes: ")
        
    context_data = context_data.next_context_data()
print("End of chain reached")

----> Level (chain index):  4
      parms_id:  [2796926214238341906, 7385196832706630708, 1778331432907072121, 9574839751679865602]
      coeff_modulus primes: 
----> Level (chain index):  3
.... first_context_data()
----> Level (chain index):  2
      parms_id:  [9611362035343820607, 16250750482064473934, 11666325943188447289, 11197906540159540193]
      coeff_modulus primes: 
----> Level (chain index):  1
      parms_id:  [8338375294373729721, 6255090713888968186, 9221042086196212239, 13262945982911515329]
      coeff_modulus primes: 
----> Level (chain index):  0
.... last_context_data()
End of chain reached


In [8]:
keygen = seal.KeyGenerator(context)

secret_key = keygen.secret_key()
public_key = keygen.create_public_key()
relin_keys = keygen.create_relin_keys()

In [9]:
print(secret_key.parms_id(), public_key.parms_id(), relin_keys.parms_id())

[2796926214238341906, 7385196832706630708, 1778331432907072121, 9574839751679865602] [2796926214238341906, 7385196832706630708, 1778331432907072121, 9574839751679865602] [2796926214238341906, 7385196832706630708, 1778331432907072121, 9574839751679865602]


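All three keys report the same parms_id: the one for the key level (level 4; compare In [4]). Each parms_id prints as four 64-bit numbers because it is a 256-bit hash. The count of four is not related to the number of levels.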

In [10]:
encryptor = seal.Encryptor(context, public_key)
evaluator = seal.Evaluator(context)
decryptor = seal.Decryptor(context, secret_key)

In [12]:
ptxt = seal.Plaintext("1x^3 + 2x^2 + 3x^1 + 4") # A plaintext can also be given directly as a polynomial string (coefficients in hex)
enc = encryptor.encrypt(ptxt)

print(ptxt.parms_id(), "Not set in BFV")
print(enc.parms_id())

[0, 0, 0, 0] Not set in BFV
[2386594185047272216, 3177129986462177089, 5264335169243394227, 3608211254040463884]


## Modulus switching

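Modulus switching uses the multi-level parameter chain examined above.
#### Why use modulus switching
It shrinks the ciphertext, which makes subsequent computation faster. Modulus switching matters much more in CKKS than in BFV; see Tutorial 4.

`Evaluator.mod_switch_to_next()` (the in-place variant `mod_switch_to_next_inplace()` is used below)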
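In [14] walks a ciphertext down the chain with `mod_switch_to_next_inplace()`. Below is a library-free toy model of what that call does to the ciphertext's level. The classes and the function are hypothetical stand-ins, not the SEAL API. In SEAL itself, switching past the last level raises an error, and the toy mimics that with `ValueError`.

```python
class ToyContextData:
    """Hypothetical stand-in for one link of the modulus switching chain."""
    def __init__(self, chain_index, primes, next_data=None):
        self.chain_index = chain_index
        self.primes = primes
        self.next_data = next_data

def build_data_chain(data_bits):
    """Link the data levels; returns the first (highest) data level."""
    link = None
    for n in range(1, len(data_bits) + 1):
        # level n-1 keeps the first n primes and points at the level below it
        link = ToyContextData(n - 1, data_bits[:n], link)
    return link

class ToyCiphertext:
    def __init__(self, context_data):
        self.context_data = context_data

def toy_mod_switch_to_next(ct):
    """Move the ciphertext one level down; fail at the end of the chain."""
    nxt = ct.context_data.next_data
    if nxt is None:
        raise ValueError("end of modulus switching chain reached")
    ct.context_data = nxt

# Data levels for coeff_modulus {50, 30, 30, 50, 50}: special prime removed
first = build_data_chain([50, 30, 30, 50])
ct = ToyCiphertext(first)
levels = [ct.context_data.chain_index]
while ct.context_data.next_data is not None:
    toy_mod_switch_to_next(ct)
    levels.append(ct.context_data.chain_index)
print(levels)  # [3, 2, 1, 0]
```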

In [14]:
context_data = context.first_context_data()

while context_data.next_context_data():
    print("----> Level (chain index): ",context_data.chain_index())
    print("      parms_id: ", context_data.parms_id())
    print("      Noise budget at this level", decryptor.invariant_noise_budget(enc))
    
    print("Modulus switching")
    evaluator.mod_switch_to_next_inplace(enc)
    context_data = context_data.next_context_data()
    

print("----> Level (chain index): ",context_data.chain_index())
print("      parms_id: ", context_data.parms_id())
print("      Noise budget at this level", decryptor.invariant_noise_budget(enc))
print("End of chain reached")

----> Level (chain index):  3
      parms_id:  [2386594185047272216, 3177129986462177089, 5264335169243394227, 3608211254040463884]
      Noise budget at this level 132
Modulus switching
----> Level (chain index):  2
      parms_id:  [9611362035343820607, 16250750482064473934, 11666325943188447289, 11197906540159540193]
      Noise budget at this level 82
Modulus switching
----> Level (chain index):  1
      parms_id:  [8338375294373729721, 6255090713888968186, 9221042086196212239, 13262945982911515329]
      Noise budget at this level 52
Modulus switching
----> Level (chain index):  0
      parms_id:  [12645946865612918007, 3410116064097512307, 265341692546732382, 11895432757488484810]
      Noise budget at this level 22
End of chain reached


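The noise budget dropped step by step, by the bit size of each dropped prime (50, 30, 30). That looks like a pure loss. However, the size of a ciphertext depends linearly on the number of primes in the coefficient modulus, so each switch makes the ciphertext correspondingly smaller. If no further computation is needed, it makes sense to switch a ciphertext down to the lowest level before sending it back for decryption.   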
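The numbers in In [14] line up with the diagram. Each switch removed the last data-level prime (50, then 30, then 30 bits), and for this fresh ciphertext the budget fell by exactly that many bits. The sketch below checks this against the printed budgets. It also computes the ciphertext size, assuming a 2-polynomial ciphertext that stores one 64-bit word per coefficient per prime (SEAL's RNS layout).

```python
poly_modulus_degree = 8192
data_primes = [50, 30, 30, 50]          # level 3 primes (special prime excluded)
observed_budgets = [132, 82, 52, 22]    # noise budgets printed in In [14], levels 3..0

# Each mod switch drops the last remaining prime: 50, then 30, then 30 bits
dropped_bits = [data_primes[-1 - i] for i in range(len(observed_budgets) - 1)]
budget_drops = [a - b for a, b in zip(observed_budgets, observed_budgets[1:])]
print(dropped_bits, budget_drops)  # [50, 30, 30] [50, 30, 30]

def ciphertext_bytes(num_primes, size=2, n=poly_modulus_degree):
    """Assumed layout: size polynomials x n coefficients x num_primes 64-bit words."""
    return size * n * num_primes * 8

for level in range(len(data_primes) - 1, -1, -1):
    print(f"level {level}: {ciphertext_bytes(level + 1)} bytes")
```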

In [19]:
enc_x = encryptor.encrypt(ptxt)
print("Compute the 8th power")
print("Noise budget fresh:", decryptor.invariant_noise_budget(enc_x))
evaluator.square_inplace(enc_x)
evaluator.relinearize_inplace(enc_x, relin_keys)
print("Noise budget of the 2nd power:", decryptor.invariant_noise_budget(enc_x))
evaluator.square_inplace(enc_x)
evaluator.relinearize_inplace(enc_x, relin_keys)
print("Noise budget of the 4th power:", decryptor.invariant_noise_budget(enc_x))

Compute the 8th power
Noise budget fresh: 132
Noise budget of the 2nd power: 100
Noise budget of the 4th power: 67


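After a few computations, modulus switching no longer lowers the noise budget (67 -> 67 in In [20] below).  
Q: what exactly counts as 'a few' computations? Roughly: switching divides both the modulus and the existing noise by the dropped prime, but adds a small rounding error. For a fresh ciphertext the rounding error dominates, so the budget drops by about the bit size of the dropped prime. Once the accumulated noise is much larger than the rounding error, the budget stays essentially unchanged.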
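A toy model makes the answer concrete. Assume, purely for illustration, that the budget is log2(q) minus log2(noise). Also assume that switching down by a prime p divides both q and the noise by p but adds a fixed rounding error. The rounding error is chosen to be about the size of fresh noise. This is an assumption tuned to mimic the numbers above, not a SEAL formula. Under these assumptions, a fresh ciphertext loses roughly the full 50 bits, while a ciphertext whose noise has already grown loses almost nothing.

```python
import math

ROUNDING_NOISE = 2.0**28  # toy assumption: rounding error added by each switch

def budget(q_bits, noise):
    """Toy invariant: remaining budget = log2(q) - log2(noise)."""
    return q_bits - math.log2(noise)

def toy_mod_switch(q_bits, noise, p_bits):
    """Divide modulus and noise by the dropped prime p, then add rounding error."""
    return q_bits - p_bits, noise / 2.0**p_bits + ROUNDING_NOISE

q_bits = 160  # data-level modulus: 50 + 30 + 30 + 50 bits

# Fresh ciphertext: small noise (budget 132), so the rounding error dominates
fresh = 2.0**28
q2, n2 = toy_mod_switch(q_bits, fresh, 50)
print(round(budget(q_bits, fresh)), round(budget(q2, n2)))  # 132 82

# After some multiplications: large noise (budget 67), rounding error is negligible
grown = 2.0**93
q3, n3 = toy_mod_switch(q_bits, grown, 50)
print(round(budget(q_bits, grown)), round(budget(q3, n3)))  # 67 67
```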

In [20]:
evaluator.mod_switch_to_next_inplace(enc_x)
print("Noise budget after modulus switching:", decryptor.invariant_noise_budget(enc_x))

Noise budget after modulus switching: 67


In [21]:
evaluator.square_inplace(enc_x)
evaluator.relinearize_inplace(enc_x, relin_keys)
print("Noise budget of the 8th power", decryptor.invariant_noise_budget(enc_x))
evaluator.mod_switch_to_next_inplace(enc_x)
print("Noise budget after modulus switching", decryptor.invariant_noise_budget(enc_x))


Noise budget of the 8th power 34
Noise budget after modulus switching 34


In [22]:
recovered = decryptor.decrypt(enc_x)
print("result (hex):", recovered.to_string())
# Presumably correct (not checked in this cell)

result (hex): 1x^24 + 10x^23 + 88x^22 + 330x^21 + EFCx^20 + 3A30x^19 + C0B8x^18 + 22BB0x^17 + 58666x^16 + C88D0x^15 + 9C377x^14 + F4C0Ex^13 + E8B38x^12 + 5EE89x^11 + F8BFFx^10 + 30304x^9 + 5B9D4x^8 + 12653x^7 + 4DFB5x^6 + 879F8x^5 + 825FBx^4 + F1FFEx^3 + 3FFFFx^2 + 60000x^1 + 10000
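The result can be checked without SEAL. Three squarings compute (x^3 + 2x^2 + 3x + 4)^8. Its degree (24) is far below poly_modulus_degree, so no x^8192 + 1 wrap-around occurs, and only the coefficients need to be reduced mod the plain modulus. The sketch assumes the plain modulus picked by `PlainModulus.Batching(8192, 20)` here is 1032193. This value is inferred from the printed coefficients; the notebook does not print it. The sketch formats coefficients in uppercase hex in the same style as the output above.

```python
T = 1032193  # assumed plain modulus (20-bit prime from PlainModulus.Batching(8192, 20))

def poly_mul(a, b, t):
    """Multiply coefficient lists (a[i] is the coefficient of x^i) mod t."""
    out = [0] * (len(a) + len(b) - 1)
    for i, ai in enumerate(a):
        for j, bj in enumerate(b):
            out[i + j] = (out[i + j] + ai * bj) % t
    return out

def to_seal_hex(coeffs):
    """Format like '1x^3 + 2x^2 + 3x^1 + 4': highest degree first, hex, zero terms skipped."""
    terms = []
    for deg in range(len(coeffs) - 1, -1, -1):
        c = coeffs[deg]
        if c == 0:
            continue
        terms.append(f"{c:X}" if deg == 0 else f"{c:X}x^{deg}")
    return " + ".join(terms)

x = [4, 3, 2, 1]  # 1x^3 + 2x^2 + 3x^1 + 4, lowest degree first
for _ in range(3):  # square three times, as in In [19]-[21]: x -> x^2 -> x^4 -> x^8
    x = poly_mul(x, x, T)
print(to_seal_hex(x))
```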


Finally, to skip building the modulus switching chain, pass `False` as the second argument (SEAL's `expand_mod_chain` flag):  
`context = seal.SEALContext(parms, False)`


The end.