Below is the list of required files that need to be uploaded for this script to run:

- `nn_utils.py`
- `utils.py`
- `nn_relu_utils.py`
- `diabetes_binary_5050split_health_indicators_BRFSS2015.csv`

In [14]:
pip install tenseal



In [15]:
import torch
import random
import tenseal as ts
from time import time
from utils import load_diabetes_data_5050,create_dataloader, print_metrics
from nn_utils import train, evaluate_model
from nn_relu_utils import NeuralNet_Relu1

# Seed both PyTorch's and Python's random number generators with the same
# fixed value so repeated runs start from the same random state.
torch.random.manual_seed(73)
random.seed(73)

# Training Neural Network on Unencrypted Data

In [16]:
# Load the BRFSS dataset with 50/50 split. The helper from utils.py returns
# the train/test features followed by the train/test labels.
x_train, x_test, y_train, y_test = load_diabetes_data_5050()

In [17]:
# Wrap each split with the create_dataloader helper from utils.py:
# train_dl is consumed by train(), test_dl by evaluate_model().
train_dl = create_dataloader(x_train, y_train)
test_dl = create_dataloader(x_test, y_test)

In [18]:
# Plaintext model whose fc1/fc2 parameters are later copied into EncConvNet.
model = NeuralNet_Relu1()
# BCEWithLogitsLoss applies the sigmoid internally, so the network is
# expected to emit raw logits (the encrypted evaluation applies the sigmoid
# itself after decryption).
criterion = torch.nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
# The final argument (10) is presumably the number of epochs — the recorded
# log shows ten epochs; confirm against nn_utils.train.
model = train(model, train_dl, criterion, optimizer, 10)

Epoch: 1 	Training Loss: 0.516054
Epoch: 2 	Training Loss: 0.506100
Epoch: 3 	Training Loss: 0.504309
Epoch: 4 	Training Loss: 0.503408
Epoch: 5 	Training Loss: 0.502888
Epoch: 6 	Training Loss: 0.502523
Epoch: 7 	Training Loss: 0.502285
Epoch: 8 	Training Loss: 0.502063
Epoch: 9 	Training Loss: 0.501891
Epoch: 10 	Training Loss: 0.501775


In [19]:
# Evaluate the trained model on the unencrypted test set; these metrics are
# the baseline that the encrypted evaluation further down is compared against.
accuracy, precision, recall, f1, confusion = evaluate_model(model, test_dl)

# Print the evaluation metrics
print_metrics(accuracy, precision, recall, f1, confusion)

Evaluated test_set of 14139 entries in 0 seconds
Accuracy: 0.7566
Precision: 0.7260
Recall: 0.8254
F1 Score: 0.7725
Confusion Matrix:
 [[4852 2206]
 [1236 5845]]


# Evaluating the Neural Network on Encrypted Data

In [20]:
## Encryption Parameters

# Bit size of the CKKS scale: controls precision of the fractional part.
# The same value is reused below for the inner coefficient-modulus primes
# and for the global scale (2**bits_scale), so the three stay in sync.
bits_scale = 26

# Create TenSEAL context
# NOTE(review): the four inner primes presumably supply the multiplicative
# levels consumed by fc1, the degree-2 activation and fc2 — confirm against
# the TenSEAL documentation before changing the length of this list.
context = ts.context(
    ts.SCHEME_TYPE.CKKS,
    poly_modulus_degree=8192,
    coeff_mod_bit_sizes=[31, bits_scale, bits_scale, bits_scale, bits_scale, 31]
)

# set the scale
context.global_scale = pow(2, bits_scale)

# galois keys are required to do ciphertext rotations
context.generate_galois_keys()

In [21]:

# Encrypt the test set one sample at a time: every row of x_test becomes its
# own CKKS vector, so enc_x_test holds one ciphertext per test entry.
# The wall-clock time is measured because this step is slow (the recorded
# run reports 112 seconds).
t_start = time()
enc_x_test = [ts.ckks_vector(context, x.tolist()) for x in x_test]
t_end = time()
print(f"Encryption of the test-set took {int(t_end - t_start)} seconds")

Encryption of the test-set took 112 seconds


In [22]:
def relu_approx(enc_x):
    """Apply the polynomial stand-in for ReLU to an encrypted vector.

    The activation is replaced by a fixed degree-2 polynomial that is
    evaluated through ``enc_x.polyval``. With TenSEAL's lowest-degree-first
    coefficient order this is ``0.563059 + 0.5*x + 0.078047*x**2``.

    Args:
        enc_x: Object exposing ``polyval(coefficients)`` (a TenSEAL CKKS
            vector in this notebook).

    Returns:
        Whatever ``enc_x.polyval`` returns for the fixed coefficient list.
    """
    coefficients = [0.563059, 0.5, 0.078047]
    return enc_x.polyval(coefficients)

class EncConvNet:
    """Encrypted-inference counterpart of a trained two-layer network.

    Despite the name, only two fully connected layers (``fc1`` and ``fc2``)
    are evaluated, with a polynomial activation in between. The parameters
    are copied out of the trained model as plain nested Python lists so they
    can be combined with encrypted inputs.
    """

    def __init__(self, torch_nn):
        """Extract the parameters of ``torch_nn.fc1`` and ``torch_nn.fc2``.

        Args:
            torch_nn: Trained model exposing ``fc1`` and ``fc2`` layers, each
                with ``weight`` and ``bias`` tensors (presumably
                ``torch.nn.Linear`` layers — confirm in nn_relu_utils.py).
                Weights are stored transposed; biases are stored as-is.
        """
        first_layer = torch_nn.fc1
        self.fc1_weight = first_layer.weight.T.data.tolist()
        self.fc1_bias = first_layer.bias.data.tolist()

        second_layer = torch_nn.fc2
        self.fc2_weight = second_layer.weight.T.data.tolist()
        self.fc2_bias = second_layer.bias.data.tolist()

    def forward(self, enc_x):
        """Run the network on one encrypted input and return the result.

        Args:
            enc_x: Encrypted input supporting ``mm`` and ``+`` with plain
                lists (a TenSEAL CKKS vector in this notebook).

        Returns:
            The output of the second layer; no sigmoid is applied here.
        """
        # First fully connected layer: matrix product plus bias.
        hidden = enc_x.mm(self.fc1_weight) + self.fc1_bias
        # Polynomial replacement for the ReLU activation.
        activated = relu_approx(hidden)
        # Second fully connected layer produces the final output.
        return activated.mm(self.fc2_weight) + self.fc2_bias

    def __call__(self, *args, **kwargs):
        """Make instances callable by delegating to :meth:`forward`."""
        return self.forward(*args, **kwargs)


In [23]:
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix

def encrypted_evaluation(model, enc_x_test, y_test):
    """Score ``model`` on encrypted samples and compute classification metrics.

    Each encrypted sample is passed through ``model``. The result is
    decrypted, squashed with a sigmoid and thresholded at 0.5 in plaintext.
    The elapsed wall-clock time covers the whole loop, decryption included.

    Args:
        model: Callable taking one encrypted sample and returning an object
            with a ``decrypt()`` method.
        enc_x_test: Iterable of encrypted samples, paired position-wise with
            ``y_test``.
        y_test: Sized iterable of label tensors (each must support
            ``view(-1).tolist()``).

    Returns:
        Tuple ``(accuracy, precision, recall, f1, confusion)`` computed with
        the scikit-learn metric functions.
    """
    started = time()
    labels, predictions = [], []

    for ciphertext, target in zip(enc_x_test, y_test):
        # The forward pass runs on the encrypted sample.
        encrypted_logits = model(ciphertext)
        # Everything from here on happens in plaintext: decrypt, turn the
        # logits into probabilities and threshold them.
        probabilities = torch.sigmoid(torch.tensor(encrypted_logits.decrypt()))
        is_positive = probabilities >= 0.5
        labels += target.view(-1).tolist()
        predictions += is_positive.view(-1).tolist()

    elapsed_seconds = int(time() - started)
    print(f"Evaluated test_set of {len(y_test)} entries in {elapsed_seconds} seconds")

    # Calculate metrics
    return (
        accuracy_score(labels, predictions),
        precision_score(labels, predictions),
        recall_score(labels, predictions),
        f1_score(labels, predictions),
        confusion_matrix(labels, predictions),
    )

In [24]:
# Build the encrypted-inference wrapper from the trained plaintext model and
# score it on the encrypted test set. This overwrites the metric variables
# from the plaintext evaluation cell, and the run is slow (the recorded
# output reports 2076 seconds).
enc_model = EncConvNet(model)
accuracy, precision, recall, f1, confusion = encrypted_evaluation(enc_model, enc_x_test, y_test)
print_metrics(accuracy, precision, recall, f1, confusion)

Evaluated test_set of 14139 entries in 2076 seconds
Accuracy: 0.7557
Precision: 0.7367
Recall: 0.7971
F1 Score: 0.7657
Confusion Matrix:
 [[5041 2017]
 [1437 5644]]
