In [1]:
%load_ext autoreload
%autoreload 2

In [1]:
import numpy as np
from math import sqrt, ceil

from hemul.cipher import *
from hemul.ciphertext import Plaintext
from hemul.scheme import *
from hemul.algorithms import Algorithms 
from hemul.context import set_all

In [2]:
context, ev, encoder, encryptor, decryptor = set_all(30, 450, 8)
nslots = context.params.nslots
coeff_modulus = [60,30,30,30,30,30,60]


근데 이게.. SEAL-specific한가? HEAAN에는 적용 X. OpenFHE에는?   
HECATE에서 다루는 문제도 비슷한데, HECATE는 arbitrary scale로 내리는 계산을 따로 만듦 (Downscale).   

> 중요한건 문제의 타입과 문제 해결에 필요한 정보




### Example 1) 기존 코드


In [3]:
x = np.array([1,2,3,4,5,6,7,8])
y = np.array([9,8,7,6,5,4,3,2])

def fun(x,y):
    return (x**2 + y**2)**3

print(fun(x,y))

[551368 314432 195112 140608 125000 140608 195112 314432]


### Example 1) FHE 버전

```Python
x = encryptor.encrypt([1,2,3,4,5,6,7,8])
y = encryptor.encrypt([9,8,7,6,5,4,3,2])

@compile_fhe
def fun(x:Ciphertext, y:Ciphertext):
    return (x**2 + y**2)**3
```

In [4]:
x = encryptor.encrypt([1,2,3,4,5,6,7,8])

def fun_v1(x):
    x2 = ev.square(x)
    ev.rescale_next(x2)
    ev.mod_down_to(x, x2.logq)
    return ev.add(x, x2)

def fun_v2(x):
    x2 = ev.square(x)
    p = Plaintext([1,1,1,1,1,1,1,1], logp=x.logp, logn=x.logn)
    x = ev.mult_by_plain(x, p)
    return ev.add(x, x2)


In [62]:
# 결과 확인
ev._counter.reset()
res1 = fun_v1(x)


In [63]:
cnt1 = ev._counter.get()
cnt1

{'multp': [],
 'multc': [{'logq': 420, 'logp': 30, 'logn': 8, 'ntt': 1}],
 'rot': [],
 'bootstrap': [],
 'mod_switch': [{'logq': 390, 'logp': 30, 'logn': 8, 'ntt': 1}],
 'rescale': [{'logq': 390, 'logp': 30, 'logn': 8, 'ntt': 0}],
 'ntt_switch': []}

In [64]:
ev._counter.reset()
res2 = fun_v2(x)
cnt2 = ev._counter.get()

In [65]:
cnt2

{'multp': [{'logq': 390, 'logp': 30, 'logn': 8, 'ntt': 1}],
 'multc': [{'logq': 390, 'logp': 30, 'logn': 8, 'ntt': 1}],
 'rot': [],
 'bootstrap': [],
 'mod_switch': [],
 'rescale': [],
 'ntt_switch': []}

In [68]:
from hemul.utils import Call_counter
#import sympy as sp 
#p,q,n = sp.symbols('p,q,n')

cost_functions = {
    "multp": lambda p,q,n: 3*q*n, # p is ignored, but to have a uniform interface.
    "multc": lambda p,q,n: p*q*n,
    "rot": lambda p,q,n: 0.5*p*q*n,
    "bootstrap": lambda p,q,n: p**1.5*q*n,  
    "mod_switch": lambda p,q,n: 5*q*n,
    "rescale": lambda p,q,n: 6*q*n,
    "ntt_switch": lambda p,q,n: 7*q*n,
    }

class FHECostEvaluator():
    """Total cost evaluator with arbitrary cost functions"""
    def __init__(self, cost_functions):
        self.cfs = cost_functions

    def total_cost(self, counter:Call_counter):
        summary = []
        tot_cost = 0
        for key in counter.keys():
            this_cost = self.eval_sum(key, counter.get(key))
            summary.append((key, this_cost))
            tot_cost += this_cost
            #print(f"{key} cost: {this_cost}")
        
        return tot_cost, summary

    def eval(self, op, cnt):
        return self.cfs[op](cnt['logp'], cnt['logq'], cnt['logn'])

    def eval_sum(self, op, cnt_list):
        return np.sum([self.eval(op, cnt) for cnt in cnt_list])


In [69]:
CEV = FHECostEvaluator(cost_functions)

In [70]:
tot_cnt1, summary1 = CEV.total_cost(cnt1)
tot_cnt2, summary2 = CEV.total_cost(cnt2)

In [71]:
tot_cnt1 

135120.0

In [72]:
tot_cnt2

102960.0

### Example 2 loop 돌아서 곱하기 많은 케이스

In [6]:
def newton_raphson_inv(a, number=1e-6, n_iters = 20):
    for i in range(n_iters): 
        q = (2-a*number)
        number = number*q
    return number

In [9]:
algo = Algorithms(ev, encoder)

# optimize 덜 된 버전
def newton_raphson_inv_fhe(ctxt, number = 1e-4, n_iters = 20):
    
    two = algo.encode_repeat(2) # [2,2,2,2,2,2,2,2,2,...]
    number = algo.encode_repeat(number)    
    
    q_ = ev.mult_by_plain(ctxt, number, inplace=False)
    ev.rescale_next(q_)
    sub_ = ev.add_plain(q_, two, inplace=False)
    number_ = ev.mult_by_plain(sub_, number, inplace=False)
    ev.rescale_next(number_)

    for i in range(1, n_iters):
        print(i, "1", number_.logp, number_.logq)
        tmp = ev.negate(number_, inplace=False)
        print(i, "1-1", tmp.logp, tmp.logq)
        print(i, "1-1-1", ctxt.logp, ctxt.logq)
        ev.match_mod(ctxt, tmp)
        q_ = ev.mult(ctxt, tmp, inplace=False)
        print(i, "1-2", q_.logp, q_.logq)
        ev.rescale_next(q_)
        
        sub_ = ev.add_plain(q_, two, inplace=False)
        ev.match_mod(number_, sub_)
        print(i, "2", number_.logp, number_.logq)
        ev.mult(number_, sub_, inplace=True)
        ev.rescale_next(number_)
        print(i, "3", number_.logp, number_.logq)
        if number_.logq < 2*number_.logp:
            ev.bootstrap(number_)
            print("Bootstrapping...")
    return number_

### 문제: scale이 바뀌는 모든 method 이전에 logq를 체크하지 않으면 자동으로 bootstrapping할 방법이 없음. 
1. 원래 OpenFHE도 그렇게 하는 걸까?
2. 컴파일러로 하려면 SSA인 상태로 변수를 계속 따라가는 것? 
3. 컴파일러로 하면 장점:
   1. 정해진 logq에 따라 bootstrapping을 하는 것이 아니고 bootstrapping 개수를 최적화하는 적당한 logq를 정할 수 있음 

In [11]:
#  컴파일러가 필요하겠어...! 
def newton_raphson_inv_fhe(ctxt, number = 1e-4, n_iters = 20):
    two = algo.encode_repeat(2) # [2,2,2,2,2,2,2,2,2,...]
    number = algo.encode_repeat(number)    
    
    q_ = ev.mult_by_plain(ctxt, number, inplace=False)
    ev.rescale_next(q_)
    sub_ = ev.add_plain(q_, two, inplace=False)
    number_ = ev.mult_by_plain(sub_, number, inplace=False)
    ev.rescale_next(number_)

    for i in range(1, n_iters):
        print(i, "1", number_.logp, number_.logq)
        tmp = ev.negate(number_, inplace=False)
        print(i, "1-1", tmp.logp, tmp.logq)
        print(i, "1-1-1", ctxt.logp, ctxt.logq)
        ctxt_ = ev.copy(ctxt)
        ev.match_mod(ctxt_, tmp)
        q_ = ev.mult(ctxt_, tmp, inplace=False)
        print(i, "1-2", q_.logp, q_.logq)
        ev.rescale_next(q_)
        
        sub_ = ev.add_plain(q_, two, inplace=False)
        ev.match_mod(number_, sub_)
        print(i, "2", number_.logp, number_.logq)
        ev.mult(number_, sub_, inplace=True)
        ev.rescale_next(number_)
        print(i, "3", number_.logp, number_.logq)
        if number_.logq < 2*number_.logp:
            ev.bootstrap(number_)
            print("Bootstrapping...")
    return number_

In [12]:
xarr = np.array([10, 15, 20, 25, 30, 35, 40, 45, 50, 55, 60])
result = newton_raphson_inv(xarr)

ctxt = encryptor.encrypt(xarr)
result_he = newton_raphson_inv_fhe(ctxt)

1 1 30 390
1 1-1 30 390
1 1-1-1 30 450
1 1-2 60 390
1 2 30 360
1 3 30 330
2 1 30 330
2 1-1 30 330
2 1-1-1 30 450
2 1-2 60 330
2 2 30 300
2 3 30 270
3 1 30 270
3 1-1 30 270
3 1-1-1 30 450
3 1-2 60 270
3 2 30 240
3 3 30 210
4 1 30 210
4 1-1 30 210
4 1-1-1 30 450
4 1-2 60 210
4 2 30 180
4 3 30 150
5 1 30 150
5 1-1 30 150
5 1-1-1 30 450
5 1-2 60 150
5 2 30 120
5 3 30 90
6 1 30 90
6 1-1 30 90
6 1-1-1 30 450
6 1-2 60 90
6 2 30 60
6 3 30 30
Bootstrapping...
7 1 30 420
7 1-1 30 420
7 1-1-1 30 450
7 1-2 60 420
7 2 30 390
7 3 30 360
8 1 30 360
8 1-1 30 360
8 1-1-1 30 450
8 1-2 60 360
8 2 30 330
8 3 30 300
9 1 30 300
9 1-1 30 300
9 1-1-1 30 450
9 1-2 60 300
9 2 30 270
9 3 30 240
10 1 30 240
10 1-1 30 240
10 1-1-1 30 450
10 1-2 60 240
10 2 30 210
10 3 30 180
11 1 30 180
11 1-1 30 180
11 1-1-1 30 450
11 1-2 60 180
11 2 30 150
11 3 30 120
12 1 30 120
12 1-1 30 120
12 1-1-1 30 450
12 1-2 60 120
12 2 30 90
12 3 30 60
13 1 30 60
13 1-1 30 60
13 1-1-1 30 450
13 1-2 60 60
13 2 30 30


AssertionError: no more noise budget! do bootstrapping

In [11]:
np.isclose(result_he._arr[:8], result[:8], atol=1e-5)

array([ True,  True,  True,  True,  True,  True,  True,  True])

### 3. rotation과 NTT 변환 억제

matrix 곱하기? 