In [None]:
import numpy as np
from Pyfhel.util import ENCODING_t
from Pyfhel import PyCtxt, Pyfhel, PyPtxt
import torch
from torch import nn
import pickle

In [None]:
pyfhel = Pyfhel()
pyfhel.contextGen(p=65537, m=16384, base=2, sec=128, intDigits=64, fracDigits = 32)
pyfhel.keyGen()

In [None]:
pyfhel.relinKeyGen(32,5)

In [None]:
class ApproximateSigmoid(nn.Module):
    def forward(self,x,pyfhel=None):
        if type(x) == list or type(x) == PyCtxt:
            with torch.no_grad():
                return self.forward_encrypted(x,pyfhel)
        else:
            # f3(x) = 0.5 + 1.20096(x/8) - 0.81562(x/8)^3
            return 0.5 + 1.20096*(x/8) - 0.81562*((x/8)**3)
    
    def forward_encrypted(self,x,pyfhel):
        assert type(pyfhel) == Pyfhel
        if type(x) == list:
            return [self.forward_encrypred_single(a,pyfhel) for a in x]
        else:
            return self.forward_encrypred_single(x,pyfhel)
    
    def forward_encrypred_single(self,a,pyfhel):
        term_1 = pyfhel.encode(0.5)
        a_by_8 = pyfhel.multiply_plain(a,pyfhel.encode(0.125),in_new_ctxt=True)
        term_2 = pyfhel.multiply_plain(a_by_8,pyfhel.encode(1.20096),in_new_ctxt=True)
        a_by_8_cube = pyfhel.power(a_by_8,3,in_new_ctxt=True)
        pyfhel.relinearize(a_by_8_cube)
        term_3 = pyfhel.multiply_plain(a_by_8_cube,pyfhel.encode(- 0.81562),in_new_ctxt=True)
        pyfhel.add(term_2,term_3)
        pyfhel.add_plain(term_2,term_1)
        return term_2

class ModifiedLinear(nn.Linear):
    def forward(self,x,pyfhel=None):
        if type(x) == list or type(x) == PyCtxt:
            with torch.no_grad():
                return self.forward_encrypted(x,pyfhel)
        else:
            return super().forward(x)
    
    def forward_encrypted(self,x,pyfhel):
        assert type(pyfhel) == Pyfhel
        assert type(x) == list # no support for batching right now 
        outputs = []
        weights = self.weight.detach().numpy()
        biases = self.bias.detach().numpy()
        for row,bias in zip(weights,biases):
            out = pyfhel.encryptFrac(bias)
            assert len(row) == len(x)
            for weight_element,input_element in zip(row,x):
                weight_input_product = pyfhel.multiply_plain(input_element,pyfhel.encode(weight_element),in_new_ctxt=True)
                pyfhel.relinearize(weight_input_product)
                pyfhel.add(out,weight_input_product)
            outputs.append(out)
        return outputs

In [None]:
class ModifiedSequential(nn.Sequential):
    def forward(self,x,pyfhel=None):
        if type(x) == torch.Tensor:
            return super().forward(x)
        return self.forward_encrypted(x,pyfhel)
    
    def forward_encrypted(self,x,pyfhel):
        for child in self.children():
            x = child(x,pyfhel)
        return x

In [91]:
INPUT_SIZE = 4
HIDDEN_SIZE = 2
OUTPUT_SIZE = 1

# creating a neural network which can handle 
encrpyted_neural_processor = ModifiedSequential(
    ModifiedLinear(INPUT_SIZE,HIDDEN_SIZE),
    ApproximateSigmoid(),
    ModifiedLinear(HIDDEN_SIZE,OUTPUT_SIZE),
    ApproximateSigmoid()
)

# correspoing pytorch model initialization would look like this:
traditional_neural_network = nn.Sequential(
    nn.Linear(INPUT_SIZE,HIDDEN_SIZE),
    nn.Sigmoid(),
    nn.Linear(HIDDEN_SIZE,OUTPUT_SIZE),
    nn.Sigmoid()
)


In [92]:
# train the perceptron on IRIS dataset
from sklearn import datasets
from sklearn.model_selection import train_test_split

iris = datasets.load_iris()
X = iris['data']
y = iris['target']

mask = y<=1

X=X[mask]
y=y[mask]

# need this to normalize inputs later
maxes = np.max(X,axis=0)

X = X/maxes
y = y.reshape(-1,1)

X = torch.tensor(X,dtype=torch.float)
y = torch.tensor(y,dtype=torch.float)

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.20)

In [93]:
labels = {0:'setosa', 1:'versicolor'}

In [94]:
loss = torch.nn.BCELoss()
optimizer = torch.optim.Adam(encrpyted_neural_processor.parameters(), lr=0.01)

## The model trains on unencrypted data but may predict on both: encrypted and unencrypted data!

In [95]:
EPOCHS = 45

for i in range(EPOCHS):
    total_loss = 0
    for a_x,a_y in zip(X_train,y_train):
        a_out = encrpyted_neural_processor(a_x)
        a_out.clip_(0,1)
        a_loss = loss(a_out,a_y)
        total_loss += a_loss.item()
        optimizer.zero_grad()
        a_loss.backward()
        optimizer.step()
    if i%5 == 0:
        print(f'loss after {i} epochs is {total_loss}')

loss after 0 epochs is 55.58208787441254
loss after 5 epochs is 43.32298055291176
loss after 10 epochs is 16.029682844877243
loss after 15 epochs is 2.0014456705539487
loss after 20 epochs is 0.8141750950017013
loss after 25 epochs is 0.4686754490248859
loss after 30 epochs is 0.281995673525671
loss after 35 epochs is 0.18029425293207169
loss after 40 epochs is 0.10111747030168772


In [96]:
# the model is now trained. Let's try out a test with both: encrypted and unencrypted data

# first, lets prepare unencrypted input
unencrypted_input = (X_test/maxes)[1].float()

print(unencrypted_input)

tensor([0.1020, 0.1860, 0.0538, 0.0617])


In [97]:
# Lets also prepare encrypted input
encrpyted_input = [pyfhel.encryptFrac(a_inp.item()) for a_inp in unencrypted_input]

print(encrpyted_input)

[<Pyfhel Ciphertext at 0x7f17e2036f80, encoding=FRACTIONAL, size=2/2, noiseBudget=410>, <Pyfhel Ciphertext at 0x7f17e2038b80, encoding=FRACTIONAL, size=2/2, noiseBudget=410>, <Pyfhel Ciphertext at 0x7f17e2e91900, encoding=FRACTIONAL, size=2/2, noiseBudget=410>, <Pyfhel Ciphertext at 0x7f17e20339c0, encoding=FRACTIONAL, size=2/2, noiseBudget=410>]


In [98]:
encrpyted_neural_processor(unencrypted_input)

tensor([-0.0330], grad_fn=<SubBackward0>)

In [99]:
encrpyted_output = encrpyted_neural_processor(encrpyted_input,pyfhel)

In [100]:
pyfhel.decryptFrac(encrpyted_output[0])

-0.033422748325392604