<h1>Encrypted Inference: Linear Regression</h1>

In this tutorial, we train a linear regression model in plaintext on the Boston Housing dataset and then use the trained model to perform inference on the test data, first in plaintext and then on secret-shared (encrypted) data. The encrypted runs use the Falcon protocol with 3 parties and SPDZ with 3 and 5 parties. The results show that encrypted inference reaches nearly the same accuracy as plaintext inference.

In [1]:
import pandas as pd
import numpy as np

In [2]:
import torch
import torch.nn as nn
import torch.utils.data as data_utils

In [3]:
# Set a manual seed so that results are reproducible
torch.manual_seed(0)

<torch._C.Generator at 0x7fd8fc4422d0>

<h2>Data Loading and Processing</h2>

In [4]:
# Import the dataset and drop the leftover index column
dataset = pd.read_csv("dataset/Boston.csv")
dataset = dataset.drop("Unnamed: 0", axis=1)

In [5]:
#Visualize and look at columns and rows of dataset
dataset.head()

Unnamed: 0,crim,zn,indus,chas,nox,rm,age,dis,rad,tax,ptratio,black,lstat,medv
0,0.00632,18.0,2.31,0,0.538,6.575,65.2,4.09,1,296,15.3,396.9,4.98,24.0
1,0.02731,0.0,7.07,0,0.469,6.421,78.9,4.9671,2,242,17.8,396.9,9.14,21.6
2,0.02729,0.0,7.07,0,0.469,7.185,61.1,4.9671,2,242,17.8,392.83,4.03,34.7
3,0.03237,0.0,2.18,0,0.458,6.998,45.8,6.0622,3,222,18.7,394.63,2.94,33.4
4,0.06905,0.0,2.18,0,0.458,7.147,54.2,6.0622,3,222,18.7,396.9,5.33,36.2


In [6]:
X_data = dataset.drop("medv",axis=1)
y_data = dataset["medv"]

In [7]:
X_data = X_data.apply(
    lambda x: (x - x.mean()) / x.std()
)
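The cell above standardizes every feature column to zero mean and unit standard deviation (a z-score). The following is a minimal sketch of the same transformation on a toy frame; the column values are made up for illustration. Note that pandas' `std()` uses the sample standard deviation (`ddof=1`) by default.

```python
import pandas as pd

# Toy data with made-up values; only the transformation matters here.
toy = pd.DataFrame({"rm": [5.0, 6.0, 7.0], "tax": [200.0, 300.0, 400.0]})

# Same transformation as in the tutorial: subtract the column mean and
# divide by the column's sample standard deviation (pandas default ddof=1).
standardized = toy.apply(lambda col: (col - col.mean()) / col.std())

print(standardized)
print(standardized.mean())  # per-column mean after standardization
print(standardized.std())   # per-column standard deviation after standardization
```

Standardizing puts all features on a comparable scale, which generally helps plain SGD converge with a single learning rate.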

In [8]:
y_data

0      24.0
1      21.6
2      34.7
3      33.4
4      36.2
       ... 
501    22.4
502    20.6
503    23.9
504    22.0
505    11.9
Name: medv, Length: 506, dtype: float64

In [9]:
X_data

Unnamed: 0,crim,zn,indus,chas,nox,rm,age,dis,rad,tax,ptratio,black,lstat
0,-0.419367,0.284548,-1.286636,-0.272329,-0.144075,0.413263,-0.119895,0.140075,-0.981871,-0.665949,-1.457558,0.440616,-1.074499
1,-0.416927,-0.487240,-0.592794,-0.272329,-0.739530,0.194082,0.366803,0.556609,-0.867024,-0.986353,-0.302794,0.440616,-0.491953
2,-0.416929,-0.487240,-0.592794,-0.272329,-0.739530,1.281446,-0.265549,0.556609,-0.867024,-0.986353,-0.302794,0.396035,-1.207532
3,-0.416338,-0.487240,-1.305586,-0.272329,-0.834458,1.015298,-0.809088,1.076671,-0.752178,-1.105022,0.112920,0.415751,-1.360171
4,-0.412074,-0.487240,-1.305586,-0.272329,-0.834458,1.227362,-0.510674,1.076671,-0.752178,-1.105022,0.112920,0.440616,-1.025487
...,...,...,...,...,...,...,...,...,...,...,...,...,...
501,-0.412820,-0.487240,0.115624,-0.272329,0.157968,0.438881,0.018654,-0.625178,-0.981871,-0.802418,1.175303,0.386834,-0.417734
502,-0.414839,-0.487240,0.115624,-0.272329,0.157968,-0.234316,0.288648,-0.715931,-0.981871,-0.802418,1.175303,0.440616,-0.500355
503,-0.413038,-0.487240,0.115624,-0.272329,0.157968,0.983986,0.796661,-0.772919,-0.981871,-0.802418,1.175303,0.440616,-0.982076
504,-0.407361,-0.487240,0.115624,-0.272329,0.157968,0.724955,0.736268,-0.667776,-0.981871,-0.802418,1.175303,0.402826,-0.864446


In [10]:
features = torch.tensor(X_data.values.astype(np.float16).astype(np.float32)) 
targets = torch.tensor(y_data.values.astype(np.float16).astype(np.float32))
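Casting to `float16` and back to `float32` deliberately rounds every value to half precision. The short sketch below, which uses a few hand-picked target values, shows the effect; it explains why targets printed later in this tutorial look like 11.8984375 rather than 11.9.

```python
import numpy as np

# A few example target values (house prices in the dataset look like these).
values = np.array([11.9, 27.9, 17.2], dtype=np.float64)

# Same round trip as in the tutorial: float16 keeps only about 3 decimal digits.
round_tripped = values.astype(np.float16).astype(np.float32)

for original, stored in zip(values, round_tripped):
    print(f"{original} -> {float(stored)}")
```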

In [11]:
features

tensor([[-0.4194,  0.2847, -1.2871,  ..., -1.4580,  0.4407, -1.0742],
        [-0.4170, -0.4873, -0.5928,  ..., -0.3027,  0.4407, -0.4919],
        [-0.4170, -0.4873, -0.5928,  ..., -0.3027,  0.3960, -1.2080],
        ...,
        [-0.4131, -0.4873,  0.1156,  ...,  1.1758,  0.4407, -0.9819],
        [-0.4075, -0.4873,  0.1156,  ...,  1.1758,  0.4028, -0.8643],
        [-0.4146, -0.4873,  0.1156,  ...,  1.1758,  0.4407, -0.6685]])

In [12]:
# Hyperparameters
batch_size = 16
epochs = 500
train_test_split = 0.8  # fraction of samples used for training
lr = 0.001  # same learning rate that is passed to the optimizer below

In [13]:
# Index that separates the training samples from the test samples
split_index = int(len(features)*train_test_split)

train_x = features[:split_index]
train_y = targets[:split_index]

# The test set starts at split_index+1, so the sample at split_index is in neither set
test_x = features[split_index+1:]
test_y = targets[split_index+1:]
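The split above starts the test slices one position after the split index, so one sample ends up in neither set. If you prefer a split that uses every sample, a sketch along the following lines should work. The stand-in tensors only mimic the row count of the dataset, and `split_index` is a name introduced for this example.

```python
import torch

# Stand-in data with the same number of rows as the tutorial's dataset.
features = torch.arange(506 * 2, dtype=torch.float32).reshape(506, 2)
targets = torch.arange(506, dtype=torch.float32)
train_test_split = 0.8

# One index separates the two sets; both slices share it, so no row is skipped.
split_index = int(len(features) * train_test_split)

train_x, test_x = features[:split_index], features[split_index:]
train_y, test_y = targets[:split_index], targets[split_index:]

print(len(train_x), len(test_x))
```

Using this variant would shift the test set by one row, so the losses would differ slightly from the outputs printed in this tutorial.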

In [14]:
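def get_batches(X,y):
    # Split X and y into consecutive batches of at most batch_size samples
    batches = []
    
    # Iterate over the argument X (not the global train_x) and stop before len(X)
    for index in range(0,len(X),batch_size):
        batches.append((X[index:index+batch_size],y[index:index+batch_size]))
    
    return batches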
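As an alternative to a hand-written batching helper, `torch.utils.data` (already imported above as `data_utils`) can produce the same consecutive batches. This is a sketch with random stand-in tensors of the same shape as the training split, not the tutorial's actual data.

```python
import torch
import torch.utils.data as data_utils

# Random stand-ins with the shapes of the training split (404 rows, 13 features).
train_x = torch.randn(404, 13)
train_y = torch.randn(404)
batch_size = 16

# TensorDataset pairs each feature row with its target;
# DataLoader groups the pairs into batches (kept in order with shuffle=False).
dataset = data_utils.TensorDataset(train_x, train_y)
loader = data_utils.DataLoader(dataset, batch_size=batch_size, shuffle=False)

batches = list(loader)
print(len(batches))          # number of batches
print(batches[-1][0].shape)  # the last batch holds the remaining rows
```

Setting `shuffle=True` would additionally reshuffle the samples every epoch, which the manual helper does not do.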

In [15]:
train_batches=get_batches(train_x,train_y)

<h1>Plaintext Training</h1>

In [16]:
import syft as sy

In [17]:
class LinearSyNet(sy.Module):
    def __init__(self, torch_ref):
        super(LinearSyNet, self).__init__(torch_ref=torch_ref)
        self.fc1 = self.torch_ref.nn.Linear(13,1)

    def forward(self, x):
        x = self.fc1(x)
        return x
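The model consists of a single `Linear(13, 1)` layer, so its forward pass is just a weighted sum of the 13 features plus a bias. The sketch below checks this on random stand-in inputs with plain PyTorch (no Syft wrapper).

```python
import torch

torch.manual_seed(0)

layer = torch.nn.Linear(13, 1)  # weight has shape (1, 13), bias has shape (1,)
x = torch.randn(4, 13)          # four random stand-in samples

# Manual forward pass: y = x @ W^T + b
manual = x @ layer.weight.T + layer.bias

print(manual.shape)
print(torch.allclose(layer(x), manual, atol=1e-5))
```

Note that the output has shape `(batch, 1)`, which is why the training loop flattens it with `reshape([-1])` before comparing it with the targets.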

In [18]:
model = LinearSyNet(torch)
criterion = torch.nn.MSELoss(reduction='mean') 
optimizer = torch.optim.SGD(model.parameters(), lr=0.001)

In [19]:
for epoch in range(epochs):
  running_loss = 0.0
  for batch_x, batch_y in train_batches:
    # Clear the gradient buffers so that gradients from the previous step do not accumulate
    optimizer.zero_grad()

    # Get the output of the model for the given inputs
    outputs = model(batch_x).reshape([-1])

    # Compute the loss for the predicted output
    loss = criterion(outputs,batch_y)
    # Accumulate a plain float so the autograd graph of each batch can be freed
    running_loss += loss.item()
    # Compute the gradients w.r.t. the parameters
    loss.backward()

    # Update the parameters
    optimizer.step()

  # "testacc" in the log below is the MSE loss on the test set (lower is better)
  with torch.no_grad():
    test_loss = criterion(model(test_x).reshape([-1]),test_y)
  print(f"Epoch {epoch}/{epochs}  Running Loss : {running_loss/batch_size} and testacc : {test_loss.item()}")

Epoch 0/500  Running Loss : 978.0725708007812 and testacc : 367.93927001953125
Epoch 1/500  Running Loss : 831.0911865234375 and testacc : 411.168212890625
Epoch 2/500  Running Loss : 720.55322265625 and testacc : 440.28497314453125
Epoch 3/500  Running Loss : 633.6892700195312 and testacc : 457.1828308105469
Epoch 4/500  Running Loss : 562.962646484375 and testacc : 464.29931640625
Epoch 5/500  Running Loss : 503.80291748046875 and testacc : 463.9801330566406
Epoch 6/500  Running Loss : 453.3403015136719 and testacc : 458.2357177734375
Epoch 7/500  Running Loss : 409.6958312988281 and testacc : 448.6855773925781
Epoch 8/500  Running Loss : 371.5797424316406 and testacc : 436.5837707519531
Epoch 9/500  Running Loss : 338.0628967285156 and testacc : 422.87548828125
Epoch 10/500  Running Loss : 308.4455261230469 and testacc : 408.2576904296875
Epoch 11/500  Running Loss : 282.1790466308594 and testacc : 393.2343444824219
Epoch 12/500  Running Loss : 258.81988525390625 and testacc : 378.1

Epoch 123/500  Running Loss : 41.93108367919922 and testacc : 37.597877502441406
Epoch 124/500  Running Loss : 41.893070220947266 and testacc : 37.24527359008789
Epoch 125/500  Running Loss : 41.85566711425781 and testacc : 36.89934539794922
Epoch 126/500  Running Loss : 41.818870544433594 and testacc : 36.5599479675293
Epoch 127/500  Running Loss : 41.78266525268555 and testacc : 36.22693634033203
Epoch 128/500  Running Loss : 41.74702453613281 and testacc : 35.90016174316406
Epoch 129/500  Running Loss : 41.71194839477539 and testacc : 35.579463958740234
Epoch 130/500  Running Loss : 41.67741775512695 and testacc : 35.264793395996094
Epoch 131/500  Running Loss : 41.64341354370117 and testacc : 34.95596694946289
Epoch 132/500  Running Loss : 41.609928131103516 and testacc : 34.652870178222656
Epoch 133/500  Running Loss : 41.576942443847656 and testacc : 34.35540008544922
Epoch 134/500  Running Loss : 41.544464111328125 and testacc : 34.063453674316406
Epoch 135/500  Running Loss : 4

Epoch 250/500  Running Loss : 39.58168411254883 and testacc : 20.645917892456055
Epoch 251/500  Running Loss : 39.573429107666016 and testacc : 20.62322998046875
Epoch 252/500  Running Loss : 39.56526565551758 and testacc : 20.601301193237305
Epoch 253/500  Running Loss : 39.55717086791992 and testacc : 20.580106735229492
Epoch 254/500  Running Loss : 39.54916763305664 and testacc : 20.559650421142578
Epoch 255/500  Running Loss : 39.54125213623047 and testacc : 20.53989601135254
Epoch 256/500  Running Loss : 39.53340148925781 and testacc : 20.52086639404297
Epoch 257/500  Running Loss : 39.525630950927734 and testacc : 20.502532958984375
Epoch 258/500  Running Loss : 39.5179443359375 and testacc : 20.484880447387695
Epoch 259/500  Running Loss : 39.51033020019531 and testacc : 20.467897415161133
Epoch 260/500  Running Loss : 39.50278854370117 and testacc : 20.45157814025879
Epoch 261/500  Running Loss : 39.495323181152344 and testacc : 20.4359073638916
Epoch 262/500  Running Loss : 39

Epoch 376/500  Running Loss : 38.97275161743164 and testacc : 21.0404109954834
Epoch 377/500  Running Loss : 38.970130920410156 and testacc : 21.05553436279297
Epoch 378/500  Running Loss : 38.9675407409668 and testacc : 21.070709228515625
Epoch 379/500  Running Loss : 38.9649772644043 and testacc : 21.0859375
Epoch 380/500  Running Loss : 38.96242904663086 and testacc : 21.10120964050293
Epoch 381/500  Running Loss : 38.959896087646484 and testacc : 21.116533279418945
Epoch 382/500  Running Loss : 38.9573860168457 and testacc : 21.131898880004883
Epoch 383/500  Running Loss : 38.95490646362305 and testacc : 21.14729881286621
Epoch 384/500  Running Loss : 38.95243453979492 and testacc : 21.162734985351562
Epoch 385/500  Running Loss : 38.94998550415039 and testacc : 21.178211212158203
Epoch 386/500  Running Loss : 38.94756317138672 and testacc : 21.193729400634766
Epoch 387/500  Running Loss : 38.945152282714844 and testacc : 21.209272384643555
Epoch 388/500  Running Loss : 38.94276809
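Because the model is linear and the loss is the mean squared error, the SGD result can be sanity-checked against the closed-form least-squares solution. The sketch below does this with NumPy on synthetic data; the weights, bias, and noise level are made up for the example and are unrelated to the Boston dataset.

```python
import numpy as np

rng = np.random.default_rng(0)

# Synthetic regression problem with known weights and bias.
X = rng.normal(size=(200, 3))
true_w = np.array([2.0, -1.0, 0.5])
true_b = 4.0
y = X @ true_w + true_b + rng.normal(scale=0.01, size=200)

# Append a column of ones so the bias is learned as one extra weight.
X_aug = np.hstack([X, np.ones((len(X), 1))])

# Least-squares solution of X_aug @ solution ≈ y.
solution, *_ = np.linalg.lstsq(X_aug, y, rcond=None)
w_hat, b_hat = solution[:-1], solution[-1]

print(w_hat, b_hat)
```

If a long SGD run ends far from such a baseline on the training data, the learning rate or the number of epochs is a likely place to look.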

<h1>Plaintext Inference</h1>

In [20]:
plaintext_predictions = model(test_x)

In [21]:
print("MSE Loss: ",criterion(plaintext_predictions,test_y).item())

MSE Loss:  63.399288177490234


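Note that `plaintext_predictions` has shape `(N, 1)` while `test_y` has shape `(N,)`. PyTorch warns about this size mismatch and broadcasts the two tensors before averaging, so the value printed here is not directly comparable with the per-epoch test loss printed during training, where the outputs are first flattened with `reshape([-1])`. The encrypted inference later in this tutorial computes its loss in the same way, so the plaintext and encrypted values can still be compared with each other.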
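The plaintext loss above is computed between a prediction tensor of shape `(N, 1)` and a target tensor of shape `(N,)`. The sketch below uses three made-up values to show how `MSELoss` behaves in that situation: the tensors are broadcast against each other, and the result differs from the loss on flattened predictions.

```python
import warnings
import torch

criterion = torch.nn.MSELoss(reduction="mean")

predictions = torch.tensor([[1.0], [2.0], [3.0]])  # shape (3, 1), like model(test_x)
targets = torch.tensor([1.0, 2.0, 3.0])            # shape (3,), like test_y

with warnings.catch_warnings():
    warnings.simplefilter("ignore")  # silence the size-mismatch warning for this demo
    # Shapes (3, 1) and (3,) broadcast to (3, 3): every prediction is
    # compared with every target, not only with its own.
    broadcast_loss = criterion(predictions, targets).item()

# Flattening the predictions first compares element i with element i.
flat_loss = criterion(predictions.reshape(-1), targets).item()

print(broadcast_loss)
print(flat_loss)
```

Here the predictions match the targets exactly, so only the flattened variant reports a loss of zero.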


<h1>Encrypted Inference</h1>

In [22]:
import syft as sy
from sympc.module.nn import mse_loss
import sympc
from sympc.session import Session
from sympc.session import SessionManager
from sympc.tensor import MPCTensor
from sympc.optim import SGD
from sympc.config import Config
from sympc.protocol import Falcon,FSS
import time

In [23]:
def get_clients(n_parties):
  # Define the virtual machines that will be used in the computation
  parties=[]

  for index in range(n_parties):
    parties.append(sy.VirtualMachine(name = "worker"+str(index)).get_root_client())

  return parties

In [24]:
def inference(n_clients,protocol):
  # Create one virtual party per client
  parties=get_clients(n_clients)

  # Set up the session for the computation
  session = Session(parties = parties,protocol = protocol)
  SessionManager.setup_mpc(session)

  # Secret-share the trained plaintext model and the test data among the parties
  mpc_model = model.share(session)
  test_data=MPCTensor(secret=test_x, session = session)

  # Time only the encrypted forward pass
  start_time = time.time()
  enc_results = mpc_model(test_data)
  end_time = time.time()

  print(f"Time for inference: {end_time-start_time}s")

  # Reconstruct the plaintext predictions from the shares
  predictions = enc_results.reconstruct()

  return predictions
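To build some intuition for what "sharing" a value among parties means, here is a minimal sketch of additive secret sharing over integers in pure Python. It is an illustration only and not a description of SyMPC's internals or of the Falcon protocol; the ring size and the helper names `share`, `reconstruct`, and `add_shared` are assumptions made for this example.

```python
import secrets

RING_SIZE = 2 ** 64  # assumed ring for this sketch

def share(secret, n_parties):
    """Split `secret` into n_parties additive shares that sum to it modulo RING_SIZE."""
    shares = [secrets.randbelow(RING_SIZE) for _ in range(n_parties - 1)]
    last = (secret - sum(shares)) % RING_SIZE
    return shares + [last]

def reconstruct(shares):
    """Recover the secret by adding all shares modulo RING_SIZE."""
    return sum(shares) % RING_SIZE

def add_shared(a_shares, b_shares):
    """Each party adds its own two shares locally; no communication is needed."""
    return [(a + b) % RING_SIZE for a, b in zip(a_shares, b_shares)]

a_shares = share(20, 3)
b_shares = share(22, 3)

print(reconstruct(a_shares))                        # the original secret
print(reconstruct(add_shared(a_shares, b_shares)))  # sum computed on shares
```

Any single share is a uniformly random number and reveals nothing about the secret by itself. Multiplication on shares needs extra machinery, which is where protocols differ.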

In [25]:
predictions=inference(3,Falcon("semi-honest"))

Time for inference: 0.033750057220458984s
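A single wall-clock measurement is noisy, which is why the times printed in this tutorial differ slightly from run to run. A hedged sketch of a more stable measurement is to average several runs with `time.perf_counter`. The helper `time_call` is hypothetical, and the timed function below is a cheap stand-in rather than an encrypted forward pass.

```python
import time

def time_call(fn, repeats=5):
    """Return the mean wall-clock time in seconds of calling fn() `repeats` times."""
    durations = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn()
        durations.append(time.perf_counter() - start)
    return sum(durations) / len(durations)

# Stand-in workload; in the tutorial this would be the encrypted forward pass.
mean_seconds = time_call(lambda: sum(range(100_000)))
print(f"Mean time over 5 runs: {mean_seconds:.6f}s")
```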


In [26]:
print("MSE Loss: ",criterion(predictions,test_y).item())

MSE Loss:  tensor(63.3991)


We can see that the predictions and the mean squared error of the encrypted inference are almost the same as those of the plaintext model. The small differences are due to precision loss in the encrypted computation.
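One common source of such precision loss is fixed-point encoding: secret sharing works on integers, so real numbers are scaled and rounded before they are shared. The sketch below assumes 16 fractional bits, which is consistent with the encrypted outputs printed in this tutorial being multiples of 2^-16; the helpers `encode` and `decode` are hypothetical and not SyMPC functions.

```python
FRACTIONAL_BITS = 16         # assumed precision for this sketch
SCALE = 2 ** FRACTIONAL_BITS

def encode(value):
    """Map a float to an integer by scaling and rounding."""
    return round(value * SCALE)

def decode(encoded):
    """Map the integer back to a float."""
    return encoded / SCALE

plaintext = 1.8402048349380493  # a plaintext prediction printed in this tutorial
decoded = decode(encode(plaintext))

print(decoded)
print(abs(decoded - plaintext))  # rounding error, at most 2**-17
```

In a real encrypted computation the inputs and the weights are typically encoded separately and the product is truncated afterwards, so the total error can be somewhat larger than this single round trip.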

In [27]:
for index in range(0,10):
    print(f"Index {index}")
    print(f"Encrypted Prediction Output {predictions[index].item()}")
    print(f"Plaintext Prediction Output {plaintext_predictions[index].item()}")
    print(f"Expected Prediction: {test_y[index]}")
    print("\n")

Index 0
Encrypted Prediction Output 1.8402252197265625
Plaintext Prediction Output 1.8402048349380493
Expected Prediction: 5.0


Index 1
Encrypted Prediction Output 5.19158935546875
Plaintext Prediction Output 5.191580772399902
Expected Prediction: 11.8984375


Index 2
Encrypted Prediction Output 19.786148071289062
Plaintext Prediction Output 19.786155700683594
Expected Prediction: 27.90625


Index 3
Encrypted Prediction Output 13.1734619140625
Plaintext Prediction Output 13.173490524291992
Expected Prediction: 17.203125


Index 4
Encrypted Prediction Output 20.987045288085938
Plaintext Prediction Output 20.987075805664062
Expected Prediction: 27.5


Index 5
Encrypted Prediction Output 14.210540771484375
Plaintext Prediction Output 14.210532188415527
Expected Prediction: 15.0


Index 6
Encrypted Prediction Output 19.272994995117188
Plaintext Prediction Output 19.273014068603516
Expected Prediction: 17.203125


Index 7
Encrypted Prediction Output 1.7180023193359375
Plaintext Prediction 

Falcon can also provide a malicious security guarantee, but at the cost of a much longer inference time (about 0.58 s versus 0.03 s in the runs shown here).

In [28]:
predictions=inference(3,Falcon("malicious"))

Time for inference: 0.5756688117980957s


In [29]:
predictions=inference(3,FSS())

[2021-07-13T17:54:24.721347+0530][CRITICAL][logger]][31441] <class 'syft.core.store.store_memory.MemoryStore'> __delitem__ error <UID: e43609e4f3b24af1bab46d6a3463c6cf>.
[2021-07-13T17:54:24.724769+0530][CRITICAL][logger]][31441] <class 'syft.core.store.store_memory.MemoryStore'> __delitem__ error <UID: 433768ff3ded40c586a6b517735d8271>.
[2021-07-13T17:54:24.726637+0530][CRITICAL][logger]][31441] <class 'syft.core.store.store_memory.MemoryStore'> __delitem__ error <UID: 47ca2a8b0f55424b8c7dee7dc0a0ccb0>.


Time for inference: 0.5407030582427979s


In [30]:
predictions=inference(5,FSS())

[2021-07-13T17:54:26.588460+0530][CRITICAL][logger]][31441] <class 'syft.core.store.store_memory.MemoryStore'> __delitem__ error <UID: 384b45903a294c2e85c2b480de81ce4d>.
[2021-07-13T17:54:26.595719+0530][CRITICAL][logger]][31441] <class 'syft.core.store.store_memory.MemoryStore'> __delitem__ error <UID: 35ffa19d3ebf42ce9a3c5a0a619d59a9>.
[2021-07-13T17:54:26.597316+0530][CRITICAL][logger]][31441] <class 'syft.core.store.store_memory.MemoryStore'> __delitem__ error <UID: 728a7f55e01c4ff98525a9b84f6a1cff>.


Time for inference: 0.9237020015716553s



<center> <h3> Comparison </h3> </center>

| Protocol | Security Type | Parties | Inference Time (s) |
| --- | --- | --- | --- |
| Falcon | Semi-honest | 3 | 0.03391 |
| Falcon | Malicious | 3 | 0.61146 |
| FSS | Semi-honest | 3 | 0.56047 |
| FSS | Semi-honest | 5 | 0.91995 |
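To make the table easier to read, the following sketch expresses each configuration as a slowdown relative to semi-honest Falcon. The timings are copied from the table; everything else is plain arithmetic.

```python
# Inference times in seconds, copied from the comparison table.
timings = {
    ("Falcon", "semi-honest", 3): 0.03391,
    ("Falcon", "malicious", 3): 0.61146,
    ("FSS", "semi-honest", 3): 0.56047,
    ("FSS", "semi-honest", 5): 0.91995,
}

# Use the fastest configuration as the baseline.
baseline = timings[("Falcon", "semi-honest", 3)]
slowdowns = {config: seconds / baseline for config, seconds in timings.items()}

for (protocol, security, parties), ratio in slowdowns.items():
    print(f"{protocol:<7}{security:<12}{parties} parties: {ratio:5.1f}x")
```

These ratios come from single runs on one machine, so they indicate an order of magnitude rather than a precise benchmark.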

Falcon is a 3-party protocol, and even its semi-honest setting provides a significant security guarantee at the lowest inference time measured here. For an arbitrary number of parties N, you can use the Function Secret Sharing (FSS) protocol instead.

<h3>What's next?</h3>

SyMPC is still under development! We will add more features here as soon as they are stable enough, so stay tuned! 🕺

If you enjoyed this tutorial, show your support by Starring SyMPC! 🙏