In [105]:
from typing import List, Tuple
from typing import Optional
import math

import torch
import torch.nn as nn

import unittest
import numpy as np

from time import time

In [19]:
def create_binary_list_from_int(number: int) -> List[int]:
    """Return the binary digits of a non-negative integer, most significant bit first.

    Args:
        number: a non-negative Python ``int``. Because the check is
            ``type(number) is int``, ``bool`` and other int-like types
            (e.g. numpy integers) are rejected as well.

    Returns:
        A list of 0/1 ints without leading zeros, e.g. ``6 -> [1, 1, 0]``
        and ``0 -> [0]``.

    Raises:
        ValueError: if ``number`` is not an ``int`` or is negative.
    """
    # Check the type before comparing with 0: for a non-numeric argument
    # (e.g. a str) the comparison would raise TypeError first and the
    # documented ValueError would never be reached.
    if type(number) is not int or number < 0:
        raise ValueError("Only Positive integers are allowed")

    return [int(bit) for bit in format(number, "b")]

In [20]:
def generate_even_data(max_int: int, batch_size: int = 16) -> Tuple[List[int], List[List[int]]]:
    """Sample a batch of even integers below ``max_int`` as fixed-width bit lists.

    Args:
        max_int: exclusive upper bound of the sampled values; must be >= 2 so
            that at least the value 0 can be drawn. Integral floats such as
            ``128.0`` are accepted.
        batch_size: number of samples to draw.

    Returns:
        ``(labels, data)``. ``labels`` is ``[1] * batch_size`` because every
        sample is even. ``data`` holds ``batch_size`` lists of 0/1 ints, most
        significant bit first, each left-padded with zeros to
        ``(max_int - 1).bit_length()`` entries (7 for ``max_int=128``).

    Raises:
        ValueError: if ``max_int < 2``.
    """
    max_int = int(max_int)
    if max_int < 2:
        raise ValueError("max_int must be at least 2")

    # Exact integer width that fits every value below max_int. The previous
    # int(math.log(max_int, 2)) relied on float rounding and was one bit too
    # short for non-powers of two (e.g. max_int=100 gave 6, but 98 needs 7),
    # which produced rows of different lengths.
    max_length = (max_int - 1).bit_length()

    # Draw k uniformly from [0, max_int // 2) and use 2 * k, so the values are
    # the even numbers in [0, max_int - 2].
    sampled_halves = np.random.randint(0, max_int // 2, batch_size)

    labels = [1] * batch_size

    data = [create_binary_list_from_int(int(2 * k)) for k in sampled_halves]
    data = [[0] * (max_length - len(bits)) + bits for bits in data]

    return labels, data

In [21]:
class Generator(nn.Module):
    """Single-layer generator: ``Linear(input_length -> input_length)`` then a sigmoid.

    Every output element therefore lies in (0, 1). ``input_length`` is passed
    through ``int()``, so an integral float is accepted.
    """

    def __init__(self, input_length: int):
        super().__init__()
        width = int(input_length)
        self.dense_layer = nn.Linear(width, width)
        self.activation = nn.Sigmoid()

    def forward(self, x):
        """Apply the linear layer to the last dimension of ``x``, then the sigmoid.

        In ``train`` ``x`` is a ``(batch_size, input_length)`` float tensor.
        """
        pre_activation = self.dense_layer(x)
        return self.activation(pre_activation)

In [22]:
class Discriminator(nn.Module):
    """Single-layer discriminator: ``Linear(input_length -> 1)`` then a sigmoid.

    It produces one score in (0, 1) per input row. ``train`` fits it with
    BCELoss, using label 1 for real samples and 0 for generated ones.
    """

    def __init__(self, input_length: int):
        super().__init__()
        self.dense = nn.Linear(int(input_length), 1)
        self.activation = nn.Sigmoid()

    def forward(self, x):
        """Return ``sigmoid(dense(x))``, i.e. one score per row of ``x``."""
        score = self.dense(x)
        return self.activation(score)


In [48]:
import tenseal as ts
# Standalone CKKS context: polynomial modulus degree 8192, coefficient-modulus
# primes of 60/40/40/60 bits, and a global scale of 2**40.
# NOTE(review): `context` is not referenced by any later cell in this notebook;
# the training and evaluation code uses `ctx_training` / `ctx_eval` instead.
context = ts.context(
    ts.SCHEME_TYPE.CKKS,
    poly_modulus_degree=8192,
    coeff_mod_bit_sizes=[60, 40, 40, 60],
)
context.generate_galois_keys()
context.global_scale = 2 ** 40


In [337]:
# parameters
# --- Evaluation context -------------------------------------------------------
# CKKS, polynomial modulus degree 4096, coefficient-modulus primes of 40/20/40
# bits and scale 2**20. The positional -1 fills TenSEAL's plain-modulus slot
# (presumably unused for CKKS; verify against the installed TenSEAL version).
poly_mod_degree = 4096
coeff_mod_bit_sizes = [40, 20, 40]
ctx_eval = ts.context(ts.SCHEME_TYPE.CKKS, poly_mod_degree, -1, coeff_mod_bit_sizes)
ctx_eval.global_scale = 2 ** 20
# Galois keys are needed for dot-product operations.
ctx_eval.generate_galois_keys()

# --- Training context ---------------------------------------------------------
# CKKS, polynomial modulus degree 8192, 40-bit outer primes around six 21-bit
# inner primes. The inner prime size matches log2 of the global scale (2**21).
# NOTE(review): the In [268] cell later in this notebook rebuilds identical
# contexts under the same names; one of the two cells could be dropped.
poly_mod_degree = 8192
coeff_mod_bit_sizes = [40] + [21] * 6 + [40]
ctx_training = ts.context(ts.SCHEME_TYPE.CKKS, poly_mod_degree, -1, coeff_mod_bit_sizes)
ctx_training.global_scale = 2 ** 21
ctx_training.generate_galois_keys()

In [381]:
def train(max_int: int = 128, batch_size: int = 16, training_steps: int = 500,
          context=None, verbose: bool = True):
    """Train the toy "even numbers" GAN and encrypt two of its losses at every step.

    Each step does two things:
    1. Updates the generator once, pushing the discriminator to label its
       samples as real.
    2. Updates the discriminator once, with real samples labelled 1 and
       detached generated samples labelled 0.

    The scalar generator loss and the real-data discriminator loss are
    encrypted as 1-element CKKS vectors.

    Args:
        max_int: exclusive upper bound for the real even numbers. It also fixes
            the model input width, ``(max_int - 1).bit_length()`` (7 for the
            default 128), which matches the rows of ``generate_even_data``.
        batch_size: samples per step, for both real and generated data.
        training_steps: number of generator + discriminator update steps.
        context: TenSEAL context used to encrypt the losses. ``None`` (default)
            falls back to the notebook-global ``ctx_training``, looked up at
            call time exactly as before.
        verbose: if True (default), print each plaintext loss and its encrypted
            vector object every step, as before. Pass False to silence this output.

    Returns:
        ``(generator, discriminator)`` after training. These were previously
        discarded.
    """
    enc_context = ctx_training if context is None else context

    # Must equal the width of the bit lists produced by generate_even_data.
    input_length = (int(max_int) - 1).bit_length()

    # Models
    generator = Generator(input_length)
    discriminator = Discriminator(input_length)

    # Optimizers
    generator_optimizer = torch.optim.Adam(generator.parameters(), lr=0.001)
    discriminator_optimizer = torch.optim.Adam(discriminator.parameters(), lr=0.001)

    # loss
    loss = nn.BCELoss()

    for _ in range(training_steps):
        # zero the generator gradients on each iteration
        generator_optimizer.zero_grad()

        # Random 0/1 noise as generator input (Linear layers need floats)
        noise = torch.randint(0, 2, size=(batch_size, input_length)).float()
        generated_data = generator(noise)

        # Real even examples, shaped (batch_size, 1) labels / (batch_size, input_length) data
        true_labels, true_data = generate_even_data(max_int, batch_size=batch_size)
        true_labels = torch.tensor(true_labels).float().unsqueeze(1)
        true_data = torch.tensor(true_data).float()

        # Train the generator: score its samples against the "real" label (1) so that
        # minimising the loss makes the discriminator accept them. This backward pass
        # also fills discriminator gradients, but those are zeroed before the
        # discriminator's own update below, and only the generator optimizer steps here.
        generator_discriminator_out = discriminator(generated_data)
        generator_loss = loss(generator_discriminator_out, true_labels)
        generator_loss_value = generator_loss.item()
        # NOTE(review): the encrypted losses are only printed, never stored or used.
        enc_generator_loss = ts.ckks_vector(enc_context, [generator_loss_value])
        if verbose:
            print([generator_loss_value])
            print(enc_generator_loss)
        generator_loss.backward()
        generator_optimizer.step()

        # Train the discriminator on the true/generated data
        discriminator_optimizer.zero_grad()
        true_discriminator_out = discriminator(true_data)
        true_discriminator_loss = loss(true_discriminator_out, true_labels)
        true_discriminator_loss_value = true_discriminator_loss.item()
        enc_true_discriminator_loss = ts.ckks_vector(enc_context, [true_discriminator_loss_value])
        if verbose:
            print([true_discriminator_loss_value])
            print(enc_true_discriminator_loss)

        # detach(): this loss must not back-propagate into the generator
        generator_discriminator_out = discriminator(generated_data.detach())
        generator_discriminator_loss = loss(generator_discriminator_out, torch.zeros_like(true_labels))
        discriminator_loss = (true_discriminator_loss + generator_discriminator_loss) / 2
        discriminator_loss.backward()
        discriminator_optimizer.step()

    return generator, discriminator

In [382]:
# Run GAN training with the default settings (spelled out explicitly); `;` hides the cell's return value.
train(max_int=128, batch_size=16, training_steps=500);

[0.6957530379295349]
<tenseal.tensors.ckksvector.CKKSVector object at 0x7efc2403f6d0>
[0.6755548715591431]
<tenseal.tensors.ckksvector.CKKSVector object at 0x7efc2403ffd0>
[0.6958860158920288]
<tenseal.tensors.ckksvector.CKKSVector object at 0x7efc2403f610>
[0.707151472568512]
<tenseal.tensors.ckksvector.CKKSVector object at 0x7efc2403f210>
[0.6996229887008667]
<tenseal.tensors.ckksvector.CKKSVector object at 0x7efc2403fa10>
[0.7388465404510498]
<tenseal.tensors.ckksvector.CKKSVector object at 0x7efc2403f050>
[0.6939811706542969]
<tenseal.tensors.ckksvector.CKKSVector object at 0x7efc2403f6d0>
[0.7246197462081909]
<tenseal.tensors.ckksvector.CKKSVector object at 0x7efc2403f210>
[0.6993811726570129]
<tenseal.tensors.ckksvector.CKKSVector object at 0x7efc2403f990>
[0.6933154463768005]
<tenseal.tensors.ckksvector.CKKSVector object at 0x7efc2403f610>
[0.6982495188713074]
<tenseal.tensors.ckksvector.CKKSVector object at 0x7efc2403ffd0>
[0.6514685153961182]
<tenseal.tensors.ckksvector.CKKSVe

<tenseal.tensors.ckksvector.CKKSVector object at 0x7efc2403fd10>
[0.6694324016571045]
<tenseal.tensors.ckksvector.CKKSVector object at 0x7efc2403f8d0>
[0.7399413585662842]
<tenseal.tensors.ckksvector.CKKSVector object at 0x7efc2403f210>
[0.6692819595336914]
<tenseal.tensors.ckksvector.CKKSVector object at 0x7efc2403f6d0>
[0.791102409362793]
<tenseal.tensors.ckksvector.CKKSVector object at 0x7efc2403fa10>
[0.675933301448822]
<tenseal.tensors.ckksvector.CKKSVector object at 0x7efc2403f050>
[0.7231051921844482]
<tenseal.tensors.ckksvector.CKKSVector object at 0x7efc2403f750>
[0.6706147789955139]
<tenseal.tensors.ckksvector.CKKSVector object at 0x7efc2403fa10>
[0.8107243180274963]
<tenseal.tensors.ckksvector.CKKSVector object at 0x7efc2403f610>
[0.6684812307357788]
<tenseal.tensors.ckksvector.CKKSVector object at 0x7efc2403f8d0>
[0.6562548875808716]
<tenseal.tensors.ckksvector.CKKSVector object at 0x7efc2403fd10>
[0.6682192087173462]
<tenseal.tensors.ckksvector.CKKSVector object at 0x7efc2

<tenseal.tensors.ckksvector.CKKSVector object at 0x7efc2403f8d0>
[0.7017472386360168]
<tenseal.tensors.ckksvector.CKKSVector object at 0x7efc2403f990>
[0.6682980060577393]
<tenseal.tensors.ckksvector.CKKSVector object at 0x7efc2403f610>
[0.7375670671463013]
<tenseal.tensors.ckksvector.CKKSVector object at 0x7efc2403f8d0>
[0.6635792255401611]
<tenseal.tensors.ckksvector.CKKSVector object at 0x7efc2403f110>
[0.7131248712539673]
<tenseal.tensors.ckksvector.CKKSVector object at 0x7efc2403f050>
[0.669019877910614]
<tenseal.tensors.ckksvector.CKKSVector object at 0x7efc2403fd10>
[0.7232786417007446]
<tenseal.tensors.ckksvector.CKKSVector object at 0x7efc2403f8d0>
[0.6681989431381226]
<tenseal.tensors.ckksvector.CKKSVector object at 0x7efc2403f750>
[0.7206715941429138]
<tenseal.tensors.ckksvector.CKKSVector object at 0x7efc2403f610>
[0.6651473045349121]
<tenseal.tensors.ckksvector.CKKSVector object at 0x7efc2403f8d0>
[0.7696061730384827]
<tenseal.tensors.ckksvector.CKKSVector object at 0x7efc

<tenseal.tensors.ckksvector.CKKSVector object at 0x7efc2403f6d0>
[0.7135911583900452]
<tenseal.tensors.ckksvector.CKKSVector object at 0x7efc2403ffd0>
[0.670712411403656]
<tenseal.tensors.ckksvector.CKKSVector object at 0x7efc2403f990>
[0.7324883937835693]
<tenseal.tensors.ckksvector.CKKSVector object at 0x7efc2403f950>
[0.6806483864784241]
<tenseal.tensors.ckksvector.CKKSVector object at 0x7efc2403f210>
[0.684090256690979]
<tenseal.tensors.ckksvector.CKKSVector object at 0x7efc2403f050>
[0.6718769073486328]
<tenseal.tensors.ckksvector.CKKSVector object at 0x7efc2403f950>
[0.7575572729110718]
<tenseal.tensors.ckksvector.CKKSVector object at 0x7efc2403ffd0>
[0.6702632904052734]
<tenseal.tensors.ckksvector.CKKSVector object at 0x7efc2403f610>
[0.7328557372093201]
<tenseal.tensors.ckksvector.CKKSVector object at 0x7efc2403f050>
[0.6656960248947144]
<tenseal.tensors.ckksvector.CKKSVector object at 0x7efc2403fa10>
[0.7033036351203918]
<tenseal.tensors.ckksvector.CKKSVector object at 0x7efc2

<tenseal.tensors.ckksvector.CKKSVector object at 0x7efc2403ffd0>
[0.6865097880363464]
<tenseal.tensors.ckksvector.CKKSVector object at 0x7efc2403f210>
[0.7361537218093872]
<tenseal.tensors.ckksvector.CKKSVector object at 0x7efc2403f990>
[0.6790741682052612]
<tenseal.tensors.ckksvector.CKKSVector object at 0x7efc2403f750>
[0.7248713374137878]
<tenseal.tensors.ckksvector.CKKSVector object at 0x7efc2403f050>
[0.6808481812477112]
<tenseal.tensors.ckksvector.CKKSVector object at 0x7efc2403fa10>
[0.7567507028579712]
<tenseal.tensors.ckksvector.CKKSVector object at 0x7efc2403fd10>
[0.6951119899749756]
<tenseal.tensors.ckksvector.CKKSVector object at 0x7efc2403f050>
[0.7560464143753052]
<tenseal.tensors.ckksvector.CKKSVector object at 0x7efc2403f110>
[0.6819345951080322]
<tenseal.tensors.ckksvector.CKKSVector object at 0x7efc2403f990>
[0.7299984693527222]
<tenseal.tensors.ckksvector.CKKSVector object at 0x7efc2403fd10>
[0.6880450248718262]
<tenseal.tensors.ckksvector.CKKSVector object at 0x7ef

<tenseal.tensors.ckksvector.CKKSVector object at 0x7efc2403f050>
[0.7169535160064697]
<tenseal.tensors.ckksvector.CKKSVector object at 0x7efc2403f610>
[0.7365039587020874]
<tenseal.tensors.ckksvector.CKKSVector object at 0x7efc2403f290>
[0.7085827589035034]
<tenseal.tensors.ckksvector.CKKSVector object at 0x7efc2403f750>
[0.7203186750411987]
<tenseal.tensors.ckksvector.CKKSVector object at 0x7efc2403f210>
[0.7040977478027344]
<tenseal.tensors.ckksvector.CKKSVector object at 0x7efc2403f290>
[0.718166172504425]
<tenseal.tensors.ckksvector.CKKSVector object at 0x7efc2403f110>
[0.7048399448394775]
<tenseal.tensors.ckksvector.CKKSVector object at 0x7efc2403fd10>
[0.7084414958953857]
<tenseal.tensors.ckksvector.CKKSVector object at 0x7efc2403f210>
[0.7029199004173279]
<tenseal.tensors.ckksvector.CKKSVector object at 0x7efc2403f110>
[0.7340415716171265]
<tenseal.tensors.ckksvector.CKKSVector object at 0x7efc2403f750>
[0.7046116590499878]
<tenseal.tensors.ckksvector.CKKSVector object at 0x7efc

<tenseal.tensors.ckksvector.CKKSVector object at 0x7efc2403f950>
[0.7079766392707825]
<tenseal.tensors.ckksvector.CKKSVector object at 0x7efc2403fa10>
[0.7288232445716858]
<tenseal.tensors.ckksvector.CKKSVector object at 0x7efc2403f6d0>
[0.6871038675308228]
<tenseal.tensors.ckksvector.CKKSVector object at 0x7efc2403fd10>
[0.7281314134597778]
<tenseal.tensors.ckksvector.CKKSVector object at 0x7efc2403f990>
[0.7050472497940063]
<tenseal.tensors.ckksvector.CKKSVector object at 0x7efc2403ffd0>
[0.7240503430366516]
<tenseal.tensors.ckksvector.CKKSVector object at 0x7efc2403f750>
[0.7077516317367554]
<tenseal.tensors.ckksvector.CKKSVector object at 0x7efc2403f950>
[0.731705904006958]
<tenseal.tensors.ckksvector.CKKSVector object at 0x7efc2403f6d0>
[0.7029246687889099]
<tenseal.tensors.ckksvector.CKKSVector object at 0x7efc2403f8d0>
[0.7320269346237183]
<tenseal.tensors.ckksvector.CKKSVector object at 0x7efc2403ffd0>
[0.7173206210136414]
<tenseal.tensors.ckksvector.CKKSVector object at 0x7efc

<tenseal.tensors.ckksvector.CKKSVector object at 0x7efc2403f990>
[0.6868869066238403]
<tenseal.tensors.ckksvector.CKKSVector object at 0x7efc2403f750>
[0.7442229986190796]
<tenseal.tensors.ckksvector.CKKSVector object at 0x7efc2403f8d0>
[0.7022984027862549]
<tenseal.tensors.ckksvector.CKKSVector object at 0x7efc2403f210>
[0.741701602935791]
<tenseal.tensors.ckksvector.CKKSVector object at 0x7efc2403f290>
[0.697770893573761]
<tenseal.tensors.ckksvector.CKKSVector object at 0x7efc2403f050>
[0.7452989816665649]
<tenseal.tensors.ckksvector.CKKSVector object at 0x7efc2403f210>
[0.6988235116004944]
<tenseal.tensors.ckksvector.CKKSVector object at 0x7efc2403f290>
[0.7438440918922424]
<tenseal.tensors.ckksvector.CKKSVector object at 0x7efc2403f050>
[0.6679560542106628]
<tenseal.tensors.ckksvector.CKKSVector object at 0x7efc2403f8d0>
[0.744310200214386]
<tenseal.tensors.ckksvector.CKKSVector object at 0x7efc2403f6d0>
[0.7005505561828613]
<tenseal.tensors.ckksvector.CKKSVector object at 0x7efc24

<tenseal.tensors.ckksvector.CKKSVector object at 0x7efc2403fa10>
[0.6999778151512146]
<tenseal.tensors.ckksvector.CKKSVector object at 0x7efc2403f110>
[0.747442364692688]
<tenseal.tensors.ckksvector.CKKSVector object at 0x7efc2403f8d0>
[0.6783338785171509]
<tenseal.tensors.ckksvector.CKKSVector object at 0x7efc2403f290>
[0.7446959018707275]
<tenseal.tensors.ckksvector.CKKSVector object at 0x7efc2403f610>
[0.6618216037750244]
<tenseal.tensors.ckksvector.CKKSVector object at 0x7efc2403f110>
[0.7386131882667542]
<tenseal.tensors.ckksvector.CKKSVector object at 0x7efc2403f990>
[0.6828246712684631]
<tenseal.tensors.ckksvector.CKKSVector object at 0x7efc2403ffd0>
[0.7444108128547668]
<tenseal.tensors.ckksvector.CKKSVector object at 0x7efc2403f050>
[0.6830739378929138]
<tenseal.tensors.ckksvector.CKKSVector object at 0x7efc2403fd10>
[0.745570182800293]
<tenseal.tensors.ckksvector.CKKSVector object at 0x7efc2403ffd0>
[0.6734238862991333]
<tenseal.tensors.ckksvector.CKKSVector object at 0x7efc2

In [264]:
# Plaintext training set for the logistic-regression part: 16 even numbers
# below 128, each a 7-bit list (most significant bit first).
# NOTE(review): every label is 1, so this set contains no negative (odd)
# examples for the classifier to learn from.
y_train, x_train = generate_even_data(max_int=128, batch_size=16)
print(y_train)
print(x_train)

# Hand-written test set in *decimal* form (label 1 = even, 0 = odd).
# NOTE(review): unlike x_train these are not bit lists; they need the same
# 7-bit encoding (and conversion to tensors) before use with LR /
# EncryptedLR.plain_accuracy.
x_test = [115, 90, 94, 122, 122, 122, 95, 115, 83, 123, 114, 123, 94, 89, 91, 86]
y_test = [0,1,1,1,1,1,0,0,0,0,1,0,1,0,0,1]
# Guard against labelling typos: each label must match its number's parity.
assert len(x_test) == len(y_test)
assert y_test == [int(v % 2 == 0) for v in x_test]

[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
[[1, 1, 0, 0, 0, 1, 0], [1, 0, 1, 0, 1, 1, 0], [1, 1, 0, 0, 1, 1, 0], [1, 1, 0, 0, 1, 0, 0], [0, 1, 1, 1, 1, 0, 0], [0, 1, 0, 1, 1, 1, 0], [1, 0, 1, 0, 1, 1, 0], [1, 1, 1, 1, 1, 1, 0], [0, 0, 1, 0, 0, 1, 0], [1, 1, 0, 1, 1, 1, 0], [0, 0, 1, 0, 1, 0, 0], [0, 1, 0, 0, 0, 0, 0], [1, 0, 0, 0, 1, 0, 0], [0, 0, 1, 0, 0, 0, 0], [0, 1, 1, 0, 0, 0, 0], [1, 0, 0, 0, 1, 0, 0]]


In [265]:
class LR(torch.nn.Module):
    """Plain logistic regression: ``sigmoid(Linear(n_features -> 1)(x))``.

    The linear layer lives in ``self.lr``; EncryptedLR copies its weight and
    bias from that attribute.
    """

    def __init__(self, n_features):
        super().__init__()
        self.lr = torch.nn.Linear(n_features, 1)

    def forward(self, x):
        """Return the sigmoid of the linear layer's output, one value per row of ``x``."""
        return torch.sigmoid(self.lr(x))

In [266]:
# Derive the feature count from the data instead of hard-coding 7: every
# x_train row is a fixed-width bit list (7 bits for max_int=128).
n_features = len(x_train[0])
model = LR(n_features)
# use gradient descent with a learning_rate=1
optim = torch.optim.SGD(model.parameters(), lr=1)
# use Binary Cross Entropy Loss
criterion = torch.nn.BCELoss()
# NOTE(review): model / optim / criterion are not used by the later cells
# shown here; the encrypted training below builds a fresh LR(n_features).

In [267]:
class EncryptedLR:
    """Logistic regression whose parameters can be encrypted with TenSEAL CKKS.

    Parameters are copied from a plaintext model exposing a single-output
    ``torch.nn.Linear`` as ``.lr`` (the ``LR`` class). ``encrypt``/``decrypt``
    switch them between plain Python lists and CKKS vectors. ``backward``
    accumulates gradient terms over calls, and ``update_parameters`` applies
    their average.
    """

    def __init__(self, torch_lr):
        # Plain Python lists: the first (only) weight row, and a 1-element bias
        self.weight = torch_lr.lr.weight.data.tolist()[0]
        self.bias = torch_lr.lr.bias.data.tolist()
        # we accumulate gradients and count the number of iterations; the
        # accumulators start as the scalar 0 so the first ``+=`` in backward()
        # takes on the type of the added term
        self._delta_w = 0
        self._delta_b = 0
        self._count = 0

    def forward(self, enc_x):
        """Return ``sigmoid_approx(enc_x . weight + bias)``.

        ``enc_x`` is expected to be a CKKS vector of length ``len(weight)``.
        ``weight``/``bias`` are plain lists before ``encrypt`` and CKKS
        vectors after it.
        """
        enc_out = enc_x.dot(self.weight) + self.bias
        enc_out = EncryptedLR.sigmoid(enc_out)
        return enc_out

    def backward(self, enc_x, enc_out, enc_y):
        """Accumulate ``(out - y) * x`` and ``(out - y)`` for one sample.

        These are the standard logistic-regression gradient terms for the
        weight and the bias.
        """
        out_minus_y = (enc_out - enc_y)
        self._delta_w += enc_x * out_minus_y
        self._delta_b += out_minus_y
        self._count += 1

    def update_parameters(self):
        """Apply the averaged accumulated gradients (step size 1), then reset the accumulators.

        Raises:
            RuntimeError: if ``backward`` has not been called since the last update.
        """
        if self._count == 0:
            raise RuntimeError("You should at least run one forward iteration")
        # update weights
        # We use a small regularization term to keep the output
        # of the linear layer in the range of the sigmoid approximation
        self.weight -= self._delta_w * (1 / self._count) + self.weight * 0.05
        self.bias -= self._delta_b * (1 / self._count)
        # reset gradient accumulators and iterations count
        self._delta_w = 0
        self._delta_b = 0
        self._count = 0

    @staticmethod
    def sigmoid(enc_x):
        """Degree-3 polynomial approximation of the sigmoid, evaluated with ``polyval``."""
        # We use the polynomial approximation of degree 3
        # sigmoid(x) = 0.5 + 0.197 * x - 0.004 * x^3
        # from https://eprint.iacr.org/2018/462.pdf
        # which fits the function pretty well in the range [-5,5]
        return enc_x.polyval([0.5, 0.197, 0, -0.004])

    def plain_accuracy(self, x_test, y_test):
        """Accuracy of the current *plaintext* parameters on ``(x_test, y_test)``.

        Only valid while the parameters are plain lists (before ``encrypt`` or
        after ``decrypt``). Uses the exact ``torch.sigmoid``, not the
        polynomial approximation.

        Args:
            x_test: ``(N, n_features)`` tensor or nested list of features.
            y_test: N labels in {0, 1}, shaped ``(N,)`` or ``(N, 1)``
                (tensor or list).

        Returns:
            0-dim float tensor: the fraction of predictions within 0.5 of their label.
        """
        w = torch.tensor(self.weight)
        b = torch.tensor(self.bias)
        x = torch.as_tensor(x_test, dtype=w.dtype)
        # Force labels to (N, 1): a 1-D y_test would otherwise broadcast
        # against the (N, 1) predictions into an (N, N) comparison matrix.
        y = torch.as_tensor(y_test, dtype=w.dtype).reshape(-1, 1)
        out = torch.sigmoid(x.matmul(w) + b).reshape(-1, 1)
        correct = torch.abs(y - out) < 0.5
        return correct.float().mean()

    def encrypt(self, context):
        """Replace the plain weight/bias lists with CKKS vectors under ``context``."""
        self.weight = ts.ckks_vector(context, self.weight)
        self.bias = ts.ckks_vector(context, self.bias)

    def decrypt(self):
        """Replace the encrypted weight/bias with their decrypted values."""
        self.weight = self.weight.decrypt()
        self.bias = self.bias.decrypt()

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)

In [268]:
# parameters
# --- Evaluation context -------------------------------------------------------
# CKKS, polynomial modulus degree 4096, coefficient-modulus primes of 40/20/40
# bits and scale 2**20. The positional -1 fills TenSEAL's plain-modulus slot
# (presumably unused for CKKS; verify against the installed TenSEAL version).
# NOTE(review): the In [337] cell earlier in this notebook builds identical
# contexts under the same names; one of the two cells could be dropped.
poly_mod_degree = 4096
coeff_mod_bit_sizes = [40, 20, 40]
ctx_eval = ts.context(ts.SCHEME_TYPE.CKKS, poly_mod_degree, -1, coeff_mod_bit_sizes)
ctx_eval.global_scale = 2 ** 20
# Galois keys are needed for dot-product operations.
ctx_eval.generate_galois_keys()

# --- Training context ---------------------------------------------------------
# CKKS, polynomial modulus degree 8192, 40-bit outer primes around six 21-bit
# inner primes. The inner prime size matches log2 of the global scale (2**21).
poly_mod_degree = 8192
coeff_mod_bit_sizes = [40] + [21] * 6 + [40]
ctx_training = ts.context(ts.SCHEME_TYPE.CKKS, poly_mod_degree, -1, coeff_mod_bit_sizes)
ctx_training.global_scale = 2 ** 21
ctx_training.generate_galois_keys()

In [269]:
# Encrypt every training sample and its label under the training context.
t_start = time()

# One CKKS vector per sample (its bit list).
enc_x_train_array = [ts.ckks_vector(ctx_training, x) for x in x_train]

# Each label is a single number (here always 0 or 1); wrap it as a
# 1-element vector. The previous `[int(d) for d in str(y)]` gave the same
# result for 0/1, but split multi-digit labels into digits (10 -> [1, 0])
# and crashed on float labels such as 1.0.
enc_y_train_array = [ts.ckks_vector(ctx_training, [label]) for label in y_train]

t_end = time()
# Fractional seconds: int() truncated sub-second timings to "0 seconds".
print(f"Encryption of the training_set took {t_end - t_start:.2f} seconds")

Encryption of the training_set took 0 seconds


In [270]:
EPOCHS = 5

eelr = EncryptedLR(LR(n_features))
# TODO: report plaintext accuracy per epoch with eelr.plain_accuracy(x_test, y_test)
# once x_test holds bit-encoded rows (it currently holds decimal ints).

times = []
for epoch in range(EPOCHS):
    # Encrypt the current plaintext weights for this epoch; decrypt() at the end
    # turns them back into plain lists, so each epoch starts from fresh ciphertexts.
    eelr.encrypt(ctx_training)

    # To monitor whether the linear outputs stay inside the sigmoid
    # approximation's range, an `encrypted_out_distribution(eelr, ...)` helper
    # was called here. NOTE(review): that helper is not defined in this
    # notebook, and the check is time consuming.

    t_start = time()
    for enc_x, enc_y in zip(enc_x_train_array, enc_y_train_array):
        enc_out = eelr.forward(enc_x)
        eelr.backward(enc_x, enc_out, enc_y)
    eelr.update_parameters()
    t_end = time()
    times.append(t_end - t_start)

    eelr.decrypt()

# Guard EPOCHS == 0, and keep fractional seconds instead of truncating with int().
avg_epoch_time = sum(times) / len(times) if times else 0.0
print(f"\nAverage time per epoch: {avg_epoch_time:.2f} seconds")


Average time per epoch: 1 seconds
