In [1]:
import torch
import tenseal as ts
import pandas as pd
import random
from time import time
import numpy as np
import matplotlib.pyplot as plt

In [16]:
def split_train_test(x, y, ratio=0.3):
    idxs = list(range(len(x)))
    random.shuffle(idxs)
    split_idx = int(len(x)*ratio)
    test_idxs, train_idxs = idxs[:split_idx], idxs[split_idx:]
    return x[train_idxs], y[train_idxs], x[test_idxs], y[test_idxs]

In [37]:
def prepare_heart_disease_data():
    data = pd.read_csv("../data/framingham.csv")
    # Drop target columns
    X = data.drop(['TenYearCHD'], axis=1, inplace=False)
    Y = data['TenYearCHD']
    X = X.apply(lambda x: x.fillna(x.mean()),axis=0)
    # Standardize data
    X = (X - X.mean()) / X.std()
    X = torch.tensor(X.values).float()
    Y = torch.tensor(Y.values).float().unsqueeze(1)
    
    return split_train_test(X, Y)

In [38]:
x_train, y_train, x_test, y_test = prepare_heart_disease_data()

print("############# Data summary #############")
print(f"x_train has shape: {x_train.shape}")
print(f"y_train has shape: {y_train.shape}")
print(f"x_test has shape: {x_test.shape}")
print(f"y_test has shape: {y_test.shape}")
print("#######################################")

############# Data summary #############
x_train has shape: torch.Size([2967, 15])
y_train has shape: torch.Size([2967, 1])
x_test has shape: torch.Size([1271, 15])
y_test has shape: torch.Size([1271, 1])
#######################################


In [39]:
class LR(torch.nn.Module):

    def __init__(self, n_features):
        super(LR, self).__init__()
        self.lr = torch.nn.Linear(n_features, 1)
        
    def forward(self, x):
        out = torch.sigmoid(self.lr(x))
        return out

In [40]:
n_features = x_train.shape[1]
model = LR(n_features)
# use gradient descent with a learning_rate=1
optim = torch.optim.SGD(model.parameters(), lr=0.01)
# use Binary Cross Entropy Loss
criterion = torch.nn.BCELoss()

In [45]:
# define the number of epochs for both plain and encrypted training
EPOCHS = 1000
torch.random.manual_seed(0)
random.seed(0)
def train(model, optim, criterion, x, y, epochs=EPOCHS):
    for e in range(1, epochs + 1):
        optim.zero_grad()
        out = model(x)
        loss = criterion(out, y)
        loss.backward()
        optim.step()
        print(f"Loss at epoch {e}: {loss.data}")
    return model

model = train(model, optim, criterion, x_train, y_train)

Loss at epoch 1: 0.5273855328559875
Loss at epoch 2: 0.5268250107765198
Loss at epoch 3: 0.5262671113014221
Loss at epoch 4: 0.5257118940353394
Loss at epoch 5: 0.5251592993736267
Loss at epoch 6: 0.5246093273162842
Loss at epoch 7: 0.5240618586540222
Loss at epoch 8: 0.5235170125961304
Loss at epoch 9: 0.5229747891426086
Loss at epoch 10: 0.5224349498748779
Loss at epoch 11: 0.5218977928161621
Loss at epoch 12: 0.5213630795478821
Loss at epoch 13: 0.5208309292793274
Loss at epoch 14: 0.5203012824058533
Loss at epoch 15: 0.5197740793228149
Loss at epoch 16: 0.5192493200302124
Loss at epoch 17: 0.5187271237373352
Loss at epoch 18: 0.5182072520256042
Loss at epoch 19: 0.5176898837089539
Loss at epoch 20: 0.5171748995780945
Loss at epoch 21: 0.5166622400283813
Loss at epoch 22: 0.5161521434783936
Loss at epoch 23: 0.5156442523002625
Loss at epoch 24: 0.5151388049125671
Loss at epoch 25: 0.5146357417106628
Loss at epoch 26: 0.51413494348526
Loss at epoch 27: 0.5136364698410034
Loss at epoc

Loss at epoch 240: 0.4450227618217468
Loss at epoch 241: 0.4448230266571045
Loss at epoch 242: 0.4446239769458771
Loss at epoch 243: 0.44442570209503174
Loss at epoch 244: 0.44422823190689087
Loss at epoch 245: 0.4440315365791321
Loss at epoch 246: 0.4438355565071106
Loss at epoch 247: 0.4436403512954712
Loss at epoch 248: 0.44344595074653625
Loss at epoch 249: 0.44325220584869385
Loss at epoch 250: 0.4430592656135559
Loss at epoch 251: 0.4428670108318329
Loss at epoch 252: 0.44267553091049194
Loss at epoch 253: 0.44248485565185547
Loss at epoch 254: 0.4422948360443115
Loss at epoch 255: 0.44210559129714966
Loss at epoch 256: 0.4419170022010803
Loss at epoch 257: 0.44172918796539307
Loss at epoch 258: 0.4415420591831207
Loss at epoch 259: 0.4413556456565857
Loss at epoch 260: 0.44117000699043274
Loss at epoch 261: 0.4409849941730499
Loss at epoch 262: 0.4408007264137268
Loss at epoch 263: 0.440617173910141
Loss at epoch 264: 0.4404342472553253
Loss at epoch 265: 0.44025206565856934
Los

Loss at epoch 557: 0.40713468194007874
Loss at epoch 558: 0.4070669710636139
Loss at epoch 559: 0.40699946880340576
Loss at epoch 560: 0.40693214535713196
Loss at epoch 561: 0.40686503052711487
Loss at epoch 562: 0.4067980945110321
Loss at epoch 563: 0.40673139691352844
Loss at epoch 564: 0.4066648781299591
Loss at epoch 565: 0.40659859776496887
Loss at epoch 566: 0.40653252601623535
Loss at epoch 567: 0.4064665734767914
Loss at epoch 568: 0.4064008593559265
Loss at epoch 569: 0.40633538365364075
Loss at epoch 570: 0.40627002716064453
Loss at epoch 571: 0.4062049090862274
Loss at epoch 572: 0.40613994002342224
Loss at epoch 573: 0.40607523918151855
Loss at epoch 574: 0.40601062774658203
Loss at epoch 575: 0.405946284532547
Loss at epoch 576: 0.4058821499347687
Loss at epoch 577: 0.40581807494163513
Loss at epoch 578: 0.40575435757637024
Loss at epoch 579: 0.4056906998157501
Loss at epoch 580: 0.4056273400783539
Loss at epoch 581: 0.4055640995502472
Loss at epoch 582: 0.4055010676383972

Loss at epoch 899: 0.39223170280456543
Loss at epoch 900: 0.39220476150512695
Loss at epoch 901: 0.392177939414978
Loss at epoch 902: 0.3921511471271515
Loss at epoch 903: 0.3921244144439697
Loss at epoch 904: 0.39209774136543274
Loss at epoch 905: 0.3920712172985077
Loss at epoch 906: 0.39204463362693787
Loss at epoch 907: 0.3920181691646576
Loss at epoch 908: 0.3919917941093445
Loss at epoch 909: 0.39196547865867615
Loss at epoch 910: 0.3919391930103302
Loss at epoch 911: 0.39191296696662903
Loss at epoch 912: 0.3918868601322174
Loss at epoch 913: 0.3918607532978058
Loss at epoch 914: 0.3918347656726837
Loss at epoch 915: 0.39180874824523926
Loss at epoch 916: 0.39178287982940674
Loss at epoch 917: 0.391757071018219
Loss at epoch 918: 0.391731321811676
Loss at epoch 919: 0.39170557260513306
Loss at epoch 920: 0.39167994260787964
Loss at epoch 921: 0.3916544020175934
Loss at epoch 922: 0.39162883162498474
Loss at epoch 923: 0.39160338044166565
Loss at epoch 924: 0.3915780186653137
Los

In [46]:
def accuracy(model, x, y):
    out = model(x)
    correct = torch.abs(y - out) < 0.5
    return correct.float().mean()

plain_accuracy = accuracy(model, x_test, y_test)
print(f"Accuracy on plain test_set: {plain_accuracy}")

Accuracy on plain test_set: 0.8473643064498901


In [50]:
class EncryptedLR:
    
    def __init__(self, torch_lr):
        # TenSEAL processes lists and not torch tensors,
        # so we take out the parameters from the PyTorch model
        self.weight = torch_lr.lr.weight.data.tolist()[0]
        self.bias = torch_lr.lr.bias.data.tolist()
        
    def forward(self, enc_x):
        # We don't need to perform sigmoid as this model
        # will only be used for evaluation, and the label
        # can be deduced without applying sigmoid
        enc_out = enc_x.dot(self.weight) + self.bias
        return enc_out
    
    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)
        
    ################################################
    ## You can use the functions below to perform ##
    ## the evaluation with an encrypted model     ##
    ################################################
    
    def encrypt(self, context):
        self.weight = ts.ckks_vector(context, self.weight)
        self.bias = ts.ckks_vector(context, self.bias)
        
    def decrypt(self, context):
        self.weight = self.weight.decrypt()
        self.bias = self.bias.decrypt()
        

eelr = EncryptedLR(model)

In [47]:
# parameters
poly_mod_degree = 4096
coeff_mod_bit_sizes = [40, 20, 40]
# create TenSEALContext
ctx_eval = ts.context(ts.SCHEME_TYPE.CKKS, poly_mod_degree, -1, coeff_mod_bit_sizes)
# scale of ciphertext to use
ctx_eval.global_scale = 2 ** 20
# this key is needed for doing dot-product operations
ctx_eval.generate_galois_keys()

In [48]:
t_start = time()
enc_x_test = [ts.ckks_vector(ctx_eval, x.tolist()) for x in x_test]
t_end = time()
print(f"Encryption of the test-set took {int(t_end - t_start)} seconds")

Encryption of the test-set took 2 seconds


In [51]:
def encrypted_evaluation(model, enc_x_test, y_test):
    t_start = time()
    
    correct = 0
    for enc_x, y in zip(enc_x_test, y_test):
        # encrypted evaluation
        enc_out = model(enc_x)
        # plain comparison
        out = enc_out.decrypt()
        out = torch.tensor(out)
        out = torch.sigmoid(out)
        if torch.abs(out - y) < 0.5:
            correct += 1
    
    t_end = time()
    print(f"Evaluated test_set of {len(x_test)} entries in {int(t_end - t_start)} seconds")
    print(f"Accuracy: {correct}/{len(x_test)} = {correct / len(x_test)}")
    return correct / len(x_test)
    

encrypted_accuracy = encrypted_evaluation(eelr, enc_x_test, y_test)
diff_accuracy = plain_accuracy - encrypted_accuracy
print(f"Difference between plain and encrypted accuracies: {diff_accuracy}")
if diff_accuracy < 0:
    print("Oh! We got a better accuracy on the encrypted test-set! The noise was on our side...")

Evaluated test_set of 1271 entries in 5 seconds
Accuracy: 1081/1271 = 0.8505114083398898
Difference between plain and encrypted accuracies: -0.003147125244140625
Oh! We got a better accuracy on the encrypted test-set! The noise was on our side...
