In [1]:
%load_ext autoreload
%autoreload 2

In [17]:
import numpy as np
import pandas as pd

from hemul.cipher import *
from hemul.ciphertext import Plaintext
from hemul.scheme import *

from hemul.algorithms import Algorithms 
from hemul.stats import Statistics
from hemul.context import set_all

from math import sqrt
from numpy import polynomial as P
def _approx_sign(n):
    """
    Approxiate sign function in [-1,1]
    """
    p_t1 = P.Polynomial([1,0,-1])
    p_x  = P.Polynomial([0,1])

    def c_(i: int):
        return 1/4**i * math.comb(2*i,i)

    def term_(i: int):
        return c_(i) * p_x * p_t1**i

    poly = term_(0)
    for nn in range(1,n+1):
        poly += term_(nn)
    return poly

def approx_comp(x,v):
    return np.array(x > v).astype(float)

def get_bincount(arr, nnn, bmin, bmax):
    n_larger = [nnn]
    for i, bb in enumerate(range(bmin,bmax+1)):
        n_larger.append(np.sum(approx_comp(arr,bb+0.5)))

    bincount = np.array(n_larger)[:-1] - np.array(n_larger)[1:]
    return bincount

In [3]:
context, ev, encoder, encryptor, decryptor = set_all(30, 150, 12)
nslots = context.params.nslots

### Load data

빈 슬롯 == 0 이므로, category encoding할 때 1부터 시작해야함.   
Yes / No 라고 1/0으로 하지 말 것. 

In [4]:
data = pd.read_csv("./stat/WA_Fn-UseC_-HR-Employee-Attrition.csv")
new_label = {"Attrition": {"Yes":1, "No":0}}
data.replace(new_label , inplace = True)

ng1 = data.Attrition.nunique()
ng2 = data.JobSatisfaction.nunique()

ntot = len(data)
print(f"Enough slots? {ntot} < {context.params.nslots}")
V1 = encryptor.encrypt(data.Attrition.values+1) 
V2 = encryptor.encrypt(data.JobSatisfaction.values)

Enough slots? 1470 < 4096


In [5]:
bc1 = get_bincount(V1._arr, ntot, 1, 2)
bc2 = get_bincount(V2._arr, ntot, 1, 4)

print(bc1, bc2)

[1233.  237.] [289. 280. 442. 459.]


category 조합을 4 * 2 = 8 가지의 int로 표현

1. V1 = [0,1];  V2 = [0,1,2,3] 일 경우,  
4*v1 + v2 = [0,...,7]  :  empty 값인 0과 구별 어려움

2. V1 = [0,1];  V2 = [1,2,3,4] 일 경우,  
4*v1 + v2 = [1,...,8]  :  결과는 1 ~ 8로 딱 좋으나, V1 bincount때 여전히 0과 구별 어려움. 

3. V1 = [1,2];  V2 = [1,2,3,4] 일 경우,  
  4*v1 + v2 = [5,...,12]  : 분류 가능. 대신 bin의 min과 max를 따로 제공해야함. 


In [6]:
p_ng2 = Plaintext(np.repeat(ng2, ntot), nslots=nslots, logp=V1.logp)
ev.mult_by_plain(V1,p_ng2, inplace=True) ## V1 inplace로 변경됨. 
ev.rescale_next(V1)
print(V1._arr)
print(V2._arr)
cat_comb = ev.add(V1, V2)

[8. 4. 8. ... 0. 0. 0.]
[4. 2. 3. ... 0. 0. 0.]


각 category에 해당하는 값을 count 

대소비교

In [54]:
algo = Algorithms(ev, encoder)

if 불가하니 bin 갯수만큼의 질문을 반복하는 것 외에 다른 방법은 없는 듯 

In [59]:
contin = encryptor.encrypt(get_bincount(cat_comb._arr, ntot, 5, 12))
print(contin._arr[:10])

[223. 234. 369. 407.  66.  46.  73.  52.   0.   0.]


[a,b,c,d] * [A,B]

```
[a,b,c,d,#,#,#,#]  
        *  
[A,A,A,A,0,0,0,0]   
        +  
[#,#,#,#,a,b,c,d]  
        *  
[0,0,0,0,B,B,B,B]  
```

### col for multiplication

In [58]:
contin._arr[:10]

array([223., 234., 369., 407.,  66.,  46.,  73.,  52.,   0.,   0.])

## m x n array. 

[[a,b,c,d],
 [e,f,g,h],
 [i,j,k,l],
 ...]
 
최종적으로 필요한 것은 [a+e+i+..., b+f+j+..., ...] * n



```
  A  B  C  A  B  C  
+    A  B  C  A  B  C  
+       A  B  C  A  B  C  
==
     [a1+c1+b1, a2+b2+c2, a3+b3+c3]
```       
          

co



In [73]:
ctxt = ev.copy(contin)

for ii in range(1, ng1):
    ev.lrot(ctxt, ng2*ii, inplace=True)
    col_sum = ev.add(contin, ctxt) 
    print(col_sum._arr[:10])

mask_c = np.zeros(nslots)
mask_c[:ng2] =1 # Only first ng2 slots are valid
ev.mult_by_plain(col_sum, encoder.encode(mask_c), inplace=True)
col_tmp = ev.lrot(col_sum, -ng2, inplace=False)
ev.add(col_sum, col_tmp, inplace=True)

[289. 280. 442. 459.  66.  46.  73.  52.   0.   0.]
[223. 234. 369. 407.  66.  46.  73.  52.   0.   0.]


In [72]:
ctxt = ev.copy(contin)

ev.lrot(ctxt, ng2, inplace=True)
col_sum = ev.add(contin, ctxt) 
print(col_sum._arr[:10])

mask_c = np.zeros(nslots)
mask_c[:ng2] =1 # Only first ng2 slots are valid
ev.mult_by_plain(col_sum, encoder.encode(mask_c), inplace=True)
col_tmp = ev.lrot(col_sum, -ng2, inplace=False)
ev.add(col_sum, col_tmp, inplace=True)

[66. 46. 73. 52.  0.  0.  0.  0.  0.  0.]
[289. 280. 442. 459.  66.  46.  73.  52.   0.   0.]


In [64]:
col_sum._arr[:10]

array([289., 280., 442., 459., 289., 280., 442., 459.,   0.,   0.])

### row for multiplication

In [65]:
def _row_for_mult(row_sum, ntot, ng1, ng2, nslots, debug=False):
    # ng1 < ng2
    each_element=[]
    for i in range(ng1):
        mask = np.zeros(nslots)
        mask[i*ng2] = 1/ntot
        row_tmp = ev.mult_by_plain(row_sum, encoder.encode(mask), inplace=False)
        each_element.append(row_tmp)

    # check
    if debug:
        for ee in each_element:
            print(ee._arr[:10])

    result = each_element[0]
    for ee in each_element[1:]:
        ev.add(result, ee, inplace=True)

    # check
    if debug: print(result._arr[:10])

    for i in range(int(sqrt(ng2))):
        r_tmp = ev.copy(result)
        ev.lrot(r_tmp, -2**i)
        ev.add(result, r_tmp, inplace=True)
        # check
        if debug: print(result._arr[:10])
    return result

In [66]:
row_sum = algo.sum_reduce(contin, nsum=ng2) # 0-th and 4-th slots are valid
print(row_sum._arr[:10])
print(row_sum._arr[0], row_sum._arr[ng2])

[1233. 1076.  888.  592.  237.  171.  125.   52.    0.    0.]
1233.0 237.0


In [67]:
row_mult = _row_for_mult(row_sum, ntot, ng1, ng2, nslots)

In [68]:
ev.rescale_next(row_mult)
ev.rescale_next(col_sum)
E = ev.mult(row_mult, col_sum)
print(E.logp)
ev.rescale_next(E)

print(E._arr[:10])
print(contin._arr[:10])

60
[242.40612245 234.85714286 370.73877551 384.99795918  46.59387755
  45.14285714  71.26122449  74.00204082   0.           0.        ]
[223. 234. 369. 407.  66.  46.  73.  52.   0.   0.]


In [69]:
E._arr[:10]

array([242.40612245, 234.85714286, 370.73877551, 384.99795918,
        46.59387755,  45.14285714,  71.26122449,  74.00204082,
         0.        ,   0.        ])

In [70]:
#chi2
sub = ev.sub(contin, E, inplace=False)
sqr = ev.square(sub, inplace=False)
div = algo.divide(sqr,E)
chi2 = algo.sum_reduce(div, ng1*ng2)


In [71]:
chi2._arr

array([17.50507701,         nan,         nan, ...,         nan,
               nan,         nan])