In [None]:
from __future__ import print_function

import os, sys
module_path = os.path.abspath(os.path.join('../../..'))
sys.path.append(module_path)

import numpy as np
import math
import copy
import pandas as pd
from keras.utils import np_utils
from keras.datasets import mnist
import time
import pickle

from pycrcnn.he.he import TFHEnuFHE
from pycrcnn.he.tfhe_value import TFHEValue
from pycrcnn.he.alu import *

## Dataset

In [None]:
# Min-Max quantization function
def quantize_tensor(x, num_bits, min_val=None, max_val=None):
    """Linearly rescale ``x`` and round it onto ``2**num_bits`` integer levels.

    Parameters
    ----------
    x : numpy.ndarray
        Values to quantize. The caller's array is not modified.
    num_bits : int
        Bit width; the output spans ``2**num_bits - 1`` integer steps.
    min_val, max_val : number, optional
        Bounds of the input range. Each bound left as ``None`` is taken from
        ``x`` independently of the other one.

    Returns
    -------
    numpy.ndarray of int
        Rounded values in ``[-qmax, -qmin]``, i.e. ``[-7, 8]`` for 4 bits.
    """
    # Compare against None explicitly: a legitimate bound of 0 is falsy, and a
    # single missing bound must not discard (or crash on) the one provided.
    if min_val is None:
        min_val = x.min()
    if max_val is None:
        max_val = x.max()
    qmin = -2.**(num_bits-1)
    qmax = 2.**(num_bits-1) - 1.

    # Out-of-place arithmetic so integer inputs are promoted to float instead
    # of failing on an in-place true division.
    x = x - min_val                # align the array so that it starts at 0
    x = x / (max_val - min_val)    # scale to [0, 1]; a constant input divides by zero here
    x = x * (qmax - qmin)          # scale to [0, 2**num_bits - 1]
    # NOTE(review): subtracting qmax yields [-7, 8] for 4 bits, whereas the
    # original comment claimed [-8, 7]. The shift is kept unchanged so the
    # data stays identical to what was produced before.
    x = x - qmax
    q_x = x.astype(float).round().astype(int)

    return q_x

In [None]:
# Prepare Penguins dataset: read the CSV, shuffle with a fixed seed, drop incomplete rows
penguins = pd.read_csv('../penguins_size.csv').sample(frac=1, random_state=2).dropna()

# Feature selection: four input columns; the first CSV column is the label
feature_columns = ["island", "culmen_length_mm", "flipper_length_mm", "body_mass_g"]
x_train = penguins.loc[:, feature_columns].values
y_train = penguins.iloc[:, :1].values

# Encode labels: "Adelie" -> 0, "Gentoo" -> 1, anything else -> 2
species_codes = {"Adelie": 0, "Gentoo": 1}
for i in range(len(y_train)):
    y_train[i][0] = species_codes.get(y_train[i][0], 2)

# Encode islands as integers in order of first appearance
island = {}
countI = 0
for i in range(len(x_train)):
    island_name = x_train[i][0]
    if island_name not in island:
        island[island_name] = countI
        countI += 1
    x_train[i][0] = island[island_name]

# Quantize the three numeric feature columns with 4 bits
for column in (1, 2, 3):
    x_train[:, column] = quantize_tensor(x_train[:, column], 4)

# Split Train-Validation-Test set.
# The validation and test labels are sliced before the one-hot encoding below,
# so they keep their integer class codes; only y_train becomes one-hot (times 8).
train, val = 150, 64
holdout_start = train + val
x_val, y_val = x_train[train:holdout_start], y_train[train:holdout_start]
x_test, y_test = x_train[holdout_start:], y_train[holdout_start:]
y_train = np_utils.to_categorical(y_train).astype(int)*8
x_train, y_train = x_train[:train], y_train[:train]

## HE Init

In [None]:
# Build the homomorphic-encryption client. 16 is passed to the TFHEnuFHE
# constructor (presumably the integer bit width -- TODO confirm).
HE_client = TFHEnuFHE(16)

# Load the previously generated secret key from disk into the client.
# NOTE(review): the secret key lives next to the notebook; confirm this is
# intended for anything beyond local experiments.
with open("res/secret_key", "rb") as f:
    HE_client.secret_key = HE_client.ctx.load_secret_key(f)
    
# Load the matching cloud key from disk.
with open("res/cloud_key", "rb") as f:
    HE_client.cloud_key = HE_client.ctx.load_cloud_key(f)

# Create the VM from the cloud key (presumably what populates HE_client.vm,
# which the encrypted functions below use for gate_mux -- verify).
cloud_key = HE_client.cloud_key
HE_client.generate_vm(cloud_key)

In [None]:
# Sanity check of the HE client: combine one encrypted operand and one
# plaintext-encoded operand with the overloaded + and * operators.
# The results are not used by any later cell.
num1 = HE_client.encrypt(1)
num2 = HE_client.encode(6)
# Named sum_enc (not `sum`) so the builtin sum() is not shadowed for the
# rest of the kernel session.
sum_enc = num1+num2
mul = num1*num2

## EncNet Architecture

In [None]:
# Limits of a signed 16-bit integer
SHRT_MAX = (1 << 15) - 1
SHRT_MIN = -(1 << 15)

# Int square root
def isqrt(n):
    """Return the integer square root ``floor(sqrt(n))`` of a non-negative integer.

    Uses Newton's iteration with floor division only.

    Parameters
    ----------
    n : int
        Non-negative value.

    Returns
    -------
    int
        The largest ``x`` with ``x * x <= n``.

    Raises
    ------
    ValueError
        If ``n`` is negative (the iteration would otherwise return ``n``
        unchanged, which is not a square root).
    """
    if n < 0:
        raise ValueError("isqrt() argument must be non-negative")
    x = n
    y = (x + 1) // 2
    # The estimate decreases monotonically; stop as soon as it stops shrinking.
    # For n == 0 the loop body never runs, so there is no division by zero.
    while y < x:
        x = y
        y = (x + n // x) // 2
    return x

In [None]:
# Encrypted PLA tanh Activation function
def encrypted_tanh(act_in, in_dim, out_dim):
    """Apply the piecewise-linear (PLA) tanh to every pre-activation in ``act_in``.

    Each input is first divided by ``(1 << 8) * in_dim``. The output starts at
    the encoded constant 128 (inverse slope 128) and is then overwritten by a
    chain of encrypted multiplexers, applied in this order so that a later
    matching condition replaces an earlier one:

        val < 128   -> val / 4   (inverse slope 8)
        val < 75    -> val       (inverse slope 2)
        val < 32    -> val * 2   (inverse slope 1)
        val < -31   -> val       (inverse slope 2)
        val < -74   -> val / 4   (inverse slope 8)
        val < -127  -> -127      (inverse slope 128)

    Parameters
    ----------
    act_in : numpy.ndarray
        One row of pre-activations per sample; each row is flattened before
        use. Elements are expected to be TFHEValue objects.
    in_dim : int
        Input width of the layer, used only in the rescaling divisor.
    out_dim : int
        Number of units; the column count of both returned arrays.

    Returns
    -------
    (act_out, act_grad_inv) : tuple of numpy.ndarray
        Both of shape ``(act_in.shape[0], out_dim)``: the activations and the
        inverse slope of the segment selected for each element.
    """
    y_max, y_min = HE_client.encode(128), HE_client.encode(-127)
    intervals = HE_client.encode_matrix([128, 75, 32, -31, -74, -127])
    slopes_inv = HE_client.encode_matrix([128, 8, 2, 1, 2, 8, 128])
    n_rows = act_in.shape[0]
    act_out = np.full((n_rows, out_dim), y_max)
    act_grad_inv = np.full((n_rows, out_dim), slopes_inv[0])

    # Candidate output for segment k, selected when val < intervals[k].
    # The matching inverse slope is slopes_inv[k + 1].
    segment_outputs = (
        lambda v: v / 4,
        lambda v: v,
        lambda v: v * 2,
        lambda v: v,
        lambda v: v / 4,
        lambda v: y_min,
    )
    scale = (1 << 8) * in_dim

    for i in range(len(act_in)):
        # reshape(-1) instead of squeeze(): squeeze() turns a single-unit row
        # into a 0-d array, on which len() raises TypeError (out_dim == 1).
        row = act_in[i].reshape(-1)
        for j in range(len(row)):
            val = row[j] / scale

            for k, segment_output in enumerate(segment_outputs):
                below = val < intervals[k]
                act_out[i][j] = TFHEValue(HE_client.vm.gate_mux(below, segment_output(val).value, act_out[i][j].value), val.vm, val.n_bits)
                act_grad_inv[i][j] = TFHEValue(HE_client.vm.gate_mux(below, slopes_inv[k + 1].value, act_grad_inv[i][j].value), val.vm, val.n_bits)

    return act_out, act_grad_inv

In [None]:
# Encrypted L2 Loss Function
def encrypted_L2(y_true, net_out):
    """Return the element-wise error ``net_out - y_true``.

    Note that the result is the signed difference per output (no squaring or
    summation is performed here).

    Parameters
    ----------
    y_true : numpy.ndarray
        2-D array of targets, one row per sample.
    net_out : numpy.ndarray
        Network outputs; row ``i`` is squeezed before being indexed.

    Returns
    -------
    numpy.ndarray
        Array with the same two leading dimensions as ``y_true``.
    """
    n_rows, n_cols = y_true.shape[0], y_true.shape[1]
    loss = np.full((n_rows, n_cols), HE_client.encode(0))
    for i, targets in enumerate(y_true):
        for j, target in enumerate(targets):
            predicted = net_out[i].squeeze()[j]
            loss[i][j] = predicted - target
    return loss

In [None]:
# Encrypted FC Layer
class EncryptedFCLayer:
    """Fully connected layer followed by the encrypted PLA tanh activation.

    Weights and bias start as plain integer zeros and are converted with
    ``HE_client.encode_matrix`` on the first backward pass. Hidden layers are
    updated through a fixed feedback matrix (``DFA_weights``) applied to the
    network loss; the last layer uses the loss directly.
    """

    def __init__(self, in_dim, out_dim, last_layer = False):
        """Create a layer with ``in_dim`` inputs and ``out_dim`` units.

        ``last_layer`` selects the update rule used by ``compute_dDFA``.
        """
        self.in_dim, self.out_dim = in_dim, out_dim
        self.last_layer = last_layer
        self.weights = np.zeros((in_dim, out_dim)).astype(int)
        self.bias = np.zeros((1, out_dim)).astype(int)
        # Placeholder; hidden layers must be given a (n_outputs_of_net, out_dim)
        # matrix before training.
        self.DFA_weights = np.zeros((1, 1)).astype(int)

    def forward(self, fc_in):
        """Compute ``tanh_PLA(fc_in @ weights + bias)``.

        Stores the input and the per-element inverse activation slopes, both
        of which are needed by ``backward``.
        """
        self.input = fc_in
        dot = (self.input @ self.weights) + self.bias
        output, self.act_grad_inv = encrypted_tanh(dot, self.in_dim, self.out_dim)
        return output

    def backward(self, loss, lr_inv):
        """Update weights and bias in place from the network ``loss``.

        ``lr_inv`` is the inverse learning rate: updates are divided by it.
        Must be called after ``forward``.
        """
        d_DFA = self.compute_dDFA(loss, lr_inv)

        weights_update = self.input.T @ d_DFA
        weights_update = weights_update / lr_inv
        weights_update = weights_update.reshape(self.in_dim, self.out_dim)

        # Probe the first element with .flat[0]: squeeze()[0][0] fails as soon
        # as in_dim or out_dim is 1, because squeeze() drops that axis.
        if type(self.weights.flat[0]) is not TFHEValue:
            self.weights = HE_client.encode_matrix(self.weights)

        self.weights -= weights_update

        # Summing d_DFA over the batch, written as a product with a ones vector.
        ones = np.ones((len(d_DFA), 1)).astype(int)
        bias_update = d_DFA.T @ ones
        bias_update = bias_update.T / lr_inv

        # Same probing rationale as for the weights (out_dim == 1).
        if type(self.bias.flat[0]) is not TFHEValue:
            self.bias = HE_client.encode_matrix(self.bias)

        self.bias -= bias_update

    def compute_dDFA(self, loss, lr_inv):
        """Return the layer's update signal.

        Last layer: ``loss / act_grad_inv``. Hidden layer:
        ``(loss @ DFA_weights) / act_grad_inv``. ``lr_inv`` is accepted for
        signature symmetry with ``backward`` and is not used here.

        Prints "DFA not initialized!" when ``DFA_weights`` does not have the
        shape ``(loss.shape[1], weights.shape[1])``; the computation is still
        attempted afterwards.
        """
        if self.last_layer:
            d_DFA = np.divide(loss, self.act_grad_inv)
        else:
            # `or`, not `and`: the matrix is wrong if EITHER axis mismatches.
            if self.DFA_weights.shape[0] != loss.shape[1] or self.DFA_weights.shape[1] != self.weights.shape[1]: # 0 rows, 1 cols
                print("DFA not initialized!")
            dot = loss @ self.DFA_weights
            d_DFA = np.divide(dot, self.act_grad_inv)
        return d_DFA

In [None]:
# Encrypted TFHE-NN Network
class EncryptedNetwork:
    """Ordered stack of layers evaluated and trained on encrypted data."""

    def __init__(self):
        self.layers = []

    # Add layer to network
    def add(self, layer):
        """Append ``layer`` as the new last stage of the network."""
        self.layers.append(layer)

    # Serialize the network
    def serialize(self):
        """Replace each layer's matrices by their serialized form, in place.

        Cached ``act_grad_inv`` and ``input`` are dropped. Layers without a
        ``weights`` attribute are skipped. ``DFA_weights`` is serialized only
        for layers that are not flagged as last.
        """
        for layer in self.layers:
            if not hasattr(layer, "weights"):
                continue
            layer.weights = HE_client.serialize_matrix(layer.weights)
            layer.bias = HE_client.serialize_matrix(layer.bias)
            layer.act_grad_inv = None
            layer.input = None
            if layer.last_layer:
                continue
            layer.DFA_weights = HE_client.serialize_matrix(layer.DFA_weights)

    # Deserialize the network
    def deserialize(self):
        """Inverse of ``serialize``: restore each layer's matrices in place."""
        for layer in self.layers:
            if not hasattr(layer, "weights"):
                continue
            layer.weights = HE_client.deserialize_matrix(layer.weights)
            layer.bias = HE_client.deserialize_matrix(layer.bias)
            if layer.last_layer:
                continue
            layer.DFA_weights = HE_client.deserialize_matrix(layer.DFA_weights)

    # Test
    def test(self, x_test, y_test):
        """Return the encrypted count of samples whose prediction equals the label.

        ``x_test`` and ``y_test`` are encrypted here; the label of sample ``j``
        is read from ``y_test[j][0]``.
        """
        corr = HE_client.encode(0)
        enc_x = HE_client.encrypt_matrix(x_test)
        enc_y = HE_client.encrypt_matrix(y_test)

        for j in range(len(x_test)):
            pred = self.predict(enc_x[j])
            is_hit = pred == enc_y[j][0]
            incremented = corr + 1
            # Encrypted "corr += 1 if is_hit else 0"
            chosen = HE_client.vm.gate_mux(is_hit, incremented.value, corr.value)
            corr = TFHEValue(chosen, corr.vm, corr.n_bits)
        return corr

    # Predict output
    def predict(self, input_data):
        """Run one sample through every layer and return the encrypted argmax."""
        output = np.expand_dims(input_data, axis=0)
        for layer in self.layers:
            output = layer.forward(output)
        return encrypted_argmax(output.squeeze())

    @staticmethod
    def _report_timing(header, batch_index, elapsed):
        """Print ``header`` + batch index, then ``elapsed`` seconds as HH:MM:SS.ss."""
        print(header + repr(batch_index))
        print("Computation time: ")
        hours, rem = divmod(elapsed, 3600)
        minutes, seconds = divmod(rem, 60)
        print("{:0>2}:{:0>2}:{:05.2f}".format(int(hours),int(minutes),seconds))
        print("")

    # Train the network
    def fit(self, x_train, y_train, epochs, mini_batch_size, lr_inv):
        """Train for ``epochs`` passes over full mini-batches.

        A trailing partial batch is not used. Every layer receives the same
        network loss in ``backward``; ``lr_inv`` is the inverse learning rate.
        Timing of the forward and backward passes is printed per batch.
        """
        for i in range(epochs):
            for j in range(int(len(x_train)/mini_batch_size)):
                idx_start = j * mini_batch_size
                idx_end = idx_start + mini_batch_size

                batch_in = HE_client.encrypt_matrix(x_train[idx_start:idx_end])
                batch_target = HE_client.encrypt_matrix(y_train[idx_start:idx_end])

                # Forward propagation
                start_time = time.time()
                for layer in self.layers:
                    batch_in = layer.forward(batch_in)
                fwd_out = batch_in
                end_time = time.time()
                self._report_timing("End forward batch: ", j, end_time-start_time)

                # Loss
                loss = encrypted_L2(batch_target, fwd_out)

                # Backward propagation
                start_time = time.time()
                for layer in reversed(self.layers):
                    layer.backward(loss, lr_inv)
                end_time = time.time()
                self._report_timing("End backward batch: ", j, end_time-start_time)

            print("End epoch: " + repr(i))
            print("")

## Experiments

In [None]:
# Load the DFA feedback weights from disk. They are assigned (plaintext-encoded)
# to the first layer of the cross-validated network in the final cell.
DFA_weights = np.load("res/DFA_weights1.npy")

### Cross Validation

In [None]:
def load_encrypted_net(path):
    """Load one pickled network and the encrypted value stored after it.

    The file at ``path`` holds two consecutive pickles: the serialized
    network, then a serialized ciphertext (presumably the encrypted count of
    correct predictions for that network -- verify against the notebook that
    wrote it).

    Returns
    -------
    (net, corr) : tuple
        The network with its matrices deserialized, and the deserialized
        ciphertext.
    """
    # NOTE(security): pickle.load can execute arbitrary code; only load files
    # produced by a trusted run of this project.
    with open(path, "rb") as f:
        net = pickle.load(f)
        corr = HE_client.deserialize(pickle.load(f))
        net.deserialize()
    return net, corr

# Load TFHE-NN-1
enc_net1, enc_corr1 = load_encrypted_net("res/enc_net1.pkl")

# Load TFHE-NN-2
enc_net2, enc_corr2 = load_encrypted_net("res/enc_net2.pkl")

# Load TFHE-NN-3
enc_net3, enc_corr3 = load_encrypted_net("res/enc_net3.pkl")

# Load TFHE-NN-4
enc_net4, enc_corr4 = load_encrypted_net("res/enc_net4.pkl")

In [None]:
%%time
# Compute argmax: encrypted index of the largest of the four encrypted scores
argmax = encrypted_argmax([enc_corr1, enc_corr2, enc_corr3, enc_corr4])

# Prepare control signal of MUX: slot i becomes 1 when i equals the argmax,
# and keeps its initial 0 otherwise
control_signal = [HE_client.encode(0) for _ in range(4)]
for i in range(len(control_signal)):
  enc_i = HE_client.encode(i)
  current = control_signal[i]
  is_best = enc_i == argmax
  picked = HE_client.vm.gate_mux(is_best, HE_client.encode(1).value, current.value)
  control_signal[i] = TFHEValue(picked, current.vm, current.n_bits)

In [None]:
# Group encrypted weights: one list per parameter matrix, ordered net 1..4
cv_nets = (enc_net1, enc_net2, enc_net3, enc_net4)

W1 = [net.layers[0].weights for net in cv_nets]
B1 = [net.layers[0].bias for net in cv_nets]

W2 = [net.layers[1].weights for net in cv_nets]
B2 = [net.layers[1].bias for net in cv_nets]

In [None]:
%%time
# Encrypted Cross Validation
# Start from all-zero matrices and, for each candidate network, overwrite them
# with that network's parameters when its control signal equals 1.
res_weights1 = np.full(W1[0].shape, HE_client.encode(0))
res_bias1 = np.full(B1[0].shape, HE_client.encode(0))
res_weights2 = np.full(W2[0].shape, HE_client.encode(0))
res_bias2 = np.full(B2[0].shape, HE_client.encode(0))

for i in range(len(control_signal)):
  # The encrypted comparison is the same for all four matrices of network i:
  # evaluate it once per iteration instead of four times.
  is_selected = control_signal[i] == HE_client.encode(1)
  res_weights1 = encrypted_mux_matrix(is_selected, W1[i], res_weights1)
  res_bias1 = encrypted_mux_matrix(is_selected, B1[i], res_bias1)
  res_weights2 = encrypted_mux_matrix(is_selected, W2[i], res_weights2)
  res_bias2 = encrypted_mux_matrix(is_selected, B2[i], res_bias2)

### Serialization

In [None]:
# Save decrypted weights: four consecutive pickles in one file, in the order
# weights1, bias1, weights2, bias2.
# NOTE(review): this writes the selected parameters in plaintext -- confirm
# that is acceptable for the intended threat model.
with open("res/cross_validated_weights.pkl", "wb") as f:
    for encrypted_matrix in (res_weights1, res_bias1, res_weights2, res_bias2):
        pickle.dump(HE_client.decrypt_matrix(encrypted_matrix), f)

In [None]:
# Save serialized net: rebuild a 4-2-3 network carrying the selected parameters
enc_CV_net = EncryptedNetwork()
for new_layer in (EncryptedFCLayer(4, 2), EncryptedFCLayer(2, 3, last_layer=True)):
    enc_CV_net.add(new_layer)

hidden_layer, output_layer = enc_CV_net.layers

# Only the hidden layer uses the DFA feedback matrix
hidden_layer.DFA_weights = HE_client.encode_matrix(DFA_weights)

# NOTE(review): res_weights*/res_bias* come out of encrypted_mux_matrix, so
# they look like ciphertext matrices already; confirm encrypt_matrix is meant
# to receive them rather than their decrypted values.
hidden_layer.weights = HE_client.encrypt_matrix(res_weights1)
hidden_layer.bias = HE_client.encrypt_matrix(res_bias1)
output_layer.weights = HE_client.encrypt_matrix(res_weights2)
output_layer.bias = HE_client.encrypt_matrix(res_bias2)

# serialize() converts the matrices in place so the object can be pickled
enc_CV_net.serialize()

with open("res/enc_cv_net.pkl", "wb") as f:
  pickle.dump(enc_CV_net, f)