In [105]:
from typing import List, Tuple
from typing import Optional
import math

import torch
import torch.nn as nn

import unittest
import numpy as np

from time import time

In [19]:
def create_binary_list_from_int(number: int) -> List[int]:
    if number < 0 or type(number) is not int:
        raise ValueError("Only Positive integers are allowed")

    return [int(x) for x in list(bin(number))[2:]]

In [20]:
def generate_even_data(max_int: int, batch_size: int=16)-> Tuple[List[int], List[List[int]]]:
    # Get the number of binary places needed to represent the maximum number
    max_length = int(math.log(max_int, 2))

    # Sample batch_size number of integers in range 0-max_int
    sampled_integers = np.random.randint(0, int(max_int / 2), batch_size)

    # create a list of labels all ones because all numbers are even
    labels = [1] * batch_size

    # Generate a list of binary numbers for training.
    data = [create_binary_list_from_int(int(x * 2)) for x in sampled_integers]
    data = [([0] * (max_length - len(x))) + x for x in data]

    return labels, data

In [21]:
class Generator(nn.Module):

    def __init__(self, input_length: int):
        super(Generator, self).__init__()
        self.dense_layer = nn.Linear(int(input_length), int(input_length))
        self.activation = nn.Sigmoid()

    def forward(self, x):
        return self.activation(self.dense_layer(x))

In [22]:
class Discriminator(nn.Module):
    def __init__(self, input_length: int):
        super(Discriminator, self).__init__()
        self.dense = nn.Linear(int(input_length), 1);
        self.activation = nn.Sigmoid()

    def forward(self, x):
        return self.activation(self.dense(x))


In [48]:
import tenseal as ts
context = ts.context(
            ts.SCHEME_TYPE.CKKS,
            poly_modulus_degree=8192,
            coeff_mod_bit_sizes=[60, 40, 40, 60]
          )
context.generate_galois_keys()
context.global_scale = 2**40


In [337]:
# parameters
poly_mod_degree = 4096
coeff_mod_bit_sizes = [40, 20, 40]
# create TenSEALContext
ctx_eval = ts.context(ts.SCHEME_TYPE.CKKS, poly_mod_degree, -1, coeff_mod_bit_sizes)
# scale of ciphertext to use
ctx_eval.global_scale = 2 ** 20
# this key is needed for doing dot-product operations
ctx_eval.generate_galois_keys()


# parameters
poly_mod_degree = 8192
coeff_mod_bit_sizes = [40, 21, 21, 21, 21, 21, 21, 40]
# create TenSEALContext
ctx_training = ts.context(ts.SCHEME_TYPE.CKKS, poly_mod_degree, -1, coeff_mod_bit_sizes)
ctx_training.global_scale = 2 ** 21
ctx_training.generate_galois_keys()

In [373]:
def train(max_int: int = 128, batch_size: int = 16, training_steps: int = 500):
    input_length = int(math.log(max_int, 2))

    # Models
    generator = Generator(input_length)
    discriminator = Discriminator(input_length)

    # Optimizers
    generator_optimizer = torch.optim.Adam(generator.parameters(), lr=0.001)
    discriminator_optimizer = torch.optim.Adam(discriminator.parameters(), lr=0.001)

    # loss
    loss = nn.BCELoss()

    for i in range(training_steps):
        # zero the gradients on each iteration
        generator_optimizer.zero_grad()

        # Create noisy input for generator
        # Need float type instead of int
        noise = torch.randint(0, 2, size=(batch_size, input_length)).float()
        generated_data = generator(noise)

        # Generate examples of even real data
        true_labels, true_data = generate_even_data(max_int, batch_size=batch_size)
        true_labels = torch.tensor(true_labels).float()
        true_labels = true_labels.unsqueeze(1)
        true_data = torch.tensor(true_data).float()
        
        # Train the generator
        # We invert the labels here and don't train the discriminator because we want the generator
        # to make things the discriminator classifies as true.
        generator_discriminator_out = discriminator(generated_data)
        generator_loss = loss(generator_discriminator_out, true_labels)
        generator_loss_vector = [generator_loss.item()]
        print(generator_loss_vector)
        enc_generator_loss = ts.ckks_vector(ctx_training, generator_loss_vector)
        print(enc_generator_loss)
        generator_loss.backward()
        generator_optimizer.step()

        # Train the discriminator on the true/generated data
        discriminator_optimizer.zero_grad()
        true_discriminator_out = discriminator(true_data)
        true_discriminator_loss = loss(true_discriminator_out, true_labels)

        # add .detach() here think about this
        generator_discriminator_out = discriminator(generated_data.detach())
        #generator_discriminator_loss = loss(generator_discriminator_out, torch.zeros(batch_size))
        generator_discriminator_loss = loss(generator_discriminator_out, torch.zeros(batch_size).unsqueeze(1))
        discriminator_loss = (true_discriminator_loss + generator_discriminator_loss) / 2
        discriminator_loss.backward()
        discriminator_optimizer.step()

In [374]:
train()

[0.5542437434196472]
[0.5585272312164307]
[0.5629175305366516]
[0.5583036541938782]
[0.5547553896903992]
[0.5575031638145447]
[0.561896800994873]
[0.5681354999542236]
[0.5544306039810181]
[0.5640419125556946]
[0.5715162754058838]
[0.5620367527008057]
[0.5720278024673462]
[0.5731748938560486]
[0.5752259492874146]
[0.5704964399337769]
[0.5724347829818726]
[0.5706439018249512]
[0.5712560415267944]
[0.5730897188186646]
[0.575793981552124]
[0.5837514400482178]
[0.58234703540802]
[0.5748985409736633]
[0.580812931060791]
[0.5775256752967834]
[0.581829309463501]
[0.5832788944244385]
[0.5819604992866516]
[0.5805007219314575]
[0.5813208222389221]
[0.5842555165290833]
[0.5858339071273804]
[0.5876688957214355]
[0.5919299125671387]
[0.5834435224533081]
[0.5870649814605713]
[0.5863118767738342]
[0.5915297269821167]
[0.5977877974510193]
[0.5833285450935364]
[0.593869149684906]
[0.5963711142539978]
[0.5984539985656738]
[0.6011722683906555]
[0.593683123588562]
[0.6016730070114136]
[0.5964047312736511]


[0.7062872648239136]
[0.7067877650260925]
[0.7042995691299438]
[0.7159066796302795]
[0.7074021100997925]
[0.7121376991271973]
[0.7011553049087524]
[0.7057291269302368]
[0.7110608816146851]
[0.7087165117263794]
[0.7048594951629639]
[0.7086701393127441]
[0.7157617807388306]
[0.7113911509513855]
[0.7015479803085327]
[0.7106163501739502]
[0.7073631286621094]
[0.7150470614433289]
[0.7114810943603516]
[0.7150009870529175]
[0.7139045000076294]
[0.7183828353881836]
[0.7127599120140076]
[0.7078262567520142]
[0.7149680256843567]
[0.7110747694969177]
[0.7137323617935181]
[0.7155201435089111]
[0.7106055617332458]
[0.7155376672744751]
[0.7156375050544739]
[0.7090803980827332]
[0.714834451675415]
[0.7197393774986267]
[0.7169337868690491]
[0.7125425934791565]
[0.7220710515975952]
[0.7208635807037354]
[0.7198382616043091]
[0.717902421951294]
[0.7161601781845093]
[0.7276496291160583]
[0.7137881517410278]
[0.7197854518890381]
[0.7173445224761963]
[0.7186065912246704]
[0.7172107100486755]
[0.725806772708

In [264]:
y_train, x_train = generate_even_data(max_int=128, batch_size=16)
print(y_train)
y_train
print(x_train)
x_test = [115, 90, 94, 122, 122, 122, 95, 115, 83, 123, 114, 123, 94, 89, 91, 86]
y_test = [0,1,1,1,1,1,0,0,0,0,1,0,1,0,0,1]

[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
[[1, 1, 0, 0, 0, 1, 0], [1, 0, 1, 0, 1, 1, 0], [1, 1, 0, 0, 1, 1, 0], [1, 1, 0, 0, 1, 0, 0], [0, 1, 1, 1, 1, 0, 0], [0, 1, 0, 1, 1, 1, 0], [1, 0, 1, 0, 1, 1, 0], [1, 1, 1, 1, 1, 1, 0], [0, 0, 1, 0, 0, 1, 0], [1, 1, 0, 1, 1, 1, 0], [0, 0, 1, 0, 1, 0, 0], [0, 1, 0, 0, 0, 0, 0], [1, 0, 0, 0, 1, 0, 0], [0, 0, 1, 0, 0, 0, 0], [0, 1, 1, 0, 0, 0, 0], [1, 0, 0, 0, 1, 0, 0]]


In [265]:
class LR(torch.nn.Module):

    def __init__(self, n_features):
        super(LR, self).__init__()
        self.lr = torch.nn.Linear(n_features, 1)
        
    def forward(self, x):
        out = torch.sigmoid(self.lr(x))
        return out

In [266]:
#n_features = x_train.shape[1]
n_features = 7
model = LR(n_features)
# use gradient descent with a learning_rate=1
optim = torch.optim.SGD(model.parameters(), lr=1)
# use Binary Cross Entropy Loss
criterion = torch.nn.BCELoss()

In [267]:
class EncryptedLR:
    
    def __init__(self, torch_lr):
        self.weight = torch_lr.lr.weight.data.tolist()[0]
        self.bias = torch_lr.lr.bias.data.tolist()
        # we accumulate gradients and counts the number of iterations
        self._delta_w = 0
        self._delta_b = 0
        self._count = 0
        
    def forward(self, enc_x):
        enc_out = enc_x.dot(self.weight) + self.bias
        enc_out = EncryptedLR.sigmoid(enc_out)
        return enc_out
    
    def backward(self, enc_x, enc_out, enc_y):
        out_minus_y = (enc_out - enc_y)
        self._delta_w += enc_x * out_minus_y
        self._delta_b += out_minus_y
        self._count += 1
        
    def update_parameters(self):
        if self._count == 0:
            raise RuntimeError("You should at least run one forward iteration")
        # update weights
        # We use a small regularization term to keep the output
        # of the linear layer in the range of the sigmoid approximation
        self.weight -= self._delta_w * (1 / self._count) + self.weight * 0.05
        self.bias -= self._delta_b * (1 / self._count)
        # reset gradient accumulators and iterations count
        self._delta_w = 0
        self._delta_b = 0
        self._count = 0
    
    @staticmethod
    def sigmoid(enc_x):
        # We use the polynomial approximation of degree 3
        # sigmoid(x) = 0.5 + 0.197 * x - 0.004 * x^3
        # from https://eprint.iacr.org/2018/462.pdf
        # which fits the function pretty well in the range [-5,5]
        return enc_x.polyval([0.5, 0.197, 0, -0.004])
    
    def plain_accuracy(self, x_test, y_test):
        # evaluate accuracy of the model on
        # the plain (x_test, y_test) dataset
        w = torch.tensor(self.weight)
        b = torch.tensor(self.bias)
        out = torch.sigmoid(x_test.matmul(w) + b).reshape(-1, 1)
        correct = torch.abs(y_test - out) < 0.5
        return correct.float().mean()    
    
    def encrypt(self, context):
        self.weight = ts.ckks_vector(context, self.weight)
        self.bias = ts.ckks_vector(context, self.bias)
        
    def decrypt(self):
        self.weight = self.weight.decrypt()
        self.bias = self.bias.decrypt()
        
    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)

In [268]:
# parameters
poly_mod_degree = 4096
coeff_mod_bit_sizes = [40, 20, 40]
# create TenSEALContext
ctx_eval = ts.context(ts.SCHEME_TYPE.CKKS, poly_mod_degree, -1, coeff_mod_bit_sizes)
# scale of ciphertext to use
ctx_eval.global_scale = 2 ** 20
# this key is needed for doing dot-product operations
ctx_eval.generate_galois_keys()


# parameters
poly_mod_degree = 8192
coeff_mod_bit_sizes = [40, 21, 21, 21, 21, 21, 21, 40]
# create TenSEALContext
ctx_training = ts.context(ts.SCHEME_TYPE.CKKS, poly_mod_degree, -1, coeff_mod_bit_sizes)
ctx_training.global_scale = 2 ** 21
ctx_training.generate_galois_keys()

In [269]:
t_start = time()
#enc_x_train = [ts.ckks_vector(ctx_training, x.tolist()) for x in x_train]
#enc_y_train = [ts.ckks_vector(ctx_training, y.tolist()) for y in y_train]
enc_x_train_array = []
enc_y_train_array = []
for x,y in zip(x_train,y_train):
    enc_x_train = ts.ckks_vector(ctx_training, x)
    enc_x_train_array.append(enc_x_train)
    #print("Encryption of ", x," is ", enc_x_train)
    
    # convert single digit to vector
    y = [int(d) for d in str(y)]
    enc_y_train = ts.ckks_vector(ctx_training, y)
    enc_y_train_array.append(enc_y_train)
    #print("Encryption of ", y," is ", enc_y_train)

    
    
#enc_x_train = [ts.ckks_vector(ctx_training, x_train) for x in x_train]
#enc_y_train = [ts.ckks_vector(ctx_training, y_train) for y in y_train]
                                 
t_end = time()
print(f"Encryption of the training_set took {int(t_end - t_start)} seconds")

Encryption of the training_set took 0 seconds


In [270]:
eelr = EncryptedLR(LR(n_features))
#accuracy = eelr.plain_accuracy(x_test, y_test)
#print(f"Accuracy at epoch #0 is {accuracy}")
EPOCHS = 5

times = []
for epoch in range(EPOCHS):
    eelr.encrypt(ctx_training)
    
    # if you want to keep an eye on the distribution to make sure
    # the function approxiamation is still working fine
    # WARNING: this operation is time consuming
    # encrypted_out_distribution(eelr, enc_x_train)
    
    t_start = time()
    for enc_x, enc_y in zip(enc_x_train_array, enc_y_train_array):
        enc_out = eelr.forward(enc_x)
        eelr.backward(enc_x, enc_out, enc_y)
     
    eelr.update_parameters()
    t_end = time()
    times.append(t_end - t_start)
    
    eelr.decrypt()
    #accuracy = eelr.plain_accuracy(x_test, y_test)
    #print(f"Accuracy at epoch #{epoch + 1} is {accuracy}")


print(f"\nAverage time per epoch: {int(sum(times) / len(times))} seconds")
#print(f"Final accuracy is {accuracy}")

#diff_accuracy = plain_accuracy - accuracy
#print(f"Difference between plain and encrypted accuracies: {diff_accuracy}")
#if diff_accuracy < 0:
#    print("Oh! We got a better accuracy when training on encrypted data! The noise was on our side...")


Average time per epoch: 1 seconds
