In [1]:
from __future__ import print_function

import os, sys
module_path = os.path.abspath(os.path.join('../../..'))
sys.path.append(module_path)

import numpy as np
import math
import copy
import pandas as pd
from keras.utils import np_utils
from keras.datasets import mnist
import time
import pickle

from pycrcnn.he.he import TFHEnuFHE
from pycrcnn.he.tfhe_value import TFHEValue
from pycrcnn.he.alu import *

2023-01-25 09:58:00.255721: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX512F AVX512_VNNI FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2023-01-25 09:58:00.531991: I tensorflow/core/util/port.cc:104] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2023-01-25 09:58:01.645055: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory
2023-01-25 09:58:01.645134: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] 

## HE Init

In [2]:
# Client-side TFHE context (nuFHE backend). NOTE(review): the meaning of the argument 22
# is defined by TFHEnuFHE (presumably a bit-width / precision setting) — confirm there.
HE_client = TFHEnuFHE(22)

# Restore the previously generated key pair from disk instead of generating fresh keys.
with open("res/keys/secret_key", "rb") as key_file:
    HE_client.secret_key = HE_client.ctx.load_secret_key(key_file)
with open("res/keys/cloud_key", "rb") as key_file:
    HE_client.cloud_key = HE_client.ctx.load_cloud_key(key_file)

# The VM used for encrypted gate evaluation is built from the cloud (evaluation) key.
cloud_key = HE_client.cloud_key
HE_client.generate_vm(cloud_key)

In [3]:
# Smoke test of the HE setup: combine an encrypted value (1) with a plaintext-encoded one (6).
# Results are bound to enc_sum / enc_mul so this cell does not shadow Python's built-in sum().
num1 = HE_client.encrypt(1)
num2 = HE_client.encode(6)
enc_sum = num1 + num2
enc_mul = num1 * num2

## EncNet Architecture

In [4]:
# Range of a signed 16-bit integer (C `short`): [-2**15, 2**15 - 1].
SHRT_MAX = (1 << 15) - 1
SHRT_MIN = -SHRT_MAX - 1

# Int square root
def isqrt(n):
    """Return the integer square root of ``n``: the largest ``x`` with ``x * x <= n``.

    Uses integer Newton iteration, so it is exact for arbitrarily large Python ints.

    Raises:
        ValueError: if ``n`` is negative (there is no real square root).
    """
    if n < 0:
        raise ValueError("isqrt() argument must be nonnegative")
    x = n
    # Start at ceil(n / 2): always >= sqrt(n) for n >= 0, so the iteration decreases monotonically.
    y = (x + 1) // 2
    while y < x:
        x = y
        y = (x + n // x) // 2
    return x

In [5]:
# Encrypted PLA tanh Activation function
def encrypted_tanh(act_in, in_dim, out_dim):
    """Piecewise-linear (PLA) tanh approximation evaluated entirely on encrypted values.

    Each pre-activation is first divided by ``(1 << 8) * in_dim``. The result ``v`` is then
    routed through a chain of encrypted comparisons and multiplexers, so no branch depends
    on plaintext data. Thresholds are applied from highest to lowest, and each later
    (lower) threshold overrides the earlier selection:

        v >= 128          -> out = 128,  inverse slope 128
        75  <= v < 128    -> out = v/4,  inverse slope 8
        32  <= v < 75     -> out = v,    inverse slope 2
        -31 <= v < 32     -> out = 2*v,  inverse slope 1
        -74 <= v < -31    -> out = v,    inverse slope 2
        -127 <= v < -74   -> out = v/4,  inverse slope 8
        v < -127          -> out = -127, inverse slope 128

    NOTE(review): every segment's output slope (1/4, 1, 2) is twice the reciprocal of
    the inverse slope returned for it (8, 2, 1). The segments also have no intercepts,
    so the curve jumps at the breakpoints (e.g. 74 -> 18 at v = 75). Confirm this
    matches the intended PLA.

    Args:
        act_in: 2-D array of encrypted pre-activations, one row per sample
            (assumes rows hold ``out_dim`` values — TODO confirm for callers).
        in_dim: fan-in of the producing layer, used in the rescale factor.
        out_dim: number of columns of the returned arrays.

    Returns:
        ``(act_out, act_grad_inv)``: object arrays of shape ``(act_in.shape[0], out_dim)``
        holding the activations and the encoded inverse slopes of the selected segment.
    """
    y_max, y_min = HE_client.encode(128), HE_client.encode(-127)
    intervals = HE_client.encode_matrix([128, 75, 32, -31, -74, -127])
    slopes_inv = HE_client.encode_matrix([128, 8, 2, 1, 2, 8, 128])
    act_out = np.full((act_in.shape[0], out_dim), y_max)
    act_grad_inv = np.full((act_in.shape[0], out_dim), slopes_inv[0])

    # One entry per breakpoint: (threshold, output builder, inverse slope).
    # The builders run lazily inside the loop, so every encrypted op executes once per
    # element, in the same order as a hand-unrolled comparison chain.
    segments = [
        (intervals[0], lambda v: v / 4, slopes_inv[1]),
        (intervals[1], lambda v: v, slopes_inv[2]),
        (intervals[2], lambda v: v * 2, slopes_inv[3]),
        (intervals[3], lambda v: v, slopes_inv[4]),
        (intervals[4], lambda v: v / 4, slopes_inv[5]),
        (intervals[5], lambda v: y_min, slopes_inv[6]),
    ]
    scale = (1 << 8) * in_dim  # loop-invariant plaintext rescale factor

    for i in range(len(act_in)):
        # ravel (not squeeze) so a single-column row is still a 1-D sequence.
        row = np.ravel(act_in[i])
        for j in range(len(row)):
            val = row[j] / scale
            for threshold, make_out, slope_inv in segments:
                is_below = val < threshold
                act_out[i][j] = TFHEValue(HE_client.vm.gate_mux(is_below, make_out(val).value, act_out[i][j].value), val.vm, val.n_bits)
                act_grad_inv[i][j] = TFHEValue(HE_client.vm.gate_mux(is_below, slope_inv.value, act_grad_inv[i][j].value), val.vm, val.n_bits)

    return act_out, act_grad_inv

In [6]:
# Encrypted L2 Loss Function
def encrypted_L2(y_true, net_out):
    """Element-wise prediction error ``net_out - y_true`` for an encrypted batch.

    This is the error term handed to the layers' ``backward`` (the derivative of a
    squared-error loss up to a constant factor), not the squared loss value itself.

    Args:
        y_true: 2-D array of encrypted targets, one row per sample.
        net_out: network outputs; row ``i`` is squeezed before indexing.

    Returns:
        Object array shaped like ``y_true`` with the encrypted differences.
    """
    n_rows, n_cols = y_true.shape[0], y_true.shape[1]
    loss = np.full((n_rows, n_cols), HE_client.encode(0))
    for i, target_row in enumerate(y_true):
        pred_row = net_out[i].squeeze()
        for j in range(len(target_row)):
            loss[i][j] = pred_row[j] - target_row[j]
    return loss

In [7]:
# Encrypted MaxPool Layer
class EncryptedMaxPoolLayer:
    """Max-pooling layer over encrypted images; it has no trainable parameters."""

    def __init__(self, kernel_size, stride=(1, 1)):
        # Both are (rows, cols) pairs, consumed by _max.
        self.kernel_size, self.stride = kernel_size, stride

    def forward(self, batch):
        """Pool every image of ``batch`` independently and stack the results into an array."""
        pooled = [_max(img, self.kernel_size, self.stride) for img in batch]
        return np.array(pooled)

    def backward(self, loss, lr_inv):
        """Nothing to update; the error is returned unchanged."""
        return loss

def _max(image, kernel_size, stride):
    """Encrypted max-pool of a single 2-D image.

    Slides a ``kernel_size`` = (rows, cols) window with step ``stride`` = (rows, cols)
    over ``image``. Windows that would run past the edge are skipped. Each window is
    reduced with ``encrypted_max``.

    Returns:
        A list of rows (list of lists) of pooled encrypted values, in row-major order.
    """
    row_step, col_step = stride[0], stride[1]
    win_rows, win_cols = kernel_size[0], kernel_size[1]

    n_rows, n_cols = len(image), len(image[0])
    out_rows = (n_rows - win_rows) // row_step + 1
    out_cols = (n_cols - win_cols) // col_step + 1

    pooled = []
    for r in range(out_rows):
        top = r * row_step
        pooled_row = []
        for c in range(out_cols):
            left = c * col_step
            window = image[top:top + win_rows, left:left + win_cols]
            pooled_row.append(encrypted_max(window.flatten()))
        pooled.append(pooled_row)
    return pooled

In [8]:
# Encrypted Flatten Layer
class EncryptedFlattenLayer:
    """Collapses each (rows, cols) image of a batch into a single feature row."""

    def forward(self, flatten_in):
        """Reshape a (batch, rows, cols) array into (batch, rows * cols)."""
        batch, rows, cols = flatten_in.shape[0], flatten_in.shape[1], flatten_in.shape[2]
        return flatten_in.reshape(batch, rows * cols)

    def backward(self, loss, lr_inv):
        """No parameters; the error passes through unchanged."""
        return loss

In [9]:
# Encrypted FC Layer
class EncryptedFCLayer:
    """Fully connected layer with PLA-tanh activation, trained with Direct Feedback Alignment.

    Weights and bias start as plaintext int zero matrices. They are encoded with
    ``HE_client.encode_matrix`` the first time ``backward`` updates them. Hidden layers
    project the output-level error through their fixed ``DFA_weights``; the last layer
    uses the error directly.
    """

    def __init__(self, in_dim, out_dim, last_layer = False):
        self.in_dim, self.out_dim = in_dim, out_dim
        self.last_layer = last_layer
        self.weights = np.zeros((in_dim, out_dim)).astype(int)
        self.bias = np.zeros((1, out_dim)).astype(int)
        # Placeholder only: hidden layers need a (loss columns, out_dim) matrix before backward.
        self.DFA_weights = np.zeros((1, 1)).astype(int)

    def forward(self, fc_in):
        """Compute ``tanh_pla(fc_in @ weights + bias)``.

        Caches ``fc_in`` (for the weight update) and the activation's inverse slopes
        (``self.act_grad_inv``) for ``backward``.
        """
        self.input = fc_in
        dot = (self.input @ self.weights) + self.bias
        output, self.act_grad_inv = encrypted_tanh(dot, self.in_dim, self.out_dim)
        return output

    def backward(self, loss, lr_inv):
        """Update weights and bias in place from the network-level error ``loss``.

        Args:
            loss: output-level error matrix, one row per sample.
            lr_inv: inverse learning rate; updates are divided by it.
        """
        d_DFA = self.compute_dDFA(loss, lr_inv)

        weights_update = self.input.T @ d_DFA
        weights_update = weights_update / lr_inv
        weights_update = weights_update.reshape(self.in_dim, self.out_dim)

        # Zero-initialized plaintext parameters are encoded lazily on the first update.
        # .flat[0] (unlike squeeze()[0][0]) also works when in_dim or out_dim is 1.
        if type(self.weights.flat[0]) is not TFHEValue:
            self.weights = HE_client.encode_matrix(self.weights)

        self.weights -= weights_update

        # Sum d_DFA over the batch (product with a ones column), giving a (1, out_dim) update.
        ones = np.ones((len(d_DFA), 1)).astype(int)
        bias_update = d_DFA.T @ ones
        bias_update = bias_update.T / lr_inv

        if type(self.bias.flat[0]) is not TFHEValue:
            self.bias = HE_client.encode_matrix(self.bias)

        self.bias -= bias_update

    def compute_dDFA(self, loss, lr_inv):
        """Return this layer's error signal divided by the activation's inverse slopes.

        Last layer: ``loss / act_grad_inv``. Hidden layers:
        ``(loss @ DFA_weights) / act_grad_inv``. ``lr_inv`` is accepted for a uniform
        signature and is not used here.

        Raises:
            ValueError: for a hidden layer whose ``DFA_weights`` shape is not
                ``(loss.shape[1], self.weights.shape[1])``.
        """
        if self.last_layer:
            return np.divide(loss, self.act_grad_inv)

        expected_shape = (loss.shape[1], self.weights.shape[1])
        # Either dimension being wrong means the feedback matrix was never set up.
        # A single mismatch could otherwise broadcast silently in the division below.
        if self.DFA_weights.shape[0] != expected_shape[0] or self.DFA_weights.shape[1] != expected_shape[1]:
            raise ValueError("DFA not initialized! expected DFA_weights of shape %s, got %s"
                             % (expected_shape, self.DFA_weights.shape))
        dot = loss @ self.DFA_weights
        return np.divide(dot, self.act_grad_inv)

In [10]:
# Encrypted Network
class EncryptedNetwork:
    """Sequential container of encrypted layers, trained with Direct Feedback Alignment.

    ``predict`` and ``fit`` run layers in insertion order. In ``fit`` every layer's
    ``backward`` receives the same output-level error; nothing is chained from layer to
    layer, so ``backward`` return values are ignored.
    """

    def __init__(self):
        self.layers = []

    # Add layer to network
    def add(self, layer):
        """Append ``layer``; it must provide ``forward(x)`` and ``backward(loss, lr_inv)``."""
        self.layers.append(layer)

    # Serialize the network
    def serialize(self):
        """Convert the parameter matrices of every layer that has ``weights`` with
        ``HE_client.serialize_matrix`` (e.g. before pickling). Also drops the cached
        per-batch state (``act_grad_inv``, ``input``). Mutates the layers in place.
        """
        for layer in self.layers:
            if hasattr(layer, "weights"):
                layer.weights = HE_client.serialize_matrix(layer.weights)
                layer.bias = HE_client.serialize_matrix(layer.bias)
                layer.act_grad_inv = None
                layer.input = None
                # The last layer never uses DFA_weights, so only hidden layers convert it.
                if not layer.last_layer:
                    layer.DFA_weights = HE_client.serialize_matrix(layer.DFA_weights)

    # Deserialize the network
    def deserialize(self):
        """Inverse of ``serialize``: restore parameter matrices with ``HE_client.deserialize_matrix``."""
        for layer in self.layers:
            if hasattr(layer, "weights"):
                layer.weights = HE_client.deserialize_matrix(layer.weights)
                layer.bias = HE_client.deserialize_matrix(layer.bias)
                if not layer.last_layer:
                    layer.DFA_weights = HE_client.deserialize_matrix(layer.DFA_weights)

    # Test
    def test(self, x_test, y_test):
        """Encrypt a labelled test set and count correct predictions homomorphically.

        Returns:
            An encrypted counter of the samples whose predicted class equals
            ``y_test[j][0]``. Decrypt it to read the number of correct predictions.
        """
        corr = HE_client.encode(0)
        enc_x = HE_client.encrypt_matrix(x_test)
        enc_y = HE_client.encrypt_matrix(y_test)

        for j in range(len(x_test)):
            pred = self.predict(enc_x[j])
            # Encrypted "if pred == label: corr += 1" via a multiplexer.
            corr = TFHEValue(HE_client.vm.gate_mux(pred == enc_y[j][0], (corr + 1).value, corr.value), corr.vm, corr.n_bits)
        return corr

    # Predict output
    def predict(self, input_data):
        """Run one sample through all layers (as a batch of 1) and return the encrypted argmax."""
        output = np.expand_dims(input_data, axis=0)
        for layer in self.layers:
            output = layer.forward(output)
        return encrypted_argmax(output.squeeze())

    @staticmethod
    def _report_elapsed(phase, batch_idx, start_time, end_time):
        """Print the wall-clock time one forward/backward pass took, as HH:MM:SS.ss."""
        print("End " + phase + " batch: " + repr(batch_idx))
        print("Computation time: ")
        hours, rem = divmod(end_time - start_time, 3600)
        minutes, seconds = divmod(rem, 60)
        print("{:0>2}:{:0>2}:{:05.2f}".format(int(hours), int(minutes), seconds))
        print("")

    # Train the network
    def fit(self, x_train, y_train, epochs, mini_batch_size, lr_inv):
        """Train on plaintext data that is encrypted one mini-batch at a time.

        Only whole mini-batches are used; the last ``len(x_train) % mini_batch_size``
        samples are skipped. Progress and timing are printed per batch and per epoch.

        Args:
            x_train, y_train: plaintext inputs and targets.
            epochs: number of passes over the data.
            mini_batch_size: samples per encrypted batch.
            lr_inv: inverse learning rate forwarded to every layer's ``backward``.
        """
        n_batches = len(x_train) // mini_batch_size
        for i in range(epochs):
            for j in range(n_batches):
                idx_start = j * mini_batch_size
                idx_end = idx_start + mini_batch_size

                batch_in = HE_client.encrypt_matrix(x_train[idx_start:idx_end])
                batch_target = HE_client.encrypt_matrix(y_train[idx_start:idx_end])

                # Forward propagation
                start_time = time.time()
                for layer in self.layers:
                    batch_in = layer.forward(batch_in)
                fwd_out = batch_in
                self._report_elapsed("forward", j, start_time, time.time())

                # Loss
                loss = encrypted_L2(batch_target, fwd_out)

                # Backward propagation (DFA: every layer sees the same output error)
                start_time = time.time()
                for layer in reversed(self.layers):
                    layer.backward(loss, lr_inv)
                self._report_elapsed("backward", j, start_time, time.time())

            print("End epoch: " + repr(i))
            print("")

## Experiments

In [11]:
## Load DFA weights
# Fixed feedback matrices for the two hidden FC layers (assigned to their DFA_weights below).
DFA_weights1, DFA_weights2 = (np.load(path) for path in ("res/dfa/DFA_weights_L1.npy", "res/dfa/DFA_weights_L2.npy"))

### Aggregation

In [12]:
def _load_encrypted_net(path):
    """Unpickle a trained EncryptedNetwork and restore its serialized HE parameters.

    NOTE(review): pickle.load can execute arbitrary code — only load trusted model files.
    """
    with open(path, "rb") as model_file:
        net = pickle.load(model_file)
        net.deserialize()
    return net

# Load the two independently trained encrypted networks.
enc_net1 = _load_encrypted_net("out/model1/enc_net.pkl")  # TFHE-NN-1
enc_net2 = _load_encrypted_net("out/model2/enc_net.pkl")  # TFHE-NN-2

In [13]:
# Group each FC layer's encrypted parameters (layer indices 2-4) across both trained
# models: one list per parameter, model 1 first.
W2 = [net.layers[2].weights for net in (enc_net1, enc_net2)]
B2 = [net.layers[2].bias for net in (enc_net1, enc_net2)]

W3 = [net.layers[3].weights for net in (enc_net1, enc_net2)]
B3 = [net.layers[3].bias for net in (enc_net1, enc_net2)]

W4 = [net.layers[4].weights for net in (enc_net1, enc_net2)]
B4 = [net.layers[4].bias for net in (enc_net1, enc_net2)]

In [14]:
%%time
# Aggregation of encrypted weights: combine the two models' parameters without decrypting
# (presumably an element-wise encrypted mean; see encrypted_mean_matrix in pycrcnn.he.alu).
average_weights2, average_bias2 = encrypted_mean_matrix(W2), encrypted_mean_matrix(B2)
average_weights3, average_bias3 = encrypted_mean_matrix(W3), encrypted_mean_matrix(B3)
average_weights4, average_bias4 = encrypted_mean_matrix(W4), encrypted_mean_matrix(B4)

CPU times: user 11min 3s, sys: 306 ms, total: 11min 3s
Wall time: 29min 13s


### Serialization

In [15]:
# Decrypt the aggregated parameters and pickle them one after another into a single file,
# in the order W2, B2, W3, B3, W4, B4 — readers must call pickle.load in the same order.
with open("res/aggregated_weights.pkl", "wb") as f:
    for enc_matrix in (average_weights2, average_bias2,
                       average_weights3, average_bias3,
                       average_weights4, average_bias4):
        pickle.dump(HE_client.decrypt_matrix(enc_matrix), f)

In [16]:
# Save serialized net: build a fresh network (presumably mirroring the trained models'
# architecture), install the aggregated encrypted parameters, serialize, and pickle it.
aggr_net = EncryptedNetwork()
for new_layer in (
    EncryptedMaxPoolLayer((4, 4), stride=(4, 4)),
    EncryptedFlattenLayer(),
    EncryptedFCLayer(16, 4),
    EncryptedFCLayer(4, 2),
    EncryptedFCLayer(2, 3, last_layer=True),
):
    aggr_net.add(new_layer)

# Hidden FC layers receive the fixed DFA feedback matrices, encoded via HE_client.encode_matrix.
aggr_net.layers[2].DFA_weights = HE_client.encode_matrix(DFA_weights1)
aggr_net.layers[3].DFA_weights = HE_client.encode_matrix(DFA_weights2)

aggregated_params = (
    (2, average_weights2, average_bias2),
    (3, average_weights3, average_bias3),
    (4, average_weights4, average_bias4),
)
for layer_idx, layer_weights, layer_bias in aggregated_params:
    aggr_net.layers[layer_idx].weights = layer_weights
    aggr_net.layers[layer_idx].bias = layer_bias

aggr_net.serialize()

with open("out/aggregation/enc_net.pkl", "wb") as f:
    pickle.dump(aggr_net, f)