Implementation of Deep net with 2 hidden layers that learns with encrypted data

In [1]:
!pip install tenseal

Collecting tenseal
  Downloading tenseal-0.3.16-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (8.4 kB)
Downloading tenseal-0.3.16-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl (4.8 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m4.8/4.8 MB[0m [31m46.0 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: tenseal
Successfully installed tenseal-0.3.16


In [2]:
import torch
import tenseal as ts
import pandas as pd
import random
from time import time

# those are optional and are not necessary for training
import numpy as np
import matplotlib.pyplot as plt

In [3]:
torch.random.manual_seed(73)
random.seed(73)


def split_train_test(x, y, test_ratio=0.3):
    idxs = [i for i in range(len(x))]
    random.shuffle(idxs)
    # delimiter between test and train data
    delim = int(len(x) * test_ratio)
    test_idxs, train_idxs = idxs[:delim], idxs[delim:]
    return x[train_idxs], y[train_idxs], x[test_idxs], y[test_idxs]


def heart_disease_data():
    data = pd.read_csv("/kaggle/input/trydataset/framingham.csv")
    # drop rows with missing values
    data = data.dropna()
    # drop some features
    data = data.drop(columns=["education", "currentSmoker", "BPMeds", "diabetes", "diaBP", "BMI"])
    # balance data
    grouped = data.groupby('TenYearCHD')
    data = grouped.apply(lambda x: x.sample(grouped.size().min(), random_state=73).reset_index(drop=True))
    # extract labels
    y = torch.tensor(data["TenYearCHD"].values).float().unsqueeze(1)
    #data = data.drop("TenYearCHD",'columns')
    # standardize data
    data = (data - data.mean()) / data.std()
    x = torch.tensor(data.values).float()
    return split_train_test(x, y)


def random_data(m=1024, n=2):
    # data separable by the line `y = x`
    x_train = torch.randn(m, n)
    x_test = torch.randn(m // 2, n)
    y_train = (x_train[:, 0] >= x_train[:, 1]).float().unsqueeze(0).t()
    y_test = (x_test[:, 0] >= x_test[:, 1]).float().unsqueeze(0).t()
    return x_train, y_train, x_test, y_test


# You can use whatever data you want without modification to the tutorial
# x_train, y_train, x_test, y_test = random_data()
x_train, y_train, x_test, y_test = heart_disease_data()

print("############# Data summary #############")
print(f"x_train has shape: {x_train.shape}")
print(f"y_train has shape: {y_train.shape}")
print(f"x_test has shape: {x_test.shape}")
print(f"y_test has shape: {y_test.shape}")
print("#######################################")

############# Data summary #############
x_train has shape: torch.Size([780, 10])
y_train has shape: torch.Size([780, 1])
x_test has shape: torch.Size([334, 10])
y_test has shape: torch.Size([334, 1])
#######################################


  data = grouped.apply(lambda x: x.sample(grouped.size().min(), random_state=73).reset_index(drop=True))


In [4]:
# parameters
poly_mod_degree = 8192
coeff_mod_bit_sizes = [40, 21, 21, 21, 21, 21, 21, 40]
# create TenSEALContext
ctx_training = ts.context(ts.SCHEME_TYPE.CKKS, poly_mod_degree, -1, coeff_mod_bit_sizes)
ctx_training.global_scale = 2 ** 21
ctx_training.generate_galois_keys()

In [5]:
#t_start = time()
#enc_x_train = ts.ckks_tensor(ctx_training, x_train[0:200])
#enc_y_train = ts.ckks_tensor(ctx_training, y_train[0:200])
#t_end = time()
#print(f"Encryption of the training_set took {int(t_end - t_start)} seconds")
#print(enc_x_train[0])

In [6]:
class EncryptedDL2layer() :
  def __init__(self) -> None:
      self.weight1=np.random.rand(10,6)*0.01
      self.bias1=np.random.rand(6)*0.01
      self.weight2=np.random.rand(6, 3)*0.01
      self.bias2=np.random.rand(3)*0.01
      self.weight3=np.random.rand(3,1)*0.01
      self.bias3=np.random.rand(1)*0.01
      self.dw1=0
      self.db1=0
      self.dw2=0
      self.db2=0
      self.dw3=0
      self.db3=0
  def bootstrapping(enc,ctx_training) :
    return ts.ckks_tensor(ctx_training,enc.decrypt())
  def swiss(enc_x):
    return enc_x.polyval([0.6931,0.5,0.125,0])
  def swiss_derv(enc_x):
    return enc_x.polyval([0.5, 0.197, 0, -0.004])
  def forward(self,enc_x_train,ctx_training) :
    z11=enc_x_train.mm(self.weight1)
    z1=z11.add(self.bias1)
    a1=EncryptedDL2layer.swiss(z1)
    y=EncryptedDL2layer.bootstrapping(a1,ctx_training)
    z21=y.mm(self.weight2)
    z2=z21.add(self.bias2)
    a2=EncryptedDL2layer.swiss(z2)
    y1=EncryptedDL2layer.bootstrapping(a2,ctx_training)
    z31=y1.mm(self.weight3)
    z3=z31.add(self.bias3)
    a3=EncryptedDL2layer.swiss(z3)
    a3=EncryptedDL2layer.bootstrapping(a3,ctx_training)
    return a3,z3,a2,z2,a1,z1
  def backward(self,a3,z3,a2,z2,a1,z1,enc_y_train,ctx_training) :
    #calculating the output at the layer 2
    error=a3-enc_y_train
    der=EncryptedDL2layer.swiss_derv(z3)
    #finding delta3
    delta3=error.mul(der)
    #using bootstrapping
    delta3=EncryptedDL2layer.bootstrapping(delta3,ctx_training)

    del2=delta3.mm(self.weight3.transpose())
    del2=EncryptedDL2layer.bootstrapping(del2,ctx_training)
    #finding der2
    der2=EncryptedDL2layer.swiss_derv(z2)
    #using bootstrapping
    der2=EncryptedDL2layer.bootstrapping(der2,ctx_training)
    #finding delta2
    delta2=del2.mul(der2)
    del1=delta2.mm(self.weight2.transpose())
    del1=EncryptedDL2layer.bootstrapping(del1,ctx_training)
    #finding der1
    der1=EncryptedDL2layer.swiss_derv(z1)
    #using bootstrapping
    der1=EncryptedDL2layer.bootstrapping(der1,ctx_training)
    #finding delta1
    delta1=del1.mul(der1)
    #finding the gradients
    #for weight3 and bias3
    self.dw3=a2.transpose().mm(delta3)
    self.db3=delta3.sum()
    #for weight2 and bias2
    self.dw2=a1.transpose().mm(delta2)
    self.db2=delta2.sum()
    #for weight1 and bias1
    self.dw1=enc_x_train.transpose().mm(delta1)
    self.db1=delta1.sum()
  def update_params(self):
    self.weight3=self.weight3-0.01*self.dw3
    self.bias3=self.bias3-0.01*self.db3
    self.weight2=self.weight2-0.01*self.dw2
    self.bias2=self.bias2-0.01*self.db2
    self.weight1=self.weight1-0.01*self.dw1
    self.bias1=self.bias1-0.01*self.db1
  def encrypt(self, context):
    self.weight1 = ts.ckks_tensor(context, self.weight1)
    self.bias1 = ts.ckks_tensor(context, self.bias1)
    self.weight2 = ts.ckks_tensor(context, self.weight2)
    self.bias2 = ts.ckks_tensor(context, self.bias2)
    self.weight3 = ts.ckks_tensor(context, self.weight3)
    self.bias3 = ts.ckks_tensor(context, self.bias3)
  def decrypt(self):
    self.weight1 = self.weight1.decrypt()
    self.bias1 = self.bias1.decrypt()
    self.weight2 = self.weight2.decrypt()
    self.bias2 = self.bias2.decrypt()
    self.weight3 = self.weight3.decrypt()
    self.bias3 = self.bias3.decrypt()
  def accuracy(self,x_test,y_test):
    #self.decrypt()
    w1 = torch.tensor(self.weight1)
    b1 = torch.tensor(self.bias1)
    out1 = torch.swiss(x_test.matmul(w1) + b1).reshape(-1, 1)
    w2 = torch.tensor(self.weight2)
    b2 = torch.tensor(self.bias2)
    out2 = torch.swiss(out1.matmul(w2) + b2).reshape(-1, 1)
    correct = torch.abs(y_test - out2) < 0.5
    return correct.float().mean()

  def __call__(self, *args, **kwargs):
    return self.forward(*args, **kwargs)


In [7]:
EPOCHS = 5
eelr = EncryptedDL2layer()
#accuracy = eelr.accuracy(x_test, y_test)
#print(f"Accuracy at epoch #0 is {accuracy}")
times = []
for epoch in range(EPOCHS):
    print("Epoch : ",epoch)
#print(enc_x_train[0])
    for i in range(0,780,60):
        t_start = time()
        enc_x_train = ts.ckks_tensor(ctx_training, x_train[i:i+60])
        enc_y_train = ts.ckks_tensor(ctx_training, y_train[i:i+60])
        t_end = time()
        print(f"Encryption of the training_set took {int(t_end - t_start)} seconds")
        eelr.encrypt(ctx_training)
    
        # if you want to keep an eye on the distribution to make sure
        # the function approximation is still working fine
        # WARNING: this operation is time consuming
        # encrypted_out_distribution(eelr, enc_x_train)
    
        t_start = time()
        #for enc_x, enc_y in zip(enc_x_train, enc_y_train):
        a3,z3,a2,z2,a1,z1 = eelr.forward(enc_x_train,ctx_training)
        eelr.backward(a3,z3,a2,z2,a1,z1,enc_y_train,ctx_training)
        eelr.update_params()
        t_end = time()
        
        times.append(t_end - t_start)
    
        eelr.decrypt()
        print("\n")
        #accuracy = eelr.accuracy(x_test, y_test)
        #print(f"Accuracy at epoch #{epoch + 1} is {accuracy}")
    
    a3,z3,a2,z2,a1,z1 = eelr.forward(enc_x_train,ctx_training)
    data=torch.tensor(a3.decrypt().tolist())
    loss_fn = torch.nn.BCEWithLogitsLoss()
    loss = loss_fn(data, y_train[i:i+60])
    print("Loss at epoch ",epoch,loss.data)
print(f"\nAverage time per epoch: {int(sum(times) / len(times))} seconds")
    #print(f"Final accuracy is {accuracy}")
print("Final weight1 ",eelr.weight1.tolist())
print("Final bias 1",eelr.bias1.tolist())
print("Final weight 2",eelr.weight2.tolist())
print("Final bias 2",eelr.bias2.tolist())

Epoch :  0
Encryption of the training_set took 4 seconds


Encryption of the training_set took 3 seconds


Encryption of the training_set took 3 seconds


Encryption of the training_set took 3 seconds


Encryption of the training_set took 3 seconds


Encryption of the training_set took 3 seconds


Encryption of the training_set took 3 seconds


Encryption of the training_set took 3 seconds


Encryption of the training_set took 3 seconds


Encryption of the training_set took 3 seconds


Encryption of the training_set took 3 seconds


Encryption of the training_set took 3 seconds


Encryption of the training_set took 3 seconds


Loss at epoch  0 tensor(0.7164)
Epoch :  1
Encryption of the training_set took 3 seconds


Encryption of the training_set took 3 seconds


Encryption of the training_set took 3 seconds


Encryption of the training_set took 3 seconds


Encryption of the training_set took 3 seconds


Encryption of the training_set took 3 seconds


Encryption of the training_set too

In [8]:
#eelr.encrypt(ctx_training)
m,n,j,k,l,r=eelr.forward(enc_x_train,ctx_training)
enc_y_train.sub(m).decrypt().tolist()

[[0.49765461144326867],
 [-0.5012574513138099],
 [0.49457448449772085],
 [0.5000566180718824],
 [-0.49891315898969585],
 [0.49653265037187116],
 [0.49637387046043935],
 [-0.5000167228326923],
 [0.4951859849630046],
 [0.4996706969471751],
 [0.49510404001480296],
 [-0.5008814106707833],
 [-0.5015868742355932],
 [-0.5002692464943138],
 [-0.4995073880782557],
 [-0.49648289020544406],
 [-0.5002457083530673],
 [0.4983637544371586],
 [-0.4983984536667237],
 [0.49987677861992535],
 [-0.49864785617955404],
 [-0.5009571340602285],
 [0.4955348840536473],
 [-0.4995605438429126],
 [-0.5012994088510192],
 [-0.4989360371872128],
 [-0.49803725254636055],
 [-0.5024458051223881],
 [0.49904252654399583],
 [0.4971678388544877],
 [0.4950213357322076],
 [-0.5026423389102296],
 [0.4958060341902231],
 [-0.5007496081183153],
 [-0.500365334927794],
 [-0.5001437499325808],
 [0.49422598372630094],
 [0.49644081619978103],
 [0.49935173392364046],
 [0.4978335273925357],
 [0.4958331658055589],
 [-0.5008853787960602],

In [9]:
out=torch.tensor(enc_y_train.sub(m).decrypt().tolist())
correct=torch.abs(out)<0.5
print(correct.float().mean())

tensor(0.6833)


In [10]:
enc_x_test = ts.ckks_tensor(ctx_training, x_test)
enc_y_test = ts.ckks_tensor(ctx_training, y_test)

In [11]:
m,n,j,k,l,r=eelr.forward(enc_x_test,ctx_training)
print(m.decrypt().tolist())
out=torch.tensor(enc_y_test.sub(m).decrypt().tolist())
correct=torch.abs(out)<0.5
print(correct.float().mean())

[[0.4997688139868569], [0.5017065939646869], [0.5026236379084004], [0.5023130323581544], [0.5026679155911686], [0.5050289039915008], [0.502952826413671], [0.49946688874103407], [0.49952414995939626], [0.4977790022641204], [0.5024343003172412], [0.4994973518754163], [0.5040135379551463], [0.5044453540692885], [0.5022320102343802], [0.5034052815018412], [0.5048025495885101], [0.5035138173772442], [0.5049154370634495], [0.49800514744913854], [0.5038196821597637], [0.5000562043157648], [0.5005343080312649], [0.5005230133447297], [0.49945005243083324], [0.5003003893893944], [0.5016561328748383], [0.5045889128924305], [0.5034977317307896], [0.5064453760326468], [0.5035213846543506], [0.5019949795481553], [0.5039442553769202], [0.500967297027289], [0.49987151929355417], [0.5019362649744885], [0.49968773479983136], [0.4981951755493187], [0.5035798932668486], [0.4996639254529071], [0.5069167518295357], [0.5054405567779101], [0.5017353463680044], [0.49855723305558103], [0.5005596851266139], [0.5