In [None]:
from __future__ import print_function

import os, sys
module_path = os.path.abspath(os.path.join('../../..'))
sys.path.append(module_path)

import numpy as np
import math
import copy
import pandas as pd
from keras.utils import np_utils
from keras.datasets import mnist
import time
import pickle

from pycrcnn.he.HE import TFHEnuFHE
from pycrcnn.he.tfhe_value import TFHEValue
from pycrcnn.he.alu import *

## Dataset

In [None]:
# Prepare TernaryMNIST Dataset: MNIST restricted to digits 0/1/2, center-cropped
# to 16x16 and shifted to signed 8-bit pixels.
(x_train, y_train), (x_test, y_test) = mnist.load_data()

TERNARY_CLASSES = (0, 1, 2)


def _to_int8_pixels(images):
    """Shift uint8 pixels in [0, 255] to int8 values in [-128, 127] (pixel - 128).

    Widening to int16 before subtracting makes the shift explicit instead of relying
    on uint8 wrap-around followed by overwriting ``.dtype`` in place (a raw memory
    reinterpretation that NumPy documents as discouraged). The resulting values are
    the same: 0 -> -128, 128 -> 0, 255 -> 127.
    """
    return (np.asarray(images).astype(np.int16) - 128).astype(np.int8)


# Indices of the samples whose label is one of the ternary classes (original order kept).
train_indexes = np.flatnonzero(np.isin(y_train, TERNARY_CLASSES))
test_indexes = np.flatnonzero(np.isin(y_test, TERNARY_CLASSES))

# Train set: center 16x16 crop, class filter, signed pixels.
x_train = _to_int8_pixels(x_train[:, 6:22, 6:22][train_indexes])
y_train = y_train[train_indexes]

# Hold out the last `val_images` filtered training samples for validation.
val_images = 5000
idx_train = len(x_train) - val_images
x_train, x_val = x_train[:idx_train], x_train[idx_train:]
y_train, y_val = y_train[:idx_train], y_train[idx_train:]
# Only the training targets are one-hot encoded and scaled by 16 (presumably the
# fixed-point scale of the network output -- confirm); y_val stays as integer labels.
y_train = np_utils.to_categorical(y_train).astype(int) * 16

# Test set: same crop, filter and pixel shift; labels stay as integers.
x_test = _to_int8_pixels(x_test[:, 6:22, 6:22][test_indexes])
y_test = y_test[test_indexes]

## HE initialization: load the TFHE keys and build the evaluation VM

In [None]:
# Create the TFHE (nuFHE) client; 22 is presumably the integer bit-width -- confirm.
HE_client = TFHEnuFHE(22)

# Restore both keys from files in the working directory.
with open("secret_key", "rb") as key_file:
    HE_client.secret_key = HE_client.ctx.load_secret_key(key_file)

with open("cloud_key", "rb") as key_file:
    HE_client.cloud_key = HE_client.ctx.load_cloud_key(key_file)

# generate_vm presumably builds HE_client.vm (used later for gate_mux) from the cloud key.
cloud_key = HE_client.cloud_key
HE_client.generate_vm(cloud_key)

In [None]:
# Smoke test of mixed encrypted (num1) / encoded (num2) arithmetic.
# Results are bound to enc_sum / enc_prod instead of `sum` / `mul` so the builtin
# `sum` (and any same-named helper pulled in by the star import) is not shadowed
# for the rest of the notebook.
num1 = HE_client.encrypt(1)
num2 = HE_client.encode(6)
enc_sum = num1 + num2
enc_prod = num1 * num2

## EncNet Architecture

In [None]:
# Bounds of a signed 16-bit integer (C `short`).
SHRT_MAX = 32767
SHRT_MIN = -SHRT_MAX - 1


def isqrt(n):
    """Return floor(sqrt(n)) for a non-negative integer ``n`` (Newton's method).

    Args:
        n: non-negative integer.

    Returns:
        The largest integer x with x * x <= n.

    Raises:
        ValueError: if ``n`` is negative. Without this guard the iteration stops
            immediately and returns ``n`` itself, which is not a square root.
    """
    if n < 0:
        raise ValueError("isqrt() argument must be nonnegative")
    x = n
    y = (x + 1) // 2
    # Newton iterates decrease monotonically until they reach floor(sqrt(n)).
    while y < x:
        x = y
        y = (x + n // x) // 2
    return x

In [None]:
# Encrypted PLA tanh Activation function
def encrypted_tanh(act_in, in_dim, out_dim):
    """Piecewise-linear (PLA) tanh approximation evaluated homomorphically.

    Each pre-activation is divided by ``(1 << 8) * in_dim`` and then run through a
    chain of encrypted comparisons. Every comparison ``val < bound`` that holds
    overwrites the running output / inverse-slope via ``gate_mux``; because the bounds
    decrease, the last satisfied bound wins:

        val >= 128          -> out = 128       grad_inv = 128
        75   <= val < 128   -> out = val / 4   grad_inv = 8
        32   <= val < 75    -> out = val       grad_inv = 2
        -31  <= val < 32    -> out = val * 2   grad_inv = 1
        -74  <= val < -31   -> out = val       grad_inv = 2
        -127 <= val < -74   -> out = val / 4   grad_inv = 8
        val < -127          -> out = -127      grad_inv = 128

    NOTE(review): as written the segments are neither continuous nor monotonic
    (e.g. val=74 -> 74 but val=75 -> 18; val=31 -> 62 but val=32 -> 32). Confirm
    against the plaintext reference before relying on this as a tanh surrogate.

    Args:
        act_in: encrypted pre-activations, indexed as act_in[i][j]; presumably
            (batch, out_dim) -- confirm.
        in_dim: fan-in of the layer, part of the rescaling divisor.
        out_dim: number of output columns allocated for the results.

    Returns:
        (act_out, act_grad_inv): arrays of shape (act_in.shape[0], out_dim) holding
        the encrypted activations and the encrypted inverse slopes used by backward.
    """
    y_max, y_min = HE_client.encode(128), HE_client.encode(-127)
    intervals = HE_client.encode_matrix([128, 75, 32, -31, -74, -127])
    slopes_inv = HE_client.encode_matrix([128, 8, 2, 1, 2, 8, 128])
    act_out = np.full((act_in.shape[0], out_dim), y_max)
    act_grad_inv = np.full((act_in.shape[0], out_dim), slopes_inv[0])

    # Segment k is selected when val < intervals[k]; its output is
    # segment_outputs[k](val) and its inverse slope is slopes_inv[k + 1].
    segment_outputs = (
        lambda v: v / 4,
        lambda v: v,
        lambda v: v * 2,
        lambda v: v,
        lambda v: v / 4,
        lambda v: y_min,
    )

    for i in range(len(act_in)):
        # ravel (not squeeze) so a single-column row stays 1-D and len() works.
        row = np.ravel(act_in[i])
        for j in range(len(row)):
            val = row[j] / ((1 << 8) * in_dim)
            for k, segment_output in enumerate(segment_outputs):
                below = val < intervals[k]
                act_out[i][j] = TFHEValue(
                    HE_client.vm.gate_mux(below, segment_output(val).value, act_out[i][j].value),
                    val.vm, val.n_bits)
                act_grad_inv[i][j] = TFHEValue(
                    HE_client.vm.gate_mux(below, slopes_inv[k + 1].value, act_grad_inv[i][j].value),
                    val.vm, val.n_bits)

    return act_out, act_grad_inv

In [None]:
# Encrypted L2 Loss Function
def encrypted_L2(y_true, net_out):
    """Element-wise output error ``net_out - y_true`` used as the training signal.

    No squaring or reduction is performed: entry (i, j) is
    ``net_out[i].squeeze()[j] - y_true[i][j]``.

    Args:
        y_true: 2-D array of (encrypted) targets; its shape fixes the result shape.
        net_out: network output, one row per sample.

    Returns:
        Object array of shape ``(y_true.shape[0], y_true.shape[1])``.
    """
    n_rows, n_cols = y_true.shape[0], y_true.shape[1]
    loss = np.full((n_rows, n_cols), HE_client.encode(0))
    for row_idx, target_row in enumerate(y_true):
        for col_idx, target in enumerate(target_row):
            loss[row_idx][col_idx] = net_out[row_idx].squeeze()[col_idx] - target
    return loss

In [None]:
# Encrypted MaxPool Layer
class EncryptedMaxPoolLayer:
    """Parameter-free max-pooling layer applied image by image with ``_max``.

    ``backward`` returns the incoming error unchanged.
    """

    def __init__(self, kernel_size, stride=(1, 1)):
        # Both are (rows, cols) pairs, as consumed by _max.
        self.kernel_size, self.stride = kernel_size, stride

    def forward(self, batch):
        """Pool every image of ``batch`` and stack the results into an array."""
        pooled_images = []
        for image in batch:
            pooled_images.append(_max(image, self.kernel_size, self.stride))
        return np.array(pooled_images)

    def backward(self, loss, lr_inv):
        """Pass ``loss`` through; ``lr_inv`` is accepted only for interface parity."""
        return loss

def _max(image, kernel_size, stride):
    """Max-pool one 2-D image with a (rows, cols) window and a (rows, cols) stride.

    No padding is used, so the output has ``(H - k_rows) // s_rows + 1`` rows and
    ``(W - k_cols) // s_cols + 1`` columns. Each window is flattened and reduced with
    ``encrypted_max``. ``image`` must support 2-D slicing (e.g. a NumPy array).

    Returns:
        Nested Python list (rows of columns) of ``encrypted_max`` results.
    """
    row_stride, col_stride = stride[0], stride[1]
    row_kernel, col_kernel = kernel_size[0], kernel_size[1]
    n_rows, n_cols = len(image), len(image[0])

    out_rows = (n_rows - row_kernel) // row_stride + 1
    out_cols = (n_cols - col_kernel) // col_stride + 1

    pooled = []
    for out_r in range(out_rows):
        top = out_r * row_stride
        pooled_row = []
        for out_c in range(out_cols):
            left = out_c * col_stride
            window = image[top:top + row_kernel, left:left + col_kernel]
            pooled_row.append(encrypted_max(window.flatten()))
        pooled.append(pooled_row)
    return pooled

In [None]:
# Encrypted Flatten Layer
class EncryptedFlattenLayer:
    """Flatten all non-batch dimensions into a single feature axis.

    (B, H, W) -> (B, H * W); more generally (B, d1, ..., dk) -> (B, d1 * ... * dk).
    Parameter-free: ``backward`` returns the incoming error unchanged.
    """

    def __init__(self):
        pass

    def forward(self, flatten_in):
        # The feature count is computed explicitly (not via -1) so that an empty
        # batch (B == 0) still reshapes cleanly.
        n_features = int(np.prod(flatten_in.shape[1:]))
        return flatten_in.reshape(flatten_in.shape[0], n_features)

    def backward(self, loss, lr_inv):
        return loss

In [None]:
# Encrypted FC Layer
class EncryptedFCLayer:
    """Fully connected layer with PLA-tanh activation, trained with DFA.

    Hidden layers (``last_layer=False``) receive the network output error projected
    through a fixed feedback matrix ``DFA_weights`` of shape (n_outputs, out_dim);
    the last layer uses the output error directly.
    """

    def __init__(self, in_dim, out_dim, last_layer=False):
        self.in_dim, self.out_dim = in_dim, out_dim
        self.last_layer = last_layer
        # Plain zero parameters; backward() encodes them on the first update.
        self.weights = np.zeros((in_dim, out_dim)).astype(int)
        self.bias = np.zeros((1, out_dim)).astype(int)
        # Placeholder: hidden layers must be assigned a (n_outputs, out_dim) matrix.
        self.DFA_weights = np.zeros((1, 1)).astype(int)

    def forward(self, fc_in):
        """Compute tanh(fc_in @ W + b); caches the input and inverse slopes for backward."""
        self.input = fc_in
        dot = (self.input @ self.weights) + self.bias
        output, self.act_grad_inv = encrypted_tanh(dot, self.in_dim, self.out_dim)
        return output

    def backward(self, loss, lr_inv):
        """Apply W -= input.T @ d / lr_inv and b -= sum_over_batch(d) / lr_inv.

        ``d`` is the DFA error signal from ``compute_dDFA``. Plain (not yet encoded)
        parameters are encoded first so the update is homomorphic.
        """
        d_DFA = self.compute_dDFA(loss, lr_inv)

        weights_update = self.input.T @ d_DFA
        weights_update = weights_update / lr_inv
        weights_update = weights_update.reshape(self.in_dim, self.out_dim)

        # .flat[0] works for any shape (squeeze()[0][0] broke on 1-wide matrices).
        if not isinstance(self.weights.flat[0], TFHEValue):
            self.weights = HE_client.encode_matrix(self.weights)
        self.weights -= weights_update

        # d_DFA.T @ ones sums the error over the batch for each output unit.
        ones = np.ones((len(d_DFA), 1)).astype(int)
        bias_update = d_DFA.T @ ones
        bias_update = bias_update.T / lr_inv

        if not isinstance(self.bias.flat[0], TFHEValue):
            self.bias = HE_client.encode_matrix(self.bias)
        self.bias -= bias_update

    def compute_dDFA(self, loss, lr_inv):
        """Return the layer's error signal, divided by the cached inverse slopes.

        Last layer: ``loss / act_grad_inv``. Hidden layer:
        ``(loss @ DFA_weights) / act_grad_inv``. ``lr_inv`` is unused here.

        Raises:
            ValueError: if a hidden layer's DFA_weights shape is not
                (loss.shape[1], out_dim).
        """
        if self.last_layer:
            # Return here: previously the hidden-layer projection below also ran for
            # the last layer and overwrote this result with loss @ zeros((1, 1)).
            return np.divide(loss, self.act_grad_inv)

        expected_shape = (loss.shape[1], self.out_dim)
        if tuple(self.DFA_weights.shape) != expected_shape:
            raise ValueError(
                "DFA not initialized! expected DFA_weights of shape {}, got {}".format(
                    expected_shape, tuple(self.DFA_weights.shape)))
        dot = loss @ self.DFA_weights
        return np.divide(dot, self.act_grad_inv)

In [None]:
# Encrypted Network
class EncryptedNetwork:
    """Sequential container of encrypted layers.

    Training follows DFA: after the forward pass, every layer's ``backward`` receives
    the same output error; layer outputs of ``backward`` are not chained.
    """

    def __init__(self):
        self.layers = []

    # Add layer to network
    def add(self, layer):
        """Append ``layer``; layers run in insertion order on forward."""
        self.layers.append(layer)

    # Serialize the network
    def serialize(self):
        """Serialize, in place, the encrypted tensors of every layer that has weights.

        Cached forward state (``act_grad_inv``, ``input``) is dropped. Call
        ``deserialize()`` before using the network again.
        """
        for layer in self.layers:
            if hasattr(layer, "weights"):
                layer.weights = HE_client.serialize_matrix(layer.weights)
                layer.bias = HE_client.serialize_matrix(layer.bias)
                layer.act_grad_inv = None
                layer.input = None
                if not layer.last_layer:
                    layer.DFA_weights = HE_client.serialize_matrix(layer.DFA_weights)

    # Deserialize the network
    def deserialize(self):
        """Inverse of ``serialize()`` for weights, bias and (hidden layers) DFA_weights."""
        for layer in self.layers:
            if hasattr(layer, "weights"):
                layer.weights = HE_client.deserialize_matrix(layer.weights)
                layer.bias = HE_client.deserialize_matrix(layer.bias)
                if not layer.last_layer:
                    layer.DFA_weights = HE_client.deserialize_matrix(layer.DFA_weights)

    # Test
    def test(self, x_test, y_test):
        """Return the encrypted count of samples whose prediction equals the label.

        NOTE(review): ``enc_y[j][0]`` assumes ``y_test`` has one row per sample with
        the label in column 0 (e.g. shape (N, 1)); a 1-D label vector would not
        index this way -- confirm against callers.
        """
        corr = HE_client.encode(0)
        enc_x = HE_client.encrypt_matrix(x_test)
        enc_y = HE_client.encrypt_matrix(y_test)

        for j in range(len(x_test)):
            pred = self.predict(enc_x[j])
            # Homomorphic "if pred == label: corr += 1".
            corr = TFHEValue(HE_client.vm.gate_mux(pred == enc_y[j][0], (corr + 1).value, corr.value), corr.vm, corr.n_bits)
        return corr

    # Predict output
    def predict(self, input_data):
        """Run one sample (batch of 1) through all layers and return its encrypted argmax."""
        output = np.expand_dims(input_data, axis=0)
        for layer in self.layers:
            output = layer.forward(output)
        return encrypted_argmax(output.squeeze())

    @staticmethod
    def _report_elapsed(phase, batch_idx, start_time, end_time):
        """Print the wall-clock duration of one phase of one batch as HH:MM:SS.ss."""
        print("End " + phase + " batch: " + repr(batch_idx))
        print("Computation time: ")
        hours, rem = divmod(end_time - start_time, 3600)
        minutes, seconds = divmod(rem, 60)
        print("{:0>2}:{:0>2}:{:05.2f}".format(int(hours), int(minutes), seconds))
        print("")

    # Train the network
    def fit(self, x_train, y_train, epochs, batch_size, lr_inv):
        """Train with mini-batches; each batch is encrypted just before use.

        Only full batches are used: a trailing remainder of
        ``len(x_train) % batch_size`` samples is skipped every epoch.
        """
        n_batches = len(x_train) // batch_size
        for i in range(epochs):
            for j in range(n_batches):
                idx_start = j * batch_size
                idx_end = idx_start + batch_size

                batch_in = HE_client.encrypt_matrix(x_train[idx_start:idx_end])
                batch_target = HE_client.encrypt_matrix(y_train[idx_start:idx_end])

                # Forward propagation
                start_time = time.time()
                for layer in self.layers:
                    batch_in = layer.forward(batch_in)
                fwd_out = batch_in
                self._report_elapsed("forward", j, start_time, time.time())

                # Output error
                loss = encrypted_L2(batch_target, fwd_out)

                # Backward propagation: every layer receives the same error (DFA).
                start_time = time.time()
                for layer in reversed(self.layers):
                    layer.backward(loss, lr_inv)
                self._report_elapsed("backward", j, start_time, time.time())

            print("End epoch: " + repr(i))
            print("")

## Experiments

In [None]:
# Load the fixed DFA feedback matrices for the two hidden FC layers from disk.
DFA_weights1, DFA_weights2 = np.load("DFAWeights_L1.npy"), np.load("DFAWeights_L2.npy")

In [None]:
# Network Architecture: 16x16 input -> 4x4 max-pool (stride 4) -> flatten to 16
# -> FC 16->4 -> FC 4->2 -> FC 2->3 (output layer, one unit per class).
net = EncryptedNetwork()
for layer in (
    EncryptedMaxPoolLayer((4, 4), stride=(4, 4)),
    EncryptedFlattenLayer(),
    EncryptedFCLayer(16, 4),
    EncryptedFCLayer(4, 2),
    EncryptedFCLayer(2, 3, last_layer=True),
):
    net.add(layer)

# Hidden FC layers (indices 2 and 3) get their encoded DFA feedback matrices.
net.layers[2].DFA_weights = HE_client.encode_matrix(DFA_weights1)
net.layers[3].DFA_weights = HE_client.encode_matrix(DFA_weights2)

In [None]:
%%time
# Training: 25 samples (indices 100..124) -> 5 batches of 5, for 3 epochs;
# every parameter update is divided by lr_inv.
net.fit(
    x_train=x_train[100:125],
    y_train=y_train[100:125],
    epochs=3,
    batch_size=5,
    lr_inv=256,
)

In [None]:
# Decrypt the trained parameters of the three FC layers (network indices 2, 3, 4).
weights2, bias2 = HE_client.decrypt_matrix(net.layers[2].weights), HE_client.decrypt_matrix(net.layers[2].bias)
weights3, bias3 = HE_client.decrypt_matrix(net.layers[3].weights), HE_client.decrypt_matrix(net.layers[3].bias)
weights4, bias4 = HE_client.decrypt_matrix(net.layers[4].weights), HE_client.decrypt_matrix(net.layers[4].bias)

In [None]:
# Save plain (decrypted) weights as six consecutive pickle records, in this order:
# weights2, bias2, weights3, bias3, weights4, bias4 (load them back in the same order).
with open("plain_weights_M1.pkl", "wb") as f:
    for plain_param in (weights2, bias2, weights3, bias3, weights4, bias4):
        pickle.dump(plain_param, f)

In [None]:
# Save Net: serialize the encrypted tensors in place, then pickle the whole object.
# After this cell `net` holds serialized tensors; call net.deserialize() before reuse.
# Unpickling requires the layer/network classes above to be defined in that session.
net.serialize()

with open("encnetM1.pkl", "wb") as net_file:
    pickle.dump(net, net_file)