In [1]:
%load_ext autoreload
%autoreload 2

In [1]:
from muxcnn.utils import load_params, load_img
import torch
from muxcnn.resnet_muxconv import ResNet_MuxConv

## Torch model

In [2]:
from muxcnn.models.ResNet20 import ResNet, BasicBlock

model = ResNet(BasicBlock,[1,1,1])
model.eval() ########## 필수! 
load_params(model, fn_param="./ResNet8.pt",device='cpu')

In [3]:
img_tensor = load_img("./cute.jpg", hi=32, wi=32)

In [4]:
print(model(img_tensor))

before softmax tensor([[ -9.3005, -11.0586,  -2.9741,  -2.9179,  -9.3500,   4.1474,  -6.8721,
          -3.6999,  -8.9446,  -7.6901]], grad_fn=<AddmmBackward0>)
tensor([[-1.3450e+01, -1.5208e+01, -7.1236e+00, -7.0674e+00, -1.3500e+01,
         -2.0791e-03, -1.1022e+01, -7.8494e+00, -1.3094e+01, -1.1840e+01]],
       grad_fn=<LogSoftmaxBackward0>)


# MuxedCNN

In [5]:
muxed_model = ResNet_MuxConv(model, alpha=12)

functions set
degrees = [15, 15, 15, 15], margin = 0.01, eps = 0.02


In [6]:
from muxcnn.utils import get_channel_last, get_conv_params
from muxcnn.hecnn_par import MultParPack
imgl = get_channel_last(img_tensor[0].detach().numpy())
ki = 1 # initial ki
hi, wi, ch = imgl.shape

# early conv and bn
_, ins0, outs0 = get_conv_params(model.conv1, {'k':ki, 'h':hi, 'w':wi})
ct_a = MultParPack(imgl, ins0)

In [7]:
# Muxed
U, ins1, outs1 = get_conv_params(model.conv1, {'k':ki, 'h':hi, 'w':wi})
ct_a = MultParPack(imgl, ins1)
from muxcnn.hecnn_par import forward_convbn_par
out1, un1 = forward_convbn_par(model.conv1, 
                              model.bn1, ct_a, ins1)

[MultParConv] (hi,wi,ci,ki,ti,pi) =(32,32, 3, 1, 3,  8)
[MultParConv] (ho,wo,co,ko,to,po) =(32,32,16, 1,16,  2)
[MultParConv] q = 2
2.2703 s


# FHE Context

In [8]:
from muxcnn.resnet_fhe import ResNetFHE
import numpy as np

from hemul.cipher import *
from hemul.scheme import *
#from hemul.context import set_all

class Param():
    def __init__(self, n=None, logn=None, logp=None, logq=None, logQboot=None):
        self.n = n
        self.logn = logn
        self.logp = logp
        self.logq = logq 
        self.logQboot = logQboot
        if self.logn == None:
            self.logn = int(np.log2(n))

import hemul.HEAAN as he
def set_all(logp = 30, logq = 540, logn = 15):
    """
    Do I have to return context?? 
    Need to check all required rotations first 
    (running emulator can do the job)

    """
    n = 1*2**logn
    slots = n
    parms = Param(n=n, logp=logp, logq=logq)
    do_reduction=False

    ring = he.Ring()
    secretKey = he.SecretKey(ring)
    scheme = he.Scheme(secretKey, ring, False)

    algo = he.SchemeAlgo(scheme)

    # reduction때는 right rotation N_class개 필요. 
    ######
    ######


    if do_reduction:
        Nclass = Nmodel.head.shape[0]
        scheme.addLeftRotKeys(secretKey)
        for i in range(Nclass):
            scheme.addRightRotKey(secretKey, i+1) # 
    else:
        # reduction 안 하면 하나짜리 rotation만 여러번 반복.
        scheme.addLeftRotKey(secretKey, 1)

    return 


context, ev, encoder, encryptor, decryptor = set_all(30, 900, 15)
nslots = context.params.nslots

cwd =  /home/hoseung/Work/MuxConv/scripts
binding to HEAAN
FHE context is set


In [9]:
fhemodel = ResNetFHE(model)
fhemodel.set_agents(context, ev, encoder, encryptor)

functions set
degrees = [15, 15, 15, 15], margin = 0.01, eps = 0.02


In [12]:
# FHE
ctx_a = encryptor.encrypt(ct_a)
ctxt = fhemodel.forward_convbn_par_fhe(model.conv1, 
                          model.bn1, ctx_a, ins0)

[MultParConv] (hi,wi,ci,ki,ti,pi) =(32,32, 3, 1, 3,  8)
[MultParConv] (ho,wo,co,ko,to,po) =(32,32,16, 1,16,  2)
[MultParConv] q = 2


In [13]:
print(ctxt._arr[2000:2050] == out1[2000:2050])

[ True  True  True  True  True  True  True  True  True  True  True  True
  True  True  True  True  True  True  True  True  True  True  True  True
  True  True  True  True  True  True  True  True  True  True  True  True
  True  True  True  True  True  True  True  True  True  True  True  True
  True  True]


## Activation

In [24]:
from hemul.comparator_fhe import ApprRelu_FHE
from hemul.comparator import ApprRelu

In [31]:
appr_relu = ApprRelu(xmin=-10, xmax=10, min_depth=True)
out1 = appr_relu(out1)
print(out1)

functions set
degrees = [15, 15, 15, 15], margin = 0.01, eps = 0.02
[ 1.23083757e+00  4.86967574e-01  2.75344729e-01 ... -5.07712600e-10
  7.09781745e-10  1.15660812e-10]


In [36]:
appr = ApprRelu_FHE(ev, xmin=-10, xmax=10, min_depth=True)
activated = appr(ctxt)
print(activated._arr)

functions set
degrees = [15, 15, 15, 15], margin = 0.01, eps = 0.02
[ 1.23083757e+00  4.86967574e-01  2.75344729e-01 ... -5.07710431e-10
  7.09780806e-10  1.15659451e-10]


In [43]:
np.all(np.isclose(out1,activated._arr))

True

In [None]:
def forward_early(self, img_tensor):
    model = self.torch_model
    imgl = get_channel_last(img_tensor[0].detach().numpy())
    ki = 1 # initial ki
    hi, wi, ch = imgl.shape

    # early conv and bn
    _, ins0, outs0 = get_conv_params(model.conv1, {'k':ki, 'h':hi, 'w':wi})
    ct_a = MultParPack(imgl, ins0)
    ctxt, un1 = forward_convbn_par(model.conv1, 
                                   model.bn1, ct_a, ins0)
    ctxt = self.activation(ctxt)
    return ctxt, outs0 

In [21]:
torch.tensor(result[::64][:10])

tensor([ -9.2678, -10.8150,  -3.1036,  -2.9026,  -9.4616,   4.3280,  -6.9752,
         -3.5619,  -8.8842,  -7.5630], dtype=torch.float64)

In [10]:
torch.tensor(result[::64][:10])

tensor([ -9.2678, -10.8150,  -3.1036,  -2.9026,  -9.4616,   4.3280,  -6.9752,
         -3.5619,  -8.8842,  -7.5630], dtype=torch.float64)